179 lines
6.6 KiB
Python
179 lines
6.6 KiB
Python
import os
|
|
import numpy as np
|
|
|
|
|
|
def make_index(jsonData: dict, indexDict: dict):
    """Group records of a COCO-style dict into one lookup table per section.

    Args:
        jsonData: Parsed COCO-style annotation data; each section named in
            ``indexDict`` is expected to be a list of record dicts.
        indexDict: ``{section_name: [index_key, index_value]}``. For each
            section, records are grouped by ``record[index_key]`` and the
            ``record[index_value]`` values are collected.

    Returns:
        A list with one dict per ``indexDict`` entry, in ``indexDict``
        iteration order. Each dict maps every distinct ``record[index_key]``
        (in first-seen order) to the list of its ``record[index_value]``
        values (in record order).
    """
    grouped_sections = []
    for section, spec in indexDict.items():
        records = jsonData[section]
        key_field = spec[0]
        value_field = spec[1]
        groups = {}
        for record in records:
            groups.setdefault(record[key_field], []).append(record[value_field])
        grouped_sections.append(groups)
    return grouped_sections
|
|
|
|
def check_file_exist(indexDict: dict, file_path: str):
    """Keep only entries whose file exists and rewrite their values to full paths.

    Args:
        indexDict: Maps an id to a list of file names (e.g. the output of
            ``make_index`` for ``{"images": ["id", "file_name"]}``). Only the
            first file name of each entry is checked.
        file_path: Directory the file names are relative to.

    Returns:
        The same dict, mutated in place. Entries whose file is missing are
        printed and removed. Every remaining value is replaced by
        ``os.path.join(file_path, first_file_name)``.
    """
    # Iterate over a snapshot of the keys because entries are deleted below.
    for key in list(indexDict):
        full_path = os.path.join(file_path, indexDict[key][0])
        if not os.path.exists(full_path):
            print(key, indexDict[key])
            del indexDict[key]
            # The entry is gone; writing its path back would re-read it and
            # raise KeyError.
            continue
        indexDict[key] = full_path
    return indexDict
|
|
|
|
def chage_categories2numpy(category_ids: dict, data: dict):
    """Replace each entry's category-id list with a multi-hot numpy vector.

    Args:
        category_ids: Maps a category id to its column index in the vector;
            the vector length is ``len(category_ids)``.
        data: Maps a key to an iterable of category ids. It is modified in
            place.

    Returns:
        ``data``, where every value is now ``np.asarray`` of a 0/1 list with
        a 1 at the column of each listed category.
    """
    width = len(category_ids)
    for key, class_ids in data.items():
        multi_hot = [0] * width
        for column in map(category_ids.__getitem__, class_ids):
            multi_hot[column] = 1
        # Reassigning values of existing keys is safe during iteration.
        data[key] = np.asarray(multi_hot)
    return data
|
|
|
|
def get_all_use_key(categoryDict: dict):
    """Return the keys of ``categoryDict`` as a new list, in iteration order."""
    return [*categoryDict]
|
|
|
|
def remove_not_use(data: dict, used_key: list):
    """Delete every entry of ``data`` whose key is not in ``used_key``.

    Args:
        data: Dict to filter. It is modified in place.
        used_key: Keys to keep (any iterable of hashable keys; a list or a
            set both work).

    Returns:
        ``data`` itself, containing only the kept keys in their original
        order.
    """
    # Build a set once so each membership test is O(1). Testing against a
    # list made this loop quadratic in the number of keys.
    keep = set(used_key)
    # Collect the keys first because the dict cannot shrink while iterated.
    for key in [k for k in data if k not in keep]:
        del data[key]
    return data
|
|
|
|
def merge_to_list(data: dict):
    """Return the values of ``data`` as a list ordered by ascending key.

    Sorting by key lets several dicts that share the same keys produce
    position-aligned lists.
    """
    ordered_keys = sorted(data)
    return [data[key] for key in ordered_keys]
