当前位置:   article > 正文

pytorch 训练数据以及测试 全部代码(9)---deeplab v3+ 对Cityscapes数据的处理_class cityscapessegmentation(data.dataset):

class cityscapessegmentation(data.dataset):

 下面是全部的代码:

  1. import os
  2. import torch
  3. import numpy as np
  4. import scipy.misc as m
  5. from PIL import Image
  6. from torch.utils import data
  7. from dataloaders.utils import recursive_glob, decode_segmap
  8. from mypath import Path
  9. class CityscapesSegmentation(data.Dataset):
  10. def __init__(self, root=Path.db_root_dir('cityscapes'), split="train", transform=None):
  11. self.root = root
  12. self.split = split
  13. self.transform = transform
  14. self.files = {}
  15. self.n_classes = 19
  16. self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
  17. self.annotations_base = os.path.join(self.root, 'gtFine', self.split)
  18. self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')
  19. self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] # 16
  20. self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] # 19
  21. self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence', \
  22. 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain', \
  23. 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \
  24. 'motorcycle', 'bicycle'] # 20
  25. self.ignore_index = 255
  26. self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))
  27. if not self.files[split]:
  28. raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
  29. print("Found %d %s images" % (len(self.files[split]), split))
  30. def __len__(self):
  31. return len(self.files[self.split])
  32. def __getitem__(self, index):
  33. img_path = self.files[self.split][index].rstrip()
  34. lbl_path = os.path.join(self.annotations_base,
  35. img_path.split(os.sep)[-2], # os.sep=='/' get city name
  36. os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')
  37. _img = Image.open(img_path).convert('RGB')
  38. _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
  39. _tmp = self.encode_segmap(_tmp)
  40. _target = Image.fromarray(_tmp)
  41. sample = {'image': _img, 'label': _target}
  42. if self.transform: # to do Data transformation or Data enhancement and convert torch
  43. sample = self.transform(sample)
  44. return sample
  45. def encode_segmap(self, mask): # to change original image pixel value to 0-18 and 255 according class id
  46. # Put all void classes to zero
  47. for _voidc in self.void_classes:
  48. mask[mask == _voidc] = self.ignore_index # no need class and unto set 255 (white)
  49. for _validc in self.valid_classes:
  50. mask[mask == _validc] = self.class_map[_validc] # 19 classes encode from 0 to 18
  51. return mask
  52. if __name__ == '__main__':
  53. from dataloaders import custom_transforms as tr
  54. from dataloaders.utils import decode_segmap
  55. from torch.utils.data import DataLoader
  56. from torchvision import transforms
  57. import matplotlib.pyplot as plt # to show image
  58. composed_transforms_tr = transforms.Compose([
  59. tr.RandomHorizontalFlip(),
  60. tr.RandomScale((0.5, 0.75)),
  61. tr.RandomCrop((512, 1024)),
  62. tr.RandomRotate(5),
  63. tr.ToTensor()])
  64. cityscapes_train = CityscapesSegmentation(split='train',
  65. transform=composed_transforms_tr)
  66. dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)
  67. for ii, sample in enumerate(dataloader):
  68. for jj in range(sample["image"].size()[0]):
  69. img = sample['image'].numpy() # from torch convert to numpy n x c x h x w
  70. gt = sample['label'].numpy() # from torch convert to numpy n x c x h x w
  71. tmp = np.array(gt[jj]).astype(np.uint8) # tmp.shape=c x h x w
  72. tmp = np.squeeze(tmp, axis=0) # if c=1,tmp.shape=c x h x w; or tmp.shape=c x h x w
  73. segmap = decode_segmap(tmp, dataset='cityscapes')
  74. img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8) # img_tmp=h x w x c
  75. plt.figure()
  76. plt.title('display')
  77. plt.subplot(211)
  78. plt.imshow(img_tmp)
  79. plt.subplot(212)
  80. plt.imshow(segmap)
  81. if ii == 1:
  82. break
  83. plt.show(block=True)

关于下面这行代码是如何读取图片的,可以参考:https://blog.csdn.net/zz2230633069/article/details/84640867

self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')

