From 5f390c1bb6e3cba745e2d3641fcb00a87b68bfca Mon Sep 17 00:00:00 2001 From: weixin_43297441 Date: Mon, 1 Dec 2025 21:40:08 +0800 Subject: [PATCH] push to remote --- steer_vector.py | 189 ++++++++---------------------------------------- train_utils.py | 183 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+), 159 deletions(-) create mode 100644 train_utils.py diff --git a/steer_vector.py b/steer_vector.py index 6dca5ba..1396aec 100644 --- a/steer_vector.py +++ b/steer_vector.py @@ -11,7 +11,6 @@ from llm_layers import add_sv_layers from sklearn.metrics import roc_auc_score from torch.cuda.amp import autocast, GradScaler import torch.nn.functional as F -from sinkhorn_knopp import SinkhornKnopp_imb import logging from utils import load_npy_shapes,split_indices @@ -81,13 +80,13 @@ def train_model(model, optimizer, device, prompts, labels, args): # 用 exemplar 的标签估计类分布 w(truthful/hallu 的先验) # ex_hallu = P(hallucinated), ex_true = P(truthful) - ex_hallu = (num_exemplars - exemplar_labels[:num_exemplars].sum()) / num_exemplars - ex_true = (exemplar_labels[:num_exemplars].sum()) / num_exemplars - cls_dist = torch.tensor([ex_hallu, ex_true]).float().cuda() - cls_dist = cls_dist.view(-1, 1) # 形状 [2, 1] + # ex_hallu = (num_exemplars - exemplar_labels[:num_exemplars].sum()) / num_exemplars + # ex_true = (exemplar_labels[:num_exemplars].sum()) / num_exemplars + # cls_dist = torch.tensor([ex_hallu, ex_true]).float().cuda() + # cls_dist = cls_dist.view(-1, 1) # 形状 [2, 1] # 实例化带类别边缘约束的 Sinkhorn(实现论文里的 OT 问题) - sinkhorn = SinkhornKnopp_imb(args, cls_dist) + # sinkhorn = SinkhornKnopp_imb(args, cls_dist) # ========= 初始化两个类别的“原型向量” μ_c ========= # 这里 centroids 就是论文中球面上的 μ_truthful, μ_hallucinated @@ -160,7 +159,7 @@ def train_model(model, optimizer, device, prompts, labels, args): centroids, last_token_rep, batch_labels_oh, args ) - # ====== 只对 TSV 反传,LLM 本体是冻结的 ====== + # ====== 只对 SV 反传,LLM 本体是冻结的 ====== scaler.scale(loss).backward() 
scaler.step(optimizer) scaler.update() @@ -174,12 +173,12 @@ def train_model(model, optimizer, device, prompts, labels, args): # ====== 在 test 上评估,记录 AUROC ====== if (epoch + 1) % 1 == 0: test_labels_ = test_labels - test_predictions, test_labels_combined = test_model( + test_predictions, test_labels= test_model( model, centroids, test_prompts, test_labels_, device, batch_size, layer_number ) test_auroc = roc_auc_score( - test_labels_combined.cpu().numpy(), test_predictions.cpu().numpy() + test_labels.cpu().numpy(), test_predictions.cpu().numpy() ) print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}") @@ -193,6 +192,12 @@ def train_model(model, optimizer, device, prompts, labels, args): logging.info( f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}" ) + # 保存最佳AUROC时的centroids到npy文件 + centroids_np = centroids.cpu().numpy() + centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy") + np.save(centroids_file, centroids_np) + print(f"Saved best centroids to {centroids_file}") + logging.info(f"Saved best centroids to {centroids_file}") logging.info( f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, " @@ -200,143 +205,7 @@ def train_model(model, optimizer, device, prompts, labels, args): logging.info(f"Test AUROC: {test_auroc:.4f}") print(f"Epoch [{epoch+1}/{num_epochs}],Test AUROC: {test_auroc:.4f}") - # ========= 进入 Phase 2:伪标签 + 自训练 ========= - logging.info(f"SS Learning Starts") - - # 1) get_ex_data: 在 wild train 上做 TSV 推断 + Sinkhorn OT,选出高置信伪标签样本 - with torch.no_grad(): - selected_indices, selected_labels_soft = get_ex_data( - model, - train_prompts, # wild 区间的样本 - train_labels, # 这里一般是无标签 or dummy,这里暂时没用 - batch_size, - centroids, # 当前原型 - sinkhorn, # OT + Sinkhorn 对象,用来求 Q - args.num_selected_data, # 选多少个伪标签样本 - cls_dist, - args, - ) - - num_samples = len(selected_indices) + args.num_exemplars - - num_epochs = args.aug_num_epochs # 自训练的 epoch 数 - - exemplar_label = 
torch.tensor(exemplar_labels).cuda() - - # 2) 构造“增强后的训练集”:伪标签样本 + 原来的 exemplar - selected_prompts = [train_prompts[i] for i in selected_indices] - selected_labels = selected_labels_soft # 这里已经是 soft label(OT 输出的 q) - - augmented_prompts = selected_prompts + exemplar_prompts_ - exemplar_labels = torch.nn.functional.one_hot( - exemplar_label.to(torch.int64), num_classes=2 - ) - augmented_labels = torch.concat( - (selected_labels, torch.tensor(exemplar_labels).clone().cuda()) - ) - - augmented_prompts_train = augmented_prompts - augmented_labels_label = augmented_labels - - num_samples = len(augmented_prompts_train) - - # 3) 在“伪标签 + exemplar”上再训练一轮 TSV(自训练) - with autocast(dtype=torch.float16): - for epoch in range(num_epochs): - running_loss = 0.0 - total = 0 - all_labels = [] - - for batch_start in tqdm( - range(0, num_samples, batch_size), - desc=f"Epoch {epoch+1}/{num_epochs} Batches", - leave=False, - ): - batch_prompts = augmented_prompts_train[batch_start: batch_start + batch_size] - batch_labels = augmented_labels_label[batch_start: batch_start + batch_size] - - batch_prompts, batch_labels = collate_fn(batch_prompts, batch_labels) - - attention_mask = (batch_prompts != 0).half() - - batch_prompts = batch_prompts.to(device) - batch_labels = batch_labels.to(device) - attention_mask = attention_mask.to(batch_prompts.device) - - output = model( - batch_prompts.squeeze(), - attention_mask=attention_mask.squeeze(), - output_hidden_states=True, - ) - - hidden_states = output.hidden_states - hidden_states = torch.stack(hidden_states, dim=0).squeeze() - last_layer_hidden_state = hidden_states[layer_number] - - last_token_rep = get_last_non_padded_token_rep( - last_layer_hidden_state, attention_mask.squeeze() - ) - - # 这里 compute_ot_loss_cos 接收的是 soft label (OT 得到的 q) - ot_loss, similarities = compute_ot_loss_cos( - last_token_rep, centroids, batch_labels, batch_size, args - ) - - loss = ot_loss - - with torch.no_grad(): - # 自训练阶段,原型更新用的是 soft label 版的 EMA - centroids = 
update_centroids_ema( - centroids, last_token_rep, batch_labels.half(), args - ) - all_labels.append(batch_labels.cpu()) - total += batch_labels.size(0) - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - running_loss += loss.item() * batch_labels.size(0) - - epoch_loss = running_loss / total - - # ====== 用 test set 评估 AUROC ====== - with torch.no_grad(): - all_labels = torch.cat(all_labels).numpy() - test_labels_ = test_labels - - if epoch % 1 == 0: - test_predictions, test_labels_combined = test_model( - model, - centroids, - test_prompts, - test_labels_, - device, - batch_size, - layer_number, - ) - test_auroc = roc_auc_score(test_labels_combined, test_predictions) - - print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}") - losses.append(epoch_loss) - - if test_auroc > best_test_auroc: - best_test_auroc = test_auroc - best_test_epoch = epoch + args.init_num_epochs - print( - f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}" - ) - logging.info( - f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}" - ) - - logging.info( - f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, " - ) - logging.info( - f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}" - ) + return best_test_auroc @@ -345,14 +214,14 @@ def train_model(model, optimizer, device, prompts, labels, args): def test_model(model, centroids, test_prompts, test_labels, device, batch_size, layer_number): """ 在论文里相当于: - - 对 test 集算当前 TSV steer 后的 r_v + - 对 test 集算当前 TV steer 后的 r_v - 计算与两个原型 μ_c 的余弦相似度 - 用 softmax 得到属于 truthful 的概率 p(c=truthful | r_v) - 用这个概率做 AUROC 评估 """ model.eval() - val_predictions = [] # 保存 p_truthful - val_labels_combined = [] # 对应的 gt(二分类标签) + val_predictions = [] # 保存 p + val_labels = [] # 对应的(二分类标签) num_val_samples = len(test_prompts) @@ -392,11 +261,11 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size, similarity_scores = 
similarity_scores[:, 1] # 假设 index 1 = truthful val_predictions.append(similarity_scores.cpu()) - val_labels_combined.append(batch_labels.cpu()) + val_labels.append(batch_labels.cpu()) val_predictions = torch.cat(val_predictions) - val_labels_combined = torch.cat(val_labels_combined) - return val_predictions, val_labels_combined + val_labels = torch.cat(val_labels) + return val_predictions, val_labels @@ -477,7 +346,7 @@ def main(): ) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token = '') - prompts = [] + prompts_ = [] qa_pairs = [] categories = [] @@ -491,9 +360,9 @@ def main(): return_tensors='pt' ).input_ids.cuda() - prompts.append(prompt) + prompts_.append(prompt) qa_pairs.append({'Question': question, 'Answer': adversary}) - categories.append(1) # 1 = adverse + categories.append(0) # 0 = adverse for i in tqdm(range(length)): question = dataset[i]['query'] @@ -503,12 +372,14 @@ def main(): return_tensors='pt' ).input_ids.cuda() - prompts.append(prompt) + prompts_.append(prompt) qa_pairs.append({'Question': question, 'Answer': clean}) - categories.append(0) # 0 = benign + categories.append(1) # 1 = benign - train_index, val_index, test_index=split_indices(len(prompts), args.train_ratio, args.val_ratio) + train_index, val_index, test_index=split_indices(len(prompts_), args.train_ratio, args.val_ratio) + labels = [categories[test_index], categories[train_index]] + prompts = [prompts_[test_index], prompts_[train_index]] @@ -532,7 +403,7 @@ def main(): optimizer = torch.optim.AdamW(list(sv.parameters()), lr=args.lr) # ====== 6.5 调用 train_model,进入两阶段训练流程 ====== - train_model(model, optimizer, device, train_data, train_labels, args=args) + train_model(model, optimizer, device, prompts, labels, args=args) else: print("Skip training steer vectors.") diff --git a/train_utils.py b/train_utils.py new file mode 100644 index 0000000..210435a --- /dev/null +++ b/train_utils.py @@ -0,0 +1,183 @@ + +import torch +from tqdm import tqdm +from torch.cuda.amp 
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

try:
    from tqdm import tqdm
except ImportError:
    # tqdm is only used for progress bars; fall back to a pass-through
    # wrapper so the module stays importable without it.
    def tqdm(iterable, *args, **kwargs):
        return iterable


def collate_fn(prompts, labels):
    """Pad a batch of tokenized prompts to a common length and batch the labels.

    Args:
        prompts: list of tensors, each of shape (1, seq_len) — tokenized prompts
            (0 is assumed to be the padding id).
        labels: either a Python sequence of class ids, or a tensor of soft
            pseudo-labels (phase-2 self-training passes OT-derived soft labels).

    Returns:
        (prompts_padded, labels): prompts_padded has shape
        (batch, 1, max_seq_len), zero-padded on the right; labels is a tensor
        on the prompts' device.
    """
    max_seq_len = max(prompt.size(1) for prompt in prompts)

    batch_size = len(prompts)
    dtype = prompts[0].dtype
    device = prompts[0].device  # assumes all prompts share one device

    # Allocate the padded batch on the same device as the prompts, so the
    # padding mask later derived via (batch != 0) lives on that device too.
    prompts_padded = torch.zeros(
        batch_size, 1, max_seq_len, dtype=dtype, device=device
    )
    for i, prompt in enumerate(prompts):
        seq_len = prompt.size(1)
        prompts_padded[i, :, :seq_len] = prompt

    if isinstance(labels, torch.Tensor):
        # Soft pseudo-labels (float tensors) must be kept as-is: re-wrapping
        # them with torch.tensor(..., dtype=torch.long) would truncate the
        # probabilities to 0/1 and destroy the soft targets.
        labels = labels.to(device)
    else:
        labels = torch.tensor(labels, dtype=torch.long, device=device)

    return prompts_padded, labels


