Skip to content

API Reference

Tracking Module

ExperimentTracker

researchlab.tracking.tracker.ExperimentTracker

A context manager for tracking research experiments with Git state.

This class automates the process of starting an MLflow run, capturing the current Git state (base commit and diff), and logging artifacts and parameters.

Example
from researchlab import ExperimentTracker

with ExperimentTracker(experiment_name="my_project") as tracker:
    tracker.log_config("config.yaml")
    # Your training logic here
    print(f"Run ID: {tracker.run_name}")

Attributes:

Name Type Description
run_name str

The unique, readable ID for this run.

active_run Optional[Run]

The currently active MLflow run.

Source code in src/researchlab/tracking/tracker.py
class ExperimentTracker:
    """A context manager for tracking research experiments with Git state.

    This class automates the process of starting an MLflow run, capturing the
    current Git state (base commit and diff), and logging artifacts and
    parameters.

    Example:
        ```python
        from researchlab import ExperimentTracker

        with ExperimentTracker(experiment_name="my_project") as tracker:
            tracker.log_config("config.yaml")
            # Your training logic here
            print(f"Run ID: {tracker.run_name}")
        ```

    Attributes:
        run_name (str): The unique, readable ID for this run.
        active_run (Optional[Run]): The currently active MLflow run.
    """

    def __init__(
        self,
        experiment_name: str,
        run_name: str | None = None,
        tracking_uri: str | None = None,
    ):
        """Initializes the ExperimentTracker.

        Args:
            experiment_name: The name of the MLflow experiment.
            run_name: Optional custom name for the run. If not provided,
                a readable ID is generated automatically.
            tracking_uri: Optional MLflow tracking URI to use.
        """
        if tracking_uri:
            mlflow.set_tracking_uri(tracking_uri)

        mlflow.set_experiment(experiment_name)

        self.run_name = run_name or generate_run_id()
        self.active_run = None

    def __enter__(self) -> "ExperimentTracker":
        """Starts the MLflow run and captures the Git state.

        Returns:
            ExperimentTracker: The instance of the tracker.
        """
        self.active_run = mlflow.start_run(run_name=self.run_name)

        # Capture and log Git state
        git_state = get_git_state()
        mlflow.set_tag("rlab.base_commit", git_state["base_commit"])
        mlflow.set_tag("rlab.run_id", self.run_name)

        # Save patch file as an artifact
        if git_state["patch"]:
            patch_path = Path("run.patch")
            with patch_path.open("w") as f:
                f.write(git_state["patch"])
            mlflow.log_artifact(str(patch_path))
            patch_path.unlink()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Ends the MLflow run."""
        mlflow.end_run()

    def log_config(self, config_path: str):
        """Logs a YAML configuration file as parameters and an artifact.

        Recursively parses the YAML file and logs each entry as an MLflow
        parameter. Also uploads the file itself as an artifact.

        Args:
            config_path: Path to the YAML configuration file.
        """
        path = Path(config_path)
        if not path.exists():
            print(f"Warning: Config file {config_path} not found.")
            return

        with path.open() as f:
            try:
                config = yaml.safe_load(f)
                if isinstance(config, dict):
                    log_flattened_params(config)
                mlflow.log_artifact(str(path))
            except yaml.YAMLError as e:
                print(f"Error parsing YAML config: {e}")

__enter__()

Starts the MLflow run and captures the Git state.

Returns:

Name Type Description
ExperimentTracker ExperimentTracker

The instance of the tracker.

Source code in src/researchlab/tracking/tracker.py
def __enter__(self) -> "ExperimentTracker":
    """Starts the MLflow run and captures the Git state.

    Tags the run with ``rlab.base_commit`` and ``rlab.run_id`` and, when the
    working tree differs from HEAD, uploads the diff as the ``run.patch``
    artifact.

    Returns:
        ExperimentTracker: The instance of the tracker.

    Raises:
        Exception: Anything raised while capturing or logging the Git state
            is re-raised after the run has been ended as FAILED.
    """
    # Local import keeps this block self-contained; tempfile is only needed
    # for staging the patch artifact.
    import tempfile

    self.active_run = mlflow.start_run(run_name=self.run_name)

    try:
        # Capture and log Git state
        git_state = get_git_state()
        mlflow.set_tag("rlab.base_commit", git_state["base_commit"])
        mlflow.set_tag("rlab.run_id", self.run_name)

        # Stage the patch in a throwaway directory so an existing ./run.patch
        # is never overwritten and nothing is left behind if the upload
        # fails. The file name is kept so the artifact is still "run.patch".
        if git_state["patch"]:
            with tempfile.TemporaryDirectory() as tmp_dir:
                patch_path = Path(tmp_dir) / "run.patch"
                patch_path.write_text(git_state["patch"], encoding="utf-8")
                mlflow.log_artifact(str(patch_path))
    except BaseException:
        # The context manager's exit hook is not called when entry raises, so
        # close the run here; otherwise it would stay active.
        mlflow.end_run(status="FAILED")
        raise

    return self

__exit__(exc_type, exc_val, exc_tb)

Ends the MLflow run.

Source code in src/researchlab/tracking/tracker.py
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Ends the MLflow run, marking it FAILED if the body raised.

    Args:
        exc_type: Exception class raised inside the ``with`` body, or None.
        exc_val: The exception instance, or None.
        exc_tb: The traceback, or None.

    Returns:
        None, so any in-flight exception keeps propagating.
    """
    # Without an explicit status a crashed run would be recorded with the
    # default (successful) status.
    status = "FINISHED" if exc_type is None else "FAILED"
    mlflow.end_run(status=status)

