Training a Word2Vec model with DL4J

       Deep learning is attracting more and more attention, and new deep learning frameworks keep appearing. Google's TensorFlow, for example, is used mainly through its Python API, which can be inconvenient for programmers who are not comfortable with Python. This post looks at a Java-based deep learning framework, DL4J (Deeplearning4j), and walks through the code for training a Word2Vec model with it. Let's take a look~


【Code】

package com.xzw.dl4j;

import java.io.File;
import java.io.IOException;
import java.util.Collection;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.EndingPreProcessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
/**
 * Trains a Word2Vec model with DL4J, saves the word vectors, and runs two quick checks.
 *
 * @author xzw
 */
public class Word2VecTest {
	@SuppressWarnings("deprecation")
	public static void main(String[] args) throws IOException {
		System.out.println("Load data...");
		// Adjust this path to wherever your copy of the corpus lives
		File file = new File("C:/Users/Machenike/Desktop/zzz/raw_sentences.txt");
		SentenceIterator iterator = new LineSentenceIterator(file);
		iterator.setPreProcessor(new SentencePreProcessor() {
			
			private static final long serialVersionUID = 1L;

			@Override
			public String preProcess(String sentence) {
				// Lowercase each sentence before it is tokenized
				return sentence.toLowerCase();
			}
		});
		System.out.println("Tokenize data...");
		final EndingPreProcessor preProcessor = new EndingPreProcessor();
		TokenizerFactory tokenizer = new DefaultTokenizerFactory();
		tokenizer.setTokenPreProcessor(new TokenPreProcess() {
			
			@Override
			public String preProcess(String token) {
				// Lowercase the token, strip common word endings, and map every digit to "d"
				token = token.toLowerCase();
				String base = preProcessor.preProcess(token);
				base = base.replaceAll("\\d", "d");
				return base;
			}
		});
		
		System.out.println("Build model...");
		int batchSize = 1000;  // mini-batch size
		int iterations = 3;    // training iterations per mini-batch
		int layerSize = 150;   // dimensionality of each word vector
		Word2Vec vec = new Word2Vec.Builder()
			.batchSize(batchSize)
			.minWordFrequency(5)
			.useAdaGrad(false)
			.layerSize(layerSize)
			.iterations(iterations)
			.learningRate(0.025)
			.minLearningRate(1e-3)
			.negativeSample(10)
			.iterate(iterator)
			.tokenizerFactory(tokenizer)
			.build();
		// Train the model
		System.out.println("Learning...");
		vec.fit();
		// Save the word vectors; this call is why main() carries @SuppressWarnings("deprecation")
		System.out.println("Save model...");
		WordVectorSerializer.writeWordVectors(vec, "C:/Users/Machenike/Desktop/zzz/words.txt");
		System.out.println("Evaluate model...");
		String word1 = "people";
		String word2 = "money";
		double similarity = vec.similarity(word1, word2);
		System.out.println(String.format("The similarity between %s and %s is %f", word1, word2, similarity));
		String word = "day";
		int ranking = 10;
		Collection<String> similarTop10 = vec.wordsNearest(word, ranking);
		System.out.println(String.format("Similar word to %s is %s", word, similarTop10));
	}

}
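
To make the `vec.similarity(word1, word2)` call in the listing above less of a black box: as far as I know, DL4J scores two words by the cosine similarity of their vectors. The sketch below shows that calculation in plain Java. The class name `CosineSketch` and the sample vectors are made up for illustration and are not part of DL4J.

```java
class CosineSketch {

    // Cosine similarity: dot(a, b) / (|a| * |b|), in the range [-1, 1].
    static double cosine(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        // Treat a zero vector as having no similarity to anything.
        if (normA == 0.0 || normB == 0.0) {
            return 0.0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        // Hypothetical 3-dimensional vectors; the real model uses layerSize = 150.
        double[] people = {0.2, 0.7, -0.1};
        double[] money = {0.1, 0.5, 0.3};
        System.out.println(cosine(people, money));
    }
}
```

A score near 1 means the two vectors point in almost the same direction, a score near 0 means they are unrelated, and a negative score means they point in opposite directions.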

【Dataset used】


【Saved Word2Vec model】
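
If you want to inspect the saved file without DL4J, a small parser is enough. The sketch below assumes each line of `words.txt` holds a word followed by its whitespace-separated vector components; that layout is an assumption about the text format, so check it against your own file first. The class name `VectorTextSketch` and the sample text are hypothetical, and header lines or encoded words are not handled.

```java
import java.util.LinkedHashMap;
import java.util.Map;

class VectorTextSketch {

    // Parses lines of the assumed form "word v1 v2 ... vn" into a word -> vector map.
    static Map<String, double[]> parse(String text) {
        Map<String, double[]> vectors = new LinkedHashMap<>();
        for (String line : text.split("\\R")) {
            String[] parts = line.trim().split("\\s+");
            if (parts.length < 2) {
                continue; // skip blank lines and lines without any components
            }
            double[] vector = new double[parts.length - 1];
            for (int i = 1; i < parts.length; i++) {
                vector[i - 1] = Double.parseDouble(parts[i]);
            }
            vectors.put(parts[0], vector);
        }
        return vectors;
    }

    public static void main(String[] args) {
        // Made-up sample; a real file would have layerSize components per line.
        String sample = "day 0.12 -0.40 0.33\nnight 0.10 -0.38 0.29\n";
        Map<String, double[]> vectors = parse(sample);
        System.out.println(vectors.keySet());
        System.out.println(vectors.get("day").length);
    }
}
```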


【Run result】
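
The program prints a similarity score for "people" and "money" and the 10 words nearest to "day". For intuition about the second query, here is a minimal sketch of a nearest-words lookup that ranks every other word by cosine similarity to the query word. It is only an illustration of the idea: DL4J's `wordsNearest` may be implemented differently, and the class name `NearestSketch` and the toy 2-dimensional vectors are assumptions.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class NearestSketch {

    // Cosine similarity of two equal-length vectors; 0 if either is a zero vector.
    static double cosine(double[] a, double[] b) {
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return (normA == 0.0 || normB == 0.0) ? 0.0 : dot / Math.sqrt(normA * normB);
    }

    // Returns up to n words, most similar first, excluding the query word itself.
    static List<String> nearest(Map<String, double[]> vectors, String word, int n) {
        List<String> candidates = new ArrayList<>();
        double[] target = vectors.get(word);
        if (target == null) {
            return candidates; // unknown word: nothing to rank
        }
        for (String other : vectors.keySet()) {
            if (!other.equals(word)) {
                candidates.add(other);
            }
        }
        // Sort by descending similarity to the query vector.
        candidates.sort((x, y) -> Double.compare(
                cosine(vectors.get(y), target), cosine(vectors.get(x), target)));
        return new ArrayList<>(candidates.subList(0, Math.min(n, candidates.size())));
    }

    public static void main(String[] args) {
        Map<String, double[]> vectors = new LinkedHashMap<>();
        vectors.put("day", new double[]{1.0, 0.0});
        vectors.put("night", new double[]{0.9, 0.1});
        vectors.put("money", new double[]{0.0, 1.0});
        vectors.put("week", new double[]{0.8, 0.3});
        System.out.println(nearest(vectors, "day", 2));
    }
}
```

This brute-force ranking compares the query against the whole vocabulary, which is fine for a small corpus like the one used here.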







