remove non-existent key in encoder
This commit is contained in:
parent
1b7f952c39
commit
ebc17c80dc
|
|
@ -0,0 +1,3 @@
|
||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="torch2" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="PyDocumentationSettings">
|
||||||
|
<option name="format" value="PLAIN" />
|
||||||
|
<option name="myDocStringFormat" value="Plain" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="17">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="jax" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="pyyaml" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="flax" />
|
||||||
|
<item index="3" class="java.lang.String" itemvalue="tensorflow" />
|
||||||
|
<item index="4" class="java.lang.String" itemvalue="tensorboard" />
|
||||||
|
<item index="5" class="java.lang.String" itemvalue="jaxlib" />
|
||||||
|
<item index="6" class="java.lang.String" itemvalue="opencv-python" />
|
||||||
|
<item index="7" class="java.lang.String" itemvalue="Pillow" />
|
||||||
|
<item index="8" class="java.lang.String" itemvalue="transformers" />
|
||||||
|
<item index="9" class="java.lang.String" itemvalue="timm" />
|
||||||
|
<item index="10" class="java.lang.String" itemvalue="ruamel_yaml" />
|
||||||
|
<item index="11" class="java.lang.String" itemvalue="torch" />
|
||||||
|
<item index="12" class="java.lang.String" itemvalue="torchvision" />
|
||||||
|
<item index="13" class="java.lang.String" itemvalue="pandas" />
|
||||||
|
<item index="14" class="java.lang.String" itemvalue="scipy" />
|
||||||
|
<item index="15" class="java.lang.String" itemvalue="tqdm" />
|
||||||
|
<item index="16" class="java.lang.String" itemvalue="numpy" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="torch2" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="torch2" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/advclip.iml" filepath="$PROJECT_DIR$/.idea/advclip.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
2
main.py
2
main.py
|
|
@ -1,4 +1,4 @@
|
||||||
from train.hash_train import Trainer
|
from train.text_train import Trainer
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -130,6 +130,12 @@ class SimpleTokenizer(object):
|
||||||
text = ''.join([self.decoder[token] for token in tokens])
|
text = ''.join([self.decoder[token] for token in tokens])
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def my_decode(self, tokens):
    """Decode BPE token ids to text, silently dropping unknown ids.

    Unlike a plain decode, any token id missing from ``self.decoder`` is
    filtered out first, so decoding never raises ``KeyError`` on
    out-of-vocabulary ids.
    """
    # Keep only ids present in the vocabulary before looking anything up.
    known = [tok for tok in tokens if tok in self.decoder]
    joined = ''.join(self.decoder[tok] for tok in known)
    # Map the byte-level characters back to raw bytes, then decode as UTF-8;
    # the end-of-word marker becomes a plain space.
    raw = bytearray(self.byte_decoder[ch] for ch in joined)
    return raw.decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
tokens = []
|
tokens = []
|
||||||
|
|
|
||||||
|
|
@ -249,17 +249,18 @@ class Trainer(TrainBase):
|
||||||
def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||||
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
|
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
|
||||||
|
|
||||||
bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device, non_blocking=True)
|
bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
|
||||||
mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits
|
mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits
|
||||||
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)
|
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.args.topk, -1)
|
||||||
|
|
||||||
#clean state
|
#clean state
|
||||||
clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
|
clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
|
||||||
final_adverse = []
|
final_adverse = []
|
||||||
|
|
||||||
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
||||||
|
# print(texts)
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
important_scores = self.get_important_scores(text, clean_embeds, self.batch_size, self.max_length)
|
important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
|
||||||
list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
|
list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
|
||||||
words, sub_words, keys = self._tokenize(text)
|
words, sub_words, keys = self._tokenize(text)
|
||||||
final_words = copy.deepcopy(words)
|
final_words = copy.deepcopy(words)
|
||||||
|
|
@ -315,7 +316,8 @@ class Trainer(TrainBase):
|
||||||
times += 1
|
times += 1
|
||||||
print(times)
|
print(times)
|
||||||
image.float()
|
image.float()
|
||||||
raw_text=[self.clip_tokenizer.decode(token) for token in text]
|
|
||||||
|
raw_text=[self.clip_tokenizer.my_decode(token) for token in text]
|
||||||
image = image.to(self.rank, non_blocking=True)
|
image = image.to(self.rank, non_blocking=True)
|
||||||
text = text.to(self.rank, non_blocking=True)
|
text = text.to(self.rank, non_blocking=True)
|
||||||
negetive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
negetive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,9 @@ def get_args():
|
||||||
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
|
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
|
||||||
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
|
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
|
||||||
parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
|
parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
|
||||||
# parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat")
|
parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")
|
||||||
# parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat")
|
parser.add_argument("--topk", type=int, default=10)
|
||||||
|
parser.add_argument("--num-perturbation", type=int, default=3)
|
||||||
parser.add_argument("--txt-dim", type=int, default=1024)
|
parser.add_argument("--txt-dim", type=int, default=1024)
|
||||||
parser.add_argument("--output-dim", type=int, default=512)
|
parser.add_argument("--output-dim", type=int, default=512)
|
||||||
parser.add_argument("--epochs", type=int, default=100)
|
parser.add_argument("--epochs", type=int, default=100)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue