【debug】torch weight norm device

m = weight_norm(nn.Linear(20, 40), name='weight').cuda()
print(m.weight.device) # on cpu
inputs # a tensor on cuda
outputs = m(inputs)

不会报错,按常识模型和数据要放在同一个设备上才行,其实是weight_norm运算的时候用的不是m.weight,而是m.weight_g和m.weight_v.

print(m.weight_g.device) # on cuda
print(m.weight_v.device) # on cuda

猜你喜欢

转载自blog.csdn.net/weixin_42262721/article/details/128891493
今日推荐