Skip to content

API Reference

Tracking Module

ExperimentTracker

researchlab.tracking.tracker.ExperimentTracker

A context manager for tracking research experiments with Git state.

This class automates the process of starting an MLflow run, capturing the current Git state (base commit and diff), and logging artifacts and parameters.

Example
from researchlab import ExperimentTracker

with ExperimentTracker(experiment_name="my_project") as tracker:
    tracker.log_config("config.yaml")
    # Your training logic here
    print(f"Run ID: {tracker.run_name}")

Attributes:

Name Type Description
run_name str

The unique, readable ID for this run.

active_run Optional[Run]

The currently active MLflow run.

Source code in src/researchlab/tracking/tracker.py
class ExperimentTracker:
    """A context manager for tracking research experiments with Git state.

    This class automates the process of starting an MLflow run, capturing the
    current Git state (base commit and diff), and logging artifacts and
    parameters.

    Example:
        ```python
        from researchlab import ExperimentTracker

        with ExperimentTracker(experiment_name="my_project") as tracker:
            tracker.log_config("config.yaml")
            # Your training logic here
            print(f"Run ID: {tracker.run_name}")
        ```

    Attributes:
        run_name (str): The unique, readable ID for this run.
        active_run (Optional[Run]): The currently active MLflow run.
    """

    def __init__(
        self,
        experiment_name: str,
        run_name: str | None = None,
        tracking_uri: str | None = None,
    ):
        """Initializes the ExperimentTracker.

        Args:
            experiment_name: The name of the MLflow experiment.
            run_name: Optional custom name for the run. If not provided,
                a readable ID is generated automatically.
            tracking_uri: Optional MLflow tracking URI to use.
        """
        if tracking_uri:
            mlflow.set_tracking_uri(tracking_uri)

        mlflow.set_experiment(experiment_name)

        self.run_name = run_name or generate_run_id()
        self.active_run = None

    def __enter__(self) -> "ExperimentTracker":
        """Starts the MLflow run and captures the Git state.

        Returns:
            ExperimentTracker: The instance of the tracker.
        """
        self.active_run = mlflow.start_run(run_name=self.run_name)

        # Capture and log Git state
        git_state = get_git_state()
        mlflow.set_tag("rlab.base_commit", git_state["base_commit"])
        mlflow.set_tag("rlab.run_id", self.run_name)

        # Save patch file as an artifact. Writing into a temporary directory
        # (instead of the CWD) avoids collisions between concurrent runs and
        # guarantees the file is cleaned up even if log_artifact raises.
        if git_state["patch"]:
            import tempfile

            with tempfile.TemporaryDirectory() as tmp_dir:
                patch_path = Path(tmp_dir) / "run.patch"
                patch_path.write_text(git_state["patch"], encoding="utf-8")
                mlflow.log_artifact(str(patch_path))

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Ends the MLflow run, marking it FAILED if an exception occurred."""
        # Propagate failure status so crashed runs are not recorded as
        # successfully FINISHED in the MLflow UI.
        mlflow.end_run(status="FAILED" if exc_type is not None else "FINISHED")

    def log_config(self, config_path: str):
        """Logs a YAML configuration file as parameters and an artifact.

        Recursively parses the YAML file and logs each entry as an MLflow
        parameter. Also uploads the file itself as an artifact.

        Args:
            config_path: Path to the YAML configuration file.
        """
        path = Path(config_path)
        if not path.exists():
            print(f"Warning: Config file {config_path} not found.")
            return

        # Upload the raw file first so it is preserved even if parsing fails.
        mlflow.log_artifact(str(path))

        with path.open(encoding="utf-8") as f:
            try:
                config = yaml.safe_load(f)
            except yaml.YAMLError as e:
                print(f"Error parsing YAML config: {e}")
                return

        if isinstance(config, dict):
            log_flattened_params(config)

__enter__()

Starts the MLflow run and captures the Git state.

Returns:

Name Type Description
ExperimentTracker ExperimentTracker

The instance of the tracker.

Source code in src/researchlab/tracking/tracker.py
def __enter__(self) -> "ExperimentTracker":
    """Starts the MLflow run and captures the Git state.

    Returns:
        ExperimentTracker: The instance of the tracker.
    """
    self.active_run = mlflow.start_run(run_name=self.run_name)

    # Capture and log Git state
    git_state = get_git_state()
    mlflow.set_tag("rlab.base_commit", git_state["base_commit"])
    mlflow.set_tag("rlab.run_id", self.run_name)

    # Save patch file as an artifact. Writing into a temporary directory
    # (instead of the CWD) avoids collisions between concurrent runs and
    # guarantees the file is cleaned up even if log_artifact raises.
    if git_state["patch"]:
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            patch_path = Path(tmp_dir) / "run.patch"
            patch_path.write_text(git_state["patch"], encoding="utf-8")
            mlflow.log_artifact(str(patch_path))

    return self

__exit__(exc_type, exc_val, exc_tb)

Ends the MLflow run.

Source code in src/researchlab/tracking/tracker.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Ends the MLflow run, marking it FAILED if an exception occurred."""
    # Propagate failure status so crashed runs are not recorded as
    # successfully FINISHED in the MLflow UI.
    mlflow.end_run(status="FAILED" if exc_type is not None else "FINISHED")

__init__(experiment_name, run_name=None, tracking_uri=None)

Initializes the ExperimentTracker.

Parameters:

Name Type Description Default
experiment_name str

The name of the MLflow experiment.

required
run_name str | None

Optional custom name for the run. If not provided, a readable ID is generated automatically.

None
tracking_uri str | None

Optional MLflow tracking URI to use.

None
Source code in src/researchlab/tracking/tracker.py
def __init__(
    self,
    experiment_name: str,
    run_name: str | None = None,
    tracking_uri: str | None = None,
):
    """Sets up the tracker's MLflow context.

    Args:
        experiment_name: The name of the MLflow experiment.
        run_name: Optional custom name for the run. If not provided,
            a readable ID is generated automatically.
        tracking_uri: Optional MLflow tracking URI to use.
    """
    # Point MLflow at a specific backend only when one was requested.
    if tracking_uri:
        mlflow.set_tracking_uri(tracking_uri)

    mlflow.set_experiment(experiment_name)

    # Fall back to an auto-generated, human-readable ID.
    self.run_name = run_name if run_name else generate_run_id()
    self.active_run = None

log_config(config_path)

Logs a YAML configuration file as parameters and an artifact.

Recursively parses the YAML file and logs each entry as an MLflow parameter. Also uploads the file itself as an artifact.

Parameters:

Name Type Description Default
config_path str

Path to the YAML configuration file.

required
Source code in src/researchlab/tracking/tracker.py
def log_config(self, config_path: str):
    """Logs a YAML configuration file as parameters and an artifact.

    Recursively parses the YAML file and logs each entry as an MLflow
    parameter. Also uploads the file itself as an artifact.

    Args:
        config_path: Path to the YAML configuration file.
    """
    path = Path(config_path)
    if not path.exists():
        print(f"Warning: Config file {config_path} not found.")
        return

    # Upload the raw file first so it is preserved even if parsing fails.
    mlflow.log_artifact(str(path))

    with path.open(encoding="utf-8") as f:
        try:
            config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            print(f"Error parsing YAML config: {e}")
            return

    if isinstance(config, dict):
        log_flattened_params(config)

Utilities

researchlab.tracking.utils

find_run_by_rlab_id(run_id)

Finds an MLflow run by the custom rlab.run_id tag.

Parameters:

Name Type Description Default
run_id str

The readable rlab run ID to search for.

required

Returns:

Type Description
Run | None

Optional[Run]: The MLflow Run object if found, else None.

Source code in src/researchlab/tracking/utils.py
def find_run_by_rlab_id(run_id: str) -> Run | None:
    """Looks up an MLflow run via its custom rlab.run_id tag.

    Args:
        run_id: The readable rlab run ID to search for.

    Returns:
        Optional[Run]: The MLflow Run object if found, else None.
    """
    # NOTE(review): run_id is interpolated directly into the filter string;
    # this assumes IDs never contain quotes (true for generate_run_id output).
    query = f"tags.'rlab.run_id' = '{run_id}'"
    matches = mlflow.search_runs(
        filter_string=query,
        run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY,
        output_format="list",
    )
    if not matches:
        return None
    return matches[0]

generate_run_id()

Generates a unique and readable run ID.

Returns:

Name Type Description
str str

A string in the format YYYY-MM-DD_slug (e.g., 2026-02-15_radiant-octopus).

Source code in src/researchlab/tracking/utils.py
def generate_run_id() -> str:
    """Generates a unique and readable run ID.

    Returns:
        str: A string in the format YYYY-MM-DD_slug (e.g., 2026-02-15_radiant-octopus).
    """
    # Date prefix keeps runs chronologically sortable; the two-word slug
    # keeps them memorable.
    stamp = datetime.datetime.now().strftime("%Y-%m-%d")
    return "_".join((stamp, coolname.generate_slug(2)))

get_git_state(repo_path='.')

Captures the current Git state including base commit and patch.

This captures staged, unstaged, and untracked changes by temporarily using git's "intent-to-add" feature. It restores the untracked state after generating the patch.

Parameters:

Name Type Description Default
repo_path str

Path to the git repository. Defaults to ".".

'.'

Returns:

Type Description
dict[str, Any]

dict[str, Any]: A dictionary containing: base_commit (str) — the hex SHA of the current HEAD commit; patch (str) — the git diff output; is_dirty (bool) — True if there are any changes (including untracked files).

Source code in src/researchlab/tracking/utils.py
def get_git_state(repo_path: str = ".") -> dict[str, Any]:
    """Captures the current Git state including base commit and patch.

    This captures staged, unstaged, and untracked changes by temporarily
    using git's "intent-to-add" feature. It restores the untracked state
    after generating the patch.

    Args:
        repo_path: Path to the git repository. Defaults to ".". Because
            parent directories are searched, any path inside the repo works.

    Returns:
        dict[str, Any]: A dictionary containing:
            - base_commit (str): The hex SHA of the current HEAD commit.
            - patch (str): The git diff output.
            - is_dirty (bool): True if there are any changes (including untracked).
    """
    # NOTE(review): the Repo handle is never closed; GitPython can hold onto
    # cat-file subprocess handles — consider `with git.Repo(...)` — confirm.
    repo = git.Repo(repo_path, search_parent_directories=True)

    # Base commit (HEAD)
    base_commit = repo.head.commit.hexsha

    # Identify untracked files to restore them later
    untracked_files = repo.untracked_files

    try:
        # To include untracked files in the patch, we use 'git add -N' (intent-to-add)
        # which makes them appear in the diff without actually staging their content.
        # This DOES NOT modify the actual file content on disk.
        for f in untracked_files:
            repo.git.add(f, intent_to_add=True)

        # Patch of all changes (staged, unstaged, and untracked-via-intent-to-add)
        patch = repo.git.diff(repo.head.commit)

        # Check if dirty (including untracked files)
        is_dirty = repo.is_dirty(untracked_files=True)

    finally:
        # Restore the untracked state in the index for files we touched.
        # repo.git.reset(f) is equivalent to 'git reset <file>', which only
        # affects the index and does NOT revert working tree changes.
        # The suppress() guards against per-file reset failures so a partial
        # failure cannot abort the restoration of the remaining files.
        for f in untracked_files:
            with contextlib.suppress(git.exc.GitCommandError):
                repo.git.reset(f)

    return {"base_commit": base_commit, "patch": patch, "is_dirty": is_dirty}

log_flattened_params(d, prefix='')

Recursively logs a dictionary as flattened MLflow parameters.

Parameters:

Name Type Description Default
d dict[str, Any]

The dictionary to log.

required
prefix str

Optional prefix for parameter keys (used for recursion).

''
Source code in src/researchlab/tracking/utils.py
def log_flattened_params(d: dict[str, Any], prefix: str = "") -> None:
    """Flattens a nested dict and logs each leaf as an MLflow parameter.

    Nested keys are joined with dots (e.g. {"a": {"b": 1}} logs "a.b" = 1).

    Args:
        d: The dictionary to log.
        prefix: Optional prefix for parameter keys (used for recursion).
    """
    for name, value in d.items():
        full_key = f"{prefix}{name}" if prefix else name
        if not isinstance(value, dict):
            mlflow.log_param(full_key, value)
        else:
            # Recurse into sub-dicts, extending the dotted key path.
            log_flattened_params(value, prefix=f"{full_key}.")

Design Module

Core

researchlab.design.core

Config

Bases: BaseModel

Base class for immutable hyperparameters.

This class serves as a foundation for configuration objects. It uses Pydantic for schema validation and serialization. Configurations are immutable (frozen) to ensure reproducibility.

Example
class TrainingConfig(Config):
    learning_rate: float = 1e-3
    batch_size: int = 32

config = TrainingConfig(learning_rate=0.01)
print(config.model_dump())
# {'learning_rate': 0.01, 'batch_size': 32}
Source code in src/researchlab/design/core.py
class Config(BaseModel):
    """Base class for immutable hyperparameters.

    Subclass this to declare experiment configuration objects. Pydantic
    provides schema validation and serialization; instances are frozen
    (immutable) so a config cannot drift mid-experiment, which helps
    reproducibility.

    Example:
        ```python
        class TrainingConfig(Config):
            learning_rate: float = 1e-3
            batch_size: int = 32

        config = TrainingConfig(learning_rate=0.01)
        print(config.model_dump())
        # {'learning_rate': 0.01, 'batch_size': 32}
        ```
    """

    # frozen=True makes instances hashable and rejects attribute assignment.
    model_config = ConfigDict(frozen=True)

FieldSelector

Bases: Selector[S, C]

A simplified Selector using dot-notation strings for mapping.

This selector allows you to define a mapping from kernel argument names to dot-notation paths within the State or Config objects. It supports nested attributes (e.g., state.sub.val) and dictionary/sequence access (e.g., state.dict.key, state.list.0).

Parameters:

Name Type Description Default
**mappings str

Keyword arguments where keys are the kernel argument names and values are dot-notation strings starting with 'state.' or 'config.'.

{}
Example
class MyState(State):
    val: int
class MyConfig(Config):
    factor: int

@FieldSelector(a="state.val", b="config.factor")
def multiply(a, b):
    return a * b

