JAX Study Notes

1. Installation environment

1. Create a Python 3.7 virtual environment

conda create -n jax python=3.7
conda activate jax

2. Download the jax and jaxlib packages, then install them from inside the virtual environment

cd C:\Users\dz\Downloads\jax-0.2.9_and_jaxlib-0.1.61-cp37-win_amd64
pip install jaxlib-0.1.61-cp37-none-win_amd64.whl
pip install jax==0.2.9
pip install matplotlib

3. Write a .py file to test whether the environment was installed successfully

import jax.numpy as jnp
import matplotlib.pyplot as plt
x_jnp=jnp.linspace(0,10,1000)
y_jnp=jnp.sin(x_jnp)*jnp.cos(x_jnp)
print(x_jnp,y_jnp)
plt.plot(x_jnp,y_jnp)
plt.show()

If a plot like the following appears, the installation succeeded.
[Figure: plot of sin(x)·cos(x) produced by the test script]

2. Basic knowledge

2.1. Random numbers

import jax.numpy as jnp
from jax import random
key=random.PRNGKey(0)# random seed (a PRNG key)
x=random.normal(key,(10,),dtype=jnp.float32)# generate a 1-D array of 10 samples
print(x)
print(type(x))#<class 'jax.interpreters.xla._DeviceArray'>
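
Reusing the same key reproduces the same numbers, so the usual pattern is to split the key into fresh subkeys before each new draw. A minimal sketch:

import jax.numpy as jnp
from jax import random
key=random.PRNGKey(0)
key,subkey=random.split(key)# derive a fresh subkey; keep key for later splits
x1=random.normal(subkey,(3,))
key,subkey=random.split(key)
x2=random.normal(subkey,(3,))
print(x1,x2)# two different draws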

2.2. The grad function and DeviceArray properties

1. JAX allows us to transform Python functions: jax.grad returns a new function that computes the gradient of the original with respect to its first argument
import jax.numpy as jnp
from jax import random
import jax
def sum_of_squares(x):
    return jnp.sum(x**2)
sum_of_squares_dx=jax.grad(sum_of_squares)# takes a numerical function written in Python and returns a new Python function that computes the gradient of the original
x=jnp.asarray([1.0,2.0,3.0,4.0])
print(sum_of_squares(x))# sum of squares -> 30.0
print(sum_of_squares_dx(x))# derivative w.r.t. each element of x -> [2. 4. 6. 8.]

2. To take the gradient with respect to a different argument (or several at once), set argnums

import jax.numpy as jnp
import jax
def sum_squared_error(x,y):
    return jnp.sum((x-y)**2)
sum_squared_error_dx_dy=jax.grad(sum_squared_error,argnums=(0,1))
x = jnp.asarray([2.0, 3.0, 4.0, 5.0])
y = jnp.asarray([1.0, 2.0, 3.0, 4.0])
# gradient w.r.t. x: 2x-2y; gradient w.r.t. y: 2y-2x
print(sum_squared_error_dx_dy(x,y))# ([2., 2., 2., 2.], [-2., -2., -2., -2.])

3. To get both the value and the gradient of a function, use jax.value_and_grad

import jax.numpy as jnp
import jax
def sum_squared_error(x,y):
    return jnp.sum((x-y)**2)
sum_squared_error_dx_dy=jax.value_and_grad(sum_squared_error)
x = jnp.asarray([2.0, 3.0, 4.0, 5.0])
y = jnp.asarray([1.0, 2.0, 3.0, 4.0])
#(DeviceArray(4., dtype=float32), DeviceArray([2., 2., 2., 2.], dtype=float32))
print(sum_squared_error_dx_dy(x,y))#jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs)) 

4. When the function returns a tuple (to carry intermediate results) rather than a scalar, use has_aux=True

import jax.numpy as jnp
import jax
def sum_squared_error(x,y):
    return jnp.sum((x-y)**2),x-y
"""jax.grad is only defined on scalar functions, 
and our new function returns a tuple.
But we need to return a tuple to return our intermediate results!
This is where has_aux comes in"""
sum_squared_error_dx_dy=jax.grad(sum_squared_error,has_aux=True)
x = jnp.asarray([2.0, 3.0, 4.0, 5.0])
y = jnp.asarray([1.0, 2.0, 3.0, 4.0])
#(DeviceArray([2., 2., 2., 2.], dtype=float32), DeviceArray([1., 1., 1., 1.], dtype=float32))
print(sum_squared_error_dx_dy(x,y))

5. DeviceArrays are immutable: "modification" goes through indexed updates that return a new array, leaving the original unchanged

import jax.numpy as jnp
import numpy as np
# 1. Modifying a NumPy array in place:
"""x=np.array([1,2,3])
def in_place_modify(x):
    x[0]=4
    return None
in_place_modify(x)
print(x)#[4 2 3]"""
# 2. "Modifying" a jnp array: the indexed update returns a new array; the old one is unaffected
y=jnp.array([1,2,3])
def jax_in_place_modify(x):
    return x.at[0].set(4)
print(jax_in_place_modify(y))#[4 2 3]
print(y)#[1 2 3]
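
Besides set, the .at syntax supports other out-of-place updates such as add; each call returns a new array and leaves the original untouched. A small sketch (assuming a reasonably recent JAX version):

import jax.numpy as jnp
y=jnp.array([1,2,3])
print(y.at[0].set(4))# [4 2 3]
print(y.at[1].add(10))# [ 1 12  3]
print(y)# [1 2 3], unchanged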

2.3. vmap automatic vectorization
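
jax.vmap turns a function written for a single example into one that runs over a whole batch, without a Python loop. A minimal sketch; in_axes says which axis of each argument to map over (None broadcasts the argument unchanged):

import jax
import jax.numpy as jnp
def dot(a,b):# written for single vectors
    return jnp.dot(a,b)
a=jnp.ones((10,3))# a batch of 10 vectors
b=jnp.ones((10,3))
batched_dot=jax.vmap(dot,in_axes=(0,0))# map over the leading axis of both arguments
print(batched_dot(a,b).shape)# (10,)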

3. Small projects

3.1. XOR function

The network takes a 2-element input, has 3 neurons in its first layer and 1 neuron in its second layer, and outputs the XOR (exclusive OR) of the two input bits, as shown below.
[Figure: the 2-3-1 network structure]

import itertools
import jax
import jax.numpy as jnp
import numpy as np                   
learning_rate=1
inputs=jnp.array([[0,0],[0,1],[1,0],[1,1]])
def sigmoid(x):
    return 1/(1+jnp.exp(-x))
def net(params,x):
    w1,b1,w2,b2=params
    hidden=jnp.tanh(jnp.dot(w1,x)+b1)
    return sigmoid(jnp.dot(w2,hidden)+b2)# squash to (0,1) for binary classification
def loss(params,x,y):
    out=net(params,x)
    cross_entropy=-y*jnp.log(out)-(1-y)*jnp.log(1-out)
    return cross_entropy
def test_all_inputs(inputs,params):
    predictions=[int(net(params,inp)>0.5) for inp in inputs]
    for inp,out in zip(inputs,predictions):
        print(inp,'->',out)
    return (predictions==[np.bitwise_xor(*inp) for inp in inputs])# compare predictions with the true XOR values
# 1. jax.grad takes a function and returns a new function that computes the gradient of the original. By default the gradient is taken with respect to the first argument; this can be controlled with jax.grad's argnums parameter.
loss_grad=jax.grad(loss)
def initial_params():
    return [np.random.randn(3,2),np.random.randn(3),np.random.randn(3),np.random.randn()]
params=initial_params()# initialize the parameters
for n in itertools.count():# training loop
    x=inputs[np.random.choice(inputs.shape[0])]# draw one of the four samples at random
    y=np.bitwise_xor(*x)# XOR of its two values
    grads=loss_grad(params,x,y)
    params=[param-learning_rate*grad for param,grad in zip(params,grads)]# gradient-descent parameter update
    if not n%100:
        print('Iteration {}'.format(n))# test once every 100 iterations
        if test_all_inputs(inputs,params):# stop once every output is correct
            break
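
As an aside, the per-example calls in test_all_inputs are exactly what jax.vmap from section 2.3 batches away. A sketch reusing net, params, and inputs from above:

batched_net=jax.vmap(net,in_axes=(None,0))# broadcast params, map over the batch axis of inputs
predictions=(batched_net(params,inputs)>0.5).astype(int)# e.g. [0 1 1 0] once training converges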

3.2. Linear regression

import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
# 1. Data
xs=np.random.normal(size=(100,))
noise=np.random.normal(scale=0.1,size=(100,))
ys=xs*3-1+noise
plt.scatter(xs,ys)
# plt.show()
# 2. Model: y_hat(x; theta) = w*x + b
def model(theta,x):
    w,b=theta
    return w*x+b
def loss_fn(theta,x,y):
    prediction=model(theta,x)
    return jnp.mean((prediction-y)**2)# squared-error loss J(x, y; theta) = mean((y_hat - y)^2)
# 3. Parameter update: theta_new = theta - lr * grad_theta J(x, y; theta)
def update(theta,x,y,lr=0.1):
    return theta -lr*jax.grad(loss_fn)(theta,x,y)
theta=jnp.array([1.,1.])
for _ in range(1000):
    theta=update(theta,xs,ys)
plt.plot(xs,model(theta,xs))
plt.show()  
w,b=theta
print(f"w:{
      
      w:<.2f},b:{
      
      b:<.2f}") #w:2.99,b:-1.00

[Figure: the fitted line plotted over the training scatter]
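
Since update is a pure function, it can be compiled with jax.jit; a minimal sketch (the fitted result is unchanged, but iterations after the first traced call run faster):

update_jit=jax.jit(update)# compile the update step once, reuse it every iteration
theta=jnp.array([1.,1.])
for _ in range(1000):
    theta=update_jit(theta,xs,ys)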

3.3. MNIST Handwriting Recognition

pip install tensorflow_datasets -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorflow
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp
from jax import jit,grad,random
from jax.experimental import optimizers,stax
num_classes= 10
input_shape=(-1,28*28)
step_size=0.001# learning rate (unused below: adam is created with 2e-4)
batch_size=128
momentum_mass=0.9
rng=random.PRNGKey(0)
# 1. Data
x_train=jnp.load(r"C:\Users\dz\Desktop\JAX\mnist_train_x.npy")
y_train=jnp.load(r"C:\Users\dz\Desktop\JAX\mnist_train_y.npy")
total_train_imgs=len(y_train)
def one_hot_nojit(x,k=10,dtype=jnp.float32):# one-hot encode integer labels
    return jnp.array(x[:,None]==jnp.arange(k),dtype)
y_train=one_hot_nojit(y_train)
ds_train=tf.data.Dataset.from_tensor_slices((x_train,y_train)).shuffle(1024).batch(256).prefetch(tf.data.experimental.AUTOTUNE)
ds_train=tfds.as_numpy(ds_train)
# 2. Network
init_random_params,predict=stax.serial(stax.Dense(1024),stax.Relu,stax.Dense(1024),stax.Relu,stax.Dense(10),stax.LogSoftmax)
def pred_check(params,batch):# count correct predictions in a batch
    inputs,targets=batch
    predict_result=predict(params,inputs)
    predicted_class=jnp.argmax(predict_result,axis=1)
    targets=jnp.argmax(targets,axis=1)
    return jnp.sum(predicted_class==targets)
def loss(params,batch):# cross entropy: predict returns log-probabilities (LogSoftmax)
    inputs,targets=batch
    return jnp.mean(jnp.sum(-targets*predict(params,inputs),axis=1))
opt_init,opt_update,get_params=optimizers.adam(step_size=2e-4)
_,init_params=init_random_params(rng,input_shape)
opt_state=opt_init(init_params)
def update(i,opt_state,batch):
    params=get_params(opt_state)
    return opt_update(i,grad(loss)(params,batch),opt_state)
# 3. Training
for _ in range(17):
    itercount=0
    for batch_raw in ds_train:
        data=batch_raw[0].reshape((-1,28*28))
        targets=batch_raw[1].reshape((-1,10))
        opt_state=update((itercount),opt_state,(data,targets))
        itercount+=1
    params=get_params(opt_state)
    train_acc=[]
    correct_preds=0.0
    for batch_raw in ds_train:
        data=batch_raw[0].reshape((-1,28*28))
        targets=batch_raw[1]
        correct_preds+=pred_check(params,(data,targets))
    train_acc.append(correct_preds/float(total_train_imgs))
    print(f"training set accuracy:{
      
      train_acc}")

3.4. Iris classification

Four input features and three classes; a two-layer perceptron is used for classification.

from sklearn.datasets import load_iris
import jax.numpy as jnp
from jax import random,grad
import jax
# 1. Data
data=load_iris()
iris_data=jnp.float32(data.data)# convert the features to float32
iris_target=jnp.float32(data.target)
iris_data=jax.random.shuffle(random.PRNGKey(17),iris_data)# shuffle pseudo-randomly; the same key...
iris_target=jax.random.shuffle(random.PRNGKey(17),iris_target)# ...keeps features and labels aligned
def one_hot_nojit(x,k=3,dtype=jnp.float32):
    return jnp.array(x[:,None]==jnp.arange(k),dtype)
iris_target=one_hot_nojit(iris_target)
# 2. Network structure
def Dense(dense_shape=[1,1]):
    rng=random.PRNGKey(17)
    weight=random.normal(rng,shape=dense_shape)
    bias=random.normal(rng,shape=(dense_shape[-1],))
    params=[weight,bias]# parameter structure
    def apply_fun(inputs,params=params):
        w,b=params
        return jnp.dot(inputs,w)+b# inputs dot weights, plus bias
    return apply_fun
def selu(x,alpha=1.67,lmbda=1.05):
    return lmbda*jnp.where(x>0,x,alpha*jnp.exp(x)-alpha)
def softmax(x,axis=-1):
    unnormalized=jnp.exp(x)
    return unnormalized/unnormalized.sum(axis,keepdims=True)
def cross_entropy(y_true,y_pred):
    y_true=jnp.array(y_true)
    y_pred=jnp.array(y_pred)
    res=-jnp.sum(y_true*jnp.log(y_pred+1e-7),axis=-1)# add 1e-7 to avoid log(0)
    return res
def mlp(x,params):
    a0,b0,a1,b1=params
    x=Dense()(x,[a0,b0])
    x=jax.nn.selu(x)
    x=Dense()(x,[a1,b1])
    x=softmax(x,axis=-1)
    return x
def loss_mlp(params,x,y):
    preds=mlp(x,params)
    loss_value=cross_entropy(y,preds)
    return jnp.mean(loss_value)
rng=random.PRNGKey(17)
a0=random.normal(rng,shape=(4,5))
b0=random.normal(rng,shape=(5,))
a1=random.normal(rng,shape=(5,3))
b1=random.normal(rng,shape=(3,))
params=[a0,b0,a1,b1]
learning_rate=2.17e-4
# 3. Training
for i in range(20000):
    loss=loss_mlp(params,iris_data,iris_target)
    if i%1000==0:
        predict_result=mlp(iris_data,params)
        predicted_class=jnp.argmax(predict_result,axis=1)
        _iris_target=jnp.argmax(iris_target,axis=1)
        accuracy=jnp.sum(predicted_class==_iris_target)/len(_iris_target)
        print("i:",i,"loss:",loss,"accuracy:",accuracy)
    params_grad=grad(loss_mlp)(params,iris_data,iris_target)
    params=[(p-g*learning_rate) for p,g in zip(params,params_grad)]
predict_result=mlp(iris_data,params)
predicted_class=jnp.argmax(predict_result,axis=1)
iris_target=jnp.argmax(iris_target,axis=1)
accuracy=jnp.sum(predicted_class==iris_target)/len(iris_target)
print(accuracy)
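
As a side note, the hand-rolled one_hot_nojit computes the same encoding as the built-in jax.nn.one_hot. A small sketch, reusing the definitions above:

labels=jnp.array([0,1,2,1])
print(jax.nn.one_hot(labels,3))# built-in one-hot encoding
print(one_hot_nojit(labels))# the hand-rolled version gives the same matrix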

