在pytorch中取一个tensor的均值,然后该张量中的所有值与其对比!

Pytorch中的Tensor的shape是(B, C, W, H),对该tensor取均值并与所有值做对比代码如下:

C, H, W = tensor.shape[1], tensor.shape[2], tensor.shape[3]
for c in range(C):
	mean = torch.mean(x[0][c])
	for h in range(H):
		for w in range(W):
		if x[0][c][h][w] >= mean:
		x[0][c][h][w] = mean

猜你喜欢

转载自blog.csdn.net/qq_37760750/article/details/107413057
今日推荐