Understanding the Pseudo-document-based Topic Model (PTM) and a Source Code Walkthrough

Author: Qian Yang, School of Management, Hefei University of Technology. Email: [email protected]. There may be shortcomings in this article; discussion and feedback are welcome.

Reproduction without the author's permission is prohibited.

Paper Source

Zuo Y., Wu J., Zhang H., et al. Topic modeling of short texts: A pseudo-document view. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), ACM, 2016: 2105-2114.

The paper was published at KDD 2016, a top conference in computer science, by researchers from Beihang University.

Paper Overview


Topic models fundamentally rely on word co-occurrence, but in short texts co-occurrence information is very sparse, which hurts the quality of the learned topics. There are of course many ways to handle topic learning for short texts; this paper proposes a pseudo-document strategy, in which short texts are aggregated into latent pseudo-documents before topics are learned.
Let us look at the probabilistic graphical model:


[Figure: probabilistic graphical models of (a) PTM and (b) SPTM]


Figure (a) is the basic PTM; figure (b) introduces a sparsity prior, namely the spike-and-slab prior, which has been used in many topic models (see some of my earlier blog posts for details). Here its purpose is to make the topic distributions of the pseudo-documents sparse.
The generative process of the model is as follows:


[Figure: generative process of the model]
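Since the figure is not reproduced here, below is a minimal sketch of the generative process of the plain PTM, reconstructed from the Gibbs sampler code later in this post (the symbols mirror the variable names K1, K2, alpha1, alpha2, beta in the code; the notation is mine, not necessarily the paper's):

$$
\begin{aligned}
& \psi \sim \mathrm{Dirichlet}(\alpha_1), \qquad \theta_p \sim \mathrm{Dirichlet}(\alpha_2)\ \ (p = 1,\dots,K_1), \qquad \phi_k \sim \mathrm{Dirichlet}(\beta)\ \ (k = 1,\dots,K_2) \\
& \text{for each short document } d:\quad l_d \sim \mathrm{Mult}(\psi);\quad \text{for each word position } n:\ \ z_{d,n} \sim \mathrm{Mult}(\theta_{l_d}),\ \ w_{d,n} \sim \mathrm{Mult}(\phi_{z_{d,n}})
\end{aligned}
$$

In words: every short document is attached to exactly one latent pseudo-document, and its word topics are drawn from that pseudo-document's topic distribution rather than from a per-document one.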


Introducing sparsity only changes how the right-hand part of the model (the pseudo-document topic distributions) is generated.

Model Inference


First, sample the pseudo-document that each document belongs to, as shown below. The formula given is for SPTM, which includes the sparsity component; for the plain PTM only a minor modification is needed.


[Figure: sampling formula for the pseudo-document assignment (SPTM)]
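The formula image is not shown here. Written out from the sampleZ1 code below (my reconstruction, with all counts excluding document d: m_p is the number of documents currently assigned to pseudo-document p, n_{p,k} the number of words in p assigned to topic k, n_{p,·} its total, n_{d,k} the number of words of document d assigned to topic k, and N_d the length of d), the collapsed conditional for the plain PTM is roughly:

$$
p(l_d = p \mid \mathbf{l}_{\neg d}, \mathbf{z}) \;\propto\; \frac{m_p + \alpha_1}{M + K_1\alpha_1} \cdot \frac{\prod_{k=1}^{K_2}\prod_{j=0}^{n_{d,k}-1}\bigl(n_{p,k} + \alpha_2 + j\bigr)}{\prod_{i=0}^{N_d-1}\bigl(n_{p,\cdot} + K_2\alpha_2 + i\bigr)}
$$

The running product over all N_d words in the denominator is exactly what the index variable implements in the code.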


Then sample the topic of each word in the document, as shown below:


[Figure: sampling formula for the word-topic assignment]
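Again the formula image is missing. From the sampleZ2 code below, the conditional for the topic of word w at position n of document d, given its pseudo-document p = l_d (counts excluding the current token; my reconstruction), is:

$$
p(z_{d,n} = k \mid \cdot) \;\propto\; \frac{n_{p,k} + \alpha_2}{n_{p,\cdot} + K_2\alpha_2} \cdot \frac{n_{k,w} + \beta}{n_{k,\cdot} + V\beta}
$$

This has the same form as the word-topic update in collapsed Gibbs sampling for LDA, except that document-topic counts are replaced by pseudo-document-topic counts.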


Next, sample whether a pseudo-document contains a given topic, i.e., the pseudo-document's topic selector. This formula follows the sampling scheme of Wang and Blei; their paper deals with a nonparametric model and provides a detailed derivation that is well worth studying:
C. Wang and D. M. Blei. Decoupling sparsity and smoothness in the discrete hierarchical Dirichlet process. In Advances in Neural Information Processing Systems, pages 1982-1989, 2009.


[Figure: sampling formula for the pseudo-document topic selector]

Source Code Walkthrough


The code analyzed here implements the plain PTM model; with the formulas above it is easy to follow.
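From the loadTxts method below, the expected input is a plain text file with one tokenized document per line, tokens separated by whitespace, for example (the contents are purely illustrative):

topic model short text sparse
pseudo document aggregation kdd

Note that the reader/writer helper IOUtil is a small I/O utility class from the original project and is not listed in this post.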

package main;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

public class PseudoDocTM implements Runnable {

    public int K1 = 1000;  // number of pseudo-documents
    public int K2 = 100;   // number of topics

    public int M;
    public int V;

    public double alpha1 = 0.1;
    public double alpha2 = 0.1;

    public double beta = 0.01;

    public int mp[]; // number of documents assigned to each pseudo-document

    public int npk[][];  // number of words in pseudo-document p generated by topic k
    public int npkSum[];  // total number of words in each pseudo-document

    public int nkw[][]; // count of word w assigned to topic k
    public int nkwSum[]; // total number of words assigned to topic k

    public int zAssigns_1[];  // pseudo-document assignment of each document
    public int zAssigns_2[][]; // topic assignment of each word in each document

    public int niters = 200; 
    public int saveStep = 1000; 
    public String inputPath="";
    public String outputPath="";

    public int innerSteps = 10;

    public List<List<Integer>> docs = new ArrayList<List<Integer>>(); // documents as lists of word ids
    public HashMap<String, Integer> w2i = new HashMap<String, Integer>(); // word -> id
    public HashMap<Integer, String> i2w = new HashMap<Integer, String>(); // id -> word


    public PseudoDocTM(int P,int K,int iter,int innerStep,int saveStep,double alpha1,double alpha2,double beta,String inputPath,String outputPath){
        this.K1=P;
        this.K2=K;
        this.niters=iter;
        this.innerSteps= innerStep;
        this.saveStep =saveStep;
        this.alpha1=alpha1;
        this.alpha2= alpha2;
        this.beta = beta;
        this.inputPath=inputPath;
        this.outputPath=outputPath;
    }
    // load the corpus
    public void loadTxts(String txtPath) {
        BufferedReader reader = IOUtil.getReader(txtPath, "UTF-8");

        String line;
        try {
            line = reader.readLine();
            while (line != null) {
                List<Integer> doc = new ArrayList<Integer>();

                String[] tokens = line.trim().split("\\s+");
                for (String token : tokens) {
                    if (!w2i.containsKey(token)) {
                        w2i.put(token, w2i.size());
                        i2w.put(w2i.get(token), token);
                    }
                    doc.add(w2i.get(token));
                }
                docs.add(doc);
                line = reader.readLine();
            }
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        }

        // number of documents
        M = docs.size();
        // vocabulary size
        V = w2i.size();

        return;
    }
    // initialize the model
    public void initModel() {

        mp = new int[K1];

        npk = new int[K1][K2];
        npkSum = new int[K1];

        nkw = new int[K2][V];
        nkwSum = new int[K2];

        zAssigns_1 = new int[M]; // pseudo-document each document belongs to
        zAssigns_2 = new int[M][]; // topic of each word in each document

        for (int m = 0; m != M; m++) {
            // number of words in the document
            int N = docs.get(m).size();
            // initialize the topic assignment array for this document
            zAssigns_2[m] = new int[N];
            // randomly assign the document to a pseudo-document
            int z1 = (int) Math.floor(Math.random()*K1);
            zAssigns_1[m] = z1;

            mp[z1] ++; // one more document in pseudo-document z1
            // randomly assign a topic to each word
            for (int n = 0; n != N; n++) {
                int w = docs.get(m).get(n);
                int z2 = (int) Math.floor(Math.random()*K2);

                npk[z1][z2] ++;
                npkSum[z1] ++;

                nkw[z2][w] ++;
                nkwSum[z2] ++;

                zAssigns_2[m][n] = z2;
            }
        }
    }
    // sample the pseudo-document that document m belongs to
    public void sampleZ1(int m) {
        int z1 = zAssigns_1[m];  // current pseudo-document of the document
        int N = docs.get(m).size(); // number of words in the document

        mp[z1] --; // remove the document: one fewer document in pseudo-document z1

        Map<Integer, Integer> k2Count = new HashMap<Integer, Integer>();
        for (int n = 0; n != N; n++){ // loop over each word in the document
            int z2 = zAssigns_2[m][n]; // topic assigned to this word
            if (k2Count.containsKey(z2)) { // count how many words of this document belong to each topic
                k2Count.put(z2, k2Count.get(z2)+1);
            } else {
                k2Count.put(z2, 1);
            }

            npk[z1][z2] --;
            npkSum[z1] --;
        }

        double k2Alpha2 = K2 * alpha2;   // K2 * alpha2 in the denominator

        double[] pTable = new double[K1];
        // loop over each candidate pseudo-document
        for (int k = 0; k != K1; k++) {
            double expectTM = 1.0;
            int index = 0;
            // multiply term by term, using the per-topic word counts of this document
            for (int z2 : k2Count.keySet()) {
                int c = k2Count.get(z2);
                for (int i = 0; i != c; i++) {
                    expectTM *= (npk[k][z2] + alpha2 + i) / (k2Alpha2 + npkSum[k] + index);
                    index ++;
                }
            }
            // probability according to the sampling formula
            pTable[k] = (mp[k] + alpha1) / (M + K1 * alpha1) * expectTM;
        }
        // roulette-wheel selection over the cumulative probabilities
        for (int k = 1; k != K1; k++) { // note: k starts from 1 here, not 0
            pTable[k] += pTable[k-1];
        }

        double r = Math.random() * pTable[K1-1];

        for (int k = 0; k != K1; k++) {
            if (pTable[k] > r) {
                z1 = k;
                break;
            }
        }
        // add the counts back for the newly sampled pseudo-document
        mp[z1] ++;
        for (int n =0; n != N; n++) {
            int z2 = zAssigns_2[m][n];
            npk[z1][z2] ++;
            npkSum[z1] ++;
        }

        zAssigns_1[m] = z1;
    }
    // sample the topic of the n-th word in document m
    public void sampleZ2(int m, int n) {

        int z1 = zAssigns_1[m]; // pseudo-document of the document
        int z2 = zAssigns_2[m][n]; // current topic of the n-th word of document m
        int w = docs.get(m).get(n); // word id

        npk[z1][z2] --;  // remove the word: count of topic z2 in pseudo-document z1
        npkSum[z1] --; // total word count of pseudo-document z1
        nkw[z2][w] --; // count of word w under topic z2
        nkwSum[z2] --; // total word count of topic z2

        double VBeta = V * beta; // V * beta in the denominator
        double k2Alpha2 = K2 * alpha2; // K2 * alpha2 in the denominator

        double[] pTable = new double[K2];
        // compute according to the sampling formula; note a small discrepancy with the paper: the denominator of the topic-word term should be written as it is here
        for (int k = 0; k != K2; k++) {
            pTable[k] = (npk[z1][k] + alpha2) / (npkSum[z1] + k2Alpha2) *
                    (nkw[k][w] + beta) / (nkwSum[k] + VBeta);
        }
        // roulette-wheel selection
        for (int k = 1; k != K2; k++) {
            pTable[k] += pTable[k-1];
        }

        double r = Math.random() * pTable[K2-1];

        for (int k = 0; k != K2; k++) {
            if (pTable[k] > r) {
                z2 = k;
                break;
            }
        }
        // add the counts back with the newly sampled topic
        npk[z1][z2] ++;
        npkSum[z1] ++;
        nkw[z2][w] ++;
        nkwSum[z2] ++;

        zAssigns_2[m][n] = z2;
        return;
    }

    public void estimate() {
        long start = 0;
        for (int iter = 0; iter != niters; iter++) {
            start = System.currentTimeMillis();
            System.out.println("PAM4ST Iteration: " + iter + " ...");
            if(iter%this.saveStep==0&&iter!=0&&iter!=this.niters-1){
                this.storeResult(iter);
            }
            // for each document, sample its pseudo-document assignment
            for (int i = 0; i != innerSteps; i++) {
                for (int m = 0; m != M; m++) {
                    this.sampleZ1(m);
                }
            }
            // for each document, sample the topic of every word
            for (int i = 0; i != innerSteps; i++) {
                for (int m = 0; m != M; m++) {
                    int N = docs.get(m).size();
                    for (int n = 0; n != N; n++) {
                        sampleZ2(m, n);
                    }
                }
            }
            System.out.println("cost time:"+(System.currentTimeMillis()-start));
        }
        return;
    }
    // compute the topic distribution of each pseudo-document (analogous to the document-topic distribution in LDA)
    public double[][] computeThetaP() {
        double[][] theta = new double[K1][K2];
        for (int k1 = 0; k1 != K1; k1++) {
            for (int k2 = 0; k2 != K2; k2++) {
                theta[k1][k2] = (npk[k1][k2] + alpha2) / (npkSum[k1] + K2*alpha2);
            }
        }
        return theta;
    }

    public void saveThetaP(String path) throws IOException {
        BufferedWriter writer = IOUtil.getWriter(path);
        double[][] theta = this.computeThetaP();
        for (int k1 = 0; k1 != K1; k1++) {
            for (int k2 = 0; k2 != K2; k2++) {
                writer.append(theta[k1][k2]+" ");
            }
            writer.newLine();
        }
        writer.flush();
        writer.close();
    }

    public void saveZAssigns1(String path) throws IOException {
        BufferedWriter writer = IOUtil.getWriter(path);

        for (int m = 0; m != M; m++) {
            writer.append(zAssigns_1[m]+"\n");
        }

        writer.flush();
        writer.close();
    }
    // compute the topic-word distributions
    public double[][] computePhi() {
        double[][] phi = new double[K2][V];
        for (int k = 0; k != K2; k++) {
            for (int v = 0; v != V; v++) {
                phi[k][v] = (nkw[k][v] + beta) / (nkwSum[k] + V*beta);
            }
        }
        return phi;
    }
    // sort the words of each topic by probability (descending)
    public ArrayList<List<Entry<String, Double>>> sortedTopicWords(
            double[][] phi, int T) {
        ArrayList<List<Entry<String, Double>>> res = new ArrayList<List<Entry<String, Double>>>();
        for (int k = 0; k != T; k++) {
            HashMap<String, Double> term2weight = new HashMap<String, Double>();
            for (String term : w2i.keySet())
                term2weight.put(term, phi[k][w2i.get(term)]);

            List<Entry<String, Double>> pairs = new ArrayList<Entry<String, Double>>(
                    term2weight.entrySet());
            Collections.sort(pairs, new Comparator<Entry<String, Double>>() {
                public int compare(Entry<String, Double> o1,
                        Entry<String, Double> o2) {
                    return (o2.getValue().compareTo(o1.getValue()));
                }
            });
            res.add(pairs);
        }
        return res;
    }


    public void printTopics(String path,int top_n) throws IOException {
        BufferedWriter writer = IOUtil.getWriter(path);
        double[][] phi = computePhi();
        ArrayList<List<Entry<String, Double>>> pairsList = this
                .sortedTopicWords(phi, K2);
        for (int k = 0; k != K2; k++) {
            writer.write("Topic " + k + ":\n");
            for (int i = 0; i != top_n; i++) {
                writer.write(pairsList.get(k).get(i).getKey() + " "
                        + pairsList.get(k).get(i).getValue()+"\n");
            }
        }
        writer.close();
    }

    public void savePhi(String path) {
        BufferedWriter writer = IOUtil.getWriter(path, "utf-8");

        double[][] phi = computePhi();
        int K = phi.length;
        assert K > 0;
        int V = phi[0].length;

        try {
            for (int k = 0; k != K; k++) {
                for (int v = 0; v != V; v++) {
                    writer.append(phi[k][v]+" ");
                }
                writer.append("\n");
            }
            writer.flush();
            writer.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return;
    }

    public void saveWordmap(String path) {
        BufferedWriter writer = IOUtil.getWriter(path, "utf-8");

        try {
            for (String word : w2i.keySet())
                writer.append(word + "\t" + w2i.get(word) + "\n");

            writer.flush();
            writer.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return;
    }

    public void saveAssign(String path){
        BufferedWriter writer = IOUtil.getWriter(path, "utf-8");
        try {
            for(int i=0;i<zAssigns_2.length;i++){
                for(int j=0;j<zAssigns_2[i].length;j++){
                    writer.write(docs.get(i).get(j)+":"+zAssigns_2[i][j]+" ");
                }
                writer.write("\n");
            }
            writer.flush();
            writer.close();
        } catch (IOException e) {
            e.printStackTrace();
        }

        return;
    }
    public void printModel(){
        System.out.println("\tK1 :"+this.K1+
                "\tK2 :"+this.K2+
                "\tniters :"+this.niters+
                "\tinnerSteps :"+this.innerSteps+
                "\tsaveStep :"+this.saveStep +
                "\talpha1 :"+this.alpha1+
                "\talpha2 :"+this.alpha2+
                "\tbeta :"+this.beta +
                "\tinputPath :"+this.inputPath+
                "\toutputPath :"+this.outputPath);
    }

    int[][] ndk;
    int[] ndkSum;

    public void convert_zassigns_to_arrays_theta(){
        ndk = new int[M][K2];
        ndkSum = new int[M];

        for (int m = 0; m != M; m++) {
            for (int n = 0; n != docs.get(m).size(); n++) {
                ndk[m][zAssigns_2[m][n]] ++;
                ndkSum[m] ++;
            }
        }
    }
    // compute the document-topic distributions from the topic assignments
    public double[][] computeTheta() {
        convert_zassigns_to_arrays_theta();
        double[][] theta = new double[M][K2];
        for (int m = 0; m != M; m++) {
            for (int k = 0; k != K2; k++) {
                theta[m][k] = (ndk[m][k] + alpha2) / (ndkSum[m] + K2 * alpha2);
            }
        }
        return theta;
    }

    public void saveTheta(String path) {
        BufferedWriter writer = IOUtil.getWriter(path, "utf-8");

        double[][] theta = computeTheta();
        try {
            for (int m = 0; m != M; m++) {
                for (int k = 0; k != K2; k++) {
                    writer.append(theta[m][k]+" ");
                }
                writer.append("\n");
            }
            writer.flush();
            writer.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return;
    }

    public void storeResult(int times){
        String appendString="final";
        if(times!=0){
            appendString =times+"";
        }
        try {
            this.printTopics(outputPath+"/model-"+appendString+".twords",20);
            this.saveWordmap(outputPath+"/wordmap.txt");
            this.savePhi(outputPath+"/model-"+appendString+".phi");
            this.saveAssign(outputPath+"/model-"+appendString+".tassign");
            this.saveTheta(outputPath+"/model-"+appendString+".theta");
            this.saveThetaP(outputPath+"/model-"+appendString+".thetap");
            this.saveZAssigns1(outputPath+"/model-"+appendString+".assign1");
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
    }
    public void run() {
        printModel();
        this.loadTxts(inputPath); // load the corpus
        this.initModel(); // initialize the model
        this.estimate(); // run Gibbs sampling
        this.storeResult(0); // save the results

    }


    public static void PseudoDocTM(int P,int K,int iter,int innerStep,int saveStep,double alpha1,double alpha2,double beta,int threadNum,String path){
        File trainFile = new File(path);
        String parent_path = trainFile.getParentFile().getAbsolutePath();
        (new File(parent_path+"/PTM_with_case_"+P+"_"+K+"_"+iter+"_"+alpha1+"_"+alpha2+"_"+beta+"/")).mkdirs();
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        (new PseudoDocTM(P,K,iter,innerStep,saveStep,alpha1,alpha2,beta,path,parent_path+"/PTM_with_case_"+P+"_"+K+"_"+iter+"_"+alpha1+"_"+alpha2+"_"+beta)).run();

    }
}
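As a usage sketch (the corpus path and parameter values below are only illustrative, and IOUtil from the original project must be on the classpath), the model can be launched through the static helper at the bottom of the class:

package main;

public class RunPTM {
    public static void main(String[] args) {
        // 1000 pseudo-documents, 100 topics, 200 iterations, 10 inner Gibbs sweeps,
        // save every 1000 iterations, alpha1 = 0.1, alpha2 = 0.1, beta = 0.01;
        // the threadNum argument is not used by the static helper; the path is illustrative
        PseudoDocTM.PseudoDocTM(1000, 100, 200, 10, 1000, 0.1, 0.1, 0.01, 1,
                "data/short_texts.txt");
    }
}

The output files (.twords, .phi, .theta, .thetap, .tassign, .assign1 and wordmap.txt) are written to a folder named PTM_with_case_... created next to the input file.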

Originally published at blog.csdn.net/qy20115549/article/details/79877825