DIFM Network Explained and Reproduced

Network Explanation

https://mp.weixin.qq.com/s?__biz=Mzk0MzIzODM5MA==&mid=2247484784&idx=1&sn=9fa9ae9caa951a7781059b39fdde75dd&chksm=c337b8e9f44031ffef69decc7a5b11b7c4a6b4633084beda9e8eb9a7814bc16f29bcd8fe2f11#rd
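
In brief, as can be read off the implementation below (the linked article gives the full walkthrough): DIFM passes the sparse field embeddings through a Dual-FEN layer with two branches, a vector-wise part (multi-head self-attention over the field embeddings) and a bit-wise part (a DNN over the flattened embeddings). A combination layer projects both branch outputs to a per-field input-aware factor m_x = P_vec(o_vec) + P_bit(o_bit). The reweighting layer then rescales both the first-order weights and the embeddings, w_{x,i} = m_{x,i} \times w_i and v_{x,i} = m_{x,i} \times v_i, and a standard FM prediction layer plus a sigmoid produces the click probability.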

Network Structure Code

# coding:utf-8
# @Email: [email protected]
# @Time: 2022/8/1 3:36 PM
# @File: DIFM.py
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as F
from torchkeras import summary

from tools import *
from BaseModel import BaseModel

class FM(nn.Module):
    '''
    FM pairwise-interaction term only (without the linear term and bias)
    '''
    def __init__(self):
        super(FM, self).__init__()

    def forward(self, inputs):
        # (batch_size, field_size, embedding_size)
        fm_input = inputs
        square_of_sum = torch.pow(torch.sum(fm_input, dim=1, keepdim=True), 2)
        sum_of_square = torch.sum(fm_input * fm_input, dim=1, keepdim=True)
        cross_term = square_of_sum - sum_of_square
        cross_term = 0.5 * torch.sum(cross_term, dim=2, keepdim=False)
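        # FM trick: (sum_i v_i)^2 - sum_i v_i^2, halved and summed over the embedding
        # dimension, equals the sum of all pairwise interactions <v_i, v_j> with i < j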

        # (batch_size, 1)
        return cross_term

class Linear_W(nn.Module):
    def __init__(self, dense_nums):
        super(Linear_W, self).__init__()
        self.dense_nums = dense_nums
        if dense_nums is not None and dense_nums != 0:
            self.weight = nn.Parameter(torch.Tensor(dense_nums, 1))
            torch.nn.init.normal_(self.weight, mean=0, std=0.0001)

    def forward(self, sparse_inputs, dense_inputs=None, sparse_feat_refine_weight=None):
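        # sparse_inputs: (batch_size, field_num, emb_dim) sparse field embeddings
        # dense_inputs: (batch_size, dense_nums) raw numeric features
        # sparse_feat_refine_weight: (batch_size, field_num) input-aware factors m_x used to refine the first-order term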
        linear_logit = torch.zeros([sparse_inputs.shape[0], 1], device=sparse_inputs.device)

        sparse_logit = sparse_inputs * sparse_feat_refine_weight.unsqueeze(-1)

        sparse_logit = torch.sum(sparse_logit, dim=-1, keepdim=False)
        sparse_logit = torch.unsqueeze(torch.sum(sparse_logit, dim=-1, keepdim=False), dim=-1)

        linear_logit += sparse_logit

        if dense_inputs is not None:
            dense_logit = torch.matmul(dense_inputs, self.weight)
            linear_logit += dense_logit

        return linear_logit

class MultiheadAttention(nn.Module):
    def __init__(self, emb_dim, head_num, scaling=True, use_residual=True):
        super(MultiheadAttention, self).__init__()
        self.emb_dim = emb_dim
        self.head_num = head_num
        self.scaling = scaling
        self.use_residual = use_residual
        self.att_emb_size = emb_dim // head_num
        assert emb_dim % head_num == 0, "emb_dim must be divisible by head_num"

        self.W_Q = nn.Parameter(torch.Tensor(emb_dim, emb_dim))
        self.W_K = nn.Parameter(torch.Tensor(emb_dim, emb_dim))
        self.W_V = nn.Parameter(torch.Tensor(emb_dim, emb_dim))

        if self.use_residual:
            self.W_R = nn.Parameter(torch.Tensor(emb_dim, emb_dim))

        # initialize the weights to avoid NaNs in the attention computation
        for weight in self.parameters():
            nn.init.xavier_uniform_(weight)

    def forward(self, inputs):

        '''1. Linear projections to produce Q, K, V'''
        # dim: [batch_size, fields, emb_size]
        querys = torch.tensordot(inputs, self.W_Q, dims=([-1], [0]))
        keys = torch.tensordot(inputs, self.W_K, dims=([-1], [0]))
        values = torch.tensordot(inputs, self.W_V, dims=([-1], [0]))
        # equivalent to: querys = torch.matmul(inputs, self.W_Q)

        '''2. Split into heads'''
        # dim: [head_num, batch_size, fields, emb_size // head_num]
        querys = torch.stack(torch.split(querys, self.att_emb_size, dim=2))
        keys = torch.stack(torch.split(keys, self.att_emb_size, dim=2))
        values = torch.stack(torch.split(values, self.att_emb_size, dim=2))

        '''3. Scaled dot-product attention'''
        # attention scores, dim: [head_num, batch_size, fields, fields]
        inner_product = torch.matmul(querys, keys.transpose(-2, -1))
        # equivalent to: inner_product = torch.einsum('bnik,bnjk->bnij', querys, keys)
        if self.scaling:
            inner_product /= self.att_emb_size ** 0.5
        # normalize the scores with softmax
        attn_w = F.softmax(inner_product, dim=-1)
        # weighted sum: multiply the attention weights by V to get the per-head outputs
        results = torch.matmul(attn_w, values)

        '''4. Concatenate the heads'''
        # dim: [batch_size, fields, emb_size] after the squeeze below
        results = torch.cat(torch.split(results, 1, dim=0), dim=-1)
        results = torch.squeeze(results, dim=0)

        # residual connection
        if self.use_residual:
            results = results + torch.tensordot(inputs, self.W_R, dims=([-1], [0]))

        results = F.relu(results)
        # results = F.tanh(results)

        return results

class DNN(nn.Module):
    def __init__(self, input_dim, dnn_hidden_units, use_bn, dropout):
        super(DNN, self).__init__()
        fc_layers = []
        self.input_dim = input_dim
        for hidden in dnn_hidden_units:
            fc_layers.append(nn.Linear(input_dim, hidden))
            if use_bn: fc_layers.append(nn.BatchNorm1d(hidden))
            fc_layers.append(nn.ReLU())
            fc_layers.append(nn.Dropout(p=dropout))
            input_dim = hidden
        self.fc_layers = nn.Sequential(*fc_layers)

        for name, tensor in self.fc_layers.named_parameters():
            if 'weight' in name:
                nn.init.normal_(tensor, mean=0, std=0.0001)

    def forward(self, inputs):
        dnn_output = self.fc_layers(inputs.view(-1, self.input_dim))
        return dnn_output

class DIFM(nn.Module):
    def __init__(self, sparse_fields, dense_nums, emb_dim=8, head_num=1, scaling=True, use_residual=True, dnn_hidden_units=(256, 128), use_bn=True, dropout=0.2,):
        super(DIFM, self).__init__()
        self.sparse_field_num = len(sparse_fields)
        self.offsets = np.array((0, *np.cumsum(sparse_fields)[:-1]), dtype=np.longlong)
        self.dnn_input_dim = self.sparse_field_num * emb_dim

        # sparse embedding table
        self.embedding = nn.Embedding(sum(sparse_fields) + 1, embedding_dim=emb_dim)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

        # vector-wise part
        self.vector_wise_net = MultiheadAttention(emb_dim=emb_dim, head_num=head_num, scaling=scaling, use_residual=use_residual)
        # bit-wise part
        self.bit_wise_net = DNN(input_dim=self.dnn_input_dim, dnn_hidden_units=dnn_hidden_units, use_bn=use_bn, dropout=dropout)

        # transform matrices P_vec and P_bit
        self.transform_matrix_P_vec = nn.Linear(self.dnn_input_dim, self.sparse_field_num, bias=False)
        self.transform_matrix_P_bit = nn.Linear(dnn_hidden_units[-1], self.sparse_field_num, bias=False)

        self.linear_model = Linear_W(dense_nums=dense_nums)

        self.fm = FM()

    def forward(self, inputs):
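        # layout assumption (matches the Criteo data in the main script below):
        # columns 0-12 are the dense features I1-I13, columns 13 onward are the sparse features C1-C26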
        dense_inputs, sparse_inputs = inputs[:, :13], inputs[:, 13:]

        # sparse embedding lookup
        sparse_inputs = sparse_inputs.long()
        sparse_inputs = sparse_inputs + sparse_inputs.new_tensor(self.offsets).unsqueeze(0)
        spare_emb = self.embedding(sparse_inputs)   # (None, sparse_field_num, emb_dim)

        # Dual-FEN Layer
        ## vector-wise part
        vec_out = self.vector_wise_net(spare_emb)   # (None, sparse_field_num, emb_dim)
        vec_out = vec_out.reshape(vec_out.shape[0], -1)   # (None, sparse_field_num * emb_dim)
        ## bit-wise part
        # bit_out = self.bit_wise_net(spare_emb.view(-1, self.dnn_input_dim))   # (None, dnn_hidden_units[-1])
        bit_out = self.bit_wise_net(spare_emb.reshape(vec_out.shape[0], -1))   # (None, dnn_hidden_units[-1])

        # Combination Layer: m_x is the combined input-aware factor
        m_vec = self.transform_matrix_P_vec(vec_out)   # (None, sparse_field_num)
        m_bit = self.transform_matrix_P_bit(bit_out)   # (None, sparse_field_num)
        m_x = m_vec + m_bit   # [None, sparse_field_num]

        # Reweighting Layer
        ## w_{x,i} = m_{x,i} \times w_i
        logit = self.linear_model(spare_emb, dense_inputs, sparse_feat_refine_weight=m_x)   # (None, 1)
        ## v_{x,i} = m_{x,i} \times v_i
        refined_fm_input = spare_emb * m_x.unsqueeze(-1)   # (None, sparse_field_num, emb_dim)

        # FM Prediction Layer
        logit += self.fm(refined_fm_input)    # (None, 1)

        return torch.sigmoid(logit.squeeze(-1))
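
Before the training code, a minimal smoke test (my own sketch, not part of the original post): it builds a toy batch in the layout the forward pass expects, 13 dense columns followed by 26 sparse columns, and checks that the output holds one probability per sample. The field cardinalities (10 values per field) and the batch size are arbitrary assumptions.

# hedged smoke test for DIFM; assumes the class definitions above are in scope
import numpy as np
import torch

toy_sparse_fields = np.array([10] * 26, dtype=np.int32)   # 26 categorical fields, 10 values each (assumed)
model = DIFM(sparse_fields=toy_sparse_fields, dense_nums=13, emb_dim=8, head_num=2)

dense_part = torch.rand(4, 13)                        # 4 samples, 13 dense features
sparse_part = torch.randint(0, 10, (4, 26)).float()   # raw (un-offset) category indices per field
out = model(torch.cat([dense_part, sparse_part], dim=1))
print(out.shape)   # expected: torch.Size([4])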

Training Code

import time

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import roc_auc_score, accuracy_score

def printlog(info):
    # nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # print("%s " % nowtime + "----------"*11 + '---')
    print(str(info))

class BaseModel():
    def __init__(self, net):
        super(BaseModel, self).__init__()
        self.net = net

    def fit(self, train_loader, val_loader, epochs, loss_function, optimizer, metric_name):
        start_time = time.time()
        print("\n" + "********** start training **********")

        columns = ["epoch", "loss", *metric_name, "val_loss"] + ['val_' + mn for mn in metric_name]
        dfhistory = pd.DataFrame(columns=columns)

        '''   Training   '''
        for epoch in range(1, epochs + 1):
            printlog("Epoch {0} / {1}".format(epoch, epochs))
            step_start = time.time()
            step_num = 0

            train_loss = []
            train_pred_probs, train_y, train_pre = [], [], []
            self.net.train()
            for batch, (x, y) in enumerate(train_loader):
                step_num += 1

                # zero the gradients
                optimizer.zero_grad()

                # forward pass: compute the loss
                pred_probs = self.net(x)
                loss = loss_function(pred_probs, y.float().detach())
                # loss = loss_function(pred, y)

                # backward pass and parameter update
                loss.backward()
                optimizer.step()

                train_loss.append(loss.item())
                train_pred_probs.extend(pred_probs.tolist())
                train_y.extend(y.tolist())
                train_pre.extend(torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs)).tolist())

            '''   Validation   '''
            val_loss = []
            val_pred_probs, val_y, val_pre = [], [], []
            self.net.eval()
            # no gradient computation
            with torch.no_grad():
                for batch, (x, y) in enumerate(val_loader):
                    pred_probs = self.net(x)
                    loss = loss_function(pred_probs, y.float().detach())
                    val_loss.append(loss.item())
                    val_pred_probs.extend(pred_probs.tolist())
                    val_y.extend(y.tolist())
                    val_pre.extend(
                        torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs)).tolist())

            '''   end of epoch: record metrics   '''
            epoch_loss, epoch_val_loss = np.mean(train_loss), np.mean(val_loss)
            train_auc = roc_auc_score(y_true=train_y, y_score=train_pred_probs)
            train_acc = accuracy_score(y_true=train_y, y_pred=train_pre)
            val_auc = roc_auc_score(y_true=val_y, y_score=val_pred_probs)
            val_acc = accuracy_score(y_true=val_y, y_pred=val_pre)

            dfhistory.loc[epoch - 1] = (epoch, epoch_loss, train_acc, train_auc, epoch_val_loss, val_acc, val_auc)

            step_end = time.time()

            print("step_num: %s - %.1fs - loss: %.5f   accuracy: %.5f   auc: %.5f - val_loss: %.5f   val_accuracy: %.5f   val_auc: %.5f"
                % (step_num, (step_end - step_start) % 60, epoch_loss, train_acc, train_auc, epoch_val_loss, val_acc, val_auc))

        end_time = time.time()
        print('********** end of training, run time: {:.0f} min {:.0f} s **********'.format((end_time - start_time) // 60,
                                                                                             (end_time - start_time) % 60))
        print()
        return dfhistory

    def evaluate(self, val_X, val_y):
        val_X = torch.tensor(val_X).float()

        self.net.eval()
        with torch.no_grad():
            pred_probs = self.net(val_X).data
        pred = torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs))

        precision = np.around(metrics.precision_score(val_y, pred), 4)
        recall = np.around(metrics.recall_score(val_y, pred), 4)
        accuracy = np.around(metrics.accuracy_score(val_y, pred), 4)
        f1 = np.around(metrics.f1_score(val_y, pred), 4)
        auc = np.around(metrics.roc_auc_score(val_y, pred_probs), 4)
        loss = np.around(metrics.log_loss(val_y, pred_probs), 4)

        acc_condition, precision_condition, recall_condition = self.accDealWith2(val_y, pred)

        return precision, recall, accuracy, f1, auc, loss, acc_condition, precision_condition, recall_condition

    def predict(self, x):
        pred_probs = self.net(torch.tensor(x).float()).data
        print(pred_probs)
        pred = torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs))
        print(pred)

    def plot_metric(self, dfhistory, metric):
        train_metrics = dfhistory[metric]
        val_metrics = dfhistory['val_' + metric]
        epochs = range(1, len(train_metrics) + 1)
        plt.plot(epochs, train_metrics, 'bo--')
        plt.plot(epochs, val_metrics, 'ro-')
        plt.title('Training and validation ' + metric)
        plt.xlabel("Epochs")
        plt.ylabel(metric)
        plt.legend(["train_" + metric, 'val_' + metric])
        plt.show()

    def accDealWith2(self, y_test, y_pre):
        lenall = len(y_test)
        if type(y_test) != list:
            y_test = y_test.flatten()
        pos = 0
        pre = 0
        rec = 0
        precision_len = 0
        recall_len = 0

        for i in range(lenall):
            # accuracy
            if y_test[i] == y_pre[i]:
                pos += 1
            # precision
            if y_pre[i] == 1:
                pre += 1
                if y_test[i] == 1:
                    precision_len += 1
            # recall
            if y_test[i] == 1:
                rec += 1
                if y_pre[i] == 1:
                    recall_len += 1

        acc_condition = 'correct predictions: {}, total samples: {}'.format(pos, lenall)
        if pre != 0:
            precision_condition = 'predicted positive: {}, actually positive among them: {}, precision: {}'.format(
                pre, precision_len, np.around(precision_len / pre, 4))
        else:
            precision_condition = 'predicted positive: {}, actually positive among them: {}, precision: {}'.format(pre, precision_len, 0.0)

        if rec != 0:
            recall_condition = 'positive samples: {}, correctly recalled positives: {}, recall: {}'.format(rec, recall_len, np.around(recall_len / rec, 4))
        else:
            recall_condition = 'positive samples: {}, correctly recalled positives: {}, recall: {}'.format(rec, recall_len, 0.0)

        return acc_condition, precision_condition, recall_condition

main

if __name__ == '__main__':
    print()
    data = pd.read_csv('./data/criteo_sampled_data_test.csv')

    # I1-I13: 13 numerical (dense) feature columns in total
    # C1-C26: 26 categorical (sparse) feature columns in total
    dense_cols = ['I' + str(i) for i in range(1, 14)]
    sparse_cols = ['C' + str(i) for i in range(1, 27)]
    stat_pnrate_pd(data=data, labname='label', message='criteo_sampled_data_test')

    data_X = data[dense_cols + sparse_cols]
    data_y = data['label']

    sparse_fields = data_X[sparse_cols].max().values + 1
    sparse_fields = sparse_fields.astype(np.int32)
    print(sparse_fields)
    dense_fields_num = 13

    tmp_X, test_X, tmp_y, test_y = train_test_split(data_X, data_y, test_size=0.01, random_state=42, stratify=data_y)
    train_X, val_X, train_y, val_y = train_test_split(tmp_X, tmp_y, test_size=0.01, random_state=42, stratify=tmp_y)
    print(train_X.shape)
    print(val_X.shape)

    train_set = TensorDataset(torch.tensor(train_X.values).float(), torch.tensor(train_y.values).float())
    val_set = TensorDataset(torch.tensor(val_X.values).float(), torch.tensor(val_y.values).float())

    train_loader = DataLoader(dataset=train_set, batch_size=2048, shuffle=True)
    val_loader = DataLoader(dataset=val_set, batch_size=2048, shuffle=False)

    net = DIFM(sparse_fields=sparse_fields, dense_nums=dense_fields_num, emb_dim=8, head_num=2, scaling=True, use_residual=True, dnn_hidden_units=(256, 128), use_bn=True, dropout=0.2)

    loss_function = nn.BCELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    base = BaseModel(net=net)
    dfhistory = base.fit(train_loader=train_loader, val_loader=val_loader,
                         epochs=3, loss_function=loss_function, optimizer=optimizer, metric_name=['accuracy', 'auc'])

    summary(model=net, input_data=torch.tensor(val_X.values).float())
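
The script above builds a held-out test split (test_X, test_y) but never scores it. A minimal follow-up sketch (my addition, assuming it is placed right after base.fit inside the same __main__ block) that reuses BaseModel.evaluate:

    # hedged addition: score the held-out test split with the trained model
    precision, recall, accuracy, f1, auc, logloss, acc_c, prec_c, rec_c = base.evaluate(test_X.values, test_y.values)
    print(precision, recall, accuracy, f1, auc, logloss)
    print(acc_c, prec_c, rec_c)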


Reprinted from blog.csdn.net/qq_42363032/article/details/126105169