Guía introductoria al aprendizaje profundo en 2023 (16) - Aceleración de JAX y TPU

Guía introductoria al aprendizaje profundo en 2023 (16) - Aceleración de JAX y TPU

En la sección anterior, presentamos el principio del aprendizaje por refuerzo instruido por humanos, uno de los algoritmos centrales de ChatGPT. Sé que no todos lo entendieron, porque requiere muchas reservas de conocimiento. Pero no importa, el modelo grande no se puede entrenar en un día y es imposible alinearlo en un día. Tenemos mucho tiempo para establecer lo básico primero.

La razón por la que no se discutió la parte de aprendizaje intensivo de la sección anterior es que nos preocupa que todos olviden todos los conocimientos de matemáticas y no hayan aprendido a programar en la clase de matemáticas. En esta sección, presentaremos dos herramientas básicas, se puede decir que una es la biblioteca NumPy que todo marco de aprendizaje profundo de Python no debe pasar por alto, y la otra es la biblioteca NumPy JAX desarrollada por Google, que se puede considerar como la GPU y Versión TPU.

El propósito de aprender estos dos marcos es compensar las lecciones de matemáticas, especialmente la programación matemática. Esta es también la primera vez que TPU aparece en nuestra sección de tutoriales. Por supuesto, también se puede utilizar la GPU.

matriz

La función central de NumPy es el soporte de matrices multidimensionales.

Podemos instalar NumPy a través del método y luego introducir la biblioteca NumPy pip install numpya través del método en Python . Sin embargo, NumPy no puede admitir la aceleración de GPU y TPU, lo que no es muy práctico para los cálculos que trataremos en el futuro, por lo que presentamos aquí la biblioteca JAX.import numpy as np

Para la documentación de instalación de JAX, consulte la documentación oficial de JAX

Hemos usado CUDA para la aceleración de GPU muchas veces antes, aquí también podríamos echar un vistazo al efecto de aceleración de TPU.
Solo Google tiene TPU, y solo podemos comprar servicios en la nube de TPU, pero podemos usar Google Colab para usar TPU.

En Colab, los tiempos de ejecución de JAX y TPU ya están instalados. Podemos activar la TPU ejecutando el siguiente código:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

Veamos cuántos dispositivos de TPU hay disponibles:

print(jax.device_count())
print(jax.local_device_count())
print(jax.devices())

La salida es la siguiente:

8
8
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Explique que tenemos 8 dispositivos TPU disponibles.

A continuación, usaremos jax.numpy en lugar de numpy.

La característica más importante de NumPy es el soporte para matrices multidimensionales. Podemos np.arraycrear una matriz multidimensional por

Comencemos con un vector unidimensional:

import jax.numpy as jnp
a1 = jnp.array([1,2,3])
print(a1)

Luego podemos usar la matriz 2D para crear una matriz:

a2 = jnp.array([[1,2],[0,4]])
print(a2)

A la matriz se le puede asignar un valor inicial uniformemente. La función de ceros crea una matriz de todos los 0, la función de unos crea una matriz de todos los 1 y la función completa crea una matriz de todos los valores.

Por ejemplo, para asignar 0 valores a una matriz de 10 filas y 10 columnas, podemos escribir:

a3 = jnp.zeros((10,10))
print(a3)

Una matriz de todos los 1:

a4 = jnp.ones((10,10))

Dotación completa 100:

a5 = jnp.full((10,10),100)

Además, también podemos generar una secuencia a través de la función linspace. El primer parámetro de la función linpsace es el valor inicial de la secuencia, el segundo parámetro es el valor final de la secuencia y el tercer parámetro es la longitud de la secuencia. Por ejemplo, podemos generar una secuencia del 1 al 100 con una longitud de 100:

a7 = jnp.linspace(1,100,100) # 从1到100,生成100个数
a7.reshape(10,10)
print(a7)

Finalmente, la forma en que JAX genera valores aleatorios para matrices es diferente de NumPy, y no existe un paquete como jnp.random. Podemos usar jax.random para generar valores aleatorios. Todas las funciones de generación de números aleatorios de JAX requieren un estado aleatorio explícito como primer parámetro.Este estado consta de dos enteros de 32 bits sin signo, denominados clave. El uso de una clave no la modifica, por lo que reutilizar la misma clave dará el mismo resultado. Si necesita un nuevo número aleatorio, puede usar jax.random.split() para generar una nueva subclave.

from jax import random
key = random.PRNGKey(0) # a random key
key, subkey = random.split(key) # split a key into two subkeys
a8 = random.uniform(subkey,shape=(10,10)) # a random number using subkey
print(a8)

Norma

