advclip/dataset/dataloader.py

69 lines
2.8 KiB
Python

from .base import BaseDataset
import os
import numpy as np
import scipy.io as scio
def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
    """Randomly partition aligned arrays into query, train and retrieval parts.

    NumPy's global random generator is seeded with ``seed`` and one
    permutation of the positions ``0 .. len(indexs) - 1`` is drawn. The
    permutation is then cut as follows:

    * query:     the first ``query_num`` shuffled positions;
    * train:     the next ``train_num`` shuffled positions;
    * retrieval: every shuffled position after the query ones, so the
      train part is the leading portion of the retrieval part.

    Args:
        captions: Array of captions; must support indexing with an integer
            array (e.g. a NumPy array).
        indexs: Array of sample indices; its length determines how many
            positions are shuffled.
        labels: Array of labels, indexed the same way as ``captions``.
        query_num: Number of shuffled positions assigned to the query part.
        train_num: Number of shuffled positions assigned to the train part.
        seed: Value forwarded to ``np.random.seed``.

    Returns:
        A tuple ``(split_indexs, split_captions, split_labels)``. Each
        element is itself a 3-tuple ordered ``(query, train, retrieval)``.
    """
    # Seeding goes through the global generator, so later np.random calls
    # made by the caller are affected as well.
    np.random.seed(seed=seed)
    shuffled = np.random.permutation(np.arange(len(indexs)))

    train_stop = query_num + train_num
    positions = (
        shuffled[:query_num],            # query
        shuffled[query_num:train_stop],  # train
        shuffled[query_num:],            # retrieval (contains the train part)
    )

    # Apply the same three position sets to each of the aligned arrays.
    split_indexs, split_captions, split_labels = (
        tuple(source[chosen] for chosen in positions)
        for source in (indexs, captions, labels)
    )
    return split_indexs, split_captions, split_labels
def dataloader(captionFile: str,
               indexFile: str,
               labelFile: str,
               maxWords=77,
               imageResolution=224,
               query_num=5000,
               train_num=10000,
               seed=None,
               npy=False):
    """Load captions, indices and labels from disk and build three datasets.

    Captions come from either a ``.mat`` file (variable ``"caption"``) or a
    ``.txt`` file with one caption per line. Indices come from a ``.mat``
    file (variable ``"index"``) unless ``npy`` is true, in which case
    ``indexFile`` is read with ``np.load``. Labels are always read from the
    ``"category"`` variable of a ``.mat`` file. The loaded arrays are split
    with ``split_data`` and each part is wrapped in a ``BaseDataset``.

    Args:
        captionFile: Path whose name ends in ``mat`` or ``txt``.
        indexFile: Path to the index file (``.mat``, or NumPy file if ``npy``).
        labelFile: Path to the ``.mat`` label file.
        maxWords: Forwarded to ``BaseDataset``.
        imageResolution: Forwarded to ``BaseDataset``.
        query_num: Forwarded to ``split_data``.
        train_num: Forwarded to ``split_data``.
        seed: Forwarded to ``split_data``.
        npy: Selects ``np.load`` for ``indexFile`` and is also forwarded to
            ``BaseDataset``.

    Returns:
        ``(train_data, query_data, retrieval_data)``; the query and retrieval
        datasets are built with ``is_train=False``.

    Raises:
        ValueError: If ``captionFile`` ends in neither ``mat`` nor ``txt``.
    """
    # --- captions -----------------------------------------------------------
    if captionFile.endswith("mat"):
        captions = scio.loadmat(captionFile)["caption"]
        # A single-row matrix is unwrapped to its only row.
        if captions.shape[0] == 1:
            captions = captions[0]
    elif captionFile.endswith("txt"):
        with open(captionFile, "r") as f:
            # One caption per line, each wrapped in its own one-element list.
            captions = np.asarray([[line.strip()] for line in f])
    else:
        raise ValueError("the format of 'captionFile' doesn't support, only support [txt, mat] format.")

    # --- indices ------------------------------------------------------------
    if npy:
        # NOTE: allow_pickle=True can execute arbitrary code while unpickling;
        # only point this at trusted files.
        indexs = np.load(indexFile, allow_pickle=True)
    else:
        indexs = scio.loadmat(indexFile)["index"]

    # --- labels -------------------------------------------------------------
    labels = scio.loadmat(labelFile)["category"]

    (
        (query_indexs, train_indexs, retrieval_indexs),
        (query_captions, train_captions, retrieval_captions),
        (query_labels, train_labels, retrieval_labels),
    ) = split_data(captions, indexs, labels, query_num=query_num, train_num=train_num, seed=seed)

    # Keyword arguments shared by all three dataset constructions.
    shared = dict(maxWords=maxWords, imageResolution=imageResolution, npy=npy)

    train_data = BaseDataset(
        captions=train_captions, indexs=train_indexs, labels=train_labels, **shared
    )
    query_data = BaseDataset(
        captions=query_captions, indexs=query_indexs, labels=query_labels,
        is_train=False, **shared
    )
    retrieval_data = BaseDataset(
        captions=retrieval_captions, indexs=retrieval_indexs, labels=retrieval_labels,
        is_train=False, **shared
    )
    return train_data, query_data, retrieval_data