赞
踩
今晚又是在工位熬夜的一天,没办法,我实在是太菜了,只能熬夜学习了,说起来都是以前自己过得太轻松了,导致我现在不得不使劲补基础。
好了回归正题,为什么我要写这篇博客,最近我在github上找到了一个指纹识别的代码,不过这个代码里面用到了训练集(train),测试集(eval)和验证集(val)。一开始我想测试一下这个代码能不能跑,所以测试集和验证集我用的是一个数据集,结果可以跑啊,不过网络效果就是很不错,毕竟用的测试集就是验证集。所以我现在想以8:2的形式把训练集分为训练集和测试集,没有找到相关的代码,所以今晚就熬夜研究了一下,研究出来了,记录一下这个过程,顺便心疼一下我的头发。
本文感谢以下参考博客:
以及我之前写过的一个博客:
3. 基于python和Opencv将多张图片结合为一张图片的办法
先介绍一下我的情况啊,我用的是livdet数据集,这个数据集下面的分类是:传感器-真/假(假的话下面还会分一个材料),大概是这样的啊:
LivDet-LivDet_2009-Training-Biometrika-Alive
然后这个文件夹下面都是图片了,本次测试文件夹为tif文件。
看了参考-1大佬的代码,我也用了这个函数,不过我测试的时候,我发现好像这个东西只能分割数字本身,而不能分割图片,当然也有可能我的代码写的比较垃圾,没试出来,这里只介绍我个人思路。
opencv
numpy
torch
torchvision
glob
import torch
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
import glob
import numpy as np
import cv2
关于这一点,可以看一下参考-3当中的用法
def open_image(path1):
img_path = glob.glob(path1)
return np.array([cv2.imread(true_path,0) for true_path in img_path])
这里参数的意义是:
images:上一个函数打开过的图片数组,每张图片都在这个大数组里面
list1:训练集的索引,我将其保存在了这个列表里面
list2:验证集的索引,同上
代码相当简单,相信看一眼就明白了,毕竟老夫也不是什么厉害人,写不出漂亮代码
def generate_image(images,list1,list2):
a = len(list1)
b = len(list2)
res1 = []
res2 = []
for i in range(a):
res1.append(images[i])
for j in range(b):
res2.append(images[j])
return res1,res2
关于random_split的使用可以看一下参考-1,这个大佬说的很明白,我一开始看人家的,我以为十分制分割呢,结果我一开始就写的lengths=[8:2],然后疯狂报错,我才发现是写你要分割的具体数量,这里我一共有520张图片,分一下就是416张训练集,104张验证集。
all_data = open_image('LivDet/LivDet_2009/Traning/Biometrika/Alive/*') # 保存图片
num = range(len(all_data)) # 获得图片长度
train_data_num, val_data_num = random_split(dataset=num,
lengths=[416,104]) # 得到打乱过后的索引
print(list(train_data_num)) # 一会给你们看看打印出来的效果
print(list(val_data_num))
output_dir = '/dataset/livdet2009/train/Biometrika/live/' # 设置你要保存的路径
output_dirr = '/dataset/livdet2009/val/Biometrika/live/'
train_data,test_data = generate_image(all_data,list(train_data_num),list(val_data_num)) # 给图片的过程
for i,img in enumerate(train_data):
cv2.imwrite(output_dir+str(i)+'.tif',img) # 保存训练集
for j,imgg in enumerate(test_data):
cv2.imwrite(output_dirr+str(j)+'.tif',imgg) # 保存验证集
给你们看看print那两句话得到的效果:
可以看到啊,这里成功把索引打乱了,可以确保得到的图片具有随机性,使得网络效果更好
import torch from torch.utils.data import random_split from torchvision.datasets import ImageFolder import glob import numpy as np import cv2 def open_image(path1): img_path = glob.glob(path1) return np.array([cv2.imread(true_path,0) for true_path in img_path]) def generate_image(images,list1,list2): a = len(list1) b = len(list2) res1 = [] res2 = [] for i in range(a): res1.append(images[i]) for j in range(b): res2.append(images[j]) return res1,res2 all_data = open_image('LivDet/LivDet_2009/Traning/Biometrika/Alive/*') # 保存图片 num = range(len(all_data)) # 获得图片长度 train_data_num, val_data_num = random_split(dataset=num, lengths=[416,104]) # 得到打乱过后的索引 print(list(train_data_num)) # 一会给你们看看打印出来的效果 print(list(val_data_num)) output_dir = '/dataset/livdet2009/train/Biometrika/live/' # 设置你要保存的路径 output_dirr = '/dataset/livdet2009/val/Biometrika/live/' train_data,test_data = generate_image(all_data,list(train_data_num),list(val_data_num)) # 给图片的过程 for i,img in enumerate(train_data): cv2.imwrite(output_dir+str(i)+'.tif',img) # 保存训练集 for j,imgg in enumerate(test_data): cv2.imwrite(output_dirr+str(j)+'.tif',imgg) # 保存验证集
代码很好懂啊,随便看看就明白了,实在不明白私信我,我给你讲
防火防盗防诈骗
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。