基于Kubernetes的机器学习微服务系统设计系列——(六)特征选择微服务

  特征选择微服务主要实现如下特征选择算法:Document Frequency(DF)、Information Gain(IG)、Chi-Square Test(CHI,即 χ² 检验)、Mutual Information(MI)、Matrix Projection(MP)。

特征选择类图

  特征选择类图如图所示:

特征选择微服务类图

部分实现代码

特征选择Action类

package com.robin.feature.action;

import com.robin.feature.corpus.CorpusManager;
import com.robin.feature.AbstractFeature;
import com.robin.feature.FeatureFactory;
import com.robin.feature.FeatureFactory.FeatureMethod;
import com.robin.loader.MircoServiceAction;
import com.robin.log.RobinLogger;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;

/**
 * <DT><B>Description:</B></DT>
 * <DD>Feature-selection action class.</DD>
 *
 * Adapts feature selection to Jersey server resource calls.
 *
 * @version Version1.0
 * @author Robin
 * @version <I> Date:2018-04-01</I>
 * @author  <I> E-mail:[email protected]</I>
 */
public class FeatureSelectAction implements MircoServiceAction {

    private static final Logger LOGGER = RobinLogger.getLogger();

    /**
     * Status codes written to the {@code status} field of an error response.
     */
    public enum StatusCode {
        OK,
        JSON_ERR,
        KIND_ERR,
        VERSION_ERR,
        TRAIN_SCALE_ERR,
        METHOD_ERR,
        TEXTS_NULL,
    }

    /**
     * Outcome of a validation or processing step: a status code plus an
     * optional human-readable message. Static because it never touches the
     * enclosing action instance.
     */
    private static class ActionStatus {

        StatusCode statusCode;
        String msg;

    }

    /**
     * Builds an {@link ActionStatus} in one call.
     *
     * @param statusCode status code to report
     * @param msg message to report, may be {@code null} for {@link StatusCode#OK}
     * @return the populated status
     */
    private static ActionStatus newStatus(StatusCode statusCode, String msg) {
        ActionStatus status = new ActionStatus();
        status.statusCode = statusCode;
        status.msg = msg;
        return status;
    }

    /**
     * Builds the set of accepted values for one request field.
     *
     * @param values accepted values
     * @return a new set containing exactly {@code values}
     */
    private static HashSet<String> allowedValues(String... values) {
        HashSet<String> set = new HashSet<>();
        for (String value : values) {
            set.add(value);
        }
        return set;
    }

    /**
     * Converts an error status into the JSON error response
     * ({@code status} and {@code msg} fields).
     *
     * @param actionStatus the error status to convert
     * @return JSONObject carrying the status code name and the message
     */
    private JSONObject getErrorJson(ActionStatus actionStatus) {
        JSONObject errJson = new JSONObject();
        try {
            errJson.put("status", actionStatus.statusCode.toString());
            errJson.put("msg", actionStatus.msg);
        } catch (JSONException ex) {
            LOGGER.log(Level.SEVERE, ex.getMessage(), ex);
        }
        return errJson;
    }

    /**
     * Checks that a top-level string field is present and holds one of the
     * accepted values.
     *
     * @param jsonObj the request JSON
     * @param key name of the field to check
     * @param valueSet accepted values for the field
     * @param errStatusCode status code to report when the check fails
     * @return a status with {@link StatusCode#OK} when the field is valid,
     * otherwise {@code errStatusCode} with an explanatory message
     */
    private ActionStatus checkJSONObjectTerm(JSONObject jsonObj,
            String key,
            HashSet<String> valueSet,
            StatusCode errStatusCode) {
        if (jsonObj.isNull(key)) {
            return newStatus(errStatusCode, "The input parameter is missing " + key + ".");
        }

        try {
            String value = jsonObj.getString(key);
            if (!valueSet.contains(value)) {
                return newStatus(errStatusCode, "The value [" + value + "] of " + key + " is error.");
            }
        } catch (JSONException ex) {
            // A field that cannot be read is a validation failure, not a success.
            LOGGER.log(Level.SEVERE, ex.getMessage(), ex);
            return newStatus(errStatusCode, "The input parameter is missing " + key + ".");
        }

        return newStatus(StatusCode.OK, null);
    }

