update
This commit is contained in:
parent
a22408c1ec
commit
f679a8cedb
|
|
@ -0,0 +1,193 @@
|
||||||
|
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
|
||||||
|
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
|
||||||
|
|
||||||
|
### Linux ###
|
||||||
|
*~
|
||||||
|
|
||||||
|
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||||
|
.fuse_hidden*
|
||||||
|
|
||||||
|
# KDE directory preferences
|
||||||
|
.directory
|
||||||
|
|
||||||
|
# Linux trash folder which might appear on any partition or disk
|
||||||
|
.Trash-*
|
||||||
|
|
||||||
|
# .nfs files are created when an open file is removed but is still being accessed
|
||||||
|
.nfs*
|
||||||
|
|
||||||
|
### Python ###
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
### Python Patch ###
|
||||||
|
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
||||||
|
poetry.toml
|
||||||
|
|
||||||
|
# ruff
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# LSP config files
|
||||||
|
pyrightconfig.json
|
||||||
|
|
||||||
|
dataset/
|
||||||
|
|
||||||
|
# End of https://www.toptal.com/developers/gitignore/api/python,linux
|
||||||
|
|
@ -32,7 +32,7 @@ def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=N
|
||||||
def dataloader(captionFile: str,
|
def dataloader(captionFile: str,
|
||||||
indexFile: str,
|
indexFile: str,
|
||||||
labelFile: str,
|
labelFile: str,
|
||||||
maxWords=32,
|
maxWords=77,
|
||||||
imageResolution=224,
|
imageResolution=224,
|
||||||
query_num=5000,
|
query_num=5000,
|
||||||
train_num=10000,
|
train_num=10000,
|
||||||
|
|
|
||||||
4
main.py
4
main.py
|
|
@ -3,6 +3,8 @@ from train.hash_train import Trainer
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
Trainer()
|
engine=Trainer()
|
||||||
|
engine.test()
|
||||||
|
engine.train()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,63 +33,18 @@ class Trainer(TrainBase):
|
||||||
args = get_args()
|
args = get_args()
|
||||||
super(Trainer, self).__init__(args, rank)
|
super(Trainer, self).__init__(args, rank)
|
||||||
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
||||||
image_representation, text_representation=self.generate_mapping()
|
text_representation, text_representation=self.generate_mapping()
|
||||||
self.image_representation=image_representation
|
self.image_representation=text_representation
|
||||||
self.text_representation=text_representation
|
self.text_representation=text_representation
|
||||||
self.device=rank
|
self.device=rank
|
||||||
# self.run()
|
# self.run()
|
||||||
|
|
||||||
def _init_model(self):
|
def _init_model(self):
|
||||||
self.logger.info("init model.")
|
self.logger.info("init model.")
|
||||||
# self.generator=Generator()
|
|
||||||
# linear = False
|
|
||||||
# if self.args.hash_layer == "linear":
|
|
||||||
# linear = True
|
|
||||||
# self.bert=BertModel.from_pretrained("bert-base-cased", output_hidden_states=True).to(self.rank)
|
|
||||||
# self.bert.eval()
|
|
||||||
# self.logger.info("ViT+GPT!")
|
|
||||||
# HashModel = DCMHT
|
|
||||||
# if self.args.victim_model == 'JDSH':
|
|
||||||
# from model.JDSH import TxtNet, ImgNet
|
|
||||||
# # self.img_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
|
||||||
# # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
|
||||||
# self.img_model=ImgNet(code_len=self.args.output_dim).to(self.rank)
|
|
||||||
# self.txt_model=TxtNet(code_len=self.args.output_dim, txt_feat_len=self.args.txt_dim).to(self.rank)
|
|
||||||
# path=os.path.join(self.args.checkpoints,self.args.victim_model+'/'+str(self.args.output_dim)+'_'+self.args.dataset+'latest.pth')
|
|
||||||
# checkpoint=torch.load(path)
|
|
||||||
# self.img_model.load_state_dict(torch.load(checkpoint['ImgNet'], map_location=f"cuda:{self.rank}"))
|
|
||||||
# self.txt_model.load_state_dict(torch.load(checkpoint['TxtNet'], map_location=f"cuda:{self.rank}"))
|
|
||||||
# self.img_model.eval()
|
|
||||||
# self.txt_model.eval()
|
|
||||||
# elif self.args.victim_model == 'DJSRH':
|
|
||||||
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
|
||||||
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
|
||||||
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
|
||||||
# elif self.args.victim_model == 'SSAH':
|
|
||||||
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
|
||||||
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
|
||||||
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
|
||||||
# elif self.args.victim_model == 'DCHUC':
|
|
||||||
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
|
||||||
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
|
||||||
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
|
||||||
|
|
||||||
# if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
|
|
||||||
# self.logger.info("load pretrained model.")
|
|
||||||
# self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
|
||||||
model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device)
|
model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device)
|
||||||
self.model= model_clip
|
self.model= model_clip
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.model.float()
|
self.model.float()
|
||||||
# self.optimizer = BertAdam([
|
|
||||||
# {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
|
|
||||||
# {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
|
|
||||||
# {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
|
|
||||||
# ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
|
|
||||||
# b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
|
|
||||||
# weight_decay=self.args.weight_decay, max_grad_norm=1.0)
|
|
||||||
|
|
||||||
# print(self.model)
|
|
||||||
|
|
||||||
def _init_dataset(self):
|
def _init_dataset(self):
|
||||||
self.logger.info("init dataset.")
|
self.logger.info("init dataset.")
|
||||||
|
|
@ -132,22 +87,21 @@ class Trainer(TrainBase):
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
shuffle=True
|
shuffle=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_mapping(self):
|
def generate_mapping(self):
|
||||||
text_train=[]
|
text_train=[]
|
||||||
label_train=[]
|
label_train=[]
|
||||||
for image, text, label, index in self.train_loader:
|
for image, text, label, index in self.train_loader:
|
||||||
# image=image.to(self.device, non_blocking=True)
|
|
||||||
text=text.to(device, non_blocking=True)
|
text=text.to(device, non_blocking=True)
|
||||||
|
# print(self.model.vocab_size)
|
||||||
temp_text=self.model.encode_text(text)
|
temp_text=self.model.encode_text(text)
|
||||||
# temp_image=self.model.encode_image(image)
|
|
||||||
# image_train.append(temp_image.cpu().detach().numpy())
|
|
||||||
text_train.append(temp_text.cpu().detach().numpy())
|
text_train.append(temp_text.cpu().detach().numpy())
|
||||||
label_train.append(label.detach().numpy())
|
label_train.append(label.detach().numpy())
|
||||||
text_train=np.concatenate(text_train, axis=0)
|
text_train=np.concatenate(text_train, axis=0)
|
||||||
# image_train=np.concatenate(image_train, axis=0)
|
|
||||||
label_train=np.concatenate(label_train, axis=0)
|
label_train=np.concatenate(label_train, axis=0)
|
||||||
label_unipue=np.unique(label_train,axis=0)
|
label_unipue=np.unique(label_train,axis=0)
|
||||||
# image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
|
|
||||||
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
|
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
|
||||||
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
|
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
|
||||||
|
|
||||||
|
|
@ -162,7 +116,6 @@ class Trainer(TrainBase):
|
||||||
epsilon=0.03125, alpha=3/255, num_iter=100):
|
epsilon=0.03125, alpha=3/255, num_iter=100):
|
||||||
|
|
||||||
delta = torch.zeros_like(image,requires_grad=True)
|
delta = torch.zeros_like(image,requires_grad=True)
|
||||||
# clean_output = self.model.encode_image(image)
|
|
||||||
one=torch.zeros_like(positive)
|
one=torch.zeros_like(positive)
|
||||||
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
||||||
for i in range(num_iter):
|
for i in range(num_iter):
|
||||||
|
|
@ -178,42 +131,38 @@ class Trainer(TrainBase):
|
||||||
|
|
||||||
return delta.detach()
|
return delta.detach()
|
||||||
|
|
||||||
def train_epoch(self, epoch):
|
# def train_epoch(self, epoch):
|
||||||
self.change_state(mode="valid")
|
# self.change_state(mode="valid")
|
||||||
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
||||||
all_loss = 0
|
# all_loss = 0
|
||||||
times = 0
|
# times = 0
|
||||||
adv_images=[]
|
# adv_images=[]
|
||||||
adv_labels=[]
|
# adv_labels=[]
|
||||||
texts=[]
|
# texts=[]
|
||||||
for image, text, label, index in self.train_loader:
|
# for image, text, label, index in self.train_loader:
|
||||||
self.global_step += 1
|
# self.global_step += 1
|
||||||
times += 1
|
# times += 1
|
||||||
image.float()
|
# image.float()
|
||||||
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
|
# if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
|
||||||
label = torch.ones([image.shape[0]], dtype=torch.int)
|
# label = torch.ones([image.shape[0]], dtype=torch.int)
|
||||||
label = label.diag()
|
# label = label.diag()
|
||||||
image = image.to(self.rank, non_blocking=True)
|
# image = image.to(self.rank, non_blocking=True)
|
||||||
text = text.to(self.rank, non_blocking=True)
|
# text = text.to(self.rank, non_blocking=True)
|
||||||
index = index.numpy()
|
|
||||||
image_anchor=self.image_representation(label.detach().cpu().numpy())
|
|
||||||
text_anchor=self.text_representation(label.detach().cpu().numpy())
|
|
||||||
negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
|
|
||||||
target_label=label.flip(dims=[0])
|
|
||||||
target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
|
|
||||||
target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
|
|
||||||
positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
|
|
||||||
# print("text shape:", text.shape)
|
|
||||||
# index = index.numpy()
|
# index = index.numpy()
|
||||||
# print(text.shape)
|
# image_anchor=self.image_representation(label.detach().cpu().numpy())
|
||||||
delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
|
# text_anchor=self.text_representation(label.detach().cpu().numpy())
|
||||||
torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
|
# negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
|
||||||
adv_image=delta+image
|
# target_label=label.flip(dims=[0])
|
||||||
adv_images.append(adv_image)
|
# target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
|
||||||
adv_labels.append(target_label)
|
# target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
|
||||||
texts.append(text)
|
# positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
|
||||||
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
|
# delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
|
||||||
return adv_images, texts, adv_labels
|
# torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
|
||||||
|
# adv_image=delta+image
|
||||||
|
# adv_images.append(adv_image)
|
||||||
|
# adv_labels.append(target_label)
|
||||||
|
# texts.append(text)
|
||||||
|
# return adv_images, texts, adv_labels
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -227,55 +176,7 @@ class Trainer(TrainBase):
|
||||||
|
|
||||||
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
||||||
|
|
||||||
def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
|
|
||||||
|
|
||||||
s = torch.matmul(a, b.t())
|
|
||||||
b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s)))
|
|
||||||
|
|
||||||
return b_loss
|
|
||||||
|
|
||||||
def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
|
|
||||||
"""
|
|
||||||
"""
|
|
||||||
kl_divergence = torch.mean(a * torch.log(a / (b + 0.001)))
|
|
||||||
print("mean", torch.mean(a - b))
|
|
||||||
print("kl", kl_divergence)
|
|
||||||
return kl_divergence
|
|
||||||
|
|
||||||
|
|
||||||
def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
|
|
||||||
|
|
||||||
# $\vartheta$
|
|
||||||
vartheta = self.args.vartheta
|
|
||||||
if self.args.sim_threshold != 0:
|
|
||||||
threshold = self.args.sim_threshold
|
|
||||||
similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b)
|
|
||||||
|
|
||||||
positive_similarity = similarity * label_sim
|
|
||||||
# 只要cosine为负值的全都算为计算正确了,因为优化到2确实很难。
|
|
||||||
negative_similarity = similarity * (1 - label_sim)
|
|
||||||
|
|
||||||
if self.args.similarity_function == "cosine":
|
|
||||||
positive_similarity = positive_similarity.clip(threshold) - threshold
|
|
||||||
negative_similarity = negative_similarity.clip(max=1.)
|
|
||||||
negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
|
|
||||||
elif self.args.similarity_function == "euclidean":
|
|
||||||
# 有euclidean距离可知,当有一半长度的hash码不同时,其negative_similarity距离应该是长度(concat操作将outputdim翻倍),所以这里clip掉认为认定的值
|
|
||||||
# 人为认定的最大值是一半长度的hash码不同。
|
|
||||||
max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5
|
|
||||||
negative_similarity = negative_similarity.clip(max=max_value)
|
|
||||||
negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
|
|
||||||
|
|
||||||
if self.args.loss_type == "l1":
|
|
||||||
positive_loss = positive_similarity.mean()
|
|
||||||
negative_loss = negative_similarity.mean()
|
|
||||||
elif self.args.loss_type == "l2":
|
|
||||||
positive_loss = torch.pow(positive_similarity, 2).mean()
|
|
||||||
negative_loss = torch.pow(negative_similarity, 2).mean()
|
|
||||||
else:
|
|
||||||
raise ValueError("argument of loss_type is not support.")
|
|
||||||
|
|
||||||
return similarity, positive_loss, negative_loss
|
|
||||||
|
|
||||||
def make_hash_code(self, code: list) -> torch.Tensor:
|
def make_hash_code(self, code: list) -> torch.Tensor:
|
||||||
|
|
||||||
|
|
@ -290,16 +191,16 @@ class Trainer(TrainBase):
|
||||||
|
|
||||||
def get_code(self, data_loader, length: int):
|
def get_code(self, data_loader, length: int):
|
||||||
|
|
||||||
img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
|
img_buffer = []
|
||||||
text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
|
text_buffer = []
|
||||||
|
|
||||||
for image, text, label, index in tqdm(data_loader):
|
for image, text, label, index in tqdm(data_loader):
|
||||||
image = image.to(self.rank, non_blocking=True)
|
image = image.to(self.rank, non_blocking=True)
|
||||||
text = text.to(self.rank, non_blocking=True)
|
text = text.to(self.rank, non_blocking=True)
|
||||||
index = index.numpy()
|
index = index.numpy()
|
||||||
image_hash=self.img_model(image)
|
image_hash=self.model.encode_image(image)
|
||||||
text_feat=self.bert(text)[0]
|
# text_feat=self.bert(text)[0]
|
||||||
text_hash=self.txt_model(text_feat)
|
text_hash=self.model.encode_text(text)
|
||||||
img_buffer[index, :] = image_hash.data
|
img_buffer[index, :] = image_hash.data
|
||||||
text_buffer[index, :] = text_hash.data
|
text_buffer[index, :] = text_hash.data
|
||||||
|
|
||||||
|
|
@ -333,9 +234,10 @@ class Trainer(TrainBase):
|
||||||
|
|
||||||
|
|
||||||
def test(self, mode_name="i2t"):
|
def test(self, mode_name="i2t"):
|
||||||
if self.args.pretrained == "":
|
self.logger.info("Valid Clean.")
|
||||||
raise RuntimeError("test step must load a model! please set the --pretrained argument.")
|
# if self.args.pretrained == "":
|
||||||
self.change_state(mode="valid")
|
# raise RuntimeError("test step must load a model! please set the --pretrained argument.")
|
||||||
|
# self.change_state(mode="valid")
|
||||||
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ def get_args():
|
||||||
parser.add_argument("--txt-dim", type=int, default=1024)
|
parser.add_argument("--txt-dim", type=int, default=1024)
|
||||||
parser.add_argument("--output-dim", type=int, default=64)
|
parser.add_argument("--output-dim", type=int, default=64)
|
||||||
parser.add_argument("--epochs", type=int, default=100)
|
parser.add_argument("--epochs", type=int, default=100)
|
||||||
parser.add_argument("--max-words", type=int, default=32)
|
parser.add_argument("--max-words", type=int, default=77)
|
||||||
parser.add_argument("--resolution", type=int, default=224)
|
parser.add_argument("--resolution", type=int, default=224)
|
||||||
parser.add_argument("--batch-size", type=int, default=64)
|
parser.add_argument("--batch-size", type=int, default=64)
|
||||||
parser.add_argument("--num-workers", type=int, default=4)
|
parser.add_argument("--num-workers", type=int, default=4)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue