Fully Connected Neural Networks in Practice

Fully Connected Neural Networks in Practice

Based on books and blog posts; the goal is to understand what every line of the code does.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, losses

def load_dataset():
    """Download the UCI Auto MPG data file and load it into a DataFrame.

    The file is fetched through ``keras.utils.get_file`` and parsed with
    ``pandas.read_csv``:

    * fields are space separated, and repeated spaces are skipped;
    * ``"?"`` entries are read as NaN (missing values);
    * anything after a tab character on a line is ignored (presumably the
      car-name field of the raw file -- verify against the data source).

    Returns:
        pandas.DataFrame with the columns MPG, Cylinders, Displacement,
        Horsepower, Weight, Acceleration, Model Year and Origin.
    """
    dataset_path = keras.utils.get_file(
        'auto-mpg.data',
        'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data',
    )
    # Column labels for the raw file ("Horsepower" was previously misspelled).
    column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                    'Acceleration', 'Model Year', 'Origin']
    # read_csv already returns a fresh frame, so no extra copy is needed.
    return pd.read_csv(dataset_path, names=column_names,
                       na_values="?", comment='\t',
                       sep=" ", skipinitialspace=True)
# Download (if needed) and load the raw Auto MPG data.
dataset = load_dataset()
# Peek at the first few rows (the result is only displayed in a notebook/REPL).
dataset.head()

def preprocess_dataset(dataset):
    """Clean the Auto MPG frame, one-hot encode ``Origin`` and split it.

    The caller's frame is left untouched; all work happens on a copy.

    Steps:
      1. drop every row that contains a missing value;
      2. remove the ``Origin`` column and replace it with three float
         indicator columns: ``USA`` (code 1), ``Europe`` (code 2) and
         ``Japan`` (code 3);
      3. draw a reproducible 80% sample (``random_state=0``) as the training
         split; every remaining row goes to the test split.

    Args:
        dataset: DataFrame containing an ``Origin`` column.

    Returns:
        ``(train_dataset, test_dataset)`` tuple of DataFrames.
    """
    cleaned = dataset.copy().dropna()
    origin_codes = cleaned.pop('Origin')
    # Turn each category code into a 0.0/1.0 indicator column.
    for code, region in ((1, 'USA'), (2, 'Europe'), (3, 'Japan')):
        cleaned[region] = (origin_codes == code) * 1.0
    train_part = cleaned.sample(frac=0.8, random_state=0)
    test_part = cleaned.drop(train_part.index)
    return train_part, test_part
# Drop missing rows, one-hot encode Origin and split 80/20 into train/test.
train_dataset,test_dataset = preprocess_dataset(dataset)
# Peek at both splits (only displayed in a notebook/REPL).
train_dataset.head()
test_dataset.head()

# Quick look at the data: pairwise plots of a few training columns, with
# kernel density estimates on the diagonal.
sns_plot = sns.pairplot(train_dataset[['Cylinders','Displacement','Weight','MPG']],diag_kind ='kde')
# Descriptive statistics of the training split. MPG is the target, so its
# column is removed; transposing gives one row per feature with columns such
# as 'mean' and 'std', which norm() reads later.
train_stats = train_dataset.describe()
train_stats.pop('MPG')
train_stats=train_stats.transpose()
# Bare expression: only displayed when run in a notebook/REPL.
train_stats

# Feature standardisation
def norm(x,train_stats):
    """Standardise ``x`` column by column using precomputed statistics.

    Args:
        x: DataFrame of raw feature values.
        train_stats: DataFrame indexed by feature name with ``mean`` and
            ``std`` columns (for example a transposed ``DataFrame.describe()``).

    Returns:
        ``(x - mean) / std``, aligned on the feature names.
    """
    centred = x - train_stats['mean']
    return centred / train_stats['std']

# Separate the MPG target from the features (pop removes it from each frame).
train_labels = train_dataset.pop('MPG')
test_labels = test_dataset.pop('MPG')
# Standardise both splits with the *training* statistics so no test-set
# information leaks into the preprocessing.
norm_train_dataset = norm(train_dataset,train_stats)
norm_test_dataset = norm(test_dataset,train_stats)
# Sanity check: feature and label row counts should match per split.
print(norm_train_dataset.shape,train_labels.shape)
print(norm_test_dataset.shape,test_labels.shape)

