连续型特征做embedding代码示例

为什么要将连续特征也做emb?

● 一方面连续特征emb后能更充分的与其它特征进行交叉;
● 另一方面可以使得学习更加充分,避免数值微小的变化带来预测结果的剧烈变化。

在这里插入图片描述
实现思路:

  • 对于连续值做归一化;
  • 然后新增列用以做label encoder;
  • 对编码的tensor做emb;
  • 取出连续值tensor,然后相乘;
ddd = pd.DataFrame({
    
    'x1': [0.001, 0.002, 0.003], 'x2': [0.1, 0.1, 0.2]})

'''将编码后的值拼回去'''
dense_cols = ['x1', 'x2']
dense_cols_enc = [c + '_enc' for c in dense_cols]
for i in range(len(dense_cols)):
    enc = LabelEncoder()
    ddd[dense_cols_enc[i]] = enc.fit_transform(ddd[dense_cols[i]].values).copy()
print(ddd)

'''计算fields'''
dense_fields = ddd[dense_cols_enc].max().values + 1
dense_fields = dense_fields.astype(np.int32)
offsets = np.array((0, *np.cumsum(dense_fields)[:-1]), dtype=np.longlong)
print(dense_fields, offsets)

'''用编码后的做emb'''
tensor = torch.tensor(ddd.values)
emb_tensor = tensor[:, -2:] + tensor.new_tensor(offsets).unsqueeze(0)
emb_tensor = emb_tensor.long()

embedding = nn.Embedding(sum(dense_fields) + 1, embedding_dim=4)
torch.nn.init.xavier_uniform_(embedding.weight.data)
dense_emb = embedding(emb_tensor)

print('---', dense_emb.shape)
print(dense_emb.data)

# print(embedding.weight.shape)
# print(embedding.weight.data)
# print(embedding.weight.data[1])

'''取出原来的数值特征并增加维度用于相乘'''
dense_tensor = torch.unsqueeze(tensor[:, :2], dim=-1)
print('---', dense_tensor.shape)
print(dense_tensor)

dense_emb = dense_emb * dense_tensor
print(dense_emb)
      x1   x2  x1_enc  x2_enc
0  0.001  0.1       0       0
1  0.002  0.1       1       0
2  0.003  0.2       2       1

[3 2] [0 3]

--- torch.Size([3, 2, 4])
tensor([[[-0.1498, -0.5054,  0.0211, -0.2746],
         [ 0.0133,  0.3257, -0.2117, -0.0956]],

        [[-0.1296, -0.4524,  0.5334,  0.0894],
         [ 0.0133,  0.3257, -0.2117, -0.0956]],

        [[ 0.5597,  0.3630, -0.7686, -0.1408],
         [ 0.6840, -0.5328,  0.0422, -0.6365]]])
         
--- torch.Size([3, 2, 1])
tensor([[[0.0010],
         [0.1000]],

        [[0.0020],
         [0.1000]],

        [[0.0030],
         [0.2000]]], dtype=torch.float64)
         
tensor([[[-1.4985e-04, -5.0542e-04,  2.1051e-05, -2.7457e-04],
         [ 1.3284e-03,  3.2572e-02, -2.1174e-02, -9.5578e-03]],

        [[-2.5924e-04, -9.0472e-04,  1.0668e-03,  1.7884e-04],
         [ 1.3284e-03,  3.2572e-02, -2.1174e-02, -9.5578e-03]],

        [[ 1.6790e-03,  1.0891e-03, -2.3059e-03, -4.2229e-04],
         [ 1.3679e-01, -1.0656e-01,  8.4448e-03, -1.2731e-01]]],
       dtype=torch.float64, grad_fn=<MulBackward0>)

reference:
https://www.zhihu.com/question/352399723/answer/869939360

有关offsets可以看:
https://blog.csdn.net/qq_42363032/article/details/125928623?spm=1001.2014.3001.5501

猜你喜欢

转载自blog.csdn.net/qq_42363032/article/details/125999994