Why does Spark's Word2Vec return a vector?

Mehran :

Running Spark's example for Word2Vec, I realized that it takes in an array of strings and gives out a single vector. My question is: shouldn't it return a matrix instead of a vector? I was expecting one vector per input word, but it returns one vector, period!

Or maybe it should have accepted a single string (one word) as input instead of an array of strings; then, sure, it could return one vector as output. But accepting an array of strings and returning one single vector makes no sense to me.

[UPDATE]

Per @Shaido's request, here's the code with my minor change to print the schema for the output:

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.Word2Vec;
import org.apache.spark.ml.feature.Word2VecModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class JavaWord2VecExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("JavaWord2VecExample")
                .getOrCreate();

        // $example on$
        // Input data: Each row is a bag of words from a sentence or document.
        List<Row> data = Arrays.asList(
                RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
                RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
                RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))
        );
        StructType schema = new StructType(new StructField[]{
                new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
        });
        Dataset<Row> documentDF = spark.createDataFrame(data, schema);

        // Learn a mapping from words to Vectors.
        Word2Vec word2Vec = new Word2Vec()
                .setInputCol("text")
                .setOutputCol("result")
                .setVectorSize(7)
                .setMinCount(0);

        Word2VecModel model = word2Vec.fit(documentDF);
        Dataset<Row> result = model.transform(documentDF);

        for (Row row : result.collectAsList()) {
            List<String> text = row.getList(0);
            System.out.println("Schema: " + row.schema());
            Vector vector = (Vector) row.get(1);
            System.out.println("Text: " + text + " => \nVector: " + vector + "\n");
        }
        // $example off$

        spark.stop();
    }
}

And it prints:

Schema: StructType(StructField(text,ArrayType(StringType,true),false), StructField(result,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true))
Text: [Hi, I, heard, about, Spark] => 
Vector: [-0.0033279924420639875,-0.0024428479373455048,0.01406305879354477,0.030621735751628878,0.00792500376701355,0.02839711122214794,-0.02286271695047617]

Schema: StructType(StructField(text,ArrayType(StringType,true),false), StructField(result,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true))
Text: [I, wish, Java, could, use, case, classes] => 
Vector: [-9.96453288410391E-4,-0.013741840076233658,0.013064394239336252,-0.01155538750546319,-0.010510949650779366,0.004538436819400106,-0.0036846946126648356]

Schema: StructType(StructField(text,ArrayType(StringType,true),false), StructField(result,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true))
Text: [Logistic, regression, models, are, neat] => 
Vector: [0.012510885251685977,-0.014472834207117558,0.002779599279165268,0.0022389178164303304,0.012743516173213721,-0.02409198731184006,0.017409833287820222]

Please correct me if I'm wrong, but the input is an array of strings and the output is a single vector, whereas I was expecting each word to be mapped to a vector of its own.
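
For reference, the model itself does hold one vector per word: Word2VecModel.getVectors() returns them as a DataFrame with one row per vocabulary word, which makes it even more puzzling to me that transform() collapses them into a single vector. A small addition to the code above shows this (the wordVectors variable name is mine):

// Each vocabulary word has its own 7-dimensional vector in the model;
// getVectors() exposes them as a DataFrame with "word" and "vector" columns.
Dataset<Row> wordVectors = model.getVectors();
wordVectors.show(false);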

desertnaut :

This is an attempt to justify Spark's rationale here, and it should be read as a complement to the nice programming explanation already provided in another answer...

To start with, how exactly individual word embeddings should be combined is not, in principle, a feature of the Word2Vec model itself (which is about, well, individual words), but an issue of concern to "higher-order" models, such as Sentence2Vec, Paragraph2Vec, Doc2Vec, Wikipedia2Vec, etc. (you could name a few more, I guess...).

Having said that, it turns out that a very first approach to combining word vectors in order to get vector representations of larger pieces of text (phrases, sentences, tweets, etc.) is indeed to simply average the vector representations of the constituent words, which is exactly what Spark ML does.
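
To see that this is literally what Spark does, here is a minimal sketch (the class and method names are mine, and it assumes the model and documentDF objects from the question's code): it rebuilds each document vector as the plain average of the per-word vectors exposed by getVectors(), and prints it next to the output of transform(), which should match.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.spark.ml.feature.Word2VecModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class AveragingCheck {
    // Recompute each document vector by hand and print it next to
    // what Word2VecModel.transform() produced.
    static void compare(Word2VecModel model, Dataset<Row> documentDF, int vectorSize) {
        // word -> embedding lookup, taken from the fitted model
        Map<String, Vector> vocab = new HashMap<>();
        for (Row r : model.getVectors().collectAsList()) {
            vocab.put(r.getString(0), (Vector) r.get(1));
        }
        for (Row row : model.transform(documentDF).collectAsList()) {
            List<String> words = row.getList(0);
            double[] avg = new double[vectorSize];
            for (String w : words) {
                // words absent from the vocabulary contribute a zero vector;
                // with minCount = 0 every word is present anyway
                double[] v = vocab.containsKey(w) ? vocab.get(w).toArray()
                                                  : new double[vectorSize];
                for (int i = 0; i < vectorSize; i++) {
                    avg[i] += v[i] / words.size();
                }
            }
            System.out.println("manual average: " + Arrays.toString(avg));
            System.out.println("transform():    " + row.get(1));
        }
    }
}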

Starting from the practitioner community, we have:

How to concatenate word vectors to form sentence vector (SO answer):

There are at least three common ways to combine embedding vectors; (a) summing, (b) summing & averaging or (c) concatenating. [...] See gensim.models.doc2vec.Doc2Vec, dm_concat and dm_mean - it allows you to use any of those three options
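
For concreteness, those three options look roughly like this on raw double[] embeddings (a toy sketch with made-up method names; gensim's actual implementations differ in detail):

// (a) element-wise sum of the word vectors
static double[] sum(double[][] wordVecs) {
    double[] out = new double[wordVecs[0].length];
    for (double[] v : wordVecs)
        for (int i = 0; i < out.length; i++) out[i] += v[i];
    return out;
}

// (b) sum followed by averaging, which is what Spark's Word2VecModel.transform() does
static double[] average(double[][] wordVecs) {
    double[] out = sum(wordVecs);
    for (int i = 0; i < out.length; i++) out[i] /= wordVecs.length;
    return out;
}

// (c) concatenation; note the output length now depends on the sentence length
static double[] concatenate(double[][] wordVecs) {
    int d = wordVecs[0].length;
    double[] out = new double[wordVecs.length * d];
    for (int w = 0; w < wordVecs.length; w++)
        System.arraycopy(wordVecs[w], 0, out, w * d, d);
    return out;
}

Option (c) makes the dimensionality depend on the sentence length, which is one practical reason a fixed-size average is the natural default for a DataFrame column like Spark's.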

Sentence2Vec : Evaluation of popular theories — Part I (Simple average of word vectors) (blog post):

So what’s first thing that comes to your mind when you have word vectors and need to calculate sentence vector.

Just average them?

Yes that’s what we are going to do here.

Sentence2Vec (Github repo):

Word2Vec can help to find other words with similar semantic meaning. However, Word2Vec can only take 1 word each time, while a sentence consists of multiple words. To solve this, I write the Sentence2Vec, which is actually a wrapper to Word2Vec. To obtain the vector of a sentence, I simply get the averaged vector sum of each word in the sentence.

It certainly seems that, at least for practitioners, this simple averaging of the individual word vectors is far from unexpected.

An expected counter-argument here is that blog posts and SO answers are arguably not the most credible of sources; what about the researchers and the relevant scientific literature? Well, it turns out that this simple averaging is far from uncommon there, too:

From Distributed Representations of Sentences and Documents (Le & Mikolov, Google, ICML 2014):

[image: quoted excerpt from the paper]

From NILC-USP at SemEval-2017 Task 4: A Multi-view Ensemble for Twitter Sentiment Analysis (SemEval 2017, section 2.1.2):

[image: quoted excerpt from the paper]


It should be clear by now that the particular design choice in Spark ML is far from arbitrary, or even uncommon; I have blogged about what certainly seem to be absurd design choices in Spark ML (see Classification in Spark 2.0: "Input validation failed" and other wondrous tales), but it seems that this is not such a case...
