1. Proceso de implementación
1. Preparar datos
La diferencia con el método de PyTorch para implementar la regresión logística para la entrada de características multidimensionales es que este documento usa el método DataLoader y hereda la clase abstracta DataSet para implementar la optimización de descenso de gradiente mini_batch en el conjunto de datos. el código se muestra a continuación:
import torch
import numpy as np
from torch.utils.data import Dataset,DataLoader
class DiabetesDataSet(Dataset):
def __init__(self, filepath):
xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[:,:-1])
self.y_data = torch.from_numpy(xy[:,[-1]])
def __getitem__(self, index):
return self.x_data[index],self.y_data[index]
def __len__(self):
return self.len
dataset = DiabetesDataSet('G:/datasets/diabetes/diabetes.csv')
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)
2. Modelo de diseño
class Model(torch.nn.Module):
def __init__(self):
super(Model,self).__init__()
self.linear1 = torch.nn.Linear(8,6)
self.linear2 = torch.nn.Linear(6,4)
self.linear3 = torch.nn.Linear(4,1)
self.activate = torch.nn.Sigmoid()
def forward(self, x):
x = self.activate(self.linear1(x))
x = self.activate(self.linear2(x))
x = self.activate(self.linear3(x))
return x
model = Model()
3. Construya la función de pérdida y el optimizador
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
4. Proceso de formación
Cada vez que saca muestras de mini_batch para entrenamiento, el código es el siguiente:
epoch_list = []
loss_list = []
for epoch in range(100):
count = 0
loss1 = 0
for i, data in enumerate(train_loader,0):
# 1.Prepare data
inputs, labels = data
# 2.Forward
y_pred = model(inputs)
loss = criterion(y_pred,labels)
print(epoch,i,loss.item())
count += 1
loss1 += loss.item()
# 3.Backward
optimizer.zero_grad()
loss.backward()
# 4.Update
optimizer.step()
epoch_list.append(epoch)
loss_list.append(loss1/count)
5. Visualización de resultados
plt.plot(epoch_list,loss_list,'b')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.grid()
plt.show()