This commit is contained in:
parent
02c8c5db1a
commit
c10ee107cc
|
|
@ -13,7 +13,8 @@ from utils import get_model,set_requires_grad,unnormalize
|
|||
# from models import GanAttack
|
||||
# from pixel2style2pixel.scripts.align_all_parallel import align_face
|
||||
# from pixel2style2pixel.models.stylegan2.model import Generator
|
||||
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
|
||||
from model import GanAttack,CLIPLoss,VggLoss
|
||||
from prompt import get_prompt
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
# from torchsummary import summary
|
||||
|
|
|
|||
18
model.py
18
model.py
|
|
@ -9,6 +9,12 @@ from pixel2style2pixel.models.psp import pSp
|
|||
from argparse import Namespace
|
||||
from utils import normalize
|
||||
|
||||
import hydra
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
import sys
|
||||
import os
|
||||
sys.path.append('./pixel2style2pixel')
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
|
|
@ -111,3 +117,15 @@ class VggLoss(torch.nn.Module):
|
|||
feature2=torch.flatten(feature2)
|
||||
similarity = F.cosine_similarity(feature1,feature2)
|
||||
return similarity
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||
def test(cfg):
|
||||
prompt=torch.randn([1,1024]).to(device)
|
||||
model=GanAttack(cfg,prompt).to(device)
|
||||
data=torch.randn([2,3,256,256]).to(device)
|
||||
result_images,x=model(data)
|
||||
print(result_images.shape)
|
||||
print(x.shape)
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
Loading…
Reference in New Issue