Change text tokenizer to BERT

This commit is contained in:
leewlving 2024-06-18 20:27:29 +08:00
parent d35f23c5ed
commit 312671bdc7
3 changed files with 16 additions and 12 deletions

View File

@ -63,21 +63,20 @@ class BaseDataset(Dataset):
def _load_text(self, index: int):
    """Build the fixed-length token-id tensor for one randomly chosen caption.

    One caption is drawn at random from ``self.captions[index]``, tokenized
    with ``self.tokenizer._tokenize``, wrapped in the CLS/SEP markers taken
    from ``self.SPECIAL_TOKEN``, converted to ids with
    ``self.tokenizer._convert_tokens_to_ids`` and right-padded with ``0``.

    Args:
        index: Position of the sample in ``self.captions``.

    Returns:
        A ``torch.Tensor`` of token ids; it holds ``self.maxWords`` entries
        whenever ``self.maxWords >= 1``.

    Raises:
        ValueError: If the selected sample has no captions to draw from.
    """
    captions = self.captions[index]
    # Same draw (and same error on an empty list) as
    # random.randint(0, len(captions) - 1).
    use_cap = captions[random.randrange(len(captions))]

    # CLS + word pieces are capped at maxWords - 1 so that the closing SEP
    # marker always fits inside the maxWords budget.
    words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + self.tokenizer._tokenize(use_cap)
    words = words[:self.maxWords - 1]
    words.append(self.SPECIAL_TOKEN["SEP_TOKEN"])

    caption = self.tokenizer._convert_tokens_to_ids(words)
    # Right-pad with id 0; a non-positive multiplier yields no padding.
    # NOTE(review): 0 is presumably the tokenizer's padding id - confirm
    # against the vocabulary in use.
    caption += [0] * (self.maxWords - len(caption))
    return torch.tensor(caption)

View File

@ -232,6 +232,10 @@ class BertTokenizer(PreTrainedTokenizer):
def _convert_token_to_id(self, token):
    """Look up the vocabulary id of ``token``; absent tokens get the unknown-token id."""
    # The fallback is resolved first, exactly as the original did when it
    # evaluated the default argument of ``dict.get``.
    fallback_id = self.vocab.get(self.unk_token)
    return self.vocab.get(token, fallback_id)
def _convert_tokens_to_ids(self, tokens):
    """Convert every token in ``tokens`` to its id via ``_convert_token_to_id`` and return the ids as a list."""
    return list(map(self._convert_token_to_id, tokens))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""

View File

@ -40,6 +40,7 @@ class Trainer(TrainBase):
self.text_var=text_var
self.device=rank
self.tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder)
# self.run()
def _init_model(self):