Introductory Guide to Deep Learning in 2023 (16) - JAX and TPU Acceleration


In the previous section, we introduced reinforcement learning from human feedback (RLHF), one of the core algorithms behind ChatGPT. I suspect many readers didn't fully follow it, because it requires a lot of background knowledge. That's fine: a large model can't be trained in a day, and it can't be aligned in a day either. We have plenty of time to lay the groundwork first.

We didn't go deeper into the reinforcement learning part in the previous section because we worried that many readers have forgotten most of their math, and math classes never taught us how to program it. In this section, we introduce two basic tools. One is NumPy, which no Python deep learning framework can get around. The other is JAX, developed by Google, which can be thought of as a GPU and TPU version of NumPy.

The goal of learning these two libraries is to brush up on math, especially on doing math in code. This is also the first time TPUs appear in this tutorial series. Of course, a GPU works too.

Matrices

The core function of NumPy is the support of multidimensional matrices.

We can install NumPy with pip install numpy and then import it in Python with import numpy as np. However, NumPy does not support GPU or TPU acceleration. That makes it impractical for the computations we will deal with later, so we introduce the JAX library here.

For installation instructions, see the official JAX documentation.

We have used CUDA for GPU acceleration many times before, so this time let's look at TPU acceleration.
TPUs are made by Google and are available to us only as a cloud service, but we can access them through Google Colab.

On Colab, the JAX and TPU runtimes are already installed. We can activate the TPU by running the following code:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

Let's see how many TPU devices are available:

print(jax.device_count())
print(jax.local_device_count())
print(jax.devices())

The output is as follows:

8
8
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

This shows that we have 8 TPU devices available: 4 chips (see the coords field), each with 2 cores (core_on_chip 0 and 1).

From here on, we use jax.numpy instead of numpy.

The most important feature of NumPy is its support for multidimensional matrices. We create them with np.array, or jnp.array in JAX.

Let's start with a one-dimensional vector:

import jax.numpy as jnp
a1 = jnp.array([1,2,3])
print(a1)

Next, we can pass a nested list (a 2D array) to create a matrix:

a2 = jnp.array([[1,2],[0,4]])
print(a2)

We can also fill a matrix with the same initial value everywhere. The zeros function creates a matrix of all 0s, the ones function creates a matrix of all 1s, and the full function creates a matrix filled with a value we specify.

For example, to create a matrix of zeros with 10 rows and 10 columns, we can write:

a3 = jnp.zeros((10,10))
print(a3)

A matrix of all 1s:

a4 = jnp.ones((10,10))

A matrix filled with 100:

a5 = jnp.full((10,10),100)

We can also generate a sequence with the linspace function. Its first parameter is the start value of the sequence, the second is the end value (included by default), and the third is the number of elements. For example, we can generate 100 numbers from 1 to 100:

a7 = jnp.linspace(1, 100, 100)  # 100 evenly spaced numbers from 1 to 100
a7 = a7.reshape(10, 10)  # reshape returns a new array, so assign the result
print(a7)
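To check how linspace spaces its points, the sketch below compares it with jnp.arange. This is our own illustration, not part of the original. arange uses a step size and stops before its end value, while linspace includes the end value by default, so both calls below produce 1, 2, ..., 100.

```python
import jax.numpy as jnp

# linspace: 100 evenly spaced points from 1 to 100, end value included
seq_linspace = jnp.linspace(1, 100, 100)
# arange: start at 1, step by 1, stop *before* 101
seq_arange = jnp.arange(1, 101)

# The spacing is (100 - 1) / (100 - 1) = 1.0, so both give 1, 2, ..., 100
print(seq_linspace[:3], seq_linspace[-1])
print(jnp.allclose(seq_linspace, seq_arange))

# reshape returns a new array; JAX arrays are never modified in place
grid = seq_linspace.reshape(10, 10)
print(grid.shape)  # (10, 10)
```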

Finally, JAX generates random matrices differently from NumPy, because there is no jnp.random module. Instead, we use jax.random. Every JAX random-number function takes an explicit random state, called a key, as its first argument. This key consists of two unsigned 32-bit integers. Using a key does not modify it, so reusing the same key gives the same result. When you need new random numbers, use jax.random.split() to derive new subkeys.

from jax import random
key = random.PRNGKey(0)  # create a key from the seed 0
key, subkey = random.split(key)  # split into a new key and a subkey
a8 = random.uniform(subkey, shape=(10, 10))  # 10x10 matrix of uniform random numbers in [0, 1)
print(a8)
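The "same key, same result" rule is easy to check directly. In this minimal sketch of our own, the shape (3,) is arbitrary. Drawing twice from one key repeats the numbers, and drawing from a split subkey gives new ones.

```python
from jax import random

key = random.PRNGKey(0)

# Drawing twice with the same key gives exactly the same numbers
x1 = random.uniform(key, shape=(3,))
x2 = random.uniform(key, shape=(3,))
print((x1 == x2).all())  # True

# Splitting derives a fresh subkey, which yields different numbers
key, subkey = random.split(key)
x3 = random.uniform(subkey, shape=(3,))
print((x1 == x3).all())  # almost surely False
```

This explicit style is why JAX code passes keys around instead of relying on a hidden global seed.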

Norm

A norm is a mathematical concept used to measure the "size" of a vector in a vector space. A norm must satisfy the following properties:

  • Non-negativity: every vector has a norm greater than or equal to zero, and only the zero vector has norm zero.
  • Homogeneity: for any real number λ and any vector v, ||λv|| = |λ| ||v||.
  • Triangle inequality: for any vectors u and v, ||u + v|| ≤ ||u|| + ||v||.

In practice, norms are used to measure the size of vectors or matrices. For example, in machine learning, norms are often used to compute regularization terms.

Common norms are:

  • L0 norm: the number of non-zero elements in the vector.
  • L1 norm: the sum of the absolute values of the elements. The L1 norm of the difference between two vectors is their Manhattan distance.
  • L2 norm: the square root of the sum of the squares of the elements. The L2 norm of the difference between two vectors is their Euclidean distance.
  • Infinity norm: the largest absolute value among the elements.

Note that the L0 norm is not strictly a norm, because it violates homogeneity: doubling a vector does not change its number of non-zero elements. Machine learning still often uses it to count the non-zero elements of a vector, which is why it is also called a "pseudo-norm".
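Each definition above is a one-liner in jax.numpy. The sketch below computes them by hand on an illustrative vector, [3, 0, -4], chosen by us, and cross-checks them against jnp.linalg.norm. It also shows concretely how the L0 count breaks homogeneity.

```python
import jax.numpy as jnp

v = jnp.array([3.0, 0.0, -4.0])

l0 = jnp.count_nonzero(v)          # number of non-zero entries
l1 = jnp.sum(jnp.abs(v))           # sum of absolute values
l2 = jnp.sqrt(jnp.sum(v ** 2))     # square root of the sum of squares
linf = jnp.max(jnp.abs(v))         # largest absolute value
print(l0, l1, l2, linf)            # values: 2, 7, 5, 4

# Cross-check against jnp.linalg.norm
print(jnp.allclose(l1, jnp.linalg.norm(v, ord=1)))
print(jnp.allclose(l2, jnp.linalg.norm(v)))
print(jnp.allclose(linf, jnp.linalg.norm(v, ord=jnp.inf)))

# L0 breaks homogeneity: a true norm of 2*v would double, but the count stays 2
print(jnp.count_nonzero(2 * v))
```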

Let's start by calculating the L1 norm of a one-dimensional vector. Don't be intimidated by the name of the L1 norm, it is actually the sum of absolute values:

a10 = jnp.array([1, 2, 3])
norm10_1 = jnp.linalg.norm(a10, ord=1)  # |1| + |2| + |3|
print(norm10_1)

As expected, the result is 6.

Next, let's look at the L2 norm, which gives the Euclidean length: square each element, add the squares, and take the square root:

a10 = jnp.array([1, 2, 3])
norm10 = jnp.linalg.norm(a10)
print(norm10)

Following the definition of the L2 norm, we can calculate it by hand: sqrt(1*1 + 2*2 + 3*3) = sqrt(14) ≈ 3.7416573.

We can see that the value of norm10 above is the same as our manual calculation.

Now let's calculate the infinity norm, which is simply the largest absolute value:

norm10_inf = jnp.linalg.norm(a10, ord = jnp.inf)
print(norm10_inf)

The result is 3.

Let's try a bigger example:

a10 = jnp.linspace(1,100,100) # 100 evenly spaced numbers from 1 to 100
n10 = jnp.linalg.norm(a10,ord=2)
print(n10)

This result is 581.67865.
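This value can be checked by hand. For 1, 2, ..., n, the sum of squares has the closed form n(n+1)(2n+1)/6, so for n = 100 the L2 norm is sqrt(338350). A small check of our own:

```python
import jax.numpy as jnp

n = 100
seq = jnp.linspace(1, n, n)

# Closed-form sum of squares 1^2 + 2^2 + ... + n^2
sum_sq_closed_form = n * (n + 1) * (2 * n + 1) // 6   # 338350
norm_closed_form = jnp.sqrt(float(sum_sq_closed_form))

norm_jax = jnp.linalg.norm(seq, ord=2)
print(sum_sq_closed_form, norm_closed_form, norm_jax)
print(jnp.allclose(norm_jax, norm_closed_form))
```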

Inverse Matrix

A square matrix with 1s on the main diagonal and 0s everywhere else is called an identity matrix. In NumPy and JAX, we use the eye function to generate one.

Because an identity matrix is square, we don't need separate values for rows and columns. A single value, the size of the matrix, is enough. Passing this value as the first argument of eye generates the identity matrix.
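A minimal sketch of eye in use. The 3x3 matrix m is only an illustrative example of ours. Multiplying by the identity should leave a matrix unchanged.

```python
import jax.numpy as jnp

identity3 = jnp.eye(3)   # 3x3 identity matrix (float32 by default)
print(identity3)

# Multiplying by the identity leaves a matrix unchanged
m = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 9.0]])
print(jnp.allclose(jnp.dot(identity3, m), m))
```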

Next, let's review how matrix multiplication is calculated.

For each row of matrix A, we need to multiply with each column of matrix B. "Multiply" here means taking a row of A and a column of B, multiplying their corresponding elements, and adding those products. This sum is the element at the corresponding position in the resulting matrix.

As an example, suppose we have two 2x2 matrices A and B:

A = | 1 2 |     B = | 4 5 |
    | 3 4 |         | 6 7 |

We can compute the product of matrix A and matrix B like this:

A·B = | 1*4 + 2*6   1*5 + 2*7 | = | 16 19 |
      | 3*4 + 4*6   3*5 + 4*7 |   | 36 43 |

Let's use JAX to calculate:

ma1 = jnp.array([[1,2],[3,4]])
ma2 = jnp.array([[4,5],[6,7]])
ma3 = jnp.dot(ma1,ma2)
print(ma1)
print(ma2)
print(ma3)

The output is:

[[1 2]
 [3 4]]
[[4 5]
 [6 7]]
[[16 19]
 [36 43]]
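To connect the row-times-column rule with jnp.dot, here is a deliberately slow sketch of our own. The helper name matmul_by_hand is hypothetical and meant only for illustration. It builds each entry as the sum of products of a row of A and a column of B. It also shows that the @ operator performs matrix multiplication, while * multiplies element by element, which is a common source of bugs.

```python
import jax.numpy as jnp

def matmul_by_hand(a, b):
    """Hypothetical helper: row-by-column matrix product with explicit loops."""
    rows, inner = a.shape
    inner_b, cols = b.shape
    assert inner == inner_b, "columns of a must match rows of b"
    # Entry (i, j) = sum over k of a[i, k] * b[k, j]
    result = [[jnp.sum(a[i, :] * b[:, j]) for j in range(cols)]
              for i in range(rows)]
    return jnp.array(result)

ma1 = jnp.array([[1, 2], [3, 4]])
ma2 = jnp.array([[4, 5], [6, 7]])

print(matmul_by_hand(ma1, ma2))  # entries 16, 19, 36, 43
print(ma1 @ ma2)                 # same result as jnp.dot(ma1, ma2)
print(ma1 * ma2)                 # element-wise: entries 4, 10, 18, 28
```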

If A·B = I, where I is the identity matrix, then B is called the inverse matrix of A, written A^-1. For square matrices, A·B = I also implies B·A = I.

We can use the jnp.linalg.inv function to calculate the inverse of a matrix.

ma1 = jnp.array([[1,2],[3,4]])
inv1 = jnp.linalg.inv(ma1)
print(inv1)

The output result is:

[[-2.0000002   1.0000001 ]
 [ 1.5000001  -0.50000006]]
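The exact inverse of this matrix is [[-2, 1], [1.5, -0.5]]. Its determinant is 1*4 - 2*3 = -2. Trailing digits such as -2.0000002 come from floating-point rounding. In the sketch below, we check the result by multiplying back and comparing with the identity using a tolerance rather than exact equality. The tolerance 1e-5 is our own choice.

```python
import jax.numpy as jnp

ma1 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
inv1 = jnp.linalg.inv(ma1)

# Both A @ A^-1 and A^-1 @ A should be (approximately) the identity
print(jnp.allclose(ma1 @ inv1, jnp.eye(2), atol=1e-5))
print(jnp.allclose(inv1 @ ma1, jnp.eye(2), atol=1e-5))

# Compare with the exact inverse: (1 / det) * [[4, -2], [-3, 1]], det = -2
exact = jnp.array([[-2.0, 1.0], [1.5, -0.5]])
print(jnp.allclose(inv1, exact, atol=1e-5))
```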

Derivatives and Gradients

The derivative of a function at a point is its instantaneous rate of change there. Geometrically, it is the slope of the function's tangent line at that point, which tells us how steep the function is there.

The gradient is a vector that points in the direction in which the function's directional derivative at that point is largest, the direction of steepest ascent. Its length equals that maximum rate of change. For a real-valued function of a single variable, the gradient is simply the derivative.

As a framework built for deep learning, JAX treats gradients as a first-class feature. We use the jax.grad function to calculate gradients. For a function of one variable, the gradient is just the derivative. The following code calculates the gradient of the sin function at x = 1.0, which should equal cos(1.0) ≈ 0.5403:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

# Compute the gradient (derivative) of f at x = 1.0
grad_f = jax.grad(f)
print(grad_f(1.0))
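For a function of several variables, jax.grad returns a vector with one partial derivative per input. This is the "gradient is a vector" idea above. In the minimal sketch below, the function sum_of_squares, f(v) = v1^2 + v2^2 + v3^2, is our own illustrative choice. Its gradient is 2v.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(v):
    # f(v) = v1^2 + v2^2 + v3^2, a scalar-valued function of a vector
    return jnp.sum(v ** 2)

grad_fn = jax.grad(sum_of_squares)
v = jnp.array([1.0, 2.0, 3.0])
print(grad_fn(v))  # [2. 4. 6.], i.e. 2 * v

# Single-variable case: the gradient of sin at 1.0 matches cos(1.0)
print(jnp.allclose(jax.grad(jnp.sin)(1.0), jnp.cos(1.0)))
```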

If we repeatedly take small steps against the gradient, we move toward a minimum of the function. This method is called gradient descent, a widely used optimization algorithm. Its core idea is this: at any point, the gradient points in the direction in which the function increases fastest, so the opposite direction (the negative gradient) is the direction in which it decreases fastest. In one dimension, if the derivative at a point is positive, decreasing x lowers the function; if the derivative is negative, increasing x lowers it. By repeatedly stepping along the negative gradient, we approach a (local) minimum of the function.

So what is gradient descent good for? We can use it to find the minimum of a function. The following code finds the minimum of f(x) = x^2:

import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

grad_f = jax.grad(f)

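x = 2.0  # starting point
learning_rate = 0.1  # learning rate (step size)
num_steps = 100  # number of iterations

for i in range(num_steps):
    grad = grad_f(x)  # compute the gradient at the current x
    x = x - learning_rate * grad  # step in the negative gradient direction

print(x)  # final x; should be close to 0, where f reaches its minimum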

In my run, the result was 4.0740736e-10. In other words, gradient descent on f(x) = x^2 drove x very close to 0, which is exactly where the function reaches its minimum value of 0. This matches a hand calculation. Each step computes x - 0.1 * 2x = 0.8x, so after 100 steps x = 2 * 0.8^100 ≈ 4.07e-10.

Here, the learning rate (or step size) is a positive number that controls how large each update step is. It must be chosen carefully. If it is too large, the algorithm may fail to converge; if it is too small, convergence may be very slow.
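To see both failure modes, the sketch below reruns the f(x) = x^2 example with three learning rates. The helper run_gradient_descent is hypothetical and ours, and the three rates are illustrative. For this particular f, each step computes x - lr * 2x = (1 - 2*lr) * x. When lr > 1, |1 - 2*lr| > 1, so the iterates grow instead of shrinking.

```python
import jax

def f(x):
    return x ** 2

grad_f = jax.grad(f)

def run_gradient_descent(learning_rate, num_steps, x0=2.0):
    """Hypothetical helper: plain gradient descent on f starting from x0."""
    x = x0
    for _ in range(num_steps):
        x = x - learning_rate * grad_f(x)
    return float(x)

for lr in (0.1, 0.001, 1.1):
    print(lr, run_gradient_descent(lr, num_steps=100))
# 0.1   -> close to 0 (converges)
# 0.001 -> still far from 0 after 100 steps (too slow)
# 1.1   -> huge magnitude (diverges)
```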

Probability

After refreshing some memories of linear algebra and calculus, let's finish by reviewing probability theory.

Let's start with tossing a coin. We know that if a coin is fair and we toss it enough times, the number of heads will be close to half the total number of tosses.

A random experiment like this, with exactly two possible outcomes, has a fancy name: a Bernoulli trial.

Next, we will use JAX's Bernoulli distribution to simulate the process of tossing a coin.

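import time
import jax
import jax.numpy as jnp
from jax import random

# Generate a random array of shape (10,): each element is True (heads) with probability 0.5, otherwise False
key = random.PRNGKey(int(time.time()))  # seed from the current time so each run differs
rand_matrix = random.bernoulli(key, p=0.5, shape=(10, ))
print(rand_matrix)
mean_x = jnp.mean(rand_matrix)  # fraction of heads (True counts as 1)
print(mean_x)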

The mean function computes the average. Here it gives the fraction of heads, which estimates the mathematical expectation (0.5 for a fair coin).

The printed result may be 0.5, 0.3, 0.8, and so on. Ten tosses are too few for the number of heads to reliably land near half the total.

Here is a run that gave 0.6:

[ True  True  True  True False False  True False False  True]
0.6

After several runs, even results such as 0.1 or 0.9 can show up:

[False False False False False False False False False  True]
0.1

When we increase the shape to larger numbers such as 100, 1000, or 10000, the result gets closer and closer to 0.5. This is the law of large numbers.
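The sketch below runs this experiment in a single loop. We use a fixed seed (42, an arbitrary choice) instead of the current time so that runs are reproducible. For a fair coin, the typical deviation of the observed frequency from 0.5 shrinks roughly like 0.5 / sqrt(n).

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)  # fixed seed so the run is reproducible

for n in (10, 100, 1000, 10000, 100000):
    key, subkey = random.split(key)  # a fresh subkey for each experiment
    tosses = random.bernoulli(subkey, p=0.5, shape=(n,))
    freq = jnp.mean(tosses)          # fraction of heads
    print(n, float(freq))
```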

Let's review two statistics that measure how spread out data is:

  • Variance: variance measures how far data points deviate from the mean. Specifically, it is the average of the squared distances between the data points and the mean.
  • Standard deviation: the standard deviation is the square root of the variance. Because the variance is built from squared deviations, its unit is the square of the original data's unit. Taking the square root gives the standard deviation the same unit as the original data, which makes it easier to interpret.

Both of these statistics reflect the degree of dispersion of the data distribution. The larger the variance and standard deviation, the more scattered the data points; conversely, the smaller the variance and standard deviation, the more concentrated the data points.

We can use JAX's var function to calculate the variance and the std function to calculate the standard deviation.

import time
import jax
import jax.numpy as jnp
from jax import random

# Generate a random array of shape (1000,): each element is True with probability 0.5
key = random.PRNGKey(int(time.time()))
rand_matrix = random.bernoulli(key, p=0.5, shape=(1000, ))
#print(rand_matrix)
mean_x = jnp.mean(rand_matrix)  # close to 0.5
print(mean_x)
var_x = jnp.var(rand_matrix)  # close to 0.5 * (1 - 0.5) = 0.25
print(var_x)
std_x = jnp.std(rand_matrix)  # close to sqrt(0.25) = 0.5
print(std_x)
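To connect these functions with the definitions above, the sketch below recomputes variance and standard deviation by hand. It is our own, with seed 0 as an arbitrary choice. By default, jnp.var and jnp.std divide by n (the ddof parameter changes this), so they should match the plain "average squared deviation". For a Bernoulli(p) variable, the theoretical variance is p(1 - p), which is 0.25 here.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
# Convert booleans to 0.0 / 1.0 so the arithmetic is explicit
tosses = random.bernoulli(key, p=0.5, shape=(1000,)).astype(jnp.float32)

mean = jnp.mean(tosses)
var_by_hand = jnp.mean((tosses - mean) ** 2)   # average squared deviation
std_by_hand = jnp.sqrt(var_by_hand)            # back to the original unit

print(jnp.allclose(var_by_hand, jnp.var(tosses)))
print(jnp.allclose(std_by_hand, jnp.std(tosses)))

# Theory for Bernoulli(0.5): variance 0.25, standard deviation 0.5
print(var_by_hand, std_by_hand)
```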

Finally, let's review the information content we talked about earlier. Here is a question: for a Bernoulli distribution, which value of p maximizes the average information content?

First consider two special cases. If p = 0, we never get heads. We know the outcome in advance, so the information content is 0. If p = 1, we never get tails. Again we know the outcome, and the information content is also 0.

If p = 0.01, the average information content is still small, because we can guess tails almost every time and be right. An occasional heads carries a lot of information on its own, but it occurs so rarely that the average stays small.

If p = 0.5, we cannot predict whether the result will be heads or tails, and the average information content is at its largest.

This is only a qualitative analysis. We also need a quantitative formula, the entropy of a discrete random variable X:

$$H(X) = -\sum_{x \in X} p(x) \log_2 p(x)$$
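For a single coin toss with P(heads) = p, the sum has only two terms. That gives the binary entropy below. Setting its derivative to zero confirms the qualitative argument that p = 1/2 is the maximum. We add this short derivation for completeness:

$$H(p) = -p \log_2 p - (1 - p) \log_2 (1 - p)$$

$$\frac{dH}{dp} = \left(-\log_2 p - \frac{1}{\ln 2}\right) + \left(\log_2 (1 - p) + \frac{1}{\ln 2}\right) = \log_2 \frac{1 - p}{p}$$

$$\frac{dH}{dp} = 0 \iff \frac{1 - p}{p} = 1 \iff p = \frac{1}{2}, \qquad H\!\left(\tfrac{1}{2}\right) = -\tfrac{1}{2}\log_2 \tfrac{1}{2} - \tfrac{1}{2}\log_2 \tfrac{1}{2} = 1$$

$$\frac{d^2H}{dp^2} = -\frac{1}{\ln 2}\left(\frac{1}{p} + \frac{1}{1 - p}\right) < 0 \quad \text{for } 0 < p < 1,$$

so p = 1/2 is a maximum, worth exactly 1 bit.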

import jax.numpy as jnp

# Average information content (entropy, in bits) of a discrete random variable X
def avg_information(p):
    p = jnp.maximum(p, 1e-10)  # avoid log2(0) for zero probabilities
    return jnp.negative(jnp.sum(jnp.multiply(p, jnp.log2(p))))

# Entropy when X takes the values 0 and 1 with probabilities 0.3 and 0.7
p = jnp.array([0.3, 0.7])
avg_info = avg_information(p)
print(avg_info)

Trying a few values gives these results. When p is 0.3, the average information content is 0.8812325. When p is 0.01, it is 0.08079329. When p is 0.5, it is 1.0, the maximum.
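The same numbers can be reproduced with a scalar-p variant of the function, and a scan over a grid of p values shows that the maximum sits at 0.5. This is a sketch of our own. The name binary_entropy and the clipping bound 1e-6 are our choices. The bound keeps both p and 1 - p away from 0 in float32.

```python
import jax.numpy as jnp

def binary_entropy(p):
    # Entropy in bits of a coin with P(heads) = p, clipped away from log2(0)
    p = jnp.clip(p, 1e-6, 1 - 1e-6)
    return -(p * jnp.log2(p) + (1 - p) * jnp.log2(1 - p))

for p in (0.01, 0.3, 0.5):
    print(p, float(binary_entropy(p)))

# Scan p over a grid from 0 to 1 and find where the entropy is largest
ps = jnp.linspace(0.0, 1.0, 101)
best_p = ps[jnp.argmax(binary_entropy(ps))]
print(float(best_p))  # about 0.5
```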

If a plain Python function runs too slowly, we can speed it up with JAX's jit (just-in-time compilation) function. We only need to add the @jit decorator above the function definition:

import jax.numpy as jnp
from jax import jit

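# Average information content (entropy, in bits) of a discrete random variable X
@jit
def avg_information(p):
    p = jnp.maximum(p, 1e-10)  # avoid log2(0) for zero probabilities
    return jnp.negative(jnp.sum(jnp.multiply(p, jnp.log2(p))))

# Entropy when X takes the values 0 and 1 with probabilities 0.01 and 0.99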
p = jnp.array([0.01, 0.99])
avg_info = avg_information(p)
print(avg_info)
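The sketch below is our own rough way to see what jit does. It applies jit as a function call rather than a decorator, so that the plain and compiled versions can be compared side by side. The first call traces and compiles the function, and later calls reuse the compiled code. JAX dispatches work asynchronously, so block_until_ready() is needed for timings to be meaningful. For a function this tiny, compilation overhead will likely outweigh any gain. jit pays off for larger computations that are called repeatedly.

```python
import time
import jax.numpy as jnp
from jax import jit

def avg_information(p):
    p = jnp.maximum(p, 1e-10)
    return -jnp.sum(p * jnp.log2(p))

avg_information_jit = jit(avg_information)
p = jnp.array([0.01, 0.99])

# First call: trace + compile + run
start = time.perf_counter()
avg_information_jit(p).block_until_ready()
first_call = time.perf_counter() - start

# Second call: reuses the compiled version
start = time.perf_counter()
result = avg_information_jit(p).block_until_ready()
second_call = time.perf_counter() - start

print(first_call, second_call)
print(jnp.allclose(result, avg_information(p)))  # same value with or without jit
```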

Summary

Above, we picked a few topics from linear algebra, calculus, and probability theory to refresh everyone's memory, and showed how to implement and accelerate them with JAX.
Our examples are simple, but they actually run on the TPU.

Large models provide powerful capabilities, but we still need to spend enough time on the fundamentals. Hardware and frameworks change rapidly, while the fundamentals of mathematics evolve very slowly, so the return on that investment is high. Once the fundamentals are solid, new frameworks and hardware can be learned along the way.


Origin blog.csdn.net/lusing/article/details/131113520