当前位置:   article > 正文

(笔记)yolov5-5中train.py的一些简单注释

import test # import test.py to get map after each epoch from models.experim

train.py里面加了很多额外的功能,使得整体看起来比较复杂,其实核心部分主要就是以下三部分:

读取数据集,

加载模型,

训练中损失的计算。

这里简单的将train.py按每部分的功能进行了一些注释。

  1. import argparse
  2. import logging
  3. import math
  4. import os
  5. import random
  6. import time
  7. from copy import deepcopy
  8. from pathlib import Path
  9. from threading import Thread
  10. import numpy as np
  11. import torch.distributed as dist
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torch.optim as optim
  15. import torch.optim.lr_scheduler as lr_scheduler
  16. import torch.utils.data
  17. import yaml
  18. from torch.cuda import amp
  19. from torch.nn.parallel import DistributedDataParallel as DDP
  20. from torch.utils.tensorboard import SummaryWriter
  21. from tqdm import tqdm
  22. import test # import test.py to get mAP after each epoch
  23. from models.experimental import attempt_load
  24. from models.yolo import Model
  25. from utils.autoanchor import check_anchors
  26. from utils.datasets import create_dataloader
  27. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  28. fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  29. check_requirements, print_mutation, set_logging, one_cycle, colorstr
  30. from utils.google_utils import attempt_download
  31. from utils.loss import ComputeLoss
  32. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  33. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
  34. from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
  35. logger = logging.getLogger(__name__)
  36. def train(hyp, opt, device, tb_writer=None):
  37. logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) #读取hyp超参数文件
  38. save_dir, epochs, batch_size, total_batch_size, weights, rank = \
  39. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  40. '''
  41. 创建目录,设置模型、txt等保存的路径
  42. save_dir = Path(opt.save_dir) ,save_dir就是根据opt参数里面所设置的路径而生成的目录
  43. '''
  44. # Directories
  45. wdir = save_dir / 'weights'
  46. wdir.mkdir(parents=True, exist_ok=True) # make dir
  47. last = wdir / 'last.pt'
  48. best = wdir / 'best.pt'
  49. results_file = save_dir / 'results.txt'
  50. '''
  51. 将本次运行的超参数(hyp),和选项操作(opt)给保存成yaml格式,
  52. 保存在了每次训练得到的exp文件中,这两个yaml显示了我们本次训练所选择的超参数和opt参数,opt参数是train代码下面那一堆参数选择
  53. '''
  54. # Save run settings
  55. with open(save_dir / 'hyp.yaml', 'w') as f:
  56. yaml.dump(hyp, f, sort_keys=False)
  57. with open(save_dir / 'opt.yaml', 'w') as f:
  58. yaml.dump(vars(opt), f, sort_keys=False)
  59. '''
  60. 配置:画图开关,cuda,种子,读取数据集相关的yaml文件
  61. '''
  62. # Configure
  63. plots = not opt.evolve # create plots
  64. cuda = device.type != 'cpu'
  65. init_seeds(2 + rank)
  66. with open(opt.data,encoding='utf-8') as f:
  67. data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
  68. is_coco = opt.data.endswith('coco.yaml')
  69. '''
  70. 加载相关日志功能:如logger,wandb
  71. '''
  72. # Logging- Doing this before checking the dataset. Might update data_dict
  73. loggers = {'wandb': None} # loggers dict
  74. if rank in [-1, 0]:
  75. opt.hyp = hyp # add hyperparameters
  76. run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
  77. wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
  78. loggers['wandb'] = wandb_logger.wandb
  79. data_dict = wandb_logger.data_dict
  80. if wandb_logger.wandb:
  81. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
  82. nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
  83. names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  84. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  85. '''
  86. 加载模型
  87. '''
  88. # Model
  89. pretrained = weights.endswith('.pt')
  90. if pretrained:
  91. with torch_distributed_zero_first(rank):
  92. attempt_download(weights) # download if not found locally
  93. ckpt = torch.load(weights, map_location=device) # load checkpoint
  94. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  95. exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
  96. state_dict = ckpt['model'].float().state_dict() # to FP32
  97. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  98. model.load_state_dict(state_dict, strict=False) # load
  99. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  100. else:
  101. model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  102. with torch_distributed_zero_first(rank):
  103. check_dataset(data_dict) # check
  104. train_path = data_dict['train']
  105. test_path = data_dict['val']
  106. '''
  107. 冰冻一些层,使得这些层在反向传播的时候不再更新权重,需要冻结的层,可以写在freeze列表中
  108. '''
  109. # Freeze
  110. freeze = [] # parameter names to freeze (full or partial)
  111. for k, v in model.named_parameters():
  112. v.requires_grad = True # train all layers
  113. if any(x in k for x in freeze):
  114. print('freezing %s' % k)
  115. v.requires_grad = False
  116. '''
  117. nbs为名义批次,比如实际批次为16,那么64/16=4,每4次迭代,才进行一次反向传播更新权重,可以节约显存.
  118. '''
  119. # Optimizer
  120. nbs = 64 # nominal batch size
  121. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  122. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  123. logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  124. '''
  125. 设置优化器,权重weight使用了正则化,偏置bias则不使用正则化
  126. '''
  127. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  128. for k, v in model.named_modules():
  129. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  130. pg2.append(v.bias) # biases
  131. if isinstance(v, nn.BatchNorm2d):
  132. pg0.append(v.weight) # no decay
  133. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  134. pg1.append(v.weight) # apply decay
  135. if opt.adam:
  136. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  137. else:
  138. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  139. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  140. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  141. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  142. del pg0, pg1, pg2
  143. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  144. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  145. '''
  146. 设置学习率策略:两者可供选择,线性学习率和余弦退火学习率
  147. '''
  148. if opt.linear_lr:
  149. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  150. else:
  151. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf'] 余弦退火方式
  152. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  153. # plot_lr_scheduler(optimizer, scheduler, epochs)
  154. '''
  155. 设置ema(指数移动平均):目的是为了收敛的曲线更加平滑
  156. '''
  157. # EMA
  158. ema = ModelEMA(model) if rank in [-1, 0] else None
  159. '''
  160. 继续接着训练,需要加载优化器,ema模型,训练结果txt,周期
  161. '''
  162. # Resume
  163. start_epoch, best_fitness = 0, 0.0
  164. if pretrained:
  165. # Optimizer
  166. if ckpt['optimizer'] is not None:
  167. optimizer.load_state_dict(ckpt['optimizer'])
  168. best_fitness = ckpt['best_fitness']
  169. # EMA
  170. if ema and ckpt.get('ema'):
  171. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  172. ema.updates = ckpt['updates']
  173. # Results
  174. if ckpt.get('training_results') is not None:
  175. results_file.write_text(ckpt['training_results']) # write results.txt
  176. # Epochs
  177. start_epoch = ckpt['epoch'] + 1
  178. if opt.resume:
  179. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  180. if epochs < start_epoch:
  181. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  182. (weights, ckpt['epoch'], epochs))
  183. epochs += ckpt['epoch'] # finetune additional epochs
  184. del ckpt, state_dict
  185. '''
  186. 模型默认的下采样倍率model.stride: [8,16,32]
  187. gs代表模型下采样的最大步长: 后续为了保证输入模型的图片宽高是最大步长的整数倍
  188. nl代表模型输出的尺度,默认为3个尺度, 分别下采样8倍,16倍,32倍. nl=3
  189. imgsz, imgsz_test代表训练和测试的图片大小,比如opt.img_size=[640,480],那么训练图片的最大边为640,测试图片最大边为480
  190. 如果opt.img_size=[640],那么自动补成[640,640]
  191. 当然比如这边imgsz是640,那么训练的图片是640*640吗,不一定,具体看你怎么设置,默认是padding成正方形进行训练的.
  192. '''
  193. # Image sizes
  194. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  195. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  196. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  197. '''
  198. 多卡训练
  199. '''
  200. # DP mode
  201. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  202. model = torch.nn.DataParallel(model)
  203. # SyncBatchNorm
  204. if opt.sync_bn and cuda and rank != -1:
  205. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  206. logger.info('Using SyncBatchNorm()')
  207. '''
  208. 加载数据集
  209. '''
  210. # Trainloader
  211. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  212. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  213. world_size=opt.world_size, workers=opt.workers,
  214. image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
  215. '''
  216. 检验加载的数据集是否正确: 利用数据集中的最大类别<nc
  217. '''
  218. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  219. nb = len(dataloader) # number of batches
  220. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  221. # Process 0
  222. if rank in [-1, 0]:
  223. testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
  224. hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
  225. world_size=opt.world_size, workers=opt.workers,
  226. pad=0.5, prefix=colorstr('val: '))[0]
  227. if not opt.resume:
  228. labels = np.concatenate(dataset.labels, 0)
  229. c = torch.tensor(labels[:, 0]) # classes
  230. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  231. # model._initialize_biases(cf.to(device))
  232. if plots:
  233. plot_labels(labels, names, save_dir, loggers)
  234. if tb_writer:
  235. tb_writer.add_histogram('classes', c, 0)
  236. '''
  237. Yolov5原本在模型配置文件(如yolov5l.py)中有默认的anchors,这些anchors是基于COCO数据集在640×640图像大小下锚定框的尺寸。
  238. Yolov5会自动按照新的数据集的labels自动学习anchors的尺寸。采用 k 均值和遗传学习算法对自定义数据集进行分析,获得适合自定义数据集中对象边界框预测的预设锚定框。
  239. 训练一开始会先计算Best Possible Recall (BPR),当BPR < 0.98时,再在kmean_anchors函数中进行k 均值和遗传学习算法更新anchors。
  240. check_anchors函数
  241. '''
  242. # Anchors
  243. if not opt.noautoanchor:
  244. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  245. model.half().float() # pre-reduce anchor precision
  246. # DDP mode
  247. if cuda and rank != -1:
  248. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
  249. # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
  250. find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))
  251. '''
  252. 模型参数的一些调整
  253. '''
  254. # Model parameters
  255. hyp['box'] *= 3. / nl # scale to layers
  256. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  257. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  258. hyp['label_smoothing'] = opt.label_smoothing
  259. model.nc = nc # attach number of classes to model
  260. model.hyp = hyp # attach hyperparameters to model
  261. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  262. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  263. model.names = names
  264. '''
  265. 开始训练
  266. '''
  267. # Start training
  268. t0 = time.time()
  269. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  270. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  271. maps = np.zeros(nc) # mAP per class
  272. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  273. scheduler.last_epoch = start_epoch - 1 # do not move
  274. scaler = amp.GradScaler(enabled=cuda)
  275. compute_loss = ComputeLoss(model) # init loss class
  276. logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
  277. f'Using {dataloader.num_workers} dataloader workers\n'
  278. f'Logging results to {save_dir}\n'
  279. f'Starting training for {epochs} epochs...')
  280. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  281. model.train()
  282. # Update image weights (optional)
  283. if opt.image_weights:
  284. # Generate indices
  285. if rank in [-1, 0]:
  286. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  287. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  288. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  289. # Broadcast if DDP
  290. if rank != -1:
  291. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  292. dist.broadcast(indices, 0)
  293. if rank != 0:
  294. dataset.indices = indices.cpu().numpy()
  295. # Update mosaic border
  296. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  297. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  298. mloss = torch.zeros(4, device=device) # mean losses
  299. if rank != -1:
  300. dataloader.sampler.set_epoch(epoch)
  301. pbar = enumerate(dataloader)
  302. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
  303. if rank in [-1, 0]:
  304. pbar = tqdm(pbar, total=nb) # progress bar
  305. optimizer.zero_grad() #梯度清零
  306. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  307. # ni用来记录当前的迭代次数,如果小于nw(warm up需要的迭代次数),就进行wam uo
  308. ni = i + nb * epoch # number integrated batches (since train start)
  309. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  310. # Warmup
  311. if ni <= nw:
  312. xi = [0, nw] # x interp
  313. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  314. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  315. for j, x in enumerate(optimizer.param_groups):
  316. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  317. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  318. if 'momentum' in x:
  319. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  320. '''
  321. 对图片尺寸进行变换,多尺度训练,在opt参数里面可以选择开启或关闭
  322. '''
  323. # Multi-scale
  324. if opt.multi_scale:
  325. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  326. sf = sz / max(imgs.shape[2:]) # scale factor
  327. if sf != 1:
  328. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  329. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  330. # Forward
  331. with amp.autocast(enabled=cuda):
  332. pred = model(imgs) # forward
  333. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  334. if rank != -1:
  335. loss *= opt.world_size # gradient averaged between devices in DDP mode
  336. if opt.quad:
  337. loss *= 4.
  338. # Backward,选择优化器之前的步骤就是调用loss进行反向传播
  339. scaler.scale(loss).backward()
  340. # Optimize
  341. if ni % accumulate == 0:
  342. scaler.step(optimizer) # optimizer.step
  343. scaler.update()
  344. optimizer.zero_grad() #梯度清零
  345. if ema:
  346. ema.update(model)
  347. # Print
  348. if rank in [-1, 0]:
  349. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  350. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  351. s = ('%10s' * 2 + '%10.4g' * 6) % (
  352. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  353. pbar.set_description(s)
  354. # Plot 画图
  355. if plots and ni < 3:
  356. f = save_dir / f'train_batch{ni}.jpg' # filename
  357. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  358. # if tb_writer:
  359. # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  360. # tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), []) # add model graph
  361. elif plots and ni == 10 and wandb_logger.wandb:
  362. wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
  363. save_dir.glob('train*.jpg') if x.exists()]})
  364. # end batch ------------------------------------------------------------------------------------------------
  365. # end epoch ----------------------------------------------------------------------------------------------------
  366. # Scheduler
  367. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  368. scheduler.step()
  369. # DDP process 0 or single-GPU
  370. if rank in [-1, 0]:
  371. # mAP
  372. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  373. final_epoch = epoch + 1 == epochs
  374. if not opt.notest or final_epoch: # Calculate mAP
  375. wandb_logger.current_epoch = epoch + 1
  376. results, maps, times = test.test(data_dict,
  377. batch_size=batch_size * 2,
  378. imgsz=imgsz_test,
  379. model=ema.ema,
  380. single_cls=opt.single_cls,
  381. dataloader=testloader,
  382. save_dir=save_dir,
  383. verbose=nc < 50 and final_epoch,
  384. plots=plots and final_epoch,
  385. wandb_logger=wandb_logger,
  386. compute_loss=compute_loss,
  387. is_coco=is_coco)
  388. # Write
  389. with open(results_file, 'a') as f:
  390. f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
  391. if len(opt.name) and opt.bucket:
  392. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  393. # Log
  394. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  395. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  396. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  397. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  398. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  399. if tb_writer:
  400. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  401. if wandb_logger.wandb:
  402. wandb_logger.log({tag: x}) # W&B
  403. # Update best mAP
  404. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  405. if fi > best_fitness:
  406. best_fitness = fi
  407. wandb_logger.end_epoch(best_result=best_fitness == fi)
  408. # Save model
  409. if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
  410. ckpt = {'epoch': epoch,
  411. 'best_fitness': best_fitness,
  412. 'training_results': results_file.read_text(),
  413. 'model': deepcopy(model.module if is_parallel(model) else model).half(),
  414. 'ema': deepcopy(ema.ema).half(),
  415. 'updates': ema.updates,
  416. 'optimizer': optimizer.state_dict(),
  417. 'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
  418. # Save last, best and delete
  419. torch.save(ckpt, last)
  420. if best_fitness == fi:
  421. torch.save(ckpt, best)
  422. if wandb_logger.wandb:
  423. if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
  424. wandb_logger.log_model(
  425. last.parent, opt, epoch, fi, best_model=best_fitness == fi)
  426. del ckpt
  427. # end epoch ----------------------------------------------------------------------------------------------------
  428. # end training
  429. if rank in [-1, 0]:
  430. # Plots
  431. if plots:
  432. plot_results(save_dir=save_dir) # save as results.png
  433. if wandb_logger.wandb:
  434. files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
  435. wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
  436. if (save_dir / f).exists()]})
  437. # Test best.pt
  438. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  439. if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
  440. for m in (last, best) if best.exists() else (last): # speed, mAP tests
  441. results, _, _ = test.test(opt.data,
  442. batch_size=batch_size * 2,
  443. imgsz=imgsz_test,
  444. conf_thres=0.001,
  445. iou_thres=0.7,
  446. model=attempt_load(m, device).half(),
  447. single_cls=opt.single_cls,
  448. dataloader=testloader,
  449. save_dir=save_dir,
  450. save_json=True,
  451. plots=False,
  452. is_coco=is_coco)
  453. # Strip optimizers
  454. final = best if best.exists() else last # final model
  455. for f in last, best:
  456. if f.exists():
  457. strip_optimizer(f) # strip optimizers
  458. if opt.bucket:
  459. os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
  460. if wandb_logger.wandb and not opt.evolve: # Log the stripped model
  461. wandb_logger.wandb.log_artifact(str(final), type='model',
  462. name='run_' + wandb_logger.wandb_run.id + '_model',
  463. aliases=['last', 'best', 'stripped'])
  464. wandb_logger.finish_run()
  465. else:
  466. dist.destroy_process_group()
  467. torch.cuda.empty_cache()
  468. return results
  469. if __name__ == '__main__':
  470. parser = argparse.ArgumentParser()
  471. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path') #选择用来训练的网络模型路径,当default为空时,就是没有预训练模型,从头开始训练
  472. parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path') #网络模型的配置参数,地址在models/hub中的yaml文件
  473. parser.add_argument('--data', type=str, default='data/new.yaml', help='data.yaml path') #数据集地址
  474. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path') #超参数,data里面的hyp.xx.yaml两个二选一
  475. parser.add_argument('--epochs', type=int, default=300) #设置训练多少轮
  476. parser.add_argument('--batch-size', type=int, default=32, help='total batch size for all GPUs') #设置batch_size,每次送入网络多少张图片
  477. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes') #设置训练和预测时候的图片尺寸大小(保持一致)
  478. parser.add_argument('--rect', action='store_true', help='rectangular training') #矩阵训练方式,默认关闭
  479. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') #继续上次中止的训练,在填入default="上次训练的权重位置"即可接着训练
  480. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') #是否只保存最后一次训练的权重,不用设置
  481. parser.add_argument('--notest', action='store_true', help='only test final epoch') #是否只在最后一轮进行测试,不用设置
  482. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check') #是否禁止采用锚框,不用设置
  483. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') #超参数进化,超参数的调优,默认关闭
  484. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') #不用管
  485. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') #是否把图片缓存用于更好的训练中,默认关闭
  486. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') #对上一轮测试效果不是很好的图片
  487. #在下一轮中对这些图片加一些相关的权重,着重训练,默认关闭
  488. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') #选择GPU还是CPU不用设置系统会自动选择
  489. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') #对图片尺寸进行变换,多尺度训练,默认关闭
  490. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class') #为的单类别还是多类别,默认多类别
  491. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer') #Adam优化器,默认是不用的,用的是随机梯度下降
  492. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') #不用看
  493. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') #不用看
  494. parser.add_argument('--workers', type=int, default=4, help='maximum number of dataloader workers') #可以先将workers改为0,训练之后没什么问题再调大
  495. parser.add_argument('--project', default='runs/train', help='save to project/name') #训练过的权重文件保存路径
  496. parser.add_argument('--entity', default=None, help='W&B entity') #不用管
  497. parser.add_argument('--name', default='exp', help='save to project/name') #存放训练好的权重文件的文件名
  498. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') #如果设置为True,每次训练的结果就不会每次新建一个exp而是在一个文件,没啥用
  499. parser.add_argument('--quad', action='store_true', help='quad dataloader') #是否选择quad dataloader这种取数据方式,当训练尺寸>640时效果更好
  500. #在640尺寸上效果没有默认dataloder好,默认不用
  501. parser.add_argument('--linear-lr', action='store_true', help='linear LR') #一种学习率优化方式,默认的是余弦退火方式
  502. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon') #标签平滑可以设置为0.01、0.005,防止分类算法中过拟合的情况产生
  503. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
  504. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
  505. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch') #不用管
  506. # parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used') 作者还没实现
  507. opt = parser.parse_args()
  508. # Set DDP variables
  509. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  510. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  511. set_logging(opt.global_rank)
  512. if opt.global_rank in [-1, 0]:
  513. check_git_status()
  514. check_requirements()
  515. # Resume
  516. wandb_run = check_wandb_resume(opt)
  517. if opt.resume and not wandb_run: # resume an interrupted run
  518. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  519. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  520. apriori = opt.global_rank, opt.local_rank
  521. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  522. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace
  523. opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # reinstate
  524. logger.info('Resuming training from %s' % ckpt)
  525. else:
  526. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  527. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  528. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  529. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  530. opt.name = 'evolve' if opt.evolve else opt.name
  531. opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
  532. # DDP mode
  533. opt.total_batch_size = opt.batch_size
  534. device = select_device(opt.device, batch_size=opt.batch_size)
  535. if opt.local_rank != -1:
  536. assert torch.cuda.device_count() > opt.local_rank
  537. torch.cuda.set_device(opt.local_rank)
  538. device = torch.device('cuda', opt.local_rank)
  539. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  540. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  541. opt.batch_size = opt.total_batch_size // opt.world_size
  542. # Hyperparameters
  543. with open(opt.hyp) as f:
  544. hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps
  545. # Train 训练模式
  546. logger.info(opt)
  547. if not opt.evolve:
  548. tb_writer = None # init loggers
  549. if opt.global_rank in [-1, 0]:
  550. prefix = colorstr('tensorboard: ')
  551. logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
  552. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  553. train(hyp, opt, device, tb_writer)
  554. # Evolve hyperparameters (optional) 进化超参数
  555. else:
  556. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  557. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  558. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  559. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  560. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  561. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  562. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  563. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  564. 'box': (1, 0.02, 0.2), # box loss gain
  565. 'cls': (1, 0.2, 4.0), # cls loss gain
  566. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  567. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  568. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  569. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  570. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  571. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  572. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  573. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  574. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  575. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  576. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  577. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  578. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  579. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  580. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  581. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  582. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  583. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  584. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  585. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  586. opt.notest, opt.nosave = True, True # only test/save final epoch
  587. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  588. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  589. if opt.bucket:
  590. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  591. for _ in range(300): # generations to evolve
  592. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  593. # Select parent(s)
  594. parent = 'single' # parent selection method: 'single' or 'weighted'
  595. x = np.loadtxt('evolve.txt', ndmin=2)
  596. n = min(5, len(x)) # number of previous results to consider
  597. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  598. w = fitness(x) - fitness(x).min() # weights
  599. if parent == 'single' or len(x) == 1:
  600. # x = x[random.randint(0, n - 1)] # random selection
  601. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  602. elif parent == 'weighted':
  603. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  604. # Mutate
  605. mp, s = 0.8, 0.2 # mutation probability, sigma
  606. npr = np.random
  607. npr.seed(int(time.time()))
  608. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  609. ng = len(meta)
  610. v = np.ones(ng)
  611. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  612. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  613. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  614. hyp[k] = float(x[i + 7] * v[i]) # mutate
  615. # Constrain to limits
  616. for k, v in meta.items():
  617. hyp[k] = max(hyp[k], v[1]) # lower limit
  618. hyp[k] = min(hyp[k], v[2]) # upper limit
  619. hyp[k] = round(hyp[k], 5) # significant digits
  620. # Train mutation
  621. results = train(hyp.copy(), opt, device)
  622. # Write mutation results
  623. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  624. # Plot results
  625. plot_evolution(yaml_file)
  626. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  627. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/53696
推荐阅读
相关标签
  

闽ICP备14008679号