Update model.py
This commit is contained in:
parent
d50d76daa6
commit
f9bc83733b
|
|
@ -159,7 +159,7 @@ class LayerNorm(nn.LayerNorm):
|
|||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
ret = super(LayerNorm).forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue