Merge 7451020bbb into dcba3cb2e2
This commit is contained in:
commit
3ef827c453
|
|
@ -159,7 +159,8 @@ class LayerNorm(nn.LayerNorm):
|
|||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
ret = F.layer_norm(
|
||||
x.type(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue