数据挖掘 -- Apriori关联规则算法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/l1832876815/article/details/89314435
1. 算法原理

Apriori关联规则算法的目的就是找出所有的频繁项集,所以需要定义一个评估标准找出频繁项集,即最小支持度。 首先从原始数据集中找出出现的所有项,对应数据集确定候选1项集,根据候选一项集每项在原始项集中的出现次数计算每一项的sup值。比较sup值 / 原始数据集数 的值与最小支持度,小于则舍去,计算出频繁一项集,然后对频繁一项集两项之间求补集,并按照一项集中求sup的方法求取候选二项集及频繁二项集。之后递归求取频繁n项集,当频繁项集项数只有一项时递归结束。得到最后的频繁项集。

2. 代码实现
import java.util.ArrayList;

/**
 * @Description 项集item
 * @Author Clxk
 * @Date 2019/4/15 10:57
 * @Version 1.0
 */
public class Data {

    private ArrayList<String> data = new ArrayList<>();

    private int cnt;

    public ArrayList<String> getData() {
        return data;
    }

    public void setData(ArrayList<String> data) {
        this.data = data;
    }

    public int getCnt() {
        return cnt;
    }

    public void setCnt(int cnt) {
        this.cnt = cnt;
    }

    @Override
    public boolean equals(Object obj) {
        Data rhs = (Data) obj;
        boolean eq = this.cnt == rhs.cnt;

        if(this.cnt == rhs.cnt) {
            for(int i = 0; i < data.size(); i++) {
                if(!data.get(i).equals(rhs.data.get(i))) {
                    eq = false;
                    break;
                }
            }
        }
        return eq;
    }
}

import java.util.*;

/**
 * @Description Apriori
 * @Author Clxk
 * @Date 2019/4/15 10:43
 * @Version 1.0
 */
public class Main {

    /**
     * 初始数据集最大值
     */
    private static final int MAXN = 1000;
    /**
     * 数据集长度、最小支持度
     */
    private static int datacnt = 0;
    private static double minsupport = 0;
    /**
     * 初始数据集
     */
    private static ArrayList<String> []data = new ArrayList[500];
    /**
     * 项集结构
     */
    private static ArrayList<Data> items = new ArrayList<>();


    public static void main(String[] args) {

        /**
         * 原始数据集读取
         */
        Scanner scanner = new Scanner(System.in);
        System.out.println("请输入数据集的大小: ");
        datacnt = scanner.nextInt();
        System.out.println("请输入最小支持度: ");
        minsupport = scanner.nextDouble();
        System.out.println("请输入原始数据集: ");
        String str;
        scanner.nextLine();
        for (int i = 0; i < datacnt; i++) {
            data[i] = new ArrayList<>();
            str = scanner.nextLine();
            String[] split = str.split("\\s");
            for (int j = 0; j < split.length; j++) {
                data[i].add(split[j]);
            }
        }

        /**
         * 数据集处理
         */
        solve(data);


    }

    /**
     * 数据集处理
     * @param data
     */
    public static void solve(ArrayList<String>[] data) {

        getFrequent(data, 1);
    }

    /**
     * 获取到频繁1项集
     * @param data
     */
    public static void getFrequentOne(ArrayList<String>[] data) {

        /**
         * 获取不重复集合
         */
        for(ArrayList<String> list : data) {
            if(list == null) break;
            for(String s: list) {
                Data dt = new Data();
                List<String> ls = new ArrayList<>();
                ls.add(s);
                dt.setData((ArrayList<String>) ls);
                int is_have = 0;
                for(int i = 0; i < items.size(); i++) {
                    Data d = items.get(i);
                    if(d.getData().equals(ls)) {
                        is_have = 1;
                        break;
                    }
                }
                if(is_have == 0) {
                    items.add(dt);
                }
            }
        }
    }

    /**
     * 输出候选n项集
     * @param n
     */
    public static void getCandidate(int n) {

        System.out.println("候选" + n + "项集为: ");
        outList();
    }

    /**
     * 输出频繁n项集
     * @param n
     */
    public static void getItems(int n) {
        for(int i = 0; i < items.size(); i++) {
            if((double)items.get(i).getCnt() / datacnt < minsupport) {
                items.remove(i);
                i--;
            }
        }
        System.out.println("频繁"+ n +"项集为: ");
        outList();
    }

    /**
     * 获取频繁n项集
     * @param data
     * @param n
     */
    public static void getFrequent(ArrayList<String>[] data, int n) {

        if(n == 1) {
            getFrequentOne(data);
        } else {
            ArrayList<Data> array = new ArrayList<>();
            for(int i = 0; i < items.size(); i++) {
                Set<String> set = new HashSet<>();
                ArrayList<String> data1 = items.get(i).getData();
                for(int j = i+1; j < items.size(); j++) {
                    set.clear();
                    ArrayList<String> data2 = items.get(j).getData();
                    for(int u = 0; u < Math.max(data1.size(), data2.size()); u++) {
                        if(data1.size() > u) set.add(data1.get(u));
                        if(data2.size() > u) set.add(data2.get(u));
                    }
                    if(set == null || set.size() !=  n) continue;
                    put2Items(array,set);
                }
            }
            items = (ArrayList<Data>) array.clone();
        }

        /**
         * 获取sup值
         */
        addSup(n);

        /**
         * 输出候选n项集
         */
        getCandidate(n);
        /**
         * 输出频繁n项集
         */
        getItems(n);

        if(items.size() > 1) {
            getFrequent(data, n+1);
        }

    }

    /**
     * 获取Sup值
     * @param n
     */
    public static void addSup(int n) {
        for(int i = 0; i < items.size(); i++) {
            ArrayList<String> list = items.get(i).getData();
            int cnt = 0;
            for(int j = 0; j < datacnt; j++) {
                int have = 1;
                ArrayList<String> cur = data[j];
                for(int u = 0; u < list.size(); u++) {
                    if(!cur.contains(list.get(u))) {
                        have = 0;
                        break;
                    }
                }
                if(have == 1) cnt++;
            }
            Data d = new Data();
            d.setData(list);
            d.setCnt(cnt);
            items.set(i, d);
        }
    }

    /**
     * 整理候选频繁项集,同项相加
     * @param array,set
     */
    public static void put2Items(ArrayList<Data> array, Set<String> set) {
        Data data = new Data();
        for(String s:set) {
            data.getData().add(s);
        }
        int is_have = 0;
        for(int i = 0; i < array.size(); i++) {
            if(array.get(i) == null) break;
            if(array.get(i).equals(data)) {
                is_have = 1;
                array.set(i, data);
                break;
            }
        }
        if(is_have == 0) {
            array.add(data);
        }
    }


    /**
     * 输出项集
     */
    public static void outList() {

        for(Data data : items) {
            System.out.println(Arrays.toString(data.getData().toArray()) + "   " + data.getCnt());
        }

    }
}

猜你喜欢

转载自blog.csdn.net/l1832876815/article/details/89314435
今日推荐