
Training a PointNet++ Classification Network on Your Own Dataset


Contents

1. Overview

2. Reading the Dataset

3. Running the Code

4. Troubleshooting

5. Summary


1. Overview

This post walks through training the PointNet++ classification network on your own dataset. Here I use the open autonomous-driving point cloud dataset released by the University of Sydney for testing.

2. Reading the Dataset

Commonly used datasets: for an overview, see the CSDN post "点云数据集" by 爱学习的小菜鸡.

The open autonomous-driving dataset released by the University of Sydney is used for training here.

Method 1: install the snark toolkit

math-deg2rad.exe -h

Method 2: read the binary files with a Python script

# -*- coding: utf-8 -*-
""" Simple example for loading object binary data. """
import numpy as np
import matplotlib.pyplot as plt

names = ['t', 'intensity', 'id',
         'x', 'y', 'z',
         'azimuth', 'range', 'pid']
formats = ['int64', 'uint8', 'uint8',
           'float32', 'float32', 'float32',
           'float32', 'float32', 'int32']
binType = np.dtype(dict(names=names, formats=formats))
data = np.fromfile('objects/excavator.0.10974.bin', binType)

# 3D points, one per row
# stack the x/y/z columns vertically (row-wise), then transpose so each row is one point
P = np.vstack([data['x'], data['y'], data['z']]).T
print(P)

# visualize the point cloud
x = []
y = []
z = []
for point in P:
    x.append(point[0])
    y.append(point[1])
    z.append(point[2])
x = np.array(x, dtype=float)
y = np.array(y, dtype=float)
z = np.array(z, dtype=float)

ax = plt.subplot(projection='3d')
ax.set_title('image_show')
ax.scatter(x, y, z, c='r')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()
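The classification loader introduced later reads each sample with `np.loadtxt(..., delimiter=',')`, so the Sydney `.bin` objects need to be dumped to comma-separated `.txt` files first. Below is a minimal sketch of that conversion; the `objects/` and `data/owndata/` paths and the xyz-only output are assumptions, adjust them to your own layout.

```python
# Hedged sketch: convert Sydney Urban Objects .bin files to comma-separated x,y,z .txt files.
# Paths and the xyz-only output format are assumptions; adapt to your dataset layout.
import glob
import os
import numpy as np

names = ['t', 'intensity', 'id', 'x', 'y', 'z', 'azimuth', 'range', 'pid']
formats = ['int64', 'uint8', 'uint8', 'float32', 'float32', 'float32', 'float32', 'float32', 'int32']
binType = np.dtype(dict(names=names, formats=formats))

src_dir = 'objects'          # assumed: folder with the original .bin object files
dst_dir = 'data/owndata'     # assumed: root of the classification dataset to be built

for bin_path in glob.glob(os.path.join(src_dir, '*.bin')):
    data = np.fromfile(bin_path, binType)
    # one point per row: x, y, z
    P = np.vstack([data['x'], data['y'], data['z']]).T
    # e.g. excavator.0.10974.bin -> class folder "excavator", file "excavator.0.10974.txt"
    stem = os.path.splitext(os.path.basename(bin_path))[0]
    cls = stem.split('.')[0]
    out_dir = os.path.join(dst_dir, cls)
    os.makedirs(out_dir, exist_ok=True)
    np.savetxt(os.path.join(out_dir, stem + '.txt'), P, fmt='%.6f', delimiter=',')
```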

  • Building the classification dataset

The dataset format follows the ModelNet10 / ModelNet40 classification layout: one folder per class containing the per-sample `.txt` files, plus four list files in the dataset root. The original post shows a screenshot of this directory tree; a sketch for generating the list files follows this description.

Description of the list files:

allclass: all class names

filelist: the paths of all sample files

testlist: test samples

trainlist: training samples
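As a reference, here is a minimal sketch that builds the four list files from a `data/owndata/<class>/<sample>.txt` layout. The 80/20 train/test split and the fixed random seed are assumptions; `trainlist.txt` and `testlist.txt` hold sample names without the `.txt` extension, which is what the loader below expects.

```python
# Hedged sketch: generate allclass.txt / filelist.txt / trainlist.txt / testlist.txt
# for a data/owndata/<class>/<sample>.txt layout. The 80/20 split is an assumption.
import os
import random

root = 'data/owndata'
classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))

with open(os.path.join(root, 'allclass.txt'), 'w') as f:
    f.write('\n'.join(classes) + '\n')

file_paths, train_names, test_names = [], [], []
random.seed(0)
for cls in classes:
    # sample names without the .txt extension, e.g. "excavator.0.10974"
    names = sorted(os.path.splitext(fn)[0]
                   for fn in os.listdir(os.path.join(root, cls)) if fn.endswith('.txt'))
    file_paths += [os.path.join(cls, name + '.txt') for name in names]
    random.shuffle(names)
    split = int(0.8 * len(names))
    train_names += names[:split]
    test_names += names[split:]

for list_name, lines in [('filelist.txt', file_paths),
                         ('trainlist.txt', train_names),
                         ('testlist.txt', test_names)]:
    with open(os.path.join(root, list_name), 'w') as f:
        f.write('\n'.join(lines) + '\n')
```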

3. Running the Code

Code layout: (the original post shows a screenshot of the repository tree here)

Newly added file: train_ownpointcloud_cls.py

  1. """
  2. Author: Benny
  3. Date: Nov 2019
  4. """
  5. import os
  6. import sys
  7. import torch
  8. import numpy as np
  9. import datetime
  10. import logging
  11. import provider
  12. import importlib
  13. import shutil
  14. import argparse
  15. from pathlib import Path
  16. from tqdm import tqdm
  17. from data_utils.OwnPointCloudDataLoader import OwnPointCloudDataLoader
  18. BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  19. ROOT_DIR = BASE_DIR
  20. sys.path.append(os.path.join(ROOT_DIR, 'models'))
  21. def parse_args():
  22. '''PARAMETERS'''
  23. parser = argparse.ArgumentParser('training')
  24. parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
  25. parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
  26. parser.add_argument('--batch_size', type=int, default=20, help='batch size in training')
  27. parser.add_argument('--model', default='pointnet_cls', help='model name [default: pointnet_cls]')
  28. parser.add_argument('--num_category', default=4, type=int, choices=[4,10, 40], help='training on ModelNet10/40')
  29. parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
  30. parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
  31. parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
  32. parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
  33. parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
  34. parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
  35. parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
  36. parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
  37. parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
  38. return parser.parse_args()
  39. def inplace_relu(m):
  40. classname = m.__class__.__name__
  41. if classname.find('ReLU') != -1:
  42. m.inplace=True
  43. def test(model, loader, num_class=40):
  44. mean_correct = []
  45. class_acc = np.zeros((num_class, 3))
  46. classifier = model.eval()
  47. for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
  48. if not args.use_cpu:
  49. points, target = points.cuda(), target.cuda()
  50. points = points.transpose(2, 1)
  51. pred, _ = classifier(points)
  52. pred_choice = pred.data.max(1)[1]
  53. for cat in np.unique(target.cpu()):
  54. classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()
  55. class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])
  56. class_acc[cat, 1] += 1
  57. correct = pred_choice.eq(target.long().data).cpu().sum()
  58. mean_correct.append(correct.item() / float(points.size()[0]))
  59. class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
  60. class_acc = np.mean(class_acc[:, 2])
  61. instance_acc = np.mean(mean_correct)
  62. return instance_acc, class_acc
  63. def main(args):
  64. def log_string(str):
  65. logger.info(str)
  66. print(str)
  67. '''HYPER PARAMETER'''
  68. os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
  69. '''CREATE DIR'''
  70. timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
  71. exp_dir = Path('./log/')
  72. exp_dir.mkdir(exist_ok=True)
  73. exp_dir = exp_dir.joinpath('classification')
  74. exp_dir.mkdir(exist_ok=True)
  75. if args.log_dir is None:
  76. exp_dir = exp_dir.joinpath(timestr)
  77. else:
  78. exp_dir = exp_dir.joinpath(args.log_dir)
  79. exp_dir.mkdir(exist_ok=True)
  80. checkpoints_dir = exp_dir.joinpath('checkpoints/')
  81. checkpoints_dir.mkdir(exist_ok=True)
  82. log_dir = exp_dir.joinpath('logs/')
  83. log_dir.mkdir(exist_ok=True)
  84. '''LOG'''
  85. args = parse_args()
  86. logger = logging.getLogger("Model")
  87. logger.setLevel(logging.INFO)
  88. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  89. file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
  90. file_handler.setLevel(logging.INFO)
  91. file_handler.setFormatter(formatter)
  92. logger.addHandler(file_handler)
  93. log_string('PARAMETER ...')
  94. log_string(args)
  95. '''DATA LOADING'''
  96. log_string('Load dataset ...')
  97. data_path = 'data/owndata/'
  98. train_dataset = OwnPointCloudDataLoader(root=data_path, args=args, split='train', process_data=args.process_data)
  99. test_dataset = OwnPointCloudDataLoader(root=data_path, args=args, split='test', process_data=args.process_data)
  100. trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
  101. testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=10)
  102. '''MODEL LOADING'''
  103. num_class = args.num_category
  104. model = importlib.import_module(args.model)
  105. shutil.copy('./models/%s.py' % args.model, str(exp_dir))
  106. shutil.copy('models/pointnet2_utils.py', str(exp_dir))
  107. shutil.copy('./train_classification.py', str(exp_dir))
  108. classifier = model.get_model(num_class, normal_channel=args.use_normals)
  109. criterion = model.get_loss()
  110. classifier.apply(inplace_relu)
  111. if not args.use_cpu:
  112. classifier = classifier.cuda()
  113. criterion = criterion.cuda()
  114. try:
  115. checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
  116. start_epoch = checkpoint['epoch']
  117. classifier.load_state_dict(checkpoint['model_state_dict'])
  118. log_string('Use pretrain model')
  119. except:
  120. log_string('No existing model, starting training from scratch...')
  121. start_epoch = 0
  122. if args.optimizer == 'Adam':
  123. optimizer = torch.optim.Adam(
  124. classifier.parameters(),
  125. lr=args.learning_rate,
  126. betas=(0.9, 0.999),
  127. eps=1e-08,
  128. weight_decay=args.decay_rate
  129. )
  130. else:
  131. optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
  132. scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
  133. global_epoch = 0
  134. global_step = 0
  135. best_instance_acc = 0.0
  136. best_class_acc = 0.0
  137. '''TRANING'''
  138. logger.info('Start training...')
  139. for epoch in range(start_epoch, args.epoch):
  140. log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
  141. mean_correct = []
  142. classifier = classifier.train()
  143. scheduler.step()
  144. #print('================batch_id===============')
  145. for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
  146. optimizer.zero_grad()
  147. points = points.data.numpy()
  148. #print(points)
  149. points = provider.random_point_dropout(points)
  150. points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
  151. points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
  152. points = torch.Tensor(points)
  153. points = points.transpose(2, 1)
  154. if not args.use_cpu:
  155. points, target = points.cuda(), target.cuda()
  156. pred, trans_feat = classifier(points)
  157. loss = criterion(pred, target.long(), trans_feat)
  158. pred_choice = pred.data.max(1)[1]
  159. correct = pred_choice.eq(target.long().data).cpu().sum()
  160. mean_correct.append(correct.item() / float(points.size()[0]))
  161. loss.backward()
  162. optimizer.step()
  163. global_step += 1
  164. train_instance_acc = np.mean(mean_correct)
  165. log_string('Train Instance Accuracy: %f' % train_instance_acc)
  166. with torch.no_grad():
  167. instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)
  168. if (instance_acc >= best_instance_acc):
  169. best_instance_acc = instance_acc
  170. best_epoch = epoch + 1
  171. if (class_acc >= best_class_acc):
  172. best_class_acc = class_acc
  173. log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
  174. log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))
  175. if (instance_acc >= best_instance_acc):
  176. logger.info('Save model...')
  177. savepath = str(checkpoints_dir) + '/best_model.pth'
  178. log_string('Saving at %s' % savepath)
  179. state = {
  180. 'epoch': best_epoch,
  181. 'instance_acc': instance_acc,
  182. 'class_acc': class_acc,
  183. 'model_state_dict': classifier.state_dict(),
  184. 'optimizer_state_dict': optimizer.state_dict(),
  185. }
  186. torch.save(state, savepath)
  187. global_epoch += 1
  188. logger.info('End of training...')
  189. if __name__ == '__main__':
  190. args = parse_args()
  191. main(args)
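For reference, once both new files are in place, training is launched just like the original train_classification.py, e.g. `python train_ownpointcloud_cls.py --model pointnet2_cls_ssg --num_category 4 --log_dir owndata_cls`. The model name and log directory here are only illustrative choices; any classification model under models/ that exposes get_model/get_loss should work with the arguments defined above.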

Newly added file: OwnPointCloudDataLoader.py (placed under data_utils/, matching the import in the training script)

'''
@author: Xu Yan
@file: ModelNet.py
@time: 2021/3/19 15:51
'''
import os
import numpy as np
import warnings
import pickle

from tqdm import tqdm
from torch.utils.data import Dataset

warnings.filterwarnings('ignore')


def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc


def farthest_point_sample(point, npoint):
    """
    Input:
        xyz: pointcloud data, [N, D]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:, :3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    point = point[centroids.astype(np.int32)]
    return point


class OwnPointCloudDataLoader(Dataset):
    def __init__(self, root, args, split='train', process_data=False):
        self.root = root
        self.npoints = args.num_point
        self.process_data = process_data
        self.uniform = args.use_uniform_sample
        self.use_normals = args.use_normals
        self.num_category = args.num_category

        if self.num_category == 10:
            self.catfile = os.path.join(self.root, 'allclass.txt')
        else:
            self.catfile = os.path.join(self.root, 'allclass.txt')

        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))

        shape_ids = {}
        if self.num_category == 10:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'trainlist.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'testlist.txt'))]
        else:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'trainlist.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'testlist.txt'))]
        print('shape_ids:')
        print(shape_ids['train'])

        assert (split == 'train' or split == 'test')
        # origin format : airplane_0627
        # new format : pedestrian.9.6994
        shape_names = [''.join(x.split('.')[0]) for x in shape_ids[split]]
        print('shapenames:')
        print(shape_names)
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                         in range(len(shape_ids[split]))]
        print('The size of %s data is %d' % (split, len(self.datapath)))

        if self.uniform:
            self.save_path = os.path.join(root,
                                          'ownpointcloud%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
        else:
            self.save_path = os.path.join(root, 'ownpointcloud%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))

        if self.process_data:
            if not os.path.exists(self.save_path):
                print('Processing data %s (only running in the first time)...' % self.save_path)
                self.list_of_points = [None] * len(self.datapath)
                self.list_of_labels = [None] * len(self.datapath)

                for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
                    fn = self.datapath[index]
                    cls = self.classes[self.datapath[index][0]]
                    cls = np.array([cls]).astype(np.int32)
                    point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)

                    if self.uniform:
                        point_set = farthest_point_sample(point_set, self.npoints)
                    else:
                        point_set = point_set[0:self.npoints, :]

                    self.list_of_points[index] = point_set
                    self.list_of_labels[index] = cls

                with open(self.save_path, 'wb') as f:
                    pickle.dump([self.list_of_points, self.list_of_labels], f)
            else:
                print('Load processed data from %s...' % self.save_path)
                with open(self.save_path, 'rb') as f:
                    self.list_of_points, self.list_of_labels = pickle.load(f)

    def __len__(self):
        return len(self.datapath)

    def _get_item(self, index):
        if self.process_data:
            point_set, label = self.list_of_points[index], self.list_of_labels[index]
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            label = np.array([cls]).astype(np.int32)
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)

            if self.uniform:
                point_set = farthest_point_sample(point_set, self.npoints)
            else:
                point_set = point_set[0:self.npoints, :]

        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if not self.use_normals:
            point_set = point_set[:, 0:3]

        return point_set, label[0]

    def __getitem__(self, index):
        return self._get_item(index)


if __name__ == '__main__':
    import torch

    data = OwnPointCloudDataLoader('/data/owndata/', split='train')
    DataLoader = torch.utils.data.DataLoader(data, batch_size=4, shuffle=True)
    for point, label in DataLoader:
        print(point.shape)
        print(label.shape)
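Note that the __main__ block above still calls the constructor without the args namespace that __init__ reads, so running the file directly will fail. A minimal sketch for trying the loader standalone could look like this (the Namespace values are placeholders that mirror the training defaults):

```python
# Hedged sketch: stand-alone use of OwnPointCloudDataLoader.
# The Namespace fields mirror the attributes read in __init__; the values are placeholders.
import torch
from argparse import Namespace
from data_utils.OwnPointCloudDataLoader import OwnPointCloudDataLoader

args = Namespace(num_point=1024, use_uniform_sample=False, use_normals=False, num_category=4)
data = OwnPointCloudDataLoader(root='data/owndata/', args=args, split='train')
loader = torch.utils.data.DataLoader(data, batch_size=4, shuffle=True)
for point, label in loader:
    print(point.shape, label.shape)
    break
```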

Screenshot of a training run: (image in the original post)

4. Troubleshooting

Several issues came up along the way; most were easy to resolve, so here I only record the one I found trickiest.

1. RuntimeError: stack expects each tensor to be equal size, but got [124, 3] at entry 0 and [162, 3] at entry 1

This error occurs when the DataLoader collates a batch: ModelNet assumes every `.txt` file contains the same number of points (10000 by default), so when your own samples have varying point counts the stack fails.

Solution: make sure the tensors reaching torch.stack all have the same size, either by

1. making every `.txt` point cloud contain the same number of points, or

2. padding the tensor to a fixed size when implementing `__getitem__` (the original post links to its code here); a minimal padding sketch is given below.
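A minimal sketch of option 2, assuming padding by repeating randomly chosen existing points (the original post only links to its own fix, so this is just one reasonable variant):

```python
# Hedged sketch: pad or truncate every cloud to npoints inside _get_item,
# so the default collate_fn's torch.stack sees equally sized tensors.
import numpy as np

def pad_point_cloud(point_set, npoints):
    n = point_set.shape[0]
    if n >= npoints:
        return point_set[:npoints, :]
    # repeat randomly chosen existing points until npoints rows are reached
    extra_idx = np.random.choice(n, npoints - n, replace=True)
    return np.concatenate([point_set, point_set[extra_idx]], axis=0)

# in OwnPointCloudDataLoader._get_item, right after loading point_set:
#     point_set = pad_point_cloud(point_set, self.npoints)
```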

5. Summary

A few things to watch out for when building the dataset:

1. Keep the directory structure consistent with the expected layout; file names do not have to match the original ones, since they can be adapted in the code.

2. Keep the `.txt` contents consistent, both in format and in point count (see the issue above). The reference dataset is quite standardized, while your own data will often be messier; this can also be handled on the code side.

3. Test with a small batch of data first, and only generate the full dataset with scripts once everything works. Otherwise a faulty dataset means repeatedly changing the code, which is tedious and can easily introduce new errors, since it is hard to anticipate every corner case.