    /**
     * Validates the request JSON: {@code kind}, {@code version},
     * {@code metadata.feature.trainScale} (must lie strictly between 0 and 1)
     * and {@code metadata.feature.method} (every entry must be a known
     * selection method).
     *
     * @param jsonObj the request JSON
     * @return a status with {@link StatusCode#OK} when the request is valid,
     * otherwise the status describing the first failed check
     */
    private ActionStatus checkInputJSONObject(JSONObject jsonObj) {
        ActionStatus termStatus = checkJSONObjectTerm(jsonObj, "kind",
                allowedValues("feature"), StatusCode.KIND_ERR);
        if (termStatus.statusCode != StatusCode.OK) {
            return termStatus;
        }

        termStatus = checkJSONObjectTerm(jsonObj, "version",
                allowedValues("v1"), StatusCode.VERSION_ERR);
        if (termStatus.statusCode != StatusCode.OK) {
            return termStatus;
        }

        JSONObject featureConfig;
        double trainScale;
        try {
            featureConfig = jsonObj.getJSONObject("metadata").getJSONObject("feature");
            trainScale = featureConfig.getDouble("trainScale");
        } catch (JSONException ex) {
            // Previously this was only logged and the request was reported as valid.
            LOGGER.log(Level.SEVERE, ex.getMessage(), ex);
            return newStatus(StatusCode.TRAIN_SCALE_ERR,
                    "The input parameter metadata.feature.trainScale is missing or invalid.");
        }
        // Written as a positive range test so that NaN is rejected as well.
        if (!(trainScale > 0 && trainScale < 1.0)) {
            return newStatus(StatusCode.TRAIN_SCALE_ERR,
                    "The input train_scale [" + trainScale + "] is error.");
        }

        HashSet<String> methodNames = allowedValues("DF", "CHI", "MP", "IG", "MI");
        try {
            JSONArray methods = featureConfig.getJSONArray("method");
            for (int i = 0; i < methods.length(); i++) {
                String method = methods.getString(i);
                if (!methodNames.contains(method)) {
                    return newStatus(StatusCode.METHOD_ERR,
                            "The input method [" + method + "] is error.");
                }
            }
        } catch (JSONException ex) {
            LOGGER.log(Level.SEVERE, ex.getMessage(), ex);
            return newStatus(StatusCode.METHOD_ERR,
                    "The input parameter metadata.feature.method is missing or invalid.");
        }

        return newStatus(StatusCode.OK, null);
    }

    /**
     * Runs global feature selection and collects the selected term codes.
     *
     * @param selecter the feature selector for one method
     * @param dimension number of features to keep
     * @return JSON array of the selected term codes, in the selector's order
     */
    private JSONArray buildGlobalFeatureArray(AbstractFeature selecter, int dimension) {
        JSONArray featureArr = new JSONArray();
        for (Map.Entry<Integer, Double> entry : selecter.selectGlobalFeature(dimension)) {
            featureArr.put(entry.getKey());
        }
        return featureArr;
    }

    /**
     * Runs per-label (local) feature selection and collects the selected term
     * codes for every label.
     *
     * @param selecter the feature selector for one method
     * @param dimension number of features to keep per label
     * @return JSON object mapping each label to an array of selected term codes
     * @throws JSONException if a label cannot be written to the JSON object
     */
    private JSONObject buildLocalFeatureJson(AbstractFeature selecter, int dimension) throws JSONException {
        JSONObject labelFeatureJson = new JSONObject();
        Map<String, List<Map.Entry<Integer, Double>>> labelsMap = selecter.selectLocalFeature(dimension);
        for (Map.Entry<String, List<Map.Entry<Integer, Double>>> labelEntry : labelsMap.entrySet()) {
            JSONArray labelFeatureArr = new JSONArray();
            for (Map.Entry<Integer, Double> entry : labelEntry.getValue()) {
                labelFeatureArr.put(entry.getKey());
            }
            labelFeatureJson.put(labelEntry.getKey(), labelFeatureArr);
        }
        return labelFeatureJson;
    }

