push to remote

This commit is contained in:
weixin_43297441 2025-12-01 21:40:08 +08:00
parent 9b8eca3c38
commit 5f390c1bb6
2 changed files with 213 additions and 159 deletions

View File

@ -11,7 +11,6 @@ from llm_layers import add_sv_layers
from sklearn.metrics import roc_auc_score
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from sinkhorn_knopp import SinkhornKnopp_imb
import logging
from utils import load_npy_shapes,split_indices
@ -81,13 +80,13 @@ def train_model(model, optimizer, device, prompts, labels, args):
# 用 exemplar 的标签估计类分布(truthful/hallu 的先验)
# ex_hallu = P(hallucinated), ex_true = P(truthful)
ex_hallu = (num_exemplars - exemplar_labels[:num_exemplars].sum()) / num_exemplars
ex_true = (exemplar_labels[:num_exemplars].sum()) / num_exemplars
cls_dist = torch.tensor([ex_hallu, ex_true]).float().cuda()
cls_dist = cls_dist.view(-1, 1) # 形状 [2, 1]
# ex_hallu = (num_exemplars - exemplar_labels[:num_exemplars].sum()) / num_exemplars
# ex_true = (exemplar_labels[:num_exemplars].sum()) / num_exemplars
# cls_dist = torch.tensor([ex_hallu, ex_true]).float().cuda()
# cls_dist = cls_dist.view(-1, 1) # 形状 [2, 1]
# 实例化带类别边缘约束的 Sinkhorn实现论文里的 OT 问题)
sinkhorn = SinkhornKnopp_imb(args, cls_dist)
# sinkhorn = SinkhornKnopp_imb(args, cls_dist)
# ========= 初始化两个类别的“原型向量” μ_c =========
# 这里 centroids 就是论文中球面上的 μ_truthful, μ_hallucinated
@ -160,7 +159,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
centroids, last_token_rep, batch_labels_oh, args
)
# ====== 只对 TSV 反传LLM 本体是冻结的 ======
# ====== 只对 SV 反传LLM 本体是冻结的 ======
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
@ -174,12 +173,12 @@ def train_model(model, optimizer, device, prompts, labels, args):
# ====== 在 test 上评估,记录 AUROC ======
if (epoch + 1) % 1 == 0:
test_labels_ = test_labels
test_predictions, test_labels_combined = test_model(
test_predictions, test_labels= test_model(
model, centroids, test_prompts, test_labels_, device, batch_size, layer_number
)
test_auroc = roc_auc_score(
test_labels_combined.cpu().numpy(), test_predictions.cpu().numpy()
test_labels.cpu().numpy(), test_predictions.cpu().numpy()
)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
@ -193,6 +192,12 @@ def train_model(model, optimizer, device, prompts, labels, args):
logging.info(
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
)
# 保存最佳AUROC时的centroids到npy文件
centroids_np = centroids.cpu().numpy()
centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy")
np.save(centroids_file, centroids_np)
print(f"Saved best centroids to {centroids_file}")
logging.info(f"Saved best centroids to {centroids_file}")
logging.info(
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
@ -200,143 +205,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
logging.info(f"Test AUROC: {test_auroc:.4f}")
print(f"Epoch [{epoch+1}/{num_epochs}],Test AUROC: {test_auroc:.4f}")
# ========= 进入 Phase 2伪标签 + 自训练 =========
logging.info(f"SS Learning Starts")
# 1) get_ex_data: 在 wild train 上做 TSV 推断 + Sinkhorn OT选出高置信伪标签样本
with torch.no_grad():
selected_indices, selected_labels_soft = get_ex_data(
model,
train_prompts, # wild 区间的样本
train_labels, # 这里一般是无标签 or dummy这里暂时没用
batch_size,
centroids, # 当前原型
sinkhorn, # OT + Sinkhorn 对象,用来求 Q
args.num_selected_data, # 选多少个伪标签样本
cls_dist,
args,
)
num_samples = len(selected_indices) + args.num_exemplars
num_epochs = args.aug_num_epochs # 自训练的 epoch 数
exemplar_label = torch.tensor(exemplar_labels).cuda()
# 2) 构造“增强后的训练集”:伪标签样本 + 原来的 exemplar
selected_prompts = [train_prompts[i] for i in selected_indices]
selected_labels = selected_labels_soft # 这里已经是 soft labelOT 输出的 q
augmented_prompts = selected_prompts + exemplar_prompts_
exemplar_labels = torch.nn.functional.one_hot(
exemplar_label.to(torch.int64), num_classes=2
)
augmented_labels = torch.concat(
(selected_labels, torch.tensor(exemplar_labels).clone().cuda())
)
augmented_prompts_train = augmented_prompts
augmented_labels_label = augmented_labels
num_samples = len(augmented_prompts_train)
# 3) 在“伪标签 + exemplar”上再训练一轮 TSV自训练
with autocast(dtype=torch.float16):
for epoch in range(num_epochs):
running_loss = 0.0
total = 0
all_labels = []
for batch_start in tqdm(
range(0, num_samples, batch_size),
desc=f"Epoch {epoch+1}/{num_epochs} Batches",
leave=False,
):
batch_prompts = augmented_prompts_train[batch_start: batch_start + batch_size]
batch_labels = augmented_labels_label[batch_start: batch_start + batch_size]
batch_prompts, batch_labels = collate_fn(batch_prompts, batch_labels)
attention_mask = (batch_prompts != 0).half()
batch_prompts = batch_prompts.to(device)
batch_labels = batch_labels.to(device)
attention_mask = attention_mask.to(batch_prompts.device)
output = model(
batch_prompts.squeeze(),
attention_mask=attention_mask.squeeze(),
output_hidden_states=True,
)
hidden_states = output.hidden_states
hidden_states = torch.stack(hidden_states, dim=0).squeeze()
last_layer_hidden_state = hidden_states[layer_number]
last_token_rep = get_last_non_padded_token_rep(
last_layer_hidden_state, attention_mask.squeeze()
)
# 这里 compute_ot_loss_cos 接收的是 soft label (OT 得到的 q)
ot_loss, similarities = compute_ot_loss_cos(
last_token_rep, centroids, batch_labels, batch_size, args
)
loss = ot_loss
with torch.no_grad():
# 自训练阶段,原型更新用的是 soft label 版的 EMA
centroids = update_centroids_ema(
centroids, last_token_rep, batch_labels.half(), args
)
all_labels.append(batch_labels.cpu())
total += batch_labels.size(0)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
running_loss += loss.item() * batch_labels.size(0)
epoch_loss = running_loss / total
# ====== 用 test set 评估 AUROC ======
with torch.no_grad():
all_labels = torch.cat(all_labels).numpy()
test_labels_ = test_labels
if epoch % 1 == 0:
test_predictions, test_labels_combined = test_model(
model,
centroids,
test_prompts,
test_labels_,
device,
batch_size,
layer_number,
)
test_auroc = roc_auc_score(test_labels_combined, test_predictions)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
losses.append(epoch_loss)
if test_auroc > best_test_auroc:
best_test_auroc = test_auroc
best_test_epoch = epoch + args.init_num_epochs
print(
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
)
logging.info(
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
)
logging.info(
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
)
logging.info(
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
)
return best_test_auroc
@ -345,14 +214,14 @@ def train_model(model, optimizer, device, prompts, labels, args):
def test_model(model, centroids, test_prompts, test_labels, device, batch_size, layer_number):
"""
在论文里相当于
- test 集算当前 TSV steer 后的 r_v
- test 集算当前 TV steer 后的 r_v
- 计算与两个原型 μ_c 的余弦相似度
- softmax 得到属于 truthful 的概率 p(c=truthful | r_v)
- 用这个概率做 AUROC 评估
"""
model.eval()
val_predictions = [] # 保存 p_truthful
val_labels_combined = [] # 对应的 gt(二分类标签)
val_predictions = [] # 保存 p
val_labels = [] # 对应的(二分类标签)
num_val_samples = len(test_prompts)
@ -392,11 +261,11 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size,
similarity_scores = similarity_scores[:, 1] # 假设 index 1 = truthful
val_predictions.append(similarity_scores.cpu())
val_labels_combined.append(batch_labels.cpu())
val_labels.append(batch_labels.cpu())
val_predictions = torch.cat(val_predictions)
val_labels_combined = torch.cat(val_labels_combined)
return val_predictions, val_labels_combined
val_labels = torch.cat(val_labels)
return val_predictions, val_labels
@ -477,7 +346,7 @@ def main():
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token = '')
prompts = []
prompts_ = []
qa_pairs = []
categories = []
@ -491,9 +360,9 @@ def main():
return_tensors='pt'
).input_ids.cuda()
prompts.append(prompt)
prompts_.append(prompt)
qa_pairs.append({'Question': question, 'Answer': adversary})
categories.append(1) # 1 = adverse
categories.append(0) # 0 = adverse
for i in tqdm(range(length)):
question = dataset[i]['query']
@ -503,12 +372,14 @@ def main():
return_tensors='pt'
).input_ids.cuda()
prompts.append(prompt)
prompts_.append(prompt)
qa_pairs.append({'Question': question, 'Answer': clean})
categories.append(0) # 0 = benign
categories.append(1) # 1 = benign
train_index, val_index, test_index=split_indices(len(prompts), args.train_ratio, args.val_ratio)
train_index, val_index, test_index=split_indices(len(prompts_), args.train_ratio, args.val_ratio)
labels = [categories[test_index], categories[train_index]]
prompts = [prompts_[test_index], prompts_[train_index]]
@ -532,7 +403,7 @@ def main():
optimizer = torch.optim.AdamW(list(sv.parameters()), lr=args.lr)
# ====== 6.5 调用 train_model进入两阶段训练流程 ======
train_model(model, optimizer, device, train_data, train_labels, args=args)
train_model(model, optimizer, device, prompts, labels, args=args)
else:
print("Skip training steer vectors.")

