Neural Differential Equations: A ResNet Variant That Reduces Memory While Preserving Accuracy

Text content:

1. These are study notes on neural differential equations, written mainly to practice picking up new material and reading papers that are heavy on mathematical principles.

2. Neural differential equations can be used for time-series modeling, dynamical-system modeling, and so on, but this article focuses on a classification problem, a ResNet variant, because it is the easiest case to understand.


Personal understanding:

The code that implements the adjoint sensitivity method is fairly complicated, but its logic follows the algorithm steps one-to-one, so it is easiest to understand by reading the two side by side. In essence, the gradient computation is reduced to solving another differential equation.

In practice, OdeintAdjointMethod is implemented by subclassing torch.autograd.Function and defining its forward and backward methods so that both call the ODE solver, instead of letting autograd apply the chain rule through the solver's internal operations.


Fundamental:

The adjoint sensitivity method is a way to train a neural ODE that avoids the scalability problems of backpropagating directly through the operations of the ODE solver. The forward pass is computed with an ordinary differential equation (ODE) solver, and the backward pass uses the adjoint sensitivity method, so the gradients are also obtained by calling an ODE solver. To update the parameters of the derivative (dynamics) function, the gradient of the loss with respect to those parameters is computed with the adjoint method. The final algorithm packs the state, the adjoint, and the parameter-gradient accumulators into one augmented state, then calls the ODE solver backward in time to obtain the gradient with respect to theta. These gradients are used to update the parameters of the CNN encoder and of the derivative function.
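
For reference, the quantities described above can be written out as follows. The notation is an assumption borrowed from the usual Neural ODE presentation and does not appear in the original notes: z(t) is the hidden state, f the derivative function with parameters \theta, L the loss, a(t) the adjoint, and [t_0, t_1] the integration interval.

```latex
% Forward pass: the hidden state is the solution of an ODE
\frac{dz(t)}{dt} = f(z(t), t, \theta), \qquad
z(t_1) = z(t_0) + \int_{t_0}^{t_1} f(z(t), t, \theta)\, dt

% Adjoint state and its dynamics (solved backward from t_1 to t_0)
a(t) = \frac{\partial L}{\partial z(t)}, \qquad
\frac{da(t)}{dt} = -a(t)^{\top} \frac{\partial f(z(t), t, \theta)}{\partial z}

% Gradient with respect to the parameters of the derivative function
\frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^{\top} \frac{\partial f(z(t), t, \theta)}{\partial \theta}\, dt
```

All three integrals can be evaluated by a single backward call to the ODE solver on an augmented state, which is why the backward pass does not need the intermediate activations of the forward solve.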

The algorithm flow of the adjoint sensitivity method is as follows:

[Figure: algorithm flow of the adjoint sensitivity method]

A more complete version:

[Figure: more complete version of the algorithm]
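
The flow can be tried out on a toy problem. The following is a minimal sketch in plain Python, not the torchdiffeq implementation: it assumes the scalar ODE dz/dt = theta * z with loss L = z(T), and uses a hand-written fixed-step RK4 integrator. The names `rk4_step`, `integrate`, and `adjoint_gradient` are hypothetical helpers introduced only for this illustration.

```python
import math


def rk4_step(f, t, y, h):
    """One classical RK4 step for a state stored as a list of floats."""
    k1 = f(t, y)
    k2 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k1)])
    k3 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k2)])
    k4 = f(t + h, [yi + h * ki for yi, ki in zip(y, k3)])
    return [yi + h / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]


def integrate(f, y0, t0, t1, steps=200):
    """Integrate from t0 to t1; the step is negative when t1 < t0."""
    h = (t1 - t0) / steps
    t, y = t0, list(y0)
    for _ in range(steps):
        y = rk4_step(f, t, y, h)
        t += h
    return y


def adjoint_gradient(theta, z0, T):
    # Forward pass: solve dz/dt = theta * z and keep only the final state.
    (zT,) = integrate(lambda t, y: [theta * y[0]], [z0], 0.0, T)

    # Augmented dynamics for the state [z, a, g]:
    #   a is the adjoint dL/dz(t), g accumulates dL/dtheta.
    def augmented(t, state):
        z, a, _ = state
        return [theta * z,   # dz/dt = f(z)
                -a * theta,  # da/dt = -a * df/dz
                -a * z]      # dg/dt = -a * df/dtheta

    # Backward pass: L = z(T) gives a(T) = 1; integrate from T back to 0.
    _, a0, grad_theta = integrate(augmented, [zT, 1.0, 0.0], T, 0.0)
    return zT, a0, grad_theta


zT, dL_dz0, dL_dtheta = adjoint_gradient(theta=0.5, z0=2.0, T=1.0)
# Analytic values for comparison: dL/dz0 = exp(theta*T), dL/dtheta = T*z0*exp(theta*T).
print(dL_dz0, math.exp(0.5))
print(dL_dtheta, 2.0 * math.exp(0.5))
```

The augmented state here plays the same role as `aug_state` in the torchdiffeq listing further down, where the minus signs come from passing `-adj_y` to `torch.autograd.grad`.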


odeint is the ODE solver used in this article.


The repository provides the following solvers:

SOLVERS = {
    'dopri8': Dopri8Solver,
    'dopri5': Dopri5Solver,
    'bosh3': Bosh3Solver,
    'fehlberg2': Fehlberg2,
    'adaptive_heun': AdaptiveHeunSolver,
    'euler': Euler,
    'midpoint': Midpoint,
    'rk4': RK4,
    'explicit_adams': AdamsBashforth,
    'implicit_adams': AdamsBashforthMoulton,
    # Backward compatibility: use the same name as before
    'fixed_adams': AdamsBashforthMoulton,
    # ~Backwards compatibility
    'scipy_solver': ScipyWrapperODESolver,
}
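
To get a feel for how the fixed-step entries in this table differ, here is a small sketch in plain Python. It is not the repository's code: `TOY_SOLVERS` and `toy_odeint` are hypothetical names, and the test problem dy/dt = -y on [0, 1] is an assumption chosen because its exact solution exp(-1) is known.

```python
import math


def euler_step(f, t, y, h):
    # First order: follow the slope at the start of the step.
    return y + h * f(t, y)


def midpoint_step(f, t, y, h):
    # Second order: follow the slope evaluated at the midpoint.
    return y + h * f(t + h / 2, y + h / 2 * f(t, y))


def rk4_step(f, t, y, h):
    # Fourth order: weighted average of four slope evaluations.
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


# Toy lookup table in the spirit of SOLVERS (names are illustrative only).
TOY_SOLVERS = {'euler': euler_step, 'midpoint': midpoint_step, 'rk4': rk4_step}


def toy_odeint(f, y0, t0, t1, method, steps=20):
    step = TOY_SOLVERS[method]
    h = (t1 - t0) / steps
    t, y = t0, y0
    for _ in range(steps):
        y = step(f, t, y, h)
        t += h
    return y


def decay(t, y):
    return -y


errors = {name: abs(toy_odeint(decay, 1.0, 0.0, 1.0, name) - math.exp(-1.0))
          for name in TOY_SOLVERS}
for name, err in errors.items():
    print(f"{name:9s} error = {err:.2e}")
```

With the same number of steps, the higher-order methods give a smaller error at the price of more function evaluations per step. Adaptive solvers such as dopri5 go one step further and choose the step size from rtol and atol.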

The complete source code is at rtqichen/torchdiffeq: Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation (github.com). The core forward and backward propagation parts are reproduced below with comments:

class OdeintAdjointMethod(torch.autograd.Function):

    @staticmethod
    def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method,
                adjoint_options, t_requires_grad, *adjoint_params):

        ctx.shapes = shapes
        ctx.func = func
        ctx.adjoint_rtol = adjoint_rtol
        ctx.adjoint_atol = adjoint_atol
        ctx.adjoint_method = adjoint_method
        ctx.adjoint_options = adjoint_options
        ctx.t_requires_grad = t_requires_grad
        ctx.event_mode = event_fn is not None

        with torch.no_grad():
            ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)

            if event_fn is None:
                y = ans
                ctx.save_for_backward(t, y, *adjoint_params)
            else:
                event_t, y = ans
                ctx.save_for_backward(t, y, event_t, *adjoint_params)

        return ans

    @staticmethod
    def backward(ctx, *grad_y):
        with torch.no_grad():
            func = ctx.func
            adjoint_rtol = ctx.adjoint_rtol
            adjoint_atol = ctx.adjoint_atol
            adjoint_method = ctx.adjoint_method
            adjoint_options = ctx.adjoint_options
            t_requires_grad = ctx.t_requires_grad

            # Backprop as if integrating up to event time.
            # Does NOT backpropagate through the event time.
            event_mode = ctx.event_mode
            if event_mode:
                t, y, event_t, *adjoint_params = ctx.saved_tensors
                _t = t
                t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
                grad_y = grad_y[1]
            else:
                t, y, *adjoint_params = ctx.saved_tensors
                grad_y = grad_y[0]

            adjoint_params = tuple(adjoint_params)

            ##################################
            #      Set up initial state      #
            ##################################

            # [-1] because y and grad_y are both of shape (len(t), *y0.shape)
            aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]]  # vjp_t, y, vjp_y
            aug_state.extend([torch.zeros_like(param) for param in adjoint_params])  # vjp_params

            ##################################
            #    Set up backward ODE func    #
            ##################################

            # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
            def augmented_dynamics(t, y_aug):
                # Dynamics of the original system augmented with
                # the adjoint wrt y, and an integrator wrt t and args.
                y = y_aug[1]
                adj_y = y_aug[2]
                # ignore gradients wrt time and parameters

                with torch.enable_grad():
                    t_ = t.detach()
                    t = t_.requires_grad_(True)
                    y = y.detach().requires_grad_(True)

                    # If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
                    # doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
                    # wrt t here means we won't compute that if we don't need it.
                    func_eval = func(t if t_requires_grad else t_, y)

                    # Workaround for PyTorch bug #39784
                    _t = torch.as_strided(t, (), ())  # noqa
                    _y = torch.as_strided(y, (), ())  # noqa
                    _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params)  # noqa

                    vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
                        func_eval, (t, y) + adjoint_params, -adj_y,
                        allow_unused=True, retain_graph=True
                    )

                # autograd.grad returns None if no gradient, set to zero.
                vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
                vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
                vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
                              for param, vjp_param in zip(adjoint_params, vjp_params)]

                return (vjp_t, func_eval, vjp_y, *vjp_params)

            ##################################
            #       Solve adjoint ODE        #
            ##################################

            if t_requires_grad:
                time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
            else:
                time_vjps = None
            for i in range(len(t) - 1, 0, -1):
                if t_requires_grad:
                    # Compute the effect of moving the current time measurement point.
                    # We don't compute this unless we need to, to save some computation.
                    func_eval = func(t[i], y[i])
                    dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1))
                    aug_state[0] -= dLd_cur_t
                    time_vjps[i] = dLd_cur_t

                # Run the augmented system backwards in time.
                aug_state = odeint(
                    augmented_dynamics, tuple(aug_state),
                    t[i - 1:i + 1].flip(0),
                    rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
                )
                aug_state = [a[1] for a in aug_state]  # extract just the t[i - 1] value
                aug_state[1] = y[i - 1]  # update to use our forward-pass estimate of the state
                aug_state[2] += grad_y[i - 1]  # update any gradients wrt state at this time point

            if t_requires_grad:
                time_vjps[0] = aug_state[0]

            # Compute the gradients.
            # Only compute gradient wrt initial time when in event handling mode.
            if event_mode and t_requires_grad:
                time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])])

            adj_y = aug_state[2]
            adj_params = aug_state[3:]

        return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params)

Model architecture:

Starting from a small residual network, the 6 standard residual blocks that follow the downsampling layers are replaced by a single ODESolve module (the ODEBlock below) to obtain ODE-Net.

Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): GroupNorm(32, 64, eps=1e-05, affine=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (4): GroupNorm(32, 64, eps=1e-05, affine=True)
  (5): ReLU(inplace=True)
  (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ODEBlock(
    (odefunc): ODEfunc(
      (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv1): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
      (conv2): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
    )
  )
  (8): GroupNorm(32, 64, eps=1e-05, affine=True)
  (9): ReLU(inplace=True)
  (10): AdaptiveAvgPool2d(output_size=(1, 1))
  (11): Flatten()
  (12): Linear(in_features=64, out_features=10, bias=True)
)

The results are as follows:

ODE-Net needs fewer parameters and less memory than ResNet. The training time is not reported as a measured value, only as a time complexity.

[Figure: results table comparing ODE-Net with ResNet]
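
The parameter saving can be checked by hand from the printed architecture. The sketch below is plain arithmetic, not code from the repository. The helper names are hypothetical, and the ResNet figures assume the `ResBlock(64, 64)` definition from the demo script further down (two bias-free 3x3 convolutions and two GroupNorm layers, six blocks in total).

```python
def conv2d_params(c_in, c_out, k, bias=True):
    # k x k kernel per (input channel, output channel) pair, plus optional bias.
    return c_in * c_out * k * k + (c_out if bias else 0)


def group_norm_params(channels):
    # Affine GroupNorm has one weight and one bias per channel.
    return 2 * channels


def linear_params(n_in, n_out):
    return n_in * n_out + n_out


# Layers (0)-(6): shared downsampling stem.
downsampling = (conv2d_params(1, 64, 3) + group_norm_params(64)
                + conv2d_params(64, 64, 4) + group_norm_params(64)
                + conv2d_params(64, 64, 4))

# Layers (8)-(12): shared classification head.
head = group_norm_params(64) + linear_params(64, 10)

# ODEfunc: two ConcatConv2d layers (64 channels + 1 time channel -> 64)
# and three GroupNorm layers.
ode_block = 2 * conv2d_params(65, 64, 3) + 3 * group_norm_params(64)

# ResBlock(64, 64): two bias-free 3x3 convolutions and two GroupNorm layers.
res_block = 2 * conv2d_params(64, 64, 3, bias=False) + 2 * group_norm_params(64)

odenet_total = downsampling + ode_block + head
resnet_total = downsampling + 6 * res_block + head
print(odenet_total, resnet_total)  # 208266 576778
```

The 65 input channels in the ODE function come from ConcatConv2d, which concatenates the time t as one extra channel. Because a single ODE function replaces six residual blocks, ODE-Net ends up with roughly a third of the parameters.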


### Run the ResNet/ODE-Net demo (MNIST):
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=160)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

parser.add_argument('--save', type=str, default='./experiment1')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()

if args.adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)


class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut


class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    return train_loader, test_loader, train_eval_loader


def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    initial_learning_rate = args.lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn


def one_hot(x, K):
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader):
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
    logger = logging.getLogger()
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)
    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    for f in package_files:
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger


if __name__ == '__main__':

    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    is_odenet = args.network == 'odenet'

    if args.downsampling_method == 'conv':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
    elif args.downsampling_method == 'res':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]

    feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
    fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    criterion = nn.CrossEntropyLoss().to(device)

    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size
    )

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    for itr in range(args.nepochs * batches_per_epoch):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        optimizer.zero_grad()
        x, y = data_gen.__next__()
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        loss = criterion(logits, y)

        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        loss.backward()
        optimizer.step()

        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        batch_time_meter.update(time.time() - end)
        if is_odenet:
            f_nfe_meter.update(nfe_forward)
            b_nfe_meter.update(nfe_backward)
        end = time.time()

        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader)
                val_acc = accuracy(model, test_loader)
                if val_acc > best_acc:
                    torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
                    best_acc = val_acc
                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "Train Acc {:.4f} | Test Acc {:.4f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc
                    )
                )
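
The learning-rate schedule used by the script is piecewise constant. A dependency-free sketch of the same idea is shown below. `piecewise_lr` is a hypothetical helper that mirrors `learning_rate_with_decay` without NumPy, and the 468 batches per epoch assume the 60000 MNIST training images with batch size 128 and `drop_last=True`.

```python
def piecewise_lr(base_lr, batches_per_epoch, boundary_epochs, decay_rates):
    """Return a function mapping an iteration index to a learning rate."""
    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    values = [base_lr * decay for decay in decay_rates]

    def lr_at(itr):
        # Use the value of the first boundary that has not been reached yet;
        # after the last boundary, fall back to the final value.
        for i, boundary in enumerate(boundaries):
            if itr < boundary:
                return values[i]
        return values[-1]

    return lr_at


batches_per_epoch = 60000 // 128  # 468 full batches, the remainder is dropped
lr_at = piecewise_lr(0.1, batches_per_epoch, [60, 100, 140], [1, 0.1, 0.01, 0.001])

schedule = {epoch: lr_at(epoch * batches_per_epoch) for epoch in (0, 59, 60, 100, 140)}
for epoch, lr in schedule.items():
    print(f"epoch {epoch:3d}: lr = {lr:g}")
```

With the default arguments the rate starts at 0.1 and is divided by 10 at epochs 60, 100, and 140.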

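Training process (NFE-F and NFE-B are the running averages of the number of function evaluations in the forward and backward passes; NFE-B stays at 0.0 here, which is consistent with the default --adjoint False, where gradients flow through the solver's autograd graph instead of a backward ODE solve):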

Epoch 0000 | Time 3.425 (3.425) | NFE-F 32.0 | NFE-B 0.0 | Train Acc 0.0987 | Test Acc 0.0958
Epoch 0001 | Time 3.279 (0.840) | NFE-F 20.3 | NFE-B 0.0 | Train Acc 0.9755 | Test Acc 0.9779
Epoch 0002 | Time 3.500 (0.839) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9858 | Test Acc 0.9875
Epoch 0003 | Time 3.403 (0.828) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9884 | Test Acc 0.9879
Epoch 0004 | Time 3.303 (0.807) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9926 | Test Acc 0.9921
Epoch 0005 | Time 3.308 (0.801) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9940 | Test Acc 0.9930
Epoch 0006 | Time 3.255 (0.804) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9917 | Test Acc 0.9894
Epoch 0007 | Time 3.376 (0.808) | NFE-F 20.2 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9929
Epoch 0008 | Time 3.260 (0.806) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9935 | Test Acc 0.9934
Epoch 0009 | Time 3.248 (0.832) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9909
Epoch 0010 | Time 3.286 (0.817) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9959 | Test Acc 0.9947
Epoch 0011 | Time 3.281 (0.827) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9967 | Test Acc 0.9951
Epoch 0012 | Time 3.382 (0.825) | NFE-F 20.9 | NFE-B 0.0 | Train Acc 0.9949 | Test Acc 0.9929
Epoch 0013 | Time 3.299 (0.862) | NFE-F 22.0 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9949
Epoch 0014 | Time 3.326 (0.824) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9947 | Test Acc 0.9936
Epoch 0015 | Time 3.291 (0.839) | NFE-F 21.3 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9948
Epoch 0016 | Time 3.467 (0.935) | NFE-F 24.4 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9941
Epoch 0017 | Time 3.483 (0.900) | NFE-F 23.2 | NFE-B 0.0 | Train Acc 0.9970 | Test Acc 0.9939
Epoch 0018 | Time 3.309 (0.872) | NFE-F 22.2 | NFE-B 0.0 | Train Acc 0.9961 | Test Acc 0.9932
Epoch 0019 | Time 3.294 (0.913) | NFE-F 23.6 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9954
Epoch 0020 | Time 3.504 (0.984) | NFE-F 25.9 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9951
Epoch 0021 | Time 3.589 (0.966) | NFE-F 25.2 | NFE-B 0.0 | Train Acc 0.9966 | Test Acc 0.9929
Epoch 0022 | Time 3.503 (0.994) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9949
Epoch 0023 | Time 3.457 (0.995) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9978 | Test Acc 0.9939
Epoch 0024 | Time 3.529 (0.985) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9958
Epoch 0025 | Time 3.459 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9973 | Test Acc 0.9947
Epoch 0026 | Time 3.541 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9979 | Test Acc 0.9946
Epoch 0027 | Time 3.513 (0.993) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9959
Epoch 0028 | Time 3.505 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9982 | Test Acc 0.9953
Epoch 0029 | Time 3.501 (0.990) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9953
Epoch 0030 | Time 3.475 (0.992) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9954
Epoch 0031 | Time 3.506 (0.993) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9947
Epoch 0032 | Time 3.527 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9981 | Test Acc 0.9954
Epoch 0033 | Time 3.529 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9945
Epoch 0034 | Time 3.545 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9959
Epoch 0035 | Time 3.479 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9990 | Test Acc 0.9953
Epoch 0036 | Time 3.479 (0.997) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9989 | Test Acc 0.9963
Epoch 0037 | Time 3.540 (0.998) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9957
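
The NFE-F values in the log cluster around 20, 26, and 32. One way to read them, stated here as an assumption rather than something the original notes claim, is that the default dopri5 solver spends 2 evaluations on choosing the initial step size and 6 new evaluations on every attempted step. The helper `dopri5_steps_from_nfe` below is hypothetical and only inverts that assumed cost model; the exact accounting may differ between torchdiffeq versions.

```python
def dopri5_steps_from_nfe(nfe, initial_evals=2, evals_per_step=6):
    """Invert nfe = initial_evals + evals_per_step * steps (assumed cost model).

    Returns the number of attempted solver steps, or None if the count
    does not fit the model exactly.
    """
    steps, remainder = divmod(nfe - initial_evals, evals_per_step)
    return steps if remainder == 0 and steps >= 0 else None


for nfe in (20, 26, 32):
    print(nfe, dopri5_steps_from_nfe(nfe))
```

Under this reading, the drift of NFE-F from about 20 to about 26 during training would mean that the learned dynamics gradually require one more solver step per forward pass to meet the tolerance.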

References:

1. Neural ODE: Introductory Tutorial - Zhihu (zhihu.com)

2. rtqichen/torchdiffeq: Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation. (github.com)

Origin blog.csdn.net/KPer_Yang/article/details/130186477