RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat2'

解决方法是将用到的数据进行类型转换:

原来:

prediction, h_state = rnn(x, h_state)

加上类型转换:

x = torch.tensor(x, dtype=torch.float32)
prediction, h_state = rnn(x, h_state)
发布了41 篇原创文章 · 获赞 13 · 访问量 6692

猜你喜欢

转载自blog.csdn.net/comli_cn/article/details/104609123
今日推荐