get clean embedding from bert
This commit is contained in:
parent
73c901a18c
commit
9170cbd6e9
|
|
@ -262,7 +262,7 @@ class Trainer(TrainBase):
|
|||
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)
|
||||
|
||||
#clean state
|
||||
clean_embeds=self.model.encode_text(text_tokens)
|
||||
clean_embeds=self.ref_net(text_inputs.input_ids, attention_mask=text_inputs.attention_mask)
|
||||
final_adverse = []
|
||||
for i, text in enumerate(texts):
|
||||
important_scores = self.get_important_scores(text, clean_embeds, self.batch_size, self.max_length)
|
||||
|
|
|
|||
Loading…
Reference in New Issue