This commit is contained in:
leewlving 2023-12-10 19:46:43 +08:00
parent 02c8c5db1a
commit c10ee107cc
2 changed files with 21 additions and 2 deletions

View File

@ -13,7 +13,8 @@ from utils import get_model,set_requires_grad,unnormalize
# from models import GanAttack
# from pixel2style2pixel.scripts.align_all_parallel import align_face
# from pixel2style2pixel.models.stylegan2.model import Generator
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
from model import GanAttack,CLIPLoss,VggLoss
from prompt import get_prompt
import torch.nn.functional as F
import time
# from torchsummary import summary

View File

@ -9,6 +9,12 @@ from pixel2style2pixel.models.psp import pSp
from argparse import Namespace
from utils import normalize
import hydra
from omegaconf import DictConfig, OmegaConf
import sys
import os
sys.path.append('./pixel2style2pixel')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@ -111,3 +117,15 @@ class VggLoss(torch.nn.Module):
feature2=torch.flatten(feature2)
similarity = F.cosine_similarity(feature1,feature2)
return similarity
@hydra.main(version_base=None, config_path="./config", config_name="config")
def test(cfg):
prompt=torch.randn([1,1024]).to(device)
model=GanAttack(cfg,prompt).to(device)
data=torch.randn([2,3,256,256]).to(device)
result_images,x=model(data)
print(result_images.shape)
print(x.shape)
if __name__ == "__main__":
test()