赞
踩
个人小作业,虽说做的很差,也算是一个学习的转化;主要用于分类自己下载的壁纸
学期末需要一个学习成果的展示,高难度的自己做不来,模型也跑不动(电脑有点渣),刚好自己也有图片分类的需求,最后决定做了这个,确实也算做了一个自己用得到的小程序
需要自动加载指定目录所有图片,自行迁移至指定目录并存入不同的文件夹
│ colorUi.ui 正在使用的UI界面文件 │ fun.py 对于模型函数的初步封装,为PyQt界面提供支持 │ main.py 入口部分 │ model.py 模型的训练、加载 │ ui.py 正在使用的UI界面py文件 │ ui.ui 老的UI界面文件 │ utils.py 一些读取图片处理图片的函数 ├─fun_test 内含各类图片共100张,用于最后的功能测试 ├─make_data_set 用于处理制作数据集 ├─model 训练好的模型存储的路径 ├─test 内含处理好的数据集的测试集,存储格式是是numpy数组的序列化,三通道维度信息(N,108.,192,3);标签一维数组 ├─test_pic 测试集原始数据目录,路径下各种图片独占一个目录,用于通过make_data_set制作数据集,目录应与train_pic对应 │ ├─dongman 其中一个分类 │ ├─dongwu 其中一个分类 │ ├─fengjing 其中一个分类 │ ├─meinv 其中一个分类 │ └─youxi 其中一个分类 ├─train 内含处理好的数据集的训练集,存储格式是是numpy数组的序列化,三通道维度信息(N,108.,192,3);标签一维数组 └─train_pic ├─dongman 其中一个分类 ├─dongwu 其中一个分类 ├─fengjing 其中一个分类 ├─meinv 其中一个分类 └─youxi 其中一个分类
import json import os import cv2 import numpy import numpy as np import tensorflow as tf from tensorflow import keras from tqdm import tqdm from utils import img_resize def init_network(): """ 初始化神经网络,支持五种类型 :return: 模型 """ model = tf.keras.Sequential([ tf.keras.layers.Conv2D(filters=48, kernel_size=(3, 3), padding='same', activation='relu', strides=1, input_shape=(108, 192, 3)), tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), # 抑制过拟合 tf.keras.layers.Dropout(rate=0.6), tf.keras.layers.Conv2D(filters=24, kernel_size=( 3, 3), padding='same', activation='relu', strides=1), # 2*2池化取最大值 tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), # 抑制过拟合 tf.keras.layers.Dropout(rate=0.6), # 维度拉伸成1维 tf.keras.layers.Flatten(), # 第二层隐藏层,使用relu激活函数 tf.keras.layers.Dense(256, activation='relu'), # 抑制过拟合 tf.keras.layers.Dropout(rate=0.6), tf.keras.layers.Dense(256, activation='relu'), tf.keras.layers.Dropout(rate=0.5), tf.keras.layers.Dense(256, activation='relu'), tf.keras.layers.Dropout(rate=0.5), # 输出层 tf.keras.layers.Dense(5, activation='softmax') ]) model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) model.summary() return model def getTrainData(): """ 获取训练集数据 :return: train_images, train_labels, class_names """ fp = open('./train/train.json', 'r', encoding='utf8') class_names = json.load(fp)['support'] fp.close() # 返回加载来的数据集 pic_train_images = numpy.load('./train/train_pic.npy') train_images = pic_train_images.reshape( pic_train_images.shape[0], 108, 192, 3) / 255.0 print(train_images.shape) train_labels = numpy.load('./train/train_labels.npy') print(numpy.load('./train/train_labels.npy').shape) return train_images, train_labels, class_names def getTestData(): """ 获取测试集包的数据 :return: train_images, train_labels, class_names """ fp = open('./test/test.json', 'r', encoding='utf8') class_names = json.load(fp)['support'] fp.close() # 返回加载来的数据集 pic_test_images = numpy.load('./test/test_pic.npy') test_images = pic_test_images.reshape(pic_test_images.shape[0], 108, 192, 3) / 255.0 print(test_images.shape) test_labels = numpy.load('./test/test_labels.npy') print(numpy.load('./test/test_labels.npy').shape) return test_images, test_labels, class_names def getTestImages(): """ 加载测试集1920*1080的壁纸 """ path = './test_pic' imgs = [] labels = [] k = 0 paths = os.listdir(path) paths.sort() for j in paths: pbar = tqdm(total=100) for i in os.listdir(path + '/' + j): pbar.update(100.0 / len(os.listdir(path + '/' + j))) pic_path = path + '/' + j + '/' + i # img = img_resize(cv2.imread(pic_path, cv2.IMREAD_GRAYSCALE)) img = img_resize(cv2.imread(pic_path)) if img.shape[0] != 108 or img.shape[1] != 192: os.remove(pic_path) continue imgs.append(img) labels.append(k) pbar.close() k = k + 1 pic_test_images = np.array(imgs) test_images = pic_test_images.reshape( pic_test_images.shape[0], 108, 192, 3) / 255.0 return test_images, np.array(labels) def getModel(train_mode=False): """ 获取模型 :param train_mode: 是否训练 :return: 模型 """ # 如果训练 if train_mode: # 初始化神经网络 model = init_network() # 加载数据集 train_images, train_labels, _ = getTrainData() test_images, test_labels, _ = getTestData() print(train_images.shape) print(train_labels.shape) print(test_images.shape) print(test_labels.shape) # 开始训练,训练二十次,显示日志信息 model.fit(train_images, keras.utils.to_categorical( train_labels), batch_size=128, epochs=100, verbose=2) # 评估模型,不输出预测结果 test_loss, test_acc = model.evaluate( test_images, keras.utils.to_categorical(test_labels), verbose=2) # 输出损失值 print('测试集损失:', test_loss) # 输出正确率 print('测试集正确率:', test_acc) # 保存模型 model.save('.\\model\\expll.h5') return model, test_loss, test_acc else: # 加载模型 model = tf.keras.models.load_model('.\\model\\780_3x3_1_3_100_expll.h5') # 打印模型信息 model.summary() test_images, test_labels, _ = getTestData() # 评估模型,不输出预测结果 test_loss, test_acc = model.evaluate( test_images, keras.utils.to_categorical(test_labels), verbose=2) # print([np.where(i == np.max(i))[0][0] for i in model.predict(test_images)]) return model, test_loss, test_acc # 训练模型 # if __name__ == '__main__': # model = getModel(True)
import json import os import shutil import numpy as np from PyQt5.QtCore import * import utils from model import getModel def getModelSupportTypes(data): """ 获取模型支持的分类 :return: """ temp = '' for i in data: temp = temp + ' ' + i return temp def getModelInfo(loss, acc): """ 获取模型信息 :return: 模型测试准确度 """ return '测试集损失:{:.3f}\n测试集准确率:{:.3f}%'.format(loss, acc * 100) class Service(QObject): signalRunTime = pyqtSignal(str, bool) model = None signalWorking = pyqtSignal(bool) loadModelStatus = False signalModelInfo = pyqtSignal(str) signalModelSupportTypes = pyqtSignal(str) def __init__(self): super().__init__() def predict(self, imgs: np.array): """ 预测 :param imgs: 预测图片集 :return: 预测结果 """ rs = self.model.predict(imgs) return [np.where(i == np.max(i))[0][0] for i in rs] def iniModel(self): """ 初始化加载模型 """ if self.loadModelStatus: self.signalRunTime.emit('模型加载中···', False) return self.loadModelStatus = True self.signalRunTime.emit('正在加载模型···', False) self.model, loss, acc = getModel() with open('model/model.json', 'r', encoding='utf8') as fp: info = json.load(fp) self.signalModelInfo.emit('方法:' + info['way'] + '\n' + getModelInfo(loss, acc)) self.signalModelSupportTypes.emit(getModelSupportTypes(info['support'])) self.signalRunTime.emit('模型加载完成', False) self.loadModelStatus = False def startRun(self, window): """ 开始进行分类 :param window: 窗口对象 """ if len(window.getFromPath()) == 0 or len(window.getTargetPath()) == 0: self.signalRunTime.emit('\n存在路径为空\n', False) self.signalWorking.emit(False) return list_path = [] self.signalRunTime.emit('\n检索中······\n', False) utils.getListDir(window.fromPath.toPlainText(), window.getRecursionPathStatus(), list_path, imageCallback=None, dirCallback=lambda x: self.signalRunTime.emit('检索检索到目录: {0}\n'.format(x), False)) self.signalRunTime.emit('检索完成,共计{0}张图片\n'.format(len(list_path)), False) if len(list_path) == 0: self.signalWorking.emit(False) return self.signalRunTime.emit('开始读取图片······', False) img = utils.get_data(list_path, lambda x: self.signalRunTime.emit('已加载: {0}\n'.format(x), False)) self.signalRunTime.emit('读取图片完成', False) self.signalRunTime.emit('维度信息:{0}'.format(img.shape), False) self.signalRunTime.emit('进行分类识别中······', False) rs = self.predict(img) self.signalRunTime.emit('分类识别完成\n***********\n识别结果:\n***********\n***********\n***********\n', False) with open('.\\model\\model.json', encoding='utf8') as fp: supportTypes = json.load(fp)['support'] outRunInfo = '\n' for i in zip(list_path, rs): outRunInfo = outRunInfo + '路径: {0}; 结果:{1}\n\n'.format(i[0], supportTypes[i[1]]) self.signalRunTime.emit(outRunInfo + '\n\n***********\n***********\n识别结果输出结束\n***********\n***********\n', False) targetPathRoot = window.getTargetPath() for i in supportTypes: if not os.path.exists(targetPathRoot + '/' + i): os.mkdir(targetPathRoot + '/' + i) self.signalRunTime.emit('\n\n开始进行分类迁移······', False) onlyMoveMax = window.getOnlyNumber() with open('.\\model\\model.json', encoding='utf8') as fp: supportTypes = json.load(fp)['support'] for j in range(0, int(len(list_path) * 1.0 / onlyMoveMax + 1)): for i in list(zip(list_path, rs))[onlyMoveMax * j:onlyMoveMax * (j + 1)]: try: self.signalRunTime.emit( '来源: {0}; 迁移至:{1}\n\n'.format(i[0], (targetPathRoot + '/' + supportTypes[i[1]])), False) shutil.move(i[0], targetPathRoot + '/' + supportTypes[i[1]]) except Exception as e: self.signalRunTime.emit( 'ERROR: {0}'.format(e, False)) self.signalRunTime.emit('\n\n迁移结束,任务完成\n\n', False) self.signalWorking.emit(False)
# -*- coding: utf-8 -*- import os import sys from concurrent.futures import ThreadPoolExecutor from PyQt5.QtWidgets import * import fun from ui import Ui_Form threadPool = ThreadPoolExecutor(max_workers=20) def openPath(callback): # 选择图片 path = QFileDialog.getExistingDirectory(None, "选择存储文件夹", os.getcwd()) if path == "": return 0 callback(path) class MainWindow(QWidget, Ui_Form): service = None img = None working = False def __init__(self, service_): super(MainWindow, self).__init__() self.service = service_ self.setupUi(self) def openFromPath(self): """ 选择来源路径 """ openPath(callback=lambda x: self.fromPath.setText(x)) def openTargetPath(self): """ 选择输出路径 """ openPath(callback=lambda x: self.targetPath.setText(x)) def outRuntimeInfo(self, data, refresh=True): """ 输出运行时 :param data: 日志 :param refresh: 追加或清空再输出 """ if refresh: self.runtimeInfor.setText(data) else: self.runtimeInfor.setText(self.runtimeInfor.toPlainText() + '\n' + data) self.runtimeInfor.moveCursor(self.runtimeInfor.textCursor().End) def getFromPath(self): """ 获取源路径 :return: 源路径 """ return self.fromPath.toPlainText() def getTargetPath(self): """ 获取输出路径 :return: 输出路径 """ return self.targetPath.toPlainText() def outSupportTypes(self, data): """ 输出模型支持的类型 :param data: 类型串 """ self.modelType.setText(data) def outModelInfo(self, data): """ 输出模型信息 :param data: 模型信息 """ self.modelInfor.setText(data) def getOnlyNumber(self): """ 单次处理图片数量 :return: 数量 """ return self.onlyNumber.value() def getRecursionPathStatus(self): """ 是否递归目录 """ return self.recursionPath.checkState() == 2 def startRun(self): """ 开始分类 """ if self.working: self.outRuntimeInfo('任务执行中', False) return try: threadPool.submit(service.startRun, self) except Exception as e: print(e) def setWorking(self, status): self.working = status if __name__ == '__main__': service = fun.Service() app = QApplication(sys.argv) # 初始化窗口 m = MainWindow(service) m.btu_selectFromPath.clicked.connect(m.openFromPath) m.btu_selectTargetPath.clicked.connect(m.openTargetPath) m.btu_startRun.clicked.connect(m.startRun) m.setWindowTitle('1920*1080壁纸分类') m.show() service.signalRunTime.connect(m.outRuntimeInfo) service.signalWorking.connect(m.setWorking) service.signalModelInfo.connect(m.outModelInfo) service.signalModelSupportTypes.connect(m.outSupportTypes) threadPool.submit(service.iniModel) sys.exit(app.exec_())
虽说很简单,或许显得很那么······没用,但是也是自己的一个小成果,也算是又做了一个对自己有用的工具吧!
项目文件所在地址,内含训练好的模型,目前支持五种:https://github.com/WindSnowLi/picture-classify
原文
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。