头歌--机器学习之感知机

目录

第1关:感知机 - 西瓜好坏自动识别

第2关:scikit-learn感知机实践 - 癌细胞精准识别


第1关:感知机 - 西瓜好坏自动识别

#encoding=utf8
import numpy as np
#构建感知机算法
class Perceptron(object):
    def __init__(self, learning_rate = 0.01, max_iter = 200):
        self.lr = learning_rate
        self.max_iter = max_iter
    def fit(self, data, label):
        '''
        input:data(ndarray):训练数据特征
              label(ndarray):训练数据标签
        output:w(ndarray):训练好的权重
               b(ndarry):训练好的偏置
        '''
        #编写感知机训练方法,w为权重,b为偏置
        self.w = np.array([1.]*data.shape[1])
        self.b = np.array([1.])
        #********* Begin *********#
        i = 0
        while i < self.max_iter:
            flag = True
            for j in range(len(label)):
                if label[j] * (np.inner(self.w, data[j]) + self.b) <= 0:
                    flag = False
                    self.w += self.lr * (label[j] * data[j])
                    self.b += self.lr * label[j]
            if flag: 
                break
            i+=1
        #********* End *********#
    def predict(self, data):
        '''
        input:data(ndarray):测试数据特征
        output:predict(ndarray):预测标签
        '''
        #********* Begin *********#
        y = np.inner(data, self.w) + self.b
        for i in range(len(y)):
            if y[i] >= 0:
                y[i] = 1
            else:
                y[i] = -1
        predict = y
        #********* End *********#
        return predict


第2关:scikit-learn感知机实践 - 癌细胞精准识别

#encoding=utf8
import os

if os.path.exists('./step2/result.csv'):
    os.remove('./step2/result.csv')

#********* Begin *********#
import pandas as pd
train_data = pd.read_csv('./step2/train_data.csv')
train_label = pd.read_csv('./step2/train_label.csv')
train_label = train_label['target']
test_data = pd.read_csv('./step2/test_data.csv')
from sklearn.linear_model.perceptron import Perceptron
clf = Perceptron(eta0 = 0.01,max_iter = 200)
clf.fit(train_data, train_label)
result = clf.predict(test_data)
frameResult = pd.DataFrame({'result':result})
frameResult.to_csv('./step2/result.csv', index = False)
#********* End *********#

猜你喜欢

转载自blog.csdn.net/m0_61059796/article/details/130369811