add setter for dtype

This commit is contained in:
KaiyuYue 2022-06-23 22:33:24 -04:00 committed by GitHub
parent b46f5ac758
commit 61ae1d1e31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 1 deletions

View File

@ -296,6 +296,7 @@ class CLIP(nn.Module):
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
self._dtype = self.visual.conv1.weight.dtype
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
@ -336,7 +337,11 @@ class CLIP(nn.Module):
@property
def dtype(self):
return self.visual.conv1.weight.dtype
return self._dtype
@dtype.setter
def dtype(self, value):
self._dtype = value
def encode_image(self, image):
return self.visual(image.type(self.dtype))