remove function def for torchscript conversion

This commit is contained in:
ofernandesumojo 2023-08-28 15:18:20 -05:00 committed by GitHub
parent 3aff74026b
commit 08fd725374
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 8 deletions

View File

@ -136,15 +136,11 @@ class ModifiedResNet(nn.Module):
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)