This commit is contained in:
leewlving 2023-12-10 19:46:43 +08:00
parent 02c8c5db1a
commit c10ee107cc
2 changed files with 21 additions and 2 deletions

View File

@ -13,7 +13,8 @@ from utils import get_model,set_requires_grad,unnormalize
# from models import GanAttack # from models import GanAttack
# from pixel2style2pixel.scripts.align_all_parallel import align_face # from pixel2style2pixel.scripts.align_all_parallel import align_face
# from pixel2style2pixel.models.stylegan2.model import Generator # from pixel2style2pixel.models.stylegan2.model import Generator
from model import GanAttack,CLIPLoss,VggLoss,get_prompt from model import GanAttack,CLIPLoss,VggLoss
from prompt import get_prompt
import torch.nn.functional as F import torch.nn.functional as F
import time import time
# from torchsummary import summary # from torchsummary import summary

View File

@ -9,6 +9,12 @@ from pixel2style2pixel.models.psp import pSp
from argparse import Namespace from argparse import Namespace
from utils import normalize from utils import normalize
import hydra
from omegaconf import DictConfig, OmegaConf
import sys
import os
sys.path.append('./pixel2style2pixel')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@ -110,4 +116,16 @@ class VggLoss(torch.nn.Module):
feature1=torch.flatten(feature1) feature1=torch.flatten(feature1)
feature2=torch.flatten(feature2) feature2=torch.flatten(feature2)
similarity = F.cosine_similarity(feature1,feature2) similarity = F.cosine_similarity(feature1,feature2)
return similarity return similarity
@hydra.main(version_base=None, config_path="./config", config_name="config")
def test(cfg):
prompt=torch.randn([1,1024]).to(device)
model=GanAttack(cfg,prompt).to(device)
data=torch.randn([2,3,256,256]).to(device)
result_images,x=model(data)
print(result_images.shape)
print(x.shape)
if __name__ == "__main__":
test()