add setter for dtype
This commit is contained in:
parent
b46f5ac758
commit
61ae1d1e31
|
|
@ -296,6 +296,7 @@ class CLIP(nn.Module):
|
|||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
self.initialize_parameters()
|
||||
self._dtype = self.visual.conv1.weight.dtype
|
||||
|
||||
def initialize_parameters(self):
|
||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||
|
|
@ -336,7 +337,11 @@ class CLIP(nn.Module):
|
|||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.visual.conv1.weight.dtype
|
||||
return self._dtype
|
||||
|
||||
@dtype.setter
|
||||
def dtype(self, value):
|
||||
self._dtype = value
|
||||
|
||||
def encode_image(self, image):
|
||||
return self.visual(image.type(self.dtype))
|
||||
|
|
|
|||
Loading…
Reference in New Issue