模型持久化方法(pickle和PMML)

pickle方式

pickle方式应用场景:在python环境中训练模型,然后使用pickle将模型持久化为一个模型文件,然后就可以在python环境加载持久化后的模型文件对新数据进行预测。

1、pickle模块

pickle是Python标准库自带的模块,无需通过pip安装,直接在代码中import pickle即可。

2、代码示例

pickle_demo.py模块完成了模型的训练、持久化和加载,代码如下:

"""
pickle方式模型持久化
"""

import pickle
from sklearn import linear_model as lm
from src.utils import read_data
import os


def train_and_save_model(data, model_path):
    """Fit a linear regression on ``data`` and pickle it to ``model_path``.

    Args:
        data: Table-like object that is indexed as ``data[["x"]]`` for the
            features and ``data["y"]`` for the target.
        model_path: Destination file for the pickled model. It is opened in
            ``"wb"`` mode, so an existing file is overwritten.

    Returns:
        The fitted ``LinearRegression`` estimator.
    """
    model = lm.LinearRegression()
    model.fit(data[["x"]], data["y"])
    # A context manager flushes and closes the file deterministically instead
    # of leaving an anonymous file object for the garbage collector.
    with open(model_path, "wb") as model_file:
        pickle.dump(model, model_file)

    return model


def load_model(model_path):
    """Load a pickled model from ``model_path`` and print its coefficients.

    Args:
        model_path: Path of the pickle file to read.

    Returns:
        The unpickled model object.

    Warning:
        ``pickle.load`` can execute arbitrary code while deserializing; only
        load files that come from a trusted source.
    """
    # Close the file as soon as the model is deserialized rather than leaking
    # the handle until garbage collection.
    with open(model_path, "rb") as model_file:
        model = pickle.load(model_file)
    print(model.coef_)
    return model


if __name__ == "__main__":
    # Store the model in the "data" directory next to this script.
    # os.path.join uses the platform's separator instead of a hard-coded "/".
    model_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "data", "liner_model.pickle"
    )
    # Train on the sample CSV, persist the model, then load it back.
    train_and_save_model(read_data("simple_example.csv"), model_path)
    load_model(model_path)

使用到的自定义模块utils.py代码如下:

import os
import pandas as pd


def read_data(file_name):
    """Read a CSV file from the ``data`` directory next to this module.

    Args:
        file_name: Name of the CSV file inside the ``data`` directory.

    Returns:
        The result of ``pandas.read_csv`` for that file.
    """
    home_path = os.path.dirname(os.path.abspath(__file__))
    # os.path.join already uses the platform's path separator, so no explicit
    # Windows/Linux branch is needed.
    data_path = os.path.join(home_path, "data", file_name)

    return pd.read_csv(data_path)

数据文件simple_example.csv(位于src/data目录下)内容如下:

x,y
10,7.7
10,9.87
11,11.18
12,10.43
13,12.36
14,14.15
15,15.73
16,16.4
17,18.86
18,16.13
19,18.21
20,18.37
21,22.61
22,19.83
23,22.67
24,22.7
25,25.16
26,25.55
27,28.21
28,28.12

程序成功运行后,会生成src/data/liner_model.pickle文件,内容如图:

PMML方式

PMML(Predictive Model Markup Language,预测模型标记语言)是一种基于XML的标准模型描述格式。应用场景:主要用于跨环境使用模型,比如在Python环境中训练模型,然后在Java环境中调用模型。

1、安装sklearn2pmml

执行安装命令pip install sklearn2pmml,如图:

2、代码示例

pmml_demo.py模块完成模型训练和持久化,这里选用的是最简单的线性回归,代码如下:

import sklearn2pmml as pmml
from sklearn2pmml import PMMLPipeline
from src.utils import read_data
from sklearn import linear_model as lm
import os


def save_model(data, model_path):
    """Train a one-step linear-regression pipeline and export it as PMML.

    Args:
        data: Table-like object indexed as ``data[["x"]]`` for the features
            and ``data["y"]`` for the target.
        model_path: Output path handed to ``sklearn2pmml`` for the PMML file.
    """
    steps = [("regression", lm.LinearRegression())]
    pipeline = PMMLPipeline(steps)

    features = data[["x"]]
    target = data["y"]
    pipeline.fit(features, target)

    pmml.sklearn2pmml(pipeline, model_path, with_repr=True)


if __name__ == "__main__":
    data = read_data("simple_example.csv")
    # Write the PMML file into the "data" directory next to this script.
    # os.path.join uses the platform's separator instead of a hard-coded "/".
    model_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "data", "liner_model.pmml"
    )
    save_model(data, model_path)

 生成一个src/data/liner_model.pmml文件,其实就是xml文件,内容如下:

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
	<Header>
		<Application name="JPMML-SkLearn" version="1.5.32"/>
		<Timestamp>2020-03-11T02:31:05Z</Timestamp>
	</Header>
	<MiningBuildTask>
		<Extension>PMMLPipeline(steps=[('regression', LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False))])</Extension>
	</MiningBuildTask>
	<DataDictionary>
		<DataField name="y" optype="continuous" dataType="double"/>
		<DataField name="x" optype="continuous" dataType="double"/>
	</DataDictionary>
	<RegressionModel functionName="regression">
		<MiningSchema>
			<MiningField name="y" usageType="target"/>
			<MiningField name="x"/>
		</MiningSchema>
		<RegressionTable intercept="-0.9495378313625551">
			<NumericPredictor name="x" coefficient="1.0329669989952859"/>
		</RegressionTable>
	</RegressionModel>
