This commit is contained in:
hem9984 2024-07-22 21:30:47 -04:00 committed by GitHub
commit f902900c2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 4 deletions

View File

@ -154,15 +154,16 @@ class ModifiedResNet(nn.Module):
return x return x
class LayerNorm(nn.LayerNorm): class LayerNorm(nn.Module):
"""Subclass torch's LayerNorm to handle fp16.""" def __init__(self, *args, **kwargs):
super(LayerNorm, self).__init__()
self.inner_layernorm = nn.LayerNorm(*args, **kwargs)
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
orig_type = x.dtype orig_type = x.dtype
ret = super().forward(x.type(torch.float32)) ret = self.inner_layernorm(x.type(torch.float32))
return ret.type(orig_type) return ret.type(orig_type)
class QuickGELU(nn.Module): class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x) return x * torch.sigmoid(1.702 * x)