remove function def for torchscript conversion
This commit is contained in:
parent
3aff74026b
commit
08fd725374
|
|
@ -136,15 +136,11 @@ class ModifiedResNet(nn.Module):
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
def stem(x):
|
|
||||||
x = self.relu1(self.bn1(self.conv1(x)))
|
|
||||||
x = self.relu2(self.bn2(self.conv2(x)))
|
|
||||||
x = self.relu3(self.bn3(self.conv3(x)))
|
|
||||||
x = self.avgpool(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
x = x.type(self.conv1.weight.dtype)
|
x = x.type(self.conv1.weight.dtype)
|
||||||
x = stem(x)
|
x = self.relu1(self.bn1(self.conv1(x)))
|
||||||
|
x = self.relu2(self.bn2(self.conv2(x)))
|
||||||
|
x = self.relu3(self.bn3(self.conv3(x)))
|
||||||
|
x = self.avgpool(x)
|
||||||
x = self.layer1(x)
|
x = self.layer1(x)
|
||||||
x = self.layer2(x)
|
x = self.layer2(x)
|
||||||
x = self.layer3(x)
|
x = self.layer3(x)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue