添加 heatmap.py
This commit is contained in:
parent
57f99677c4
commit
844d18af5e
|
|
@ -0,0 +1,33 @@
|
||||||
|
import numpy as np
|
||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# 数据矩阵
|
||||||
|
data = np.array([
|
||||||
|
[78.64, 71.14, 67.81, 91.54],
|
||||||
|
[76.26, 77.4, 75.94, 91.09],
|
||||||
|
[77.36, 74.32, 76.42, 92.98],
|
||||||
|
[76.97, 75.48, 71.99, 94.04]
|
||||||
|
])
|
||||||
|
|
||||||
|
# 行列标签
|
||||||
|
row_labels = ['TruthfulQA (s)', 'TriviaQA (s)', 'CoQA (s)', 'TydiQA-GP (s)']
|
||||||
|
col_labels = ['TruthfulQA (t)', 'TriviaQA (t)', 'CoQA (t)', 'TydiQA-GP (t)']
|
||||||
|
|
||||||
|
# 设置画布大小
|
||||||
|
plt.figure(figsize=(6, 5))
|
||||||
|
|
||||||
|
# 绘制热图
|
||||||
|
ax = sns.heatmap(data, annot=True, fmt=".2f", cmap="Purples",
|
||||||
|
xticklabels=col_labels, yticklabels=row_labels,
|
||||||
|
linewidths=0.5, linecolor='gray', cbar=True)
|
||||||
|
|
||||||
|
# 调整刻度标签
|
||||||
|
plt.xticks(rotation=30, ha="right") # 旋转 X 轴标签
|
||||||
|
plt.yticks(rotation=0) # Y 轴标签保持水平
|
||||||
|
|
||||||
|
# 设置标题
|
||||||
|
plt.title("Transferability results across different datasets.")
|
||||||
|
|
||||||
|
# 显示图像
|
||||||
|
plt.show()
|
||||||
Loading…
Reference in New Issue