This commit is contained in:
weixin_43297441 2025-12-02 11:27:03 +08:00
parent f3dc5d0d85
commit 67139e4279
6 changed files with 147 additions and 8 deletions

View File

@ -0,0 +1,16 @@
2025-12-02 11:13:28,208 - INFO - Starting training
2025-12-02 11:13:28,208 - INFO - component=res, str_layer=9
2025-12-02 11:13:48,802 - INFO - Epoch [1/20], Loss: 0.3692
2025-12-02 11:13:48,802 - INFO - Best test AUROC: 1.0000, at epoch: 0
2025-12-02 11:13:48,803 - INFO - Saved best centroids to SV_alpaca_7B_AdvBench/res/9/5/best_centroids_epoch_0.npy
2025-12-02 11:13:48,803 - INFO - Epoch [1/20], Train Loss: 0.3692,
2025-12-02 11:13:48,803 - INFO - Test AUROC: 1.0000
2025-12-02 11:14:08,725 - INFO - Epoch [2/20], Loss: 0.0742
2025-12-02 11:14:08,726 - INFO - Epoch [2/20], Train Loss: 0.0742,
2025-12-02 11:14:08,726 - INFO - Test AUROC: 1.0000
2025-12-02 11:14:28,751 - INFO - Epoch [3/20], Loss: 0.0150
2025-12-02 11:14:28,751 - INFO - Epoch [3/20], Train Loss: 0.0150,
2025-12-02 11:14:28,751 - INFO - Test AUROC: 1.0000
2025-12-02 11:14:48,784 - INFO - Epoch [4/20], Loss: 0.0036
2025-12-02 11:14:48,784 - INFO - Epoch [4/20], Train Loss: 0.0036,
2025-12-02 11:14:48,784 - INFO - Test AUROC: 1.0000

2
a.sh
View File

@ -1,2 +1,2 @@
export HF_ENDPOINT=https://hf-mirror.com
CUDA_VISIBLE_DEVICES=9 python hal_generate.py
CUDA_VISIBLE_DEVICES=9 python steer_vector.py

112
cache_utils.py Normal file
View File

@ -0,0 +1,112 @@
import copy
import importlib.metadata
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from packaging import version
# from .configuration_utils import PretrainedConfig
# from .utils import (
# is_hqq_available,
# is_optimum_quanto_available,
# is_torchdynamo_compiling,
# logging,
# )
# from .utils.deprecation import deprecate_kwarg
# if is_hqq_available():
# from hqq.core.quantize import Quantizer as HQQQuantizer
# logger = logging.get_logger(__name__)
class Cache(torch.nn.Module):
    """
    Abstract base class for all cache implementations.

    Subclasses decide the concrete data layout; this base only fixes the
    interface (`update`, `get_seq_length`, `get_max_cache_shape`) and builds a
    few shared convenience helpers on top of it.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store the new `key_states`/`value_states` for layer `layer_idx` and
        return the resulting (cached) key/value tensors.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Subclass-specific extra arguments, allowing new kinds of cache
                to define their own knobs.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the sequence length of the cached states (optionally for one layer)."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    # Deprecated alias for `get_max_cache_shape`. "max_length" was ambiguous:
    # some cache objects (e.g. sliding-window or sink caches) technically accept
    # an unbounded token stream, and what callers really need is the capacity of
    # the cache instance — hence the more explicit name.
    def get_max_length(self) -> Optional[int]:
        # Upstream emits a deprecation warning here; logging is stripped out in
        # this standalone copy (see commented-out logger at the top of the file).
        return self.get_max_cache_shape()

    def get_max_cache_shape(self) -> Optional[int]:
        """Return the maximum sequence length (i.e. max capacity) of the cache object."""
        raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, return the usable length of the cache."""
        # An unbounded cache is fully usable. A bounded cache must evict part of
        # its content when cached + new tokens would exceed its capacity, so
        # only `capacity - new_seq_length` of it remains usable.
        capacity = self.get_max_cache_shape()
        cached_len = self.get_seq_length(layer_idx)
        if capacity is None:
            return cached_len
        if cached_len + new_seq_length <= capacity:
            return cached_len
        return capacity - new_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the cached states for beam search, given the selected beam indices."""
        for idx in range(len(self.key_cache)):
            for storage in (self.key_cache, self.value_cache):
                layer_states = storage[idx]
                if layer_states != []:  # skip layers that have not been filled yet
                    storage[idx] = layer_states.index_select(0, beam_idx.to(layer_states.device))

    @property
    def seen_tokens(self):
        # Deprecated upstream in favor of the `cache_position` model input;
        # kept for backward compatibility with callers that still read it.
        return getattr(self, "_seen_tokens", None)

View File

@ -91,7 +91,7 @@ class SVLayer(nn.Module):
self.lam = lam
def forward(self, x):
if self.tv is not None:
if self.sv is not None:
x = x.half()
y = self.lam[0] * self.sv.repeat(1,x.shape[1],1)

View File

@ -48,8 +48,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
# ========= 日志 & 结果保存目录 =========
dir_name = f"SV_{args.model_name}_{args.dataset_name}/{args.component}/{args.str_layer}/{args.lam}"
log_dir = f"/{dir_name}/"
log_file = os.path.join(log_dir, f"log.txt")
log_file = os.path.join(dir_name, f"log.txt")
os.makedirs(dir_name, exist_ok=True)
logging.basicConfig(
@ -74,7 +73,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
scaler = GradScaler('cuda') # 混合精度的梯度缩放器
num_trains = args.num_train
num_trains = len(train_prompts)
@ -207,7 +206,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
return best_test_auroc
return best_test_auroc
@ -265,6 +264,12 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size,
val_predictions = torch.cat(val_predictions)
val_labels = torch.cat(val_labels)
# Debug: print predictions and labels distribution
print(f"[DEBUG] test_model: {len(val_predictions)} samples")
print(f"[DEBUG] Predictions min/max/mean: {val_predictions.min():.4f}/{val_predictions.max():.4f}/{val_predictions.mean():.4f}")
print(f"[DEBUG] Labels distribution: {torch.sum(val_labels == 0)} zeros, {torch.sum(val_labels == 1)} ones")
return val_predictions, val_labels
@ -287,6 +292,7 @@ def main():
# 1. 解析命令行参数
# ======================
parser = argparse.ArgumentParser()
parser.add_argument('--model_prefix', type=str, default='')
parser.add_argument('--model_name', type=str, default='alpaca_7B')
parser.add_argument('--num_gene', type=int, default=1) # 每个问题生成多少个答案
parser.add_argument('--train_sv', type=bool, default=True) # 是否执行“生成答案”阶段1=生成+保存答案0=不生成)
@ -378,8 +384,12 @@ def main():
train_index, val_index, test_index=split_indices(len(prompts_), args.train_ratio, args.val_ratio)
labels = [categories[test_index], categories[train_index]]
prompts = [prompts_[test_index], prompts_[train_index]]
# Convert numpy arrays to lists for Python list indexing
test_index_list = test_index.tolist()
train_index_list = train_index.tolist()
labels = [[categories[i] for i in test_index_list], [categories[i] for i in train_index_list]]
prompts = [[prompts_[i] for i in test_index_list], [prompts_[i] for i in train_index_list]]

View File

@ -3,6 +3,7 @@ import torch
from tqdm import tqdm
from torch.amp import autocast
import torch.nn.functional as F
import numpy as np