机器学习特征工程——给任意属性增加任意次方的全组合

在机器学习中,我们时常会碰到需要给属性增加字段的情况。譬如有x、y两个属性,当结果倾向于线性时,我们可以很简单的通过线性回归得到模型。但很多时候,线性(在数学上称为多元一次方程),线性是拟合不了结果的。

往往,我们就需要在给定的几个属性上,通过增加属性来尝试能否拟合。那么原本只有两列,x、y,我们增加2次方的属性后,就会变成x、y、x^2、x*y、y^2,变成了5个属性,根据以往经验,我们知道通过这5个属性是能拟合出曲线。

2次方时,我们还能很简单的写出来所有的组合形式,但是当5次方时,原本有4列时,我们该增加多少列,增加的列该怎么计算呢。这就有点麻烦了,譬如(x+y+z)^3展开后就是x^3+y^3+z^3+3xy^2+3xz^2+3x^2y+3yz^2+3x^2z+3y^2z+6xyz. 去掉系数后,就是我们需要追加的所有列了。我们这篇就是做一个程序,来通过给定的m列,n次方,来给出所有的组合形式。

譬如m为2,n也为2,那么我们给出结果组合:[{0,2}, {1,1}, {2,0}],代表追加3列,第一列是x^0 * y^2,第二列是x^1 * y^1,第三列是x^2 * y^0.

通过观察我们发现,我们需要做的是求这样的方程的所有解:X1+X2+X3+……+Xm = N。其中0<=X<=n。

那么解法就是,我们可以定义一个int[m],该数组共有m个元素,每个元素的取值范围在0到n之间,并且该数组的所有元素的和等于n即可。

直接看程序:

/**
 * @author wuweifeng wrote on 2018/6/4.
 */
public class LineAdder {
    private static int lines = 3;
    private static int power = 5;

    private static int[] resultArray;

    public static void main(String[] args) {
        resultArray = new int[lines];
        deal(0);
    }

    public static void deal(int m) {
        for (int i = 0; i <= power; i++) {
            resultArray[m] = i;
            if (m == lines - 1) {
                //如果找到一个解
                if (check()) {
                    print();
                    return;
                }
            } else {
                deal(m + 1);
            }
        }
    }

    /**
     * 判断是否符合结果
     *
     * @return 是否符合
     */
    private static boolean check() {
        int total = 0;
        for (int one : resultArray) {
            total += one;
        }
        return power == total;
    }

    private static void print() {
        for (int one : resultArray) {
            System.out.print(one);
        }
        System.out.print("\n");
    }
}    

结果是:

005
014
023
032
041
050
104
113
122
131
140
203
212
221
230
302
311
320
401
410
500
这就是有3列,并且希望求出5次方时的所有组合的答案。

下面我们将它优化一下,让他能处理文本,能处理一行一行的数据,直接把列追加在文本上。

直接上代码:

package ploy;

import java.util.ArrayList;
import java.util.List;

/**
 * @author wuweifeng wrote on 2018/6/4.
 */
public class LineAdder {
    private int lines = 3;
    private int power = 5;

    private List<int[]> resultList = new ArrayList<>();

    private int[] resultArray;

    public List<int[]> lineAdd(int lines, int power) {
        resultArray = new int[lines];
        this.lines = lines;
        this.power = power;
        deal(0);
        return resultList;
    }

    private void deal(int m) {
        for (int i = 0; i <= power; i++) {
            resultArray[m] = i;
            if (m == lines - 1) {
                //如果找到一个解
                if (check()) {
                    print();
                    return;
                }
            } else {
                deal(m + 1);
            }
        }
    }

    /**
     * 判断是否符合结果
     *
     * @return 是否符合
     */
    private boolean check() {
        int total = 0;
        for (int one : resultArray) {
            total += one;
        }
        return power == total;
    }

    private void print() {
        for (int one : resultArray) {
            System.out.print(one);

        }
        System.out.print("\n");
        int[] temp = new int[resultArray.length];
        System.arraycopy(resultArray, 0, temp, 0, resultArray.length);
        resultList.add(temp);
    }
}
package ploy;

import java.io.*;
import java.util.List;

/**
 * @author wuweifeng wrote on 2018/6/5.
 */
public class TextDeal {
    public static void main(String[] args) throws IOException {
        new TextDeal().linePower("/Users/wuwf/Downloads/ml_data/1逻辑回归入门/data11.csv",
                "/Users/wuwf/Downloads/ml_data/1逻辑回归入门/data_new.csv", 2, 0,1);
    }

    /**
     * @param filePath
     *         文件的路径
     * @param outputPath
     *         输出文件的路径
     * @param power
     *         要做几次方
     * @param lineNums
     *         都有哪几列,需要power,不填默认所有列。从第0列开始
     */
    public void linePower(String filePath, String outputPath, Integer power, Integer... lineNums) throws IOException {
        BufferedReader reader = buildReader(filePath);
        BufferedWriter writer = buildWriter(outputPath);

        addCSVHeader(reader, writer, power, lineNums);

    }

    private Integer[] getLineNums(String[] lines, Integer... lineNums) {
        //为null,则是所有列
        if (lineNums == null) {
            lineNums = new Integer[lines.length];
            for (int i = 0; i < lines.length; i++) {
                lineNums[i] = i;
            }
        }
        return lineNums;
    }

    private List<int[]> getAddList(int power, Integer... lineNums) {
        LineAdder lineAdder = new LineAdder();
        //计算共需增加多少列
        return lineAdder.lineAdd(lineNums.length, power);
    }

    /**
     * 给header里增加相应的列名,都在第一行
     */
    private void addCSVHeader(BufferedReader reader, BufferedWriter writer, Integer power, Integer... lineNums)
            throws IOException {
        //读取第一行
        String header = reader.readLine();
        //所有的列名
        String[] lines = header.split(",");
        lineNums = getLineNums(lines, lineNums);

        //计算共需增加多少列
        List<int[]> list = getAddList(power, lineNums);

        String[] addLines = new String[list.size()];

        String[] needLines = new String[lineNums.length];
        for (int i = 0; i < lineNums.length; i++) {
            needLines[i] = lines[lineNums[i]];
        }
        //设置每一列的名字
        for (int i = 0; i < list.size(); i++) {
            int[] array = list.get(i);
            String s = "";
            for (int j = 0; j < array.length; j++) {
                s += needLines[j] + array[j];
            }
            addLines[i] = s;
        }

        for (String addLine : addLines) {
            header += "," + addLine;
        }
        //将新增的列,写入header文件
        writer.write(header);
        writer.newLine();
        writer.flush();

        String oneLine;

        while ((oneLine = reader.readLine()) != null) {
            addLines = new String[list.size()];
            lines = oneLine.split(",");

            needLines = new String[lineNums.length];
            for (int i = 0; i < lineNums.length; i++) {
                needLines[i] = lines[lineNums[i]];
            }

            //设置每一列的值
            for (int i = 0; i < list.size(); i++) {
                int[] array = list.get(i);
                double s = 1;
                for (int j = 0; j < array.length; j++) {
                    //譬如a,b,对应02时,该列就是a的0次方乘以b的2次方
                    s *= Math.pow(Double.valueOf(needLines[j]), array[j]);
                }
                addLines[i] = s + "";
            }
            for (String addLine : addLines) {
                oneLine += "," + addLine;
            }
            writer.write(oneLine);

            //写入相关文件
            writer.newLine();
        }

        //将新增的列,写入header文件
        writer.flush();
        //关闭流
        reader.close();
        writer.close();
    }

    private BufferedReader buildReader(String filePath) {
        try {
            return new BufferedReader(new FileReader(new File(filePath)));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

    private BufferedWriter buildWriter(String outputPath) {
        //写入相应的文件
        try {
            return new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputPath), "utf-8"));
        } catch (UnsupportedEncodingException | FileNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

}

看效果:

扫描二维码关注公众号,回复: 1506798 查看本文章

假如csv文件是这样的

a,b
1,2
2,3

4,5

运行后,结果是

a,b,a0b2,a1b1,a2b0
1,2,4.0,2.0,1.0
2,3,9.0,6.0,4.0
4,5,25.0,20.0,16.0

可以看到已经完成了做2次方的展开。

这个类,可以完成任意次方的模拟及计算。

猜你喜欢

转载自blog.csdn.net/tianyaleixiaowu/article/details/80577786
今日推荐