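Point Cloud Processing: PointNet Point Cloud Classification
Author: Zhihao Cao
Date: 2022.5
Abstract: This example shows how to implement PointNet with PaddlePaddle 2.3.0 to classify point clouds from the ShapeNet dataset.
1. Environment Setup
This tutorial is written for PaddlePaddle 2.3.0. If your environment runs a different version, first install this version by following the official installation guide.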
```python
import os
import numpy as np
import random
import h5py
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

print(paddle.__version__)
```

```
2.3.0
```

2. Dataset
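2.1 Dataset Overview
ShapeNet is a large-scale, richly annotated dataset of 3D shapes. It was released jointly in 2015 by Stanford University, Princeton University, and the Toyota Technological Institute at Chicago.
Official ShapeNet link: https://vision.princeton.edu/projects/2014/3DShapeNets/
AI Studio link: ShapeNet dataset (preprocessed)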
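The ShapeNet data is stored in HDF5 (.h5) files. Each file contains the following keys:
1. data: the xyz coordinates of every point in each sample;
2. label: the object category of each sample, e.g. airplane;
3. pid: the part label of every point in each sample. For example, if a sample belongs to the airplane category, its points are labeled with parts such as wing and fuselage.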
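The sketch below makes the three keys concrete. It builds small synthetic arrays in the layout that the loading code later in this tutorial relies on: `data` as (samples, points, 3), one class id per sample in `label`, and one part id per point in `pid`. The sizes, the value ranges, and the (samples, 1) shape of `label` are assumptions for illustration and were not read from the real files. With h5py, `f[key].shape` confirms the actual shapes of your copy.

```python
import numpy as np

# Hypothetical sizes: 4 samples with 2048 points each (assumed, not from the real files)
num_samples, num_points = 4, 2048
rng = np.random.default_rng(0)

# data: xyz coordinates of every point, shape (samples, points, 3)
data = rng.standard_normal((num_samples, num_points, 3)).astype(np.float32)
# label: one object-category id per sample, shape (samples, 1)
label = rng.integers(0, 16, size=(num_samples, 1))
# pid: one part id per point, shape (samples, points)
pid = rng.integers(0, 4, size=(num_samples, num_points))

sample = 0
xyz = data[sample]                  # (2048, 3) coordinates of one object
category = int(label[sample, 0])    # a single class id for the whole object
parts, counts = np.unique(pid[sample], return_counts=True)

print(xyz.shape, category)
print(dict(zip(parts.tolist(), counts.tolist())))  # number of points per part id
```

Classification, the task in this tutorial, only needs `data` and `label`. `pid` matters for part segmentation.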
2.2 Unzip the Dataset
```
!unzip data/data70460/shapenet_part_seg_hdf5_data.zip
!mv hdf5_data dataset
```

2.3 Data File Lists
All data files of the ShapeNet dataset, grouped into training, test, and validation splits:
```python
train_list = [
    "ply_data_train0.h5",
    "ply_data_train1.h5",
    "ply_data_train2.h5",
    "ply_data_train3.h5",
    "ply_data_train4.h5",
    "ply_data_train5.h5",
]
test_list = ["ply_data_test0.h5", "ply_data_test1.h5"]
val_list = ["ply_data_val0.h5"]
```

2.4 Build the Data Generator
Note: the entire ShapeNet dataset is read into memory.
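The loading code in this section turns each file into a list of per-sample arrays, then transposes every sample to channels-first order. The sketch below traces that shape flow in numpy only. A hypothetical in-memory array stands in for an h5 file, so neither h5py nor paddle is needed.

```python
import numpy as np

num_point = 2048
# Stand-in for f["data"] of one file: 10 samples x 2048 points x xyz (hypothetical)
file_data = np.zeros((10, num_point, 3), dtype=np.float32)

datas = []
# Same slicing as make_data: keep the first num_point points of every sample.
# list.extend iterates over axis 0, so each element is one (num_point, 3) sample.
datas.extend(file_data[:, :num_point, :])

# Same transpose as PointDataset.__getitem__: (num_point, 3) -> (3, num_point)
sample = datas[0].T.astype("float32")

# A DataLoader stacks samples along a new batch axis: (batch, 3, num_point),
# which is the (N, C, L) layout that Conv1D expects
batch = np.stack([d.T for d in datas[:4]])

print(len(datas), datas[0].shape, sample.shape, batch.shape)
```

Assuming float32 coordinates, each sample holds 2048 × 3 × 4 = 24,576 bytes, so 10,000 samples take roughly 246 MB. This is why reading the whole dataset into memory up front is practical.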
```python
# Map each split to its list of HDF5 files; any other mode falls back to
# the validation files, as in the original if/elif/else version
FILE_LISTS = {"train": train_list, "test": test_list, "val": val_list}


def make_data(mode="train", path="./dataset/", num_point=2048):
    datas = []
    labels = []
    for file_name in FILE_LISTS.get(mode, val_list):
        # `with` closes the file even if reading raises an error
        with h5py.File(os.path.join(path, file_name), "r") as f:
            # keep the first num_point points of every sample
            datas.extend(f["data"][:, :num_point, :])
            labels.extend(f["label"][:])
    return datas, labels
```

Note: the dataset is constructed by subclassing paddle.io.Dataset.
```python
class PointDataset(paddle.io.Dataset):
    def __init__(self, datas, labels):
        super().__init__()
        self.datas = datas
        self.labels = labels

    def __getitem__(self, index):
        # (num_point, 3) -> (3, num_point): channels first, as Conv1D expects
        data = paddle.to_tensor(self.datas[index].T.astype("float32"))
        label = paddle.to_tensor(self.labels[index].astype("int64"))
        return data, label

    def __len__(self):
        return len(self.datas)
```

Note: the data is loaded with paddle.io.DataLoader, the PaddlePaddle API that groups samples into mini-batches of the given batch size.
```python
# Load the data
datas, labels = make_data(mode="train", num_point=2048)
train_dataset = PointDataset(datas, labels)
datas, labels = make_data(mode="val", num_point=2048)
val_dataset = PointDataset(datas, labels)
datas, labels = make_data(mode="test", num_point=2048)
test_dataset = PointDataset(datas, labels)

# Instantiate the data loaders
train_loader = paddle.io.DataLoader(
    train_dataset, batch_size=128, shuffle=True, drop_last=False
)
val_loader = paddle.io.DataLoader(
    val_dataset, batch_size=32, shuffle=False, drop_last=False
)
test_loader = paddle.io.DataLoader(
    test_dataset, batch_size=128, shuffle=False, drop_last=False
)
```

3. Define the Network
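PointNet is a point cloud network proposed by researchers at Stanford University. The paper introduces two ideas. First, a spatial transformer network (T-Net) handles rotation of the input. A rotated point cloud still represents the same object, so the network needs a component that learns to align the input before features are extracted. Second, per-point features are aggregated with max pooling into a global feature of the whole point cloud. Max pooling is a symmetric function, so this global feature does not depend on the order in which the points are listed.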
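The numpy sketch below illustrates the two ideas in the paragraph above. It relies on stand-ins that are assumptions, not the tutorial's model:

- A random linear map followed by ReLU plays the shared per-point MLP.
- A fixed rotation about the z-axis plays a transform that a trained T-Net might predict.

The first part mirrors how the `input_fc` layer in the code below is initialized (zero weights plus an identity bias), so the predicted 3×3 transform starts out as the identity.

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.standard_normal((2048, 3))  # one cloud, (N, 3)

# T-Net output at initialization (mirrors input_fc's last Linear layer)
features = rng.standard_normal(256)      # arbitrary pooled features
weight = np.zeros((256, 9))              # zero weights ...
bias = np.eye(3).reshape(-1)             # ... plus a flattened identity bias
transform = (features @ weight + bias).reshape(3, 3)
assert np.allclose(transform, np.eye(3))  # identity regardless of the features

# Applying a 3x3 transform: every point (row) is multiplied by the matrix,
# like paddle.matmul(x, t_net) on a (B, N, 3) tensor in forward()
theta = np.pi / 6
rotation = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                     [np.sin(theta), np.cos(theta), 0.0],
                     [0.0, 0.0, 1.0]])
rotated = points @ rotation
# A transform equal to the inverse rotation would restore the original cloud
restored = rotated @ rotation.T
assert np.allclose(restored, points)

# Max pooling over points gives an order-independent global feature
mlp_weight = rng.standard_normal((3, 64))  # stand-in for the shared per-point MLP


def global_feature(cloud):
    per_point = np.maximum(cloud @ mlp_weight, 0.0)  # (N, 64) per-point features
    return per_point.max(axis=0)                     # max over points -> (64,)


shuffled = points[rng.permutation(len(points))]
same = np.allclose(global_feature(points), global_feature(shuffled))
print(same)
```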
3.1 Define the Network Structure
```python
class PointNet(nn.Layer):
    def __init__(self, name_scope="PointNet_", num_classes=16, num_point=2048):
        super().__init__()
        # Input T-Net: shared per-point MLP followed by max pooling over all points
        self.input_transform_net = nn.Sequential(
            nn.Conv1D(3, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
            nn.Conv1D(64, 128, 1),
            nn.BatchNorm(128),
            nn.ReLU(),
            nn.Conv1D(128, 1024, 1),
            nn.BatchNorm(1024),
            nn.ReLU(),
            nn.MaxPool1D(num_point),
        )
        # Regresses the 3x3 input transform; zero weights plus an identity
        # bias make the initial transform the identity matrix
        self.input_fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(
                256,
                9,
                weight_attr=paddle.ParamAttr(
                    initializer=paddle.nn.initializer.Assign(paddle.zeros((256, 9)))
                ),
                bias_attr=paddle.ParamAttr(
                    initializer=paddle.nn.initializer.Assign(
                        paddle.reshape(paddle.eye(3), [-1])
                    )
                ),
            ),
        )
        self.mlp_1 = nn.Sequential(
            nn.Conv1D(3, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
            nn.Conv1D(64, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
        )
        # Feature T-Net: predicts a 64x64 transform for the per-point features
        self.feature_transform_net = nn.Sequential(
            nn.Conv1D(64, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
            nn.Conv1D(64, 128, 1),
            nn.BatchNorm(128),
            nn.ReLU(),
            nn.Conv1D(128, 1024, 1),
            nn.BatchNorm(1024),
            nn.ReLU(),
            nn.MaxPool1D(num_point),
        )
        self.feature_fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 64 * 64),
        )
        self.mlp_2 = nn.Sequential(
            nn.Conv1D(64, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
            nn.Conv1D(64, 128, 1),
            nn.BatchNorm(128),
            nn.ReLU(),
            nn.Conv1D(128, 1024, 1),
            nn.BatchNorm(1024),
            nn.ReLU(),
        )
        # Classification head; LogSoftmax outputs log-probabilities for F.nll_loss
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(p=0.7),
            nn.Linear(256, num_classes),
            nn.LogSoftmax(axis=-1),
        )

    def forward(self, inputs):
        # inputs: [B, 3, N]
        batchsize = inputs.shape[0]

        # Input transform: predict a 3x3 matrix and apply it to every point
        t_net = self.input_transform_net(inputs)  # [B, 1024, 1]
        t_net = paddle.squeeze(t_net, axis=-1)  # [B, 1024]
        t_net = self.input_fc(t_net)  # [B, 9]
        t_net = paddle.reshape(t_net, [batchsize, 3, 3])
        x = paddle.transpose(inputs, (0, 2, 1))  # [B, N, 3]
        x = paddle.matmul(x, t_net)
        x = paddle.transpose(x, (0, 2, 1))  # [B, 3, N]
        x = self.mlp_1(x)  # [B, 64, N]

        # Feature transform: predict a 64x64 matrix and apply it
        t_net = self.feature_transform_net(x)
        t_net = paddle.squeeze(t_net, axis=-1)
        t_net = self.feature_fc(t_net)
        t_net = paddle.reshape(t_net, [batchsize, 64, 64])
        x = paddle.transpose(x, (0, 2, 1))  # [B, N, 64]
        x = paddle.matmul(x, t_net)
        x = paddle.transpose(x, (0, 2, 1))  # [B, 64, N]
        x = self.mlp_2(x)  # [B, 1024, N]

        # Global feature: max over all points, independent of point order
        x = paddle.max(x, axis=-1)  # [B, 1024]
        x = self.fc(x)
        return x
```

3.2 Visualize the Network Structure
Note: the model structure is visualized with the PaddlePaddle API paddle.summary.
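The parameter counts that paddle.summary reports can be checked by hand. This pure-Python sketch assumes the usual conventions:

- A Conv1D with kernel size k has C_in·C_out·k weights plus C_out biases.
- A Linear layer has in·out weights plus out biases.
- paddle.summary counts four length-C vectors for each BatchNorm layer (scale, shift, running mean, running variance). The BatchNorm-1 row shows this: 256 = 4 × 64.

Under these assumptions it rebuilds PointNet's layer list and sums the parameters.

```python
def conv1d_params(c_in, c_out, k=1):
    return c_in * c_out * k + c_out  # weights + bias


def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + bias


def bn_params(c):
    return 4 * c  # scale, shift, running mean, running variance


def conv_stack(channels):
    """Conv1D(k=1) + BatchNorm pairs between consecutive channel sizes."""
    return sum(conv1d_params(a, b) + bn_params(b) for a, b in zip(channels, channels[1:]))


def fc_stack(sizes):
    """Linear layers between consecutive sizes (ReLU/Dropout have no parameters)."""
    return sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))


num_classes = 16
input_tnet = conv_stack([3, 64, 128, 1024]) + fc_stack([1024, 512, 256, 9])
mlp_1 = conv_stack([3, 64, 64])
feature_tnet = conv_stack([64, 64, 128, 1024]) + fc_stack([1024, 512, 256, 64 * 64])
mlp_2 = conv_stack([64, 64, 128, 1024])
head = fc_stack([1024, 512, 256, num_classes])

total = input_tnet + mlp_1 + feature_tnet + mlp_2 + head
print(f"{total:,}")  # 3,476,825, the "Total params" line of the summary
```

More than half of the parameters sit in the two T-Nets. The 256 → 64·64 layer of the feature transform alone contributes 1,052,672.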
```python
pointnet = PointNet()
paddle.summary(pointnet, (64, 3, 2048))
```

```
-----------------------------------------------------------------
Layer (type)   Input Shape          Output Shape          Param #
=================================================================
Conv1D-1       [[64, 3, 2048]]      [64, 64, 2048]            256
BatchNorm-1    [[64, 64, 2048]]     [64, 64, 2048]            256
ReLU-1         [[64, 64, 2048]]     [64, 64, 2048]              0
Conv1D-2       [[64, 64, 2048]]     [64, 128, 2048]         8,320
BatchNorm-2    [[64, 128, 2048]]    [64, 128, 2048]           512
ReLU-2         [[64, 128, 2048]]    [64, 128, 2048]             0
Conv1D-3       [[64, 128, 2048]]    [64, 1024, 2048]      132,096
BatchNorm-3    [[64, 1024, 2048]]   [64, 1024, 2048]        4,096
ReLU-3         [[64, 1024, 2048]]   [64, 1024, 2048]            0
MaxPool1D-1    [[64, 1024, 2048]]   [64, 1024, 1]               0
Linear-1       [[64, 1024]]         [64, 512]             524,800
ReLU-4         [[64, 512]]          [64, 512]                   0
Linear-2       [[64, 512]]          [64, 256]             131,328
ReLU-5         [[64, 256]]          [64, 256]                   0
Linear-3       [[64, 256]]          [64, 9]                 2,313
Conv1D-4       [[64, 3, 2048]]      [64, 64, 2048]            256
BatchNorm-4    [[64, 64, 2048]]     [64, 64, 2048]            256
ReLU-6         [[64, 64, 2048]]     [64, 64, 2048]              0
Conv1D-5       [[64, 64, 2048]]     [64, 64, 2048]          4,160
BatchNorm-5    [[64, 64, 2048]]     [64, 64, 2048]            256
ReLU-7         [[64, 64, 2048]]     [64, 64, 2048]              0
Conv1D-6       [[64, 64, 2048]]     [64, 64, 2048]          4,160
BatchNorm-6    [[64, 64, 2048]]     [64, 64, 2048]            256
ReLU-8         [[64, 64, 2048]]     [64, 64, 2048]              0
Conv1D-7       [[64, 64, 2048]]     [64, 128, 2048]         8,320
BatchNorm-7    [[64, 128, 2048]]    [64, 128, 2048]           512
ReLU-9         [[64, 128, 2048]]    [64, 128, 2048]             0
Conv1D-8       [[64, 128, 2048]]    [64, 1024, 2048]      132,096
BatchNorm-8    [[64, 1024, 2048]]   [64, 1024, 2048]        4,096
ReLU-10        [[64, 1024, 2048]]   [64, 1024, 2048]            0
MaxPool1D-2    [[64, 1024, 2048]]   [64, 1024, 1]               0
Linear-4       [[64, 1024]]         [64, 512]             524,800
ReLU-11        [[64, 512]]          [64, 512]                   0
Linear-5       [[64, 512]]          [64, 256]             131,328
ReLU-12        [[64, 256]]          [64, 256]                   0
Linear-6       [[64, 256]]          [64, 4096]          1,052,672
Conv1D-9       [[64, 64, 2048]]     [64, 64, 2048]          4,160
BatchNorm-9    [[64, 64, 2048]]     [64, 64, 2048]            256
ReLU-13        [[64, 64, 2048]]     [64, 64, 2048]              0
Conv1D-10      [[64, 64, 2048]]     [64, 128, 2048]         8,320
BatchNorm-10   [[64, 128, 2048]]    [64, 128, 2048]           512
ReLU-14        [[64, 128, 2048]]    [64, 128, 2048]             0
Conv1D-11      [[64, 128, 2048]]    [64, 1024, 2048]      132,096
BatchNorm-11   [[64, 1024, 2048]]   [64, 1024, 2048]        4,096
ReLU-15        [[64, 1024, 2048]]   [64, 1024, 2048]            0
Linear-7       [[64, 1024]]         [64, 512]             524,800
ReLU-16        [[64, 512]]          [64, 512]                   0
Linear-8       [[64, 512]]          [64, 256]             131,328
ReLU-17        [[64, 256]]          [64, 256]                   0
Dropout-1      [[64, 256]]          [64, 256]                   0
Linear-9       [[64, 256]]          [64, 16]                4,112
LogSoftmax-1   [[64, 16]]           [64, 16]                    0
=================================================================
Total params: 3,476,825
Trainable params: 3,461,721
Non-trainable params: 15,104
-----------------------------------------------------------------
Input size (MB): 1.50
Forward/backward pass size (MB): 11333.40
Params size (MB): 13.26
Estimated Total Size (MB): 11348.16
-----------------------------------------------------------------

{'total_params': 3476825, 'trainable_params': 3461721}
```

4. Training
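Note: the model is trained with the paddle.optimizer.Adam optimizer, and the loss is computed with F.nll_loss. The network ends in LogSoftmax, so F.nll_loss on its output is equivalent to the cross-entropy loss.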
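The numpy sketch below shows that equivalence on made-up logits. F.nll_loss with its default mean reduction averages the negative log-probability assigned to the true class, which is exactly cross-entropy computed from softmax probabilities. The sketch also computes top-1 accuracy, which is what paddle.metric.accuracy reports with its default k=1. The helper names are illustrative and are not Paddle APIs.

```python
import numpy as np


def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # shift for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def nll_loss(log_probs, labels):
    # mean negative log-probability of the true class
    return -log_probs[np.arange(len(labels)), labels].mean()


def top1_accuracy(log_probs, labels):
    return (log_probs.argmax(axis=-1) == labels).mean()


# Made-up logits for a batch of 3 samples and 4 classes
logits = np.array([[2.0, 0.5, 0.1, -1.0],
                   [0.2, 0.1, 3.0, 0.0],
                   [1.0, 2.0, 0.5, 0.3]])
labels = np.array([0, 2, 0])

log_probs = log_softmax(logits)
loss = nll_loss(log_probs, labels)

# Cross-entropy written directly from softmax probabilities gives the same value
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
cross_entropy = -np.log(probs[np.arange(3), labels]).mean()

print(float(loss), float(cross_entropy), top1_accuracy(log_probs, labels))
```

Two of the three made-up samples have their largest logit on the true class, so the accuracy is 2/3.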
```python
def train():
    model = PointNet(num_classes=16, num_point=2048)
    model.train()
    # weight_decay=0.001 adds L2 regularization to all parameters
    optim = paddle.optimizer.Adam(
        parameters=model.parameters(), weight_decay=0.001
    )

    epoch_num = 10
    for epoch in range(epoch_num):
        # train
        print(
            "===================================train==========================================="
        )
        for batch_id, data in enumerate(train_loader):
            inputs, labels = data
            predicts = model(inputs)
            loss = F.nll_loss(predicts, labels)
            acc = paddle.metric.accuracy(predicts, labels)
            if batch_id % 20 == 0:
                print(
                    "train: epoch: {}, batch_id: {}, loss is: {}, accuracy is: {}".format(
                        epoch, batch_id, loss.numpy(), acc.numpy()
                    )
                )
            loss.backward()
            optim.step()
            optim.clear_grad()

        # Checkpoint every 2 epochs (0, 2, 4, 6, 8). Each save overwrites the
        # previous one, so the final files hold the weights after epoch 8.
        if epoch % 2 == 0:
            paddle.save(model.state_dict(), "./model/PointNet.pdparams")
            paddle.save(optim.state_dict(), "./model/PointNet.pdopt")

        # validation
        print(
            "===================================val==========================================="
        )
        model.eval()
        accuracies = []
        losses = []
        with paddle.no_grad():  # no gradients are needed for validation
            for batch_id, data in enumerate(val_loader):
                inputs, labels = data
                predicts = model(inputs)
                loss = F.nll_loss(predicts, labels)
                acc = paddle.metric.accuracy(predicts, labels)
                losses.append(loss.numpy())
                accuracies.append(acc.numpy())
        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("validation: loss is: {}, accuracy is: {}".format(avg_loss, avg_acc))
        model.train()


if __name__ == "__main__":
    train()
```

```
===================================train===========================================
train: epoch: 0, batch_id: 0, loss is: [8.135595], accuracy is: [0.046875]
train: epoch: 0, batch_id: 20, loss is: [0.96110815], accuracy is: [0.7265625]
train: epoch: 0, batch_id: 40, loss is: [0.77762437], accuracy is: [0.8046875]
train: epoch: 0, batch_id: 60, loss is: [0.575164], accuracy is: [0.84375]
train: epoch: 0, batch_id: 80, loss is: [0.60243726], accuracy is: [0.8359375]
===================================val===========================================
validation: loss is: 0.5027859807014465, accuracy is: 0.848895251750946
===================================train===========================================
train: epoch: 1, batch_id: 0, loss is: [0.5886416], accuracy is: [0.8359375]
train: epoch: 1, batch_id: 20, loss is: [0.59509534], accuracy is: [0.8515625]
train: epoch: 1, batch_id: 40, loss is: [0.43501458], accuracy is: [0.875]
train: epoch: 1, batch_id: 60, loss is: [0.5497817], accuracy is: [0.8515625]
train: epoch: 1, batch_id: 80, loss is: [0.2889481], accuracy is: [0.8984375]
===================================val===========================================
validation: loss is: 0.2470872551202774, accuracy is: 0.9263771176338196
===================================train===========================================
train: epoch: 2, batch_id: 0, loss is: [0.43095332], accuracy is: [0.8984375]
train: epoch: 2, batch_id: 20, loss is: [0.42620662], accuracy is: [0.8984375]
train: epoch: 2, batch_id: 40, loss is: [0.31073096], accuracy is: [0.8984375]
train: epoch: 2, batch_id: 60, loss is: [0.21410619], accuracy is: [0.9375]
train: epoch: 2, batch_id: 80, loss is: [0.23696409], accuracy is: [0.9296875]
===================================val===========================================
validation: loss is: 0.24663102626800537, accuracy is: 0.9278147220611572
===================================train===========================================
train: epoch: 3, batch_id: 0, loss is: [0.1000444], accuracy is: [0.96875]
train: epoch: 3, batch_id: 20, loss is: [0.2845613], accuracy is: [0.9296875]
train: epoch: 3, batch_id: 40, loss is: [0.46592], accuracy is: [0.859375]
train: epoch: 3, batch_id: 60, loss is: [0.3819336], accuracy is: [0.9140625]
train: epoch: 3, batch_id: 80, loss is: [0.08518291], accuracy is: [0.9765625]
===================================val===========================================
validation: loss is: 0.17066480219364166, accuracy is: 0.9491525292396545
===================================train===========================================
train: epoch: 4, batch_id: 0, loss is: [0.11713062], accuracy is: [0.9609375]
train: epoch: 4, batch_id: 20, loss is: [0.1716559], accuracy is: [0.953125]
train: epoch: 4, batch_id: 40, loss is: [0.15082854], accuracy is: [0.96875]
train: epoch: 4, batch_id: 60, loss is: [0.2787561], accuracy is: [0.96875]
train: epoch: 4, batch_id: 80, loss is: [0.11986132], accuracy is: [0.9609375]
===================================val===========================================
validation: loss is: 0.1389710158109665, accuracy is: 0.9608050584793091
===================================train===========================================
train: epoch: 5, batch_id: 0, loss is: [0.17427993], accuracy is: [0.9453125]
train: epoch: 5, batch_id: 20, loss is: [0.25355965], accuracy is: [0.9609375]
train: epoch: 5, batch_id: 40, loss is: [0.18881711], accuracy is: [0.9609375]
train: epoch: 5, batch_id: 60, loss is: [0.14433464], accuracy is: [0.953125]
train: epoch: 5, batch_id: 80, loss is: [0.13028377], accuracy is: [0.96875]
===================================val===========================================
validation: loss is: 0.09753856807947159, accuracy is: 0.9671609997749329
===================================train===========================================
train: epoch: 6, batch_id: 0, loss is: [0.12662013], accuracy is: [0.9765625]
train: epoch: 6, batch_id: 20, loss is: [0.1309431], accuracy is: [0.9609375]
train: epoch: 6, batch_id: 40, loss is: [0.29988244], accuracy is: [0.9453125]
train: epoch: 6, batch_id: 60, loss is: [0.114668], accuracy is: [0.9609375]
train: epoch: 6, batch_id: 80, loss is: [0.48784435], accuracy is: [0.9296875]
===================================val===========================================
validation: loss is: 0.16411711275577545, accuracy is: 0.9576271176338196
===================================train===========================================
train: epoch: 7, batch_id: 0, loss is: [0.12558301], accuracy is: [0.9609375]
train: epoch: 7, batch_id: 20, loss is: [0.1776012], accuracy is: [0.953125]
train: epoch: 7, batch_id: 40, loss is: [0.12831621], accuracy is: [0.9609375]
train: epoch: 7, batch_id: 60, loss is: [0.15245995], accuracy is: [0.953125]
train: epoch: 7, batch_id: 80, loss is: [0.08825297], accuracy is: [0.9609375]
===================================val===========================================
validation: loss is: 0.06742173433303833, accuracy is: 0.9809321761131287
===================================train===========================================
train: epoch: 8, batch_id: 0, loss is: [0.07868354], accuracy is: [0.96875]
train: epoch: 8, batch_id: 20, loss is: [0.1875119], accuracy is: [0.96875]
train: epoch: 8, batch_id: 40, loss is: [0.04444], accuracy is: [0.9921875]
train: epoch: 8, batch_id: 60, loss is: [0.08977574], accuracy is: [0.9765625]
train: epoch: 8, batch_id: 80, loss is: [0.13062863], accuracy is: [0.9765625]
===================================val===========================================
validation: loss is: 0.13399624824523926, accuracy is: 0.9661017060279846
===================================train===========================================
train: epoch: 9, batch_id: 0, loss is: [0.14676869], accuracy is: [0.953125]
train: epoch: 9, batch_id: 20, loss is: [0.16409941], accuracy is: [0.9609375]
train: epoch: 9, batch_id: 40, loss is: [0.08795467], accuracy is: [0.96875]
train: epoch: 9, batch_id: 60, loss is: [0.05970801], accuracy is: [0.984375]
train: epoch: 9, batch_id: 80, loss is: [0.2631768], accuracy is: [0.9296875]
===================================val===========================================
validation: loss is: 0.11335306614637375, accuracy is: 0.9682203531265259
```

5. Evaluation and Testing
Note: the trained weights are loaded with model.load_dict, and the model is then evaluated on the test set.
```python
def evaluation():
    model = PointNet()
    model_state_dict = paddle.load("./model/PointNet.pdparams")
    model.load_dict(model_state_dict)
    model.eval()
    accuracies = []
    losses = []
    with paddle.no_grad():  # inference only, no gradients needed
        for batch_id, data in enumerate(test_loader):
            inputs, labels = data
            predicts = model(inputs)
            loss = F.nll_loss(predicts, labels)
            acc = paddle.metric.accuracy(predicts, labels)
            losses.append(loss.numpy())
            accuracies.append(acc.numpy())
    avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
    # These numbers come from the held-out test set, not the validation set
    print("test: loss is: {}, accuracy is: {}".format(avg_loss, avg_acc))


if __name__ == "__main__":
    evaluation()
```
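Both the validation loop and evaluation() average per-batch accuracies with np.mean. This gives every batch equal weight, even though with drop_last=False the last batch is usually smaller. When that last batch is small, the result can differ from the accuracy over all samples. The sketch below shows the sample-weighted alternative on hypothetical batch results; the numbers are made up for illustration.

```python
import numpy as np

# Hypothetical per-batch results: accuracy and number of samples in each batch
batch_accuracies = np.array([0.96, 0.95, 0.50])
batch_sizes = np.array([128, 128, 4])  # the last batch is small (drop_last=False)

# What the loops above compute: every batch counts equally
unweighted = batch_accuracies.mean()

# Accuracy over all samples: weight each batch by its size
weighted = np.average(batch_accuracies, weights=batch_sizes)

print(f"unweighted={unweighted:.4f}, weighted={weighted:.4f}")
```

Inside the Paddle loops, the size of each batch is available as inputs.shape[0]. Recording it next to acc.numpy() is enough to compute the weighted mean, and the same weighting applies to the loss.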