import typing as tp

import jax
import jax.numpy as jnp
from einops import rearrange
from functools import partial

from flax import nnx


class FeedForward(nnx.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Acts on the last axis of its input; all other axes are broadcast over.
    """

    def __init__(self, dim, hidden_dim, dropout, rngs: nnx.Rngs):
        self.net = nnx.Sequential(
            nnx.Linear(dim, hidden_dim, rngs=rngs),
            nnx.gelu,  # already a plain callable; wrapping it in partial() was redundant
            nnx.Dropout(dropout, rngs=rngs),
            nnx.Linear(hidden_dim, dim, rngs=rngs),
            nnx.Dropout(dropout, rngs=rngs),
        )

    def __call__(self, x):
        return self.net(x)


class MixerBlock(nnx.Module):
    """One MLP-Mixer layer: token mixing then channel mixing, both residual."""

    def __init__(self, dim, num_patch, token_dim, channel_dim, dropout, rngs: nnx.Rngs):
        super().__init__()
        self.ln1 = nnx.LayerNorm(dim, rngs=rngs)
        # Token-mixing MLP: operates across the num_patch axis.
        self.ffn1 = FeedForward(num_patch, token_dim, dropout, rngs=rngs)
        self.ln2 = nnx.LayerNorm(dim, rngs=rngs)
        # Channel-mixing MLP: operates across the feature (dim) axis.
        self.ffn2 = FeedForward(dim, channel_dim, dropout, rngs=rngs)

    def __call__(self, x):
        # x: (batch, num_patch, dim)
        #
        # Token mixing: ffn1's first Linear has in_features=num_patch, so the
        # patch axis must be last while it runs.  Applying ffn1 directly to
        # (b, n, d) input (as the original code did) puts Linear(num_patch, ...)
        # on the channel axis and crashes whenever num_patch != dim.
        y = rearrange(self.ln1(x), 'b n d -> b d n')
        x = x + rearrange(self.ffn1(y), 'b d n -> b n d')

        # Channel mixing acts on the last (feature) axis directly.
        x = x + self.ffn2(self.ln2(x))
        return x


class MLPMixer(nnx.Module):
    """MLP-Mixer (Tolstikhin et al., 2021) classifier for NHWC image input."""

    def __init__(self, in_channels, dim, num_classes, patch_size, dropout,
                 image_size, depth, token_dim, channel_dim, rngs: nnx.Rngs):
        super().__init__()

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        self.num_patch = (image_size // patch_size) ** 2

        # Non-overlapping patch embedding: strides must equal the kernel size.
        # Without it the conv slides with stride 1 and produces H*W positions
        # instead of num_patch patches, breaking every downstream shape.
        self.to_patch_embedding = nnx.Sequential(
            nnx.Conv(in_channels, dim,
                     kernel_size=(patch_size, patch_size),
                     strides=(patch_size, patch_size),
                     rngs=rngs),
        )

        self.mixer_blocks = [
            MixerBlock(dim, self.num_patch, token_dim, channel_dim, dropout, rngs=rngs)
            for _ in range(depth)
        ]

        self.layer_norm = nnx.LayerNorm(dim, rngs=rngs)

        self.mlp_head = nnx.Sequential(
            nnx.Linear(dim, num_classes, rngs=rngs),
        )

    def __call__(self, x):
        # x: (batch, image_size, image_size, in_channels) -- NHWC, as nnx.Conv expects.
        x = self.to_patch_embedding(x)            # (b, h/p, w/p, dim)
        # Flatten the spatial grid into the patch/token axis.  The original
        # code skipped this step (note: `rearrange` was imported but unused),
        # feeding a 4-D tensor into the mixer blocks.
        x = rearrange(x, 'b h w d -> b (h w) d')  # (b, num_patch, dim)

        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)

        x = self.layer_norm(x)
        x = jnp.mean(x, axis=1)  # global average pool over patches -> (b, dim)
        return self.mlp_head(x)  # (b, num_classes)


if __name__ == "__main__":
    # Dropped the unused NCHW `img` tensor from the original; nnx.Conv is NHWC.
    model = MLPMixer(in_channels=3, image_size=224, patch_size=16, dropout=0.2,
                     num_classes=1000, dim=512, depth=8, token_dim=256,
                     channel_dim=2048, rngs=nnx.Rngs(0))
    # nnx.display(model)
    out_img = model(jnp.ones((1, 224, 224, 3)))

    print("Shape of out :", out_img.shape)  # (B, num_classes) -- the original comment was wrong