add correct

This commit is contained in:
leewlving 2024-01-18 17:47:08 +08:00
parent 0f7414bc5f
commit a3b086bd19
3 changed files with 82 additions and 32 deletions

View File

@ -84,24 +84,12 @@ class LabelEncoder(nn.Module):
class TargetedGanAttack(nn.Module):
def __init__(self, cfg,prompt):
super().__init__()
# self.net= get_stylegan_inverter(cfg)
# self.generator.eval()
# self.inverter=inverter
# self.images_resize=images_resize
# self.generator=get_stylegan_generator(cfg)
self.prompt=prompt
text_len=self.prompt.shape[1]
# self.mlp=nn.Sequential(
# nn.Linear(text_len+512, 4096),
# nn.ReLU(inplace=True),
# nn.Linear(4096, 512)
# )
self.mlp=nn.LazyLinear(4096)
self.mlp=nn.LazyLinear(1024)
self.label_encoder=LabelEncoder()
encoder_lis = [
# MNIST:1*28*28
nn.Conv2d(29, 32, kernel_size=3, stride=1, padding=0, bias=True),
nn.LazyConv2d( 32, kernel_size=3, stride=1, padding=0, bias=True),
nn.InstanceNorm2d(32),
nn.ReLU(),
# 8*26*26
@ -150,9 +138,76 @@ class TargetedGanAttack(nn.Module):
mixed_feature=self.encoder(mixed_feature)
mixed_feature=self.bottle_neck(mixed_feature)
mixed_feature=self.decoder(mixed_feature)
output=torch.chunk(mixed_feature,chunks=19,dim=1)
style=torch.cat(output[:18],dim=1)
noisy=output[-1]
return style, noisy
return mixed_feature
class UnTargetedGanAttack(nn.Module):
def __init__(self, prompt):
super().__init__()
self.prompt=prompt
self.mlp=nn.LazyLinear(1024)
self.label_encoder=LabelEncoder()
encoder_lis = [
# MNIST:1*28*28
nn.LazyConv2d( 32, kernel_size=3, stride=1, padding=0, bias=True),
nn.InstanceNorm2d(32),
nn.ReLU(),
# 8*26*26
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True),
nn.InstanceNorm2d(64),
nn.ReLU(),
# 16*12*12
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True),
nn.InstanceNorm2d(128),
nn.ReLU(),
# 32*5*5
]
bottle_neck_lis = [ResnetBlock(128),
ResnetBlock(128),
ResnetBlock(128),
ResnetBlock(128),]
decoder_lis = [
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, bias=False),
nn.InstanceNorm2d(64),
nn.ReLU(),
# state size. 16 x 11 x 11
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, bias=False),
nn.InstanceNorm2d(32),
nn.ReLU(),
# state size. 8 x 23 x 23
nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0, bias=False),
nn.Tanh()
# state size. image_nc x 28 x 28
]
self.encoder = nn.Sequential(*encoder_lis)
self.bottle_neck = nn.Sequential(*bottle_neck_lis)
self.decoder = nn.Sequential(*decoder_lis)
def forward(self, detailcode,noise,label):
label_feat=self.label_encoder(label)
batch_size=detailcode.shape[0]
prompt=self.prompt.repeat(batch_size,18,1).to(device)
# print(prompt.shape)
x_prompt=torch.cat([detailcode,prompt],dim=2)
x_prompt=self.mlp(x_prompt)
noise1=noise[7].squeeze().repeat(batch_size,1).to(device)
# noise_temp.append(noise[7].squeeze())
mixed_feature = torch.cat([x_prompt,noise1,label_feat], dim=1)
# x=x_prompt+x
mixed_feature=self.encoder(mixed_feature)
mixed_feature=self.bottle_neck(mixed_feature)
mixed_feature=self.decoder(mixed_feature)
output=torch.chunk(mixed_feature,chunks=19,dim=1)
style=torch.cat(output[:18],dim=1)
noisy=output[-1]
return style, noisy
class CLIPLoss(torch.nn.Module):
def __init__(self):
@ -212,6 +267,8 @@ class ResnetBlock(nn.Module):
out = x + self.conv_block(x)
return out
# @hydra.main(version_base=None, config_path="./config", config_name="config")
# def test(cfg):
# prompt=torch.randn([1,1024]).to(device)

View File

@ -27,8 +27,6 @@ def generate_labels(labels,NUM_CLASSES=307):
while labels[i]==rand_v:
rand_v = torch.randint(0, NUM_CLASSES-1)
targets[i] = rand_v
# targets = targets.astype(np.int32)
return targets
def cw_loss(outputs, labels):
@ -58,7 +56,7 @@ def main(cfg: DictConfig) -> None:
classifier=get_model(cfg)
classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
classifier.eval()
g_ema, _=get_stylegan_generator(cfg)
g_ema=get_stylegan_generator(cfg)
prompt=get_prompt(cfg)
# net=get_stylegan_inverter(cfg)
@ -87,25 +85,20 @@ def main(cfg: DictConfig) -> None:
inputs = inputs.to(device)
labels = labels.to(device)
target=generate_labels(labels)
noise=g_ema.make_noise()
# base_code=base_code.to(device)
detail_code=detail_code.to(device)
# codes = model.net.encoder(inputs)
generated_img,adv_latent_codes=model(inputs,detail_code,target)
style, noisy=model(inputs,detail_code,target)
detail=torch.clamp(style,-cfg.epsilon,cfg.spsilon)+detail_code
noisy=[ torch.clamp(single, -cfg.noise_epsilon,cfg.noise_epsilon) for single in noisy ]
noise[:8]=noisy
generated_img=g_ema(detail,input_is_latent=True,noise=noise,randomize_noise=False)
optimizer.zero_grad()
loss_vgg=loss_fn_vgg(inputs,generated_img)
loss_clip=clip_loss(generated_img,prompt)
adv_outputs = classifier(generated_img)
loss_classifier=cw_loss(adv_outputs,target)
# clean_outputs=classifier(inputs)
# adv_preds=criterion(adv_outputs,labels)
# clean_preds=criterion(clean_outputs,labels)
# loss_classifier=max_loss(clean_preds,adv_preds,torch.ones_like(preds))
# _, preds = torch.max(classifier(generated_img), 1)
# loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
loss.backward()
optimizer.step()

View File

@ -7,7 +7,7 @@ from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad,unnormalize
from model import GanAttack,CLIPLoss
from model import UnTargetedGanAttack,CLIPLoss
import lpips
from prompt import get_prompt
import torch.nn.functional as F
@ -56,7 +56,7 @@ def main(cfg: DictConfig) -> None:
# net=get_stylegan_inverter(cfg)
model=GanAttack(cfg,prompt).to(device)
model=UnTargetedGanAttack(cfg,prompt).to(device)