Pytorch重要代码段(持续更新)

1.构建简单网络

方法1:构建一个类

我们自己构建的网络,都要继承自torch的nn.Module模块,自己定义的类里,一定会有__init__ 方法和 forward 方法,前者相当于确定我们有哪些积木,后者决定这些积木怎么搭起来。此处是一个最简单的单隐层神经网络。

class Net(nn.Module):
    def __init__(self,n_feature,n_hidden,n_output):
        super(Net, self).__init__()
        self.hidden=nn.Linear(n_feature,n_hidden)
        self.output=nn.Linear(n_hidden,n_output)

    def forward(self,x):
        x=nn.ReLU(self.hidden(x))
        x=self.output(x)
        return x

net=Net(2,10,2)
print(net)

输出结果:

Net(
  (hidden): Linear(in_features=2, out_features=10, bias=True)
  (output): Linear(in_features=10, out_features=2, bias=True)
)

方法2:利用nn.Sequential

使用nn.Sequential依次列出网络模块和激活函数,是一种更加快速简便的方法。对比可以看出,此处没有hidden,output等我们命名的名称,而是用数字来标识不同的层。注意辨析此处的nn.ReLU 和上文使用的nn.functional.relu .

net=nn.Sequential(
    nn.Linear(3,10),
    nn.ReLU(),
    nn.Linear(10,2)
)
print(net)

输出结果:

Sequential(
  (0): Linear(in_features=3, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=2, bias=True)
)

2.数据装载器

Step1:定义自己的DataSet

首先需要调用相关模块from torch.utils.data import Dataset,DataLoader,然后创建自己的dataset类,一般都要重写__init__,__ len__,__ getitem __ 这几个函数,len是返回数据量大小,getitem是按索引取数,具体的写法要根据自己数据的形式来。创建对象的时候把我们的训练数据装进去即可。

class dataset(Dataset):        # 继承自Dataset类
    def __init__(self,train_data_x,train_data_y):
        self.x=train_data_x
        self.y=train_data_y

    def __len__(self):
        pass

    def __getitem__(self, item):
        pass

Step2:Dataloader

使用dataloader时需要先创建一个Dataloader对象,并传入相应的参数:

  • 第一个参数即我们刚才定义的dataset对象
  • 第二,确定batch_size大小
  • shuffle表示是否要将数据打乱,默认为False,即不打乱
dataloader=DataLoader(dataset,batch_size=5,shuffle=True)

然后利用循环跑里面的数据即可。

3.优化器

定义优化器和损失函数

optim = torch.optim.AdamW(model.parameters())  # 选择优化器
func_loss = nn.CrossEntropyLoss()              # 选择损失函数

在迭代中,进行:

loss = loss_func(y_pred,y_true)                # 计算损失函数
optimizer.zero_grad()                          # 梯度清零
loss.backward()                                # 误差反向传播
optimizer.step()                               # 参数更新

猜你喜欢

转载自blog.csdn.net/codelady_g/article/details/127683055