
Image Classification: Building Your Own Dataset

Code walkthrough

1. Import the libraries

import os
import cv2
import time
import shutil
import urllib3
import requests
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
urllib3.disable_warnings()
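
All of these are common third-party packages. Assuming the usual PyPI package names, they can be installed with:

pip install opencv-python requests urllib3 numpy pandas pillow tqdm matplotlib scipy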

2. Collect image data with a web crawler

The cookies and request headers below appear to have been captured from a Baidu Images browser session via the developer tools; they expire over time, so replace them with values from your own session before running the crawler.

cookies = {
            'BDqhfp': '%E7%8B%97%E7%8B%97%26%26NaN-1undefined%26%2618880%26%2621',
            'BIDUPSID': '06338E0BE23C6ADB52165ACEB972355B',
            'PSTM': '1646905430',
            'BAIDUID': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
            'BDORZ': 'B490B5EBF6F3CD402E515D22BCDA1598',
            'H_PS_PSSID': '35836_35105_31254_36024_36005_34584_36142_36120_36032_35993_35984_35319_26350_35723_22160_36061',
            'BDSFRCVID': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
            'H_BDCLCKID_SF': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
            'BDSFRCVID_BFESS': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
            'H_BDCLCKID_SF_BFESS': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
            'indexPageSugList': '%5B%22%E7%8B%97%E7%8B%97%22%5D',
            'cleanHistoryStatus': '0',
            'BAIDUID_BFESS': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
            'BDRCVFR[dG2JNJb_ajR]': 'mk3SLVN4HKm',
            'BDRCVFR[-pGxjrCMryR]': 'mk3SLVN4HKm',
            'ab_sr': '1.0.1_Y2YxZDkwMWZkMmY2MzA4MGU0OTNhMzVlNTcwMmM2MWE4YWU4OTc1ZjZmZDM2N2RjYmVkMzFiY2NjNWM4Nzk4NzBlZTliYWU0ZTAyODkzNDA3YzNiMTVjMTllMzQ0MGJlZjAwYzk5MDdjNWM0MzJmMDdhOWNhYTZhMjIwODc5MDMxN2QyMmE1YTFmN2QyY2M1M2VmZDkzMjMyOThiYmNhZA==',
            'delPer': '0',
            'PSINO': '2',
            'BA_HECTOR': '8h24a024042g05alup1h3g0aq0q',
            }

headers = {
            'Connection': 'keep-alive',
            'sec-ch-ua': '" Not;A Brand";v="99", "Google Chrome";v="97", "Chromium";v="97"',
            'Accept': 'text/plain, */*; q=0.01',
            'X-Requested-With': 'XMLHttpRequest',
            'sec-ch-ua-mobile': '?0',
            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.99 Safari/537.36',
            'sec-ch-ua-platform': '"macOS"',
            'Sec-Fetch-Site': 'same-origin',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Dest': 'empty',
            'Referer': 'https://image.baidu.com/search/index?tn=baiduimage&ipn=r&ct=201326592&cl=2&lm=-1&st=-1&fm=result&fr=&sf=1&fmq=1647837998851_R&pv=&ic=&nc=1&z=&hd=&latest=&copyright=&se=1&showtab=0&fb=0&width=&height=&face=0&istype=2&dyTabStr=MCwzLDIsNiwxLDUsNCw4LDcsOQ%3D%3D&ie=utf-8&sid=&word=%E7%8B%97%E7%8B%97',
            'Accept-Language': 'zh-CN,zh;q=0.9',
            }
if not os.path.exists('dataset'):
    os.makedirs('dataset')
    print('Created the dataset folder')
def craw_single_class(keyword, DOWNLOAD_NUM=200):
    # keyword is a (Chinese search term, English folder name) pair from class_dict.items()
    if os.path.exists('dataset/' + keyword[1]):
        print('Folder dataset/{} already exists; crawled images will be saved into it'.format(keyword[1]))
    else:
        os.makedirs('dataset/{}'.format(keyword[1]))
        print('Created folder: dataset/{}'.format(keyword[1]))
    count = 1
    with tqdm(total=DOWNLOAD_NUM, position=0, leave=True) as pbar:
        # index of the image currently being crawled
        num = 0
        # whether to keep crawling
        FLAG = True
        while FLAG:
            page = 30 * count
            params = (
                ('tn', 'resultjson_com'),
                ('logid', '12508239107856075440'),
                ('ipn', 'rj'),
                ('ct', '201326592'),
                ('is', ''),
                ('fp', 'result'),
                ('fr', ''),
                ('word', f'{keyword[0]}'),
                ('queryWord', f'{keyword[0]}'),
                ('cl', '2'),
                ('lm', '-1'),
                ('ie', 'utf-8'),
                ('oe', 'utf-8'),
                ('adpicid', ''),
                ('st', '-1'),
                ('z', ''),
                ('ic', ''),
                ('hd', ''),
                ('latest', ''),
                ('copyright', ''),
                ('s', ''),
                ('se', ''),
                ('tab', ''),
                ('width', ''),
                ('height', ''),
                ('face', '0'),
                ('istype', '2'),
                ('qc', ''),
                ('nc', '1'),
                ('expermode', ''),
                ('nojc', ''),
                ('isAsync', ''),
                ('pn', f'{page}'),
                ('rn', '30'),
                ('gsm', '1e'),
                ('1647838001666', ''),
            )
            response = requests.get('https://image.baidu.com/search/acjson', headers=headers, params=params,
                                    cookies=cookies)
            if response.status_code == 200:
                try:
                    json_data = response.json().get("data")
                    if json_data:
                        for x in json_data:
                            img_type = x.get("type")  # renamed to avoid shadowing the built-in `type`
                            if img_type not in ["gif"]:
                                img = x.get("thumbURL")
                                fromPageTitleEnc = x.get("fromPageTitleEnc")
                                try:
                                    resp = requests.get(url=img, verify=False)
                                    time.sleep(1)
                                    # print(f"URL {img}")
                                    # file name to save under
                                    # file_save_path = f'dataset/{keyword}/{num}-{fromPageTitleEnc}.{img_type}'
                                    file_save_path = f'dataset/{keyword[1]}/{num}.{img_type}'
                                    with open(file_save_path, 'wb') as f:
                                        f.write(resp.content)
                                        f.flush()
                                        # print('Image {} ({}) downloaded'.format(num, fromPageTitleEnc))
                                        num += 1
                                        pbar.update(1)  # advance the progress bar
                                    # stop once the requested number of images has been saved
                                    if num >= DOWNLOAD_NUM:
                                        FLAG = False
                                        print('{} images crawled'.format(num))
                                        break
                                except Exception:
                                    pass
                except Exception:
                    pass
            else:
                break
            count += 1
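
Before running the full class loop in the main routine, it can be worth smoke-testing the crawler on a single class. A minimal sketch, where the ('狗狗', 'Dog') pair is purely illustrative and mirrors the tuples that class_dict.items() yields below:

# Smoke test: crawl ~20 images for one illustrative class
craw_single_class(('狗狗', 'Dog'), DOWNLOAD_NUM=20)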

3. Remove corrupted or unreadable image files

def remove_error_file(dataset_path):
    for fruit in tqdm(os.listdir(dataset_path)):
        for file in os.listdir(os.path.join(dataset_path, fruit)):
            file_path = os.path.join(dataset_path, fruit, file)
            img = cv2.imread(file_path)
            if img is None:
                print(file_path, 'failed to read, deleting')
                os.remove(file_path)
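
cv2.imread returns None only for files it cannot decode at all; truncated downloads may still decode partially. A stricter, hypothetical variant (not part of the original script) uses PIL's verify() to catch those as well:

def remove_truncated_file(dataset_path):
    # Image.verify() raises on corrupt or truncated files
    # that cv2.imread might still partially decode
    for cls in tqdm(os.listdir(dataset_path)):
        for file in os.listdir(os.path.join(dataset_path, cls)):
            file_path = os.path.join(dataset_path, cls, file)
            try:
                with Image.open(file_path) as img:
                    img.verify()
            except Exception:
                print(file_path, 'failed verification, deleting')
                os.remove(file_path)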

4. Remove images that are not three-channel

def remove_not_three_channel_photo(dataset_path):
    for fruit in tqdm(os.listdir(dataset_path)):
        for file in os.listdir(os.path.join(dataset_path, fruit)):
            file_path = os.path.join(dataset_path, fruit, file)
            img = np.array(Image.open(file_path))
            try:
                channel = img.shape[2]
                if channel != 3:
                    print(file_path, 'not three-channel, deleting')
                    os.remove(file_path)
            except IndexError:  # grayscale images have no channel axis
                print(file_path, 'not three-channel, deleting')
                os.remove(file_path)
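
Deleting non-RGB images discards data. If you would rather keep grayscale and RGBA images, a hypothetical alternative (not in the original script) converts them to three-channel RGB in place:

def convert_to_rgb(dataset_path):
    # Convert non-RGB images (grayscale, RGBA, palette) to RGB instead of deleting them
    for cls in tqdm(os.listdir(dataset_path)):
        for file in os.listdir(os.path.join(dataset_path, cls)):
            file_path = os.path.join(dataset_path, cls, file)
            with Image.open(file_path) as img:
                if img.mode != 'RGB':
                    img.convert('RGB').save(file_path)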

5. Split into training and test sets

def split_dataset(dataset_path, test_frac=0.2):
    dataset_name = dataset_path.split('_')[0]
    classes = os.listdir(dataset_path)
    os.mkdir(os.path.join(dataset_path, 'train'))
    os.mkdir(os.path.join(dataset_path, 'val'))
    # create one subfolder per class inside train and val
    for fruit in classes:
        os.mkdir(os.path.join(dataset_path, 'train', fruit))
        os.mkdir(os.path.join(dataset_path, 'val', fruit))
    df = pd.DataFrame()
    print('{:^18} {:^18} {:^18}'.format('Class', 'Train images', 'Test images'))
    for fruit in classes:  # iterate over the classes
        # list all image file names of this class
        old_dir = os.path.join(dataset_path, fruit)
        images_filename = os.listdir(old_dir)
        np.random.shuffle(images_filename)  # shuffle in place
        # split into training and test sets
        testset_numer = int(len(images_filename) * test_frac)  # number of test images
        testset_images = images_filename[:testset_numer]  # file names to move into the val directory
        trainset_images = images_filename[testset_numer:]  # file names to move into the train directory
        # move test images into the val directory
        for image in testset_images:
            old_img_path = os.path.join(dataset_path, fruit, image)  # original file path
            new_test_path = os.path.join(dataset_path, 'val', fruit, image)  # new path under val
            shutil.move(old_img_path, new_test_path)  # move the file
        # move training images into the train directory
        for image in trainset_images:
            old_img_path = os.path.join(dataset_path, fruit, image)  # original file path
            new_train_path = os.path.join(dataset_path, 'train', fruit, image)  # new path under train
            shutil.move(old_img_path, new_train_path)  # move the file
        # remove the old class folder
        assert len(os.listdir(old_dir)) == 0  # make sure every image has been moved out
        shutil.rmtree(old_dir)  # delete the folder
        # print the per-class counts in aligned columns
        print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))
        # append one summary row per class (the original index=range(len(classes)) duplicated rows)
        df_cache = pd.DataFrame([{'class': fruit, 'trainset': len(trainset_images), 'testset': len(testset_images)}])
        df = pd.concat([df, df_cache], axis=0, ignore_index=True)
    # rename the dataset folder
    shutil.move(dataset_path, dataset_name + '_split')
    # export the per-class count table as a csv file
    df['total'] = df['trainset'] + df['testset']
    df.to_csv('dataset_statistics.csv', index=False)
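
After the rename, dataset_split/ has the train/<class>/ and val/<class>/ layout that torchvision.datasets.ImageFolder expects, so the split can be consumed directly for training. A minimal sketch, assuming torchvision is installed (it is not used in the original script) and an illustrative 224x224 input size:

from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # illustrative input size
    transforms.ToTensor(),
])
train_dataset = datasets.ImageFolder('dataset_split/train', transform=transform)
val_dataset = datasets.ImageFolder('dataset_split/val', transform=transform)
print(train_dataset.classes)  # folder names become the class labels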

6. Main routine

class_dict = {'黄瓜': 'Cucumber', '南瓜': 'Pumpkin', '冬瓜': 'Winter Melon', '木瓜': 'Papaya', '苦瓜': 'Bitter melon',
              '丝瓜': 'Loofah', '窝瓜': 'Nesting melon', '甜瓜': 'Melon', '香瓜': 'Cantaloupe',
              '白兰瓜': 'White Orchid Melon', '黄金瓜': 'Golden Melon', '西葫芦': 'Zucchini', '人参果': 'Ginseng fruit',
              '羊角蜜': 'Horn Honey', '佛手瓜': 'Fritillary melon', '伊丽莎白瓜': 'Elizabethan melon'}
# the Chinese keys are the Baidu search terms; the English values name the class folders
dataset_path = 'dataset'
for each in class_dict.items():
    craw_single_class(each, DOWNLOAD_NUM=180)
# clean the whole dataset once after crawling instead of re-scanning it per class
remove_error_file(dataset_path)
remove_not_three_channel_photo(dataset_path)
test_frac = 0.2  # fraction of images held out as the test set
np.random.seed(123)  # fix the random seed for reproducibility
split_dataset(dataset_path, test_frac)
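
If everything runs cleanly, the final directory layout should look roughly like this:

dataset_split/
├── train/
│   ├── Cucumber/
│   ├── Pumpkin/
│   └── ...
└── val/
    ├── Cucumber/
    ├── Pumpkin/
    └── ...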