以图搜图引擎 With Saprk

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/lingzidong/article/details/86681560

最近搞了一个SX搜索图片相似度的系统,非常的简单。主要原理是这样的
1.用Phash算法计算出每二个图片的Phash值,存在CSV中
2.用Spark读入CSV,并且计算出要搜索的图片的Hash值
3.将这个值广播出去,然后求一个hamming距离的最大值
代码如下,在我的gayhub中也有limn2o4’s github

import cv2
import numpy as np
import phash
import os
import csv

def get_pHash(img):
    
    
    img_gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
    img_resize = cv2.resize(img_gray,(64,64),interpolation=cv2.INTER_CUBIC)
    
    
    h,w = img_resize.shape[:2]
    img_float = np.zeros((h,w),np.float32)
    img_float[:h,:w] = img_resize
    img_dct = cv2.dct(cv2.dct(img_float))
    
    img_dct = cv2.resize(img_dct,(32,32),interpolation=cv2.INTER_CUBIC)

    num_list = img_dct.flatten()
    
    num_avg = sum(num_list)/len(num_list)
    bin_list = ['0' if i < num_avg else '1' for i in num_list]
    #print(''.join(bin_list))
    return ''.join(['%x' % int(''.join(bin_list[x:x+4]),2) for x in range(0,32*32,4)])



if __name__ == '__main__':
    
    
    root_path = '/home/limn2o4/Documents/jpg/'

    path_list = os.listdir(root_path)

    with open('img_data.csv','w') as f:
        csv_writter = csv.writer(f)
        for img_name in path_list :
            print(img_name)    
            img = cv2.imread(root_path+img_name)
            if img.all() == None:
                raise ValueError("wrong img data")
            csv_writter.writerow([img_name,get_pHash(img)])

下面的代码有点短,但是……
我想这才是spark的精髓,用最短的操作实现功能

import findspark

findspark.init()

from pyspark import SparkConf,SparkContext,sql
from phash import get_pHash
import csv
from io import StringIO
import cv2



if __name__ == '__main__':

    conf = SparkConf().setMaster('local[*]').setAppName("imgSearch")
    sc = SparkContext(conf=conf)
    sqlContext = sql.SQLContext(sc)
    img_data_df = sqlContext.read.csv('/home/limn2o4/Documents/Code/SparkImgSImg/img_data.csv')
    #转换成rdd,方便进行mapValues操作
    img_data = img_data_df.rdd.map(lambda p : (p._c0,p._c1))
    target_img = cv2.imread('/home/limn2o4/Documents/jpg/100503.jpg')
    target_hash = get_pHash(target_img)
    #print(target_hash)
    search_hash = sc.broadcast(target_hash)

    def hamming_distance(str1):
        assert(len(str1)==len(search_hash.value))
        return sum([ch1 != ch2 for ch1,ch2 in zip(str1,search_hash.value)])
    
    #print(img_data.take(3))
    dist_rdd = img_data.mapValues(hamming_distance)
    sort_rdd = dist_rdd.sortBy(lambda x : x[1],ascending=True)
    #print(sort_rdd.take(10))
    sort_rdd.saveAsTextFile('result')

参考:
《Spark快速大数据分析》
https://blog.csdn.net/sinat_26917383/article/details/70287521

猜你喜欢

转载自blog.csdn.net/lingzidong/article/details/86681560