# advclip/model/spectral_norm.py
# (89 lines, 2.8 KiB, Python)
import torch
from torch.nn import Parameter
def l2normalize(v, eps=1e-12):
    """Scale tensor *v* to unit L2 norm; *eps* guards against division by zero."""
    denom = v.norm() + eps
    return v / denom
class SpectralNorm(object):
    """Spectral normalization of a module's ``weight`` parameter.

    Registered as a forward pre-hook: before every forward pass the module's
    ``weight`` attribute is recomputed as ``weight_bar / sigma``, where
    ``sigma`` — the largest singular value of the weight viewed as a 2-D
    matrix — is estimated by power iteration on the persistent vectors
    ``weight_u`` and ``weight_v`` (SN-GAN style, one iteration per step).
    """

    def __init__(self):
        # Parameter to normalize and number of power-iteration steps per
        # forward pass (one step is the standard SN-GAN setting).
        self.name = "weight"
        self.power_iterations = 1

    @staticmethod
    def _l2normalize(v, eps=1e-12):
        # Local copy of the module-level l2normalize helper so the class is
        # self-contained; eps guards against division by zero.
        return v / (v.norm() + eps)

    def compute_weight(self, module):
        """Return the spectrally normalized weight ``w_bar / sigma``.

        Reads ``weight_u``/``weight_v``/``weight_bar`` from *module* and
        updates u and v in place (one power-iteration step each call).
        """
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")

        height = w.data.shape[0]
        w_mat = w.view(height, -1)  # flatten all trailing dims
        for _ in range(self.power_iterations):
            # Power iteration; .data keeps these updates out of the autograd
            # graph (u and v are running estimates, not learned parameters).
            v.data = self._l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = self._l2normalize(torch.mv(w_mat.data, v.data))
        # sigma ~= largest singular value of w_mat given converged u, v.
        sigma = u.dot(w_mat.mv(v))
        return w / sigma.expand_as(w)

    @staticmethod
    def apply(module):
        """Attach spectral normalization to *module*'s ``weight``.

        Idempotent: re-applying to an already-wrapped module neither
        re-creates the auxiliary parameters, re-deletes the raw weight
        (which would raise ``KeyError``), nor stacks a duplicate hook.
        Returns the :class:`SpectralNorm` hook instance.
        """
        name = "weight"
        fn = SpectralNorm()
        try:
            # Already applied: auxiliary parameters exist, skip setup.
            getattr(module, name + "_u")
            getattr(module, name + "_v")
            getattr(module, name + "_bar")
        except AttributeError:
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]

            # Persistent power-iteration vectors (not trained).
            u = Parameter(w.data.new(height).normal_(0, 1),
                          requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1),
                          requires_grad=False)
            w_bar = Parameter(w.data)

            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)
            # Remove the raw weight so only w_bar is optimized; ``weight``
            # becomes a plain tensor recomputed before each forward().
            del module._parameters[name]
            # Recompute the normalized weight before every forward().
            module.register_forward_pre_hook(fn)
        setattr(module, name, fn.compute_weight(module))
        return fn

    def remove(self, module):
        """Detach spectral normalization, restoring a plain ``weight``."""
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        # Forward pre-hook: refresh the normalized weight attribute.
        setattr(module, self.name, self.compute_weight(module))
def spectral_norm(module):
    """Wrap *module*'s ``weight`` with spectral normalization; return *module*."""
    SpectralNorm.apply(module)
    return module
def remove_spectral_norm(module):
    """Undo :func:`spectral_norm` on *module*.

    Finds the SpectralNorm forward pre-hook targeting ``weight``, restores
    the plain parameter, drops the hook, and returns *module*. Raises
    ``ValueError`` if no such hook is registered.
    """
    target = 'weight'
    for key, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == target:
            hook.remove(module)
            del module._forward_pre_hooks[key]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        target, module))