# Listing info from the source page: 89 lines, 2.8 KiB, Python.
import torch
from torch.nn import Parameter


def l2normalize(v, eps=1e-12):
    """Return ``v`` scaled to (approximately) unit L2 norm.

    ``eps`` is added to the norm before dividing, so an all-zero input
    comes back as zeros instead of NaNs.
    """
    denominator = v.norm() + eps
    return v / denominator


class SpectralNorm(object):
    """Forward pre-hook that divides a module parameter by its spectral norm.

    ``apply`` moves the raw parameter ``<name>`` to ``<name>_bar`` and
    registers two vectors ``<name>_u`` / ``<name>_v`` (parameters with
    ``requires_grad=False``) that hold power-iteration state. Before every
    ``forward()`` the hook recomputes ``module.<name> = <name>_bar / sigma``.
    Here ``sigma`` is the power-iteration estimate of the largest singular
    value of ``<name>_bar`` reshaped to ``(shape[0], -1)``.

    Args:
        name: Name of the parameter to normalize. Defaults to ``"weight"``.
        power_iterations: Power-iteration steps performed on every weight
            recomputation. Must be at least 1.

    Raises:
        ValueError: if ``power_iterations`` is smaller than 1.
    """

    def __init__(self, name="weight", power_iterations=1):
        # With zero steps sigma would be computed from the raw, unnormalized
        # random u/v vectors, so reject it up front.
        if power_iterations < 1:
            raise ValueError(
                "power_iterations must be >= 1, got {}".format(power_iterations))
        self.name = name
        self.power_iterations = power_iterations

    def compute_weight(self, module):
        """Refine ``u``/``v`` in place and return the normalized parameter.

        ``u`` and ``v`` are overwritten through ``.data``, using only
        detached tensors, so the power iteration is not tracked by autograd.
        ``sigma`` is then computed from the live ``<name>_bar`` parameter,
        so gradients flow through the normalization into ``<name>_bar``.
        """
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")

        height = w.data.shape[0]
        # Detached 2-D view of the raw parameter. It is loop-invariant
        # because only u and v change inside the loop.
        w_mat = w.data.view(height, -1)
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w_mat), u.data))
            u.data = l2normalize(torch.mv(w_mat, v.data))

        sigma = u.dot(w.view(height, -1).mv(v))
        return w / sigma.expand_as(w)

    @staticmethod
    def apply(module, name="weight", power_iterations=1):
        """Attach spectral normalization of ``module.<name>`` and return the hook.

        If ``<name>_u``, ``<name>_v`` and ``<name>_bar`` already exist on
        ``module`` they are reused. Otherwise they are created from the
        current ``module.<name>`` parameter, with ``u`` and ``v`` drawn
        from a standard normal distribution.

        Args:
            module: The module to modify in place.
            name: Parameter to normalize. Defaults to ``"weight"``.
            power_iterations: Power-iteration steps per forward pass.

        Returns:
            The registered ``SpectralNorm`` hook.

        Raises:
            AttributeError: if ``module`` has no attribute ``name`` and no
                existing power-iteration state.
            KeyError: if ``name`` is not a registered parameter of
                ``module``. This is the case, for example, when the module
                has already been processed by ``apply``.
        """
        fn = SpectralNorm(name, power_iterations)

        try:
            # Only probe for existing power-iteration state; the values
            # themselves are read later by compute_weight().
            getattr(module, name + "_u")
            getattr(module, name + "_v")
            getattr(module, name + "_bar")
        except AttributeError:
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]
            # u and v only carry power-iteration state, so they are kept
            # out of autograd.
            u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
            # Shares storage with the original parameter's data.
            w_bar = Parameter(w.data)

            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)

        # Remove the raw parameter from the parameter list. From now on,
        # module.<name> is a plain tensor attribute derived from <name>_bar.
        del module._parameters[name]

        setattr(module, name, fn.compute_weight(module))

        # Recompute the weight before every forward().
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        """Undo ``apply`` by restoring ``<name>`` as an ordinary parameter.

        The restored value is the normalized weight from one final
        ``compute_weight`` call, detached via ``.data``. The ``_u``, ``_v``
        and ``_bar`` parameters are deleted. The forward pre-hook itself is
        not unregistered here; the caller must do that.
        """
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        """Forward pre-hook: refresh ``module.<name>`` before ``forward()``.

        ``inputs`` is required by the pre-hook signature and is not used.
        """
        setattr(module, self.name, self.compute_weight(module))


def spectral_norm(module):
    """Enable spectral normalization of ``module.weight`` and return ``module``.

    This is a thin convenience wrapper around ``SpectralNorm.apply``. The
    hook object that ``apply`` returns is discarded, and the same ``module``
    is handed back, so the call can be used inline, for example
    ``spectral_norm(nn.Conv2d(...))``.
    """
    # apply() mutates ``module`` in place; the returned hook is not needed.
    _hook = SpectralNorm.apply(module)
    return module


def remove_spectral_norm(module, name='weight'):
    """Remove the spectral-norm hook for parameter ``name`` from ``module``.

    Calls the matching hook's ``remove()``, which restores ``name`` as a
    plain parameter, and then unregisters the forward pre-hook.

    Args:
        module: Module previously passed to ``spectral_norm``.
        name: Parameter the hook normalizes. Defaults to ``'weight'``.

    Returns:
        The same ``module``.

    Raises:
        ValueError: if no ``SpectralNorm`` hook for ``name`` is registered.
    """
    # Iterate over a snapshot so that deleting a hook never touches the dict
    # currently being iterated.
    for key, hook in list(module._forward_pre_hooks.items()):
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[key]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))