Training LeNet5 with Chainer

I had been learning Chainer for quite a while. Back when I studied Caffe, the very first demo was training LeNet5, but Chainer's MNIST demo trains an MLP rather than LeNet5. So I tried rewriting it to train LeNet5, only to get an error. See: Invalid operation is performed in: Convolution2DFunction (Forward)
I was stuck there for a long time, and when I came back to it later I got stuck in exactly the same place. After rereading the Chainer documentation, it occurred to me to change the code below.
It was originally:

train, test = chainer.datasets.get_mnist(withlabel=True)

After a small change, it becomes:

train, test = chainer.datasets.get_mnist(ndim=3, withlabel=True)

The only difference is the extra argument ndim=3: by default get_mnist flattens each image into a 784-element vector, while ndim=3 keeps each sample as a (1, 28, 28) array (channel, height, width), which is the input layout L.Convolution2D expects.
With that, training runs.
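
A minimal sketch (my addition, not from the original post) that makes the difference visible by printing the shape of one sample from each variant:

import chainer

train_flat, _ = chainer.datasets.get_mnist(withlabel=True)         # default ndim=1
train_img, _ = chainer.datasets.get_mnist(ndim=3, withlabel=True)  # images as (C, H, W)

print(train_flat[0][0].shape)  # (784,)      -> what Convolution2D rejects
print(train_img[0][0].shape)   # (1, 28, 28) -> what Convolution2D accepts

The log below is from a run that had been resumed from an earlier epoch-20 snapshot (hence -r result/snapshot_iter_12000) and trained for one more epoch: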

(venv) pikachu@pikachu-PC:~/PycharmProjects/LeNet$ optirun python lenet5_demo1.py -g 0 -r result/snapshot_iter_12000 -e 21
/home/pikachu/swcontest/LRCN/venv/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 21

(1, 28, 28)
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           1.21789     1.64665               0.614184       0.9141                    355.524       
2           0.237945    1.54542               0.933001       0.9519                    435.652       
3           0.143741    1.51993               0.957601       0.9671                    490.544       
4           0.108724    1.50635               0.9672         0.9735                    699.786       
5           0.0872876   1.4972                0.973549       0.9776                    3470.16       
6           0.0745747   1.49964               0.977166       0.9763                    3576.54       
7           0.0649679   1.49016               0.980049       0.9819                    3732.18       
8           0.058112    1.48777               0.982265       0.9831                    3806.33       
9           0.0519419   1.48632               0.983882       0.9836                    3826.74       
10          0.0472876   1.48512               0.985599       0.9837                    3835.67       
11          0.0431381   1.48331               0.986548       0.9848                    3842.56       
12          0.0392124   1.48257               0.987849       0.9852                    3848.72       
13          0.0371112   1.48026               0.988915       0.9858                    3855.14       
14          0.0341503   1.48032               0.989432       0.9859                    3861.11       
15          0.0313045   1.47911               0.990248       0.9876                    3867.27       
16          0.0290763   1.47942               0.990982       0.9867                    3873.29       
17          0.0275939   1.47704               0.991282       0.9874                    3879.52       
18          0.0253959   1.47738               0.991748       0.9877                    3885.51       
19          0.0239291   1.477                 0.992265       0.9875                    3891.59       
20          0.0223463   1.47749               0.992798       0.9871                    3897.57       
21          0.0208822   1.47678               0.993432       0.9874                    3908.42       
(venv) pikachu@pikachu-PC:~/PycharmProjects/LeNet$ 
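
Since this run was resumed from result/snapshot_iter_12000, here is a quick sketch (my own addition, not from the original post) of pulling the trained weights back out of such a trainer snapshot and classifying a single test image. It reuses the LeNet5 class from the full script below, and the 'updater/model:main/predictor/' prefix assumes the L.Classifier / StandardUpdater setup used there.

import numpy as np
import chainer
from chainer.serializers import NpzDeserializer

net = LeNet5()  # the network class defined in the full script below
with np.load('result/snapshot_iter_12000') as f:
    # Trainer snapshots nest the model parameters under updater/model:main/predictor/
    NpzDeserializer(f, path='updater/model:main/predictor/').load(net)

_, test = chainer.datasets.get_mnist(ndim=3, withlabel=True)
x, t = test[0]
with chainer.using_config('train', False), chainer.no_backprop_mode():
    y = net(x[None])  # add a batch axis -> input shape (1, 1, 28, 28)
print('predicted:', int(y.data.argmax()), 'label:', int(t))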

The full code follows. (One note on the log above: validation/main/loss plateaus around 1.48 even as validation accuracy keeps improving, because in eval mode the network returns F.softmax outputs and L.Classifier then computes softmax cross entropy on those already-normalized probabilities, which inflates the reported loss without affecting accuracy.)

#!/usr/bin/env python
import argparse
import numpy as np
import chainer
from chainer.backends import cuda
from chainer import Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions


# Network definition
# (The MLP below is left over from the original Chainer MNIST example; this
# script actually trains the LeNet5 class defined after it.)
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


class LeNet5(Chain):
    def __init__(self):
        super(LeNet5, self).__init__()
        with self.init_scope():
            self.conv1 = L.Convolution2D(in_channels=1, out_channels=6, ksize=5, stride=1)
            self.conv2 = L.Convolution2D(in_channels=6, out_channels=16, ksize=5, stride=1)
            self.conv3 = L.Convolution2D(in_channels=16, out_channels=120, ksize=4, stride=1)
            self.fc4 = L.Linear(None, 84)
            self.fc5 = L.Linear(84, 10)

    def __call__(self, x):
        h = F.sigmoid(self.conv1(x))
        h = F.max_pooling_2d(h, 2, 2)
        h = F.sigmoid(self.conv2(h))
        h = F.max_pooling_2d(h, 2, 2)
        h = F.sigmoid(self.conv3(h))
        h = F.sigmoid(self.fc4(h))
        # L.Classifier applies softmax_cross_entropy to this output. In training
        # mode the raw fc5 logits are returned; in eval mode (train=False)
        # softmax probabilities are returned instead, so the Evaluator computes
        # its loss on already-normalized values, which is why
        # validation/main/loss sits around 1.48 in the log above.
        if chainer.config.train:
            return self.fc5(h)
        return F.softmax(self.fc5(h))


def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=0,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units (only meaningful for the MLP variant)')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(LeNet5())
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist(ndim=3, withlabel=True)
    print(train[0][0].shape)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()


if __name__ == '__main__':
    main()
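
As a sanity check, here is a small sketch (my own addition, not part of the script above) that traces the feature-map shapes of this LeNet5 on a dummy (1, 1, 28, 28) input. It shows why conv3 uses ksize=4 rather than the classic 5 used for 32x32 LeNet inputs: after two 5x5 convolutions and two 2x2 poolings, a 28x28 image is down to 4x4, so a 4x4 kernel reduces it to 1x1 before the fully connected layers.

import numpy as np
import chainer.functions as F

net = LeNet5()
x = np.zeros((1, 1, 28, 28), dtype=np.float32)  # dummy batch of one MNIST image
h = F.sigmoid(net.conv1(x))
print(h.shape)  # (1, 6, 24, 24)
h = F.max_pooling_2d(h, 2, 2)
print(h.shape)  # (1, 6, 12, 12)
h = F.sigmoid(net.conv2(h))
print(h.shape)  # (1, 16, 8, 8)
h = F.max_pooling_2d(h, 2, 2)
print(h.shape)  # (1, 16, 4, 4)
h = F.sigmoid(net.conv3(h))
print(h.shape)  # (1, 120, 1, 1)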

Reposted from blog.csdn.net/qq_32768743/article/details/80246080