Remove non-existent keys in encoder

This commit is contained in:
Li Wenyun 2024-06-26 16:48:08 +08:00
parent 1b7f952c39
commit ebc17c80dc
11 changed files with 88 additions and 7 deletions

3
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml

12
.idea/advclip.iml Normal file
View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="torch2" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

View File

@ -0,0 +1,30 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="17">
<item index="0" class="java.lang.String" itemvalue="jax" />
<item index="1" class="java.lang.String" itemvalue="pyyaml" />
<item index="2" class="java.lang.String" itemvalue="flax" />
<item index="3" class="java.lang.String" itemvalue="tensorflow" />
<item index="4" class="java.lang.String" itemvalue="tensorboard" />
<item index="5" class="java.lang.String" itemvalue="jaxlib" />
<item index="6" class="java.lang.String" itemvalue="opencv-python" />
<item index="7" class="java.lang.String" itemvalue="Pillow" />
<item index="8" class="java.lang.String" itemvalue="transformers" />
<item index="9" class="java.lang.String" itemvalue="timm" />
<item index="10" class="java.lang.String" itemvalue="ruamel_yaml" />
<item index="11" class="java.lang.String" itemvalue="torch" />
<item index="12" class="java.lang.String" itemvalue="torchvision" />
<item index="13" class="java.lang.String" itemvalue="pandas" />
<item index="14" class="java.lang.String" itemvalue="scipy" />
<item index="15" class="java.lang.String" itemvalue="tqdm" />
<item index="16" class="java.lang.String" itemvalue="numpy" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="torch2" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="torch2" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/advclip.iml" filepath="$PROJECT_DIR$/.idea/advclip.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View File

@ -1,4 +1,4 @@
from train.hash_train import Trainer
from train.text_train import Trainer
if __name__ == "__main__":

View File

@ -131,6 +131,12 @@ class SimpleTokenizer(object):
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
def my_decode(self, tokens):
    """Decode token ids to text, skipping ids that are not in ``self.decoder``.

    Unlike a plain ``self.decoder[token]`` lookup, unknown ids are dropped
    instead of raising ``KeyError``.

    Args:
        tokens: Iterable of token ids. Items may be plain ``int``s or any
            integer-like scalar implementing ``__index__`` (e.g. NumPy
            integers or 0-d integer tensors).

    Returns:
        The decoded string, with every ``'</w>'`` marker replaced by a space.
    """
    import operator

    pieces = []
    for item in tokens:
        # Normalize integer-like scalars to a plain int before the lookup.
        # Objects that hash by identity (0-d tensors do) would otherwise
        # never match a vocabulary key, so every token would be silently
        # dropped and the result would always be ''.
        # NOTE(review): assumes self.decoder is keyed by plain ints — confirm.
        try:
            key = operator.index(item)
        except TypeError:
            # Not integer-like: keep the original membership semantics.
            key = item
        if key in self.decoder:
            pieces.append(self.decoder[key])
    text = ''.join(pieces)
    # Map each vocabulary character back to its byte value, then decode.
    text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
    return text
def tokenize(self, text):
tokens = []
text = whitespace_clean(basic_clean(text)).lower()

View File

@ -249,17 +249,18 @@ class Trainer(TrainBase):
def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device, non_blocking=True)
bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.args.topk, -1)
#clean state
clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
final_adverse = []
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
# print(texts)
for i, text in enumerate(texts):
important_scores = self.get_important_scores(text, clean_embeds, self.batch_size, self.max_length)
important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
words, sub_words, keys = self._tokenize(text)
final_words = copy.deepcopy(words)
@ -315,7 +316,8 @@ class Trainer(TrainBase):
times += 1
print(times)
image.float()
raw_text=[self.clip_tokenizer.decode(token) for token in text]
raw_text=[self.clip_tokenizer.my_decode(token) for token in text]
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
negetive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])

View File

@ -16,8 +16,9 @@ def get_args():
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
# parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat")
# parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat")
parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")
parser.add_argument("--topk", type=int, default=10)
parser.add_argument("--num-perturbation", type=int, default=3)
parser.add_argument("--txt-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=512)
parser.add_argument("--epochs", type=int, default=100)