赞
踩
复现了一下PointNet++,拿自己做的数据集试试感觉效果挺好,于是开始读读代码,打打基础,顺便找点灵感和思路。代码中有些注释是基于我自己制作的数据集进行解释的,我的数据集仿照的是ShapeNet数据集的格式,在我的数据集中,只有一种物体:book(书),book有两个部分:background(背景)和seam(书缝),分别对应0和1。
ShapeNetDataLoader.py作用就是把n个点云数据转换成一个数组,数组有n项,每项包含点的信息,点的大类别(book),点的小类别(background,seam)。
# *_*coding:utf-8 *_* import os import json import warnings import numpy as np from torch.utils.data import Dataset warnings.filterwarnings('ignore') def pc_normalize(pc): centroid = np.mean(pc, axis=0) pc = pc - centroid m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) pc = pc / m return pc class PartNormalDataset(Dataset): def __init__(self,root = './data/book_seam_dataset', npoints=50000, split='train', class_choice=None, normal_channel=False): self.npoints = npoints self.root = root self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') self.cat = {} self.normal_channel = normal_channel with open(self.catfile, 'r') as f: for line in f: ls = line.strip().split() self.cat[ls[0]] = ls[1] self.cat = {k: v for k, v in self.cat.items()} # {'book': '12345678'} self.classes_original = dict(zip(self.cat, range(len(self.cat)))) # {'book': 0} if not class_choice is None: self.cat = {k:v for k,v in self.cat.items() if k in class_choice} # print(self.cat) self.meta = {} with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f: train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) # {'1', '2', ...} with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f: val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f: test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) for item in self.cat: # item:'book' # print('category', item) self.meta[item] = [] dir_point = os.path.join(self.root, self.cat[item]) fns = sorted(os.listdir(dir_point)) # print(fns[0][0:-4]) if split == 'trainval': # 取训练集+验证集 fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] # fn[0:-4]就是‘1.txt’里面的‘1’, fns:['1.txt', '10.txt', ...] elif split == 'train': fns = [fn for fn in fns if fn[0:-4] in train_ids] elif split == 'val': fns = [fn for fn in fns if fn[0:-4] in val_ids] elif split == 'test': fns = [fn for fn in fns if fn[0:-4] in test_ids] else: print('Unknown split: %s. Exiting..' % (split)) exit(-1) # print(os.path.basename(fns)) for fn in fns: token = (os.path.splitext(os.path.basename(fn))[0]) # os.path.basename删除目录名,保留文件名, token:'1' self.meta[item].append(os.path.join(dir_point, token + '.txt')) # {'book': ['data/book_seam_datas...5678/1.txt','...',...} self.datapath = [] for item in self.cat: for fn in self.meta[item]: self.datapath.append((item, fn)) self.classes = {} for i in self.cat.keys(): self.classes[i] = self.classes_original[i] # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels self.seg_classes = {'book': [0, 1]} # for cat in sorted(self.seg_classes.keys()): # print(cat, self.seg_classes[cat]) self.cache = {} # from index to (point_set, cls, seg) tuple self.cache_size = 20000 # 缓存点的数据,采样点最多不能超过缓存点数量最大值(20000) def __getitem__(self, index): if index in self.cache: point_set, cls, seg = self.cache[index] else: fn = self.datapath[index] # ('book', 'data/book_seam_datas...5678/5.txt') cat = self.datapath[index][0] # 'book' cls = self.classes[cat] # [0] cls = np.array([cls]).astype(np.int32) data = np.loadtxt(fn[1]).astype(np.float32) if not self.normal_channel: point_set = data[:, 0:3] else: point_set = data[:, 0:6] seg = data[:, -1].astype(np.int32) if len(self.cache) < self.cache_size: self.cache[index] = (point_set, cls, seg) point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) choice = np.random.choice(len(seg), self.npoints, replace=True) # resample point_set = point_set[choice, :] seg = seg[choice] return point_set, cls, seg # 点的信息,点的大类别(book),点的小类别(background,seam) def __len__(self): return len(self.datapath)
pointnet2_part_seg_msg.py是整个网络的整体框架,通过调用pointnet_utils.py中的自定义网络模型进行一层一层的搭建,因此看不到网络的具体细节。
pointnet2_utils.py中定义了各类网络模型以及pointnet++的关键算法。
import torch import torch.nn as nn import torch.nn.functional as F from time import time import numpy as np def timeit(tag, t): print("{}: {}s".format(tag, time() - t)) return time() def pc_normalize(pc): l = pc.shape[0] centroid = np.mean(pc, axis=0) pc = pc - centroid m = np.max(np.sqrt(np.sum(pc**2, axis=1))) pc = pc / m return pc def square_distance(src, dst): """ Calculate Euclid distance between each two points. src^T * dst = xn * xm + yn * ym + zn * zm; sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst B:batchsize, N:第一组点个数, M:第二组点个数, C:输入点通道数(xyz.C=3) Input: src: source points, [B, N, C] dst: target points, [B, M, C] Output: dist: per-point square distance, [B, N, M] batchsize个[N,M] """ B, N, _ = src.shape _, M, _ = dst.shape dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) # permute:转换维度 dist += torch.sum(src ** 2, -1).view(B, N, 1) # view:按维度填充 dist += torch.sum(dst ** 2, -1).view(B, 1, M) # 数组广播机制,右边的式子复制N组后与dist叠加 return dist def index_points(points, idx): # i按照输入的点云数据和索引返回由索引的点云数据。 """ Input: points: input points data, [B, N, C] idx: sample index data, [B, S] Return: new_points:, indexed points data, [B, S, C] """ device = points.device B = points.shape[0] view_shape = list(idx.shape) #view_shape=[B,S] view_shape[1:] = [1] * (len(view_shape) - 1) #[1] * (len(view_shape) - 1) -> [1],即view_shape=[B,1] repeat_shape = list(idx.shape) #repeat_shape=[B,S] repeat_shape[0] = 1 #repeat_shape=[1,S] #.view(view_shape)=.view(B,1) #.repeat(repeat_shape)=.view(1,S) #batch_indices的维度[B,S] batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) new_points = points[batch_indices, idx, :] return new_points def farthest_point_sample(xyz, npoint): ''' FPS的逻辑如下: 假设一共有n个点,整个点集为N = {f1, f2,…,fn}, 目标是选取n1个起始点做为下一步的中心点: 随机选取一个点fi为起始点,并写入起始点集 B = {fi}; 选取剩余n-1个点计算和fi点的距离,选择最远点fj写入起始点集B={fi,fj}; 选取剩余n-2个点计算和点集B中每个点的距离, 将最短的那个距离作为该点到点集的距离, 这样得到n-2个到点集的距离,选取最远的那个点写入起始点B = {fi, fj ,fk},同时剩下n-3个点, 如果n1=3 则到此选择完毕; 如果n1 > 3则重复上面步骤直到选取n1个起始点为止. ''' """ Input: xyz: pointcloud data, [B, N, 3] npoint: number of samples Return: centroids: sampled pointcloud index, [B, npoint] """ device = xyz.device B, N, C = xyz.shape # B:BatchSize, N:ndataset(点云中点的个数), C:dimension centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) # 提取得到中心点的集合 distance = torch.ones(B, N).to(device) * 1e10 # 记录某个样本中所有点到某一个点的距离,先取很大 farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) # 当前最远的点,随机初始化,范围为0~N,初始化B个,对应到每个样本都随机有一个初始最远点,B列的行向量 batch_indices = torch.arange(B, dtype=torch.long).to(device) # batch的索引,0~(B-1)的数组 for i in range(npoint): centroids[:, i] = farthest # 第i个最远点 centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) # 取出最远点xyz坐标 dist = torch.sum((xyz - centroid) ** 2, -1) # 计算距离,-1代表行求和 mask = dist < distance # 一个bool值的张量数组 distance[mask] = dist[mask] # True的会留下,False删除 farthest = torch.max(distance, -1)[1] # 返回一个张量,第一项是最大值,第二项是索引,-1代表列索引 return centroids def query_ball_point(radius, nsample, xyz, new_xyz): ''' ''' """ Input: radius: local region radius # radius为半径,new_xyz为中心,取nsample个点 nsample: max sample number in local region xyz: all points, [B, N, 3] # 所有点 new_xyz: query points, [B, S, 3] # farthest_point_sample得到S个中心点, new_xyz为中心点xyz Return: group_idx: grouped points index, [B, S, nsample] # nsameple个点的索引 """ device = xyz.device B, N, C = xyz.shape _, S, _ = new_xyz.shape group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) # torch.arange得到索引,view转换为三维,repeat使其复制成[B,S,N] sqrdists = square_distance(new_xyz, xyz) # 计算中心点与所有点之间的欧几里德距离 group_idx[sqrdists > radius ** 2] = N # 大于半径的点设置成N group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] # 做升序排列,前面大于radius^2的都是N,会是最大值,所以会直接在剩下的点中取出前nsample个点. 0代表输出值,1代表索引 # 考虑到有可能前nsample个点中也有被赋值为N的点(即球形区域内不足nsample个点),这种点需要舍弃,直接用第一个点来代替即可 # group_first: [B, S, nsample], 实际就是把group_idx中的第一个点的值复制到[B, S, nsample]的维度,便利于后面的替换 # 这里要用view是因为group_idx[:, :, 0]取出之后的tensor相当于二维Tensor,因此需要用view变成三维tensor group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) # 找到group_idx中值等于N的点,会输出0,1构成的三维Tensor,维度为[B,S,nsample] mask = group_idx == N # 将这些点的值替换为第一个点的值 group_idx[mask] = group_first[mask] return group_idx def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): """ Input: npoint: radius: nsample: xyz: input points position data, [B, N, 3] points: input points data, [B, N, D] Return: new_xyz: sampled points position data, [B, npoint, nsample, 3] new_points: sampled points data, [B, npoint, nsample, 3+D] """ B, N, C = xyz.shape S = npoint fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 中心点 new_xyz = index_points(xyz, fps_idx) # 中心点位置 idx = query_ball_point(radius, nsample, xyz, new_xyz) # 球查询得到点的索引 grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] # 球查询点的位置 grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) # 计算与中心点距离 if points is not None: grouped_points = index_points(points, idx) new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] C=3,D为点的特征维度(位置、法向、颜色) else: new_points = grouped_xyz_norm if returnfps: return new_xyz, new_points, grouped_xyz, fps_idx else: return new_xyz, new_points def sample_and_group_all(xyz, points): """ Input: xyz: input points position data, [B, N, 3] points: input points data, [B, N, D] Return: new_xyz: sampled points position data, [B, 1, 3] new_points: sampled points data, [B, 1, N, 3+D] """ device = xyz.device B, N, C = xyz.shape new_xyz = torch.zeros(B, 1, C).to(device) grouped_xyz = xyz.view(B, 1, N, C) #new_xyz代表中心点,用原点表示 if points is not None: new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) else: new_points = grouped_xyz return new_xyz, new_points class PointNetSetAbstraction(nn.Module): def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): super(PointNetSetAbstraction, self).__init__() self.npoint = npoint self.radius = radius self.nsample = nsample self.mlp_convs = nn.ModuleList() self.mlp_bns = nn.ModuleList() last_channel = in_channel for out_channel in mlp: self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) # MLP就相当于是1x1卷积 self.mlp_bns.append(nn.BatchNorm2d(out_channel)) last_channel = out_channel self.group_all = group_all def forward(self, xyz, points): """ N是输入点的数量,C是坐标维度(C=3),D是特征维度(除坐标维度以外的其他特征维度) S是输出点的数量,C是坐标维度,D'是新的特征维度 Input: xyz: input points position data, [B, C, N] points: input points data, [B, D, N] Return: new_xyz: sampled points position data, [B, C, S] new_points_concat: sample points feature data, [B, D', S] """ xyz = xyz.permute(0, 2, 1) # [B, N, 3] if points is not None: points = points.permute(0, 2, 1) if self.group_all: new_xyz, new_points = sample_and_group_all(xyz, points) else: new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) # new_xyz: sampled points position data, [B, npoint, C] # new_points: sampled points data, [B, npoint, nsample, C+D] new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] # pytorch的通道顺序是NCHW # N - Batch # C - Channel # H - Height # W - Width # 对[3+D, nsample]的维度上做逐像素的卷积,结果相当于对单个C+D维度做1d的卷积 for i, conv in enumerate(self.mlp_convs): bn = self.mlp_bns[i] new_points = F.relu(bn(conv(new_points))) new_points = torch.max(new_points, 2)[0] new_xyz = new_xyz.permute(0, 2, 1) return new_xyz, new_points class PointNetSetAbstractionMsg(nn.Module): def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): super(PointNetSetAbstractionMsg, self).__init__() self.npoint = npoint self.radius_list = radius_list self.nsample_list = nsample_list self.conv_blocks = nn.ModuleList() self.bn_blocks = nn.ModuleList() for i in range(len(mlp_list)): convs = nn.ModuleList() bns = nn.ModuleList() last_channel = in_channel + 3 for out_channel in mlp_list[i]: convs.append(nn.Conv2d(last_channel, out_channel, 1)) bns.append(nn.BatchNorm2d(out_channel)) last_channel = out_channel self.conv_blocks.append(convs) self.bn_blocks.append(bns) def forward(self, xyz, points): """ Input: xyz: input points position data, [B, C, N] points: input points data, [B, D, N] Return: new_xyz: sampled points position data, [B, C, S] new_points_concat: sample points feature data, [B, D', S] """ xyz = xyz.permute(0, 2, 1) if points is not None: points = points.permute(0, 2, 1) B, N, C = xyz.shape S = self.npoint new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) new_points_list = [] # 针对多个radius和nsample取点 for i, radius in enumerate(self.radius_list): K = self.nsample_list[i] group_idx = query_ball_point(radius, K, xyz, new_xyz) grouped_xyz = index_points(xyz, group_idx) grouped_xyz -= new_xyz.view(B, S, 1, C) if points is not None: grouped_points = index_points(points, group_idx) grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) else: grouped_points = grouped_xyz grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] for j in range(len(self.conv_blocks[i])): conv = self.conv_blocks[i][j] bn = self.bn_blocks[i][j] grouped_points = F.relu(bn(conv(grouped_points))) new_points = torch.max(grouped_points, 2)[0] # [B, D', S] new_points_list.append(new_points) new_xyz = new_xyz.permute(0, 2, 1) new_points_concat = torch.cat(new_points_list, dim=1) return new_xyz, new_points_concat class PointNetFeaturePropagation(nn.Module): def __init__(self, in_channel, mlp): super(PointNetFeaturePropagation, self).__init__() self.mlp_convs = nn.ModuleList() self.mlp_bns = nn.ModuleList() last_channel = in_channel for out_channel in mlp: self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) self.mlp_bns.append(nn.BatchNorm1d(out_channel)) last_channel = out_channel def forward(self, xyz1, xyz2, points1, points2): """ Input: xyz1: input points position data, [B, C, N] # 所有点 xyz2: sampled input points position data, [B, C, S] # 采样点 points1: input points data, [B, D, N] points2: input points data, [B, D, S] Return: new_points: upsampled points data, [B, D', N] """ " 将B C N 转换为B N C 然后利用插值将高维点云数目S 插值到低维点云数目N (N大于S)" " xyz1 低维点云 数量为N xyz2 高维点云 数量为S" xyz1 = xyz1.permute(0, 2, 1) xyz2 = xyz2.permute(0, 2, 1) points2 = points2.permute(0, 2, 1) B, N, C = xyz1.shape _, S, _ = xyz2.shape "如果最后只有一个点,就将S直复制N份后与与低维信息进行拼接" if S == 1: interpolated_points = points2.repeat(1, N, 1) else: dists = square_distance(xyz1, xyz2) # [B,N,S] dists, idx = dists.sort(dim=-1) # 找到距离最近的三个邻居 dists, idx = dists[:, :, :3], idx[:, :, :3] # [B,N,3],N个点与这S个距离最近的前三个点的索引 dist_recip = 1.0 / (dists + 1e-8) # 求距离的倒数 2,512,3 对应论文中的 Wi(x) norm = torch.sum(dist_recip, dim=2, keepdim=True) # 也就是将距离最近的三个邻居的加起来 此时对应论文中公式的分母部分 weight = dist_recip / norm """ 这里的weight是计算权重 dist_recip中存放的是三个邻居的距离 norm中存放是距离的和 两者相除就是每个距离占总和的比重 也就是weight """ interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) # 点乘 if points1 is not None: points1 = points1.permute(0, 2, 1) new_points = torch.cat([points1, interpolated_points], dim=-1) else: new_points = interpolated_points new_points = new_points.permute(0, 2, 1) for i, conv in enumerate(self.mlp_convs): bn = self.mlp_bns[i] new_points = F.relu(bn(conv(new_points))) return new_points
总的来讲,这个文件主要实现的是两个网络结构:PointNetSetAbstraction和PointNetFeaturePropagation,
PointNetSetAbstractionMsg只是PointNetSetAbstraction使用多个采样半径后叠加的结果。
文中的注释参考了博主weixin_42707080的PointNet++系列文章和正在学习的浅语的文章《PointNet++上采样(Feature Propagation)》。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。