使用 PMML 格式存储模型，实现跨语言调用

PMML 提供了一种轻量的跨语言模型调用方式。比如用 Python 训练了一个模型，想在 Java 里调用，很多时候需要重写底层的预测逻辑，十分不便；PMML 正是为解决这一问题而产生的：无需重写底层逻辑，在 Python 中用 LightGBM、XGBoost、TensorFlow 训练的模型都可以通过这种方式在 Java 或其他语言中调用。在性能方面它可能并不是最高效的方式，但可以作为一种参考。

主要参考:
https://blog.csdn.net/hopeztm/article/details/78321700
 https://henning.kropponline.de/2015/09/06/jpmml-example-random-forest/

https://github.com/jpmml/jpmml-evaluator


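首先安装 sklearn2pmml 包（下面的示例代码还用到了 sklearn_pandas，需另外执行 `pip install sklearn-pandas`；此外 sklearn2pmml 在导出 PMML 时会调用 Java，本机需装有 Java 运行环境）: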

pip install git+https://github.com/jpmml/sklearn2pmml.git


先看下在 Python 里训练一个决策树模型（DecisionTreeClassifier）并保存为 PMML 文件，数据用的是鸢尾花（iris）数据集，网上到处都有，直接看代码:

from sklearn_pandas import DataFrameMapper
import pandas as pd
from sklearn import tree
from sklearn2pmml import PMMLPipeline
from sklearn2pmml import sklearn2pmml

# Fit a decision tree on the iris CSV and export the fitted pipeline as a PMML
# document that a PMML evaluator in another language can load.

FEATURES = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
LABEL = "species"

iris_df = pd.read_csv("xml/iris.csv")
print(iris_df.columns)

# One pass-through mapping (no transformer) per feature column.
mapper = DataFrameMapper([([feature], None) for feature in FEATURES])
classifier = tree.DecisionTreeClassifier()
pipeline = PMMLPipeline([('mapper', mapper), ("classifier", classifier)])

# Every column except the label; columns.difference() returns them sorted by
# name, and the mapper then picks the features out by name.
X = iris_df[iris_df.columns.difference([LABEL])]
y = iris_df[LABEL]
pipeline.fit(X, y)

# with_repr=True additionally stores the pipeline's Python repr in the PMML file.
sklearn2pmml(pipeline, "./xml/IrisClassificationTree.pmml", with_repr=True)
在 Java 里面调用:

看下pom.xml依赖:

<!-- JPMML runtime used by the Java code: evaluator, evaluator extension and the PMML object model, all pinned to 1.3.6 here. -->
<dependency>
	<groupId>org.jpmml</groupId>
	<artifactId>pmml-evaluator</artifactId>
	<version>1.3.6</version>
</dependency>
<dependency>
	<groupId>org.jpmml</groupId>
	<artifactId>pmml-evaluator-extension</artifactId>
	<version>1.3.6</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-model -->
<dependency>
	<groupId>org.jpmml</groupId>
	<artifactId>pmml-model</artifactId>
	<version>1.3.6</version>
</dependency>

java代码:

package com.meituan.test;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;

/**
 * Minimal demo: loads a PMML model file with JPMML-Evaluator and scores a single
 * hard-coded iris sample, printing every target field of the result.
 */
public class PMMLPrediction {

	/** Model file used when no path is supplied on the command line. */
	private static final String DEFAULT_PMML_PATH =
			"/Users/shuubiasahi/Documents/python/credit-tfgan/xml/IrisClassificationTree.pmml";

	/**
	 * Scores one iris sample against a PMML model.
	 *
	 * @param args optional; {@code args[0]}, when present, is the path of the
	 *             PMML file to load instead of {@link #DEFAULT_PMML_PATH}
	 * @throws Exception if the model cannot be loaded or evaluated
	 */
	public static void main(String[] args) throws Exception {
		String pathxml = (args != null && args.length > 0) ? args[0] : DEFAULT_PMML_PATH;
		Map<String, Double> map = new HashMap<String, Double>();
		map.put("sepal_length", 5.1);
		map.put("sepal_width", 3.5);
		map.put("petal_length", 1.4);
		map.put("petal_width", 0.2);
		predictLrHeart(map, pathxml);
	}

	/**
	 * Loads the PMML model at {@code pathxml}, evaluates it on {@code irismap} and
	 * prints each target field as {@code "target: <name> value: <value>"}.
	 *
	 * <p>NOTE(review): the name looks copied from a logistic-regression example;
	 * the body is generic and works for any model JPMML can evaluate. The name is
	 * kept so existing callers keep compiling.
	 *
	 * @param irismap raw input values keyed by PMML input field name; a field that
	 *                is absent from the map is passed to {@link InputField#prepare}
	 *                as {@code null}
	 * @param pathxml path of the {@code .pmml} file
	 * @throws Exception if the file cannot be read or parsed, or evaluation fails
	 */
	public static void predictLrHeart(Map<String, ?> irismap, String pathxml) throws Exception {
		// Load the model; the stream is only needed for unmarshalling, so close it right away.
		PMML pmml;
		try (InputStream is = new FileInputStream(new File(pathxml))) {
			pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
		}

		ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
		Evaluator evaluator = (Evaluator) modelEvaluatorFactory.newModelEvaluator(pmml);

		// Convert every raw input value into the representation the model expects.
		List<InputField> inputFields = evaluator.getInputFields();
		Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
		for (InputField inputField : inputFields) {
			FieldName inputFieldName = inputField.getName();
			Object rawValue = irismap.get(inputFieldName.getValue());
			arguments.put(inputFieldName, inputField.prepare(rawValue));
		}

		Map<FieldName, ?> results = evaluator.evaluate(arguments);

		// A model may declare several target fields; print each one. For a classifier
		// the value is a result object (e.g. a score distribution), not a bare label.
		List<TargetField> targetFields = evaluator.getTargetFields();
		for (TargetField targetField : targetFields) {
			FieldName targetFieldName = targetField.getName();
			Object targetFieldValue = results.get(targetFieldName);
			System.out.println("target: " + targetFieldName.getValue()
					+ " value: " + targetFieldValue);
		}
	}
}

结果:
target: species value: NodeScoreDistribution{result=0, probability_entries=[0=1.0, 1=0.0, 2=0.0], entityId=2, confidence_entries=[]}

猜你喜欢

转载自blog.csdn.net/luoyexuge/article/details/80079569