This commit is contained in:
kalenforn 2022-07-01 17:35:22 +08:00
commit 9e6030dbf9
27 changed files with 2890 additions and 0 deletions

78
READEM.md Normal file
View File

@ -0,0 +1,78 @@
# Differentiable Cross Modal Hashing via Multimodal Transformers
## Framework
The main architecture of our method.
![framework](./data/structure.jpg)
We propose a selecting mechanism to generate hash codes that transforms the discrete space into a continuous space. Each hash code is encoded as a $2D$ vector.
![hash](./data/method.jpg)
## Dependencies
Our code is written in Python; you need to install the following packages to run it:
- pytorch 1.9.1
- sklearn
- tqdm
- pillow
## Training
### Processing dataset
Before training, you need to download the original data from [coco](https://cocodataset.org/#download)(include 2017 train,val and annotations), [nuswide](https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html)(include all), [mirflickr25k](https://www.kaggle.com/datasets/paulrohan2020/mirflickr25k)(include mirflickr25k and mirflickr25k_annotations_v080),
then use the "data/make_XXX.py" to generate .mat file
For example:
> cd COCO_DIR # include train val images and annotations files
>
> make mat
>
> cp DCMHT/data/make_coco.py mat
>
> python make_coco.py --coco-dir ../ --save-dir ./
After all mat files have been generated, the `dataset` dir will look like this:
~~~
dataset
├── base.py
├── __init__.py
├── dataloader.py
├── coco
│   ├── caption.mat
│   ├── index.mat
│   └── label.mat
├── flickr25k
│   ├── caption.mat
│   ├── index.mat
│   └── label.mat
└── nuswide
    ├── caption.txt # Notice! It is a txt file!
    ├── index.mat
    └── label.mat
~~~
### Download CLIP pretrained model
Pretrained model URLs can be found around line 30 of [CLIP/clip/clip.py](https://github.com/openai/CLIP/blob/main/clip/clip.py). This code is based on "ViT-B/32".
You should copy ViT-B-32.pt to this dir.
### Start
After the dataset has been prepared, we can run the following command to train.
> python main.py --is-train --hash-layer select --dataset coco --caption-file caption.mat --index-file index.mat --label-file label.mat --similarity-function euclidean --loss-type l2 --vartheta 0.75 --lr 0.0001 --output-dim 64 --save-dir ./result/coco/64 --clip-path ./ViT-B-32.pt --batch-size 256
## Result
![result](./data/result.png)
## Acknowledgements
[CLIP](https://github.com/openai/CLIP)
[SSAH](https://github.com/lelan-li/SSAH)
[GCH](https://github.com/DeXie0808/GCH)
[AGAH](https://github.com/WendellGul/AGAH)
[DADH](https://github.com/Zjut-MultimediaPlus/DADH)
[deep-cross-modal-hashing](https://github.com/WangGodder/deep-cross-modal-hashing)

BIN
data/method.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

BIN
data/result.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 271 KiB

BIN
data/structure.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 190 KiB

0
dataset/__init__.py Normal file
View File

102
dataset/base.py Normal file
View File

@ -0,0 +1,102 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from torch.utils.data import Dataset
import torch
import random
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
class BaseDataset(Dataset):
    """Paired image/text/label dataset for cross-modal hashing.

    ``captions`` / ``indexs`` / ``labels`` are parallel sequences: position
    ``i`` holds the caption list, the image reference, and the multi-hot label
    of sample ``i``.  When ``npy`` is True, ``indexs`` holds raw image arrays
    instead of file paths.
    """

    def __init__(self,
                 captions: dict,
                 indexs: dict,
                 labels: dict,
                 is_train=True,
                 tokenizer=None,
                 maxWords=32,
                 imageResolution=224,
                 npy=False):
        self.captions = captions
        self.indexs = indexs
        self.labels = labels
        self.npy = npy
        self.maxWords = maxWords
        # Bug fix: the original used ``tokenizer=Tokenizer()`` as the default,
        # which instantiated one tokenizer at class-definition time and shared
        # it across every dataset; build it lazily instead (same effective
        # behavior for callers that pass nothing).
        self.tokenizer = Tokenizer() if tokenizer is None else tokenizer
        # Training: resize shorter side then center-crop; evaluation: plain
        # deterministic square resize.  Normalization uses CLIP statistics.
        self.transform = Compose([
            Resize(imageResolution, interpolation=Image.BICUBIC),
            CenterCrop(imageResolution),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]) if is_train else Compose([
            Resize((imageResolution, imageResolution), interpolation=Image.BICUBIC),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
                              "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}
        self.__length = len(self.indexs)

    def __len__(self):
        return self.__length

    def _load_image(self, index: int) -> torch.Tensor:
        """Load and preprocess the image of sample *index*."""
        if not self.npy:
            image_path = self.indexs[index].strip()
            image = Image.open(image_path).convert("RGB")
        else:
            image = Image.fromarray(self.indexs[index]).convert("RGB")
        image = self.transform(image)
        return image

    def _load_text(self, index: int):
        """Tokenize one randomly chosen caption of sample *index*.

        The token sequence is wrapped in CLS/SEP tokens, truncated to
        ``maxWords`` and zero-padded (0 acts as the pad id).
        """
        captions = self.captions[index]
        use_cap = captions[random.randint(0, len(captions) - 1)]
        words = self.tokenizer.tokenize(use_cap)
        words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
        # reserve one slot for the trailing SEP token
        total_length_with_CLS = self.maxWords - 1
        if len(words) > total_length_with_CLS:
            words = words[:total_length_with_CLS]
        words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
        caption = self.tokenizer.convert_tokens_to_ids(words)
        while len(caption) < self.maxWords:
            caption.append(0)
        caption = torch.tensor(caption)
        return caption

    def _load_label(self, index: int) -> torch.Tensor:
        label = self.labels[index]
        label = torch.from_numpy(label)
        return label

    def get_all_label(self):
        """Stack every sample's multi-hot label into one (N, num_classes) int64 tensor."""
        labels = torch.zeros([self.__length, len(self.labels[0])], dtype=torch.int64)
        for i, item in enumerate(self.labels):
            labels[i] = torch.from_numpy(item)
        return labels

    def __getitem__(self, index):
        image = self._load_image(index)
        caption = self._load_text(index)
        label = self._load_label(index)
        return image, caption, label, index

68
dataset/dataloader.py Normal file
View File

@ -0,0 +1,68 @@
from .base import BaseDataset
import os
import numpy as np
import scipy.io as scio
def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
    """Split three parallel arrays into (query, train, retrieval) partitions.

    A single random permutation is drawn; the first ``query_num`` entries form
    the query set, the next ``train_num`` the training set, and everything
    after the query set is the retrieval set (so the retrieval set
    deliberately contains the training samples — standard retrieval protocol).
    Returns three tuples, each ordered (query, train, retrieval).
    """
    np.random.seed(seed=seed)
    shuffled = np.random.permutation(range(len(indexs)))
    parts = (
        shuffled[: query_num],                        # query
        shuffled[query_num: query_num + train_num],   # train
        shuffled[query_num:],                         # retrieval (superset of train)
    )
    split_indexs = tuple(indexs[p] for p in parts)
    split_captions = tuple(captions[p] for p in parts)
    split_labels = tuple(labels[p] for p in parts)
    return split_indexs, split_captions, split_labels
def dataloader(captionFile: str,
               indexFile: str,
               labelFile: str,
               maxWords=32,
               imageResolution=224,
               query_num=5000,
               train_num=10000,
               seed=None,
               npy=False):
    """Load caption/index/label files and build (train, query, retrieval) datasets.

    Captions come from a .mat file (key "caption") or a plain .txt file with
    one caption per line; indexes from a .mat file (key "index") or a .npy
    array when ``npy`` is True; labels from a .mat file (key "category").
    """
    if captionFile.endswith("mat"):
        loaded = scio.loadmat(captionFile)["caption"]
        # a (1, N) matrix is flattened to its single row
        captions = loaded[0] if loaded.shape[0] == 1 else loaded
    elif captionFile.endswith("txt"):
        with open(captionFile, "r") as f:
            captions = np.asarray([[line.strip()] for line in f.readlines()])
    else:
        raise ValueError("the format of 'captionFile' doesn't support, only support [txt, mat] format.")

    indexs = np.load(indexFile, allow_pickle=True) if npy else scio.loadmat(indexFile)["index"]
    labels = scio.loadmat(labelFile)["category"]

    split_indexs, split_captions, split_labels = split_data(
        captions, indexs, labels, query_num=query_num, train_num=train_num, seed=seed)

    def build(slot, training):
        # slot order inside the split tuples: 0 = query, 1 = train, 2 = retrieval
        return BaseDataset(captions=split_captions[slot], indexs=split_indexs[slot],
                           labels=split_labels[slot], maxWords=maxWords,
                           imageResolution=imageResolution, is_train=training, npy=npy)

    train_data = build(1, True)
    query_data = build(0, False)
    retrieval_data = build(2, False)
    return train_data, query_data, retrieval_data

178
dataset/make_coco.py Normal file
View File

@ -0,0 +1,178 @@
import os
import numpy as np
def make_index(jsonData: dict, indexDict: dict):
    """Group COCO-style records by a key field.

    jsonData: parsed COCO annotation json.
    indexDict: {jsonData key: [group-by field, collected field]}.
    Returns one {group_value: [collected values]} dict per entry of
    indexDict, in indexDict's iteration order.
    """
    result = []
    for name, (key_field, val_field) in indexDict.items():
        grouped = {}
        for record in jsonData[name]:
            # setdefault replaces the original manual membership test +
            # append — one lookup instead of two, identical grouping.
            grouped.setdefault(record[key_field], []).append(record[val_field])
        result.append(grouped)
    return result
def check_file_exist(indexDict: dict, file_path: str):
    """Resolve each entry's file name against *file_path*; drop missing files.

    On entry ``indexDict`` maps id -> [file_name, ...]; on return it maps
    id -> absolute path, with entries whose file does not exist removed
    (each removal is reported via print).

    Bug fix: the original popped a missing entry and then unconditionally
    re-inserted it on the next line, so missing files were never actually
    dropped.
    """
    for key in list(indexDict.keys()):
        full = os.path.join(file_path, indexDict[key][0])
        if not os.path.exists(full):
            print(key, indexDict[key])
            indexDict.pop(key)
        else:
            indexDict[key] = full
    return indexDict
def chage_categories2numpy(category_ids: dict, data: dict):
    """Convert each value of *data* (a list of category ids) into a multi-hot
    numpy vector whose columns are given by *category_ids* (id -> column)."""
    num_classes = len(category_ids)
    for key, id_list in data.items():
        vec = np.zeros(num_classes, dtype=int)
        for cid in id_list:
            vec[category_ids[cid]] = 1
        data[key] = vec
    return data
def get_all_use_key(categoryDict: dict):
    """Return every key of *categoryDict* as a list (the ids that carry labels)."""
    return [key for key in categoryDict]
def remove_not_use(data: dict, used_key: list):
    """Drop every entry of *data* whose key is not in *used_key* (in place).

    *used_key* is materialized into a set once, so the filter runs in
    O(len(data)) instead of the original O(len(data) * len(used_key))
    list-membership scan.  Returns the (mutated) dict.
    """
    keep = set(used_key)
    for key in list(data.keys()):
        if key not in keep:
            data.pop(key)
    return data
def merge_to_list(data: dict):
    """Return *data*'s values ordered by ascending key."""
    return [data[key] for key in sorted(data)]
if __name__ == "__main__":
    # Build index/caption/label .mat files for MS-COCO (train2017 + val2017):
    # group captions and category ids per image, keep only images that exist
    # on disk AND have instance labels, then save key-sorted parallel lists.
    import json
    import scipy.io as scio
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--coco-dir", default="./", type=str, help="the coco dataset dir")
    parser.add_argument("--save-dir", default="./", type=str, help="mat file saved dir")
    args = parser.parse_args()
    PATH = args.coco_dir
    # ---- train2017 captions -------------------------------------------------
    jsonFile = os.path.join(PATH, "annotations", "captions_train2017.json")
    with open(jsonFile, "r") as f:
        jsonData = json.load(f)
    indexDict = {"images": ["id", "file_name"], "annotations": ["image_id", "caption"]}
    result = make_index(jsonData, indexDict)
    indexDict_, captionDict = result
    # resolve file names to paths; drop images missing from train2017/
    indexDict_ = check_file_exist(indexDict_, os.path.join(PATH, "train2017"))
    print("caption:", len(indexDict_), len(captionDict))
    # print_result = list(indexDict.keys())
    # print_result.sort()
    # print(print_result)
    # indexList = merge_to_list(indexDict_)
    # captionList = merge_to_list(captionDict)
    # print(indexDict[565962], indexList[4864])
    # print(captionDict[565962], captionList[4864])
    # print(result)
    # ---- train2017 instance labels -----------------------------------------
    jsonFile = os.path.join(PATH, "annotations", "instances_train2017.json")
    with open(jsonFile, "r") as f:
        jsonData = json.load(f)
    # map COCO category id -> dense column index
    categroy_ids = {}
    for i, item in enumerate(jsonData['categories']):
        categroy_ids.update({item['id']: i})
    # NOTE: dict order matters here — result[0] is the "annotations" grouping,
    # result[1] the "images" grouping.
    indexDict = {"annotations": ["image_id", "category_id"], "images": ["id", "file_name"]}
    result = make_index(jsonData, indexDict)
    categoryDict = result[0]
    cateIndexDict = result[1]
    # cateIndexList = merge_to_list(cateIndexDict)
    # print(categoryDict[91654])
    categoryDict = chage_categories2numpy(categroy_ids, categoryDict)
    # print(categoryDict[91654])
    # categoryList = merge_to_list(categoryDict)
    # print(categoryDict[91654], categoryList[780])
    # print(indexList[100], cateIndexList[100])
    # print("category:", len(categoryDict), len(cateIndexList))
    used_key = get_all_use_key(categoryDict)
    # unify the indexes: keep only image ids that have instance labels
    indexDict_ = remove_not_use(indexDict_, used_key)
    captionDict = remove_not_use(captionDict, used_key)
    categoryIndexDict = remove_not_use(cateIndexDict, used_key)
    categoryDict = remove_not_use(categoryDict, used_key)
    # convert to key-sorted parallel lists
    indexList = merge_to_list(indexDict_)
    captionList = merge_to_list(captionDict)
    categoryIndexList = merge_to_list(categoryIndexDict)
    categoryList = merge_to_list(categoryDict)
    print("result", len(indexDict_), len(categoryDict))
    print("category:", len(categoryDict), len(categoryIndexList))
    # sanity check: caption and category lists must be aligned per image
    for i in range(len(indexList)):
        if indexList[i] != categoryIndexList[i]:
            print("Not the same:", i, indexList[i], categoryIndexList[i])
    # ---- val2017: same pipeline --------------------------------------------
    val_jsonFile = os.path.join(PATH, "annotations", "captions_val2017.json")
    with open(val_jsonFile, "r") as f:
        jsonData = json.load(f)
    indexDict = {"images": ["id", "file_name"], "annotations": ["image_id", "caption"]}
    result = make_index(jsonData, indexDict)
    val_indexDict = result[0]
    val_captionDict = result[1]
    val_indexDict = check_file_exist(val_indexDict, os.path.join(PATH, "val2017"))
    jsonFile = os.path.join(PATH, "annotations", "instances_val2017.json")
    with open(jsonFile, "r") as f:
        jsonData = json.load(f)
    categroy_ids = {}
    for i, item in enumerate(jsonData['categories']):
        categroy_ids.update({item['id']: i})
    indexDict = {"annotations": ["image_id", "category_id"], "images": ["id", "file_name"]}
    result = make_index(jsonData, indexDict)
    val_categoryDict = result[0]
    val_categoryIndexDict = result[1]
    val_categoryDict = chage_categories2numpy(categroy_ids, val_categoryDict)
    used_key = get_all_use_key(val_categoryDict)
    val_indexDict = remove_not_use(val_indexDict, used_key)
    val_captionDict = remove_not_use(val_captionDict, used_key)
    val_categoryIndexDict = remove_not_use(val_categoryIndexDict, used_key)
    val_categoryDict = remove_not_use(val_categoryDict, used_key)
    val_indexList = merge_to_list(val_indexDict)
    val_captionList = merge_to_list(val_captionDict)
    val_categoryIndexList = merge_to_list(val_categoryIndexDict)
    val_categoryList = merge_to_list(val_categoryDict)
    # append val split after the train split
    indexList.extend(val_indexList)
    captionList.extend(val_captionList)
    categoryIndexList.extend(val_categoryIndexList)
    categoryList.extend(val_categoryList)
    print(len(indexList), len(captionList), len(categoryIndexList))
    # ---- save as .mat files -------------------------------------------------
    indexs = {"index": indexList}
    captions = {"caption": captionList}
    categorys = {"category": categoryList}
    scio.savemat(os.path.join(args.save_dir, "index.mat"), indexs)
    scio.savemat(os.path.join(args.save_dir, "caption.mat"), captions)
    scio.savemat(os.path.join(args.save_dir, "label.mat"), categorys)

View File

@ -0,0 +1,81 @@
import os
import scipy.io as scio
import numpy as np

# Build index/caption/label .mat files for MIRFLICKR-25K.
# Expects root_dir to contain "mirflickr25k_annotations_v080" and "mirflickr".
# Usage (from the download dir):
#   mkdir mat
#   mv make_mirflickr25k.py mat
#   python make_mirflickr25k.py
root_dir = "PATH/TO/YOUR/DOWNLOAD/DIR/"
# Bug fix: os.path.join discards every preceding component when a later one
# starts with "/", so the original leading slashes silently dropped root_dir.
file_path = os.path.join(root_dir, "mirflickr25k_annotations_v080")
file_list = os.listdir(file_path)
# keep only the strict annotation files (drop the *_r1 relaxed files and README)
file_list = [item for item in file_list if "_r1" not in item and "README" not in item]
print("class num:", len(file_list))
class_index = {}
for i, item in enumerate(file_list):
    class_index.update({item: i})

# image id -> multi-hot label vector, accumulated over all annotation files
label_dict = {}
for path_id in file_list:
    path = os.path.join(file_path, path_id)
    with open(path, "r") as f:
        for item in f:
            item = item.strip()
            if item not in label_dict:
                label = np.zeros(len(file_list))
                label[class_index[path_id]] = 1
                label_dict.update({item: label})
            else:
                label_dict[item][class_index[path_id]] = 1
print("create label:", len(label_dict))

# sort ids so index/caption/label lists stay aligned
keys = list(label_dict.keys())
keys.sort()
labels = []
for key in keys:
    labels.append(label_dict[key])
print("labels created:", len(labels))
labels = {"category": labels}

PATH = os.path.join(root_dir, "mirflickr")
index = [os.path.join(PATH, "im" + item + ".jpg") for item in keys]
print("index created:", len(index))
index = {"index": index}

# one caption per image: the space-joined tag list
captions_path = os.path.join(root_dir, "mirflickr/meta/tags")
captions_list = os.listdir(captions_path)
captions_dict = {}
for item in captions_list:
    id_ = item.split(".")[0].replace("tags", "")
    caption = ""
    with open(os.path.join(captions_path, item), "r") as f:
        for word in f.readlines():
            caption += word.strip() + " "
    caption = caption.strip()
    captions_dict.update({id_: caption})
captions = []
for item in keys:
    captions.append([captions_dict[item]])
print("captions created:", len(captions))
captions = {"caption": captions}

scio.savemat(os.path.join(root_dir, "mat/index.mat"), index)
scio.savemat(os.path.join(root_dir, "mat/caption.mat"), captions)
scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)

107
dataset/make_nuswide.py Normal file
View File

@ -0,0 +1,107 @@
import os
import scipy.io as scio
import numpy as np

# Build index/label .mat files plus caption.txt for NUS-WIDE.
# Usage (from the download dir):
#   mkdir mat
#   mv make_nuswide.py mat
#   python make_nuswide.py
root_dir = "PATH/TO/YOUR/DOWNLOAD/DIR/"
# Bug fix: the original passed absolute-looking components ("/Low-Level-...")
# to os.path.join, which made it discard root_dir entirely.
imageListFile = os.path.join(root_dir, "Low-Level-Features/ImageList/Imagelist.txt")
labelPath = os.path.join(root_dir, "nuswide/Groundtruth/AllLabels")
textFile = os.path.join(root_dir, "Low-Level-Features/NUS_WID_Tags/All_Tags.txt")
classIndexFile = os.path.join(root_dir, "Low-Level-Features/Concepts81.txt")
# you can use the image urls to download images
imagePath = os.path.join(root_dir, "nuswide/Flickr")

with open(imageListFile, "r") as f:
    indexs = f.readlines()
# the list uses Windows separators; normalize and prefix the image dir
indexs = [os.path.join(imagePath, item.strip().replace("\\", "/")) for item in indexs]
print("indexs length:", len(indexs))

#class_index = {}
#with open(classIndexFile, "r") as f:
#    data = f.readlines()
#
#for i, item in enumerate(data):
#    class_index.update({item.strip(): i})

# caption per image: all tags of the line except the leading image id
captions = []
with open(textFile, "r") as f:
    for line in f:
        if len(line.strip()) == 0:
            print("some line empty!")
            continue
        caption = line.split()[1:]
        caption = " ".join(caption).strip()
        if len(caption) == 0:
            # placeholder text for images with no tags at all
            caption = "123456"
        captions.append(caption)
print("captions length:", len(captions))

#labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8)
# label_lists = os.listdir(labelPath)
# only the label files listed in used_label.txt are kept
with open(os.path.join(root_dir, "nuswide/Groundtruth/used_label.txt")) as f:
    label_lists = f.readlines()
label_lists = [item.strip() for item in label_lists]
class_index = {}
for i, item in enumerate(label_lists):
    class_index.update({item: i})
labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8)
for item in label_lists:
    path = os.path.join(labelPath, item)
    class_label = item  # .split(".")[0].split("_")[-1]
    with open(path, "r") as f:
        data = f.readlines()
    # each ground-truth file holds one 0/1 flag per image, in list order
    for i, val in enumerate(data):
        labels[i][class_index[class_label]] = 1 if val.strip() == "1" else 0
print("labels sum:", labels.sum())

# drop the samples listed in not_used_id.txt from all three collections
not_used_id = []
with open(os.path.join(root_dir, "nuswide/Groundtruth/not_used_id.txt")) as f:
    not_used_id = f.readlines()
not_used_id = [int(item.strip()) for item in not_used_id]
# for item in not_used_id:
#     indexs.pop(item)
#     captions.pop(item)
#     labels = np.delete(labels, item, 0)
ind = list(range(len(indexs)))
for item in not_used_id:
    ind.remove(item)
    indexs[item] = ""
    captions[item] = ""
indexs = [item for item in indexs if item != ""]
captions = [item for item in captions if item != ""]
ind = np.asarray(ind)
labels = labels[ind]
# ind = range(len(indexs))
print("indexs length:", len(indexs))
print("captions length:", len(captions))
print("labels shape:", labels.shape)
indexs = {"index": indexs}
captions = {"caption": captions}
labels = {"category": labels}
scio.savemat(os.path.join(root_dir, "mat/index.mat"), indexs)
# scio.savemat("caption.mat", captions)
scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)
# captions go to a plain text file (one caption per line), not a .mat
captions = [item + "\n" for item in captions["caption"]]
with open(os.path.join(root_dir, "mat/caption.txt"), "w") as f:
    f.writelines(captions)
print("finished!")

8
main.py Normal file
View File

@ -0,0 +1,8 @@
from train.hash_train import Trainer

if __name__ == "__main__":
    # Entry point: Trainer handles everything (argument parsing, setup and
    # the training/evaluation loop) from its constructor.
    Trainer()

1
model/__init__.py Executable file
View File

@ -0,0 +1 @@
from .clip import *

Binary file not shown.

224
model/clip.py Executable file
View File

@ -0,0 +1,224 @@
import hashlib
import os
import urllib
import urllib.request  # `import urllib` alone does not guarantee the request submodule is loaded
import warnings
from typing import Any, Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # older torchvision releases: fall back to the PIL resampling constant
    BICUBIC = Image.BICUBIC


def _version_tuple(version: str) -> tuple:
    """Parse the leading numeric components of a version string.

    Local-version suffixes such as "1.13.0+cu117" are stripped; non-numeric
    components (e.g. "a0") are ignored.
    """
    return tuple(int(part) for part in version.split("+")[0].split(".")[:3] if part.isdigit())


# Bug fix: the original compared the split version STRINGS lexicographically,
# so e.g. "1.13" sorted before "1.7" and the warning fired on new PyTorch
# releases; compare numeric tuples instead.
if _version_tuple(torch.__version__) < (1, 7, 1):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
# Model name -> download URL.  The second-to-last path segment of each URL is
# the file's expected SHA-256 checksum, which _download verifies.
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def _transform(n_px):
    # Standard CLIP preprocessing: resize the shorter side to n_px,
    # center-crop to n_px x n_px, force RGB, convert to tensor, and normalize
    # with the CLIP training statistics.
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
def available_models() -> List[str]:
    """Names of the CLIP models that can be downloaded, in registry order."""
    return [name for name in _MODELS]
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # rebuild a fresh eager-mode model from the weights
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names: JIT graphs have the trace-time device baked in
    # as constants, so rewrite them to the requested device
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # Collect this module's graph(s); accessing .graph can raise
        # RuntimeError for modules without one.
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)
        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU (the archive stores fp16 weights)
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)
            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:  # 5 == torch half dtype constant in the graph
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)
        model.float()

    return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """Tokenize one string or a list of strings for CLIP.

    Each text is wrapped in start/end-of-text tokens and written into a row of
    a zero-padded ``[len(texts), context_length]`` LongTensor.  If a text is
    longer than ``context_length``, it is truncated when ``truncate`` is True
    (keeping the end-of-text token as the last entry), otherwise a
    RuntimeError is raised.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    result = torch.zeros(len(texts), context_length, dtype=torch.long)

    for row, text in enumerate(texts):
        tokens = [sot_token] + _tokenizer.encode(text) + [eot_token]
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[row]} is too long for context length {context_length}")
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[row, :len(tokens)] = torch.tensor(tokens)

    return result

161
model/hash_model.py Normal file
View File

@ -0,0 +1,161 @@
import os
import torch
import logging
import torch.nn as nn
import numpy as np
from typing import Union
from model.model import build_model
from utils import get_logger, get_summary_writer
def weights_init_kaiming(m):
    """Kaiming-initialize Linear/Conv weights and reset BatchNorm affine params.

    Intended for use with ``module.apply(weights_init_kaiming)``; modules of
    other types are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_uniform_(m.weight, mode='fan_out')
        # Bug fix: the original dereferenced m.bias unconditionally and
        # crashed on Linear(..., bias=False); guard like the Conv branch does.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        # only affine BatchNorms carry weight/bias parameters
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
class LinearHash(nn.Module):
    """Plain hashing head: features -> tanh(dropout(fc(features)))."""

    def __init__(self, inputDim=2048, outputDim=64):
        super(LinearHash, self).__init__()
        # attribute names kept for checkpoint compatibility
        self.fc = nn.Linear(inputDim, outputDim)
        self.fc.apply(weights_init_kaiming)
        self.drop_out = nn.Dropout(p=0.2)

    def forward(self, data):
        projected = self.drop_out(self.fc(data))
        return torch.tanh(projected)
class HashLayer(nn.Module):
    """Selecting-mechanism head: one 2-way softmax per output bit.

    A shared fc + ReLU produces a LINEAR_EMBED-dim embedding; each of the
    ``outputDim`` small heads then emits a (batch, 2) probability pair.
    """
    LINEAR_EMBED = 128
    SIGMOID_ALPH = 10

    def __init__(self, inputDim=2048, outputDim=64):
        super(HashLayer, self).__init__()
        # attribute names kept for checkpoint compatibility
        self.fc = nn.Linear(inputDim, self.LINEAR_EMBED)
        self.fc.apply(weights_init_kaiming)
        self.hash_list = nn.ModuleList(
            [nn.Linear(self.LINEAR_EMBED, 2) for _ in range(outputDim)])
        for head in self.hash_list:
            head.apply(weights_init_kaiming)

    def forward(self, data):
        shared = torch.relu(self.fc(data))
        # one (batch, 2) softmax pair per bit
        return [torch.softmax(head(shared), dim=-1) for head in self.hash_list]
class HashLayer_easy_logic(nn.Module):
    """Vectorized variant of HashLayer: one fc produces all bit-pairs at once.

    Bug fixes vs. the original: ``super()`` named HashLayer instead of this
    class, and ``__init__`` iterated a non-existent ``self.hash_list``, so the
    class could never be instantiated.
    """
    LINEAR_EMBED = 128
    SIGMOID_ALPH = 10

    def __init__(self, inputDim=2048, outputDim=64):
        super(HashLayer_easy_logic, self).__init__()
        self.bit = outputDim
        # one linear layer emits 2 logits per bit
        self.fc = nn.Linear(inputDim, outputDim * 2)
        self.fc.apply(weights_init_kaiming)

    def forward(self, data):
        logits = self.fc(data)
        # (batch, bit, 2): softmax over each bit's logit pair
        pairs = logits.view(logits.shape[0], self.bit, 2)
        return torch.softmax(pairs, dim=-1)
class DCMHT(nn.Module):
    """Differentiable Cross-Modal Hashing Transformer.

    Wraps a CLIP backbone with one hashing head per modality.  ``linear=True``
    selects the plain tanh head (LinearHash); otherwise the per-bit softmax
    head (HashLayer) is used.
    """

    def __init__(self,
                 outputDim=64,
                 clipPath="./ViT-B-32.pt",
                 writer=None,
                 saveDir="./result/log",
                 logger: logging.Logger=None,
                 is_train=True,
                 linear=False):
        super(DCMHT, self).__init__()
        os.makedirs(saveDir, exist_ok=True)
        self.logger = logger if logger is not None else get_logger(os.path.join(saveDir, "train.log" if is_train else "test.log"))
        # NOTE(review): when writer is None OR is_train is False, a new summary
        # writer is created — confirm that is intended for evaluation runs.
        self.writer = writer if writer is not None and is_train else get_summary_writer(os.path.join(saveDir, "tensorboard"))
        # embedDim is CLIP's joint embedding width (text_projection.shape[1])
        embedDim, self.clip = self.load_clip(clipPath)
        # if is_train:
        # self.clip.eval()
        # print("start freezen")
        # self.freezen()
        self.image_hash = LinearHash(inputDim=embedDim, outputDim=outputDim) if linear else HashLayer(inputDim=embedDim, outputDim=outputDim)
        self.text_hash = LinearHash(inputDim=embedDim, outputDim=outputDim) if linear else HashLayer(inputDim=embedDim, outputDim=outputDim)
        # print(self.image_hash)
        # print(self.text_hash)

    def freezen(self):
        # Freeze the lower CLIP layers; the final layer norms/projections, the
        # transformer blocks with index >= 12 and conv2.* stay trainable.
        for name, param in self.clip.named_parameters():
            # print(name)
            if name.find("ln_final.") == 0 or name.find("text_projection") == 0 or name.find("logit_scale") == 0 \
                    or name.find("visual.ln_post.") == 0 or name.find("visual.proj") == 0:
                # print("1")
                continue
            elif name.find("visual.transformer.resblocks.") == 0 or name.find("transformer.resblocks.") == 0:
                layer_num = int(name.split(".resblocks.")[1].split(".")[0])
                if layer_num >= 12:
                    # print("2")
                    continue
            if name.find("conv2.") == 0:
                # print("3")
                continue
            else:
                # parameters below the freeze boundary are frozen
                param.requires_grad = False

    def load_clip(self, clipPath: str) -> tuple:
        """Load CLIP weights from a JIT archive or a plain state dict.

        Returns (embedding dim, rebuilt CLIP model).
        """
        try:
            model = torch.jit.load(clipPath, map_location="cpu").eval()
            state_dict = model.state_dict()
        except RuntimeError:
            # not a JIT archive: assume a plain state dict
            state_dict = torch.load(clipPath, map_location="cpu")
        return state_dict["text_projection"].shape[1], build_model(state_dict)

    def encode_image(self, image):
        # CLIP image features -> hashing-head output
        image_embed = self.clip.encode_image(image)
        image_embed = self.image_hash(image_embed)
        return image_embed

    def eval(self):
        # NOTE(review): overrides nn.Module.eval without calling super(), so
        # the CLIP backbone stays in training mode — confirm this asymmetry
        # is intentional.
        self.image_hash.eval()
        self.text_hash.eval()
        # self.clip.eval()

    def train(self):
        # Same caveat as eval(): super().train() is not called, only the two
        # hashing heads switch mode.
        self.image_hash.train()
        self.text_hash.train()

    def encode_text(self, text):
        # CLIP text features -> hashing-head output
        text_embed = self.clip.encode_text(text)
        text_embed = self.text_hash(text_embed)
        return text_embed

    def forward(self, image, text):
        return self.encode_image(image), self.encode_text(text)

450
model/model.py Executable file
View File

@ -0,0 +1,450 @@
from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
    """CLIP's modified ResNet bottleneck.

    All convolutions have stride 1; when ``stride > 1`` the spatial reduction
    is performed by an AvgPool2d after the second convolution (and inside the
    downsample shortcut) instead of a strided convolution.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        # attribute names and child-module ordering are kept identical so
        # pretrained state_dicts keep loading
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # shortcut path: avgpool then a 1x1 stride-1 conv (keys "-1"/"0"/"1"
            # are part of the pretrained checkpoint layout)
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(self.avgpool(y)))
        y = y + shortcut
        return self.relu(y)
class AttentionPool2d(nn.Module):
    """2D attention pooling: flattens the spatial grid, prepends the mean
    feature as an extra token, adds learned positional embeddings, and runs a
    single multi-head attention pass, returning the output at the mean-token
    position."""
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # one position per spatial cell plus one for the prepended mean token
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads
    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        # functional multi-head attention with separate q/k/v projection weights
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        # x[0] is the output at the prepended mean-token position
        return x[0]
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """
    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)
        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
    def _make_layer(self, planes, blocks, stride=1):
        # stack `blocks` Bottlenecks; only the first one downsamples
        layers = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))
        return nn.Sequential(*layers)
    def forward(self, x):
        def stem(x):
            # three conv-bn-relu stages followed by a 2x2 average pool
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x
        # match the parameter dtype so an fp16-converted model accepts fp32 input
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)
        return x
