This commit is contained in:
leewlving 2024-06-08 16:23:24 +08:00
parent a22408c1ec
commit f679a8cedb
5 changed files with 245 additions and 148 deletions

193
.gitignore vendored Normal file
View File

@ -0,0 +1,193 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
dataset/
# End of https://www.toptal.com/developers/gitignore/api/python,linux

View File

@ -32,7 +32,7 @@ def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=N
def dataloader(captionFile: str, def dataloader(captionFile: str,
indexFile: str, indexFile: str,
labelFile: str, labelFile: str,
maxWords=32, maxWords=77,
imageResolution=224, imageResolution=224,
query_num=5000, query_num=5000,
train_num=10000, train_num=10000,

View File

@ -3,6 +3,8 @@ from train.hash_train import Trainer
if __name__ == "__main__": if __name__ == "__main__":
Trainer() engine=Trainer()
engine.test()
engine.train()

View File

@ -33,63 +33,18 @@ class Trainer(TrainBase):
args = get_args() args = get_args()
super(Trainer, self).__init__(args, rank) super(Trainer, self).__init__(args, rank)
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
image_representation, text_representation=self.generate_mapping() text_representation, text_representation=self.generate_mapping()
self.image_representation=image_representation self.image_representation=text_representation
self.text_representation=text_representation self.text_representation=text_representation
self.device=rank self.device=rank
# self.run() # self.run()
def _init_model(self): def _init_model(self):
self.logger.info("init model.") self.logger.info("init model.")
# self.generator=Generator()
# linear = False
# if self.args.hash_layer == "linear":
# linear = True
# self.bert=BertModel.from_pretrained("bert-base-cased", output_hidden_states=True).to(self.rank)
# self.bert.eval()
# self.logger.info("ViT+GPT!")
# HashModel = DCMHT
# if self.args.victim_model == 'JDSH':
# from model.JDSH import TxtNet, ImgNet
# # self.img_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
# # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
# self.img_model=ImgNet(code_len=self.args.output_dim).to(self.rank)
# self.txt_model=TxtNet(code_len=self.args.output_dim, txt_feat_len=self.args.txt_dim).to(self.rank)
# path=os.path.join(self.args.checkpoints,self.args.victim_model+'/'+str(self.args.output_dim)+'_'+self.args.dataset+'latest.pth')
# checkpoint=torch.load(path)
# self.img_model.load_state_dict(torch.load(checkpoint['ImgNet'], map_location=f"cuda:{self.rank}"))
# self.txt_model.load_state_dict(torch.load(checkpoint['TxtNet'], map_location=f"cuda:{self.rank}"))
# self.img_model.eval()
# self.txt_model.eval()
# elif self.args.victim_model == 'DJSRH':
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
# elif self.args.victim_model == 'SSAH':
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
# elif self.args.victim_model == 'DCHUC':
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
# if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
# self.logger.info("load pretrained model.")
# self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device) model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device)
self.model= model_clip self.model= model_clip
self.model.eval() self.model.eval()
self.model.float() self.model.float()
# self.optimizer = BertAdam([
# {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
# {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
# {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
# ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
# b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
# weight_decay=self.args.weight_decay, max_grad_norm=1.0)
# print(self.model)
def _init_dataset(self): def _init_dataset(self):
self.logger.info("init dataset.") self.logger.info("init dataset.")
@ -132,22 +87,21 @@ class Trainer(TrainBase):
pin_memory=True, pin_memory=True,
shuffle=True shuffle=True
) )
def generate_mapping(self): def generate_mapping(self):
text_train=[] text_train=[]
label_train=[] label_train=[]
for image, text, label, index in self.train_loader: for image, text, label, index in self.train_loader:
# image=image.to(self.device, non_blocking=True)
text=text.to(device, non_blocking=True) text=text.to(device, non_blocking=True)
# print(self.model.vocab_size)
temp_text=self.model.encode_text(text) temp_text=self.model.encode_text(text)
# temp_image=self.model.encode_image(image)
# image_train.append(temp_image.cpu().detach().numpy())
text_train.append(temp_text.cpu().detach().numpy()) text_train.append(temp_text.cpu().detach().numpy())
label_train.append(label.detach().numpy()) label_train.append(label.detach().numpy())
text_train=np.concatenate(text_train, axis=0) text_train=np.concatenate(text_train, axis=0)
# image_train=np.concatenate(image_train, axis=0)
label_train=np.concatenate(label_train, axis=0) label_train=np.concatenate(label_train, axis=0)
label_unipue=np.unique(label_train,axis=0) label_unipue=np.unique(label_train,axis=0)
# image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0) text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
@ -162,7 +116,6 @@ class Trainer(TrainBase):
epsilon=0.03125, alpha=3/255, num_iter=100): epsilon=0.03125, alpha=3/255, num_iter=100):
delta = torch.zeros_like(image,requires_grad=True) delta = torch.zeros_like(image,requires_grad=True)
# clean_output = self.model.encode_image(image)
one=torch.zeros_like(positive) one=torch.zeros_like(positive)
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
for i in range(num_iter): for i in range(num_iter):
@ -178,42 +131,38 @@ class Trainer(TrainBase):
return delta.detach() return delta.detach()
def train_epoch(self, epoch): # def train_epoch(self, epoch):
self.change_state(mode="valid") # self.change_state(mode="valid")
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) # self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
all_loss = 0 # all_loss = 0
times = 0 # times = 0
adv_images=[] # adv_images=[]
adv_labels=[] # adv_labels=[]
texts=[] # texts=[]
for image, text, label, index in self.train_loader: # for image, text, label, index in self.train_loader:
self.global_step += 1 # self.global_step += 1
times += 1 # times += 1
image.float() # image.float()
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]: # if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
label = torch.ones([image.shape[0]], dtype=torch.int) # label = torch.ones([image.shape[0]], dtype=torch.int)
label = label.diag() # label = label.diag()
image = image.to(self.rank, non_blocking=True) # image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True) # text = text.to(self.rank, non_blocking=True)
index = index.numpy()
image_anchor=self.image_representation(label.detach().cpu().numpy())
text_anchor=self.text_representation(label.detach().cpu().numpy())
negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
target_label=label.flip(dims=[0])
target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
# print("text shape:", text.shape)
# index = index.numpy() # index = index.numpy()
# print(text.shape) # image_anchor=self.image_representation(label.detach().cpu().numpy())
delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True), # text_anchor=self.text_representation(label.detach().cpu().numpy())
torch.from_numpy(negetive_code).to(self.rank, non_blocking=True)) # negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
adv_image=delta+image # target_label=label.flip(dims=[0])
adv_images.append(adv_image) # target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
adv_labels.append(target_label) # target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
texts.append(text) # positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}") # delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
return adv_images, texts, adv_labels # torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
# adv_image=delta+image
# adv_images.append(adv_image)
# adv_labels.append(target_label)
# texts.append(text)
# return adv_images, texts, adv_labels
@ -227,55 +176,7 @@ class Trainer(TrainBase):
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
s = torch.matmul(a, b.t())
b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s)))
return b_loss
def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
"""
"""
kl_divergence = torch.mean(a * torch.log(a / (b + 0.001)))
print("mean", torch.mean(a - b))
print("kl", kl_divergence)
return kl_divergence
def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
# $\vartheta$
vartheta = self.args.vartheta
if self.args.sim_threshold != 0:
threshold = self.args.sim_threshold
similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b)
positive_similarity = similarity * label_sim
# 只要cosine为负值的全都算为计算正确了因为优化到2确实很难。
negative_similarity = similarity * (1 - label_sim)
if self.args.similarity_function == "cosine":
positive_similarity = positive_similarity.clip(threshold) - threshold
negative_similarity = negative_similarity.clip(max=1.)
negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
elif self.args.similarity_function == "euclidean":
# 有euclidean距离可知当有一半长度的hash码不同时其negative_similarity距离应该是长度concat操作将outputdim翻倍所以这里clip掉认为认定的值
# 人为认定的最大值是一半长度的hash码不同。
max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5
negative_similarity = negative_similarity.clip(max=max_value)
negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
if self.args.loss_type == "l1":
positive_loss = positive_similarity.mean()
negative_loss = negative_similarity.mean()
elif self.args.loss_type == "l2":
positive_loss = torch.pow(positive_similarity, 2).mean()
negative_loss = torch.pow(negative_similarity, 2).mean()
else:
raise ValueError("argument of loss_type is not support.")
return similarity, positive_loss, negative_loss
def make_hash_code(self, code: list) -> torch.Tensor: def make_hash_code(self, code: list) -> torch.Tensor:
@ -290,16 +191,16 @@ class Trainer(TrainBase):
def get_code(self, data_loader, length: int): def get_code(self, data_loader, length: int):
img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) img_buffer = []
text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) text_buffer = []
for image, text, label, index in tqdm(data_loader): for image, text, label, index in tqdm(data_loader):
image = image.to(self.rank, non_blocking=True) image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True) text = text.to(self.rank, non_blocking=True)
index = index.numpy() index = index.numpy()
image_hash=self.img_model(image) image_hash=self.model.encode_image(image)
text_feat=self.bert(text)[0] # text_feat=self.bert(text)[0]
text_hash=self.txt_model(text_feat) text_hash=self.model.encode_text(text)
img_buffer[index, :] = image_hash.data img_buffer[index, :] = image_hash.data
text_buffer[index, :] = text_hash.data text_buffer[index, :] = text_hash.data
@ -333,9 +234,10 @@ class Trainer(TrainBase):
def test(self, mode_name="i2t"): def test(self, mode_name="i2t"):
if self.args.pretrained == "": self.logger.info("Valid Clean.")
raise RuntimeError("test step must load a model! please set the --pretrained argument.") # if self.args.pretrained == "":
self.change_state(mode="valid") # raise RuntimeError("test step must load a model! please set the --pretrained argument.")
# self.change_state(mode="valid")
save_dir = os.path.join(self.args.save_dir, "PR_cruve") save_dir = os.path.join(self.args.save_dir, "PR_cruve")
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)

View File

@ -21,7 +21,7 @@ def get_args():
parser.add_argument("--txt-dim", type=int, default=1024) parser.add_argument("--txt-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=64) parser.add_argument("--output-dim", type=int, default=64)
parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--max-words", type=int, default=32) parser.add_argument("--max-words", type=int, default=77)
parser.add_argument("--resolution", type=int, default=224) parser.add_argument("--resolution", type=int, default=224)
parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--num-workers", type=int, default=4)