
YOLOv8 Training and Inference Source Code Walkthrough

Train: the training code

import cv2
import numpy as np
from PIL import Image
from ultralytics import YOLO
from torchvision import utils as vutils

"""Training"""
# Load the model
model = YOLO("yolov8s-seg.pt")  # load a pretrained model (recommended for training)
# model = YOLO(r"D:\PC_DeepLearing\yolov8\runs\segment\train_resize2560\weights\best.pt")  # or resume from your own weights
# Use the model
model.train(data="turntable.yaml", cfg="default.yaml")  # train the model
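The contents of turntable.yaml are not shown in this article. As a rough, hypothetical sketch (every path and class name below is a placeholder, not a value taken from the article), a YOLO segmentation dataset YAML generally has this shape, generated here from Python:

from pathlib import Path

# Hypothetical minimal dataset YAML for a YOLOv8 segmentation task.
# Field names follow the Ultralytics dataset YAML convention; the paths
# and class names are illustrative placeholders only.
yaml_text = """\
path: D:/datasets/turntable   # dataset root (placeholder)
train: images/train           # training images, relative to 'path'
val: images/val               # validation images, relative to 'path'

names:
  0: turntable                # class id -> class name (placeholder)
"""
Path("turntable.yaml").write_text(yaml_text, encoding="utf-8")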

The training call steps into

D:\PC_DeepLearing\yolov8\ultralytics\engine\model.py

where we find the following line:

        self.trainer.train()

Execution then jumps into

D:\PC_DeepLearing\yolov8\ultralytics\engine\trainer.py

where we find:

            self._do_train(world_size)

Inside it we can see that self._setup_train(world_size) handles reading and loading the training data:

def _do_train(self, world_size=1):
    """Train completed, evaluate and plot if specified by arguments."""
    if world_size > 1:
        self._setup_ddp(world_size)
    self._setup_train(world_size)  # read and load the training data
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()

Next, look at how data loading is set up in _setup_train(self, world_size):

def _setup_train(self, world_size):
    """Builds dataloaders and optimizer on correct rank process."""
    # Model
    self.run_callbacks('on_pretrain_routine_start')
    ckpt = self.setup_model()
    self.model = self.model.to(self.device)
    self.set_model_attributes()

    # Freeze layers (loop over named parameters to find matches)
    freeze_list = self.args.freeze if isinstance(
        self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else []
    always_freeze_names = ['.dfl']  # always freeze these layers
    freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
    for k, v in self.model.named_parameters():
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze_layer_names):
            LOGGER.info(f"Freezing layer '{k}'")
            v.requires_grad = False
        elif not v.requires_grad:
            LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
                        'See ultralytics.engine.trainer for customization of frozen layers.')
            v.requires_grad = True

    # Check AMP (works for both single-GPU and DDP training)
    self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
    if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
        callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
        self.amp = torch.tensor(check_amp(self.model), device=self.device)
        callbacks.default_callbacks = callbacks_backup  # restore callbacks
    if RANK > -1 and world_size > 1:  # DDP
        dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
    self.amp = bool(self.amp)  # as boolean
    self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
    if world_size > 1:
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])

    # Check imgsz (validate the image size against the model stride)
    gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
    self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)

    # Batch size
    if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
        self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)

    # Dataloaders
    batch_size = self.batch_size // max(world_size, 1)
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
    if RANK in (-1, 0):
        self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
        self.validator = self.get_validator()
        metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
        self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
        self.ema = ModelEMA(self.model)
        if self.args.plots:
            self.plot_training_labels()

    # Optimizer (with optimizer='auto' the iteration count decides between SGD and AdamW)
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
    iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
    self.optimizer = self.build_optimizer(model=self.model,
                                          name=self.args.optimizer,
                                          lr=self.args.lr0,
                                          momentum=self.args.momentum,
                                          decay=weight_decay,
                                          iterations=iterations)

    # Scheduler (decides how the learning rate decays across epochs)
    if self.args.cos_lr:
        self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
    else:
        self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
    self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
    self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
    self.resume_training(ckpt)
    self.scheduler.last_epoch = self.start_epoch - 1  # do not move
    self.run_callbacks('on_pretrain_routine_end')
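To make the scheduler concrete, here is a small self-contained sketch that reproduces the two lr_lambda factors above. The one_cycle re-implementation follows my reading of the Ultralytics helper, so treat it as an approximation:

import math

epochs, lrf = 100, 0.01  # example values; the final LR is lr0 * lrf

def one_cycle(y1=1.0, y2=0.01, steps=100):
    """Cosine ramp from y1 to y2 over `steps` epochs (mirrors ultralytics.utils.one_cycle)."""
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

cos_lf = one_cycle(1, lrf, epochs)                       # used when cos_lr=True
lin_lf = lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf  # used when cos_lr=False

for epoch in (0, 25, 50, 75, 99):
    # LambdaLR multiplies lr0 by this factor at each epoch
    print(f"epoch {epoch:3d}  cosine {cos_lf(epoch):.4f}  linear {lin_lf(epoch):.4f}")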

From the call

self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')

we can step into D:\PC_DeepLearing\yolov8\ultralytics\models\yolo\detect\train.py:

self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
if RANK in (-1, 0):
    self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')

The get_dataloader method then looks like this:

def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
    """Construct and return dataloader."""
    assert mode in ['train', 'val']
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = self.build_dataset(dataset_path, mode, batch_size)
    shuffle = mode == 'train'
    if getattr(dataset, 'rect', False) and shuffle:
        LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False
    workers = self.args.workers if mode == 'train' else self.args.workers * 2
    return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader
get_dataloader calls build_dataset:

def build_dataset(self, img_path, mode='train', batch=None):
    """
    Build YOLO Dataset.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
        batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
    """
    gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
    return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)

which in turn calls build_yolo_dataset:

def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
    """Build YOLO Dataset."""
    return YOLODataset(
        img_path=img_path,
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=mode == 'train',  # augmentation
        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
        rect=cfg.rect or rect,  # rectangular batches
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=int(stride),
        pad=0.0 if mode == 'train' else 0.5,
        prefix=colorstr(f'{mode}: '),
        use_segments=cfg.task == 'segment',
        use_keypoints=cfg.task == 'pose',
        classes=cfg.classes,
        data=data,
        fraction=cfg.fraction if mode == 'train' else 1.0)

Continuing into YOLODataset, note that it inherits from BaseDataset, which lives in

D:\PC_DeepLearing\yolov8\ultralytics\data\base.py

class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
        use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
        """Initializes the YOLODataset with optional configurations for segments and keypoints."""
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        self.data = data
        assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
        super().__init__(*args, **kwargs)

In BaseDataset, the main things to look at are

self.im_files = self.get_img_files(self.img_path)
self.labels = self.get_labels()

as well as

self.transforms = self.build_transforms(hyp=hyp)

Then step into

D:\PC_DeepLearing\yolov8\ultralytics\data\dataset.py

def build_transforms(self, hyp=None):
    """Builds and appends transforms to the list."""
    if self.augment:
        hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
        hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
        transforms = v8_transforms(self, self.imgsz, hyp)
    else:
        transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
    transforms.append(
        Format(bbox_format='xywh',
               normalize=True,
               return_mask=self.use_segments,
               return_keypoint=self.use_keypoints,
               batch_idx=True,
               mask_ratio=hyp.mask_ratio,
               mask_overlap=hyp.overlap_mask))
    return transforms

Next, open D:\PC_DeepLearing\yolov8\ultralytics\data\augment.py and look at v8_transforms:
def v8_transforms(dataset, imgsz, hyp, stretch=False):
    """Convert images to a size suitable for YOLOv8 training."""
    pre_transform = Compose([
        Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
        CopyPaste(p=hyp.copy_paste),
        RandomPerspective(
            degrees=hyp.degrees,
            translate=hyp.translate,
            scale=hyp.scale,
            shear=hyp.shear,
            perspective=hyp.perspective,
            pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
        )])
    flip_idx = dataset.data.get('flip_idx', [])  # for keypoints augmentation
    if dataset.use_keypoints:
        kpt_shape = dataset.data.get('kpt_shape', None)
        if len(flip_idx) == 0 and hyp.fliplr > 0.0:
            hyp.fliplr = 0.0
            LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
        elif flip_idx and (len(flip_idx) != kpt_shape[0]):
            raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
    return Compose([
        pre_transform,
        MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
        Albumentations(p=1.0),
        RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
        RandomFlip(direction='vertical', p=hyp.flipud),
        RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)])  # transforms
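Since the mosaic, mixup, HSV, and flip probabilities all come from hyp, i.e. the training arguments, the simplest way to tune or disable an augmentation is to override it at train time. A minimal sketch (the override values are examples, not recommendations):

from ultralytics import YOLO

model = YOLO("yolov8s-seg.pt")
# Each keyword below feeds one of the transforms built in v8_transforms
model.train(data="turntable.yaml", cfg="default.yaml",
            mosaic=0.0,   # disable Mosaic
            mixup=0.0,    # disable MixUp
            fliplr=0.5,   # keep horizontal flip at probability 0.5
            flipud=0.0)   # no vertical flip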

For data loading, look at LetterBox in

D:\PC_DeepLearing\yolov8\ultralytics\data\augment.py

YOLO first resizes the image while preserving its aspect ratio, then pads it to the target size.

Pay close attention here: both the preprocessing and the postprocessing at deployment time depend on this behavior (a standalone sketch of the same logic follows the class below).

class LetterBox:
    """Resize image and padding for detection, instance segmentation, pose."""

    def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
        """Initialize LetterBox object with specific parameters."""
        self.new_shape = new_shape
        self.auto = auto
        self.scaleFill = scaleFill
        self.scaleup = scaleup
        self.stride = stride
        self.center = center  # Put the image in the middle or top-left

    def __call__(self, labels=None, image=None):
        """Return updated labels and image with added border."""
        if labels is None:
            labels = {}
        img = labels.get('img') if image is None else image
        shape = img.shape[:2]  # current shape [height, width]
        new_shape = labels.pop('rect_shape', self.new_shape)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        # Scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        if not self.scaleup:  # only scale down, do not scale up (for better val mAP)
            r = min(r, 1.0)

        # Compute padding
        ratio = r, r  # width, height ratios
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
        if self.auto:  # minimum rectangle
            dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride)  # wh padding
        elif self.scaleFill:  # stretch
            dw, dh = 0.0, 0.0
            new_unpad = (new_shape[1], new_shape[0])
            ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

        if self.center:
            dw /= 2  # divide padding into 2 sides
            dh /= 2

        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
                                 value=(114, 114, 114))  # add border
        if labels.get('ratio_pad'):
            labels['ratio_pad'] = (labels['ratio_pad'], (left, top))  # for evaluation

        if len(labels):
            labels = self._update_labels(labels, ratio, dw, dh)
            labels['img'] = img
            labels['resized_shape'] = new_shape
            return labels
        else:
            return img

    def _update_labels(self, labels, ratio, padw, padh):
        """Update labels."""
        labels['instances'].convert_bbox(format='xyxy')
        labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
        labels['instances'].scale(*ratio)
        labels['instances'].add_padding(padw, padh)
        return labels
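Because this resize-then-pad step must be reproduced exactly at deployment time, here is a minimal standalone sketch of the same center-letterbox logic in plain OpenCV/NumPy (no auto or scaleFill handling), returning the scale r and per-side padding (dw, dh) needed to undo it later:

import cv2
import numpy as np

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114)):
    """Minimal re-implementation of the center letterbox above."""
    h, w = img.shape[:2]
    r = min(new_shape[0] / h, new_shape[1] / w)          # uniform scale ratio
    new_unpad = (int(round(w * r)), int(round(h * r)))   # (width, height) after resize
    dw = (new_shape[1] - new_unpad[0]) / 2               # x padding per side
    dh = (new_shape[0] - new_unpad[1]) / 2               # y padding per side
    if (w, h) != new_unpad:
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right,
                             cv2.BORDER_CONSTANT, value=color)
    return img, r, (dw, dh)

# Example: a 1080x1920 frame letterboxed into a 1280x1280 model input
frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
padded, r, (dw, dh) = letterbox(frame, (1280, 1280))
print(padded.shape, r, (dw, dh))  # (1280, 1280, 3) 0.666... (0.0, 280.0)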

Predict: inference code and postprocessing

  1. """预测推理"""
  2. model = YOLO(r"D:\PC_DeepLearing\yolov8\runs\segment\train_resize1280_1201\weights\best.onnx") # 加载预训练模型(建议用于训练)
  3. results = model.predict(data="turntable.yaml",
  4. cfg="default.yaml",
  5. source=r"C:\Users\1\Desktop\231130\C1\料盘定位\OK",
  6. save=True,
  7. save_conf=True,
  8. batch=1,
  9. imgsz=[1280,1280],
  10. iou=0.35,
  11. conf=0.1,
  12. rect=False
  13. ) # 在验证集上评估模型性能

Start in D:\PC_DeepLearing\yolov8\ultralytics\engine\model.py:

def predict(self, source=None, stream=False, predictor=None, **kwargs):
    """
    Perform prediction using the YOLO model.

    Args:
        source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
            Accepts all source types accepted by the YOLO model.
        stream (bool): Whether to stream the predictions or not. Defaults to False.
        predictor (BasePredictor): Customized predictor.
        **kwargs : Additional keyword arguments passed to the predictor.
            Check the 'configuration' section in the documentation for all available options.

    Returns:
        (List[ultralytics.engine.results.Results]): The prediction results.
    """
    if source is None:
        source = ASSETS
        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
    is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
        x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
    custom = {'conf': 0.25, 'save': is_cli}  # method defaults
    args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'}  # highest priority args on the right
    prompts = args.pop('prompts', None)  # for SAM-type models
    if not self.predictor:
        self.predictor = (predictor or self._smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
        self.predictor.setup_model(model=self.model, verbose=is_cli)
    else:  # only update args if predictor is already setup
        self.predictor.args = get_cfg(self.predictor.args, args)
        if 'project' in args or 'name' in args:
            self.predictor.save_dir = get_save_dir(self.predictor.args)
    if prompts and hasattr(self.predictor, 'set_prompts'):  # for SAM-type models
        self.predictor.set_prompts(prompts)
    return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)

Inference happens in the predictor:

D:\PC_DeepLearing\yolov8\ultralytics\engine\predictor.py

The inference pipeline looks like this:

@smart_inference_mode()
def stream_inference(self, source=None, model=None, *args, **kwargs):
    """Streams real-time inference on camera feed and saves results to file."""
    if self.args.verbose:
        LOGGER.info('')

    # Setup model
    if not self.model:
        self.setup_model(model)

    with self._lock:  # for thread-safe inference
        # Setup source every time predict is called
        self.setup_source(source if source is not None else self.args.source)

        # Check if save_dir/ label file exists
        if self.args.save or self.args.save_txt:
            (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

        # Warmup model
        if not self.done_warmup:
            self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
            self.done_warmup = True

        self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
        self.run_callbacks('on_predict_start')
        for batch in self.dataset:
            self.run_callbacks('on_predict_batch_start')
            self.batch = batch
            path, im0s, vid_cap, s = batch

            # Preprocess
            with profilers[0]:
                im = self.preprocess(im0s)

            # Inference
            with profilers[1]:
                preds = self.inference(im, *args, **kwargs)

            # Postprocess
            with profilers[2]:
                self.results = self.postprocess(preds, im, im0s)
            self.run_callbacks('on_predict_postprocess_end')

            # Visualize, save, write results
            n = len(im0s)
            for i in range(n):
                self.seen += 1
                self.results[i].speed = {
                    'preprocess': profilers[0].dt * 1E3 / n,
                    'inference': profilers[1].dt * 1E3 / n,
                    'postprocess': profilers[2].dt * 1E3 / n}
                p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
                p = Path(p)

                if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                    s += self.write_results(i, self.results, (p, im, im0))
                if self.args.save or self.args.save_txt:
                    self.results[i].save_dir = self.save_dir.__str__()
                if self.args.show and self.plotted_img is not None:
                    self.show(p)
                if self.args.save and self.plotted_img is not None:
                    self.save_preds(vid_cap, i, str(self.save_dir / p.name))

            self.run_callbacks('on_predict_batch_end')
            yield from self.results

            # Print time (inference-only)
            if self.args.verbose:
                LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')

    # Release assets
    if isinstance(self.vid_writer[-1], cv2.VideoWriter):
        self.vid_writer[-1].release()  # release final video writer

    # Print results
    if self.args.verbose and self.seen:
        t = tuple(x.t / self.seen * 1E3 for x in profilers)  # speeds per image
        LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
                    f'{(1, 3, *im.shape[2:])}' % t)
    if self.args.save or self.args.save_txt or self.args.save_crop:
        nl = len(list(self.save_dir.glob('labels/*.txt')))  # number of labels
        s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")

    self.run_callbacks('on_predict_end')
The image first goes through LetterBox and is then normalized during preprocessing (D:\PC_DeepLearing\yolov8\ultralytics\engine\predictor.py):

def preprocess(self, im):
    """
    Prepares input image before inference.

    Args:
        im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
    """
    not_tensor = not isinstance(im, torch.Tensor)
    if not_tensor:
        im = np.stack(self.pre_transform(im))
        im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
        im = np.ascontiguousarray(im)  # contiguous
        im = torch.from_numpy(im)

    im = im.to(self.device)
    im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
    if not_tensor:
        im /= 255  # 0 - 255 to 0.0 - 1.0
    return im
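For an exported model (ONNX, for example) the same preprocessing can be reproduced without torch. A rough NumPy-only sketch, assuming the letterbox helper from the standalone sketch earlier and a BGR frame as produced by cv2.imread:

import numpy as np

def preprocess_numpy(bgr_img, input_hw=(1280, 1280)):
    """Mirror of BasePredictor.preprocess for one BGR image:
    letterbox -> BGR to RGB -> HWC to CHW -> float32 in [0, 1] -> batch dim."""
    img, r, (dw, dh) = letterbox(bgr_img, input_hw)  # letterbox sketch from above
    img = img[..., ::-1].transpose(2, 0, 1)          # BGR->RGB, HWC->CHW
    img = np.ascontiguousarray(img, dtype=np.float32) / 255.0
    return img[None], r, (dw, dh)                    # shape (1, 3, h, w)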

After inference comes postprocessing. Step into D:\PC_DeepLearing\yolov8\ultralytics\models\yolo\segment\predict.py:

def postprocess(self, preds, img, orig_imgs):
    """Applies non-max suppression and processes detections for each image in an input batch."""
    p = ops.non_max_suppression(preds[0],
                                self.args.conf,
                                self.args.iou,
                                agnostic=self.args.agnostic_nms,
                                max_det=self.args.max_det,
                                nc=len(self.model.names),
                                classes=self.args.classes)

    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    results = []
    proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
    for i, pred in enumerate(p):
        orig_img = orig_imgs[i]
        img_path = self.batch[0][i]
        if not len(pred):  # save empty boxes
            masks = None
        elif self.args.retina_masks:
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
        else:
            masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
    return results
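For a segmentation model, each row of pred after NMS is laid out as [x1, y1, x2, y2, conf, cls, m1 ... m32]: four box coordinates, a confidence, a class id, and 32 mask coefficients (32 being the default number of mask prototypes). A shape-only sketch of how the slices above line up:

import torch

pred = torch.rand(5, 6 + 32)  # 5 detections, 32 mask coefficients each
boxes = pred[:, :4]           # xyxy in model-input coordinates -> scale_boxes
scores = pred[:, 4]           # confidence after NMS
classes = pred[:, 5]          # class ids
coeffs = pred[:, 6:]          # mask coefficients -> process_mask
print(boxes.shape, scores.shape, classes.shape, coeffs.shape)
# torch.Size([5, 4]) torch.Size([5]) torch.Size([5]) torch.Size([5, 32])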
The two key calls here are:

masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)

First, look at the box operation:

def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
    """
    Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
    (img1_shape) to the shape of a different image (img0_shape).

    Args:
        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
        boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
        img0_shape (tuple): the shape of the target image, in the format of (height, width).
        ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
            calculated based on the size difference between the two images.
        padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
            rescaling.

    Returns:
        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
            (img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)  # compute the wh padding values
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        boxes[..., [0, 2]] -= pad[0]  # x padding: subtract the pad from the box x coordinates
        boxes[..., [1, 3]] -= pad[1]  # y padding: subtract the pad from the box y coordinates
    boxes[..., :4] /= gain  # divide by the scale ratio to map back to original-image coordinates
    clip_boxes(boxes, img0_shape)
    return boxes
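A worked example of this arithmetic, reusing the numbers from the letterbox sketch (a 1080x1920 original letterboxed into 1280x1280): gain = min(1280/1080, 1280/1920) = 2/3 and pad = (0, 280), so each box is shifted by the padding and divided by the gain:

import torch

img1_shape = (1280, 1280)  # model input (h, w)
img0_shape = (1080, 1920)  # original image (h, w)

gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # 0.6667
pad = (round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
       round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1))           # (0, 280)

box = torch.tensor([[100.0, 380.0, 500.0, 900.0]])  # xyxy in the 1280x1280 letterboxed frame
box[:, [0, 2]] -= pad[0]
box[:, [1, 3]] -= pad[1]
box[:, :4] /= gain
print(box)  # tensor([[150., 150., 750., 930.]]) in original 1920x1080 coordinates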

Then the mask operation:

def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    """
    Apply masks to bounding boxes using the output of the mask head.

    Args:
        protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
        masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
        bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
        shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
        upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.

    Returns:
        (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
            are the height and width of the input image. The mask is applied to the bounding boxes.
    """
    c, mh, mw = protos.shape  # CHW
    ih, iw = shape
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW

    downsampled_bboxes = bboxes.clone()
    downsampled_bboxes[:, 0] *= mw / iw
    downsampled_bboxes[:, 2] *= mw / iw
    downsampled_bboxes[:, 3] *= mh / ih
    downsampled_bboxes[:, 1] *= mh / ih

    masks = crop_mask(masks, downsampled_bboxes)  # CHW
    if upsample:
        masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
    return masks.gt_(0.5)
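The core of process_mask is a single matrix multiplication: each detection's 32 coefficients linearly combine the 32 prototype masks, and a sigmoid turns the result into a probability map. A shape-only sketch with random tensors (160x160 is the typical prototype resolution for a 640 input; it scales with imgsz):

import torch

n, c, mh, mw = 5, 32, 160, 160    # 5 detections, 32 prototypes of 160x160
protos = torch.rand(c, mh, mw)    # mask prototypes from the model head
masks_in = torch.rand(n, c)       # per-detection coefficients (pred[:, 6:])

# (n, c) @ (c, mh*mw) -> (n, mh*mw) -> (n, mh, mw)
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
print(masks.shape)                # torch.Size([5, 160, 160])
binary = masks.gt(0.5)            # thresholded, as at the end of process_mask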

At this point, if the input image size does not match the model input size, we also need to look at how results are saved.

Pay special attention to the rectangular case, i.e. when width and height differ.

See write_results in D:\PC_DeepLearing\yolov8\ultralytics\engine\predictor.py:

def write_results(self, idx, results, batch):
    """Write inference results to a file or directory."""
    p, im, _ = batch
    log_string = ''
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim
    if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
        log_string += f'{idx}: '
        frame = self.dataset.count
    else:
        frame = getattr(self.dataset, 'frame', 0)
    self.data_path = p
    self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
    log_string += '%gx%g ' % im.shape[2:]  # print string
    result = results[idx]
    log_string += result.verbose()

    if self.args.save or self.args.show:  # Add bbox to image
        plot_args = {
            'line_width': self.args.line_width,
            'boxes': self.args.boxes,
            'conf': self.args.show_conf,
            'labels': self.args.show_labels}
        if not self.args.retina_masks:
            plot_args['im_gpu'] = im[idx]
        self.plotted_img = result.plot(**plot_args)

    # Write
    if self.args.save_txt:
        result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
    if self.args.save_crop:
        result.save_crop(save_dir=self.save_dir / 'crops',
                         file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}'))
    return log_string

Inside self.plotted_img = result.plot(**plot_args) we can see how the output image is drawn for saving. Look at result.plot(**plot_args) in

D:\PC_DeepLearing\yolov8\ultralytics\engine\results.py

which leads to the following section:

# Plot Segment results
if pred_masks and show_masks:
    if im_gpu is None:
        img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
        im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
            2, 0, 1).flip(0).contiguous() / 255
    idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
    annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)

# Plot Detect results
if pred_boxes and show_boxes:
    for d in reversed(pred_boxes):
        c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
        name = ('' if id is None else f'id:{id} ') + names[c]
        label = (f'{name} {conf:.2f}' if conf else name) if labels else None
        annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))

# Plot Classify results
if pred_probs is not None and show_probs:
    text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
    x = round(self.orig_shape[0] * 0.03)
    annotator.text([x, x], text, txt_color=(255, 255, 255))  # TODO: allow setting colors

Finally, in D:\PC_DeepLearing\yolov8\ultralytics\utils\plotting.py, look at the masks method:

def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
    """
    Plot masks on image.

    Args:
        masks (tensor): Predicted masks on cuda, shape: [n, h, w]
        colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]
        im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
        alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
        retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
    """
    if self.pil:
        # Convert to numpy first
        self.im = np.asarray(self.im).copy()
    if len(masks) == 0:
        self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
    if im_gpu.device != masks.device:
        im_gpu = im_gpu.to(masks.device)
    colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0  # shape(n,3)
    colors = colors[:, None, None]  # shape(n,1,1,3)
    masks = masks.unsqueeze(3)  # shape(n,h,w,1)
    masks_color = masks * (colors * alpha)  # shape(n,h,w,3)

    inv_alpha_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
    mcs = masks_color.max(dim=0).values  # shape(n,h,w,3)

    im_gpu = im_gpu.flip(dims=[0])  # flip channel
    im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
    im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
    im_mask = (im_gpu * 255)
    im_mask_np = im_mask.byte().cpu().numpy()
    self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
    if self.pil:
        # Convert im back to PIL and update draw
        self.fromarray(self.im)
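The blending above boils down to out = img * (1 - mask * alpha) + color * mask * alpha for each mask. A single-mask NumPy sketch of that equation:

import numpy as np

h, w = 4, 4
img = np.full((h, w, 3), 0.8, dtype=np.float32)       # image in [0, 1], HWC
mask = np.zeros((h, w, 1), dtype=np.float32)
mask[1:3, 1:3] = 1.0                                  # a 2x2 masked region
color = np.array([1.0, 0.0, 0.0], dtype=np.float32)   # red, in [0, 1]
alpha = 0.5

out = img * (1 - mask * alpha) + color * (mask * alpha)  # alpha-blend one mask
print(out[0, 0], out[1, 1])  # [0.8 0.8 0.8] unchanged, [0.9 0.4 0.4] blended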
