matplotlib----各种画图实现代码集锦

color_bar:

t-SNE  Embedding分布图:

#自己代码的版本
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import matplotlib.ticker as ticker
import pickle

from sklearn import datasets
from sklearn.manifold import TSNE


import os
# Allow a duplicate Intel OpenMP runtime to be loaded (presumably torch and an MKL-linked
# numpy/sklearn each bring one); without this the process may abort with OMP Error #15.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

dataset='retailrocket'  #-------------------------------------------------------------------------------------------------------------------
# dataset='Tmall'

# Behavior names per dataset. The t-SNE helpers read every behavior except the last from
# checkpoint['user_embeds'][i] and the last one from checkpoint['user_embed'].
BEHAVIORS_BY_DATASET = {
    'Tmall': ['pv', 'fav', 'cart', 'buy'],
    'IJCAI_15': ['click', 'fav', 'cart', 'buy'],
    'JD': ['review', 'browse', 'buy'],
    'retailrocket': ['view', 'cart', 'buy'],
}
try:
    behaviors = BEHAVIORS_BY_DATASET[dataset]
except KeyError:
    # Fail here with a clear message instead of a NameError on `behaviors` much later.
    raise ValueError(
        f"unknown dataset {dataset!r}; expected one of {sorted(BEHAVIORS_BY_DATASET)}"
    ) from None


# Color / marker palettes; entry i styles behavior i.
#'red', 'cyan', 'blue', 'green', 'black', 'magenta', 'pink', 'purple','chocolate', 'orange', 'steelblue', 'crimson', 'lightgreen', 'salmon','gold', 'darkred'
#'x', 's', '^', 'o', 'v', '^', '<', '>', 'D'
c = ['red', 'cyan', 'blue', 'black', 'green']
m = ['x', 's', '^', 'o', 'v', '^', '<', '>', 'D']

# #SSL Tmall
# Tmall_2D_NOSSL_NOBe
# loadPath = "/home/ww/Code/work3/T_Meta_SSL_MBGCN/Model/Tmall/for_t_SNE_2th_main_MB_GCN_multi_Tmall_2021_06_02__09_00_54_lr_0.001_reg_0.001_batch_size_1024_time_slot_31104000_gnn_layer_[16,16,16].pth"
# Tmall_2D_NOSSL
# loadPath_NOSSL = "/home/ww/Code/work3/T_Meta_SSL_MBGCN/Model/Tmall/for_t_SNE_2th_main_MB_GCN_multi_behavior_NOSSL_Tmall_2021_06_02__09_04_42_lr_0.001_reg_0.001_batch_size_1024_time_slot_31104000_gnn_layer_[16,16,16].pth"
# Tmall_2D
# loadPath_SSL = "/home/ww/Code/work3/T_Meta_SSL_MBGCN/Model/Tmall/for_t_SNE_3th_main_MB_GCN_multi_behavior_SSL_Tmall_2021_06_03__16_38_38_lr_0.001_reg_0.001_batch_size_1024_time_slot_31104000_gnn_layer_[16,16,16].pth"


#-------------------------------------------------------------------------------------------------------------------
#SSL retailrocket
# retailrocket_2D_NOSSL_NOBe
loadPath = "/home/ww/Code/work3/T_Meta_SSL_MBGCN/Model/retailrocket/for_t_SNE_2th_main_MB_GCN_multi_retailrocket_2021_06_02__09_01_43_lr_0.001_reg_0.001_batch_size_1024_time_slot_31104000_gnn_layer_[16,16,16].pth"
# retailrocket_2D_NOSSL
# loadPath_NOSSL = "/home/ww/Code/work3/T_Meta_SSL_MBGCN/Model/retailrocket/for_t_SNE_2th_main_MB_GCN_multi_behavior_NOSSL_retailrocket_2021_06_02__09_05_21_lr_0.001_reg_0.001_batch_size_1024_time_slot_31104000_gnn_layer_[16,16,16].pth"
# retailrocket_2D
# loadPath_SSL = "/home/ww/Code/work3/T_Meta_SSL_MBGCN/Model/retailrocket/for_t_SNE_3th_main_MB_GCN_multi_behavior_SSL_retailrocket_2021_06_03__16_39_44_lr_0.001_reg_0.001_batch_size_1024_time_slot_31104000_gnn_layer_[16,16,16].pth"


