添加 mlp_mix.py

This commit is contained in:
liwenyun 2024-11-13 17:58:14 +08:00
parent fca1280172
commit 8bb9ef837f
1 changed files with 92 additions and 0 deletions

92
mlp_mix.py Normal file
View File

@ -0,0 +1,92 @@
import typing as tp
import jax
import jax.numpy as jnp
from einops import rearrange
from functools import partial
from flax import nnx
class FeedForward(nnx.Module):
def __init__(self, dim, hidden_dim, dropout , rngs: nnx.Rngs):
self.net=nnx.Sequential(
nnx.Linear(dim, hidden_dim , rngs=rngs),
partial(nnx.gelu),
nnx.Dropout(dropout , rngs=rngs),
nnx.Linear(hidden_dim, dim , rngs=rngs),
nnx.Dropout(dropout , rngs=rngs)
)
def __call__(self, x):
return self.net(x)
class MixerBlock(nnx.Module):
def __init__(self, dim, num_patch, token_dim, channel_dim, dropout , rngs: nnx.Rngs):
super().__init__()
self.ln1=nnx.LayerNorm(dim, rngs=rngs)
self.ffn1=FeedForward(num_patch,token_dim,dropout,rngs=rngs)
self.ln2=nnx.LayerNorm(dim, rngs=rngs)
self.ffn2=FeedForward(dim, channel_dim, dropout, rngs=rngs)
def __call__(self, x):
# print(x.shape)
x = x + self.ffn1(self.ln1(x))
x = x + self.ffn2(self.ln2(x))
return x
class MLPMixer(nnx.Module):
def __init__(self, in_channels, dim, num_classes, patch_size,dropout, image_size, depth, token_dim, channel_dim, rngs: nnx.Rngs):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
self.num_patch = (image_size// patch_size) ** 2
self.to_patch_embedding = nnx.Sequential(
nnx.Conv(in_channels, dim, kernel_size=(patch_size, patch_size), rngs=rngs),
)
self.mixer_blocks=[]
for _ in range(depth):
self.mixer_blocks.append(MixerBlock(dim, self.num_patch, token_dim, channel_dim,dropout, rngs=rngs))
self.layer_norm = nnx.LayerNorm(dim, rngs=rngs)
self.mlp_head = nnx.Sequential(
nnx.Linear(dim, num_classes, rngs=rngs)
)
def __call__(self, x):
x = self.to_patch_embedding(x)
for mixer_block in self.mixer_blocks:
x = mixer_block(x)
x = self.layer_norm(x)
x = jnp.mean(x, axis=1)
return self.mlp_head(x)
if __name__ == "__main__":
img = jnp.ones([1, 3, 224, 224])
model = MLPMixer(in_channels=3, image_size=224, patch_size=16,dropout=0.2, num_classes=1000,
dim=512, depth=8, token_dim=256, channel_dim=2048,rngs=nnx.Rngs(0))
# nnx.display(model)
out_img = model(jnp.ones((1, 224, 224,3)))
print("Shape of out :", out_img.shape) # [B, in_channels, image_size, image_size]