diff --git a/heatmap.py b/heatmap.py new file mode 100644 index 0000000..9103815 --- /dev/null +++ b/heatmap.py @@ -0,0 +1,33 @@ +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt + +# 数据矩阵 +data = np.array([ + [78.64, 71.14, 67.81, 91.54], + [76.26, 77.4, 75.94, 91.09], + [77.36, 74.32, 76.42, 92.98], + [76.97, 75.48, 71.99, 94.04] +]) + +# 行列标签 +row_labels = ['TruthfulQA (s)', 'TriviaQA (s)', 'CoQA (s)', 'TydiQA-GP (s)'] +col_labels = ['TruthfulQA (t)', 'TriviaQA (t)', 'CoQA (t)', 'TydiQA-GP (t)'] + +# 设置画布大小 +plt.figure(figsize=(6, 5)) + +# 绘制热图 +ax = sns.heatmap(data, annot=True, fmt=".2f", cmap="Purples", + xticklabels=col_labels, yticklabels=row_labels, + linewidths=0.5, linecolor='gray', cbar=True) + +# 调整刻度标签 +plt.xticks(rotation=30, ha="right") # 旋转 X 轴标签 +plt.yticks(rotation=0) # Y 轴标签保持水平 + +# 设置标题 +plt.title("Transferability results across different datasets.") + +# 显示图像 +plt.show()