2019/2/28:
torch.Tensor
是默认的tensor类型(torch.FloatTensor
)的简称。也就是float、float32。
pytorch的权重weight,若未手动修改,默认格式一定是float32,不会随输入的格式而自适应更改。
问题在于,在pytorch中,要进行wx+b,w与x的数据格式要一致!!!
最后,如果要进行反向传播更新参数,那么权重weight这类参数的数据格式一定不能是整型int!!!
2019/3/1:
torch.nn.CrossEntropyLoss()用于N分类,要求:
网络最后的全连接层输出N个特征图,且不需要加激活层。
标签 target 应为 LongTensor类型(BCELoss要求为DoubleTensor)。
如果使用 torch.utils.data.TensorDataset 与 torch.utils.data.DataLoader,那么 0<= label[index]<=N-1。标签为1D,不需要one-hot。