state = MyState(val=5)
config = MyConfig(factor=3)
print(multiply(state, config))
# 15
Source code in src/researchlab/design/core.py
class FieldSelector[S: State, C: Config](Selector[S, C]):
    """A simplified Selector using dot-notation strings for mapping.

    This selector allows you to define a mapping from kernel argument names to
    dot-notation paths within the `State` or `Config` objects. It supports
    nested attributes (e.g., `state.sub.val`) and dictionary/sequence access
    (e.g., `state.dict.key`, `state.list.0`).

    Args:
        **mappings: Keyword arguments where keys are the kernel argument names
            and values are dot-notation strings starting with 'state.' or 'config.'.

    Example:
        ```python
        class MyState(State):
            val: int
        class MyConfig(Config):
            factor: int

        @FieldSelector(a="state.val", b="config.factor")
        def multiply(a, b):
            return a * b

        state = MyState(val=5)
        config = MyConfig(factor=3)
        print(multiply(state, config))
        # 15
        ```
    """

    def __init__(self, **mappings: str):
        """Initializes the FieldSelector with path mappings."""

        def extractor(state: S, config: C) -> tuple[tuple[Any, ...], dict[str, Any]]:
            kwargs = {}
            for arg_name, path in mappings.items():
                parts = path.split(".")
                root_name = parts[0]
                if root_name == "state":
                    obj = state
                elif root_name == "config":
                    obj = config
                else:
                    raise ValueError(f"Path must start with 'state' or 'config', got '{path}'")

                for part in parts[1:]:
                    if hasattr(obj, part):
                        obj = getattr(obj, part)
                    else:
                        # Handle dictionary/list/tuple access
                        try:
                            obj = obj[part]
                        except (TypeError, KeyError, IndexError):
                            # Try integer index for sequences
                            try:
                                idx = int(part)
                                obj = obj[idx]
                            except (ValueError, TypeError, IndexError, KeyError):
                                raise AttributeError(
                                    f"Could not resolve path '{path}': '{part}' not found."
                                ) from None
                kwargs[arg_name] = obj
            return (), kwargs

        super().__init__(extractor)

__init__(**mappings)

Initializes the FieldSelector with path mappings.

Source code in src/researchlab/design/core.py
def __init__(self, **mappings: str):
    """Initializes the FieldSelector with path mappings."""

    def resolve_step(obj: Any, part: str, path: str) -> Any:
        # Attribute access takes precedence; otherwise fall back to
        # mapping access, then to integer indexing for sequences.
        if hasattr(obj, part):
            return getattr(obj, part)
        try:
            return obj[part]
        except (TypeError, KeyError, IndexError):
            pass
        try:
            return obj[int(part)]
        except (ValueError, TypeError, IndexError, KeyError):
            raise AttributeError(
                f"Could not resolve path '{path}': '{part}' not found."
            ) from None

    def extractor(state: S, config: C) -> tuple[tuple[Any, ...], dict[str, Any]]:
        kwargs: dict[str, Any] = {}
        for arg_name, path in mappings.items():
            root, *segments = path.split(".")
            if root == "state":
                current = state
            elif root == "config":
                current = config
            else:
                raise ValueError(f"Path must start with 'state' or 'config', got '{path}'")

            for segment in segments:
                current = resolve_step(current, segment, path)
            kwargs[arg_name] = current
        return (), kwargs

    super().__init__(extractor)

Kernel

Bases: Protocol

Protocol for pure functions representing the core logic.

Kernels are pure functions that take specific inputs (defined by P) and return a result (defined by R). They should be free of side effects and depend only on their inputs, making them easy to test and transform.

Source code in src/researchlab/design/core.py
@runtime_checkable
class Kernel[P, R](Protocol):
    """Protocol for pure functions representing the core logic.

    Kernels are pure functions that take specific inputs (defined by `P`) and
    return a result (defined by `R`). They should be free of side effects and
    depend only on their inputs, making them easy to test and transform.
    """

    def __call__(self, *args: P, **kwargs: Any) -> R: ...

SelectedKernel

Bases: Module

The wrapped kernel returned by a Selector.

This class wraps a pure kernel function and an extractor. When called with State and Config objects, it uses the extractor to derive the arguments for the kernel and then executes the kernel.

Attributes:

Name Type Description
raw Callable[..., R]

The original pure kernel function.

Source code in src/researchlab/design/core.py
class SelectedKernel[S: State, C: Config, R](eqx.Module):
    """The wrapped kernel returned by a Selector.

    This class wraps a pure kernel function and an extractor. When called with
    `State` and `Config` objects, it uses the extractor to derive the arguments
    for the kernel and then executes the kernel.

    Attributes:
        raw: The original pure kernel function.
    """

    _func: Callable[..., R]
    _extractor: Callable[[S, C], tuple[tuple[Any, ...], dict[str, Any]]]

    def __call__(self, state: S, config: C) -> R:
        """Executes the kernel using data extracted from state and config.

        Args:
            state: The current simulation state.
            config: The experiment configuration.

        Returns:
            The result of the kernel function.
        """
        args, kwargs = self._extractor(state, config)
        return self._func(*args, **kwargs)

    @property
    def raw(self) -> Callable[..., R]:
        """Access the original pure function.

        Returns:
            The underlying kernel function before it was wrapped.
        """
        return self._func

raw property

Access the original pure function.

Returns:

Type Description
Callable[..., R]

The underlying kernel function before it was wrapped.

__call__(state, config)

Executes the kernel using data extracted from state and config.

Parameters:

Name Type Description Default
state S

The current simulation state.

required
config C

The experiment configuration.

required

Returns:

Type Description
R

The result of the kernel function.

Source code in src/researchlab/design/core.py
def __call__(self, state: S, config: C) -> R:
    """Runs the wrapped kernel on arguments pulled from state and config.

    Args:
        state: The current simulation state.
        config: The experiment configuration.

    Returns:
        The result of the kernel function.
    """
    positional, keyword = self._extractor(state, config)
    return self._func(*positional, **keyword)

Selector

A decorator/higher-order function to bind State and Config to Kernel arguments.

A Selector defines how to extract arguments for a kernel from the broader State and Config context. It decouples the kernel's signature from the application's data structure.

Parameters:

Name Type Description Default
extractor Callable[[S, C], tuple[tuple[Any, ...], dict[str, Any]]]

A function that takes (state, config) and returns a tuple ((args...), {kwargs...}) containing the arguments to be passed to the kernel.

required
Example
class MyState(State):
    val: int
class MyConfig(Config):
    factor: int

# Define how to extract arguments
def my_extractor(state, config):
    return (state.val, config.factor), {}

# Wrap the kernel
@Selector(my_extractor)
def compute(val, factor):
    return val * factor

state = MyState(val=10)
config = MyConfig(factor=2)
print(compute(state, config))
# 20
Source code in src/researchlab/design/core.py
class Selector[S: State, C: Config]:
    """A decorator/higher-order function to bind State and Config to Kernel arguments.

    A `Selector` defines how to extract arguments for a kernel from the broader
    `State` and `Config` context. It decouples the kernel's signature from the
    application's data structure.

    Args:
        extractor: A function that takes `(state, config)` and returns a tuple
            `((args...), {kwargs...})` containing the arguments to be passed
            to the kernel.

    Example:
        ```python
        class MyState(State):
            val: int
        class MyConfig(Config):
            factor: int

        # Define how to extract arguments
        def my_extractor(state, config):
            return (state.val, config.factor), {}

        # Wrap the kernel
        @Selector(my_extractor)
        def compute(val, factor):
            return val * factor

        state = MyState(val=10)
        config = MyConfig(factor=2)
        print(compute(state, config))
        # 20
        ```
    """

    def __init__(
        self,
        extractor: Callable[[S, C], tuple[tuple[Any, ...], dict[str, Any]]],
    ):
        """Initializes the Selector with an extractor function."""
        self.extractor = extractor

    def __call__[R](self, func: Callable[..., R]) -> SelectedKernel[S, C, R]:
        """Wraps a kernel function.

        Args:
            func: The pure kernel function to wrap.

        Returns:
            A `SelectedKernel` that can be called with `(state, config)`.
        """
        return SelectedKernel(func, self.extractor)

__call__(func)

Wraps a kernel function.

Parameters:

Name Type Description Default
func Callable[..., R]

The pure kernel function to wrap.

required

Returns:

Type Description
SelectedKernel[S, C, R]

A SelectedKernel that can be called with (state, config).

Source code in src/researchlab/design/core.py
def __call__[R](self, func: Callable[..., R]) -> SelectedKernel[S, C, R]:
    """Wraps a kernel function.

    Args:
        func: The pure kernel function to wrap.

    Returns:
        A `SelectedKernel` that can be called with `(state, config)`.
    """
    return SelectedKernel(func, self.extractor)

__init__(extractor)

Initializes the Selector with an extractor function.

Source code in src/researchlab/design/core.py
def __init__(
    self,
    extractor: Callable[[S, C], tuple[tuple[Any, ...], dict[str, Any]]],
):
    """Stores the extractor used to derive kernel arguments."""
    self.extractor = extractor

State

Bases: Module

Abstract base class for simulation state.

States must be JAX PyTrees, which is handled automatically by inheriting from equinox.Module. This ensures compatibility with JAX transformations like jax.jit, jax.grad, and jax.vmap.

Example
import jax.numpy as jnp
import equinox as eqx

class EnvState(State):
    position: jnp.ndarray
    velocity: jnp.ndarray

state = EnvState(position=jnp.zeros(2), velocity=jnp.zeros(2))
Source code in src/researchlab/design/core.py
class State(eqx.Module):
    """Abstract base class for simulation state.

    Inheriting from `equinox.Module` makes every subclass a JAX PyTree
    automatically, so states compose with JAX transformations such as
    `jax.jit`, `jax.grad`, and `jax.vmap` without extra registration.

    Example:
        ```python
        import jax.numpy as jnp
        import equinox as eqx

        class EnvState(State):
            position: jnp.ndarray
            velocity: jnp.ndarray

        state = EnvState(position=jnp.zeros(2), velocity=jnp.zeros(2))
        ```
    """

Infrastructure

researchlab.design.infra

DataProvider

Bases: ABC

Abstract base class for providing data batches to the training loop.

Implementations should handle data loading, preprocessing, and batching. The get_batch method can optionally depend on the current State, allowing for curriculum learning or state-dependent sampling.

Source code in src/researchlab/design/infra.py
class DataProvider[S: State](ABC):
    """Abstract base class for providing data batches to the training loop.

    Implementations should handle data loading, preprocessing, and batching.
    The `get_batch` method can optionally depend on the current `State`,
    allowing for curriculum learning or state-dependent sampling.
    """

    @abstractmethod
    def __next__(self) -> Any:
        """Returns the next batch of data from the iterator.

        This method supports the iterator protocol.

        Returns:
            A batch of data (structure depends on implementation).

        Raises:
            StopIteration: If the data stream is exhausted.
        """
        ...

    @abstractmethod
    def get_batch(self, state: S) -> Any:
        """Returns a batch of data, potentially dependent on the current state.

        Args:
            state: The current simulation state.

        Returns:
            A batch of data.
        """
        ...

__next__() abstractmethod

Returns the next batch of data from the iterator.

This method supports the iterator protocol.

Returns:

Type Description
Any

A batch of data (structure depends on implementation).

Raises:

Type Description
StopIteration

If the data stream is exhausted.

Source code in src/researchlab/design/infra.py
@abstractmethod
def __next__(self) -> Any:
    """Returns the next batch of data from the iterator.

    Implements the iterator protocol, so providers can be consumed with
    `next(provider)` or a plain `for` loop.

    Returns:
        A batch of data (structure depends on implementation).

    Raises:
        StopIteration: If the data stream is exhausted.
    """
    ...

get_batch(state) abstractmethod

Returns a batch of data, potentially dependent on the current state.

Parameters:

Name Type Description Default
state S

The current simulation state.

required

Returns:

Type Description
Any

A batch of data.

Source code in src/researchlab/design/infra.py
@abstractmethod
def get_batch(self, state: S) -> Any:
    """Returns a batch of data, potentially dependent on the current state.

    Args:
        state: The current simulation state.

    Returns:
        A batch of data.
    """
    ...

EquinoxPersister

Bases: Persister[S, C]

Concrete implementation using Equinox's serialization (Safetensors).

This persister saves the State object to a .eqx file using equinox.tree_serialise_leaves. It assumes that the Config object is handled separately or reconstructible, as Equinox serialization primarily handles array data (leaves).

Example
from pathlib import Path

persister = EquinoxPersister()
# Saves only the state (config is ignored)
persister.save(state, config, step=100, path=Path("ckpt.eqx"))
Source code in src/researchlab/design/infra.py
class EquinoxPersister[S: State, C: Config](Persister[S, C]):
    """Concrete implementation using Equinox's serialization (Safetensors).

    This persister saves the `State` object to a `.eqx` file using `equinox.tree_serialise_leaves`.
    It assumes that the `Config` object is handled separately or reconstructible, as Equinox
    serialization primarily handles array data (leaves).

    Example:
        ```python
        from pathlib import Path

        persister = EquinoxPersister()
        # Saves only the state (config is ignored)
        persister.save(state, config, step=100, path=Path("ckpt.eqx"))
        ```
    """

    def save(self, state: S, config: C, step: int, path: Path) -> None:
        """Saves state to a .eqx file.

        Config is not saved by EquinoxPersister as it is usually static,
        but for completeness we could pickle it or verify if config is a PyTree.

        For now, we only serialize the State using equinox, assuming Config is handled separately
        or is reconstructible.

        Args:
            state: The state to save.
            config: The config (ignored).
            step: The step (ignored/not saved in file).
            path: The path to save the .eqx file.
        """
        # Ensure directory exists
        path.parent.mkdir(parents=True, exist_ok=True)
        eqx.tree_serialise_leaves(path, state)

    def load(self, path: Path, state_structure: S, config_structure: C) -> tuple[S, C, int]:
        """Loads state from a .eqx file.

        Note: This implementation currently only restores State.
        It returns the passed `config_structure` as is, and assumes step=0
        (since step is not intrinsically stored in the tree leaves unless added to State).

        Args:
            path: Path to the .eqx file.
            state_structure: Structure of the state to load into.
            config_structure: Structure of the config (returned as is).

        Returns:
            (loaded_state, config_structure, 0)
        """
        loaded_state = eqx.tree_deserialise_leaves(path, state_structure)
        return loaded_state, config_structure, 0

load(path, state_structure, config_structure)

Loads state from a .eqx file.

Note: This implementation currently only restores State. It returns the passed config_structure as is, and assumes step=0 (since step is not intrinsically stored in the tree leaves unless added to State).

Parameters:

Name Type Description Default
path Path

Path to the .eqx file.

required
state_structure S

Structure of the state to load into.

required
config_structure C

Structure of the config (returned as is).

required

Returns:

Type Description
tuple[S, C, int]

(loaded_state, config_structure, 0)

