This commit is contained in:
parent
5f390c1bb6
commit
f3dc5d0d85
|
|
@ -5,14 +5,14 @@ from datasets import load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import argparse
|
import argparse
|
||||||
from train_utils import get_last_non_padded_token_rep, compute_ot_loss_cos, update_centroids_ema, update_centroids_ema_hard, get_ex_data, collate_fn
|
from train_utils import get_last_non_padded_token_rep, compute_ot_loss_cos, update_centroids_ema, update_centroids_ema_hard, get_ex_data, collate_fn ,split_indices
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
from llm_layers import add_sv_layers
|
from llm_layers import add_sv_layers
|
||||||
from sklearn.metrics import roc_auc_score
|
from sklearn.metrics import roc_auc_score
|
||||||
from torch.cuda.amp import autocast, GradScaler
|
from torch.amp import autocast, GradScaler
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import logging
|
import logging
|
||||||
from utils import load_npy_shapes,split_indices
|
|
||||||
|
|
||||||
|
|
||||||
def seed_everything(seed: int):
|
def seed_everything(seed: int):
|
||||||
|
|
@ -72,7 +72,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
||||||
losses = []
|
losses = []
|
||||||
best_test_auroc = -1
|
best_test_auroc = -1
|
||||||
|
|
||||||
scaler = GradScaler() # 混合精度的梯度缩放器
|
scaler = GradScaler('cuda') # 混合精度的梯度缩放器
|
||||||
|
|
||||||
num_trains = args.num_train
|
num_trains = args.num_train
|
||||||
|
|
||||||
|
|
@ -122,7 +122,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
||||||
attention_mask = attention_mask.to(batch_prompts.device)
|
attention_mask = attention_mask.to(batch_prompts.device)
|
||||||
|
|
||||||
# ======= 前向传播:取最后一层最后非 padding token 的表示 r_v =======
|
# ======= 前向传播:取最后一层最后非 padding token 的表示 r_v =======
|
||||||
with autocast(dtype=torch.float16):
|
with autocast('cuda', dtype=torch.float16):
|
||||||
output = model(
|
output = model(
|
||||||
batch_prompts.squeeze(),
|
batch_prompts.squeeze(),
|
||||||
attention_mask=attention_mask.squeeze(),
|
attention_mask=attention_mask.squeeze(),
|
||||||
|
|
@ -226,7 +226,7 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size,
|
||||||
num_val_samples = len(test_prompts)
|
num_val_samples = len(test_prompts)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with autocast(dtype=torch.float16):
|
with autocast('cuda', dtype=torch.float16):
|
||||||
for batch_start in range(0, num_val_samples, batch_size):
|
for batch_start in range(0, num_val_samples, batch_size):
|
||||||
batch_prompts = test_prompts[batch_start:batch_start + batch_size]
|
batch_prompts = test_prompts[batch_start:batch_start + batch_size]
|
||||||
batch_labels = test_labels[batch_start:batch_start + batch_size]
|
batch_labels = test_labels[batch_start:batch_start + batch_size]
|
||||||
|
|
@ -253,7 +253,7 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size,
|
||||||
last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
|
last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
|
||||||
centroids = F.normalize(centroids, p=2, dim=-1)
|
centroids = F.normalize(centroids, p=2, dim=-1)
|
||||||
|
|
||||||
with autocast(dtype=torch.float16):
|
with autocast('cuda', dtype=torch.float16):
|
||||||
similarities = torch.matmul(last_token_rep, centroids.T) # [B, 2]
|
similarities = torch.matmul(last_token_rep, centroids.T) # [B, 2]
|
||||||
|
|
||||||
# 相似度 / 温度 -> softmax -> 取“真”的那一维作为概率
|
# 相似度 / 温度 -> softmax -> 取“真”的那一维作为概率
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torch.cuda.amp import autocast
|
from torch.amp import autocast
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -49,7 +49,7 @@ def get_ex_data(model, prompts, labels, batch_size, centroids, sinkhorn, num_sel
|
||||||
num_samples = len(prompts)
|
num_samples = len(prompts)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with autocast(dtype=torch.float16):
|
with autocast('cuda', dtype=torch.float16):
|
||||||
for batch_start in tqdm(range(0, num_samples, batch_size)):
|
for batch_start in tqdm(range(0, num_samples, batch_size)):
|
||||||
batch_prompts = prompts[batch_start: batch_start + batch_size]
|
batch_prompts = prompts[batch_start: batch_start + batch_size]
|
||||||
batch_labels = labels[batch_start: batch_start + batch_size]
|
batch_labels = labels[batch_start: batch_start + batch_size]
|
||||||
|
|
@ -180,4 +180,22 @@ def update_centroids_ema_hard(centroids, last_token_rep, pseudo_label, args):
|
||||||
# EMA update for centroids
|
# EMA update for centroids
|
||||||
updated_centroids = F.normalize(args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch, p=2, dim=-1)
|
updated_centroids = F.normalize(args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch, p=2, dim=-1)
|
||||||
|
|
||||||
return updated_centroids
|
return updated_centroids
|
||||||
|
|
||||||
|
|
||||||
|
def split_indices(n, train_ratio=0.8, val_ratio=0.1, seed=42):
    """Randomly partition the indices 0..n-1 into train/val/test subsets.

    A seeded NumPy generator shuffles the full index range once, then the
    permutation is sliced by the requested ratios.  The test split receives
    whatever remains after the train and val cuts, so the three parts always
    cover all ``n`` indices exactly once.

    Args:
        n: Total number of samples to split.
        train_ratio: Fraction of samples assigned to the training split.
        val_ratio: Fraction of samples assigned to the validation split.
        seed: Seed for ``np.random.default_rng`` — fixes the shuffle so the
            split is reproducible across runs.

    Returns:
        Tuple ``(train_index, val_index, test_index)`` of 1-D ``np.ndarray``
        objects holding disjoint index subsets.
    """
    rng = np.random.default_rng(seed)
    perm = np.arange(n)
    rng.shuffle(perm)

    # Cut points: ratios are truncated to ints, remainder goes to test.
    cut_train = int(n * train_ratio)
    cut_val = cut_train + int(n * val_ratio)

    return perm[:cut_train], perm[cut_train:cut_val], perm[cut_val:]
16
utils.py
16
utils.py
|
|
@ -106,19 +106,3 @@ def load_npy_shapes(directory_path, steer_place="layer"):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def split_indices(n, train_ratio=0.8, val_ratio=0.1, seed=42):
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
rng = np.random.default_rng(seed)
|
|
||||||
indices = np.arange(n)
|
|
||||||
rng.shuffle(indices)
|
|
||||||
|
|
||||||
n_train = int(n * train_ratio)
|
|
||||||
n_val = int(n * val_ratio)
|
|
||||||
|
|
||||||
train_index = indices[:n_train]
|
|
||||||
val_index = indices[n_train:n_train + n_val]
|
|
||||||
test_index = indices[n_train + n_val:]
|
|
||||||
|
|
||||||
return train_index, val_index, test_index
|
|
||||||
Loading…
Reference in New Issue