diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bf9d45b --- /dev/null +++ b/.gitignore @@ -0,0 +1,193 @@ +# Created by https://www.toptal.com/developers/gitignore/api/python,linux +# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +dataset/ + +# End of https://www.toptal.com/developers/gitignore/api/python,linux diff --git a/dataset/dataloader.py b/dataset/dataloader.py index d04c1bb..9cf1eda 100644 --- a/dataset/dataloader.py +++ b/dataset/dataloader.py @@ -32,7 +32,7 @@ def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=N def dataloader(captionFile: str, indexFile: str, labelFile: str, - maxWords=32, + maxWords=77, imageResolution=224, query_num=5000, train_num=10000, diff --git a/main.py b/main.py index 4de8332..33dd85e 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,8 @@ from train.hash_train import Trainer if __name__ == "__main__": - Trainer() + engine=Trainer() + engine.test() + engine.train() diff --git a/train/hash_train.py b/train/hash_train.py index f0e9b42..972fbfa 100644 --- a/train/hash_train.py +++ b/train/hash_train.py @@ -33,63 +33,18 @@ class Trainer(TrainBase): args = get_args() super(Trainer, self).__init__(args, rank) self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) - image_representation, text_representation=self.generate_mapping() - self.image_representation=image_representation + text_representation, text_representation=self.generate_mapping() + self.image_representation=text_representation self.text_representation=text_representation self.device=rank # self.run() def _init_model(self): self.logger.info("init model.") - # self.generator=Generator() - # linear = False - # if self.args.hash_layer == "linear": - # linear = True - # self.bert=BertModel.from_pretrained("bert-base-cased", output_hidden_states=True).to(self.rank) - # self.bert.eval() - # self.logger.info("ViT+GPT!") - # HashModel = DCMHT - # if self.args.victim_model == 'JDSH': - # from model.JDSH import TxtNet, ImgNet - # # self.img_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, - # # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) - # self.img_model=ImgNet(code_len=self.args.output_dim).to(self.rank) - # self.txt_model=TxtNet(code_len=self.args.output_dim, txt_feat_len=self.args.txt_dim).to(self.rank) - # path=os.path.join(self.args.checkpoints,self.args.victim_model+'/'+str(self.args.output_dim)+'_'+self.args.dataset+'latest.pth') - # checkpoint=torch.load(path) - # self.img_model.load_state_dict(torch.load(checkpoint['ImgNet'], map_location=f"cuda:{self.rank}")) - # self.txt_model.load_state_dict(torch.load(checkpoint['TxtNet'], map_location=f"cuda:{self.rank}")) - # self.img_model.eval() - # self.txt_model.eval() - # elif self.args.victim_model == 'DJSRH': - # self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, - # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) - # self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) - # elif self.args.victim_model == 'SSAH': - # self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, - # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) - # self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) - # elif self.args.victim_model == 'DCHUC': - # self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, - # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) - # self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) - - # if self.args.pretrained != "" and os.path.exists(self.args.pretrained): - # self.logger.info("load pretrained model.") - # self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device) self.model= model_clip self.model.eval() self.model.float() - # self.optimizer = BertAdam([ - # {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr}, - # {'params': self.model.image_hash.parameters(), 'lr': self.args.lr}, - # {'params': self.model.text_hash.parameters(), 'lr': self.args.lr} - # ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine', - # b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs, - # weight_decay=self.args.weight_decay, max_grad_norm=1.0) - - # print(self.model) def _init_dataset(self): self.logger.info("init dataset.") @@ -132,22 +87,21 @@ class Trainer(TrainBase): pin_memory=True, shuffle=True ) + + + def generate_mapping(self): text_train=[] label_train=[] for image, text, label, index in self.train_loader: - # image=image.to(self.device, non_blocking=True) text=text.to(device, non_blocking=True) + # print(self.model.vocab_size) temp_text=self.model.encode_text(text) - # temp_image=self.model.encode_image(image) - # image_train.append(temp_image.cpu().detach().numpy()) text_train.append(temp_text.cpu().detach().numpy()) label_train.append(label.detach().numpy()) text_train=np.concatenate(text_train, axis=0) - # image_train=np.concatenate(image_train, axis=0) label_train=np.concatenate(label_train, axis=0) label_unipue=np.unique(label_train,axis=0) - # image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0) @@ -162,7 +116,6 @@ class Trainer(TrainBase): epsilon=0.03125, alpha=3/255, num_iter=100): delta = torch.zeros_like(image,requires_grad=True) - # clean_output = self.model.encode_image(image) one=torch.zeros_like(positive) alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) for i in range(num_iter): @@ -178,42 +131,38 @@ class Trainer(TrainBase): return delta.detach() - def train_epoch(self, epoch): - self.change_state(mode="valid") - self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) - all_loss = 0 - times = 0 - adv_images=[] - adv_labels=[] - texts=[] - for image, text, label, index in self.train_loader: - self.global_step += 1 - times += 1 - image.float() - if self.args.dataset not in ["flickr25k", "coco", "nuswide"]: - label = torch.ones([image.shape[0]], dtype=torch.int) - label = label.diag() - image = image.to(self.rank, non_blocking=True) - text = text.to(self.rank, non_blocking=True) - index = index.numpy() - image_anchor=self.image_representation(label.detach().cpu().numpy()) - text_anchor=self.text_representation(label.detach().cpu().numpy()) - negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0) - target_label=label.flip(dims=[0]) - target_image_anchor=self.image_representation(target_label.detach().cpu().numpy()) - target_text_anchor=self.text_representation(target_label.detach().cpu().numpy()) - positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0) - # print("text shape:", text.shape) - # index = index.numpy() - # print(text.shape) - delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True), - torch.from_numpy(negetive_code).to(self.rank, non_blocking=True)) - adv_image=delta+image - adv_images.append(adv_image) - adv_labels.append(target_label) - texts.append(text) - # self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}") - return adv_images, texts, adv_labels + # def train_epoch(self, epoch): + # self.change_state(mode="valid") + # self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) + # all_loss = 0 + # times = 0 + # adv_images=[] + # adv_labels=[] + # texts=[] + # for image, text, label, index in self.train_loader: + # self.global_step += 1 + # times += 1 + # image.float() + # if self.args.dataset not in ["flickr25k", "coco", "nuswide"]: + # label = torch.ones([image.shape[0]], dtype=torch.int) + # label = label.diag() + # image = image.to(self.rank, non_blocking=True) + # text = text.to(self.rank, non_blocking=True) + # index = index.numpy() + # image_anchor=self.image_representation(label.detach().cpu().numpy()) + # text_anchor=self.text_representation(label.detach().cpu().numpy()) + # negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0) + # target_label=label.flip(dims=[0]) + # target_image_anchor=self.image_representation(target_label.detach().cpu().numpy()) + # target_text_anchor=self.text_representation(target_label.detach().cpu().numpy()) + # positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0) + # delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True), + # torch.from_numpy(negetive_code).to(self.rank, non_blocking=True)) + # adv_image=delta+image + # adv_images.append(adv_image) + # adv_labels.append(target_label) + # texts.append(text) + # return adv_images, texts, adv_labels @@ -227,55 +176,7 @@ class Trainer(TrainBase): self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") - def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor): - - s = torch.matmul(a, b.t()) - b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s))) - - return b_loss - def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor): - """ - """ - kl_divergence = torch.mean(a * torch.log(a / (b + 0.001))) - print("mean", torch.mean(a - b)) - print("kl", kl_divergence) - return kl_divergence - - - def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05): - - # $\vartheta$ - vartheta = self.args.vartheta - if self.args.sim_threshold != 0: - threshold = self.args.sim_threshold - similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b) - - positive_similarity = similarity * label_sim - # 只要cosine为负值的全都算为计算正确了,因为优化到2确实很难。 - negative_similarity = similarity * (1 - label_sim) - - if self.args.similarity_function == "cosine": - positive_similarity = positive_similarity.clip(threshold) - threshold - negative_similarity = negative_similarity.clip(max=1.) - negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity - elif self.args.similarity_function == "euclidean": - # 有euclidean距离可知,当有一半长度的hash码不同时,其negative_similarity距离应该是长度(concat操作将outputdim翻倍),所以这里clip掉认为认定的值 - # 人为认定的最大值是一半长度的hash码不同。 - max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5 - negative_similarity = negative_similarity.clip(max=max_value) - negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity - - if self.args.loss_type == "l1": - positive_loss = positive_similarity.mean() - negative_loss = negative_similarity.mean() - elif self.args.loss_type == "l2": - positive_loss = torch.pow(positive_similarity, 2).mean() - negative_loss = torch.pow(negative_similarity, 2).mean() - else: - raise ValueError("argument of loss_type is not support.") - - return similarity, positive_loss, negative_loss def make_hash_code(self, code: list) -> torch.Tensor: @@ -290,16 +191,16 @@ class Trainer(TrainBase): def get_code(self, data_loader, length: int): - img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) - text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) + img_buffer = [] + text_buffer = [] for image, text, label, index in tqdm(data_loader): image = image.to(self.rank, non_blocking=True) text = text.to(self.rank, non_blocking=True) index = index.numpy() - image_hash=self.img_model(image) - text_feat=self.bert(text)[0] - text_hash=self.txt_model(text_feat) + image_hash=self.model.encode_image(image) + # text_feat=self.bert(text)[0] + text_hash=self.model.encode_text(text) img_buffer[index, :] = image_hash.data text_buffer[index, :] = text_hash.data @@ -333,9 +234,10 @@ class Trainer(TrainBase): def test(self, mode_name="i2t"): - if self.args.pretrained == "": - raise RuntimeError("test step must load a model! please set the --pretrained argument.") - self.change_state(mode="valid") + self.logger.info("Valid Clean.") + # if self.args.pretrained == "": + # raise RuntimeError("test step must load a model! please set the --pretrained argument.") + # self.change_state(mode="valid") save_dir = os.path.join(self.args.save_dir, "PR_cruve") os.makedirs(save_dir, exist_ok=True) query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) diff --git a/utils/get_args.py b/utils/get_args.py index 8b6aaa6..d8513fe 100644 --- a/utils/get_args.py +++ b/utils/get_args.py @@ -21,7 +21,7 @@ def get_args(): parser.add_argument("--txt-dim", type=int, default=1024) parser.add_argument("--output-dim", type=int, default=64) parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--max-words", type=int, default=32) + parser.add_argument("--max-words", type=int, default=77) parser.add_argument("--resolution", type=int, default=224) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--num-workers", type=int, default=4)