183
train_utils.py Normal file
View File

@ -0,0 +1,183 @@
import torch
from tqdm import tqdm
from torch.cuda.amp import autocast
import torch.nn.functional as F
def collate_fn(prompts, labels):
    """Pad a list of [1, seq_len] prompt tensors to a common length and batch them.

    Args:
        prompts: list of tensors, each of shape [1, seq_len]; all are assumed
            to live on the same device and share one dtype.
        labels: sequence of per-prompt integer class labels.

    Returns:
        (prompts_padded, labels):
            prompts_padded -- [batch, 1, max_seq_len] tensor, right-padded with
                zeros (0 doubles as the pad id that downstream attention masks
                test with ``!= 0``).
            labels -- LongTensor of shape [batch] on the prompts' device.
    """
    # The longest sequence in the batch determines the padded width.
    max_seq_len = max(prompt.size(1) for prompt in prompts)
    batch_size = len(prompts)
    dtype = prompts[0].dtype
    device = prompts[0].device  # Assuming all prompts are on the same device
    # Fix: allocate the padded batch on the prompts' device. The original
    # captured `device` but left this tensor on CPU, which was inconsistent
    # with `labels` below and forced an extra host->device transfer later.
    prompts_padded = torch.zeros(batch_size, 1, max_seq_len, dtype=dtype, device=device)
    # Right-pad each prompt to the maximum sequence length.
    for i, prompt in enumerate(prompts):
        seq_len = prompt.size(1)
        prompts_padded[i, :, :seq_len] = prompt
    # Stack labels into a tensor on the same device as the prompts.
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return prompts_padded, labels
def get_last_non_padded_token_rep(hidden_states, attention_mask):
    """Return the hidden state of each sequence's last non-padded token.

    Args:
        hidden_states: [batch, max_seq_len, hidden] hidden states.
        attention_mask: mask whose squeezed form is [batch, max_seq_len];
            1 marks real tokens, 0 marks padding. Sequences are assumed
            right-padded, so real tokens form a prefix.

    Returns:
        [batch, hidden] tensor where row i == hidden_states[i, len_i - 1, :].
    """
    # Sequence lengths: sum of the attention mask (1 per real token).
    lengths = attention_mask.squeeze().sum(dim=1).long()
    batch_size = hidden_states.size(0)
    # Vectorized gather of the last real token per row. Replaces the original
    # per-row Python loop + torch.stack with one advanced-indexing op — same
    # result, no Python-level iteration.
    rows = torch.arange(batch_size, device=hidden_states.device)
    return hidden_states[rows, lengths - 1, :]
