diff --git a/GanAttack.py b/GanAttack.py index ff69b93..4335c6b 100644 --- a/GanAttack.py +++ b/GanAttack.py @@ -18,6 +18,7 @@ from pixel2style2pixel.models.psp import pSp from model import GanAttack,CLIPLoss,VggLoss,get_prompt import torch.nn.functional as F import time +# from torchsummary import summary import hydra @@ -64,7 +65,7 @@ def main(cfg: DictConfig) -> None: classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset))) classifier.eval() # g_ema, _=get_stylegan_generator(cfg) - prompt=get_prompt(cfg.prompt) + prompt=get_prompt(cfg) net=get_stylegan_inverter(cfg) model=GanAttack(net,prompt).to(device) @@ -75,9 +76,11 @@ def main(cfg: DictConfig) -> None: criterion = nn.CrossEntropyLoss() max_loss=nn.MarginRankingLoss(0.1) clip_loss=CLIPLoss().to(device) - vgg_loss=VggLoss().to(device) - - set_requires_grad(model.mlp.parameters()) +# vgg_loss=VggLoss().to(device) +# summary(model, input_size = (3, 256, 256), batch_size = 5) +# set_requires_grad(model.mlp.parameters()) + for p in (model.mlp.parameters()): + p.requires_grad =True optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum) start_time = time.time() @@ -95,7 +98,7 @@ def main(cfg: DictConfig) -> None: optimizer.zero_grad() generated_img,adv_latent_codes=model(inputs) loss_vgg=vgg_loss(inputs,generated_img) - loss_l1=F.l1_loss(codes,adv_latent_codes) +# loss_l1=F.l1_loss(codes,adv_latent_codes) loss_clip=clip_loss(generated_img,prompt) outputs = classifier(generated_img) preds=criterion(outputs,labels) @@ -103,7 +106,9 @@ def main(cfg: DictConfig) -> None: # _, preds = torch.max(classifier(generated_img), 1) # loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels)) - loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier +# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier +# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier + loss=loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier loss.backward() optimizer.step() diff --git a/config/config.yaml b/config/config.yaml index bf93078..c03f642 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,17 +1,17 @@ -dataset: gender_dataset +dataset: identity_dataset classifier: model: resnet18 lr: 0.01 momentum: 0.9 - num_epochs: 25 + num_epochs: 20 num_workers : 4 batch_size: 64 paths: gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset - identity_dataset: ./dataset/CelebA_HQ_facial_identity_dataset - inverter_cfg: checkpoint/psp_ffhq_encode.pt + identity_dataset: /dataset/face_identity/CelebA_HQ_facial_identity_dataset + inverter_cfg: /dataset/face_identity/psp_ffhq_encode.pt classifier: checkpoint stylegan: checkpoint/stylegan2-ffhq-config-f.pt adv_embedding: pretrained_models @@ -23,7 +23,7 @@ prompt: red lipstick # 'Arched_Eyebrows', 'Bangs', 'Wearing_Earrings', 'Bags_Under_Eyes', 'Receding_Hairline', 'Pale_Skin'] optim: - batch_size: 8 + batch_size: 1 num_epochs: 200 num_workers : 4 images_resize: 256 diff --git a/model.py b/model.py index e1015dc..1feeee4 100644 --- a/model.py +++ b/model.py @@ -11,7 +11,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") def get_prompt(cfg): - model, preprocess = clip.load("ViT-B/32", device=device) + model, preprocess = clip.load("RN50", device=device) prompt=cfg.prompt text=clip.tokenize(prompt).to(device) with torch.no_grad(): @@ -28,9 +28,9 @@ class GanAttack(nn.Module): # self.inverter=inverter # self.images_resize=images_resize self.prompt=prompt - text_len=self.prompt.shape[0] + text_len=self.prompt.shape[1] self.mlp=nn.Sequential( - nn.Linear(text_len+512, 4096), + nn.Linear(1024+512, 4096), nn.ReLU(inplace=True), nn.Linear(4096, 512) ) @@ -44,11 +44,15 @@ class GanAttack(nn.Module): # _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path) x=codes batch_size=img.shape[0] - prompt=self.prompt.repeat(batch_size).to(device) - x_prompt=torch.cat([codes,prompt],dim=1) +# print(img.shape) +# print(self.prompt.shape) +# print(codes.shape) + prompt=self.prompt.repeat(batch_size,18,1).to(device) + print(prompt.shape) + x_prompt=torch.cat([codes,prompt],dim=2) x_prompt=self.mlp(x_prompt) x=x_prompt+x - im,_=self.net.decoder([x], input_is_latent=True, randomize_noise=False, return_latents=False) + im,_=self.net.decoder(x, input_is_latent=True, randomize_noise=False, return_latents=False) result_images = self.net.face_pool(im) return result_images,x @@ -56,7 +60,7 @@ class GanAttack(nn.Module): class CLIPLoss(torch.nn.Module): def __init__(self): super(CLIPLoss, self).__init__() - self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") + self.model, self.preprocess = clip.load("RN50", device="cuda") self.model.eval() self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224)) # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1) diff --git a/prompt.py b/prompt.py new file mode 100644 index 0000000..f69e1a5 --- /dev/null +++ b/prompt.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import models +from omegaconf import DictConfig, OmegaConf +from data.dataset import get_dataset,get_adv_dataset +from utils import get_model,set_requires_grad,unnormalize +import sys +import os +import clip +import torch.nn.functional as F +import time +import hydra +from scipy import io as spio + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +@hydra.main(version_base=None, config_path="./config", config_name="config") +def save_prompt(cfg): + model, preprocess = clip.load("RN50", device=device) + prompt=cfg.prompt + text=clip.tokenize(prompt).to(device) + with torch.no_grad(): + prompt = model.encode_text(text) + prompt=prompt.cpu().numpy() + spio.savemat("prompt.mat",{'prompt':prompt}) + +def get_prompt(cfg): + data=spio.loadmat("prompt.mat") + prompt=data['prompt'] + return torch.from_numpy(prompt).to(device) + + + +if __name__ == "__main__": + save_prompt() \ No newline at end of file