当前位置:   article > 正文

TensorFlow之数据读取_tensorflow.data

tensorflow.data

小数据量读取

这仅用于可以完全加载到存储器中的小的数据集有两种方法:

  • 存储在常数中。
  • 存储在变量中,初始化后,永远不要改变它的值。

使用常数更简单一些,但是会使用更多的内存,因为常数会内联的存储在数据流图数据结构中,这个结构体可能会被复制几次。

  1. training_data = ...
  2. training_labels = ...
  3. with tf.Session() as sess:
  4. input_data = tf.constant(training_data)
  5. input_labels = tf.constant(training_labels)

要改为使用变量的方式,您就需要在数据流图建立后初始化这个变量。

  1. training_data = ...
  2. training_labels = ...
  3. with tf.Session() as sess:
  4. data_initializer = tf.placeholder(dtype=training_data.dtype,
  5. shape=training_data.shape)
  6. label_initializer = tf.placeholder(dtype=training_labels.dtype,
  7. shape=training_labels.shape)
  8. input_data = tf.Variable(data_initalizer, trainable=False, collections=[])
  9. input_labels = tf.Variable(label_initalizer, trainable=False, collections=[])
  10. ...
  11. sess.run(input_data.initializer,
  12. feed_dict={data_initializer: training_data})
  13. sess.run(input_labels.initializer,
  14. feed_dict={label_initializer: training_lables})

设定trainable=False可以防止该变量被数据流图的GraphKeys.TRAINABLE_VARIABLES收集,这样我们就不会在训练的时候尝试更新它的值;设定collections=[]可以防止GraphKeys.VARIABLES收集后做为保存和恢复的中断点。设定这些标志,是为了减少额外的开销。

文件读取流程

 先看下文件读取以及读取数据处理成张量结果的过程:

一般数据文件格式有文本、excel和图片数据。那么TensorFlow都有对应的解析函数,除了这几种。还有TensorFlow指定的文件格式。

TensorFlow还提供了一种内置文件格式TFRecord,二进制数据和训练类别标签数据存储在同一文件。模型训练前图像等文本信息转换为TFRecord格式。TFRecord文件是protobuf格式。数据不压缩,可快速加载到内存。TFRecords文件包含 tf.train.Example protobuf,需要将Example填充到协议缓冲区,将协议缓冲区序列化为字符串,然后使用该文件将该字符串写入TFRecords文件。在图像操作我们会介绍整个过程以及详细参数。

 数据读取的实现

 

文件队列生成函数

  • tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, name=None)

产生指定文件张量

文件阅读器类

  • class tf.TextLineReader

阅读文本文件逗号分隔值(CSV)格式

  • tf.FixedLengthRecordReader

要读取每个记录是固定数量字节的二进制文件

  • tf.TFRecordReader

读取TfRecords文件

解码

由于从文件中读取的是字符串,需要函数去解析这些字符串到张量

  • tf.decode_csv(records,record_defaults,field_delim = None,name = None)将CSV转换为张量,与tf.TextLineReader搭配使用

  • tf.decode_raw(bytes,out_type,little_endian = None,name = None) 将字节转换为一个数字向量表示,字节为一字符串类型的张量,与函数tf.FixedLengthRecordReader搭配使用

生成文件队列

将文件名列表交给tf.train.string_input_producer函数。string_input_producer来生成一个先入先出的队列,文件阅读器会需要它们来取数据。string_input_producer提供的可配置参数来设置文件名乱序和最大的训练迭代数,QueueRunner会为每次迭代(epoch)将所有的文件名加入文件名队列中,如果shuffle=True的话,会对文件名进行乱序处理。一过程是比较均匀的,因此它可以产生均衡的文件名队列。