    /**
     * Implements the concrete action: validates the request, splits the corpus
     * into training and test sets, runs every requested feature-selection
     * method on the training set and attaches the result to the corpus JSON
     * under {@code featureSelect}.
     *
     * @param obj the request; must be a {@link JSONObject}
     * @return on success a JSONObject with {@code status: "OK"} and the
     * enriched corpus under {@code result}; on any validation or processing
     * failure a JSONObject with the error {@code status} and {@code msg}
     */
    @Override
    public Object action(Object obj) {
        if (!(obj instanceof JSONObject)) {
            ActionStatus typeStatus = newStatus(StatusCode.JSON_ERR,
                    "The action arguments is not JSONObject.");
            LOGGER.log(Level.SEVERE, typeStatus.msg);
            return this.getErrorJson(typeStatus);
        }

        JSONObject corpusJson = (JSONObject) obj;
        ActionStatus checkStatus = this.checkInputJSONObject(corpusJson);
        if (checkStatus.statusCode != StatusCode.OK) {
            LOGGER.log(Level.SEVERE, checkStatus.msg);
            return this.getErrorJson(checkStatus);
        }

        long beginTime = System.currentTimeMillis();

        // optJSONObject returns null for a missing or non-object value, whereas
        // getJSONObject throws, which made the original null check unreachable.
        if (null == corpusJson.optJSONObject("texts")) {
            ActionStatus textsStatus = newStatus(StatusCode.TEXTS_NULL, "The input texts is null.");
            LOGGER.log(Level.SEVERE, textsStatus.msg);
            return this.getErrorJson(textsStatus);
        }

        try {
            // Generate the training set and the test set.
            CorpusManager.divide(corpusJson);
            JSONObject testSetJson = (JSONObject) corpusJson.remove("testSet");
            JSONObject trainSetJson = (JSONObject) corpusJson.remove("trainSet");

            JSONObject metadataFeatureJson = corpusJson.getJSONObject("metadata").getJSONObject("feature");
            boolean globalFeature = metadataFeatureJson.getBoolean("globalFeature");
            int globalDimension = metadataFeatureJson.getInt("globalDimension");
            boolean localFeature = metadataFeatureJson.getBoolean("localFeature");
            int localDimension = metadataFeatureJson.getInt("localDimension");

            JSONObject globalFeatureJson = new JSONObject();
            JSONObject localFeatureJson = new JSONObject();

            // Feature selection, once per requested method.
            JSONArray methodArr = metadataFeatureJson.getJSONArray("method");
            for (int i = 0; i < methodArr.length(); i++) {
                String selectMethod = methodArr.getString(i);
                AbstractFeature selecter = FeatureFactory.creatInstance(trainSetJson,
                        FeatureMethod.valueOf(selectMethod));
                if (globalFeature) {
                    globalFeatureJson.put(selectMethod,
                            buildGlobalFeatureArray(selecter, globalDimension));
                }
                if (localFeature) {
                    localFeatureJson.put(selectMethod,
                            buildLocalFeatureJson(selecter, localDimension));
                }
            }

            JSONObject featureSelectJson = new JSONObject();
            featureSelectJson.put("globalFeature", globalFeatureJson);
            featureSelectJson.put("localFeature", localFeatureJson);
            corpusJson.put("featureSelect", featureSelectJson);
            corpusJson.put("trainSet", trainSetJson);
            corpusJson.put("testSet", testSetJson);

            int spendTime = (int) (System.currentTimeMillis() - beginTime);
            metadataFeatureJson.put("spendTime", spendTime);
        } catch (JSONException ex) {
            // Report the failure instead of answering "OK" with a half-processed corpus.
            LOGGER.log(Level.SEVERE, ex.getMessage(), ex);
            return this.getErrorJson(newStatus(StatusCode.JSON_ERR,
                    "Feature selection failed: " + ex.getMessage()));
        }

        JSONObject rsp = new JSONObject();
        try {
            rsp.put("status", "OK");
            rsp.put("result", corpusJson);
        } catch (JSONException ex) {
            LOGGER.log(Level.SEVERE, ex.getMessage(), ex);
        }
        return rsp;
    }
}

特征选择抽象类

package com.robin.feature;

import com.robin.container.MapSort;
import com.robin.feature.corpus.CorpusManager;
import com.robin.log.RobinLogger;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import java.util.logging.Level;
import java.util.logging.Logger;
import org.codehaus.jettison.json.JSONObject;

/**
 * <DT><B>Description:</B></DT>
 * <DD>Abstract base class for feature selection.</DD>
 *
 * Subclasses supply the scoring functions
 * ({@link #computeGlobalFeatureValue()} and
 * {@link #computeLocalFeatureValue(String)}); this class ranks the scores
 * and keeps the top {@code dimension} terms.
 *
 * @version Version1.0
 * @author Robin
 * @version <I> Date:2018-04-05</I>
 * @author  <I> E-mail:[email protected]</I>
 */
public abstract class AbstractFeature {

    /**
     * Logger.
     */
    protected static final Logger LOGGER = RobinLogger.getLogger();
    /**
     * Set of all term codes seen in the training set.
     */
    protected Set<Integer> globalTermsSet;
    /**
     * Term document-frequency per training class:
     * &lt;class label, &lt;term code, document count&gt;&gt;.
     */
    protected HashMap<String, HashMap<Integer, Integer>> everyClassDFMap;

    /**
     * Training-set JSON object; its keys are iterated as class labels.
     */
    protected JSONObject trainSetJson;

    /**
     * Global feature scores: &lt;term code, score&gt;. Computed lazily by
     * {@link #selectGlobalFeature(int)}.
     */
    protected HashMap<Integer, Double> globalFeatureValueMap;

    /**
     * Local feature scores per class label: &lt;label, &lt;term code, score&gt;&gt;.
     * Filled lazily by {@link #selectLocalFeature(int)}.
     */
    protected HashMap<String, HashMap<Integer, Double>> allLocalFeatureValueMap;

    /**
     * Creates the selector and builds the per-class document-frequency maps.
     * <p>
     * NOTE(review): construction calls {@link #getTermDFMap(String)}, which is
     * overridable; an override runs before the subclass constructor body.
     *
     * @param trainSetJson training-set JSON object
     */
    public AbstractFeature(JSONObject trainSetJson) {
        this.trainSetJson = trainSetJson;
        this.allLocalFeatureValueMap = new HashMap<>();
        initEveryClassDFMap();
    }

    /**
     * Returns the number of distinct term codes collected from the training set.
     *
     * @return size of the global term set, or 0 if it has not been created
     */
    public int getAllTermTotal() {
        return (globalTermsSet != null) ? globalTermsSet.size() : 0;
    }

    /**
     * Returns the number of terms that currently have a global feature score.
     *
     * @return size of the global score map, or 0 if it has not been computed yet
     */
    public int getGlobalFeatureSize() {
        // The original null check was inverted: it dereferenced the map exactly
        // when it was null and returned 0 when it was populated.
        return (globalFeatureValueMap != null) ? globalFeatureValueMap.size() : 0;
    }

    /**
     * Computes the global feature score of every term.
     *
     * @return HashMap of &lt;term code, score&gt;
     */
    protected abstract HashMap<Integer, Double> computeGlobalFeatureValue();

    /**
     * Computes the local feature score of every term for one class.
     *
     * @param label class label
     * @return HashMap of &lt;term code, score&gt;
     */
    protected abstract HashMap<Integer, Double> computeLocalFeatureValue(String label);

    /**
     * Selects the top {@code dimension} features globally.
     *
     * @param dimension number of features to keep; values &lt;= 0 yield an empty list
     * @return the highest-scoring entries in descending score order
     */
    public List<Map.Entry<Integer, Double>> selectGlobalFeature(int dimension) {
        if (null == globalFeatureValueMap) {
            // Compute the global scores once and cache them.
            globalFeatureValueMap = this.computeGlobalFeatureValue();
        }
        List<Map.Entry<Integer, Double>> featureList
                = new MapSort<Integer, Double>().descendSortByValue(globalFeatureValueMap);
        keepFirst(featureList, dimension);
        return featureList;
    }

    /**
     * Selects the top {@code dimension} features separately for every class
     * label of the training set.
     *
     * @param dimension number of features to keep per label; values &lt;= 0
     * yield empty lists
     * @return Map from class label to its highest-scoring entries in
     * descending score order
     */
    public Map<String, List<Map.Entry<Integer, Double>>> selectLocalFeature(int dimension) {
        Map<String, List<Map.Entry<Integer, Double>>> localFeatureListMap = new HashMap<>();
        Iterator<String> labelsIt = this.trainSetJson.keys();
        while (labelsIt.hasNext()) {
            String label = labelsIt.next();

            // Compute the scores of a label once and cache them.
            HashMap<Integer, Double> localValueMap = allLocalFeatureValueMap.get(label);
            if (null == localValueMap) {
                localValueMap = this.computeLocalFeatureValue(label);
                allLocalFeatureValueMap.put(label, localValueMap);
            }
            List<Map.Entry<Integer, Double>> localFeatureList
                    = new MapSort<Integer, Double>().descendSortByValue(localValueMap);
            keepFirst(localFeatureList, dimension);
            localFeatureListMap.put(label, localFeatureList);
        }
        return localFeatureListMap;
    }

    /**
     * Truncates {@code list} in place to at most {@code dimension} leading
     * elements. A negative dimension is treated as 0 (the original index loop
     * ran past index 0 and threw in that case).
     *
     * @param list list to truncate
     * @param dimension maximum number of elements to keep
     */
    private static <T> void keepFirst(List<T> list, int dimension) {
        int keep = Math.max(0, dimension);
        if (keep < list.size()) {
            list.subList(keep, list.size()).clear();
        }
    }

