Update model.py

This commit is contained in:
Alireza Davoudi 2022-09-05 21:10:05 +04:30 committed by GitHub
parent f77e7c2467
commit 7451020bbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -159,7 +159,8 @@ class LayerNorm(nn.LayerNorm):
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super(LayerNorm, self).forward(x.type(torch.float32))
ret = F.layer_norm(
x.type(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
return ret.type(orig_type)