当前位置:   article > 正文

Python循环产生批量数据batch_python batch

python batch

Python循环产生批量数据batch

 

目录

Python循环产生批量数据batch

一、Python循环产生批量数据batch

二、TensorFlow循环产生批量数据batch 

(1) tf.train.slice_input_producer

(2) tf.train.batch和tf.train.shuffle_batch

(3) TF循环产生批量数据batch 的完整例子

三、更加实用的方法:数据巨大的情况


一、Python循环产生批量数据batch

   在机器学习中,经常需要产生一个batch的数据用于训练模型,比如tensorflow的接口tf.train.batch就可以实现数据批量读取的操作。本博客将不依赖TensorFlow,实现一个类似于tensorflow接口tf.train.batch的方法,循环产生批量数据batch。实现的代码和测试的代码如下:

   TXT文本如下,格式:图片名 label1 label2 ,注意label可以多个

  1. 1.jpg 1 11
  2. 2.jpg 2 12
  3. 3.jpg 3 13
  4. 4.jpg 4 14
  5. 5.jpg 5 15
  6. 6.jpg 6 16
  7. 7.jpg 7 17
  8. 8.jpg 8 18

    要想产生batch数据,关键是要用到Python的关键字yield,实现一个batch一个batch的返回数据,代码实现主要有两个方法:

  1. def get_data_batch(inputs, batch_size=None, shuffle=False):
  2. '''
  3. 循环产生批量数据batch
  4. :param inputs: list数据
  5. :param batch_size: batch大小
  6. :param shuffle: 是否打乱inputs数据
  7. :return: 返回一个batch数据
  8. '''
  1. def get_next_batch(batch):
  2. return batch.__next__()

    使用时,将数据传到 get_data_batch( )方法,然后使用get_next_batch( )获得一个batch数据,完整的Python代码如下:

  1. # -*-coding: utf-8 -*-
  2. """
  3. @Project: create_batch_data
  4. @File : create_batch_data.py
  5. @Author : panjq
  6. @E-mail : pan_jinquan@163.com
  7. @Date : 2017-10-27 18:20:15
  8. """
  9. import math
  10. import random
  11. import os
  12. import glob
  13. import numpy as np
  14. def get_data_batch(inputs, batch_size=None, shuffle=False):
  15. '''
  16. 循环产生批量数据batch
  17. :param inputs: list类型数据,多个list,请[list0,list1,...]
  18. :param batch_size: batch大小
  19. :param shuffle: 是否打乱inputs数据
  20. :return: 返回一个batch数据
  21. '''
  22. rows = len(inputs[0])
  23. indices = list(range(rows))
  24. # 如果输入是list,则需要转为list
  25. if shuffle:
  26. random.seed(100)
  27. random.shuffle(indices)
  28. while True:
  29. batch_indices = np.asarray(indices[0:batch_size]) # 产生一个batch的index
  30. indices = indices[batch_size:] + indices[:batch_size] # 循环移位,以便产生下一个batch
  31. batch_data = []
  32. for data in inputs:
  33. data = np.asarray(data)
  34. temp_data=data[batch_indices] #使用下标查找,必须是ndarray类型类型
  35. batch_data.append(temp_data.tolist())
  36. yield batch_data
  37. def get_data_batch2(inputs, batch_size=None, shuffle=False):
  38. '''
  39. 循环产生批量数据batch
  40. :param inputs: list类型数据,多个list,请[list0,list1,...]
  41. :param batch_size: batch大小
  42. :param shuffle: 是否打乱inputs数据
  43. :return: 返回一个batch数据
  44. '''
  45. # rows,cols=inputs.shape
  46. rows = len(inputs[0])
  47. indices = list(range(rows))
  48. if shuffle:
  49. random.seed(100)
  50. random.shuffle(indices)
  51. while True:
  52. batch_indices = indices[0:batch_size] # 产生一个batch的index
  53. indices = indices[batch_size:] + indices[:batch_size] # 循环移位,以便产生下一个batch
  54. batch_data = []
  55. for data in inputs:
  56. temp_data = find_list(batch_indices, data)
  57. batch_data.append(temp_data)
  58. yield batch_data
  59. def get_data_batch_one(inputs, batch_size=None, shuffle=False):
  60. '''
  61. 产生批量数据batch,非循环迭代
  62. 迭代次数由:iter_nums= math.ceil(sample_nums / batch_size)
  63. :param inputs: list类型数据,多个list,请[list0,list1,...]
  64. :param batch_size: batch大小
  65. :param shuffle: 是否打乱inputs数据
  66. :return: 返回一个batch数据
  67. '''
  68. # rows,cols=inputs.shape
  69. rows = len(inputs[0])
  70. indices = list(range(rows))
  71. if shuffle:
  72. random.seed(100)
  73. random.shuffle(indices)
  74. while True:
  75. batch_data = []
  76. cur_nums = len(indices)
  77. batch_size = np.where(cur_nums > batch_size, batch_size, cur_nums)
  78. batch_indices = indices[0:batch_size] # 产生一个batch的index
  79. indices = indices[batch_size:]
  80. # indices = indices[batch_size:] + indices[:batch_size] # 循环移位,以便产生下一个batch
  81. for data in inputs:
  82. temp_data = find_list(batch_indices, data)
  83. batch_data.append(temp_data)
  84. yield batch_data
  85. def find_list(indices, data):
  86. out = []
  87. for i in indices:
  88. out = out + [data[i]]
  89. return out
  90. def get_list_batch(inputs, batch_size=None, shuffle=False):
  91. '''
  92. 循环产生batch数据
  93. :param inputs: list数据
  94. :param batch_size: batch大小
  95. :param shuffle: 是否打乱inputs数据
  96. :return: 返回一个batch数据
  97. '''
  98. if shuffle:
  99. random.shuffle(inputs)
  100. while True:
  101. batch_inouts = inputs[0:batch_size]
  102. inputs = inputs[batch_size:] + inputs[:batch_size] # 循环移位,以便产生下一个batch
  103. yield batch_inouts
  104. def load_file_list(text_dir):
  105. text_dir = os.path.join(text_dir, '*.txt')
  106. text_list = glob.glob(text_dir)
  107. return text_list
  108. def get_next_batch(batch):
  109. return batch.__next__()
  110. def load_image_labels(finename):
  111. '''
  112. 载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签1,如:test_image/1.jpg 0 2
  113. :param test_files:
  114. :return:
  115. '''
  116. images_list = []
  117. labels_list = []
  118. with open(finename) as f:
  119. lines = f.readlines()
  120. for line in lines:
  121. # rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
  122. content = line.rstrip().split(' ')
  123. name = content[0]
  124. labels = []
  125. for value in content[1:]:
  126. labels.append(float(value))
  127. images_list.append(name)
  128. labels_list.append(labels)
  129. return images_list, labels_list
  130. if __name__ == '__main__':
  131. filename = './training_data/test.txt'
  132. images_list, labels_list = load_image_labels(filename)
  133. # 若输入为np.arange数组,则需要tolist()为list类型,如:
  134. # images_list = np.reshape(np.arange(8*3), (8,3))
  135. # labels_list = np.reshape(np.arange(8*3), (8,3))
  136. # images_list=images_list.tolist()
  137. # labels_list=labels_list.tolist()
  138. iter = 5 # 迭代3次,每次输出一个batch个
  139. # batch = get_data_batch([images_list, labels_list], batch_size=3, shuffle=False)
  140. batch = get_data_batch2(inputs=[images_list,labels_list], batch_size=5, shuffle=True)
  141. for i in range(iter):
  142. print('**************************')
  143. batch_images, batch_labels = get_next_batch(batch)
  144. print('batch_images:{}'.format(batch_images))
  145. print('batch_labels:{}'.format(batch_labels))

   运行输出结果为:

**************************
batch_images:['1.jpg', '2.jpg', '3.jpg']
batch_labels:[[1.0, 11.0], [2.0, 12.0], [3.0, 13.0]]
**************************
batch_images:['4.jpg', '5.jpg', '6.jpg']
batch_labels:[[4.0, 14.0], [5.0, 15.0], [6.0, 16.0]]
**************************
batch_images:['7.jpg', '8.jpg', '1.jpg']
batch_labels:[[7.0, 17.0], [8.0, 18.0], [1.0, 11.0]]
**************************
batch_images:['2.jpg', '3.jpg', '4.jpg']
batch_labels:[[2.0, 12.0], [3.0, 13.0], [4.0, 14.0]]
**************************
batch_images:['5.jpg', '6.jpg', '7.jpg']
batch_labels:[[5.0, 15.0], [6.0, 16.0], [7.0, 17.0]]

Process finished with exit code 0

二、TensorFlow循环产生批量数据batch 

    使用TensorFlow实现产生批量数据batch,需要几个接口,

(1) tf.train.slice_input_producer

tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。

  1. slice_input_producer(tensor_list,
  2. num_epochs=None,
  3. shuffle=True,
  4. seed=None,
  5. capacity=32,
  6. shared_name=None,
  7. name=None)
  8. # 第一个参数
  9. # tensor_list:包含一系列tensor的列表,表中tensor的第一维度的值必须相等,即个数必须相等,有多少个图像,就应该有多少个对应的标签。
  10. # 第二个参数num_epochs: 可选参数,是一个整数值,代表迭代的次数,如果设置
  11. # num_epochs = None, 生成器可以无限次遍历tensor列表,如果设置为
  12. # num_epochs = N,生成器只能遍历tensor列表N次。
  13. # 第三个参数shuffle: bool类型,设置是否打乱样本的顺序。一般情况下,如果shuffle = True,生成的样本顺序就被打乱了,在批处理的时候不需要再次打乱样本,使用
  14. # tf.train.batch函数就可以了;
  15. # 如果shuffle = False, 就需要在批处理时候使用
  16. # tf.train.shuffle_batch函数打乱样本。
  17. # 第四个参数seed: 可选的整数,是生成随机数的种子,在第三个参数设置为shuffle = True的情况下才有用。
  18. # 第五个参数capacity:设置tensor列表的容量。
  19. # 第六个参数shared_name:可选参数,如果设置一个‘shared_name’,则在不同的上下文环境(Session)中可以通过这个名字共享生成的tensor。
  20. # 第七个参数name:可选,设置操作的名称

    tf.train.slice_input_producer定义了样本放入文件名队列的方式,包括迭代次数,是否乱序等,要真正将文件放入文件名队列,还需要调用tf.train.start_queue_runners 函数来启动执行文件名队列填充的线程,之后计算单元才可以把数据读出来,否则文件名队列为空的,计算单元就会处于一直等待状态,导致系统阻塞。

    例子:

  1. import tensorflow as tf
  2. images = ['img1', 'img2', 'img3', 'img4', 'img5']
  3. labels= [1,2,3,4,5]
  4. epoch_num=8
  5. f = tf.train.slice_input_producer([images, labels],num_epochs=None,shuffle=False)
  6. with tf.Session() as sess:
  7. sess.run(tf.global_variables_initializer())
  8. coord = tf.train.Coordinator()
  9. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  10. for i in range(epoch_num):
  11. k = sess.run(f)
  12. print '************************'
  13. print (i,k)
  14. coord.request_stop()
  15. coord.join(threads)

