From c10ee107ccc640442aaef73ade2cc38806507737 Mon Sep 17 00:00:00 2001 From: leewlving Date: Sun, 10 Dec 2023 19:46:43 +0800 Subject: [PATCH] ss --- GanAttack.py | 3 ++- model.py | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/GanAttack.py b/GanAttack.py index bbe189e..47f5538 100644 --- a/GanAttack.py +++ b/GanAttack.py @@ -13,7 +13,8 @@ from utils import get_model,set_requires_grad,unnormalize # from models import GanAttack # from pixel2style2pixel.scripts.align_all_parallel import align_face # from pixel2style2pixel.models.stylegan2.model import Generator -from model import GanAttack,CLIPLoss,VggLoss,get_prompt +from model import GanAttack,CLIPLoss,VggLoss +from prompt import get_prompt import torch.nn.functional as F import time # from torchsummary import summary diff --git a/model.py b/model.py index 372d952..0956c1d 100644 --- a/model.py +++ b/model.py @@ -9,6 +9,12 @@ from pixel2style2pixel.models.psp import pSp from argparse import Namespace from utils import normalize +import hydra +from omegaconf import DictConfig, OmegaConf +import sys +import os +sys.path.append('./pixel2style2pixel') + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -110,4 +116,16 @@ class VggLoss(torch.nn.Module): feature1=torch.flatten(feature1) feature2=torch.flatten(feature2) similarity = F.cosine_similarity(feature1,feature2) - return similarity \ No newline at end of file + return similarity + +@hydra.main(version_base=None, config_path="./config", config_name="config") +def test(cfg): + prompt=torch.randn([1,1024]).to(device) + model=GanAttack(cfg,prompt).to(device) + data=torch.randn([2,3,256,256]).to(device) + result_images,x=model(data) + print(result_images.shape) + print(x.shape) +if __name__ == "__main__": + test() + \ No newline at end of file