change text tokenizer to bert
This commit is contained in:
parent
d35f23c5ed
commit
312671bdc7
|
|
@ -63,21 +63,20 @@ class BaseDataset(Dataset):
|
|||
|
||||
def _load_text(self, index: int):
|
||||
captions = self.captions[index]
|
||||
# print(len(captions))
|
||||
caption=self.tokenizer(captions, padding='max_length', truncation=True, max_length=self.maxWords, return_tensors='pt')
|
||||
use_cap = captions[random.randint(0, len(captions) - 1)]
|
||||
|
||||
# words = self.tokenizer.tokenize(use_cap)
|
||||
# words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
|
||||
# total_length_with_CLS = self.maxWords - 1
|
||||
# if len(words) > total_length_with_CLS:
|
||||
# words = words[:total_length_with_CLS]
|
||||
words = self.tokenizer._tokenize(use_cap)
|
||||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
|
||||
total_length_with_CLS = self.maxWords - 1
|
||||
if len(words) > total_length_with_CLS:
|
||||
words = words[:total_length_with_CLS]
|
||||
|
||||
# words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
|
||||
# caption = self.tokenizer.convert_tokens_to_ids(words)
|
||||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
|
||||
caption = self.tokenizer._convert_tokens_to_ids(words)
|
||||
|
||||
# while len(caption) < self.maxWords:
|
||||
# caption.append(0)
|
||||
# caption = torch.tensor(caption)
|
||||
while len(caption) < self.maxWords:
|
||||
caption.append(0)
|
||||
caption = torch.tensor(caption)
|
||||
|
||||
return caption
|
||||
|
||||
|
|
|
|||
|
|
@ -232,6 +232,10 @@ class BertTokenizer(PreTrainedTokenizer):
|
|||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
||||
|
||||
def _convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return [self._convert_token_to_id(token) for token in tokens]
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ class Trainer(TrainBase):
|
|||
self.text_var=text_var
|
||||
self.device=rank
|
||||
self.tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder)
|
||||
|
||||
# self.run()
|
||||
|
||||
def _init_model(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue