# Getting started with PyTorch: learning to load datasets
import torch
import numpy as np
import torchvision
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
class DiabetesDataset(Dataset):
    """Map-style dataset over a comma-separated table of float32 values.

    Every column except the last is used as the feature vector; the last
    column is the target label.

    Args:
        filepath: Path to the CSV file. ``np.loadtxt`` decompresses files
            ending in ``.gz`` transparently.
    """

    def __init__(self, filepath):
        # Use the caller-supplied path (it was previously ignored in favour of
        # a hard-coded 'diabetes.csv.gz'). ndmin=2 keeps a one-row file 2-D so
        # the column slicing below still works.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32, ndmin=2)
        self.len = xy.shape[0]
        # All columns except the last one are features.
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Indexing with the list [-1] keeps the label as an (N, 1) column,
        # matching the (batch, 1) output of a 1-unit final Linear layer.
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return self.len
# Load 'diabetes.csv.gz' and wrap it in a loader that yields shuffled
# mini-batches of 32 samples using two worker processes.
# (The name `train_loder` is kept as-is because the training loop uses it.)
dataset = DiabetesDataset('diabetes.csv.gz')
train_loder = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
class Model(torch.nn.Module):
    """Small fully connected binary classifier with layer sizes 8 -> 6 -> 4 -> 1.

    The hidden layers use ReLU. The output layer uses a sigmoid, so every
    prediction lies in [0, 1] and can be passed directly to ``BCELoss``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        # ReLU activation for the hidden layers (torch.nn.Sigmoid() can be
        # swapped in as an alternative).
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map input of shape (batch, 8) to sigmoid outputs of shape (batch, 1)."""
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        # Combine ReLU hidden layers with a final Sigmoid layer.
        x = self.sigmoid(self.linear3(x))
        # This return was previously swallowed into the comment above, which
        # made forward() return None.
        return x
model = Model()
# Binary cross-entropy averaged over all elements of the batch.
# reduction='mean' replaces the deprecated size_average=True: the behavior is
# the same, without the deprecation UserWarning.
criterion = torch.nn.BCELoss(reduction='mean')
# Plain stochastic gradient descent over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

if __name__ == '__main__':
    # The loop must be behind this guard: the DataLoader uses worker
    # processes (num_workers > 0), and on platforms that start them with
    # "spawn" (e.g. Windows) each worker re-imports this script.
    for epoch in range(100):
        # 1. Prepare data: each batch is an (inputs, labels) pair.
        for inputs, labels in train_loder:
            # 2. Forward pass.
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, loss.item())
            # 3. Backward pass: clear stale gradients, then backpropagate.
            optimizer.zero_grad()
            loss.backward()
            # 4. Parameter update.
            optimizer.step()
# Sample output from one run. Only the last epochs are shown; each line is
# "epoch loss", printed once per mini-batch. This run used
# BCELoss(size_average=True), which produced the deprecation UserWarning lines.
# 96 0.6488460302352905
# 96 0.6837149858474731
# 96 0.6773790717124939
# 96 0.5895138382911682
# 96 0.6454289555549622
# 96 0.6540895700454712
# 96 0.634219229221344
# 96 0.5071176886558533
# 96 0.6521698832511902
# 96 0.5890663862228394
# 96 0.6260353326797485
# 96 0.6036486625671387
# 96 0.6821258068084717
# 96 0.6444405913352966
# 96 0.7117087244987488
# 96 0.6727374792098999
# 96 0.5942217707633972
# 96 0.572909951210022
# 96 0.5683532953262329
# 96 0.5906638503074646
# 96 0.5729745626449585
# 96 0.6592338681221008
# 96 0.5666840672492981
# 96 0.5628921389579773
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# 97 0.5504413843154907
# 97 0.5592724680900574
# 97 0.6298409700393677
# 97 0.5115721821784973
# 97 0.5389221906661987
# 97 0.5870144963264465
# 97 0.5902682542800903
# 97 0.6780911684036255
# 97 0.6389861702919006
# 97 0.7095761895179749
# 97 0.6540379524230957
# 97 0.6255027055740356
# 97 0.6726484298706055
# 97 0.6435151100158691
# 97 0.6405020952224731
# 97 0.6269117593765259
# 97 0.6627392172813416
# 97 0.6168067455291748
# 97 0.6124289035797119
# 97 0.713973879814148
# 97 0.5994608402252197
# 97 0.5881710648536682
# 97 0.6652423143386841
# 97 0.5875065922737122
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# 98 0.6116299033164978
# 98 0.6460859179496765
# 98 0.6689874529838562
# 98 0.5503941774368286
# 98 0.6141761541366577
# 98 0.5878812074661255
# 98 0.692798376083374
# 98 0.5875721573829651
# 98 0.5414928793907166
# 98 0.6070823073387146
# 98 0.6634365916252136
# 98 0.6647754311561584
# 98 0.5649587512016296
# 98 0.5715025663375854
# 98 0.6386815905570984
# 98 0.5531622767448425
# 98 0.6154085993766785
# 98 0.6604638695716858
# 98 0.6641201972961426
# 98 0.5941271781921387
# 98 0.6185081005096436
# 98 0.6462123990058899
# 98 0.6671874523162842
# 98 0.685947060585022
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# D:\Anaconda3\envs\pcd\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
#   warnings.warn(warning.format(ret))
# 99 0.5591648817062378
# 99 0.7278904914855957
# 99 0.5913693904876709
# 99 0.5824277400970459
# 99 0.6379620432853699
# 99 0.6897075176239014
# 99 0.5347273349761963
# 99 0.6482306718826294
# 99 0.7141227722167969
# 99 0.5258570313453674
# 99 0.6128413081169128
# 99 0.6075040698051453
# 99 0.521778404712677
# 99 0.6718929409980774
# 99 0.661381721496582
# 99 0.6632771492004395
# 99 0.6090944409370422
# 99 0.5503813624382019
# 99 0.6275145411491394
# 99 0.698398232460022
# 99 0.6591424345970154
# 99 0.5420821309089661
# 99 0.6173321604728699
# 99 0.634967565536499
# Process finished with exit code 0