This commit is contained in:
Alireza Davoudi 2024-06-11 15:40:36 +03:00 committed by GitHub
commit 3ef827c453
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -159,7 +159,8 @@ class LayerNorm(nn.LayerNorm):
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
ret = F.layer_norm(
x.type(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
return ret.type(orig_type)