# GanAttack/model.py
# (source listing metadata: 29 lines, 660 B, Python)

import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import os
import clip
class GanAttack(nn.Module):
    """Upsample a batch of images by an integer factor with a FIR filter.

    The input is upsampled by inserting ``factor - 1`` zeros after every
    pixel along both spatial axes, padded by ``self.pad``, and then filtered
    with ``self.kernel`` independently for every channel.

    NOTE(review): despite the class name, this module only performs
    filtered upsampling; confirm it is the module callers expect.

    Args:
        kernel: 1-D or 2-D filter taps (list, tuple or tensor). A 1-D kernel
            is expanded to a separable 2-D kernel via an outer product. The
            kernel is normalised to sum to 1 and then scaled by
            ``factor ** 2`` so the mean pixel value is preserved after the
            zero insertion.
        factor: integer upsampling factor (default 2).
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        # Normalise, then scale by factor**2 to compensate for the
        # factor**2 - 1 zeros inserted per input pixel during upsampling.
        kernel = self._make_kernel(kernel) * (factor ** 2)
        # A buffer follows the module across .to()/.cuda() and is stored in
        # state_dict, but is not a trainable parameter.
        self.register_buffer('kernel', kernel)
        # Split the total padding (pad0 + pad1 == kernel_size - 1) so that the
        # filtered output has exactly factor * H rows for a square kernel;
        # pad0 goes on the top/left, pad1 on the bottom/right. pad1 can be
        # negative (e.g. a 1-tap kernel), which means "crop".
        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        """Upsample ``input`` of shape (N, C, H, W) by ``self.factor``.

        Returns:
            Tensor of shape (N, C, H * factor, W * factor) for a square
            kernel, with the same dtype as ``input``.
        """
        return self._upsample_fir(input, self.kernel, self.factor, self.pad)

    @staticmethod
    def _make_kernel(k):
        """Return ``k`` as a float32 2-D kernel normalised to sum to 1."""
        k = torch.as_tensor(k, dtype=torch.float32)
        if k.ndim == 1:
            # Separable kernel: outer product of the 1-D taps with itself.
            k = k[None, :] * k[:, None]
        # Out-of-place division so a caller-supplied tensor is never mutated.
        return k / k.sum()

    @staticmethod
    def _upsample_fir(x, kernel, up, pad):
        """Zero-insert upsample ``x`` by ``up``, pad/crop by ``pad``, then filter.

        NOTE(review): intended to reproduce the
        ``upfirdn2d(x, kernel, up=up, down=1, pad=pad)`` call this class
        used to make; verify against that op if it exists elsewhere.
        """
        batch, channel, in_h, in_w = x.shape
        kernel_h, kernel_w = kernel.shape
        pad0, pad1 = pad

        # Fold channels into the batch so a single 1-channel convolution
        # filters every channel independently.
        out = x.reshape(batch * channel, in_h, 1, in_w, 1)
        # Zero insertion: append up - 1 zeros after every row and column.
        out = nn.functional.pad(out, [0, up - 1, 0, 0, 0, up - 1])
        out = out.reshape(batch * channel, 1, in_h * up, in_w * up)

        # Spatial padding; negative amounts crop instead of pad.
        out = nn.functional.pad(
            out, [max(pad0, 0), max(pad1, 0), max(pad0, 0), max(pad1, 0)]
        )
        out = out[
            :,
            :,
            max(-pad0, 0):out.shape[2] - max(-pad1, 0),
            max(-pad0, 0):out.shape[3] - max(-pad1, 0),
        ]

        # conv2d computes cross-correlation; flip the kernel for a true
        # convolution, and match the input's dtype/device (the buffer is
        # float32).
        weight = torch.flip(kernel, [0, 1]).reshape(1, 1, kernel_h, kernel_w)
        weight = weight.to(dtype=out.dtype, device=out.device)
        out = nn.functional.conv2d(out, weight)
        return out.reshape(batch, channel, out.shape[2], out.shape[3])