|
|
|
|
|
|
def _load_json(path: str) -> dict:
    """Parse the JSON file at ``path``, read as UTF-8."""
    import json

    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _build_split(coco_dir: str, split: str):
    """Build position-aligned per-image lists for one split (e.g. ``"train2017"``).

    Reads ``annotations/captions_<split>.json`` and
    ``annotations/instances_<split>.json`` under ``coco_dir``. It keeps only
    image ids that have an image file in ``<coco_dir>/<split>``, at least one
    caption, and at least one instance annotation. The three returned lists
    therefore have equal length and share one order (ascending image id).

    Args:
        coco_dir: Root directory of the COCO dataset.
        split: Split name, used both in the annotation file names and as the
            image sub-directory.

    Returns:
        ``(index_list, caption_list, category_list)``: the full image paths,
        the per-image lists of caption strings, and the per-image multi-hot
        numpy category vectors.
    """
    annotation_dir = os.path.join(coco_dir, "annotations")
    image_dir = os.path.join(coco_dir, split)

    # image id -> [file_name], and image id -> [caption, ...]
    caption_json = _load_json(os.path.join(annotation_dir, "captions_" + split + ".json"))
    path_dict, caption_dict = make_index(
        caption_json,
        {"images": ["id", "file_name"], "annotations": ["image_id", "caption"]},
    )
    del caption_json  # release the parsed JSON before loading the next file
    # Drops ids whose image file is missing; the values become full paths.
    path_dict = check_file_exist(path_dict, image_dir)
    print("caption:", len(path_dict), len(caption_dict))

    instance_json = _load_json(os.path.join(annotation_dir, "instances_" + split + ".json"))
    # Category id -> column index, in the order the categories are listed.
    category_ids = {item["id"]: i for i, item in enumerate(instance_json["categories"])}
    # image id -> [category_id, ...], and image id -> [file_name]
    category_dict, file_name_dict = make_index(
        instance_json,
        {"annotations": ["image_id", "category_id"], "images": ["id", "file_name"]},
    )
    del instance_json
    category_dict = chage_categories2numpy(category_ids, category_dict)

    # Keep only ids present in every mapping so the merged lists line up
    # row by row. Passing a set keeps remove_not_use's per-key membership
    # test O(1) instead of scanning a list.
    used_key = set(path_dict) & set(caption_dict) & set(category_dict)
    for mapping in (path_dict, caption_dict, category_dict):
        remove_not_use(mapping, used_key)

    # Sanity check: both annotation files must name the same image file.
    # Compare path to path; the old check compared a path string with a
    # [file_name] list, which never matched.
    for i, image_id in enumerate(sorted(used_key)):
        names = file_name_dict.get(image_id)
        if names is None or path_dict[image_id] != os.path.join(image_dir, names[0]):
            print("Not the same:", i, path_dict[image_id], names)

    index_list = merge_to_list(path_dict)
    caption_list = merge_to_list(caption_dict)
    category_list = merge_to_list(category_dict)
    print("result", len(index_list), len(category_list))
    return index_list, caption_list, category_list


def main(argv=None):
    """Write ``index.mat``, ``caption.mat`` and ``label.mat`` from COCO 2017.

    Train entries come first, followed by val. Within each split, entries are
    ordered by ascending image id, and position ``i`` in every output refers
    to the same image.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read
            ``sys.argv[1:]``.
    """
    import argparse

    import scipy.io as scio

    parser = argparse.ArgumentParser()
    parser.add_argument("--coco-dir", default="./", type=str, help="the coco dataset dir")
    parser.add_argument("--save-dir", default="./", type=str, help="mat file saved dir")
    args = parser.parse_args(argv)

    index_list, caption_list, category_list = [], [], []
    for split in ("train2017", "val2017"):
        split_index, split_caption, split_category = _build_split(args.coco_dir, split)
        index_list.extend(split_index)
        caption_list.extend(split_caption)
        category_list.extend(split_category)

    print(len(index_list), len(caption_list), len(category_list))

    os.makedirs(args.save_dir, exist_ok=True)
    scio.savemat(os.path.join(args.save_dir, "index.mat"), {"index": index_list})
    scio.savemat(os.path.join(args.save_dir, "caption.mat"), {"caption": caption_list})
    scio.savemat(os.path.join(args.save_dir, "label.mat"), {"category": category_list})


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|