DCMHT
This commit is contained in:
commit
9e6030dbf9
|
|
@ -0,0 +1,78 @@
|
|||
# Differentiable Cross Modal Hashing via Multimodal Transformers
|
||||
|
||||
## Framework
|
||||
The main architecture of our method.
|
||||

|
||||
|
||||
We propose a selecting mechanism to generate hash codes that transforms the discrete space into a continuous space. The hash code is encoded as a $2D$ vector.
|
||||

|
||||
|
||||
## Dependencies
|
||||
We use Python to build our code; you need to install the following packages to run it:
|
||||
|
||||
- pytorch 1.9.1
|
||||
- sklearn
|
||||
- tqdm
|
||||
- pillow
|
||||
|
||||
## Training
|
||||
|
||||
### Processing dataset
|
||||
Before training, you need to download the original data from [coco](https://cocodataset.org/#download) (include 2017 train, val and annotations), [nuswide](https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html) (include all), [mirflickr25k](https://www.kaggle.com/datasets/paulrohan2020/mirflickr25k) (include mirflickr25k and mirflickr25k_annotations_v080),
|
||||
then use "data/make_XXX.py" to generate the .mat files.
|
||||
|
||||
For example:
|
||||
> cd COCO_DIR # include train val images and annotations files
|
||||
>
|
||||
> make mat
|
||||
>
|
||||
> cp DCMHT/data/make_coco.py mat
|
||||
>
|
||||
> python make_coco.py --coco-dir ../ --save-dir ./
|
||||
|
||||
After all the .mat files are generated, the `dataset` directory will look like this:
|
||||
~~~
|
||||
dataset
|
||||
├── base.py
|
||||
├── __init__.py
|
||||
├── dataloader.py
|
||||
├── coco
|
||||
│ ├── caption.mat
|
||||
│ ├── index.mat
|
||||
│ └── label.mat
|
||||
├── flickr25k
|
||||
│ ├── caption.mat
|
||||
│ ├── index.mat
|
||||
│ └── label.mat
|
||||
└── nuswide
|
||||
├── caption.txt # Notice! It is a txt file!
|
||||
├── index.mat
|
||||
└── label.mat
|
||||
~~~
|
||||
|
||||
### Download CLIP pretrained model
|
||||
The pretrained model URLs can be found around line 30 of [CLIP/clip/clip.py](https://github.com/openai/CLIP/blob/main/clip/clip.py). This code is based on "ViT-B/32".
|
||||
|
||||
You should copy ViT-B-32.pt to this dir.
|
||||
|
||||
### Start
|
||||
|
||||
After the dataset has been prepared, we can run the following command to train.
|
||||
> python main.py --is-train --hash-layer select --dataset coco --caption-file caption.mat --index-file index.mat --label-file label.mat --similarity-function euclidean --loss-type l2 --vartheta 0.75 --lr 0.0001 --output-dim 64 --save-dir ./result/coco/64 --clip-path ./ViT-B-32.pt --batch-size 256
|
||||
|
||||
|
||||
## Result
|
||||

|
||||
|
||||
## Acknowledgements
|
||||
[CLIP](https://github.com/openai/CLIP)
|
||||
|
||||
[SSAH](https://github.com/lelan-li/SSAH)
|
||||
|
||||
[GCH](https://github.com/DeXie0808/GCH)
|
||||
|
||||
[AGAH](https://github.com/WendellGul/AGAH)
|
||||
|
||||
[DADH](https://github.com/Zjut-MultimediaPlus/DADH)
|
||||
|
||||
[deep-cross-modal-hashing](https://github.com/WangGodder/deep-cross-modal-hashing)
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 98 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 271 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 190 KiB |
|
|
@ -0,0 +1,102 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import unicode_literals
|
||||
from __future__ import print_function
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
    """Cross-modal retrieval dataset yielding (image, caption tokens, label, index).

    `captions`, `indexs` and `labels` are parallel collections indexed by
    sample position: `indexs` holds image file paths (or decoded image arrays
    when ``npy=True``), `captions` holds a list of candidate captions per
    sample, and `labels` holds multi-hot numpy label vectors.
    """

    def __init__(self,
                 captions: dict,
                 indexs: dict,
                 labels: dict,
                 is_train=True,
                 tokenizer=None,
                 maxWords=32,
                 imageResolution=224,
                 npy=False):
        self.captions = captions
        self.indexs = indexs
        self.labels = labels
        # npy=True means `indexs` stores raw image arrays instead of file paths.
        self.npy = npy

        self.maxWords = maxWords
        # BUGFIX: the default used to be `tokenizer=Tokenizer()`, which Python
        # evaluates once at definition time, so every dataset silently shared a
        # single tokenizer instance. Build one lazily per dataset instead.
        self.tokenizer = Tokenizer() if tokenizer is None else tokenizer

        # Train: resize shorter side then center-crop; eval: deterministic
        # square resize. The normalization constants are CLIP's image mean/std.
        self.transform = Compose([
            Resize(imageResolution, interpolation=Image.BICUBIC),
            CenterCrop(imageResolution),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]) if is_train else Compose([
            Resize((imageResolution, imageResolution), interpolation=Image.BICUBIC),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
                              "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}

        self.__length = len(self.indexs)

    def __len__(self):
        return self.__length

    def _load_image(self, index: int) -> torch.Tensor:
        """Read the image for `index` and apply the preprocessing transform."""
        if not self.npy:
            image_path = self.indexs[index].strip()
            image = Image.open(image_path).convert("RGB")
        else:
            image = Image.fromarray(self.indexs[index]).convert("RGB")
        image = self.transform(image)

        return image

    def _load_text(self, index: int):
        """Tokenize one randomly chosen caption of sample `index`.

        Returns a tensor of exactly `maxWords` token ids: CLS + (truncated)
        caption + SEP, zero-padded on the right.
        """
        captions = self.captions[index]
        # A sample may carry several captions; pick one at random.
        use_cap = captions[random.randint(0, len(captions) - 1)]

        words = self.tokenizer.tokenize(use_cap)
        words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
        total_length_with_CLS = self.maxWords - 1  # reserve one slot for SEP
        if len(words) > total_length_with_CLS:
            words = words[:total_length_with_CLS]

        words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
        caption = self.tokenizer.convert_tokens_to_ids(words)

        while len(caption) < self.maxWords:
            caption.append(0)
        caption = torch.tensor(caption)

        return caption

    def _load_label(self, index: int) -> torch.Tensor:
        """Return the multi-hot label vector of sample `index` as a tensor."""
        label = self.labels[index]
        label = torch.from_numpy(label)

        return label

    def get_all_label(self):
        """Stack every sample's label into one (N, num_classes) int64 tensor."""
        labels = torch.zeros([self.__length, len(self.labels[0])], dtype=torch.int64)
        for i, item in enumerate(self.labels):
            labels[i] = torch.from_numpy(item)
        return labels

    def __getitem__(self, index):
        image = self._load_image(index)
        caption = self._load_text(index)
        label = self._load_label(index)

        return image, caption, label, index
|
||||
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
|
||||
from .base import BaseDataset
|
||||
import os
|
||||
import numpy as np
|
||||
import scipy.io as scio
|
||||
|
||||
|
||||
def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
    """Randomly partition the dataset into query / train / retrieval splits.

    Returns three tuples (indexs, captions, labels), each ordered as
    (query, train, retrieval). The retrieval split intentionally contains the
    train split as a subset (everything except the query samples).
    """
    np.random.seed(seed=seed)
    shuffled = np.random.permutation(range(len(indexs)))

    # Slice the shuffled positions once, then reuse them for every modality.
    selections = (
        shuffled[: query_num],                        # query
        shuffled[query_num: query_num + train_num],   # train
        shuffled[query_num:],                         # retrieval
    )

    split_indexs = tuple(indexs[sel] for sel in selections)
    split_captions = tuple(captions[sel] for sel in selections)
    split_labels = tuple(labels[sel] for sel in selections)
    return split_indexs, split_captions, split_labels
|
||||
|
||||
def dataloader(captionFile: str,
               indexFile: str,
               labelFile: str,
               maxWords=32,
               imageResolution=224,
               query_num=5000,
               train_num=10000,
               seed=None,
               npy=False):
    """Load caption/index/label files and build train/query/retrieval datasets.

    captionFile may be .mat (key "caption") or .txt (one caption per line);
    indexFile is .mat (key "index") unless ``npy=True``; labelFile is .mat
    (key "category"). Returns (train_data, query_data, retrieval_data).
    """
    if captionFile.endswith("mat"):
        captions = scio.loadmat(captionFile)["caption"]
        # A 1xN cell array collapses to its single row.
        if captions.shape[0] == 1:
            captions = captions[0]
    elif captionFile.endswith("txt"):
        with open(captionFile, "r") as f:
            captions = np.asarray([[line.strip()] for line in f.readlines()])
    else:
        raise ValueError("the format of 'captionFile' doesn't support, only support [txt, mat] format.")

    indexs = np.load(indexFile, allow_pickle=True) if npy else scio.loadmat(indexFile)["index"]
    labels = scio.loadmat(labelFile)["category"]

    split_indexs, split_captions, split_labels = split_data(
        captions, indexs, labels, query_num=query_num, train_num=train_num, seed=seed)

    # Splits are ordered (query, train, retrieval); only training augments.
    shared = dict(maxWords=maxWords, imageResolution=imageResolution, npy=npy)
    train_data = BaseDataset(captions=split_captions[1], indexs=split_indexs[1],
                             labels=split_labels[1], **shared)
    query_data = BaseDataset(captions=split_captions[0], indexs=split_indexs[0],
                             labels=split_labels[0], is_train=False, **shared)
    retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2],
                                 labels=split_labels[2], is_train=False, **shared)

    return train_data, query_data, retrieval_data
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_index(jsonData: dict, indexDict: dict):
    """Group COCO-style annotation lists into per-key dictionaries.

    indexDict maps a section of jsonData (e.g. "images", "annotations") to a
    [key_field, value_field] pair. For each section, the entries are grouped
    as {entry[key_field]: [entry[value_field], ...]}; the grouped dicts are
    returned in the same order as indexDict's keys.
    """
    result = []
    for name, (key_field, value_field) in indexDict.items():
        grouped = {}
        for entry in jsonData[name]:
            grouped.setdefault(entry[key_field], []).append(entry[value_field])
        result.append(grouped)

    return result
|
||||
|
||||
def check_file_exist(indexDict: dict, file_path: str):
    """Resolve each entry's file name against `file_path`, dropping missing files.

    indexDict maps image id -> [file_name, ...]; on return it maps
    image id -> absolute path, with entries whose file does not exist removed.
    Mutates and returns `indexDict`.
    """
    keys = list(indexDict.keys())
    for item in keys:
        if not os.path.exists(os.path.join(file_path, indexDict[item][0])):
            # File missing on disk: report and drop the entry entirely.
            print(item, indexDict[item])
            indexDict.pop(item)
            # BUGFIX: the original fell through and re-read the popped key
            # on the next line, raising KeyError for every missing file.
            continue
        indexDict[item] = os.path.join(file_path, indexDict[item][0])
    return indexDict
|
||||
|
||||
def chage_categories2numpy(category_ids: dict, data: dict):
    """Turn each sample's category-id list into a multi-hot numpy vector.

    category_ids maps a raw category id to its contiguous class index.
    Mutates `data` in place (id list -> numpy array) and returns it.
    """
    num_classes = len(category_ids)
    for key, id_list in list(data.items()):
        multi_hot = [0] * num_classes
        for cid in id_list:
            multi_hot[category_ids[cid]] = 1
        data[key] = np.asarray(multi_hot)

    return data
|
||||
|
||||
def get_all_use_key(categoryDict: dict):
    """Return every key of `categoryDict` as a list (the usable image ids)."""
    return [*categoryDict]
|
||||
|
||||
def remove_not_use(data: dict, used_key: list):
    """Drop, in place, every entry of `data` whose key is not in `used_key`.

    Returns the (mutated) `data` dict for chaining.
    """
    for key in list(data):
        if key not in used_key:
            del data[key]
    return data
|
||||
|
||||
def merge_to_list(data: dict):
    """Return the values of `data` ordered by ascending key.

    All parallel dicts (index/caption/label) pass through this with the same
    key set, so the resulting lists stay aligned sample-by-sample.
    """
    return [data[key] for key in sorted(data.keys())]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import json
    import scipy.io as scio
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--coco-dir", default="./", type=str, help="the coco dataset dir")
    parser.add_argument("--save-dir", default="./", type=str, help="mat file saved dir")
    args = parser.parse_args()

    # --- train2017 captions: build image-id -> file name and image-id -> captions maps ---
    PATH = args.coco_dir
    jsonFile = os.path.join(PATH, "annotations", "captions_train2017.json")
    with open(jsonFile, "r") as f:
        jsonData = json.load(f)
    indexDict = {"images": ["id", "file_name"], "annotations": ["image_id", "caption"]}
    result = make_index(jsonData, indexDict)
    indexDict_, captionDict = result
    # Resolve file names to full paths; images missing on disk are dropped.
    indexDict_ = check_file_exist(indexDict_, os.path.join(PATH, "train2017"))
    print("caption:", len(indexDict_), len(captionDict))
    # --- train2017 instances: per-image category ids turned into multi-hot labels ---
    jsonFile = os.path.join(PATH, "annotations", "instances_train2017.json")
    with open(jsonFile, "r") as f:
        jsonData = json.load(f)
    # Map COCO's sparse category ids onto contiguous class indices.
    categroy_ids = {}
    for i, item in enumerate(jsonData['categories']):
        categroy_ids.update({item['id']: i})
    indexDict = {"annotations": ["image_id", "category_id"], "images": ["id", "file_name"]}
    result = make_index(jsonData, indexDict)
    categoryDict = result[0]
    cateIndexDict = result[1]
    categoryDict = chage_categories2numpy(categroy_ids, categoryDict)
    # Unify indices: keep only image ids that carry instance annotations.
    used_key = get_all_use_key(categoryDict)
    indexDict_ = remove_not_use(indexDict_, used_key)
    captionDict = remove_not_use(captionDict, used_key)
    categoryIndexDict = remove_not_use(cateIndexDict, used_key)
    categoryDict = remove_not_use(categoryDict, used_key)
    # Convert the aligned dicts into lists ordered by ascending image id.
    indexList = merge_to_list(indexDict_)
    captionList = merge_to_list(captionDict)
    categoryIndexList = merge_to_list(categoryIndexDict)
    categoryList = merge_to_list(categoryDict)
    print("result", len(indexDict_), len(categoryDict))
    print("category:", len(categoryDict), len(categoryIndexList))
    # Sanity check: captions and categories must describe the same images in the same order.
    for i in range(len(indexList)):
        if indexList[i] != categoryIndexList[i]:
            print("Not the same:", i, indexList[i], categoryIndexList[i])

    # --- repeat the same pipeline for val2017 ---
    val_jsonFile = os.path.join(PATH, "annotations", "captions_val2017.json")
    with open(val_jsonFile, "r") as f:
        jsonData = json.load(f)
    indexDict = {"images": ["id", "file_name"], "annotations": ["image_id", "caption"]}
    result = make_index(jsonData, indexDict)
    val_indexDict = result[0]
    val_captionDict = result[1]
    val_indexDict = check_file_exist(val_indexDict, os.path.join(PATH, "val2017"))
    jsonFile = os.path.join(PATH, "annotations", "instances_val2017.json")
    with open(jsonFile, "r") as f:
        jsonData = json.load(f)
    categroy_ids = {}
    for i, item in enumerate(jsonData['categories']):
        categroy_ids.update({item['id']: i})
    indexDict = {"annotations": ["image_id", "category_id"], "images": ["id", "file_name"]}
    result = make_index(jsonData, indexDict)
    val_categoryDict = result[0]
    val_categoryIndexDict = result[1]
    val_categoryDict = chage_categories2numpy(categroy_ids, val_categoryDict)
    used_key = get_all_use_key(val_categoryDict)
    val_indexDict = remove_not_use(val_indexDict, used_key)
    val_captionDict = remove_not_use(val_captionDict, used_key)
    val_categoryIndexDict = remove_not_use(val_categoryIndexDict, used_key)
    val_categoryDict = remove_not_use(val_categoryDict, used_key)

    val_indexList = merge_to_list(val_indexDict)
    val_captionList = merge_to_list(val_captionDict)
    val_categoryIndexList = merge_to_list(val_categoryIndexDict)
    val_categoryList = merge_to_list(val_categoryDict)

    # train and val are concatenated into one combined dataset.
    indexList.extend(val_indexList)
    captionList.extend(val_captionList)
    categoryIndexList.extend(val_categoryIndexList)
    categoryList.extend(val_categoryList)

    print(len(indexList), len(captionList), len(categoryIndexList))
    indexs = {"index": indexList}
    captions = {"caption": captionList}
    categorys = {"category": categoryList}

    # Output format matches what dataset/dataloader.py expects.
    scio.savemat(os.path.join(args.save_dir, "index.mat"), indexs)
    scio.savemat(os.path.join(args.save_dir, "caption.mat"), captions)
    scio.savemat(os.path.join(args.save_dir, "label.mat"), categorys)
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
import os
import scipy.io as scio
import numpy as np

# Build index.mat / caption.mat / label.mat for MIRFLICKR-25K
# (needs mirflickr25k_annotations_v080 and mirflickr).
# Usage (run from inside the download dir):
#   mkdir mat
#   mv make_mirflickr25k.py mat
#   python make_mirflickr25k.py
root_dir = "PATH/TO/YOUR/DOWNLOAD/DIR/"

# BUGFIX: these sub-paths had a leading "/", which makes os.path.join discard
# root_dir entirely (os.path.join(a, "/b") == "/b"); they are relative now.
file_path = os.path.join(root_dir, "mirflickr25k_annotations_v080")

file_list = os.listdir(file_path)

# Skip the strict "_r1" annotation variants and the README.
file_list = [item for item in file_list if "_r1" not in item and "README" not in item]

print("class num:", len(file_list))

# annotation file name -> class index
class_index = {}
for i, item in enumerate(file_list):
    class_index.update({item: i})

# image id (string) -> multi-hot numpy label vector
label_dict = {}
for path_id in file_list:
    path = os.path.join(file_path, path_id)
    with open(path, "r") as f:
        # Each annotation file lists the image ids belonging to that class.
        for item in f:
            item = item.strip()
            if item not in label_dict:
                label = np.zeros(len(file_list))
                label[class_index[path_id]] = 1
                label_dict.update({item: label})
            else:
                label_dict[item][class_index[path_id]] = 1

print("create label:", len(label_dict))
# Sort image ids (lexicographically) so index/caption/label lists stay aligned.
keys = list(label_dict.keys())
keys.sort()

labels = []
for key in keys:
    labels.append(label_dict[key])
print("labels created:", len(labels))
labels = {"category": labels}

PATH = os.path.join(root_dir, "mirflickr")
# Images are named im<id>.jpg inside the mirflickr dir.
index = [os.path.join(PATH, "im" + item + ".jpg") for item in keys]
print("index created:", len(index))
index = {"index": index}

captions_path = os.path.join(root_dir, "mirflickr/meta/tags")
captions_list = os.listdir(captions_path)
captions_dict = {}
for item in captions_list:
    # "tags123.txt" -> "123"
    id_ = item.split(".")[0].replace("tags", "")
    caption = ""
    with open(os.path.join(captions_path, item), "r") as f:
        # Join all tag words of an image into one space-separated caption.
        for word in f.readlines():
            caption += word.strip() + " "
    caption = caption.strip()
    captions_dict.update({id_: caption})

captions = []

for item in keys:
    captions.append([captions_dict[item]])

print("captions created:", len(captions))
captions = {"caption": captions}

# Output format matches what dataset/dataloader.py expects.
scio.savemat(os.path.join(root_dir, "mat/index.mat"), index)
scio.savemat(os.path.join(root_dir, "mat/caption.mat"), captions)
scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
import os
import scipy.io as scio
import numpy as np

# Build index.mat / label.mat / caption.txt for NUS-WIDE.
# Usage (run from inside the download dir):
#   mkdir mat
#   mv make_nuswide.py mat
#   python make_nuswide.py
root_dir = "PATH/TO/YOUR/DOWNLOAD/DIR/"


# BUGFIX: these sub-paths had a leading "/", which makes os.path.join discard
# root_dir entirely (os.path.join(a, "/b") == "/b"); they are relative now.
imageListFile = os.path.join(root_dir, "Low-Level-Features/ImageList/Imagelist.txt")
labelPath = os.path.join(root_dir, "nuswide/Groundtruth/AllLabels")
textFile = os.path.join(root_dir, "Low-Level-Features/NUS_WID_Tags/All_Tags.txt")
classIndexFile = os.path.join(root_dir, "Low-Level-Features/Concepts81.txt")

# you can use the image urls to download images
imagePath = os.path.join(root_dir, "nuswide/Flickr")

with open(imageListFile, "r") as f:
    indexs = f.readlines()

# The image list uses Windows-style separators; normalize and prepend the image dir.
indexs = [os.path.join(imagePath, item.strip().replace("\\", "/")) for item in indexs]
print("indexs length:", len(indexs))

captions = []
with open(textFile, "r") as f:
    for line in f:
        if len(line.strip()) == 0:
            print("some line empty!")
            continue
        # Each line is "<image id> <tag> <tag> ..."; keep only the tags.
        caption = line.split()[1:]
        caption = " ".join(caption).strip()
        if len(caption) == 0:
            # Placeholder text so every image keeps a (dummy) caption.
            caption = "123456"
        captions.append(caption)

print("captions length:", len(captions))

# Only the concepts listed in used_label.txt become label columns.
with open(os.path.join(root_dir, "nuswide/Groundtruth/used_label.txt")) as f:
    label_lists = f.readlines()
label_lists = [item.strip() for item in label_lists]

class_index = {}
for i, item in enumerate(label_lists):
    class_index.update({item: i})

labels = np.zeros([len(indexs), len(class_index)], dtype=np.int8)

# Each ground-truth file holds one 0/1 flag per image for a single concept.
for item in label_lists:
    path = os.path.join(labelPath, item)
    class_label = item

    with open(path, "r") as f:
        data = f.readlines()
        for i, val in enumerate(data):
            labels[i][class_index[class_label]] = 1 if val.strip() == "1" else 0
print("labels sum:", labels.sum())

not_used_id = []
with open(os.path.join(root_dir, "nuswide/Groundtruth/not_used_id.txt")) as f:
    not_used_id = f.readlines()
not_used_id = [int(item.strip()) for item in not_used_id]

# Drop the unused samples; blanking entries first avoids index shifting
# while the original positions are still being referenced.
ind = list(range(len(indexs)))
for item in not_used_id:
    ind.remove(item)
    indexs[item] = ""
    captions[item] = ""
indexs = [item for item in indexs if item != ""]
captions = [item for item in captions if item != ""]
ind = np.asarray(ind)
labels = labels[ind]

print("indexs length:", len(indexs))
print("captions length:", len(captions))
print("labels shape:", labels.shape)

indexs = {"index": indexs}
captions = {"caption": captions}
labels = {"category": labels}

scio.savemat(os.path.join(root_dir, "mat/index.mat"), indexs)
scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)

# Captions are written as plain text, one per line (see README: NUS-WIDE
# uses caption.txt instead of caption.mat).
captions = [item + "\n" for item in captions["caption"]]

with open(os.path.join(root_dir, "mat/caption.txt"), "w") as f:
    f.writelines(captions)

print("finished!")
|
||||
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from train.hash_train import Trainer


if __name__ == "__main__":
    # Entry point: constructing Trainer kicks off the whole training/eval run
    # (CLI arguments such as --is-train, --dataset, --output-dim are documented
    # in the README's Training section).
    Trainer()
|
||||
|
||||
|
||||
|
|
@ -0,0 +1 @@
|
|||
# Re-export the public CLIP API (available_models, load, tokenize) at package level.
from .clip import *
|
||||
Binary file not shown.
|
|
@ -0,0 +1,224 @@
|
|||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import Any, Union, List
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
from .model import build_model
|
||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
BICUBIC = InterpolationMode.BICUBIC
|
||||
except ImportError:
|
||||
BICUBIC = Image.BICUBIC
|
||||
|
||||
|
||||
def version_to_tuple(version):
    """Parse up to the first three numeric components of a version string.

    Non-numeric tails such as the "+cu110" in "1.7.1+cu110" are ignored;
    a component with no leading digits counts as 0.
    """
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# BUGFIX: the original compared the split version strings lexicographically,
# so ["1", "10", "0"] < ["1", "7", "1"] was True and torch >= 1.10 warned
# spuriously. Compare numeric tuples instead.
if version_to_tuple(torch.__version__) < (1, 7, 1):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
||||
|
||||
|
||||
__all__ = ["available_models", "load", "tokenize"]
# Module-wide BPE tokenizer instance shared by tokenize().
_tokenizer = _Tokenizer()

# Official OpenAI checkpoint URLs. The second-to-last path segment of each URL
# is the file's SHA256 digest, which _download() uses to verify the download.
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
|
||||
|
||||
|
||||
def _download(url: str, root: str):
    """Download `url` into directory `root` and return the local file path.

    The expected SHA256 is taken from the URL's second-to-last path segment
    (see _MODELS). An existing file with a matching checksum is reused; a
    mismatching one is re-downloaded. Raises RuntimeError if the target path
    is not a regular file or the final checksum still does not match.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    # Stream in 8 KiB chunks with a tqdm progress bar sized from Content-Length.
    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        # BUGFIX: message previously read "does not not match".
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
|
||||
|
||||
|
||||
def _transform(n_px):
    """Build CLIP's image preprocessing pipeline for an `n_px` input resolution."""
    def _to_rgb(image):
        return image.convert("RGB")

    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _to_rgb,
        ToTensor(),
        # CLIP's dataset image mean / std.
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ]
    return Compose(steps)
|
||||
|
||||
|
||||
def available_models() -> List[str]:
    """Return the names of the CLIP models that can be downloaded."""
    return [*_MODELS]
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive; if jit was not requested, load onto CPU so the
        # state dict can be rebuilt into a plain module below
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict (the file is not a TorchScript archive)
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # rebuild a plain nn.Module from the weights; none of the graph
        # patching below applies in this path
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            # CLIP checkpoints store fp16 weights; fp16 ops are unsupported on CPU
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names baked into the TorchScript graph: trace a tiny
    # graph that contains the desired device constant, then copy that constant
    # over every "cuda"-valued constant in the model's graphs
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU, using the same trace-and-copy trick with
    # an aten::to node whose dtype argument is float32
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:  # 5 == torch half dtype constant
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
|
||||
|
||||
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Tokenize one string or a list of strings for CLIP.

    Each text is wrapped in start/end-of-text tokens and right-padded with
    zeros to `context_length` (77 for all CLIP models). When `truncate` is
    True, over-long inputs are cut and their last token replaced by the
    end-of-text token; otherwise a RuntimeError is raised.

    Returns a LongTensor of shape [number of input strings, context_length].
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    result = torch.zeros(len(texts), context_length, dtype=torch.long)

    for row, text in enumerate(texts):
        tokens = [sot_token] + _tokenizer.encode(text) + [eot_token]
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[row]} is too long for context length {context_length}")
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[row, :len(tokens)] = torch.tensor(tokens)

    return result
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
import os
|
||||
import torch
|
||||
import logging
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
|
||||
from model.model import build_model
|
||||
from utils import get_logger, get_summary_writer
|
||||
|
||||
def weights_init_kaiming(m):
    """Kaiming-style initializer for a single module (use via module.apply).

    Linear: kaiming-uniform weight (fan_out) and zero bias.
    Conv:   kaiming-normal weight (fan_in) and zero bias (if present).
    BatchNorm (affine): unit weight and zero bias.
    Other module types are left untouched.
    """
    kind = m.__class__.__name__
    if 'Linear' in kind:
        nn.init.kaiming_uniform_(m.weight, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)
    elif 'Conv' in kind:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'BatchNorm' in kind:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
|
||||
|
||||
|
||||
class LinearHash(nn.Module):
    """Simple hashing head: one linear projection, dropout, then tanh.

    The tanh squashes each bit's logit into (-1, 1), acting as a relaxed
    (differentiable) binary hash code.
    """

    def __init__(self, inputDim=2048, outputDim=64):
        super(LinearHash, self).__init__()
        # One logit per hash bit.
        self.fc = nn.Linear(inputDim, outputDim)
        self.fc.apply(weights_init_kaiming)
        self.drop_out = nn.Dropout(p=0.2)

    def forward(self, data):
        projected = self.fc(data)
        return torch.tanh(self.drop_out(projected))
|
||||
|
||||
|
||||
class HashLayer(nn.Module):
    """Selecting-mechanism hashing head.

    A shared linear layer maps the backbone embedding to LINEAR_EMBED
    dimensions; each hash bit then has its own 2-way classifier, and the
    per-bit softmax over the 2 classes is the relaxed hash code.
    """

    LINEAR_EMBED = 128
    SIGMOID_ALPH = 10

    def __init__(self, inputDim=2048, outputDim=64):
        super(HashLayer, self).__init__()
        self.fc = nn.Linear(inputDim, self.LINEAR_EMBED)
        self.fc.apply(weights_init_kaiming)
        # One tiny 2-class classifier per output bit.
        self.hash_list = nn.ModuleList(
            nn.Linear(self.LINEAR_EMBED, 2) for _ in range(outputDim)
        )
        for head in self.hash_list:
            head.apply(weights_init_kaiming)

    def forward(self, data):
        shared = torch.relu(self.fc(data))
        # List of [batch, 2] distributions, one entry per hash bit.
        return [torch.softmax(head(shared), dim=-1) for head in self.hash_list]
|
||||
|
||||
class HashLayer_easy_logic(nn.Module):
    """Vectorised variant of HashLayer.

    Instead of one small classifier per bit, a single linear layer emits
    ``outputDim * 2`` logits which are reshaped to [batch, bit, 2] and
    softmaxed along the last axis — one 2-way distribution per hash bit.
    """

    LINEAR_EMBED = 128
    SIGMOID_ALPH = 10

    def __init__(self, inputDim=2048, outputDim=64):
        # Bug fix: the original called super(HashLayer, self).__init__(),
        # which raises TypeError because HashLayer is not in this class's
        # MRO, and then iterated a non-existent self.hash_list attribute
        # (AttributeError). Both made the class unconstructible.
        super(HashLayer_easy_logic, self).__init__()
        self.bit = outputDim
        self.fc = nn.Linear(inputDim, outputDim * 2)
        self.fc.apply(weights_init_kaiming)

    def forward(self, data):
        embed = self.fc(data)
        # [batch, bit, 2]: two logits per hash bit.
        logits = embed.view(embed.shape[0], self.bit, 2)
        return torch.softmax(logits, dim=-1)
|
||||
|
||||
|
||||
class DCMHT(nn.Module):
    """Differentiable Cross-Modal Hashing via Multimodal Transformers.

    Wraps a pretrained CLIP backbone and attaches one hashing head per
    modality (image and text). ``linear=True`` selects LinearHash (tanh
    codes); otherwise the selecting-mechanism HashLayer (per-bit softmax
    codes) is used.
    """

    def __init__(self,
                outputDim=64,
                clipPath="./ViT-B-32.pt",
                writer=None,
                saveDir="./result/log",
                logger: logging.Logger=None,
                is_train=True,
                linear=False):
        super(DCMHT, self).__init__()
        os.makedirs(saveDir, exist_ok=True)
        # Log file name depends on the train/test phase.
        self.logger = logger if logger is not None else get_logger(os.path.join(saveDir, "train.log" if is_train else "test.log"))
        # NOTE(review): when writer is None this creates a summary writer
        # even if is_train is False — confirm that is intended.
        self.writer = writer if writer is not None and is_train else get_summary_writer(os.path.join(saveDir, "tensorboard"))
        # embedDim is CLIP's joint embedding size, read from the checkpoint.
        embedDim, self.clip = self.load_clip(clipPath)
        # if is_train:
        #     self.clip.eval()
        #     self.freezen()
        self.image_hash = LinearHash(inputDim=embedDim, outputDim=outputDim) if linear else HashLayer(inputDim=embedDim, outputDim=outputDim)
        self.text_hash = LinearHash(inputDim=embedDim, outputDim=outputDim) if linear else HashLayer(inputDim=embedDim, outputDim=outputDim)

    def freezen(self):
        # Freeze most CLIP parameters, keeping the final norms/projections,
        # the logit scale and high-numbered transformer blocks trainable.
        for name, param in self.clip.named_parameters():
            if name.find("ln_final.") == 0 or name.find("text_projection") == 0 or name.find("logit_scale") == 0 \
                    or name.find("visual.ln_post.") == 0 or name.find("visual.proj") == 0:
                continue
            elif name.find("visual.transformer.resblocks.") == 0 or name.find("transformer.resblocks.") == 0:
                # Keep resblocks with index >= 12 trainable.
                layer_num = int(name.split(".resblocks.")[1].split(".")[0])
                if layer_num >= 12:
                    continue
            # NOTE(review): parameter names are dotted ("visual.conv2..."),
            # so find("conv2.") == 0 likely never matches — verify intent.
            if name.find("conv2.") == 0:
                continue
            else:
                # paramenters which < freeze_layer_num will be freezed
                param.requires_grad = False

    def load_clip(self, clipPath: str) -> tuple:
        """Load a CLIP checkpoint (JIT archive or plain state dict).

        Returns (embedding dimension, CLIP module built from the weights).
        """
        try:
            # JIT archive path: load the scripted model just to grab weights.
            model = torch.jit.load(clipPath, map_location="cpu").eval()
            state_dict = model.state_dict()
        except RuntimeError:
            # Fall back to a raw state-dict checkpoint.
            state_dict = torch.load(clipPath, map_location="cpu")

        return state_dict["text_projection"].shape[1], build_model(state_dict)

    def encode_image(self, image):
        """CLIP image features passed through the image hashing head."""
        image_embed = self.clip.encode_image(image)
        image_embed = self.image_hash(image_embed)

        return image_embed

    def eval(self):
        # Intentionally overrides nn.Module.eval: only the hashing heads
        # are switched; CLIP's mode is deliberately left untouched.
        self.image_hash.eval()
        self.text_hash.eval()
        # self.clip.eval()

    def train(self):
        # Intentionally overrides nn.Module.train (no `mode` argument):
        # only the hashing heads are switched to training mode.
        self.image_hash.train()
        self.text_hash.train()

    def encode_text(self, text):
        """CLIP text features passed through the text hashing head."""
        text_embed = self.clip.encode_text(text)
        text_embed = self.text_hash(text_embed)

        return text_embed

    def forward(self, image, text):
        # Returns (image hash codes, text hash codes).
        return self.encode_image(image), self.encode_text(text)
|
||||
|
||||
|
|
@ -0,0 +1,450 @@
|
|||
from collections import OrderedDict
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """CLIP-style ResNet bottleneck (1x1 -> 3x3 -> 1x1, 4x expansion).

    All convolutions use stride 1; spatial downsampling happens through an
    AvgPool2d placed after the 3x3 conv (anti-aliased striding). The
    shortcut likewise uses avgpool followed by a stride-1 projection when
    the shape changes.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity when no downsampling is requested.
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # Shortcut branch: avgpool first, then a stride-1 projection conv.
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(self.avgpool(out)))

        return self.relu(out + shortcut)
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
    """Attention-based global pooling over a 2D feature map.

    The flattened feature map is prepended with its mean token; attention
    is computed over all tokens and only the first (mean-query) token's
    output is returned, projected to ``output_dim``.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # +1 position accounts for the prepended mean token.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Output projection; falls back to embed_dim when output_dim is None.
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        # Prepend the spatial mean as the pooling query token.
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        # Functional multi-head attention with separate q/k/v weights
        # (use_separate_proj_weight=True); c_proj acts as the out-projection.
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # Only the pooled (mean-query) token is returned: shape [N, output_dim].
        return x[0]
|
||||
|
||||
|
||||
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # 3-layer stem; resolution is halved by conv1 and again by avgpool.
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages; _inplanes mutates while the stages are built.
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one residual stage; only its first block may downsample."""
        stage = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        stage.extend(Bottleneck(self._inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        # Match the input dtype to the (possibly fp16) stem weights.
        x = x.type(self.conv1.weight.dtype)
        # Stem: three conv-bn-relu stages followed by a 2x2 average pool.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            x = self.relu(bn(conv(x)))
        x = self.avgpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16.

    Normalisation always runs in fp32 for numerical stability; the result
    is cast back to the input's original dtype.
    """

    def forward(self, x: torch.Tensor):
        original_dtype = x.dtype
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(original_dtype)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation used by CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        return torch.sigmoid(1.702 * x) * x
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x)).

    ``attn_mask`` may be a tensor or a callable mapping the sequence length
    to a mask (used for lazily-built causal masks).
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        # 4x-expansion MLP with QuickGELU activation.
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attention over x (LND layout), applying attn_mask if set."""
        attn_mask_ = self.attn_mask
        # A callable mask is built per sequence length (x.size(0) since LND).
        if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
            attn_mask_ = self.attn_mask(x.size(0))  # LND

        # Move the mask to the input's dtype/device before use.
        attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
    """A stack of ``layers`` residual attention blocks over LND input."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        # All blocks share the same (possibly callable) attention mask.
        blocks = (ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
    """CLIP image encoder: patch embedding -> transformer -> class-token projection."""

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding (kernel == stride == patch_size).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # One position per patch plus one for the class token.
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        # Projection into the joint image/text embedding space.
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the learned class token to every sequence in the batch.
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the class-token representation.
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
|
||||
|
||||
|
||||
class CLIP(nn.Module):
    """Full CLIP model: a vision tower plus a text transformer tower.

    ``vision_layers`` being a tuple/list selects the ModifiedResNet tower;
    an int selects the VisionTransformer. ``forward`` returns scaled
    cosine-similarity logits between an image batch and a text batch.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # ResNet tower: head count derived from the final feature width.
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        # attn_mask is passed as a *callable* (the bound method, not its
        # result), so each block lazily builds a causal mask sized to the
        # actual sequence length.
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Logit temperature, initialised to 1/0.07 as in the CLIP paper.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialise embeddings and projections with CLIP's scheme."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            # Zero-init the last BN gamma of every bottleneck so each
            # residual branch starts close to identity.
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self, context_length):
        """Build an additive causal mask of the given sequence length."""
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(context_length, context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        # dtype of the vision stem conv; becomes fp16 after convert_weights.
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Run the vision tower; returns [batch, embed_dim] features."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Run the text tower; returns [batch, embed_dim] features."""
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        # The positional embedding is sliced so inputs shorter than the
        # full context length still work.
        x = x + self.positional_embedding[:x.size(1), :].type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
|
||||
|
||||
|
||||
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16, in place.

    Halves conv/linear weights and biases, MultiheadAttention projection
    tensors, and any ``text_projection`` / ``proj`` attributes; everything
    else (e.g. LayerNorm) keeps its dtype.
    """

    def _halve(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            attrs = [f"{s}_proj_weight" for s in ["in", "q", "k", "v"]]
            attrs += ["in_proj_bias", "bias_k", "bias_v"]
            for attr_name in attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            projection = getattr(module, name, None)
            if projection is not None:
                projection.data = projection.data.half()

    model.apply(_halve)
|
||||
|
||||
|
||||
def build_model(state_dict: dict):
    """Reconstruct a CLIP model from a pretrained state dict.

    Architecture hyperparameters (ViT vs ResNet tower, widths, layer
    counts, resolutions) are inferred from tensor shapes and key names in
    ``state_dict``; weights are then converted to fp16 and loaded.
    """
    # A ViT checkpoint is recognised by its visual projection parameter.
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        # One attention in-projection per transformer block.
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # Positional embedding rows = grid**2 + 1 (class token).
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # Count bottleneck blocks in each of the four ResNet stages.
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        # The ResNet downsamples by 32x overall.
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # These bookkeeping entries are not parameters of the module.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    # NOTE(review): the model is returned in train mode (no .eval()),
    # unlike upstream CLIP — confirm this is intended.
    return model
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch optimization for BERT model."""
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.optimizer import required
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def warmup_cosine(x, warmup=0.002):
    """Linear warmup, then cosine decay over training progress x in [0, 1].

    NOTE(review): the cosine term uses raw x rather than progress measured
    past the warmup point — confirm this matches the intended schedule.
    """
    if x < warmup:
        return x / warmup
    return 0.5 * (1.0 + math.cos(math.pi * x))
|
||||
|
||||
def warmup_constant(x, warmup=0.002):
    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
        Learning rate is 1. afterwards. """
    return x / warmup if x < warmup else 1.0
|
||||
|
||||
def warmup_linear(x, warmup=0.002):
    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
        After `t_total`-th training step, learning rate is zero. """
    if x < warmup:
        return x / warmup
    # Decay linearly from 1 at x == warmup down to 0 at x == 1.
    return max((x - 1.) / (warmup - 1.), 0)
|
||||
|
||||
# Mapping from schedule name (BertAdam's `schedule` argument) to the
# LR-multiplier function of training progress.
SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
}
|
||||
|
||||
|
||||
class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.
    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1 means constant learning rate. Default: -1
        schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """
    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
                 max_grad_norm=1.0):
        # Validate hyperparameters eagerly so misconfiguration fails fast.
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)

    def get_lr(self):
        """Return the scheduled LR for each parameter that has grad state.

        Returns [0] as soon as a parameter with no optimizer state is
        encountered (i.e. before the first step).
        """
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    # No step has been taken for this parameter yet.
                    return [0]
                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
                else:
                    lr_scheduled = group['lr']
                lr.append(lr_scheduled)
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Add grad clipping
                # NOTE(review): this clips each parameter's gradient
                # individually, not the global norm across all parameters.
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # No bias-correction terms: the update uses the raw moment
                # estimates, as in BERT's Adam variant.
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data

                if group['t_total'] != -1:
                    # NOTE(review): progress is computed from state['step']
                    # *before* the increment below, so the first update uses
                    # progress 0 — confirm this matches the intended schedule.
                    schedule_fct = SCHEDULES[group['schedule']]
                    progress = state['step']/group['t_total']
                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
                else:
                    lr_scheduled = group['lr']

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

        return loss
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
import gzip
|
||||
import html
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
|
||||
|
||||
@lru_cache()
def default_bpe():
    """Absolute path of the bundled BPE vocabulary next to this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
|
||||
@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping every utf-8 byte value (0..255) to a unicode character.

    Printable bytes map to themselves; the remaining bytes are assigned
    fresh code points starting at 256. The reversible bpe codes work on
    unicode strings, and this avoids mapping to whitespace/control
    characters the bpe code barfs on, while keeping lookup tables between
    utf-8 bytes and unicode strings.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    char_codes = printable[:]
    offset = 0
    for b in range(2 ** 8):
        if b not in printable:
            byte_values.append(b)
            # Assign code points 256, 257, ... to the non-printable bytes.
            char_codes.append(2 ** 8 + offset)
            offset += 1
    return dict(zip(byte_values, (chr(c) for c in char_codes)))
|
||||
|
||||
|
||||
def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    previous = word[0]
    pairs = set()
    for symbol in word[1:]:
        pairs.add((previous, symbol))
        previous = symbol
    return pairs
|
||||
|
||||
|
||||
def basic_clean(text):
    """Repair mojibake with ftfy, un-escape HTML entities (twice, for
    double-encoded input), and strip surrounding whitespace."""
    repaired = ftfy.fix_text(text)
    unescaped = html.unescape(html.unescape(repaired))
    return unescaped.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
    """Collapse every whitespace run to a single space and trim the ends."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer (the tokenizer bundled with CLIP).

    Text is ftfy-fixed, whitespace-normalised and lowercased, split by a
    regex into words/digits/punctuation/specials, mapped to printable
    unicode bytes, then greedily merged according to the ranked merge
    table loaded from *bpe_path*.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        # byte -> printable unicode char, plus the inverse for decode()
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        # skip the header line; keep 49152-256-2 merge rules
        # (vocab budget minus 256*2 byte tokens and the 2 special tokens)
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]  # end-of-word variants
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # merge pair -> priority; lower rank merges first
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # memoises bpe(); the special tokens map to themselves
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply BPE merges to one pre-tokenised word.

        Returns the resulting subword symbols joined by single spaces; the
        last symbol carries the '</w>' end-of-word marker.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # merge the lowest-ranked (earliest learned) pair first
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:  # ValueError: no further occurrence of `first`
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Return the list of BPE token ids for *text*."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # re-express the token as printable byte characters before BPE
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Inverse of encode(): token ids back to a readable string."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def tokenize(self, text):
        """Like encode() but returns subword strings instead of ids."""
        tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
        return tokens

    def convert_tokens_to_ids(self, tokens):
        """Map subword strings (as produced by tokenize()) to vocabulary ids."""
        return [self.encoder[bpe_token] for bpe_token in tokens]
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
tqdm
scipy
numpy
pillow
matplotlib
scikit-learn
torch==1.9.1
|
||||
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
import os
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from torch import distributed as dist
|
||||
from utils import get_logger, get_summary_writer
|
||||
|
||||
|
||||
class TrainBase(object):
    """Skeleton for cross-modal hashing train/valid/test loops.

    Subclasses override `_init_dataset`/`_init_model` and the
    `train`/`valid`/`test`/`compute_loss` hooks; this base wires up
    logging, best-mAP bookkeeping, and generic hash-code extraction.
    """

    def __init__(self,
                args,
                rank=0):
        # args: parsed command-line namespace; rank: device index that
        # the model and batches are moved to via `.to(self.rank)`.
        self.args = args
        os.makedirs(args.save_dir, exist_ok=True)
        self._init_writer()
        self.logger.info(self.args)
        self.rank = rank

        self._init_dataset()
        self._init_model()

        self.global_step = 0
        # self.global_step_t = 0
        # running best mAPs and the epochs at which they were reached
        self.max_mapi2t = 0
        self.max_mapt2i = 0
        self.best_epoch_i = 0
        self.best_epoch_t = 0

    def _init_dataset(self):
        """Hook: subclasses populate the three data loaders."""
        self.train_loader = None
        self.query_loader = None
        self.retrieval_loader = None

    def _init_model(self):
        """Hook: subclasses build the model (and an optional DDP wrapper)."""
        self.model = None
        self.model_ddp = None

    def _init_writer(self):
        # one log file per mode plus a tensorboard writer under save_dir
        self.logger = get_logger(os.path.join(self.args.save_dir, "train.log" if self.args.is_train else "test.log"))
        self.writer = get_summary_writer(os.path.join(self.args.save_dir, "tensorboard"))

    def run(self):
        """Dispatch to train() or test() according to args.is_train."""
        if self.args.is_train:
            self.train()
        else:
            self.test()

    def change_state(self, mode):
        """Switch the model between "train" and "valid" (eval) modes."""
        if mode == "train":
            self.model.train()
        elif mode == "valid":
            self.model.eval()

    def get_code(self, data_loader, length: int):
        """Encode every sample in *data_loader*; return (image_codes, text_codes).

        *length* must equal the dataset size: rows are written at each
        sample's dataset index, so the result is order-independent.
        NOTE(review): codes are stored as produced by encode_image/encode_text
        (not binarised here) — confirm callers binarise downstream.
        """
        img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)

        for image, text, label, index in tqdm(data_loader):
            image = image.to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            index = index.numpy()
            image_hash = self.model.encode_image(image)
            text_hash = self.model.encode_text(text)

            img_buffer[index, :] = image_hash.data
            text_buffer[index, :] = text_hash.data

        return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)

    def hash_loss(self, a: torch.Tensor):
        """Quantization loss: mean L2 distance of each row to its sign vector, halved."""
        return torch.mean(torch.sqrt(torch.sum(torch.pow(torch.sign(a) - a, 2), dim=1))) * 0.5

    def similarity_loss(self):
        raise NotImplementedError("Function of 'similarity_loss' doesn't implement.")

    def save_model(self, epoch):
        """Save the model's state_dict under save_dir as model-<epoch>.pth."""
        torch.save(self.model.state_dict(), os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth"))
        self.logger.info("save mode to {}".format(os.path.join(self.args.save_dir, "model-" + str(epoch) + ".pth")))

    def train(self):
        raise NotImplementedError("Function of 'train' doesn't implement.")

    def valid(self):
        raise NotImplementedError("Function of 'valid' doesn't implement.")

    def test(self):
        raise NotImplementedError("Function of 'test' doesn't implement.")

    def compute_loss(self):
        raise NotImplementedError("Function of 'compute_loss' doesn't implement.")
|
||||
|
|
@ -0,0 +1,360 @@
|
|||
from torch.nn.modules import loss
|
||||
from model.hash_model import DCMHT as DCMHT
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
import scipy.io as scio
|
||||
|
||||
|
||||
from .base import TrainBase
|
||||
from model.optimization import BertAdam
|
||||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity
|
||||
from utils.calc_utils import calc_map_k_matrix as calc_map_k
|
||||
from dataset.dataloader import dataloader
|
||||
|
||||
class Trainer(TrainBase):
    """DCMHT trainer: CLIP-backed image/text hashing trained with an
    intra/inter-modal similarity loss and validated by cross-modal mAP."""

    def __init__(self,
                rank=0):
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        # training (or testing) starts immediately on construction
        self.run()

    def _init_model(self):
        """Build the DCMHT model, optionally restore a checkpoint, set up BertAdam."""
        self.logger.info("init model.")
        linear = False
        if self.args.hash_layer == "linear":
            linear = True

        self.logger.info("ViT+GPT!")
        HashModel = DCMHT
        self.model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
                        writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
        if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
            self.logger.info("load pretrained model.")
            self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))

        self.model.float()
        # CLIP backbone uses its own (typically smaller) lr; hash heads share args.lr
        self.optimizer = BertAdam([
                    {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
                    {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
                    {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
                    ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
                    b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
                    weight_decay=self.args.weight_decay, max_grad_norm=1.0)

        print(self.model)

    def _init_dataset(self):
        """Resolve dataset file paths, build train/query/retrieval loaders and label tensors."""
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
        self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
        self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                        indexFile=self.args.index_file,
                                        labelFile=self.args.label_file,
                                        maxWords=self.args.max_words,
                                        imageResolution=self.args.resolution,
                                        query_num=self.args.query_num,
                                        train_num=self.args.train_num,
                                        seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        # retrieval set size is derived from the data, not from args
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
        self.train_loader = DataLoader(
                dataset=train_data,
                batch_size=self.args.batch_size,
                num_workers=self.args.num_workers,
                pin_memory=True,
                shuffle=True
            )
        self.query_loader = DataLoader(
                dataset=query_data,
                batch_size=self.args.batch_size,
                num_workers=self.args.num_workers,
                pin_memory=True,
                shuffle=True
            )
        self.retrieval_loader = DataLoader(
                dataset=retrieval_data,
                batch_size=self.args.batch_size,
                num_workers=self.args.num_workers,
                pin_memory=True,
                shuffle=True
            )

    def train_epoch(self, epoch):
        """Run one optimisation pass over the training loader."""
        self.change_state(mode="train")
        self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
        all_loss = 0
        times = 0
        for image, text, label, index in self.train_loader:
            self.global_step += 1
            times += 1
            # NOTE(review): Tensor.float() is not in-place and this result is
            # discarded — likely a no-op; confirm intent.
            image.float()
            if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
                # unknown dataset: fall back to an identity similarity matrix
                label = torch.ones([image.shape[0]], dtype=torch.int)
                label = label.diag()
            # print(text.dtype)
            # text.float()
            # label.float()
            image = image.to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            # print("text shape:", text.shape)
            index = index.numpy()
            # print(text.shape)
            hash_img, hash_text = self.model(image, text)
            if self.args.hash_layer == "select":
                # selection head returns per-bit 2-way distributions; flatten to one vector
                hash_img = torch.cat(hash_img, dim=-1) if isinstance(hash_img, list) else hash_img.view(hash_img.shape[0], -1)
                hash_text = torch.cat(hash_text, dim=-1)if isinstance(hash_text, list) else hash_text.view(hash_text.shape[0], -1)
            loss = self.compute_loss(hash_img, hash_text, label, epoch, times)
            # NOTE(review): accumulating the tensor (not loss.item()) keeps each
            # step's autograd graph alive until the epoch ends — confirm memory impact.
            all_loss += loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")

    def train(self):
        """Full training loop: per-epoch train, validate, checkpoint."""
        self.logger.info("Start train.")

        for epoch in range(self.args.epochs):
            self.train_epoch(epoch)
            self.valid(epoch)
            self.save_model(epoch)

        self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")

    def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
        """Negative log-likelihood of pairwise similarities (DCMH-style bayesian loss)."""
        # s: pairwise inner products between the two code matrices
        s = torch.matmul(a, b.t())
        b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s)))

        return b_loss

    def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
        """KL-divergence-style penalty between the two code distributions (debug prints included)."""
        kl_divergence = torch.mean(a * torch.log(a / (b + 0.001)))
        print("mean", torch.mean(a - b))
        print("kl", kl_divergence)
        return kl_divergence

    def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
        """Pull similar pairs together / push dissimilar pairs apart.

        Returns (similarity matrix, positive-pair loss, negative-pair loss);
        the distance function is chosen by args.similarity_function.
        """
        # $\vartheta$
        vartheta = self.args.vartheta
        if self.args.sim_threshold != 0:
            threshold = self.args.sim_threshold
        similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b)

        positive_similarity = similarity * label_sim
        # Pairs whose cosine is already negative count as "correct" — pushing
        # the distance all the way to 2 is too hard in practice.
        negative_similarity = similarity * (1 - label_sim)

        if self.args.similarity_function == "cosine":
            positive_similarity = positive_similarity.clip(threshold) - threshold
            negative_similarity = negative_similarity.clip(max=1.)
            negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
        elif self.args.similarity_function == "euclidean":
            # For Euclidean distance, when half of the hash bits differ the
            # distance equals this cap (the concat doubles output_dim), so
            # larger distances are clipped to it.
            # The assumed maximum corresponds to half of the hash bits differing.
            max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5
            negative_similarity = negative_similarity.clip(max=max_value)
            negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity

        if self.args.loss_type == "l1":
            positive_loss = positive_similarity.mean()
            negative_loss = negative_similarity.mean()
        elif self.args.loss_type == "l2":
            positive_loss = torch.pow(positive_similarity, 2).mean()
            negative_loss = torch.pow(negative_similarity, 2).mean()
        else:
            raise ValueError("argument of loss_type is not support.")

        return similarity, positive_loss, negative_loss

    def make_hash_code(self, code: list) -> torch.Tensor:
        """Turn per-bit 2-way selection outputs into a ±1 hash code matrix."""
        code = torch.stack(code)
        # print(code.shape)
        code = code.permute(1, 0, 2)
        # argmax over the 2-way axis gives bits in {0, 1}; remap 0 -> -1
        hash_code = torch.argmax(code, dim=-1)
        hash_code[torch.where(hash_code == 0)] = -1
        hash_code = hash_code.float()

        return hash_code

    def get_code(self, data_loader, length: int):
        """Like TrainBase.get_code but binarises via make_hash_code (select head)."""
        img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)

        for image, text, label, index in tqdm(data_loader):
            image = image.to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            index = index.numpy()
            image_hash = self.model.encode_image(image)
            image_hash = self.make_hash_code(image_hash)
            text_hash = self.model.encode_text(text)
            text_hash = self.make_hash_code(text_hash)
            # image_hash.to(self.rank)
            # text_hash.to(self.rank)
            img_buffer[index, :] = image_hash.data
            text_buffer[index, :] = text_hash.data

        return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)

    def our_loss(self, image, text, label, epoch, times):
        """Total loss = intra-modal + inter-modal similarity (+ quantization for non-select heads)."""
        loss = 0

        label_sim = calc_neighbor(label, label)
        if image.is_cuda:
            label_sim = label_sim.to(image.device)
        intra_similarity, intra_positive_loss, intra_negative_loss = self.similarity_loss(image, text, label_sim)
        inter_similarity_i, inter_positive_loss_i, inter_negative_loss_i = self.similarity_loss(image, image, label_sim)
        inter_similarity_t, inter_positive_loss_t, inter_negative_loss_t = self.similarity_loss(text, text, label_sim)

        # NOTE(review): both branches of this ternary are identical — the
        # euclidean case presumably intended a different weighting; confirm.
        intra_similarity_loss = (intra_positive_loss + intra_negative_loss) if self.args.similarity_function == "euclidean" else (intra_positive_loss + intra_negative_loss)
        inter_similarity_loss = inter_positive_loss_t + inter_positive_loss_i + (inter_negative_loss_i + inter_negative_loss_t) if self.args.similarity_function == "euclidean" else inter_positive_loss_t + inter_positive_loss_i + inter_negative_loss_i + inter_negative_loss_t
        similarity_loss = inter_similarity_loss + intra_similarity_loss

        # if self.writer is not None:
        #     self.writer.add_scalar("intra similarity max", intra_similarity.max(), self.global_step)
        #     self.writer.add_scalar("intra similarity min", intra_similarity.min(), self.global_step)
        #     self.writer.add_scalar("intra positive loss", intra_positive_loss.data, self.global_step)
        #     self.writer.add_scalar("intra negative loss", intra_negative_loss.data, self.global_step)

        #     self.writer.add_scalar("inter image similarity max", inter_similarity_i.max(), self.global_step)
        #     self.writer.add_scalar("inter image similarity min", inter_similarity_i.min(), self.global_step)
        #     self.writer.add_scalar("inter image positive loss", inter_positive_loss_i.data, self.global_step)
        #     self.writer.add_scalar("inter image negative loss", inter_negative_loss_i.data, self.global_step)

        #     self.writer.add_scalar("inter text similarity max", inter_similarity_t.max(), self.global_step)
        #     self.writer.add_scalar("inter text similarity min", inter_similarity_t.min(), self.global_step)
        #     self.writer.add_scalar("inter text positive loss", inter_positive_loss_t.data, self.global_step)
        #     self.writer.add_scalar("inter text negative loss", inter_negative_loss_t.data, self.global_step)

        #     self.writer.add_scalar("intra similarity loss", intra_similarity_loss.data, self.global_step)
        #     self.writer.add_scalar("inter similarity loss", inter_similarity_loss.data, self.global_step)
        #     self.writer.add_scalar("similarity loss", similarity_loss.data, self.global_step)

        if self.args.hash_layer != "select":
            # continuous-valued codes additionally need a quantization penalty
            quantization_loss = (self.hash_loss(image) + self.hash_loss(text)) / 2
            loss = similarity_loss + quantization_loss
            if self.global_step % self.args.display_step == 0:
                self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
                    f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
                    f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
                    f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
                    f"QUATIZATION LOSS, {quantization_loss.data}, "\
                    f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
        else:
            loss = similarity_loss # + self.args.qua_gamma * (image_quantization_loss + text_quantization_loss)
            if self.global_step % self.args.display_step == 0:
                self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
                    f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
                    f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
                    f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
                    # f"QUATIZATION LOSS, image: {image_quantization_loss.data}, text: {text_quantization_loss.data}, "\
                    f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")

        return loss

    def compute_loss(self, image, text, label, epoch, times):
        """Dispatch point for the training loss (currently always our_loss)."""
        loss = self.our_loss(image, text, label, epoch, times)

        return loss

    def test(self, mode_name="i2t"):
        """Evaluate a pretrained checkpoint: compute mAPs and dump codes/labels to .mat."""
        if self.args.pretrained == "":
            raise RuntimeError("test step must load a model! please set the --pretrained argument.")
        self.change_state(mode="valid")
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        # select head uses the overriding get_code (binarised); others use the base version
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
        mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
        # print("map map")
        mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
        mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
        mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
        # NOTE(review): only max_mapt2i is updated here; max_mapi2t is not — confirm intent.
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")

        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()

        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    def valid(self, epoch):
        """Compute mAPs on the query/retrieval split; save codes when a new best is reached."""
        self.logger.info("Valid.")
        self.change_state(mode="valid")
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
        # print("get all code")
        mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
        # print("map map")
        mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
        mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
        mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
        if self.max_mapi2t < mAPi2t:
            self.best_epoch_i = epoch
            self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
        self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
        if self.max_mapt2i < mAPt2i:
            self.best_epoch_t = epoch
            self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
            MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")

    def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
        """Dump query/retrieval codes and labels to a .mat file for PR-curve plotting."""
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()

        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
        self.logger.info(f">>>>>> save best {mode_name} data!")
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from .utils import *
|
||||
from .logger import get_logger, get_summary_writer
|
||||
from .get_args import get_args
|
||||
|
|
@ -0,0 +1,357 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def calc_hammingDist(B1, B2):
    """Pairwise Hamming distance between rows of two ±1 code matrices.

    A 1-D B1 is treated as a single query row. For ±1 codes of length q,
    hamming = 0.5 * (q - <b1, b2>).
    """
    code_len = B2.shape[1]
    if len(B1.shape) < 2:
        B1 = B1.unsqueeze(0)
    inner = B1.mm(B2.transpose(0, 1))
    return 0.5 * (code_len - inner)