Norma (Norma) es un concepto matemático que se utiliza para medir el "tamaño" de un vector en un espacio vectorial. La norma debe satisfacer las siguientes propiedades:

  • No negatividad: Todos los vectores tienen una norma mayor o igual a cero, excepto los vectores cero.
  • Homogeneidad: Para cualquier número real λ y cualquier vector v, existe ||λv|| = |λ| ||v||.
  • Desigualdad triangular: Para cualquier vector u y v, hay ||u + v|| ≤ ||u|| + ||v||.
    En aplicaciones prácticas, las normas se suelen utilizar para medir el tamaño de vectores o matrices. Por ejemplo, en el aprendizaje automático, las normas se utilizan a menudo para el cálculo de los términos de regularización.

Las normas comunes son:

  • Norma L0: el número de elementos distintos de cero en el vector.
  • Norma L1: La suma de los valores absolutos de cada elemento en el vector, también conocida como distancia de Manhattan.
  • Norma L2: la suma cuadrada de cada elemento en el vector y luego la raíz cuadrada, también conocida como distancia euclidiana.
  • Norma infinita: el valor máximo del valor absoluto de cada elemento en el vector.
    Cabe señalar que la norma L0 no es estrictamente una norma porque viola la homogeneidad. Pero en el aprendizaje automático, la norma L0 se usa a menudo para medir la cantidad de elementos distintos de cero en un vector, por lo que también se denomina "pseudo-norma".

Comencemos por calcular la norma L1 de un vector unidimensional, no se deje intimidar por el nombre de la norma L1, en realidad es la suma de valores absolutos:

norm10_1 = jnp.linalg.norm(a10,ord=1)
print(norm10_1)

Como era de esperar, el resultado es 6.

A continuación, veamos la norma L2, que es la distancia euclidiana, es decir, el cuadrado y la raíz cuadrada:

a10 = jnp.array([1, 2, 3])
norm10 = jnp.linalg.norm(a10)
print(norm10)

Según la definición de la norma L2, podemos calcularla manualmente: norm10 = jnp.sort(1 + 2 2 + 3 3) = 3.7416573.

Podemos ver que el valor de la norma 10 anterior es el mismo que nuestro cálculo manual.

Calculemos la norma infinita, que en realidad es el valor máximo:

norm10_inf = jnp.linalg.norm(a10, ord = jnp.inf)
print(norm10_inf)

El resultado es 3.

Hagamos una gran consolidación:

a10 = jnp.linspace(1,100,100) # 从1到100,生成100个数
n10 = jnp.linalg.norm(a10,ord=2)
print(n10)

Este resultado es 581.67865.

matriz inversa

Una matriz cuadrada con 1 en la diagonal y todos 0 en la otra se llama matriz identidad. En NumPy y JAX, usamos la función eye para generar la matriz de identidad.

Como es una matriz cuadrada, no se necesitan dos valores para filas y columnas, y solo se requiere un valor, este valor es el número de filas y columnas de la matriz. Al asignar este valor al primer parámetro de la función del ojo, se puede generar una matriz de identidad.

A continuación, revisemos cómo se calcula la multiplicación de matrices.

Para cada fila de la matriz A, necesitamos multiplicar con cada columna de la matriz B. "Multiplicar" aquí significa tomar una fila de A y una columna de B, multiplicar sus elementos correspondientes y sumar esos productos. Esta suma es el elemento en la posición correspondiente en la matriz resultante.

Como ejemplo, supongamos que tenemos dos matrices A y B de 2x2:

A = 1 2     B = 4 5
    3 4         6 7

Podemos calcular el producto de la matriz A y la matriz B así:

(1*4 + 2*6) (1*5 + 2*7)     16 19
(3*4 + 4*6) (3*5 + 4*7) =  34 43

Usemos JAX para calcular:

ma1 = jnp.array([[1,2],[3,4]])
ma2 = jnp.array([[4,5],[6,7]])
ma3 = jnp.dot(ma1,ma2)
print(ma1)
print(ma2)
print(ma3)

La salida es:

[[1 2]
 [3 4]]
[[4 5]
 [6 7]]
[[16 19]
 [36 43]]

Si A*B=I, I es la matriz identidad, entonces llamamos a B la matriz inversa de A.

Podemos usar la función inv para calcular la inversa de una matriz.

ma1 = jnp.array([[1,2],[3,4]])
inv1 = jnp.linalg.inv(ma1)
print(inv1)

El resultado de salida es:

[[-2.0000002   1.0000001 ]
 [ 1.5000001  -0.50000006]]

Derivadas y Gradientes

Una derivada es la tasa de cambio de una función en un punto y se usa para describir la tasa de cambio de la función en ese punto. La derivada puede representar la pendiente de la función en ese punto, es decir, qué tan inclinada es la función en ese punto.

