diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/advclip.iml b/.idea/advclip.iml new file mode 100644 index 0000000..7373713 --- /dev/null +++ b/.idea/advclip.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..bd38a81 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,30 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..2c0f094 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..43f32c1 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/main.py b/main.py index e3c978e..c68339f 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from train.hash_train import Trainer +from train.text_train import Trainer if __name__ == "__main__": diff --git a/model/simple_tokenizer.py b/model/simple_tokenizer.py index 3eb73c2..e22e0d0 100755 --- a/model/simple_tokenizer.py +++ b/model/simple_tokenizer.py @@ -130,6 +130,12 @@ class SimpleTokenizer(object): text = ''.join([self.decoder[token] for token in tokens]) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') return text + + def my_decode(self, tokens): + tokens=[item for item in tokens if item in self.decoder] + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text def tokenize(self, text): tokens = [] diff --git a/train/text_train.py b/train/text_train.py index 0b77cf6..5d6211b 100644 --- a/train/text_train.py +++ b/train/text_train.py @@ -249,17 +249,18 @@ class Trainer(TrainBase): def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var, beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05): - bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device, non_blocking=True) + bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt') mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits - word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1) + word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.args.topk, -1) #clean state clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask) final_adverse = [] # alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) + # print(texts) for i, text in enumerate(texts): - important_scores = self.get_important_scores(text, clean_embeds, self.batch_size, self.max_length) + important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words) list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True) words, sub_words, keys = self._tokenize(text) final_words = copy.deepcopy(words) @@ -315,7 +316,8 @@ class Trainer(TrainBase): times += 1 print(times) image.float() - raw_text=[self.clip_tokenizer.decode(token) for token in text] + + raw_text=[self.clip_tokenizer.my_decode(token) for token in text] image = image.to(self.rank, non_blocking=True) text = text.to(self.rank, non_blocking=True) negetive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()]) diff --git a/utils/get_args.py b/utils/get_args.py index bb23590..36c1518 100644 --- a/utils/get_args.py +++ b/utils/get_args.py @@ -16,8 +16,9 @@ def get_args(): parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]") parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]") parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101']) - # parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat") - # parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat") + parser.add_argument("--text_encoder", type=str, default="bert-base-uncased") + parser.add_argument("--topk", type=int, default=10) + parser.add_argument("--num-perturbation", type=int, default=3) parser.add_argument("--txt-dim", type=int, default=1024) parser.add_argument("--output-dim", type=int, default=512) parser.add_argument("--epochs", type=int, default=100)