#-------------------------------------------------------------------------------------------------------------------
# Test split: a per-user list where entry u is that user's test item, or None if the user
# has no test interaction. Only users with a test item are embedded.
# NOTE: pickle can execute arbitrary code -- only load files you produced yourself.
# loadpath_tst = "/home/ww/Code/DATASET/Tmall/tst_int"
loadpath_tst = "/home/ww/Code/DATASET/work3_dataset/retailrocket/tst_int"
with open(loadpath_tst, 'rb') as f:
    test_data = pickle.load(f)
test_pairs = [(idx, item) for idx, item in enumerate(test_data) if item is not None]
test_user = np.array([idx for idx, _ in test_pairs])
test_item = np.array([item for _, item in test_pairs])

# test_user = np.random.choice(np.array([idx for idx, i in enumerate(test_data) if i is not None]), size=5000)
# test_item = np.random.choice(np.array([i for idx, i in enumerate(test_data) if i is not None]), size=5000)


# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine; every consumer
# of these tensors calls .cpu() anyway.
checkpoint = torch.load(loadPath, map_location='cpu')
# checkpoint_NOSSL = torch.load(loadPath_NOSSL)
# checkpoint_SSL = torch.load(loadPath_SSL)
# model = checkpoint['model']
# model_NOSSL = checkpoint_NOSSL['model']
# model_SSL = checkpoint_SSL['model']
# params = model.state_dict()


def get_data(embedding, beh_index, test_index):
    """Select the given rows of one behavior's embedding table and label them with the behavior.

    :param embedding: 2-D torch tensor of embeddings; must live on the CPU for ``.numpy()``
    :param beh_index: integer behavior index written into every label
    :param test_index: row indices to keep (e.g. the test user ids)
    :return: ``(data, label, n_samples, n_features)`` -- ``data`` is the selected rows as a
             NumPy array, ``label`` an int32 array filled with ``beh_index``, and
             ``n_samples``/``n_features`` describe ``data`` (not the full table)
    """
    data = embedding[test_index].detach().numpy()
    label = np.full(data.shape[0], beh_index, dtype=np.int32)
    # Describe the returned rows so the counts agree with `data` and `label`.
    n_samples, n_features = data.shape
    return data, label, n_samples, n_features



# 对样本进行预处理并画图
def plot_embedding(data, label, title):
    """Draw each 2-D point as its behavior-index text, colored per behavior.

    :param data: array of shape (n_points, >=2); each column is min-max scaled to [0, 1] and
                 only the first two columns are drawn
    :param label: integer behavior index per point; also indexes the module-level color list ``c``
    :param title: figure title
    :return: the created matplotlib Figure
    """
    data = np.asarray(data)
    x_min, x_max = np.min(data, 0), np.max(data, 0)
    span = x_max - x_min
    # Min-max normalize per column; a constant column would otherwise divide by zero (NaN/inf).
    data = (data - x_min) / np.where(span == 0, 1, span)
    fig = plt.figure()
    plt.subplot(111)
    for (x, y), lab in zip(data[:, :2], label):
        plt.text(x, y, str(lab), color=c[lab],
                 fontdict={'weight': 'normal', 'size': 10})
    plt.title(title, fontsize=14)
    return fig


