use clip decoder

leewlving 2024-06-18 21:28:14 +08:00
parent 312671bdc7
commit cb449df1c5
5 changed files with 14 additions and 17 deletions

View File

@@ -8,7 +8,8 @@ import torch
import random
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
-from model.bert_tokenizer import BertTokenizer
+from model.simple_tokenizer import SimpleTokenizer as Tokenizer
class BaseDataset(Dataset):
@@ -18,8 +19,8 @@ class BaseDataset(Dataset):
captions: dict,
indexs: dict,
labels: dict,
-tokenizer: any,
is_train=True,
+tokenizer=Tokenizer(),
maxWords=32,
imageResolution=224,
npy=False):
@@ -65,14 +66,14 @@ class BaseDataset(Dataset):
captions = self.captions[index]
use_cap = captions[random.randint(0, len(captions) - 1)]
-words = self.tokenizer._tokenize(use_cap)
+words = self.tokenizer.tokenize(use_cap)
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
total_length_with_CLS = self.maxWords - 1
if len(words) > total_length_with_CLS:
words = words[:total_length_with_CLS]
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
-caption = self.tokenizer._convert_tokens_to_ids(words)
+caption = self.tokenizer.convert_tokens_to_ids(words)
while len(caption) < self.maxWords:
caption.append(0)
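Note (editor's sketch, not part of the commit): the new caption path in BaseDataset.__getitem__ assumes the project's model/simple_tokenizer.SimpleTokenizer exposes tokenize() and convert_tokens_to_ids(), as CLIP4Clip-style variants of CLIP's tokenizer do. A minimal sketch of that flow, with illustrative special-token strings and a placeholder caption:

    from model.simple_tokenizer import SimpleTokenizer as Tokenizer

    tokenizer = Tokenizer()
    # Illustrative CLIP-style special tokens; the repo's actual SPECIAL_TOKEN dict may differ.
    SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>"}
    max_words = 32

    use_cap = "a dog playing in the park"          # placeholder caption
    words = tokenizer.tokenize(use_cap)            # BPE sub-word tokens
    words = [SPECIAL_TOKEN["CLS_TOKEN"]] + words
    if len(words) > max_words - 1:                 # leave room for the SEP token
        words = words[: max_words - 1]
    words = words + [SPECIAL_TOKEN["SEP_TOKEN"]]
    caption = tokenizer.convert_tokens_to_ids(words)
    caption = caption + [0] * (max_words - len(caption))   # pad to maxWords with 0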

View File

@@ -32,7 +32,6 @@ def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=N
def dataloader(captionFile: str,
indexFile: str,
labelFile: str,
-tokenizer: any,
maxWords=77,
imageResolution=224,
query_num=5000,
@@ -60,9 +59,9 @@ def dataloader(captionFile: str,
split_indexs, split_captions, split_labels = split_data(captions, indexs, labels, query_num=query_num, train_num=train_num, seed=seed)
-train_data = BaseDataset(captions=split_captions[1], indexs=split_indexs[1], labels=split_labels[1], tokenizer=tokenizer, maxWords=maxWords, imageResolution=imageResolution, npy=npy)
-query_data = BaseDataset(captions=split_captions[0], indexs=split_indexs[0], labels=split_labels[0], tokenizer=tokenizer, maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
-retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], tokenizer=tokenizer, maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
+train_data = BaseDataset(captions=split_captions[1], indexs=split_indexs[1], labels=split_labels[1], maxWords=maxWords, imageResolution=imageResolution, npy=npy)
+query_data = BaseDataset(captions=split_captions[0], indexs=split_indexs[0], labels=split_labels[0], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
+retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
return train_data, query_data, retrieval_data
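With the tokenizer parameter dropped from both dataloader() and BaseDataset, callers no longer construct or pass a tokenizer. A hedged usage sketch (file paths are placeholders, not taken from the repo; only keyword arguments visible in this diff are used):

    from dataset.dataloader import dataloader

    train_data, query_data, retrieval_data = dataloader(
        captionFile="path/to/captions",     # placeholder
        indexFile="path/to/indexs",         # placeholder
        labelFile="path/to/labels",         # placeholder
        maxWords=32,
        imageResolution=224,
        query_num=5000,
    )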

View File

@@ -233,9 +233,9 @@ class BertTokenizer(PreTrainedTokenizer):
""" Converts a token (str) in an id using the vocab. """
return self.vocab.get(token, self.vocab.get(self.unk_token))
-def _convert_tokens_to_ids(self, tokens):
-""" Converts a token (str) in an id using the vocab. """
-return [self._convert_token_to_id(token) for token in tokens]
+# def _convert_tokens_to_ids(self, tokens):
+# """ Converts a token (str) in an id using the vocab. """
+# return [self._convert_token_to_id(token) for token in tokens]
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""

View File

@@ -15,8 +15,6 @@ from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similari
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
import clip
-from model.bert_tokenizer import BertTokenizer
# from transformers import BertModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -39,8 +37,6 @@ class Trainer(TrainBase):
self.text_mean=text_mean
self.text_var=text_var
self.device=rank
-self.tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder)
# self.run()
def _init_model(self):
@@ -59,7 +55,6 @@
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
indexFile=self.args.index_file,
labelFile=self.args.label_file,
-tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder),
maxWords=self.args.max_words,
imageResolution=self.args.resolution,
query_num=self.args.query_num,

View File

@@ -17,6 +17,7 @@ from dataset.dataloader import dataloader
import clip
import copy
from model.bert_tokenizer import BertTokenizer
+from model.simple_tokenizer import SimpleTokenizer as Tokenizer
from transformers import BertForMaskedLM
# from transformers import BertModel
@@ -123,6 +124,7 @@ class Trainer(TrainBase):
self.image_mean=image_mean
self.image_var=image_var
self.device=rank
+self.clip_tokenizer=Tokenizer()
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder)
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
# self.run()
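This trainer now keeps two tokenizers side by side: the CLIP SimpleTokenizer for the caption path that feeds the CLIP text encoder, and the BERT tokenizer for the BertForMaskedLM reference net. A hedged sketch of how the two branches might encode the same caption (method names assumed from the CLIP4Clip-style SimpleTokenizer and a HuggingFace-style BertTokenizer; the pretrained name and caption are placeholders):

    from model.simple_tokenizer import SimpleTokenizer as Tokenizer
    from model.bert_tokenizer import BertTokenizer

    clip_tokenizer = Tokenizer()
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")   # placeholder name

    text = "a man riding a horse on the beach"   # placeholder caption

    # CLIP branch: BPE tokens -> ids for the CLIP text encoder used for hashing.
    clip_ids = clip_tokenizer.convert_tokens_to_ids(clip_tokenizer.tokenize(text))

    # BERT branch: WordPiece tokens -> ids for the BertForMaskedLM reference net.
    bert_ids = bert_tokenizer.convert_tokens_to_ids(bert_tokenizer.tokenize(text))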