当前位置:   article > 正文

Tensorflow (1): 读取数据的三种方式及tfrecord的使用_placeholder 使用tfrecord

placeholder 使用tfrecord

参考: https://blog.csdn.net/lujiandong1/article/details/53376802
https://blog.csdn.net/happyhorizion/article/details/77894055

读取数据的三种方式

Preloaded data: 预加载数据

import tensorflow as tf  
# 设计Graph  
x1 = tf.constant([2, 3, 4])  
x2 = tf.constant([4, 0, 1])  
y = tf.add(x1, x2)  
# 打开一个session --> 计算y  
with tf.Session() as sess:  
    print sess.run(y) 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

Feeding: Python产生数据,再把数据喂给后端

import tensorflow as tf  
# 设计Graph  
x1 = tf.placeholder(tf.int16)  
x2 = tf.placeholder(tf.int16)  
y = tf.add(x1, x2)  
# 用Python产生数据  
li1 = [2, 3, 4]  
li2 = [4, 0, 1]  
# 打开一个session --> 喂数据 --> 计算y  
with tf.Session() as sess:  
    print sess.run(y, feed_dict={x1: li1, x2: li2})  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

Reading from file: 从文件中直接读取

aaa

#-*- coding:utf-8 -*-
import tensorflow as tf
# 生成一个先入先出队列和一个QueueRunner,生成文件名队列
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
# 定义Reader
reader = tf.TextLineReader()
# key, value会分别得到文件名及行数和文件内容:['A.csv:3', 'Alpha3,A3']
key, value = reader.read(filename_queue)
# 定义Decoder
# sess.run([example, label])
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2)
# 运行Graph
with tf.Session() as sess:
    coord = tf.train.Coordinator()  #创建一个协调器,管理线程
    threads = tf.train.start_queue_runners(coord=coord)  #启动QueueRunner, 此时文件名队列已经进队。
    for i in range(10):
        train_1, label_1 = sess.run([example_batch, label_batch])
        print train_1, label_1
    coord.request_stop()
    coord.join(threads)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

也可以使用shuffle_batch来实现,其中shuffle_batch不使用时数据是按照顺序的。
batch和batch_join的区别: 一般来说,单一文件多线程,那么选用tf.train.batch(需要打乱样本,有对应的tf.train.shuffle_batch);而对于多线程多文件的情况,一般选用tf.train.batch_join来获取样本(打乱样本同样也有对应的tf.train.shuffle_batch_join使用),与多个reader对应

#-*- coding:utf-8 -*-
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [['null'], ['null']]
#定义了多种解码器,每个解码器跟一个reader相连
example_list = [tf.decode_csv(value, record_defaults=record_defaults)
                  for _ in range(2)]  # Reader设置为2
# 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。
example_batch, label_batch = tf.train.shuffle_batch_join(
      example_list, batch_size=5, capacity=200,
                       min_after_dequeue=10)
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10):
        e_val,l_val = sess.run([example_batch,label_batch])
        print e_val,l_val
    coord.request_stop()
    coord.join(threads)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

使用多个reader可以并行读取数据,提高效率
使用join可以保证实例与标签对应
也可以在代码中设置epochs

tfrecords

除了使用csv或者其他格式的数据,推荐使用tf内定标准格式——tfrecords

制作tfrecords

import tensorflow as tf
import numpy as np

# create tfrecord_writer
tfrecords_filename = 'test_tf.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename)

for i in range(100):
    img_raw = np.random.random_integers(0, 255, size=(7, 30))
    img_raw = img_raw.tostring()
    example = tf.train.Example(features = tf.train.Features(
        feature = {
            'label' : tf.train.Feature(int64_list = tf.train.Int64List(value = [i])),
            'img_raw' : tf.train.Feature(bytes_list = tf.train.BytesList(value = [img_raw]))
        }
    ))
    writer.write(example.SerializeToString())

writer.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

由于存在于内存中的对象都是暂时的,无法长期驻存,为了把对象的状态保持下来,这时需要把对象写入到磁盘或者其他介质中,这个过程就叫做序列化。不序列化则无法保存

解析tfrecors

#encoding: utf-8
import tensorflow as tf
import numpy as np
from PIL import Image


if __name__=='__main__':
    tfrecords_filename = "test_tf.tfrecords"
    filename_queue = tf.train.string_input_producer([tfrecords_filename],) #读入流中
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })  #取出包含image和label的feature对象
    image = tf.decode_raw(features['img_raw'],tf.int64)
    image = tf.reshape(image, [7,30])
    label = tf.cast(features['label'], tf.int64)
    image, label = tf.train.shuffle_batch([image, label],
                                               batch_size= 2, capacity=200, min_after_dequeue=10)

    with tf.Session() as sess: #开始一个会话
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
        coord=tf.train.Coordinator()
        threads= tf.train.start_queue_runners(coord=coord)
        for i in range(20):
            example, l = sess.run([image,label])#在会话中取出image和label
            img=Image.fromarray(example, 'RGB')#这里Image是之前提到的
            img.save('./'+str(i)+'_''Label_'+str(l)+'.jpg')#存下图片
            print(example, l)

        coord.request_stop()
        coord.join(threads)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

Note: 这里使用shuffle_batch_size就出问题了,暂不清楚为什么。这样的话,以上代码不知道是否数据与标签统一,看晚上的做法都是这样做的
对于以上问题,经过实验发现,并不会导致数据与标签不一致

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/249791
推荐阅读
相关标签
  

闽ICP备14008679号