# File-listing metadata captured with this source: 29 lines, 660 B, Python.
import torch
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
import torch.nn as nn
|
|
import os
|
|
import clip
|
|
|
|
|
|
# class GanAttack(nn.Module):
|
|
|
|
class GanAttack(nn.Module):
    """Upsample a 4-D tensor by an integer ``factor`` with an FIR filter.

    Despite its name, this module performs only upsampling.  ``forward``
    inserts ``factor - 1`` zeros between neighbouring samples along height
    and width, pads the result, and convolves it with the normalized 2-D
    ``kernel``.  For a square kernel the output spatial size is exactly
    ``factor`` times the input size.

    Args:
        kernel: 1-D or 2-D sequence/tensor of filter taps.  A 1-D kernel is
            expanded into a separable 2-D kernel via an outer product.  The
            kernel is normalized to sum to 1, then multiplied by
            ``factor ** 2``.
        factor: Positive integer upsampling factor.

    Raises:
        ValueError: If ``factor`` is smaller than 1, or ``kernel`` is not
            1-D or 2-D.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        if factor < 1:
            raise ValueError(f"factor must be a positive integer, got {factor!r}")

        self.factor = factor

        # Gain of factor**2: after zero insertion only one sample in every
        # factor x factor cell is nonzero, so this restores the signal level
        # (a constant input stays constant away from the borders).
        kernel = _make_kernel(kernel) * (factor ** 2)
        # Stored as a buffer: it follows .to()/.cuda() and is saved in the
        # state_dict, but it is not a trainable parameter.
        self.register_buffer('kernel', kernel)

        # Split the filter's extra length between the two sides so the
        # output is exactly factor * input size.  Only shape[0] is used, so
        # this assumes a square kernel.
        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        """Return ``input`` upsampled by ``self.factor``.

        ``input`` must be 4-D, (N, C, H, W).  Each channel is filtered
        independently.  The result is (N, C, H * factor, W * factor) for a
        square kernel.
        """
        return _upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)


def _make_kernel(k):
    """Return ``k`` as a float32 2-D kernel normalized to sum to 1.

    A 1-D input becomes the outer product with itself.  The caller's
    tensor is never modified in place.
    """
    k = torch.as_tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    elif k.ndim != 2:
        raise ValueError(f"kernel must be 1-D or 2-D, got {k.ndim}-D")
    return k / k.sum()


def _upfirdn2d(x, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter and downsample a 4-D tensor (pure PyTorch).

    Applies, per channel:
      1. zero insertion by ``up`` along H and W;
      2. padding by ``pad = (before, after)`` on both spatial axes
         (negative values crop);
      3. true 2-D convolution with ``kernel``;
      4. subsampling by ``down``.

    Args:
        x: Tensor of shape (N, C, H, W).
        kernel: 2-D filter tensor.
        up: Integer upsampling factor.
        down: Integer downsampling factor.
        pad: ``(pad0, pad1)`` padding applied to both H and W.

    Returns:
        Tensor of shape (N, C, H_out, W_out).
    """
    batch, channel, in_h, in_w = x.shape
    kernel_h, kernel_w = kernel.shape
    pad0, pad1 = pad

    # Fold channels into the batch so every plane is filtered on its own.
    out = x.reshape(batch * channel, 1, in_h, 1, in_w, 1)
    # Zero-stuff: append up - 1 zeros after every sample along H and W.
    out = nn.functional.pad(out, [0, up - 1, 0, 0, 0, up - 1])
    out = out.reshape(batch * channel, 1, in_h * up, in_w * up)

    # Apply the positive part of the padding, then crop any negative part.
    out = nn.functional.pad(
        out, [max(pad0, 0), max(pad1, 0), max(pad0, 0), max(pad1, 0)]
    )
    out = out[
        :,
        :,
        max(-pad0, 0): out.shape[2] - max(-pad1, 0),
        max(-pad0, 0): out.shape[3] - max(-pad1, 0),
    ]

    # conv2d computes cross-correlation; flip the kernel for true convolution
    # and match the input's dtype/device.
    weight = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w).to(out)
    out = nn.functional.conv2d(out, weight)
    out = out[:, :, ::down, ::down]
    return out.reshape(batch, channel, out.shape[2], out.shape[3])