Scala 余弦相似度

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer


/**
  * Demo of character-level cosine similarity between two strings.
  *
  * Each string is turned into a count vector over the union of characters
  * found in both strings, and the cosine of the angle between the two count
  * vectors is reported.
  */
object test423_cosvec {
  def main(args: Array[String]): Unit = {
    val str1 = "听说菠萝就是凤梨"
    val str2 = "凤梨肯定不会是菠萝"

    val result = textCosine(str1, str2)
    // The value is a similarity (1.0 = same character distribution), not a distance.
    println("两句话的余弦相似度: " + result)
  }

  /**
    * Euclidean (L2) norm of a vector.
    *
    * @param vec the vector
    * @return square root of the sum of squared components; 0.0 for an empty vector
    */
  def module(vec: Vector[Double]): Double =
    math.sqrt(vec.map(math.pow(_, 2)).sum)

  /**
    * Inner (dot) product of two vectors.
    *
    * @param v1 first vector
    * @param v2 second vector, same length as `v1`
    * @return sum of the pairwise products
    * @throws IllegalArgumentException if the vectors differ in length
    */
  def innerProduct(v1: Vector[Double], v2: Vector[Double]): Double = {
    // A length mismatch would make the dot product and the norms (which use
    // full lengths) inconsistent, so reject it instead of silently truncating.
    require(v1.length == v2.length,
      s"vector lengths differ: ${v1.length} vs ${v2.length}")
    v1.zip(v2).map { case (a, b) => a * b }.sum
  }

  /**
    * Cosine of the angle between two vectors, clamped to [-1, 1].
    *
    * The cosine is undefined when either vector has zero norm (all zeros or
    * empty); 0.0 is returned in that case. Previously 0/0 produced NaN, and
    * `NaN <= 1` being false made the method return 1.0.
    *
    * @param v1 first vector
    * @param v2 second vector, same length as `v1`
    * @return cosine similarity in [-1, 1], or 0.0 if either norm is zero
    * @throws IllegalArgumentException if the vectors differ in length
    */
  def cosvec(v1: Vector[Double], v2: Vector[Double]): Double = {
    // Computed first so the length check runs even for zero-norm inputs.
    val dot = innerProduct(v1, v2)
    val denominator = module(v1) * module(v2)
    if (denominator == 0.0) 0.0
    else {
      // Floating-point rounding can push the ratio slightly outside [-1, 1].
      math.max(-1.0, math.min(1.0, dot / denominator))
    }
  }

  /**
    * Character-level cosine similarity of two strings.
    *
    * Builds a count vector for each string over the sorted union of the
    * `Char`s (UTF-16 code units) occurring in either string, then returns
    * their cosine. Counts are non-negative, so the result lies in [0, 1].
    * It is 0.0 when either string is empty.
    *
    * Also prints the character set and both count vectors to stdout.
    *
    * @param str1 first text
    * @param str2 second text
    * @return cosine similarity of the two character-count vectors
    */
  def textCosine(str1: String, str2: String): Double = {
    val set = mutable.Set[Char]() // every character appearing in either sentence
    str1.foreach(set += _)
    str2.foreach(set += _)
    println(set)
    // Sort once so both vectors share the same deterministic axis order.
    val vocabulary = set.toList.sorted
    val ints1: Vector[Double] = countVector(str1, vocabulary)
    println("===ints1: " + ints1)
    val ints2: Vector[Double] = countVector(str2, vocabulary)
    println("===ints2: " + ints2)
    cosvec(ints1, ints2)
  }

  /** Occurrence count of each vocabulary character in `text`, in vocabulary order. */
  private def countVector(text: String, vocabulary: List[Char]): Vector[Double] = {
    // One pass over the text instead of one rescan per vocabulary character.
    val counts = text.groupBy(identity).map { case (ch, occurrences) => ch -> occurrences.length }
    vocabulary.map(ch => counts.getOrElse(ch, 0).toDouble).toVector
  }

}

发布了53 篇原创文章 · 获赞 40 · 访问量 4万+

猜你喜欢

转载自blog.csdn.net/u012761191/article/details/105311343
今日推荐