__init__(experiment_name, run_name=None, tracking_uri=None)

Initializes the ExperimentTracker.

Parameters:

Name Type Description Default
experiment_name str

The name of the MLflow experiment.

required
run_name str | None

Optional custom name for the run. If not provided, a readable ID is generated automatically.

None
tracking_uri str | None

Optional MLflow tracking URI to use.

None
Source code in src/researchlab/tracking/tracker.py
def __init__(
    self,
    experiment_name: str,
    run_name: str | None = None,
    tracking_uri: str | None = None,
):
    """Configures MLflow and prepares the tracker's run name.

    Args:
        experiment_name: Name of the MLflow experiment to select.
        run_name: Custom name for the run. When omitted (or empty), a
            readable ID is generated instead.
        tracking_uri: MLflow tracking URI to switch to. When omitted (or
            empty), the current tracking URI is left untouched.
    """
    # The tracking URI has to be applied before the experiment is selected.
    if tracking_uri:
        mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)

    resolved_name = run_name if run_name else generate_run_id()
    self.run_name = resolved_name
    # No run exists until the context is entered.
    self.active_run = None

log_config(config_path)

Logs a YAML configuration file as parameters and an artifact.

Parses the YAML file and logs each entry as an MLflow parameter, recursively flattening nested mappings into dot-separated keys. Also uploads the file itself as an artifact.

Parameters:

Name Type Description Default
config_path str

Path to the YAML configuration file.

