import csv
import numpy as np
import matplotlib.pyplot as plt
def readData(filename, encoding=None):
    """
    Load the two feature columns of a CSV data set, split by class label.

    The first row is treated as a header and skipped. In every following
    row, column 9 is the class label ('是' = yes, '否' = no), columns 7 and 8
    are the two features (stored as density and sugar), and column 10 is the
    numeric label. Rows whose label is neither of the two values are ignored,
    and so are completely empty rows (e.g. a trailing blank line).

    :param filename: path of the CSV data set
    :param encoding: text encoding handed to open(); None keeps the platform
        default, which is what the function used before this parameter existed
    :return: tuple (X1, X2, y1, y2, density, sugar) where
        X1, X2: lists of [density, sugar] pairs for the '是' / '否' rows
        y1, y2: lists of single-element [label] lists for the '是' / '否' rows
        density, sugar: flat lists of the feature values of all kept rows,
        in file order
    :raises IndexError: if a non-empty row has fewer than 11 columns
    :raises ValueError: if a feature or label cell is not a valid float
    """
    X1, X2, y1, y2 = [], [], [], []
    density, sugar = [], []
    # Each recognised label feeds one (features, targets) pair of lists.
    buckets = {'是': (X1, y1), '否': (X2, y2)}

    # newline='' lets the csv module do its own line-ending handling.
    with open(filename, encoding=encoding, newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; an empty file is not an error
        for line in reader:
            if not line:
                # csv.reader yields [] for blank lines; nothing to parse.
                continue
            bucket = buckets.get(line[9])
            if bucket is None:
                continue
            features, targets = bucket
            d, s = float(line[7]), float(line[8])
            features.append([d, s])
            targets.append([float(line[10])])
            density.append(d)
            sugar.append(s)
    return X1, X2, y1, y2, density, sugar
def LDA(X1, X2):
    """
    Two-class linear discriminant analysis.

    Builds the within-class scatter matrix Sw from both sample sets and
    returns the vector omega that solves Sw @ omega = (mean1 - mean2).T.

    :param X1: array-like with shape [n1, d], samples of the first class
    :param X2: array-like with shape [n2, d], samples of the second class
    :return: omega: np.array with shape [d, 1], the LDA projection vector
    :raises numpy.linalg.LinAlgError: if Sw is singular
    """
    X1 = np.asarray(X1, dtype=float)
    X2 = np.asarray(X2, dtype=float)

    mean1 = X1.mean(axis=0, keepdims=True)  # shape [1, d]
    mean2 = X2.mean(axis=0, keepdims=True)

    # Centre each class once and reuse the result for the scatter matrix.
    centered1 = X1 - mean1
    centered2 = X2 - mean2
    Sw = centered1.T.dot(centered1) + centered2.T.dot(centered2)  # shape [d, d]

    # Solve the linear system directly instead of forming inv(Sw) explicitly:
    # same result, fewer operations and better numerical behaviour.
    omega = np.linalg.solve(Sw, (mean1 - mean2).T)  # shape [d, 1]
    return omega
if __name__ == '__main__':
    import sys

    # Data set location: the first command-line argument when given,
    # otherwise the original hard-coded default path.
    dataset = sys.argv[1] if len(sys.argv) > 1 else "C:\\Users\\14399\\Desktop\\西瓜3.0.csv"
    X1, X2, y1, y2, density, sugar = readData(dataset)

    X1 = np.array(X1)
    X2 = np.array(X2)

    # Visualisation: plot each class from its own sample array rather than
    # slicing the combined lists at a fixed index, so the colours stay
    # correct whatever the row order or class sizes in the file are.
    plt.plot(X1[:, 0], X1[:, 1], 'r+')
    plt.plot(X2[:, 0], X2[:, 1], 'bo')

    # LDA
    omega = LDA(X1, X2)

    # Draw the straight line through the origin that satisfies
    # omega[0] * x + omega[1] * y = 0, for x from 0 to x_max.
    x_max = 0.9
    lda_left = 0
    lda_right = -(omega[0, 0] * x_max) / omega[1, 0]
    plt.plot([0, x_max], [lda_left, lda_right], 'g')

    plt.xlabel('density')
    plt.ylabel('sugar')
    plt.title('LDA')
    plt.show()
结果:
西瓜3.0数据集链接:https://pan.baidu.com/s/1RXTUG9gP1Jn9HKFCiEzOlA 密码:3h6n
参考资料:https://blog.csdn.net/victoriaw/article/details/77989610