Reference code: https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch
For installation and environment setup, refer to the repository author's instructions or other blog posts.
Training requires the dataset to be in COCO format. I use the VisDrone dataset and convert it in two steps: txt -> XML -> coco.json.
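For reference, each line of a VisDrone annotation txt file contains eight comma-separated fields, which is why the script below reads line[5] as the category, line[6] as truncation, and line[7] as occlusion (written into the XML as difficult). The sample values here are made up:

<bbox_left>,<bbox_top>,<bbox_width>,<bbox_height>,<score>,<object_category>,<truncation>,<occlusion>
684,8,273,116,0,4,0,0    # e.g. a 'car' (category 4) box at (684, 8) with size 273x116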
import os
from PIL import Image

# Change the paths below to your own paths
root_dir = "./VisDrone2019-DET-train/"
annotations_dir = root_dir + "annotations/"
image_dir = root_dir + "images/"
xml_dir = root_dir + "Annotations_XML/"

# Replace the class names below with your own categories; the script also works for other datasets
class_name = ['ignored regions', 'pedestrian', 'people', 'bicycle', 'car', 'van', 'truck',
              'tricycle', 'awning-tricycle', 'bus', 'motor', 'others']

for filename in os.listdir(annotations_dir):
    fin = open(annotations_dir + filename, 'r')
    image_name = filename.split('.')[0]
    img = Image.open(image_dir + image_name + ".jpg")  # if your images are PNG, change the extension to ".png"
    xml_name = xml_dir + image_name + '.xml'
    with open(xml_name, 'w') as fout:
        fout.write('<annotation>' + '\n')
        fout.write('\t' + '<folder>VOC2007</folder>' + '\n')
        fout.write('\t' + '<filename>' + image_name + '.jpg' + '</filename>' + '\n')
        fout.write('\t' + '<source>' + '\n')
        fout.write('\t\t' + '<database>' + 'VisDrone2018 Database' + '</database>' + '\n')
        fout.write('\t\t' + '<annotation>' + 'VisDrone2018' + '</annotation>' + '\n')
        fout.write('\t\t' + '<image>' + 'flickr' + '</image>' + '\n')
        fout.write('\t\t' + '<flickrid>' + 'Unspecified' + '</flickrid>' + '\n')
        fout.write('\t' + '</source>' + '\n')
        fout.write('\t' + '<owner>' + '\n')
        fout.write('\t\t' + '<flickrid>' + 'Haipeng Zhang' + '</flickrid>' + '\n')
        fout.write('\t\t' + '<name>' + 'Haipeng Zhang' + '</name>' + '\n')
        fout.write('\t' + '</owner>' + '\n')
        fout.write('\t' + '<size>' + '\n')
        fout.write('\t\t' + '<width>' + str(img.size[0]) + '</width>' + '\n')
        fout.write('\t\t' + '<height>' + str(img.size[1]) + '</height>' + '\n')
        fout.write('\t\t' + '<depth>' + '3' + '</depth>' + '\n')
        fout.write('\t' + '</size>' + '\n')
        fout.write('\t' + '<segmented>' + '0' + '</segmented>' + '\n')
        for line in fin.readlines():
            line = line.split(',')
            fout.write('\t' + '<object>' + '\n')
            fout.write('\t\t' + '<name>' + class_name[int(line[5])] + '</name>' + '\n')
            fout.write('\t\t' + '<pose>' + 'Unspecified' + '</pose>' + '\n')
            fout.write('\t\t' + '<truncated>' + line[6] + '</truncated>' + '\n')
            fout.write('\t\t' + '<difficult>' + str(int(line[7])) + '</difficult>' + '\n')
            fout.write('\t\t' + '<bndbox>' + '\n')
            fout.write('\t\t\t' + '<xmin>' + line[0] + '</xmin>' + '\n')
            fout.write('\t\t\t' + '<ymin>' + line[1] + '</ymin>' + '\n')
            # pay attention to this point! (0-based)
            fout.write('\t\t\t' + '<xmax>' + str(int(line[0]) + int(line[2]) - 1) + '</xmax>' + '\n')
            fout.write('\t\t\t' + '<ymax>' + str(int(line[1]) + int(line[3]) - 1) + '</ymax>' + '\n')
            fout.write('\t\t' + '</bndbox>' + '\n')
            fout.write('\t' + '</object>' + '\n')
        fin.close()
        fout.write('</annotation>')
# coding=utf-8
import xml.etree.ElementTree as ET
import os
import json

# Class list of the source XML annotations; replace with your own categories
# (e.g. the VisDrone classes) if you are not converting VOC data.
voc_clses = ['aeroplane', 'bicycle', 'bird', 'boat',
             'bottle', 'bus', 'car', 'cat', 'chair',
             'cow', 'diningtable', 'dog', 'horse',
             'motorbike', 'person', 'pottedplant',
             'sheep', 'sofa', 'train', 'tvmonitor']

categories = []
for iind, cat in enumerate(voc_clses):
    cate = {}
    cate['supercategory'] = cat
    cate['name'] = cat
    cate['id'] = iind + 1  # one-indexed, to match voc_clses.index(cls_name) + 1 used below
    categories.append(cate)


def getimages(xmlname, id):
    sig_xml_box = []
    tree = ET.parse(xmlname)
    root = tree.getroot()
    images = {}
    for i in root:  # iterate over the top-level nodes
        if i.tag == 'filename':
            file_name = i.text  # e.g. 0001.jpg
            images['file_name'] = file_name
        if i.tag == 'size':
            for j in i:
                if j.tag == 'width':
                    width = j.text
                    images['width'] = width
                if j.tag == 'height':
                    height = j.text
                    images['height'] = height
        if i.tag == 'object':
            for j in i:
                if j.tag == 'name':
                    cls_name = j.text
                    cat_id = voc_clses.index(cls_name) + 1
                if j.tag == 'bndbox':
                    bbox = []
                    xmin = 0
                    ymin = 0
                    xmax = 0
                    ymax = 0
                    for r in j:
                        if r.tag == 'xmin':
                            xmin = eval(r.text)
                        if r.tag == 'ymin':
                            ymin = eval(r.text)
                        if r.tag == 'xmax':
                            xmax = eval(r.text)
                        if r.tag == 'ymax':
                            ymax = eval(r.text)
                    bbox.append(xmin)
                    bbox.append(ymin)
                    bbox.append(xmax - xmin)
                    bbox.append(ymax - ymin)
                    bbox.append(id)        # image_id of the image this box belongs to
                    bbox.append(cat_id)
                    # annotation area; in COCO the area is < w*h because it is computed
                    # from the segmentation, so 10.0 is subtracted as a rough approximation
                    bbox.append((xmax - xmin) * (ymax - ymin) - 10.0)
                    sig_xml_box.append(bbox)
    images['id'] = id
    return images, sig_xml_box


def txt2list(txtfile):
    f = open(txtfile)
    l = []
    for line in f:
        l.append(line[:-1])
    return l


# voc2007xmls = 'anns'
voc2007xmls = '/data2/chenjia/data/VOCdevkit/VOC2007/Annotations'
# test_txt = 'voc2007/test.txt'
test_txt = '/data2/chenjia/data/VOCdevkit/VOC2007/ImageSets/Main/test.txt'
xml_names = txt2list(test_txt)
xmls = []
bboxes = []
ann_js = {}
for ind, xml_name in enumerate(xml_names):
    xmls.append(os.path.join(voc2007xmls, xml_name + '.xml'))
json_name = 'annotations/instances_voc2007val.json'
images = []
for i_index, xml_file in enumerate(xmls):
    image, sig_xml_bbox = getimages(xml_file, i_index)
    images.append(image)
    bboxes.extend(sig_xml_bbox)
ann_js['images'] = images
ann_js['categories'] = categories
annotations = []
for box_ind, box in enumerate(bboxes):
    anno = {}
    anno['image_id'] = box[-3]
    anno['category_id'] = box[-2]
    anno['bbox'] = box[:-3]
    anno['id'] = box_ind
    anno['area'] = box[-1]
    anno['iscrowd'] = 0
    annotations.append(anno)
ann_js['annotations'] = annotations
json.dump(ann_js, open(json_name, 'w'), indent=4)  # indent=4 for nicer formatting
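Note that this second script is written for VOC2007 (class list and paths). To run it on the XML files produced by the first script, the class list, XML directory, image-list txt, and output json name all need to be replaced; classes that appear in the XML files but are missing from the list (for VisDrone, 'ignored regions' and 'others') would either have to be added or their boxes skipped. A rough sketch of the variables to change (the paths below are placeholders, not the ones I actually used):

# Replace the VOC class list with the VisDrone categories used in the yml
voc_clses = ["pedestrian", "people", "bicycle", "car", "van", "truck",
             "tricycle", "awning-tricycle", "bus", "motor"]

# Point the script at the XML files generated by the first script
voc2007xmls = './VisDrone2019-DET-train/Annotations_XML'   # placeholder path
test_txt = './VisDrone2019-DET-train/train_list.txt'       # placeholder: txt listing image names without extension
json_name = 'annotations/instances_train2019.json'         # must match the set name in the yml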
Place the generated json files and images according to the following directory structure, and remember to rename the json files accordingly:
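A sketch of the expected layout, based on the repository's dataset convention and the project/set names used in the yml below:

datasets/
    visdrone2019/
        train2019/
            *.jpg
        val2019/
            *.jpg
        annotations/
            instances_train2019.json
            instances_val2019.json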
Modify the contents of coco.yml under the projects directory to match your own dataset:
project_name: visdrone2019  # also the folder name of the dataset under the data_path folder
train_set: train2019
val_set: val2019
num_gpus: 1

# mean and std in RGB order; this part can remain unchanged as long as your dataset is similar to coco.
mean: [0.373, 0.378, 0.364]
std: [0.191, 0.182, 0.194]

# these are the coco anchors, change them if necessary
anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'

# must match your dataset's category_id.
# category_id is one-indexed:
# for example, 'car' is at index 3 here, so its category_id is 4
obj_list: ["pedestrian", "people", "bicycle", "car", "van", "truck", "tricycle", "awning-tricycle", "bus", "motor"]
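As an aside, the mean/std above are dataset-specific rather than the repository defaults. A minimal sketch of how such per-channel statistics could be estimated (assuming the training images are under ./VisDrone2019-DET-train/images/, and accepting that averaging per-image statistics is only an approximation):

import os
import numpy as np
from PIL import Image

image_dir = "./VisDrone2019-DET-train/images/"  # placeholder: your training image folder
means, stds = [], []
for name in os.listdir(image_dir):
    img = np.asarray(Image.open(os.path.join(image_dir, name)).convert("RGB"),
                     dtype=np.float32) / 255.0
    pixels = img.reshape(-1, 3)
    means.append(pixels.mean(axis=0))  # per-channel mean of this image
    stds.append(pixels.std(axis=0))    # per-channel std of this image
print("mean:", np.mean(means, axis=0))  # averaged over images, an approximation
print("std:", np.mean(stds, axis=0))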
Start training:

python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10 --load_weights /path/to/your/weights/efficientdet-d2.pth
Download the pretrained model file in advance and place it in the corresponding folder. d0, d1, or d2 are recommended (larger models will run out of GPU memory); if you still hit an out-of-memory error, reduce the batch_size.
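For example, a variant of the training command above that drops to EfficientDet-D0 and halves the batch size (the weight path is a placeholder) could look like:

python train.py -c 0 --batch_size 4 --lr 1e-5 --num_epochs 10 --load_weights /path/to/your/weights/efficientdet-d0.pth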