pytorch基本

pytorch主要分为以下几个模块来训练模型:

    tensor:tensor为基本结构,可以直接创建,从list创建以及由numpy数组得到,torch还提供一套运算以及shape变换方式。
    Variable:自动求导机制,利用Variable包装tensor后,便可以使用其求导的功能了,有点像个装饰器。
    nn:nn模块是整个pytorch的核心,自己设计的Net(),继承nn.Model后可以提取模型参数,进行前向forward()运算(自己设计),以及后向运算(自动),nn提供基本网络结构单元,例如nn.Linear(),nn.Conv2d()等,还提供基本损失函数nn.CrossEntropyLoss等。
    torch.optim:该模块提供自动求导更新参数等功能,用它封装模型参数nn.parameter()后,loss求导后,可以用.step来更新整个参数。
    torch.utils.data.DataSet:该模块提供加载数据初始化的方式,完善好getitem和len的接口后,便可以利用DataLoader多进程加载数据。

参考:

https://blog.csdn.net/qq_16949707/article/details/79067474

猜你喜欢

转载自blog.csdn.net/weixin_39752599/article/details/84726957