diff --git a/clip/model.py b/clip/model.py index 232b779..e71bbc2 100644 --- a/clip/model.py +++ b/clip/model.py @@ -159,7 +159,7 @@ class LayerNorm(nn.LayerNorm): def forward(self, x: torch.Tensor): orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) + ret = super(LayerNorm).forward(x.type(torch.float32)) return ret.type(orig_type)