【Torch-RecHub学习】DeepFM

1. 本文简述

github.com/datawhalech… DataWhale团队发布的开源推荐系统工具包。目前实现了常用的LR、FM、MLP等经典模型的组件,但仍缺少应用GCN的组件。

本文学习使用torch-rechub工具包实现经典的DeepFM模型。

2. DeepFM

DeepFM.png

  • 2017年由哈工深和华为诺亚实验室提出
  • 使用Wide&Deep框架,对Wide部分改造,解决CTR点击率预测问题
  • 创新:FM(低阶特征交叉建模) + DNN(高阶特征交叉建模)
  • 创新:提高了模型wide侧提取信息的能力,不用手工特征工程

arxiv.org/pdf/1703.04…

2.1 FM部分

目的:学习一阶和二阶特征交叉

FM组件.png 上图为FM组件,用于学习一阶和二阶的低阶特征交叉

  • Addition:所有的输入相加;线性交叉linear interactions,作为一阶交叉order-1
  • Inner Product:该单元的输出是两个输入向量的乘积;成对交叉Pairwise,作为二阶交叉order-2
  • Sigmoid:CTR预测用sigmoid作为输出函数
  • 最下层蓝色箭头:Embedding,需要学习的隐向量latent vector
  • 黑色箭头:Normal Connection,待学习的权重连接
  • 红色箭头:默认权重为1的连接

FM公式.png 上式为因子分解机FM给出的公式

2.2 Deep部分

Deep部分.png 上图为Deep部分,是一个前馈神经网络DNN,用于学习高阶特征交叉

Deep公式.png 上式为Deep部分公式

2.3 补充

  1. FM和Deep共享同一个Embedding层 从原始特征中同时学习低阶和高阶的特征交叉,不需要人工的特征工程作为输入
  2. 设计的特征输入结构 统一了Embedding的长度,把FM中的隐向量作为网络中的权重被学习和被用于把Input vector压缩为Embedding vector

2.2 代码实现

class DeepFM(torch.nn.Module):

    def __init__(self, deep_features, fm_features, mlp_params):
        super(DeepFM, self).__init__()
        self.deep_features = deep_features
        self.fm_features = fm_features
        self.deep_dims = sum([fea.embed_dim for fea in deep_features])
        self.fm_dims = sum([fea.embed_dim for fea in fm_features])
        self.linear = LR(self.fm_dims)  # 像公式说的FM是LR的二阶拓展
        self.fm = FM(reduce_sum=True)  # 二阶交叉部分
        self.embedding = EmbeddingLayer(deep_features + fm_features)
        self.mlp = MLP(self.deep_dims, **mlp_params)  # Deep部分是多层MLP

    def forward(self, x):
        input_deep = self.embedding(x, self.deep_features, squeeze_dim=True)  # [batch_size, deep_dims]
        input_fm = self.embedding(x, self.fm_features, squeeze_dim=False)  # [batch_size, num_fields, embed_dim]

        y_linear = self.linear(input_fm.flatten(start_dim=1))  # 摊平为1维后送入Addition部分
        y_fm = self.fm(input_fm)  # 计算二阶交叉部分
        y_deep = self.mlp(input_deep)  # [batch_size, 1]
        y = y_linear + y_fm + y_deep  # 三个部分的输出相加
        return torch.sigmoid(y.squeeze(1))
复制代码

猜你喜欢

转载自juejin.im/post/7110780639435030565