Java调用PMML模型

生成PMML模型

具体见我的上一篇博客Python XGBoost保存模型PMML

Java调用PMML模型

Java基本的运行环境就不说了,大家如果能看到这篇文章,基本上就都掌握了Java运行环境。
首先maven导入需要的jar包

  <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>4.11</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.jpmml</groupId>
      <artifactId>pmml-evaluator</artifactId>
      <version>1.4.1</version>
    </dependency>
    <dependency>
      <groupId>org.jpmml</groupId>
      <artifactId>pmml-evaluator-extension</artifactId>
      <version>1.4.1</version>
    </dependency>
  </dependencies>

导入jar包后,将下面代码复制到代码处

package sso.passport;
/**
 * function:java实现调用pmml文件
 * datatime:2020-07-10 16:09
 **/

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import java.io.*;
import java.util.*;
import java.util.List;

public class Classification {
    public static void main(String[] args) throws Exception {
        //模型路径
        String pathxml = System.getProperty("user.dir") + "/model/xgboost.pmml";
        //传入模型特征数据
        Map<String, Double> map = new HashMap<String, Double>();
        map.put("x1", 5.1);
        map.put("x2", 3.5);
        map.put("x3", 0.4);
        map.put("x4", 0.2);
        //模型预测
        predictLrHeart(map, pathxml);
    }

    public static void predictLrHeart(Map<String, Double> irismap, String pathxml) throws Exception {
        PMML pmml;
        File file = new File(pathxml);
        InputStream inputStream = new FileInputStream(file);
        try (InputStream is = inputStream) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);

            ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
            ModelEvaluator<?> modelEvaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
            Evaluator evaluator = (Evaluator) modelEvaluator;

            List<InputField> inputFields = evaluator.getInputFields();
            Map<FieldName, FieldValue> argements = new LinkedHashMap<>();
            for (InputField inputField : inputFields) {
                FieldName inputFieldName = inputField.getName();
                Object raeValue = irismap.get(inputFieldName.getValue());
                FieldValue inputFieldValue = inputField.prepare(raeValue);
                argements.put(inputFieldName, inputFieldValue);
            }
            Map<FieldName, ?> results = evaluator.evaluate(argements);
            List<TargetField> targetFields = evaluator.getTargetFields();
            for (TargetField targetField : targetFields) {
                FieldName targetFieldName = targetField.getName();
                Object targetFieldValue = results.get(targetFieldName);
//                System.out.println("target: " + targetFieldName.getValue());
                System.out.println(targetFieldValue);
            }
        }
    }
}

本代码也是根据鸢尾花数据进行操作的,由于本人对java语言不甚了解,其中详细注释不好多说,但是一看就能明白。
大家运行如果报错的话,请看下一篇文章(可能有你需要的哦)。
(1)、如果您在阅读博客时遇到问题或者不理解的地方,可以联系我,互相交流、互相进步;
(2)、本人业余时间可以承接毕业设计和各种小项目,如系统构建、成立网站、数据挖掘、机器学习、深度学习等。有需要的加QQ:1143948594,备注“csdn项目”。

猜你喜欢

转载自blog.csdn.net/qq_32113189/article/details/107541890