def plot_embedding_3D(data, label, title):
    """Draw one faint 3-D scatter cloud per behavior.

    :param data: sequence indexed by behavior; ``data[i]`` holds that behavior's points with
                 x, y, z in its first three columns
    :param label: unused; kept for signature parity with ``plot_embedding``
    :param title: figure title
    :return: the created matplotlib Figure
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for beh_idx, points in enumerate(data):
        points = np.asarray(points)
        ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                   c=c[beh_idx], alpha=0.1, s=0.5)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    ax.set_title(title)
    return fig

def get_input():
    """Stack the test users' embeddings of every behavior into one t-SNE input matrix.

    Behaviors ``0 .. len(behaviors)-2`` are read from ``checkpoint['user_embeds'][i]``; the last
    behavior is read from ``checkpoint['user_embed']`` (NOTE(review): presumably the final
    fused embedding -- confirm against the training code).

    :return: ``(input_data, input_label)`` -- rows of all behaviors stacked in behavior order
             and the matching behavior index per row; ``(None, None)`` if ``behaviors`` is empty
    """
    data_parts, label_parts = [], []
    last = len(behaviors) - 1
    for i in range(len(behaviors)):
        emb = checkpoint['user_embed'] if i == last else checkpoint['user_embeds'][i]
        data, label, _, _ = get_data(emb.cpu(), i, test_user)
        data_parts.append(data)
        label_parts.append(label)
        print(f"data.shape: {data.shape}")
        print(f"label.shape: {label.shape}")
    if not data_parts:
        return None, None
    # Concatenate once instead of re-copying the growing array on every iteration.
    input_data = np.vstack(data_parts)
    input_label = np.concatenate(label_parts)
    print(f"input_data.shape: {input_data.shape}")
    print(f"input_label.shape: {input_label.shape}")
    return input_data, input_label

def get_input_3D(ckpt=None):
    """Collect the test users' embeddings separately per behavior (one t-SNE run each).

    :param ckpt: checkpoint dict with ``'user_embeds'`` and ``'user_embed'`` entries; defaults to
                 the module-level ``checkpoint_SSL`` as before
    :return: ``(input_data, input_label)`` as NumPy arrays with a leading behavior axis
    :raises NameError: if ``ckpt`` is omitted and ``checkpoint_SSL`` has not been loaded
    """
    if ckpt is None:
        try:
            ckpt = checkpoint_SSL
        except NameError as err:
            # The SSL checkpoint's torch.load is commented out in the config section.
            raise NameError(
                "checkpoint_SSL is not loaded: uncomment its torch.load(...) "
                "or call get_input_3D(ckpt=checkpoint)"
            ) from err

    input_data, input_label = [], []
    last = len(behaviors) - 1
    for i in range(len(behaviors)):
        # Last behavior comes from the 'user_embed' entry, the others from 'user_embeds'[i].
        emb = ckpt['user_embed'] if i == last else ckpt['user_embeds'][i]
        data, label, _, _ = get_data(emb.cpu(), i, test_user)
        input_data.append(data)
        input_label.append(label)
        print(f"input_data.shape: {data.shape}")
        print(f"input_label.shape: {label.shape}")
    return np.array(input_data), np.array(input_label)




#--------------------------------------------------------------------------------------------------------
# 2-D t-SNE of all behaviors' test-user embeddings stacked together (one color per behavior).
tsne_cache_path = '/home/ww/Code/work3/T_Meta_SSL_MBGCN/TSNE/' + dataset + '_2D_NOSSL_NOBe'

input_data, input_label = get_input()
print('Starting compute t-SNE Embedding...')
# n_components=2: the result is cached under "_2D_" and drawn by the 2-D plot_embedding, which
# only uses the first two columns -- a third t-SNE dimension would be computed and then dropped.
ts = TSNE(n_components=2, init='pca', random_state=0)
result = ts.fit_transform(input_data)  # [n_points, 2]
with open(tsne_cache_path, 'wb') as f:
    pickle.dump(result, f)
#--------------------------------------------------------------------------------------------------------


# #--------------------------------------------------------------------------------------------------------
# # 3-D variant: one t-SNE per behavior (requires checkpoint_SSL or get_input_3D(ckpt=...)).
# ts = [None]*len(behaviors)
# result = [None]*len(behaviors)

# input_data, input_label = get_input_3D()
# print('Starting compute t-SNE Embedding...')
# for i in range(input_data.shape[0]):
#     ts[i]= TSNE(n_components=3, init='pca', random_state=0)
#     result[i] = ts[i].fit_transform(input_data[i])  #[n_points, 3]
# pickle.dump(result, open('/home/ww/Code/work3/T_Meta_SSL_MBGCN/TSNE/'+dataset,'wb'))
# #--------------------------------------------------------------------------------------------------------

# Reload the cached projection (pickle: only load files this script wrote).
with open(tsne_cache_path, 'rb') as f:
    result = pickle.load(f)
# Draw the figure
fig = plot_embedding(result, input_label, 't-SNE Embedding of digits')
# fig = plot_embedding_3D(np.array(result), input_label, 't-SNE Embedding of digits')


# Show the figure
plt.show()



# fig.colorbar(cax)
# plt.savefig(f'D:\CODE\master_behavior_attention\Picture\multi_head_self_attentionself_attention{user_number}.pdf')
# plt.savefig('D:\CODE\master_behavior_attention\Picture\self_attentionself_attention6.jpg')
# plt.savefig('D:\CODE\master_behavior_attention\Pictureattentionself_attention.jpg')
# plt.show()
# PlotMats(self_attention_ndarray = params['self_attention_para'].cpu().numpy(), , show=False, savePath='visualization/legend.pdf', vrange=[0, 1])

超参实验画图代码:

from matplotlib import pyplot as plt
from matplotlib import font_manager
import matplotlib.ticker as ticker



# Color / marker palette: entry k styles the k-th dataset curve in every subplot.
lineColors = ['#0e72cc', '#6ca30f', '#f59311', '#16afcc', '#555555', '#fa4343']
lineMarkers = ['x', 'D', 's', '+', 'o']
# lineMarkers = ['v', '^', '<', '>', 'D', 'o']

# Each list gives the metric change of one dataset relative to the default setting (the 0 entry).
# The commented percent-scale lists appear to be the same numbers x100.

# time_slot
# y1_0 = [-2.110, 0.220, 0, -0.330]
# y1_1 = [-0.840, -0.290, 0, -0.420]
# y1_2 = [-1.060, -0.390, 0, 0.350]
# NOTE(review): y1_0[3] (-0.0023) and y1_1[1] (-0.0019) disagree with the percent lists above
# (-0.330 and -0.290) -- confirm which values are correct.
y1_0 = [-0.0211, 0.0022, 0, -0.0023]
y1_1 = [-0.0084, -0.0019, 0, -0.0042]
y1_2 = [-0.0106, -0.0039, 0, 0.0035]
x1 = ['slot1','slot2', 'slot3', 'slot4']

# gnn_layer
# y2_0 = [-1.33, -1.05, 0, -0.41]
# y2_1 = [-0.23, 0.34, 0, -0.12]
# y2_2 = [-0.33, -0.12, 0,-0.43]
y2_0 = [-0.0133, -0.0105, 0, -0.0041]
y2_1 = [-0.0023, 0.0034, 0, -0.0012]
y2_2 = [-0.0033, -0.0012, 0,-0.0043]
x2 = ['1 layer','2 layer', '3 layer', '4 layer']

# hidden_dim
# y3_0 = [-2.21, -1.84, -0.62, 0 ,0.25]
# y3_1 = [-1.06, -0.71, -0.24, 0, -0.11]
# y3_2 = [-0.75, -0.46, -0.14, 0, -0.15]
# NOTE(review): y3_1[0] = -0.056 is an order of magnitude off its neighbours and its percent
# twin (-1.06), and all non-zero entries of y3_2 differ from the percent list -- confirm.
y3_0 = [-0.0221, -0.0184, -0.0062, 0 ,0.0025]
y3_1 = [-0.056, -0.0071, -0.0024, 0, -0.0011]
y3_2 = [-0.0085, -0.0036, -0.0004, 0, -0.0035]
x3 = ['8', '16', '32', '64', '128']

series_labels = ["TaoBao Data", "IJCAI Contest", "JD Data"]
panels = [
    (x1, [y1_0, y1_1, y1_2]),
    (x2, [y2_0, y2_1, y2_2]),
    (x3, [y3_0, y3_1, y3_2]),
]
save_path = "hyper_param.pdf"  # output file; change as needed

# Figure size (e.g. plt.figure(figsize=(20,8), dpi=80))
width = 10
height = 2
plt.figure(1, figsize=(width, height))

# One subplot per hyper-parameter; label= feeds the legend, color/marker set the line style.
for panel_idx, (xs, ys_list) in enumerate(panels):
    plt.subplot(1, len(panels), panel_idx + 1)
    for k, (ys, name) in enumerate(zip(ys_list, series_labels)):
        plt.plot(xs, ys, label=name, color=lineColors[k], marker=lineMarkers[k])
    plt.grid(True)
# Only the last subplot carries the legend (default position is upper right).
plt.legend(loc="lower right")
# Tick formatting tip: plt.gca().yaxis.set_major_formatter(ticker.FormatStrFormatter('%.6f'))
# Chinese legend text needs a font: plt.legend(prop=my_font, loc="upper left")

# Horizontal spacing between subplots
plt.subplots_adjust(wspace=0.3)

# Save BEFORE show(): show() may close the figure, and savefig("") wrote a blank ".png".
plt.savefig(save_path)
plt.show()

"""
c–cyan–青色
r–red–红色
m–magenta–品红
g–green–绿色
b–blue–蓝色
y–yellow–黄色
k–black–黑色
w–white–白色
"""

"""
-- 虚线
-. 形式即为-.
: 细小的虚线
"""

"""
s–方形
h–六角形
H–六角形
*–*形
+–加号
x–x形
d–菱形
D–菱形
p–五角形

