[Conceptos básicos del aprendizaje automático] El modelo de regresión lineal predice la relación entre la cantidad de me gusta de videos y las colecciones de la estación B

El modelo de regresión lineal predice la relación entre la cantidad de me gusta de videos y las colecciones de la estación B (Hua Nong Brothers)


Prefacio

Los modelos de regresión lineal se pueden utilizar para predecir la tendencia de los datos. Mediante el entrenamiento en el conjunto de datos existente, se puede obtener una función lineal Y = w * X + b , y el valor subsiguiente se puede predecir a través de esta función lineal.


1. Modelo de regresión lineal

La regresión lineal se basa en el supuesto de una correlación lineal entre el valor objetivo X y el valor propio Y. El modelo lineal
Función lineal
se resuelve a través de un conjunto de datos conocido . El método de solución específico es construir una función de pérdida, haciendo que el valor de la función de pérdida sea cada vez mayor Cuanto menor sea hasta que se cumpla el requisito de precisión o el número de iteraciones. La función de pérdida puede entenderse como la diferencia entre el valor predicho y el valor real del proceso de cálculo del tipo obtenido, de modo que la brecha es menor, el modelo con el valor real como. Definición de función de pérdida:
Función de pérdida
Para minimizar la función de pérdida, minimice la pérdida (w, b) . Al presentar el algoritmo de descenso de gradiente , la velocidad de descenso a lo largo de la dirección del gradiente es la más rápida. Actualice w y b en cada iteración hasta que se cumplan los requisitos.
Descenso de gradiente
Calcule las derivadas parciales de la pérdida (w, b) con respecto a w y b, respectivamente (puede importar y = w * x + b en la función de pérdida):
w derivada parcial
b adelanto parcial

Dos, obtén datos

Al rastrear la información del video en BILIBILI, este artículo obtuvo la información del video de " Hua Nong Brothers ". Puede consultar en el blog todos los detalles de rastreo del video B de la estación UP . Tome la cantidad de me gusta y las colecciones de los videos y establezca un modelo de regresión lineal para predecir su relación. El diagrama de dispersión de los me gusta de videos (eje x) y las colecciones (eje y) es el siguiente:
Gráfico de dispersión
debido a la concentración deficiente de datos, los datos deben normalizarse. Este artículo utiliza los valores máximo y mínimo para normalizar.

Tres, entrenamiento modelo

Después del entrenamiento, w = 0.7229486928307687 b = 0.20322045504258518 El modelo entrenado es el siguiente:
Modelo de entrenamiento

Cuarto, el código

# 线型回归模型预测B站视频点赞量与收藏量的关系(华农兄弟)
import json
import numpy as np
import time
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

class LR(object):
    def __init__(self, max_iterator = 1000, learn_rate = 0.01):
        self.max_iterator = max_iterator
        self.learn_rate = learn_rate
        self.w = np.random.normal(1, 0.1)
        self.b = np.random.normal(1, 0.1)

    def cal_day(self, release_date, now_date):
        # 计算天数
        start_time = time.mktime(time.strptime(release_date.split(' ')[0], '%Y-%m-%d'))
        end_time = time.mktime(time.strptime(now_date.split(' ')[0], '%Y-%m-%d'))
        return int((end_time - start_time)/(24*60*60))


    def load_data(self, url):
        with open(url, 'r', encoding='utf-8') as f:
            data_dect = json.load(f)
        # print(data_dect)
        
        # 视频播放数量以及发布距离现在的天数
        watched_number_list = []
        time_list = []
        dm_number_list = []
        liked_list = []
        collected_list = []
        for sample in data_dect:
            # 去掉坏点
            if sample['watched'] != '':
                watched_number_list.append([float(sample['watched'])]) #观看数量
                liked_list.append([float(sample['liked'])])	#点赞数
                collected_list.append([float(sample['collected'])])	#收藏数
                dm_number_list.append([float(sample['bullet_comments'])])	#弹幕数
                time_list.append([float(self.cal_day(sample['date'], sample['now_date']))]) #视频发布距离现在时间

        return np.array(time_list), np.array(watched_number_list), np.array(liked_list), np.array(collected_list), np.array(dm_number_list)

    def train_set_normalize(self, train_set):
        data_range = np.max(train_set) - np.min(train_set)
        return (train_set - np.min(train_set)) / data_range



    def cal_gradient(self, x, y):
    	# 计算梯度
        # print(x, y)
        dw = np.mean((x * self.w + self.b - y) * x)
        db = np.mean(self.b + x * self.w - y)
        return dw, db
    
    
    def train(self, x, y):
        # 训练模型,使用梯度下降
        train_w = []
        train_b = []
        for i in range(self.max_iterator):
            print(self.w, self.b)
            train_w.append(self.w)
            train_b.append(self.b)
            i += 1
            # 计算梯度值,向着梯度下降的方向
            dw, db = self.cal_gradient(x, y)
            self.w -= self.learn_rate*dw
            self.b -= self.learn_rate*db
        return train_w, train_b

    def predict(self, x):
        # 预测
        return x * self.w + self.b
    
    def myplot(self, x, y, train_w, train_b):
        
        plt.pause(2)
        plt.ion()
        # 动态绘图
        for i in range(0, self.max_iterator, 30):
            
            plt.clf()
            # 原始散点图
            plt.scatter(x, y, marker = 'o',color = 'yellow', s = 40)
            plt.xlabel('liked')
            plt.ylabel('collected')
            plt.plot(x, train_w[i] * x  + train_b[i], c='red')
            plt.title('step: %d learning-rate: %.2f function: y=%.2f * x + %.2f' %(i, self.learn_rate, train_w[i], train_b[i]))
            plt.pause(0.5) 
            
        plt.show()
        plt.ioff()
        plt.pause(200)

        
  


lr = LR()
time_list, watched_number_list, liked_list, collected_list, dm_number_list = lr.load_data(r'2020\Crawl\Bilibili\Item1\data\video_detial.json')
# 需要对数据进行归一化处理

tw, tb = lr.train(lr.train_set_normalize(liked_list), lr.train_set_normalize(collected_list))
lr.myplot(lr.train_set_normalize(liked_list), lr.train_set_normalize(collected_list), tw, tb)

referencias

  1. https://www.cnblogs.com/geo-will/p/10468253.html

Supongo que te gusta

Origin blog.csdn.net/qq_37753409/article/details/109004339
Recomendado
Clasificación