21 lines
496 B
Python
21 lines
496 B
Python
from torch import nn
|
|
|
|
|
|
class LatentCodesDiscriminator(nn.Module):
    """MLP discriminator that scores latent codes.

    The network is ``n_mlp - 1`` hidden blocks of
    ``Linear(style_dim, style_dim) -> LeakyReLU(0.2)``, followed by a final
    ``Linear(style_dim, 1)`` projection to a single score.

    Args:
        style_dim: Number of features in each input latent code; every linear
            layer (including the output head) consumes this many features.
        n_mlp: Total number of linear layers, counting the output head. Values
            ``<= 1`` yield just the output ``Linear(style_dim, 1)``.
    """

    def __init__(self, style_dim, n_mlp):
        super().__init__()

        self.style_dim = style_dim

        layers = []
        for _ in range(n_mlp - 1):
            layers.append(nn.Linear(style_dim, style_dim))
            layers.append(nn.LeakyReLU(0.2))
        # The head must take ``style_dim`` inputs. A hard-coded 512 here only
        # worked when style_dim happened to be 512 and broke forward() otherwise.
        layers.append(nn.Linear(style_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, w):
        """Return one score per latent code.

        Args:
            w: Tensor whose last dimension has size ``style_dim`` (required by
                the ``nn.Linear`` layers); leading dimensions are passed through.

        Returns:
            Tensor with the same leading dimensions as ``w`` and a last
            dimension of size 1.
        """
        return self.mlp(w)
|