《TensorFlow机器学习项目实战》人工数据集的最近邻聚类（k-NN）——注：下面的代码实际做的是最近邻（k=1）分类，而不是聚类

import tensorflow as tf
import numpy as np
import time

import matplotlib
import matplotlib.pyplot as plt

from sklearn.datasets.samples_generator import make_circles

# Nearest-neighbour classification (k = 1) of two noisy concentric circles.
# The first 70% of the generated samples form the labelled training set; each
# remaining point gets the label of its closest training point (squared
# Euclidean distance), computed with a TensorFlow 1.x graph.

N = 210  # total number of generated samples
K = 2  # number of classes produced by make_circles; the classifier itself uses 1 neighbour
# Maximum number of iterations, if the conditions are not met
# NOTE(review): MAX_ITERS is not used anywhere in this script.
MAX_ITERS = 1000
cut = int(N * 0.7)  # split index: first 70% train, remaining 30% test

start = time.time()

data, features = make_circles(n_samples=N, shuffle=True, noise=0.12, factor=0.4)
tr_data, tr_features = data[:cut], features[:cut]
te_data, te_features = data[cut:], features[cut:]

# Training points coloured by their true label.
fig, ax = plt.subplots()
ax.scatter(tr_data[:, 0], tr_data[:, 1], marker='o', s=100,
           c=tr_features, cmap=plt.cm.coolwarm)
ax.set_title("Training data (true labels)")

# Build the graph ONCE for all test points. Creating fresh ops inside a Python
# loop (one set per test point) keeps growing the default graph and slows every
# later iteration. Broadcasting (n_test, 1, d) - (1, n_train, d) yields all
# pairwise differences in a single op.
diff = tf.subtract(tf.expand_dims(te_data, 1), tf.expand_dims(tr_data, 0))
distances = tf.reduce_sum(tf.square(diff), axis=2)  # squared Euclidean, (n_test, n_train)
neighbor = tf.argmin(distances, axis=1)  # index of the closest training point per test point

# The context manager closes the session and releases its resources.
with tf.Session() as sess:
    nearest = sess.run(neighbor)

# Predicted label of each test point = label of its nearest training point.
test = tr_features[nearest].tolist()
print(test)

fig, ax = plt.subplots()
ax.scatter(te_data[:, 0], te_data[:, 1], marker='o', s=100,
           c=test, cmap=plt.cm.coolwarm)
ax.set_title("Test data (1-NN predicted labels)")

end = time.time()
print("Found in %.2f seconds" % (end-start))
print("Cluster assignments:", test)
# te_features holds the true test labels, so prediction quality can be reported.
print("Test accuracy: %.2f%%" % (100.0 * np.mean(np.asarray(test) == te_features)))
plt.show()

[原文此处的图片已缺失]

[原文此处的图片已缺失]

猜你喜欢

转载自blog.csdn.net/lly1122334/article/details/87387475