赞
踩
本文章从源码入手,详细的剖析 MMDetection 的训练细节。理解它是如何实现通用,灵活的训练的。
MMDetection 支持单机单卡、多卡训练。而且对于训练部分的代码和模型、数据集的代码的耦合性极低。真正实现了支持各种训练模式、各种模型和数据集的通用训练模板。
tools/train.py
是单机单卡训练
时需要运行的文件,其主要使用方法为:
python tools/train.py ${CONFIG_FILE} [optional arguments]
如果为单机多卡
,需要运行 tools/dist_train.sh
,注意此文件不支持多机多卡训练:
./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]
可选参数如下:
=========== optional arguments ===========
# --work-dir 存储日志和模型的目录
# --resume-from 加载 checkpoint 的目录
# --no-validate 是否在训练的时候进行验证
# 互斥组:
# --gpus 使用的 GPU 数量
# --gpu_ids 使用指定 GPU 的 id
# --seed 随机数种子
# --deterministic 是否设置 cudnn 为确定性行为
# --options 其他参数
# --launcher 分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
# none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
# --local_rank 本地进程编号,此参数 torch.distributed.launch 会自动传入。
我们来看一下 dist_train.sh 里面究竟是什么?
可以看出 dist_train.sh 的本质就是使用 torch.distributed.launch(这是分布式的辅助启动工具) 运行 tools/train.py。
torch.distributed.launch 需要使用 python -m 来运行,-m 是把一个模块当做脚本来运行的一个参数。一般情况下可以给 torch.distributed.launch 传如下的参数
--nproc_per_node:表示每台机器的 GPU 数量
--nnodes:表示机器的数量
--node_rank:机器的排名,如果为 0 代表是 master 节点(机器)。
--master_addr:master 节点的 IP 地址
--master_port:master 节点开放的端口号
可以看到,因为 dist_train.sh 只支持单机多卡训练,所以只传参了 --nproc_per_node(GPU 个数)和 --master_port(开放的端口号)
#!/usr/bin/env bash
CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
$(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
(一)从命令行和配置文件获取参数配置
(二)构建模型
# 构建模型: 需要传入 cfg.model,cfg.train_cfg,cfg.test_cfg
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
(三)构建数据集
# 构建数据集: 需要传入 cfg.data.train,表明是训练集
datasets = [build_dataset(cfg.data.train)]
(四)训练模型
# 训练检测器:需要传入模型、数据集、配置参数等
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
所以对于 train.py
来说,首先从命令行和配置文件读取配置,然后分别用 build_detector
、build_dataset
构建模型和数据集,最后将模型和数据集传入 train_detector
进行训练。
下面我们来看一下源码:
import argparse import copy import os import os.path as osp import time import warnings import mmcv import torch # Config 用于读取配置文件, DictAction 将命令行字典类型参数转化为 key-value 形式 from mmcv import Config, DictAction from mmcv.runner import get_dist_info, init_dist from mmcv.utils import get_git_hash from mmdet import __version__ from mmdet.apis import set_random_seed, train_detector from mmdet.datasets import build_dataset from mmdet.models import build_detector from mmdet.utils import collect_env, get_root_logger # python tools/train.py ${CONFIG_FILE} [optional arguments] # =========== optional arguments =========== # --work-dir 存储日志和模型的目录 # --resume-from 加载 checkpoint 的目录 # --no-validate 是否在训练的时候进行验证 # 互斥组: # --gpus 使用的 GPU 数量 # --gpu_ids 使用指定 GPU 的 id # --seed 随机数种子 # --deterministic 是否设置 cudnn 为确定性行为 # --options 其他参数 # --launcher 分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi'] # none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。 # --local_rank 本地进程编号,此参数 torch.distributed.launch 会自动传入。 def parse_args(): parser = argparse.ArgumentParser(description='Train a detector') parser.add_argument('config', help='train config file path') parser.add_argument('--work-dir', help='the dir to save logs and models') parser.add_argument( '--resume-from', help='the checkpoint file to resume from') # action: store (默认, 表示保存参数) # action: store_true, store_false (如果指定参数, 则为 True, False) parser.add_argument( '--no-validate', action='store_true', help='whether not to evaluate the checkpoint during training') # --------- 创建一个互斥组. argparse 将会确保互斥组中的参数只能出现一个 --------- group_gpus = parser.add_mutually_exclusive_group() group_gpus.add_argument( '--gpus', type=int, help='number of gpus to use ' '(only applicable to non-distributed training)') # 可以使用 python train.py --gpu-ids 0 1 2 3 指定使用的 GPU id # 参数结果:[0, 1, 2, 3] # nargs = '*':参数个数可以设置0个或n个 # nargs = '+':参数个数可以设置1个或n个 # nargs = '?':参数个数可以设置0个或1个 group_gpus.add_argument( '--gpu-ids', type=int, nargs='+', help='ids of gpus to use ' '(only applicable to non-distributed training)') # ------------------------------------------------------------------------ parser.add_argument('--seed', type=int, default=None, help='random seed') parser.add_argument( '--deterministic', action='store_true', help='whether to set deterministic options for CUDNN backend.') # 其他参数: 可以使用 --options a=1,2,3 指定其他参数 # 参数结果: {'a': [1, 2, 3]} parser.add_argument( '--options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file (deprecate), ' 'change to --cfg-options instead.') parser.add_argument( '--cfg-options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. If the value to ' 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') # 如果使用 dist_utils.sh 进行分布式训练, launcher 默认为 pytorch parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') # 本地进程编号,此参数 torch.distributed.launch 会自动传入。 parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() # 如果环境中没有 LOCAL_RANK,就设置它为当前的 local_rank if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) if args.options and args.cfg_options: raise ValueError( '--options and --cfg-options cannot be both ' 'specified, --options is deprecated in favor of --cfg-options') if args.options: warnings.warn('--options is deprecated in favor of --cfg-options') args.cfg_options = args.options return args def main(): args = parse_args() cfg = Config.fromfile(args.config) # 从文件读取配置 # 从命令行读取额外的配置 if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) # import modules from string list. if cfg.get('custom_imports', None): from mmcv.utils import import_modules_from_strings import_modules_from_strings(**cfg['custom_imports']) # set cudnn_benchmark,设置True 可以加速输入大小固定的模型. 如:SSD300 if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True # work_dir is determined in this priority: CLI > segment in file > filename # work_dir 的优先程度为: 命令行 > 配置文件 if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None cfg.work_dir = args.work_dir # 当 work_dir 为 None 的时候, 使用 ./work_dir/配置文件名 作为默认工作目录 elif cfg.get('work_dir', None) is None: # use config filename as default work_dir if cfg.work_dir is None # os.path.basename(path) 返回文件名 # os.path.splitext(path) 分割路径, 返回路径名和文件扩展名的元组 cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) # 是否继续上次的训练 if args.resume_from is not None: cfg.resume_from = args.resume_from # gpu id if args.gpu_ids is not None: cfg.gpu_ids = args.gpu_ids else: cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) # init distributed env first, since logger depends on the dist info. # 如果 launcher 为 none,不启用分布式训练。不使用 dist_train.sh 默认参数为 none. if args.launcher == 'none': distributed = False # launcher 不为 none,启用分布式训练。使用 dist_train.sh,会传 ‘pytorch’ else: distributed = True # 初始化 dist 里面会调用 init_process_group init_dist(args.launcher, **cfg.dist_params) # re-set gpu_ids with distributed training mode _, world_size = get_dist_info() cfg.gpu_ids = range(world_size) # create work_dir mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) # dump config cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) # init the logger before other steps timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) log_file = osp.join(cfg.work_dir, f'{timestamp}.log') logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) # init the meta dict to record some important information such as # environment info and seed, which will be logged meta = dict() # log env info env_info_dict = collect_env() env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) dash_line = '-' * 60 + '\n' logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line) meta['env_info'] = env_info meta['config'] = cfg.pretty_text # log some basic info logger.info(f'Distributed training: {distributed}') logger.info(f'Config:\n{cfg.pretty_text}') # 设置随机化种子 if args.seed is not None: logger.info(f'Set random seed to {args.seed}, ' f'deterministic: {args.deterministic}') set_random_seed(args.seed, deterministic=args.deterministic) cfg.seed = args.seed meta['seed'] = args.seed meta['exp_name'] = osp.basename(args.config) # 构建模型: 需要传入 cfg.model, cfg.train_cfg, cfg.test_cfg model = build_detector( cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')) model.init_weights() # 构建数据集: 需要传入 cfg.data.train datasets = [build_dataset(cfg.data.train)] # workflow 代表流程: # [('train', 2), ('val', 1)] 就代表,训练两个 epoch 验证一个 epoch if len(cfg.workflow) == 2: val_dataset = copy.deepcopy(cfg.data.val) val_dataset.pipeline = cfg.data.train.pipeline datasets.append(build_dataset(val_dataset)) if cfg.checkpoint_config is not None: # save mmdet version, config file content and class names in # checkpoints as meta data cfg.checkpoint_config.meta = dict( mmdet_version=__version__ + get_git_hash()[:7], CLASSES=datasets[0].CLASSES) # add an attribute for visualization convenience model.CLASSES = datasets[0].CLASSES # 训练检测器, 传入:模型, 数据集, config 等 train_detector( model, datasets, cfg, distributed=distributed, validate=(not args.no_validate), timestamp=timestamp, meta=meta) if __name__ == '__main__': main()
我们分别来看看与 train.py 相关的核心函数:
1、init_dist:
此函数负责调用 init_process_group,完成分布式的初始化
。在运行 dist_train.py
训练时,默认传递的 launcher 是 ‘pytorch’。所以此函数会进一步调用 _init_dist_pytorch 来完成初始化。
因为 torch.distributed 可以采用单进程控制多 GPU,也可以一个进程控制一个 GPU
。一个进程控制一个 GPU 是目前Pytorch中,无论是单节点还是多节点,进行数据并行训练最快的方式。在 mmdet 中也是这么实现的。既然是单个进程控制单个 GPU,那么我么就需要绑定当前进程控制的是哪个 GPU。可以理解为在使用 torch.distributed.launch 运行 py 文件时。 它会多次调用 py 文件,每个 py 文件控制一个 GPU。并向每个 py 文件传参 --local_rank。(local_rank 是在这台机器上的本地进程编号)这样对于每个 py 文件,都能拿到传入的本地进程编号,我们只需要把当前进程绑定到指定的 GPU 即可。
在 _init_dist_pytorch 中就会设置当前进程控制的默认 GPU(torch.cuda.set_device),再使用 dist.init_process_group 初始化,初始化的方式为默认的 env://,即环境变量的方式。使用 env:// 方式初始化就需要用 torch.distributed.launch 运行 py 文件,torch.distributed.launch 会根据传入的参数设置环境变量,并运行 py 文件。
# Copyright (c) Open-MMLab. All rights reserved. import functools import os import subprocess import torch import torch.distributed as dist import torch.multiprocessing as mp from mmcv.utils import TORCH_VERSION def init_dist(launcher, backend='nccl', **kwargs): if mp.get_start_method(allow_none=True) is None: mp.set_start_method('spawn') # 默认进到这里 if launcher == 'pytorch': # 调用下面的 _init_dist_pytorch 函数来初始化。 _init_dist_pytorch(backend, **kwargs) elif launcher == 'mpi': _init_dist_mpi(backend, **kwargs) elif launcher == 'slurm': _init_dist_slurm(backend, **kwargs) else: raise ValueError(f'Invalid launcher type: {launcher}') def _init_dist_pytorch(backend, **kwargs): # rank 是所有进程的总编号,算上本地进程,从 0 开始。 rank = int(os.environ['RANK']) num_gpus = torch.cuda.device_count() # 这里也可以使用命令行传递来的 --local_rank。 torch.cuda.set_device(rank % num_gpus) dist.init_process_group(backend=backend, **kwargs)
2、set_random_seed:
此函数会对 python、numpy、torch 都设置随机数种子。
保持随机数种子相同时,卷积的结果在CPU上相同,在GPU上仍然不相同。这是因为,cudnn卷积行为的不确定性。使用 torch.backends.cudnn.deterministic = True 可以解决。
cuDNN 使用非确定性算法,并且可以使用 torch.backends.cudnn.enabled = False 来进行禁用。如果设置为 torch.backends.cudnn.enabled = True,说明设置为使用非确定性算法(即会自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题)
一般来讲,应该遵循以下准则:
def set_random_seed(seed, deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# manual_seed_all 是为所有 GPU 都设置随机数种子。
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
3、get_root_logger:
get_root_logger 调用 get_logger 函数获取 logger 对象。
import logging
from mmcv.utils import get_logger
def get_root_logger(log_file=None, log_level=logging.INFO):
logger = get_logger(name='mmdet', log_file=log_file, log_level=log_level)
return logger
这里实现的 get_logger 函数非常灵活,如果传入相同的 log 的 name,会返回配置相同的 log。传入以点分割的日志名称的子模块,也会返回相同的 log。如:a 和 a.b 会返回相同的 log。
如果传入 log_file 会保存 log 的输出到 log_file 指定的路径,如果不传入 log_file,不保存日志的输出。只在控制台输出。
下面我们来分析一下源码:
import logging import torch.distributed as dist # 记录是否创建过 name 对应的 log,如果创建过设置为 True logger_initialized = {} def get_logger(name, log_file=None, log_level=logging.INFO): # 获取 log 对象。 logger = logging.getLogger(name) # 如果已经创建过,直接返回 if name in logger_initialized: return logger # 如果是创建过的以 ‘.’ 分割的子模块,也直接返回 for logger_name in logger_initialized: if name.startswith(logger_name): return logger stream_handler = logging.StreamHandler() handlers = [stream_handler] # 获取当前的 rank(总进程编号) if dist.is_available() and dist.is_initialized(): rank = dist.get_rank() else: rank = 0 # 只有 rank 0(master 节点的 local_rank 为 0 的进程)的主机才保存日志 if rank == 0 and log_file is not None: file_handler = logging.FileHandler(log_file, 'w') handlers.append(file_handler) formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s') for handler in handlers: handler.setFormatter(formatter) handler.setLevel(log_level) logger.addHandler(handler) # 对于非 rank 为 0 的进程,只有 error 以上的信息才会显示 if rank == 0: logger.setLevel(log_level) else: logger.setLevel(logging.ERROR) # 将 log name 对应的值设为 True,表示创建过。 logger_initialized[name] = True return logger
因为在 train.py 中主要调用:构建模型(build_detector),构建数据集(build_dataset),训练模型(train_detector)的函数,我们下来分别看看这三个函数的源码。
build_detector
函数将配置文件config中的:model、train_cfg 和 test_cfg
三部分传入参数。
下面以 faster_rcnn_r50_fpn_1x_coco.py
配置文件来举例:
具体在faster_rcnn_r50_fpn.py
文件中
model
model = dict( type='FasterRCNN', pretrained='torchvision://resnet50', backbone=dict( type='ResNet', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), frozen_stages=1, norm_cfg=dict(type='BN', requires_grad=True), norm_eval=True, style='pytorch'), neck=dict( type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5), rpn_head=dict( type='RPNHead', in_channels=256, feat_channels=256, anchor_generator=dict( type='AnchorGenerator', scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[.0, .0, .0, .0], target_stds=[1.0, 1.0, 1.0, 1.0]), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox=dict(type='L1Loss', loss_weight=1.0)), roi_head=dict( type='StandardRoIHead', bbox_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), bbox_head=dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80, bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.1, 0.1, 0.2, 0.2]), reg_class_agnostic=False, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='L1Loss', loss_weight=1.0))))
train_cfg
train_cfg = dict( rpn=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, match_low_quality=True, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), allowed_border=-1, pos_weight=-1, debug=False), rpn_proposal=dict( nms_across_levels=False, nms_pre=2000, nms_post=1000, max_num=1000, nms_thr=0.7, min_bbox_size=0), rcnn=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), pos_weight=-1, debug=False))
test_cfg
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
)
运行时会将上面的三个值作为参数传入 build_detector
函数,build_detector 函数会调用 build
函数,build 函数调用 build_from_cfg
函数构建检测器对象。其中 train_cfg
和 test_cfg
作为默认参数用于构建 detector 对象。
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
# 调用 build_from_cfg 用来根据 config 字典构建 registry 里面的对象
return build_from_cfg(cfg, registry, default_args)
def build_detector(cfg, train_cfg=None, test_cfg=None):
# 调用 build 函数,传入 cfg, registry 对象,
# 把 train_cfg 和 test_cfg 作为默认字典传入
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
build_from_cfg
在 mmcv/utils/registery.py
中。其中参数 cfg 字典中的 type 键所对应的值表示需要创建的对象的类型。build_from_cfg 会自动在 Registry 注册的类中找到需要创建的类,并传入默认参数实例化。
def build_from_cfg(cfg, registry, default_args=None): """Build a module from config dict. Args: cfg (dict): Config dict. It should at least contain the key "type". registry (:obj:`Registry`): The registry to search the type from. default_args (dict, optional): Default initialization arguments. Returns: object: The constructed object. """ if not isinstance(cfg, dict): raise TypeError(f'cfg must be a dict, but got {type(cfg)}') if 'type' not in cfg: raise KeyError( f'the cfg dict must contain the key "type", but got {cfg}') if not isinstance(registry, Registry): raise TypeError('registry must be an mmcv.Registry object, ' f'but got {type(registry)}') if not (isinstance(default_args, dict) or default_args is None): raise TypeError('default_args must be a dict or None, ' f'but got {type(default_args)}') args = cfg.copy() # 获取 type 对应的值 obj_type = args.pop('type') if is_str(obj_type): # 获取需要创建的对象 obj_cls = registry.get(obj_type) if obj_cls is None: raise KeyError( f'{obj_type} is not in the {registry.name} registry') elif inspect.isclass(obj_type): obj_cls = obj_type else: raise TypeError( f'type must be a str or valid type, but got {type(obj_type)}') # 如果 default_args 不是 None,传入默认值再实例化。 if default_args is not None: for name, value in default_args.items(): args.setdefault(name, value) return obj_cls(**args)
那么什么是 registry?
registry 就是注册类,将一个字符串和类关联起来。如果索引字符串就会获得类。Registry 是注册所需要的类,可以用它来注册类。我们可以使用如下的方式来注册类。
backbones = Registry('backbone')
@backbones.register_module()
class ResNet:
pass
backbones = Registry('backbone')
@backbones.register_module(name='mnet')
class MobileNet:
pass
backbones = Registry('backbone')
class ResNet:
pass
backbones.register_module(ResNet)
下面是 Registry 类的代码,它的内部维护了一个已经注册的类的字典 ——_module_dict。每当注册一个类就在字典里添加一个字符串(默认为类名)与类的映射。register_module 方法,利用装饰器将类名和类添加到 _module_dict 中。对于注册的模块可以通过 build_from_cfg 来构建。
import inspect import warnings from functools import partial from .misc import is_str class Registry: """A registry to map strings to classes. Args: name (str): Registry name. """ def __init__(self, name): self._name = name # 已经注册的类的字典 self._module_dict = dict() def __len__(self): return len(self._module_dict) def __contains__(self, key): return self.get(key) is not None def __repr__(self): format_str = self.__class__.__name__ + \ f'(name={self._name}, ' \ f'items={self._module_dict})' return format_str @property def name(self): return self._name @property def module_dict(self): return self._module_dict def get(self, key): """Get the registry record. Args: key (str): The class name in string format. Returns: class: The corresponding class. """ return self._module_dict.get(key, None) def _register_module(self, module_class, module_name=None, force=False): if not inspect.isclass(module_class): raise TypeError('module must be a class, ' f'but got {type(module_class)}') if module_name is None: module_name = module_class.__name__ if not force and module_name in self._module_dict: raise KeyError(f'{module_name} is already registered ' f'in {self.name}') self._module_dict[module_name] = module_class def deprecated_register_module(self, cls=None, force=False): warnings.warn( 'The old API of register_module(module, force=False) ' 'is deprecated and will be removed, please use the new API ' 'register_module(name=None, force=False, module=None) instead.') if cls is None: return partial(self.deprecated_register_module, force=force) self._register_module(cls, force=force) return cls def register_module(self, name=None, force=False, module=None): """Register a module. A record will be added to `self._module_dict`, whose key is the class name or the specified name, and value is the class itself. It can be used as a decorator or a normal function. Example: >>> backbones = Registry('backbone') >>> @backbones.register_module() >>> class ResNet: >>> pass >>> backbones = Registry('backbone') >>> @backbones.register_module(name='mnet') >>> class MobileNet: >>> pass >>> backbones = Registry('backbone') >>> class ResNet: >>> pass >>> backbones.register_module(ResNet) Args: name (str | None): The module name to be registered. If not specified, the class name will be used. force (bool, optional): Whether to override an existing class with the same name. Default: False. module (type): Module class to be registered. """ if not isinstance(force, bool): raise TypeError(f'force must be a boolean, but got {type(force)}') # NOTE: This is a walkaround to be compatible with the old api, # while it may introduce unexpected bugs. if isinstance(name, type): return self.deprecated_register_module(name, force=force) # use it as a normal method: x.register_module(module=SomeClass) if module is not None: self._register_module( module_class=module, module_name=name, force=force) return module # raise the error ahead of time if not (name is None or isinstance(name, str)): raise TypeError(f'name must be a str, but got {type(name)}') # use it as a decorator: @x.register_module() def _register(cls): self._register_module( module_class=cls, module_name=name, force=force) return cls return _register
build_dataset
也类似,通过调用 build_from_cfg
创建。
def build_dataset(cfg, default_args=None): from .dataset_wrappers import (ConcatDataset, RepeatDataset, ClassBalancedDataset) if isinstance(cfg, (list, tuple)): dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) elif cfg['type'] == 'RepeatDataset': dataset = RepeatDataset( build_dataset(cfg['dataset'], default_args), cfg['times']) elif cfg['type'] == 'ClassBalancedDataset': dataset = ClassBalancedDataset( build_dataset(cfg['dataset'], default_args), cfg['oversample_thr']) elif isinstance(cfg.get('ann_file'), (list, tuple)): dataset = _concat_dataset(cfg, default_args) else: dataset = build_from_cfg(cfg, DATASETS, default_args) return dataset
train_detector
的主要流程为:
(一)构建 data loaders:
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed) for ds in dataset
]
(二)构建分布式处理对象:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
(三)构建优化器:
optimizer = build_optimizer(model, cfg.optimizer)
(四)创建 EpochBasedRunner 并进行训练:
runner = EpochBasedRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)
我们来看一下源码:
def train_detector(model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None): # 获取 logger logger = get_root_logger(cfg.log_level) # ==================== 构建 data loaders ==================== dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] # 获得 samples_per_gpu if 'imgs_per_gpu' in cfg.data: logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. ' 'Please use "samples_per_gpu" instead') if 'samples_per_gpu' in cfg.data: logger.warning( f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' f'={cfg.data.imgs_per_gpu} is used in this experiments') else: logger.warning( 'Automatically set "samples_per_gpu"="imgs_per_gpu"=' f'{cfg.data.imgs_per_gpu} in this experiments') cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu data_loaders = [ build_dataloader( ds, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu, # cfg.gpus will be ignored if distributed len(cfg.gpu_ids), dist=distributed, seed=cfg.seed) for ds in dataset ] # ==================== 构建分布式处理对象 ===================== # 如果是多卡会进入此 if if distributed: find_unused_parameters = cfg.get('find_unused_parameters', False) # Sets the `find_unused_parameters` parameter in # torch.nn.parallel.DistributedDataParallel model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=find_unused_parameters) # 单卡进入 else: model = MMDataParallel( model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) # ====================== 构建优化器 ========================== optimizer = build_optimizer(model, cfg.optimizer) # ============= 创建 EpochBasedRunner 并进行训练 ============== runner = EpochBasedRunner( model, optimizer=optimizer, work_dir=cfg.work_dir, logger=logger, meta=meta) # an ugly workaround to make .log and .log.json filenames the same runner.timestamp = timestamp # fp16 setting fp16_cfg = cfg.get('fp16', None) if fp16_cfg is not None: optimizer_config = Fp16OptimizerHook( **cfg.optimizer_config, **fp16_cfg, distributed=distributed) elif distributed and 'type' not in cfg.optimizer_config: optimizer_config = OptimizerHook(**cfg.optimizer_config) else: optimizer_config = cfg.optimizer_config # register hooks runner.register_training_hooks(cfg.lr_config, optimizer_config, cfg.checkpoint_config, cfg.log_config, cfg.get('momentum_config', None)) if distributed: runner.register_hook(DistSamplerSeedHook()) # register eval hooks if validate: val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) val_dataloader = build_dataloader( val_dataset, samples_per_gpu=1, workers_per_gpu=cfg.data.workers_per_gpu, dist=distributed, shuffle=False) eval_cfg = cfg.get('evaluation', {}) eval_hook = DistEvalHook if distributed else EvalHook runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) if cfg.resume_from: runner.resume(cfg.resume_from) elif cfg.load_from: runner.load_checkpoint(cfg.load_from) runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
在本篇文章(一)中,主要讲解了,train.py 中的主要流程,train.py 中的重要的函数以及函数的具体实现。但是 train_detector 只讲了流程,并没有拆开详细讲解。在下一小结中我们会详细讲解 train_detector 的每一步究竟做了什么。
参考:
https://zhuanlan.zhihu.com/p/163747610
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。