API Reference
Tracking Module
ExperimentTracker
researchlab.tracking.tracker.ExperimentTracker
A context manager for tracking research experiments with Git state.
This class automates the process of starting an MLflow run, capturing the current Git state (base commit and diff), and logging artifacts and parameters.
Example
Attributes:
| Name | Type | Description |
|---|---|---|
| run_name | str | The unique, readable ID for this run. |
| active_run | Optional[Run] | The currently active MLflow run. |
Source code in src/researchlab/tracking/tracker.py
__enter__()
Starts the MLflow run and captures the Git state.
Returns:
| Name | Type | Description |
|---|---|---|
| ExperimentTracker | ExperimentTracker | The instance of the tracker. |
Source code in src/researchlab/tracking/tracker.py
__exit__(exc_type, exc_val, exc_tb)
__init__(experiment_name, run_name=None, tracking_uri=None)
Initializes the ExperimentTracker.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| experiment_name | str | The name of the MLflow experiment. | required |
| run_name | str \| None | Optional custom name for the run. If not provided, a readable ID is generated automatically. | None |
| tracking_uri | str \| None | Optional MLflow tracking URI to use. | None |
Source code in src/researchlab/tracking/tracker.py
log_config(config_path)
Logs a YAML configuration file as parameters and an artifact.
Recursively parses the YAML file and logs each entry as an MLflow parameter. Also uploads the file itself as an artifact.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config_path | str | Path to the YAML configuration file. | required |
Source code in src/researchlab/tracking/tracker.py
Utilities
researchlab.tracking.utils
find_run_by_rlab_id(run_id)
Finds an MLflow run by the custom rlab.run_id tag.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| run_id | str | The readable rlab run ID to search for. | required |

Returns:
| Type | Description |
|---|---|
| Run \| None | The MLflow Run object if found, else None. |
Source code in src/researchlab/tracking/utils.py
generate_run_id()
Generates a unique and readable run ID.
Returns:
| Name | Type | Description |
|---|---|---|
| str | str | A string in the format YYYY-MM-DD_slug (e.g., 2026-02-15_radiant-octopus). |
Source code in src/researchlab/tracking/utils.py
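As an illustration of the YYYY-MM-DD_slug format, here is a minimal sketch using only the standard library. The function name make_run_id and the word lists are hypothetical, not the library's vocabulary. Word lists this small cannot guarantee uniqueness, so treat it as a demonstration of the format only.

```python
import datetime
import random

# Hypothetical word lists; the vocabulary used by researchlab is not documented here.
ADJECTIVES = ["radiant", "quiet", "amber", "brisk"]
NOUNS = ["octopus", "falcon", "maple", "comet"]


def make_run_id(today=None, rng=None):
    """Return an ID shaped like YYYY-MM-DD_adjective-noun."""
    today = today or datetime.date.today()
    rng = rng or random.Random()
    slug = f"{rng.choice(ADJECTIVES)}-{rng.choice(NOUNS)}"
    # date.isoformat() yields YYYY-MM-DD
    return f"{today.isoformat()}_{slug}"


print(make_run_id(datetime.date(2026, 2, 15)))
```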
get_git_state(repo_path='.')
Captures the current Git state including base commit and patch.
This captures staged, unstaged, and untracked changes by temporarily using git's "intent-to-add" feature. It restores the untracked state after generating the patch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| repo_path | str | Path to the git repository. Defaults to ".". | '.' |

Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing: base_commit (str), the hex SHA of the current HEAD; patch (str), the git diff output; is_dirty (bool), True if there are any changes (including untracked). |
Source code in src/researchlab/tracking/utils.py
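The intent-to-add approach described above can be sketched with plain git commands. This is a sketch under assumptions, not the library's code: capture_git_state and _git are hypothetical names, and the exact git invocations used by get_git_state may differ.

```python
import subprocess


def _git(repo_path, *args):
    """Run a git command inside repo_path and return its stdout."""
    result = subprocess.run(
        ["git", "-C", repo_path, *args],
        check=True, capture_output=True, text=True,
    )
    return result.stdout


def capture_git_state(repo_path="."):
    base_commit = _git(repo_path, "rev-parse", "HEAD").strip()
    # List untracked files, honoring .gitignore (NUL-separated for safety).
    listing = _git(repo_path, "ls-files", "--others", "--exclude-standard", "-z")
    untracked = [p for p in listing.split("\0") if p]
    if untracked:
        # Register the files without staging their content so the diff sees them.
        _git(repo_path, "add", "--intent-to-add", "--", *untracked)
    try:
        # Diff against HEAD covers staged, unstaged, and intent-to-add changes.
        patch = _git(repo_path, "diff", "HEAD")
    finally:
        if untracked:
            # Drop the intent-to-add entries so the files are untracked again.
            _git(repo_path, "reset", "--quiet", "--", *untracked)
    return {"base_commit": base_commit, "patch": patch, "is_dirty": bool(patch)}
```

The try/finally block matters here: if the diff fails, the index is still restored to the state it had before the call.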
log_flattened_params(d, prefix='')
Recursively logs a dictionary as flattened MLflow parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| d | dict[str, Any] | The dictionary to log. | required |
| prefix | str | Optional prefix for parameter keys (used for recursion). | '' |
Source code in src/researchlab/tracking/utils.py
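The recursive flattening can be illustrated without MLflow by collecting the flattened keys into a dictionary instead of logging them. In this sketch, flatten_params is a hypothetical helper, and the "." separator between nested keys is an assumption.

```python
from typing import Any


def flatten_params(d: dict[str, Any], prefix: str = "") -> dict[str, Any]:
    """Collect nested entries under dot-joined keys (separator is assumed)."""
    flat: dict[str, Any] = {}
    for key, value in d.items():
        full_key = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse, carrying the accumulated key as the new prefix.
            flat.update(flatten_params(value, full_key))
        else:
            flat[full_key] = value
    return flat


config = {"model": {"layers": 4, "optimizer": {"lr": 0.001}}, "seed": 7}
print(flatten_params(config))
# {'model.layers': 4, 'model.optimizer.lr': 0.001, 'seed': 7}
```

Each resulting key/value pair corresponds to one parameter that would be logged to the active run.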
Design Module
Core
researchlab.design.core
Config
Bases: BaseModel
Base class for immutable hyperparameters.
This class serves as a foundation for configuration objects. It uses Pydantic for schema validation and serialization. Configurations are immutable (frozen) to ensure reproducibility.
Example
Source code in src/researchlab/design/core.py
FieldSelector
Bases: Selector[S, C]
A simplified Selector using dot-notation strings for mapping.
This selector allows you to define a mapping from kernel argument names to
dot-notation paths within the State or Config objects. It supports
nested attributes (e.g., state.sub.val) and dictionary/sequence access
(e.g., state.dict.key, state.list.0).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| **mappings | str | Keyword arguments where keys are the kernel argument names and values are dot-notation strings starting with 'state.' or 'config.'. | {} |
Example
Source code in src/researchlab/design/core.py
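To make the dot-notation mapping concrete, the sketch below resolves paths against plain Python objects. The helper resolve_path is hypothetical and stands in for whatever FieldSelector does internally. SimpleNamespace objects replace real State and Config classes.

```python
from types import SimpleNamespace
from typing import Any


def resolve_path(path: str, state: Any, config: Any) -> Any:
    """Follow a dot-notation path rooted at 'state' or 'config'."""
    root, *parts = path.split(".")
    if root not in ("state", "config"):
        raise ValueError(f"path must start with 'state.' or 'config.': {path!r}")
    current = state if root == "state" else config
    for part in parts:
        if isinstance(current, dict):
            current = current[part]        # dictionary key access
        elif isinstance(current, (list, tuple)):
            current = current[int(part)]   # sequence index access
        else:
            current = getattr(current, part)  # attribute access
    return current


state = SimpleNamespace(sub=SimpleNamespace(val=3), table={"key": "x"}, items=[10, 20])
config = SimpleNamespace(factor=2)
mappings = {"val": "state.sub.val", "second": "state.items.1", "factor": "config.factor"}
kwargs = {name: resolve_path(path, state, config) for name, path in mappings.items()}
print(kwargs)
# {'val': 3, 'second': 20, 'factor': 2}
```

The resulting kwargs dictionary is what would be passed to the kernel as keyword arguments.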
__init__(**mappings)
Initializes the FieldSelector with path mappings.
Source code in src/researchlab/design/core.py
Kernel
Bases: Protocol
Protocol for pure functions representing the core logic.
Kernels are pure functions that take specific inputs (defined by P) and
return a result (defined by R). They should be free of side effects and
depend only on their inputs, making them easy to test and transform.
Source code in src/researchlab/design/core.py
SelectedKernel
Bases: Module
The wrapped kernel returned by a Selector.
This class wraps a pure kernel function and an extractor. When called with
State and Config objects, it uses the extractor to derive the arguments
for the kernel and then executes the kernel.
Attributes:
| Name | Type | Description |
|---|---|---|
| raw | Callable[..., R] | The original pure kernel function. |
Source code in src/researchlab/design/core.py
raw
property
Access the original pure function.
Returns:
| Type | Description |
|---|---|
| Callable[..., R] | The underlying kernel function before it was wrapped. |
__call__(state, config)
Executes the kernel using data extracted from state and config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S | The current simulation state. | required |
| config | C | The experiment configuration. | required |

Returns:
| Type | Description |
|---|---|
| R | The result of the kernel function. |
Source code in src/researchlab/design/core.py
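The wrap-and-extract behavior can be sketched in plain Python. SelectedKernelSketch is a hypothetical stand-in: the real SelectedKernel is documented as a Module subclass, and this version ignores that and uses dictionaries in place of State and Config.

```python
class SelectedKernelSketch:
    """Plain-Python stand-in for a wrapped kernel (not the library class)."""

    def __init__(self, func, extractor):
        self._func = func
        self._extractor = extractor

    @property
    def raw(self):
        # Expose the unwrapped function, e.g. for direct unit testing.
        return self._func

    def __call__(self, state, config):
        # The extractor returns (args, kwargs) for the kernel.
        args, kwargs = self._extractor(state, config)
        return self._func(*args, **kwargs)


def scale(val, factor):
    return val * factor


def extractor(state, config):
    return (state["val"],), {"factor": config["factor"]}


kernel = SelectedKernelSketch(scale, extractor)
print(kernel({"val": 10}, {"factor": 2}))  # 20
print(kernel.raw(3, 4))                    # 12
```

Keeping the raw function reachable means the kernel can be tested with plain arguments, without building State or Config objects.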
Selector
A decorator/higher-order function to bind State and Config to Kernel arguments.
A Selector defines how to extract arguments for a kernel from the broader
State and Config context. It decouples the kernel's signature from the
application's data structure.
Parameters:
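| Name | Type | Description | Default |
|---|---|---|---|
| extractor | Callable[[S, C], tuple[tuple[Any, ...], dict[str, Any]]] | A function that takes the state and config and returns a tuple of positional arguments and a dictionary of keyword arguments for the kernel. | required |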
Example
```python
from researchlab.design.core import Config, Selector, State


class MyState(State):
    val: int


class MyConfig(Config):
    factor: int


# Define how to extract arguments: return (args, kwargs) for the kernel
def my_extractor(state, config):
    return (state.val, config.factor), {}


# Wrap the kernel
@Selector(my_extractor)
def compute(val, factor):
    return val * factor


state = MyState(val=10)
config = MyConfig(factor=2)
print(compute(state, config))
# 20
```
Source code in src/researchlab/design/core.py
__call__(func)
Wraps a kernel function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., R] | The pure kernel function to wrap. | required |

Returns:
| Type | Description |
|---|---|
| SelectedKernel[S, C, R] | A SelectedKernel that wraps func. |
Source code in src/researchlab/design/core.py
__init__(extractor)
State
Bases: Module
Abstract base class for simulation state.
States must be JAX PyTrees, which is handled automatically by inheriting
from equinox.Module. This ensures compatibility with JAX transformations
like jax.jit, jax.grad, and jax.vmap.
Example
Source code in src/researchlab/design/core.py
Infrastructure
researchlab.design.infra
DataProvider
Bases: ABC
Abstract base class for providing data batches to the training loop.
Implementations should handle data loading, preprocessing, and batching.
The get_batch method can optionally depend on the current State,
allowing for curriculum learning or state-dependent sampling.
Source code in src/researchlab/design/infra.py
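A state-dependent provider might look like the sketch below. The base class DataProviderSketch is a hypothetical stand-in that mirrors only the two documented abstract methods. The __iter__ method, the CurriculumProvider class, and the state["step"] key are assumptions made for illustration.

```python
from abc import ABC, abstractmethod
from typing import Any


class DataProviderSketch(ABC):
    """Stand-in mirroring the documented abstract methods."""

    def __iter__(self):
        # Assumed here so the provider works in a for-loop.
        return self

    @abstractmethod
    def __next__(self) -> Any: ...

    @abstractmethod
    def get_batch(self, state: Any) -> Any: ...


class CurriculumProvider(DataProviderSketch):
    """Serves larger batches as training progresses."""

    def __init__(self, data, max_batches):
        self._data = data
        self._served = 0
        self._max_batches = max_batches

    def __next__(self):
        if self._served >= self._max_batches:
            raise StopIteration  # data stream exhausted
        self._served += 1
        return self._data[:2]

    def get_batch(self, state):
        # Grow the batch by one item every 10 steps, capped at the data size.
        size = min(len(self._data), 1 + state["step"] // 10)
        return self._data[:size]


provider = CurriculumProvider([1, 2, 3, 4], max_batches=3)
print(provider.get_batch({"step": 0}))   # [1]
print(provider.get_batch({"step": 25}))  # [1, 2, 3]
```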
__next__()
abstractmethod
Returns the next batch of data from the iterator.
This method supports the iterator protocol.
Returns:
| Type | Description |
|---|---|
| Any | A batch of data (structure depends on implementation). |

Raises:
| Type | Description |
|---|---|
| StopIteration | If the data stream is exhausted. |
Source code in src/researchlab/design/infra.py
get_batch(state)
abstractmethod
Returns a batch of data, potentially dependent on the current state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S | The current simulation state. | required |

Returns:
| Type | Description |
|---|---|
| Any | A batch of data. |
EquinoxPersister
Bases: Persister[S, C]
Concrete implementation using Equinox's serialization (Safetensors).
This persister saves the State object to a .eqx file using equinox.tree_serialise_leaves.
It assumes that the Config object is handled separately or reconstructible, as Equinox
serialization primarily handles array data (leaves).
Example
Source code in src/researchlab/design/infra.py
load(path, state_structure, config_structure)
Loads state from a .eqx file.
Note: This implementation currently only restores State.
It returns the passed config_structure as is, and assumes step=0
(since step is not intrinsically stored in the tree leaves unless added to State).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Path | Path to the .eqx file. | required |
| state_structure | S | Structure of the state to load into. | required |
| config_structure | C | Structure of the config (returned as is). | required |

Returns:
| Type | Description |
|---|---|
| tuple[S, C, int] | (loaded_state, config_structure, 0) |
Source code in src/researchlab/design/infra.py
save(state, config, step, path)
Saves state to a .eqx file.
EquinoxPersister does not save the Config, which is usually static. Pickling it, or serializing it when it is a PyTree, would be a possible extension.
Currently only the State is serialized with Equinox; the Config is assumed to be handled separately or to be reconstructible.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S | The state to save. | required |
| config | C | The config (ignored). | required |
| step | int | The step (ignored/not saved in file). | required |
| path | Path | The path to save the .eqx file. | required |
Source code in src/researchlab/design/infra.py
MLFlowTelemetry
Bases: Telemetry[S, C]
Concrete implementation of Telemetry using MLflow.
This class logs metrics and parameters to an active MLflow run. It assumes
that an MLflow run has already been started (e.g., by ExperimentTracker).
It automatically flattens nested parameter dictionaries and converts JAX scalar arrays to Python floats to ensure compatibility with MLflow.
Example
Source code in src/researchlab/design/infra.py
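The scalar conversion can be illustrated without JAX or MLflow. In this sketch, to_float and prepare_metrics are hypothetical helpers, and FakeScalar stands in for a zero-dimensional array. The sketch relies on the .item() method that NumPy and JAX arrays provide.

```python
from typing import Any


def to_float(value: Any) -> float:
    """Convert an array-like scalar or a plain number to a Python float."""
    item = getattr(value, "item", None)
    # Zero-dimensional arrays expose .item(); plain numbers do not.
    return float(item()) if callable(item) else float(value)


def prepare_metrics(metrics: dict[str, Any]) -> dict[str, float]:
    return {name: to_float(value) for name, value in metrics.items()}


class FakeScalar:
    """Stand-in for a zero-dimensional array."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


print(prepare_metrics({"loss": FakeScalar(0.25), "acc": 1}))
# {'loss': 0.25, 'acc': 1.0}
```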
__init__()
log_metrics(metrics, step)
Logs scalar metrics to MLflow.
Converts JAX/NumPy scalars to Python floats before logging.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics | dict[str, float] | Dictionary of metrics. | required |
| step | int | Current step. | required |
Source code in src/researchlab/design/infra.py
log_params(params)
Logs parameters to MLflow, flattening nested dictionaries.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | dict[str, Any] | Dictionary of parameters. | required |
Persister
Bases: ABC
Abstract base class for saving and loading checkpoints.
Persisters handle the serialization and deserialization of the simulation state and configuration.
Source code in src/researchlab/design/infra.py
load(path, state_structure, config_structure)
abstractmethod
Loads state and config from a checkpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Path | Path to the checkpoint file/directory. | required |
| state_structure | S | A structure (PyTree) matching the state to load. | required |
| config_structure | C | A structure matching the config to load. | required |

Returns:
| Type | Description |
|---|---|
| tuple[S, C, int] | A tuple containing the loaded state, the config, and the step count. |
Source code in src/researchlab/design/infra.py
save(state, config, step, path)
abstractmethod
Saves the current state and config to the specified path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S | The current simulation state. | required |
| config | C | The experiment configuration. | required |
| step | int | The current step count. | required |
| path | Path | The file path or directory to save to. | required |
Source code in src/researchlab/design/infra.py
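As a sketch of the save/load contract, the example below implements it with JSON and dictionaries. PersisterSketch and JsonPersister are hypothetical names, and real State objects are PyTrees rather than dictionaries. Unlike EquinoxPersister, this version stores the config and step as well.

```python
import json
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path


class PersisterSketch(ABC):
    """Stand-in mirroring the documented save/load signatures."""

    @abstractmethod
    def save(self, state, config, step, path): ...

    @abstractmethod
    def load(self, path, state_structure, config_structure): ...


class JsonPersister(PersisterSketch):
    def save(self, state, config, step, path):
        payload = {"state": state, "config": config, "step": step}
        Path(path).write_text(json.dumps(payload))

    def load(self, path, state_structure, config_structure):
        payload = json.loads(Path(path).read_text())
        # Start from the given structures so missing keys keep their defaults.
        state = {**state_structure, **payload["state"]}
        config = {**config_structure, **payload["config"]}
        return state, config, payload["step"]


with tempfile.TemporaryDirectory() as tmp:
    checkpoint = Path(tmp) / "checkpoint.json"
    persister = JsonPersister()
    persister.save({"weights": [0.1, 0.2]}, {"lr": 0.01}, step=100, path=checkpoint)
    state, config, step = persister.load(checkpoint, {"weights": []}, {"lr": 0.0})
    print(state, config, step)
```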
Telemetry
Bases: ABC
Abstract base class for logging metrics and parameters.
Telemetry components handle the reporting of experiment results. They should be non-intrusive and only observe the state/metrics.
Source code in src/researchlab/design/infra.py
log_metrics(metrics, step)
abstractmethod
Logs scalar metrics for a given step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics | dict[str, float] | A dictionary mapping metric names to values. | required |
| step | int | The current global step or iteration. | required |
Source code in src/researchlab/design/infra.py
log_params(params)
abstractmethod
Logs hyperparameters or configuration settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | dict[str, Any] | A dictionary of parameters to log. Nested dictionaries are supported. | required |
Visualizer
Bases: ABC
Abstract base class for rendering and recording simulation states.
This abstraction supports three main functionalities:
1. Generating frames from states via render(state).
2. Displaying the last generated frame via show().
3. Recording/Streaming frames via the _record() hook (if is_recording=True).
Source code in src/researchlab/design/infra.py
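The render/record flow listed above can be sketched as follows. VisualizerSketch is a hypothetical stand-in for the real base class, and the _draw method and last_frame attribute are invented for illustration. Only render, show, _record, close, and the is_recording flag come from the documentation.

```python
from abc import ABC, abstractmethod


class VisualizerSketch(ABC):
    def __init__(self, record=False, fps=None):
        self.is_recording = record
        self.fps = fps
        self.last_frame = None  # hypothetical attribute

    @abstractmethod
    def _draw(self, state): ...  # hypothetical frame-generation method

    def _record(self, frame):
        """Hook called for each frame when recording; does nothing by default."""

    def render(self, state):
        self.last_frame = self._draw(state)
        if self.is_recording:
            self._record(self.last_frame)
        return self.last_frame

    @abstractmethod
    def close(self): ...


class TextVisualizer(VisualizerSketch):
    def __init__(self, record=False, fps=None):
        super().__init__(record, fps)
        self.frames = []

    def _draw(self, state):
        return "#" * state["level"]

    def _record(self, frame):
        self.frames.append(frame)

    def show(self):
        print(self.last_frame)

    def close(self):
        pass  # nothing to release in this sketch


viz = TextVisualizer(record=True)
viz.render({"level": 3})
viz.show()  # prints ###
```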
__del__()
__init__(record=False, fps=None)
Initializes the visualizer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| record | bool | If True, enable recording mode to trigger the _record() hook. | False |
| fps | None \| int | Optional frames per second for recording and showing. | None |
Source code in src/researchlab/design/infra.py
close()
abstractmethod
render(state)
Generate a frame and trigger the _record() hook if recording is enabled.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S | The state to render. | required |

Returns:
| Type | Description |
|---|---|
| Any | The rendered frame (e.g., image array). |
Source code in src/researchlab/design/infra.py
Orchestrator
researchlab.design.orchestrator
Loop
A generic training/simulation loop orchestrator.
The Loop class coordinates the interaction between the core components
(State, Config, Kernel) and the infrastructure components (Data, Telemetry,
Persistence, Visualization). It manages the execution flow, step counting,
and periodic tasks.
Example
Source code in src/researchlab/design/orchestrator.py
__init__(config, initial_state, step_fn, data_provider, telemetry=None, persister=None, visualizer=None)
Initializes the loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | C | Immutable hyperparameters. | required |
| initial_state | S | Initial simulation state. | required |
| step_fn | Kernel[tuple[S, C, Any], tuple[S, dict[str, Any]]] | A pure kernel function that takes (state, config, batch) and returns (new_state, metrics). | required |
| data_provider | DataProvider[S] | Source of data. | required |
| telemetry | Telemetry[S, C] \| None | Optional logger. | None |
| persister | Persister[S, C] \| None | Optional checkpointer. | None |
| visualizer | Visualizer[S] \| None | Optional renderer. | None |
Source code in src/researchlab/design/orchestrator.py
run(num_steps)
Run the loop for a specified number of steps.
Source code in src/researchlab/design/orchestrator.py
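The execution flow can be sketched as a plain function. This is a simplified illustration, not the Loop implementation: run_loop is a hypothetical name, checkpointing and visualization are omitted, and the zero-based step index passed to the logger is an assumption.

```python
def run_loop(config, state, step_fn, get_batch, num_steps, log_metrics=None):
    """Advance the state num_steps times, logging metrics after each step."""
    for step in range(num_steps):
        batch = get_batch(state)  # the batch may depend on the current state
        state, metrics = step_fn(state, config, batch)
        if log_metrics is not None:
            log_metrics(metrics, step)
    return state


def step_fn(state, config, batch):
    # Pure function: builds a new state instead of mutating the old one.
    total = state["total"] + config["lr"] * sum(batch)
    return {"total": total}, {"total": total}


history = []
final = run_loop(
    {"lr": 0.5}, {"total": 0.0}, step_fn, lambda state: [1, 2, 3], 4,
    log_metrics=lambda metrics, step: history.append((step, metrics["total"])),
)
print(final)
# {'total': 12.0}
```

Because step_fn is pure, the same function can be tested in isolation or passed through transformations before being handed to the loop.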
Utilities
researchlab.design.utils
flatten_config(config, prefix='', separator='.')
Flattens a Config object (Pydantic model) into a dictionary.
It uses model_dump() to get the configuration as a dictionary and then
flattens any nested dictionaries.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | Config | The Config object to flatten. | required |
| prefix | str | An optional string to prepend to keys. | '' |
| separator | str | The separator for nested keys. Defaults to ".". | '.' |

Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A flattened dictionary representation of the configuration. |
Example
Source code in src/researchlab/design/utils.py
flatten_pytree(tree, prefix='', separator='.')
Flattens a JAX PyTree into a dictionary with dot-notation keys.
This utility traverses the PyTree structure and produces a flat dictionary
where keys represent the path to each leaf node using dot notation. This
is particularly useful for logging complex nested State objects to
flat metric logging systems like MLflow.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | Any | The PyTree object to flatten (e.g., a State instance). | required |
| prefix | str | An optional string to prepend to all generated keys. | '' |
| separator | str | The string used to separate path components. Defaults to ".". | '.' |

Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary mapping flattened string keys to the leaf values of the tree. |
Source code in src/researchlab/design/utils.py
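The path-to-key idea can be shown on nested dictionaries, lists, and tuples using only the standard library. The function flatten_tree is a hypothetical stand-in that does not use JAX, so it does not handle arbitrary PyTree node types, and it keeps dictionary insertion order. Joining the prefix with the separator is also an assumption.

```python
def flatten_tree(tree, prefix="", separator="."):
    """Map each leaf of a nested container to a separator-joined path key."""
    if isinstance(tree, dict):
        children = tree.items()
    elif isinstance(tree, (list, tuple)):
        children = enumerate(tree)  # sequence positions become path components
    else:
        return {prefix: tree}  # leaf: the accumulated path is its key
    flat = {}
    for key, child in children:
        path = f"{prefix}{separator}{key}" if prefix else str(key)
        flat.update(flatten_tree(child, path, separator))
    return flat


state = {"params": {"w": [1.0, 2.0], "b": 0.5}, "step": 3}
print(flatten_tree(state))
# {'params.w.0': 1.0, 'params.w.1': 2.0, 'params.b': 0.5, 'step': 3}
```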
unflatten_config(flat_dict, config_cls, separator='.')
Reconstructs a Config object from a flattened dictionary.
This function unflattens a dictionary into a nested structure and then
validates it against the Pydantic model config_cls.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| flat_dict | dict[str, Any] | The flattened dictionary containing configuration values. | required |
| config_cls | type[C] | The class of the Config to construct. | required |
| separator | str | The separator used in the flattened keys. Defaults to ".". | '.' |

Returns:
| Type | Description |
|---|---|
| C | An instance of config_cls. |

Raises:
| Type | Description |
|---|---|
| ValueError | If there is a key conflict (e.g., 'a' is both a value and a container). |
| ValidationError | If the reconstructed data fails Pydantic validation. |
Source code in src/researchlab/design/utils.py
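The unflattening step and its key-conflict check can be sketched as below. The function unflatten is a hypothetical helper that covers only the dictionary reconstruction. The Pydantic validation against config_cls that follows in the library is omitted here.

```python
def unflatten(flat_dict, separator="."):
    """Rebuild a nested dictionary from separator-joined keys."""
    nested = {}
    for flat_key, value in flat_dict.items():
        *parents, leaf = flat_key.split(separator)
        node = nested
        for part in parents:
            child = node.setdefault(part, {})
            if not isinstance(child, dict):
                # An earlier key already stored a plain value at this path.
                raise ValueError(f"key conflict at {part!r} in {flat_key!r}")
            node = child
        if isinstance(node.get(leaf), dict):
            # An earlier key already made this path a container.
            raise ValueError(f"key conflict at {leaf!r} in {flat_key!r}")
        node[leaf] = value
    return nested


print(unflatten({"model.layers": 4, "model.optimizer.lr": 0.001, "seed": 7}))
# {'model': {'layers': 4, 'optimizer': {'lr': 0.001}}, 'seed': 7}
```

A conflicting input such as {"a": 1, "a.b": 2} raises ValueError in this sketch, which matches the documented error case.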
unflatten_pytree(flat_dict, structure, separator='.')
Populates a PyTree structure with values from a flattened dictionary.
This function attempts to reconstruct the values of a PyTree by looking up
keys in a flattened dictionary that correspond to the paths in the structure.
Leaves not found in the dictionary are kept as is from the structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| flat_dict | dict[str, Any] | A dictionary containing flattened keys and values. | required |
| structure | T | A PyTree instance defining the target structure. | required |
| separator | str | The separator used in the flattened keys. Defaults to ".". | '.' |

Returns:
| Type | Description |
|---|---|
| T | A new PyTree with the same structure as structure, with leaf values updated from flat_dict. |
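
The lookup-by-path behavior, including keeping leaves that are absent from the dictionary, can be sketched on plain containers. The function fill_tree is a hypothetical stand-in that handles only dictionaries, lists, and tuples rather than general JAX PyTrees.

```python
def fill_tree(flat_dict, structure, separator=".", _path=""):
    """Return a copy of structure with leaves replaced by matching flat_dict values."""

    def child_path(key):
        return f"{_path}{separator}{key}" if _path else str(key)

    if isinstance(structure, dict):
        return {
            key: fill_tree(flat_dict, value, separator, child_path(key))
            for key, value in structure.items()
        }
    if isinstance(structure, (list, tuple)):
        filled = [
            fill_tree(flat_dict, value, separator, child_path(index))
            for index, value in enumerate(structure)
        ]
        return type(structure)(filled)  # preserve list vs. tuple
    # Leaf: take the new value if its path is present, else keep the original.
    return flat_dict.get(_path, structure)


template = {"params": {"w": [1.0, 2.0], "b": 0.5}, "step": 3}
updated = fill_tree({"params.w.1": 9.0, "step": 4}, template)
print(updated)
# {'params': {'w': [1.0, 9.0], 'b': 0.5}, 'step': 4}
```

The template itself is left unchanged, since the function builds new containers at every level.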