update GanAttack

This commit is contained in:
leewlving 2023-12-06 20:59:04 +08:00
parent cf67f37cd3
commit a6080d42a7
1 changed files with 6 additions and 4 deletions

View File

@ -97,9 +97,12 @@ def main(cfg: DictConfig) -> None:
loss_vgg=vgg_loss(inputs,generated_img) loss_vgg=vgg_loss(inputs,generated_img)
loss_l1=F.l1_loss(codes,adv_latent_codes) loss_l1=F.l1_loss(codes,adv_latent_codes)
loss_clip=clip_loss(generated_img,prompt) loss_clip=clip_loss(generated_img,prompt)
outputs = classifier(generated_img)
preds=criterion(outputs,labels)
loss_classifier=max_loss(torch.ones_like(preds),preds,-torch.ones_like(preds))
# _, preds = torch.max(classifier(generated_img), 1)
# loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
_, preds = torch.max(classifier(generated_img), 1)
loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -124,8 +127,7 @@ def main(cfg: DictConfig) -> None:
generated_img,adv_latent_codes=model(inputs) generated_img,adv_latent_codes=model(inputs)
outputs = classifier(generated_img) outputs = classifier(generated_img)
_, preds = torch.max(outputs, 1) preds=criterion(outputs ,labels)
loss = criterion(outputs, labels)
# running_loss += loss.item() * inputs.size(0) # running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data) running_corrects += torch.sum(preds == labels.data)