(2) tf.train.batch和tf.train.shuffle_batch

    tf.train.batch是一个tensor队列生成器,作用是按照给定的tensor顺序,把batch_size个tensor推送到文件队列,作为训练一个batch的数据,等待tensor出队执行计算。

  1. tf.train.batch(tensors,
  2. batch_size,
  3. num_threads=1,
  4. capacity=32,
  5. enqueue_many=False,
  6. shapes=None,
  7. dynamic_pad=False,
  8. allow_smaller_final_batch=False,
  9. shared_name=None,
  10. name=None)
  11. # 第一个参数tensors:tensor序列或tensor字典,可以是含有单个样本的序列;
  12. # 第二个参数batch_size: 生成的batch的大小;
  13. # 第三个参数num_threads:执行tensor入队操作的线程数量,可以设置使用多个线程同时并行执行,提高运行效率,但也不是数量越多越好;
  14. # 第四个参数capacity: 定义生成的tensor序列的最大容量;
  15. # 第五个参数enqueue_many: 定义第一个传入参数tensors是多个tensor组成的序列,还是单个tensor;
  16. # 第六个参数shapes: 可选参数,默认是推测出的传入的tensor的形状;
  17. # 第七个参数dynamic_pad: 定义是否允许输入的tensors具有不同的形状,设置为True,会把输入的具有不同形状的tensor归一化到相同的形状;
  18. # 第八个参数allow_smaller_final_batch: 设置为True,表示在tensor队列中剩下的tensor数量不够一个batch_size的情况下,允许最后一个batch的数量少于batch_size, 设置为False,则不管什么情况下,生成的batch都拥有batch_size个样本;
  19. # 第九个参数shared_name: 可选参数,设置生成的tensor序列在不同的Session中的共享名称;
  20. # 第十个参数name: 操作的名称;

    如果tf.train.batch的第一个参数 tensors 传入的是tenor列表或者字典,返回的是tensor列表或字典,如果传入的是只含有一个元素的列表,返回的是单个的tensor,而不是一个列表。

    与tf.train.batch函数相对的还有一个tf.train.shuffle_batch函数,两个函数作用一样,都是生成一定数量的tensor,组成训练一个batch需要的数据集,区别是tf.train.shuffle_batch会打乱样本顺序。

