Update model.py
This commit is contained in:
parent
f9bc83733b
commit
f77e7c2467
|
|
@ -159,7 +159,7 @@ class LayerNorm(nn.LayerNorm):
|
|||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super(LayerNorm).forward(x.type(torch.float32))
|
||||
ret = super(LayerNorm, self).forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue