# advclip/dataset/base.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import random
import re

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

from model.simple_tokenizer import SimpleTokenizer as Tokenizer

class BaseDataset(Dataset):
    def __init__(self,
                 captions: dict,
                 indexs: dict,
                 labels: dict,
                 is_train=True,
                 tokenizer=Tokenizer(),
                 maxWords=32,
                 imageResolution=224,
                 npy=False):
        self.captions = captions
        self.indexs = indexs
        self.labels = labels
        self.npy = npy
        self.maxWords = maxWords
        self.tokenizer = tokenizer
        # CLIP's standard normalization statistics; training keeps the aspect
        # ratio and center-crops, evaluation resizes to a square directly
        self.transform = Compose([
            Resize(imageResolution, interpolation=Image.BICUBIC),
            CenterCrop(imageResolution),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]) if is_train else Compose([
            Resize((imageResolution, imageResolution), interpolation=Image.BICUBIC),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
                              "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}
        self.__length = len(self.indexs)

    def __len__(self):
        return self.__length

    def _load_image(self, index: int) -> torch.Tensor:
        if not self.npy:
            # indexs maps the sample index to an image path on disk
            image_path = self.indexs[index].strip()
            image = Image.open(image_path).convert("RGB")
        else:
            # in npy mode, indexs holds decoded pixel arrays
            image = Image.fromarray(self.indexs[index]).convert("RGB")
        image = self.transform(image)
        return image

    def pre_caption(self, caption):
        # lowercase, strip punctuation, and normalize separators
        caption = re.sub(
            r"([,.'!?\"()*#:;~])",
            '',
            caption.lower(),
        ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
        # collapse runs of whitespace
        caption = re.sub(r"\s{2,}", ' ', caption)
        caption = caption.rstrip('\n')
        caption = caption.strip(' ')
        # truncate to at most maxWords words
        caption_words = caption.split(' ')
        if len(caption_words) > self.maxWords:
            caption = ' '.join(caption_words[:self.maxWords])
        return caption
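
    # Illustrative example of pre_caption's normalization (not from the source;
    # the output is derived by tracing the substitutions above):
    #   pre_caption("A man, riding a horse / on the beach!!")
    #   -> "a man riding a horse on the beach"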

    def _load_text(self, index: int):
        # each index maps to a list of candidate captions; sample one at random
        captions = self.captions[index]
        use_cap = random.choice(captions)
        words = self.tokenizer.tokenize(use_cap)
        words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
        # reserve one slot for the SEP token appended below
        total_length_with_CLS = self.maxWords - 1
        if len(words) > total_length_with_CLS:
            words = words[:total_length_with_CLS]
        words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
        caption = self.tokenizer.convert_tokens_to_ids(words)
        # zero-pad to the fixed sequence length
        while len(caption) < self.maxWords:
            caption.append(0)
        caption = torch.tensor(caption)
        return caption

    def _load_label(self, index: int) -> torch.Tensor:
        label = self.labels[index]
        label = torch.from_numpy(label)
        return label

    def get_all_label(self):
        # stack every sample's multi-hot label vector into one (N, C) tensor;
        # indexing by position avoids iterating dict keys directly
        labels = torch.zeros([self.__length, len(self.labels[0])], dtype=torch.int64)
        for i in range(self.__length):
            labels[i] = torch.from_numpy(self.labels[i])
        return labels

    def __getitem__(self, index):
        image = self._load_image(index)
        caption = self._load_text(index)
        label = self._load_label(index)
        return image, caption, label, index
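
# A minimal usage sketch for BaseDataset (an assumed call pattern, not from the
# source; `captions`, `indexs`, and `labels` are dicts keyed by integer sample
# index, as the constructor implies):
#
#   from torch.utils.data import DataLoader
#   dataset = BaseDataset(captions=captions, indexs=indexs, labels=labels,
#                         is_train=True, maxWords=32, imageResolution=224)
#   loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
#   for image, caption, label, index in loader:
#       pass  # image: (B, 3, 224, 224); caption: (B, 32); label: (B, num_classes)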

# def default_loader(path):
#     return Image.open(path).convert('RGB')

# class IaprDataset(Dataset):
#     def __init__(self, args, txt, transform=None, loader=default_loader):
#         self.transform = transform
#         self.loader = loader
#         name_label = []
#         for line in open(txt):
#             line = line.strip('\n').split()
#             # the last 255 binary codes are the label; the first 2912 are the
#             # bag-of-words binary encoding of the caption words
#             label = list(map(int, np.array(line[len(line) - 255:])))
#             tem = re.split('[/.]', line[0])
#             file_name, sample_name = tem[0], tem[1]
#             name_label.append([file_name, sample_name, label])
#         self.name_label = name_label
#         self.image_dir = args.image_dir
#         self.text_dir = args.text_dir
#
#     def __getitem__(self, index):
#         words = self.name_label[index]  # words = [file_name, sample_name, label]
#         img_path = os.path.join(self.image_dir, words[0], words[1] + '.jpg')
#         text_path = os.path.join(self.text_dir, words[0], words[1] + '.txt')
#         # image
#         img = self.loader(img_path)
#         if self.transform is not None:
#             img = self.transform(img)
#         # text
#         text = 'None'
#         for line in open(text_path):
#             text = '[CLS]' + line + '[SEP]'
#         # label
#         label = torch.LongTensor(words[2])
#         # image, caption, label, index
#         return img, text, label, index
#
#     def __len__(self):
#         return len(self.name_label)

default_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet statistics
])

class MScocoDataset(Dataset):
    def __init__(self, data_path, img_filename, text_filename, label_filename, transform=None):
        self.data_path = data_path
        self.transform = default_transform if transform is None else transform
        # one image path per line, relative to data_path
        img_filepath = os.path.join(data_path, img_filename)
        with open(img_filepath, 'r') as f:
            self.imgs = [x.strip() for x in f]
        # one caption per line, aligned with the image list
        text_filepath = os.path.join(data_path, text_filename)
        with open(text_filepath, 'r') as f:
            self.texts = [line.replace('\n', '') for line in f]
        # multi-hot label matrix, one row per sample
        label_filepath = os.path.join(data_path, label_filename)
        self.labels = np.genfromtxt(label_filepath, dtype=np.int32)

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.data_path, self.imgs[index]))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.from_numpy(self.labels[index]).float()
        text = self.texts[index]
        # image, caption, label, index
        return img, text, label, index

    def __len__(self):
        return len(self.imgs)
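
# Usage sketch for MScocoDataset (the file names below are hypothetical; the
# class only assumes one image path per line, one caption per line, and a
# numeric label matrix readable by np.genfromtxt):
#
#   dataset = MScocoDataset(data_path="./data/mscoco",
#                           img_filename="train_imgs.txt",
#                           text_filename="train_caps.txt",
#                           label_filename="train_labels.txt")
#   img, text, label, index = dataset[0]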

class MirflickrDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # each annotation file lists the ids of the images carrying that class
        file_path = os.path.join(root_dir, "mirflickr25k_annotations_v080")
        file_list = os.listdir(file_path)
        file_list = [item for item in file_list if "_r1" not in item and "README" not in item]
        self.class_index = {item: i for i, item in enumerate(file_list)}
        # build a multi-hot label vector per image id
        self.label_dict = {}
        for path_id in file_list:
            path = os.path.join(file_path, path_id)
            with open(path, "r") as f:
                for item in f:
                    item = item.strip()
                    if item not in self.label_dict:
                        self.label_dict[item] = np.zeros(len(file_list))
                    self.label_dict[item][self.class_index[path_id]] = 1
        # join each image's tag file into a single caption string
        self.captions_dict = {}
        captions_path = os.path.join(root_dir, "mirflickr/meta/tags")
        for item in os.listdir(captions_path):
            id_ = item.split(".")[0].replace("tags", "")
            with open(os.path.join(captions_path, item), "r") as f:
                caption = " ".join(word.strip() for word in f)
            self.captions_dict[id_] = caption.strip()
        # stable, integer-indexable ordering of image ids for __getitem__
        self.ids = sorted(self.label_dict.keys())
        self.transform = default_transform if transform is None else transform

    def __getitem__(self, index):
        # a DataLoader passes an integer index; map it to the image id string
        id_ = self.ids[index]
        label = torch.from_numpy(self.label_dict[id_]).float()
        PATH = os.path.join(self.root_dir, "mirflickr")
        img = Image.open(os.path.join(PATH, "im" + id_ + ".jpg"))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        text = self.captions_dict[id_]
        # image, caption, label, index
        return img, text, label, index

    def __len__(self):
        return len(self.label_dict)
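
# Usage sketch for MirflickrDataset (assumes the standard MIRFLICKR-25K layout:
# root_dir/mirflickr25k_annotations_v080 for annotations, root_dir/mirflickr
# for images, and root_dir/mirflickr/meta/tags for per-image tag files):
#
#   dataset = MirflickrDataset(root_dir="./data/mirflickr25k")
#   img, text, label, index = dataset[0]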

class NusWideDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # the original paths started with "/", which makes os.path.join discard
        # root_dir entirely; they are joined relative to root_dir here
        self.image_list_file = os.path.join(root_dir, "Low-Level-Features/ImageList/Imagelist.txt")
        self.label_path = os.path.join(root_dir, "nuswide/Groundtruth/AllLabels")
        self.text_file = os.path.join(root_dir, "Low-Level-Features/NUS_WID_Tags/All_Tags.txt")
        self.transform = default_transform if transform is None else transform
        # parsing label_dict/captions_dict from the files above is not
        # implemented in this file; __getitem__ still mirrors MirflickrDataset
        # and expects these structures to be populated before use
        self.label_dict = {}
        self.captions_dict = {}
        self.ids = []

    def __getitem__(self, index):
        # copied from MirflickrDataset in the original file; note that the
        # image path below still points at the "mirflickr" directory
        id_ = self.ids[index]
        label = torch.from_numpy(self.label_dict[id_]).float()
        PATH = os.path.join(self.root_dir, "mirflickr")
        img = Image.open(os.path.join(PATH, "im" + id_ + ".jpg"))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        text = self.captions_dict[id_]
        # image, caption, label, index
        return img, text, label, index

    def __len__(self):
        return len(self.label_dict)
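
# A hedged sketch of how the NUS-WIDE metadata referenced above is commonly
# laid out and could be parsed (the file formats are assumptions about the
# standard NUS-WIDE release, not taken from this repository):
#
#   with open(self.image_list_file) as f:
#       image_paths = [line.strip() for line in f]       # one relative path per line
#   for label_file in os.listdir(self.label_path):       # e.g. one 0/1 file per class
#       with open(os.path.join(self.label_path, label_file)) as f:
#           bits = [int(line.strip()) for line in f]     # aligned with image_paths
#   with open(self.text_file) as f:
#       for line in f:                                   # "<image id> <tag> <tag> ..."
#           image_id, *tags = line.split()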