Very effective code snippets for experiments

1. Greatly speed up PyTorch training

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True  # let cuDNN auto-tune convolution algorithms; helps most when input sizes are fixed
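As a minimal usage sketch (the model and loader names MyModel and train_loader are hypothetical, not from the original snippet), the device object is then used to move the model and each batch onto the GPU:

# hypothetical usage: move the model and each batch to the chosen device
model = MyModel().to(device)
for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = model(inputs)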

2. Back up the original log file by adding a .bak suffix, to avoid overwriting it directly

# from co-teaching train code
import os
import datetime

txtfile = save_dir + "/" + model_str + "_%s.txt" % str(args.optimizer)  ## good job!
nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if os.path.exists(txtfile):
    os.system('mv %s %s' % (txtfile, txtfile + ".bak-%s" % nowTime))  # back up the old log file
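A note on the design: os.system('mv ...') only works where a Unix mv command is available. A portable alternative (my suggestion, not part of the original code) is to rename the file from Python directly:

import os

if os.path.exists(txtfile):
    os.replace(txtfile, txtfile + ".bak-%s" % nowTime)  # no shell needed; overwrites an existing destination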

3. accuracy() returns a list; when calling it, unpack the values directly instead of indexing into the list

# from the co-teaching code, but the MixMatch_pytorch code also has it
import torch.nn.functional as F

def accuracy(logit, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    output = F.softmax(logit, dim=1)
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape also handles non-contiguous tensors
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

# usage
prec1, = accuracy(logit, labels, topk=(1,))       # the trailing comma unpacks the single-element result
prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
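A side note on this function: softmax is monotonic, so applying it before topk does not change which classes rank highest; the returned accuracies would be identical if topk were applied to the raw logits.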

4. Use a Logger class to record the results of each epoch to a log file

# from the Pytorch_MixMatch code
import numpy as np
import matplotlib.pyplot as plt

class Logger(object):
    '''Save training process to log file with simple plot function.'''
    def __init__(self, fpath, title=None, resume=False): 
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume: 
                self.file = open(fpath, 'r') 
                name = self.file.readline()
                self.names = name.rstrip().split('\t')
                self.numbers = {}
                for _, name in enumerate(self.names):
                    self.numbers[name] = []

                for numbers in self.file:
                    numbers = numbers.rstrip().split('\t')
                    for i in range(0, len(numbers)):
                        self.numbers[self.names[i]].append(numbers[i])
                self.file.close()
                self.file = open(fpath, 'a')  
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        if self.resume: 
            pass
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for _, name in enumerate(self.names):
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()


    def append(self, numbers):
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            self.file.write("{0:.4f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):   
        names = self.names if names is None else names
        numbers = self.numbers
        for _, name in enumerate(names):
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        if self.file is not None:
            self.file.close()

# usage
logger = Logger(new_folder+'/log_for_%s_WebVision1M.txt'%data_type, title=title)
logger.set_names(['epoch', 'val_acc', 'val_acc_ImageNet'])
for epoch in range(100):
    logger.append([epoch, val_acc, val_acc_ImageNet])
logger.close()
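The plot method draws the recorded curves with matplotlib but does not display or save the figure itself; a small follow-up sketch (the output filename is a hypothetical example) could be:

logger.plot(['val_acc', 'val_acc_ImageNet'])   # draw the selected columns
plt.savefig(new_folder + '/log_curves.png')    # hypothetical output path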

5. Use argparse to turn the code into a command-line tool, so that different datasets, optimizers, and settings are handled by different arguments instead of many highly redundant copies of the code; see the sketch below.
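A minimal argparse sketch (the argument names and defaults here are illustrative assumptions, not from any particular repository):

import argparse

parser = argparse.ArgumentParser(description='Training configuration')
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset name')
parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'adam'])
parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
parser.add_argument('--epochs', type=int, default=100)
args = parser.parse_args()

# one script now covers many settings, e.g.:
#   python train.py --dataset cifar100 --optimizer adam --lr 0.001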

# to be continued

Source: www.cnblogs.com/Gelthin2017/p/12148011.html