push to remote
This commit is contained in:
parent
9b8eca3c38
commit
5f390c1bb6
189
steer_vector.py
189
steer_vector.py
|
|
@ -11,7 +11,6 @@ from llm_layers import add_sv_layers
|
||||||
from sklearn.metrics import roc_auc_score
|
from sklearn.metrics import roc_auc_score
|
||||||
from torch.cuda.amp import autocast, GradScaler
|
from torch.cuda.amp import autocast, GradScaler
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from sinkhorn_knopp import SinkhornKnopp_imb
|
|
||||||
import logging
|
import logging
|
||||||
from utils import load_npy_shapes,split_indices
|
from utils import load_npy_shapes,split_indices
|
||||||
|
|
||||||
|
|
@ -81,13 +80,13 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
||||||
|
|
||||||
# 用 exemplar 的标签估计类分布 w(truthful/hallu 的先验)
|
# 用 exemplar 的标签估计类分布 w(truthful/hallu 的先验)
|
||||||
# ex_hallu = P(hallucinated), ex_true = P(truthful)
|
# ex_hallu = P(hallucinated), ex_true = P(truthful)
|
||||||
ex_hallu = (num_exemplars - exemplar_labels[:num_exemplars].sum()) / num_exemplars
|
# ex_hallu = (num_exemplars - exemplar_labels[:num_exemplars].sum()) / num_exemplars
|
||||||
ex_true = (exemplar_labels[:num_exemplars].sum()) / num_exemplars
|
# ex_true = (exemplar_labels[:num_exemplars].sum()) / num_exemplars
|
||||||
cls_dist = torch.tensor([ex_hallu, ex_true]).float().cuda()
|
# cls_dist = torch.tensor([ex_hallu, ex_true]).float().cuda()
|
||||||
cls_dist = cls_dist.view(-1, 1) # 形状 [2, 1]
|
# cls_dist = cls_dist.view(-1, 1) # 形状 [2, 1]
|
||||||
|
|
||||||
# 实例化带类别边缘约束的 Sinkhorn(实现论文里的 OT 问题)
|
# 实例化带类别边缘约束的 Sinkhorn(实现论文里的 OT 问题)
|
||||||
sinkhorn = SinkhornKnopp_imb(args, cls_dist)
|
# sinkhorn = SinkhornKnopp_imb(args, cls_dist)
|
||||||
|
|
||||||
# ========= 初始化两个类别的“原型向量” μ_c =========
|
# ========= 初始化两个类别的“原型向量” μ_c =========
|
||||||
# 这里 centroids 就是论文中球面上的 μ_truthful, μ_hallucinated
|
# 这里 centroids 就是论文中球面上的 μ_truthful, μ_hallucinated
|
||||||
|
|
@ -160,7 +159,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
||||||
centroids, last_token_rep, batch_labels_oh, args
|
centroids, last_token_rep, batch_labels_oh, args
|
||||||
)
|
)
|
||||||
|
|
||||||
# ====== 只对 TSV 反传,LLM 本体是冻结的 ======
|
# ====== 只对 SV 反传,LLM 本体是冻结的 ======
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
|
@ -174,12 +173,12 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
||||||
# ====== 在 test 上评估,记录 AUROC ======
|
# ====== 在 test 上评估,记录 AUROC ======
|
||||||
if (epoch + 1) % 1 == 0:
|
if (epoch + 1) % 1 == 0:
|
||||||
test_labels_ = test_labels
|
test_labels_ = test_labels
|
||||||
test_predictions, test_labels_combined = test_model(
|
test_predictions, test_labels= test_model(
|
||||||
model, centroids, test_prompts, test_labels_, device, batch_size, layer_number
|
model, centroids, test_prompts, test_labels_, device, batch_size, layer_number
|
||||||
)
|
)
|
||||||
|
|
||||||
test_auroc = roc_auc_score(
|
test_auroc = roc_auc_score(
|
||||||
test_labels_combined.cpu().numpy(), test_predictions.cpu().numpy()
|
test_labels.cpu().numpy(), test_predictions.cpu().numpy()
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
|
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
|
||||||
|
|
@ -193,6 +192,12 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
|
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
|
||||||
)
|
)
|
||||||
|
# 保存最佳AUROC时的centroids到npy文件
|
||||||
|
centroids_np = centroids.cpu().numpy()
|
||||||
|
centroids_file = os.path.join(dir_name, f"best_centroids_epoch_{epoch}.npy")
|
||||||
|
np.save(centroids_file, centroids_np)
|
||||||
|
print(f"Saved best centroids to {centroids_file}")
|
||||||
|
logging.info(f"Saved best centroids to {centroids_file}")
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
|
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
|
||||||
|
|
@ -200,143 +205,7 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
||||||
logging.info(f"Test AUROC: {test_auroc:.4f}")
|
logging.info(f"Test AUROC: {test_auroc:.4f}")
|
||||||
print(f"Epoch [{epoch+1}/{num_epochs}],Test AUROC: {test_auroc:.4f}")
|
print(f"Epoch [{epoch+1}/{num_epochs}],Test AUROC: {test_auroc:.4f}")
|
||||||
|
|
||||||
# ========= 进入 Phase 2:伪标签 + 自训练 =========
|
|
||||||
logging.info(f"SS Learning Starts")
|
|
||||||
|
|
||||||
# 1) get_ex_data: 在 wild train 上做 TSV 推断 + Sinkhorn OT,选出高置信伪标签样本
|
|
||||||
with torch.no_grad():
|
|
||||||
selected_indices, selected_labels_soft = get_ex_data(
|
|
||||||
model,
|
|
||||||
train_prompts, # wild 区间的样本
|
|
||||||
train_labels, # 这里一般是无标签 or dummy,这里暂时没用
|
|
||||||
batch_size,
|
|
||||||
centroids, # 当前原型
|
|
||||||
sinkhorn, # OT + Sinkhorn 对象,用来求 Q
|
|
||||||
args.num_selected_data, # 选多少个伪标签样本
|
|
||||||
cls_dist,
|
|
||||||
args,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_samples = len(selected_indices) + args.num_exemplars
|
|
||||||
|
|
||||||
num_epochs = args.aug_num_epochs # 自训练的 epoch 数
|
|
||||||
|
|
||||||
exemplar_label = torch.tensor(exemplar_labels).cuda()
|
|
||||||
|
|
||||||
# 2) 构造“增强后的训练集”:伪标签样本 + 原来的 exemplar
|
|
||||||
selected_prompts = [train_prompts[i] for i in selected_indices]
|
|
||||||
selected_labels = selected_labels_soft # 这里已经是 soft label(OT 输出的 q)
|
|
||||||
|
|
||||||
augmented_prompts = selected_prompts + exemplar_prompts_
|
|
||||||
exemplar_labels = torch.nn.functional.one_hot(
|
|
||||||
exemplar_label.to(torch.int64), num_classes=2
|
|
||||||
)
|
|
||||||
augmented_labels = torch.concat(
|
|
||||||
(selected_labels, torch.tensor(exemplar_labels).clone().cuda())
|
|
||||||
)
|
|
||||||
|
|
||||||
augmented_prompts_train = augmented_prompts
|
|
||||||
augmented_labels_label = augmented_labels
|
|
||||||
|
|
||||||
num_samples = len(augmented_prompts_train)
|
|
||||||
|
|
||||||
# 3) 在“伪标签 + exemplar”上再训练一轮 TSV(自训练)
|
|
||||||
with autocast(dtype=torch.float16):
|
|
||||||
for epoch in range(num_epochs):
|
|
||||||
running_loss = 0.0
|
|
||||||
total = 0
|
|
||||||
all_labels = []
|
|
||||||
|
|
||||||
for batch_start in tqdm(
|
|
||||||
range(0, num_samples, batch_size),
|
|
||||||
desc=f"Epoch {epoch+1}/{num_epochs} Batches",
|
|
||||||
leave=False,
|
|
||||||
):
|
|
||||||
batch_prompts = augmented_prompts_train[batch_start: batch_start + batch_size]
|
|
||||||
batch_labels = augmented_labels_label[batch_start: batch_start + batch_size]
|
|
||||||
|
|
||||||
batch_prompts, batch_labels = collate_fn(batch_prompts, batch_labels)
|
|
||||||
|
|
||||||
attention_mask = (batch_prompts != 0).half()
|
|
||||||
|
|
||||||
batch_prompts = batch_prompts.to(device)
|
|
||||||
batch_labels = batch_labels.to(device)
|
|
||||||
attention_mask = attention_mask.to(batch_prompts.device)
|
|
||||||
|
|
||||||
output = model(
|
|
||||||
batch_prompts.squeeze(),
|
|
||||||
attention_mask=attention_mask.squeeze(),
|
|
||||||
output_hidden_states=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = output.hidden_states
|
|
||||||
hidden_states = torch.stack(hidden_states, dim=0).squeeze()
|
|
||||||
last_layer_hidden_state = hidden_states[layer_number]
|
|
||||||
|
|
||||||
last_token_rep = get_last_non_padded_token_rep(
|
|
||||||
last_layer_hidden_state, attention_mask.squeeze()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 这里 compute_ot_loss_cos 接收的是 soft label (OT 得到的 q)
|
|
||||||
ot_loss, similarities = compute_ot_loss_cos(
|
|
||||||
last_token_rep, centroids, batch_labels, batch_size, args
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = ot_loss
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# 自训练阶段,原型更新用的是 soft label 版的 EMA
|
|
||||||
centroids = update_centroids_ema(
|
|
||||||
centroids, last_token_rep, batch_labels.half(), args
|
|
||||||
)
|
|
||||||
all_labels.append(batch_labels.cpu())
|
|
||||||
total += batch_labels.size(0)
|
|
||||||
|
|
||||||
scaler.scale(loss).backward()
|
|
||||||
scaler.step(optimizer)
|
|
||||||
scaler.update()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
running_loss += loss.item() * batch_labels.size(0)
|
|
||||||
|
|
||||||
epoch_loss = running_loss / total
|
|
||||||
|
|
||||||
# ====== 用 test set 评估 AUROC ======
|
|
||||||
with torch.no_grad():
|
|
||||||
all_labels = torch.cat(all_labels).numpy()
|
|
||||||
test_labels_ = test_labels
|
|
||||||
|
|
||||||
if epoch % 1 == 0:
|
|
||||||
test_predictions, test_labels_combined = test_model(
|
|
||||||
model,
|
|
||||||
centroids,
|
|
||||||
test_prompts,
|
|
||||||
test_labels_,
|
|
||||||
device,
|
|
||||||
batch_size,
|
|
||||||
layer_number,
|
|
||||||
)
|
|
||||||
test_auroc = roc_auc_score(test_labels_combined, test_predictions)
|
|
||||||
|
|
||||||
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
|
|
||||||
losses.append(epoch_loss)
|
|
||||||
|
|
||||||
if test_auroc > best_test_auroc:
|
|
||||||
best_test_auroc = test_auroc
|
|
||||||
best_test_epoch = epoch + args.init_num_epochs
|
|
||||||
print(
|
|
||||||
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
|
|
||||||
)
|
|
||||||
logging.info(
|
|
||||||
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, "
|
|
||||||
)
|
|
||||||
logging.info(
|
|
||||||
f"Best test AUROC: {best_test_auroc:.4f}, at epoch: {best_test_epoch}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return best_test_auroc
|
return best_test_auroc
|
||||||
|
|
||||||
|
|
@ -345,14 +214,14 @@ def train_model(model, optimizer, device, prompts, labels, args):
|
||||||
def test_model(model, centroids, test_prompts, test_labels, device, batch_size, layer_number):
|
def test_model(model, centroids, test_prompts, test_labels, device, batch_size, layer_number):
|
||||||
"""
|
"""
|
||||||
在论文里相当于:
|
在论文里相当于:
|
||||||
- 对 test 集算当前 TSV steer 后的 r_v
|
- 对 test 集算当前 TV steer 后的 r_v
|
||||||
- 计算与两个原型 μ_c 的余弦相似度
|
- 计算与两个原型 μ_c 的余弦相似度
|
||||||
- 用 softmax 得到属于 truthful 的概率 p(c=truthful | r_v)
|
- 用 softmax 得到属于 truthful 的概率 p(c=truthful | r_v)
|
||||||
- 用这个概率做 AUROC 评估
|
- 用这个概率做 AUROC 评估
|
||||||
"""
|
"""
|
||||||
model.eval()
|
model.eval()
|
||||||
val_predictions = [] # 保存 p_truthful
|
val_predictions = [] # 保存 p
|
||||||
val_labels_combined = [] # 对应的 gt(二分类标签)
|
val_labels = [] # 对应的(二分类标签)
|
||||||
|
|
||||||
num_val_samples = len(test_prompts)
|
num_val_samples = len(test_prompts)
|
||||||
|
|
||||||
|
|
@ -392,11 +261,11 @@ def test_model(model, centroids, test_prompts, test_labels, device, batch_size,
|
||||||
similarity_scores = similarity_scores[:, 1] # 假设 index 1 = truthful
|
similarity_scores = similarity_scores[:, 1] # 假设 index 1 = truthful
|
||||||
|
|
||||||
val_predictions.append(similarity_scores.cpu())
|
val_predictions.append(similarity_scores.cpu())
|
||||||
val_labels_combined.append(batch_labels.cpu())
|
val_labels.append(batch_labels.cpu())
|
||||||
|
|
||||||
val_predictions = torch.cat(val_predictions)
|
val_predictions = torch.cat(val_predictions)
|
||||||
val_labels_combined = torch.cat(val_labels_combined)
|
val_labels = torch.cat(val_labels)
|
||||||
return val_predictions, val_labels_combined
|
return val_predictions, val_labels
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -477,7 +346,7 @@ def main():
|
||||||
)
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token = '')
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token = '')
|
||||||
|
|
||||||
prompts = []
|
prompts_ = []
|
||||||
qa_pairs = []
|
qa_pairs = []
|
||||||
categories = []
|
categories = []
|
||||||
|
|
||||||
|
|
@ -491,9 +360,9 @@ def main():
|
||||||
return_tensors='pt'
|
return_tensors='pt'
|
||||||
).input_ids.cuda()
|
).input_ids.cuda()
|
||||||
|
|
||||||
prompts.append(prompt)
|
prompts_.append(prompt)
|
||||||
qa_pairs.append({'Question': question, 'Answer': adversary})
|
qa_pairs.append({'Question': question, 'Answer': adversary})
|
||||||
categories.append(1) # 1 = adverse
|
categories.append(0) # 0 = adverse
|
||||||
|
|
||||||
for i in tqdm(range(length)):
|
for i in tqdm(range(length)):
|
||||||
question = dataset[i]['query']
|
question = dataset[i]['query']
|
||||||
|
|
@ -503,12 +372,14 @@ def main():
|
||||||
return_tensors='pt'
|
return_tensors='pt'
|
||||||
).input_ids.cuda()
|
).input_ids.cuda()
|
||||||
|
|
||||||
prompts.append(prompt)
|
prompts_.append(prompt)
|
||||||
qa_pairs.append({'Question': question, 'Answer': clean})
|
qa_pairs.append({'Question': question, 'Answer': clean})
|
||||||
categories.append(0) # 0 = benign
|
categories.append(1) # 1 = benign
|
||||||
|
|
||||||
train_index, val_index, test_index=split_indices(len(prompts), args.train_ratio, args.val_ratio)
|
train_index, val_index, test_index=split_indices(len(prompts_), args.train_ratio, args.val_ratio)
|
||||||
|
|
||||||
|
labels = [categories[test_index], categories[train_index]]
|
||||||
|
prompts = [prompts_[test_index], prompts_[train_index]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -532,7 +403,7 @@ def main():
|
||||||
optimizer = torch.optim.AdamW(list(sv.parameters()), lr=args.lr)
|
optimizer = torch.optim.AdamW(list(sv.parameters()), lr=args.lr)
|
||||||
|
|
||||||
# ====== 6.5 调用 train_model,进入两阶段训练流程 ======
|
# ====== 6.5 调用 train_model,进入两阶段训练流程 ======
|
||||||
train_model(model, optimizer, device, train_data, train_labels, args=args)
|
train_model(model, optimizer, device, prompts, labels, args=args)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print("Skip training steer vectors.")
|
print("Skip training steer vectors.")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,183 @@
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(prompts, labels):
    """Right-pad a batch of tokenized prompts to a common length.

    Each prompt is expected to be a tensor of shape (1, seq_len). Returns a
    zero-padded tensor of shape (batch, 1, max_seq_len) — zeros double as the
    padding positions later detected by `(batch != 0)` masks — plus the labels
    packed into a long tensor on the prompts' device.
    """
    longest = max(p.size(1) for p in prompts)
    ref = prompts[0]  # all prompts assumed on the same device / dtype

    # Zero-initialized container for the padded batch (kept on the default
    # device, matching the original allocation).
    padded = torch.zeros(len(prompts), 1, longest, dtype=ref.dtype)
    for row, p in enumerate(prompts):
        padded[row, :, :p.size(1)] = p

    label_tensor = torch.tensor(labels, dtype=torch.long, device=ref.device)
    return padded, label_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_non_padded_token_rep(hidden_states, attention_mask):
    """Return the hidden state of the last real (non-padded) token per sequence.

    Args:
        hidden_states: (batch, max_seq_len, hidden_size) activations.
        attention_mask: mask with 1 for real tokens and 0 for padding; it is
            squeezed to (batch, max_seq_len) before use, so a batch of size 1
            would collapse to 1-D — assumed not to occur here (TODO confirm
            with callers).

    Returns:
        (batch, hidden_size) tensor of last-token representations.
    """
    # Sequence length of each row = count of real (mask == 1) tokens.
    lengths = attention_mask.squeeze().sum(dim=1).long()
    batch_size = hidden_states.size(0)
    # Vectorized gather: row i at position lengths[i] - 1. Replaces the
    # original per-row Python loop + torch.stack with a single advanced-
    # indexing op; output is identical.
    rows = torch.arange(batch_size, device=hidden_states.device)
    return hidden_states[rows, lengths - 1, :]
|
||||||
|
|
||||||
|
|
||||||
|
def get_ex_data(model, prompts, labels, batch_size, centroids, sinkhorn, num_selected_data, cls_dist, args):
    """Pseudo-label unlabeled prompts and pick the most confident samples.

    Runs the (frozen) model over `prompts` in batches to collect last-token
    representations, assigns soft pseudo-labels with the provided `sinkhorn`
    optimal-transport solver against the class `centroids`, then selects
    `num_selected_data` low-entropy samples (class-balanced by `cls_dist`)
    via `compute_entropy`.

    Returns:
        (selected_indices, selected_labels_soft): indices into `prompts` and
        the corresponding soft pseudo-label rows.

    NOTE(review): `labels` is only gathered into `all_labels`, which is never
    read afterwards — it does not affect the result. Requires CUDA (`.cuda()`
    calls) and fp16 autocast.
    """
    all_embeddings = []
    all_labels = []
    num_samples = len(prompts)

    # Inference only: no gradients, fp16 autocast for speed/memory.
    with torch.no_grad():
        with autocast(dtype=torch.float16):
            for batch_start in tqdm(range(0, num_samples, batch_size)):
                batch_prompts = prompts[batch_start: batch_start + batch_size]
                batch_labels = labels[batch_start: batch_start + batch_size]
                # Pad the batch to a common length; zeros mark padding.
                batch_prompts, batch_labels = collate_fn(batch_prompts,batch_labels)
                attention_mask = (batch_prompts != 0).half()
                batch_prompts = batch_prompts.cuda()
                batch_labels = batch_labels.cuda()
                attention_mask = attention_mask.to(batch_prompts.device)
                all_labels.append(batch_labels.cpu().numpy())

                output = model(batch_prompts.squeeze(), attention_mask=attention_mask.squeeze(), output_hidden_states=True)
                hidden_states = output.hidden_states

                # Stack per-layer states; take the final layer's activations.
                hidden_states = torch.stack(hidden_states, dim=0).squeeze()
                last_layer_hidden_state = hidden_states[-1]

                # Representation of the last real (non-padded) token per row.
                last_token_rep = get_last_non_padded_token_rep(last_layer_hidden_state, attention_mask.squeeze())
                all_embeddings.append(last_token_rep)

    # Project all embeddings onto the unit sphere before OT matching.
    all_embeddings = F.normalize(torch.concat(all_embeddings),p=2,dim=-1)

    # Soft assignment of samples to class centroids via Sinkhorn OT.
    pseudo_label = sinkhorn(all_embeddings, centroids)

    # Keep only the most confident (lowest-entropy) pseudo-labeled samples.
    selected_indices = compute_entropy(all_embeddings, centroids, pseudo_label, num_selected_data, cls_dist, args)

    selected_labels_soft = pseudo_label[selected_indices]

    return selected_indices, selected_labels_soft
|
||||||
|
|
||||||
|
|
||||||
|
def compute_ot_loss_cos(last_token_rep, centroids, pseudo_label, batch_size, args):
    """Soft cross-entropy between cosine class-assignments and pseudo-labels.

    Representations and centroids are projected onto the unit sphere, their
    pairwise cosine similarities are temperature-scaled by `args.cos_temp`
    and softmax-normalized, and the (soft) `pseudo_label` rows serve as the
    cross-entropy targets. `batch_size` is accepted for interface
    compatibility but unused.

    Returns:
        (loss, similarities): scalar mean loss and the scaled cosine logits.
    """
    reps_unit = F.normalize(last_token_rep, p=2, dim=-1)
    cents_unit = F.normalize(centroids, p=2, dim=-1)

    # Temperature-scaled cosine logits: (num_samples, num_classes).
    logits = torch.matmul(reps_unit, cents_unit.T) / args.cos_temp
    probs = F.softmax(logits, dim=-1)

    # Mean soft cross-entropy; epsilon guards log(0).
    loss = -torch.sum(pseudo_label * torch.log(probs + 1e-8)) / pseudo_label.shape[0]
    return loss, logits
|
||||||
|
|
||||||
|
|
||||||
|
def compute_entropy(last_token_rep, centroids, pseudo_label, k, cls_dist, args):
    """Select the k most confident pseudo-labeled samples, class-balanced.

    Confidence is the per-sample soft cross-entropy between the softmaxed
    cosine similarities (temperature `args.cos_temp`) and the pseudo-labels —
    lower is more confident. Selection is split between the two classes in
    proportion to `cls_dist`; if either class has too few candidates, falls
    back to a plain global top-k.

    Returns:
        1-D tensor of selected sample indices.
    """
    reps_unit = F.normalize(last_token_rep, p=2, dim=-1)
    cents_unit = F.normalize(centroids, p=2, dim=-1)
    logits = torch.matmul(reps_unit, cents_unit.T) / args.cos_temp
    probs = F.softmax(logits, dim=-1)

    # Per-sample soft cross-entropy against the OT pseudo-labels.
    per_sample_ce = torch.sum(-(pseudo_label * torch.log(probs + 1e-8)), dim=1)
    hard_assign = torch.argmax(probs, dim=1)

    # * Added for preventing severe cases
    # Class-wise data selection: pick pseudo-labeled samples in proportion to
    # the exemplar set's class distribution so one class cannot dominate.
    quota_0 = k * cls_dist[0]
    quota_1 = k * cls_dist[1]

    idx_0 = (hard_assign == 0).nonzero(as_tuple=True)[0]
    idx_1 = (hard_assign == 1).nonzero(as_tuple=True)[0]
    ce_0 = per_sample_ce[idx_0]
    ce_1 = per_sample_ce[idx_1]

    if len(ce_0) < quota_0 or len(ce_1) < quota_1:
        # Not enough candidates in some class: global top-k fallback.
        top_k_indices = torch.topk(per_sample_ce, k, largest=False, sorted=True).indices
    else:
        best_0 = idx_0[torch.topk(ce_0, int(quota_0), largest=False, sorted=True).indices]
        best_1 = idx_1[torch.topk(ce_1, int(quota_1), largest=False, sorted=True).indices]
        top_k_indices = torch.cat((best_0, best_1))

    return top_k_indices
|
||||||
|
|
||||||
|
|
||||||
|
def update_centroids_ema(centroids, last_token_rep, pseudo_label, args):
    """EMA update of class centroids from soft pseudo-labeled batch reps.

    The batch contribution per class is the pseudo-label-weighted mean of the
    L2-normalized representations; it is blended into the (normalized) old
    centroids with decay `args.ema_decay` and re-projected onto the unit
    sphere.
    """
    reps_unit = F.normalize(last_token_rep, p=2, dim=1)
    cents_unit = F.normalize(centroids, p=2, dim=1)

    # Soft-assignment weighted sums: (num_classes, hidden).
    class_sums = torch.matmul(pseudo_label.T, reps_unit)

    # Total soft mass per class; epsilon avoids division by zero for a class
    # absent from the batch.
    class_mass = pseudo_label.sum(dim=0).unsqueeze(1) + 1e-8
    batch_centroids = class_sums / class_mass

    # Exponential moving average, then back onto the unit sphere.
    blended = args.ema_decay * cents_unit + (1 - args.ema_decay) * batch_centroids
    return F.normalize(blended, p=2, dim=1)
|
||||||
|
|
||||||
|
def update_centroids_ema_hard(centroids, last_token_rep, pseudo_label, args):
    """EMA centroid update using hard (argmax) pseudo-label assignments.

    Like `update_centroids_ema`, but each sample contributes only to its
    argmax class (one-hot assignment) instead of fractionally to every class.
    """
    reps_unit = F.normalize(last_token_rep, p=2, dim=1)
    cents_unit = F.normalize(centroids, p=2, dim=1)

    # Harden the soft labels into a one-hot assignment matrix.
    winners = torch.argmax(pseudo_label, dim=1)
    one_hot = torch.zeros_like(pseudo_label)
    one_hot[torch.arange(pseudo_label.size(0)), winners] = 1

    # Per-class sums and counts of assigned samples (epsilon guards empty
    # classes).
    class_sums = torch.matmul(one_hot.T.float(), reps_unit)
    class_counts = one_hot.sum(dim=0).unsqueeze(1) + 1e-8
    batch_centroids = class_sums / class_counts

    # Exponential moving average, re-projected onto the unit sphere.
    blended = args.ema_decay * cents_unit + (1 - args.ema_decay) * batch_centroids
    return F.normalize(blended, p=2, dim=-1)
|
||||||
Loading…
Reference in New Issue