# advclip/dataset/dataloader.py
#
# 129 lines
# 5.8 KiB
# Python

from .base import BaseDataset, MirflickrDataset, MScocoDataset , IaprDataset
import os
import numpy as np
import scipy.io as scio
import torch
import torchvision.transforms as transforms
import json
from torch.utils.data import Dataset
def _as_indexable(values):
    """Return ``values`` in a form that accepts an integer index array."""
    # Only plain lists/tuples are converted; arrays (and any other container
    # the caller passes) are returned untouched so existing behaviour is kept.
    if isinstance(values, (list, tuple)):
        return np.asarray(values)
    return values


def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
    """Randomly partition aligned captions/indexs/labels into query, train and retrieval sets.

    One random permutation of ``range(len(indexs))`` is drawn and sliced:

    * query     -- the first ``query_num`` permuted positions;
    * retrieval -- every position after the query block;
    * train     -- the first ``train_num`` positions of the retrieval block,
      so the train split is a subset of the retrieval split and is disjoint
      from the query split.

    The same positions are applied to all three inputs, which keeps them
    aligned entry-for-entry.

    Args:
        captions: Captions, indexable along the first axis. Lists and tuples
            are accepted and converted to arrays.
        indexs: Sample indices; ``len(indexs)`` fixes the number of samples.
        labels: Labels, indexable along the first axis.
        query_num: Number of samples put in the query split.
        train_num: Number of samples put in the train split.
        seed: Passed to ``np.random.seed``. This reseeds NumPy's global
            generator on every call, including when ``seed`` is ``None``.

    Returns:
        ``(split_indexs, split_captions, split_labels)``; each is a tuple
        ordered ``(query, train, retrieval)``.
    """
    # A Python list cannot be indexed with an index array (TypeError), which
    # is what happened for captions loaded from a JSON file.
    captions = _as_indexable(captions)
    indexs = _as_indexable(indexs)
    labels = _as_indexable(labels)

    np.random.seed(seed=seed)
    order = np.random.permutation(range(len(indexs)))

    # (query, train, retrieval) position arrays; train overlaps retrieval by design.
    positions = (
        order[:query_num],
        order[query_num: query_num + train_num],
        order[query_num:],
    )

    split_indexs = tuple(indexs[pos] for pos in positions)
    split_captions = tuple(captions[pos] for pos in positions)
    split_labels = tuple(labels[pos] for pos in positions)
    return split_indexs, split_captions, split_labels
def dataloader(captionFile: str,
               indexFile: str,
               labelFile: str,
               maxWords=77,
               imageResolution=224,
               query_num=5000,
               train_num=1000,
               seed=None,
               npy=False):
    """Load captions, indices and labels from disk and build the three dataset splits.

    Args:
        captionFile: Path to the captions. The format is chosen from the file
            name ending: ``mat`` (variable ``"caption"``), ``txt`` (one caption
            per line) or ``json`` (key ``"caption"``).
        indexFile: Path to the sample indices. Read with ``scipy.io.loadmat``
            (variable ``"index"``) unless ``npy`` is true, in which case it is
            read with ``np.load``.
        labelFile: Path to a ``.mat`` file holding the variable ``"category"``.
        maxWords: Forwarded to ``BaseDataset``.
        imageResolution: Forwarded to ``BaseDataset``.
        query_num: Size of the query split (see ``split_data``).
        train_num: Size of the train split (see ``split_data``).
            NOTE(review): the default here (1000) differs from the default of
            ``split_data`` (10000) -- confirm which one is intended.
        seed: Seed forwarded to ``split_data``.
        npy: Selects the ``np.load`` path for ``indexFile`` and is forwarded
            to ``BaseDataset``.

    Returns:
        ``(train_data, query_data, retrieval_data)`` as ``BaseDataset``
        instances; the query and retrieval sets are built with
        ``is_train=False``.

    Raises:
        ValueError: If ``captionFile`` does not end in ``mat``, ``txt`` or ``json``.
    """
    if captionFile.endswith("mat"):
        captions = scio.loadmat(captionFile)["caption"]
        # Drop a leading singleton axis so captions are indexed per sample.
        captions = captions[0] if captions.shape[0] == 1 else captions
    elif captionFile.endswith("txt"):
        with open(captionFile, "r") as f:
            # One single-element row per line, giving an (N, 1) array.
            captions = np.asarray([[line.strip()] for line in f])
    elif captionFile.endswith("json"):
        with open(captionFile, "r") as f:
            data = json.load(f)
        # json yields a Python list, which split_data cannot index with an
        # index array; convert it so the JSON path works like the others.
        captions = np.asarray(data["caption"])
    else:
        raise ValueError("the format of 'captionFile' doesn't support, only support [txt, json, mat] format.")

    if not npy:
        indexs = scio.loadmat(indexFile)["index"]
    else:
        # NOTE: allow_pickle=True unpickles object arrays, which can execute
        # arbitrary code -- only load index files from a trusted source.
        indexs = np.load(indexFile, allow_pickle=True)
    labels = scio.loadmat(labelFile)["category"]

    # Each result is ordered (query, train, retrieval).
    split_indexs, split_captions, split_labels = split_data(
        captions, indexs, labels, query_num=query_num, train_num=train_num, seed=seed)

    train_data = BaseDataset(captions=split_captions[1], indexs=split_indexs[1], labels=split_labels[1],
                             maxWords=maxWords, imageResolution=imageResolution, npy=npy)
    query_data = BaseDataset(captions=split_captions[0], indexs=split_indexs[0], labels=split_labels[0],
                             maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
    retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2],
                                 maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
    return train_data, query_data, retrieval_data
