# haloscope/steer_vector.py

import os
import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import argparse
from train_utils import get_last_non_padded_token_rep, compute_ot_loss_cos, update_centroids_ema, update_centroids_ema_hard, get_ex_data, collate_fn
from transformers import AutoTokenizer, AutoModelForCausalLM
from llm_layers import add_sv_layers
from sklearn.metrics import roc_auc_score
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from sinkhorn_knopp import SinkhornKnopp_imb
import logging
def seed_everything(seed: int):
import random, os
import numpy as np
import torch
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # disable benchmarking so runs stay deterministic
def train_model(model, optimizer, device, prompts, labels, args):
"""
对应论文里的两阶段训练:
- Phase 1: 使用人工标注的 exemplar 做 initial training式 (3)(4)(5)(6)
- Phase 2: 使用 OT + Sinkhorn 选出来的伪标签数据做 self-trainingSec.3.3 pseudo-labeling
参数说明:
- model: 冻结好的 LLM + 已经插入 TSV 层的模型(只训练 TSV
- optimizer: 只优化 TSV 的 AdamW
- device: cuda
- prompts: [test_prompts, train_prompts, exemplar_prompts]
- labels: 对应三块的标签 [test_labels, train_labels, exemplar_labels]
- args: 超参数学习率、epoch 数、Sinkhorn 参数等)
"""
layer_number = -1 # 使用最后一层 hidden state这里 -1 当索引)
# ========= Logging & results directory =========
dir_name = f"TSV_{args.model_name}_{args.dataset_name}/exemplar_num_{args.num_exemplars}_num_selected_data_{args.num_selected_data}/{args.component}/{args.str_layer}/{args.lam}"
log_dir = dir_name
log_file = os.path.join(log_dir, "log.txt")
os.makedirs(dir_name, exist_ok=True)
logging.basicConfig(
filename=log_file,
filemode="w",
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
logging.info("Starting training")
logging.info(
f"Training parameters: few_shot_size={args.num_exemplars}, "
f"num_selected_data={args.num_selected_data}, "
f"component={args.component}, str_layer={args.str_layer}"
)
# Unpack the data: test / wild train / exemplar
test_prompts, train_prompts, exemplar_prompts = prompts[0], prompts[1], prompts[2]
test_labels, train_labels, exemplar_labels = labels[0], labels[1], labels[2]
batch_size = args.batch_size
losses = []
best_test_auroc = -1
scaler = GradScaler()  # gradient scaler for mixed-precision training
num_exemplars = args.num_exemplars
# ========= Sinkhorn OT setup (used in the pseudo-labeling phase) =========
args.num_iters_sk = 3  # number of Sinkhorn iterations (paper Eqs. (8)-(9))
args.epsilon_sk = 0.05  # entropic regularization epsilon for OT (also 0.05 in the paper)
# Estimate the class prior w (truthful / hallucinated) from the exemplar labels:
# ex_hallu = P(hallucinated), ex_true = P(truthful)
ex_hallu = (num_exemplars - exemplar_labels[:num_exemplars].sum()) / num_exemplars
ex_true = (exemplar_labels[:num_exemplars].sum()) / num_exemplars
cls_dist = torch.tensor([ex_hallu, ex_true]).float().cuda()
cls_dist = cls_dist.view(-1, 1)  # shape [2, 1]
# Instantiate Sinkhorn with class-marginal constraints (solves the OT problem from the paper)
sinkhorn = SinkhornKnopp_imb(args, cls_dist)
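# Rough sketch of the class-marginal-constrained Sinkhorn step this object presumably runs
# (an assumption for illustration only; the actual code is in sinkhorn_knopp.py):
#   Q = exp(similarities / epsilon_sk)                  # [N, 2] affinities to the prototypes
#   for _ in range(num_iters_sk):
#       Q *= (cls_dist.T / Q.sum(dim=0, keepdim=True))  # match the class marginals w
#       Q /= Q.sum(dim=1, keepdim=True)                 # each sample's assignment sums to 1
#   # rows of Q are the soft assignments q later used as pseudo-labels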
# ========= Initialize the two class "prototype vectors" mu_c =========
# centroids are the on-sphere prototypes mu_truthful, mu_hallucinated from the paper
centroids = torch.randn((2, model.config.hidden_size)).half().cuda()
centroids = F.normalize(centroids, p=2, dim=1)
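# The prototypes are kept on the unit hypersphere, so the cosine similarities used below
# reduce to plain dot products between L2-normalized vectors.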
# Collate the exemplar prompts/labels into a single padded batch
exemplar_prompts_, exemplar_labels_ = exemplar_prompts, exemplar_labels
exemplar_prompts, exemplar_labels = collate_fn(exemplar_prompts, exemplar_labels)
# ========= Phase 1: Initial training on exemplars =========
num_epochs = args.init_num_epochs
for epoch in range(num_epochs):
running_loss = 0.0
total = 0
num_samples = num_exemplars
# Mini-batch training on the exemplars (D_E, the human-annotated set in the paper)
for batch_start in tqdm(
range(0, num_samples, batch_size),
desc=f"Epoch {epoch+1}/{num_epochs} Batches",
leave=False,
):
batch_prompts = exemplar_prompts[batch_start: batch_start + batch_size]
batch_labels = exemplar_labels[batch_start: batch_start + batch_size]
# attention_mask: 1 = valid token, 0 = padding
attention_mask = (batch_prompts != 0).half()
batch_prompts = batch_prompts.to(device)
batch_labels = batch_labels.to(device)
attention_mask = attention_mask.to(batch_prompts.device)
# ======= Forward pass: take the last-layer representation r_v of the last non-padding token =======
with autocast(dtype=torch.float16):
output = model(
batch_prompts.squeeze(),
attention_mask=attention_mask.squeeze(),
output_hidden_states=True,
)
hidden_states = output.hidden_states # tuple(len=L+1)
hidden_states = torch.stack(hidden_states, dim=0).squeeze()
last_layer_hidden_state = hidden_states[layer_number] # [B, T, H]
# r_v from the paper: hidden state of the last non-padding token (includes the TSV steering)
last_token_rep = get_last_non_padded_token_rep(
last_layer_hidden_state, attention_mask.squeeze()
)
# Convert labels to one-hot vectors (2 classes)
batch_labels_oh = torch.nn.functional.one_hot(
batch_labels, num_classes=2
)
# compute_ot_loss_cos corresponds to Eq. (5), in its OT-compatible form:
# - a softmax-style classification loss over cosine similarities to the centroids
# - internally uses args.cos_temp for temperature scaling
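# Sketch of the prototype cross-entropy it presumably computes (assumption; see train_utils.py):
#   logits = (normalize(r_v) @ normalize(centroids).T) / cos_temp   # [B, 2]
#   loss = -(labels_one_hot * log_softmax(logits, dim=-1)).sum(dim=1).mean()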
ot_loss, similarities = compute_ot_loss_cos(
last_token_rep, centroids, batch_labels_oh, batch_size, args
)
loss = ot_loss
total += batch_labels.size(0)
# ====== Update the class prototypes mu_c via EMA (paper Eq. (6)) ======
with torch.no_grad():
centroids = update_centroids_ema_hard(
centroids, last_token_rep, batch_labels_oh, args
)
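# Presumably the hard-label EMA update amounts to (sketch; details in train_utils.py):
#   mu_c <- normalize(ema_decay * mu_c + (1 - ema_decay) * mean({r_v : label(v) = c}))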
# ====== Backpropagate only through the TSV (the base LLM stays frozen) ======
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
running_loss += loss.item() * batch_labels.size(0)
# Average loss over the epoch
epoch_loss = running_loss / total
# ====== Evaluate on the test set and track AUROC ======
if (epoch + 1) % 1 == 0:
test_labels_ = test_labels
test_predictions, test_labels_combined = test_model(
model, centroids, test_prompts, test_labels_, device, batch_size, layer_number
)
test_auroc = roc_auc_score(
test_labels_combined.cpu().numpy(), test_predictions.cpu().numpy()
)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
logging.info(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
losses.append(epoch_loss)
if test_auroc > best_test_auroc:
best_test_auroc = test_auroc
best_test_epoch = epoch
print(f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}")
logging.info(
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
)
logging.info(
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
)
logging.info(f"Test AUROC: {test_auroc:.4f}")
print(f"Epoch [{epoch+1}/{num_epochs}],Test AUROC: {test_auroc:.4f}")
# ========= Phase 2: pseudo-labeling + self-training =========
logging.info("Semi-supervised self-training starts")
# 1) get_ex_data: run TSV-steered inference + Sinkhorn OT on the wild train split and pick high-confidence pseudo-labeled samples
with torch.no_grad():
selected_indices, selected_labels_soft = get_ex_data(
model,
train_prompts,  # samples from the wild (unlabeled) split
train_labels,  # usually unlabeled / dummy; not used here
batch_size,
centroids,  # current prototypes
sinkhorn,  # OT + Sinkhorn solver used to compute Q
args.num_selected_data,  # how many pseudo-labeled samples to select
cls_dist,
args,
)
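# get_ex_data scores every wild sample against the current prototypes, solves the OT problem
# with Sinkhorn to obtain the assignment matrix Q, and returns the indices and soft labels of
# the most confident samples (interface sketch; the details live in train_utils.py).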
num_samples = len(selected_indices) + args.num_exemplars
num_epochs = args.aug_num_epochs  # number of self-training epochs
exemplar_label = exemplar_labels.clone().detach().cuda()
# 2) Build the augmented training set: pseudo-labeled samples + the original exemplars
selected_prompts = [train_prompts[i] for i in selected_indices]
selected_labels = selected_labels_soft  # already soft labels (the q returned by OT)
augmented_prompts = selected_prompts + exemplar_prompts_
exemplar_labels = torch.nn.functional.one_hot(
exemplar_label.to(torch.int64), num_classes=2
)
augmented_labels = torch.concat(
(selected_labels, exemplar_labels.clone().cuda())
)
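# The augmented label matrix is [N, 2]: soft OT assignments for the pseudo-labeled rows,
# one-hot vectors for the exemplar rows.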
augmented_prompts_train = augmented_prompts
augmented_labels_label = augmented_labels
num_samples = len(augmented_prompts_train)
# 3) Train the TSV again on "pseudo-labels + exemplars" (self-training)
with autocast(dtype=torch.float16):
for epoch in range(num_epochs):
running_loss = 0.0
total = 0
all_labels = []
for batch_start in tqdm(
range(0, num_samples, batch_size),
desc=f"Epoch {epoch+1}/{num_epochs} Batches",
leave=False,
):
batch_prompts = augmented_prompts_train[batch_start: batch_start + batch_size]
batch_labels = augmented_labels_label[batch_start: batch_start + batch_size]
batch_prompts, batch_labels = collate_fn(batch_prompts, batch_labels)
attention_mask = (batch_prompts != 0).half()
batch_prompts = batch_prompts.to(device)
batch_labels = batch_labels.to(device)
attention_mask = attention_mask.to(batch_prompts.device)
output = model(
batch_prompts.squeeze(),
attention_mask=attention_mask.squeeze(),
output_hidden_states=True,
)
hidden_states = output.hidden_states
hidden_states = torch.stack(hidden_states, dim=0).squeeze()
last_layer_hidden_state = hidden_states[layer_number]
last_token_rep = get_last_non_padded_token_rep(
last_layer_hidden_state, attention_mask.squeeze()
)
# Here compute_ot_loss_cos receives soft labels (the q obtained from OT)
ot_loss, similarities = compute_ot_loss_cos(
last_token_rep, centroids, batch_labels, batch_size, args
)
loss = ot_loss
with torch.no_grad():
# In the self-training phase, prototypes are updated with the soft-label version of the EMA
centroids = update_centroids_ema(
centroids, last_token_rep, batch_labels.half(), args
)
all_labels.append(batch_labels.cpu())
total += batch_labels.size(0)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
running_loss += loss.item() * batch_labels.size(0)
epoch_loss = running_loss / total
# ====== Evaluate AUROC on the test set ======
with torch.no_grad():
all_labels = torch.cat(all_labels).numpy()
test_labels_ = test_labels
if epoch % 1 == 0:
test_predictions, test_labels_combined = test_model(
model,
centroids,
test_prompts,
test_labels_,
device,
batch_size,
layer_number,
)
test_auroc = roc_auc_score(test_labels_combined, test_predictions)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
losses.append(epoch_loss)
if test_auroc > best_test_auroc:
best_test_auroc = test_auroc
best_test_epoch = epoch + args.init_num_epochs
print(
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
)
logging.info(
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
)
logging.info(
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
)
logging.info(
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
)
return best_test_auroc
def test_model(model, centroids, test_prompts, test_labels, device, batch_size, layer_number):
"""
在论文里相当于:
- 对 test 集算当前 TSV steer 后的 r_v
- 计算与两个原型 μ_c 的余弦相似度
- 用 softmax 得到属于 truthful 的概率 p(c=truthful | r_v)
- 用这个概率做 AUROC 评估
"""
model.eval()
val_predictions = [] # 保存 p_truthful
val_labels_combined = [] # 对应的 gt二分类标签
num_val_samples = len(test_prompts)
with torch.no_grad():
with autocast(dtype=torch.float16):
for batch_start in range(0, num_val_samples, batch_size):
batch_prompts = test_prompts[batch_start:batch_start + batch_size]
batch_labels = test_labels[batch_start:batch_start + batch_size]
batch_prompts, batch_labels = collate_fn(batch_prompts, batch_labels)
attention_mask = (batch_prompts != 0).half().to(device)
batch_prompts = batch_prompts.to(device)
batch_labels = batch_labels.to(device)
# Forward pass; use the hidden state of the last non-padding token as r_v
output = model(
batch_prompts.squeeze(),
attention_mask=attention_mask.squeeze(),
output_hidden_states=True,
)
hidden_states = output.hidden_states
hidden_states = torch.stack(hidden_states, dim=0).squeeze()
last_layer_hidden_state = hidden_states[layer_number]
last_token_rep = get_last_non_padded_token_rep(
last_layer_hidden_state, attention_mask.squeeze()
)
# L2-normalize the representation and the prototypes onto the unit sphere
last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
centroids = F.normalize(centroids, p=2, dim=-1)
with autocast(dtype=torch.float16):
similarities = torch.matmul(last_token_rep, centroids.T) # [B, 2]
# similarity / temperature -> softmax -> take the "truthful" dimension as the probability
similarity_scores = torch.softmax(similarities / 0.1, dim=-1)
similarity_scores = similarity_scores[:, 1]  # assumes index 1 = truthful
val_predictions.append(similarity_scores.cpu())
val_labels_combined.append(batch_labels.cpu())
val_predictions = torch.cat(val_predictions)
val_labels_combined = torch.cat(val_labels_combined)
return val_predictions, val_labels_combined
HF_NAMES = {
'llama3.1-8B': 'meta-llama/Meta-Llama-3.1-8B',
'qwen2.5-7B': 'Qwen/Qwen2.5-7B'
}
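# Short CLI model names mapped to Hugging Face Hub ids; extend this dict to add other backbones.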
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='llama3.1-8B')
parser.add_argument('--model_prefix', type=str, default='', help='prefix of model name')
parser.add_argument('--num_gene', type=int, default=1)
parser.add_argument('--gene', type=int, default=0)
parser.add_argument('--generate_gt', type=int, default=0)
parser.add_argument('--dataset_name', type=str, default='tqa')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--wild_ratio', type=float, default=0.75)
parser.add_argument('--thres_gt', type=float, default=0.5)
parser.add_argument('--most_likely', type=int, default=0)
parser.add_argument("--model_dir", type=str, default=None, help='local directory with model data')
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--cos_temp", type=float, default=0.1)
parser.add_argument("--ema_decay", type=float, default=0.99)
parser.add_argument("--lr", type=float, default=0.005)
parser.add_argument("--str_layer", type=int, default=9)
parser.add_argument("--component", type=str, default='res')
parser.add_argument("--lam", type=float, default=5)
parser.add_argument("--init_num_epochs", type=int, default=20)
parser.add_argument("--aug_num_epochs", type=int, default=20)
parser.add_argument("--num_exemplars", type=int, default=32)
parser.add_argument("--num_selected_data", type=int, default=128)
parser.add_argument("--cls_dist", type=str, default='proxy')
parser.add_argument("--optimizer", type=str, default='AdamW')
parser.add_argument("--num_iters_sk", type=int, default=3)
parser.add_argument("--epsilon_sk", type=float, default=0.05)
args = parser.parse_args()
model_name_or_path = HF_NAMES[args.model_prefix + args.model_name]
if args.dataset_name == "tqa":
dataset = load_dataset("truthful_qa", 'generation')['validation']
elif args.dataset_name == 'triviaqa':
dataset = load_dataset("trivia_qa", "rc.nocontext", split="validation")
id_mem = set()
def remove_dups(batch):
if batch['question_id'][0] in id_mem:
return {_: [] for _ in batch.keys()}
id_mem.add(batch['question_id'][0])
return batch
dataset = dataset.map(remove_dups, batch_size=1, batched=True, load_from_cache_file=False)
elif args.dataset_name == 'sciq':
dataset = load_dataset("allenai/sciq", split="validation")
elif args.dataset_name == 'nq_open':
dataset = load_dataset("google-research-datasets/nq_open", split="validation")
else:
raise ValueError("Invalid dataset name")
if args.gene:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token = '')
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto", token = '')
device = torch.device("cuda")
all_decoded_answers = []
begin_index = 0
end_index = len(dataset)
os.makedirs(f'./save_for_eval/{args.dataset_name}_hal_det/answers', exist_ok=True)
period_token_id = [tokenizer(_)['input_ids'][-1] for _ in ['\n']]
period_token_id += [tokenizer.eos_token_id]
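# Candidate stop-token ids (newline and EOS); note they are not actually passed to model.generate below.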
for i in range(begin_index, end_index):
answers = [None] * args.num_gene
answers_ = [None] * args.num_gene
question = dataset[i]['question']
prompt = tokenizer(f"Answer the question concisely. Q: {question}" + " A:", return_tensors='pt').input_ids.cuda()
for gen_iter in range(args.num_gene):
if args.most_likely:
generated = model.generate(prompt,
num_beams=5,
num_return_sequences=1,
do_sample=False,
max_new_tokens=64,
)
else:
generated = model.generate(prompt,
do_sample=True,
num_return_sequences=1,
num_beams=1,
max_new_tokens=64,
temperature=0.5,
top_p=1.0)
decoded = tokenizer.decode(generated[0, prompt.shape[-1]:],
skip_special_tokens=True)
# answers[gen_iter] = decoded
# Cleaning
# Truncate the answer at any spurious continuation the model sometimes appends.
for stop_marker in ['\nAnswer the question concisely.', 'Answer the question concisely',
                    'The answer to the question', 'How to Write a Concise Statement',
                    'Q:', '\nYou are an AI assistant', 'You are an AI assistant',
                    'A:', 'B:', 'C:', 'D:']:
    if stop_marker in decoded:
        print('#####error')
        print(decoded.split(stop_marker)[1])
        print('#####error')
        decoded = decoded.split(stop_marker)[0]
print(f'Cleaned Answer: {decoded}')
answers[gen_iter] = decoded
print('sample: ', i)
if args.most_likely:
info = 'most_likely_'
else:
info = 'batch_generations_'
print("Saving answers")
print(decoded)
np.save(f'./save_for_eval/{args.dataset_name}_hal_det/answers/' + info + f'hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy',
answers)
elif args.generate_gt:
from bleurt_pytorch import BleurtForSequenceClassification, BleurtTokenizer
model = BleurtForSequenceClassification.from_pretrained('lucadiliello/BLEURT-20').cuda()
tokenizer = BleurtTokenizer.from_pretrained('lucadiliello/BLEURT-20')
model.eval()
gts = np.zeros(0)
length = len(dataset)
for i in range(length):
if args.dataset_name == 'tqa':
best_answer = dataset[i]['best_answer']
correct_answer = dataset[i]['correct_answers']
all_answers = [best_answer] + correct_answer
question = dataset[i]['question']
elif args.dataset_name == 'triviaqa':
all_answers = dataset[i]['answer']['aliases']
if args.most_likely:
answers = np.load(
f'./save_for_eval/{args.dataset_name}_hal_det/answers/most_likely_hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy')
else:
answers = np.load(
f'./save_for_eval/{args.dataset_name}_hal_det/answers/batch_generations_hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy')
# get the gt.
predictions = answers
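# Score every generated answer with BLEURT against each reference answer; the maximum
# score over the references is kept as that answer's ground-truth quality score.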
all_results = np.zeros((len(all_answers), len(predictions)))
with torch.no_grad():
for anw in range(len(all_answers)):
inputs = tokenizer(predictions.tolist(), [all_answers[anw]] * len(predictions),
padding='longest', return_tensors='pt')
for key in list(inputs.keys()):
inputs[key] = inputs[key].cuda()
res = np.asarray(model(**inputs).logits.flatten().tolist())
all_results[anw] = res
gts = np.concatenate([gts, np.max(all_results, axis=0)], 0)
if i % 10 == 0:
print("samples passed: ", i)
if args.most_likely:
np.save(f'./ml_{args.dataset_name}_bleurt_score.npy', gts)
else:
np.save(f'./bg_{args.dataset_name}_bleurt_score.npy', gts)
else:
device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto", token = '')
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token = '')
prompts = []
qa_pairs = []
categories = []
length = len(dataset)
for i in tqdm(range(length)):
question = dataset[i]['question']
if args.dataset_name == 'tqa':
categories.append(dataset[i]['category'])
answers = np.load(
f'./save_for_eval/{args.dataset_name}_hal_det/answers/most_likely_hal_det_{args.model_name}_{args.dataset_name}_answers_index_{i}.npy')
for anw in answers:
prompt = tokenizer(
f"Answer the question concisely. Q: {question}" + " A:" + anw,
return_tensors='pt').input_ids.cuda()
prompts.append(prompt)
qa_pairs.append({'Question': question, 'Answer': anw})
gts = np.load(f'./ml_{args.dataset_name}_bleurt_score.npy')
length = len(dataset)
if args.dataset_name == 'tqa' or args.dataset_name == 'triviaqa':
args.thres_gt = 0.5
else:
args.thres_gt = 0.2
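# An answer is labeled truthful (1) when its best BLEURT score exceeds the dataset-specific threshold.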
gt_label = np.asarray(gts > args.thres_gt, dtype=np.int32)
# index = np.random.permutation(length)
# exemplar_index = index[:args.num_exemplars]
# wild_q_indices = index[:int(args.wild_ratio * length)]
index = np.load(f'data_indices/data_index_{args.dataset_name}.npy')
exemplar_index = np.load(f'data_indices/exemplar_idx_{args.dataset_name}.npy')
wild_q_indices = index[:int(args.wild_ratio * length)]
wild_q_indices1 = wild_q_indices[:len(wild_q_indices) - 100]
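# Data split: questions outside the wild pool form the test set; inside the pool, the
# pre-selected exemplar indices are the labeled set, and the remaining wild questions
# (excluding the last 100, which are held out) form the unlabeled training set.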
args.num_exemplars = len(exemplar_index)
gt_label_test = []
gt_label_wild = []
gt_label_exemplar = []
test_prompts = []
train_prompts = []
exemplar_prompts = []
for i in range(length):
if i not in wild_q_indices:
gt_label_test.extend(gt_label[i: i+1])
test_prompts.extend(prompts[i:i+1])
elif i in exemplar_index:
gt_label_exemplar.extend(gt_label[i: i+1])
exemplar_prompts.extend(prompts[i:i+1])
elif i in wild_q_indices1:
gt_label_wild.extend(gt_label[i: i+1])
train_prompts.extend(prompts[i:i+1])
gt_label_test = np.asarray(gt_label_test)
gt_label_exemplar = np.asarray(gt_label_exemplar)
gt_label_wild = np.asarray(gt_label_wild)
labels = [ gt_label_test, gt_label_wild, gt_label_exemplar]
prompts = [ test_prompts, train_prompts, exemplar_prompts]
num_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
for param in model.parameters():
param.requires_grad = False
tsv = nn.ParameterList(
[nn.Parameter(torch.zeros(hidden_size), requires_grad=True) for _ in range(num_layers)])
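# One learnable steering vector per decoder layer, zero-initialized so training starts from
# the unmodified model; these TSV parameters are the only weights being optimized.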
tsv.to(device)
add_sv_layers(model, tsv, [args.lam], args)
optimizer = torch.optim.AdamW(list(tsv.parameters()), lr=args.lr)
train_model(model, optimizer, device, prompts, labels, args=args)
if __name__ == '__main__':
seed_everything(42)
main()