This commit is contained in:
weixin_43297441 2025-12-01 21:45:59 +08:00
parent 5f390c1bb6
commit f3dc5d0d85
3 changed files with 28 additions and 26 deletions

View File

@@ -5,14 +5,14 @@ from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import argparse
from train_utils import get_last_non_padded_token_rep, compute_ot_loss_cos, update_centroids_ema, update_centroids_ema_hard, get_ex_data, collate_fn
from train_utils import get_last_non_padded_token_rep, compute_ot_loss_cos, update_centroids_ema, update_centroids_ema_hard, get_ex_data, collate_fn ,split_indices
from transformers import AutoTokenizer, AutoModelForCausalLM
from llm_layers import add_sv_layers
from sklearn.metrics import roc_auc_score
from torch.cuda.amp import autocast, GradScaler
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import logging
from utils import load_npy_shapes,split_indices
def seed_everything(seed: int):
@@ -72,7 +72,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
losses = []
best_test_auroc = -1
scaler = GradScaler() # 混合精度的梯度缩放器
scaler = GradScaler('cuda') # 混合精度的梯度缩放器
num_trains = args.num_train
@@ -122,7 +122,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
attention_mask = attention_mask.to(batch_prompts.device)
# ======= 前向传播:取最后一层最后非 padding token 的表示 r_v =======
with autocast(dtype=torch.float16):
with autocast('cuda', dtype=torch.float16):
output = model(
batch_prompts.squeeze(),
attention_mask=attention_mask.squeeze(),
@@ -226,7 +226,7 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size,
num_val_samples = len(test_prompts)
with torch.no_grad():
with autocast(dtype=torch.float16):
with autocast('cuda', dtype=torch.float16):
for batch_start in range(0, num_val_samples, batch_size):
batch_prompts = test_prompts[batch_start:batch_start + batch_size]
batch_labels = test_labels[batch_start:batch_start + batch_size]
@@ -253,7 +253,7 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size,
last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
centroids = F.normalize(centroids, p=2, dim=-1)
with autocast(dtype=torch.float16):
with autocast('cuda', dtype=torch.float16):
similarities = torch.matmul(last_token_rep, centroids.T) # [B, 2]
# 相似度 / 温度 -> softmax -> 取“真”的那一维作为概率

View File

@@ -1,7 +1,7 @@
import torch
from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.amp import autocast
import torch.nn.functional as F
@@ -49,7 +49,7 @@ def get_ex_data(model, prompts, labels, batch_size, centroids, sinkhorn, num_sel
num_samples = len(prompts)
with torch.no_grad():
with autocast(dtype=torch.float16):
with autocast('cuda', dtype=torch.float16):
for batch_start in tqdm(range(0, num_samples, batch_size)):
batch_prompts = prompts[batch_start: batch_start + batch_size]
batch_labels = labels[batch_start: batch_start + batch_size]
@@ -180,4 +180,22 @@ def update_centroids_ema_hard(centroids, last_token_rep, pseudo_label, args):
# EMA update for centroids
updated_centroids = F.normalize(args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch, p=2, dim=-1)
return updated_centroids
return updated_centroids
def split_indices(n, train_ratio=0.8, val_ratio=0.1, seed=42):
    """Randomly partition ``range(n)`` into train/val/test index arrays.

    The split is deterministic for a given ``seed``. Split sizes are
    ``int(n * train_ratio)`` and ``int(n * val_ratio)``; the test split
    receives whatever remains.

    Args:
        n: Total number of samples to index.
        train_ratio: Fraction of samples assigned to the training split.
        val_ratio: Fraction of samples assigned to the validation split.
        seed: Seed for NumPy's default RNG, fixing the shuffle order.

    Returns:
        Tuple ``(train_index, val_index, test_index)`` of integer ndarrays.
    """
    # NOTE(review): relies on `np` being in scope — confirm this module
    # imports numpy at the top of the file.
    generator = np.random.default_rng(seed)
    order = np.arange(n)
    generator.shuffle(order)
    train_size = int(n * train_ratio)
    val_size = int(n * val_ratio)
    # One split call yields the three contiguous segments of the shuffle.
    train_index, val_index, test_index = np.split(
        order, [train_size, train_size + val_size]
    )
    return train_index, val_index, test_index

View File

@@ -106,19 +106,3 @@ def load_npy_shapes(directory_path, steer_place="layer"):
def split_indices(n, train_ratio=0.8, val_ratio=0.1, seed=42):
    """Shuffle ``range(n)`` with a seeded RNG and cut it into three splits.

    Returns ``(train_index, val_index, test_index)`` ndarrays whose sizes
    are ``int(n * train_ratio)``, ``int(n * val_ratio)``, and the
    remainder, respectively. Output is identical for identical ``seed``.
    """
    # Generator.permutation(n) is documented as equivalent to shuffling
    # np.arange(n) in place with the same bit stream.
    shuffled = np.random.default_rng(seed).permutation(n)
    cut_train = int(n * train_ratio)
    cut_val = cut_train + int(n * val_ratio)
    return shuffled[:cut_train], shuffled[cut_train:cut_val], shuffled[cut_val:]