From 7671b4496ef8c5565e2687435021cbff6dd010bd Mon Sep 17 00:00:00 2001 From: leewlving Date: Thu, 18 Jan 2024 10:17:23 +0800 Subject: [PATCH] add cw loss --- targetedAttack.py | 44 ++++++++++++++++++++++++++++---------------- untargetedAttack.py | 27 +++++++++++---------------- 2 files changed, 39 insertions(+), 32 deletions(-) diff --git a/targetedAttack.py b/targetedAttack.py index 47d99aa..77a7d5a 100644 --- a/targetedAttack.py +++ b/targetedAttack.py @@ -19,7 +19,26 @@ import hydra device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +def generate_labels(labels,NUM_CLASSES=307): + targets = torch.zeros_like(labels) + for i in range(len(labels)): + rand_v = torch.randint(0, NUM_CLASSES-1) + while labels[i]==rand_v: + rand_v = torch.randint(0, NUM_CLASSES-1) + targets[i] = rand_v + # targets = targets.astype(np.int32) + + return targets + +def cw_loss(outputs, labels): + one_hot_labels = torch.eye(outputs.shape[1]).to(device)[labels] + + # find the max logit other than the target class + other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0] + # get the target class's logit + real = torch.max(one_hot_labels * outputs, dim=1)[0] + return torch.clamp((other - real), min=-0.) def get_stylegan_generator(cfg): generator=Generator(1024, 512, 8) @@ -53,13 +72,7 @@ def main(cfg: DictConfig) -> None: max_loss=nn.MarginRankingLoss(0.1) clip_loss=CLIPLoss().to(device) loss_fn_vgg = lpips.LPIPS(net='vgg').to(device) - # vgg_loss=VggLoss().to(device) -# summary(model, input_size = (3, 256, 256), batch_size = 5) -# set_requires_grad(model.mlp.parameters()) - # for p in (model.mlp.parameters()): - # p.requires_grad =True - # for p in (model.noise_mlp.parameters()): - # p.requires_grad =True + model.train() optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum) start_time = time.time() @@ -73,22 +86,21 @@ def main(cfg: DictConfig) -> None: for i, (inputs, labels,detail_code) in enumerate(train_dataloader): inputs = inputs.to(device) labels = labels.to(device) + target=generate_labels(labels) # base_code=base_code.to(device) detail_code=detail_code.to(device) # codes = model.net.encoder(inputs) - generated_img,adv_latent_codes=model(inputs,detail_code) - # _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path) + generated_img,adv_latent_codes=model(inputs,detail_code,target) + optimizer.zero_grad() - # generated_img,adv_latent_codes=model(inputs) - # loss_vgg=vgg_loss(inputs,generated_img) - # loss_l1=F.l1_loss(base_code,adv_latent_codes) loss_vgg=loss_fn_vgg(inputs,generated_img) loss_clip=clip_loss(generated_img,prompt) adv_outputs = classifier(generated_img) - clean_outputs=classifier(inputs) - adv_preds=criterion(adv_outputs,labels) - clean_preds=criterion(clean_outputs,labels) - loss_classifier=max_loss(clean_preds,adv_preds,torch.ones_like(preds)) + loss_classifier=cw_loss(adv_outputs,target) + # clean_outputs=classifier(inputs) + # adv_preds=criterion(adv_outputs,labels) + # clean_preds=criterion(clean_outputs,labels) + # loss_classifier=max_loss(clean_preds,adv_preds,torch.ones_like(preds)) # _, preds = torch.max(classifier(generated_img), 1) # loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels)) diff --git a/untargetedAttack.py b/untargetedAttack.py index ea62f4d..02249e6 100644 --- a/untargetedAttack.py +++ b/untargetedAttack.py @@ -31,6 +31,15 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # mean_latent = g_ema.mean_latent(4096) # return g_ema, mean_latent +def cw_loss(outputs, labels): + one_hot_labels = torch.eye(outputs.shape[1]).to(device)[labels] + + # find the max logit other than the target class + other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0] + # get the target class's logit + real = torch.max(one_hot_labels * outputs, dim=1)[0] + return torch.clamp((real - other), min=-0.) + @@ -59,10 +68,7 @@ def main(cfg: DictConfig) -> None: # vgg_loss=VggLoss().to(device) # summary(model, input_size = (3, 256, 256), batch_size = 5) # set_requires_grad(model.mlp.parameters()) - for p in (model.mlp.parameters()): - p.requires_grad =True - for p in (model.noise_mlp.parameters()): - p.requires_grad =True + optimizer = optim.SGD(model.mlp.parameters()+model.noise_mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum) start_time = time.time() @@ -79,23 +85,12 @@ def main(cfg: DictConfig) -> None: detail_code=detail_code.to(device) # codes = model.net.encoder(inputs) generated_img,adv_latent_codes=model(inputs,detail_code) - # _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path) optimizer.zero_grad() - # generated_img,adv_latent_codes=model(inputs) - # loss_vgg=vgg_loss(inputs,generated_img) - # loss_l1=F.l1_loss(base_code,adv_latent_codes) loss_vgg=loss_fn_vgg(inputs,generated_img) loss_clip=clip_loss(generated_img,prompt) adv_outputs = classifier(generated_img) - clean_outputs=classifier(inputs) - adv_preds=criterion(adv_outputs,labels) - clean_preds=criterion(clean_outputs,labels) - loss_classifier=max_loss(clean_preds,adv_preds,torch.ones_like(preds)) - # _, preds = torch.max(classifier(generated_img), 1) - # loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels)) + loss_classifier=cw_loss(adv_outputs,labels) -# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier -# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier loss.backward() optimizer.step()