get clean embedding from bert

This commit is contained in:
leewlving 2024-06-18 22:55:03 +08:00
parent 73c901a18c
commit 9170cbd6e9
1 changed files with 1 additions and 1 deletions

View File

@ -262,7 +262,7 @@ class Trainer(TrainBase):
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)
#clean state
clean_embeds=self.model.encode_text(text_tokens)
clean_embeds=self.ref_net(text_inputs.input_ids, attention_mask=text_inputs.attention_mask)
final_adverse = []
for i, text in enumerate(texts):
important_scores = self.get_important_scores(text, clean_embeds, self.batch_size, self.max_length)