这个QueueRunner工作线程是独立于文件阅读器的线程,因此乱序和将文件名推入到文件名队列这些过程不会阻塞文件阅读器运行。根据你的文件格式,选择对应的文件阅读器,然后将文件名队列提供给阅读器的 read 方法。阅读器的read方法会输出一个键来表征输入的文件和其中纪录(对于调试非常有用),同时得到一个字符串标量,这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。

  1. # 读取CSV格式文件
  2. # 1、构建文件队列
  3. # 2、构建读取器,读取内容
  4. # 3、解码内容
  5. # 4、现读取一个内容,如果有需要,就批处理内容
  6. import tensorflow as tf
  7. import os
  8. def readcsv_decode(filelist):
  9. """
  10. 读取并解析文件内容
  11. :param filelist: 文件列表
  12. :return: None
  13. """
  14. # 把文件目录和文件名合并
  15. flist = [os.path.join("./csvdata/",file) for file in filelist]
  16. # 构建文件队列
  17. file_queue = tf.train.string_input_producer(flist,shuffle=False)
  18. # 构建阅读器,读取文件内容
  19. reader = tf.TextLineReader()
  20. key,value = reader.read(file_queue)
  21. record_defaults = [["null"],["null"]] # [[0],[0],[0],[0]]
  22. # 解码内容,按行解析,返回的是每行的列数据
  23. example,label = tf.decode_csv(value,record_defaults=record_defaults)
  24. # 通过tf.train.batch来批处理数据
  25. example_batch,label_batch = tf.train.batch([example,label],batch_size=9,num_threads=1,capacity=9)
  26. with tf.Session() as sess:
  27. # 线程协调员
  28. coord = tf.train.Coordinator()
  29. # 启动工作线程
  30. threads = tf.train.start_queue_runners(sess,coord=coord)
  31. # 这种方法不可取
  32. # for i in range(9):
  33. # print(sess.run([example,label]))
  34. # 打印批处理的数据
  35. print(sess.run([example_batch,label_batch]))
  36. coord.request_stop()
  37. coord.join(threads)
  38. return None
  39. if __name__=="__main__":
  40. filename_list = os.listdir("./csvdata")
  41. readcsv_decode(filename_list)

每次read的执行都会从文件中读取一行内容,注意,(这与后面的图片和TfRecords读取不一样),decode_csv操作会解析这一行内容并将其转为张量列表。如果输入的参数有缺失,record_default参数可以根据张量的类型来设置默认值。在调用run或者eval去执行read之前,你必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。

TensorFlow读取csv文件

 

  1. import tensorflow as tf
  2. import os
  3. os.environ["CUDA_VISIBLE_DEVICES"] = "0"#指定GPU
  4. """
  5. csv文件读取
  6. 1、先找到文件。构造一个列表
  7. 2、构造文件的队列
  8. 3、构造阅读器,读取队列内容(行)
  9. 4、解码内容
  10. 5、批处理(多个样本)
  11. """
  12. #批处理大小跟队列,数据没关系,只决定这批次的取多少数据
  13. def csv_read(filelist):
  14. """
  15. 读取csv文件
  16. """
  17. #1、构造文件队列
  18. file_queue = tf.train.string_input_producer(filelist,shuffle = True)
  19. #2、构造csv阅读器,以行读取
  20. reader = tf.TextLineReader()
  21. key,value = reader.read(file_queue)
  22. print (value)
  23. #3、对每一行内容解码
  24. #record_defaults:指定每一个样本的每一列的类型,指定默认值,
  25. records = [["None"],["None"]]#是字符串格式
  26. example,label = tf.decode_csv(value,record_defaults = records)
  27. #想要读取多个数据,就要批处理
  28. example_batch,label_batch = tf.train.batch([example,label],batch_size = 20,num_threads = 1,capacity=9)
  29. return example_batch,label_batch
  30. if __name__ == "__main__":
  31. #1、找到文件,放入列表 路径+名字
  32. file_name = os.listdir("./data/")#返回文件名列表
  33. #路径+名字
  34. filelist = [os.path.join("./data/", file) for file in file_name]
  35. example_batch,label_batch = csv_read(filelist)
  36. #开启会话
  37. with tf.Session() as sess:
  38. #定义一个线程协调器
  39. coord = tf.train.Coordinator()
  40. #开启读取文件的线程
  41. threads = tf.train.start_queue_runners(sess,coord = coord)
  42. #打印读取的内容
  43. print (sess.run([example_batch,label_batch]))
  44. #回收子线程
  45. coord.request_stop()
  46. coord.join(threads)