(3) TF循环产生批量数据batch 的完整例子

  1. # -*-coding: utf-8 -*-
  2. """
  3. @Project: LSTM
  4. @File : tf_create_batch_data.py
  5. @Author : panjq
  6. @E-mail : pan_jinquan@163.com
  7. @Date : 2018-10-28 17:50:24
  8. """
  9. import tensorflow as tf
  10. def get_data_batch(inputs,batch_size,labels_nums,one_hot=False,shuffle=False,num_threads=1):
  11. '''
  12. :param inputs: 输入数据,可以是多个list
  13. :param batch_size:
  14. :param labels_nums:标签个数
  15. :param one_hot:是否将labels转为one_hot的形式
  16. :param shuffle:是否打乱顺序,一般train时shuffle=True,验证时shuffle=False
  17. :return:返回batch的images和labels
  18. '''
  19. # 生成队列
  20. inputs_que= tf.train.slice_input_producer(inputs, shuffle=shuffle)
  21. min_after_dequeue = 200
  22. capacity = min_after_dequeue + 3 * batch_size # 保证capacity必须大于min_after_dequeue参数值
  23. if shuffle:
  24. out_batch = tf.train.shuffle_batch(inputs_que,
  25. batch_size=batch_size,
  26. capacity=capacity,
  27. min_after_dequeue=min_after_dequeue,
  28. num_threads=num_threads)
  29. else:
  30. out_batch = tf.train.batch(inputs_que,
  31. batch_size=batch_size,
  32. capacity=capacity,
  33. num_threads=num_threads)
  34. return out_batch
  35. def get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=False):
  36. '''
  37. :param images:图像
  38. :param labels:标签
  39. :param batch_size:
  40. :param labels_nums:标签个数
  41. :param one_hot:是否将labels转为one_hot的形式
  42. :param shuffle:是否打乱顺序,一般train时shuffle=True,验证时shuffle=False
  43. :return:返回batch的images和labels
  44. '''
  45. images_que, labels_que= tf.train.slice_input_producer([images,labels], shuffle=shuffle)
  46. min_after_dequeue = 200
  47. capacity = min_after_dequeue + 3 * batch_size # 保证capacity必须大于min_after_dequeue参数值
  48. if shuffle:
  49. images_batch, labels_batch = tf.train.shuffle_batch([images_que, labels_que],
  50. batch_size=batch_size,
  51. capacity=capacity,
  52. min_after_dequeue=min_after_dequeue)
  53. else:
  54. images_batch, labels_batch = tf.train.batch([images_que, labels_que],
  55. batch_size=batch_size,
  56. capacity=capacity)
  57. if one_hot:
  58. labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
  59. return images_batch,labels_batch
  60. def load_image_labels(finename):
  61. '''
  62. 载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签1,如:test_image/1.jpg 0 2
  63. :param test_files:
  64. :return:
  65. '''
  66. images_list=[]
  67. labels_list=[]
  68. with open(finename) as f:
  69. lines = f.readlines()
  70. for line in lines:
  71. #rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
  72. content=line.rstrip().split(' ')
  73. name=content[0]
  74. labels=[]
  75. for value in content[1:]:
  76. labels.append(float(value))
  77. images_list.append(name)
  78. labels_list.append(labels)
  79. return images_list,labels_list
  80. if __name__ == '__main__':
  81. filename='./training_data/train.txt'
  82. # 输入数据可以是list,也可以是np.array
  83. images_list, labels_list=load_image_labels(filename)
  84. # np.arange数组如:
  85. # images_list = np.reshape(np.arange(8*3), (8,3))
  86. # labels_list = np.reshape(np.arange(8*3), (8,3))
  87. iter = 5 # 迭代5次,每次输出一个batch个
  88. # batch_images, batch_labels = get_data_batch( inputs=[images_list, labels_list],batch_size=3,labels_nums=2,one_hot=False,shuffle=False,num_threads=1)
  89. # 或者
  90. batch_images, batch_labels = get_batch_images(images_list,labels_list,batch_size=3,labels_nums=2,one_hot=False,shuffle=False)
  91. with tf.Session() as sess: # 开始一个会话
  92. sess.run(tf.global_variables_initializer())
  93. coord = tf.train.Coordinator()
  94. threads = tf.train.start_queue_runners(coord=coord)
  95. for i in range(iter):
  96. # 在会话中取出images和labels
  97. images, labels = sess.run([batch_images, batch_labels] )
  98. print('**************************')
  99. print('batch_images:{}'.format(images ))
  100. print('batch_labels:{}'.format(labels))
  101. # 停止所有线程
  102. coord.request_stop()
  103. coord.join(threads)

    运行输出结果:

