This commit is contained in:
weixin_43297441 2025-03-31 18:02:46 +08:00
parent ba518178a1
commit d828c6de9a
2 changed files with 11 additions and 10 deletions

View File

@ -10,6 +10,7 @@ import numpy as np
import pickle
# from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q
from utils import mahalanobis_distance
from scipy.io import savemat
import llama_iti
import pickle
import argparse
@ -411,6 +412,10 @@ def main():
inv_C_t= torch.linalg.pinv(C_t) + torch.eye(C_t.shape[0], dtype=int).cuda() * epsilon
inv_C_h= torch.linalg.pinv(C_h) + torch.eye(C_h.shape[0], dtype=int).cuda() * epsilon
test_t=torch.from_numpy(embed_generated[:, layer, :]).cuda()-centered_t
test_h=torch.from_numpy(embed_generated[:, layer, :]).cuda()-centered_h
# scores= torch.sqrt(torch.clamp(test_t @ inv_C_t @ test_t.T, min=0.0))
# - torch.sqrt(torch.clamp(test_h @ inv_C_h @ test_h.T, min=0.0))
scores= torch.sqrt(torch.clamp(centered @ inv_C_t @ centered.T, min=0.0))
- torch.sqrt(torch.clamp(centered @ inv_C_h @ centered.T, min=0.0))
# scores= mahalanobis_distance(torch.from_numpy(embed_generated[:, layer, :]).cuda(), torch.from_numpy(mean_recorded).cuda(), C_) torch.clamp(centered @ inv_C_t @ centered.T, min=0.0)
@ -499,11 +504,6 @@ def main():
allow_pickle=True)
else:
assert "Not supported!"
# embed_generated = np.load(f'save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/' + info + f'{args.model_name}_gene_embeddings_layer_wise.npy',
# allow_pickle=True)
# embed_generated = np.load(
# f'save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/' + info + f'{args.model_name}_gene_embeddings_head_wise.npy',
# allow_pickle=True)
feat_indices_wild = []
feat_indices_eval = []
@ -525,9 +525,6 @@ def main():
else:
embed_generated_wild = embed_generated[feat_indices_wild]
embed_generated_eval = embed_generated[feat_indices_eval]
# print(embed_generated.shape)
# print(embed_generated_h.shape)
# print(embed_generated_t.shape)
embed_generated_hal,embed_generated_tru=embed_generated_h[feat_indices_eval], embed_generated_t[feat_indices_eval]
@ -566,7 +563,10 @@ def main():
assert test_scores.shape[1] == 1
test_scores = np.sqrt(np.sum(np.square(test_scores), axis=1))
mdic = {"gt_1": test_scores[gt_label_test == 1], "gt_0": test_scores[gt_label_test == 0], "scale_gt_1":returned_results['best_sign'] * test_scores[gt_label_test == 1],
"scale_gt_0": returned_results['best_sign'] *test_scores[gt_label_test == 0]
}
savemat("tqa_score.mat", mdic)
measures = get_measures(returned_results['best_sign'] * test_scores[gt_label_test == 1],
returned_results['best_sign'] *test_scores[gt_label_test == 0], plot=False)
print_measures(measures[0], measures[1], measures[2], 'direct-projection')
@ -612,6 +612,7 @@ def main():
clf.eval()
torch.save(clf.state_dict(), f'save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/' + info+ '_{layer}_model_weights.pth')
output = clf(torch.from_numpy(
embed_generated_test[:, layer, :]).cuda())
pca_wild_score_binary_cls = torch.sigmoid(output)
@ -623,7 +624,7 @@ def main():
breakpoint()
measures = get_measures(pca_wild_score_binary_cls[gt_label_test == 1],
pca_wild_score_binary_cls[gt_label_test == 0], plot=False)
# print_measures(measures[0], measures[1], measures[2], 'class-acc')
if measures[0] > best_auroc:
best_auroc = measures[0]
best_result = [100 * measures[0]]

BIN
tqa_score.mat Normal file

Binary file not shown.