129 lines
5.8 KiB
Python
129 lines
5.8 KiB
Python
|
|
import json
import os

import numpy as np
import scipy.io as scio
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

from .base import BaseDataset, IaprDataset, MirflickrDataset, MScocoDataset
|
|
|
|
|
|
def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
    """Shuffle the dataset and carve it into query / train / retrieval splits.

    The retrieval set is everything outside the query set, so it deliberately
    contains the training samples (the usual cross-modal retrieval protocol).

    Args:
        captions: caption array, indexable by a numpy integer array.
        indexs: image index array, same length as ``captions``.
        labels: label array, same length as ``captions``.
        query_num: number of samples reserved for the query split.
        train_num: number of samples used for the train split.
        seed: optional RNG seed for a reproducible shuffle.

    Returns:
        Three 3-tuples, each ordered (query, train, retrieval):
        ``(split_indexs, split_captions, split_labels)``.
    """
    np.random.seed(seed=seed)
    shuffled = np.random.permutation(range(len(indexs)))

    # Query and train are disjoint; retrieval is the complement of query,
    # so the train samples are a subset of the retrieval set.
    order = ("query", "train", "retrieval")
    picks = {
        "query": shuffled[:query_num],
        "train": shuffled[query_num: query_num + train_num],
        "retrieval": shuffled[query_num:],
    }

    split_indexs = tuple(indexs[picks[name]] for name in order)
    split_captions = tuple(captions[picks[name]] for name in order)
    split_labels = tuple(labels[picks[name]] for name in order)

    return split_indexs, split_captions, split_labels
|
|
|
|
def dataloader(captionFile: str,
               indexFile: str,
               labelFile: str,
               maxWords=77,
               imageResolution=224,
               query_num=5000,
               train_num=1000,
               seed=None,
               npy=False):
    """Load captions/indices/labels from disk and build the three datasets.

    Args:
        captionFile: path to the captions, in .mat, .txt, or .json format.
        indexFile: path to the image index file (.mat, or numpy when ``npy``).
        labelFile: path to the .mat label file (read from its "category" key).
        maxWords: maximum caption length forwarded to BaseDataset.
        imageResolution: image side length forwarded to BaseDataset.
        query_num: size of the query split.
        train_num: size of the train split.
        seed: optional seed for a reproducible split.
        npy: when True, load ``indexFile`` with numpy instead of scipy.io.

    Returns:
        ``(train_data, query_data, retrieval_data)`` as BaseDataset instances.

    Raises:
        ValueError: if ``captionFile`` has an unsupported extension.
    """
    if captionFile.endswith("mat"):
        captions = scio.loadmat(captionFile)["caption"]
        # A 1xN MATLAB cell array loads with a leading singleton axis; drop it.
        captions = captions[0] if captions.shape[0] == 1 else captions
    elif captionFile.endswith("txt"):
        with open(captionFile, "r") as f:
            captions = f.readlines()
        captions = np.asarray([[item.strip()] for item in captions])
    elif captionFile.endswith("json"):
        with open(captionFile, "r") as f:
            data = json.load(f)
        # BUGFIX: wrap in an ndarray so split_data can fancy-index it;
        # a plain Python list rejects numpy integer-array indices.
        captions = np.asarray(data["caption"])
    else:
        raise ValueError("the format of 'captionFile' is not supported, only [txt, json, mat] formats are supported.")

    if not npy:
        indexs = scio.loadmat(indexFile)["index"]
    else:
        indexs = np.load(indexFile, allow_pickle=True)

    labels = scio.loadmat(labelFile)["category"]

    split_indexs, split_captions, split_labels = split_data(
        captions, indexs, labels, query_num=query_num, train_num=train_num, seed=seed)

    # split_* tuples are ordered (query, train, retrieval).
    train_data = BaseDataset(captions=split_captions[1], indexs=split_indexs[1], labels=split_labels[1],
                             maxWords=maxWords, imageResolution=imageResolution, npy=npy)
    query_data = BaseDataset(captions=split_captions[0], indexs=split_indexs[0], labels=split_labels[0],
                             maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
    retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2],
                                 maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)

    return train_data, query_data, retrieval_data
|
|
|
|
def get_dataset_filename(split):
    """Map a split name to its (image-list, text-list, label-list) filenames.

    Args:
        split: one of 'train', 'test', or 'db'.

    Returns:
        A 3-tuple of .txt filenames for images, texts, and labels.

    Raises:
        KeyError: if ``split`` is not a known split name.
    """
    prefix = {'train': 'cm_train', 'test': 'cm_test', 'db': 'cm_database'}[split]
    return (f'{prefix}_imgs.txt', f'{prefix}_txts.txt', f'{prefix}_labels.txt')
|
|
|
|
def cross_modal_dataset(args):
    """Build the train/test/database datasets and loaders for ``args.dataset``.

    Only 'iapr' currently builds all three loaders; 'flickr25k' and 'coco'
    construct their dataset object but have no split/loader logic yet
    (the original code fell through to an UnboundLocalError on return).

    Args:
        args: namespace providing .dataset, .resolution, .batch_size, and the
            per-dataset root paths referenced below.

    Returns:
        ``(train_loader, test_loader, db_loader, train_set, test_set, db_set)``.

    Raises:
        NotImplementedError: for dataset choices whose loaders are not built.
        ValueError: for an unknown dataset name.
    """
    # CLIP-style normalization constants.
    transform_test = transforms.Compose([
        transforms.Resize((args.resolution, args.resolution)),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])

    if args.dataset == 'flickr25k':
        dataset = MirflickrDataset(root_dir=args.flickr25k_root)
        # BUGFIX: the original never assigned the loaders here, so the return
        # below crashed with UnboundLocalError; fail explicitly instead.
        raise NotImplementedError("flickr25k split/loader logic is not implemented yet.")
    elif args.dataset == 'coco':
        img_name, text_name, label_name = get_dataset_filename('train')
        dataset = MScocoDataset(data_path=args.coco_root, img_filename=args.coco_img_root,
                                text_filename=args.coco_txt_root, label_filename=args.coco_label_root,
                                transform=transform_test)
        # BUGFIX: same fall-through problem as flickr25k above.
        raise NotImplementedError("coco split/loader logic is not implemented yet.")
    elif args.dataset == 'iapr':
        train_file = os.path.join(args.iapr_root, 'iapr_train')
        # BUGFIX: test/retrieval used args.root while train used args.iapr_root;
        # presumably all three split files live under iapr_root — TODO confirm.
        test_file = os.path.join(args.iapr_root, 'iapr_test')
        retrieval_file = os.path.join(args.iapr_root, 'iapr_retrieval')

        train_set = IaprDataset(args, txt=train_file, transform=transform_test)
        test_set = IaprDataset(args, txt=test_file, transform=transform_test)
        db_set = IaprDataset(args, txt=retrieval_file, transform=transform_test)

        # BUGFIX: the original called Dataset.DataLoader, which does not exist
        # (torch.utils.data.Dataset has no DataLoader attribute).
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
        db_loader = DataLoader(db_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
    else:
        raise ValueError("Not support.")

    return train_loader, test_loader, db_loader, train_set, test_set, db_set
|
|
|
|
|
|
|