|
||||
|
||||
|
||||
def calc_map_k_matrix(qB, rB, query_L, retrieval_L, k=None, rank=0):
    """Mean average precision@k, computed from one batched Hamming matrix.

    qB/rB: ±1 query and retrieval codes; query_L/retrieval_L: label
    matrices (items are relevant when they share any label). k defaults
    to the full retrieval-set size; `rank` is accepted but unused here.
    """
    n_queries = query_L.shape[0]
    if qB.is_cuda:
        qB, rB = qB.cpu(), rB.cpu()
    if k is None:
        k = retrieval_L.shape[0]

    # ground-truth relevance matrix and per-query relevant counts
    gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
    tsums = torch.sum(gnds, dim=-1, keepdim=True, dtype=torch.int32)
    hamms = calc_hammingDist(qB, rB)
    _, order = torch.sort(hamms, dim=-1)

    cap = torch.tensor([k], dtype=torch.int32).expand_as(tsums)
    totals = torch.min(tsums, cap)

    mean_ap = 0
    for q in range(n_queries):
        ranked = gnds[q][order[q]]
        limit = totals[q].squeeze()
        hit_ranks = torch.arange(1, limit + 1).type(torch.float32)
        positions = torch.nonzero(ranked)[:limit].squeeze().type(torch.float32) + 1.0
        mean_ap = mean_ap + torch.mean(hit_ranks / positions)
    mean_ap = mean_ap / n_queries

    return mean_ap
