commit 9e6030dbf9d20b1311a92697d3d399d81287067c Author: kalenforn <1564920382@qq.com> Date: Fri Jul 1 17:35:22 2022 +0800 DCMHT diff --git a/READEM.md b/READEM.md new file mode 100644 index 0000000..dcd709b --- /dev/null +++ b/READEM.md @@ -0,0 +1,78 @@ +# Differentiable Cross Modal Hashing via Multimodal Transformers + +## Framework +The main architecture of our method. +![framework](./data/structure.jpg) + +We propose a selecting mechanism to generate hash code that will transfor the discrete space into a continuous space. Hash code will be encoded as a $2D$ vector. +![hash](./data/method.jpg) + +## Dependencies +We use python to build our code, you need to install those package to run + +- pytorch 1.9.1 +- sklearn +- tqdm +- pillow + +## Training + +### Processing dataset +Before training, you need to download the oringal data from [coco](https://cocodataset.org/#download)(include 2017 train,val and annotations), [nuswide](https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html)(include all), [mirflickr25k](https://www.kaggle.com/datasets/paulrohan2020/mirflickr25k)(include mirflickr25k and mirflickr25k_annotations_v080), +then use the "data/make_XXX.py" to generate .mat file + +For example: +> cd COCO_DIR # include train val images and annotations files +> +> make mat +> +> cp DCMHT/data/make_coco.py mat +> +> python make_coco.py --coco-dir ../ --save-dir ./ + +After all mat file generated, the dir of `dataset` will like this: +~~~ +dataset +├── base.py +├── __init__.py +├── dataloader.py +├── coco +│   ├── caption.mat +│   ├── index.mat +│   └── label.mat +├── flickr25k +│   ├── caption.mat +│   ├── index.mat +│   └── label.mat +└── nuswide +    ├── caption.txt # Notice! It is a txt file! +    ├── index.mat +    └── label.mat +~~~ + +### Download CLIP pretrained model +Pretrained model will be found in the 30 lines of [CLIP/clip/clip.py](https://github.com/openai/CLIP/blob/main/clip/clip.py). This code is based on the "ViT-B/32". + +You should copy ViT-B-32.pt to this dir. + +### Start + +After the dataset has been prepared, we could run the follow command to train. +> python main.py --is-train --hash-layer select --dataset coco --caption-file caption.mat --index-file index.mat --label-file label.mat --similarity-function euclidean --loss-type l2 --vartheta 0.75 --lr 0.0001 --output-dim 64 --save-dir ./result/coco/64 --clip-path ./ViT-B-32.pt --batch-size 256 + + +## Result +![result](./data/result.png) + +## Acknowledegements +[CLIP](https://github.com/openai/CLIP) + +[SSAH](https://github.com/lelan-li/SSAH) + +[GCH](https://github.com/DeXie0808/GCH) + +[AGAH](https://github.com/WendellGul/AGAH) + +[DADH](https://github.com/Zjut-MultimediaPlus/DADH) + +[deep-cross-modal-hashing](https://github.com/WangGodder/deep-cross-modal-hashing) diff --git a/data/method.jpg b/data/method.jpg new file mode 100644 index 0000000..99ff4bf Binary files /dev/null and b/data/method.jpg differ diff --git a/data/result.png b/data/result.png new file mode 100644 index 0000000..f34ebef Binary files /dev/null and b/data/result.png differ diff --git a/data/structure.jpg b/data/structure.jpg new file mode 100644 index 0000000..4565789 Binary files /dev/null and b/data/structure.jpg differ diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataset/base.py b/dataset/base.py new file mode 100644 index 0000000..cbf32d5 --- /dev/null +++ b/dataset/base.py @@ -0,0 +1,102 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +from torch.utils.data import Dataset +import torch +import random +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from model.simple_tokenizer import SimpleTokenizer as Tokenizer + + +class BaseDataset(Dataset): + + def __init__(self, + + captions: dict, + indexs: dict, + labels: dict, + is_train=True, + tokenizer=Tokenizer(), + maxWords=32, + imageResolution=224, + npy=False): + + self.captions = captions + self.indexs = indexs + self.labels = labels + self.npy = npy + + self.maxWords = maxWords + self.tokenizer = tokenizer + + self.transform = Compose([ + Resize(imageResolution, interpolation=Image.BICUBIC), + CenterCrop(imageResolution), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) if is_train else Compose([ + Resize((imageResolution, imageResolution), interpolation=Image.BICUBIC), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", + "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} + + self.__length = len(self.indexs) + + def __len__(self): + return self.__length + + def _load_image(self, index: int) -> torch.Tensor: + if not self.npy: + image_path = self.indexs[index].strip() + # print(image_path) + image = Image.open(image_path).convert("RGB") + else: + image = Image.fromarray(self.indexs[index]).convert("RGB") + image = self.transform(image) + + return image + + def _load_text(self, index: int): + captions = self.captions[index] + use_cap = captions[random.randint(0, len(captions) - 1)] + + words = self.tokenizer.tokenize(use_cap) + words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words + total_length_with_CLS = self.maxWords - 1 + if len(words) > total_length_with_CLS: + words = words[:total_length_with_CLS] + + words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] + caption = self.tokenizer.convert_tokens_to_ids(words) + + while len(caption) < self.maxWords: + caption.append(0) + caption = torch.tensor(caption) + + return caption + + def _load_label(self, index: int) -> torch.Tensor: + label = self.labels[index] + label = torch.from_numpy(label) + + return label + + def get_all_label(self): + labels = torch.zeros([self.__length, len(self.labels[0])], dtype=torch.int64) + for i, item in enumerate(self.labels): + + labels[i] = torch.from_numpy(item) + return labels + + def __getitem__(self, index): + image = self._load_image(index) + caption = self._load_text(index) + label = self._load_label(index) + + return image, caption, label, index + diff --git a/dataset/dataloader.py b/dataset/dataloader.py new file mode 100644 index 0000000..d04c1bb --- /dev/null +++ b/dataset/dataloader.py @@ -0,0 +1,68 @@ + +from .base import BaseDataset +import os +import numpy as np +import scipy.io as scio + + +def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None): + np.random.seed(seed=seed) + random_index = np.random.permutation(range(len(indexs))) + query_index = random_index[: query_num] + train_index = random_index[query_num: query_num + train_num] + retrieval_index = random_index[query_num:] + + query_indexs = indexs[query_index] + query_captions = captions[query_index] + query_labels = labels[query_index] + + train_indexs = indexs[train_index] + train_captions = captions[train_index] + train_labels = labels[train_index] + + retrieval_indexs = indexs[retrieval_index] + retrieval_captions = captions[retrieval_index] + retrieval_labels = labels[retrieval_index] + + split_indexs = (query_indexs, train_indexs, retrieval_indexs) + split_captions = (query_captions, train_captions, retrieval_captions) + split_labels = (query_labels, train_labels, retrieval_labels) + return split_indexs, split_captions, split_labels + +def dataloader(captionFile: str, + indexFile: str, + labelFile: str, + maxWords=32, + imageResolution=224, + query_num=5000, + train_num=10000, + seed=None, + npy=False): + if captionFile.endswith("mat"): + captions = scio.loadmat(captionFile)["caption"] + captions = captions[0] if captions.shape[0] == 1 else captions + elif captionFile.endswith("txt"): + with open(captionFile, "r") as f: + captions = f.readlines() + captions = np.asarray([[item.strip()] for item in captions]) + else: + raise ValueError("the format of 'captionFile' doesn't support, only support [txt, mat] format.") + if not npy: + indexs = scio.loadmat(indexFile)["index"] + else: + indexs = np.load(indexFile, allow_pickle=True) + labels = scio.loadmat(labelFile)["category"] + # for item in ['__version__', '__globals__', '__header__']: + # captions.pop(item) + # indexs.pop(item) + # labels.pop(item) + + split_indexs, split_captions, split_labels = split_data(captions, indexs, labels, query_num=query_num, train_num=train_num, seed=seed) + + train_data = BaseDataset(captions=split_captions[1], indexs=split_indexs[1], labels=split_labels[1], maxWords=maxWords, imageResolution=imageResolution, npy=npy) + query_data = BaseDataset(captions=split_captions[0], indexs=split_indexs[0], labels=split_labels[0], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy) + retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy) + + return train_data, query_data, retrieval_data + + diff --git a/dataset/make_coco.py b/dataset/make_coco.py new file mode 100644 index 0000000..782aa22 --- /dev/null +++ b/dataset/make_coco.py @@ -0,0 +1,178 @@ +import os +import numpy as np + + +def make_index(jsonData: dict, indexDict: dict): + """ + use coco dict data as orignial data. + indexDict: {jsonData's key: [index_key, index_value]} + """ + result = [] + for name in indexDict: + data = jsonData[name] + middle_dict = {} + for item in data: + if item[indexDict[name][0]] not in middle_dict: + middle_dict.update({item[indexDict[name][0]]: [item[indexDict[name][1]]]}) + else: + middle_dict[item[indexDict[name][0]]].append(item[indexDict[name][1]]) + result.append(middle_dict) + + return result + +def check_file_exist(indexDict: dict, file_path: str): + keys = list(indexDict.keys()) + for item in keys: + # print(indexDict[item]) + if not os.path.exists(os.path.join(file_path, indexDict[item][0])): + print(item, indexDict[item]) + indexDict.pop(item) + indexDict[item] = os.path.join(file_path, indexDict[item][0]) + return indexDict + +def chage_categories2numpy(category_ids: dict, data: dict): + + for item in data: + class_item = [0] * len(category_ids) + for class_id in data[item]: + class_item[category_ids[class_id]] = 1 + data[item] = np.asarray(class_item) + + return data + +def get_all_use_key(categoryDict: dict): + return list(categoryDict.keys()) + +def remove_not_use(data: dict, used_key: list): + + keys = list(data.keys()) + for item in keys: + if item not in used_key: + # print("remove:", item, indexDict[item]) + data.pop(item) + # print(len(category_list)) + return data + +def merge_to_list(data: dict): + + result = [] + key_sort = list(data.keys()) + key_sort.sort() + # print(key_sort) + # print(key_sort.index(91654)) + + for item in key_sort: + result.append(data[item]) + + return result + + +if __name__ == "__main__": + import json + import scipy.io as scio + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--coco-dir", default="./", type=str, help="the coco dataset dir") + parser.add_argument("--save-dir", default="./", type=str, help="mat file saved dir") + args = parser.parse_args() + + + PATH = args.coco_dir + jsonFile = os.path.join(PATH, "annotations", "captions_train2017.json") + with open(jsonFile, "r") as f: + jsonData = json.load(f) + indexDict = {"images": ["id", "file_name"], "annotations": ["image_id", "caption"]} + result = make_index(jsonData, indexDict) + indexDict_, captionDict = result + indexDict_ = check_file_exist(indexDict_, os.path.join(PATH, "train2017")) + print("caption:", len(indexDict_), len(captionDict)) + # print_result = list(indexDict.keys()) + # print_result.sort() + # print(print_result) + # indexList = merge_to_list(indexDict_) + # captionList = merge_to_list(captionDict) + # print(indexDict[565962], indexList[4864]) + # print(captionDict[565962], captionList[4864]) + # print(result) + jsonFile = os.path.join(PATH, "annotations", "instances_train2017.json") + with open(jsonFile, "r") as f: + jsonData = json.load(f) + categroy_ids = {} + for i, item in enumerate(jsonData['categories']): + categroy_ids.update({item['id']: i}) + indexDict = {"annotations": ["image_id", "category_id"], "images": ["id", "file_name"]} + result = make_index(jsonData, indexDict) + categoryDict = result[0] + cateIndexDict = result[1] + # cateIndexList = merge_to_list(cateIndexDict) + # print(categoryDict[91654]) + categoryDict = chage_categories2numpy(categroy_ids, categoryDict) + # print(categoryDict[91654]) + # categoryList = merge_to_list(categoryDict) + # print(categoryDict[91654], categoryList[780]) + # print(indexList[100], cateIndexList[100]) + # print("category:", len(categoryDict), len(cateIndexList)) + used_key = get_all_use_key(categoryDict) + # 统一index + indexDict_ = remove_not_use(indexDict_, used_key) + captionDict = remove_not_use(captionDict, used_key) + categoryIndexDict = remove_not_use(cateIndexDict, used_key) + categoryDict = remove_not_use(categoryDict, used_key) + # 转变为list + indexList = merge_to_list(indexDict_) + captionList = merge_to_list(captionDict) + categoryIndexList = merge_to_list(categoryIndexDict) + categoryList = merge_to_list(categoryDict) + print("result", len(indexDict_), len(categoryDict)) + print("category:", len(categoryDict), len(categoryIndexList)) + for i in range(len(indexList)): + if indexList[i] != categoryIndexList[i]: + print("Not the same:", i, indexList[i], categoryIndexList[i]) + + val_jsonFile = os.path.join(PATH, "annotations", "captions_val2017.json") + with open(val_jsonFile, "r") as f: + jsonData = json.load(f) + indexDict = {"images": ["id", "file_name"], "annotations": ["image_id", "caption"]} + result = make_index(jsonData, indexDict) + val_indexDict = result[0] + val_captionDict = result[1] + val_indexDict = check_file_exist(val_indexDict, os.path.join(PATH, "val2017")) + jsonFile = os.path.join(PATH, "annotations", "instances_val2017.json") + with open(jsonFile, "r") as f: + jsonData = json.load(f) + categroy_ids = {} + for i, item in enumerate(jsonData['categories']): + categroy_ids.update({item['id']: i}) + indexDict = {"annotations": ["image_id", "category_id"], "images": ["id", "file_name"]} + result = make_index(jsonData, indexDict) + val_categoryDict = result[0] + val_categoryIndexDict = result[1] + val_categoryDict = chage_categories2numpy(categroy_ids, val_categoryDict) + used_key = get_all_use_key(val_categoryDict) + val_indexDict = remove_not_use(val_indexDict, used_key) + val_captionDict = remove_not_use(val_captionDict, used_key) + val_categoryIndexDict = remove_not_use(val_categoryIndexDict, used_key) + val_categoryDict = remove_not_use(val_categoryDict, used_key) + + val_indexList = merge_to_list(val_indexDict) + val_captionList = merge_to_list(val_captionDict) + val_categoryIndexList = merge_to_list(val_categoryIndexDict) + val_categoryList = merge_to_list(val_categoryDict) + + indexList.extend(val_indexList) + captionList.extend(val_captionList) + categoryIndexList.extend(val_categoryIndexList) + categoryList.extend(val_categoryList) + + print(len(indexList), len(captionList), len(categoryIndexList)) + indexs = {"index": indexList} + captions = {"caption": captionList} + categorys = {"category": categoryList} + + scio.savemat(os.path.join(args.save_dir, "index.mat"), indexs) + scio.savemat(os.path.join(args.save_dir, "caption.mat"), captions) + scio.savemat(os.path.join(args.save_dir, "label.mat"), categorys) + + + diff --git a/dataset/make_mirflickr25k.py b/dataset/make_mirflickr25k.py new file mode 100644 index 0000000..fe5d302 --- /dev/null +++ b/dataset/make_mirflickr25k.py @@ -0,0 +1,81 @@ +import os +import scipy.io as scio +import numpy as np + +# mirflickr25k_annotations_v080 and mirflickr +# mkdir mat +# mv make_mirflickr25k.py mat +# python make_mirflickr25k.py +root_dir = "PATH/TO/YOUR/DOWNLOAD/DIR/" + +file_path = os.path.join(root_dir, "/mirflickr25k_annotations_v080") + +file_list = os.listdir(file_path) + +file_list = [item for item in file_list if "_r1" not in item and "README" not in item] + +print("class num:", len(file_list)) + +class_index = {} +for i, item in enumerate(file_list): + class_index.update({item: i}) + +label_dict = {} +for path_id in file_list: + path = os.path.join(file_path, path_id) + with open(path, "r") as f: + for item in f: + item = item.strip() + if item not in label_dict: + label = np.zeros(len(file_list)) + label[class_index[path_id]] = 1 + label_dict.update({item: label}) + else: + # print() + label_dict[item][class_index[path_id]] = 1 + +# print(label_dict) +print("create label:", len(label_dict)) +keys = list(label_dict.keys()) +keys.sort() + +labels = [] +for key in keys: + labels.append(label_dict[key]) +print("labels created:", len(labels)) +labels = {"category": labels} + + +PATH = os.path.join(root_dir, "mirflickr") +index = [os.path.join(PATH, "im" + item + ".jpg") for item in keys] +print("index created:", len(index)) +index= {"index": index} + + +captions_path = os.path.join(root_dir, "/mirflickr/meta/tags") +captions_list = os.listdir(captions_path) +captions_dict = {} +for item in captions_list: + id_ = item.split(".")[0].replace("tags", "") + caption = "" + with open(os.path.join(captions_path, item), "r") as f: + for word in f.readlines(): + caption += word.strip() + " " + caption = caption.strip() + captions_dict.update({id_: caption}) + +captions = [] + +for item in keys: + captions.append([captions_dict[item]]) + +print("captions created:", len(captions)) +captions = {"caption": captions} + +scio.savemat(os.path.join(root_dir, "/mat/index.mat"), index) +scio.savemat(os.path.join(root_dir, "/mat/caption.mat"), captions) +scio.savemat(os.path.join(root_dir, "/mat/label.mat"), labels) + + + + diff --git a/dataset/make_nuswide.py b/dataset/make_nuswide.py new file mode 100644 index 0000000..aa6ec32 --- /dev/null +++ b/dataset/make_nuswide.py @@ -0,0 +1,107 @@ +import os +import scipy.io as scio +import numpy as np + +# mkdir mat +# mv make_nuswide.py mat +# python make_nuswide.py +root_dir = "PATH/TO/YOUR/DOWNLOAD/DIR/" + + +imageListFile = os.path.join(root_dir, "/Low-Level-Features/ImageList/Imagelist.txt") +labelPath = os.path.join(root_dir, "/nuswide/Groundtruth/AllLabels") +textFile = os.path.join(root_dir, "/Low-Level-Features/NUS_WID_Tags/All_Tags.txt") +classIndexFile = os.path.join(root_dir, "/Low-Level-Features/Concepts81.txt") + +# you can use the image urls to download images +imagePath = os.path.join(root_dir, "nuswide/Flickr") + +with open(imageListFile, "r") as f: + indexs = f.readlines() + +indexs = [os.path.join(imagePath, item.strip().replace("\\", "/")) for item in indexs] +print("indexs length:", len(indexs)) + +#class_index = {} +#with open(classIndexFile, "r") as f: +# data = f.readlines() +# +#for i, item in enumerate(data): +# class_index.update({item.strip(): i}) + +captions = [] +with open(textFile, "r") as f: + for line in f: + if len(line.strip()) == 0: + print("some line empty!") + continue + caption = line.split()[1:] + caption = " ".join(caption).strip() + if len(caption) == 0: + caption = "123456" + captions.append(caption) + +print("captions length:", len(captions)) + +#labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8) +# label_lists = os.listdir(labelPath) +with open(os.path.join(root_dir, "/nuswide/Groundtruth/used_label.txt")) as f: + label_lists = f.readlines() +label_lists = [item.strip() for item in label_lists] + +class_index = {} +for i, item in enumerate(label_lists): + class_index.update({item: i}) + +labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8) + +for item in label_lists: + path = os.path.join(labelPath, item) + class_label = item# .split(".")[0].split("_")[-1] + + with open(path, "r") as f: + data = f.readlines() + for i, val in enumerate(data): + labels[i][class_index[class_label]] = 1 if val.strip() == "1" else 0 +print("labels sum:", labels.sum()) + +not_used_id = [] +with open(os.path.join(root_dir, "/nuswide/Groundtruth/not_used_id.txt")) as f: + not_used_id = f.readlines() +not_used_id = [int(item.strip()) for item in not_used_id] + +# for item in not_used_id: +# indexs.pop(item) +# captions.pop(item) +# labels = np.delete(labels, item, 0) +ind = list(range(len(indexs))) +for item in not_used_id: + ind.remove(item) + indexs[item] = "" + captions[item] = "" +indexs = [item for item in indexs if item != ""] +captions = [item for item in captions if item != ""] +ind = np.asarray(ind) +labels = labels[ind] +# ind = range(len(indexs)) + +print("indexs length:", len(indexs)) +print("captions length:", len(captions)) +print("labels shape:", labels.shape) + +indexs = {"index": indexs} +captions = {"caption": captions} +labels = {"category": labels} + +scio.savemat(os.path.join(root_dir, "/mat/index.mat"), indexs) +# scio.savemat("caption.mat", captions) +scio.savemat(os.path.join(root_dir, "/mat/label.mat"), labels) + + +captions = [item + "\n" for item in captions["caption"]] + +with open(os.path.join(root_dir, "/mat/caption.txt"), "w") as f: + f.writelines(captions) + +print("finished!") + diff --git a/main.py b/main.py new file mode 100644 index 0000000..4de8332 --- /dev/null +++ b/main.py @@ -0,0 +1,8 @@ +from train.hash_train import Trainer + + +if __name__ == "__main__": + + Trainer() + + diff --git a/model/__init__.py b/model/__init__.py new file mode 100755 index 0000000..dcc5619 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/model/bpe_simple_vocab_16e6.txt.gz b/model/bpe_simple_vocab_16e6.txt.gz new file mode 100755 index 0000000..7b5088a Binary files /dev/null and b/model/bpe_simple_vocab_16e6.txt.gz differ diff --git a/model/clip.py b/model/clip.py new file mode 100755 index 0000000..6ce5565 --- /dev/null +++ b/model/clip.py @@ -0,0 +1,224 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if torch.__version__.split(".") < ["1", "7", "1"]: + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/model/hash_model.py b/model/hash_model.py new file mode 100644 index 0000000..2d9c851 --- /dev/null +++ b/model/hash_model.py @@ -0,0 +1,161 @@ +import os +import torch +import logging +import torch.nn as nn +import numpy as np +from typing import Union + +from model.model import build_model +from utils import get_logger, get_summary_writer + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + nn.init.kaiming_uniform_(m.weight, mode='fan_out') + nn.init.constant_(m.bias, 0.0) + elif classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif classname.find('BatchNorm') != -1: + if m.affine: + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0.0) + + +class LinearHash(nn.Module): + + def __init__(self, inputDim=2048, outputDim=64): + super(LinearHash, self).__init__() + self.fc = nn.Linear(inputDim, outputDim) + self.fc.apply(weights_init_kaiming) + self.drop_out = nn.Dropout(p=0.2) + + def forward(self, data): + result = self.fc(data) + return torch.tanh(self.drop_out(result)) + + +class HashLayer(nn.Module): + + LINEAR_EMBED = 128 + SIGMOID_ALPH = 10 + def __init__(self, inputDim=2048, outputDim=64): + + super(HashLayer, self).__init__() + self.fc = nn.Linear(inputDim, self.LINEAR_EMBED) + self.fc.apply(weights_init_kaiming) + self.hash_list = nn.ModuleList([nn.Linear(self.LINEAR_EMBED, 2) for _ in range(outputDim)]) + for item in self.hash_list: + item.apply(weights_init_kaiming) + + def forward(self, data): + + embed = self.fc(data) + embed = torch.relu(embed) + + softmax_list = [torch.softmax(item(embed), dim=-1) for item in self.hash_list] + + return softmax_list + +class HashLayer_easy_logic(nn.Module): + + LINEAR_EMBED = 128 + SIGMOID_ALPH = 10 + def __init__(self, inputDim=2048, outputDim=64): + + super(HashLayer, self).__init__() + self.bit = outputDim + self.fc = nn.Linear(inputDim, outputDim * 2) + self.fc.apply(weights_init_kaiming) + for item in self.hash_list: + item.apply(weights_init_kaiming) + + def forward(self, data): + + embed = self.fc(data) + softmax_list = embed.view(embed.shape[0], self.bit, 2) + + softmax_list = torch.softmax(softmax_list, dim=-1) + + return softmax_list + + +class DCMHT(nn.Module): + + def __init__(self, + outputDim=64, + clipPath="./ViT-B-32.pt", + writer=None, + saveDir="./result/log", + logger: logging.Logger=None, + is_train=True, + linear=False): + super(DCMHT, self).__init__() + os.makedirs(saveDir, exist_ok=True) + self.logger = logger if logger is not None else get_logger(os.path.join(saveDir, "train.log" if is_train else "test.log")) + self.writer = writer if writer is not None and is_train else get_summary_writer(os.path.join(saveDir, "tensorboard")) + embedDim, self.clip = self.load_clip(clipPath) + # if is_train: + # self.clip.eval() + # print("start freezen") + # self.freezen() + self.image_hash = LinearHash(inputDim=embedDim, outputDim=outputDim) if linear else HashLayer(inputDim=embedDim, outputDim=outputDim) + self.text_hash = LinearHash(inputDim=embedDim, outputDim=outputDim) if linear else HashLayer(inputDim=embedDim, outputDim=outputDim) + # print(self.image_hash) + # print(self.text_hash) + + def freezen(self): + for name, param in self.clip.named_parameters(): + # print(name) + if name.find("ln_final.") == 0 or name.find("text_projection") == 0 or name.find("logit_scale") == 0 \ + or name.find("visual.ln_post.") == 0 or name.find("visual.proj") == 0: + # print("1") + continue + elif name.find("visual.transformer.resblocks.") == 0 or name.find("transformer.resblocks.") == 0: + layer_num = int(name.split(".resblocks.")[1].split(".")[0]) + if layer_num >= 12: + # print("2") + continue + if name.find("conv2.") == 0: + # print("3") + continue + else: + # paramenters which < freeze_layer_num will be freezed + param.requires_grad = False + + def load_clip(self, clipPath: str) -> tuple: + try: + model = torch.jit.load(clipPath, map_location="cpu").eval() + state_dict = model.state_dict() + except RuntimeError: + state_dict = torch.load(clipPath, map_location="cpu") + + return state_dict["text_projection"].shape[1], build_model(state_dict) + + def encode_image(self, image): + + image_embed = self.clip.encode_image(image) + image_embed = self.image_hash(image_embed) + + return image_embed + + def eval(self): + self.image_hash.eval() + self.text_hash.eval() + # self.clip.eval() + + def train(self): + self.image_hash.train() + self.text_hash.train() + + def encode_text(self, text): + + text_embed = self.clip.encode_text(text) + text_embed = self.text_hash(text_embed) + + return text_embed + + def forward(self, image, text): + return self.encode_image(image), self.encode_text(text) + diff --git a/model/model.py b/model/model.py new file mode 100755 index 0000000..aa8fc27 --- /dev/null +++ b/model/model.py @@ -0,0 +1,450 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + # self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + # return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + attn_mask_ = self.attn_mask + if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): + attn_mask_ = self.attn_mask(x.size(0)) # LND + + attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] + + def forward(self, x: torch.Tensor): + # x, video_frame = x_tuple + # print(x.shape) + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + # print(x.shape) + # print(x.shape) + x = self.conv1(x) # shape = [*, width, grid, grid] + # print("image feature map:", x.shape) + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + # print(x.shape) + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + # print(x.shape) + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + # print(x.shape) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + # print(x.shape) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self, context_length): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(context_length, context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding[:x.size(1), :].type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + # print("vision width:", vision_width) + # print("vision patch size", vision_patch_size) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + # return model.eval() + return model diff --git a/model/optimization.py b/model/optimization.py new file mode 100644 index 0000000..f040b83 --- /dev/null +++ b/model/optimization.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch optimization for BERT model.""" + +import math +import torch +from torch.optim import Optimizer +from torch.optim.optimizer import required +from torch.nn.utils import clip_grad_norm_ +import logging + +logger = logging.getLogger(__name__) + +def warmup_cosine(x, warmup=0.002): + if x < warmup: + return x/warmup + return 0.5 * (1.0 + math.cos(math.pi * x)) + +def warmup_constant(x, warmup=0.002): + """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. + Learning rate is 1. afterwards. """ + if x < warmup: + return x/warmup + return 1.0 + +def warmup_linear(x, warmup=0.002): + """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. + After `t_total`-th training step, learning rate is zero. """ + if x < warmup: + return x/warmup + return max((x-1.)/(warmup-1.), 0) + +SCHEDULES = { + 'warmup_cosine': warmup_cosine, + 'warmup_constant': warmup_constant, + 'warmup_linear': warmup_linear, +} + + +class BertAdam(Optimizer): + """Implements BERT version of Adam algorithm with weight decay fix. + Params: + lr: learning rate + warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 + t_total: total number of training steps for the learning + rate schedule, -1 means constant learning rate. Default: -1 + schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' + b1: Adams b1. Default: 0.9 + b2: Adams b2. Default: 0.999 + e: Adams epsilon. Default: 1e-6 + weight_decay: Weight decay. Default: 0.01 + max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 + """ + def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', + b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, + max_grad_norm=1.0): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if schedule not in SCHEDULES: + raise ValueError("Invalid schedule parameter: {}".format(schedule)) + if not 0.0 <= warmup < 1.0 and not warmup == -1: + raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) + if not 0.0 <= b1 < 1.0: + raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) + if not 0.0 <= b2 < 1.0: + raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) + if not e >= 0.0: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) + defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, + b1=b1, b2=b2, e=e, weight_decay=weight_decay, + max_grad_norm=max_grad_norm) + super(BertAdam, self).__init__(params, defaults) + + def get_lr(self): + lr = [] + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + if len(state) == 0: + return [0] + if group['t_total'] != -1: + schedule_fct = SCHEDULES[group['schedule']] + lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) + else: + lr_scheduled = group['lr'] + lr.append(lr_scheduled) + return lr + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['next_m'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['next_v'] = torch.zeros_like(p.data) + + next_m, next_v = state['next_m'], state['next_v'] + beta1, beta2 = group['b1'], group['b2'] + + # Add grad clipping + if group['max_grad_norm'] > 0: + clip_grad_norm_(p, group['max_grad_norm']) + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7 + next_m.mul_(beta1).add_(grad, alpha=1 - beta1) + # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7 + next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + update = next_m / (next_v.sqrt() + group['e']) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + if group['weight_decay'] > 0.0: + update += group['weight_decay'] * p.data + + if group['t_total'] != -1: + schedule_fct = SCHEDULES[group['schedule']] + progress = state['step']/group['t_total'] + lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) + else: + lr_scheduled = group['lr'] + + update_with_lr = lr_scheduled * update + p.data.add_(-update_with_lr) + + state['step'] += 1 + + return loss \ No newline at end of file diff --git a/model/simple_tokenizer.py b/model/simple_tokenizer.py new file mode 100755 index 0000000..3eb73c2 --- /dev/null +++ b/model/simple_tokenizer.py @@ -0,0 +1,143 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + def tokenize(self, text): + tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return tokens + + def convert_tokens_to_ids(self, tokens): + return [self.encoder[bpe_token] for bpe_token in tokens] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d315af2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +tqdm +scipy +numpy +pillow +matplotlib +sklearn +pytorch==1.9.1 + diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/base.py b/train/base.py new file mode 100644 index 0000000..eacc2b1 --- /dev/null +++ b/train/base.py @@ -0,0 +1,94 @@ +import os +from tqdm import tqdm +import torch + +from torch import distributed as dist +from utils import get_logger, get_summary_writer + + +class TrainBase(object): + + def __init__(self, + args, + rank=0): + + self.args = args + os.makedirs(args.save_dir, exist_ok=True) + self._init_writer() + self.logger.info(self.args) + self.rank = rank + + self._init_dataset() + self._init_model() + + self.global_step = 0 + # self.global_step_t = 0 + self.max_mapi2t = 0 + self.max_mapt2i = 0 + self.best_epoch_i = 0 + self.best_epoch_t = 0 + + def _init_dataset(self): + self.train_loader = None + self.query_loader = None + self.retrieval_loader = None + + def _init_model(self): + self.model = None + self.model_ddp = None + + def _init_writer(self): + self.logger = get_logger(os.path.join(self.args.save_dir, "train.log" if self.args.is_train else "test.log")) + self.writer = get_summary_writer(os.path.join(self.args.save_dir, "tensorboard")) + + def run(self): + if self.args.is_train: + self.train() + else: + self.test() + + def change_state(self, mode): + + if mode == "train": + self.model.train() + elif mode == "valid": + self.model.eval() + + def get_code(self, data_loader, length: int): + + img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) + text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) + + for image, text, label, index in tqdm(data_loader): + image = image.to(self.rank, non_blocking=True) + text = text.to(self.rank, non_blocking=True) + index = index.numpy() + image_hash = self.model.encode_image(image) + text_hash = self.model.encode_text(text) + + img_buffer[index, :] = image_hash.data + text_buffer[index, :] = text_hash.data + + return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) + + def hash_loss(self, a: torch.Tensor): + return torch.mean(torch.sqrt(torch.sum(torch.pow(torch.sign(a) - a, 2), dim=1))) * 0.5 + + def similarity_loss(self): + raise NotImplementedError("Function of 'similarity_loss' doesn't implement.") + + def save_model(self, epoch): + torch.save(self.model.state_dict(), os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth")) + self.logger.info("save mode to {}".format(os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth"))) + + def train(self): + raise NotImplementedError("Function of 'train' doesn't implement.") + + def valid(self): + raise NotImplementedError("Function of 'valid' doesn't implement.") + + def test(self): + raise NotImplementedError("Function of 'test' doesn't implement.") + + def compute_loss(self): + raise NotImplementedError("Function of 'compute_loss' doesn't implement.") diff --git a/train/hash_train.py b/train/hash_train.py new file mode 100644 index 0000000..ab495d2 --- /dev/null +++ b/train/hash_train.py @@ -0,0 +1,360 @@ +from torch.nn.modules import loss +from model.hash_model import DCMHT as DCMHT +import os +from tqdm import tqdm +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import scipy.io as scio + + +from .base import TrainBase +from model.optimization import BertAdam +from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity +from utils.calc_utils import calc_map_k_matrix as calc_map_k +from dataset.dataloader import dataloader + +class Trainer(TrainBase): + + def __init__(self, + rank=0): + args = get_args() + super(Trainer, self).__init__(args, rank) + self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) + self.run() + + def _init_model(self): + self.logger.info("init model.") + linear = False + if self.args.hash_layer == "linear": + linear = True + + self.logger.info("ViT+GPT!") + HashModel = DCMHT + self.model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, + writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) + if self.args.pretrained != "" and os.path.exists(self.args.pretrained): + self.logger.info("load pretrained model.") + self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) + + self.model.float() + self.optimizer = BertAdam([ + {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr}, + {'params': self.model.image_hash.parameters(), 'lr': self.args.lr}, + {'params': self.model.text_hash.parameters(), 'lr': self.args.lr} + ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine', + b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs, + weight_decay=self.args.weight_decay, max_grad_norm=1.0) + + print(self.model) + + def _init_dataset(self): + self.logger.info("init dataset.") + self.logger.info(f"Using {self.args.dataset} dataset.") + self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file) + self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file) + self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file) + train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file, + indexFile=self.args.index_file, + labelFile=self.args.label_file, + maxWords=self.args.max_words, + imageResolution=self.args.resolution, + query_num=self.args.query_num, + train_num=self.args.train_num, + seed=self.args.seed) + self.train_labels = train_data.get_all_label() + self.query_labels = query_data.get_all_label() + self.retrieval_labels = retrieval_data.get_all_label() + self.args.retrieval_num = len(self.retrieval_labels) + self.logger.info(f"query shape: {self.query_labels.shape}") + self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}") + self.train_loader = DataLoader( + dataset=train_data, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=True, + shuffle=True + ) + self.query_loader = DataLoader( + dataset=query_data, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=True, + shuffle=True + ) + self.retrieval_loader = DataLoader( + dataset=retrieval_data, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=True, + shuffle=True + ) + + def train_epoch(self, epoch): + self.change_state(mode="train") + self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) + all_loss = 0 + times = 0 + for image, text, label, index in self.train_loader: + self.global_step += 1 + times += 1 + image.float() + if self.args.dataset not in ["flickr25k", "coco", "nuswide"]: + label = torch.ones([image.shape[0]], dtype=torch.int) + label = label.diag() + # print(text.dtype) + # text.float() + # label.float() + image = image.to(self.rank, non_blocking=True) + text = text.to(self.rank, non_blocking=True) + # print("text shape:", text.shape) + index = index.numpy() + # print(text.shape) + hash_img, hash_text = self.model(image, text) + if self.args.hash_layer == "select": + hash_img = torch.cat(hash_img, dim=-1) if isinstance(hash_img, list) else hash_img.view(hash_img.shape[0], -1) + hash_text = torch.cat(hash_text, dim=-1)if isinstance(hash_text, list) else hash_text.view(hash_text.shape[0], -1) + loss = self.compute_loss(hash_img, hash_text, label, epoch, times) + all_loss += loss + + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}") + + def train(self): + self.logger.info("Start train.") + + for epoch in range(self.args.epochs): + self.train_epoch(epoch) + self.valid(epoch) + self.save_model(epoch) + + self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") + + def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor): + + s = torch.matmul(a, b.t()) + b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s))) + + return b_loss + + def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor): + """ + """ + kl_divergence = torch.mean(a * torch.log(a / (b + 0.001))) + print("mean", torch.mean(a - b)) + print("kl", kl_divergence) + return kl_divergence + + + def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05): + + # $\vartheta$ + vartheta = self.args.vartheta + if self.args.sim_threshold != 0: + threshold = self.args.sim_threshold + similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b) + + positive_similarity = similarity * label_sim + # 只要cosine为负值的全都算为计算正确了,因为优化到2确实很难。 + negative_similarity = similarity * (1 - label_sim) + + if self.args.similarity_function == "cosine": + positive_similarity = positive_similarity.clip(threshold) - threshold + negative_similarity = negative_similarity.clip(max=1.) + negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity + elif self.args.similarity_function == "euclidean": + # 有euclidean距离可知,当有一半长度的hash码不同时,其negative_similarity距离应该是长度(concat操作将outputdim翻倍),所以这里clip掉认为认定的值 + # 人为认定的最大值是一半长度的hash码不同。 + max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5 + negative_similarity = negative_similarity.clip(max=max_value) + negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity + + if self.args.loss_type == "l1": + positive_loss = positive_similarity.mean() + negative_loss = negative_similarity.mean() + elif self.args.loss_type == "l2": + positive_loss = torch.pow(positive_similarity, 2).mean() + negative_loss = torch.pow(negative_similarity, 2).mean() + else: + raise ValueError("argument of loss_type is not support.") + + return similarity, positive_loss, negative_loss + + def make_hash_code(self, code: list) -> torch.Tensor: + + code = torch.stack(code) + # print(code.shape) + code = code.permute(1, 0, 2) + hash_code = torch.argmax(code, dim=-1) + hash_code[torch.where(hash_code == 0)] = -1 + hash_code = hash_code.float() + + return hash_code + + def get_code(self, data_loader, length: int): + + img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) + text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) + + for image, text, label, index in tqdm(data_loader): + image = image.to(self.rank, non_blocking=True) + text = text.to(self.rank, non_blocking=True) + index = index.numpy() + image_hash = self.model.encode_image(image) + image_hash = self.make_hash_code(image_hash) + text_hash = self.model.encode_text(text) + text_hash = self.make_hash_code(text_hash) + # image_hash.to(self.rank) + # text_hash.to(self.rank) + img_buffer[index, :] = image_hash.data + text_buffer[index, :] = text_hash.data + + return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) + + def our_loss(self, image, text, label, epoch, times): + loss = 0 + + label_sim = calc_neighbor(label, label) + if image.is_cuda: + label_sim = label_sim.to(image.device) + intra_similarity, intra_positive_loss, intra_negative_loss = self.similarity_loss(image, text, label_sim) + inter_similarity_i, inter_positive_loss_i, inter_negative_loss_i = self.similarity_loss(image, image, label_sim) + inter_similarity_t, inter_positive_loss_t, inter_negative_loss_t = self.similarity_loss(text, text, label_sim) + + intra_similarity_loss = (intra_positive_loss + intra_negative_loss) if self.args.similarity_function == "euclidean" else (intra_positive_loss + intra_negative_loss) + inter_similarity_loss = inter_positive_loss_t + inter_positive_loss_i + (inter_negative_loss_i + inter_negative_loss_t) if self.args.similarity_function == "euclidean" else inter_positive_loss_t + inter_positive_loss_i + inter_negative_loss_i + inter_negative_loss_t + similarity_loss = inter_similarity_loss + intra_similarity_loss + + # if self.writer is not None: + # self.writer.add_scalar("intra similarity max", intra_similarity.max(), self.global_step) + # self.writer.add_scalar("intra similarity min", intra_similarity.min(), self.global_step) + # self.writer.add_scalar("intra positive loss", intra_positive_loss.data, self.global_step) + # self.writer.add_scalar("intra negative loss", intra_negative_loss.data, self.global_step) + + # self.writer.add_scalar("inter image similarity max", inter_similarity_i.max(), self.global_step) + # self.writer.add_scalar("inter image similarity min", inter_similarity_i.min(), self.global_step) + # self.writer.add_scalar("inter image positive loss", inter_positive_loss_i.data, self.global_step) + # self.writer.add_scalar("inter image negative loss", inter_negative_loss_i.data, self.global_step) + + # self.writer.add_scalar("inter text similarity max", inter_similarity_t.max(), self.global_step) + # self.writer.add_scalar("inter text similarity min", inter_similarity_t.min(), self.global_step) + # self.writer.add_scalar("inter text positive loss", inter_positive_loss_t.data, self.global_step) + # self.writer.add_scalar("inter text negative loss", inter_negative_loss_t.data, self.global_step) + + # self.writer.add_scalar("intra similarity loss", intra_similarity_loss.data, self.global_step) + # self.writer.add_scalar("inter similarity loss", inter_similarity_loss.data, self.global_step) + # self.writer.add_scalar("similarity loss", similarity_loss.data, self.global_step) + + if self.args.hash_layer != "select": + quantization_loss = (self.hash_loss(image) + self.hash_loss(text)) / 2 + loss = similarity_loss + quantization_loss + if self.global_step % self.args.display_step == 0: + self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\ + f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \ + f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\ + f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\ + f"QUATIZATION LOSS, {quantization_loss.data}, "\ + f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}") + else: + loss = similarity_loss # + self.args.qua_gamma * (image_quantization_loss + text_quantization_loss) + if self.global_step % self.args.display_step == 0: + self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\ + f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \ + f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\ + f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\ + # f"QUATIZATION LOSS, image: {image_quantization_loss.data}, text: {text_quantization_loss.data}, "\ + f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}") + + return loss + + def compute_loss(self, image, text, label, epoch, times): + + loss = self.our_loss(image, text, label, epoch, times) + + return loss + + def test(self, mode_name="i2t"): + if self.args.pretrained == "": + raise RuntimeError("test step must load a model! please set the --pretrained argument.") + self.change_state(mode="valid") + save_dir = os.path.join(self.args.save_dir, "PR_cruve") + os.makedirs(save_dir, exist_ok=True) + query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) + retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) + mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) + # print("map map") + mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) + mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) + mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) + self.max_mapt2i = max(self.max_mapt2i, mAPt2i) + self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}") + + query_img = query_img.cpu().detach().numpy() + query_txt = query_txt.cpu().detach().numpy() + retrieval_img = retrieval_img.cpu().detach().numpy() + retrieval_txt = retrieval_txt.cpu().detach().numpy() + query_labels = self.query_labels.numpy() + retrieval_labels = self.retrieval_labels.numpy() + + result_dict = { + 'q_img': query_img, + 'q_txt': query_txt, + 'r_img': retrieval_img, + 'r_txt': retrieval_txt, + 'q_l': query_labels, + 'r_l': retrieval_labels + } + scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) + self.logger.info(">>>>>> save all data!") + + + def valid(self, epoch): + self.logger.info("Valid.") + self.change_state(mode="valid") + query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) + retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) + # print("get all code") + mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) + # print("map map") + mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) + mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) + mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) + if self.max_mapi2t < mAPi2t: + self.best_epoch_i = epoch + self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t") + self.max_mapi2t = max(self.max_mapi2t, mAPi2t) + if self.max_mapt2i < mAPt2i: + self.best_epoch_t = epoch + self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i") + self.max_mapt2i = max(self.max_mapt2i, mAPt2i) + self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \ + MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}") + + def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"): + + save_dir = os.path.join(self.args.save_dir, "PR_cruve") + os.makedirs(save_dir, exist_ok=True) + + query_img = query_img.cpu().detach().numpy() + query_txt = query_txt.cpu().detach().numpy() + retrieval_img = retrieval_img.cpu().detach().numpy() + retrieval_txt = retrieval_txt.cpu().detach().numpy() + query_labels = self.query_labels.numpy() + retrieval_labels = self.retrieval_labels.numpy() + + result_dict = { + 'q_img': query_img, + 'q_txt': query_txt, + 'r_img': retrieval_img, + 'r_txt': retrieval_txt, + 'q_l': query_labels, + 'r_l': retrieval_labels + } + scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) + self.logger.info(f">>>>>> save best {mode_name} data!") + + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e98ec6d --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,3 @@ +from .utils import * +from .logger import get_logger, get_summary_writer +from .get_args import get_args \ No newline at end of file diff --git a/utils/calc_utils.py b/utils/calc_utils.py new file mode 100644 index 0000000..826839b --- /dev/null +++ b/utils/calc_utils.py @@ -0,0 +1,357 @@ +import torch +import numpy as np +from tqdm import tqdm + + +def calc_hammingDist(B1, B2): + q = B2.shape[1] + if len(B1.shape) < 2: + B1 = B1.unsqueeze(0) + distH = 0.5 * (q - B1.mm(B2.transpose(0, 1))) + return distH + + +def calc_map_k_matrix(qB, rB, query_L, retrieval_L, k=None, rank=0): + + num_query = query_L.shape[0] + if qB.is_cuda: + qB = qB.cpu() + rB = rB.cpu() + map = 0 + if k is None: + k = retrieval_L.shape[0] + gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32) + tsums = torch.sum(gnds, dim=-1, keepdim=True, dtype=torch.int32) + hamms = calc_hammingDist(qB, rB) + _, ind = torch.sort(hamms, dim=-1) + + totals = torch.min(tsums, torch.tensor([k], dtype=torch.int32).expand_as(tsums)) + for iter in range(num_query): + gnd = gnds[iter][ind[iter]] + total = totals[iter].squeeze() + count = torch.arange(1, total + 1).type(torch.float32) + tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float32) + 1.0 + map = map + torch.mean(count / tindex) + map = map / num_query + + return map + + +def calc_map_k(qB, rB, query_L, retrieval_L, k=None, rank=0): + + num_query = query_L.shape[0] + qB = torch.sign(qB) + rB = torch.sign(rB) + map = 0 + if k is None: + k = retrieval_L.shape[0] + for iter in range(num_query): + q_L = query_L[iter] + if len(q_L.shape) < 2: + q_L = q_L.unsqueeze(0) # [1, hash length] + gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32) + tsum = torch.sum(gnd) + if tsum == 0: + continue + hamm = calc_hammingDist(qB[iter, :], rB) + _, ind = torch.sort(hamm) + ind.squeeze_() + gnd = gnd[ind] + total = min(k, int(tsum)) + count = torch.arange(1, total + 1).type(torch.float32) + tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float32) + 1.0 + if tindex.is_cuda: + count = count.to(rank) + map = map + torch.mean(count / tindex) + map = map / num_query + return map + + +def calc_precisions_topn_matrix(qB, rB, query_L, retrieval_L, recall_gas=0.02, num_retrieval=10000): + if not isinstance(qB, torch.Tensor): + qB = torch.from_numpy(qB) + rB = torch.from_numpy(rB) + query_L = torch.from_numpy(query_L) + retrieval_L = torch.from_numpy(retrieval_L) + qB = qB.float() + rB = rB.float() + qB = torch.sign(qB - 0.5) + rB = torch.sign(rB - 0.5) + if qB.is_cuda: + qB = qB.cpu() + rB = rB.cpu() + num_query = query_L.shape[0] + # num_retrieval = retrieval_L.shape[0] + precisions = [0] * int(1 / recall_gas) + gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32) + hamms = calc_hammingDist(qB, rB) + _, inds = torch.sort(hamms, dim=-1) + for iter in range(num_query): + gnd = gnds[iter] + ind = inds[iter] + + gnd = gnd[ind] + for i, recall in enumerate(np.arange(recall_gas, 1 + recall_gas, recall_gas)): + total = int(num_retrieval * recall) + right = torch.nonzero(gnd[: total]).squeeze().numpy() + + right_num = right.size + precisions[i] += (right_num/total) + for i in range(len(precisions)): + precisions[i] /= num_query + return precisions + + +def calc_precisions_topn(qB, rB, query_L, retrieval_L, recall_gas=0.02, num_retrieval=10000): + qB = qB.float() + rB = rB.float() + qB = torch.sign(qB - 0.5) + rB = torch.sign(rB - 0.5) + num_query = query_L.shape[0] + # num_retrieval = retrieval_L.shape[0] + precisions = [0] * int(1 / recall_gas) + for iter in range(num_query): + q_L = query_L[iter] + if len(q_L.shape) < 2: + q_L = q_L.unsqueeze(0) # [1, hash length] + gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32) + hamm = calc_hammingDist(qB[iter, :], rB) + _, ind = torch.sort(hamm) + ind.squeeze_() + gnd = gnd[ind] + for i, recall in enumerate(np.arange(recall_gas, 1 + recall_gas, recall_gas)): + total = int(num_retrieval * recall) + right = torch.nonzero(gnd[: total]).squeeze().numpy() + # right_num = torch.nonzero(gnd[: total]).squeeze().shape[0] + right_num = right.size + precisions[i] += (right_num/total) + for i in range(len(precisions)): + precisions[i] /= num_query + return precisions + + +def calc_precisions_hash(qB, rB, query_L, retrieval_L): + qB = qB.float() + rB = rB.float() + qB = torch.sign(qB - 0.5) + rB = torch.sign(rB - 0.5) + num_query = query_L.shape[0] + num_retrieval = retrieval_L.shape[0] + bit = qB.shape[1] + hamm = calc_hammingDist(qB, rB) + hamm = hamm.type(torch.ByteTensor) + total_num = [0] * (bit + 1) + max_hamm = int(torch.max(hamm)) + gnd = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze() + total_right = torch.sum(torch.matmul(query_L, retrieval_L.t())>0) + precisions = np.zeros([max_hamm + 1]) + recalls = np.zeros([max_hamm + 1]) + # _, index = torch.sort(hamm) + # del _ + # for i in range(index.shape[0]): + # gnd[i, :] = gnd[i, index[i]] + # del index + right_num = 0 + recall_num = 0 + for i, radius in enumerate(range(0, max_hamm+1)): + recall = torch.nonzero(hamm == radius) + right = gnd[recall.split(1, dim=1)] + recall_num += recall.shape[0] + del recall + right_num += torch.nonzero(right).shape[0] + del right + precisions[i] += (right_num / (recall_num + 1e-8)) + # recalls[i] += (recall_num / num_retrieval / num_query) + recalls[i] += (recall_num / total_right) + return precisions, recalls + +def calc_precisions_hash_my(qB, rB, *, Gnd, num_query, num_retrieval): + if not isinstance(qB, torch.Tensor): + qB = torch.from_numpy(qB) + if not isinstance(rB, torch.Tensor): + rB = torch.from_numpy(rB) + if not isinstance(Gnd, torch.Tensor): + Gnd = torch.from_numpy(Gnd) + + def CalcHammingDist_np(B1, B2): + q = B2.shape[1] + distH = 0.5 * (q - np.dot(B1, B2.transpose())) + return distH + bit = qB.shape[1] + # if isinstance(qB, np.ndarray): + # hamm = CalcHammingDist_np(qB, rB) + # else: + hamm = calc_hammingDist(qB, rB) + hamm = hamm.type(torch.ByteTensor) + total_num = [0] * (bit + 1) + max_hamm = int(torch.max(hamm)) + + gnd = Gnd + + total_right = torch.sum(gnd>0) + precisions = np.zeros([max_hamm + 1]) + recalls = np.zeros([max_hamm + 1]) + + right_num = 0 + recall_num = 0 + for i, radius in enumerate(range(0, max_hamm+1)): + recall = torch.nonzero(hamm == radius) + right = gnd[recall.split(1, dim=1)] + recall_num += recall.shape[0] + del recall + right_num += torch.nonzero(right).shape[0] + del right + precisions[i] += (right_num / (recall_num + 1e-8)) + recalls[i] += (recall_num / num_retrieval / num_query) + # recalls[i] += (recall_num / total_right) + p = precisions.round(2) + r = recalls.round(2) + # return p, r + + precisions = [] + recalls = [] + + precision_ = 0 + num = 1 + for i in range(len(r) - 1): + if r[i] == r[i + 1]: + precision_ += p[i] + num += 1 + else: + precision_ += p[i] + precisions.append(precision_ / num) + recalls.append(r[i]) + precision_ = 0 + num = 1 + + return np.asarray(precisions).round(2), np.asarray(recalls) + + +def calc_precisions_hamming_radius(qB, rB, query_L, retrieval_L, hamming_gas=1): + num_query = query_L.shape[0] + bit = qB.shape[1] + precisions = [0] * int(bit / hamming_gas) + for iter in range(num_query): + q_L = query_L[iter] + if len(q_L.shape) < 2: + q_L = q_L.unsqueeze(0) # [1, hash length] + gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32) + hamm = calc_hammingDist(qB[iter, :], rB) + _, ind = torch.sort(hamm) + ind.squeeze_() + gnd = gnd[ind] + for i, recall in enumerate(np.arange(1, bit+1, hamming_gas)): + total = torch.nonzero(hamm <= recall).squeeze().shape[0] + if total == 0: + precisions[i] += 0 + continue + right = torch.nonzero(gnd[: total]).squeeze().numpy() + right_num = right.size + + precisions[i] += (right_num / total) + for i in range(len(precisions)): + precisions[i] /= num_query + return precisions + + +def calc_neighbor(label1, label2): + # calculate the similar matrix + Sim = label1.matmul(label2.transpose(0, 1)) > 0 + return Sim.float() + + +def norm_max_min(x: torch.Tensor, dim=None): + if dim is None: + max = torch.max(x) + min = torch.min(x) + if dim is not None: + max = torch.max(x, dim=dim)[0] + min = torch.min(x, dim=dim)[0] + if dim > 0: + max = max.unsqueeze(len(x.shape) - 1) + min = min.unsqueeze(len(x.shape) - 1) + norm = (x - min) / (max - min) + return norm + + +def norm_mean(x: torch.Tensor, dim=None): + if dim is None: + mean = torch.mean(x) + std = torch.std(x) + if dim is not None: + mean = torch.mean(x, dim=dim) + std = torch.std(x, dim=dim) + if dim > 0: + mean = mean.unsqueeze(len(x.shape) - 1) + std = std.unsqueeze(len(x.shape) - 1) + norm = (x - mean) / std + return norm + + +def norm_abs_mean(x: torch.Tensor, dim=None): + if dim is None: + mean = torch.mean(x) + std = torch.std(x) + if dim is not None: + mean = torch.mean(x, dim=dim) + std = torch.std(x, dim=dim) + if dim > 0: + mean = mean.unsqueeze(len(x.shape) - 1) + std = std.unsqueeze(len(x.shape) - 1) + norm = torch.abs(x - mean) / std + return norm + + +def factorial(n): + if n == 0: + return 1 + else: + return n * factorial(n - 1) + + +def calc_IF(all_bow): + word_num = torch.sum(all_bow, dim=0) + total_num = torch.sum(word_num) + IF = word_num / total_num + return IF + + +# def calc_loss(B, F, G, Sim, gamma1, gamma2, eta): +# theta = torch.matmul(F, G.transpose(0, 1)) / 2 +# inter_loss = torch.sum(torch.log(1 + torch.exp(theta)) - Sim * theta) +# theta_f = torch.matmul(F, F.transpose(0, 1)) / 2 +# intra_img = torch.sum(torch.log(1 + torch.exp(theta_f)) - Sim * theta_f) +# theta_g = torch.matmul(G, G.transpose(0, 1)) / 2 +# intra_txt = torch.sum(torch.log(1 + torch.exp(theta_g)) - Sim * theta_g) +# intra_loss = gamma1 * intra_img + gamma2 * intra_txt +# quan_loss = torch.sum(torch.pow(B - F, 2) + torch.pow(B - G, 2)) * eta +# # term3 = torch.sum(torch.pow(F.sum(dim=0), 2) + torch.pow(G.sum(dim=0), 2)) +# # loss = term1 + gamma * term2 + eta * term3 +# loss = inter_loss + intra_loss + quan_loss +# return loss + + +# if __name__ == '__main__': +# qB = torch.Tensor([[1, -1, 1, 1], +# [-1, -1, -1, 1], +# [1, 1, -1, 1], +# [1, 1, 1, -1]]) +# rB = torch.Tensor([[1, -1, 1, -1], +# [-1, -1, 1, -1], +# [-1, -1, 1, -1], +# [1, 1, -1, -1], +# [-1, 1, -1, -1], +# [1, 1, -1, 1]]) +# query_L = torch.Tensor([[0, 1, 0, 0], +# [1, 1, 0, 0], +# [1, 0, 0, 1], +# [0, 1, 0, 1]]) +# retrieval_L = torch.Tensor([[1, 0, 0, 1], +# [1, 1, 0, 0], +# [0, 1, 1, 0], +# [0, 0, 1, 0], +# [1, 0, 0, 0], +# [0, 0, 1, 0]]) +# +# map = calc_map_k(qB, rB, query_L, retrieval_L) +# print(map) diff --git a/utils/get_args.py b/utils/get_args.py new file mode 100644 index 0000000..7ffa88b --- /dev/null +++ b/utils/get_args.py @@ -0,0 +1,48 @@ +import argparse + + +def get_args(): + + parser = argparse.ArgumentParser() + + parser.add_argument("--hash-layer", type=str, default="select", help="choice a hash layer [select, linear] to run. select: select mechaism, linear: sign function.") + parser.add_argument("--save-dir", type=str, default="./result/64-bit") + parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.") + parser.add_argument("--pretrained", type=str, default="") + parser.add_argument("--dataset", type=str, default="flickr25k", help="choise from [coco, mirflckr25k, nuswide]") + parser.add_argument("--index-file", type=str, default="index.mat") + parser.add_argument("--caption-file", type=str, default="caption.mat") + parser.add_argument("--label-file", type=str, default="label.mat") + parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]") + parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]") + # parser.add_argument("--test-index-file", type=str, default="./data/test/index.mat") + # parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat") + # parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat") + + parser.add_argument("--output-dim", type=int, default=64) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--max-words", type=int, default=32) + parser.add_argument("--resolution", type=int, default=224) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--query-num", type=int, default=5120) + parser.add_argument("--train-num", type=int, default=10240) + parser.add_argument("--lr-decay-freq", type=int, default=5) + parser.add_argument("--display-step", type=int, default=50) + parser.add_argument("--seed", type=int, default=1814) + + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--lr-decay", type=float, default=0.9) + parser.add_argument("--clip-lr", type=float, default=0.00001) + parser.add_argument("--weight-decay", type=float, default=0.2) + parser.add_argument("--warmup-proportion", type=float, default=0.1, + help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.") + parser.add_argument("--vartheta", type=float, default=0.5, help="the rate of error code.") + parser.add_argument("--sim-threshold", type=float, default=0.1) + + parser.add_argument("--is-train", action="store_true") + + args = parser.parse_args() + + return args + diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..bda3109 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,22 @@ +import os +import logging + +from torch.utils.tensorboard import SummaryWriter + +def get_logger(filename=None): + logger = logging.getLogger('logger') + logger.setLevel(logging.DEBUG) + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO) + if filename is not None: + handler = logging.FileHandler(filename) + handler.setLevel(logging.DEBUG) + handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) + logging.getLogger().addHandler(handler) + return logger + +def get_summary_writer(dirname: str): + + os.makedirs(dirname, exist_ok=True) + return SummaryWriter(log_dir=dirname) \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..771700f --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,229 @@ +import torch +import numpy as np +from typing import Union +import torch.nn as nn +from torch.nn import functional as F + +from sklearn.metrics.pairwise import euclidean_distances + + +def compute_metrics(x): + # 取复值的原因在于cosine的值越大说明越相似,但是需要取的是前N个值,所以取符号变为增函数s + sx = np.sort(-x, axis=1) + d = np.diag(-x) + d = d[:, np.newaxis] + ind = sx - d + ind = np.where(ind == 0) + ind = ind[1] + metrics = {} + metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) + metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) + metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) + metrics['MR'] = np.median(ind) + 1 + metrics["MedianR"] = metrics['MR'] + metrics["MeanR"] = np.mean(ind) + 1 + metrics["cols"] = [int(i) for i in list(ind)] + return metrics + +def encode_hash(a: Union[torch.Tensor, np.ndarray]): + if isinstance(a, torch.Tensor): + hash_a = torch.sign(a) + # where 是吧所有false值转为0 + # hash_a = torch.where(hash_a>0, hash_a, torch.tensor(0)) + return hash_a + else: + hash_a = np.sign(a) + # hash_a = np.where(hash_a > 0, hash_a, 0) + return hash_a + +def calc_neighbor(a: torch.Tensor, b: torch.Tensor): + # print(a.dtype, b.dtype) + return (a.matmul(b.transpose(0, 1)) > 0).float() + + +def euclidean_similarity(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]): + + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + similarity = torch.cdist(a, b, p=2.0) + elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + similarity = euclidean_distances(a, b) + else: + raise ValueError("input value must in [torch.Tensor, numpy.ndarray], but it is %s, %s"%(type(a), type(b))) + return similarity + + +def euclidean_dist_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor): + """ + calculate euclidean distance as inner product + :param tensor1: a tensor with shape (a, c) + :param tensor2: a tensor with shape (b, c) + :return: the euclidean distance matrix which each point is the distance between a row in tensor1 and a row in tensor2. + """ + dim1 = tensor1.shape[0] + dim2 = tensor2.shape[0] + multi = torch.matmul(tensor1, tensor2.t()) + a2 = torch.sum(torch.pow(tensor1, 2), dim=1, keepdim=True).expand(dim1, dim2) + b2 = torch.sum(torch.pow(tensor2, 2), dim=1, keepdim=True).t().expand(dim1, dim2) + dist = torch.sqrt(a2 + b2 - 2 * multi) + return dist + + +def cosine_similarity(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]): + + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + a = a / a.norm(dim=-1, keepdim=True) if len(torch.where(a != 0)[0]) > 0 else a + b = b / b.norm(dim=-1, keepdim=True) if len(torch.where(b != 0)[0]) > 0 else b + return torch.matmul(a, b.t()) + elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + a = a / np.linalg.norm(a, axis=-1, keepdims=True) if len(np.where(a != 0)[0]) > 0 else a + b = b / np.linalg.norm(b, axis=-1, keepdims=True) if len(np.where(b != 0)[0]) > 0 else b + return np.matmul(a, b.T) + else: + raise ValueError("input value must in [torch.Tensor, numpy.ndarray], but it is %s, %s"%(type(a), type(b))) + +def calc_map_k(qB, rB, query_L, retrieval_L, k=None, rank=0): + # qB: {-1,+1}^{mxq} + # rB: {-1,+1}^{nxq} + # query_L: {0,1}^{mxl} + # retrieval_L: {0,1}^{nxl} + num_query = query_L.shape[0] + qB = torch.sign(qB) + rB = torch.sign(rB) + map = 0 + if k is None: + k = retrieval_L.shape[0] + # print("query num:", num_query) + for iter in range(num_query): + q_L = query_L[iter] + if len(q_L.shape) < 2: + q_L = q_L.unsqueeze(0) # [1, hash length] + gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32) + tsum = torch.sum(gnd) + if tsum == 0: + continue + hamm = calcHammingDist(qB[iter, :], rB) + _, ind = torch.sort(hamm) + ind.squeeze_() + gnd = gnd[ind] + total = min(k, int(tsum)) + count = torch.arange(1, total + 1).type(torch.float32) + tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float32) + 1.0 + if tindex.is_cuda: + count = count.to(rank) + map = map + torch.mean(count / tindex) + map = map / num_query + return map + +def softmax_hash(code: Union[torch.Tensor, np.ndarray], dim_alph=0.25): + + device = None + if isinstance(code, torch.Tensor): + device = code.device + if code.is_cuda: + code = code.detach.cpu().numpy() + else: + code = code.cpu().numpy() + + softmax_code = np.exp(code) / np.exp(code).sum(axis=-1, keepdims=True) + # print("max softmax_code:", softmax_code.max()) + # print("min softmax_code:", softmax_code.min()) + # print(1 / int(dim_alph * softmax_code.shape[-1])) + # print(np.sum(softmax_code[0])) + hash_code = np.where(softmax_code >= 1 / int(dim_alph * softmax_code.shape[-1]), softmax_code, -1) + hash_code = np.where(hash_code <= 1 / int(dim_alph * softmax_code.shape[-1]), hash_code, 1) + # print(hash_code[0]) + # print("hash code sum:", hash_code.sum()) + # print(len(np.where(hash_code == 1.0)[0])) + # print(len(np.where(hash_code == -1.0)[0])) + if device is not None: + hash_code = torch.from_numpy(hash_code).to(device) + return hash_code + +# def calcHammingDist(B1, B2): +# result = np.zeros((B1.shape[0], B2.shape[0])) +# for i, data in enumerate(B1): +# result[i] = np.sum(np.where((data + B2) != 2, data + B2, 0), axis=-1) +# return result + +def calcHammingDist(B1, B2): + + if len(B1.shape) < 2: + B1.view(1, -1) + if len(B2.shape) < 2: + B2.view(1, -1) + q = B2.shape[1] + if isinstance(B1, torch.Tensor): + distH = 0.5 * (q - torch.matmul(B1, B2.t())) + elif isinstance(B1, np.ndarray): + distH = 0.5 * (q - np.matmul(B1, B2.transpose())) + else: + raise ValueError("B1, B2 must in [torch.Tensor, np.ndarray]") + return distH + +def compute_hash_similarity(visual_embed, text_embed, use_softmax_hash=False, alph=0.25): + # hamming distance的值越大说明越不相似 + hash_visual = encode_hash(visual_embed) if not use_softmax_hash else softmax_hash(visual_embed, alph) + hash_text = encode_hash(text_embed) if not use_softmax_hash else softmax_hash(text_embed, alph) + vt_similarity = calcHammingDist(hash_visual, hash_text) + # print(vt_similarity[0]) + tv_similarity = calcHammingDist(hash_text, hash_visual) + return vt_similarity, tv_similarity + +class CrossEn(nn.Module): + def __init__(self, mode="cosine"): + super(CrossEn, self).__init__() + # if mode == "euclidean": + # self.compute_func = F.softmax + # else: + # self.compute_func = F.log_softmax + self.mode = mode + + def forward(self, sim_matrix): + # if self.mode == "cosine": + # logpt = F.log_softmax(sim_matrix, dim=-1) + # logpt = torch.diag(logpt) + # nce_loss = -logpt + # sim_loss = nce_loss.mean() + # elif self.mode == "euclidean": + # logpt = F.softmax(sim_matrix, dim=-1) + # logpt = torch.diag(sim_matrix) + # sim_loss = logpt.mean() + # else: + # raise ValueError("mode paramater is not support.[cosine, euclidean]") + if self.mode == "euclidean": + sim_matrix = -sim_matrix + logpt = F.log_softmax(sim_matrix, dim=-1) + logpt = torch.diag(logpt) + nce_loss = -logpt + sim_loss = nce_loss.mean() + return sim_loss + +class CrossEn_mean(nn.Module): + def __init__(self, mode="cosine"): + super(CrossEn_mean, self).__init__() + # if mode == "euclidean": + # self.compute_func = F.softmax + # else: + # self.compute_func = F.log_softmax + self.mode = mode + + def forward(self, sim_matrix): + # if self.mode == "cosine": + # logpt = F.log_softmax(sim_matrix, dim=-1) + # logpt = torch.diag(logpt) + # nce_loss = -logpt + # sim_loss = nce_loss.mean() + # elif self.mode == "euclidean": + # logpt = F.softmax(sim_matrix, dim=-1) + # logpt = torch.diag(sim_matrix) + # sim_loss = logpt.mean() + # else: + # raise ValueError("mode paramater is not support.[cosine, euclidean]") + # if self.mode == "euclidean": + # sim_matrix = -sim_matrix + # print(sim_matrix.max(), sim_matrix.min()) + # logpt = F.log_softmax(sim_matrix, dim=-1) + # logpt = torch.diag(logpt) + # print(logpt.max()) + sim_loss = sim_matrix.mean() + return sim_loss