From d828c6de9aac94088225e031b7a4b7e078b72eba Mon Sep 17 00:00:00 2001 From: weixin_43297441 Date: Mon, 31 Mar 2025 18:02:46 +0800 Subject: [PATCH] new mat --- hal_det_llama.py | 21 +++++++++++---------- tqa_score.mat | Bin 0 -> 2032 bytes 2 files changed, 11 insertions(+), 10 deletions(-) create mode 100644 tqa_score.mat diff --git a/hal_det_llama.py b/hal_det_llama.py index 5baa091..5ad1882 100644 --- a/hal_det_llama.py +++ b/hal_det_llama.py @@ -10,6 +10,7 @@ import numpy as np import pickle # from utils import get_llama_activations_bau, tokenized_tqa, tokenized_tqa_gen, tokenized_tqa_gen_end_q from utils import mahalanobis_distance +from scipy.io import savemat import llama_iti import pickle import argparse @@ -411,6 +412,10 @@ def main(): inv_C_t= torch.linalg.pinv(C_t) + torch.eye(C_t.shape[0], dtype=int).cuda() * epsilon inv_C_h= torch.linalg.pinv(C_h) + torch.eye(C_h.shape[0], dtype=int).cuda() * epsilon + test_t=torch.from_numpy(embed_generated[:, layer, :]).cuda()-centered_t + test_h=torch.from_numpy(embed_generated[:, layer, :]).cuda()-centered_h + # scores= torch.sqrt(torch.clamp(test_t @ inv_C_t @ test_t.T, min=0.0)) + # - torch.sqrt(torch.clamp(test_h @ inv_C_h @ test_h.T, min=0.0)) scores= torch.sqrt(torch.clamp(centered @ inv_C_t @ centered.T, min=0.0)) - torch.sqrt(torch.clamp(centered @ inv_C_h @ centered.T, min=0.0)) # scores= mahalanobis_distance(torch.from_numpy(embed_generated[:, layer, :]).cuda(), torch.from_numpy(mean_recorded).cuda(), C_) torch.clamp(centered @ inv_C_t @ centered.T, min=0.0) @@ -499,11 +504,6 @@ def main(): allow_pickle=True) else: assert "Not supported!" - # embed_generated = np.load(f'save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/' + info + f'{args.model_name}_gene_embeddings_layer_wise.npy', - # allow_pickle=True) - # embed_generated = np.load( - # f'save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/' + info + f'{args.model_name}_gene_embeddings_head_wise.npy', - # allow_pickle=True) feat_indices_wild = [] feat_indices_eval = [] @@ -525,9 +525,6 @@ def main(): else: embed_generated_wild = embed_generated[feat_indices_wild] embed_generated_eval = embed_generated[feat_indices_eval] - # print(embed_generated.shape) - # print(embed_generated_h.shape) - # print(embed_generated_t.shape) embed_generated_hal,embed_generated_tru=embed_generated_h[feat_indices_eval], embed_generated_t[feat_indices_eval] @@ -566,7 +563,10 @@ def main(): assert test_scores.shape[1] == 1 test_scores = np.sqrt(np.sum(np.square(test_scores), axis=1)) - + mdic = {"gt_1": test_scores[gt_label_test == 1], "gt_0": test_scores[gt_label_test == 0], "scale_gt_1":returned_results['best_sign'] * test_scores[gt_label_test == 1], + "scale_gt_0": returned_results['best_sign'] *test_scores[gt_label_test == 0] + } + savemat("tqa_score.mat", mdic) measures = get_measures(returned_results['best_sign'] * test_scores[gt_label_test == 1], returned_results['best_sign'] *test_scores[gt_label_test == 0], plot=False) print_measures(measures[0], measures[1], measures[2], 'direct-projection') @@ -612,6 +612,7 @@ def main(): clf.eval() + torch.save(clf.state_dict(), f'save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/' + info+ '_{layer}_model_weights.pth') output = clf(torch.from_numpy( embed_generated_test[:, layer, :]).cuda()) pca_wild_score_binary_cls = torch.sigmoid(output) @@ -623,7 +624,7 @@ def main(): breakpoint() measures = get_measures(pca_wild_score_binary_cls[gt_label_test == 1], pca_wild_score_binary_cls[gt_label_test == 0], plot=False) - + # print_measures(measures[0], measures[1], measures[2], 'class-acc') if measures[0] > best_auroc: best_auroc = measures[0] best_result = [100 * measures[0]] diff --git a/tqa_score.mat b/tqa_score.mat new file mode 100644 index 0000000000000000000000000000000000000000..c8be060d040c42faedd14534293886b587cda958 GIT binary patch literal 2032 zcmbW%dra0<90u?gMJ__Gmzh8~d1vXAkX}RuiTC$CNRzx^F3F;1$mHUMn}C;)#8A_D z!Gto>ypHAlQVy_^-*YT+5Q0gf#z31>VZ@x8m?p6vjIFIdn(J(5XXpCIv+w6bN5o8v zm}napXtza2#EedwpOR>skrJPsl%AOyX3I#=n!hZ-7MYnCpPiUsOHT{4MW?6PqT@4d z!9lj5kg(w3upql_jD5^F%l~LjlcTM?HqXgoapCC3?%MNApPo6lEYsPYEY6nX?6{!r z{ZKAOI_;1<&z8uQTNRQ$p};7<;$_U6WK*qCQ&d&XD#2J!>DS__wpUgg-?*=l@_<9) z<>V*t+*m07YLNuqJ0R9W8>P!VMD0&{Bmd}|ApnU|5S}k8gN3U zHoB;FrmIT)YKg#or&wnnmCS`Maz53RNp5j+qh*=YU*BvL7Oxf8+Bn%=?;*pzoK@3< zHhF76ir6D=i09x+N72emHTd1N^7g#l#*a+*R4`kzg-rbsgq&%ij3H3PxWf-HSzPPluO54)v>E5q-#%x3~3CJ znZF;AX}gX{Nw}Yy8t0)Nm5-F0`^seB$a0A}R4j>271Euzz1M%|o%c`P_U^npnRnTu zfokmdRMol3Nj>y=QHq=QN~LwPB;;wV3^^@Vjoh1Ea&E4hD#+R(mvp=uK2!$@G3KCBdJGZ(_0(L=8|? ztjG@hEB_?Q#>}uO<*kbDB&l!*9 z{A~}lrp8;jNsW~K)+!@vFNh_#T^y4;WcaWu+3+w~qCa`b_{GOpZMpBO+M?QJ8^5Sx+{CbMMp-~LGeG(OW9)EB+lATtytG84R^Ju)Str( zgWs3FNhd{p5$kAPP$CnHHp|&t9mbTrbMi_3%Ui}f^I># zCf$N=LARh=O}C(1&@Je;f^Ie4f^JQ^1>J&fLANH|f^Ie4f^I>#Cf$N=LARh=(5*?g znr=b2pj%D1pj*%_=oWMfy47^6=@xVgx&_^WZqc2;dMda7%HRGZ{wM!xk7G^8nvOLc zgN{MRpkvUnrejUVpkvT+1s#KqLC2tD&@t%v6di+(O*+