tflearn saves the model at the end of each epoch

Key code: pass `checkpoint_path` to `tflearn.DNN(...)` and `snapshot_epoch=True` to `model.fit(...)`:
tflearn.DNN(net, checkpoint_path='model_resnet_cifar10', max_checkpoints=10, tensorboard_verbose=0, clip_gradients=0.)
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
My demo:
def get_model(width, height, classes=40):
    """Build a small residual network classifier wrapped in a ``tflearn.DNN``.

    Args:
        width: First spatial dimension of the input images.
        height: Second spatial dimension of the input images.
        classes: Number of output units of the final softmax layer.

    Returns:
        A ``tflearn.DNN`` whose checkpoints are written with the prefix
        ``model_resnet_cifar10`` (at most 10 are kept).
    """
    # TODO: modify model
    # Three-channel input, e.g. 224x224x3 for RGB images.
    images = input_data(shape=[None, width, height, 3])

    # Depth knob for the residual stages.
    # 32 layers: n=5, 56 layers: n=9, 110 layers: n=18
    n = 2

    # Stem convolution followed by the first residual stage (16 channels).
    features = tflearn.conv_2d(images, 16, 3,
                               regularizer='L2', weight_decay=0.0001)
    features = tflearn.residual_block(features, n, 16)

    # Each later stage opens with one downsampling block and is followed by
    # n-1 regular blocks at the same channel count.
    for channels in (32, 64):
        features = tflearn.residual_block(features, 1, channels,
                                          downsample=True)
        features = tflearn.residual_block(features, n - 1, channels)

    features = tflearn.batch_normalization(features)
    features = tflearn.activation(features, 'relu')
    features = tflearn.global_avg_pool(features)

    # Classification head: softmax over `classes` outputs.
    logits = tflearn.fully_connected(features, classes, activation='softmax')

    # Momentum optimizer with a staircase learning-rate decay.
    # A previously tried schedule:
    #   tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
    optimizer = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000,
                                 staircase=True)
    trainable = tflearn.regression(logits, optimizer=optimizer,
                                   loss='categorical_crossentropy')

    # Wrap the graph in a trainer that keeps up to 10 checkpoints.
    return tflearn.DNN(trainable,
                       checkpoint_path='model_resnet_cifar10',
                       max_checkpoints=10,
                       tensorboard_verbose=0,
                       clip_gradients=0.)



def main():
    """Train the residual network on folder-organised images and save it.

    Steps:
      1. Load the train/test image sets from ``data/train`` and ``data/test``.
      2. Build the model and try to resume from an existing checkpoint.
      3. Train with a per-epoch snapshot; training may be cut short by the
         early-stopping callback raising ``StopIteration``.
      4. Save the trained model, then save a second copy after clearing the
         ``TRAIN_OPS`` graph collection.

    NOTE(review): ``width`` and ``height`` are not defined in this function;
    they are presumably module-level globals -- confirm against the rest of
    the file.
    """
    import os  # local import: only needed to create the output directory

    trainX, trainY = image_preloader("data/train", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
    testX, testY = image_preloader("data/test", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
    #trainX = trainX.reshape([-1, width, height, 1])
    #testX = testX.reshape([-1, width, height, 1])
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = get_model(width, height, classes=3755)

    filename = 'tflearn_resnet/model.tflearn'

    # Best-effort resume: a missing or incompatible checkpoint must not abort
    # the run, but the reason is reported instead of being silently dropped.
    # `except Exception` (rather than a bare `except`) lets Ctrl-C and
    # SystemExit propagate. To resume from the final saved model instead,
    # load `filename` here.
    try:
        model.load("model_resnet_cifar10-195804")
    except Exception as exc:
        print("Could not load checkpoint, training from scratch: %s" % exc)
    else:
        print("Model loaded OK. Resume training!")

    # Presumably raises StopIteration once validation accuracy passes the
    # threshold -- the callback class is defined outside this block.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=1024, callbacks=early_stopping_cb, run_id='cnn_handwrite')
    except StopIteration:
        print("OK, stop iterate!Good!")

    # Make sure the output directory exists so the save after a long training
    # run cannot fail on a missing parent directory.
    out_dir = os.path.dirname(filename)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    model.save(filename)

    # Empty the TRAIN_OPS collection in place and save a second copy;
    # presumably this yields a model that can be loaded for inference only.
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
    filename = 'tflearn_resnet/model-infer.tflearn'
    model.save(filename)

 

Guess you like

Origin http://43.154.161.224:23101/article/api/json?id=325812711&siteId=291194637