class Network(keras.Model):
    """Fully connected regressor: Dense(64, relu) -> Dense(64, relu) -> Dense(1)."""

    def __init__(self):
        # Initialise the keras.Model base class before assigning sub-layers.
        super().__init__()
        # Two hidden layers of 64 ReLU units each.
        self.fc1 = layers.Dense(64, activation='relu')
        self.fc2 = layers.Dense(64, activation='relu')
        # One output unit with no activation (a linear regression output).
        self.fc3 = layers.Dense(1)

    def call(self, inputs):
        """Pass ``inputs`` through fc1, fc2 and fc3 in order."""
        hidden = self.fc2(self.fc1(inputs))
        return self.fc3(hidden)

def build_model(input_dim=9):
    """Create a ``Network``, build its weights, print a summary and return it.

    Args:
        input_dim: number of features per sample. The default of 9 matches
            this script's preprocessing: 6 numeric columns remain after MPG
            and Origin are removed, plus the three Origin indicator columns.

    Returns:
        The built ``Network`` instance.
    """
    model = Network()
    # Build explicitly so summary() can report the layers' weights. The batch
    # dimension is left as None instead of a fixed size (previously 4, while
    # training feeds batches of 32), so the model is not tied to one batch size.
    model.build(input_shape=(None, input_dim))
    model.summary()
    return model

model = build_model()
# RMSprop with a learning rate of 0.001.
optimizer = tf.keras.optimizers.RMSprop(0.001)
# Pair each normalised feature row with its MPG label.
train_db = tf.data.Dataset.from_tensor_slices((norm_train_dataset.values,train_labels.values))
# A shuffle buffer as large as the training set gives a full uniform shuffle
# each epoch; a buffer of 100 only shuffled within a sliding window.
train_db = train_db.shuffle(len(norm_train_dataset)).batch(32)

def train(model, train_db, optimizer, norm_test_data, test_labels, epochs=200):
    """Train ``model`` on MSE loss and record per-epoch train/test MAE.

    Args:
        model: Keras model mapping a feature batch to predictions; assumed to
            return shape ``(batch, 1)`` like ``Network`` -- TODO confirm for
            other models.
        train_db: ``tf.data.Dataset`` yielding ``(features, labels)`` batches.
        optimizer: Keras optimizer used to apply the gradients.
        norm_test_data: normalised test features; its ``.values`` are fed to
            the model.
        test_labels: test targets, one per row of ``norm_test_data``.
        epochs: number of passes over ``train_db`` (default 200).

    Returns:
        ``(train_mae_losses, test_mae_losses)``, two lists of floats with one
        entry per epoch. The train entry is the mean of that epoch's batch
        MAEs (NaN if ``train_db`` is empty). The test entry is the MAE over
        the whole test set after that epoch.
    """
    train_mae_losses = []
    test_mae_losses = []
    for epoch in range(epochs):
        batch_maes = []
        for step, (x, y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                out = model(x)
                # Reshape the (batch,) labels to the (batch, 1) shape of the
                # predictions. Otherwise the subtraction inside MSE/MAE
                # broadcasts to (batch, batch) and compares every prediction
                # with every label, which drives the model towards the mean.
                y = tf.reshape(tf.cast(y, out.dtype), tf.shape(out))
                loss = tf.reduce_mean(losses.MSE(y, out))
            batch_maes.append(float(tf.reduce_mean(losses.MAE(y, out))))
            if step % 10 == 0:
                print(epoch, step, float(loss))
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_mae_losses.append(
            sum(batch_maes) / len(batch_maes) if batch_maes else float('nan'))
        # Evaluate on the test data passed in as an argument, not a global.
        test_out = model(tf.constant(norm_test_data.values))
        test_y = tf.reshape(tf.cast(test_labels, test_out.dtype), tf.shape(test_out))
        test_mae_losses.append(float(tf.reduce_mean(losses.MAE(test_y, test_out))))
    return train_mae_losses, test_mae_losses


def plot(train_mae_losses, test_mae_losses):
    """Draw the train and test MAE curves on a new figure and show it.

    Args:
        train_mae_losses: sequence of training MAE values, one per epoch.
        test_mae_losses: sequence of test MAE values, one per epoch.
    """
    plt.figure()
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.plot(train_mae_losses, label='Train')
    plt.plot(test_mae_losses, label='test')
    # A single legend call covers both labelled lines.
    plt.legend()
    plt.show()

# Run the training loop, then plot the per-epoch train/test MAE curves.
train_mae_losses,test_mae_losses = train(model,train_db,optimizer,norm_test_dataset,test_labels)
plot(train_mae_losses, test_mae_losses)

[Figure omitted: image not preserved in this copy]
[Figure omitted: image not preserved in this copy]

You may also like

Source: blog.csdn.net/qq_42830971/article/details/112532616