def get_ex_data(model, prompts, labels, batch_size, centroids, sinkhorn, num_selected_data, cls_dist, args):
    """Embed unlabeled prompts with the (frozen) model, assign soft labels via
    the Sinkhorn OT solver, and select the most confident samples.

    Returns:
        (selected_indices, selected_labels_soft): indices into `prompts` and
        the corresponding soft (OT) pseudo-label rows.
    """
    embeddings = []
    gathered_labels = []
    total = len(prompts)
    with torch.no_grad(), autocast(dtype=torch.float16):
        for start in tqdm(range(0, total, batch_size)):
            chunk_prompts = prompts[start: start + batch_size]
            chunk_labels = labels[start: start + batch_size]
            chunk_prompts, chunk_labels = collate_fn(chunk_prompts, chunk_labels)
            mask = (chunk_prompts != 0).half()
            chunk_prompts = chunk_prompts.cuda()
            chunk_labels = chunk_labels.cuda()
            mask = mask.to(chunk_prompts.device)
            gathered_labels.append(chunk_labels.cpu().numpy())
            out = model(chunk_prompts.squeeze(), attention_mask=mask.squeeze(), output_hidden_states=True)
            stacked = torch.stack(out.hidden_states, dim=0).squeeze()
            # Final layer only (note: training elsewhere may use a fixed
            # layer index; this helper always takes the last layer).
            final_layer = stacked[-1]
            embeddings.append(get_last_non_padded_token_rep(final_layer, mask.squeeze()))
    # Project all embeddings onto the unit sphere before OT / selection.
    reps = F.normalize(torch.concat(embeddings), p=2, dim=-1)
    pseudo_label = sinkhorn(reps, centroids)
    selected_indices = compute_entropy(reps, centroids, pseudo_label, num_selected_data, cls_dist, args)
    selected_labels_soft = pseudo_label[selected_indices]
    return selected_indices, selected_labels_soft
def compute_ot_loss_cos(last_token_rep, centroids, pseudo_label, batch_size, args):
    """Soft cross-entropy between pseudo-labels and a temperature-scaled
    cosine-similarity softmax over the class prototypes.

    Note: `batch_size` is accepted for interface compatibility but unused;
    the loss is averaged over pseudo_label.shape[0].

    Args:
        last_token_rep: [N, hidden] representations (normalized here).
        centroids: [C, hidden] class prototypes (normalized here).
        pseudo_label: [N, C] soft target distribution.
        args: needs `cos_temp` (softmax temperature).

    Returns:
        (ot_loss, similarities): scalar loss and the raw scaled logits.
    """
    reps = F.normalize(last_token_rep, p=2, dim=-1)
    protos = F.normalize(centroids, p=2, dim=-1)
    # Cosine-similarity logits, sharpened by the temperature.
    similarities = (reps @ protos.T) / args.cos_temp
    pt = F.softmax(similarities, dim=-1)
    # Soft cross-entropy; 1e-8 guards against log(0).
    ot_loss = -(pseudo_label * torch.log(pt + 1e-8)).sum() / pseudo_label.shape[0]
    return ot_loss, similarities
def compute_entropy(last_token_rep, centroids, pseudo_label, k, cls_dist, args):
    """Select the k most confident pseudo-labeled samples, class-balanced.

    Confidence is the soft cross-entropy between the OT pseudo-label and the
    temperature-scaled cosine-softmax over the prototypes: lower CE means the
    model's softmax agrees with the pseudo-label, hence ``largest=False``.

    Args:
        last_token_rep: [N, hidden] representations (normalized below).
        centroids: [C, hidden] class prototypes (normalized below).
        pseudo_label: [N, C] soft labels (e.g. from Sinkhorn OT).
        k: total number of samples to select.
        cls_dist: class prior used to split the quota between the two
            classes; presumably a [2, 1] tensor of P(class0)/P(class1) —
            TODO confirm against the caller.
        args: needs `cos_temp` (softmax temperature).

    Returns:
        1-D LongTensor of selected row indices into `last_token_rep`.
    """
    last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
    centroids = F.normalize(centroids, p=2, dim=-1)
    similarities = torch.matmul(last_token_rep, centroids.T)
    similarities = similarities / args.cos_temp
    pt = F.softmax(similarities, dim=-1)
    # Element-wise CE terms; summed per-sample further down. 1e-8 guards log(0).
    ce = - (pseudo_label * torch.log(pt + 1e-8))
    # Hard class prediction per sample, used only for the class-wise quota split.
    pseudo_label_hard = torch.argmax(pt,dim=1)
    # * Added for preventing severe cases
    # Class-wise data selection: Select pseudo-labeled unlabeled data in proportion to the class distribution of the exemplar set.
    # NOTE(review): the int() truncation below can make the two quotas sum to
    # slightly less than k — verify that is acceptable.
    cls0_num = k*cls_dist[0]
    cls1_num = k*cls_dist[1]
    cls_0_indices = (pseudo_label_hard == 0).nonzero(as_tuple=True)[0]
    cls_1_indices = (pseudo_label_hard == 1).nonzero(as_tuple=True)[0]
    ce = torch.sum(ce, dim=1)
    ce_class_0 = ce[cls_0_indices]
    ce_class_1 = ce[cls_1_indices]
    if len(ce_class_0) < cls0_num or len(ce_class_1) < cls1_num: # Fallback to top-k across all classes
        # Not enough samples predicted in some class: ignore the quota and
        # take the k lowest-CE samples overall.
        _, top_k_indices = torch.topk(ce, k, largest=False, sorted=True)
    else:
        # Lowest-CE (most confident) samples within each predicted class.
        top_0_indices = cls_0_indices[torch.topk(ce_class_0, int(cls0_num), largest=False, sorted=True).indices]
        top_1_indices = cls_1_indices[torch.topk(ce_class_1, int(cls1_num), largest=False, sorted=True).indices]
        top_k_indices = torch.cat((top_0_indices, top_1_indices))
    return top_k_indices