Gradiente es un vector que indica que la derivada direccional de la función en ese punto toma su valor máximo a lo largo de esa dirección. El gradiente puede representar la dirección en la que la función cambia más rápido y con la mayor tasa de cambio en ese punto. En una función de valor real univariada, el gradiente puede entenderse simplemente como una derivada.

Como marco que admite el aprendizaje profundo, se da prioridad al soporte de JAX para gradientes. Podemos usar la función jax.grad para calcular el gradiente. Para una función de una variable, el gradiente es la derivada. Podemos usar el siguiente código para calcular el gradiente de la función sin en x=1.0:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

# 计算 f 在 x=1.0 处的梯度
grad_f = jax.grad(f)
print(grad_f(1.0))

Si vamos en la dirección del gradiente cada vez, entonces podemos encontrar el extremo de la función. Este método de avance en la dirección del gradiente es el método de descenso del gradiente. El método de descenso del gradiente es un algoritmo de optimización de uso común. Su idea central es: si el valor del gradiente de una función en un cierto punto es positivo, entonces la función disminuirá más rápido a lo largo de la dirección del gradiente en ese punto; si una función en un cierto punto El valor del gradiente de es negativo, entonces la función crece más rápido a lo largo de la dirección del gradiente en este punto. Por lo tanto, podemos encontrar el extremo de la función avanzando continuamente en la dirección del gradiente.

Entonces, ¿qué hace el descenso de gradiente? Podemos usar el descenso de gradiente para encontrar el valor mínimo de una función. Podemos usar el siguiente código para resolver la función f ( x ) = x 2 f(x)=x^2f ( x )=XValor mínimo de 2 :

import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

grad_f = jax.grad(f)

x = 2.0  # 初始点
learning_rate = 0.1  # 学习率
num_steps = 100  # 迭代步数

for i in range(num_steps):
    grad = grad_f(x)  # 计算梯度
    x = x - learning_rate * grad  # 按负梯度方向更新 x

print(x)  # 打印最终的 x 值,应接近 0(函数的最小值)

El resultado de mi ejecución esta vez es 4.0740736e-10 En otras palabras, usamos el método de descenso de gradiente para resolver la función f ( x ) = x 2 f(x)=x^2f ( x )=XEl valor mínimo de 2 , el valor x final está cerca de 0, que es el valor mínimo de la función.

Entre ellos, la tasa de aprendizaje (o tamaño de paso) es un número positivo que se usa para controlar la magnitud de cada actualización de paso. La tasa de aprendizaje debe seleccionarse cuidadosamente. Si es demasiado grande, es posible que el algoritmo no converja, y si es demasiado pequeña, la velocidad de convergencia puede ser demasiado lenta.

probabilidad

Después de despertar algunos recuerdos de álgebra lineal y matemáticas avanzadas, repasemos finalmente la teoría de la probabilidad.

Comencemos con lanzar una moneda. Sabemos que suponiendo que una moneda es par, el número de caras será cerca de la mitad del número total de lanzamientos si se lanzan suficientes veces.

A este tipo de experimento aleatorio con solo dos resultados posibles, le damos un nombre alto llamado ensayo de Bernoulli (ensayo de Bernoulli).

A continuación, utilizaremos la distribución de Bernoulli de JAX para simular el proceso de lanzar una moneda.

import jax
import time
from jax import random

# 生成一个形状为 (10, ) 的随机矩阵,元素取值为 0 或 1,概率为 0.5
key = random.PRNGKey(int(time.time()))
rand_matrix = random.bernoulli(key, p=0.5, shape=(10, ))
print(rand_matrix)
mean_x = jnp.mean(rand_matrix)
print(mean_x)

La función media se utiliza para calcular el promedio, también conocida como expectativa matemática.

El resultado impreso puede ser 0,5, 0,3, 0,8, etc. Esto se debe a que solo lanzamos la moneda 10 veces, lo cual es tan poco frecuente que el número de caras que caían no era necesariamente cerca de la mitad del total.

Este es el resultado de uno de los 0.6:

[ True  True  True  True False False  True False False  True]
0.6

Después de ejecutar varias veces, no es raro que aparezcan 0.1 y 0.9:

[False False False False False False False False False  True]
0.1

Cuando cambiamos la forma a un número más grande como 100, 1000, 10000, el resultado se acerca cada vez más a 0,5.

Repasemos los dos valores que representan la desviación:

  • Varianza: la varianza es una forma de medir cuánto se desvía un punto de datos de la media. En otras palabras, describe el cuadrado de la distancia promedio entre los puntos de datos y la media.
  • Desviación estándar: La desviación estándar es la raíz cuadrada de la varianza. Debido a que la varianza se eleva al cuadrado sobre la base de la desviación media, su dimensión (unidad) es diferente de los datos originales. Para resolver este problema, introducimos el concepto de desviación estándar. La desviación estándar tiene la misma dimensión que los datos originales, que es más fácil de interpretar.

