Source code for common.optim.dl.litmodule.nnmodule.mlp

from dataclasses import dataclass
from typing import Annotated as An

from einops import rearrange
from jaxtyping import Float
from omegaconf import MISSING
from torch import Tensor, nn

from common.utils.beartype import ge, lt


@dataclass
class MLPConfig:
    """Configuration consumed by :class:`MLP`.

    Attributes:
        dims: Layer widths. The first entry is the size of the flattened
            input, and the last is the output size. Each consecutive pair
            defines one ``nn.Linear`` layer. The default is
            ``omegaconf.MISSING``, so a value has to be supplied.
        p_dropout: Dropout probability used after each hidden activation.
            A value of ``0.0`` adds no dropout layers. It is annotated with
            ``ge(0)`` / ``lt(1)``, which presumably means 0 <= p < 1.
    """

    # Mandatory: left unset until the config system fills it in.
    dims: list[int] = MISSING
    # Falsy value (0.0) means "no dropout"; MLP skips nn.Dropout entirely.
    p_dropout: An[float, ge(0), lt(1)] = 0.0
class MLP(nn.Module):
    """Multi-layer perceptron built from stacked ``nn.Linear`` layers.

    Each consecutive pair in ``config.dims`` defines one linear layer.
    Every layer except the last is followed by ``activation_fn``. When
    ``config.p_dropout`` is non-zero, an ``nn.Dropout`` layer follows
    each of those activations as well.

    Args:
        config: Layer widths and dropout probability.
        activation_fn: Activation module inserted after every hidden
            linear layer. The same instance is registered at each hidden
            position, so any parameters it holds are shared across layers.
    """

    def __init__(
        self: "MLP",
        config: MLPConfig,
        activation_fn: nn.Module,
    ) -> None:
        super().__init__()
        self.model = nn.Sequential()
        # Materialize as a plain list so slicing/zip also works when dims is
        # a config container rather than a Python list.
        dims = list(config.dims)
        # Index of the final linear layer; no activation/dropout after it.
        last_layer_idx = len(dims) - 2
        # NOTE(review): with fewer than two entries in dims no layers are
        # added, and forward() only flattens its input.
        for i, (d_in, d_out) in enumerate(zip(dims, dims[1:])):
            self.model.add_module(
                name=f"fc_{i}",
                module=nn.Linear(d_in, d_out),
            )
            if i < last_layer_idx:
                self.model.add_module(name=f"act_{i}", module=activation_fn)
                if config.p_dropout:  # skip a no-op Dropout when p == 0
                    self.model.add_module(
                        name=f"drop_{i}",
                        module=nn.Dropout(config.p_dropout),
                    )

    def forward(
        self: "MLP",
        x: Float[Tensor, " batch_size *d_input"],
    ) -> Float[Tensor, " batch_size output_size"]:
        """Flatten all non-batch dimensions of ``x`` and apply the MLP.

        Args:
            x: Input batch. Every dimension after the first is flattened
                into a single feature dimension.

        Returns:
            The output of ``self.model`` applied to the flattened input.
        """
        # Distinct names rather than re-annotating one variable: type
        # checkers reject redefining an already-annotated name.
        flat_x: Float[Tensor, " batch_size flattened_d_input"] = rearrange(
            x,
            "batch_size ... -> batch_size (...)",
        )
        out: Float[Tensor, " batch_size output_size"] = self.model(flat_x)
        return out