While writing transformer code, I came across the function Tensor.repeat() (often written as torch.repeat(), though it is actually a method on tensors). To remember its input and output behavior, I came up with a mnemonic. The following three snippets illustrate it.
import torch

a = torch.arange(512)       # a.shape: (512,)
b = a.repeat(1, 32)         # a is treated as shape (1, 512)
print(b.shape)
# b.shape: (1, 32*512) = (1, 16384)
a = torch.ones(32, 100)     # a.shape: (32, 100)
b = a.repeat(1, 2, 3)       # a is treated as shape (1, 32, 100)
# b.shape: (1, 2*32, 3*100) = (1, 64, 300)
a = torch.ones(32, 100)
b = a.repeat(10)            # fewer repeat dims than tensor dims -> error
# RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
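Putting the three cases together, the mnemonic can be sketched as a runnable check (my own summary of the rule, not from any PyTorch docs text): repeat() must receive at least as many arguments as the tensor has dimensions; the tensor is first padded with leading size-1 dimensions, then each dimension i is tiled by the i-th argument.

```python
import torch

# x.repeat(r1, ..., rk) requires k >= x.dim(); x is viewed as k-dimensional
# by prepending size-1 dims, then dim i is tiled ri times.
a = torch.arange(512)            # shape (512,), viewed as (1, 512)
assert a.repeat(1, 32).shape == (1, 32 * 512)

c = torch.ones(32, 100)          # viewed as (1, 32, 100)
assert c.repeat(1, 2, 3).shape == (1, 2 * 32, 3 * 100)

# Too few repeat dims raises the RuntimeError quoted above.
try:
    c.repeat(10)
except RuntimeError as e:
    print("raised:", e)
```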
Then, in the positional-encoding step of the transformer definition, there is this line:
positions = torch.arange(inputs.size(1)).repeat(inputs.size(0), 1) + 1
This gives positions the shape (batch_size, seq_len), i.e. (inputs.size(0), inputs.size(1)). In words: for each sample (sentence) there are seq_len words, hence seq_len positions, and each position gets its own positional encoding of dimension d_model. (Once I fully master the transformer, I really should write a blog post about it.)
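The shape claim above is easy to verify with dummy sizes. Here batch_size=4 and seq_len=10 are hypothetical values I chose for illustration, and `inputs` is a dummy token-id batch rather than real data:

```python
import torch

# Dummy batch: 4 sentences of 10 tokens each (hypothetical sizes).
inputs = torch.zeros(4, 10, dtype=torch.long)

# arange(seq_len) has shape (10,); repeat(batch_size, 1) views it as
# (1, 10) and tiles the first dim, giving (4, 10); +1 makes positions 1-based.
positions = torch.arange(inputs.size(1)).repeat(inputs.size(0), 1) + 1

assert positions.shape == (4, 10)                    # (batch_size, seq_len)
assert positions[0].tolist() == list(range(1, 11))   # each row is 1..seq_len
```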