将XGBoost模型存储为PMML并在Java中调用

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Katherine_hsr/article/details/84951324

首先，我们构建一个XGBoost模型并将其存储为PMML格式。这里用到的包是sklearn2pmml，它可以将sklearn中的模型（包括通过sklearn接口使用的XGBoost模型）保存为PMML格式。注意：sklearn2pmml在导出时会调用Java程序，因此本机需要安装Java。

import pandas as pd
from xgboost.sklearn import XGBClassifier
from sklearn2pmml import PMMLPipeline
from sklearn_pandas import DataFrameMapper
from sklearn2pmml import sklearn2pmml


# Train an XGBoost classifier inside a PMMLPipeline and export it as a PMML file
# that the Java code below can load with JPMML-Evaluator.

# Input spreadsheet and output PMML location (edit these to run elsewhere).
DATA_PATH = '/Users/huoshirui/Desktop/xyworking/pythonData/dataClean/kexin_data_huoshirui.xlsx'
PMML_PATH = '/Users/huoshirui/Desktop/test/PMML/xgboost.pmml'

# Model input features, in the order the mapper feeds them to the classifier.
# The Java scoring code must supply values under these same names.
FEATURES = [
    'kx_output_riskscore',
    'kx_new_risk_0',
    'kx_new_risk_1',
    'kx_new_risk_2',
    'kx_new_risk_3',
    'kx_new_risk_4',
    'kx_new_risk_5',
    'kx_new_risk_6',
    'kx_new_risk_7',
    'kx_new_risk_8',
    'kx_new_risk_11',
    'kx_new_risk_12',
    'kx_new_risk_13',
    'kx_new_risk_14',
    'kx_new_risk_15',
    'kx_new_risk_sumList',
    'kx_new_is_riskList',
]

# Load the training data. 'mbl_no' is dropped before training
# (presumably an identifier column, not a feature -- confirm).
df = pd.read_excel(DATA_PATH)
df = df.drop(columns=['mbl_no'])

clf = XGBClassifier(
    learning_rate=0.01,
    n_estimators=1000,
    max_depth=4,
    min_child_weight=1,
    gamma=0.0001,
    subsample=0.3,
    colsample_bytree=0.8,
    colsample_bylevel=0.7,
    objective='binary:logistic',
    # `n_jobs` / `random_state` replace the deprecated `nthread` / `seed`
    # aliases; the wrapper forwards them as the same booster parameters.
    n_jobs=-1,
    scale_pos_weight=1,
    random_state=666,
)

# Pass every feature through unchanged (transformer None), one column each.
mapper = DataFrameMapper([([name], None) for name in FEATURES])

pipeline = PMMLPipeline([('mapper', mapper), ('classifier', clf)])

# Fit on every column except the label; the mapper selects only FEATURES.
pipeline.fit(df[df.columns.difference(['target'])], df['target'])

# Export to PMML; with_repr=True embeds the pipeline repr as an Extension
# element inside MiningBuildTask.
sklearn2pmml(pipeline, PMML_PATH, with_repr=True)

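运行后，我们会在调用sklearn2pmml时指定的路径（/Users/huoshirui/Desktop/test/PMML/xgboost.pmml）得到xgboost.pmml文件。它是一个XML文件，可以用文本编辑器直接打开查看。部分内容如下（注意DataDictionary中没有出现kx_new_risk_13～kx_new_risk_15，很可能是因为这些特征没有被任何一棵树使用，导出时被省略了）：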

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
	<Header>
		<Application name="JPMML-SkLearn" version="1.5.8"/>
		<Timestamp>2018-12-05T02:44:50Z</Timestamp>
	</Header>
	<MiningBuildTask>
		<Extension>PMMLPipeline(steps=[('mapper', DataFrameMapper(default=False, df_out=False,
        features=[(['kx_output_riskscore'], None), (['kx_new_risk_0'], None), (['kx_new_risk_1'], None), (['kx_new_risk_2'], None), (['kx_new_risk_3'], None), (['kx_new_risk_4'], None), (['kx_new_risk_5'], None), (['kx_new_risk_6'], None), (['kx_new_risk_7'], None), (['kx_new_risk_8'], None), (['kx_new_risk_11'], None), (['kx_new_risk_12'], None), (['kx_new_risk_13'], None), (['kx_new_risk_14'], None), (['kx_new_risk_15'], None), (['kx_new_risk_sumList'], None), (['kx_new_is_riskList'], None)],
        input_df=False, sparse=False)),
       ('classifier', XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=0.7,
       colsample_bytree=0.8, gamma=0.0001, learning_rate=0.01,
       max_delta_step=0, max_depth=4, min_child_weight=1, missing=None,
       n_estimators=1000, n_jobs=1, nthread=-1,
       objective='binary:logistic', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=666, silent=True,
       subsample=0.3))])</Extension>
	</MiningBuildTask>
	<DataDictionary>
		<DataField name="target" optype="categorical" dataType="integer">
			<Value value="0"/>
			<Value value="1"/>
		</DataField>
		<DataField name="kx_output_riskscore" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_0" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_1" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_2" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_3" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_4" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_5" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_6" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_7" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_8" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_11" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_12" optype="continuous" dataType="double"/>
		<DataField name="kx_new_risk_sumList" optype="continuous" dataType="double"/>
		<DataField name="kx_new_is_riskList" optype="continuous" dataType="double"/>
	</DataDictionary>
	<TransformationDictionary>

有了PMML模型文件后，我们就可以编写Java代码来加载这个模型并进行预测了。下面的代码使用jpmml-evaluator库，需要在项目中引入该依赖。代码只读取模型声明的输入字段，因此输入Map中多余的特征会被忽略。
Java代码如下：

package com.seeyon.apps.outerspace.util;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.text.SimpleDateFormat;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;

public class jpmml {

    /**
     * Demo entry point: scores one hand-built sample against the PMML file
     * exported by the Python training script and prints the result.
     *
     * @param args unused
     * @throws Exception if loading or evaluating the model fails
     */
    public static void main(String[] args) throws Exception {
        String pathxml = "/Users/huoshirui/Desktop/test/PMML/xgboost.pmml";
        // Keys must match the PMML input field names; keys that the model
        // does not declare as inputs are never read.
        Map<String, Double> map = new HashMap<String, Double>();
        map.put("kx_output_riskscore", 400D);
        map.put("kx_new_risk_0", 0D);
        map.put("kx_new_risk_1", 1D);
        map.put("kx_new_risk_2", 1D);
        map.put("kx_new_risk_3", 1D);
        map.put("kx_new_risk_4", 0D);
        map.put("kx_new_risk_5", 0D);
        map.put("kx_new_risk_6", 0D);
        map.put("kx_new_risk_7", 0D);
        map.put("kx_new_risk_8", 0D);
        map.put("kx_new_risk_11", 0D);
        map.put("kx_new_risk_12", 0D);
        map.put("kx_new_risk_13", 0D);
        map.put("kx_new_risk_14", 0D);
        map.put("kx_new_risk_15", 0D);
        map.put("kx_new_risk_sumList", 2D);
        map.put("kx_new_is_riskList", 1D);
        predictLrHeart(map, pathxml);
    }

    /**
     * Loads a PMML model from disk, evaluates it on one sample, and prints
     * the value of every target field.
     *
     * <p>Only the model's declared input fields are looked up in
     * {@code kxmap}. A field with no entry in the map is passed to
     * {@link InputField#prepare(Object)} as {@code null}, which JPMML
     * presumably treats as a missing value (confirm for the library version
     * in use).
     *
     * @param kxmap   feature values keyed by input field name
     * @param pathxml path of the PMML file to load
     * @throws Exception if the file cannot be opened or parsed, or if
     *                   preparing or evaluating the inputs fails
     */
    public static void predictLrHeart(Map<String, Double> kxmap, String pathxml) throws Exception {
        PMML pmml;
        // try-with-resources closes the stream on every path. Failures
        // propagate to the caller instead of being silently swallowed.
        try (InputStream is = new FileInputStream(new File(pathxml))) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        }

        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        ModelEvaluator<?> modelEvaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
        Evaluator evaluator = (Evaluator) modelEvaluator;

        // Build the argument map in the model's declared input order.
        List<InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = kxmap.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        // Print each target field's result, e.g. a ProbabilityDistribution
        // for a classification model.
        List<TargetField> targetFields = evaluator.getTargetFields();
        for (TargetField targetField : targetFields) {
            FieldName targetFieldName = targetField.getName();
            Object targetFieldValue = results.get(targetFieldName);
            System.out.println("target: " + targetFieldName.getValue()
                    + " value: " + targetFieldValue);
        }
    }
}

编译运行后的输出结果如下。模型预测的类别为1：属于类别1的概率约为0.693，属于类别0的概率约为0.307。
target: target value: ProbabilityDistribution{result=1, probability_entries=[1=0.69272673, 0=0.30727327]}

猜你喜欢

转载自blog.csdn.net/Katherine_hsr/article/details/84951324