    /**
     * Initializes the per-class term document-frequency maps and, as a side
     * effect of {@link #getTermDFMap(String)}, the global term set.
     */
    protected final void initEveryClassDFMap() {
        this.everyClassDFMap = new HashMap<>();
        this.globalTermsSet = new HashSet<>();

        Iterator<String> labelsIt = this.trainSetJson.keys();
        while (labelsIt.hasNext()) {
            String label = labelsIt.next();
            this.everyClassDFMap.put(label, this.getTermDFMap(label));
        }
    }

    /**
     * Counts, for one training class, in how many texts each term code occurs.
     * <BR>
     * Per the original author's note, each text map uses term code -1 to hold
     * the text's total term count, so the count stored under -1 in the result
     * equals the number of texts in the class.
     * <BR>
     * NOTE(review): every key, including -1, is also added to
     * {@link #globalTermsSet}; confirm whether callers of
     * {@link #getAllTermTotal()} expect that.
     *
     * @param label class label
     * @return HashMap of &lt;term code, number of texts containing it&gt;;
     * empty if no term-document matrix is available for the label
     */
    protected HashMap<Integer, Integer> getTermDFMap(String label) {
        // <term code, document count> for this class.
        HashMap<Integer, Integer> thisDFMap = new HashMap<>();
        // <text id, <term code, term count>>
        HashMap<String, HashMap<Integer, Integer>> tdmMap = CorpusManager.getTdmMap(this.trainSetJson, label);
        if (null == tdmMap) {
            LOGGER.severe("词文档矩阵Map为空或NULL!");
            return thisDFMap;
        }
        for (HashMap<Integer, Integer> textMap : tdmMap.values()) {
            for (Integer termCode : textMap.keySet()) {
                thisDFMap.merge(termCode, 1, Integer::sum);
                // Piggy-backs on this pass to populate the global term set.
                globalTermsSet.add(termCode);
            }
        }
        return thisDFMap;
    }

    /**
     * Sums, for one training class, the term counts of each term code over
     * all texts of the class. Term code -1 is skipped.
     * <BR>
     * Original author's note: this method is currently unused, and it is an
     * open question why the sum is not divided by the number of texts.
     *
     * @param label class label
     * @return HashMap of &lt;term code, summed term count&gt;; empty if no
     * term-document matrix is available for the label
     */
    protected HashMap<Integer, Integer> getTermTFMap(String label) {
        // <term code, term frequency>
        HashMap<Integer, Integer> thisTFMap = new HashMap<>();
        // <text id, <term code, term frequency>>
        HashMap<String, HashMap<Integer, Integer>> tdmMap = CorpusManager.getTdmMap(this.trainSetJson, label);
        if (null == tdmMap) {
            LOGGER.log(Level.SEVERE, "词文档矩阵Map为空或NULL!");
            return thisTFMap;
        }
        for (HashMap<Integer, Integer> textMap : tdmMap.values()) {
            for (Map.Entry<Integer, Integer> termEntry : textMap.entrySet()) {
                Integer termCode = termEntry.getKey();
                if (termCode == -1) {
                    continue;
                }
                thisTFMap.merge(termCode, termEntry.getValue(), Integer::sum);
                // Piggy-backs on this pass to populate the global term set.
                globalTermsSet.add(termCode);
            }
        }
        return thisTFMap;
    }

    /**
     * Merges the document-frequency maps of every class except the one
     * currently being processed.
     *
     * @param currLabel label of the class currently being processed
     * @return HashMap of &lt;term code, document count summed over all other
     * classes&gt;
     */
    protected HashMap<Integer, Integer> getOtherClassDFMap(String currLabel) {
        HashMap<Integer, Integer> otherClassDFMap = new HashMap<>();

        Iterator<String> labelsIt = this.trainSetJson.keys();
        while (labelsIt.hasNext()) {
            String label = labelsIt.next();
            if (label.equals(currLabel)) {
                continue;
            }
            HashMap<Integer, Integer> classDFMap = everyClassDFMap.get(label);
            for (Map.Entry<Integer, Integer> termEntry : classDFMap.entrySet()) {
                otherClassDFMap.merge(termEntry.getKey(), termEntry.getValue(), Integer::sum);
            }
        }
        return otherClassDFMap;
    }
}

请求JSON

  特征选择服务请求的JSON格式如下,红框部分为特征选择配置参数。

响应JSON

  特征选择服务响应的JSON格式如下,红框部分为特征选择结果。为便于显示,DF的结果删除了大部分,其他特征的结果已折叠。

猜你喜欢

转载自blog.csdn.net/xsdjj/article/details/83927523