import argparse
import os
import pickle
import re
from pprint import pprint

import datasets
import evaluate
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from baukit import Trace, TraceDict
from datasets import load_dataset, load_metric
from scipy.spatial import distance
from sklearn.decomposition import PCA
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import Perceptron
from torch.autograd import Variable
from tqdm import tqdm

import llama_iti
from metric_utils import get_measures, print_measures
from utils import mahalanobis_distance


def seed_everything(seed: int):
    """Seed every RNG in use (python, numpy, torch CPU/CUDA) for reproducibility."""
    import random

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # BUGFIX: was True, which lets cuDNN auto-tune non-deterministic kernels and
    # silently breaks the determinism requested on the line above.
    torch.backends.cudnn.benchmark = False


def _str2bool(value):
    """argparse-friendly boolean parser.

    BUGFIX: the original flags used ``type=bool``; ``bool("False")`` is True, so
    ``--most_likely False`` and ``--use_rouge False`` could never be turned off.
    Defaults are unchanged.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ('true', 't', 'yes', 'y', '1')


# Short model key -> HuggingFace hub id (or local checkpoint path).
HF_NAMES = {
    'llama_7B': 'baffo32/decapoda-research-llama-7B-hf',
    'honest_llama_7B': 'validation/results_dump/llama_7B_seed_42_top_48_heads_alpha_15',
    'alpaca_7B': 'circulus/alpaca-7b',
    'vicuna_7B': 'AlekseyKorshuk/vicuna-7b',
    'llama2_chat_7B': 'models/Llama-2-7b-chat-hf',
    'llama2_chat_13B': 'models/Llama-2-13b-chat-hf',
    'llama2_chat_70B': 'meta-llama/Llama-2-70b-chat-hf',
}


def main():
    """Extract per-layer / per-head / per-MLP hidden states for generated answers,
    then score hallucination detection with a low-rank subspace score and a
    linear probe, reporting AUROC on a held-out test split.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='llama2_chat_7B')
    parser.add_argument('--model_name', type=str, default='step-1-8k')
    parser.add_argument('--dataset_name', type=str, default='tqa')
    parser.add_argument('--num_gene', type=int, default=1)
    parser.add_argument('--use_rouge', type=_str2bool, default=False)
    parser.add_argument('--weighted_svd', type=int, default=0)
    parser.add_argument('--feat_loc_svd', type=int, default=1)
    parser.add_argument('--wild_ratio', type=float, default=0.75)
    parser.add_argument('--thres_gt', type=float, default=0.5)
    parser.add_argument('--most_likely', type=_str2bool, default=True)
    parser.add_argument("--model_dir", type=str, default=None,
                        help='local directory with model data')
    args = parser.parse_args()

    MODEL = HF_NAMES[args.model] if not args.model_dir else args.model_dir

    # ------------------------------------------------------------------ dataset
    used_indices = []  # only populated for tydiqa (English subset)
    if args.dataset_name == "tqa":
        dataset = load_dataset("truthful_qa", 'generation')['validation']
    elif args.dataset_name == 'triviaqa':
        dataset = load_dataset("trivia_qa", "rc.nocontext", split="validation")
        id_mem = set()

        def remove_dups(batch):
            # trivia_qa contains duplicated question ids; keep the first occurrence.
            if batch['question_id'][0] in id_mem:
                return {_: [] for _ in batch.keys()}
            id_mem.add(batch['question_id'][0])
            return batch

        dataset = dataset.map(remove_dups, batch_size=1, batched=True,
                              load_from_cache_file=False)
    elif args.dataset_name == 'tydiqa':
        dataset = datasets.load_dataset("tydiqa", "secondary_task", split="train")
        for i in range(len(dataset)):
            if 'english' in dataset[i]['id']:
                used_indices.append(i)
    elif args.dataset_name == 'coqa':
        import json
        import pandas as pd
        from datasets import Dataset

        def _save_dataset():
            # https://github.com/lorenzkuhn/semantic_uncertainty/blob/main/code/parse_coqa.py
            save_path = f'./coqa_dataset'
            if not os.path.exists(save_path):
                # https://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json
                with open(f'./coqa-dev-v1.0.json', 'r') as infile:
                    data = json.load(infile)['data']
                dataset = {}
                dataset['story'] = []
                dataset['question'] = []
                dataset['answer'] = []
                dataset['additional_answers'] = []
                dataset['id'] = []
                for sample_id, sample in enumerate(data):
                    story = sample['story']
                    questions = sample['questions']
                    answers = sample['answers']
                    additional_answers = sample['additional_answers']
                    for question_index, question in enumerate(questions):
                        dataset['story'].append(story)
                        dataset['question'].append(question['input_text'])
                        dataset['answer'].append({
                            'text': answers[question_index]['input_text'],
                            'answer_start': answers[question_index]['span_start']
                        })
                        dataset['id'].append(sample['id'] + '_' + str(question_index))
                        additional_answers_list = []
                        for i in range(3):
                            additional_answers_list.append(
                                additional_answers[str(i)][question_index]['input_text'])
                        dataset['additional_answers'].append(additional_answers_list)
                        # grow the story so later turns see the full dialog history
                        story = story + ' Q: ' + question['input_text'] + ' A: ' + \
                            answers[question_index]['input_text']
                        if not story[-1] == '.':
                            story = story + '.'
                dataset_df = pd.DataFrame.from_dict(dataset)
                dataset = Dataset.from_pandas(dataset_df)
                dataset.save_to_disk(save_path)
            return save_path

        def get_dataset(tokenizer, split='validation'):
            # from https://github.com/lorenzkuhn/semantic_uncertainty/blob/main/code/parse_coqa.py
            dataset = datasets.load_from_disk(_save_dataset())
            id_to_question_mapping = dict(zip(dataset['id'], dataset['question']))

            def encode_coqa(example):
                example['answer'] = [example['answer']['text']] + example['additional_answers']
                example['prompt'] = prompt = example['story'] + ' Q: ' + example['question'] + ' A:'
                return tokenizer(prompt, truncation=False, padding=False)

            dataset = dataset.map(encode_coqa, batched=False, load_from_cache_file=False)
            dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'],
                               output_all_columns=True)
            return dataset

        dataset = get_dataset(
            llama_iti.LlamaTokenizer.from_pretrained(MODEL, trust_remote_code=True))
    else:
        raise ValueError("Invalid dataset name")

    # ------------------------------------------------------------------- model
    tokenizer = llama_iti.LlamaTokenizer.from_pretrained(MODEL, trust_remote_code=True)
    # NOTE(review): .cuda() after device_map="auto" is redundant for single-GPU
    # and can fight multi-GPU dispatch — kept for parity with prior runs.
    model = llama_iti.LlamaForCausalLM.from_pretrained(
        MODEL, low_cpu_mem_usage=True, torch_dtype=torch.float16,
        device_map="auto").cuda()

    HEADS = [f"model.layers.{i}.self_attn.head_out"
             for i in range(model.config.num_hidden_layers)]
    MLPS = [f"model.layers.{i}.mlp" for i in range(model.config.num_hidden_layers)]

    if args.dataset_name == 'tydiqa':
        length = len(used_indices)
    else:
        length = len(dataset)

    # Hoisted out of the loop: `info` only depends on args (was set per-iteration
    # and would be unbound for an empty dataset).
    info = 'most_likely_' if args.most_likely else 'batch_generations_'
    save_dir = f'save_for_eval/{args.dataset_name}/{args.model_name}_hal_det/'

    def get_question(i):
        """Question text for row i, honouring the tydiqa English-only remap."""
        if args.dataset_name == 'tydiqa':
            return dataset[int(used_indices[i])]['question']
        return dataset[i]['question']

    def build_prompt(i, question, completion):
        """Tokenize the dataset-specific QA prompt followed by `completion`."""
        if args.dataset_name == 'tydiqa':
            text = ("Concisely answer the following question based on the information in the given passage: \n"
                    + " Passage: " + dataset[int(used_indices[i])]['context']
                    + " \n Q: " + question + " \n A:" + completion)
        elif args.dataset_name == 'coqa':
            text = dataset[i]['prompt'] + completion
        else:
            text = f"Answer the question concisely. Q: {question}" + " A:" + completion
        return tokenizer(text, return_tensors='pt').input_ids.cuda()

    # -------------------------------------------- pass 1: layer-wise embeddings
    embed_generated = []
    for i in tqdm(range(length)):
        question = get_question(i)
        answers = np.load(
            save_dir + 'answers/' + info +
            f'hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy')
        for anw in answers:
            prompt = build_prompt(i, question, anw)
            with torch.no_grad():
                hidden_states = model(prompt, output_hidden_states=True).hidden_states
                hidden_states = torch.stack(hidden_states, dim=0).squeeze()
                # keep only the last-token representation at every layer
                hidden_states = hidden_states.detach().cpu().numpy()[:, -1, :]
            embed_generated.append(hidden_states)
    embed_generated = np.asarray(np.stack(embed_generated), dtype=np.float32)
    np.save(save_dir + info + f'{args.model_name}_gene_embeddings_layer_wise.npy',
            embed_generated)

    # --------------------------------- pass 2: head-wise / MLP-wise embeddings
    def last_token_head_mlp_states(prompt):
        """Run the model under TraceDict; return (head_states, mlp_states)
        for the last token, each of shape (num_layers, hidden)."""
        with torch.no_grad():
            with TraceDict(model, HEADS + MLPS) as ret:
                model(prompt, output_hidden_states=True)
            head_states = [ret[head].output.squeeze().detach().cpu() for head in HEADS]
            head_states = torch.stack(head_states, dim=0).squeeze().numpy()
            mlp_states = [ret[mlp].output.squeeze().detach().cpu() for mlp in MLPS]
            mlp_states = torch.stack(mlp_states, dim=0).squeeze().numpy()
        return head_states[:, -1, :], mlp_states[:, -1, :]

    embed_generated_loc1, embed_generated_loc2 = [], []      # answers: head / mlp
    embed_generated_h_loc1, embed_generated_h_loc2 = [], []  # hallucinations
    embed_generated_t_loc1, embed_generated_t_loc2 = [], []  # truths
    for i in tqdm(range(length)):
        question = get_question(i)
        answers = np.load(
            save_dir + 'answers/' + info +
            f'hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy')
        truths = np.load(
            save_dir + 'truths/' + info +
            f'hal_det_{args.model_name}_{args.dataset_name}_truths_index_{i}.npy')
        hallucinations = np.load(
            save_dir + 'hallucinations/' + info +
            f'hal_det_{args.model_name}_{args.dataset_name}_hallucinations_index_{i}.npy')
        for anw in answers:
            head_last, mlp_last = last_token_head_mlp_states(build_prompt(i, question, anw))
            embed_generated_loc1.append(head_last)
            embed_generated_loc2.append(mlp_last)
        for hal in hallucinations:
            head_last, mlp_last = last_token_head_mlp_states(build_prompt(i, question, hal))
            embed_generated_h_loc1.append(head_last)
            embed_generated_h_loc2.append(mlp_last)
        for tru in truths:
            head_last, mlp_last = last_token_head_mlp_states(build_prompt(i, question, tru))
            embed_generated_t_loc1.append(head_last)
            embed_generated_t_loc2.append(mlp_last)

    embed_generated_loc1 = np.asarray(np.stack(embed_generated_loc1), dtype=np.float32)
    embed_generated_loc2 = np.asarray(np.stack(embed_generated_loc2), dtype=np.float32)
    embed_generated_h_loc1 = np.asarray(np.stack(embed_generated_h_loc1), dtype=np.float32)
    embed_generated_h_loc2 = np.asarray(np.stack(embed_generated_h_loc2), dtype=np.float32)
    embed_generated_t_loc1 = np.asarray(np.stack(embed_generated_t_loc1), dtype=np.float32)
    embed_generated_t_loc2 = np.asarray(np.stack(embed_generated_t_loc2), dtype=np.float32)

    # BUGFIX: the original saves were inconsistent with the loads below:
    #  * the answer MLP file was named "..._embeddings_mlp_wise.npy" but loaded as
    #    "..._gene_embeddings_mlp_wise.npy" (FileNotFoundError on feat_loc==2);
    #  * the *_h_* and *_t_* head/mlp payloads were swapped (loc2 = MLP states
    #    went into the "head_wise" files and vice versa).
    # Convention now: head_wise files hold loc1 (attention-head states), mlp_wise
    # files hold loc2 (MLP states).
    np.save(save_dir + info + f'{args.model_name}_gene_embeddings_head_wise.npy',
            embed_generated_loc1)
    np.save(save_dir + info + f'{args.model_name}_gene_embeddings_mlp_wise.npy',
            embed_generated_loc2)
    np.save(save_dir + info + f'{args.model_name}_gene_embeddings_h_head_wise.npy',
            embed_generated_h_loc1)
    np.save(save_dir + info + f'{args.model_name}_gene_embeddings_h_mlp_wise.npy',
            embed_generated_h_loc2)
    np.save(save_dir + info + f'{args.model_name}_gene_embeddings_t_head_wise.npy',
            embed_generated_t_loc1)
    np.save(save_dir + info + f'{args.model_name}_gene_embeddings_t_mlp_wise.npy',
            embed_generated_t_loc2)

    # ------------------------- ground-truth labels and wild/val/test splitting
    if args.use_rouge:
        gts = np.load(f'./ml_{args.dataset_name}_{args.model_name}_rouge_score.npy')
        gts_bg = np.load(f'./bg_{args.dataset_name}_{args.model_name}_rouge_score.npy')
    else:
        gts = np.load(f'./ml_{args.dataset_name}_{args.model_name}_bleurt_score.npy')
        gts_bg = np.load(f'./bg_{args.dataset_name}_{args.model_name}_bleurt_score.npy')

    thres = args.thres_gt
    gt_label = np.asarray(gts > thres, dtype=np.int32)         # 1 = truthful
    gt_label_bg = np.asarray(gts_bg > thres, dtype=np.int32)

    permuted_index = np.random.permutation(length)
    wild_q_indices = permuted_index[:int(args.wild_ratio * length)]
    # hold the last 100 wild questions out as a validation set
    wild_q_indices1 = wild_q_indices[:len(wild_q_indices) - 100]
    wild_q_indices2 = wild_q_indices[len(wild_q_indices) - 100:]
    gt_label_test = []
    gt_label_wild = []
    gt_label_val = []
    for i in range(length):
        if i not in wild_q_indices:
            gt_label_test.extend(gt_label[i: i + 1])
        elif i in wild_q_indices1:
            gt_label_wild.extend(gt_label[i: i + 1])
        else:
            gt_label_val.extend(gt_label[i: i + 1])
    gt_label_test = np.asarray(gt_label_test)
    gt_label_wild = np.asarray(gt_label_wild)
    gt_label_val = np.asarray(gt_label_val)

    def svd_embed_score(embed_generated_wild, gt_label, embed_generated_h,
                        embed_generated_t, begin_k, k_span, mean=1, svd=10,
                        epsilon=1e-20):
        """Grid-search (k, layer) maximizing AUROC of a rank-k subspace score.

        The score is the difference of Mahalanobis-style distances to the
        truthful vs hallucinated covariance subspaces. Since the direction the
        score points in is a priori unknown, both signs are evaluated (cf. the
        representation-engineering paper, https://arxiv.org/abs/2310.01405).

        Returns a dict with the best k / layer / sign / mean / scores / AUROC.
        """
        embed_generated = embed_generated_wild
        best_auroc_over_k = 0
        # Initialize everything the return dict reads (the originals could be
        # referenced before assignment when no layer improved on 0 AUROC).
        best_result_over_k = None
        best_layer_over_k = 0
        best_k = begin_k
        best_sign_over_k = 1
        best_scores_over_k = None
        best_mean_over_k = None

        def rank_k_cov_inv(x, k, epsilon):
            """Pseudo-inverse of the rank-k covariance (1/n) V_k S_k^2 V_k^T."""
            _, sing, v_t = torch.linalg.svd(x, full_matrices=False)
            sing_sq = torch.diag(sing[:k]) ** 2
            v_t = v_t[:k, :]
            cov = (1 / x.shape[0]) * v_t.T @ sing_sq @ v_t
            # BUGFIX: the ridge identity was built with dtype=int; use the
            # covariance dtype so the epsilon regularizer is a float matrix.
            ridge = torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device) * epsilon
            return torch.linalg.pinv(cov) + ridge

        for k in tqdm(range(begin_k, k_span)):
            best_auroc = 0
            best_result = None
            best_layer = 0
            best_scores = None
            best_mean = None
            best_sign = 1
            mean_recorded = None
            for layer in range(len(embed_generated_wild[0])):
                if mean:
                    mean_recorded = embed_generated[:, layer, :].mean(0)
                    centered = embed_generated[:, layer, :] - mean_recorded
                    centered_h = embed_generated_h[:, layer, :] - embed_generated_h[:, layer, :].mean(0)
                    centered_t = embed_generated_t[:, layer, :] - embed_generated_t[:, layer, :].mean(0)
                else:
                    # BUGFIX: centered_h / centered_t were unbound on this path.
                    centered = embed_generated[:, layer, :]
                    centered_h = embed_generated_h[:, layer, :]
                    centered_t = embed_generated_t[:, layer, :]
                centered = torch.from_numpy(centered).cuda()
                centered_h = torch.from_numpy(centered_h).cuda()
                centered_t = torch.from_numpy(centered_t).cuda()

                inv_C_t = rank_k_cov_inv(centered_t, k, epsilon)
                inv_C_h = rank_k_cov_inv(centered_h, k, epsilon)
                # truthful-subspace distance minus hallucination-subspace distance
                # (unused covariance of `centered` and two dead re-computations of
                # these matmuls were removed — pure wasted GPU work).
                scores = torch.sqrt(torch.clamp(centered @ inv_C_t @ centered.T, min=0.0)) \
                    - torch.sqrt(torch.clamp(centered @ inv_C_h @ centered.T, min=0.0))
                scores = torch.mean(scores, -1, keepdim=True)
                scores = torch.sqrt(torch.sum(torch.square(scores), dim=1))
                scores = scores.data.cpu().numpy()

                measures1 = get_measures(scores[gt_label == 1], scores[gt_label == 0], plot=False)
                measures2 = get_measures(-scores[gt_label == 1], -scores[gt_label == 0], plot=False)
                if measures1[0] > measures2[0]:
                    measures = measures1
                    sign_layer = 1
                else:
                    measures = measures2
                    sign_layer = -1
                if measures[0] > best_auroc:
                    best_auroc = measures[0]
                    best_result = [100 * measures[2], 100 * measures[0]]
                    best_layer = layer
                    best_scores = sign_layer * scores
                    best_mean = mean_recorded
                    best_sign = sign_layer
            print('k: ', k, 'best result: ', best_result, 'layer: ', best_layer,
                  'best_auroc: ', best_auroc)
            if best_auroc > best_auroc_over_k:
                best_auroc_over_k = best_auroc
                best_result_over_k = best_result
                best_layer_over_k = best_layer
                best_k = k
                best_sign_over_k = best_sign
                best_scores_over_k = best_scores
                best_mean_over_k = best_mean
        return {'k': best_k,
                'best_layer': best_layer_over_k,
                'best_auroc': best_auroc_over_k,
                'best_result': best_result_over_k,
                'best_scores': best_scores_over_k,
                'best_mean': best_mean_over_k,
                'best_sign': best_sign_over_k,
                }

    # ----------------------------------------- load features for the chosen loc
    feat_loc = args.feat_loc_svd
    if args.most_likely:
        if feat_loc == 1:
            embed_generated = np.load(
                save_dir + info + f'{args.model_name}_gene_embeddings_head_wise.npy',
                allow_pickle=True)
            embed_generated_h = np.load(
                save_dir + info + f'{args.model_name}_gene_embeddings_h_head_wise.npy',
                allow_pickle=True)
            embed_generated_t = np.load(
                save_dir + info + f'{args.model_name}_gene_embeddings_t_head_wise.npy',
                allow_pickle=True)
        elif feat_loc == 2:
            embed_generated = np.load(
                save_dir + info + f'{args.model_name}_gene_embeddings_mlp_wise.npy',
                allow_pickle=True)
            embed_generated_h = np.load(
                save_dir + info + f'{args.model_name}_gene_embeddings_h_mlp_wise.npy',
                allow_pickle=True)
            embed_generated_t = np.load(
                save_dir + info + f'{args.model_name}_gene_embeddings_t_mlp_wise.npy',
                allow_pickle=True)
        else:
            # BUGFIX: was `assert "Not supported!"`, a no-op on a truthy string.
            raise NotImplementedError(f"feat_loc_svd={feat_loc} is not supported")

    feat_indices_wild = []
    feat_indices_eval = []
    for i in range(length):
        if i in wild_q_indices1:
            feat_indices_wild.extend(np.arange(i, i + 1).tolist())
        elif i in wild_q_indices2:
            feat_indices_eval.extend(np.arange(i, i + 1).tolist())
    if feat_loc == 3:
        embed_generated_wild = embed_generated[feat_indices_wild][:, 1:, :]
        embed_generated_eval = embed_generated[feat_indices_eval][:, 1:, :]
        embed_generated_hal, embed_generated_tru = \
            embed_generated_h[feat_indices_wild][:, 1:, :], \
            embed_generated_t[feat_indices_wild][:, 1:, :]
    else:
        embed_generated_wild = embed_generated[feat_indices_wild]
        embed_generated_eval = embed_generated[feat_indices_eval]
        embed_generated_hal, embed_generated_tru = \
            embed_generated_h[feat_indices_eval], embed_generated_t[feat_indices_eval]

    # pick (k, layer, sign) on the validation set
    returned_results = svd_embed_score(embed_generated_eval, gt_label_val,
                                       embed_generated_hal, embed_generated_tru,
                                       1, 15, mean=1, svd=10)

    # ------------------------------------- direct PCA projection on wild data
    pca_model = PCA(n_components=returned_results['k'], whiten=False).fit(
        embed_generated_wild[:, returned_results['best_layer'], :])
    projection = pca_model.components_.T
    if args.weighted_svd:
        projection = pca_model.singular_values_ * projection

    scores = np.mean(np.matmul(embed_generated_wild[:, returned_results['best_layer'], :],
                               projection), -1, keepdims=True)
    assert scores.shape[1] == 1
    best_scores = np.sqrt(np.sum(np.square(scores), axis=1)) * returned_results['best_sign']

    feat_indices_test = []
    for i in range(length):
        if i not in wild_q_indices:
            feat_indices_test.extend(np.arange(1 * i, 1 * i + 1).tolist())
    if feat_loc == 3:
        embed_generated_test = embed_generated[feat_indices_test][:, 1:, :]
    else:
        embed_generated_test = embed_generated[feat_indices_test]

    test_scores = np.mean(np.matmul(embed_generated_test[:, returned_results['best_layer'], :],
                                    projection), -1, keepdims=True)
    assert test_scores.shape[1] == 1
    test_scores = np.sqrt(np.sum(np.square(test_scores), axis=1))
    measures = get_measures(returned_results['best_sign'] * test_scores[gt_label_test == 1],
                            returned_results['best_sign'] * test_scores[gt_label_test == 0],
                            plot=False)
    print_measures(measures[0], measures[1], measures[2], 'direct-projection')

    # -------------------------- linear probe on pseudo-labels from best_scores
    from linear_probe import get_linear_acc

    thresholds = np.linspace(0, 1, num=40)[1:-1]
    auroc_over_thres = []
    for thres_wild in thresholds:
        best_auroc = 0
        best_result = None
        best_layer = 0
        # loop-invariant: the pseudo-label cut does not depend on the layer
        thres_wild_score = np.sort(best_scores)[int(len(best_scores) * thres_wild)]
        for layer in range(len(embed_generated_wild[0])):
            true_wild = embed_generated_wild[:, layer, :][best_scores > thres_wild_score]
            false_wild = embed_generated_wild[:, layer, :][best_scores <= thres_wild_score]
            embed_train = np.concatenate([true_wild, false_wild], 0)
            label_train = np.concatenate(
                [np.ones(len(true_wild)), np.zeros(len(false_wild))], 0)

            best_acc, final_acc, (clf, best_state, best_preds, preds, labels_val), \
                losses_train = get_linear_acc(embed_train, label_train,
                                              embed_train, label_train, 2,
                                              epochs=50, print_ret=True,
                                              batch_size=512, cosine=True,
                                              nonlinear=True, learning_rate=0.05,
                                              weight_decay=0.0003)
            clf.eval()
            output = clf(torch.from_numpy(embed_generated_test[:, layer, :]).cuda())
            pca_wild_score_binary_cls = torch.sigmoid(output)
            pca_wild_score_binary_cls = pca_wild_score_binary_cls.cpu().data.numpy()
            if np.isnan(pca_wild_score_binary_cls).sum() > 0:
                breakpoint()
            measures = get_measures(pca_wild_score_binary_cls[gt_label_test == 1],
                                    pca_wild_score_binary_cls[gt_label_test == 0],
                                    plot=False)
            if measures[0] > best_auroc:
                best_auroc = measures[0]
                best_result = [100 * measures[0]]
                best_layer = layer
        auroc_over_thres.append(best_auroc)
        print('thres: ', thres_wild, 'best result: ', best_result,
              'best_layer: ', best_layer)


if __name__ == '__main__':
    seed_everything(42)
    main()