Why is NumPy so fast

NumPy is the basic package for scientific computing with Python and is the foundation of Python's data science foundation. It has the following six characteristics:

  1. Powerful N-dimensional arrays. NumPy's vectorization, indexing, and broadcasting concepts are fast and general, and are the de facto standard for array computing today.
  2. Numerical calculation tool. NumPy provides a comprehensive set of mathematical functions, random number generators, linear algebra routines, Fourier transforms, and more.
  3. Interoperable. NumPy supports a wide range of hardware and computing platforms, and works well with distributed, GPU, and sparse array libraries.
  4. high performance. At its core, NumPy is well-optimized C code. Enjoy the flexibility of Python and the speed of compiled code.
  5. Easy to use. NumPy's high-level syntax makes it easily accessible and productive for programmers of any background or experience level.
  6. open source. Distributed under the liberal BSD license, NumPy is openly developed and maintained on GitHub by a vibrant, responsive, and diverse community.

The above is from NumPy’s official website. Personally, I think NumPy has the following two characteristics. Let’s understand it together:

  1. high performance
  2. API is simple and easy to use

High-performance NumPy

High performance is the primary requirement of scientific computing, a large amount of data, a large number of loops, fast word first, after all, no one wants to run a model for several days. Let's look at a simple example first: use Python's list comprehension to output an array of length 1000, each element is the square of a natural number, the code is as follows:

L = range(1000)
%timeit [i**2 for i in L]
复制代码

Here's how using NumPy:

a = np.arange(1000)
%timeit a**2
复制代码

The time consumption of the two in the test results is as follows:

cycle test time consuming
list comprehension loop 437 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
numpy loop 1.97 µs ± 19.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

You can see the test results that NumPy is hundreds of times faster than the list comprehension.

Why is NumPy so fast? The first reason is that the core of NumPy is well-optimized C code with the speed of compiled code. The second is the data structure design and algorithm of NumPy, which is not introduced in many articles. This is what I want to focus on.

NumPy提供的最重要的数据结构是一个称为NumPy数组的强大对象,它有下面两个特点:

  1. 数组长度固定
  2. 仅支持同类型数据元素

我们知道Python的list,是动态的并且可以存放任意类型的元素, 比如:

>>> a = ["a", 9, 9.0, True]
>>> a.append("b")
>>> a
['a', 9, 9.0, True, 'b']
复制代码

a数组包含了4个元素,数据类型各不相同,我们还可以使用append方法往a中添加元素。实际数组有长度(length)和容量(capacity)两个概念,我这里借用一下go-slice的图示意一下:

容量和长度

  • 长度表示当前数组内元素个数
  • 容量表示当前数组最多可以存储多少个元素,超过了则需要重新申请内存区域
  • 一般扩容申请都会翻倍。比如上图是6个方格,已经使用了4个,再添加3个,这时候会直接再申请6个,而不是3个。

所以我们可以设想一下,使用列表推导式的时候,经过多次的内存申请,效率就低了下来。而NumPy数组是长度固定的,一次申请到位,自然效率会高不少。

如果大家做过协议处理,一定理解定长和不定长协议。定长协议中,每个协议长度相同,计算起来非常快捷,直接当前位置+固定长度就可以获取下一个协议位置;而不定长协议,还需要解析当前协议长度,判断当前协议的长度,才可以得知下一个协议的位置。

我们把数组在内存中的存储相信成协议的字节流,这就一致了。NumPy仅支持同类型的数据元素,就是定长协议的解析,效率很高。

# 定长
+---+---+---+---+---+---+
|   |   |   |   |   |   |
+---+---+---+---+---+---+

# 不定长
+--+----+------+--+-----+
|  |    |      |  |     |
+--+----+------+--+-----+
复制代码
  • 字符图不好理解的话,大家可以把定长元素想象为高铁车厢,不定长元素想象成汽车,那么春运的铁路运输和公路运输效率就一目了然

需要注意,NumPy中数组也可以放不同 Python类型 元素,但是它们都会(长度)向上对齐到 NumPy数据类型 ,下面的U32就是NumPy的数据类型:

np.array(["a", 9, 9.0])
array(['a', '9', '9.0'], dtype='<U32')
复制代码

NumPy支持矩阵运算,这也是NumPy高性能原因所在。

我们先复习一下矩阵的哈达玛积(Hadamard product),使用符号A⊙B表示:

  • 两个矩阵对应位置的元素逐一相乘

那么a**2的运算使用矩阵的方式就是这样:

+-----------+     +-----------+     +--------------+
|1,2,3,...,n|  ⊙  |1,2,3,...,n|  =  |1,4,9,...,N*N |
+-----------+     +-----------+     +--------------+
复制代码
  • 这里二维变成一维,计算法则是一样的

使用矩阵后就可以进行并发处理了,这和大数据中的map-reduce模型类似。我们可以这样理解它,普通的列表推导式:

L = range(1000)
for i in L:
    i**2        
复制代码

需要经历1000次循环,并且只能够在CPU的单核上逐次执行。根据矩阵运算公式,对应A_i位置的元素,只需要和B_i位置的元素相乘,和其它的999个数都无关,那么我们可以将整个大的运算拆分成1000个小运算,分到多个CPU核上并发执行,计算出每个位置元素后,再汇总即可。

所以NumPy采用良好的数据结构+高效的算法,性能自然上去了。

简单易用的NumPy

如果仅关注性能那么Fortran, matlab, R语言也足够了,或者直接使用C语言。Python语言足够简洁和灵活,使用它包装的API简单易用,又使NumPy的开发效率带来很大提升。我们可以通过下面几个小的例子来了解NumPy的这个特点。

首先是NumPy的数组切片非常强大,如图:

  • 红色切片取第0行的,第3-第4个元素(左闭右开)
  • 绿色切片取第4行和第4列后面的元素
  • 蓝色切片取第3列
  • 紫色切片按照2的步进取元素

然后是NumPy支持数学运算:

>>> import numpy as np
>>> a = np.arange(5)
>>> a
array([0, 1, 2, 3, 4])
>>> np.sin(a)
array([ 0.        ,  0.84147098,  0.90929743,  0.14112001, -0.7568025 ])
复制代码

同样的计算,我们在python中大概这样实现:

import math
>>> import math
>>> items = [0, 1, 2, 3, 4]
>>> list(map(lambda x: math.sin(x), items))
[0.0, 0.8414709848078965, 0.9092974268256817, 0.1411200080598672, -0.7568024953079282]
复制代码

对比可见,NumPy包装的API更便捷。

最后NumPy还提供了很多统计函数, 比如:

>>> x = np.array([1, 2, 3, 4])
>>> np.sum(x)
10
>>> x.sum()
10
>>> x.max()
4
>>> x.min()
1
>>> x.mean()
2.5
复制代码
  • np.sum(x)x.sum() 是API的两种写法,我们更常使用后面的方法

更多的API使用,可以阅读参考链接中的用户指南和参考手册

小结

NumPy由于高性能和简单易用,是Python进行科学计算的基石。本文从NumPy的数据结构和算法实现上,探讨了其高性能的原理,并简单介绍了部分API,希望能吸引你学习它的兴趣。

参考链接

Guess you like

Origin juejin.im/post/7188897810287165495