MindSpore 模型训练及导出为 ONNX

MindSpore 是华为推出的深度学习框架，官网：https://www.mindspore.cn/

本文代码主要参考官方“快速入门”教程，并在其基础上增加了将模型导出为 ONNX 的步骤。

MindSpore 搭建模型的语法与 PyTorch 基本相似。

不同之处在于它把“网络”和“模型”分开：网络（nn.Cell）只是单纯的网络结构，可以直接用来导出模型文件；模型（Model）封装了网络、损失函数和优化器，主要用来训练和预测，本文中的预测都通过模型（Model.predict）完成。

训练代码如下:

# -*- coding: utf-8 -*-
import os
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model,export,load_checkpoint
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype
import numpy as np
import mindspore.dataset as ds
from mindspore.train.callback import Callback


# MNIST archive (download and unzip manually before running):
# https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/MNIST_Data.zip
# All paths are relative to the current working directory.
mnist_path = "./datasets/MNIST_Data"
train_data_path = mnist_path + "/train"
test_data_path = mnist_path + "/test"
# Directory the ModelCheckpoint callback writes .ckpt files into.
model_path = "./models/ckpt/"

# Dataset definition
def create_dataset(data_path, batch_size=128, repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST input pipeline used for both training and testing.

    Labels are cast to int32. Images are resized to 32x32 with
    ``Inter.LINEAR`` interpolation, scaled by 1/255, then shifted/scaled so
    that the result equals ``(x - 0.1307) / 0.3081``, and finally converted
    from HWC to CHW layout. The dataset is then shuffled (buffer of 10000),
    batched with the last incomplete batch dropped, and repeated.

    Args:
        data_path (str): Directory passed to ``ds.MnistDataset``.
        batch_size (int): Number of samples per batch.
        repeat_size (int): Number of times the dataset is repeated.
        num_parallel_workers (int): Worker count for every ``map`` stage.

    Returns:
        The processed MindSpore dataset.
    """
    dataset = ds.MnistDataset(data_path)

    norm_mean, norm_std = 0.1307, 0.3081
    label_cast = C.TypeCast(mstype.int32)
    # Image transforms, applied to the "image" column in exactly this order.
    image_pipeline = (
        CV.Resize((32, 32), interpolation=Inter.LINEAR),
        CV.Rescale(1.0 / 255.0, 0.0),                          # [0, 255] -> [0, 1]
        CV.Rescale(1 / norm_std, -1 * norm_mean / norm_std),   # (x - mean) / std
        CV.HWC2CHW(),
    )

    dataset = dataset.map(operations=label_cast, input_columns="label",
                          num_parallel_workers=num_parallel_workers)
    for image_op in image_pipeline:
        dataset = dataset.map(operations=image_op, input_columns="image",
                              num_parallel_workers=num_parallel_workers)

    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.repeat(repeat_size)

# custom callback function
class StepLossAccInfo(Callback):
    """Record the loss of every training step and periodically evaluate accuracy.

    Args:
        model: The ``Model`` being trained. Its ``eval`` is called with
            ``dataset_sink_mode=False``, and the result is read under the key
            ``"Accuracy"``. The model must therefore be built with a metric of
            that name.
        eval_dataset: Dataset passed to ``model.eval``.
        step_loss (dict): Must hold list values under ``"step"`` and
            ``"loss_value"``. Both are appended to in place, as strings.
        steps_eval (dict): Must hold list values under ``"step"`` and
            ``"acc"``. Both are appended to in place.
        eval_interval (int): Run an evaluation every this many global steps.
            Defaults to 125, the previously hard-coded value.
    """

    def __init__(self, model, eval_dataset, step_loss, steps_eval, eval_interval=125):
        super(StepLossAccInfo, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.step_loss = step_loss
        self.steps_eval = steps_eval
        self.eval_interval = eval_interval

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        # cur_step_num is already the global step count across epochs.
        # MindSpore's LossMonitor/ModelCheckpoint derive the in-epoch step
        # from it via "% batch_num". The former "(cur_epoch-1)*1875 +"
        # offset therefore double-counted from the second epoch on, and it
        # hard-coded the steps-per-epoch of one specific batch size.
        cur_step = cb_params.cur_step_num
        self.step_loss["loss_value"].append(str(cb_params.net_outputs))
        self.step_loss["step"].append(str(cur_step))
        if cur_step % self.eval_interval == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])

# Network definition
class mnist(nn.Cell):
    """LeNet-style CNN producing ``num_class`` logits.

    Input: ``conv1`` takes 1 channel, and ``fc1`` expects 12 * 5 * 5 features.
    Together these imply (N, 1, 32, 32) inputs:
    32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5.
    """

    def __init__(self, num_class=10):
        super(mnist, self).__init__()
        # Layers are created in the same order as before, so parameter
        # registration and initialization order are unchanged.
        self.conv1 = nn.Conv2d(1, 8, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(8, 12, 5, pad_mode='valid')
        self.fc1 = nn.Dense(12 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 60, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(60, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Stage 1: conv -> ReLU -> 2x2 max-pool
        out = self.conv1(x)
        out = self.max_pool2d(self.relu(out))
        # Stage 2: conv -> ReLU -> 2x2 max-pool
        out = self.conv2(out)
        out = self.max_pool2d(self.relu(out))
        # Classifier head: flatten -> two ReLU dense layers -> logits
        out = self.flatten(out)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        return self.fc3(out)
                
network = mnist()
# Momentum optimizer: learning rate 0.01, momentum 0.9.
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# Define the model: wraps network, loss and optimizer for training,
# evaluation and prediction. The "Accuracy" key is what StepLossAccInfo reads.
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# Save checkpoints every 375 steps (keeping at most 16) under model_path,
# so the parameters can be reloaded for export or fine-tuning.
config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
ckpoint_cb = ModelCheckpoint(prefix="mnist", directory=model_path, config=config_ck)
# Evaluation data uses the default batch size of create_dataset.
eval_dataset = create_dataset(test_data_path)

step_loss = {"step": [], "loss_value": []}
steps_eval = {"step": [], "acc": []}
# Collect the per-step loss and the periodic accuracy information.
step_loss_acc_info = StepLossAccInfo(model, eval_dataset, step_loss, steps_eval)

repeat_size = 1
# Batch size 32; train for a single epoch with per-step callbacks
# (sink mode off so step_end fires every step).
ds_train = create_dataset(train_data_path, 32, repeat_size)
model.train(1, ds_train, callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info], dataset_sink_mode=False)

# Test: run prediction on one batch of the test set.
ds_test = create_dataset(test_data_path).create_dict_iterator()
data = next(ds_test)
images = data["image"].asnumpy()
labels = data["label"].asnumpy()
print(model.predict(Tensor(data['image'])))
print(images.shape)

# Export to ONNX.
# Build the checkpoint path from model_path, the directory ModelCheckpoint
# wrote to. The former "models\\ckpt\\..." literal only resolved on Windows.
# NOTE(review): "mnist-1_1875" assumes epoch 1 ended at step 1875 and that no
# earlier checkpoints existed in model_path (MindSpore may otherwise change
# the prefix). Confirm the actual file name.
ckpt_file = os.path.join(model_path, "mnist-1_1875.ckpt")
load_checkpoint(ckpt_file, net=network)
print("load****")
# Dummy input fixing the exported graph's input shape; named so it does not
# shadow the builtin input().
dummy_input = np.ones([1, 1, 32, 32]).astype(np.float32)
# Writes minist.onnx (file name kept: the ONNX check script loads it).
export(network, Tensor(dummy_input), file_name='minist', file_format='ONNX')

注意：数据集需要提前自行下载：https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/MNIST_Data.zip ，并解压到 ./datasets/MNIST_Data 目录（即代码中的 mnist_path），其中应包含 train 和 test 两个子目录。

运行后即可在当前目录下得到导出的 minist.onnx 文件。

运行结果:

可以打开导出的 ONNX 文件查看网络结构：

 

下面用同一个全 1 输入（形状为 [1, 1, 32, 32]）分别运行 ONNX 模型和 MindSpore 模型，验证两者的输出是否一致：

onnx:

import onnxruntime
import numpy as np

# Same all-ones input that was used when exporting, so outputs are comparable.
dummy_input = np.ones([1, 1, 32, 32]).astype(np.float32)
# Pin the CPU provider explicitly. onnxruntime >= 1.9 raises if `providers` is
# omitted on builds that ship several execution providers (e.g. onnxruntime-gpu).
session = onnxruntime.InferenceSession("minist.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
outs = session.run(None, {input_name: dummy_input})
print('onnx result is:', outs)

运行结果:

MindSpore：

import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import load_checkpoint, load_param_into_net
from mindspore import Tensor, Model
import numpy as np

class mnist(nn.Cell):
    """LeNet-style CNN producing ``num_class`` logits.

    Input: ``conv1`` takes 1 channel, and ``fc1`` expects 12 * 5 * 5 features.
    Together these imply (N, 1, 32, 32) inputs:
    32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5.
    """

    def __init__(self, num_class=10):
        super(mnist, self).__init__()
        # Layers are created in the same order as before, so parameter
        # registration and initialization order are unchanged.
        self.conv1 = nn.Conv2d(1, 8, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(8, 12, 5, pad_mode='valid')
        self.fc1 = nn.Dense(12 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 60, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(60, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Stage 1: conv -> ReLU -> 2x2 max-pool
        out = self.conv1(x)
        out = self.max_pool2d(self.relu(out))
        # Stage 2: conv -> ReLU -> 2x2 max-pool
        out = self.conv2(out)
        out = self.max_pool2d(self.relu(out))
        # Classifier head: flatten -> two ReLU dense layers -> logits
        out = self.flatten(out)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        return self.fc3(out)
    
network = mnist()
# Forward slashes work on both Windows and POSIX. The former
# "models\\ckpt\\..." literal only resolved on Windows.
ckpt_path = "models/ckpt/mnist-1_1875.ckpt"
trained_ckpt = load_checkpoint(ckpt_path)
load_param_into_net(network, trained_ckpt)
# Same all-ones input as the ONNX check, so the two outputs can be compared.
# Named so it does not shadow the builtin input().
dummy_input = np.ones([1, 1, 32, 32]).astype(np.float32)
# Only prediction is needed, so no loss, optimizer, metrics or eval network
# are attached. Model.predict runs the wrapped network directly.
model = Model(network)
print(model.predict(Tensor(dummy_input)))

运行结果:

两者的运行结果一致，说明导出的 ONNX 文件正确。

猜你喜欢

转载自blog.csdn.net/zhou_438/article/details/114024403