import typing as tp
from functools import partial

import jax
import jax.numpy as jnp
from einops import rearrange
from flax import nnx
|
class FeedForward(nnx.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps the last axis from ``dim`` to ``hidden_dim`` and back to ``dim``,
    so input and output shapes match.

    Args:
        dim: input/output feature size (size of the last axis).
        hidden_dim: width of the hidden layer.
        dropout: dropout rate applied after each linear layer.
        rngs: nnx RNG streams for parameter init and dropout.
    """

    def __init__(self, dim, hidden_dim, dropout, rngs: nnx.Rngs):
        self.net = nnx.Sequential(
            nnx.Linear(dim, hidden_dim, rngs=rngs),
            # FIX: `partial(nnx.gelu)` was a no-op wrapper — the activation
            # is already a callable; pass it directly.
            nnx.gelu,
            nnx.Dropout(dropout, rngs=rngs),
            nnx.Linear(hidden_dim, dim, rngs=rngs),
            nnx.Dropout(dropout, rngs=rngs),
        )

    def __call__(self, x):
        """Apply the MLP along the last axis of ``x``; shape is preserved."""
        return self.net(x)
class MixerBlock(nnx.Module):
    """One MLP-Mixer block: token mixing then channel mixing, each residual.

    Operates on sequences shaped ``(batch, num_patch, dim)``.

    Args:
        dim: per-patch feature (channel) size.
        num_patch: number of patches in the sequence (token axis length).
        token_dim: hidden width of the token-mixing MLP.
        channel_dim: hidden width of the channel-mixing MLP.
        dropout: dropout rate inside both MLPs.
        rngs: nnx RNG streams for parameter init and dropout.
    """

    def __init__(self, dim, num_patch, token_dim, channel_dim, dropout, rngs: nnx.Rngs):
        super().__init__()
        # Token mixing: MLP over the patch axis, hence input size num_patch.
        self.ln1 = nnx.LayerNorm(dim, rngs=rngs)
        self.ffn1 = FeedForward(num_patch, token_dim, dropout, rngs=rngs)
        # Channel mixing: MLP over the feature axis.
        self.ln2 = nnx.LayerNorm(dim, rngs=rngs)
        self.ffn2 = FeedForward(dim, channel_dim, dropout, rngs=rngs)

    def __call__(self, x):
        """x: (batch, num_patch, dim) -> (batch, num_patch, dim)."""
        # BUG FIX: the token-mixing MLP mixes over the PATCH axis, so the
        # normalized input must be transposed to (b, dim, num_patch) before
        # ffn1 and transposed back afterwards. Previously ffn1 (whose first
        # Linear expects last axis == num_patch) was applied directly to the
        # channel axis, which fails unless dim == num_patch.
        tokens = rearrange(self.ln1(x), 'b n d -> b d n')
        tokens = self.ffn1(tokens)
        x = x + rearrange(tokens, 'b d n -> b n d')

        # Channel mixing works on the last axis directly.
        x = x + self.ffn2(self.ln2(x))
        return x
class MLPMixer(nnx.Module):
    """MLP-Mixer image classifier for NHWC inputs.

    Pipeline: patch embedding (strided conv) -> flatten to a patch sequence
    -> ``depth`` MixerBlocks -> LayerNorm -> mean over patches -> linear head.

    Args:
        in_channels: number of input image channels.
        dim: patch embedding / mixer feature size.
        num_classes: size of the classification head output.
        patch_size: side length of each square patch.
        dropout: dropout rate inside the mixer MLPs.
        image_size: input side length (image must be square).
        depth: number of MixerBlocks.
        token_dim: hidden width of the token-mixing MLPs.
        channel_dim: hidden width of the channel-mixing MLPs.
        rngs: nnx RNG streams for parameter init and dropout.
    """

    def __init__(self, in_channels, dim, num_classes, patch_size, dropout,
                 image_size, depth, token_dim, channel_dim, rngs: nnx.Rngs):
        super().__init__()

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        self.num_patch = (image_size // patch_size) ** 2

        # BUG FIX: the conv must stride by patch_size with VALID padding so it
        # extracts non-overlapping patches. The default (stride 1, SAME
        # padding) produced an image_size x image_size feature grid instead of
        # the (image_size/patch_size)^2 patch grid.
        self.to_patch_embedding = nnx.Sequential(
            nnx.Conv(in_channels, dim,
                     kernel_size=(patch_size, patch_size),
                     strides=(patch_size, patch_size),
                     padding='VALID',
                     rngs=rngs),
        )

        self.mixer_blocks = [
            MixerBlock(dim, self.num_patch, token_dim, channel_dim, dropout, rngs=rngs)
            for _ in range(depth)
        ]

        self.layer_norm = nnx.LayerNorm(dim, rngs=rngs)

        self.mlp_head = nnx.Sequential(
            nnx.Linear(dim, num_classes, rngs=rngs),
        )

    def __call__(self, x):
        """x: (batch, H, W, C) NHWC image -> (batch, num_classes) logits."""
        x = self.to_patch_embedding(x)  # (b, H/p, W/p, dim) patch grid

        # BUG FIX: flatten the spatial patch grid into a sequence so the
        # mixer blocks see (b, num_patch, dim); previously the 4-D grid was
        # fed to the blocks directly.
        x = rearrange(x, 'b h w d -> b (h w) d')

        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)

        x = self.layer_norm(x)

        # Global average pooling over the patch axis before classification.
        x = jnp.mean(x, axis=1)

        return self.mlp_head(x)
if __name__ == "__main__":
    # Flax convolutions expect NHWC inputs: (batch, height, width, channels).
    # FIX: the old script built an unused NCHW tensor [1, 3, 224, 224] and
    # then passed a separately constructed NHWC one; build the right layout
    # once and use it.
    img = jnp.ones((1, 224, 224, 3))

    model = MLPMixer(in_channels=3, image_size=224, patch_size=16, dropout=0.2,
                     num_classes=1000, dim=512, depth=8, token_dim=256,
                     channel_dim=2048, rngs=nnx.Rngs(0))
    # nnx.display(model)

    out = model(img)

    # FIX: the output is class logits, not an image — shape (B, num_classes).
    print("Shape of out :", out.shape)  # (B, num_classes)