numpy中axis的通俗理解(个人学习见解)

import numpy as np
a = np.arange(12).reshape(2,3,2)
print a
# 个人理解,a将1~11分成两个数组,分别是三行两列(0~5一组,另一组6~11)
# output
[[[ 0  1]
  [ 2  3]
  [ 4  5]]

 [[ 6  7]
  [ 8  9]
  [10 11]]]

然而,numpy中的axis=0可以理解为最外层括号,每向内增加一层括号,axis加1;

# 这里sum(axis=0)其实就是第一个括号里面的两个数组相加
b = a.sum(axis=0)
# output
[[ 6, 8], 
[10, 12], 
[14, 16]]

# 这里sum(axis=1)其实就是第二个括号内的数相加
c = a.sum(axis=1)
# output
[[ 6, 9],
 [24, 27]]

# 这里sum(axis=2)其实就是第三个括号内的数相加
d = a.sum(axis=2)
# output
[[ 1, 5, 9], 
 [13, 17, 21]]

不多说,直接上图:

猜你喜欢

转载自blog.csdn.net/CDaron/article/details/111340630