class LayerNorm(nn.LayerNorm):
    """LayerNorm that upcasts fp16 inputs to fp32 for the normalization and
    casts the result back to the input dtype (numerical-stability shim)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = super().forward(x.type(torch.float32))
        return normed.type(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation used by CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x))."""
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # attn_mask may be a ready-made tensor or a callable that produces a
        # mask for a given sequence length (CLIP passes build_attention_mask)
        self.attn_mask = attn_mask
    def attention(self, x: torch.Tensor):
        attn_mask_ = self.attn_mask
        if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
            # lazily build a mask sized to the current sequence length
            attn_mask_ = self.attn_mask(x.size(0))  # LND
        attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
    def forward(self, x: torch.Tensor):
        # input/output layout is LND (sequence-first)
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks sharing one attention mask."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        # x is LND (sequence-first); the blocks preserve the shape
        return self.resblocks(x)
class VisionTransformer(nn.Module):
    """CLIP's ViT image encoder: non-overlapping patch embedding via a strided
    conv, a prepended class token, learned positional embeddings, a pre-norm
    transformer, and a final projection of the class-token output."""
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # patchify: kernel == stride == patch_size, no bias
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # prepend the learned class token to every sequence
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        # only the class-token output is used as the image representation
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj
        return x
class CLIP(nn.Module):
    """The CLIP dual encoder: a visual backbone (ModifiedResNet or
    VisionTransformer) and a text transformer, both projecting into a shared
    ``embed_dim``-dimensional embedding space."""
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()
        self.context_length = context_length
        # a tuple/list of layer counts selects the ResNet backbone,
        # a single int selects the ViT backbone
        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )
        # NOTE: the bound method (a callable, not a tensor) is passed as the
        # mask; each resblock calls it lazily with the current sequence length
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask
        )
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.initialize_parameters()
    def initialize_parameters(self):
        """Initialize embeddings and projections (CLIP's published scheme)."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
            # zero-init the final BN gamma in every bottleneck so each
            # residual branch starts near identity
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
    def build_attention_mask(self, context_length):
        """Causal (additive, upper-triangular -inf) mask of the given size.

        Unlike upstream CLIP this takes the length as an argument; the text
        resblocks call it lazily per forward pass (see ResidualAttentionBlock).
        """
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(context_length, context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask
    @property
    def dtype(self):
        # the visual stem conv's dtype reflects fp16/fp32 conversion of the model
        return self.visual.conv1.weight.dtype
    def encode_image(self, image):
        """Image batch -> joint embedding (no normalization applied here)."""
        return self.visual(image.type(self.dtype))
    def encode_text(self, text):
        """Token-id batch [batch, n_ctx] -> joint embedding read at each
        sequence's EOT position."""
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        # slice so sequences shorter than context_length are supported
        x = x + self.positional_embedding[:x.size(1), :].type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
    def forward(self, image, text):
        """Return (logits_per_image, logits_per_text): scaled cosine similarities."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
    """Cast conv/linear/attention weights (and CLIP's projection parameters)
    to fp16 in place; parameters of other module types are left untouched."""

    def _to_fp16(module):
        # plain conv / linear layers: weight and optional bias
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        # MultiheadAttention keeps its projections as raw parameters
        if isinstance(module, nn.MultiheadAttention):
            attr_names = [f"{prefix}_proj_weight" for prefix in ["in", "q", "k", "v"]]
            attr_names += ["in_proj_bias", "bias_k", "bias_v"]
            for attr_name in attr_names:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        # CLIP-specific bare projection parameters
        for proj_name in ["text_projection", "proj"]:
            proj = getattr(module, proj_name, None)
            if proj is not None:
                proj.data = proj.data.half()

    model.apply(_to_fp16)
def build_model(state_dict: dict):
    """Infer CLIP hyper-parameters from a checkpoint's state dict, build the
    model, convert its weights to fp16, and load the checkpoint.

    ViT checkpoints are recognised by the presence of "visual.proj";
    otherwise a ModifiedResNet backbone is assumed.
    """
    vit = "visual.proj" in state_dict
    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        # one resblock per ".attn.in_proj_weight" key under "visual."
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # positional embedding has grid**2 + 1 rows (class token included)
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # blocks per ResNet stage, counted from the per-stage key prefixes
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32
    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )
    # metadata entries are not parameters; drop them before loading
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]
    convert_weights(model)
    model.load_state_dict(state_dict)
    # returned in train mode; callers switch to eval as needed
    return model

168
model/optimization.py Normal file
View File

@ -0,0 +1,168 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""
import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging
logger = logging.getLogger(__name__)
def warmup_cosine(x, warmup=0.002):
    """Linear warmup over the first `warmup` fraction of training, then cosine
    decay 0.5*(1+cos(pi*x)), where x is overall progress in [0, 1]."""
    if x >= warmup:
        return 0.5 * (1.0 + math.cos(math.pi * x))
    return x / warmup
def warmup_constant(x, warmup=0.002):
    """Linearly increases learning rate over `warmup`*`t_total` (as provided to
    BertAdam) training steps. Learning rate is 1. afterwards."""
    return x / warmup if x < warmup else 1.0
def warmup_linear(x, warmup=0.002):
    """Triangular schedule: peak 1.0 at `warmup`*`t_total` steps (as provided
    to BertAdam), then linear decay reaching zero at x = 1."""
    if x >= warmup:
        return max((x - 1.) / (warmup - 1.), 0)
    return x / warmup
# Mapping from schedule name (the `schedule` argument of BertAdam) to the
# function that scales the base learning rate given training progress.
SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
}
class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.
    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1 means constant learning rate. Default: -1
        schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """
    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
                 max_grad_norm=1.0):
        # validate hyper-parameters up front so bad configs fail fast
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)
    def get_lr(self):
        """Return the scheduled learning rate per parameter with state.

        Returns [0] if any inspected parameter has no state yet, i.e. when
        called before the first step().
        """
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
                else:
                    lr_scheduled = group['lr']
                lr.append(lr_scheduled)
        return lr
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)
                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']
                # Add grad clipping
                # NOTE(review): this clips each parameter tensor's norm
                # individually, not the global gradient norm over all params.
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])
                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # No bias correction here: this is the BERT variant of Adam.
                update = next_m / (next_v.sqrt() + group['e'])
                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data
                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    progress = state['step']/group['t_total']
                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
                else:
                    lr_scheduled = group['lr']
                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)
                # step counter incremented after the update: the schedule is
                # evaluated at the pre-update step count
                state['step'] += 1
        return loss