TensorFlow读取图片文件 

  1. import tensorflow as tf
  2. import os
  3. os.environ["CUDA_VISIBLE_DEVICES"] = "0"#指定GPU
  4. #tf.image.resize_images(images,size) 对图片进行缩放
  5. def images_read(filelist):
  6. """
  7. 读取图片文件并转换成张量
  8. """
  9. #1、构造文件队列
  10. file_queue = tf.train.string_input_producer(filelist,shuffle = True)
  11. #2、构造图片阅读器,默认按照一张一张读取
  12. reader = tf.WholeFileReader()
  13. key,value = reader.read(file_queue)
  14. print (value)
  15. #3、对读取的内容容解码
  16. image = tf.image.decode_png(value)
  17. #4、处理图片大小,统一大小
  18. image_resize = tf.image.resize_images(image,[200,200])
  19. #把样本的形状固定
  20. image_resize.set_shape([200,200,1])
  21. print (image_resize)
  22. #5、想要读取多个数据,就要批处理
  23. image_batch = tf.train.batch([image_resize],batch_size = 20,num_threads = 1,capacity=20)
  24. print (image_batch)
  25. return image_batch
  26. if __name__ == "__main__":
  27. #1、找到文件,放入列表 路径+名字
  28. file_name = os.listdir("./data/BSD68/")#返回文件名列表
  29. #路径+名字
  30. filelist = [os.path.join("./data/BSD68/", file) for file in file_name]
  31. image_batch = images_read(filelist)
  32. #开启会话
  33. with tf.Session() as sess:
  34. #定义一个线程协调器
  35. coord = tf.train.Coordinator()
  36. #开启读取文件的线程
  37. threads = tf.train.start_queue_runners(sess,coord = coord)
  38. #打印读取的内容
  39. print (sess.run(image_batch))
  40. #回收子线程
  41. coord.request_stop()
  42. coord.join(threads)

TensorFlow读取二进制文件

  1. import tensorflow as tf
  2. import os
  3. os.environ["CUDA_VISIBLE_DEVICES"] = "0"#指定GPU
  4. #读取二进制文件
  5. #定义命令行参数
  6. flags = tf.app.flags
  7. tf.app.flags.DEFINE_string("cifar_dir","./cifar-10-binary/","文件目录")
  8. FLAGS = flags.FLAGS
  9. class CifarRead(object):
  10. """
  11. 完成读取二进制文件,写进tfrecords,读取tfrecords
  12. """
  13. def __init__(self,filelist):
  14. #文件列表
  15. self.file_list = filelist
  16. #定义读取图片的属性
  17. self.height = 32
  18. self.width = 32
  19. self.channel = 3
  20. #二进制每张图片存储的字节
  21. self.label_bytes = 1
  22. self.image_bytes = self.height *self.width*self.channel
  23. self.bytes = self.label_bytes + self.image_bytes
  24. def read_and_decode(self):
  25. #1、构造文件队列
  26. file_queue = tf.train.string_input_producer(self.file_list)
  27. #2、构造文件读取器,读取内容,每个样本的字节数
  28. reader = tf.FixedLengthRecordReader(self.bytes)
  29. key,value = reader.read(file_queue)
  30. print (value)
  31. #3、解码内容,二进制文件解码,标签值和特征值在一起
  32. label_image = tf.decode_raw(value,tf.uint8)
  33. #4、分割出图片和标签数据,切出特征值和目标值
  34. label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32)#切片并转换类型
  35. image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
  36. #print (label,image)
  37. #5、可以对图片的特征数据进行形状改变
  38. image_reshape = tf.reshape(image,[self.height,self.width,self.channel])
  39. #6、批处理
  40. image_batch,label_batch = tf.train.batch([image_reshape,label],batch_size= 10,num_threads = 1,capacity = 10)
  41. return image_batch,label_batch
  42. if __name__ == "__main__":
  43. #1、找到文件,放入列表 路径+名字
  44. file_name = os.listdir(FLAGS.cifar_dir)#返回文件名列表
  45. #路径+名字
  46. filelist = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
  47. image_batch = CifarRead(filelist)
  48. image_batch,label_batch = image_batch.read_and_decode()
  49. #开启会话
  50. with tf.Session() as sess:
  51. #定义一个线程协调器
  52. coord = tf.train.Coordinator()
  53. #开启读取文件的线程
  54. threads = tf.train.start_queue_runners(sess,coord = coord)
  55. #打印读取的内容
  56. print (sess.run([image_batch,label_batch]))
  57. #回收子线程
  58. coord.request_stop()
  59. coord.join(threads)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/249806
推荐阅读
相关标签
  

闽ICP备14008679号