t-SNE 降维样例代码
import numpy as np
from sklearn.manifold import TSNE
# For the UCI ML handwritten digits dataset
from sklearn.datasets import load_digits
# Import matplotlib for plotting graphs and seaborn for attractive graphics.
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import seaborn as sns
def plot(x, colors):
    """Scatter-plot a 2-D embedding coloured by class and label each class.

    Every point is drawn in the palette colour of its class, and the class
    id is written at the median position of that class's points.

    Args:
        x: Array indexed as ``x[:, 0]`` / ``x[:, 1]`` holding the two
            embedding coordinates of each sample.
        colors: NumPy array of class ids, one per row of ``x``. The ids are
            cast to ``int8`` and used as indices into a 10-colour palette.

    Returns:
        Tuple ``(figure, axes, texts)``; ``texts`` contains one text artist
        per class id in ``0..9`` that has at least one point.

    Side effects:
        Saves the figure to ``./digits_tsne-pastel.png`` at 120 dpi.
    """
    n_classes = 10

    # Other seaborn palettes ("husl", ...) can be substituted here; see
    # https://seaborn.pydata.org/generated/seaborn.color_palette.html
    palette = np.array(sns.color_palette("pastel", n_classes))

    # Create the scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    ax.scatter(x[:, 0], x[:, 1], lw=0, s=40,
               c=palette[colors.astype(np.int8)])

    # Add one label per class at the median position of its points.
    txts = []
    for i in range(n_classes):
        members = x[colors == i, :]
        if members.shape[0] == 0:
            # The median of an empty selection is NaN (NumPy also emits a
            # RuntimeWarning), which would put the label at a non-finite
            # position; skip classes that have no points.
            continue
        xtext, ytext = np.median(members, axis=0)
        txt = ax.text(xtext, ytext, str(i), fontsize=24)
        # White outline so the digit stays legible on top of the points.
        txt.set_path_effects(
            [pe.Stroke(linewidth=5, foreground="w"), pe.Normal()]
        )
        txts.append(txt)

    plt.savefig('./digits_tsne-pastel.png', dpi=120)
    return f, ax, txts
# Load the scikit-learn handwritten digits dataset and show the shape of
# its feature matrix.
digits = load_digits()
print(digits.data.shape)

# Regroup the samples so that rows are ordered by class id 0..9: build one
# boolean mask per class, then concatenate the selections in class order.
# X holds the feature rows, Y the matching target labels.
_class_masks = [digits.target == label for label in range(10)]
X = np.concatenate([digits.data[mask] for mask in _class_masks], axis=0)
Y = np.concatenate([digits.target[mask] for mask in _class_masks])

# Compute the t-SNE embedding of X. Parameters such as perplexity or
# random_state can be varied to obtain different plots.
digits_final = TSNE(perplexity=30).fit_transform(X)

plot(digits_final, Y)
def plot2(data, x='x', y='y'):
    """Draw a seaborn scatter of ``data`` coloured by its ``'Label'`` column.

    Args:
        data: Data passed to ``sns.lmplot``; must provide the columns named
            by ``x`` and ``y`` as well as a ``'Label'`` column used as hue.
        x: Name of the column plotted on the horizontal axis; also used as
            the x-axis label.
        y: Name of the column plotted on the vertical axis; also used as
            the y-axis label.

    Side effects:
        Changes the global seaborn context and style, and saves the current
        figure to ``./digits_tsne-plot2.png`` at 120 dpi.
    """
    sns.set_context("notebook", font_scale=1.1)
    sns.set_style("ticks")

    # Large, translucent markers; the regression fit is switched off so
    # only the scatter is drawn.
    marker_style = {"s": 200, "alpha": 0.3}
    sns.lmplot(
        data=data,
        x=x,
        y=y,
        hue='Label',
        height=9,
        fit_reg=False,
        legend=True,
        scatter_kws=marker_style,
    )

    title = plt.title('t-SNE Results: Digits', weight='bold')
    title.set_fontsize('14')

    # Both axis labels share the same bold weight and font size.
    for set_axis_label, text in ((plt.xlabel, x), (plt.ylabel, y)):
        set_axis_label(text, weight='bold').set_fontsize('10')

    plt.savefig('./digits_tsne-plot2.png', dpi=120)
import pandas as pd
# Gather the two embedding coordinates and the class labels into a
# DataFrame with the 'x', 'y' and 'Label' columns, then draw the second plot.
data = pd.DataFrame({
    'x': digits_final[:, 0],
    'y': digits_final[:, 1],
    'Label': Y,
})
plot2(data)
结果如下所示: