From f3dc5d0d85119b44cd4acffa7c57529cac4fd875 Mon Sep 17 00:00:00 2001
From: weixin_43297441
Date: Mon, 1 Dec 2025 21:45:59 +0800
Subject: [PATCH] up

---
 steer_vector.py | 14 +++++++-------
 train_utils.py  | 25 ++++++++++++++++++++++---
 utils.py        | 16 ----------------
 3 files changed, 29 insertions(+), 26 deletions(-)

diff --git a/steer_vector.py b/steer_vector.py
index 1396aec..cb0f5bb 100644
--- a/steer_vector.py
+++ b/steer_vector.py
@@ -5,14 +5,14 @@ from datasets import load_dataset
 from tqdm import tqdm
 import numpy as np
 import argparse
-from train_utils import get_last_non_padded_token_rep, compute_ot_loss_cos, update_centroids_ema, update_centroids_ema_hard, get_ex_data, collate_fn
+from train_utils import get_last_non_padded_token_rep, compute_ot_loss_cos, update_centroids_ema, update_centroids_ema_hard, get_ex_data, collate_fn ,split_indices
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from llm_layers import add_sv_layers
 from sklearn.metrics import roc_auc_score
-from torch.cuda.amp import autocast, GradScaler
+from torch.amp import autocast, GradScaler
 import torch.nn.functional as F
 import logging
-from utils import load_npy_shapes,split_indices
+
 
 
 def seed_everything(seed: int):
@@ -72,7 +72,7 @@
 
     losses = []
     best_test_auroc = -1
-    scaler = GradScaler()  # 混合精度的梯度缩放器
+    scaler = GradScaler('cuda')  # 混合精度的梯度缩放器
 
     num_trains = args.num_train
 
@@ -122,7 +122,7 @@
         attention_mask = attention_mask.to(batch_prompts.device)
 
         # ======= 前向传播:取最后一层最后非 padding token 的表示 r_v =======
-        with autocast(dtype=torch.float16):
+        with autocast('cuda', dtype=torch.float16):
             output = model(
                 batch_prompts.squeeze(),
                 attention_mask=attention_mask.squeeze(),
@@ -226,7 +226,7 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size,
     num_val_samples = len(test_prompts)
 
     with torch.no_grad():
-        with autocast(dtype=torch.float16):
+        with autocast('cuda', dtype=torch.float16):
             for batch_start in range(0, num_val_samples, batch_size):
                 batch_prompts = test_prompts[batch_start:batch_start + batch_size]
                 batch_labels = test_labels[batch_start:batch_start + batch_size]
@@ -253,7 +253,7 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size,
     last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
     centroids = F.normalize(centroids, p=2, dim=-1)
 
-    with autocast(dtype=torch.float16):
+    with autocast('cuda', dtype=torch.float16):
         similarities = torch.matmul(last_token_rep, centroids.T)  # [B, 2]
 
         # 相似度 / 温度 -> softmax -> 取“真”的那一维作为概率
diff --git a/train_utils.py b/train_utils.py
index 210435a..fbcdc68 100644
--- a/train_utils.py
+++ b/train_utils.py
@@ -1,7 +1,8 @@
 import torch
+import numpy as np
 from tqdm import tqdm
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 import torch.nn.functional as F
 
 
 
@@ -49,7 +50,7 @@ def get_ex_data(model, prompts, labels, batch_size, centroids, sinkhorn, num_sel
     num_samples = len(prompts)
 
     with torch.no_grad():
-        with autocast(dtype=torch.float16):
+        with autocast('cuda', dtype=torch.float16):
             for batch_start in tqdm(range(0, num_samples, batch_size)):
                 batch_prompts = prompts[batch_start: batch_start + batch_size]
                 batch_labels = labels[batch_start: batch_start + batch_size]
@@ -180,4 +181,22 @@ def update_centroids_ema_hard(centroids, last_token_rep, pseudo_label, args):
 
     # EMA update for centroids
     updated_centroids = F.normalize(args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch, p=2, dim=-1)
-    return updated_centroids
\ No newline at end of file
+    return updated_centroids
+
+
+def split_indices(n, train_ratio=0.8, val_ratio=0.1, seed=42):
+
+
+
+    rng = np.random.default_rng(seed)
+    indices = np.arange(n)
+    rng.shuffle(indices)
+
+    n_train = int(n * train_ratio)
+    n_val = int(n * val_ratio)
+
+    train_index = indices[:n_train]
+    val_index = indices[n_train:n_train + n_val]
+    test_index = indices[n_train + n_val:]
+
+    return train_index, val_index, test_index
\ No newline at end of file
diff --git a/utils.py b/utils.py
index bf9e580..dc73335 100644
--- a/utils.py
+++ b/utils.py
@@ -106,19 +106,3 @@ def load_npy_shapes(directory_path, steer_place="layer"):
 
 
 
-def split_indices(n, train_ratio=0.8, val_ratio=0.1, seed=42):
-
-
-
-    rng = np.random.default_rng(seed)
-    indices = np.arange(n)
-    rng.shuffle(indices)
-
-    n_train = int(n * train_ratio)
-    n_val = int(n * val_ratio)
-
-    train_index = indices[:n_train]
-    val_index = indices[n_train:n_train + n_val]
-    test_index = indices[n_train + n_val:]
-
-    return train_index, val_index, test_index
\ No newline at end of file