PyTorch out of memory 解决方案

如果模型在运行了一些时间后出现的outofmemory,那么有可能是因为无用的临时变量太多了,我们需要使用torch.cuda.empty_cache()进行清理就可以了。

try:
    loss, outputs = model(src, lengths, dec, targets)
except RuntimeError as e:
    if 'out of memory' in str(e):
        print('| WARNING: ran out of memory')
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()
    else:
        raise e

如果这一批数据中有一个很长的句子容易导致oom那么就跳过这个句子

except RuntimeError as e:
    if 'out of memory' in str(e):
        print('| WARNING: ran out of memory, skipping batch')
        ooms += 1
        self.zero_grad()
    else:
        raise e

猜你喜欢

转载自blog.csdn.net/Dooonald/article/details/89875620