t-SNE 绘图 (CC2)

t-SNE

# coding: utf-8
import collections
import numpy as np
import os
import pickle
from sklearn.neighbors import NearestNeighbors
import numpy as np
from sklearn.manifold import TSNE
    # ....... (the start of this function is elided in the original post)
    # Append the four verification sets to the training data so every sample
    # is embedded by the same t-SNE run. Assumes X / y and each *_verify /
    # *_verify_labels value are plain lists built earlier (so `+` concatenates)
    # — TODO confirm against the full script.
    X = X + black_verify + white_verify + unknown_verify + bd_verify
    # Call-form print: the original `print expr` statement is a SyntaxError on
    # Python 3; with a single argument this form prints the same on Python 2.
    print(black_verify_labels + white_verify_labels + unknown_verify_labels + bd_verify_labels)
    y = y + black_verify_labels + white_verify_labels + unknown_verify_labels + bd_verify_labels
    print("ALL data check:")
    print("len of X:", len(X))
    print("len of y:", len(y))

    # Project all samples to 2-D. No random_state is passed, so the layout can
    # differ between runs unless a global NumPy seed is set elsewhere.
    X_embedded = TSNE(n_components=2).fit_transform(X)

    # Persist the embedding together with its labels; the plotting script
    # reads this same file name back.
    with open("tsne_data_X.pkl", "wb") as f:
        pickle.dump([X_embedded, y], f)
import pickle
from collections import Counter
import numpy as np
import matplotlib.pyplot as Plot

def main(data_path="tsne_data_X.pkl", out_path="tsne.pdf"):
    """Plot a saved 2-D t-SNE embedding, one colour per class.

    Loads ``[X_embedded, y]`` from *data_path* and prints a few diagnostics.
    Every point whose label equals 0-5 is drawn with the matching entry of
    ``titles`` / ``colors``. The figure is saved to *out_path* and then shown.
    Points with any other label appear in the printed ``Counter`` summary
    but are not drawn.

    Args:
        data_path: Pickle file holding ``[X_embedded, y]``.
        out_path: File the figure is saved to before it is displayed.

    Raises:
        ValueError: If the number of embedded points and labels differ.
        TypeError: If a label is a list (``Counter`` cannot hash it).
    """
    # Only load pickles produced by our own t-SNE step: pickle.load can run
    # arbitrary code. encoding='iso-8859-1' allows reading numpy arrays that
    # were pickled under Python 2.
    with open(data_path, "rb") as f:
        X_embedded, y = pickle.load(f, encoding='iso-8859-1')

    print(len(X_embedded))
    print(len(y))
    print(X_embedded[:3])
    print(y[:3])
    if len(X_embedded) != len(y):
        raise ValueError(
            "{} embedded points but {} labels".format(len(X_embedded), len(y)))

    # Reject list-valued labels up front and name the culprit; otherwise
    # Counter() below fails with an opaque "unhashable type" error.
    for i, label in enumerate(y):
        if isinstance(label, list):
            raise TypeError("label at index {} is a list: {!r}".format(i, label))
    print(Counter(y))

    Y, labels = np.array(X_embedded), np.array(y)
    titles = ("white", "black", "black_verify_labels", "white_verify_labels",
              "unknown_verify_labels", "bd_verify_labels")
    colors = ['b', 'c', 'y', 'm', 'r', 'g', 'peru']
    # Class k is drawn with titles[k] / colors[k]. zip stops at the shorter
    # sequence, so the class count follows len(titles) instead of a hard-coded
    # range (the trailing 'peru' colour is currently unused).
    for class_id, (title, color) in enumerate(zip(titles, colors)):
        # Element-wise comparison is kept per item so mixed-type label arrays
        # compare the same way as before.
        idx = [k for k, lab in enumerate(labels) if lab == class_id]
        Plot.scatter(Y[idx, 0], Y[idx, 1], 20, color=color, label=title)
    Plot.legend()
    Plot.savefig(out_path)
    Plot.show()
# Build the t-SNE scatter plot as soon as this script is executed. There is no
# `if __name__ == "__main__":` guard, so importing the module also runs it.
main()

 

猜你喜欢

转载自 https://www.cnblogs.com/bonelee/p/9116202.html