"""

一张图里面画好几个折线图的代码:

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# x-axis bins (used as tick labels) and one curve per method.
names = ['0.2-0.4', '0.4-0.6', '0.6-0.8', '0.8-1.0']

lines = {
    'SVM': [0.1738, 0.3996, 0.6976, 0.8942],
    'STGCN': [0.0226, 0.4558, 0.8159, 0.9317],
    'GMAN': [0.2318, 0.5361, 0.8269, 0.9353],
    'DeepCrime': [0.2271, 0.5811, 0.7930, 0.9296],
    'ST-MetaNet': [0.0653, 0.3146, 0.9509, 0.9999],
    'ST-SHN': [0.3815, 0.6204, 0.8272, 0.9353],
}

lineColors = ['#0e72cc', '#6ca30f', '#f59311', '#16afcc', '#555555', '#fa4343']
# lineMarkers = ['x', 'D', 's', '+', 'o']
lineMarkers = ['v', '^', '<', '>', 'D', 'o']

# Evenly spaced x positions, one per bin.
x = [i*0.3 for i in range(len(names))]

# Build the figure once; the original re-created the legend/layout on every iteration and
# called tight_layout() on an undefined `fig`.
fig = plt.figure(1)  #, figsize=(width,height))
plt.subplot(111)
for idx, (lineName, values) in enumerate(lines.items()):
    plt.plot(x, values, color=lineColors[idx], label=lineName, marker='.')
plt.xticks(x, names)
plt.grid(True)
plt.legend(loc='lower right', ncol=2, framealpha=1, fancybox=False, handlelength=1.2, handleheight=1.2, handletextpad=0.4, labelspacing=0.4, columnspacing=0.2, fontsize=10, borderaxespad=-0.25, borderpad=0.2)
# Explicit margins (these would override any tight_layout() result anyway).
plt.subplots_adjust(top=0.958, bottom=0.076, left=0.106, right=0.85, hspace=0.2, wspace=0.2)

plt.show()
# plt.savefig('figures/sparsity_%s.pdf' % title)  # `title` must be defined; save before show()

画attention图的原始代码:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.ticker as ticker

import torch  # this snippet's own import block above does not import torch

# Toy example: build a 4x4 "attention" matrix whose rows sum to 1 and draw it as a heat map.
a = torch.randn(4, 2)
b = a.softmax(dim=1)                        # (4, 2), each row sums to 1
a_col = a.softmax(dim=0).transpose(0, 1)    # (2, 4), each row sums to 1
print(a, '\n',  b, '\n', a_col)
d = b.matmul(a_col)                         # (4, 4)
print(d)

d = d.numpy()
variables = ['A','B','C','X']
labels = ['ID_0','ID_1','ID_2','ID_3']

df = pd.DataFrame(d, columns=variables, index=labels)

fig = plt.figure()

ax = fig.add_subplot(111)

cax = ax.matshow(df, interpolation='nearest', cmap='hot_r')
fig.colorbar(cax)

# Pin exactly one tick per cell and label it directly. The older "[''] + labels with a
# MultipleLocator" trick presumably relied on an extra tick at -1 and mislabels / warns
# (FixedFormatter without FixedLocator) on current matplotlib.
ax.set_xticks(range(len(df.columns)))
ax.set_yticks(range(len(df.index)))
ax.set_xticklabels(list(df.columns))
ax.set_yticklabels(list(df.index))

plt.show()

参照博客: https://blog.csdn.net/m0_38133212/article/details/86664569

おすすめ

转载自: blog.csdn.net/weiwei935707936/article/details/113795344