300 lines
10 KiB
Python
300 lines
10 KiB
Python
from __future__ import absolute_import
|
||
from __future__ import division
|
||
from __future__ import unicode_literals
|
||
from __future__ import print_function
|
||
|
||
from torch.utils.data import Dataset
|
||
import torch
|
||
import random
|
||
from PIL import Image
|
||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
||
import re
|
||
import os
|
||
import numpy as np
|
||
from torchvision import transforms
|
||
|
||
|
||
|
||
class BaseDataset(Dataset):
    """Generic image/caption/label dataset.

    Wraps three parallel mappings indexed by the same integer keys:
    ``indexs`` (image paths, or raw arrays when ``npy=True``), ``captions``
    (each entry is a sequence of candidate caption strings), and ``labels``
    (numpy label vectors).
    """

    def __init__(self,
                 captions: dict,
                 indexs: dict,
                 labels: dict,
                 is_train=True,
                 tokenizer=None,
                 maxWords=32,
                 imageResolution=224,
                 npy=False):
        """
        Args:
            captions: index -> sequence of caption strings.
            indexs: index -> image path (or raw image array when ``npy`` is True).
            labels: index -> numpy label vector.
            is_train: selects the train-time (resize + center-crop) transform.
            tokenizer: tokenizer instance; a SimpleTokenizer is created lazily
                when None (avoids a stateful default argument that was built
                once at class-definition time).
            maxWords: maximum length, in tokens, of an encoded caption.
            imageResolution: output spatial size of the image transform.
            npy: when True, entries of ``indexs`` are arrays, not file paths.
        """
        self.captions = captions
        self.indexs = indexs
        self.labels = labels
        self.npy = npy

        self.maxWords = maxWords
        # Lazy default: do not instantiate a tokenizer at import time.
        self.tokenizer = tokenizer if tokenizer is not None else Tokenizer()

        # CLIP-style preprocessing (the mean/std constants below are the ones
        # published with CLIP). The eval branch resizes to an exact square
        # instead of resize + center-crop.
        self.transform = Compose([
            Resize(imageResolution, interpolation=Image.BICUBIC),
            CenterCrop(imageResolution),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]) if is_train else Compose([
            Resize((imageResolution, imageResolution), interpolation=Image.BICUBIC),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
                              "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}

        self.__length = len(self.indexs)

    def __len__(self):
        return self.__length

    def _load_image(self, index: int) -> torch.Tensor:
        """Load the image at ``index`` (path or raw array) and apply the transform."""
        if not self.npy:
            image = Image.open(self.indexs[index].strip()).convert("RGB")
        else:
            image = Image.fromarray(self.indexs[index]).convert("RGB")
        return self.transform(image)

    def pre_caption(self, caption):
        """Lower-case, strip punctuation, collapse whitespace, and truncate a
        caption to at most ``maxWords`` space-separated words."""
        caption = re.sub(
            r"([,.'!?\"()*#:;~])",
            '',
            caption.lower(),
        ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
        caption = re.sub(r"\s{2,}", ' ', caption)
        caption = caption.rstrip('\n')
        caption = caption.strip(' ')
        caption_words = caption.split(' ')
        if len(caption_words) > self.maxWords:
            caption = ' '.join(caption_words[:self.maxWords])
        return caption

    def _load_text(self, index: int):
        """Pick one caption at random, tokenize it with CLS/SEP markers, pad
        with zeros to ``maxWords`` ids, and return the id tensor."""
        captions = self.captions[index]
        use_cap = captions[random.randint(0, len(captions) - 1)]

        words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + self.tokenizer.tokenize(use_cap)
        total_length_with_CLS = self.maxWords - 1
        if len(words) > total_length_with_CLS:
            words = words[:total_length_with_CLS]
        words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]

        caption = self.tokenizer.convert_tokens_to_ids(words)
        while len(caption) < self.maxWords:
            caption.append(0)  # 0 acts as the pad id
        # BUG FIX: this method previously returned ``captions`` (the raw
        # caption list), silently discarding the tokenized, padded tensor
        # built above.
        return torch.tensor(caption)

    def _load_label(self, index: int) -> torch.Tensor:
        """Return the label vector at ``index`` as a tensor."""
        return torch.from_numpy(self.labels[index])

    def get_all_label(self):
        """Stack every label vector into one [N, num_classes] int64 tensor."""
        labels = torch.zeros([self.__length, len(self.labels[0])], dtype=torch.int64)
        # BUG FIX: the old loop did ``enumerate(self.labels)`` on a dict, so
        # ``item`` was a *key* and ``torch.from_numpy(key)`` would fail.
        # Index the mapping explicitly instead.
        for i in range(self.__length):
            labels[i] = torch.from_numpy(self.labels[i])
        return labels

    def __getitem__(self, index):
        image = self._load_image(index)
        caption = self._load_text(index)
        label = self._load_label(index)
        return image, caption, label, index
|
||
|
||
# def default_loader(path):
|
||
# return Image.open(path).convert('RGB')
|
||
|
||
# class IaprDataset(Dataset):
|
||
# def __init__(self, args,txt,transform=None, loader=default_loader):
|
||
# self.transform = transform
|
||
# self.loader = loader
|
||
|
||
# name_label = []
|
||
# for line in open(txt):
|
||
# line = line.strip('\n').split()
|
||
#             label = list(map(int, np.array(line[len(line)-255:]))) # the last 255 binary bits are the label; the first 2912 are the bag-of-words binary encoding
|
||
# tem = re.split('[/.]', line[0])
|
||
# file_name, sample_name = tem[0], tem[1]
|
||
# name_label.append([file_name, sample_name, label])
|
||
# # # print('label = ', label)
|
||
# # print('file_name = %s, sample_name = %s' %(file_name, sample_name))
|
||
# # label_list = np.where(label=='1')
|
||
# # print('label_list = ', label_list)
|
||
# self.name_label = name_label
|
||
# self.image_dir=args.image_dir
|
||
# self.text_dir = args.text_dir
|
||
|
||
|
||
# def __getitem__(self, index):
|
||
# words = self.name_label[index] # words = [file_name, sample_name, label]
|
||
# # print('words = ', words[0:2])
|
||
|
||
# img_path = os.path.join(self.image_dir, words[0], words[1]+'.jpg')
|
||
# text_path = os.path.join(self.text_dir, words[0], words[1]+'.txt')
|
||
# # img
|
||
# img = self.loader(img_path)
|
||
# if self.transform is not None:
|
||
# img = self.transform(img)
|
||
# # text
|
||
# text = 'None'
|
||
# for line in open(text_path):
|
||
# text = '[CLS]' + line + '[SEP]'
|
||
|
||
# # label
|
||
# label = torch.LongTensor(words[2])
|
||
# # image, caption, label, index
|
||
# return img, text, label, index
|
||
|
||
|
||
# def __len__(self):
|
||
# return len(self.name_label)
|
||
|
||
# Default preprocessing shared by the dataset classes below: resize/crop to
# 224x224 and normalize. The mean/std values below are the widely used
# ImageNet channel statistics.
default_transform = Compose([
    Resize(224),
    CenterCrop(224),
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
|
||
|
||
class MScocoDataset(Dataset):
    """MS-COCO retrieval dataset.

    Reads three text files under ``data_path``: one image path per line,
    one caption per line, and a whitespace-separated label matrix
    (parsed with ``np.genfromtxt``). Items are aligned by line number.
    """

    def __init__(self, data_path, img_filename, text_filename, label_filename, transform=None):
        self.data_path = data_path
        # Fall back to the module-level default preprocessing pipeline.
        self.transform = default_transform if transform is None else transform

        with open(os.path.join(data_path, img_filename), 'r') as f:
            self.imgs = [line.strip() for line in f]

        with open(os.path.join(data_path, text_filename), 'r') as f:
            self.texts = [line.replace('\n', '') for line in f.readlines()]

        self.labels = np.genfromtxt(os.path.join(data_path, label_filename), dtype=np.int32)

    def __getitem__(self, index):
        """Return (image, caption, label, index) for one sample."""
        img = Image.open(os.path.join(self.data_path, self.imgs[index])).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = torch.from_numpy(self.labels[index]).float()
        text = self.texts[index]
        return img, text, label, index

    def __len__(self):
        return len(self.imgs)
|
||
|
||
class MirflickrDataset(Dataset):
    """MIRFLICKR-25K dataset.

    Builds multi-hot label vectors from the annotation files under
    ``mirflickr25k_annotations_v080`` (one class per file, skipping the
    "_r1" variants and the README) and space-joined tag captions from
    ``mirflickr/meta/tags``. Items are keyed by the image id string.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir

        anno_dir = os.path.join(root_dir, "mirflickr25k_annotations_v080")
        anno_files = [name for name in os.listdir(anno_dir)
                      if "_r1" not in name and "README" not in name]

        # One label column per annotation file, in directory-listing order.
        self.class_index = {name: i for i, name in enumerate(anno_files)}

        self.label_dict = {}
        num_classes = len(anno_files)
        for anno_name in anno_files:
            col = self.class_index[anno_name]
            with open(os.path.join(anno_dir, anno_name), "r") as fh:
                for line in fh:
                    image_id = line.strip()
                    if image_id not in self.label_dict:
                        self.label_dict[image_id] = np.zeros(num_classes)
                    self.label_dict[image_id][col] = 1

        self.captions_dict = {}
        tags_dir = os.path.join(root_dir, "mirflickr/meta/tags")
        for tag_name in os.listdir(tags_dir):
            # "tags123.txt" -> image id "123"
            image_id = tag_name.split(".")[0].replace("tags", "")
            with open(os.path.join(tags_dir, tag_name), "r") as fh:
                words = [line.strip() for line in fh.readlines()]
            self.captions_dict[image_id] = " ".join(words).strip()

        self.transform = default_transform if transform is None else transform

    def __getitem__(self, index):
        """Return (image, caption, label, index); ``index`` is an image-id string."""
        label = self.label_dict[index]
        image_root = os.path.join(self.root_dir, "mirflickr")
        img = Image.open(os.path.join(image_root, "im" + index + ".jpg")).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        text = self.captions_dict[index]
        return img, text, label, index

    def __len__(self):
        return len(self.label_dict)
|
||
|
||
|
||
class NusWideDataset(Dataset):
    """NUS-WIDE dataset skeleton.

    NOTE(review): the actual loading of labels/captions is not implemented —
    ``__getitem__`` was copy-pasted from ``MirflickrDataset`` and still reads
    ``self.label_dict`` / ``self.captions_dict``, which are left empty here,
    and still builds image paths under a "mirflickr" directory. The paths
    computed in ``__init__`` are stored but not yet consumed.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # BUG FIX: the sub-paths previously started with "/", and
        # os.path.join discards every preceding component when it meets an
        # absolute path — so root_dir was silently thrown away. They also
        # were plain locals, dropped on return; store them as attributes.
        self.imageListFile = os.path.join(root_dir, "Low-Level-Features/ImageList/Imagelist.txt")
        self.labelPath = os.path.join(root_dir, "nuswide/Groundtruth/AllLabels")
        self.textFile = os.path.join(root_dir, "Low-Level-Features/NUS_WID_Tags/All_Tags.txt")

        # BUG FIX: ``transform`` was accepted but never stored, so
        # __getitem__ crashed with AttributeError on self.transform.
        self.transform = default_transform if transform is None else transform

        # TODO(review): populate these from the files above; until then the
        # dataset reports length 0 instead of raising AttributeError.
        self.label_dict = {}
        self.captions_dict = {}

    def __getitem__(self, index):
        """Return (image, caption, label, index); ``index`` is an image-id string."""
        label = self.label_dict[index]
        # NOTE(review): "mirflickr"/"im<id>.jpg" is copy-paste residue from
        # MirflickrDataset — replace with the NUS-WIDE image layout.
        PATH = os.path.join(self.root_dir, "mirflickr")
        img = Image.open(os.path.join(PATH, "im" + index + ".jpg"))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        text = self.captions_dict[index]
        return img, text, label, index

    def __len__(self):
        return len(self.label_dict)