Pytorch学习系列之一 : Numpy基础操作

import numpy as np
import torch

定义numpy数组

a = np.array([1, 2, 3, 4, 5, 6])
b = np.array([8, 7, 6, 5, 4, 3])
print(a.shape, b.shape)
print(a)

reshape改变维度

aa = np.reshape(a, (3, 2))
bb = np.reshape(b, (1, 1, 1, 6))
print(aa, aa.shape)
print(bb, bb.shape)

压缩成一维度

b1 = np.squeeze(bb)
print(b1, b1.shape)

argmax用法

index = np.argmax(bb)
print(“find max:”,index, bb, bb[0][0][0][index])

ff = np.array([[1, 5, 5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]])

print(np.argmax(ff, axis=0))#每一列最大索引

print(np.argmax(ff, axis=1))#每一行最大索引

维度倒置

aaa = aa.transpose((1, 0))
print(aaa, aaa.shape)
a2 = np.reshape(aa, -1)
print(a2, a2.shape)

m1 = np.zeros((6, 6), dtype=np.uint8)

生成等差数列

m2 = np.linspace(6, 10, 100)
print(m1, m2)

猜你喜欢

转载自blog.csdn.net/thequitesunshine007/article/details/118250979