Ambas estadísticas reflejan el grado de dispersión de la distribución de datos. Cuanto mayor sea la varianza y la desviación estándar, más dispersos serán los puntos de datos; por el contrario, cuanto menor sea la varianza y la desviación estándar, más concentrados serán los puntos de datos.

Podemos usar la función var de JAX para calcular la varianza y la función std para calcular la desviación estándar.

import jax
import time
from jax import random

# 生成一个形状为 (1000, ) 的随机矩阵,元素取值为 0 或 1,概率为 0.5
key = random.PRNGKey(int(time.time()))
rand_matrix = random.bernoulli(key, p=0.5, shape=(1000, ))
#print(rand_matrix)
mean_x = jnp.mean(rand_matrix)
print(mean_x)
var_x = jnp.var(rand_matrix)
print(var_x)
std_x = jnp.std(rand_matrix)
print(std_x)

Finalmente, repasemos la cantidad de información de la que hablamos anteriormente. Pensemos en una pregunta, ¿cómo maximizar la cantidad promedio de información en la distribución de Bernoulli?

Primero construimos dos casos especiales. Por ejemplo, si p=0, entonces nunca obtendremos un resultado de cabeza arriba. En este momento, sabemos el resultado y la cantidad de información es 0. Si p=1, entonces nunca obtendremos el resultado de cruz, en este momento, también sabemos el resultado, y la cantidad de información también es 0.

Si p = 0.01, la cantidad promedio de información que se nos puede traer todavía no es grande, porque básicamente podemos adivinar ciegamente que el resultado es cruz, y el resultado de frente ocasional, aunque trae una sola más grande La cantidad de información, pero la probabilidad de ocurrencia es demasiado baja, por lo que la cantidad promedio de información aún no es grande.

Y si p = 0.5, no podemos adivinar si el resultado es cabeza arriba o respaldo En este momento, la cantidad promedio de información que obtenemos es la más grande.

Por supuesto, esto es solo un análisis cualitativo, también necesitamos dar una fórmula cuantitativa:

H ( X ) = − ∑ x ∈ X pags ( x ) Iniciar sesión ⁡ 2 pags ( x ) H(X) = - \sum_{x \in X} p(x) \log_2 p(x)H ( X )=X Xpag ( x )iniciar sesión2pag ( x )

import jax.numpy as jnp

# 计算离散型随机变量 X 的平均信息量
def avg_information(p):
    p = jnp.maximum(p, 1e-10)
    return jnp.negative(jnp.sum(jnp.multiply(p, jnp.log2(p))))

# 计算随机变量 X 取值为 0 和 1 的概率分别为 0.3 和 0.7 时的平均信息量
p = jnp.array([0.3, 0.7])
avg_info = avg_information(p)
print(avg_info)

Probamos varios cálculos y podemos obtener que cuando p es 0.3, la cantidad promedio de información es 0.8812325; cuando p es 0.01, la cantidad promedio de información es 0.08079329; cuando p es 0.5, la cantidad promedio de información es 1.0, llegando a la máximo.

Si el cálculo con la función de Python es demasiado lento, podemos llamar a la función jit de JAX para acelerarlo. Solo necesitamos agregar @jit delante de la definición de la función.

import jax.numpy as jnp
from jax import jit

# 计算离散型随机变量 X 的平均信息量
@jit
def avg_information(p):
    p = jnp.maximum(p, 1e-10)
    return jnp.negative(jnp.sum(jnp.multiply(p, jnp.log2(p))))

# 计算随机变量 X 取值为 0 和 1 的概率分别为 0.3 和 0.7 时的平均信息量
p = jnp.array([0.01, 0.99])
avg_info = avg_information(p)
print(avg_info)

resumen

Arriba hemos seleccionado algunos puntos de conocimiento de álgebra lineal, matemáticas avanzadas y teoría de la probabilidad para despertar la memoria de todos. Al mismo tiempo, también presentamos su implementación y aceleración en JAX.
Aunque nuestros ejemplos no tienen nada especial, en realidad se ejecutan en la TPU.

Aunque el modelo grande proporciona una gran habilidad, aún necesitamos dedicar suficiente tiempo a las habilidades básicas. Tanto el hardware como el marco están en la parte creciente, pero la evolución del conocimiento básico de las matemáticas es muy lenta y la relación entrada-salida es muy alta. Después de tener habilidades básicas sólidas, el marco y el nuevo hardware se pueden aprender mientras se usan.

Supongo que te gusta

Origin blog.csdn.net/lusing/article/details/131113520
Recomendado
Clasificación