1. Study notes on neural ordinary differential equations, written mainly as an exercise in learning new material and in reading papers that are heavy on mathematical derivations;
2. Neural differential equations can be used for time-series modeling, dynamics modeling, and so on, but these notes focus on the classification problem (a ResNet variant), which is the easiest to understand;
Personal understanding:
The code implementation of the adjoint sensitivity method looks complicated, but the code logic matches the algorithm steps one-to-one, so reading them side by side makes it easy to follow. In essence, the gradient computation is reduced to solving another differential equation:
In engineering terms, OdeintAdjointMethod is implemented by subclassing torch.autograd.Function and overriding its forward and backward methods, so that both passes are carried out by an ODE solver instead of by torch.autograd's usual chain-rule backpropagation through the solver's internal operations.
Fundamentals:
The adjoint sensitivity method is what makes training practical: it avoids the scalability and memory problems of backpropagating through every internal operation of the solver. The forward pass is computed with an ordinary differential equation (ODE) solver; the backward pass is then posed as another ODE, the adjoint system, so it too can be computed by an ODE solver. To update the parameters of the derivative (dynamics) function, the gradient of the loss with respect to those parameters is obtained from the adjoint state. Concretely, the algorithm packs the state, the adjoint, and the running parameter gradients into a single augmented variable, calls the ODE solver backwards in time, and uses the result to update the CNN encoder and the derivative-function parameters.
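In symbols, the quantities integrated backwards in time are the adjoint equations from the Neural ODE paper (restated briefly here for reference; see the paper for the derivation):

$$a(t) = \frac{\partial L}{\partial z(t)}, \qquad \frac{\mathrm{d}a(t)}{\mathrm{d}t} = -\,a(t)^{\top}\frac{\partial f(z(t),t,\theta)}{\partial z}, \qquad \frac{\mathrm{d}L}{\mathrm{d}\theta} = -\int_{t_1}^{t_0} a(t)^{\top}\frac{\partial f(z(t),t,\theta)}{\partial \theta}\,\mathrm{d}t$$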
The algorithm flow of the adjoint sensitivity method is as follows:
A more complete version:
The code implementation of the adjoint sensitivity method looks complicated, but its logic matches the algorithm steps exactly, and comparing the two makes it easy to follow; the gradient computation ultimately reduces to solving another differential equation.
**OdeintAdjointMethod is implemented by subclassing torch.autograd.Function, overriding forward and backward, and carrying out both passes with the ODE solver rather than with torch.autograd's built-in chain-rule gradient computation.** odeint is the ODE solver used throughout this article.
The solver types provided by the repository are as follows:
SOLVERS = {
'dopri8': Dopri8Solver,
'dopri5': Dopri5Solver,
'bosh3': Bosh3Solver,
'fehlberg2': Fehlberg2,
'adaptive_heun': AdaptiveHeunSolver,
'euler': Euler,
'midpoint': Midpoint,
'rk4': RK4,
'explicit_adams': AdamsBashforth,
'implicit_adams': AdamsBashforthMoulton,
# Backward compatibility: use the same name as before
'fixed_adams': AdamsBashforthMoulton,
# ~Backwards compatibility
'scipy_solver': ScipyWrapperODESolver,
}
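As a quick illustration of how these are selected (a minimal sketch with toy dynamics, not code from the repository), odeint takes the dynamics function, the initial state, the requested time points, and a method key from the table above:

import torch
from torchdiffeq import odeint

# Toy linear dynamics dy/dt = A @ y; any callable f(t, y) -> dy/dt works.
class LinearODE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.A = torch.nn.Parameter(torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]))

    def forward(self, t, y):
        return y @ self.A.t()

y0 = torch.tensor([1.0, 0.0])
t = torch.linspace(0.0, 1.0, steps=10)
# 'dopri5' is the default adaptive Runge-Kutta method; any key from SOLVERS can be passed.
ys = odeint(LinearODE(), y0, t, rtol=1e-3, atol=1e-3, method='dopri5')
print(ys.shape)  # torch.Size([10, 2]): one state per requested time point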
The complete source code is at rtqichen/torchdiffeq: Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation (github.com); reproduced below are the core parts, with comments, for the forward and backward passes.
class OdeintAdjointMethod(torch.autograd.Function):
@staticmethod
def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method,
adjoint_options, t_requires_grad, *adjoint_params):
ctx.shapes = shapes
ctx.func = func
ctx.adjoint_rtol = adjoint_rtol
ctx.adjoint_atol = adjoint_atol
ctx.adjoint_method = adjoint_method
ctx.adjoint_options = adjoint_options
ctx.t_requires_grad = t_requires_grad
ctx.event_mode = event_fn is not None
with torch.no_grad():
ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)
if event_fn is None:
y = ans
ctx.save_for_backward(t, y, *adjoint_params)
else:
event_t, y = ans
ctx.save_for_backward(t, y, event_t, *adjoint_params)
return ans
@staticmethod
def backward(ctx, *grad_y):
with torch.no_grad():
func = ctx.func
adjoint_rtol = ctx.adjoint_rtol
adjoint_atol = ctx.adjoint_atol
adjoint_method = ctx.adjoint_method
adjoint_options = ctx.adjoint_options
t_requires_grad = ctx.t_requires_grad
# Backprop as if integrating up to event time.
# Does NOT backpropagate through the event time.
event_mode = ctx.event_mode
if event_mode:
t, y, event_t, *adjoint_params = ctx.saved_tensors
_t = t
t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
grad_y = grad_y[1]
else:
t, y, *adjoint_params = ctx.saved_tensors
grad_y = grad_y[0]
adjoint_params = tuple(adjoint_params)
##################################
#   Create the initial augmented state   #
##################################
# [-1] because y and grad_y are both of shape (len(t), *y0.shape)
aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]] # vjp_t, y, vjp_y
aug_state.extend([torch.zeros_like(param) for param in adjoint_params]) # vjp_params
##################################
#  Create the backward (augmented) ODE function  #
##################################
# TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
def augmented_dynamics(t, y_aug):
# The augmented dynamics function
# Dynamics of the original system augmented with
# the adjoint wrt y, and an integrator wrt t and args.
y = y_aug[1]
adj_y = y_aug[2]
# ignore gradients wrt time and parameters
with torch.enable_grad():
t_ = t.detach()
t = t_.requires_grad_(True)
y = y.detach().requires_grad_(True)
# If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
# doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
# wrt t here means we won't compute that if we don't need it.
func_eval = func(t if t_requires_grad else t_, y)
# Workaround for PyTorch bug #39784
_t = torch.as_strided(t, (), ()) # noqa
_y = torch.as_strided(y, (), ()) # noqa
_params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params) # noqa
vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
func_eval, (t, y) + adjoint_params, -adj_y,
allow_unused=True, retain_graph=True
)
# autograd.grad returns None if no gradient, set to zero.
vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
for param, vjp_param in zip(adjoint_params, vjp_params)]
return (vjp_t, func_eval, vjp_y, *vjp_params)
##################################
#   Solve the adjoint ODE backwards in time   #
##################################
if t_requires_grad:
time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
else:
time_vjps = None
for i in range(len(t) - 1, 0, -1):
if t_requires_grad:
# Compute the effect of moving the current time measurement point.
# We don't compute this unless we need to, to save some computation.
func_eval = func(t[i], y[i])
dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1))
aug_state[0] -= dLd_cur_t
time_vjps[i] = dLd_cur_t
# Run the augmented system backwards in time.
aug_state = odeint(
augmented_dynamics, tuple(aug_state),
t[i - 1:i + 1].flip(0),
rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
)
aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value
aug_state[1] = y[i - 1] # update to use our forward-pass estimate of the state
aug_state[2] += grad_y[i - 1] # update any gradients wrt state at this time point
if t_requires_grad:
time_vjps[0] = aug_state[0]
# Compute the gradients
# Only compute gradient wrt initial time when in event handling mode.
if event_mode and t_requires_grad:
time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])])
adj_y = aug_state[2]
adj_params = aug_state[3:]
return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params)
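Stripped of the ODE-specific details, the listing above follows the standard torch.autograd.Function recipe: forward runs under torch.no_grad() and stashes whatever backward will need via ctx.save_for_backward, and backward must return one gradient (or None) per forward argument, which is why the return tuple above pads with None for the non-tensor arguments. A toy example of the same pattern (illustrative only, not from the repository):

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save the input, just as the listing above saves t, y and the adjoint parameters.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # One gradient per forward input; here the chain rule d(x^2)/dx = 2x is written by hand.
        return grad_output * 2 * x

x = torch.tensor(3.0, requires_grad=True)
Square.apply(x).backward()
print(x.grad)  # tensor(6.)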
Model architecture:
Starting from the small residual network, the 6 standard residual blocks that follow the downsampling layers are replaced by a single ODESolve block to obtain the ODE-Net.
Sequential(
(0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
(1): GroupNorm(32, 64, eps=1e-05, affine=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(4): GroupNorm(32, 64, eps=1e-05, affine=True)
(5): ReLU(inplace=True)
(6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(7): ODEBlock(
(odefunc): ODEfunc(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(relu): ReLU(inplace=True)
(conv1): ConcatConv2d(
(_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv2): ConcatConv2d(
(_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
)
)
(8): GroupNorm(32, 64, eps=1e-05, affine=True)
(9): ReLU(inplace=True)
(10): AdaptiveAvgPool2d(output_size=(1, 1))
(11): Flatten()
(12): Linear(in_features=64, out_features=10, bias=True)
)
The result is as follows:
Parameter count and memory usage are better than ResNet's; wall-clock time is not reported directly, only the time complexity.
### Run the ResNet / ODE-Net demo:
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=160)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--save', type=str, default='./experiment1')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()
if args.adjoint:
from torchdiffeq import odeint_adjoint as odeint
else:
from torchdiffeq import odeint
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def norm(dim):
return nn.GroupNorm(min(32, dim), dim)
class ResBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(ResBlock, self).__init__()
self.norm1 = norm(inplanes)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.conv1 = conv3x3(inplanes, planes, stride)
self.norm2 = norm(planes)
self.conv2 = conv3x3(planes, planes)
def forward(self, x):
shortcut = x
out = self.relu(self.norm1(x))
if self.downsample is not None:
shortcut = self.downsample(out)
out = self.conv1(out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv2(out)
return out + shortcut
class ConcatConv2d(nn.Module):
def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
super(ConcatConv2d, self).__init__()
module = nn.ConvTranspose2d if transpose else nn.Conv2d
self._layer = module(
dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
bias=bias
)
def forward(self, t, x):
tt = torch.ones_like(x[:, :1, :, :]) * t
ttx = torch.cat([tt, x], 1)
return self._layer(ttx)
class ODEfunc(nn.Module):
def __init__(self, dim):
super(ODEfunc, self).__init__()
self.norm1 = norm(dim)
self.relu = nn.ReLU(inplace=True)
self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm2 = norm(dim)
self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm3 = norm(dim)
self.nfe = 0
def forward(self, t, x):
self.nfe += 1
out = self.norm1(x)
out = self.relu(out)
out = self.conv1(t, out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv2(t, out)
out = self.norm3(out)
return out
class ODEBlock(nn.Module):
def __init__(self, odefunc):
super(ODEBlock, self).__init__()
self.odefunc = odefunc
self.integration_time = torch.tensor([0, 1]).float()
def forward(self, x):
self.integration_time = self.integration_time.type_as(x)
out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
return out[1]
@property
def nfe(self):
return self.odefunc.nfe
@nfe.setter
def nfe(self, value):
self.odefunc.nfe = value
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.view(-1, shape)
class RunningAverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, momentum=0.99):
self.momentum = momentum
self.reset()
def reset(self):
self.val = None
self.avg = 0
def update(self, val):
if self.val is None:
self.avg = val
else:
self.avg = self.avg * self.momentum + val * (1 - self.momentum)
self.val = val
def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
if data_aug:
transform_train = transforms.Compose([
transforms.RandomCrop(28, padding=4),
transforms.ToTensor(),
])
else:
transform_train = transforms.Compose([
transforms.ToTensor(),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
])
train_loader = DataLoader(
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
shuffle=True, num_workers=2, drop_last=True
)
train_eval_loader = DataLoader(
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
)
test_loader = DataLoader(
datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
)
return train_loader, test_loader, train_eval_loader
def inf_generator(iterable):
"""Allows training with DataLoaders in a single infinite loop:
for i, (x, y) in enumerate(inf_generator(train_loader)):
"""
iterator = iterable.__iter__()
while True:
try:
yield iterator.__next__()
except StopIteration:
iterator = iterable.__iter__()
def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
initial_learning_rate = args.lr * batch_size / batch_denom
boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
vals = [initial_learning_rate * decay for decay in decay_rates]
def learning_rate_fn(itr):
lt = [itr < b for b in boundaries] + [True]
i = np.argmax(lt)
return vals[i]
return learning_rate_fn
def one_hot(x, K):
return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)
def accuracy(model, dataset_loader):
total_correct = 0
for x, y in dataset_loader:
x = x.to(device)
y = one_hot(np.array(y.numpy()), 10)
target_class = np.argmax(y, axis=1)
predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
total_correct += np.sum(predicted_class == target_class)
return total_correct / len(dataset_loader.dataset)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def makedirs(dirname):
if not os.path.exists(dirname):
os.makedirs(dirname)
def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
logger = logging.getLogger()
if debug:
level = logging.DEBUG
else:
level = logging.INFO
logger.setLevel(level)
if saving:
info_file_handler = logging.FileHandler(logpath, mode="a")
info_file_handler.setLevel(level)
logger.addHandler(info_file_handler)
if displaying:
console_handler = logging.StreamHandler()
console_handler.setLevel(level)
logger.addHandler(console_handler)
logger.info(filepath)
with open(filepath, "r") as f:
logger.info(f.read())
for f in package_files:
logger.info(f)
with open(f, "r") as package_f:
logger.info(package_f.read())
return logger
if __name__ == '__main__':
makedirs(args.save)
logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
logger.info(args)
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
is_odenet = args.network == 'odenet'
if args.downsampling_method == 'conv':
downsampling_layers = [
nn.Conv2d(1, 64, 3, 1),
norm(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 4, 2, 1),
norm(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 4, 2, 1),
]
elif args.downsampling_method == 'res':
downsampling_layers = [
nn.Conv2d(1, 64, 3, 1),
ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
]
feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]
model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)
logger.info(model)
logger.info('Number of parameters: {}'.format(count_parameters(model)))
criterion = nn.CrossEntropyLoss().to(device)
train_loader, test_loader, train_eval_loader = get_mnist_loaders(
args.data_aug, args.batch_size, args.test_batch_size
)
data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)
lr_fn = learning_rate_with_decay(
args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
decay_rates=[1, 0.1, 0.01, 0.001]
)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
best_acc = 0
batch_time_meter = RunningAverageMeter()
f_nfe_meter = RunningAverageMeter()
b_nfe_meter = RunningAverageMeter()
end = time.time()
for itr in range(args.nepochs * batches_per_epoch):
for param_group in optimizer.param_groups:
param_group['lr'] = lr_fn(itr)
optimizer.zero_grad()
x, y = data_gen.__next__()
x = x.to(device)
y = y.to(device)
logits = model(x)
loss = criterion(logits, y)
if is_odenet:
nfe_forward = feature_layers[0].nfe
feature_layers[0].nfe = 0
loss.backward()
optimizer.step()
if is_odenet:
nfe_backward = feature_layers[0].nfe
feature_layers[0].nfe = 0
batch_time_meter.update(time.time() - end)
if is_odenet:
f_nfe_meter.update(nfe_forward)
b_nfe_meter.update(nfe_backward)
end = time.time()
if itr % batches_per_epoch == 0:
with torch.no_grad():
train_acc = accuracy(model, train_eval_loader)
val_acc = accuracy(model, test_loader)
if val_acc > best_acc:
torch.save({
'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
best_acc = val_acc
logger.info(
"Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
"Train Acc {:.4f} | Test Acc {:.4f}".format(
itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
b_nfe_meter.avg, train_acc, val_acc
)
)
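To reproduce a run like the one below, save the script (it corresponds to odenet_mnist.py in the torchdiffeq examples directory) and invoke it with the arguments defined above, for example:

python odenet_mnist.py --network odenet --save ./experiment1

Passing --adjoint True switches the backward pass to the O(1)-memory adjoint solver. In the log below, NFE-B stays at 0.0, which is consistent with the default --adjoint False: the backward pass then goes through autograd's recorded graph, so ODEfunc.forward is never called again and its nfe counter is not incremented.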
Training process:
Epoch 0000 | Time 3.425 (3.425) | NFE-F 32.0 | NFE-B 0.0 | Train Acc 0.0987 | Test Acc 0.0958
Epoch 0001 | Time 3.279 (0.840) | NFE-F 20.3 | NFE-B 0.0 | Train Acc 0.9755 | Test Acc 0.9779
Epoch 0002 | Time 3.500 (0.839) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9858 | Test Acc 0.9875
Epoch 0003 | Time 3.403 (0.828) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9884 | Test Acc 0.9879
Epoch 0004 | Time 3.303 (0.807) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9926 | Test Acc 0.9921
Epoch 0005 | Time 3.308 (0.801) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9940 | Test Acc 0.9930
Epoch 0006 | Time 3.255 (0.804) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9917 | Test Acc 0.9894
Epoch 0007 | Time 3.376 (0.808) | NFE-F 20.2 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9929
Epoch 0008 | Time 3.260 (0.806) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9935 | Test Acc 0.9934
Epoch 0009 | Time 3.248 (0.832) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9909
Epoch 0010 | Time 3.286 (0.817) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9959 | Test Acc 0.9947
Epoch 0011 | Time 3.281 (0.827) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9967 | Test Acc 0.9951
Epoch 0012 | Time 3.382 (0.825) | NFE-F 20.9 | NFE-B 0.0 | Train Acc 0.9949 | Test Acc 0.9929
Epoch 0013 | Time 3.299 (0.862) | NFE-F 22.0 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9949
Epoch 0014 | Time 3.326 (0.824) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9947 | Test Acc 0.9936
Epoch 0015 | Time 3.291 (0.839) | NFE-F 21.3 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9948
Epoch 0016 | Time 3.467 (0.935) | NFE-F 24.4 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9941
Epoch 0017 | Time 3.483 (0.900) | NFE-F 23.2 | NFE-B 0.0 | Train Acc 0.9970 | Test Acc 0.9939
Epoch 0018 | Time 3.309 (0.872) | NFE-F 22.2 | NFE-B 0.0 | Train Acc 0.9961 | Test Acc 0.9932
Epoch 0019 | Time 3.294 (0.913) | NFE-F 23.6 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9954
Epoch 0020 | Time 3.504 (0.984) | NFE-F 25.9 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9951
Epoch 0021 | Time 3.589 (0.966) | NFE-F 25.2 | NFE-B 0.0 | Train Acc 0.9966 | Test Acc 0.9929
Epoch 0022 | Time 3.503 (0.994) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9949
Epoch 0023 | Time 3.457 (0.995) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9978 | Test Acc 0.9939
Epoch 0024 | Time 3.529 (0.985) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9958
Epoch 0025 | Time 3.459 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9973 | Test Acc 0.9947
Epoch 0026 | Time 3.541 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9979 | Test Acc 0.9946
Epoch 0027 | Time 3.513 (0.993) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9959
Epoch 0028 | Time 3.505 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9982 | Test Acc 0.9953
Epoch 0029 | Time 3.501 (0.990) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9953
Epoch 0030 | Time 3.475 (0.992) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9954
Epoch 0031 | Time 3.506 (0.993) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9947
Epoch 0032 | Time 3.527 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9981 | Test Acc 0.9954
Epoch 0033 | Time 3.529 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9945
Epoch 0034 | Time 3.545 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9959
Epoch 0035 | Time 3.479 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9990 | Test Acc 0.9953
Epoch 0036 | Time 3.479 (0.997) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9989 | Test Acc 0.9963
Epoch 0037 | Time 3.540 (0.998) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9957