"""MLP-Mixer image classifier implemented with Flax NNX."""

import typing as tp
from functools import partial

import jax
import jax.numpy as jnp
from einops import rearrange
from flax import nnx


class FeedForward(nnx.Module):
    """Two-layer MLP applied to the last axis of its input.

    Layout: Linear(dim -> hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim -> dim) -> Dropout.

    Args:
        dim: Size of the input (and output) feature axis.
        hidden_dim: Size of the hidden layer.
        dropout: Dropout rate used after each linear layer.
        rngs: NNX random number streams used for parameter init and dropout.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float, rngs: nnx.Rngs):
        self.net = nnx.Sequential(
            nnx.Linear(dim, hidden_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout, rngs=rngs),
            nnx.Linear(hidden_dim, dim, rngs=rngs),
            nnx.Dropout(dropout, rngs=rngs),
        )

    def __call__(self, x):
        """Run the MLP over the last axis of ``x``."""
        return self.net(x)


class MixerBlock(nnx.Module):
    """One Mixer layer: token mixing followed by channel mixing.

    Both sub-layers are pre-normalised with LayerNorm and wrapped in a
    residual connection.

    Args:
        dim: Channel (embedding) size of each patch.
        num_patch: Number of patches per image.
        token_dim: Hidden size of the token-mixing MLP.
        channel_dim: Hidden size of the channel-mixing MLP.
        dropout: Dropout rate used inside both MLPs.
        rngs: NNX random number streams.
    """

    def __init__(
        self,
        dim: int,
        num_patch: int,
        token_dim: int,
        channel_dim: int,
        dropout: float,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.ln1 = nnx.LayerNorm(dim, rngs=rngs)
        # Token-mixing MLP: its feature axis is the patch axis (num_patch).
        self.ffn1 = FeedForward(num_patch, token_dim, dropout, rngs=rngs)
        self.ln2 = nnx.LayerNorm(dim, rngs=rngs)
        # Channel-mixing MLP: its feature axis is the channel axis (dim).
        self.ffn2 = FeedForward(dim, channel_dim, dropout, rngs=rngs)

    def __call__(self, x):
        """Apply the block to ``x`` of shape ``(batch, num_patch, dim)``."""
        # ffn1 was built for `num_patch` input features, so the patch axis
        # must be moved last before it is applied and moved back afterwards.
        tokens = rearrange(self.ln1(x), 'b n d -> b d n')
        tokens = self.ffn1(tokens)
        x = x + rearrange(tokens, 'b d n -> b n d')

        x = x + self.ffn2(self.ln2(x))
        return x


class MLPMixer(nnx.Module):
    """MLP-Mixer classifier for square, channels-last images.

    Args:
        in_channels: Number of input image channels.
        dim: Embedding size of each patch.
        num_classes: Number of output logits.
        patch_size: Side length of each square patch, in pixels.
        dropout: Dropout rate used inside every mixer MLP.
        image_size: Side length of the (square) input image, in pixels.
        depth: Number of stacked ``MixerBlock`` layers.
        token_dim: Hidden size of the token-mixing MLPs.
        channel_dim: Hidden size of the channel-mixing MLPs.
        rngs: NNX random number streams.

    Raises:
        AssertionError: If ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(
        self,
        in_channels: int,
        dim: int,
        num_classes: int,
        patch_size: int,
        dropout: float,
        image_size: int,
        depth: int,
        token_dim: int,
        channel_dim: int,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        self.num_patch = (image_size // patch_size) ** 2

        # A stride equal to the kernel size with no padding makes the conv
        # emit exactly one embedding per non-overlapping patch. With the
        # library defaults (stride 1, 'SAME' padding) the spatial size would
        # stay image_size x image_size and not match `num_patch`.
        self.to_patch_embedding = nnx.Sequential(
            nnx.Conv(
                in_channels,
                dim,
                kernel_size=(patch_size, patch_size),
                strides=(patch_size, patch_size),
                padding='VALID',
                rngs=rngs,
            ),
        )

        # NOTE(review): recent Flax NNX releases may require module
        # containers to be wrapped (e.g. nnx.List) - confirm against the
        # installed Flax version.
        self.mixer_blocks = [
            MixerBlock(dim, self.num_patch, token_dim, channel_dim, dropout, rngs=rngs)
            for _ in range(depth)
        ]

        self.layer_norm = nnx.LayerNorm(dim, rngs=rngs)

        self.mlp_head = nnx.Sequential(
            nnx.Linear(dim, num_classes, rngs=rngs),
        )

    def __call__(self, x):
        """Classify a batch of images.

        Args:
            x: Array of shape ``(batch, image_size, image_size, in_channels)``.

        Returns:
            Array of shape ``(batch, num_classes)``.
        """
        x = self.to_patch_embedding(x)
        # Flatten the patch grid into a sequence: (b, h, w, c) -> (b, h*w, c).
        x = rearrange(x, 'b h w c -> b (h w) c')

        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)

        x = self.layer_norm(x)
        # Average over the patch axis to get one feature vector per image.
        x = jnp.mean(x, axis=1)

        return self.mlp_head(x)


if __name__ == "__main__":
    # Channels-last input: [B, image_size, image_size, in_channels].
    img = jnp.ones((1, 224, 224, 3))

    model = MLPMixer(
        in_channels=3,
        image_size=224,
        patch_size=16,
        dropout=0.2,
        num_classes=1000,
        dim=512,
        depth=8,
        token_dim=256,
        channel_dim=2048,
        rngs=nnx.Rngs(0),
    )

    out_img = model(img)

    print("Shape of out :", out_img.shape)  # expected: (1, 1000)