def get_last_non_padded_token_rep(hidden_states, attention_mask):
    """Return the hidden state of the last real (non-padded) token per sequence.

    Args:
        hidden_states: (batch, max_seq_len, hidden_size) layer activations.
        attention_mask: mask with 1 for real tokens and 0 for padding;
            squeezed to (batch, max_seq_len) before use.

    Returns:
        (batch, hidden_size) tensor of last-token representations.
    """
    # Sequence lengths = number of real tokens per row.
    lengths = attention_mask.squeeze().sum(dim=1).long()

    # Vectorized gather of hidden_states[i, lengths[i]-1, :] for every i
    # (replaces the original per-row Python loop + torch.stack).
    batch_size = hidden_states.size(0)
    batch_idx = torch.arange(batch_size, device=hidden_states.device)
    return hidden_states[batch_idx, lengths - 1, :]


def get_ex_data(model, prompts, labels, batch_size, centroids, sinkhorn,
                num_selected_data, cls_dist, args):
    """Pseudo-label unlabeled prompts and select the most confident ones.

    Runs the (frozen) model over `prompts`, pools the last-token representation
    of the final layer, obtains soft pseudo-labels from the class-marginal
    constrained Sinkhorn solver, and selects `num_selected_data` low-entropy
    samples class-proportionally via `compute_entropy`.

    Returns:
        (selected_indices, selected_labels_soft): indices into `prompts` and
        the corresponding soft pseudo-labels.
    """
    all_embeddings = []
    all_labels = []
    num_samples = len(prompts)

    with torch.no_grad():
        with autocast(dtype=torch.float16):
            for batch_start in tqdm(range(0, num_samples, batch_size)):
                batch_prompts = prompts[batch_start: batch_start + batch_size]
                batch_labels = labels[batch_start: batch_start + batch_size]
                batch_prompts, batch_labels = collate_fn(batch_prompts, batch_labels)

                # 0 is the padding id, so a non-zero test recovers the mask.
                attention_mask = (batch_prompts != 0).half()
                batch_prompts = batch_prompts.cuda()
                batch_labels = batch_labels.cuda()
                attention_mask = attention_mask.to(batch_prompts.device)
                all_labels.append(batch_labels.cpu().numpy())

                output = model(
                    batch_prompts.squeeze(),
                    attention_mask=attention_mask.squeeze(),
                    output_hidden_states=True,
                )
                hidden_states = torch.stack(output.hidden_states, dim=0).squeeze()
                last_layer_hidden_state = hidden_states[-1]

                last_token_rep = get_last_non_padded_token_rep(
                    last_layer_hidden_state, attention_mask.squeeze()
                )
                all_embeddings.append(last_token_rep)

    # Project embeddings onto the unit sphere before OT / prototype matching.
    all_embeddings = F.normalize(torch.concat(all_embeddings), p=2, dim=-1)

    # Soft pseudo-label matrix Q from the Sinkhorn OT solver.
    pseudo_label = sinkhorn(all_embeddings, centroids)

    selected_indices = compute_entropy(
        all_embeddings, centroids, pseudo_label, num_selected_data, cls_dist, args
    )
    selected_labels_soft = pseudo_label[selected_indices]

    return selected_indices, selected_labels_soft


