机器学习笔记 - Google 神经网络库JAX/FLAX入门

1、概述

        JAX是由Google Research开发的用于高性能数值计算和机器学习研究的框架。它允许您使用 NumPy 一致的 API 构建 Python 应用程序,该 API 专门用于微分、矢量化、并行化和编译为 GPU/TPU Just-In-Time。JAX 在设计时将性能和速度作为第一要务,并且原生兼容常见的机器学习加速器,例如GPU和TPU。大型 ML 模型的训练可能需要很长时间——您可能会对将 JAX 用于速度和性能特别重要的应用程序感兴趣!

GitHub - google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and moreComposable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more - GitHub - google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and morehttps://github.com/google/jax

2、JAX 与 TensorFlow

        TensorFlow是一个成熟的产品,具有丰富且功能齐全的生态系统,能够支持机器学习从业者可能拥有的大多数用例(例如,用于设备推理计算的TFLite、用于共享预训练模型的 TFHub,以及许多其他专业的应用程序)。这种类型的广泛授权既对比又补充了 JAX 的哲学,后者更狭隘地关注速度和性能。我们建议在您确实希望最大限度地提高速度和性能但不需要任何只有TensorFlow 生态系统才能提供的长尾特性和附加功能的情况下使用 JAX。

3、FLAX 简介

        就像JAX专注于速度一样,也鼓励 JAX 生态系统的其他成员进行专业化。例如,Flax专注于神经网络,jgraph专注于图网络。

        Flax是一个基于 JAX 的神经网络库,最初由 Google Research 的 Brain 团队(与 JAX 团队密切合作)开发,但现在是开源的。如果您想在 GPU 和 TPU 上以更快的速度训练机器学习模型,或者如果您的 ML 项目可能受益于Autograd和XLA的结合,请考虑将Flax用于您的下一个项目! Flax特别适合使用大型语言模型的项目,是前沿机器学习研究的热门选择。

4、示例:使用JAX/FLAX的数字识别器

# Importing all the libraries necessary for the project
import jax
import flax
import numpy as np
import jax.numpy as jnp
import tensorflow as tf
import pandas as pd
import os
from flax import linen as nn # the Linen API
from flax.training import train_state 
import optax
import matplotlib.pyplot as plt
%matplotlib inline

# List all the available devices
jax.local_devices()

mnist_train = pd.read_csv('../input/digit-recognizer/train.csv')

labels = mnist_train["label"]
features = mnist_train.drop(labels = ["label"],axis = 1)
features = features/255.0
features = features.values.reshape(-1,28,28,1)
labels = labels.to_numpy()

train_ds = {
        'images': features,
        'labels': labels
       }


class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)   
        x = nn.log_softmax(x)
        return x


def cross_entropy_loss(*, logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))


def compute_metrics(logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
      'loss': loss,
      'accuracy': accuracy,
    }
    return metrics


@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = CNN().apply({'params': params}, batch['images'])
        loss = cross_entropy_loss(logits=logits, labels=batch['labels'])
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, batch['labels'])
    return state, metrics


per_device_batch_size = 32

total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)


def train_epoch(state, train_ds, batch_size, epoch, rng):
    train_ds_size = len(train_ds['images'])
    steps_per_epoch = train_ds_size // batch_size 

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    training_batch_metrics = jax.device_get(batch_metrics)
    training_epoch_metrics = {
      k: np.mean([metrics[k] for metrics in training_batch_metrics])
      for k in training_batch_metrics[0]}

    print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

    return state, training_epoch_metrics

def create_train_state(rng, learning_rate, momentum):
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28,28,1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 2e-5
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng


num_epochs = 10
training_accuracy = []
for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_metrics = train_epoch(state, train_ds, total_batch_size, epoch, input_rng)
    training_accuracy.append(train_metrics["accuracy"])


# Plot the Accuracy 
plt.plot(training_accuracy)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

Dog Breed classification using JAX and FLAX | KaggleExplore and run machine learning code with Kaggle Notebooks | Using data from Dog Breed Identificationhttps://www.kaggle.com/code/nilaychauhan/dog-breed-classification-using-jax-and-flax/notebook

猜你喜欢

转载自blog.csdn.net/bashendixie5/article/details/125327206