Merge 09af741334 into dcba3cb2e2
This commit is contained in:
commit
f902900c2d
|
|
@ -154,15 +154,16 @@ class ModifiedResNet(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.LayerNorm):
|
class LayerNorm(nn.Module):
|
||||||
"""Subclass torch's LayerNorm to handle fp16."""
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(LayerNorm, self).__init__()
|
||||||
|
self.inner_layernorm = nn.LayerNorm(*args, **kwargs)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
orig_type = x.dtype
|
orig_type = x.dtype
|
||||||
ret = super().forward(x.type(torch.float32))
|
ret = self.inner_layernorm(x.type(torch.float32))
|
||||||
return ret.type(orig_type)
|
return ret.type(orig_type)
|
||||||
|
|
||||||
|
|
||||||
class QuickGELU(nn.Module):
|
class QuickGELU(nn.Module):
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
return x * torch.sigmoid(1.702 * x)
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue