fixed layer_norm

fixed error that occurs for some models
This commit is contained in:
hem9984 2024-07-22 21:30:24 -04:00 committed by GitHub
parent dcba3cb2e2
commit 09af741334
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 5 additions and 4 deletions

View File

@ -154,15 +154,16 @@ class ModifiedResNet(nn.Module):
return x
class LayerNorm(nn.Module):
    """Wrap an inner ``nn.LayerNorm`` so fp16 inputs are normalized in fp32.

    Running layer norm directly in half precision loses accuracy (and, per
    the commit message, fails outright for some models), so the input is
    upcast to ``float32`` for the normalization and the result is cast back
    to the input's original dtype.
    """

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded unchanged to the wrapped
        # nn.LayerNorm (normalized_shape, eps, elementwise_affine, ...).
        super().__init__()
        self.inner_layernorm = nn.LayerNorm(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        # Normalize in float32 for numerical stability, then restore dtype.
        ret = self.inner_layernorm(x.type(torch.float32))
        return ret.type(orig_type)
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate each element by the sigmoid of its scaled value.
        gate = torch.sigmoid(x * 1.702)
        return x * gate