所使用的数据变换(数据增强)组合为:

  1. composed_transforms_tr = transforms.Compose([
  2. tr.RandomHorizontalFlip(),
  3. tr.RandomScale((0.5, 0.75)),
  4. tr.RandomCrop((512, 1024)),
  5. tr.RandomRotate(5),
  6. tr.ToTensor()])

上面这些图像变换(或者说数据增强)的实现代码如下:

上面列出的前四个变换的输入和输出都是PIL图像(经过convert('RGB')和Image.fromarray之后,类型是PIL.Image.Image,而不是PngImageFile),图像的数据类型(uint8)和像素的取值范围不发生改变,通道结构也没有变化(原图为h x w x 3,标签图为h x w),只是h和w会随缩放、裁剪而改变。

  1. class RandomHorizontalFlip(object):
  2. def __call__(self, sample):
  3. img = sample['image']
  4. mask = sample['label']
  5. if random.random() < 0.5:
  6. img = img.transpose(Image.FLIP_LEFT_RIGHT)
  7. mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
  8. return {'image': img,
  9. 'label': mask}
  10. class RandomScale(object):
  11. def __init__(self, limit):
  12. self.limit = limit
  13. def __call__(self, sample):
  14. img = sample['image']
  15. mask = sample['label']
  16. assert img.size == mask.size
  17. scale = random.uniform(self.limit[0], self.limit[1])
  18. w = int(scale * img.size[0])
  19. h = int(scale * img.size[1])
  20. img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)
  21. return {'image': img, 'label': mask}
  22. class RandomCrop(object):
  23. def __init__(self, size, padding=0):
  24. if isinstance(size, numbers.Number):
  25. self.size = (int(size), int(size))
  26. else:
  27. self.size = size # h, w
  28. self.padding = padding
  29. def __call__(self, sample):
  30. img, mask = sample['image'], sample['label']
  31. if self.padding > 0:
  32. img = ImageOps.expand(img, border=self.padding, fill=0)
  33. mask = ImageOps.expand(mask, border=self.padding, fill=0)
  34. assert img.size == mask.size
  35. w, h = img.size
  36. th, tw = self.size # target size
  37. if w == tw and h == th:
  38. return {'image': img,
  39. 'label': mask}
  40. if w < tw or h < th:
  41. img = img.resize((tw, th), Image.BILINEAR)
  42. mask = mask.resize((tw, th), Image.NEAREST)
  43. return {'image': img,
  44. 'label': mask}
  45. x1 = random.randint(0, w - tw)
  46. y1 = random.randint(0, h - th)
  47. img = img.crop((x1, y1, x1 + tw, y1 + th))
  48. mask = mask.crop((x1, y1, x1 + tw, y1 + th))
  49. return {'image': img,
  50. 'label': mask}
  51. class RandomRotate(object):
  52. def __init__(self, degree):
  53. self.degree = degree
  54. def __call__(self, sample):
  55. img = sample['image']
  56. mask = sample['label']
  57. rotate_degree = random.random() * 2 * self.degree - self.degree
  58. img = img.rotate(rotate_degree, Image.BILINEAR)
  59. mask = mask.rotate(rotate_degree, Image.NEAREST)
  60. return {'image': img,
  61. 'label': mask}
  62. class ToTensor(object):
  63. """Convert ndarrays in sample to Tensors."""
  64. def __call__(self, sample):
  65. # swap color axis because
  66. # numpy image: H x W x C
  67. # torch image: C X H X W
  68. img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
  69. mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1))
  70. mask[mask == 255] = 0 #
  71. img = torch.from_numpy(img).float()
  72. mask = torch.from_numpy(mask).float()
  73. return {'image': img,
  74. 'label': mask}

直到第五个也就是最后一个变换(ToTensor),才对原图做类型转换:首先从PIL图像变为numpy数组,同时数据类型从uint8变为float32;然后维度从(h x w x c)调整为(c x h x w);最后从numpy数组变为torch的tensor,并强制数据类型为torch.FloatTensor。这样,就将原图转变为一个可以输入后面深度学习网络的tensor了。

与此相对应,标签图也是从PIL图像变为numpy数组,同时数据类型从uint8变为float32;然后维度从(h x w)增加一维得到(h x w x 1),接着调整为(1 x h x w);然后对mask里面的数值进行处理:数值为255的像素全部被重置为0,所以mask里面的值现在只有0-18这些数字了(注意:这样被忽略的区域就和类别0无法区分了);最后从numpy数组变为torch的tensor,并强制数据类型为torch.FloatTensor。这样,就将标签图转变为一个可以输入后面深度学习网络的tensor了。

