pytorch-张量-张量的计算-统计相关的计算

3统计相关的计算

p y t o r c h pytorch pytorch中 包含了一些基础的统计计算的功能,可以很方便的获取张量中的均值,标准差,最大值,最小值及位置等。

import torch

# 一维张量的最大值和最小值
a = torch.tensor([12., 34, 25, 11, 67, 32, 29, 30, 99, 55, 23, 44])
print("最大值:", a.max())
print("最大值位置:", a.argmax())
print("最小值", a.min())
print("最小值位置", a.argmin())

b = a.reshape(3, 4)
print("2-D张量:", b)
# 最大值及位置(每行)
print("最大值:", b.max(dim=1))
print("最大值位置:", b.argmax(dim=1))
# 最小值及位置(每列)
print("最小值", b.min(dim=0))
print("最小值位置", b.argmin(dim=0))

# torch.sort()可以对一维张量进行排序,或者对高维张量在指定的维度进行排序,输出排序结果,还会输出对应的值在原始位置的索引
print(torch.sort(a))
# 按照降序排序
print(torch.sort(a, descending=True))
# 对二维张量进行排序
bsort, bsort_id = torch.sort(b)
print("b sort", bsort)
print("b sort_id", bsort_id)
print("b argsort:", torch.argsort(b))

# torch.topk()根据指定的k值,计算出张量中取值大小为第k大的数值与数值所在的位置
# torch.kthvalue()根据指定的数值k,计算出张量中取值大小为第k小的数值与数值所在的位置

# 获取张量前几大的数值
print(torch.topk(a, 4))
# 获取2-D张量每列前几大的数值
btop2, btop2_id = torch.topk(b, 2, dim=0)
print("b 每列 top2", btop2)
print("b 每列 top2 位置", btop2_id)

# 获取张量第k小的数值和位置
print(torch.kthvalue(a, 3))

# 获取2-D张量第k小的数值和位置
print(torch.kthvalue(b, 3, dim=1))
# 获取2-D张量第k小的数值和位置
bkth, bkth_id = torch.kthvalue(b, 3, dim=1, keepdim=True)
print(bkth)

# torch.mean()根据指定的维度计算均值
# 计算每行的平均值
print(torch.mean(b, dim=1, keepdim=True))
# 计算每列的均值
print(torch.mean(b, dim=0, keepdim=True))
# torch.sum()根据指定的维度求和
# 计算每行的和
print(torch.sum(b, dim=1, keepdim=True))
# 计算每列的和
print(torch.sum(b, dim=0, keepdim=True))
# torch.cumsum()根据指定的维度计算累加和
# 按照行计算累加和
print(torch.cumsum(b, dim=1))
# 按照列计算累加和
print(torch.cumsum(b, dim=0))
# torch.median()根据指定的维度计算中位数
# 计算每行的中位数
print(torch.median(b, dim=1, keepdim=True))
# 计算每列的中位数
print(torch.median(b, dim=0, keepdim=True))
# 按照行计算乘积
print(torch.prod(b, dim=1, keepdim=True))
# 按照列计算乘积
print(torch.prod(b, dim=1, keepdim=True))
# 按照行计算累乘积
print(torch.cumprod(b, dim=1))
# 按照列计算累乘积
print(torch.cumprod(b, dim=0))
# torch.std()计算张量的标准差
print(torch.std(a))


猜你喜欢

转载自blog.csdn.net/weixin_45955630/article/details/111668304