目录
8、加载数据集
windows需要在下面的代码前面加上
if __name__ == '__main__':
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""
============================================
时间:2021.8.
作者:手可摘星辰不去高声语
文件名:.py
功能:
1、Ctrl + Enter 在下方新建行但不移动光标;
2、Shift + Enter 在下方新建行并移到新行行首;
3、Shift + Enter 任意位置换行
4、Ctrl + D 向下复制当前行
5、Ctrl + Y 删除当前行
6、Ctrl + Shift + V 打开剪切板
7、Ctrl + / 注释(取消注释)选择的行;
8、Ctrl + E 可打开最近访问过的文件
9、Double Shift + / 万能搜索
============================================
"""
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class DiabetesDataset(Dataset):
def __init__(self, filepath):
self.xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
self.len = self.xy.shape[0]
self.x_data = torch.from_numpy(self.xy[:, :-1])
self.y_data = torch.from_numpy(self.xy[:, [-1]])
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8, 6)
self.linear2 = torch.nn.Linear(6, 4)
self.linear3 = torch.nn.Linear(4, 1)
self.sigmoid = torch.nn.Sigmoid()
self.activate = torch.nn.Tanh()
def forward(self, x):
x = self.activate(self.linear1(x))
x = self.activate(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
dataset = DiabetesDataset("E:/AI学习/PyTouch/刘普洪:《PyTorch深度学习实践》完结合集/Code/DataSet/diabetes.csv")
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=1)
model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimzer = torch.optim.SGD(model.parameters(), lr=0.1)
if __name__ == '__main__':
for epoch in range(100):
for i, (x, y) in enumerate(train_loader, 0):
y_pred = model(x)
loss = criterion(y_pred, y)
print("Epoch:{}\tBCE:{}".format(epoch, loss))
optimzer.zero_grad()
loss.backward()
optimzer.step()
# y_t = model(x_test_data)
# y = y_t.data.numpy()
# result = np.where(y > 0.5, 1, 0) # 满足大于0.5的值保留,不满足的设为0
# result = (result - y_test_data.numpy()) # 同真实值相减,如果一致,则为零,寻找0的个数即为测试正确的个数
# result = np.count_nonzero(result < 0.1) # 寻找矩阵中小于0.1的数的个数,就是寻找0个个数
# result = result / 59
# print("Test Result:", result)