|
||||
|
||||
|
||||
def calc_map_k(qB, rB, query_L, retrieval_L, k=None, rank=0):
    """Per-query mAP@k over sign-binarised codes.

    Queries with no relevant retrieval item are skipped (but still count
    in the final division, matching the batched variant's convention).
    """
    n_queries = query_L.shape[0]
    qB = torch.sign(qB)
    rB = torch.sign(rB)
    if k is None:
        k = retrieval_L.shape[0]

    mean_ap = 0
    for q in range(n_queries):
        q_lab = query_L[q]
        if len(q_lab.shape) < 2:
            q_lab = q_lab.unsqueeze(0)  # [1, label dim]
        gnd = (q_lab.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
        relevant = torch.sum(gnd)
        if relevant == 0:
            continue  # AP undefined: no ground-truth neighbours
        hamm = calc_hammingDist(qB[q, :], rB)
        _, order = torch.sort(hamm)
        order.squeeze_()
        gnd = gnd[order]
        limit = min(k, int(relevant))
        hit_ranks = torch.arange(1, limit + 1).type(torch.float32)
        positions = torch.nonzero(gnd)[:limit].squeeze().type(torch.float32) + 1.0
        if positions.is_cuda:
            hit_ranks = hit_ranks.to(rank)
        mean_ap = mean_ap + torch.mean(hit_ranks / positions)
    mean_ap = mean_ap / n_queries
    return mean_ap
|
||||
|
||||
|
||||
def calc_precisions_topn_matrix(qB, rB, query_L, retrieval_L, recall_gas=0.02, num_retrieval=10000):
    """Precision at increasing top-N cut-offs (N = num_retrieval * fraction), batched.

    Accepts numpy arrays or tensors; 0/1 codes are re-centred to ±1 before
    Hamming ranking. Returns a list of precisions, one per cut-off step.
    """
    if not isinstance(qB, torch.Tensor):
        qB = torch.from_numpy(qB)
        rB = torch.from_numpy(rB)
        query_L = torch.from_numpy(query_L)
        retrieval_L = torch.from_numpy(retrieval_L)
    # binarise 0/1 codes to ±1
    qB = torch.sign(qB.float() - 0.5)
    rB = torch.sign(rB.float() - 0.5)
    if qB.is_cuda:
        qB, rB = qB.cpu(), rB.cpu()

    n_queries = query_L.shape[0]
    # num_retrieval = retrieval_L.shape[0]
    precisions = [0] * int(1 / recall_gas)
    gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
    hamms = calc_hammingDist(qB, rB)
    _, orders = torch.sort(hamms, dim=-1)

    cutoff_fracs = np.arange(recall_gas, 1 + recall_gas, recall_gas)
    for q in range(n_queries):
        ranked = gnds[q][orders[q]]
        for slot, frac in enumerate(cutoff_fracs):
            top_n = int(num_retrieval * frac)
            hits = torch.nonzero(ranked[: top_n]).squeeze().numpy()
            precisions[slot] += (hits.size / top_n)
    for slot in range(len(precisions)):
        precisions[slot] /= n_queries
    return precisions
|
||||
|
||||
|
||||
def calc_precisions_topn(qB, rB, query_L, retrieval_L, recall_gas=0.02, num_retrieval=10000):
    """Per-query variant of calc_precisions_topn_matrix (lower memory, slower)."""
    # binarise 0/1 codes to ±1
    qB = torch.sign(qB.float() - 0.5)
    rB = torch.sign(rB.float() - 0.5)
    n_queries = query_L.shape[0]
    # num_retrieval = retrieval_L.shape[0]
    precisions = [0] * int(1 / recall_gas)
    cutoff_fracs = np.arange(recall_gas, 1 + recall_gas, recall_gas)
    for q in range(n_queries):
        q_lab = query_L[q]
        if len(q_lab.shape) < 2:
            q_lab = q_lab.unsqueeze(0)  # [1, label dim]
        gnd = (q_lab.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
        hamm = calc_hammingDist(qB[q, :], rB)
        _, order = torch.sort(hamm)
        order.squeeze_()
        ranked = gnd[order]
        for slot, frac in enumerate(cutoff_fracs):
            top_n = int(num_retrieval * frac)
            hits = torch.nonzero(ranked[: top_n]).squeeze().numpy()
            precisions[slot] += (hits.size / top_n)
    for slot in range(len(precisions)):
        precisions[slot] /= n_queries
    return precisions
|
||||
|
||||
|
||||
def calc_precisions_hash(qB, rB, query_L, retrieval_L):
    """Precision/recall as a function of Hamming-ball radius.

    Codes in {0,1} are re-centred to ±1. For each radius r in
    [0, max observed distance], precision[r] and recall[r] are computed
    over ALL pairs within distance <= r (the counters accumulate across
    radii), so both curves are cumulative.
    """
    qB = qB.float()
    rB = rB.float()
    # binarise 0/1 codes to ±1
    qB = torch.sign(qB - 0.5)
    rB = torch.sign(rB - 0.5)
    num_query = query_L.shape[0]
    num_retrieval = retrieval_L.shape[0]
    bit = qB.shape[1]
    hamm = calc_hammingDist(qB, rB)
    # distances of ±1 codes are small integers; bytes keep memory down
    hamm = hamm.type(torch.ByteTensor)
    total_num = [0] * (bit + 1)  # NOTE(review): written but never read
    max_hamm = int(torch.max(hamm))
    gnd = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze()
    total_right = torch.sum(torch.matmul(query_L, retrieval_L.t())>0)
    precisions = np.zeros([max_hamm + 1])
    recalls = np.zeros([max_hamm + 1])
    # _, index = torch.sort(hamm)
    # del _
    # for i in range(index.shape[0]):
    #     gnd[i, :] = gnd[i, index[i]]
    # del index
    right_num = 0
    recall_num = 0
    for i, radius in enumerate(range(0, max_hamm+1)):
        # all (query, retrieval) pairs at exactly this distance
        recall = torch.nonzero(hamm == radius)
        right = gnd[recall.split(1, dim=1)]
        recall_num += recall.shape[0]
        del recall
        right_num += torch.nonzero(right).shape[0]
        del right
        # cumulative precision within radius; epsilon guards radius buckets
        # that are empty at the smallest radii
        precisions[i] += (right_num / (recall_num + 1e-8))
        # recalls[i] += (recall_num / num_retrieval / num_query)
        recalls[i] += (recall_num / total_right)
    return precisions, recalls
|
||||
|
||||
def calc_precisions_hash_my(qB, rB, *, Gnd, num_query, num_retrieval):
    """Hamming-radius precision/recall with a precomputed ground-truth matrix.

    Like calc_precisions_hash, but the relevance matrix *Gnd* is passed in
    directly, recall is normalised by num_retrieval * num_query, and points
    sharing the same (rounded) recall value are averaged into one point
    before returning.
    """
    if not isinstance(qB, torch.Tensor):
        qB = torch.from_numpy(qB)
    if not isinstance(rB, torch.Tensor):
        rB = torch.from_numpy(rB)
    if not isinstance(Gnd, torch.Tensor):
        Gnd = torch.from_numpy(Gnd)

    def CalcHammingDist_np(B1, B2):
        # numpy fallback distance (currently unused; kept with the
        # commented dispatch below)
        q = B2.shape[1]
        distH = 0.5 * (q - np.dot(B1, B2.transpose()))
        return distH
    bit = qB.shape[1]
    # if isinstance(qB, np.ndarray):
    #     hamm = CalcHammingDist_np(qB, rB)
    # else:
    hamm = calc_hammingDist(qB, rB)
    hamm = hamm.type(torch.ByteTensor)
    total_num = [0] * (bit + 1)  # NOTE(review): written but never read
    max_hamm = int(torch.max(hamm))

    gnd = Gnd

    total_right = torch.sum(gnd>0)  # NOTE(review): computed but unused here
    precisions = np.zeros([max_hamm + 1])
    recalls = np.zeros([max_hamm + 1])

    right_num = 0
    recall_num = 0
    for i, radius in enumerate(range(0, max_hamm+1)):
        # pairs at exactly this distance; counters accumulate -> cumulative curves
        recall = torch.nonzero(hamm == radius)
        right = gnd[recall.split(1, dim=1)]
        recall_num += recall.shape[0]
        del recall
        right_num += torch.nonzero(right).shape[0]
        del right
        precisions[i] += (right_num / (recall_num + 1e-8))
        recalls[i] += (recall_num / num_retrieval / num_query)
        # recalls[i] += (recall_num / total_right)
    p = precisions.round(2)
    r = recalls.round(2)
    # return p, r

    # merge consecutive points with equal rounded recall by averaging
    # their precisions (the final point, index len(r)-1, is dropped)
    precisions = []
    recalls = []

    precision_ = 0
    # NOTE(review): num starts at 1 and is also incremented per equal point,
    # so the divisor is one larger than the number of summed precisions —
    # confirm whether this off-by-one is intended.
    num = 1
    for i in range(len(r) - 1):
        if r[i] == r[i + 1]:
            precision_ += p[i]
            num += 1
        else:
            precision_ += p[i]
            precisions.append(precision_ / num)
            recalls.append(r[i])
            precision_ = 0
            num = 1

    return np.asarray(precisions).round(2), np.asarray(recalls)
|
||||
|
||||
|
||||
def calc_precisions_hamming_radius(qB, rB, query_L, retrieval_L, hamming_gas=1):
    """Average precision inside growing Hamming radii (1, 1+gas, ...) per query.

    :param qB: query codes, shape (num_query, bit)
    :param rB: retrieval codes, shape (num_retrieval, bit)
    :param query_L: multi-hot query labels
    :param retrieval_L: multi-hot retrieval labels
    :param hamming_gas: step between consecutive radii
    :return: list of precisions, one entry per radius, averaged over queries.
    """
    num_query = query_L.shape[0]
    bit = qB.shape[1]
    precisions = [0] * int(bit / hamming_gas)
    radii = np.arange(1, bit + 1, hamming_gas)
    for q_idx in range(num_query):
        q_label = query_L[q_idx]
        if len(q_label.shape) < 2:
            q_label = q_label.unsqueeze(0)  # [1, num_class]
        # Relevance flags of this query against the whole retrieval set.
        gnd = (q_label.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
        hamm = calc_hammingDist(qB[q_idx, :], rB)
        _, order = torch.sort(hamm)
        order.squeeze_()
        # Sort relevance by ascending Hamming distance so a count of items
        # within a radius maps to a prefix of gnd.
        gnd = gnd[order]
        for i, radius in enumerate(radii):
            retrieved = torch.nonzero(hamm <= radius).squeeze().shape[0]
            if retrieved == 0:
                continue
            hits = torch.nonzero(gnd[: retrieved]).squeeze().numpy().size
            precisions[i] += hits / retrieved
    return [total / num_query for total in precisions]
|
||||
|
||||
|
||||
def calc_neighbor(label1, label2):
    """Pairwise similarity indicator: 1.0 where two samples share any label."""
    shared = label1.matmul(label2.transpose(0, 1))
    return (shared > 0).float()
|
||||
|
||||
|
||||
def norm_max_min(x: torch.Tensor, dim=None):
    """Min-max normalize x into [0, 1], globally or along one dimension.

    :param x: input tensor
    :param dim: reduction dimension; None normalizes over the whole tensor
    :return: (x - min) / (max - min)
    """
    # Renamed locals: the original shadowed the builtins `max` and `min`,
    # and used two mutually-exclusive ifs instead of if/else.
    if dim is None:
        x_max = torch.max(x)
        x_min = torch.min(x)
    else:
        x_max = torch.max(x, dim=dim)[0]
        x_min = torch.min(x, dim=dim)[0]
        if dim > 0:
            # Restore the reduced axis so broadcasting lines up with x.
            x_max = x_max.unsqueeze(len(x.shape) - 1)
            x_min = x_min.unsqueeze(len(x.shape) - 1)
    # NOTE(review): divides by zero when all values along the axis are equal —
    # preserved from the original; confirm callers never pass constant slices.
    norm = (x - x_min) / (x_max - x_min)
    return norm
|
||||
|
||||
|
||||
def norm_mean(x: torch.Tensor, dim=None):
    """Standardize x to zero mean and unit std, globally or along one dimension."""
    if dim is None:
        mu = torch.mean(x)
        sigma = torch.std(x)
    else:
        mu = torch.mean(x, dim=dim)
        sigma = torch.std(x, dim=dim)
        if dim > 0:
            # Re-insert the reduced axis so the subtraction broadcasts.
            mu = mu.unsqueeze(len(x.shape) - 1)
            sigma = sigma.unsqueeze(len(x.shape) - 1)
    return (x - mu) / sigma
|
||||
|
||||
|
||||
def norm_abs_mean(x: torch.Tensor, dim=None):
    """Absolute z-score of x: |x - mean| / std, globally or along one dimension."""
    if dim is None:
        mu = torch.mean(x)
        sigma = torch.std(x)
    else:
        mu = torch.mean(x, dim=dim)
        sigma = torch.std(x, dim=dim)
        if dim > 0:
            # Re-insert the reduced axis so the subtraction broadcasts.
            mu = mu.unsqueeze(len(x.shape) - 1)
            sigma = sigma.unsqueeze(len(x.shape) - 1)
    return torch.abs(x - mu) / sigma
|
||||
|
||||
|
||||
def factorial(n):
    """Return n! for a non-negative integer n.

    :param n: non-negative integer
    :raises ValueError: if n is negative (the old recursive version recursed
        without bound on negative input)
    """
    if n < 0:
        raise ValueError("factorial is undefined for negative n")
    result = 1
    for k in range(2, n + 1):
        result *= k
    return result
|
||||
|
||||
|
||||
def calc_IF(all_bow):
    """Relative frequency of each word over a bag-of-words count matrix.

    :param all_bow: (num_docs, vocab) count tensor
    :return: (vocab,) tensor of per-word frequencies summing to 1.
    """
    per_word = all_bow.sum(dim=0)
    return per_word / per_word.sum()
|
||||
|
||||
|
||||
# def calc_loss(B, F, G, Sim, gamma1, gamma2, eta):
|
||||
# theta = torch.matmul(F, G.transpose(0, 1)) / 2
|
||||
# inter_loss = torch.sum(torch.log(1 + torch.exp(theta)) - Sim * theta)
|
||||
# theta_f = torch.matmul(F, F.transpose(0, 1)) / 2
|
||||
# intra_img = torch.sum(torch.log(1 + torch.exp(theta_f)) - Sim * theta_f)
|
||||
# theta_g = torch.matmul(G, G.transpose(0, 1)) / 2
|
||||
# intra_txt = torch.sum(torch.log(1 + torch.exp(theta_g)) - Sim * theta_g)
|
||||
# intra_loss = gamma1 * intra_img + gamma2 * intra_txt
|
||||
# quan_loss = torch.sum(torch.pow(B - F, 2) + torch.pow(B - G, 2)) * eta
|
||||
# # term3 = torch.sum(torch.pow(F.sum(dim=0), 2) + torch.pow(G.sum(dim=0), 2))
|
||||
# # loss = term1 + gamma * term2 + eta * term3
|
||||
# loss = inter_loss + intra_loss + quan_loss
|
||||
# return loss
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# qB = torch.Tensor([[1, -1, 1, 1],
|
||||
# [-1, -1, -1, 1],
|
||||
# [1, 1, -1, 1],
|
||||
# [1, 1, 1, -1]])
|
||||
# rB = torch.Tensor([[1, -1, 1, -1],
|
||||
# [-1, -1, 1, -1],
|
||||
# [-1, -1, 1, -1],
|
||||
# [1, 1, -1, -1],
|
||||
# [-1, 1, -1, -1],
|
||||
# [1, 1, -1, 1]])
|
||||
# query_L = torch.Tensor([[0, 1, 0, 0],
|
||||
# [1, 1, 0, 0],
|
||||
# [1, 0, 0, 1],
|
||||
# [0, 1, 0, 1]])
|
||||
# retrieval_L = torch.Tensor([[1, 0, 0, 1],
|
||||
# [1, 1, 0, 0],
|
||||
# [0, 1, 1, 0],
|
||||
# [0, 0, 1, 0],
|
||||
# [1, 0, 0, 0],
|
||||
# [0, 0, 1, 0]])
|
||||
#
|
||||
# map = calc_map_k(qB, rB, query_L, retrieval_L)
|
||||
# print(map)
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import argparse
|
||||
|
||||
|
||||
def get_args():
    """Build and parse the command-line arguments for DCMHT training/evaluation.

    :return: the parsed argparse.Namespace.
    """
    parser = argparse.ArgumentParser()

    # Typo fixes in user-facing help text only; flags and defaults unchanged.
    parser.add_argument("--hash-layer", type=str, default="select", help="choose a hash layer [select, linear] to run. select: select mechanism, linear: sign function.")
    parser.add_argument("--save-dir", type=str, default="./result/64-bit")
    parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
    parser.add_argument("--pretrained", type=str, default="")
    parser.add_argument("--dataset", type=str, default="flickr25k", help="choose from [coco, flickr25k, nuswide]")
    parser.add_argument("--index-file", type=str, default="index.mat")
    parser.add_argument("--caption-file", type=str, default="caption.mat")
    parser.add_argument("--label-file", type=str, default="label.mat")
    parser.add_argument("--similarity-function", type=str, default="euclidean", help="choose from [cosine, euclidean]")
    parser.add_argument("--loss-type", type=str, default="l2", help="choose from [l1, l2]")

    parser.add_argument("--output-dim", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--max-words", type=int, default=32)
    parser.add_argument("--resolution", type=int, default=224)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--query-num", type=int, default=5120)
    parser.add_argument("--train-num", type=int, default=10240)
    parser.add_argument("--lr-decay-freq", type=int, default=5)
    parser.add_argument("--display-step", type=int, default=50)
    parser.add_argument("--seed", type=int, default=1814)

    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--lr-decay", type=float, default=0.9)
    parser.add_argument("--clip-lr", type=float, default=0.00001)
    parser.add_argument("--weight-decay", type=float, default=0.2)
    parser.add_argument("--warmup-proportion", type=float, default=0.1,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument("--vartheta", type=float, default=0.5, help="the rate of error code.")
    parser.add_argument("--sim-threshold", type=float, default=0.1)

    parser.add_argument("--is-train", action="store_true")

    args = parser.parse_args()

    return args
|
||||
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
import os
|
||||
import logging
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
def get_logger(filename=None):
    """Return the shared 'logger' at DEBUG level, optionally mirroring to a file.

    Console output flows through the root logger configured via basicConfig;
    when `filename` is given, a FileHandler is attached to the root logger so
    records propagate into the file as well.
    """
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is None:
        return logger
    file_handler = logging.FileHandler(filename)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
    logging.getLogger().addHandler(file_handler)
    return logger
|
||||
|
||||
def get_summary_writer(dirname: str):
    """Ensure `dirname` exists and return a TensorBoard SummaryWriter logging into it."""

    os.makedirs(dirname, exist_ok=True)
    return SummaryWriter(log_dir=dirname)
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from sklearn.metrics.pairwise import euclidean_distances
|
||||
|
||||
|
||||
def compute_metrics(x):
    """Retrieval metrics (R@1/5/10, median/mean rank) from a similarity matrix.

    Row i of x scores query i against all candidates; the diagonal holds the
    ground-truth pair. Scores are negated so that an ascending sort puts the
    most similar candidates first; the position of the diagonal entry within
    its sorted row is the retrieval rank.
    """
    neg = -x
    row_sorted = np.sort(neg, axis=1)
    diag = np.diag(neg)[:, np.newaxis]
    ranks = np.where(row_sorted - diag == 0)[1]
    n = len(ranks)
    metrics = {}
    metrics['R1'] = float(np.sum(ranks == 0)) * 100 / n
    metrics['R5'] = float(np.sum(ranks < 5)) * 100 / n
    metrics['R10'] = float(np.sum(ranks < 10)) * 100 / n
    metrics['MR'] = np.median(ranks) + 1
    metrics["MedianR"] = metrics['MR']
    metrics["MeanR"] = np.mean(ranks) + 1
    metrics["cols"] = [int(i) for i in list(ranks)]
    return metrics
|
||||
|
||||
def encode_hash(a: Union[torch.Tensor, np.ndarray]):
    """Binarize real-valued codes with the sign function, matching the input framework."""
    if isinstance(a, torch.Tensor):
        return torch.sign(a)
    return np.sign(a)
|
||||
|
||||
def calc_neighbor(a: torch.Tensor, b: torch.Tensor):
    """Return 1.0 where rows of a and b share at least one positive label, else 0.0."""
    sim = a.matmul(b.transpose(0, 1))
    return (sim > 0).float()
|
||||
|
||||
|
||||
def euclidean_similarity(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]):
    """Pairwise Euclidean distance matrix between rows of a and b.

    Dispatches on the input framework: torch uses cdist, numpy uses sklearn's
    euclidean_distances. Mixed or unsupported types raise ValueError.
    """
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.cdist(a, b, p=2.0)
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return euclidean_distances(a, b)
    raise ValueError("input value must in [torch.Tensor, numpy.ndarray], but it is %s, %s"%(type(a), type(b)))
|
||||
|
||||
|
||||
def euclidean_dist_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor):
    """
    calculate euclidean distance as inner product
    :param tensor1: a tensor with shape (a, c)
    :param tensor2: a tensor with shape (b, c)
    :return: the euclidean distance matrix which each point is the distance between a row in tensor1 and a row in tensor2.
    """
    dim1 = tensor1.shape[0]
    dim2 = tensor2.shape[0]
    multi = torch.matmul(tensor1, tensor2.t())
    a2 = torch.sum(torch.pow(tensor1, 2), dim=1, keepdim=True).expand(dim1, dim2)
    b2 = torch.sum(torch.pow(tensor2, 2), dim=1, keepdim=True).t().expand(dim1, dim2)
    # Fix: floating-point cancellation can make a2 + b2 - 2*multi slightly
    # negative for near-identical rows, turning sqrt into NaN — clamp at 0.
    dist = torch.sqrt(torch.clamp(a2 + b2 - 2 * multi, min=0.0))
    return dist
|
||||
|
||||
|
||||
def cosine_similarity(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]):
    """Cosine similarity matrix between rows of a and b (torch or numpy).

    Each matrix is row-normalized only when it contains at least one non-zero
    entry, so an all-zero matrix passes through unscaled instead of dividing
    by zero.
    """
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        if len(torch.where(a != 0)[0]) > 0:
            a = a / a.norm(dim=-1, keepdim=True)
        if len(torch.where(b != 0)[0]) > 0:
            b = b / b.norm(dim=-1, keepdim=True)
        return torch.matmul(a, b.t())
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        if len(np.where(a != 0)[0]) > 0:
            a = a / np.linalg.norm(a, axis=-1, keepdims=True)
        if len(np.where(b != 0)[0]) > 0:
            b = b / np.linalg.norm(b, axis=-1, keepdims=True)
        return np.matmul(a, b.T)
    raise ValueError("input value must in [torch.Tensor, numpy.ndarray], but it is %s, %s"%(type(a), type(b)))
|
||||
|
||||
def calc_map_k(qB, rB, query_L, retrieval_L, k=None, rank=0):
    """Mean Average Precision at k for hash retrieval.

    qB/rB are sign-binarized codes in {-1, +1}; query_L/retrieval_L are
    multi-hot label matrices. Queries sharing no label with any retrieval
    item contribute zero to the mean.
    """
    num_query = query_L.shape[0]
    qB = torch.sign(qB)
    rB = torch.sign(rB)
    if k is None:
        k = retrieval_L.shape[0]
    total_map = 0
    for q_idx in range(num_query):
        q_label = query_L[q_idx]
        if len(q_label.shape) < 2:
            q_label = q_label.unsqueeze(0)  # [1, num_class]
        gnd = (q_label.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
        tsum = torch.sum(gnd)
        if tsum == 0:
            # This query is relevant to nothing retrievable — skip it.
            continue
        hamm = calcHammingDist(qB[q_idx, :], rB)
        _, order = torch.sort(hamm)
        order.squeeze_()
        # Relevance flags ordered by ascending Hamming distance.
        gnd = gnd[order]
        total = min(k, int(tsum))
        count = torch.arange(1, total + 1).type(torch.float32)
        tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float32) + 1.0
        if tindex.is_cuda:
            count = count.to(rank)
        total_map = total_map + torch.mean(count / tindex)
    return total_map / num_query
|
||||
|
||||
def softmax_hash(code: Union[torch.Tensor, np.ndarray], dim_alph=0.25):
    """Threshold the softmax of `code` into a {-1, +1} hash.

    Entries whose softmax probability exceeds 1 / int(dim_alph * dim) become
    +1, the rest -1. Accepts a torch tensor (result is returned on the
    tensor's original device) or a numpy array.
    """
    device = None
    if isinstance(code, torch.Tensor):
        device = code.device
        # Bug fix: the original wrote `code.detach.cpu()` (missing call
        # parentheses), which raised AttributeError for CUDA tensors.
        code = code.detach().cpu().numpy()

    softmax_code = np.exp(code) / np.exp(code).sum(axis=-1, keepdims=True)
    threshold = 1 / int(dim_alph * softmax_code.shape[-1])
    hash_code = np.where(softmax_code >= threshold, softmax_code, -1)
    # NOTE(review): entries exactly equal to the threshold keep their raw
    # softmax value rather than becoming +/-1 — preserved from the original.
    hash_code = np.where(hash_code <= threshold, hash_code, 1)
    if device is not None:
        hash_code = torch.from_numpy(hash_code).to(device)
    return hash_code
|
||||
|
||||
# def calcHammingDist(B1, B2):
|
||||
# result = np.zeros((B1.shape[0], B2.shape[0]))
|
||||
# for i, data in enumerate(B1):
|
||||
# result[i] = np.sum(np.where((data + B2) != 2, data + B2, 0), axis=-1)
|
||||
# return result
|
||||
|
||||
def calcHammingDist(B1, B2):
    """Hamming distance matrix between rows of two {-1, +1} code matrices.

    For codes in {-1, +1}: dist = 0.5 * (q - <b1, b2>), where q is the code
    length. 1-D inputs are promoted to a single-row matrix.
    """
    # Bug fix: the original called B1.view(1, -1) without assigning the
    # result — a no-op for torch and a TypeError for 1-D numpy input
    # (ndarray.view has a different signature). reshape works for both.
    if len(B1.shape) < 2:
        B1 = B1.reshape(1, -1)
    if len(B2.shape) < 2:
        B2 = B2.reshape(1, -1)
    q = B2.shape[1]
    if isinstance(B1, torch.Tensor):
        distH = 0.5 * (q - torch.matmul(B1, B2.t()))
    elif isinstance(B1, np.ndarray):
        distH = 0.5 * (q - np.matmul(B1, B2.transpose()))
    else:
        raise ValueError("B1, B2 must in [torch.Tensor, np.ndarray]")
    return distH
|
||||
|
||||
def compute_hash_similarity(visual_embed, text_embed, use_softmax_hash=False, alph=0.25):
    """Cross-modal Hamming distance matrices between hashed embeddings.

    A larger Hamming distance means less similar. Returns the
    (visual -> text, text -> visual) distance matrix pair.
    """
    if use_softmax_hash:
        hash_visual = softmax_hash(visual_embed, alph)
        hash_text = softmax_hash(text_embed, alph)
    else:
        hash_visual = encode_hash(visual_embed)
        hash_text = encode_hash(text_embed)
    vt_similarity = calcHammingDist(hash_visual, hash_text)
    tv_similarity = calcHammingDist(hash_text, hash_visual)
    return vt_similarity, tv_similarity
|
||||
|
||||
class CrossEn(nn.Module):
    """InfoNCE-style cross-entropy over a similarity matrix.

    The diagonal holds the matched (positive) pairs. In "euclidean" mode the
    matrix holds distances, so it is negated before the softmax to make
    smaller distances score higher.
    """

    def __init__(self, mode="cosine"):
        super(CrossEn, self).__init__()
        # Kept for callers that inspect it; only "euclidean" changes behavior.
        self.mode = mode

    def forward(self, sim_matrix):
        if self.mode == "euclidean":
            sim_matrix = -sim_matrix
        log_probs = torch.diag(F.log_softmax(sim_matrix, dim=-1))
        return (-log_probs).mean()
|
||||
|
||||
class CrossEn_mean(nn.Module):
    """Reduces a similarity/distance matrix to its scalar mean.

    `mode` is stored for interface parity with CrossEn but does not affect
    the computation.
    """

    def __init__(self, mode="cosine"):
        super(CrossEn_mean, self).__init__()
        self.mode = mode

    def forward(self, sim_matrix):
        return sim_matrix.mean()
|
||||
Loading…
Reference in New Issue