**************************
batch_images:[b'1.jpg' b'2.jpg' b'3.jpg']
batch_labels:[[ 1. 11.] [ 2. 12.][ 3. 13.]]
**************************
batch_images:[b'4.jpg' b'5.jpg' b'6.jpg']
batch_labels:[[ 4. 14.] [ 5. 15.][ 6. 16.]]
**************************
batch_images:[b'7.jpg' b'8.jpg' b'1.jpg']
batch_labels:[[ 7. 17.][ 8. 18.][ 1. 11.]]
**************************
batch_images:[b'2.jpg' b'3.jpg' b'4.jpg']
batch_labels:[[ 2. 12.] [ 3. 13.][ 4. 14.]]
**************************
batch_images:[b'5.jpg' b'6.jpg' b'7.jpg']
batch_labels:[[ 5. 15.][ 6. 16.][ 7. 17.]]

三、更加实用的方法:数据巨大的情况

    当数据量很大很大时,超过2T的数据时,我们不可能把所以数据都保存为一个文件,也不可能把数据都加载到内存。为了避免内存耗尽的情况,最简单的思路是:把数据分割成多个文件保存到硬盘(每个文件不超过2G),训练时,按batch大小逐个加载文件,再获取一个batch的训练数据。这种方法,也可以用TensorFlow TFRecord格式,利用队列方法读取文件,然后再产生一个batch数据,可以参考:《Tensorflow生成自己的图片数据集TFrecords》:https://blog.csdn.net/guyuealian/article/details/80857228

   但TensorFlow TFRecord格式存储的内容,有很多限制, 这里将实现一种类似于TensorFlow TFRecord的方法,但存储的内容没有限制,你可稍微修改保存不同的数据,基本思路是:

  1. 数据产生:利用numpy,把数据分割成多个*.npy文件保存到硬盘(每个文件不超过1G),当然你可以用其他Python工具保存其他文件格式,只要你能读取文件即可
  2. 获得训练数据:获取所有文件*.npy的列表,逐个读取文件的数据,并根据batch的大小,循环返回数据

     完整代码如下:

    这里将数据保存为data1.npy,data2.npy,data3.npy,其中*.npy文件的数据保存是二维矩阵Mat:第一列为样本的labels,剩余的列为样本的数据,

  1. indexMat1:
  2. [[0 0 5]
  3. [1 1 6]
  4. [2 2 7]
  5. [3 3 8]
  6. [4 4 9]]
  7. indexMat2:
  8. [[ 5 15 20]
  9. [ 6 16 21]
  10. [ 7 17 22]
  11. [ 8 18 23]
  12. [ 9 19 24]]
  13. indexMat3:
  14. [[10 30 35]
  15. [11 31 36]
  16. [12 32 37]
  17. [13 33 38]
  18. [14 34 39]]
  1. # -*-coding: utf-8 -*-
  2. """
  3. @Project: nlp-learning-tutorials
  4. @File : create_batch_data.py
  5. @Author : panjq
  6. @E-mail : pan_jinquan@163.com
  7. @Date : 2018-11-08 09:29:19
  8. """
  9. import math
  10. import random
  11. import os
  12. import glob
  13. import numpy as np
  14. from sklearn import preprocessing
  15. def get_data_batch(file_list,labels_nums,batch_size=None, shuffle=False,one_hot=False):
  16. '''
  17. 加载*.npy文件的数据,循环产生批量数据batch,其中*.npy文件的数据保存是二维矩阵Mat:
  18. 二维矩阵Mat:第一列为样本的labels,剩余的列为样本的数据,
  19. np.concatenate([label,data], axis=1)
  20. :param file_list: *.npy文件路径,type->list->[file0.npy,file1.npy,....]
  21. :param labels_nums: labels种类数
  22. :param batch_size: batch大小
  23. :param shuffle: 是否打乱数据,PS:只能打乱一个batch的数据,不同batch的数据不会干扰
  24. :param one_hot: 是否独热编码
  25. :return: 返回一个batch数据
  26. '''
  27. height = 0
  28. indexMat_labels = None
  29. i = 0
  30. while True:
  31. while height < batch_size:
  32. i = i%len(file_list)
  33. tempFile = file_list[i]
  34. tempMat_labels = np.load(tempFile)
  35. if indexMat_labels is None:
  36. indexMat_labels = tempMat_labels
  37. else:
  38. indexMat_labels = np.concatenate([indexMat_labels, tempMat_labels], 0)
  39. i=i+1
  40. height = indexMat_labels.shape[0]
  41. indices = list(range(height))
  42. batch_indices = np.asarray(indices[0:batch_size]) # 产生一个batch的index
  43. if shuffle:
  44. random.seed(100)
  45. random.shuffle(batch_indices)
  46. batch_indexMat_labels = indexMat_labels[batch_indices] # 使用下标查找,必须是ndarray类型类型
  47. indexMat_labels=np.delete(indexMat_labels,batch_indices,axis=0)
  48. height = indexMat_labels.shape[0]
  49. # 将数据分割成indexMat和labels
  50. batch_labels=batch_indexMat_labels[:,0] # 第一列是labels
  51. batch_indexMat=batch_indexMat_labels[:,1:] # 其余是indexMat
  52. # 是否进行独热编码
  53. if one_hot:
  54. batch_labels = batch_labels.reshape(len(batch_labels), 1)
  55. onehot_encoder = preprocessing.OneHotEncoder(sparse=False,categories=[range(labels_nums)])
  56. batch_labels = onehot_encoder.fit_transform(batch_labels)
  57. yield batch_indexMat,batch_labels
  58. def get_next_batch(batch):
  59. return batch.__next__()
  60. def get_file_list(file_dir,postfix):
  61. '''
  62. 获得后缀名为postfix所有文件列表
  63. :param file_dir:
  64. :param postfix:
  65. :return:
  66. '''
  67. file_dir=os.path.join(file_dir,postfix)
  68. file_list=glob.glob(file_dir)
  69. return file_list
  70. def create_test_data(out_dir):
  71. '''
  72. 产生测试数据
  73. :return:
  74. '''
  75. data1 = np.arange(0, 10)
  76. data1 = np.transpose(data1.reshape([2, 5]))
  77. label1 = np.arange(0, 5)
  78. label1 = label1.reshape([5, 1])
  79. path1 = os.path.join(out_dir,'data1.npy')
  80. indexMat1 = np.concatenate([label1, data1], axis=1) # 矩阵拼接,第一列为labels
  81. np.save(path1, indexMat1)
  82. data2 = np.arange(15, 25)
  83. data2 = np.transpose(data2.reshape([2, 5]))
  84. label2 = np.arange(5, 10)
  85. label2 = label2.reshape([5, 1])
  86. path2 = os.path.join(out_dir,'data2.npy')
  87. indexMat2 = np.concatenate([label2, data2], axis=1)
  88. np.save(path2, indexMat2)
  89. data3 = np.arange(30, 40)
  90. data3 = np.transpose(data3.reshape([2, 5]))
  91. label3 = np.arange(10, 15)
  92. label3 = label3.reshape([5, 1])
  93. path3 = os.path.join(out_dir,'data3.npy')
  94. indexMat3 = np.concatenate([label3, data3], axis=1)
  95. np.save(path3, indexMat3)
  96. print('indexMat1:\n{}'.format(indexMat1))
  97. print('indexMat2:\n{}'.format(indexMat2))
  98. print('indexMat3:\n{}'.format(indexMat3))
  99. if __name__ == '__main__':
  100. out_dir='./output'
  101. create_test_data(out_dir)
  102. file_list=get_file_list(file_dir=out_dir, postfix='*.npy')
  103. iter = 3 # 迭代3次,每次输出一个batch个
  104. batch = get_data_batch(file_list, labels_nums=15,batch_size=8, shuffle=False,one_hot=False)
  105. for i in range(iter):
  106. print('**************************')
  107. batch_data, batch_label = get_next_batch(batch)
  108. print('batch_images:\n{}'.format(batch_data))
  109. print('batch_labels:\n{}'.format(batch_label))

运行结果: 

  1. **************************
  2. batch_images:
  3. [[ 0 5]
  4. [ 1 6]
  5. [ 2 7]
  6. [ 3 8]
  7. [ 4 9]
  8. [15 20]
  9. [16 21]
  10. [17 22]]
  11. batch_labels:
  12. [0 1 2 3 4 5 6 7]
  13. **************************
  14. batch_images:
  15. [[18 23]
  16. [19 24]
  17. [30 35]
  18. [31 36]
  19. [32 37]
  20. [33 38]
  21. [34 39]
  22. [ 0 5]]
  23. batch_labels:
  24. [ 8 9 10 11 12 13 14 0]
  25. **************************
  26. batch_images:
  27. [[ 1 6]
  28. [ 2 7]
  29. [ 3 8]
  30. [ 4 9]
  31. [15 20]
  32. [16 21]
  33. [17 22]
  34. [18 23]]
  35. batch_labels:
  36. [1 2 3 4 5 6 7 8]

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/927172
推荐阅读
相关标签
  

闽ICP备14008679号