def update_centroids_ema(centroids, last_token_rep, pseudo_label, args):
    """Soft-label EMA update of the class prototypes on the unit sphere.

    Each prototype is blended (by `args.ema_decay`) with the pseudo-label-
    weighted mean of the normalized batch representations, then re-projected
    onto the unit sphere.

    Args:
        centroids: [C, hidden] current prototypes.
        last_token_rep: [N, hidden] batch representations.
        pseudo_label: [N, C] soft assignment weights.
        args: needs `ema_decay` in [0, 1].

    Returns:
        [C, hidden] updated, unit-norm prototypes.
    """
    reps = F.normalize(last_token_rep, p=2, dim=1)
    protos = F.normalize(centroids, p=2, dim=1)
    # Per-class weighted mean of the batch representations;
    # 1e-8 avoids division by zero for a class with no mass in the batch.
    class_mass = pseudo_label.sum(dim=0).unsqueeze(1) + 1e-8
    batch_protos = torch.matmul(pseudo_label.T, reps) / class_mass
    # Exponential moving average, renormalized back to unit length.
    blended = args.ema_decay * protos + (1 - args.ema_decay) * batch_protos
    return F.normalize(blended, p=2, dim=1)
def update_centroids_ema_hard(centroids, last_token_rep, pseudo_label, args):
    """Hard-label variant of the prototype EMA update.

    Soft pseudo-labels are first collapsed to one-hot assignments via argmax,
    then each prototype is blended with the mean of its assigned (normalized)
    representations and re-projected onto the unit sphere.

    Args:
        centroids: [C, hidden] current prototypes.
        last_token_rep: [N, hidden] batch representations.
        pseudo_label: [N, C] soft labels; only the argmax per row is used.
        args: needs `ema_decay` in [0, 1].

    Returns:
        [C, hidden] updated, unit-norm prototypes.
    """
    reps = F.normalize(last_token_rep, p=2, dim=1)
    protos = F.normalize(centroids, p=2, dim=1)
    # Collapse soft labels to one-hot assignments.
    winners = torch.argmax(pseudo_label, dim=1)
    hard = torch.zeros_like(pseudo_label)
    hard[torch.arange(pseudo_label.size(0)), winners] = 1
    # Per-class mean of the assigned representations (1e-8 guards empty classes).
    class_mass = hard.sum(dim=0).unsqueeze(1) + 1e-8
    batch_protos = torch.matmul(hard.T.float(), reps) / class_mass
    # EMA blend, renormalized (dim=-1 matches the original; for a 2-D
    # tensor it is equivalent to dim=1).
    blended = args.ema_decay * protos + (1 - args.ema_decay) * batch_protos
    return F.normalize(blended, p=2, dim=-1)