torch yolov3 minibatch failure (out of memory)


The minibatch (subdivision) training failed with an out-of-memory error. The culprit is this line, which accumulates the loss as a live tensor, so the autograd graph of every sub-batch stays in GPU memory until the single backward() call after the loop (a corrected sketch follows the script below):

loss += model(sub_imgs, sub_targets)

Full training script as it failed:
# -*- coding:utf-8 -*-
from __future__ import division

from models import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
from logger import Logger
import os
import sys
import time
import datetime
import argparse

import torch
from torch.utils.data import DataLoader

from torch.autograd import Variable
import torch.optim as optim

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=2001, help='number of epochs')
parser.add_argument('--image_folder', type=str, default='data/samples', help='path to dataset')
parser.add_argument('--batch_size', type=int, default=20, help='size of each image batch')
parser.add_argument('--learning_rate', type=float, default=0.01, help='learning_rate')
parser.add_argument('--train_dir', type=str, default=r'D:\data/',help='train_dir')
parser.add_argument('--model_config_path', type=str, default='config/yolov3_2cls.cfg', help='path to model config file')
parser.add_argument('--data_config_path', type=str, default='config/coco.data', help='path to data config file')
parser.add_argument('--weights_path', type=str, default='weights/yolov3.weights', help='path to weights file')
# parser.add_argument('--weights_path', type=str, default='checkpoints/40.weights', help='path to weights file')
parser.add_argument('--class_path', type=str, default='data/coco_2cls.names', help='path to class label file')
parser.add_argument('--conf_thres', type=float, default=0.8, help='object confidence threshold')
parser.add_argument('--nms_thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
parser.add_argument('--n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation')
parser.add_argument('--img_size', type=int, default=416, help='size of each image dimension')
parser.add_argument('--checkpoint_interval', type=int, default=2, help='interval between saving model weights')
parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='directory where model checkpoints are saved')
opt = parser.parse_args()
print(opt)

os.makedirs('output', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)
def adjust_learning_rate(optimizer, decay_rate=0.5):
    # Multiply every param group's lr by decay_rate, with a lower bound of 1e-8
    for param_group in optimizer.param_groups:
        if param_group['lr'] > 1e-8:
            param_group['lr'] = param_group['lr'] * decay_rate
    print(optimizer)
cuda = torch.cuda.is_available()  # was `torch.cuda.is_available` without parentheses, which is always truthy

classes = load_classes(opt.class_path)

# Get data configuration
data_config     = parse_data_config(opt.data_config_path)
train_path      = data_config['train']

# Get hyper parameters

module_defs = parse_model_config(opt.model_config_path)
hyperparams     = module_defs[0]
anchors = hyperparams["anchors"]
anchors = [int(x) for x in anchors.split(",")]
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
# 83/95/107 are the three [yolo] layers (layers 82/94/106, offset by the [net] block at index 0)
module_defs[83]["anchors"] = anchors
module_defs[95]["anchors"] = anchors
module_defs[107]["anchors"] = anchors
batch_size      = opt.batch_size  # int(hyperparams['batch'])
subdivisions    = int(hyperparams['subdivisions'])
sub_batch       = batch_size // subdivisions
learning_rate   = opt.learning_rate
momentum        = float(hyperparams['momentum'])
decay           = float(hyperparams['decay'])
burn_in         = int(hyperparams['burn_in'])
hyperparams['height']=hyperparams['width']=opt.img_size

if __name__ == '__main__':
    dataloader = torch.utils.data.DataLoader(
        ListDataset(opt.train_dir, img_size=opt.img_size, is_training=True, data_size=400),
        batch_size=batch_size, shuffle=True, num_workers=opt.n_cpu)

    model = Darknet(module_defs,img_size=opt.img_size)
    model.load_weights(opt.weights_path,is_training=True)
    #model.apply(weights_init_normal)

    if cuda:
        model = model.cuda()

    model.train()
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    # optimizer = optim.SGD(model.parameters(), lr=learning_rate/batch_size, momentum=momentum, dampening=0, weight_decay=decay*batch_size)
    optimizer = optim.Adamax(model.parameters(), lr=learning_rate/batch_size, weight_decay=decay*batch_size)

    print("subdivisions",subdivisions)
    logger = Logger('./logs')
    mean_loss=0
    for epoch in range(opt.epochs):
        last_time = datetime.datetime.now()
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            imgs = Variable(imgs.type(Tensor))
            targets = Variable(targets.type(Tensor), requires_grad=False)
            # print("prepa_img", (datetime.datetime.now()-last_time).microseconds)
            # last_time = datetime.datetime.now()
            loss = 0
            loss_x = 0
            loss_y = 0
            loss_w = 0
            loss_h = 0
            loss_conf = 0
            loss_cls = 0
            m_recall = 0

            for i in range(subdivisions):
                optimizer.zero_grad()  # zeroing here each iteration is redundant; backward() only runs after the loop
                sub_imgs = imgs[i*sub_batch: (i+1)*sub_batch]
                sub_targets = targets[i*sub_batch: (i+1)*sub_batch]
                # OOM source: summing loss tensors keeps every sub-batch's autograd
                # graph alive until the single loss.backward() below
                loss += model(sub_imgs, sub_targets)
                loss_x += model.losses['x']
                loss_y += model.losses['y']
                loss_w += model.losses['w']
                loss_h += model.losses['h']
                loss_conf += model.losses['conf']
                loss_cls += model.losses['cls']
                m_recall += model.losses['recall']
            # print("train_img", (datetime.datetime.now() - last_time).microseconds)
            # last_time = datetime.datetime.now()
            loss.backward()
            optimizer.step()
            # print("backw_img", (datetime.datetime.now() - last_time).microseconds)
            # last_time = datetime.datetime.now()
            # LR heuristic: at the first batch of each epoch (after epoch 0), decay the
            # learning rate if the loss exceeds the previous epoch's accumulated loss
            # divided by batch_size, then restart the accumulator
            if epoch > 0 and batch_i == 0:
                if loss.item() > mean_loss / batch_size:
                    print("mean_loss", mean_loss)
                    adjust_learning_rate(optimizer)
                mean_loss = loss.item()
            else:
                mean_loss += loss.item()
            # info = {'loss': loss.item(), 'cls': model.losses['cls'], 'conf': model.losses['conf']}
            #
            # for tag, value in info.items():
            #     logger.scalar_summary(tag, value, epoch)

            now = datetime.datetime.now()
            strftime = now.strftime("%H:%M:%S")
            print('%s [Epoch %d/%d, Batch %d/%d Losses: x %f, y %f, w %f, h %f, conf %f, cls %f, total %f, recall: %.5f]' %
                  (strftime, epoch, opt.epochs, batch_i, len(dataloader),
                   loss_x, loss_y, loss_w, loss_h, loss_conf, loss_cls,
                   loss.item(), m_recall/sub_batch))

        if (epoch % opt.checkpoint_interval == 0 and model.losses['recall']>0.9) or model.losses['recall']>0.98:
            model.save_weights('%s/%d.weights' % (opt.checkpoint_dir, epoch))
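
The fix is the standard gradient-accumulation pattern: zero the gradients once before the subdivision loop, call backward() on each sub-loss inside the loop so its graph is freed immediately, keep only a detached float for logging, and step once after the loop. A minimal sketch, assuming (as in the script above) that model(imgs, targets) returns a scalar loss tensor:

optimizer.zero_grad()                  # clear gradients once per full batch
loss_value = 0.0                       # plain Python float, holds no graph
for i in range(subdivisions):
    sub_imgs = imgs[i*sub_batch: (i+1)*sub_batch]
    sub_targets = targets[i*sub_batch: (i+1)*sub_batch]
    sub_loss = model(sub_imgs, sub_targets)
    sub_loss.backward()                # gradients accumulate in param.grad;
                                       # this sub-batch's graph is freed here
    loss_value += sub_loss.item()      # .item() detaches, nothing is retained
optimizer.step()                       # one update for the whole batch

Peak GPU memory is then bounded by a single sub-batch instead of the whole batch. If the model's loss is a per-sub-batch mean, dividing sub_loss by subdivisions before backward() keeps the update equivalent to the original large-batch gradient.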

Reposted from blog.csdn.net/jacke121/article/details/80782927