def compute_ot_loss_cos(last_token_rep, centroids, pseudo_label, batch_size, args):
    """Soft cross-entropy between pseudo-labels and prototype similarities.

    Both representations and prototypes are L2-normalized, so the matmul
    yields cosine similarities; these are temperature-scaled by
    `args.cos_temp` and softmaxed into class probabilities.

    Args:
        batch_size: unused; kept for signature compatibility with callers.

    Returns:
        (ot_loss, similarities): mean soft cross-entropy and the
        temperature-scaled similarity logits.
    """
    last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
    centroids = F.normalize(centroids, p=2, dim=-1)

    similarities = torch.matmul(last_token_rep, centroids.T) / args.cos_temp
    pt = F.softmax(similarities, dim=-1)

    # Mean soft cross-entropy; 1e-8 guards log(0).
    ot_loss = -torch.sum(pseudo_label * torch.log(pt + 1e-8)) / pseudo_label.shape[0]

    return ot_loss, similarities


def compute_entropy(last_token_rep, centroids, pseudo_label, k, cls_dist, args):
    """Select k low-entropy (high-confidence) samples, class-proportionally.

    Cross-entropy between the pseudo-label and the prototype-softmax is used
    as the confidence score (lower = more confident). Samples are selected
    per predicted class in proportion to `cls_dist` (the exemplar-set class
    distribution) to prevent severe class imbalance; if either class has too
    few candidates, falls back to a global top-k.

    Returns:
        1-D tensor of selected sample indices.
    """
    last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
    centroids = F.normalize(centroids, p=2, dim=-1)

    similarities = torch.matmul(last_token_rep, centroids.T) / args.cos_temp
    pt = F.softmax(similarities, dim=-1)

    ce = -(pseudo_label * torch.log(pt + 1e-8))
    pseudo_label_hard = torch.argmax(pt, dim=1)

    # Per-class quotas from the exemplar class distribution.
    # NOTE(review): int() truncates, so the class-wise branch may pick
    # slightly fewer than k samples when k*cls_dist is fractional.
    cls0_num = k * cls_dist[0]
    cls1_num = k * cls_dist[1]

    cls_0_indices = (pseudo_label_hard == 0).nonzero(as_tuple=True)[0]
    cls_1_indices = (pseudo_label_hard == 1).nonzero(as_tuple=True)[0]

    ce = torch.sum(ce, dim=1)
    ce_class_0 = ce[cls_0_indices]
    ce_class_1 = ce[cls_1_indices]

    if len(ce_class_0) < cls0_num or len(ce_class_1) < cls1_num:
        # Not enough candidates in some class: fall back to global top-k.
        _, top_k_indices = torch.topk(ce, k, largest=False, sorted=True)
    else:
        top_0_indices = cls_0_indices[
            torch.topk(ce_class_0, int(cls0_num), largest=False, sorted=True).indices
        ]
        top_1_indices = cls_1_indices[
            torch.topk(ce_class_1, int(cls1_num), largest=False, sorted=True).indices
        ]
        top_k_indices = torch.cat((top_0_indices, top_1_indices))

    return top_k_indices


