Reading the geo_teaser code

Copyright notice: shared for learning and exchange, in the spirit of mutual progress. https://blog.csdn.net/u013735511/article/details/80158854

geo_teaser is the code for the paper Geo-Teaser: Geo-Temporal Sequential Embedding Rank
for Point-of-interest Recommendation.
The paper studies POI recommendation based on geo-temporal sequences; the paper itself is not covered here. This post instead walks through the relevant geo_teaser code.

Data format

Take Gowalla as an example: http://snap.stanford.edu/data/loc-gowalla.html
[Figure: a sample of the check-in data; each line contains a user id, a check-in timestamp, latitude, longitude, and a location id.]
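
To make the format concrete, the snippet below is a minimal parsing sketch of my own (not part of geo_teaser); the sample line is invented. Note that the raw SNAP file is tab-separated, while the code in this post splits on single spaces, so some preprocessing of the raw file is implied.

import time

# A made-up check-in line in the space-separated form the Vocab constructor expects:
# user id, check-in time, latitude, longitude, POI (location) id
line = "0 2010-10-19T23:55:27Z 30.2359 -97.7951 22847"
tokens = line.strip().split(' ')

user = tokens[0]                               # user id
time_str = tokens[1].split('T')[0]             # keep only the date part
lat, lon = float(tokens[2]), float(tokens[3])  # latitude, longitude
poi = tokens[-1]                               # POI id

wday = time.strptime(time_str, '%Y-%m-%d').tm_wday
wday = 0 if wday < 5 else 1                    # 0 = weekday, 1 = weekend
print("%s %s %f %f %s %d" % (user, time_str, lat, lon, poi, wday))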

Wrapping the dataset

First, the wrapper for a single vocabulary item. As shown below, an item holds a word (a user or POI id), a count (how many times the POI occurs), a path and a Huffman code (used for hierarchical softmax), and item_set (the set of POIs).

class VocabItem:
    def __init__(self, word):
        self.word = word
        self.count = 0
        self.path = None # Path (list of indices) from the root to the word (leaf)
        self.code = None # Huffman encoding
        #self.user_set={}
        self.item_set={}  # keyed by POI index; the value is the rating
       # self.combination={}

Next, the wrapper for the whole dataset. The constructor reads from the file, line by line, the user id (user), the POI id, the time (time_str), and the latitude and longitude (lat, lon).

class Vocab:
   def __init__(self, fi, pair, comb, min_count,percentage):
   # def __init__(self,percentage):

        vocab_items = {}       # dict: user index -> VocabItem
        vocab_4table=[]        # list of VocabItem, indexed by POI index
        user_hash={}           # dict: user id -> user index
        vocab_hash = {}        # dict: POI id -> POI index
        poi_track={}           # dict: trajectory index -> user index
        poi_time_track = {}    # dict: trajectory index -> weekday flag
        poi_list=[]            # list of trajectories, each a list of POI indices
        loc_geo = {}           # dict: POI index -> (lat, lon)
        seq_t = {}             # dict; unused in this excerpt
        rating_count = 0       
        # fi = open(u'E:\\论文\\poi\\poi数据集\\gowalla数据集\\1.txt', 'r')
        fi=open(fi,'r')
        poi_per_track=[]   # the trajectory currently being accumulated
        pre_user=-1
        date='-1'
        user_count = 0

        for line in fi:
            # line=line[:-1]
            # print line
            # if user_count == 10:
            #    break
            line=line.strip()
            # tokens = line.split(',')
            tokens=line.split(' ')
            # if len(tokens)==1:

            user_count += 1
            user=tokens[0]
            user_hash[user]=user_hash.get(user, int(len(user_hash)))
            user=user_hash[user]
            # print user
                # print "u", user
                # line=fi.next()
                # print line

            token=tokens[-1]   # POI id
            time_str=tokens[1].split('T')[0]
            lat = float(tokens[2])   # latitude
            lon = float(tokens[3])   # longitude
            wday = time.strptime(time_str, '%Y-%m-%d').tm_wday   # date string -> day of week
            wday = 0 if wday < 5 else 1  # 0 for weekdays, 1 for weekends
            rating=1
            if date!=time_str or user!=pre_user:  # a new date or a new user closes the current trajectory
                if len(poi_per_track)!=0:
                    poi_track[len(poi_list)]=pre_user
                    poi_time_track[len(poi_list)] = wday   # caveat: wday was just recomputed from the current line, i.e. the next trajectory's date
                    poi_list.append(poi_per_track)

                    # print poi_per_track
                pre_user=user
                date=time_str
                poi_per_track=[]
            if token not in vocab_hash: # first occurrence of this POI id
                vocab_hash[token] = vocab_hash.get(token,int(len(vocab_hash))) # POI id -> POI index
                vocab_4table.append(VocabItem(token))
            token=vocab_hash[token]   # from here on, token is the POI index
            # print token
            if token not in loc_geo:
                loc_geo[token] = (lat,lon)    # POI index -> (lat, lon)
            vocab_4table[token].count+=1
            # print vocab_4table[token].count
            poi_per_track.append(token)  # extend the open trajectory
            vocab_items[user]=vocab_items.get(user,VocabItem(user)) # user index -> VocabItem
            vocab_items[user].item_set[token]=rating    # rating 1 marks a training check-in
            rating_count += 1
            if rating_count % 10000 == 0:
                sys.stdout.write("\rReading ratings %d" % rating_count)
                sys.stdout.flush()
            #if rating_count>10000:
            #    break
            # Add special tokens <bol> (beginning of line) and <eol> (end of line)
        fi.close()
        sys.stdout.write("%s reading completed\n" % fi)
        self.vocab_items = vocab_items         # dict: user index -> VocabItem
        self.vocab_hash = vocab_hash           # dict: POI id -> POI index
        self.rating_count = rating_count       # total number of check-ins read
        self.user_count=len(user_hash.keys())  # number of distinct users
        self.item_count=len(vocab_hash.keys())
        self.poi_track=poi_track
        self.poi_list=poi_list
        self.poi_time_track = poi_time_track
        self.vocab_4table=vocab_4table
        self.loc_geo = loc_geo
        self.test_data = {}
         # Add special token <unk> (unknown),
        # merge words occurring less than min_count into <unk>, and
        # sort vocab in descending order by frequency in train file
        print('num_poi_track: ', len(self.poi_track))
        print('num_poi_time_track', len(self.poi_time_track))
#        self.__sort(min_count)
#         print (self.vocab_hash)
        f=open('./gowalla_loc.hash','w')
        for x in self.vocab_hash:
            f.write(str(x)+' '+str(self.vocab_hash[x])+'\n')
        f.close()
        print(self.rating_count)   # total number of check-ins
        print(self.rating_count*percentage)
        self.split(percentage)
        #assert self.word_count == sum([t.count for t in self.vocab_items]), 'word_count and sum of t.count do not agree'
        print('Total user in training file: %d' % self.user_count)
        print('Total item in training file: %d' % self.item_count)
        print('Total rating in file: %d' % self.rating_count)
        print('Total POI tracking (day): %d' % len(self.poi_track.keys()))
        print(len(poi_list))
#        print 'Total raiting in testing set: %d' % self.rating_count-int(self.rating_count*percentage)
        #print 'Vocab size: %d' % len(self)

Now let's look at how these data are organized:
1. user_hash maps each user id to a sequential index; essentially a simple hash.
2. poi_list holds the POI trajectories, one per user and date (the weekday/weekend distinction is recorded separately).
3. poi_track and poi_time_track map a trajectory's position in poi_list to its user and its weekday flag.
4. vocab_hash maps each POI id to a sequential index, analogous to user_hash.
5. vocab_4table is a list whose elements are VocabItem objects, indexed by POI index.
6. loc_geo is a dict keyed by POI index whose values are (latitude, longitude) pairs.
7. poi_per_track accumulates one user's POI trajectory within a single date; note that the trajectory still open when the file ends is never appended to poi_list.
8. vocab_items is a dict keyed by user index with a VocabItem as value, where the VocabItem's word is the user id.
9. rating is a flag: 1 marks training data, and 11 (after split() adds 10) marks test data.
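
To make this bookkeeping concrete, here is a hand-traced example of my own; the POI ids and dates are invented. Suppose the file contains three check-ins by user 0: POIs a and b on Monday 2010-10-18, then POI c on Saturday 2010-10-23. Reading the third line triggers the flush of the Monday trajectory, leaving:

# state after the third line is processed:
vocab_hash      == {'a': 0, 'b': 1, 'c': 2}   # POI id -> POI index
user_hash       == {'0': 0}                   # user id -> user index
poi_list        == [[0, 1]]                   # the flushed Monday trajectory
poi_track       == {0: 0}                     # trajectory 0 belongs to user 0
poi_time_track  == {0: 1}                     # the caveat above: wday came from the
                                              # Saturday line, so this weekday
                                              # trajectory is labeled as weekend
loc_geo         == {0: (lat_a, lon_a), 1: (lat_b, lon_b), 2: (lat_c, lon_c)}
poi_per_track   == [2]                        # the Saturday trajectory, still open,
                                              # and dropped if the file ends here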

Splitting the dataset

Input: percentage is the train/test split ratio, 0.8 by default.

The algorithm uses random draws to pick POIs from each user's item_set and marks them as test data (by adding 10 to the rating). The marked POIs are then removed from the training trajectories.

   def split(self,percentage):
        cur_test=0
        test_case=(1-percentage)*self.rating_count   # number of check-ins to mark as test data
        print('Test case: ', test_case)
        #print test_case
        for user in self.vocab_items.keys():
            if len(self.vocab_items[user].item_set.keys())<5:    # skip users with fewer than 5 POIs
                continue
            if cur_test>=test_case:   # stop once enough test cases have been drawn
                break
            for item in self.vocab_items[user].item_set.keys():
                if cur_test<test_case and np.random.random()>percentage:    # np.random.random() draws a uniform value in [0, 1)
                    cur_test+=1
                    self.vocab_items[user].item_set[item]+=10     # rating becomes 11: marked as test data
        seq_out=('./gowalla_loc.seq.out')
        f=open(seq_out,'w')
        for i in range(len(self.poi_track)):
            try:
                user=self.poi_track[i]           # the trajectory's user
                w_day = self.poi_time_track[i]   # the trajectory's weekday flag
            except:
                print(self.poi_track[i])
                print(self.poi_time_track[i])
            if user not in self.test_data:     # initialize the user's per-slot test lists
                self.test_data[user] = {}
                self.test_data[user][0] = []
                self.test_data[user][1] = []
            for item in self.poi_list[i][:]:   # iterate over a copy, since items are removed below
                f.write(str(item)+' ')
                if self.vocab_items[user].item_set[item]>8: # rating > 8 means the POI was marked as test data
                    self.test_data[user][w_day].append(item)  # add it to this time slot's test list
                    self.poi_list[i].remove(item) # and remove it from the training trajectory
            f.write('\n')
        f.close()
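
One pitfall worth spelling out: the loop originally iterated over self.poi_list[i] while calling remove() on it, which shifts the remaining elements under the iterator and can silently skip test POIs that sit next to each other; the version above iterates over a copy (self.poi_list[i][:]) instead. A minimal demonstration:

seq = [1, 2, 2, 3]
for item in seq:
    if item == 2:
        seq.remove(item)   # shifts the remaining elements left, under the iterator
print(seq)                 # [1, 2, 3]: the second 2 was skipped

seq = [1, 2, 2, 3]
for item in seq[:]:        # iterating over a copy removes both
    if item == 2:
        seq.remove(item)
print(seq)                 # [1, 3]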

Training

Training takes a long time. The bulk of it is word2vec-style training, with an embedding dimension of 50 by default. The detailed annotations are inline in the code.

def train_process(pid):   # train the model (one worker per process)
    #print pid
    # Set fi to point to the right chunk of training file
    # partition the trajectories across the workers (Python 2 integer division)
    start = len(vocab.poi_list) / num_processes * pid
    end = len(vocab.poi_list) if pid == num_processes - 1 else len(vocab.poi_list) / num_processes * (pid + 1)
    #fi.seek(start)
    print 'Worker %d beginning training at %d, ending at %d' % (pid, start, end)
    current_item=start
    alpha = starting_alpha
    iter_num=0
    total_iter=1
    word_count = 0
    last_word_count = 0
    orig_start=start
    count = 0
    counttime = time.time()
    total = end - start
    # train on this worker's chunk of trajectories
    while start< end:
        # progress logging
        if count % 5000 == 0:
            print("worker id %d, deal with %d items(total is %d), cost %d, deal percentage %f"%(pid, count, total, time.time()-counttime, count * 1.0/(total)))
            counttime = time.time()
        count+=1
        #sys.stdout.write("%d iter\n" % (iter_num))
        #if iter_num<total_iter and start==end-1:
        #    iter_num+=1
        #    start=orig_start

        #global_word_count.value+=()
        #line = fi.readline().strip()
        # Skip blank lines
        # if not line:
        #    continue
        #print line.split()
        # Init sent, a list of indices of words in line
        #sent = vocab.indices(['<bol>'] + line.split() + ['<eol>'])
        if len(vocab.poi_list[start])==0:
            start+=1
            continue
        #word_count+=1
        if  word_count % 2000 == 0:
                global_word_count.value += (word_count - last_word_count)
                last_word_count = word_count
                #print global_word_count.value
        #if global_word_count.value%10000==0:
#                auc, ndcg10,ndcg20=evaluate(vocab, syn_user,syn0)
#                p5,p10,r5,r10=tr_error(vocab,syn0,syn_user)
#                sys.stdout.write( "\nProcessing %d, AUC: %f, NDCG@10: %f, NDCG@20: %f" %(global_word_count.value, auc,ndcg10,ndcg20))
#                sys.stdout.write( "\nProcessing %d, %f %f %f %f" %(global_word_count.value,p5,p10,r5,r10))
#                sys.stdout.flush()
    #if current_error.value>last_error.value:
        #    break
        #last_error.value=current_error.value
        # the current POI trajectory
        sent=vocab.poi_list[start]
        # its user
        c_user=vocab.poi_track[start]
        # its weekday/weekend flag
        t_state = vocab.poi_time_track[start]
       # print c_user
        for sent_pos, token in enumerate(sent):
            # draw a random window size up to win; the context is the tokens around the center
            current_win = randint(1, win+1)
            context_start = max(sent_pos - current_win, 0)
            context_end = min(sent_pos + current_win + 1, len(sent))
            # concatenate the left and right context
            context = sent[context_start:sent_pos] + sent[sent_pos+1:context_end] # Turn into an iterator?
            #print 'context',(context)
            #print 'word', (sent[sent_pos])
            if alpha!=0:
                # CBOW branch of the word2vec training
                if cbow:
                    # Compute neu1
                    neu1 = np.mean(np.array([syn0[c] for c in context]), axis=0)
                    assert len(neu1) == dim, 'neu1 and dim do not agree'

                    # Init neu1e with zeros
                    neu1e = np.zeros(dim)

                    # Compute neu1e and update syn1
                    if neg > 0:
                        classifiers = [(token, 1)] + [(target, 0) for target in table.sample(neg)]
                    else:
                        classifiers = zip(vocab[token].path, vocab[token].code)
                    for target, label in classifiers:
                       # print 'CBOW',target,label
                        z = np.dot(neu1, syn1[target])   # dot product
                        p = sigmoid(z)
                        g = alpha * (label - p)
                        neu1e += g * syn1[target] # Error to backpropagate to syn0
                        syn1[target] += g * neu1  # Update syn1

                # Update syn0
                    for context_word in context:
                        syn0[context_word] += neu1e

                # Skip-gram branch of the word2vec training
                else:
                    # update each context word separately
                    for context_word in context:
                        # Init neu1e with zeros
                        neu1e = np.zeros(dim)

                        # Compute neu1e and update syn1
                        if neg > 0:
                            classifiers = [(token, 1)] + [(target, 0) for target in table.sample(neg)]
                        else:
                            classifiers = zip(vocab[token].path, vocab[token].code)
                        for target, label in classifiers:
                            #print 'SG',target,label
                            z = np.dot(syn0[context_word], syn1[target])+np.dot(syn0[context_word],syn_t[t_state])   # the temporal vector syn_t contributes to the score
                            p = sigmoid(z)
                            g = alpha * (label - p)
                            neu1e += g * syn1[target]              # Error to backpropagate to syn0
                            syn1[target] += g * (syn0[context_word]) # Update syn1
                            syn_t[t_state] += g * (syn0[context_word]) # update the temporal vector as well
                        # Update syn0
                        syn0[context_word] += neu1e
            # Bayesian personalized ranking: sample neighboring and non-neighboring POIs, using geographic distance to decide the preference pairs
            #print 'bpr',len(vocab.vocab_items[c_user].item_set.keys())
            for x in range((len(vocab.vocab_items[c_user].item_set.keys()))):
                neighbor_item=np.random.choice(vocab.vocab_items[c_user].item_set.keys())
                non_neighbors = []
                while len(non_neighbors) < num_non_neighbors:
                    non_neighbor_item=randint(0,vocab.item_count)
                    while non_neighbor_item in vocab.vocab_items[c_user].item_set.keys():
                        non_neighbor_item=randint(0,vocab.item_count)
                    non_neighbors.append(non_neighbor_item)
                neighbor_set = []
                non_neighbor_set = []
                for item in non_neighbors:
                    temp_lat1 = vocab.loc_geo[neighbor_item][0]
                    temp_lon1 = vocab.loc_geo[neighbor_item][1]
                    temp_lat2 = vocab.loc_geo[item][0]
                    temp_lon2 = vocab.loc_geo[item][1]
                    if dis(temp_lon1,temp_lat1,temp_lon2,temp_lat2) < neighbor_threshold:
                        neighbor_set.append(item)
                    else:
                        non_neighbor_set.append(item)
                pair_set = []
                if len(neighbor_set) and len(non_neighbor_set):
                    for item_neighbor in neighbor_set:
                        for item_non_neighbor in non_neighbor_set:
                            pair_set.append((neighbor_item,item_neighbor))
                            pair_set.append((item_neighbor,item_non_neighbor))
                else:
                    for item in (neighbor_set+non_neighbor_set):
                        pair_set.append((neighbor_item,item))


                #print 'tt'
                #print syn_user[c_user]
                #print syn0[neighbor_item]
                # update the user and POI vectors with the BPR gradient
                for item in pair_set:
                    neighbor_item = item[0]
                    non_neighbor_item = item[1]
                    p_e=np.dot(syn_user[c_user],syn0[neighbor_item])-np.dot(syn_user[c_user],syn0[non_neighbor_item])
    #            print 'pe', p_e,np.dot(syn_user[c_user],syn0[neighbor_item]),np.dot(syn_user[c_user],syn0[non_neighbor_item])
                    if p_e > 6:
                        bpr_e = 0
                    elif p_e < -6:
                        bpr_e = 1
                    else:
                        bpr_e=np.exp(-p_e)/(1+np.exp(-p_e))   # sigmoid(-p_e); the branches above clamp it to avoid overflow
#                print 'bpre', bpr_e
                    syn_user[c_user]+=beta*bpr_e*(syn0[neighbor_item]-syn0[non_neighbor_item])
                    syn0[neighbor_item]+=beta*bpr_e*(syn_user[c_user])
                    syn0[non_neighbor_item]+=beta*bpr_e*(-syn_user[c_user])
                word_count+=1

        #    print 'bpr finished'

        start+=1
        #word_count+=1
        #print start
    # Print progress info
      #  global_word_count.value += (word_count - last_word_count)
#    sys.stdout.write("\rAlpha: %f Progress: %d of %d (%.2f%%)" %
#                     (alpha, global_word_count.value, vocab.rating_count,
#                      float(global_word_count.value)/vocab.rating_count * 100))
#    sys.stdout.flush()
    #fi.close()

As you can see, the training procedure takes both the POI sequence information and the geographic distance into account.
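
The excerpt also calls two helpers defined elsewhere in geo_teaser: sigmoid and dis (the geographic distance compared against neighbor_threshold), plus table.sample(neg), which draws negative samples from a word2vec-style unigram table. Their actual implementations are not shown in this post, so the following is only a plausible sketch; in particular, treating dis as the haversine distance in kilometres is an assumption on my part.

import math

def sigmoid(z):
    # logistic function; word2vec implementations usually clamp the input
    if z > 6:
        return 1.0
    if z < -6:
        return 0.0
    return 1.0 / (1.0 + math.exp(-z))

def dis(lon1, lat1, lon2, lat2):
    # haversine great-circle distance in kilometres; the real dis() in
    # geo_teaser may use a different formula or unit
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = (math.sin(dlat / 2) ** 2 +
         math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) *
         math.sin(dlon / 2) ** 2)
    return 2 * 6371.0 * math.asin(math.sqrt(a))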

Prediction

Prediction runs in parallel across workers; the detailed analysis is in the inline comments.

def predict_parallel(pid):
    # partition the users across the workers
    start = vocab.user_count / num_processes * pid
    end = vocab.user_count if pid == num_processes - 1 else vocab.user_count / num_processes * (pid + 1)
    print "predict usr between %d to %d"%(start, end)
    c=0.0
    count = 0
    counttime = time.time()
    total = end - start
    while start< end:
        # progress logging
        if count % 5 == 0:
            print("worker id %d, deal with %d items(total is %d), cost %d, deal percentage %f"%(pid, count, total, time.time()-counttime, count * 1.0/(total)))
            counttime = time.time()
        count+=1
        user=start
        # only users that appear in the test data are evaluated
        if user in vocab.test_data:
            # the user's test POIs, split by the two time slots
            wday_0 = vocab.test_data[user][0]  # weekday test POIs
            wday_1 = vocab.test_data[user][1]  # weekend test POIs; recommendation is done per time slot

            raw_rating_0 = {}
            raw_rating_1 = {}
            if len(wday_0):
                if len(wday_1):
                  #  test_case += 1
                    # the proportion of weekday test POIs
                    rate_0 = len(wday_0) * 1.0/(len(wday_0)+len(wday_1))
                    # score every candidate item, rank by score, then take the top-5 and top-10
                    for item1 in range(vocab.item_count):
                        if item1 in vocab.vocab_items[user].item_set:
                            if vocab.vocab_items[user].item_set[item1] < 8:
                                continue   # skip POIs already visited in training
                        pred = 0.0
                        for i in range(len(syn0[item1])):
                            pred += syn0[item1][i]*syn_user[user][i] + syn_t[0][i]*syn_user[user][i]
                        raw_rating_0[item1] = pred
                        pred = 0.0
                        for i in range(len(syn0[item1])):
                            pred += syn0[item1][i]*syn_user[user][i] + syn_t[1][i]*syn_user[user][i]
                        raw_rating_1[item1] = pred
                    ranked_0 = OrderedDict(sorted(raw_rating_0.items(), key=lambda x:x[1], reverse=True)) # sort by score, descending
                    ranked_1 = OrderedDict(sorted(raw_rating_1.items(), key=lambda x:x[1], reverse=True))
                    top10_0 = ranked_0.keys()[:10]
                    top10_1 = ranked_1.keys()[:10]
                    top5_0 = ranked_0.keys()[:5]
                    top5_1 = ranked_1.keys()[:5]
                    a = int(rate_0*10)
                    b = 10 - a
                    top10 = top10_0[:a] + top10_1[:b]
                    a = int(rate_0*5)
                    b = 5 - a
                    top5 = top5_0[:a] + top5_1[:b]
                else:
                  #  test_case += 1
                    # rate_0 = len(wday_0) * 1.0/(len(wday_0)+len(wday_1))
                    for item1 in range(vocab.item_count):
                        if item1 in vocab.vocab_items[user].item_set:
                            if vocab.vocab_items[user].item_set[item1] < 8:
                                continue
                        pred = 0.0
                        for i in range(len(syn0[item1])):
                            pred += syn0[item1][i]*syn_user[user][i] + syn_t[0][i]*syn_user[user][i]
                        raw_rating_0[item1] = pred

                    ranked_0 = OrderedDict(sorted(raw_rating_0.items(), key=lambda x:x[1], reverse=True))
                    top10 = ranked_0.keys()[:10]
                    top5 = ranked_0.keys()[:5]
            else:
                if len(wday_1):
                  #  test_case += 1
                # rate_0 = len(wday_0) * 1.0/(len(wday_0)+len(wday_1))
                    for item1 in range(vocab.item_count):
                        if item1 in vocab.vocab_items[user].item_set:
                            if vocab.vocab_items[user].item_set[item1] < 8:
                                continue
                        pred = 0.0
                        for i in range(len(syn0[item1])):
                            pred += syn0[item1][i]*syn_user[user][i] + syn_t[1][i]*syn_user[user][i]
                        raw_rating_1[item1] = pred
                    ranked_1 = OrderedDict(sorted(raw_rating_1.items(), key=lambda x:x[1], reverse=True))
                    top10 = ranked_1.keys()[:10]
                    top5 = ranked_1.keys()[:5]
            # the test POIs serve as the ground truth
            test_pois = wday_0+wday_1
            if len(test_pois):
                g5=0.0
                g10=0.0
                for i in top5: # hits in the top-5
                    if i in test_pois:
                        g5+=1.0
                for i in top10: # hits in the top-10
                    if i in test_pois:
                        g10+=1.0
                p5=g5/5.0     # precision@5
                p10=g10/10.0  # precision@10
                r5=g5/len(test_pois)   # recall@5
                r10=g10/len(test_pois) # recall@10
                ar=[p5, p10, r5, r10]

                result[user]=ar
        #    print len(result[user])
                ofile1 = open('result1.txt', 'a')
                ofile1.write(str(user)+str(top5)+str(top10)+'\n')
                ofile1.close()
                # print user
                # print top5
                # print top10
        start+=1
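
A side note on the scoring loop: the inner for loop accumulates, element by element, the quantity score(u, i, t) = syn_user[u]·syn0[i] + syn_user[u]·syn_t[t]. A vectorized sketch (assuming the same global numpy arrays) makes this clearer:

import numpy as np

def score_all_items(user, t_state, syn0, syn_user, syn_t):
    # scores[i] = syn_user[user] . syn0[i] + syn_user[user] . syn_t[t_state],
    # the same quantity the per-element loop in predict_parallel accumulates
    base = syn0.dot(syn_user[user])                    # shape (item_count,)
    temporal = np.dot(syn_user[user], syn_t[t_state])  # a scalar, shared by all items
    return base + temporal

Note that the temporal term is the same scalar for every item within a time slot, so it does not change the ordering inside ranked_0 or ranked_1; the time dependence of the final top-5/top-10 lists comes from interleaving the two slots' rankings according to rate_0.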
