KTV song recommendation - logistic regression - user gender prediction

Foreword

In the previous article I covered collaborative filtering, the oldest algorithm in recommender systems. It is not hard to use, yet it is genuinely useful. Different algorithms solve different problems. This article is mainly about logistic regression, which is mostly used for score prediction. I don't have score data, so I use gender instead: the example uses the list of songs a user has sung to predict that user's gender.

What is logistic regression

There is plenty of material on logistic regression; I recommend Li Hongyi's (Hung-yi Lee's) lecture videos on bilibili. Here I only mention a few points that need attention.

Network structure

Logistic regression can be viewed as a single-layer neural network. The network structure is shown in the figure.
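A minimal NumPy sketch of that single-layer view for the binary case: one linear unit computes x·w + b, and a sigmoid turns the result into a probability. The toy input (2 users, 3 songs) and the weights are made-up values, and `predict_proba` is just a name chosen for this example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def predict_proba(x, w, b):
    """Binary logistic regression = one linear unit followed by a sigmoid.

    x: (n_samples, n_features), w: (n_features,), b: scalar bias.
    Returns P(class 1 | x) for each sample.
    """
    return sigmoid(x @ w + b)

# Toy input: 2 users x 3 songs (1 = the user sang that song); illustrative weights only.
x = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])
w = np.array([0.5, -1.0, 0.2])
b = 0.0
p = predict_proba(x, w, b)
print(p)  # probability of class 1 for each user
```

Everything the model "knows" lives in w and b; training only adjusts these numbers.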

Activation function selection

Logistic regression usually uses sigmoid or softmax:

  • The upper part of the figure is binary logistic regression, whose activation function is sigmoid
  • The lower part of the figure is multinomial logistic regression, where the linear outputs go directly into a softmax without a separate activation function

If you don't know what sigmoid and softmax are, look them up on Baidu.
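For reference, both functions fit in a few lines of NumPy. The sketch below also shows why using softmax with only two classes changes nothing: a two-unit softmax gives the same probability as a sigmoid applied to the difference of the two logits. The logits are arbitrary example values.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    """Softmax over the last axis; subtracting the max avoids overflow in exp."""
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

# Two logits per sample, as produced by a 2-unit output layer.
logits = np.array([[2.0, 0.5],
                   [-1.0, 1.0]])
p_softmax = softmax(logits)

# With two classes: softmax(z)[0] = 1 / (1 + exp(z1 - z0)) = sigmoid(z0 - z1)
p_sigmoid = sigmoid(logits[:, 0] - logits[:, 1])
print(np.allclose(p_softmax[:, 0], p_sigmoid))  # True
```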

Loss function selection

Three loss functions are commonly used for logistic regression (there are actually many more; check the API docs yourself):

  • binary_crossentropy
  • categorical_crossentropy
  • sparse_categorical_crossentropy

binary_crossentropy would be the better fit here, but I used categorical_crossentropy because I was too lazy to change it, and I will build other features on this setup later.
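With two classes and one-hot labels, the three losses compute the same number; they differ only in the shape of labels and predictions they expect. The following NumPy sketch is a re-implementation for illustration (not Keras's own code), and the probabilities are made up:

```python
import numpy as np

EPS = 1e-7  # clip probabilities away from 0 and 1 before taking logs

def binary_crossentropy(y_true, p):
    """y_true: (n,) labels in {0, 1}; p: (n,) predicted P(class 1)."""
    p = np.clip(p, EPS, 1 - EPS)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

def categorical_crossentropy(y_onehot, probs):
    """y_onehot: (n, k) one-hot labels; probs: (n, k) softmax outputs."""
    probs = np.clip(probs, EPS, 1 - EPS)
    return float(np.mean(-np.sum(y_onehot * np.log(probs), axis=1)))

def sparse_categorical_crossentropy(y_int, probs):
    """Same loss as above, but labels are integer class ids instead of one-hot rows."""
    probs = np.clip(probs, EPS, 1 - EPS)
    return float(np.mean(-np.log(probs[np.arange(len(y_int)), y_int])))

# Three samples, two classes.
y_int = np.array([0, 1, 1])
y_onehot = np.eye(2)[y_int]
probs = np.array([[0.8, 0.2],
                  [0.3, 0.7],
                  [0.6, 0.4]])

bce = binary_crossentropy(y_int, probs[:, 1])        # uses only P(class 1)
cce = categorical_crossentropy(y_onehot, probs)      # uses one-hot labels
scce = sparse_categorical_crossentropy(y_int, probs) # uses integer labels
print(bce, cce, scce)  # all three print the same value (about 0.499)
```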

Gradient descent selection

There are many gradient descent variants; I use stochastic gradient descent (SGD) here. Actually, I think Adam would be more suitable than SGD, but it is up to you. As for why...
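In Keras the switch is a one-word change: pass optimizer='adam' instead of optimizer='sgd' to model.compile. To show how the two update rules differ, here is a minimal NumPy sketch that minimizes the toy function (w - 3)^2. The toy loss, learning rate and step counts are assumptions for illustration; beta1=0.9, beta2=0.999 and eps=1e-8 are the defaults suggested in the original Adam paper.

```python
import numpy as np

def grad(w):
    """Gradient of the toy loss f(w) = (w - 3)^2, minimized at w = 3."""
    return 2.0 * (w - 3.0)

def sgd(w, lr=0.1, steps=100):
    for _ in range(steps):
        w = w - lr * grad(w)                  # plain step along the negative gradient
    return w

def adam(w, lr=0.1, steps=100, beta1=0.9, beta2=0.999, eps=1e-8):
    m, v = 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(w)
        m = beta1 * m + (1 - beta1) * g       # running mean of gradients (momentum)
        v = beta2 * v + (1 - beta2) * g * g   # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)          # bias correction for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + eps)  # step rescaled by gradient magnitude
    return w

print(sgd(0.0), adam(0.0, steps=500))  # both end up close to 3
```

SGD moves by the learning rate times the raw gradient, while Adam divides each step by a running estimate of the gradient's size, which often makes it less sensitive to the choice of learning rate.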

Data preparation

This time the data is 10,000 KTV singing records. Don't ask me where the data came from; the problem was given to me by someone else.

X is the one-hot (strictly speaking, multi-hot) encoding of the songs each user has sung

Y is the one-hot encoding of the user's gender
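A minimal sketch of how X and Y can be built from per-user data, assuming plain Python dicts (uid to song list, uid to gender) like the ones get_data returns below. The users and song names here are hypothetical placeholders; the original code does this with a custom LabelEncoder whose implementation is not shown.

```python
import numpy as np

# Hypothetical toy data: uid -> songs sung, uid -> gender label.
user_songs = {
    "u0": ["song_b"],
    "u1": ["song_a", "song_b"],
    "u2": ["song_a"],
}
user_gender = {"u0": "male", "u1": "female", "u2": "male"}

uids = sorted(user_songs)                                        # fixed row order
songs = sorted({s for lst in user_songs.values() for s in lst})  # column vocabulary
genders = sorted(set(user_gender.values()))                      # ['female', 'male']
song_index = {s: i for i, s in enumerate(songs)}

# X: one row per user, a 1 in every column whose song the user has sung (multi-hot).
X = np.zeros((len(uids), len(songs)))
for r, uid in enumerate(uids):
    for s in user_songs[uid]:
        X[r, song_index[s]] = 1.0

# Y: one row per user, exactly one 1 marking the gender (one-hot).
Y = np.zeros((len(uids), len(genders)))
for r, uid in enumerate(uids):
    Y[r, genders.index(user_gender[uid])] = 1.0

print(X)
print(Y)
```

Each row of X can contain several 1s (one per song sung), while each row of Y contains exactly one.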

Now for the real technical part

Code

  • The data is split into 80% training and 20% test
  • Although there are only two classes, softmax is still used; this does not affect the result
  • The model is trained with Keras

Data collection

The code below mainly builds two matrices.

One is song_hot_matrix, the one-hot encoding of the songs each user has sung.

The other is decades_hot_matrix, the one-hot encoding of the user's gender. The code itself is not important; focus on the explanation.

import elasticsearch
import elasticsearch.helpers
import re
import numpy as np
import operator
import datetime


es_client = elasticsearch.Elasticsearch(hosts=["localhost:9200"])

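def trim_song_name(song_name):
    """
    Clean up a song name: drop bracketed annotations such as 【...】, (...) and （...）, and trim whitespace
    """
    song_name = song_name.strip()
    song_name = re.sub("【.*?】", "", song_name)
    # "(.*?)" on its own is a regex group that matches the empty string, so it removed nothing;
    # escape the brackets (half-width or full-width) to strip parenthesized text
    song_name = re.sub(r"[（(].*?[）)]", "", song_name)
    return song_name.strip()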

def trim_address_name(address_name):
    """
    Clean up the address field
    """
    return str(address_name).strip()

def get_data(size=0):
    """
    Build dicts keyed by uid: uid => list of song names, uid => address, uid => gender
    """
    cur_size=0
    song_dic = {}
    user_address_dic = {}
    user_decades_dic = {}
    
    search_result = elasticsearch.helpers.scan(
        es_client, 
        index="ktv_user_info", 
        doc_type="ktv_works", 
        scroll="10m",
        query={
            "query":{
                "range": {
                    "birthday": {
                        "gt": 63072662400
                    }
                }
            }
        }
    )

    for hit_item in search_result:
        cur_size += 1
        if size>0 and cur_size>size:
            break
            
        user_info = hit_item["_source"]
        item = get_work_info(hit_item["_id"])
        if item is None:
            continue

        work_list = item['item_list']
        if len(work_list)<2:
            continue
        
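        if user_info['gender'] == 0:    # gender 0 is skipped; only 1 and 2 are used
            continue
        if user_info['gender'] == 1:
            user_info['gender'] = "男"   # male
        if user_info['gender'] == 2:
            user_info['gender'] = "女"   # female

        # Cleaned song names for this user (use a separate loop variable so `item` is not shadowed)
        song_dic[item['uid']] = [trim_song_name(work['songname']) for work in work_list]

        user_decades_dic[item['uid']] = user_info['gender']
        user_address_dic[item['uid']] = trim_address_name(user_info['address'])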
        user_address_dic[item['uid']] = trim_address_name(user_info['address'])
        
    return (song_dic, user_address_dic, user_decades_dic)

def get_user_info(uid):
    """
    Get a user's profile information
    """
    ret = es_client.get(
        index="ktv_user_info", 
        doc_type="ktv_works", 
        id=uid
    )
    return ret['_source']

def get_work_info(uid):
    """
    Get a user's works document; returns None if it cannot be fetched
    """
    try:
        ret = es_client.get(
            index="ktv_works", 
            doc_type="ktv_works", 
            id=uid
        )
        return ret['_source']
    except Exception as ex:
        return None


def get_uniq_song_sort_list(song_dict):
    """
    Deduplicate songs and sort them by name
    """
    return sorted(list(set(np.concatenate(list(song_dict.values())).tolist())))
    
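from sklearn import preprocessing
# LabelEncoder below is a custom helper class defined in label_encoder.ipynb
# (providing fit_dict / encode_hot_dict / decode_list / get_class_count),
# not sklearn's LabelEncoder. %run is an IPython/Jupyter magic command.
%run label_encoder.ipynb

user_count = 4000
song_count = 0

# Fetch the users' singing data
song_dict, user_address_dict, user_decades_dict = get_data(user_count)

# Song dictionary: multi-hot matrix, one column per distinct song
song_label_encoder = LabelEncoder()
song_label_encoder.fit_dict(song_dict, "", True)
song_hot_matrix = song_label_encoder.encode_hot_dict(song_dict, True)

# Gender labels: one-hot matrix, one column per gender
user_decades_encoder = LabelEncoder()
user_decades_encoder.fit_dict(user_decades_dict)
decades_hot_matrix = user_decades_encoder.encode_hot_dict(user_decades_dict, False)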

song_hot_matrix

uid | scrub brush | Mahjong | your answer
0   | 0           | 1       | 0
1   | 1           | 1       | 0
2   | 1           | 0       | 0
3   | 0           | 0       | 0

decades_hot_matrix

uid | male | female
0   | 1    | 0
1   | 0    | 1
2   | 1    | 0
3   | 0    | 1

Model training

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation
from sklearn.model_selection import train_test_split

n_class = user_decades_encoder.get_class_count()   # number of gender classes
song_count = song_label_encoder.get_class_count()  # number of distinct songs
print(n_class)
print(song_count)

# Split into training and test data (80% / 20%)
train_X, test_X, train_y, test_y = train_test_split(song_hot_matrix,
                                                    decades_hot_matrix,
                                                    test_size=0.2,
                                                    random_state=0)
train_count = np.shape(train_X)[0]

# Build the model: one Dense layer followed by softmax, i.e. logistic regression.
# input_dim must equal the number of columns of the song matrix (the hard-coded 8 did not).
model = Sequential()
model.add(Dense(input_dim=np.shape(train_X)[1], units=n_class))
model.add(Activation('softmax'))

# Choose the loss function and the optimizer
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

# Training: each train_on_batch call receives the whole training set,
# so every iteration is one full-batch gradient step
print('Training -----------')
for step in range(train_count):
    scores = model.train_on_batch(train_X, train_y)
    if step % 50 == 0:
        print("step %d, loss: %f, accuracy: %.2f%%" % (step, scores[0], scores[1] * 100))
print('finish!')
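Because train_on_batch receives the entire training set every time, each iteration of the loop above is one full-batch gradient step rather than a step on a random mini-batch. As a rough sketch of what such a step computes for softmax logistic regression, here is a NumPy version. Everything in it is an assumption for illustration (the synthetic 200-user, 20-song data, the learning rate 0.5 and the 300 steps); it is not Keras's implementation and does not use the KTV data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in data (NOT the KTV data): 200 users x 20 songs.
# Users of class 0 mostly sing songs 0-9, users of class 1 mostly sing songs 10-19.
n_users, n_songs, n_class = 200, 20, 2
labels = rng.integers(0, n_class, size=n_users)
p_sing = np.where(labels[:, None] == 0,
                  np.r_[np.full(10, 0.5), np.full(10, 0.1)],
                  np.r_[np.full(10, 0.1), np.full(10, 0.5)])
X = (rng.random((n_users, n_songs)) < p_sing).astype(float)  # multi-hot song matrix
Y = np.eye(n_class)[labels]                                  # one-hot labels

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

W = np.zeros((n_songs, n_class))
b = np.zeros(n_class)
lr = 0.5
for step in range(300):
    probs = softmax(X @ W + b)              # forward pass on the whole batch
    grad_z = (probs - Y) / n_users          # d(mean cross-entropy) / d(logits)
    W -= lr * X.T @ grad_z                  # gradient step on the weights
    b -= lr * grad_z.sum(axis=0)            # gradient step on the bias
    if step % 50 == 0:
        loss = -np.mean(np.sum(Y * np.log(probs + 1e-12), axis=1))
        acc = np.mean(probs.argmax(axis=1) == labels)
        print("step %d, loss: %.4f, accuracy: %.2f%%" % (step, loss, acc * 100))
```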

Test set accuracy evaluation

After training, evaluate the model on the 20% of the data held out for testing:

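# Accuracy evaluation on the test set
from sklearn.metrics import classification_report
scores = model.evaluate(test_X, test_y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1] * 100))
Y_test = np.argmax(test_y, axis=1)                 # one-hot labels -> class index
# argmax over the softmax outputs; predict_classes is not available in newer Keras versions
y_pred = np.argmax(model.predict(test_X), axis=1)
print(classification_report(Y_test, y_pred))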

output:

accuracy: 78.43%
              precision    recall  f1-score   support

           0       0.72      0.90      0.80       220
           1       0.88      0.68      0.77       239

    accuracy                           0.78       459
   macro avg       0.80      0.79      0.78       459
weighted avg       0.80      0.78      0.78       459
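In the report above, precision for a class is the share of users predicted as that class who really belong to it, recall is the share of users of that class that the model found, and support is the number of test users per class (220 + 239 = 459). Which of 0 and 1 is male and which is female depends on the order used by user_decades_encoder, which is not shown. The following plain-Python sketch (`per_class_report` is a hypothetical helper, run on toy labels rather than the KTV results) recomputes the per-class numbers and both averages:

```python
def per_class_report(y_true, y_pred):
    """Precision, recall, F1 and support per class, like classification_report."""
    classes = sorted(set(y_true) | set(y_pred))
    report = {}
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = (precision, recall, f1, tp + fn)  # support = true samples of class c
    return report

# Toy labels, not the KTV results.
y_true = [0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 1, 1, 1, 0, 1]
report = per_class_report(y_true, y_pred)
for c, (p, r, f, s) in report.items():
    print(c, round(p, 2), round(r, 2), round(f, 2), s)

# macro avg = plain mean over classes; weighted avg = mean weighted by support
n = len(y_true)
macro_f1 = sum(v[2] for v in report.values()) / len(report)
weighted_f1 = sum(v[2] * v[3] for v in report.values()) / n
print(round(macro_f1, 2), round(weighted_f1, 2))
```

The weighted-average recall always equals the overall accuracy, which is why both read 0.78 in the report above.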

Manual testing

Then I let some friends try it out, and, um, the accuracy was 100%. Perfect!

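def pred(song_list=None):
    """Predict the gender label for one user from a list of song names."""
    song_list = song_list or []  # avoid a mutable default argument
    # Encode with the same song encoder used for training so the columns line up
    blong_hot_matrix = song_label_encoder.encode_hot_dict({"bblong": song_list}, True)
    y_pred = np.argmax(model.predict(blong_hot_matrix), axis=1)
    return user_decades_encoder.decode_list(y_pred)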

# # Male A
# print(pred(["一路向北", "暗香", "菊花台"]))
# # Male B
# print(pred(["不要说话", "平凡之路", "李白"]))
# # Female A
# print(pred(["知足", "被风吹过的夏天", "龙卷风"]))
# # Male C
# print(pred(["情人","再见","无赖","离人","你的样子"]))
# # Male D
# print(pred(["小情歌","我好想你","无与伦比的美丽"]))
# # Male E
# print(pred(["忐忑","最炫民族风","小苹果"]))

Origin http://10.200.1.11:23101/article/api/json?id=324107399&siteId=291194637