# Module: common.optim.dl.litmodule.classification.base

""":class:`.BaseClassificationLitModule` & its config."""

from abc import ABC, abstractmethod
from collections.abc import Callable  # noqa: TC003
from dataclasses import dataclass, field
from typing import Annotated as An
from typing import Any, final

import torch
import torch.nn.functional as f
from jaxtyping import Float, Int
from torch import Tensor
from torchmetrics.classification import MulticlassAccuracy

from common.optim.dl.litmodule import BaseLitModule, BaseLitModuleConfig
from common.utils.beartype import ge, one_of


@dataclass
class BaseClassificationLitModuleConfig(BaseLitModuleConfig):
    """Holds :class:`BaseClassificationLitModule` config values.

    Args:
        num_classes: Number of classes to classify between.
        wandb_column_names: Names of the columns saved to the W&B table.
    """

    num_classes: An[int, ge(2)] = 2
    wandb_column_names: list[str] = field(
        default_factory=lambda: ["x", "y", "y_hat", "logits"],
    )
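
# A minimal usage sketch; the values are hypothetical and assume the
# inherited :class:`BaseLitModuleConfig` fields all carry defaults:
#
#     config = BaseClassificationLitModuleConfig(num_classes=10)
#     assert config.wandb_column_names == ["x", "y", "y_hat", "logits"]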


class BaseClassificationLitModule(BaseLitModule, ABC):
    """Base Classification ``LightningModule``.

    Ref: :class:`lightning.pytorch.core.LightningModule`

    Attributes:
        config (BaseClassificationLitModuleConfig): This module's config.
        accuracy (torchmetrics.classification.MulticlassAccuracy): The
            multiclass accuracy metric.
        wandb_table (wandb.Table): A table to upload to W&B containing
            validation data.
    """

    def __init__(
        self: "BaseClassificationLitModule",
        *args: Any,  # noqa: ANN401
        **kwargs: Any,  # noqa: ANN401
    ) -> None:
        super().__init__(*args, **kwargs)
        self.config: BaseClassificationLitModuleConfig
        self.accuracy = MulticlassAccuracy(
            num_classes=self.config.num_classes,
        )
        # Identity placeholder; subclasses supply a proper conversion
        # through :attr:`wandb_media_x`.
        self.to_wandb_media: Callable[..., Any] = lambda x: x

    @property
    @abstractmethod
    def wandb_media_x(self):  # type: ignore[no-untyped-def]  # noqa: ANN201
        """Returns a callable converting a tensor to a W&B media object."""

    @final
    def step(
        self: "BaseClassificationLitModule",
        data: tuple[
            Float[Tensor, " batch_size *x_dim"],
            Int[Tensor, " batch_size"],
        ],
        stage: An[str, one_of("train", "val", "test")],
    ) -> Float[Tensor, " "]:
        """Computes the model accuracy and cross-entropy loss.

        Args:
            data: A tuple ``(x, y)`` where ``x`` is the input data and
                ``y`` is the target data.
            stage: See :paramref:`~.BaseLitModule.stage_step.stage`.

        Returns:
            The cross-entropy loss.
        """
        x: Float[Tensor, " batch_size *x_dim"] = data[0]
        y: Int[Tensor, " batch_size"] = data[1]
        logits: Float[Tensor, " batch_size num_classes"] = self.nnmodule(x)
        y_hat: Int[Tensor, " batch_size"] = torch.argmax(input=logits, dim=1)
        accuracy: Float[Tensor, " "] = self.accuracy(preds=y_hat, target=y)
        self.log(name=f"{stage}/acc", value=accuracy)
        self.save_wandb_data(stage, x, y, y_hat, logits)
        return f.cross_entropy(input=logits, target=y)
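
    # For illustration (hypothetical shapes): with a batch of 32 MNIST
    # images, ``x`` is ``float32[32, 1, 28, 28]``, ``logits`` is
    # ``float32[32, 10]``, ``y_hat`` is ``int64[32]``, and the returned
    # cross-entropy loss is a scalar tensor.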

    @final
    def save_wandb_data(
        self: "BaseClassificationLitModule",
        stage: An[str, one_of("train", "val", "test")],
        x: Float[Tensor, " batch_size *x_dim"],
        y: Int[Tensor, " batch_size"],
        y_hat: Int[Tensor, " batch_size"],
        logits: Float[Tensor, " batch_size num_classes"],
    ) -> None:
        """Saves rich data to be logged to W&B.

        Args:
            stage: See :paramref:`~.BaseLitModule.stage_step.stage`.
            x: The input data.
            y: The target class indices.
            y_hat: The predicted class indices.
            logits: The raw ``num_classes`` network outputs.
        """
        data = (
            self.wandb_train_data if stage == "train" else self.wandb_val_data
        )
        # Skip if data was already saved or if not on the main process.
        if data or self.global_rank != 0:
            return
        x, y, y_hat, logits = x.cpu(), y.cpu(), y_hat.cpu(), logits.cpu()
        for i in range(self.config.wandb_num_samples):
            # Cycle through the batch across validation epochs.
            index = (
                self.curr_val_epoch * self.config.wandb_num_samples + i
            ) % len(x)
            data.append(
                {
                    "x": self.wandb_media_x(x[index]),
                    "y": y[index],
                    "y_hat": y_hat[index],
                    "logits": logits[index].tolist(),
                },
            )
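
# A minimal concrete subclass is sketched below for illustration only. The
# use of ``wandb.Image`` and the assumption that ``x`` holds image tensors
# are hypothetical; any callable returning a W&B media object works.
#
#     import wandb
#
#     class ImageClassificationLitModule(BaseClassificationLitModule):
#         @property
#         def wandb_media_x(self) -> Callable[..., Any]:
#             return wandb.Image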