required
Source code in src/researchlab/tracking/tracker.py
def log_config(self, config_path: str) -> None:
    """Logs a YAML configuration file as parameters and an artifact.

    Parses the YAML file; when the top-level value is a mapping, its entries
    are logged as flattened MLflow parameters. The file itself is uploaded as
    an artifact whenever it parses successfully.

    A missing path (or one that is not a regular file) and a YAML syntax
    error are both reported on stdout and otherwise ignored.

    Args:
        config_path: Path to the YAML configuration file.
    """
    path = Path(config_path)
    # is_file() also rejects directories, which exists() would let through
    # only to fail inside open().
    if not path.is_file():
        print(f"Warning: Config file {config_path} not found.")
        return

    try:
        # Read as UTF-8 explicitly instead of relying on the locale.
        with path.open(encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except yaml.YAMLError as e:
        print(f"Error parsing YAML config: {e}")
        return

    if isinstance(config, dict):
        log_flattened_params(config)
    mlflow.log_artifact(str(path))

Utilities

researchlab.tracking.utils

find_run_by_rlab_id(run_id)

Finds an MLflow run by the custom rlab.run_id tag.

Parameters:

Name Type Description Default
run_id str

The readable rlab run ID to search for.

required

Returns:

Type Description
Run | None

Optional[Run]: The MLflow Run object if found, else None.

Source code in src/researchlab/tracking/utils.py
def find_run_by_rlab_id(run_id: str) -> Run | None:
    """Looks up an MLflow run through its ``rlab.run_id`` tag.

    Only active (non-deleted) runs are searched. The run ID is interpolated
    into the filter string as-is.

    Args:
        run_id: The readable rlab run ID to search for.

    Returns:
        Optional[Run]: The first matching MLflow Run, or None when the search
        returns nothing.
    """
    tag_filter = f"tags.'rlab.run_id' = '{run_id}'"
    matches = mlflow.search_runs(
        filter_string=tag_filter,
        run_view_type=mlflow.entities.ViewType.ACTIVE_ONLY,
        output_format="list",
    )
    if not matches:
        return None
    return matches[0]

generate_run_id()

Generates a unique and readable run ID.

Returns:

Name Type Description
str str

A string in the format YYYY-MM-DD_slug (e.g., 2026-02-15_radiant-octopus).

Source code in src/researchlab/tracking/utils.py
def generate_run_id() -> str:
    """Builds a readable run ID from today's date and a random slug.

    Returns:
        str: A string in the format YYYY-MM-DD_slug
        (e.g., 2026-02-15_radiant-octopus). The date comes from the local
        clock and the slug from ``coolname.generate_slug(2)``.
    """
    now = datetime.datetime.now()
    parts = (now.strftime("%Y-%m-%d"), coolname.generate_slug(2))
    return "_".join(parts)

get_git_state(repo_path='.')

Captures the current Git state including base commit and patch.

This captures staged, unstaged, and untracked changes by temporarily using git's "intent-to-add" feature. It restores the untracked state after generating the patch.

Parameters:

Name Type Description Default
repo_path str

Path to the git repository. Defaults to ".".

'.'

Returns:

Type Description
dict[str, Any]

dict[str, Any]: A dictionary with three keys: base_commit (str), the hex SHA of the current HEAD; patch (str), the git diff output; and is_dirty (bool), True if there are any changes (including untracked files).

Source code in src/researchlab/tracking/utils.py
def get_git_state(repo_path: str = ".") -> dict[str, Any]:
    """Captures the current Git state including base commit and patch.

    This captures staged, unstaged, and untracked changes by temporarily
    using git's "intent-to-add" feature. It restores the untracked state
    after generating the patch, even if generating the patch fails.

    Args:
        repo_path: Path to the git repository. Defaults to ".". Parent
            directories are searched for the repository root.

    Returns:
        dict[str, Any]: A dictionary containing:
            - base_commit (str): The HEX SHA of the current HEAD.
            - patch (str): The git diff output.
            - is_dirty (bool): True if there are any changes (including untracked).
    """
    repo = git.Repo(repo_path, search_parent_directories=True)

    # Base commit (HEAD)
    head_commit = repo.head.commit
    base_commit = head_commit.hexsha

    # Snapshot the untracked files once so exactly the same paths are
    # restored in the finally block.
    untracked_files = list(repo.untracked_files)

    try:
        # To include untracked files in the patch, we use 'git add -N'
        # (intent-to-add), which makes them appear in the diff without
        # staging their content or touching the files on disk.
        # The "--" separator forces git to treat the name as a path, so a
        # file called e.g. "main" or "-x" is not read as a ref or an option.
        for f in untracked_files:
            repo.git.add("--", f, intent_to_add=True)

        # Patch of all changes (staged, unstaged, and untracked-via-intent-to-add)
        patch = repo.git.diff(head_commit)

        # Check if dirty (including untracked files)
        is_dirty = repo.is_dirty(untracked_files=True)

    finally:
        # Remove the intent-to-add entries again. 'git reset -- <file>' only
        # affects the index and does NOT revert working tree changes.
        # Without "--", a file named like a ref makes the reset fail; that
        # error is suppressed below, which would leave the file in the index.
        for f in untracked_files:
            with contextlib.suppress(git.exc.GitCommandError):
                repo.git.reset("--", f)

    return {"base_commit": base_commit, "patch": patch, "is_dirty": is_dirty}

log_flattened_params(d, prefix='')

Recursively logs a dictionary as flattened MLflow parameters.

Parameters:

Name Type Description Default
d dict[str, Any]

The dictionary to log.

required
prefix str

Optional prefix for parameter keys (used for recursion).

''
Source code in src/researchlab/tracking/utils.py
def log_flattened_params(d: dict[str, Any], prefix: str = "") -> None:
    """Recursively logs a dictionary as flattened MLflow parameters.

    Nested dictionaries are flattened into dot-separated keys, e.g.
    ``{"model": {"lr": 0.1}}`` is logged as ``model.lr``. Empty nested
    dictionaries produce no parameters.

    Args:
        d: The dictionary to log.
        prefix: Optional prefix for parameter keys (used for recursion).
    """
    for k, v in d.items():
        # Always build the key through the f-string so non-string keys
        # (e.g. integers from YAML) become strings at the top level too, not
        # only when nested under a prefix.
        key = f"{prefix}{k}"
        if isinstance(v, dict):
            log_flattened_params(v, prefix=f"{key}.")
        else:
            mlflow.log_param(key, v)