# haloscope/train_utils.py


import numpy as np
import torch
import torch.nn.functional as F
from torch.amp import autocast
from tqdm import tqdm


def collate_fn(prompts, labels):
    """Pad a list of (1, seq_len) prompt tensors to a common length and stack labels."""
    # Find the maximum sequence length in the batch
    max_seq_len = max(prompt.size(1) for prompt in prompts)
    batch_size = len(prompts)
    dtype = prompts[0].dtype
    device = prompts[0].device  # Assuming all prompts are on the same device
    # Initialize a tensor to hold the batched prompts (padding index 0)
    prompts_padded = torch.zeros(batch_size, 1, max_seq_len, dtype=dtype, device=device)
    # Pad each prompt to the maximum sequence length
    for i, prompt in enumerate(prompts):
        seq_len = prompt.size(1)
        prompts_padded[i, :, :seq_len] = prompt
    # Stack labels into a tensor
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return prompts_padded, labels
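
# Example (illustrative only, not part of the original file): collate_fn expects
# each prompt as a (1, seq_len) tensor of token ids and returns a
# (batch, 1, max_seq_len) tensor right-padded with zeros, plus a LongTensor of
# labels; e.g. two prompts of lengths 3 and 5 are both padded to length 5.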


def get_last_non_padded_token_rep(hidden_states, attention_mask):
    """
    Get the last non-padded token's representation for each sequence in the batch.
    """
    # Find the length of each sequence by summing the attention mask (1 for real tokens, 0 for padding)
    lengths = attention_mask.squeeze().sum(dim=1).long()
    # Index the last non-padded token for each sequence
    batch_size, max_seq_len, hidden_size = hidden_states.size()
    last_token_reps = torch.stack([hidden_states[i, lengths[i] - 1, :] for i in range(batch_size)])
    return last_token_reps
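
# Illustrative shapes (an assumption from how the caller uses this): with
# hidden_states of shape (batch, max_seq_len, hidden_size) and a 0/1 attention
# mask that squeezes to (batch, max_seq_len), the result is a
# (batch, hidden_size) tensor holding each sequence's final real-token state.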


def get_ex_data(model, prompts, labels, batch_size, centroids, sinkhorn, num_selected_data, cls_dist, args):
    all_embeddings = []
    all_labels = []
    num_samples = len(prompts)
    with torch.no_grad():
        with autocast('cuda', dtype=torch.float16):
            for batch_start in tqdm(range(0, num_samples, batch_size)):
                batch_prompts = prompts[batch_start: batch_start + batch_size]
                batch_labels = labels[batch_start: batch_start + batch_size]
                batch_prompts, batch_labels = collate_fn(batch_prompts, batch_labels)
                attention_mask = (batch_prompts != 0).half()
                batch_prompts = batch_prompts.cuda()
                batch_labels = batch_labels.cuda()
                attention_mask = attention_mask.to(batch_prompts.device)
                all_labels.append(batch_labels.cpu().numpy())
                output = model(batch_prompts.squeeze(), attention_mask=attention_mask.squeeze(), output_hidden_states=True)
                hidden_states = output.hidden_states
                hidden_states = torch.stack(hidden_states, dim=0).squeeze()
                last_layer_hidden_state = hidden_states[-1]
                last_token_rep = get_last_non_padded_token_rep(last_layer_hidden_state, attention_mask.squeeze())
                all_embeddings.append(last_token_rep)
    all_embeddings = F.normalize(torch.cat(all_embeddings), p=2, dim=-1)
    pseudo_label = sinkhorn(all_embeddings, centroids)
    selected_indices = compute_entropy(all_embeddings, centroids, pseudo_label, num_selected_data, cls_dist, args)
    selected_labels_soft = pseudo_label[selected_indices]
    return selected_indices, selected_labels_soft
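
# Note (an assumption inferred from how the results are indexed below): `sinkhorn`
# is expected to map the (num_samples, hidden_size) normalized embeddings and the
# (num_classes, hidden_size) centroids to a (num_samples, num_classes) soft
# assignment matrix, and `cls_dist` to hold the per-class selection proportions.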


def compute_ot_loss_cos(last_token_rep, centroids, pseudo_label, batch_size, args):
    last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
    centroids = F.normalize(centroids, p=2, dim=-1)
    similarities = torch.matmul(last_token_rep, centroids.T)
    similarities = similarities / args.cos_temp
    pt = F.softmax(similarities, dim=-1)
    ot_loss = -torch.sum(pseudo_label * torch.log(pt + 1e-8)) / pseudo_label.shape[0]
    return ot_loss, similarities
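
# In effect, compute_ot_loss_cos is a soft cross-entropy between the pseudo-label
# distribution q and the temperature-scaled cosine-similarity softmax p:
#   ot_loss = -(1 / N) * sum_i sum_c q[i, c] * log(p[i, c] + 1e-8)
# where p[i] = softmax(<z_i, centroids> / cos_temp) over the row-normalized centroids.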


def compute_entropy(last_token_rep, centroids, pseudo_label, k, cls_dist, args):
    last_token_rep = F.normalize(last_token_rep, p=2, dim=-1)
    centroids = F.normalize(centroids, p=2, dim=-1)
    similarities = torch.matmul(last_token_rep, centroids.T)
    similarities = similarities / args.cos_temp
    pt = F.softmax(similarities, dim=-1)
    ce = -(pseudo_label * torch.log(pt + 1e-8))
    pseudo_label_hard = torch.argmax(pt, dim=1)
    # * Added for preventing severe cases
    # Class-wise data selection: select pseudo-labeled unlabeled data in proportion
    # to the class distribution of the exemplar set.
    cls0_num = k * cls_dist[0]
    cls1_num = k * cls_dist[1]
    cls_0_indices = (pseudo_label_hard == 0).nonzero(as_tuple=True)[0]
    cls_1_indices = (pseudo_label_hard == 1).nonzero(as_tuple=True)[0]
    ce = torch.sum(ce, dim=1)
    ce_class_0 = ce[cls_0_indices]
    ce_class_1 = ce[cls_1_indices]
    if len(ce_class_0) < cls0_num or len(ce_class_1) < cls1_num:
        # Fallback to top-k across all classes
        _, top_k_indices = torch.topk(ce, k, largest=False, sorted=True)
    else:
        top_0_indices = cls_0_indices[torch.topk(ce_class_0, int(cls0_num), largest=False, sorted=True).indices]
        top_1_indices = cls_1_indices[torch.topk(ce_class_1, int(cls1_num), largest=False, sorted=True).indices]
        top_k_indices = torch.cat((top_0_indices, top_1_indices))
    return top_k_indices
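
# Example (illustrative numbers): with k = 100 and cls_dist = [0.7, 0.3], the
# selection keeps the 70 lowest cross-entropy (most confident) samples whose hard
# pseudo-label is class 0 and the 30 lowest for class 1; if either class has too
# few candidates, it falls back to a plain top-k over all samples.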


def update_centroids_ema(centroids, last_token_rep, pseudo_label, args):
    last_token_rep_norm = F.normalize(last_token_rep, p=2, dim=1)
    centroids = F.normalize(centroids, p=2, dim=1)
    weighted_sum = torch.matmul(pseudo_label.T, last_token_rep_norm)
    # Normalize the weighted sums to get the new centroids
    pseudo_label_sum = pseudo_label.sum(dim=0).unsqueeze(1) + 1e-8
    new_centroids_batch = weighted_sum / pseudo_label_sum
    # EMA update for centroids
    updated_centroids = F.normalize(args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch, p=2, dim=1)
    return updated_centroids
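
# The soft update corresponds to
#   c_new = normalize(ema_decay * c_old + (1 - ema_decay) * c_batch),
# where c_batch is the pseudo-label-weighted mean of the normalized representations.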


def update_centroids_ema_hard(centroids, last_token_rep, pseudo_label, args):
    last_token_rep_norm = F.normalize(last_token_rep, p=2, dim=1)
    centroids = F.normalize(centroids, p=2, dim=1)
    # Convert soft pseudo-labels to one-hot (hard) assignments
    max_indices = torch.argmax(pseudo_label, dim=1)
    discrete_labels = torch.zeros_like(pseudo_label)
    discrete_labels[torch.arange(pseudo_label.size(0)), max_indices] = 1
    weighted_sum = torch.matmul(discrete_labels.T.float(), last_token_rep_norm)
    pseudo_label_sum = discrete_labels.sum(dim=0).unsqueeze(1) + 1e-8
    new_centroids_batch = weighted_sum / pseudo_label_sum
    # EMA update for centroids
    updated_centroids = F.normalize(args.ema_decay * centroids + (1 - args.ema_decay) * new_centroids_batch, p=2, dim=-1)
    return updated_centroids


def split_indices(n, train_ratio=0.8, val_ratio=0.1, seed=42):
    rng = np.random.default_rng(seed)
    indices = np.arange(n)
    rng.shuffle(indices)
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    train_index = indices[:n_train]
    val_index = indices[n_train:n_train + n_val]
    test_index = indices[n_train + n_val:]
    return train_index, val_index, test_index
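

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the training pipeline): exercises
    # split_indices, compute_ot_loss_cos, and update_centroids_ema on random
    # tensors. The SimpleNamespace stand-in for `args` is an assumption and only
    # carries the fields these helpers read (cos_temp, ema_decay).
    from types import SimpleNamespace

    args = SimpleNamespace(cos_temp=0.1, ema_decay=0.9)

    train_idx, val_idx, test_idx = split_indices(100)
    print(len(train_idx), len(val_idx), len(test_idx))  # 80 10 10

    reps = torch.randn(32, 16)
    centroids = torch.randn(2, 16)
    pseudo = F.softmax(torch.randn(32, 2), dim=1)
    loss, _ = compute_ot_loss_cos(reps, centroids, pseudo, batch_size=32, args=args)
    centroids = update_centroids_ema(centroids, reps, pseudo, args)
    print(loss.item(), centroids.shape)  # scalar loss, torch.Size([2, 16])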