def update_centroids_ema(centroids, last_token_rep, pseudo_label, args):
    """EMA-update the class prototypes with soft pseudo-label weights.

    The batch contribution is the pseudo-label-weighted mean of the
    normalized representations; the result is blended with the old
    prototypes via `args.ema_decay` and re-projected onto the unit sphere.
    """
    last_token_rep_norm = F.normalize(last_token_rep, p=2, dim=1)
    centroids = F.normalize(centroids, p=2, dim=1)

    # (num_classes, hidden) weighted sum of representations per class.
    weighted_sum = torch.matmul(pseudo_label.T, last_token_rep_norm)

    # Normalize by total per-class weight; 1e-8 guards empty classes.
    pseudo_label_sum = pseudo_label.sum(dim=0).unsqueeze(1) + 1e-8
    new_centroids_batch = weighted_sum / pseudo_label_sum

    updated_centroids = F.normalize(
        args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch,
        p=2, dim=1,
    )
    return updated_centroids


def update_centroids_ema_hard(centroids, last_token_rep, pseudo_label, args):
    """EMA-update the class prototypes using hardened (one-hot) pseudo-labels.

    Same as `update_centroids_ema`, but each sample contributes only to its
    argmax class instead of being soft-weighted.
    """
    last_token_rep_norm = F.normalize(last_token_rep, p=2, dim=1)
    centroids = F.normalize(centroids, p=2, dim=1)

    # Harden the soft labels into one-hot rows at the argmax class.
    max_indices = torch.argmax(pseudo_label, dim=1)
    discrete_labels = torch.zeros_like(pseudo_label)
    discrete_labels[torch.arange(pseudo_label.size(0)), max_indices] = 1

    weighted_sum = torch.matmul(discrete_labels.T.float(), last_token_rep_norm)

    # Per-class counts; 1e-8 guards empty classes.
    pseudo_label_sum = discrete_labels.sum(dim=0).unsqueeze(1) + 1e-8
    new_centroids_batch = weighted_sum / pseudo_label_sum

    updated_centroids = F.normalize(
        args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch,
        p=2, dim=1,
    )
    return updated_centroids