【PyTorch】快速理解 torch.manual_seed()

【PyTorch】快速理解 torch.manual_seed()

一、pytorch中的seed是什么?

在神经网络中,参数默认是进行随机初始化的。不同的初始化参数往往会导致模型的训练结果会存在一定的差异。当得到比较好的结果时我们通常希望这个结果是可以复现的,就需要保证每一次初始化的参数都不变,这就引入了随机种子。在PyTorch中,通过设置全局随机数种子可以实现这个目的。本文总结了PyTorch中固定随机种子的方法。

**随机种子Demo(快速理解)——以torch.manual_seed(seed)为例

1、首先不使用seed,直接生成随机数

import torch
print(torch.rand(1))  #tensor([0.0448])
print(torch.rand(1))  #tensor([0.6716])
print(torch.rand(1))  #tensor([0.7460])

再次重复执行一次

import torch
print(torch.rand(1))  #tensor([0.3370])
print(torch.rand(1))  #tensor([0.9321])
print(torch.rand(1))  #tensor([0.5633])

2、尝试使用seed,生成随机数

import torch
seed = 1024
torch.manual_seed(seed)
print(torch.rand(1))  #tensor([0.8090])
print(torch.rand(1))  #tensor([0.7935])
print(torch.rand(1))  #tensor([0.2099])

注意我们说的seed会生成相同的随机数是指无论重复多少次程序,上面的三个print本质上是相互独立的,而不是重复的。

重复执行一次

import torch
seed = 1024
torch.manual_seed(seed)
print(torch.rand(1))  #tensor([0.8090])
print(torch.rand(1))  #tensor([0.7935])
print(torch.rand(1))  #tensor([0.2099])

这时我们发现seed真的控制了随机数,使之无论重复多少次都可以得到与上次相同的参数。

二、类似函数的功能

#seed取值范围为: [-9223372036854775808, 18446744073709551615],超出则 RuntimeError 报错
seed = 1024  

#为CPU中设置种子,生成随机数:
torch.manual_seed(seed) 

#为特定GPU设置种子,生成随机数:
torch.cuda.manual_seed(seed)

#为所有GPU设置种子,生成随机数:
torch.cuda.manual_seed_all(seed)#如果使用多个GPU,应该使用torch.cuda.manual_seed_all()
                                #为所有的GPU设置种子

另外,在Pytorch官方文档中说明了在Pytorch的不同提交、不同版本和不同平台上,不能保证完全可重现的结果。此外,即使使用相同的种子,因为存在不同的CPU和GPU,结果也不能重现。为了保证最大成都复现,在运行任何程序之前写入下面代码**(可以放在主代码的开头)**

def seed_torch(seed=1024,cuda_deterministic=True):
    torch.manual_seed(seed)  #为CPU中设置种子
    torch.cuda.manual_seed_all(seed)   #为所有GPU设置种子
    np.random.seed(seed)      
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)    # 为了禁止hash随机化,使得实验可复现
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:  # faster, less reproducible
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True

(参考:https://blog.csdn.net/ytusdc/article/details/125529881

猜你喜欢

转载自blog.csdn.net/weixin_52527544/article/details/127206898