143
model/simple_tokenizer.py Executable file
View File

@ -0,0 +1,143 @@
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
    """Path to the bundled BPE vocabulary archive, located next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by the BPE codec.

    Printable ASCII and most latin-1 bytes map to themselves; every remaining
    byte value is assigned a fresh code point starting at 256, so no byte ever
    maps to a whitespace/control character (which the BPE code cannot handle).
    Covering all 256 bytes keeps the vocab free of UNKs.
    """
    keep = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    byte_values = keep[:]
    char_codes = keep[:]
    offset = 0
    for value in range(2**8):
        if value not in keep:
            byte_values.append(value)
            char_codes.append(2**8 + offset)
            offset += 1
    return dict(zip(byte_values, (chr(code) for code in char_codes)))
def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word`.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings).
    """
    prev = word[0]
    pairs = set()
    for sym in word[1:]:
        pairs.add((prev, sym))
        prev = sym
    return pairs
def basic_clean(text):
    """Fix mojibake with ftfy, unescape HTML entities twice (handles
    double-escaped input), and strip surrounding whitespace."""
    fixed = ftfy.fix_text(text)
    fixed = html.unescape(html.unescape(fixed))
    return fixed.strip()
def whitespace_clean(text):
    """Collapse every whitespace run to a single space and strip the ends."""
    return re.sub(r'\s+', ' ', text).strip()
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer (GPT-2 style) as used by CLIP."""
    def __init__(self, bpe_path: str = default_bpe()):
        """Load the gzipped merge list and build encoder/decoder tables.

        Vocabulary layout: 256 byte symbols, their "</w>" end-of-word
        variants, one entry per merge, then the two special tokens.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        # skip the header line and anything beyond the 49152-entry vocab
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # merge -> rank; lower rank means the merge was learned earlier
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
    def bpe(self, token):
        """Merge the characters of `token` into BPE symbols (space-joined).

        Results are memoized in self.cache; the final character carries the
        "</w>" end-of-word marker.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)
        if not pairs:
            return token+'</w>'
        while True:
            # repeatedly apply the lowest-ranked (earliest-learned) merge
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                # NOTE(review): bare except — presumably intended to catch the
                # ValueError from word.index when `first` no longer occurs.
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
    def encode(self, text):
        """Text -> list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens
    def decode(self, tokens):
        """List of token ids -> text (end-of-word markers become spaces)."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
    def tokenize(self, text):
        """Text -> list of BPE token strings (not ids)."""
        tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
        return tokens
    def convert_tokens_to_ids(self, tokens):
        """BPE token strings -> ids via the encoder table."""
        return [self.encoder[bpe_token] for bpe_token in tokens]

8
requirements.txt Normal file
View File

@ -0,0 +1,8 @@
tqdm
scipy
numpy
pillow
matplotlib
scikit-learn
torch==1.9.1

0
train/__init__.py Normal file
View File

94
train/base.py Normal file
View File

@ -0,0 +1,94 @@
import os
from tqdm import tqdm
import torch
from torch import distributed as dist
from utils import get_logger, get_summary_writer
class TrainBase(object):
    """Skeleton for train/valid/test drivers.

    Owns the logger, the tensorboard writer, the dataloaders and the model,
    and tracks the best mAP seen so far. Subclasses implement the _init_*
    hooks plus train/valid/test/compute_loss.
    """

    def __init__(self,
                args,
                rank=0):
        self.args = args
        os.makedirs(args.save_dir, exist_ok=True)
        self._init_writer()
        self.logger.info(self.args)
        self.rank = rank
        self._init_dataset()
        self._init_model()
        # bookkeeping for the training loop and best-checkpoint tracking
        self.global_step = 0
        self.max_mapi2t = 0
        self.max_mapt2i = 0
        self.best_epoch_i = 0
        self.best_epoch_t = 0

    def _init_dataset(self):
        # overridden by subclasses; the base class only declares the slots
        self.train_loader = None
        self.query_loader = None
        self.retrieval_loader = None

    def _init_model(self):
        # overridden by subclasses
        self.model = None
        self.model_ddp = None

    def _init_writer(self):
        log_name = "train.log" if self.args.is_train else "test.log"
        self.logger = get_logger(os.path.join(self.args.save_dir, log_name))
        self.writer = get_summary_writer(os.path.join(self.args.save_dir, "tensorboard"))

    def run(self):
        """Dispatch to train() or test() according to args.is_train."""
        if self.args.is_train:
            self.train()
        else:
            self.test()

    def change_state(self, mode):
        """Put the model into "train" or "valid" (eval) mode."""
        if mode == "train":
            self.model.train()
        elif mode == "valid":
            self.model.eval()

    def get_code(self, data_loader, length: int):
        """Encode every sample in `data_loader`; return (image, text) hash
        buffers of shape (length, output_dim) on this rank's device."""
        image_codes = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        text_codes = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        for image, text, label, index in tqdm(data_loader):
            image = image.to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            # codes are written back by sample index, so loader order is irrelevant
            index = index.numpy()
            image_codes[index, :] = self.model.encode_image(image).data
            text_codes[index, :] = self.model.encode_text(text).data
        return image_codes, text_codes

    def hash_loss(self, a: torch.Tensor):
        """Quantization penalty: half the mean per-row L2 distance between the
        continuous codes and their binarized (sign) versions."""
        quantized = torch.sign(a)
        per_row = torch.sqrt(torch.sum(torch.pow(quantized - a, 2), dim=1))
        return torch.mean(per_row) * 0.5

    def similarity_loss(self):
        raise NotImplementedError("Function of 'similarity_loss' doesn't implement.")

    def save_model(self, epoch):
        """Checkpoint the model weights for this epoch under save_dir."""
        save_path = os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth")
        torch.save(self.model.state_dict(), save_path)
        self.logger.info("save mode to {}".format(save_path))

    def train(self):
        raise NotImplementedError("Function of 'train' doesn't implement.")

    def valid(self):
        raise NotImplementedError("Function of 'valid' doesn't implement.")

    def test(self):
        raise NotImplementedError("Function of 'test' doesn't implement.")

    def compute_loss(self):
        raise NotImplementedError("Function of 'compute_loss' doesn't implement.")

360
train/hash_train.py Normal file
View File

@ -0,0 +1,360 @@
from torch.nn.modules import loss
from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import scipy.io as scio
from .base import TrainBase
from model.optimization import BertAdam
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity
from utils.calc_utils import calc_map_k_matrix as calc_map_k
from dataset.dataloader import dataloader
class Trainer(TrainBase):
    def __init__(self,
                rank=0):
        """Parse command-line args, build the datasets/model via TrainBase,
        log the training-set size, and immediately start run()."""
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        self.run()
    def _init_model(self):
        """Build the DCMHT model (CLIP backbone + hash heads), optionally load
        a pretrained checkpoint, and create the BertAdam optimizer with a
        separate learning rate for CLIP vs. the hash layers."""
        self.logger.info("init model.")
        linear = False
        if self.args.hash_layer == "linear":
            linear = True
        self.logger.info("ViT+GPT!")
        HashModel = DCMHT
        self.model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
                            writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
        if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
            self.logger.info("load pretrained model.")
            self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
        # train in fp32 (the CLIP checkpoint's weights load as fp16)
        self.model.float()
        self.optimizer = BertAdam([
                    {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
                    {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
                    {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
                    ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
                    b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
                    weight_decay=self.args.weight_decay, max_grad_norm=1.0)
        print(self.model)
    def _init_dataset(self):
        """Resolve dataset file paths, build the train/query/retrieval splits
        and their dataloaders, and cache the label matrices for evaluation."""
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
        self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
        self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                        indexFile=self.args.index_file,
                                        labelFile=self.args.label_file,
                                        maxWords=self.args.max_words,
                                        imageResolution=self.args.resolution,
                                        query_num=self.args.query_num,
                                        train_num=self.args.train_num,
                                        seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
        # shuffle=True on query/retrieval is harmless: hash codes are written
        # back by sample index in get_code, so loader order does not matter
        self.train_loader = DataLoader(
                dataset=train_data,
                batch_size=self.args.batch_size,
                num_workers=self.args.num_workers,
                pin_memory=True,
                shuffle=True
            )
        self.query_loader = DataLoader(
                dataset=query_data,
                batch_size=self.args.batch_size,
                num_workers=self.args.num_workers,
                pin_memory=True,
                shuffle=True
            )
        self.retrieval_loader = DataLoader(
                dataset=retrieval_data,
                batch_size=self.args.batch_size,
                num_workers=self.args.num_workers,
                pin_memory=True,
                shuffle=True
            )
def train_epoch(self, epoch):
self.change_state(mode="train")
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
all_loss = 0
times = 0
for image, text, label, index in self.train_loader:
self.global_step += 1
times += 1
image.float()
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
label = torch.ones([image.shape[0]], dtype=torch.int)
label = label.diag()
# print(text.dtype)
# text.float()
# label.float()
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
# print("text shape:", text.shape)
index = index.numpy()
# print(text.shape)
hash_img, hash_text = self.model(image, text)
if self.args.hash_layer == "select":
hash_img = torch.cat(hash_img, dim=-1) if isinstance(hash_img, list) else hash_img.view(hash_img.shape[0], -1)
hash_text = torch.cat(hash_text, dim=-1)if isinstance(hash_text, list) else hash_text.view(hash_text.shape[0], -1)
loss = self.compute_loss(hash_img, hash_text, label, epoch, times)
all_loss += loss
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
def train(self):
self.logger.info("Start train.")
for epoch in range(self.args.epochs):
self.train_epoch(epoch)
self.valid(epoch)
self.save_model(epoch)
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
    """DCMH-style negative log-likelihood pairwise loss.

    Args:
        a, b: real-valued code matrices of shapes (n, d) and (m, d).
        label_sim: (n, m) binary similarity matrix.

    Returns:
        Scalar tensor: -mean(S * theta - log(1 + exp(theta))), theta = a b^T.
    """
    s = torch.matmul(a, b.t())
    # BUG FIX: log(1 + exp(s)) overflows to inf for large s;
    # logaddexp(0, s) is the numerically stable softplus equivalent.
    b_loss = -torch.mean(label_sim * s - torch.logaddexp(torch.zeros_like(s), s))
    return b_loss
def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
    """Pointwise KL-style divergence mean(a * log(a / (b + 0.001))).

    `label_sim` is accepted only for signature parity with the other loss
    functions and is not used. The 0.001 term guards against division by
    zero; inputs are assumed to be positive, probability-like values —
    TODO(review) confirm with callers.
    """
    kl_divergence = torch.mean(a * torch.log(a / (b + 0.001)))
    # Debug prints removed: they ran on every training step and spammed stdout.
    return kl_divergence
def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
    """Contrastive similarity loss between two code batches.

    Pulls pairs with label_sim == 1 below `threshold` and pushes dissimilar
    pairs toward a capped maximum distance (the cap depends on the configured
    similarity function).

    Returns:
        (similarity matrix, positive loss, negative loss)
    """
    # $\vartheta$: fraction of hash bits allowed to differ before a negative
    # pair is considered "far enough".
    vartheta = self.args.vartheta
    if self.args.sim_threshold != 0:
        threshold = self.args.sim_threshold
    # cosine mode works on (1 - cos) so that 0 means identical in both modes.
    similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b)
    positive_similarity = similarity * label_sim
    # (translated from Chinese) Any pair whose cosine is already negative is
    # counted as solved — optimizing the distance all the way to 2 is too hard.
    negative_similarity = similarity * (1 - label_sim)
    if self.args.similarity_function == "cosine":
        # clip(threshold) floors positives: anything already below the
        # threshold contributes zero after the subtraction.
        positive_similarity = positive_similarity.clip(threshold) - threshold
        negative_similarity = negative_similarity.clip(max=1.)
        negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
    elif self.args.similarity_function == "euclidean":
        # (translated from Chinese) With euclidean distance, a pair whose
        # hash codes differ in half the bits sits at sqrt(output_dim * 2 *
        # vartheta) (the concat doubles output_dim); clip to that
        # human-chosen cap so negatives are not pushed further than it.
        max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5
        negative_similarity = negative_similarity.clip(max=max_value)
        negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
    if self.args.loss_type == "l1":
        positive_loss = positive_similarity.mean()
        negative_loss = negative_similarity.mean()
    elif self.args.loss_type == "l2":
        positive_loss = torch.pow(positive_similarity, 2).mean()
        negative_loss = torch.pow(negative_similarity, 2).mean()
    else:
        raise ValueError("argument of loss_type is not support.")
    return similarity, positive_loss, negative_loss
def make_hash_code(self, code: list) -> torch.Tensor:
    """Convert per-bit selection logits into a {-1, +1} hash matrix.

    `code` is a list of (batch, 2) tensors, one entry per hash bit. The
    argmax over the last axis picks each bit's value and class 0 is mapped
    to -1 so the result lives in {-1, +1}.
    """
    stacked = torch.stack(code).permute(1, 0, 2)
    bits = torch.argmax(stacked, dim=-1)
    bits[bits == 0] = -1
    return bits.float()
def get_code(self, data_loader, length: int):
    """Encode every sample served by `data_loader` into binary codes.

    Returns two (length, output_dim) float tensors on self.rank — image
    codes and text codes — written at each sample's dataset index.
    """
    image_codes = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
    text_codes = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
    for image, text, label, index in tqdm(data_loader):
        image = image.to(self.rank, non_blocking=True)
        text = text.to(self.rank, non_blocking=True)
        positions = index.numpy()
        image_codes[positions, :] = self.make_hash_code(self.model.encode_image(image)).data
        text_codes[positions, :] = self.make_hash_code(self.model.encode_text(text)).data
    return image_codes, text_codes
def our_loss(self, image, text, label, epoch, times):
    """Total training loss: cross-modal + within-modal similarity terms,
    plus a quantization term when not using the select hash layer.

    Args:
        image, text: hash representations for the batch.
        label: multi-hot labels used to build the similarity matrix.
        epoch, times: progress counters, used only for log messages.
    """
    loss = 0
    label_sim = calc_neighbor(label, label)
    if image.is_cuda:
        label_sim = label_sim.to(image.device)
    # Naming kept from the original author: "intra" is the cross-modal
    # (image vs text) term and "inter" the within-modality terms.
    intra_similarity, intra_positive_loss, intra_negative_loss = self.similarity_loss(image, text, label_sim)
    inter_similarity_i, inter_positive_loss_i, inter_negative_loss_i = self.similarity_loss(image, image, label_sim)
    inter_similarity_t, inter_positive_loss_t, inter_negative_loss_t = self.similarity_loss(text, text, label_sim)
    # NOTE(review): both branches of these two conditionals are identical;
    # a euclidean-specific weighting appears to have been removed at some
    # point — confirm before simplifying.
    intra_similarity_loss = (intra_positive_loss + intra_negative_loss) if self.args.similarity_function == "euclidean" else (intra_positive_loss + intra_negative_loss)
    inter_similarity_loss = inter_positive_loss_t + inter_positive_loss_i + (inter_negative_loss_i + inter_negative_loss_t) if self.args.similarity_function == "euclidean" else inter_positive_loss_t + inter_positive_loss_i + inter_negative_loss_i + inter_negative_loss_t
    similarity_loss = inter_similarity_loss + intra_similarity_loss
    # (TensorBoard scalar logging was commented out in the original source.)
    if self.args.hash_layer != "select":
        # Non-select layers output continuous codes, so penalize the gap to
        # the binarized codes.
        quantization_loss = (self.hash_loss(image) + self.hash_loss(text)) / 2
        loss = similarity_loss + quantization_loss
        if self.global_step % self.args.display_step == 0:
            self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
                f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
                f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
                f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
                f"QUATIZATION LOSS, {quantization_loss.data}, "\
                f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
    else:
        loss = similarity_loss
        if self.global_step % self.args.display_step == 0:
            self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
                f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
                f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
                f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
                f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
    return loss
def compute_loss(self, image, text, label, epoch, times):
    """Thin dispatch wrapper around our_loss, kept for API symmetry."""
    return self.our_loss(image, text, label, epoch, times)
def test(self, mode_name="i2t"):
    """Standalone evaluation with a pre-trained checkpoint.

    Computes mAP in all four retrieval directions and dumps the codes plus
    labels to a .mat file for PR-curve plotting. Requires --pretrained.
    """
    if self.args.pretrained == "":
        raise RuntimeError("test step must load a model! please set the --pretrained argument.")
    self.change_state(mode="valid")
    save_dir = os.path.join(self.args.save_dir, "PR_cruve")
    os.makedirs(save_dir, exist_ok=True)
    # The select layer has its own code extraction; otherwise defer to the
    # base class implementation.
    query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
    retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
    mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    # NOTE(review): only max_mapt2i is refreshed here; max_mapi2t is left
    # untouched in test mode — confirm whether that is intentional.
    self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
    self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
    # Move everything to numpy for scipy.io serialization.
    query_img = query_img.cpu().detach().numpy()
    query_txt = query_txt.cpu().detach().numpy()
    retrieval_img = retrieval_img.cpu().detach().numpy()
    retrieval_txt = retrieval_txt.cpu().detach().numpy()
    query_labels = self.query_labels.numpy()
    retrieval_labels = self.retrieval_labels.numpy()
    result_dict = {
        'q_img': query_img,
        'q_txt': query_txt,
        'r_img': retrieval_img,
        'r_txt': retrieval_txt,
        'q_l': query_labels,
        'r_l': retrieval_labels
    }
    scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
    self.logger.info(">>>>>> save all data!")
def valid(self, epoch):
    """Per-epoch validation: compute mAP in all four directions and persist
    the codes whenever an image->text or text->image best is reached."""
    self.logger.info("Valid.")
    self.change_state(mode="valid")
    # Select layer uses the overridden get_code; otherwise the base class's.
    query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
    retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
    mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    if self.max_mapi2t < mAPi2t:
        # New image->text best: remember the epoch and dump the codes.
        self.best_epoch_i = epoch
        self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
    self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
    if self.max_mapt2i < mAPt2i:
        # New text->image best.
        self.best_epoch_t = epoch
        self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
    self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
    self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
    """Dump query/retrieval codes plus labels to a .mat file for PR curves."""
    save_dir = os.path.join(self.args.save_dir, "PR_cruve")
    os.makedirs(save_dir, exist_ok=True)
    result_dict = {
        'q_img': query_img.cpu().detach().numpy(),
        'q_txt': query_txt.cpu().detach().numpy(),
        'r_img': retrieval_img.cpu().detach().numpy(),
        'r_txt': retrieval_txt.cpu().detach().numpy(),
        'q_l': self.query_labels.numpy(),
        'r_l': self.retrieval_labels.numpy()
    }
    mat_name = str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"
    scio.savemat(os.path.join(save_dir, mat_name), result_dict)
    self.logger.info(f">>>>>> save best {mode_name} data!")

3
utils/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from .utils import *
from .logger import get_logger, get_summary_writer
from .get_args import get_args

357
utils/calc_utils.py Normal file
View File

@ -0,0 +1,357 @@
import torch
import numpy as np
from tqdm import tqdm
def calc_hammingDist(B1, B2):
    """Pairwise Hamming distance between rows of two {-1, +1} code matrices.

    For q-bit codes, <b1, b2> = q - 2 * hamming, which gives the identity
    hamming = 0.5 * (q - <b1, b2>). A 1-D B1 is promoted to a single row.
    """
    q = B2.shape[1]
    if len(B1.shape) < 2:
        B1 = B1.unsqueeze(0)
    inner = B1.mm(B2.transpose(0, 1))
    return 0.5 * (q - inner)
def calc_map_k_matrix(qB, rB, query_L, retrieval_L, k=None, rank=0):
    """Vectorized mAP@k: rank everything once, then average per query.

    Args:
        qB, rB: query / retrieval binary codes in {-1, +1}.
        query_L, retrieval_L: multi-hot label matrices.
        k: truncation depth; defaults to the full retrieval set.
        rank: device id (unused in this matrix variant).

    NOTE(review): unlike calc_map_k below, queries with zero relevant items
    are NOT skipped here (tsum == 0 would yield NaN) — confirm that the
    labels guarantee at least one match per query.
    """
    num_query = query_L.shape[0]
    # Distances are computed on CPU regardless of where the codes live.
    if qB.is_cuda:
        qB = qB.cpu()
        rB = rB.cpu()
    map = 0
    if k is None:
        k = retrieval_L.shape[0]
    # gnds[i, j] == 1 iff query i and retrieval item j share a label.
    gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
    tsums = torch.sum(gnds, dim=-1, keepdim=True, dtype=torch.int32)
    hamms = calc_hammingDist(qB, rB)
    _, ind = torch.sort(hamms, dim=-1)
    totals = torch.min(tsums, torch.tensor([k], dtype=torch.int32).expand_as(tsums))
    for iter in range(num_query):
        # Re-order ground truth by ascending Hamming distance.
        gnd = gnds[iter][ind[iter]]
        total = totals[iter].squeeze()
        count = torch.arange(1, total + 1).type(torch.float32)
        # 1-based ranks of the relevant items in the sorted list.
        tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float32) + 1.0
        map = map + torch.mean(count / tindex)
    map = map / num_query
    return map
def calc_map_k(qB, rB, query_L, retrieval_L, k=None, rank=0):
    """mAP@k with a per-query loop; codes are re-binarized with sign().

    Args:
        qB, rB: real-valued or binary code matrices.
        query_L, retrieval_L: multi-hot label matrices.
        k: truncation depth; defaults to the whole retrieval set.
        rank: CUDA device index used when the codes live on the GPU.
    """
    num_query = query_L.shape[0]
    qB = torch.sign(qB)
    rB = torch.sign(rB)
    map = 0
    if k is None:
        k = retrieval_L.shape[0]
    for iter in range(num_query):
        q_L = query_L[iter]
        if len(q_L.shape) < 2:
            q_L = q_L.unsqueeze(0) # [1, hash length]
        gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
        tsum = torch.sum(gnd)
        if tsum == 0:
            # A query without any relevant item contributes nothing.
            continue
        hamm = calc_hammingDist(qB[iter, :], rB)
        _, ind = torch.sort(hamm)
        ind.squeeze_()
        gnd = gnd[ind]
        total = min(k, int(tsum))
        count = torch.arange(1, total + 1).type(torch.float32)
        # 1-based ranks of the relevant items in the ranking.
        tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float32) + 1.0
        if tindex.is_cuda:
            count = count.to(rank)
        map = map + torch.mean(count / tindex)
    map = map / num_query
    return map
def calc_precisions_topn_matrix(qB, rB, query_L, retrieval_L, recall_gas=0.02, num_retrieval=10000):
    """Precision@top-N curve: precision at each `recall_gas` fraction of the
    ranked retrieval list (matrix variant; accepts ndarrays or tensors).

    Codes are assumed to be in {0, 1} and are re-signed around 0.5.
    NOTE(review): the cut-off uses the `num_retrieval` argument, not
    retrieval_L.shape[0] — callers must keep the two consistent.
    """
    if not isinstance(qB, torch.Tensor):
        qB = torch.from_numpy(qB)
        rB = torch.from_numpy(rB)
        query_L = torch.from_numpy(query_L)
        retrieval_L = torch.from_numpy(retrieval_L)
    qB = qB.float()
    rB = rB.float()
    # Map {0, 1} codes to {-1, +1} for the Hamming-distance identity.
    qB = torch.sign(qB - 0.5)
    rB = torch.sign(rB - 0.5)
    if qB.is_cuda:
        qB = qB.cpu()
        rB = rB.cpu()
    num_query = query_L.shape[0]
    # num_retrieval = retrieval_L.shape[0]
    precisions = [0] * int(1 / recall_gas)
    gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
    hamms = calc_hammingDist(qB, rB)
    _, inds = torch.sort(hamms, dim=-1)
    for iter in range(num_query):
        gnd = gnds[iter]
        ind = inds[iter]
        # Re-order ground truth by ascending Hamming distance.
        gnd = gnd[ind]
        for i, recall in enumerate(np.arange(recall_gas, 1 + recall_gas, recall_gas)):
            total = int(num_retrieval * recall)
            right = torch.nonzero(gnd[: total]).squeeze().numpy()
            right_num = right.size
            precisions[i] += (right_num/total)
    for i in range(len(precisions)):
        precisions[i] /= num_query
    return precisions
def calc_precisions_topn(qB, rB, query_L, retrieval_L, recall_gas=0.02, num_retrieval=10000):
    """Per-query loop version of calc_precisions_topn_matrix.

    Codes are assumed to be in {0, 1} and are re-signed around 0.5; the
    cut-off positions are fractions of `num_retrieval`, not of the actual
    retrieval set size — callers must keep the two consistent.
    """
    qB = qB.float()
    rB = rB.float()
    # Map {0, 1} codes to {-1, +1} for the Hamming-distance identity.
    qB = torch.sign(qB - 0.5)
    rB = torch.sign(rB - 0.5)
    num_query = query_L.shape[0]
    # num_retrieval = retrieval_L.shape[0]
    precisions = [0] * int(1 / recall_gas)
    for iter in range(num_query):
        q_L = query_L[iter]
        if len(q_L.shape) < 2:
            q_L = q_L.unsqueeze(0) # [1, hash length]
        gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
        hamm = calc_hammingDist(qB[iter, :], rB)
        _, ind = torch.sort(hamm)
        ind.squeeze_()
        # Re-order ground truth by ascending Hamming distance.
        gnd = gnd[ind]
        for i, recall in enumerate(np.arange(recall_gas, 1 + recall_gas, recall_gas)):
            total = int(num_retrieval * recall)
            right = torch.nonzero(gnd[: total]).squeeze().numpy()
            right_num = right.size
            precisions[i] += (right_num/total)
    for i in range(len(precisions)):
        precisions[i] /= num_query
    return precisions
def calc_precisions_hash(qB, rB, query_L, retrieval_L):
    """Cumulative precision/recall per Hamming radius r = 0..max distance.

    Codes are assumed to be in {0, 1} and are re-signed around 0.5.
    Returns (precisions, recalls) as numpy arrays of length max_hamm + 1.
    """
    qB = qB.float()
    rB = rB.float()
    qB = torch.sign(qB - 0.5)
    rB = torch.sign(rB - 0.5)
    num_query = query_L.shape[0]
    num_retrieval = retrieval_L.shape[0]
    bit = qB.shape[1]
    hamm = calc_hammingDist(qB, rB)
    # ByteTensor keeps memory down; valid while distances fit in a byte.
    # NOTE(review): this truncates if bit > 255 — confirm the code-length bound.
    hamm = hamm.type(torch.ByteTensor)
    total_num = [0] * (bit + 1)  # unused, kept from the original
    max_hamm = int(torch.max(hamm))
    gnd = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze()
    total_right = torch.sum(torch.matmul(query_L, retrieval_L.t())>0)
    precisions = np.zeros([max_hamm + 1])
    recalls = np.zeros([max_hamm + 1])
    right_num = 0
    recall_num = 0
    for i, radius in enumerate(range(0, max_hamm+1)):
        # All (query, retrieval) index pairs at exactly this radius.
        recall = torch.nonzero(hamm == radius)
        right = gnd[recall.split(1, dim=1)]
        recall_num += recall.shape[0]
        del recall
        right_num += torch.nonzero(right).shape[0]
        del right
        # Counters are cumulative, so each entry covers distance <= radius.
        precisions[i] += (right_num / (recall_num + 1e-8))
        recalls[i] += (recall_num / total_right)
    return precisions, recalls
def calc_precisions_hash_my(qB, rB, *, Gnd, num_query, num_retrieval):
    """Variant of calc_precisions_hash taking a precomputed ground-truth
    matrix, then collapsing points with equal (rounded) recall into one
    averaged precision per recall level.
    """
    if not isinstance(qB, torch.Tensor):
        qB = torch.from_numpy(qB)
    if not isinstance(rB, torch.Tensor):
        rB = torch.from_numpy(rB)
    if not isinstance(Gnd, torch.Tensor):
        Gnd = torch.from_numpy(Gnd)
    def CalcHammingDist_np(B1, B2):
        # Numpy fallback kept from the original; not used below.
        q = B2.shape[1]
        distH = 0.5 * (q - np.dot(B1, B2.transpose()))
        return distH
    bit = qB.shape[1]
    hamm = calc_hammingDist(qB, rB)
    hamm = hamm.type(torch.ByteTensor)
    total_num = [0] * (bit + 1)  # unused, kept from the original
    max_hamm = int(torch.max(hamm))
    gnd = Gnd
    total_right = torch.sum(gnd>0)  # computed but unused in this variant
    precisions = np.zeros([max_hamm + 1])
    recalls = np.zeros([max_hamm + 1])
    right_num = 0
    recall_num = 0
    for i, radius in enumerate(range(0, max_hamm+1)):
        # All (query, retrieval) index pairs at exactly this radius.
        recall = torch.nonzero(hamm == radius)
        right = gnd[recall.split(1, dim=1)]
        recall_num += recall.shape[0]
        del recall
        right_num += torch.nonzero(right).shape[0]
        del right
        precisions[i] += (right_num / (recall_num + 1e-8))
        recalls[i] += (recall_num / num_retrieval / num_query)
    # Round to two decimals so nearly-equal recall levels merge below.
    p = precisions.round(2)
    r = recalls.round(2)
    precisions = []
    recalls = []
    precision_ = 0
    num = 1
    for i in range(len(r) - 1):
        if r[i] == r[i + 1]:
            # Same recall level: keep accumulating for the average.
            precision_ += p[i]
            num += 1
        else:
            precision_ += p[i]
            precisions.append(precision_ / num)
            recalls.append(r[i])
            precision_ = 0
            num = 1
    return np.asarray(precisions).round(2), np.asarray(recalls)
def calc_precisions_hamming_radius(qB, rB, query_L, retrieval_L, hamming_gas=1):
    """Precision inside growing Hamming-radius balls (radius 1, 1+gas, ...).

    NOTE(review): unlike the other precision helpers, the codes are used
    as-is (no re-signing around 0.5) — callers must already pass {-1, +1}.
    """
    num_query = query_L.shape[0]
    bit = qB.shape[1]
    precisions = [0] * int(bit / hamming_gas)
    for iter in range(num_query):
        q_L = query_L[iter]
        if len(q_L.shape) < 2:
            q_L = q_L.unsqueeze(0) # [1, hash length]
        gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
        hamm = calc_hammingDist(qB[iter, :], rB)
        _, ind = torch.sort(hamm)
        ind.squeeze_()
        # Re-order ground truth by ascending Hamming distance.
        gnd = gnd[ind]
        for i, recall in enumerate(np.arange(1, bit+1, hamming_gas)):
            # Number of retrieval items falling inside the current radius.
            total = torch.nonzero(hamm <= recall).squeeze().shape[0]
            if total == 0:
                precisions[i] += 0
                continue
            right = torch.nonzero(gnd[: total]).squeeze().numpy()
            right_num = right.size
            precisions[i] += (right_num / total)
    for i in range(len(precisions)):
        precisions[i] /= num_query
    return precisions
def calc_neighbor(label1, label2):
    """Binary ground-truth similarity: 1.0 where two samples share a label."""
    overlap = label1.matmul(label2.transpose(0, 1))
    neighbors = overlap > 0
    return neighbors.float()
def norm_max_min(x: torch.Tensor, dim=None):
    """Min-max normalize x to [0, 1], globally or along `dim`."""
    if dim is None:
        lo, hi = torch.min(x), torch.max(x)
    else:
        hi = torch.max(x, dim=dim)[0]
        lo = torch.min(x, dim=dim)[0]
        if dim > 0:
            # Re-insert the reduced axis at the end so broadcasting lines up.
            hi = hi.unsqueeze(len(x.shape) - 1)
            lo = lo.unsqueeze(len(x.shape) - 1)
    return (x - lo) / (hi - lo)
def norm_mean(x: torch.Tensor, dim=None):
    """Standardize x to zero mean / unit (sample) std, globally or along `dim`."""
    if dim is None:
        mu, sigma = torch.mean(x), torch.std(x)
    else:
        mu = torch.mean(x, dim=dim)
        sigma = torch.std(x, dim=dim)
        if dim > 0:
            # Re-insert the reduced axis at the end so broadcasting lines up.
            mu = mu.unsqueeze(len(x.shape) - 1)
            sigma = sigma.unsqueeze(len(x.shape) - 1)
    return (x - mu) / sigma
def norm_abs_mean(x: torch.Tensor, dim=None):
    """Absolute z-score: |x - mean| / std, globally or along `dim`."""
    if dim is None:
        mu, sigma = torch.mean(x), torch.std(x)
    else:
        mu = torch.mean(x, dim=dim)
        sigma = torch.std(x, dim=dim)
        if dim > 0:
            # Re-insert the reduced axis at the end so broadcasting lines up.
            mu = mu.unsqueeze(len(x.shape) - 1)
            sigma = sigma.unsqueeze(len(x.shape) - 1)
    return torch.abs(x - mu) / sigma
def factorial(n):
    """Factorial of a non-negative integer, computed iteratively.

    The original recursed, which exhausts the stack for large n and loops
    until RecursionError for negative n; iterate and reject negatives.

    Raises:
        ValueError: if n is negative.
    """
    if n < 0:
        raise ValueError("factorial is undefined for negative values")
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
def calc_IF(all_bow):
    """Per-word frequency over a bag-of-words matrix (rows = documents)."""
    counts = torch.sum(all_bow, dim=0)
    return counts / torch.sum(counts)
# def calc_loss(B, F, G, Sim, gamma1, gamma2, eta):
# theta = torch.matmul(F, G.transpose(0, 1)) / 2
# inter_loss = torch.sum(torch.log(1 + torch.exp(theta)) - Sim * theta)
# theta_f = torch.matmul(F, F.transpose(0, 1)) / 2
# intra_img = torch.sum(torch.log(1 + torch.exp(theta_f)) - Sim * theta_f)
# theta_g = torch.matmul(G, G.transpose(0, 1)) / 2
# intra_txt = torch.sum(torch.log(1 + torch.exp(theta_g)) - Sim * theta_g)
# intra_loss = gamma1 * intra_img + gamma2 * intra_txt
# quan_loss = torch.sum(torch.pow(B - F, 2) + torch.pow(B - G, 2)) * eta
# # term3 = torch.sum(torch.pow(F.sum(dim=0), 2) + torch.pow(G.sum(dim=0), 2))
# # loss = term1 + gamma * term2 + eta * term3
# loss = inter_loss + intra_loss + quan_loss
# return loss
# if __name__ == '__main__':
# qB = torch.Tensor([[1, -1, 1, 1],
# [-1, -1, -1, 1],
# [1, 1, -1, 1],
# [1, 1, 1, -1]])
# rB = torch.Tensor([[1, -1, 1, -1],
# [-1, -1, 1, -1],
# [-1, -1, 1, -1],
# [1, 1, -1, -1],
# [-1, 1, -1, -1],
# [1, 1, -1, 1]])
# query_L = torch.Tensor([[0, 1, 0, 0],
# [1, 1, 0, 0],
# [1, 0, 0, 1],
# [0, 1, 0, 1]])
# retrieval_L = torch.Tensor([[1, 0, 0, 1],
# [1, 1, 0, 0],
# [0, 1, 1, 0],
# [0, 0, 1, 0],
# [1, 0, 0, 0],
# [0, 0, 1, 0]])
#
# map = calc_map_k(qB, rB, query_L, retrieval_L)
# print(map)

48
utils/get_args.py Normal file
View File

@ -0,0 +1,48 @@
import argparse
def get_args():
    """Build the argparse parser and parse the command-line arguments.

    Returns:
        argparse.Namespace with all training/testing hyper-parameters.
    """
    parser = argparse.ArgumentParser()
    # Model / checkpoint options.
    parser.add_argument("--hash-layer", type=str, default="select", help="choose a hash layer [select, linear] to run. select: select mechanism, linear: sign function.")
    parser.add_argument("--save-dir", type=str, default="./result/64-bit")
    parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
    parser.add_argument("--pretrained", type=str, default="")
    # Dataset options (typo "mirflckr25k" fixed: the actual key used by the
    # trainer and the default value is "flickr25k").
    parser.add_argument("--dataset", type=str, default="flickr25k", help="choose from [coco, flickr25k, nuswide]")
    parser.add_argument("--index-file", type=str, default="index.mat")
    parser.add_argument("--caption-file", type=str, default="caption.mat")
    parser.add_argument("--label-file", type=str, default="label.mat")
    # Loss options.
    parser.add_argument("--similarity-function", type=str, default="euclidean", help="choose from [cosine, euclidean]")
    parser.add_argument("--loss-type", type=str, default="l2", help="choose from [l1, l2]")
    parser.add_argument("--output-dim", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--max-words", type=int, default=32)
    parser.add_argument("--resolution", type=int, default=224)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--query-num", type=int, default=5120)
    parser.add_argument("--train-num", type=int, default=10240)
    # NOTE(review): the trainer also reads args.retrieval_num, which is not
    # declared here — presumably it is filled in elsewhere; verify.
    parser.add_argument("--lr-decay-freq", type=int, default=5)
    parser.add_argument("--display-step", type=int, default=50)
    parser.add_argument("--seed", type=int, default=1814)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--lr-decay", type=float, default=0.9)
    parser.add_argument("--clip-lr", type=float, default=0.00001)
    parser.add_argument("--weight-decay", type=float, default=0.2)
    parser.add_argument("--warmup-proportion", type=float, default=0.1,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument("--vartheta", type=float, default=0.5, help="the rate of error code.")
    parser.add_argument("--sim-threshold", type=float, default=0.1)
    parser.add_argument("--is-train", action="store_true")
    args = parser.parse_args()
    return args

22
utils/logger.py Normal file
View File

@ -0,0 +1,22 @@
import os
import logging
from torch.utils.tensorboard import SummaryWriter
def get_logger(filename=None):
    """Return the shared 'logger' instance, optionally logging to a file.

    basicConfig installs the console handler on the root logger; the file
    handler (when `filename` is given) is attached to the named logger
    itself and guarded against being added twice.

    Args:
        filename: optional path of a log file to append to.
    """
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is not None:
        target = os.path.abspath(filename)
        already_attached = any(
            isinstance(h, logging.FileHandler) and getattr(h, "baseFilename", None) == target
            for h in logger.handlers
        )
        if not already_attached:
            handler = logging.FileHandler(filename)
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
            # BUG FIX: the original attached this handler to the ROOT logger
            # (logging.getLogger().addHandler), so every call stacked one more
            # handler and every third-party library's records went to the file.
            logger.addHandler(handler)
    return logger
def get_summary_writer(dirname: str):
    """Create `dirname` if needed and return a TensorBoard SummaryWriter on it."""
    os.makedirs(dirname, exist_ok=True)
    writer = SummaryWriter(log_dir=dirname)
    return writer

229
utils/utils.py Normal file
View File

@ -0,0 +1,229 @@
import torch
import numpy as np
from typing import Union
import torch.nn as nn
from torch.nn import functional as F
from sklearn.metrics.pairwise import euclidean_distances
def compute_metrics(x):
    """Recall@K / rank metrics from a (query, candidate) similarity matrix.

    The diagonal holds the ground-truth match. (Translated from the original
    Chinese comment:) the matrix is negated so that sorting ascending ranks
    higher similarity first; the rank of each diagonal entry is then read off
    as the column of the first zero of (sorted_row - diagonal).
    """
    neg_sorted = np.sort(-x, axis=1)
    diag = np.diag(-x)[:, np.newaxis]
    ranks = np.where(neg_sorted - diag == 0)[1]
    metrics = {}
    metrics['R1'] = float(np.sum(ranks == 0)) * 100 / len(ranks)
    metrics['R5'] = float(np.sum(ranks < 5)) * 100 / len(ranks)
    metrics['R10'] = float(np.sum(ranks < 10)) * 100 / len(ranks)
    metrics['MR'] = np.median(ranks) + 1
    metrics["MedianR"] = metrics['MR']
    metrics["MeanR"] = np.mean(ranks) + 1
    metrics["cols"] = [int(i) for i in list(ranks)]
    return metrics
def encode_hash(a: Union[torch.Tensor, np.ndarray]):
    """Sign-binarize a code matrix; handles torch tensors and numpy arrays."""
    if isinstance(a, torch.Tensor):
        return torch.sign(a)
    return np.sign(a)
def calc_neighbor(a: torch.Tensor, b: torch.Tensor):
    """1.0 where the two label vectors share at least one positive class."""
    overlap = a.matmul(b.transpose(0, 1))
    neighbors = overlap > 0
    return neighbors.float()
def euclidean_similarity(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]):
    """Pairwise euclidean distance matrix for torch tensors or numpy arrays."""
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.cdist(a, b, p=2.0)
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return euclidean_distances(a, b)
    raise ValueError("input value must in [torch.Tensor, numpy.ndarray], but it is %s, %s"%(type(a), type(b)))
def euclidean_dist_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor):
    """
    Calculate the pairwise euclidean distance via the inner-product expansion
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>.

    :param tensor1: a tensor with shape (a, c)
    :param tensor2: a tensor with shape (b, c)
    :return: (a, b) matrix of distances between rows of tensor1 and tensor2.
    """
    dim1 = tensor1.shape[0]
    dim2 = tensor2.shape[0]
    multi = torch.matmul(tensor1, tensor2.t())
    a2 = torch.sum(torch.pow(tensor1, 2), dim=1, keepdim=True).expand(dim1, dim2)
    b2 = torch.sum(torch.pow(tensor2, 2), dim=1, keepdim=True).t().expand(dim1, dim2)
    # BUG FIX: floating-point cancellation can make the radicand slightly
    # negative (e.g. the distance of a row to itself), which sqrt turns into
    # NaN; clamp at zero first.
    dist = torch.sqrt(torch.clamp(a2 + b2 - 2 * multi, min=0.0))
    return dist
def cosine_similarity(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]):
    """Pairwise cosine similarity; each input is row-normalized unless the
    whole matrix is zero (which would make the norm division blow up)."""
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        if len(torch.where(a != 0)[0]) > 0:
            a = a / a.norm(dim=-1, keepdim=True)
        if len(torch.where(b != 0)[0]) > 0:
            b = b / b.norm(dim=-1, keepdim=True)
        return torch.matmul(a, b.t())
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        if len(np.where(a != 0)[0]) > 0:
            a = a / np.linalg.norm(a, axis=-1, keepdims=True)
        if len(np.where(b != 0)[0]) > 0:
            b = b / np.linalg.norm(b, axis=-1, keepdims=True)
        return np.matmul(a, b.T)
    raise ValueError("input value must in [torch.Tensor, numpy.ndarray], but it is %s, %s"%(type(a), type(b)))
def calc_map_k(qB, rB, query_L, retrieval_L, k=None, rank=0):
    # qB: {-1,+1}^{mxq}
    # rB: {-1,+1}^{nxq}
    # query_L: {0,1}^{mxl}
    # retrieval_L: {0,1}^{nxl}
    """mAP@k with a per-query loop; codes are re-binarized with sign().

    NOTE(review): near-duplicate of calc_map_k in calc_utils.py, except this
    one calls calcHammingDist (defined later in this module) — consider
    consolidating.
    """
    num_query = query_L.shape[0]
    qB = torch.sign(qB)
    rB = torch.sign(rB)
    map = 0
    if k is None:
        k = retrieval_L.shape[0]
    for iter in range(num_query):
        q_L = query_L[iter]
        if len(q_L.shape) < 2:
            q_L = q_L.unsqueeze(0) # [1, hash length]
        gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
        tsum = torch.sum(gnd)
        if tsum == 0:
            # A query without any relevant item contributes nothing.
            continue
        hamm = calcHammingDist(qB[iter, :], rB)
        _, ind = torch.sort(hamm)
        ind.squeeze_()
        gnd = gnd[ind]
        total = min(k, int(tsum))
        count = torch.arange(1, total + 1).type(torch.float32)
        # 1-based ranks of the relevant items in the ranking.
        tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float32) + 1.0
        if tindex.is_cuda:
            count = count.to(rank)
        map = map + torch.mean(count / tindex)
    map = map / num_query
    return map
def softmax_hash(code: Union[torch.Tensor, np.ndarray], dim_alph=0.25):
    """Softmax-threshold binarization.

    Entries whose softmax probability reaches 1 / int(dim_alph * dim) become
    +1, the rest -1 (entries exactly at the threshold keep their softmax
    value, matching the original two-step np.where semantics).

    Args:
        code: (…, dim) logits; torch tensors are returned on their original
            device, ndarrays as ndarrays.
        dim_alph: fraction of the last dimension used to set the threshold.
    """
    device = None
    if isinstance(code, torch.Tensor):
        device = code.device
        # BUG FIX: the original wrote `code.detach.cpu()` (missing call
        # parentheses), which raised AttributeError for CUDA tensors; also
        # detach before .numpy() so tensors with grad don't raise.
        code = code.detach().cpu().numpy()
    softmax_code = np.exp(code) / np.exp(code).sum(axis=-1, keepdims=True)
    threshold = 1 / int(dim_alph * softmax_code.shape[-1])
    hash_code = np.where(softmax_code >= threshold, softmax_code, -1)
    hash_code = np.where(hash_code <= threshold, hash_code, 1)
    if device is not None:
        hash_code = torch.from_numpy(hash_code).to(device)
    return hash_code
# def calcHammingDist(B1, B2):
# result = np.zeros((B1.shape[0], B2.shape[0]))
# for i, data in enumerate(B1):
# result[i] = np.sum(np.where((data + B2) != 2, data + B2, 0), axis=-1)
# return result
def calcHammingDist(B1, B2):
    """Pairwise Hamming distance between rows of {-1, +1} codes.

    Works for torch tensors and numpy arrays, using the identity
    hamming = 0.5 * (q - <b1, b2>) for q-bit codes.
    """
    # BUG FIX: the original called B1.view(1, -1) without assigning the
    # result, so 1-D inputs were never actually promoted to 2-D (and
    # ndarray.view does not take a shape at all); reshape works for both,
    # and q must be read AFTER the promotion so a 1-D B2 has a shape[1].
    if len(B1.shape) < 2:
        B1 = B1.reshape(1, -1)
    if len(B2.shape) < 2:
        B2 = B2.reshape(1, -1)
    q = B2.shape[1]
    if isinstance(B1, torch.Tensor):
        distH = 0.5 * (q - torch.matmul(B1, B2.t()))
    elif isinstance(B1, np.ndarray):
        distH = 0.5 * (q - np.matmul(B1, B2.transpose()))
    else:
        raise ValueError("B1, B2 must in [torch.Tensor, np.ndarray]")
    return distH
def compute_hash_similarity(visual_embed, text_embed, use_softmax_hash=False, alph=0.25):
    """Binarize both embeddings and return the two cross-modal Hamming
    distance matrices. (Translated from the original Chinese comment:)
    a larger Hamming distance means LESS similar."""
    if use_softmax_hash:
        hash_visual = softmax_hash(visual_embed, alph)
        hash_text = softmax_hash(text_embed, alph)
    else:
        hash_visual = encode_hash(visual_embed)
        hash_text = encode_hash(text_embed)
    vt_similarity = calcHammingDist(hash_visual, hash_text)
    tv_similarity = calcHammingDist(hash_text, hash_visual)
    return vt_similarity, tv_similarity
class CrossEn(nn.Module):
    """InfoNCE-style cross entropy over a similarity matrix whose diagonal
    holds the positive pairs.

    mode="euclidean" negates the matrix first so that smaller distances act
    like larger similarities; any other mode uses the matrix as-is.
    """

    def __init__(self, mode="cosine"):
        super(CrossEn, self).__init__()
        self.mode = mode

    def forward(self, sim_matrix):
        if self.mode == "euclidean":
            sim_matrix = -sim_matrix
        # Log-probability of the correct (diagonal) match per row.
        diag_logprob = torch.diag(F.log_softmax(sim_matrix, dim=-1))
        return (-diag_logprob).mean()
class CrossEn_mean(nn.Module):
    """Degenerate loss head that simply averages the similarity matrix.

    The `mode` argument is stored for interface parity with CrossEn but does
    not change the computation.
    """

    def __init__(self, mode="cosine"):
        super(CrossEn_mean, self).__init__()
        self.mode = mode

    def forward(self, sim_matrix):
        sim_loss = sim_matrix.mean()
        return sim_loss