import torch
import numpy as np
import torchvision
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
class DiabetesDataset(Dataset):
    """Map-style dataset backed by a comma-separated file of numbers.

    The whole file is loaded eagerly with ``np.loadtxt`` as float32. Every
    column except the last becomes the feature tensor ``x_data`` with shape
    (N, C-1). The last column becomes the target tensor ``y_data``, kept
    2-D with shape (N, 1) so it lines up with a single-output model.

    Args:
        filepath: Path of the file to load. ``np.loadtxt`` also accepts
            ``.gz`` paths.
    """

    def __init__(self, filepath):
        # Load the path the caller asked for; previously this was hard-coded
        # to 'diabetes.csv.gz' and the argument was silently ignored.
        # ndmin=2 keeps a single-row file 2-D so the column slicing below works.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32, ndmin=2)
        self.len = xy.shape[0]
        # All columns except the last are the input features.
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Indexing with [-1] (a list) keeps the target 2-D: shape (N, 1).
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the (features, target) pair for row ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return self.len
# Load the (presumably gzip-compressed) CSV once, then serve it in shuffled
# mini-batches of 32 rows using two DataLoader worker processes.
# NOTE(review): with num_workers > 0 on platforms that spawn workers
# (e.g. Windows), each worker re-imports this module and re-runs its
# module-level statements; confirm that is acceptable.
dataset = DiabetesDataset('diabetes.csv.gz')
train_loder = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
class Model(torch.nn.Module):
    """Three-layer fully connected binary classifier: 8 -> 6 -> 4 -> 1.

    The hidden layers use ReLU. The output layer uses Sigmoid, so each
    prediction lies in [0, 1], which is what ``torch.nn.BCELoss`` expects.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map a (batch, 8) float tensor to (batch, 1) probabilities."""
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        # Use ReLU for the hidden layers and Sigmoid on the final layer.
        x = self.sigmoid(self.linear3(x))
        # The return was previously lost inside the trailing comment, so
        # forward() returned None and the loss computation failed.
        return x
model = Model()
# reduction='mean' averages the per-element loss over the batch. It is the
# supported replacement for the deprecated size_average=True, which raised
# a UserWarning every time this line ran.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

if __name__ == '__main__':
    for epoch in range(100):
        for inputs, labels in train_loder:
            # 1. Prepare data: inputs and labels come straight from the loader.
            # 2. Forward pass and loss.
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, loss.item())
            # 3. Backward pass: clear stale gradients before accumulating new ones.
            optimizer.zero_grad()
            loss.backward()
            # 4. Update the parameters.
            optimizer.step()
# Sample console output from one run (last epochs); each line is "<epoch> <batch loss>":
# 96 0.6488460302352905
# 96 0.6837149858474731
# 96 0.6773790717124939
# 96 0.5895138382911682
# 96 0.6454289555549622
# 96 0.6540895700454712
# 96 0.634219229221344
# 96 0.5071176886558533
# 96 0.6521698832511902
# 96 0.5890663862228394
# 96 0.6260353326797485
# 96 0.6036486625671387
# 96 0.6821258068084717
# 96 0.6444405913352966
# 96 0.7117087244987488
# 96 0.6727374792098999
# 96 0.5942217707633972
# 96 0.572909951210022
# 96 0.5683532953262329
# 96 0.5906638503074646
# 96 0.5729745626449585
# 96 0.6592338681221008
# 96 0.5666840672492981
# 96 0.5628921389579773
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# 97 0.5504413843154907
# 97 0.5592724680900574
# 97 0.6298409700393677
# 97 0.5115721821784973
# 97 0.5389221906661987
# 97 0.5870144963264465
# 97 0.5902682542800903
# 97 0.6780911684036255
# 97 0.6389861702919006
# 97 0.7095761895179749
# 97 0.6540379524230957
# 97 0.6255027055740356
# 97 0.6726484298706055
# 97 0.6435151100158691
# 97 0.6405020952224731
# 97 0.6269117593765259
# 97 0.6627392172813416
# 97 0.6168067455291748
# 97 0.6124289035797119
# 97 0.713973879814148
# 97 0.5994608402252197
# 97 0.5881710648536682
# 97 0.6652423143386841
# 97 0.5875065922737122
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# 98 0.6116299033164978
# 98 0.6460859179496765
# 98 0.6689874529838562
# 98 0.5503941774368286
# 98 0.6141761541366577
# 98 0.5878812074661255
# 98 0.692798376083374
# 98 0.5875721573829651
# 98 0.5414928793907166
# 98 0.6070823073387146
# 98 0.6634365916252136
# 98 0.6647754311561584
# 98 0.5649587512016296
# 98 0.5715025663375854
# 98 0.6386815905570984
# 98 0.5531622767448425
# 98 0.6154085993766785
# 98 0.6604638695716858
# 98 0.6641201972961426
# 98 0.5941271781921387
# 98 0.6185081005096436
# 98 0.6462123990058899
# 98 0.6671874523162842
# 98 0.685947060585022
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# 99 0.5591648817062378
# 99 0.7278904914855957
# 99 0.5913693904876709
# 99 0.5824277400970459
# 99 0.6379620432853699
# 99 0.6897075176239014
# 99 0.5347273349761963
# 99 0.6482306718826294
# 99 0.7141227722167969
# 99 0.5258570313453674
# 99 0.6128413081169128
# 99 0.6075040698051453
# 99 0.521778404712677
# 99 0.6718929409980774
# 99 0.661381721496582
# 99 0.6632771492004395
# 99 0.6090944409370422
# 99 0.5503813624382019
# 99 0.6275145411491394
# 99 0.698398232460022
# 99 0.6591424345970154
# 99 0.5420821309089661
# 99 0.6173321604728699
# 99 0.634967565536499
# Process finished with exit code 0