Merge 09af741334 into dcba3cb2e2
This commit is contained in:
commit
f902900c2d
|
|
@ -154,15 +154,16 @@ class ModifiedResNet(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.inner_layernorm = nn.LayerNorm(*args, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
ret = self.inner_layernorm(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
|
|
|||
Loading…
Reference in New Issue