torch.max(action_value, 1)[1].data.numpy()[0] 是什么意思

torch.max(action_value, 1)表示取action_value里每行的最大值

torch.max(action_value, 1)[1]表示最大值对应的下标

.data.numpy()[0]表示将将Variable转换成tensor

action_value = self.eval_net.forward(x)


action = torch.max(action_value, 1)[1].data.numpy()[0]


print("<choose_action> action_value=", action_value, "torch.max(action_value, 1)=",torch.max(action_value, 1),"torch.max(action_value, 1)[1]=",torch.max(action_value, 1)[1], "action=", action)


 <choose_action> action_value= tensor([[-0.2394, -0.3109, -0.3330, -0.0376]], grad_fn=<AddmmBackward0>) torch.max(action_value, 1)= torch.return_types.max(
values=tensor([-0.0376], grad_fn=<MaxBackward0>),
indices=tensor([3])) torch.max(action_value, 1)[1]= tensor([3]) action= 3

参考:torch.max() - 知乎a0 = torch.max(a, dim) 其中a为一个tensor dim的值为 0/1,分别代表索引每列/行最大值返回的值包含两个数据(values, indices) 分别代表最大值的值和所在的索引 一般我们只需要里面的索引,而对最大值的值不感兴…https://zhuanlan.zhihu.com/p/468861622

猜你喜欢

转载自blog.csdn.net/u013288190/article/details/128337069