把上面这两个tensor重新还原成图像并显示的代码如下:

  1. for ii, sample in enumerate(dataloader):
  2. for jj in range(sample["image"].size()[0]):
  3. img = sample['image'].numpy() # from torch convert to numpy n x 3 x h x w
  4. gt = sample['label'].numpy() # from torch convert to numpy n x 1 x h x w
  5. tmp = np.array(gt[jj]).astype(np.uint8) # tmp.shape=1 x h x w
  6. tmp = np.squeeze(tmp, axis=0) # if c=1,tmp.shape=h x w; or tmp.shape=c x h x w dimension-reduction
  7. segmap = decode_segmap(tmp, dataset='cityscapes')
  8. img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8) # img_tmp=h x w x 3
  9. plt.figure()
  10. plt.title('display')
  11. plt.subplot(211)
  12. plt.imshow(img_tmp)
  13. plt.subplot(212)
  14. plt.imshow(segmap)
  15. if ii == 1:
  16. break
  17. plt.show(block=True)

其中标签图(h x w)的解码代码如下:

只要是属于同一类的像素,就赋予该类对应的RGB数值,这样分别得到R、G、B三个通道,然后把这三个通道合并成一张彩色图

  1. segmap = decode_segmap(tmp, dataset='cityscapes') # tmp.shape=h x w
  2. def decode_segmap(label_mask, dataset, plot=False):
  3. """Decode segmentation class labels into a color image
  4. Args:
  5. label_mask (np.ndarray): an (M,N) array of integer values denoting
  6. the class label at each spatial location.
  7. plot (bool, optional): whether to show the resulting color image
  8. in a figure.
  9. Returns:
  10. (np.ndarray, optional): the resulting decoded color image.
  11. """
  12. if dataset == 'pascal':
  13. n_classes = 21
  14. label_colours = get_pascal_labels()
  15. elif dataset == 'cityscapes':
  16. n_classes = 19
  17. label_colours = get_cityscapes_labels()
  18. else:
  19. raise NotImplementedError
  20. r = label_mask.copy() # h x w
  21. g = label_mask.copy() # h x w
  22. b = label_mask.copy() # h x w
  23. for ll in range(0, n_classes):
  24. r[label_mask == ll] = label_colours[ll, 0]
  25. g[label_mask == ll] = label_colours[ll, 1]
  26. b[label_mask == ll] = label_colours[ll, 2]
  27. rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) # h x w x 3初始化
  28. rgb[:, :, 0] = r / 255.0
  29. rgb[:, :, 1] = g / 255.0
  30. rgb[:, :, 2] = b / 255.0
  31. if plot:
  32. plt.imshow(rgb)
  33. plt.show()
  34. else:
  35. return rgb

下面就是label_colours(各个类别对应的颜色)的代码,详情可以参考cityscapes的标签颜色对照表:https://blog.csdn.net/zz2230633069/article/details/84591532

  1. def get_cityscapes_labels():
  2. return np.array([
  3. # [ 0, 0, 0],
  4. [128, 64, 128],
  5. [244, 35, 232],
  6. [70, 70, 70],
  7. [102, 102, 156],
  8. [190, 153, 153],
  9. [153, 153, 153],
  10. [250, 170, 30],
  11. [220, 220, 0],
  12. [107, 142, 35],
  13. [152, 251, 152],
  14. [0, 130, 180],
  15. [220, 20, 60],
  16. [255, 0, 0],
  17. [0, 0, 142],
  18. [0, 0, 70],
  19. [0, 60, 100],
  20. [0, 80, 100],
  21. [0, 0, 230],
  22. [119, 11, 32]])
  23. def get_pascal_labels():
  24. """Load the mapping that associates pascal classes with label colors
  25. Returns:
  26. np.ndarray with dimensions (21, 3)
  27. """
  28. return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
  29. [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
  30. [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
  31. [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
  32. [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
  33. [0, 64, 128]])

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/362080
推荐阅读
相关标签
  

闽ICP备14008679号