diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..26d3352
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/advclip.iml b/.idea/advclip.iml
new file mode 100644
index 0000000..7373713
--- /dev/null
+++ b/.idea/advclip.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..bd38a81
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,30 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..2c0f094
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..43f32c1
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/main.py b/main.py
index e3c978e..c68339f 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,4 @@
-from train.hash_train import Trainer
+from train.text_train import Trainer
if __name__ == "__main__":
diff --git a/model/simple_tokenizer.py b/model/simple_tokenizer.py
index 3eb73c2..e22e0d0 100755
--- a/model/simple_tokenizer.py
+++ b/model/simple_tokenizer.py
@@ -130,6 +130,12 @@ class SimpleTokenizer(object):
text = ''.join([self.decoder[token] for token in tokens])
     text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
+
+ def my_decode(self, tokens):
+ tokens=[item for item in tokens if item in self.decoder]
+ text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+ return text
def tokenize(self, text):
tokens = []
diff --git a/train/text_train.py b/train/text_train.py
index 0b77cf6..5d6211b 100644
--- a/train/text_train.py
+++ b/train/text_train.py
@@ -249,17 +249,18 @@ class Trainer(TrainBase):
def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
- bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device, non_blocking=True)
+ bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits
- word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)
+ word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.args.topk, -1)
#clean state
clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
final_adverse = []
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
+ # print(texts)
for i, text in enumerate(texts):
- important_scores = self.get_important_scores(text, clean_embeds, self.batch_size, self.max_length)
+ important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
words, sub_words, keys = self._tokenize(text)
final_words = copy.deepcopy(words)
@@ -315,7 +316,8 @@ class Trainer(TrainBase):
times += 1
print(times)
image.float()
- raw_text=[self.clip_tokenizer.decode(token) for token in text]
+
+ raw_text=[self.clip_tokenizer.my_decode(token) for token in text]
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
negetive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
diff --git a/utils/get_args.py b/utils/get_args.py
index bb23590..36c1518 100644
--- a/utils/get_args.py
+++ b/utils/get_args.py
@@ -16,8 +16,9 @@ def get_args():
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
- # parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat")
- # parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat")
+ parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")
+ parser.add_argument("--topk", type=int, default=10)
+ parser.add_argument("--num-perturbation", type=int, default=3)
parser.add_argument("--txt-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=512)
parser.add_argument("--epochs", type=int, default=100)