def get_dataset_filename(split):
    """Look up the list-file names that belong to one dataset split.

    Args:
        split: One of ``'train'``, ``'test'`` or ``'db'``.

    Returns:
        A 3-tuple of file names, in the order ``*_imgs.txt``, ``*_txts.txt``,
        ``*_labels.txt``.

    Raises:
        KeyError: If ``split`` is not one of the three known names.
    """
    train_files = ('cm_train_imgs.txt', 'cm_train_txts.txt', 'cm_train_labels.txt')
    test_files = ('cm_test_imgs.txt', 'cm_test_txts.txt', 'cm_test_labels.txt')
    db_files = ('cm_database_imgs.txt', 'cm_database_txts.txt', 'cm_database_labels.txt')

    names_by_split = dict(train=train_files, test=test_files, db=db_files)
    return names_by_split[split]
def cross_modal_dataset(args):
    """Build the train / query / database datasets and loaders chosen by ``args.dataset``.

    Only ``'iapr'`` is wired up end to end. ``'flickr25k'`` and ``'coco'`` are
    recognised names, but the original code never produced splits or loaders
    for them and failed with ``UnboundLocalError`` at the ``return``; they now
    fail early with an explicit ``NotImplementedError``.

    Args:
        args: Namespace-like configuration. Attributes read here are
            ``dataset``, ``resolution``, ``batch_size``, ``iapr_root`` and,
            when present, ``root``. The object is also passed unchanged to
            ``IaprDataset``.

    Returns:
        ``(train_loader, test_loader, db_loader, train_set, test_set, db_set)``.
        Only the train loader shuffles.

    Raises:
        NotImplementedError: If ``args.dataset`` is ``'flickr25k'`` or ``'coco'``.
        ValueError: If ``args.dataset`` is any other unsupported name.
    """
    if args.dataset in ('flickr25k', 'coco'):
        # TODO: build and split these datasets. The disabled draft used
        #   MirflickrDataset(root_dir=args.flickr25k_root) and
        #   MScocoDataset(data_path=args.coco_root, img_filename=args.coco_img_root,
        #                 text_filename=args.coco_txt_root,
        #                 label_filename=args.coco_label_root, transform=...)
        # and sketched torch.utils.data.random_split with [0.25, 0.75] for
        # test/db and [0.65, 0.35] for train, but that code was never enabled.
        raise NotImplementedError(
            "cross_modal_dataset: splits and loaders for '%s' are not implemented." % args.dataset)
    if args.dataset != 'iapr':
        raise ValueError("Not support.")

    # One deterministic transform (no augmentation) is shared by all three sets.
    # NOTE(review): the mean/std look like the CLIP preprocessing constants -- confirm.
    transform_test = transforms.Compose([
        transforms.Resize((args.resolution, args.resolution)),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])

    # NOTE(review): the original read the train list from args.iapr_root but the
    # test/retrieval lists from args.root. args.root is still honoured when it
    # exists; otherwise fall back to args.iapr_root instead of raising
    # AttributeError. Confirm which root is intended.
    eval_root = getattr(args, 'root', args.iapr_root)
    train_file = os.path.join(args.iapr_root, 'iapr_train')
    test_file = os.path.join(eval_root, 'iapr_test')
    retrieval_file = os.path.join(eval_root, 'iapr_retrieval')

    train_set = IaprDataset(args, txt=train_file, transform=transform_test)
    test_set = IaprDataset(args, txt=test_file, transform=transform_test)
    db_set = IaprDataset(args, txt=retrieval_file, transform=transform_test)

    # DataLoader lives in torch.utils.data; it is not an attribute of Dataset.
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = make_loader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
    db_loader = make_loader(db_set, batch_size=args.batch_size, shuffle=False, num_workers=4)

    return train_loader, test_loader, db_loader, train_set, test_set, db_set