use clip decoder
This commit is contained in:
parent
312671bdc7
commit
cb449df1c5
|
|
@ -8,7 +8,8 @@ import torch
|
|||
import random
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||
from model.bert_tokenizer import BertTokenizer
|
||||
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
||||
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
|
|
@ -18,8 +19,8 @@ class BaseDataset(Dataset):
|
|||
captions: dict,
|
||||
indexs: dict,
|
||||
labels: dict,
|
||||
tokenizer: any,
|
||||
is_train=True,
|
||||
tokenizer=Tokenizer(),
|
||||
maxWords=32,
|
||||
imageResolution=224,
|
||||
npy=False):
|
||||
|
|
@ -65,14 +66,14 @@ class BaseDataset(Dataset):
|
|||
captions = self.captions[index]
|
||||
use_cap = captions[random.randint(0, len(captions) - 1)]
|
||||
|
||||
words = self.tokenizer._tokenize(use_cap)
|
||||
words = self.tokenizer.tokenize(use_cap)
|
||||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
|
||||
total_length_with_CLS = self.maxWords - 1
|
||||
if len(words) > total_length_with_CLS:
|
||||
words = words[:total_length_with_CLS]
|
||||
|
||||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
|
||||
caption = self.tokenizer._convert_tokens_to_ids(words)
|
||||
caption = self.tokenizer.convert_tokens_to_ids(words)
|
||||
|
||||
while len(caption) < self.maxWords:
|
||||
caption.append(0)
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=N
|
|||
def dataloader(captionFile: str,
|
||||
indexFile: str,
|
||||
labelFile: str,
|
||||
tokenizer: any,
|
||||
maxWords=77,
|
||||
imageResolution=224,
|
||||
query_num=5000,
|
||||
|
|
@ -60,9 +59,9 @@ def dataloader(captionFile: str,
|
|||
|
||||
split_indexs, split_captions, split_labels = split_data(captions, indexs, labels, query_num=query_num, train_num=train_num, seed=seed)
|
||||
|
||||
train_data = BaseDataset(captions=split_captions[1], indexs=split_indexs[1], labels=split_labels[1], tokenizer=tokenizer, maxWords=maxWords, imageResolution=imageResolution, npy=npy)
|
||||
query_data = BaseDataset(captions=split_captions[0], indexs=split_indexs[0], labels=split_labels[0], tokenizer=tokenizer, maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
|
||||
retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], tokenizer=tokenizer, maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
|
||||
train_data = BaseDataset(captions=split_captions[1], indexs=split_indexs[1], labels=split_labels[1], maxWords=maxWords, imageResolution=imageResolution, npy=npy)
|
||||
query_data = BaseDataset(captions=split_captions[0], indexs=split_indexs[0], labels=split_labels[0], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
|
||||
retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
|
||||
|
||||
return train_data, query_data, retrieval_data
|
||||
|
||||
|
|
|
|||
|
|
@ -233,9 +233,9 @@ class BertTokenizer(PreTrainedTokenizer):
|
|||
""" Converts a token (str) in an id using the vocab. """
|
||||
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
||||
|
||||
def _convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return [self._convert_token_to_id(token) for token in tokens]
|
||||
# def _convert_tokens_to_ids(self, tokens):
|
||||
# """ Converts a token (str) in an id using the vocab. """
|
||||
# return [self._convert_token_to_id(token) for token in tokens]
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
|
|
|
|||
|
|
@ -15,8 +15,6 @@ from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similari
|
|||
from utils.calc_utils import cal_map, cal_pr
|
||||
from dataset.dataloader import dataloader
|
||||
import clip
|
||||
from model.bert_tokenizer import BertTokenizer
|
||||
|
||||
# from transformers import BertModel
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
|
@ -39,8 +37,6 @@ class Trainer(TrainBase):
|
|||
self.text_mean=text_mean
|
||||
self.text_var=text_var
|
||||
self.device=rank
|
||||
self.tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder)
|
||||
|
||||
# self.run()
|
||||
|
||||
def _init_model(self):
|
||||
|
|
@ -58,8 +54,7 @@ class Trainer(TrainBase):
|
|||
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
|
||||
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
|
||||
indexFile=self.args.index_file,
|
||||
labelFile=self.args.label_file,
|
||||
tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder),
|
||||
labelFile=self.args.label_file,
|
||||
maxWords=self.args.max_words,
|
||||
imageResolution=self.args.resolution,
|
||||
query_num=self.args.query_num,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from dataset.dataloader import dataloader
|
|||
import clip
|
||||
import copy
|
||||
from model.bert_tokenizer import BertTokenizer
|
||||
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
||||
from transformers import BertForMaskedLM
|
||||
# from transformers import BertModel
|
||||
|
||||
|
|
@ -123,6 +124,7 @@ class Trainer(TrainBase):
|
|||
self.image_mean=image_mean
|
||||
self.image_var=image_var
|
||||
self.device=rank
|
||||
self.clip_tokenizer=Tokenizer()
|
||||
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder)
|
||||
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
|
||||
# self.run()
|
||||
|
|
|
|||
Loading…
Reference in New Issue