Source code in src/researchlab/design/infra.py
def load(self, path: Path, state_structure: S, config_structure: C) -> tuple[S, C, int]:
    """Loads state from a .eqx file.

    Note: This implementation currently only restores State.
    It returns the passed `config_structure` as is, and assumes step=0
    (since step is not intrinsically stored in the tree leaves unless added to State).

    Args:
        path: Path to the .eqx file.
        state_structure: Structure of the state to load into.
        config_structure: Structure of the config (returned as is).

    Returns:
        (loaded_state, config_structure, 0)
    """
    # Equinox fills the leaves of `state_structure` with the values stored on disk.
    loaded_state = eqx.tree_deserialise_leaves(path, state_structure)
    return loaded_state, config_structure, 0

save(state, config, step, path)

Saves state to a .eqx file.

Config is not saved by EquinoxPersister as it is usually static, but for completeness we could pickle it or verify if config is a PyTree.

For now, we only serialize the State using equinox, assuming Config is handled separately or is reconstructible.

Parameters:

Name Type Description Default
state S

The state to save.

required
config C

The config (ignored).

required
step int

The step (ignored/not saved in file).

required
path Path

The path to save the .eqx file.

required
Source code in src/researchlab/design/infra.py
def save(self, state: S, config: C, step: int, path: Path) -> None:
    """Writes the state PyTree to a .eqx file.

    Only `state` is serialised; `config` and `step` are accepted for
    interface compatibility but are not written to the file.

    Args:
        state: The state to save.
        config: The config (ignored).
        step: The step (ignored/not saved in file).
        path: The path to save the .eqx file.
    """
    # Create parent directories on demand so callers need not pre-create them.
    path.parent.mkdir(parents=True, exist_ok=True)
    eqx.tree_serialise_leaves(path, state)

MLFlowTelemetry

Bases: Telemetry[S, C]

Concrete implementation of Telemetry using MLflow.

This class logs metrics and parameters to an active MLflow run. It assumes that an MLflow run has already been started (e.g., by ExperimentTracker).

It automatically flattens nested parameter dictionaries and converts JAX scalar arrays to Python floats to ensure compatibility with MLflow.

Example
# Inside an active MLflow run
telemetry = MLFlowTelemetry()
telemetry.log_metrics({"loss": 0.5}, step=10)
Source code in src/researchlab/design/infra.py
class MLFlowTelemetry[S: State, C: Config](Telemetry[S, C]):
    """Concrete implementation of Telemetry using MLflow.

    This class logs metrics and parameters to an active MLflow run. It assumes
    that an MLflow run has already been started (e.g., by `ExperimentTracker`).

    It automatically flattens nested parameter dictionaries and converts JAX
    scalar arrays to Python floats to ensure compatibility with MLflow.

    Example:
        ```python
        # Inside an active MLflow run
        telemetry = MLFlowTelemetry()
        telemetry.log_metrics({"loss": 0.5}, step=10)
        ```
    """

    def __init__(self):
        """Initializes the MLFlowTelemetry instance."""

    def log_metrics(self, metrics: dict[str, float], step: int) -> None:
        """Logs scalar metrics to MLflow.

        Converts JAX/NumPy scalars to Python floats before logging.

        Args:
            metrics: Dictionary of metrics.
            step: Current step.
        """
        # Convert JAX/Numpy scalars to python float for MLflow
        clean_metrics = {}
        for k, v in metrics.items():
            if hasattr(v, "item"):
                try:
                    clean_metrics[k] = float(v.item())
                except (ValueError, TypeError):
                    # Fallback if item() doesn't return a float (e.g. array with >1 element)
                    # We might log it as is and let mlflow handle or error?
                    # Better to ignore or log warning? For now, try best effort.
                    clean_metrics[k] = v
            else:
                clean_metrics[k] = v
        mlflow.log_metrics(clean_metrics, step=step)

    def log_params(self, params: dict[str, Any]) -> None:
        """Logs parameters to MLflow, flattening nested dictionaries.

        Args:
            params: Dictionary of parameters.
        """
        log_flattened_params(params)

__init__()

Initializes the MLFlowTelemetry instance.

Source code in src/researchlab/design/infra.py
def __init__(self):
    """Initializes the MLFlowTelemetry instance.

    No state is stored; logging methods target the active MLflow run.
    """

log_metrics(metrics, step)

Logs scalar metrics to MLflow.

Converts JAX/NumPy scalars to Python floats before logging.

Parameters:

Name Type Description Default
metrics dict[str, float]

Dictionary of metrics.

required
step int

Current step.

required
Source code in src/researchlab/design/infra.py
def log_metrics(self, metrics: dict[str, float], step: int) -> None:
    """Logs scalar metrics to MLflow.

    JAX/NumPy scalar values are converted to plain Python floats before
    being handed to MLflow.

    Args:
        metrics: Dictionary of metrics.
        step: Current step.
    """
    cleaned: dict[str, float] = {}
    for name, value in metrics.items():
        if not hasattr(value, "item"):
            cleaned[name] = value
            continue
        try:
            cleaned[name] = float(value.item())
        except (ValueError, TypeError):
            # item() fails for arrays with more than one element; pass the
            # value through unchanged and let MLflow deal with it.
            cleaned[name] = value
    mlflow.log_metrics(cleaned, step=step)

log_params(params)

Logs parameters to MLflow, flattening nested dictionaries.

Parameters:

Name Type Description Default
params dict[str, Any]

Dictionary of parameters.

required
Source code in src/researchlab/design/infra.py
def log_params(self, params: dict[str, Any]) -> None:
    """Logs parameters to MLflow, flattening nested dictionaries first.

    Args:
        params: Dictionary of parameters; nested dicts are supported.
    """
    # Flattening and the actual mlflow call happen in the shared helper.
    log_flattened_params(params)

Persister

Bases: ABC

Abstract base class for saving and loading checkpoints.

Persisters handle the serialization and deserialization of the simulation state and configuration.

Source code in src/researchlab/design/infra.py
class Persister[S: State, C: Config](ABC):
    """Abstract base class for saving and loading checkpoints.

    Persisters handle the serialization and deserialization of the simulation
    state and configuration.
    """

    @abstractmethod
    def save(self, state: S, config: C, step: int, path: Path) -> None:
        """Saves the current state and config to the specified path.

        Args:
            state: The current simulation state.
            config: The experiment configuration.
            step: The current step count.
            path: The file path or directory to save to.
        """
        ...

    @abstractmethod
    def load(self, path: Path, state_structure: S, config_structure: C) -> tuple[S, C, int]:
        """Loads state and config from a checkpoint.

        Args:
            path: Path to the checkpoint file/directory.
            state_structure: A structure (PyTree) matching the state to load.
            config_structure: A structure matching the config to load.

        Returns:
            A tuple containing `(loaded_state, loaded_config, step)`.
        """
        ...

load(path, state_structure, config_structure) abstractmethod

Loads state and config from a checkpoint.

Parameters:

Name Type Description Default
path Path

Path to the checkpoint file/directory.

required
state_structure S

A structure (PyTree) matching the state to load.

required
config_structure C

A structure matching the config to load.

required

Returns:

Type Description
tuple[S, C, int]

A tuple containing (loaded_state, loaded_config, step).

Source code in src/researchlab/design/infra.py
@abstractmethod
def load(self, path: Path, state_structure: S, config_structure: C) -> tuple[S, C, int]:
    """Restores state and config from a checkpoint.

    Args:
        path: Path to the checkpoint file/directory.
        state_structure: A PyTree template matching the state to load.
        config_structure: A template matching the config to load.

    Returns:
        A tuple `(loaded_state, loaded_config, step)`.
    """
    ...

save(state, config, step, path) abstractmethod

Saves the current state and config to the specified path.

Parameters:

Name Type Description Default
state S

The current simulation state.

required
config C

The experiment configuration.

required
step int

The current step count.

required
path Path

The file path or directory to save to.

required
Source code in src/researchlab/design/infra.py
@abstractmethod
def save(self, state: S, config: C, step: int, path: Path) -> None:
    """Persists the current state and config to the specified path.

    Args:
        state: The current simulation state.
        config: The experiment configuration.
        step: The current step count.
        path: Destination file path or directory.
    """
    ...

Telemetry

Bases: ABC

Abstract base class for logging metrics and parameters.

Telemetry components handle the reporting of experiment results. They should be non-intrusive and only observe the state/metrics.

Source code in src/researchlab/design/infra.py
class Telemetry[S: State, C: Config](ABC):
    """Abstract base class for logging metrics and parameters.

    Telemetry components handle the reporting of experiment results. They should
    be non-intrusive and only observe the state/metrics.
    """

    @abstractmethod
    def log_metrics(self, metrics: dict[str, float], step: int) -> None:
        """Logs scalar metrics for a given step.

        Args:
            metrics: A dictionary mapping metric names to values.
            step: The current global step or iteration.
        """
        ...

    @abstractmethod
    def log_params(self, params: dict[str, Any]) -> None:
        """Logs hyperparameters or configuration settings.

        Args:
            params: A dictionary of parameters to log. Nested dictionaries are supported.
        """
        ...

log_metrics(metrics, step) abstractmethod

Logs scalar metrics for a given step.

Parameters:

Name Type Description Default
metrics dict[str, float]

A dictionary mapping metric names to values.

required
step int

The current global step or iteration.

required
Source code in src/researchlab/design/infra.py
@abstractmethod
def log_metrics(self, metrics: dict[str, float], step: int) -> None:
    """Records scalar metrics for one step.

    Args:
        metrics: A dictionary mapping metric names to values.
        step: The current global step or iteration.
    """
    ...

log_params(params) abstractmethod

Logs hyperparameters or configuration settings.

Parameters:

Name Type Description Default
params dict[str, Any]

A dictionary of parameters to log. Nested dictionaries are supported.

required
Source code in src/researchlab/design/infra.py
@abstractmethod
def log_params(self, params: dict[str, Any]) -> None:
    """Records hyperparameters or configuration settings.

    Args:
        params: A dictionary of parameters to log. Nested dictionaries are supported.
    """
    ...

Visualizer

Bases: ABC

Abstract base class for rendering and recording simulation states.

This abstraction supports three main functionalities:

1. Generating frames from states via render(state).
2. Displaying the last generated frame via show().
3. Recording/streaming frames via the _record() hook (if is_recording=True).

Source code in src/researchlab/design/infra.py
class Visualizer[S: State](ABC):
    """Abstract base class for rendering and recording simulation states.

    This abstraction supports three main functionalities:
    1. Generating frames from states via `render(state)`.
    2. Displaying the last generated frame via `show()`.
    3. Recording/Streaming frames via the `_record()` hook (if `is_recording=True`).
    """

    def __init__(self, record: bool = False, fps: None | int = None):
        """Initializes the visualizer.

        Args:
            record: If True, enable recording mode to trigger the `_record()` hook.
            fps: Optional frames per second for recording and showing.
        """
        self.is_recording = record
        self._fps = fps
        self._frame_interval = 1.0 / fps if fps is not None else None
        self._last_frame: Any | None = None
        self._last_show_time: float = time.time()  # For FPS control in show()

    def render(self, state: S) -> Any:
        """Generate a frame and trigger the `_record()` hook if recording is enabled.

        Args:
            state: The state to render.

        Returns:
            The rendered frame (e.g., image array).
        """
        frame = self._render_frame(state)
        self._last_frame = frame
        if self.is_recording:
            self._record()
        return frame

    @abstractmethod
    def _render_frame(self, state: S) -> Any:
        """Implementation-specific frame generation logic."""
        ...

    @abstractmethod
    def _record(self) -> None:
        """Hook called at the end of `render` if `is_recording` is True.

        Subclasses should implement this to stream the frame to a video file,
        buffer it selectively, or perform other persistence tasks. The last
        rendered frame is available in `self._last_frame`.
        """

    def show(self) -> None:
        """Display the last rendered frame to the user."""
        if self._last_frame is not None:
            self._show_frame(self._last_frame)

    @abstractmethod
    def _show_frame(self, frame: Any) -> None:
        """Implementation-specific display logic.

        `self._frame_interval` and `self._last_show_time` can
        be used to control the display rate if `fps` was set.

        Args:
            frame: The frame to display (e.g., image array).
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """Clean up any resources used by the visualizer (e.g., close windows, release video writers)."""
        ...

    def __del__(self):
        """Ensure resources are cleaned up when the visualizer is garbage collected."""
        self.close()

__del__()

Ensure resources are cleaned up when the visualizer is garbage collected.

Source code in src/researchlab/design/infra.py
def __del__(self):
    """Best-effort cleanup when the visualizer is garbage collected.

    Exceptions are suppressed because `__del__` may run during interpreter
    shutdown or on a partially-initialized instance, where `close()` can fail.
    """
    try:
        self.close()
    except Exception:
        pass

__init__(record=False, fps=None)

Initializes the visualizer.

Parameters:

Name Type Description Default
record bool

If True, enable recording mode to trigger the _record() hook.

False
fps None | int

Optional frames per second for recording and showing.

None
Source code in src/researchlab/design/infra.py
def __init__(self, record: bool = False, fps: None | int = None):
    """Initializes the visualizer.

    Args:
        record: If True, enable recording mode to trigger the `_record()` hook.
        fps: Optional frames per second for recording and showing.
    """
    self.is_recording = record
    self._fps = fps
    # None means "no rate limit"; otherwise the minimum seconds between frames.
    self._frame_interval = None if fps is None else 1.0 / fps
    self._last_frame: Any | None = None
    self._last_show_time: float = time.time()  # For FPS control in show()

close() abstractmethod

Clean up any resources used by the visualizer (e.g., close windows, release video writers).

Source code in src/researchlab/design/infra.py
@abstractmethod
def close(self) -> None:
    """Release all resources held by the visualizer (windows, video writers, ...)."""
    ...

render(state)

Generate a frame and trigger the _record() hook if recording is enabled.

Parameters:

Name Type Description Default
state S

The state to render.

required

Returns:

Type Description
Any

The rendered frame (e.g., image array).

Source code in src/researchlab/design/infra.py
def render(self, state: S) -> Any:
    """Produce a frame for `state`, firing the `_record()` hook when recording.

    Args:
        state: The state to render.

    Returns:
        The rendered frame (e.g., image array).
    """
    frame = self._render_frame(state)
    # Cache the frame for a later show() call and for _record() implementations.
    self._last_frame = frame
    if self.is_recording:
        self._record()
    return frame

show()

Display the last rendered frame to the user.

Source code in src/researchlab/design/infra.py
def show(self) -> None:
    """Display the most recently rendered frame, if any."""
    frame = self._last_frame
    if frame is not None:
        self._show_frame(frame)

Orchestrator

researchlab.design.orchestrator

Loop

A generic training/simulation loop orchestrator.

The Loop class coordinates the interaction between the core components (State, Config, Kernel) and the infrastructure components (Data, Telemetry, Persistence, Visualization). It manages the execution flow, step counting, and periodic tasks.

Example
# Assuming components are defined
loop = Loop(
    config=config,
    initial_state=state,
    step_fn=step_fn,
    data_provider=provider,
    telemetry=telemetry
)
loop.run(num_steps=100)
Source code in src/researchlab/design/orchestrator.py
class Loop[S: State, C: Config]:
    """A generic training/simulation loop orchestrator.

    The `Loop` class coordinates the interaction between the core components
    (State, Config, Kernel) and the infrastructure components (Data, Telemetry,
    Persistence, Visualization). It manages the execution flow, step counting,
    and periodic tasks.

    Example:
        ```python
        # Assuming components are defined
        loop = Loop(
            config=config,
            initial_state=state,
            step_fn=step_fn,
            data_provider=provider,
            telemetry=telemetry
        )
        loop.run(num_steps=100)
        ```
    """

    def __init__(
        self,
        config: C,
        initial_state: S,
        step_fn: Kernel[tuple[S, C, Any], tuple[S, dict[str, Any]]],
        data_provider: DataProvider[S],
        telemetry: Telemetry[S, C] | None = None,
        persister: Persister[S, C] | None = None,
        visualizer: Visualizer[S] | None = None,
    ):
        """Initializes the loop.

        Args:
            config: Immutable hyperparameters.
            initial_state: Initial simulation state.
            step_fn: A pure kernel function that takes (state, config, batch) and returns (new_state, metrics).
            data_provider: Source of data.
            telemetry: Optional logger.
            persister: Optional checkpointer.
            visualizer: Optional renderer.
        """
        self.config = config
        self.state = initial_state
        self.step_fn = step_fn
        self.data_provider = data_provider
        self.telemetry = telemetry
        self.persister = persister
        self.visualizer = visualizer
        self.step = 0

    def run(self, num_steps: int):
        """Run the loop for a specified number of steps."""
        for _ in range(num_steps):
            self.step += 1

            # 1. Get Data
            try:
                batch = self.data_provider.get_batch(self.state)
            except NotImplementedError:
                batch = next(self.data_provider) # Fallback to iterator protocol if implemented
            except StopIteration:
                break

            # 2. Step (Pure Kernel)
            self.state, metrics = self.step_fn(self.state, self.config, batch)

            # 3. Telemetry
            if self.telemetry:
                self.telemetry.log_metrics(metrics, self.step)

            # 4. Persistence (Example: every 1000 steps or similar logic could be added)
            # For simplicity, we expose a manual save method or rely on the user to call it.
            # Here we just show how it *could* be used.

            # 5. Visualization (Example: every N steps)
            if self.visualizer:
                # self.visualizer.render(self.state)
                # self.visualizer.show()
                pass

    def save_checkpoint(self, path: Any): # using Any for path to avoid circular imports if Path is needed
        if self.persister:
            self.persister.save(self.state, self.config, self.step, path)

    def load_checkpoint(self, path: Any):
        if self.persister:
            self.state, self.config, self.step = self.persister.load(path, self.state, self.config)

__init__(config, initial_state, step_fn, data_provider, telemetry=None, persister=None, visualizer=None)

Initializes the loop.

Parameters:

Name Type Description Default
config C

Immutable hyperparameters.

required
initial_state S

Initial simulation state.

required
step_fn Kernel[tuple[S, C, Any], tuple[S, dict[str, Any]]]

A pure kernel function that takes (state, config, batch) and returns (new_state, metrics).

required
data_provider DataProvider[S]

Source of data.

required
telemetry Telemetry[S, C] | None

Optional logger.

None
persister Persister[S, C] | None

Optional checkpointer.

None
visualizer Visualizer[S] | None

Optional renderer.

None
Source code in src/researchlab/design/orchestrator.py
def __init__(
    self,
    config: C,
    initial_state: S,
    step_fn: Kernel[tuple[S, C, Any], tuple[S, dict[str, Any]]],
    data_provider: DataProvider[S],
    telemetry: Telemetry[S, C] | None = None,
    persister: Persister[S, C] | None = None,
    visualizer: Visualizer[S] | None = None,
):
    """Initializes the loop.

    Args:
        config: Immutable hyperparameters.
        initial_state: Initial simulation state.
        step_fn: A pure kernel function that takes (state, config, batch) and returns (new_state, metrics).
        data_provider: Source of data.
        telemetry: Optional logger.
        persister: Optional checkpointer.
        visualizer: Optional renderer.
    """
    self.config = config
    self.state = initial_state
    self.step_fn = step_fn
    self.data_provider = data_provider
    # Infrastructure hooks are optional; run() checks each for None before use.
    self.telemetry = telemetry
    self.persister = persister
    self.visualizer = visualizer
    # Global step counter, incremented once per loop iteration in run().
    self.step = 0

run(num_steps)

Run the loop for a specified number of steps.

Source code in src/researchlab/design/orchestrator.py
def run(self, num_steps: int):
    """Run the loop for at most `num_steps` steps.

    Stops early when the data provider is exhausted (StopIteration).

    Args:
        num_steps: Maximum number of steps to execute.
    """
    for _ in range(num_steps):
        self.step += 1

        # 1. Get Data
        try:
            try:
                batch = self.data_provider.get_batch(self.state)
            except NotImplementedError:
                # Fallback to iterator protocol if implemented
                batch = next(self.data_provider)
        except StopIteration:
            # BUGFIX: the outer try also covers StopIteration raised by the
            # next() fallback, which a single-level try/except did not catch
            # (exceptions raised inside a handler skip sibling clauses).
            break

        # 2. Step (Pure Kernel)
        self.state, metrics = self.step_fn(self.state, self.config, batch)

        # 3. Telemetry
        if self.telemetry:
            self.telemetry.log_metrics(metrics, self.step)

        # 4. Persistence (Example: every 1000 steps or similar logic could be added)
        # For simplicity, we expose a manual save method or rely on the user to call it.

        # 5. Visualization (Example: every N steps)
        if self.visualizer:
            # self.visualizer.render(self.state)
            # self.visualizer.show()
            pass

Utilities

researchlab.design.utils

flatten_config(config, prefix='', separator='.')

Flattens a Config object (Pydantic model) into a dictionary.

It uses model_dump() to get the configuration as a dictionary and then flattens any nested dictionaries.

Parameters:

Name Type Description Default
config Config

The Config instance to flatten.

required
prefix str

An optional string to prepend to keys.

''
separator str

The separator for nested keys. Defaults to ".".

'.'

Returns:

Type Description
dict[str, Any]

A flattened dictionary representation of the configuration.

Example
class MyConfig(Config):
    nested: dict = {"a": 1}


config = MyConfig()
print(flatten_config(config))
# {'nested.a': 1}
Source code in src/researchlab/design/utils.py
def flatten_config(config: Config, prefix: str = "", separator: str = ".") -> dict[str, Any]:
    """Flattens a Config object (Pydantic model) into a dictionary.

    The model is first dumped to a (possibly nested) dict via `model_dump()`;
    nested dictionaries are then collapsed into separator-joined keys.

    Args:
        config: The `Config` instance to flatten.
        prefix: An optional string to prepend to keys.
        separator: The separator for nested keys. Defaults to ".".

    Returns:
        A flattened dictionary representation of the configuration.

    Example:
        ```python
        class MyConfig(Config):
            nested: dict = {"a": 1}


        config = MyConfig()
        print(flatten_config(config))
        # {'nested.a': 1}
        ```
    """
    flat: dict[str, Any] = {}

    def _walk(node, key):
        # Non-dict values are leaves; dicts recurse one level deeper.
        if isinstance(node, dict):
            for name, child in node.items():
                _walk(child, f"{key}{separator}{name}" if key else name)
        else:
            flat[key] = node

    _walk(config.model_dump(), prefix)
    return flat

flatten_pytree(tree, prefix='', separator='.')

Flattens a JAX PyTree into a dictionary with dot-notation keys.

This utility traverses the PyTree structure and produces a flat dictionary where keys represent the path to each leaf node using dot notation. This is particularly useful for logging complex nested State objects to flat metric logging systems like MLflow.

Parameters:

Name Type Description Default
tree Any

The PyTree object to flatten (e.g., State, dict, list).

required
prefix str

An optional string to prepend to all generated keys.

''
separator str

The string used to separate path components. Defaults to ".".

'.'

Returns:

Type Description
dict[str, Any]

A dictionary mapping flattened string keys to the leaf values of the tree.

Example
tree = {"x": 10, "sub": {"y": 20}}
print(flatten_pytree(tree))
# {'x': 10, 'sub.y': 20}
Source code in src/researchlab/design/utils.py
def flatten_pytree(tree: Any, prefix: str = "", separator: str = ".") -> dict[str, Any]:
    """Flattens a JAX PyTree into a dictionary with dot-notation keys.

    Each leaf's key is the separator-joined path from the root of the tree,
    which makes nested `State` objects easy to feed into flat metric logging
    backends such as MLflow.

    Args:
        tree: The PyTree object to flatten (e.g., `State`, dict, list).
        prefix: An optional string to prepend to all generated keys.
        separator: The string used to separate path components. Defaults to ".".

    Returns:
        A dictionary mapping flattened string keys to the leaf values of the tree.

    Example:
        ```python
        tree = {"x": 10, "sub": {"y": 20}}
        print(flatten_pytree(tree))
        # {'x': 10, 'sub.y': 20}
        ```
    """

    def _entry_name(entry) -> str:
        # jax path entries are GetAttrKey, DictKey, SequenceKey, etc.
        if isinstance(entry, jtu.GetAttrKey):
            return entry.name
        if isinstance(entry, jtu.DictKey):
            return str(entry.key)
        if isinstance(entry, jtu.SequenceKey):
            return str(entry.idx)
        return str(entry)  # Fallback for unknown key types

    flat_dict: dict[str, Any] = {}
    # Default leaf handling: arrays are leaves, lists/tuples/dicts are nodes.
    for path, leaf in jtu.tree_leaves_with_path(tree):
        key = separator.join(_entry_name(entry) for entry in path)
        if prefix:
            key = f"{prefix}{separator}{key}"
        flat_dict[key] = leaf
    return flat_dict

unflatten_config(flat_dict, config_cls, separator='.')

Reconstructs a Config object from a flattened dictionary.

This function unflattens a dictionary into a nested structure and then validates it against the Pydantic model config_cls.

Parameters:

Name Type Description Default
flat_dict dict[str, Any]

The flattened dictionary containing configuration values.

required
config_cls type[C]

The class of the Config object to reconstruct.

required
separator str

The separator used in the flattened keys. Defaults to ".".

'.'

Returns:

Type Description
C

An instance of config_cls populated with values from flat_dict.

Raises:

Type Description
ValueError

If there is a key conflict (e.g., 'a' is both a value and a container).

ValidationError

If the reconstructed data fails Pydantic validation.

Example
flat = {"nested.a": 1}
print(unflatten_config(flat, MyConfig))
# MyConfig(nested={'a': 1})
Source code in src/researchlab/design/utils.py
def unflatten_config[C: Config](
    flat_dict: dict[str, Any], config_cls: type[C], separator: str = "."
) -> C:
    """Reconstructs a Config object from a flattened dictionary.

    This function unflattens a dictionary into a nested structure and then
    validates it against the Pydantic model `config_cls`.

    Args:
        flat_dict: The flattened dictionary containing configuration values.
        config_cls: The class of the `Config` object to reconstruct.
        separator: The separator used in the flattened keys. Defaults to ".".

    Returns:
        An instance of `config_cls` populated with values from `flat_dict`.

    Raises:
        ValueError: If there is a key conflict (e.g., 'a' is both a value and a container).
        ValidationError: If the reconstructed data fails Pydantic validation.

    Example:
        ```python
        flat = {"nested.a": 1}
        print(unflatten_config(flat, MyConfig))
        # MyConfig(nested={'a': 1})
        ```
    """
    # Unflatten dict
    nested = {}
    for key, value in flat_dict.items():
        parts = key.split(separator)
        curr = nested
        for _, part in enumerate(parts[:-1]):
            if part not in curr:
                curr[part] = {}
            curr = curr[part]
            if not isinstance(curr, dict):
                # Conflict: trying to use a value as a dict container
                # This might happen if keys are ambiguous e.g. "a" and "a.b"
                raise ValueError(f"Key conflict at '{part}' in '{key}'")

        curr[parts[-1]] = value

    return config_cls.model_validate(nested)

unflatten_pytree(flat_dict, structure, separator='.')

Populates a PyTree structure with values from a flattened dictionary.

This function attempts to reconstruct the values of a PyTree by looking up keys in a flattened dictionary that correspond to the paths in the structure. Leaves whose paths are not found in the dictionary are kept as-is from the structure.

Parameters:

Name Type Description Default
flat_dict dict[str, Any]

A dictionary containing flattened keys and values.

required
structure T

A PyTree instance defining the target structure.

required
separator str

The separator used in the flattened keys. Defaults to ".".

'.'

Returns:

Type Description
T

A new PyTree with the same structure as structure, but with leaves updated from flat_dict.

Example
flat = {"x": 100, "sub.y": 200}
structure = {"x": 0, "sub": {"y": 0}}
print(unflatten_pytree(flat, structure))
# {'x': 100, 'sub': {'y': 200}}
Source code in src/researchlab/design/utils.py
def unflatten_pytree[T](flat_dict: dict[str, Any], structure: T, separator: str = ".") -> T:
    """Populates a PyTree structure with values from a flattened dictionary.

    This function attempts to reconstruct the values of a PyTree by looking up
    keys in a flattened dictionary that correspond to the paths in the `structure`.
    Leaves not found in the dictionary are kept as is from the structure.

    Args:
        flat_dict: A dictionary containing flattened keys and values.
        structure: A PyTree instance defining the target structure.
        separator: The separator used in the flattened keys. Defaults to ".".

    Returns:
        A new PyTree with the same structure as `structure`, but with leaves
        updated from `flat_dict`.

    Example:
        ```python
        flat = {"x": 100, "sub.y": 200}
        structure = {"x": 0, "sub": {"y": 0}}
        print(unflatten_pytree(flat, structure))
        # {'x': 100, 'sub': {'y': 200}}
        ```
    """
    # We traverse the structure and look up values in flat_dict using the generated path key.

    def leaf_transform(path, leaf):
        key_parts = []
        for p in path:
            if isinstance(p, jtu.GetAttrKey):
                key_parts.append(p.name)
            elif isinstance(p, jtu.DictKey):
                key_parts.append(str(p.key))
            elif isinstance(p, jtu.SequenceKey):
                key_parts.append(str(p.idx))
            else:
                key_parts.append(str(p))

        key = separator.join(key_parts)

        if key in flat_dict:
            return flat_dict[key]
        return leaf  # Keep original leaf if not in dict (or raise error?)
        # For strict reconstruction, maybe we should warn?
        # But keeping original (default/placeholder) value is often desired for partial updates.

    return jtu.tree_map_with_path(leaf_transform, structure)