pytorch tensor.scatter_()

链接
主要作用:按照dim和index的指示,替换src中的某些元素

接口

scatter_(dim,index,src)

  • dim指定替换维度
  • index.ndim=src.ndim,但是一般index.shape=(N,1),src.shape=(N,M)

实现onehot

index = torch.randint(0,10,(100,1))
buf = torch.zeros((100,10))
buf.scatter_(dim=1,index=index,value=1)

猜你喜欢

转载自blog.csdn.net/u010590593/article/details/121422881