The complete code is below:
import os
import torch
import numpy as np
import scipy.misc as m
from PIL import Image
from torch.utils import data
from dataloaders.utils import recursive_glob, decode_segmap
from mypath import Path


class CityscapesSegmentation(data.Dataset):

    def __init__(self, root=Path.db_root_dir('cityscapes'), split="train", transform=None):

        self.root = root
        self.split = split
        self.transform = transform
        self.files = {}
        self.n_classes = 19

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine', self.split)

        self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')

        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]  # 16 ids that are ignored
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]  # 19 ids that are kept
        self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
                            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
                            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                            'motorcycle', 'bicycle']  # 20 names (including 'unlabelled')

        self.ignore_index = 255
        self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):

        img_path = self.files[self.split][index].rstrip()
        lbl_path = os.path.join(self.annotations_base,
                                img_path.split(os.sep)[-2],  # second-to-last path component is the city name
                                os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')

        _img = Image.open(img_path).convert('RGB')
        _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _tmp = self.encode_segmap(_tmp)
        _target = Image.fromarray(_tmp)

        sample = {'image': _img, 'label': _target}

        if self.transform:  # apply data augmentation and convert to torch tensors
            sample = self.transform(sample)
        return sample

    def encode_segmap(self, mask):  # remap the raw label ids to train ids 0-18 and the ignore value 255
        # Put all void classes to the ignore index
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index  # unwanted classes are set to 255 (ignored)
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]  # the 19 valid classes are encoded as 0-18
        return mask


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt  # to display the images

    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor()])

    cityscapes_train = CityscapesSegmentation(split='train',
                                              transform=composed_transforms_tr)
    dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()  # torch -> numpy, n x 3 x h x w
            gt = sample['label'].numpy()   # torch -> numpy, n x 1 x h x w
            tmp = np.array(gt[jj]).astype(np.uint8)  # tmp.shape = 1 x h x w
            tmp = np.squeeze(tmp, axis=0)  # drop the channel axis: tmp.shape = h x w
            segmap = decode_segmap(tmp, dataset='cityscapes')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)  # img_tmp = h x w x 3
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)
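As a quick illustration of what encode_segmap does, here is a standalone toy example (a hypothetical 2x3 patch of raw labelIds; the lists mirror the ones defined in __init__):

import numpy as np

void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
class_map = dict(zip(valid_classes, range(19)))

mask = np.array([[7, 8, 0],
                 [26, 33, 4]], dtype=np.uint8)  # hypothetical raw labelIds

for v in void_classes:
    mask[mask == v] = 255                       # void ids -> ignore_index
for v in valid_classes:
    mask[mask == v] = class_map[v]              # valid ids -> train ids 0-18

print(mask)
# [[  0   1 255]
#  [ 13  18 255]]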

For how the images are read below, see: https://blog.csdn.net/zz2230633069/article/details/84640867
self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')
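For reference, recursive_glob in dataloaders/utils.py is just a recursive file search. A minimal sketch of such a helper, assuming it simply walks the directory tree and keeps every path that ends with the given suffix, could look like this:

import os

def recursive_glob(rootdir='.', suffix=''):
    """Collect every file under rootdir whose name ends with suffix."""
    return [os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames
            if filename.endswith(suffix)]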
The transforms are:
composed_transforms_tr = transforms.Compose([
    tr.RandomHorizontalFlip(),
    tr.RandomScale((0.5, 0.75)),
    tr.RandomCrop((512, 1024)),
    tr.RandomRotate(5),
    tr.ToTensor()])
The implementation of these image transforms (i.e. the data augmentation) is given below.
The first four transforms keep both the image and the label as PIL.PngImagePlugin.PngImageFile objects; the pixel values and their dtype (uint8) do not change, and neither does the layout (the image stays h x w x 3, the label stays h x w). A quick spot-check of this appears right after the listing.
import random
import numbers

import numpy as np
import torch
from PIL import Image, ImageOps


class RandomHorizontalFlip(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img,
                'label': mask}


class RandomScale(object):
    def __init__(self, limit):
        self.limit = limit

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size

        scale = random.uniform(self.limit[0], self.limit[1])
        w = int(scale * img.size[0])
        h = int(scale * img.size[1])

        img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)

        return {'image': img, 'label': mask}


class RandomCrop(object):
    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size  # (h, w)
        self.padding = padding

    def __call__(self, sample):
        img, mask = sample['image'], sample['label']

        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)
            mask = ImageOps.expand(mask, border=self.padding, fill=0)

        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size  # target size
        if w == tw and h == th:
            return {'image': img,
                    'label': mask}
        if w < tw or h < th:
            img = img.resize((tw, th), Image.BILINEAR)
            mask = mask.resize((tw, th), Image.NEAREST)
            return {'image': img,
                    'label': mask}

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        img = img.crop((x1, y1, x1 + tw, y1 + th))
        mask = mask.crop((x1, y1, x1 + tw, y1 + th))

        return {'image': img,
                'label': mask}


class RandomRotate(object):
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        rotate_degree = random.random() * 2 * self.degree - self.degree
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)

        return {'image': img,
                'label': mask}


class ToTensor(object):
    """Convert the PIL images in a sample to torch Tensors."""

    def __call__(self, sample):
        # swap the color axis:
        # numpy image: H x W x C
        # torch image: C x H x W
        img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1))
        mask[mask == 255] = 0  # fold the ignore label (255) into class 0

        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()

        return {'image': img,
                'label': mask}
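The claim above that the first four transforms keep PIL images (only the size may change) can be spot-checked with a tiny sketch; dummy arrays stand in for real Cityscapes data, so the concrete PIL subclass here is Image.Image rather than PngImageFile:

from PIL import Image
import numpy as np

sample = {'image': Image.fromarray(np.zeros((4, 6, 3), dtype=np.uint8)),
          'label': Image.fromarray(np.zeros((4, 6), dtype=np.uint8))}

out = RandomHorizontalFlip()(RandomScale((0.5, 0.75))(sample))
print(type(out['image']), out['image'].size)  # still a PIL image, e.g. (3, 2) after scaling
print(type(out['label']), out['label'].size)  # same type and size as the image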

Only the fifth and last transform (ToTensor) changes the types. The image first goes from PIL.PngImagePlugin.PngImageFile to a numpy array, with the dtype changing from uint8 to float32; its dimensions are then rearranged from (h x w x c) to (c x h x w); finally it is converted from numpy to a torch tensor and cast to torch.FloatTensor. At that point the image is a tensor that can be fed into the network.
The label goes through the matching steps: from PIL.PngImagePlugin.PngImageFile to a float32 numpy array; from (h x w) an extra dimension is added to give (h x w x 1), which is then rearranged to (1 x h x w); the mask values are then cleaned up, with every 255 reset to 0, so the mask now only contains the numbers 0-18; finally the array is converted to a torch tensor and cast to torch.FloatTensor. The label, too, becomes a tensor that can be fed into the network.
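This shape and dtype flow can be reproduced in isolation (dummy arrays stand in for a real crop and its label):

import numpy as np
import torch
from PIL import Image

img_pil = Image.fromarray(np.zeros((512, 1024, 3), dtype=np.uint8))    # stand-in image, h x w x 3
mask_pil = Image.fromarray(np.full((512, 1024), 255, dtype=np.uint8))  # stand-in label, h x w

img = np.array(img_pil).astype(np.float32).transpose((2, 0, 1))        # -> 3 x h x w, float32
mask = np.expand_dims(np.array(mask_pil).astype(np.float32), -1)       # -> h x w x 1
mask = mask.transpose((2, 0, 1))                                       # -> 1 x h x w
mask[mask == 255] = 0                                                  # ignore label folded into 0

img_t, mask_t = torch.from_numpy(img).float(), torch.from_numpy(mask).float()
print(img_t.shape, img_t.dtype)    # torch.Size([3, 512, 1024]) torch.float32
print(mask_t.shape, mask_t.dtype)  # torch.Size([1, 512, 1024]) torch.float32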
The code that turns these two tensors back into displayable images is:
for ii, sample in enumerate(dataloader):
    for jj in range(sample["image"].size()[0]):
        img = sample['image'].numpy()  # torch -> numpy, n x 3 x h x w
        gt = sample['label'].numpy()   # torch -> numpy, n x 1 x h x w
        tmp = np.array(gt[jj]).astype(np.uint8)  # tmp.shape = 1 x h x w
        tmp = np.squeeze(tmp, axis=0)  # drop the channel axis: tmp.shape = h x w
        segmap = decode_segmap(tmp, dataset='cityscapes')
        img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)  # img_tmp = h x w x 3
        plt.figure()
        plt.title('display')
        plt.subplot(211)
        plt.imshow(img_tmp)
        plt.subplot(212)
        plt.imshow(segmap)

    if ii == 1:
        break

plt.show(block=True)

The label map (h x w) is decoded by the code below:
every pixel of a given class is assigned that class's RGB value, and the three channel images are then merged into one color image.
segmap = decode_segmap(tmp, dataset='cityscapes')  # tmp.shape = h x w

def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
            the class label at each spatial location.
        plot (bool, optional): whether to show the resulting color image
            in a figure.
    Returns:
        (np.ndarray, optional): the resulting decoded color image.
    """
    if dataset == 'pascal':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    else:
        raise NotImplementedError

    r = label_mask.copy()  # h x w
    g = label_mask.copy()  # h x w
    b = label_mask.copy()  # h x w
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))  # initialise an h x w x 3 image
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb

Below is label_colours, the color assigned to each class; for the full Cityscapes label-color table see https://blog.csdn.net/zz2230633069/article/details/84591532:
def get_cityscapes_labels():
    return np.array([
        # [  0,   0,   0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])


def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors
    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])
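Putting decode_segmap and get_cityscapes_labels together, a quick sanity check (with a hypothetical 2x2 train-id mask, and the two functions above in scope):

import numpy as np

toy = np.array([[0, 1],
                [13, 18]], dtype=np.uint8)   # road, sidewalk, car, bicycle
rgb = decode_segmap(toy, dataset='cityscapes')
print(rgb.shape)   # (2, 2, 3), float values scaled to [0, 1]
print(rgb[0, 0])   # [0.50196078 0.25098039 0.50196078], i.e. (128, 64, 128)/255 -> road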
