链接
主要作用:按照dim和index的指示,替换src中的某些元素
接口
scatter_(dim,index,src)
- dim指定替换维度
- index.ndim=src.ndim,但是一般index.shape=(N,1),src.shape=(N,M)
实现onehot
index = torch.randint(0,10,(100,1))
buf = torch.zeros((100,10))
buf.scatter_(dim=1,index=index,value=1)
链接
主要作用:按照dim和index的指示,替换src中的某些元素
scatter_(dim,index,src)
实现onehot
index = torch.randint(0,10,(100,1))
buf = torch.zeros((100,10))
buf.scatter_(dim=1,index=index,value=1)