【写给自己】搭建cnn识别IQ路调制信号

qq_44880660的完善下,代码终于可以成功跑通。对训练函数进行了完善,之后可以套用这个模板。

def train(model, dataloader, itr, device):
    running_loss = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_func = nn.CrossEntropyLoss()
    model.train()
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_func(y_pred,y.argmax())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    ep_loss = running_loss / len(dataloader.dataset)
    print("Epoch: ", itr + 1,
          "Loss: ", round(ep_loss,3))
    return ep_loss

同时发现数据可以包装起来

train_dataset = Data.TensorDataset(x_train, y_train)
train_dataloader = DataLoader(dataset=train_dataset,batch_size=4,shuffle=True)

这样子可以批训练,训练会快很多

猜你喜欢

转载自blog.csdn.net/weixin_45121008/article/details/129188546
今日推荐