</PMML>

下面通过Java代码来调用这个持久化的模型,其本质就是解析这个xml文件,java工程的pom.xml中需要引入解析模型的依赖包,如下:

<dependency>
	<groupId>org.jpmml</groupId>
	<artifactId>pmml-evaluator</artifactId>
	<version>1.4.1</version>
</dependency>

读取模型文件并做预测的java代码如下: 


import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.xml.bind.JAXBException;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;
import org.xml.sax.SAXException;

/**
 * Loads a PMML model file and uses it to predict a value for one input.
 *
 * <p>The demo model has a single input field {@code x} and a single target
 * field {@code y}.
 *
 * @author leboop
 * @date 2020-03-11
 */
public class PMMLDemo {

	/**
	 * Parses the PMML file at {@code modelPath} and builds an evaluator for it.
	 *
	 * @param modelPath path of the PMML (XML) model file
	 * @return an evaluator for the model described by the file
	 * @throws IllegalStateException if the file cannot be read or parsed
	 */
	private static Evaluator loadPmml(String modelPath) {
		PMML pmml;
		// try-with-resources closes the stream on every path and never touches
		// a null stream when opening the file fails.
		try (InputStream inputStream = new FileInputStream(modelPath)) {
			pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
		} catch (IOException | SAXException | JAXBException e) {
			// Fail fast: building an evaluator from an empty PMML object would
			// only hide the real cause of the failure.
			throw new IllegalStateException("Failed to load PMML model from " + modelPath, e);
		}
		ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();

		return modelEvaluatorFactory.newModelEvaluator(pmml);
	}

	/**
	 * Evaluates the model for one value of the input feature {@code x}.
	 *
	 * @param evaluator evaluator returned by {@link #loadPmml(String)}
	 * @param feature value of the input field {@code x}
	 * @return the value the evaluator reports for its first target field
	 */
	private static Object predict(Evaluator evaluator, int feature) {
		// Raw input data, keyed by feature name.
		Map<String, Integer> data = new HashMap<String, Integer>();
		data.put("x", feature);
		// Input fields declared by the model; for the demo model y = a*x + b
		// the only one is x.
		List<InputField> inputFields = evaluator.getInputFields();
		// Prints e.g. [InputField{name=x, dataType=DOUBLE, opType=CONTINUOUS}]
		System.out.println(inputFields);
		// Prepared arguments handed to the evaluator, in field order.
		Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
		for (InputField inputField : inputFields) {
			FieldName inputFieldName = inputField.getName();
			// Prints e.g. x
			System.out.println(inputFieldName);
			// Look up the raw value by field name; here it is an Integer.
			Object rawValue = data.get(inputFieldName.getValue());
			// prepare() converts the raw value into the model's field type.
			FieldValue inputFieldValue = inputField.prepare(rawValue);
			// Prints e.g. ContinuousDouble{opType=CONTINUOUS, dataType=DOUBLE, value=20.0}
			System.out.println(inputFieldValue);
			arguments.put(inputFieldName, inputFieldValue);
		}
		// Run the model; prints the result map, e.g. {y=19.70980214854316}
		Map<FieldName, ?> results = evaluator.evaluate(arguments);
		System.out.println(results);
		// The demo model has one target field (y); read its predicted value.
		List<TargetField> targetFields = evaluator.getTargetFields();
		FieldName targetFieldName = targetFields.get(0).getName();
		Object targetFieldValue = results.get(targetFieldName);
		// Prints e.g. target: y value: 19.70980214854316
		System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);

		return targetFieldValue;
	}

	/**
	 * Loads the model and prints the prediction for {@code x = 20}.
	 *
	 * @param args optional; {@code args[0]} overrides the default model path
	 */
	public static void main(String args[]) {
		// The default is the author's local path; pass a path as the first
		// argument to run the demo elsewhere.
		String modelPath = args.length > 0 ? args[0]
				: "G:\\pycharm_workspace\\machine_learning_study\\src\\data\\liner_model.pmml";
		Evaluator model = PMMLDemo.loadPmml(modelPath);
		Object result = PMMLDemo.predict(model, 20);
		// For the demo model the prediction is 19.70980214854316
		System.out.println(result);
	}
}

程序输出结果:

[InputField{name=x, dataType=DOUBLE, opType=CONTINUOUS}]
x
ContinuousDouble{opType=CONTINUOUS, dataType=DOUBLE, value=20.0}
{y=19.70980214854316}
target: y value: 19.70980214854316
19.70980214854316
发布了89 篇原创文章 · 获赞 79 · 访问量 10万+

猜你喜欢

转载自blog.csdn.net/L_15156024189/article/details/104788886