RELATEED CONSULTING
相关咨询
选择下列产品马上在线沟通
服务时间:8:30-17:00
你可能遇到了下面的问题
关闭右侧工具栏

新闻中心

这里有您想知道的互联网营销解决方案
使用PointNet++测试分类自己的数据集并可视化-创新互联

我这里PointNet++的代码用的是pytorch版本的,链接为 https://github.com/yanx27/Pointnet2_pytorch

公司专注于为企业提供成都网站制作、网站设计、微信公众号开发、电子商务商城网站建设微信平台小程序开发,软件按需策划设计等一站式互联网企业服务。凭借多年丰富的经验,我们会仔细了解各客户的需求而做出多方面的分析、设计、整合,为客户设计出具风格及创意性的商业解决方案,成都创新互联更提供一系列网站制作和网站推广的服务。

将自己的数据集格式修改为和modelnet40_normal_resampled数据集格式一样。

 由于源码中测试脚本只是输出了测试数据集的分类精确度,且测试数据集同样的是有标签的,没有模型验证脚本,由于个人实验需要,希望当模型训练完成后能用自己的无标签数据输入后输出类别去检测模型的分类效果,因此根据模型测试脚本,修改了一下代码,可以实现输入一个无标签的数据,从而输出分类结果以及可视化,从而更直观的验证模型训练的准确度。 

代码如下,其中可视化部分参考这位博主的文章 pointconv pytorch modelnet40 点云分类结果可视化_对象被抛出的博客-博客_modelnet40可视化

from data_utils.ModelNetDataLoader_my import ModelNetDataLoader
import argparse
import numpy as np
import os
import torch
import logging
from tqdm import tqdm
import sys
import importlib
import matplotlib.pyplot as plt

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))

def pc_normalize(pc):  #点云数据归一化
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.is_available())
def parse_args():
    '''PARAMETERS'''
    parser = argparse.ArgumentParser('Testing')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size in training')
    parser.add_argument('--num_category', default=10, type=int, choices=[10, 40],  help='training on ModelNet10/40')
    parser.add_argument('--num_point', type=int, default=10000, help='Point Number')
    parser.add_argument('--log_dir', type=str, default='pointnet2_cls_msg', help='Experiment root')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
    parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting')
    return parser.parse_args()
#加载数据集
dataset='/home/niu/mysubject/Pointnet_Pointnet2_pytorch-master/evalset/aaa_1.txt'
pcdataset = np.loadtxt(dataset, delimiter=' ').astype(np.float32)#数据读取
point_set = pcdataset[0:10000, :] #我的输入数据设置为原始数据中10000个点
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) #归一化数据
point_set = point_set[:, 0:3] 
point_set = point_set.transpose(1,0)#将数据由N*C转换为C*N
#print(point_set.shape)
point_set = point_set.reshape(1, 3, 10000)
n_points = point_set
point_set = torch.as_tensor(point_set)#需要将数据格式变为张量,不然会报错
point_set = point_set.cuda()
#print(point_set.shape)
#print(point_set.shape)
#分类测试函数
def test(model,point_set, num_class=10, vote_num=1):
    #mean_correct = []
    classifier = model.eval()
    class_acc = np.zeros((num_class, 3))
    vote_pool = torch.zeros(1, 10).cuda()
    for _ in range(vote_num):
        pred, _ = classifier(point_set)
        print(pred)
        vote_pool += pred
    pred = vote_pool / vote_num
    # 对预测结果每行取大值得到分类
    pred_choice = pred.data.max(1)[1]
    print(pred_choice)
    #可视化
    file_dir = '/home/niu/mysubject/Pointnet_Pointnet2_pytorch-master/visualizer'
    save_name_prefix = 'pred'
    draw(n_points[:, 0, :], n_points[:, 1, :], n_points[:, 2, :], save_name_prefix, file_dir, color=pred_choice)
    return pred_choice
#定义可视化函数
def draw(x, y, z, name, file_dir, color=None):
    """
    绘制单个样本的三维点图
    """
    if color is None:
        for i in range(len(x)):
            ax = plt.subplot(projection='3d')  # 创建一个三维的绘图工程
            save_name = name + '-{}.png'.format(i)
            save_name = os.path.join(file_dir,save_name)
            ax.scatter(x[i], y[i], z[i],s=0.1, c='r')
            ax.set_zlabel('Z')  # 坐标轴
            ax.set_ylabel('Y')
            ax.set_xlabel('X')
            plt.draw()
            plt.savefig(save_name)
            # plt.show()
    else:
        colors = ['red', 'blue', 'green', 'yellow', 'orange', 'tan', 'orangered', 'lightgreen', 'coral', 'aqua']
        for i in range(len(x)):
            ax = plt.subplot(projection='3d')  # 创建一个三维的绘图工程
            save_name = name + '-{}-{}.png'.format(i, color[i])
            save_name = os.path.join(file_dir,save_name)
            ax.scatter(x[i], y[i], z[i],s=0.1, c=colors[color[i]])
            ax.set_zlabel('Z')  # 坐标轴
            ax.set_ylabel('Y')
            ax.set_xlabel('X')
            plt.draw()
            plt.savefig(save_name)
            # plt.show()

def main(args):
    def log_string(str):
        logger.info(str)
        print(str)
    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    experiment_dir = '/home/niu/mysubject/Pointnet_Pointnet2_pytorch-master/log/classification/' + args.log_dir
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    '''
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''
    num_class = args.num_category
    #选择模型
    model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0]
    model = importlib.import_module(model_name)
    
    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    if not args.use_cpu:
        classifier = classifier.cuda()
    #选择训练好的.pth文件
    checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
    classifier.load_state_dict(checkpoint['model_state_dict'])
    #预测分类
    with torch.no_grad():
         pred_choice = test(classifier.eval(), point_set, vote_num=args.num_votes, num_class=num_class)
         #log_string('pred_choice: %f' % (pred_choice))

if __name__ == '__main__':
    args = parse_args()
    main(args)

根据自己的数据格式修改自己对应的参数以及数据集路径运行即可

分类输出结果:

输出为分类的数据类别3

可视化结果保存在visualizer文件下,可视化结果:

你是否还在寻找稳定的海外服务器提供商?创新互联www.cdcxhl.cn海外机房具备T级流量清洗系统配攻击溯源,准确流量调度确保服务器高可用性,企业级服务器适合批量采购,新人活动首月15元起,快前往官网查看详情吧


网页名称:使用PointNet++测试分类自己的数据集并可视化-创新互联
网页地址:http://scpingwu.com/article/ijhej.html