torch.Tensor.repeat()

While writing Transformer code I came across the tensor method torch.Tensor.repeat(). To make its input and output shapes easy to remember, here are three short snippets that illustrate the behavior.

import torch

a = torch.arange(512)    # shape: (512,)
b = a.repeat(1, 32)
print(b.shape)           # b: (1, 32*512) = (1, 16384)

a = torch.ones(32, 100)
b = a.repeat(1, 2, 3)    # b: (1, 2*32, 3*100) = (1, 64, 300)

a = torch.ones(32, 100)
b = a.repeat(10)
# RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
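The rule behind all three cases: when repeat() receives more sizes than the tensor has dimensions, PyTorch implicitly prepends leading dimensions of size 1 before tiling; it never appends them, which is why the last example fails. A minimal sketch to confirm this reading of the first snippet (unsqueeze(0) makes the implicit leading dimension explicit):

import torch

a = torch.arange(512)                    # shape: (512,)
b = a.unsqueeze(0).repeat(1, 32)         # explicitly view a as (1, 512), then tile
print(b.shape)                           # torch.Size([1, 16384])
print(torch.equal(b, a.repeat(1, 32)))   # True: same result as the implicit version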

Then, in the Transformer's positional-encoding step, the following line appears:

positions = torch.arange(inputs.size(1)).repeat(inputs.size(0), 1) + 1

This gives positions the shape (batch_size, seq_len), i.e. (inputs.size(0), inputs.size(1)). In words: for each sample (sentence) there are seq_len words, hence seq_len positions, and each position is later mapped to its own d_model-dimensional positional encoding. (I really should write a full blog post once I've completely mastered the Transformer.)
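To see concretely what that line produces, here is a minimal sketch; batch_size=2 and seq_len=4 are made-up sizes, and the zero tensor is just a stand-in for a real batch of token ids:

import torch

batch_size, seq_len = 2, 4                         # hypothetical sizes
inputs = torch.zeros(batch_size, seq_len).long()   # stand-in for token ids
positions = torch.arange(inputs.size(1)).repeat(inputs.size(0), 1) + 1
print(positions.shape)   # torch.Size([2, 4])
print(positions)
# tensor([[1, 2, 3, 4],
#         [1, 2, 3, 4]])

Each row is the 1-based position index of every word in one sentence, repeated once per sample in the batch.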

Source: blog.csdn.net/jokerxsy/article/details/106736026