Skip to content

Train

hyperbench.train

LaTexTableLogger

Bases: Logger

A Lightning Logger that accumulates metrics and writes a LaTeX comparison table.

Multiple instances (one per model) share a class-level store keyed by experiment_name. Every time finalize() is called (after fit() or test() for each model), the current state of all accumulated metrics is written to a LaTeX file. The last model to finalize produces the most complete table.

This means the file is progressively updated as models finish training/testing, so you can open it mid-run to see partial results.

Parameters:

Name Type Description Default
save_dir str | Path

Base directory where the comparison/ subfolder will be created.

required
model_name str

The model's full name (e.g., "mlp:mean").

required
experiment_name str

Shared key that groups all models in the same experiment.

required
precision int

Decimal places for metric values in the table.

4
Source code in hyperbench/train/latex_logger.py
class LaTexTableLogger(Logger):
    """A Lightning Logger that accumulates metrics and writes a LaTeX comparison table.

    Multiple instances (one per model) share a class-level store keyed by experiment_name.
    Every time finalize() is called (after fit() or test() for each model), the current
    state of all accumulated metrics is written to a LaTeX file. The last model to
    finalize produces the most complete table.

    This means the file is progressively updated as models finish training/testing,
    so you can open it mid-run to see partial results.

    Args:
        save_dir: Base directory where the comparison/ subfolder will be created.
        model_name: The model's full name (e.g., "mlp:mean").
        experiment_name: Shared key that groups all models in the same experiment.
        precision: Decimal places for metric values in the table.
        options: Optional table settings. finalize() reads the keys "table_caption"
            (str), "sort_by" (non-empty list of "asc"/"des") and "border" (bool).
            Missing or wrongly typed entries fall back to no caption, ["asc"] and True.
    """

    # TODO: settings has to be configurable in Trainer

    # Class-level shared store: {experiment_name: {model_name: {metric_name: value}}}
    __shared_stores: ClassVar[dict[str, dict[str, dict[str, Any]]]] = {}

    # Translation table for LaTeX special characters. Escaping is done in ONE pass so
    # that the braces emitted for "\textbackslash{}" are never escaped a second time
    # (chained str.replace calls turned them into "\textbackslash\{\}").
    __latex_escapes: ClassVar[dict[int, str]] = str.maketrans(
        {
            "\\": r"\textbackslash{}",
            "&": r"\&",
            "%": r"\%",
            "$": r"\$",
            "#": r"\#",
            "_": r"\_",
            "{": r"\{",
            "}": r"\}",
            "~": r"\textasciitilde{}",
            "^": r"\textasciicircum{}",
        }
    )

    def __init__(
        self,
        save_dir: str | Path,
        model_name: str,
        experiment_name: str,
        precision: int = 4,
        options: LaTexTableConfig | None = None,
    ) -> None:
        super().__init__()
        self.__save_dir = save_dir
        self.__model_name = model_name
        self.__experiment_name = experiment_name
        self.__precision = precision

        default_empty_options: LaTexTableConfig = {}
        self.__options = options if options is not None else default_empty_options

        # Register the experiment so other loggers of the same experiment share it.
        self.__shared_stores.setdefault(experiment_name, {})

    @property
    def name(self) -> str:
        """Fixed logger name."""
        return "LaTexTableLogger"

    @property
    def version(self) -> str | int:
        """The model name doubles as the logger version."""
        return self.__model_name

    @property
    def store(self) -> dict[str, dict[str, Any]]:
        """Return a copy of the shared store for the current experiment.

        Both the model level and the per-model metric dicts are copied, so mutating
        the returned mapping does not alter the accumulated metrics.
        """
        experiment_store = self.__shared_stores.get(self.__experiment_name, {})
        return {model: dict(metrics) for model, metrics in experiment_store.items()}

    @property
    def save_dir(self) -> str | Path:
        """Base directory passed at construction time."""
        return self.__save_dir

    @property
    def experiment_name(self) -> str:
        """Key under which this logger's metrics are grouped."""
        return self.__experiment_name

    def log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:
        """Accumulate metrics for this model. Called by Lightning on every log step.

        Keeps only the latest value for each metric name. For example, if
        "val_auc" is logged at step 10 and step 20, only the step 20 value is kept.

        Args:
            metrics: Metric names mapped to their latest values.
            step: Ignored; only the most recent value per metric is stored.
        """
        # setdefault instead of indexing: clear() may have dropped the experiment
        # entry after this logger was constructed, which used to raise KeyError.
        experiment_store = self.__shared_stores.setdefault(self.__experiment_name, {})
        experiment_store.setdefault(self.__model_name, {}).update(metrics)

    def log_hyperparams(self, params: Any) -> None:
        """Hyperparameters are not part of the table; intentionally a no-op."""

    def finalize(self, status: str) -> None:
        """Write the LaTeX comparison table with all accumulated metrics so far.

        Called by Lightning after fit() and after test() for each model. Since models
        train/test sequentially, each finalize() overwrites the file with all data
        accumulated up to that point. The file grows more complete over time.

        Two files are written into ``<save_dir>/comparison``: the default-named
        overall table (test, train and val sections) and ``test.tex`` (test section
        only). Nothing is written when no test/train/val metric has been logged.

        Args:
            status: Not used by this logger.
        """
        test_results, train_results, val_results = self.__split_results()
        if not (test_results or train_results or val_results):
            return

        comparison_dir = Path(self.__save_dir) / "comparison"
        table_caption_opt = self.__options.get("table_caption")
        sort_by_opt = self.__options.get("sort_by")
        border_opt = self.__options.get("border")

        # Options of the wrong type are ignored rather than rejected.
        table_caption = table_caption_opt if isinstance(table_caption_opt, str) else None
        sort_by = sort_by_opt if isinstance(sort_by_opt, list) and sort_by_opt else ["asc"]
        border = border_opt if isinstance(border_opt, bool) else True

        self.__save_comparison_tables(
            test_results=test_results,
            save_dir=comparison_dir,
            train_results=train_results or None,
            val_results=val_results or None,
            precision=self.__precision,
            table_caption=table_caption,
            sort_by=sort_by,
            border=border,
        )
        self.__save_comparison_tables(
            test_results=test_results,
            save_dir=comparison_dir,
            train_results=None,
            val_results=None,
            precision=self.__precision,
            filename="test.tex",
            table_caption=table_caption,
            sort_by=sort_by,
            border=border,
        )

    def __split_results(
        self,
    ) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]], dict[str, dict[str, Any]]]:
        """Split all accumulated metrics into test, train and val groups.

        Metrics are classified by their name prefix:
        - "test*"  --> test_results
        - "train*" --> train_results
        - "val*" --> val_results
        - anything else (e.g., "epoch") --> ignored

        Returns:
            Tuple of (test_results, train_results, val_results), each mapping model
            names to metric dicts. Models without metrics in a group are left out
            of that group.
        """
        store = self.__shared_stores.get(self.__experiment_name, {})
        test_results: dict[str, dict[str, Any]] = {}
        train_results: dict[str, dict[str, Any]] = {}
        val_results: dict[str, dict[str, Any]] = {}

        for model_name, metrics in store.items():
            test_metrics: dict[str, Any] = {}
            train_metrics: dict[str, Any] = {}
            val_metrics: dict[str, Any] = {}

            for metric_name, value in metrics.items():
                if metric_name.startswith("test"):
                    test_metrics[metric_name] = value
                elif metric_name.startswith("train"):
                    train_metrics[metric_name] = value
                elif metric_name.startswith("val"):
                    val_metrics[metric_name] = value

            if test_metrics:
                test_results[model_name] = test_metrics
            if train_metrics:
                train_results[model_name] = train_metrics
            if val_metrics:
                val_results[model_name] = val_metrics

        return test_results, train_results, val_results

    def clear(self, experiment_name: str) -> None:
        """Remove accumulated data for an experiment.

        Args:
            experiment_name: The experiment whose data should be cleared. Unknown
                names are ignored.
        """
        self.__shared_stores.pop(experiment_name, None)

    def __build_comparison_table(
        self,
        sections_data: list[tuple[str, Mapping[str, Mapping[str, Any]]]],
        precision: int = 4,
        table_caption: str | None = None,
        sort_by: list[str] | None = None,
        border: bool = True,
    ) -> str:
        """Render all sections into one LaTeX ``table``/``tabular`` environment.

        Args:
            sections_data: (title, results) pairs, where results maps model names
                to metric dicts.
            precision: Decimal places for numeric metric values.
            table_caption: Optional caption; escaped before being emitted.
            sort_by: Per-metric "asc"/"des" orders, see __get_section_lines.
            border: If True, columns are separated by "|" and the table opens with
                ``\\hline``; otherwise it opens with ``\\toprule``.

        Returns:
            The LaTeX source ending in a newline, or "" if sections_data is empty.
        """
        if not sections_data:
            return ""

        # One tabular must have fixed column count; use max needed across sections.
        max_metrics = max(
            len({name for per_model in results.values() for name in per_model})
            for _, results in sections_data
        )
        total_cols = 1 + max_metrics
        col_spec = ("|" if border else "").join(["l", *(["c"] * max_metrics)])

        lines: list[str] = [
            rf"\begin{{tabular}}{{{col_spec}}}",
            r"\hline" if border else r"\toprule",
        ]

        for title, results in sections_data:
            lines.extend(
                self.__get_section_lines(title, results, total_cols, precision, sort_by, border)
            )

        # Replace the last section-ending rule with a final closing \hline.
        # NOTE(review): borderless tables open with \toprule yet also close with
        # \hline; \bottomrule may be the intended counterpart - confirm before
        # changing the emitted output.
        section_rule = r"\hline" if border else r"\midrule"
        closing_rule = r"\hline"
        if lines and lines[-1] == section_rule:
            lines.pop()
        if not lines or lines[-1] != closing_rule:
            lines.append(closing_rule)

        lines.append(r"\end{tabular}")
        table_lines: list[str] = [r"\begin{table}[htbp]", r"\centering"]

        if table_caption:
            table_lines.append(rf"\caption{{{self.__escape(table_caption)}}}")

        table_lines.extend(lines)
        table_lines.append(r"\end{table}")
        return "\n".join(table_lines) + "\n"

    def __get_section_lines(
        self,
        title: str,
        results: Mapping[str, Mapping[str, Any]],
        total_cols: int,
        precision: int,
        sort_by: list[str] | None,
        border: bool,
    ) -> list[str]:
        """Build the title, header and model rows of one table section.

        Metrics are listed alphabetically. The i-th entry of ``sort_by`` applies to
        the i-th metric of that list; when there are fewer entries than metrics the
        last entry is reused. With "asc" the smallest value of a metric is
        underlined as best, with "des" the largest.

        Args:
            title: Section heading, spanning all columns.
            results: Model names mapped to metric dicts.
            total_cols: Column count of the enclosing tabular; short rows are padded.
            precision: Decimal places for numeric values.
            sort_by: "asc"/"des" orders (case-insensitive); defaults to ["asc"].
            border: Selects the rule that ends the section (``\\hline`` or ``\\midrule``).

        Returns:
            The LaTeX lines of the section, ending with the section rule.

        Raises:
            ValueError: If an entry of ``sort_by`` is neither "asc" nor "des".
        """
        metrics = sorted({name for per_model in results.values() for name in per_model})

        normalized_orders: list[str] = []
        for order in sort_by or ["asc"]:
            normalized = order.lower()
            if normalized not in ("asc", "des"):
                raise ValueError(f"Invalid sort_by value: {order}. Use 'asc' or 'des'.")
            normalized_orders.append(normalized)

        # Metrics beyond the supplied orders inherit the last supplied order.
        last_order_idx = len(normalized_orders) - 1
        metric_sort: dict[str, str] = {
            metric: normalized_orders[min(idx, last_order_idx)]
            for idx, metric in enumerate(metrics)
        }

        # Project helper; presumably gathers per-metric min/max for colouring.
        metric_bounds = collect_metric_bounds(results, metrics)

        best_by_metric: dict[str, float] = {}
        for metric in metrics:
            vals = [
                per_model[metric]
                for per_model in results.values()
                if isinstance(per_model.get(metric), (int, float))
            ]
            if vals:
                best_by_metric[metric] = min(vals) if metric_sort[metric] == "asc" else max(vals)

        # Sections with fewer metrics than the widest one are padded with empty cells.
        header_cells = ["Model", *[self.__escape(metric) for metric in metrics]]
        header_cells.extend([""] * (total_cols - len(header_cells)))

        lines = [
            r"\addlinespace[3pt]",
            rf"\multicolumn{{{total_cols}}}{{c}}{{\textbf{{{self.__escape(title)}}}}} \\",
            r"\midrule",
            " & ".join(header_cells) + r" \\",
        ]

        for model_name in sorted(results):
            model_metrics = results[model_name]
            row = [self.__escape(model_name)]

            for metric in metrics:
                value = model_metrics.get(metric)
                if not isinstance(value, (int, float)):
                    # Missing or non-numeric metric for this model.
                    row.append("-")
                    continue

                formatted = f"{value:.{precision}f}"
                best = best_by_metric.get(metric)
                if best is not None and value == best:
                    formatted = rf"\underline{{{formatted}}}"

                row.append(
                    colorize_metric_value(
                        metric=metric,
                        value=float(value),
                        text=formatted,
                        metric_bounds=metric_bounds,
                        sort_order=metric_sort.get(metric, "asc"),
                    )
                )

            row.extend([""] * (total_cols - len(row)))
            lines.append(" & ".join(row) + r" \\")

        lines.append(r"\hline" if border else r"\midrule")
        return lines

    def __escape(self, value: str) -> str:
        """Escape LaTeX special characters in ``value`` in a single pass."""
        return value.translate(self.__latex_escapes)

    def __save_comparison_tables(
        self,
        test_results: Mapping[str, Mapping[str, Any]],
        save_dir: str | Path,
        train_results: Mapping[str, Mapping[str, Any]] | None = None,
        val_results: Mapping[str, Mapping[str, Any]] | None = None,
        filename: str = "overall.tex",
        precision: int = 4,
        table_caption: str | None = None,
        sort_by: list[str] | None = None,
        border: bool = True,
    ) -> Path:
        """Build the comparison table and write it to ``save_dir / filename``.

        Empty result groups are skipped. If every group is empty the file is still
        written, but with empty content.

        Args:
            test_results: Model names mapped to test metric dicts.
            save_dir: Target directory; created (with parents) if missing.
            train_results: Optional train metrics per model.
            val_results: Optional validation metrics per model.
            filename: Name of the output file.
            precision: Decimal places for metric values.
            table_caption: Optional table caption.
            sort_by: Per-metric "asc"/"des" orders.
            border: Whether to draw bordered columns.

        Returns:
            Path to the written file.
        """
        sections_data: list[tuple[str, Mapping[str, Mapping[str, Any]]]] = []
        if test_results:
            sections_data.append(("Test Results", test_results))
        if train_results:
            sections_data.append(("Train Results", train_results))
        if val_results:
            sections_data.append(("Val Results", val_results))

        content = self.__build_comparison_table(
            sections_data=sections_data,
            precision=precision,
            border=border,
            table_caption=table_caption,
            sort_by=sort_by,
        )

        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        file_path = save_path / filename
        if content != "":
            content = (
                "% Requires: \\usepackage{booktabs}\n"
                "% Requires: \\usepackage[table]{xcolor}\n" + content
            )
        # Explicit encoding keeps the output independent of the platform default.
        file_path.write_text(content, encoding="utf-8")
        return file_path

store property

Access the shared store for the current experiment.

log_metrics(metrics, step=None)

Accumulate metrics for this model. Called by Lightning on every log step.

Keeps only the latest value for each metric name. For example, if "val_auc" is logged at step 10 and step 20, only the step 20 value is kept.

Source code in hyperbench/train/latex_logger.py
def log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:
    """Accumulate metrics for this model. Called by Lightning on every log step.

    Keeps only the latest value for each metric name. For example, if
    "val_auc" is logged at step 10 and step 20, only the step 20 value is kept.
    """
    store = self.__shared_stores[self.__experiment_name]
    if self.__model_name not in store:
        store[self.__model_name] = {}
    store[self.__model_name].update(metrics)

finalize(status)

Write the LaTeX comparison table with all accumulated metrics so far.

Called by Lightning after fit() and after test() for each model. Since models train/test sequentially, each finalize() overwrites the file with all data accumulated up to that point. The file grows more complete over time.

Source code in hyperbench/train/latex_logger.py
def finalize(self, status: str) -> None:
    """Write the LaTeX comparison tables for everything accumulated so far.

    Lightning calls this after fit() and after test() of every model. Because
    models run one after another, each call overwrites the files with the data
    gathered up to that point, so the tables become more complete over time.

    Two files are produced in ``<save_dir>/comparison``: the default-named table
    with test, train and val sections, and ``test.tex`` with the test section only.
    Nothing is written when no test/train/val metric has been logged.

    Args:
        status: Not used by this logger.
    """
    test_results, train_results, val_results = self.__split_results()
    if not (test_results or train_results or val_results):
        return

    target_dir = Path(self.__save_dir) / "comparison"

    # Wrongly typed or missing options silently fall back to their defaults.
    caption = self.__options.get("table_caption")
    if not isinstance(caption, str):
        caption = None
    orders = self.__options.get("sort_by")
    if not (isinstance(orders, list) and orders):
        orders = ["asc"]
    bordered = self.__options.get("border")
    if not isinstance(bordered, bool):
        bordered = True

    # Arguments common to both output files.
    shared_kwargs = {
        "test_results": test_results,
        "save_dir": target_dir,
        "precision": self.__precision,
        "table_caption": caption,
        "sort_by": orders,
        "border": bordered,
    }

    # Full table: every section that has data.
    self.__save_comparison_tables(
        train_results=train_results or None,
        val_results=val_results or None,
        **shared_kwargs,
    )
    # Test-only table.
    self.__save_comparison_tables(
        train_results=None,
        val_results=None,
        filename="test.tex",
        **shared_kwargs,
    )

__split_results()

Split all accumulated metrics into test vs train/val groups.

Metrics are classified by their name prefix: "test*" → test_results; "train*" → train_results; "val*" → val_results; anything else (e.g., "epoch") → ignored.

Source code in hyperbench/train/latex_logger.py
def __split_results(
    self,
) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]], dict[str, dict[str, Any]]]:
    """Partition the experiment's metrics into test, train and val groups.

    A metric belongs to the group whose name it starts with ("test", "train" or
    "val"); metrics with any other name (e.g., "epoch") are dropped.

    Returns:
        Tuple of (test_results, train_results, val_results), each mapping model
        names to metric dicts. A model appears in a group only if it has at
        least one metric of that group.
    """
    prefixes = ("test", "train", "val")
    grouped: dict[str, dict[str, dict[str, Any]]] = {prefix: {} for prefix in prefixes}
    experiment_store = self.__shared_stores.get(self.__experiment_name, {})

    for model_name, metrics in experiment_store.items():
        per_prefix: dict[str, dict[str, Any]] = {prefix: {} for prefix in prefixes}

        for metric_name, value in metrics.items():
            # First matching prefix wins; unmatched names are skipped.
            matched = next((p for p in prefixes if metric_name.startswith(p)), None)
            if matched is not None:
                per_prefix[matched][metric_name] = value

        # Only register the model in groups where it actually has metrics.
        for prefix, selected in per_prefix.items():
            if selected:
                grouped[prefix][model_name] = selected

    return grouped["test"], grouped["train"], grouped["val"]

clear(experiment_name)

Remove accumulated data for an experiment.

Source code in hyperbench/train/latex_logger.py
def clear(self, experiment_name: str) -> None:
    """Drop every metric accumulated under ``experiment_name``.

    Experiment names that are not in the shared store are ignored.
    """
    stores = self.__shared_stores
    if experiment_name in stores:
        del stores[experiment_name]

MarkdownTableLogger

Bases: Logger

A Lightning Logger that accumulates metrics and writes a markdown comparison table.

Multiple instances (one per model) share a class-level store keyed by experiment_name. Every time finalize() is called (after fit() or test() for each model), the current state of all accumulated metrics is written to a markdown file. The last model to finalize produces the most complete table.

This means the file is progressively updated as models finish training/testing, so partial results are available while running.

Parameters:

Name Type Description Default
save_dir str | Path

Base directory where the comparison/ subfolder will be created.

required
model_name str

The model's full name (e.g., "mlp:mean").

required
experiment_name str

Shared key that groups all models in the same experiment.

required
precision int

Decimal places for metric values in the table.

4
Source code in hyperbench/train/markdown_logger.py
class MarkdownTableLogger(Logger):
    """A Lightning Logger that accumulates metrics and writes a markdown comparison table.

    Multiple instances (one per model) share a class-level store keyed by experiment_name.
    Every time finalize() is called (after fit() or test() for each model), the current
    state of all accumulated metrics is written to a markdown file. The last model to
    finalize produces the most complete table.

    This means the file is progressively updated as models finish training/testing,
    so partial results are available while running.

    Args:
        save_dir: Base directory where the comparison/ subfolder will be created.
        model_name: The model's full name (e.g., "mlp:mean").
        experiment_name: Shared key that groups all models in the same experiment.
        precision: Decimal places for metric values in the table.
    """

    # Class-level shared store: {experiment_name: {model_name: {metric_name: value}}}
    __shared_stores: ClassVar[dict[str, dict[str, dict[str, float]]]] = {}

    def __init__(
        self,
        save_dir: str | Path,
        model_name: str,
        experiment_name: str,
        precision: int = 4,
    ) -> None:
        super().__init__()
        self.__save_dir = save_dir
        self.__model_name = model_name
        self.__experiment_name = experiment_name
        self.__precision = precision

        # Register the experiment so other loggers of the same experiment share it.
        self.__shared_stores.setdefault(experiment_name, {})

    @property
    def name(self) -> str:
        """Fixed logger name."""
        return "MarkdownTableLogger"

    @property
    def version(self) -> str | int:
        """The model name doubles as the logger version."""
        return self.__model_name

    @property
    def store(self) -> dict[str, dict[str, float]]:
        """Return a deep copy of the shared store for the current experiment."""
        return copy.deepcopy(self.__shared_stores.get(self.__experiment_name, {}))

    @property
    def save_dir(self) -> str | Path:
        """Base directory passed at construction time."""
        return self.__save_dir

    @property
    def experiment_name(self) -> str:
        """Key under which this logger's metrics are grouped."""
        return self.__experiment_name

    def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:
        """Accumulate metrics for this model. Called by Lightning on every log step.

        Keeps only the latest value for each metric name. For example, if
        "val_auc" is logged at step 10 and step 20, only the step 20 value is kept.

        Args:
            metrics: Metric names mapped to their latest values.
            step: Ignored; only the most recent value per metric is stored.
        """
        # setdefault instead of indexing: clear() may have dropped the experiment
        # entry after this logger was constructed, which used to raise KeyError.
        experiment_store = self.__shared_stores.setdefault(self.__experiment_name, {})
        experiment_store.setdefault(self.__model_name, {}).update(metrics)

    def log_hyperparams(self, params: Any) -> None:
        """Hyperparameters are not part of the table; intentionally a no-op."""

    def finalize(self, status: str) -> None:
        """Write the markdown comparison table with all accumulated metrics so far.

        Called by Lightning after fit() and after test() for each model. Since models
        train/test sequentially, each finalize() overwrites the file with all data
        accumulated up to that point. The file grows more complete over time.

        Args:
            status: Not used by this logger.
        """
        test_results, train_results, val_results = self.__split_results()

        if not (test_results or train_results or val_results):
            return

        comparison_dir = Path(self.__save_dir) / "comparison"
        self.__save_comparison_tables(
            test_results=test_results,
            save_dir=comparison_dir,
            train_results=train_results or None,
            val_results=val_results or None,
            precision=self.__precision,
        )

    def __split_results(
        self,
    ) -> tuple[
        dict[str, dict[str, float]], dict[str, dict[str, float]], dict[str, dict[str, float]]
    ]:
        """Split all accumulated metrics into test, train and val groups.

        Metrics are classified by their name prefix:
        - "test*"  --> test_results
        - "train*" --> train_results
        - "val*" --> val_results
        - anything else (e.g., "epoch") --> ignored

        Returns:
            Tuple of (test_results, train_results, val_results), where each is a dict
            mapping model names to their respective metric dicts. Models with no metrics
            in a category are excluded from that category's dict.
        """
        store = self.__shared_stores.get(self.__experiment_name, {})
        test_results: dict[str, dict[str, float]] = {}
        train_results: dict[str, dict[str, float]] = {}
        val_results: dict[str, dict[str, float]] = {}

        for model_name, metrics in store.items():
            test_metrics: dict[str, float] = {}
            train_metrics: dict[str, float] = {}
            val_metrics: dict[str, float] = {}

            for metric_name, value in metrics.items():
                if metric_name.startswith("test"):
                    test_metrics[metric_name] = value
                elif metric_name.startswith("train"):
                    train_metrics[metric_name] = value
                elif metric_name.startswith("val"):
                    val_metrics[metric_name] = value

            if test_metrics:
                test_results[model_name] = test_metrics
            if train_metrics:
                train_results[model_name] = train_metrics
            if val_metrics:
                val_results[model_name] = val_metrics

        return test_results, train_results, val_results

    def clear(self, experiment_name: str) -> None:
        """Remove accumulated data for an experiment.

        Args:
            experiment_name: The experiment name whose data should be cleared.
                Unknown names are ignored.
        """
        self.__shared_stores.pop(experiment_name, None)

    @staticmethod
    def __escape_cell(text: str) -> str:
        """Escape pipes so a name cannot be read as a column separator."""
        return text.replace("|", "\\|")

    def __build_comparison_table(
        self,
        results: Mapping[str, Mapping[str, float]],
        precision: int = 4,
    ) -> str:
        """Build a markdown comparison table from model results.

        Examples:
            Input:

            ```python
            {
            "mlp:mean": {"test_auc": 0.85, "test_loss": 0.32},
            "gat:default": {"test_auc": 0.82},
            }
            ```

            Output:

            ```md
            | Model | test_auc | test_loss |
            | --- | --- | --- |
            | gat:default | 0.8200 | - |
            | mlp:mean | 0.8500 | 0.3200 |
            ```

        Args:
            results: Mapping of model names to metric dictionaries.
            precision: Number of decimal places for numeric metric values.

        Returns:
            Markdown table string. Returns an empty string if ``results`` is empty.
        """
        if not results:
            return ""

        # Collect all unique metric names across all models, sorted for determinism
        all_metrics = sorted(
            {metric for model_metrics in results.values() for metric in model_metrics}
        )

        # Build header row; "|" inside a name is escaped so the columns stay aligned.
        header_names = [self.__escape_cell(metric) for metric in all_metrics]
        header = "| Model | " + " | ".join(header_names) + " |"
        separator = "| --- | " + " | ".join("---" for _ in all_metrics) + " |"

        # Build one row per model, sorted by model name for determinism
        rows = []
        for model_name in sorted(results):
            model_metrics = results[model_name]
            cells = []
            for metric in all_metrics:
                value = model_metrics.get(metric)
                if isinstance(value, (int, float)):
                    cells.append(f"{value:.{precision}f}")
                else:
                    # Missing or non-numeric metric for this model.
                    cells.append("-")
            rows.append(f"| {self.__escape_cell(model_name)} | " + " | ".join(cells) + " |")

        return "\n".join([header, separator, *rows])

    def __save_comparison_tables(
        self,
        test_results: Mapping[str, Mapping[str, float]],
        save_dir: str | Path,
        train_results: Mapping[str, Mapping[str, float]] | None = None,
        val_results: Mapping[str, Mapping[str, float]] | None = None,
        filename: str = "results.md",
        precision: int = 4,
    ) -> Path:
        """Build and save markdown comparison tables to a file.

        Writes up to three sections, each only when it has data:
        - "## Test Results" with the test metrics table
        - "## Train Results" with the train metrics table
        - "## Val Results" with the validation metrics table

        If no section has data the file is still written, with empty content.

        Args:
            test_results: Dict mapping model names to test metric dicts.
            save_dir: Directory where the markdown file will be written; created
                (with parents) if missing.
            train_results: Optional dict mapping model names to train metric dicts.
            val_results: Optional dict mapping model names to val metric dicts.
            filename: Name of the output file.
            precision: Decimal places for metric values.

        Returns:
            Path to the written file.
        """
        titled_results = (
            ("Test Results", test_results),
            ("Train Results", train_results),
            ("Val Results", val_results),
        )
        sections = [
            f"## {title}\n\n{self.__build_comparison_table(results, precision)}"
            for title, results in titled_results
            if results
        ]

        content = "\n\n".join(sections) + "\n" if sections else ""

        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        file_path = save_path / filename
        # Explicit encoding keeps the output independent of the platform default.
        file_path.write_text(content, encoding="utf-8")

        return file_path

store property

Access the shared store for the current experiment.

log_metrics(metrics, step=None)

Accumulate metrics for this model. Called by Lightning on every log step.

Keeps only the latest value for each metric name. For example, if "val_auc" is logged at step 10 and step 20, only the step 20 value is kept.

Source code in hyperbench/train/markdown_logger.py
def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:
    """Accumulate metrics for this model. Called by Lightning on every log step.

    Keeps only the latest value for each metric name. For example, if
    "val_auc" is logged at step 10 and step 20, only the step 20 value is kept.
    """
    store = self.__shared_stores[self.__experiment_name]
    if self.__model_name not in store:
        store[self.__model_name] = {}
    store[self.__model_name].update(metrics)

finalize(status)

Write the markdown comparison table with all accumulated metrics so far.

Called by Lightning after fit() and after test() for each model. Since models train/test sequentially, each finalize() overwrites the file with all data accumulated up to that point. The file grows more complete over time.

Parameters:

Name Type Description Default
status str

The stage that just completed, e.g., "fit" or "test".

required
Source code in hyperbench/train/markdown_logger.py
def finalize(self, status: str) -> None:
    """Write the markdown comparison file for everything accumulated so far.

    Lightning calls this after fit() and after test() of every model. Because
    models run one after another, each call overwrites the file with the data
    gathered up to that point, so the file becomes more complete over time.
    Nothing is written when no test/train/val metric has been logged.

    Args:
        status: Not used by this logger.
    """
    test_results, train_results, val_results = self.__split_results()

    has_any_metric = any((test_results, train_results, val_results))
    if not has_any_metric:
        return

    # Empty train/val groups are passed on as None.
    self.__save_comparison_tables(
        test_results=test_results,
        save_dir=Path(self.__save_dir) / "comparison",
        train_results=train_results if train_results else None,
        val_results=val_results if val_results else None,
        precision=self.__precision,
    )

__split_results()

Split all accumulated metrics into test vs train/val groups.

Metrics are classified by their name prefix: "test*" → test_results; "train*" → train_results; "val*" → val_results; anything else (e.g., "epoch") → ignored.

Returns:

Type Description
dict[str, dict[str, float]]

Tuple of (test_results, train_results, val_results), where each is a dict

dict[str, dict[str, float]]

mapping model names to their respective metric dicts. Models with no metrics

dict[str, dict[str, float]]

in a category are excluded from that category's dict.

Source code in hyperbench/train/markdown_logger.py
def __split_results(
    self,
) -> tuple[
    dict[str, dict[str, float]], dict[str, dict[str, float]], dict[str, dict[str, float]]
]:
    """Partition this experiment's accumulated metrics into test, train, and val groups.

    Every metric is routed according to the prefix of its name:

    - ``test*``  --> test_results
    - ``train*`` --> train_results
    - ``val*``   --> val_results
    - anything else (e.g., ``epoch``) --> ignored

    Returns:
        Tuple of (test_results, train_results, val_results), where each is a dict
        mapping model names to their respective metric dicts. Models with no metrics
        in a category are excluded from that category's dict.
    """
    per_model_metrics = self.__shared_stores.get(self.__experiment_name, {})

    stage_prefixes = ("test", "train", "val")
    grouped: dict[str, dict[str, dict[str, float]]] = {prefix: {} for prefix in stage_prefixes}

    for model_name, metrics in per_model_metrics.items():
        for prefix in stage_prefixes:
            # The three prefixes are mutually exclusive, so a metric can match at most one
            selected = {
                metric_name: value
                for metric_name, value in metrics.items()
                if metric_name.startswith(prefix)
            }
            # Only register the model in a group when it contributes at least one metric
            if selected:
                grouped[prefix][model_name] = selected

    return grouped["test"], grouped["train"], grouped["val"]

clear(experiment_name)

Remove accumulated data for an experiment.

Parameters:

Name Type Description Default
experiment_name str

The experiment name whose data should be cleared.

required

Source code in hyperbench/train/markdown_logger.py
def clear(self, experiment_name: str) -> None:
    """Remove accumulated data for an experiment.

    Args:
        experiment_name: The experiment name whose data should be cleared.
            Names with no accumulated data are silently ignored.
    """
    # pop() with a default makes clearing an unknown experiment a no-op instead of a KeyError
    self.__shared_stores.pop(experiment_name, None)

__build_comparison_table(results, precision=4)

Build a markdown comparison table from model results.

Examples:

Input:

{
"mlp:mean": {"test_auc": 0.85, "test_loss": 0.32},
"gat:default": {"test_auc": 0.82},
}

Output:

| Model | test_auc | test_loss |
| --- | --- | --- |
| gat:default | 0.8200 | - |
| mlp:mean | 0.8500 | 0.3200 |

Parameters:

Name Type Description Default
results Mapping[str, Mapping[str, float]]

Mapping of model names to metric dictionaries.

required
precision int

Number of decimal places for numeric metric values.

4

Returns:

Type Description
str

Markdown table string. Returns an empty string if results is empty.

Source code in hyperbench/train/markdown_logger.py
def __build_comparison_table(
    self,
    results: Mapping[str, Mapping[str, float]],
    precision: int = 4,
) -> str:
    """Render model results as a markdown comparison table.

    The table has one column per metric (the union of the metric names of all
    models, sorted) and one row per model (sorted by model name). A metric that
    a model does not have, or whose value is not an int/float, is rendered as ``-``.

    Examples:
        Input:

        ```python
        {
        "mlp:mean": {"test_auc": 0.85, "test_loss": 0.32},
        "gat:default": {"test_auc": 0.82},
        }
        ```

        Output:

        ```md
        | Model | test_auc | test_loss |
        | --- | --- | --- |
        | gat:default | 0.8200 | - |
        | mlp:mean | 0.8500 | 0.3200 |
        ```

    Args:
        results: Mapping of model names to metric dictionaries.
        precision: Number of decimal places for numeric metric values.

    Returns:
        Markdown table string. Returns an empty string if ``results`` is empty.
    """
    if not results:
        return ""

    def as_row(label, cells):
        # Every line of the table shares the same "| first | rest ... |" layout
        return f"| {label} | " + " | ".join(cells) + " |"

    def as_cell(value):
        # Missing (None) and non-numeric values are shown as a dash
        if isinstance(value, (int, float)):
            return f"{value:.{precision}f}"
        return "-"

    # Union of the metric names of every model, sorted so the output is deterministic
    columns = sorted(set().union(*results.values()))

    lines = [
        as_row("Model", columns),
        as_row("---", ["---"] * len(columns)),
    ]
    # Models are emitted in sorted order as well, for the same reason
    lines.extend(
        as_row(model_name, [as_cell(results[model_name].get(column)) for column in columns])
        for model_name in sorted(results)
    )

    return "\n".join(lines)

__save_comparison_tables(test_results, save_dir, train_results=None, val_results=None, filename='results.md', precision=4)

Build and save markdown comparison tables to a file.

Writes up to three sections, each only when the corresponding results are non-empty: - "## Test Results" with the test metrics table - "## Train Results" with the train metrics table - "## Val Results" with the val metrics table

Parameters:

Name Type Description Default
test_results Mapping[str, Mapping[str, float]]

Dict from test_all(), mapping model names to test metric dicts.

required
save_dir str | Path

Directory where the markdown file will be written.

required
train_results Mapping[str, Mapping[str, float]] | None

Optional dict mapping model names to train metric dicts.

None
val_results Mapping[str, Mapping[str, float]] | None

Optional dict mapping model names to val metric dicts.

None
filename str

Name of the output file.

'results.md'
precision int

Decimal places for metric values.

4

Returns:

Type Description
Path

Path to the written file.

Source code in hyperbench/train/markdown_logger.py
def __save_comparison_tables(
    self,
    test_results: Mapping[str, Mapping[str, float]],
    save_dir: str | Path,
    train_results: Mapping[str, Mapping[str, float]] | None = None,
    val_results: Mapping[str, Mapping[str, float]] | None = None,
    filename: str = "results.md",
    precision: int = 4,
) -> Path:
    """Build and save markdown comparison tables to a file.

    Writes up to three sections, in this order, each only when the
    corresponding results are non-empty:

    - "## Test Results" with the test metrics table
    - "## Train Results" with the train metrics table
    - "## Val Results" with the val metrics table

    The target directory is created if needed and an existing file is overwritten.
    When every group is empty, an empty file is written.

    Args:
        test_results: Dict from test_all(), mapping model names to test metric dicts.
        save_dir: Directory where the markdown file will be written.
        train_results: Optional dict mapping model names to train metric dicts.
        val_results: Optional dict mapping model names to val metric dicts.
        filename: Name of the output file.
        precision: Decimal places for metric values.

    Returns:
        Path to the written file.
    """
    titled_results = (
        ("Test", test_results),
        ("Train", train_results),
        ("Val", val_results),
    )

    sections = []
    for title, results in titled_results:
        # Skip empty/None groups so the file never contains a heading without a table
        if not results:
            continue
        table = self.__build_comparison_table(results, precision)
        sections.append(f"## {title} Results\n\n{table}")

    # Sections are separated by a blank line; a non-empty file ends with a newline
    content = "\n\n".join(sections) + "\n" if sections else ""

    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    file_path = save_path / filename
    # Pin the encoding so the output does not depend on the platform's locale default
    file_path.write_text(content, encoding="utf-8")

    return file_path

SameNodeSpaceNegativeSampler

Bases: NegativeSampler, ABC

Base class for negative samplers that sample only from existing nodes.

Parameters:

Name Type Description Default
hyperedge_attr_enricher HyperedgeAttrsEnricher | None

An optional :class:HyperedgeAttrsEnricher to generate attributes for the new hyperedges.

None
hyperedge_weights_enricher HyperedgeWeightsEnricher | None

An optional :class:HyperedgeWeightsEnricher to generate weights for the new hyperedges.

None
return_0based_negatives bool
  • If True, the negative samples returned by the sample method will have 0-based node and hyperedge IDs.
  • If False, the negative samples will retain the original global node and hyperedge IDs from the input data.
False
Source code in hyperbench/train/negative_sampler.py
class SameNodeSpaceNegativeSampler(NegativeSampler, ABC):
    """
    Abstract base for negative samplers that draw nodes exclusively from the existing node space.

    This class only stores the optional enrichers; it does not invoke them itself.

    Args:
        hyperedge_attr_enricher: An optional :class:`HyperedgeAttrsEnricher` to generate attributes for the new hyperedges.
        hyperedge_weights_enricher: An optional :class:`HyperedgeWeightsEnricher` to generate weights for the new hyperedges.
        return_0based_negatives:
            - If ``True``, the negative samples returned by the ``sample`` method will have 0-based node and hyperedge IDs.
            - If ``False``, the negative samples will retain the original global node and hyperedge IDs from the input data.
    """

    def __init__(
        self,
        hyperedge_attr_enricher: HyperedgeAttrsEnricher | None = None,
        hyperedge_weights_enricher: HyperedgeWeightsEnricher | None = None,
        return_0based_negatives: bool = False,
    ):
        # The ID-rebasing flag is owned by the base sampler
        super().__init__(return_0based_negatives=return_0based_negatives)

        # Either enricher may be None
        self.hyperedge_weights_enricher: HyperedgeWeightsEnricher | None = (
            hyperedge_weights_enricher
        )
        self.hyperedge_attr_enricher: HyperedgeAttrsEnricher | None = hyperedge_attr_enricher

GeneratedNodesNegativeSampler

Bases: NegativeSampler, ABC

Base class for negative samplers that generate new nodes instead of sampling from existing ones.

Parameters:

Name Type Description Default
node_feature_enricher NodeEnricher

A :class:NodeEnricher to generate features for the new nodes.

required
hyperedge_attr_enricher HyperedgeAttrsEnricher | None

An optional :class:HyperedgeAttrsEnricher to generate attributes for the new hyperedges.

None
hyperedge_weights_enricher HyperedgeWeightsEnricher | None

An optional :class:HyperedgeWeightsEnricher to generate weights for the new hyperedges.

None
return_0based_negatives bool
  • If True, the negative samples returned by the sample method will have 0-based node and hyperedge IDs.
  • If False, the negative samples will retain the original global node and hyperedge IDs from the input data.
False
Source code in hyperbench/train/negative_sampler.py
class GeneratedNodesNegativeSampler(NegativeSampler, ABC):
    """
    Abstract base for negative samplers that create brand-new nodes rather than reusing existing ones.

    This class only stores the enrichers; it does not invoke them itself.

    Args:
        node_feature_enricher: A :class:`NodeEnricher` to generate features for the new nodes.
        hyperedge_attr_enricher: An optional :class:`HyperedgeAttrsEnricher` to generate attributes for the new hyperedges.
        hyperedge_weights_enricher: An optional :class:`HyperedgeWeightsEnricher` to generate weights for the new hyperedges.
        return_0based_negatives:
            - If ``True``, the negative samples returned by the ``sample`` method will have 0-based node and hyperedge IDs.
            - If ``False``, the negative samples will retain the original global node and hyperedge IDs from the input data.
    """

    def __init__(
        self,
        node_feature_enricher: NodeEnricher,
        hyperedge_attr_enricher: HyperedgeAttrsEnricher | None = None,
        hyperedge_weights_enricher: HyperedgeWeightsEnricher | None = None,
        return_0based_negatives: bool = False,
    ):
        # The ID-rebasing flag is owned by the base sampler
        super().__init__(return_0based_negatives=return_0based_negatives)

        # Hyperedge enrichers are optional and may be None
        self.hyperedge_weights_enricher: HyperedgeWeightsEnricher | None = (
            hyperedge_weights_enricher
        )
        self.hyperedge_attr_enricher: HyperedgeAttrsEnricher | None = hyperedge_attr_enricher

        # The node feature enricher is the only mandatory collaborator
        self.node_feature_enricher: NodeEnricher = node_feature_enricher

NegativeSampler

Bases: ABC

Abstract base class for negative samplers.

Parameters:

Name Type Description Default
return_0based_negatives bool
  • If True, the negative samples returned by the sample method will have 0-based node and hyperedge IDs.
  • If False, the negative samples will retain the original global node and hyperedge IDs from the input data.
False
Source code in hyperbench/train/negative_sampler.py
class NegativeSampler(ABC):
    """
    Abstract interface for negative samplers, plus helpers shared by concrete samplers.

    Args:
        return_0based_negatives:
            - If ``True``, the negative samples returned by the ``sample`` method will have 0-based node and hyperedge IDs.
            - If ``False``, the negative samples will retain the original global node and hyperedge IDs from the input data.
    """

    def __init__(self, return_0based_negatives: bool = False):
        super().__init__()
        self.return_0based_negatives: bool = return_0based_negatives

    @abstractmethod
    def sample(self, hdata: HData, seed: int | None = None) -> HData:
        """
        Produce negative samples for the given data. Must be overridden by subclasses.

        Args:
            hdata: The input data object containing graph or hypergraph information.
            seed: Optional random seed for reproducible negative sampling.

        Returns:
            The negative samples as a new :class:`HData` object.

        Raises:
            NotImplementedError: If the method is not implemented in a subclass.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def _new_negative_hyperedge_index(
        self,
        sampled_hyperedge_indexes: list[Tensor],
        negative_node_ids: Tensor,
        negative_hyperedge_ids: Tensor,
    ) -> Tensor:
        """
        Merge the per-sample hyperedge indexes into one tensor, optionally rebased to 0-based IDs.

        Args:
            sampled_hyperedge_indexes: List of hyperedge index tensors for each negative sample.
            negative_node_ids: Tensor of negative node IDs.
            negative_hyperedge_ids: Tensor of negative hyperedge IDs.

        Returns:
            The per-sample indexes concatenated along dimension 1.
            If ``self.return_0based_negatives`` is ``True``, the concatenated index is passed through
            ``HyperedgeIndex.to_0based`` with the given node and hyperedge IDs before being returned.
            Otherwise, it is returned unchanged, keeping the IDs it was built with.
        """
        merged_index = torch.cat(sampled_hyperedge_indexes, dim=1)

        if self.return_0based_negatives:
            rebased = HyperedgeIndex(merged_index).to_0based(
                node_ids_to_rebase=negative_node_ids,
                hyperedge_ids_to_rebase=negative_hyperedge_ids,
            )
            return rebased.item

        return merged_index

    def _new_global_node_ids(
        self,
        global_node_ids: Tensor | None,
        negative_node_ids: Tensor,
    ) -> Tensor | None:
        """
        Select the global node IDs that belong to the negative samples.

        Args:
            global_node_ids: The original global node IDs from the input data.
            negative_node_ids: Tensor of negative node IDs, used to index ``global_node_ids``.

        Returns:
            The global node IDs for the negative samples, or ``None`` if the input global node IDs are ``None``.
        """
        return None if global_node_ids is None else global_node_ids[negative_node_ids]

    def _new_hyperedge_attr(
        self,
        sampled_hyperedge_attrs: list[Tensor],
        hyperedge_attr: Tensor | None = None,
    ) -> Tensor | None:
        """
        Stack the per-sample hyperedge attributes into a single tensor.

        Args:
            sampled_hyperedge_attrs: List of hyperedge attribute tensors for each negative sample.
            hyperedge_attr: The original hyperedge attributes from the input data.
                Only checked for presence; its values are not read here.

        Returns:
            The attribute tensors stacked along a new leading dimension, or ``None`` if the
            input data has no hyperedge attributes or no attributes were sampled.
        """
        input_has_attrs = hyperedge_attr is not None
        if input_has_attrs and len(sampled_hyperedge_attrs) >= 1:
            return torch.stack(sampled_hyperedge_attrs, dim=0)
        return None

    def _new_enriched_hyperedge_attr(
        self,
        hyperedge_attr_enricher: HyperedgeAttrsEnricher | None,
        negative_hyperedge_index: Tensor,
    ) -> Tensor | None:
        """
        Compute hyperedge attributes for the negative samples with the given enricher.

        Args:
            hyperedge_attr_enricher: An optional :class:`HyperedgeAttrsEnricher` to generate attributes for the new hyperedges.
            negative_hyperedge_index: The index tensor for the negative hyperedges.

        Returns:
            The enriched hyperedge attribute tensor for the negative samples, or ``None`` if the enricher is not provided.
        """
        if hyperedge_attr_enricher is None:
            return None

        # Work on a clone so the caller's index tensor is never touched by the rebasing
        index_copy = negative_hyperedge_index.clone()
        rebased_index = HyperedgeIndex(index_copy).to_0based().item
        return hyperedge_attr_enricher.enrich(rebased_index)

    def _new_enriched_hyperedge_weights(
        self,
        hyperedge_weights_enricher: HyperedgeWeightsEnricher | None,
        negative_hyperedge_index: Tensor,
    ) -> Tensor | None:
        """
        Compute hyperedge weights for the negative samples with the given enricher.

        Args:
            hyperedge_weights_enricher: An optional :class:`HyperedgeWeightsEnricher` to generate weights for the new hyperedges.
            negative_hyperedge_index: The index tensor for the negative hyperedges.

        Returns:
            The enriched hyperedge weight tensor for the negative samples, or ``None`` if the enricher is not provided.
        """
        if hyperedge_weights_enricher is None:
            return None

        # Work on a clone so the caller's index tensor is never touched by the rebasing
        index_copy = negative_hyperedge_index.clone()
        rebased_index = HyperedgeIndex(index_copy).to_0based().item
        return hyperedge_weights_enricher.enrich(rebased_index)

    def _new_x(self, x: Tensor, negative_node_ids: Tensor) -> tuple[Tensor, int]:
        """
        Select the node features of the negative samples.

        Args:
            x: The original node feature matrix from the input data.
            negative_node_ids: Tensor of negative node IDs, used to index ``x``.

        Returns:
            The node feature matrix for the negative samples and the number of negative nodes.
        """
        negative_x = x[negative_node_ids]
        num_negative_nodes = len(negative_node_ids)
        return negative_x, num_negative_nodes

sample(hdata, seed=None) abstractmethod

Abstract method for negative sampling.

Parameters:

Name Type Description Default
hdata HData

The input data object containing graph or hypergraph information.

required
seed int | None

Optional random seed for reproducible negative sampling.

None

Returns:

Type Description
HData

The negative samples as a new :class:HData object.

Raises:

Type Description
NotImplementedError

If the method is not implemented in a subclass.

Source code in hyperbench/train/negative_sampler.py
@abstractmethod
def sample(self, hdata: HData, seed: int | None = None) -> HData:
    """
    Abstract method for negative sampling. Concrete samplers must override it.

    Args:
        hdata: The input data object containing graph or hypergraph information.
        seed: Optional random seed for reproducible negative sampling.

    Returns:
        The negative samples as a new :class:`HData` object.

    Raises:
        NotImplementedError: If the method is not implemented in a subclass.
    """
    # Only reachable when this base implementation is invoked explicitly (e.g. via super())
    raise NotImplementedError("Subclasses must implement this method.")

RandomNegativeSampler

Bases: SameNodeSpaceNegativeSampler

A random negative sampler. Negatives generated with return_0based_negatives = False aren't usable standalone as they have global node and hyperedge IDs. They must be concatenated with the original :class:HData object that is provided as input to the sample method, as it contains the global node and hyperedge IDs and features that can be indexed with the negative samples' IDs.

Parameters:

Name Type Description Default
num_negative_samples int

Number of negative hyperedges to generate.

required
num_nodes_per_sample int

Number of nodes per negative hyperedge.

required
hyperedge_attr_enricher HyperedgeAttrsEnricher | None

An optional :class:HyperedgeAttrsEnricher to generate attributes for the new hyperedges. If not provided, random attributes will be generated for the negative hyperedges if the input data has hyperedge attributes.

None
hyperedge_weights_enricher HyperedgeWeightsEnricher | None

An optional :class:HyperedgeWeightsEnricher to generate weights for the new hyperedges. If not provided, the negative hyperedges will not have weights.

None
return_0based_negatives bool
  • If True, the negative samples returned by the sample method will have 0-based node and hyperedge IDs.
  • If False, the negative samples will retain the original global node and hyperedge IDs from the input data.
False

Raises:

Type Description
ValueError

If either argument is not positive.

Source code in hyperbench/train/negative_sampler.py
class RandomNegativeSampler(SameNodeSpaceNegativeSampler):
    """
    A random negative sampler. Negatives generated with ``return_0based_negatives = False`` aren't usable standalone
    as they have global node and hyperedge IDs. They must be concatenated with the original :class:`HData` object
    that is provided as input to the ``sample`` method, as it contains the global node and hyperedge IDs and features
    that can be indexed with the negative samples' IDs.

    Args:
        num_negative_samples: Number of negative hyperedges to generate.
        num_nodes_per_sample: Number of nodes per negative hyperedge.
        hyperedge_attr_enricher: An optional :class:`HyperedgeAttrsEnricher` to generate attributes for the new hyperedges.
            If not provided, random attributes will be generated for the negative hyperedges if the input data has hyperedge attributes.
        hyperedge_weights_enricher: An optional :class:`HyperedgeWeightsEnricher` to generate weights for the new hyperedges.
            If not provided, the negative hyperedges will not have weights.
        return_0based_negatives:
            - If ``True``, the negative samples returned by the ``sample`` method will have 0-based node and hyperedge IDs.
            - If ``False``, the negative samples will retain the original global node and hyperedge IDs from the input data.

    Raises:
        ValueError: If ``num_negative_samples`` or ``num_nodes_per_sample`` is not positive.
    """

    def __init__(
        self,
        num_negative_samples: int,
        num_nodes_per_sample: int,
        hyperedge_attr_enricher: HyperedgeAttrsEnricher | None = None,
        hyperedge_weights_enricher: HyperedgeWeightsEnricher | None = None,
        return_0based_negatives: bool = False,
    ):
        if num_negative_samples <= 0:
            raise ValueError(f"num_negative_samples must be positive, got {num_negative_samples}.")
        if num_nodes_per_sample <= 0:
            raise ValueError(f"num_nodes_per_sample must be positive, got {num_nodes_per_sample}.")

        super().__init__(
            hyperedge_attr_enricher=hyperedge_attr_enricher,
            hyperedge_weights_enricher=hyperedge_weights_enricher,
            return_0based_negatives=return_0based_negatives,
        )
        self.num_negative_samples = num_negative_samples
        self.num_nodes_per_sample = num_nodes_per_sample

    def sample(self, hdata: HData, seed: int | None = None) -> HData:
        """
        Generate negative hyperedges by randomly sampling unique node IDs.
        Node IDs are sampled from the same node space as the input data, and the new negative hyperedge IDs
        start from the original number of hyperedges in the input data to avoid ID conflicts.
        The resulting negative samples are returned as a new :class:`HData` object with remapped 0-based node and hyperedge IDs, if ``self.return_0based_negatives == True``.
        Otherwise, the negative samples retain their original global node and hyperedge IDs from the input data.

        Examples:
            With ``self.return_0based_negatives = True``:

            >>> num_negative_samples = 2
            >>> num_nodes_per_sample = 3
            >>> negative_hyperedge_index = [[0, 0, 1, 2, 3, 4],
            ...                             [0, 1, 1, 0, 1, 0]]

            The negative hyperedge 0 connects nodes 0, 2, 4.
            The negative hyperedge 1 connects nodes 0, 1, 3.

            >>> negative_x = data.x[[0, 1, 2, 3, 4]]
            >>> negative_hyperedge_attr = random_attributes_for_2_negative_hyperedges

            With ``self.return_0based_negatives = False``:

            >>> num_negative_samples = 2
            >>> num_nodes_per_sample = 3
            >>> negative_hyperedge_index = [[100, 120, 300, 450, 500, 501],
            ...                             [3, 3, 3, 4, 4, 4]]

            Since node IDs are not remapped, the original feature matrix can be used directly.

            >>> negative_x = data.x

        Args:
            hdata: The input data object containing node and hyperedge information.
            seed: Optional random seed for reproducible negative sampling.

        Returns:
            A new :class:`HData` instance containing the negative samples.

        Raises:
            ValueError: If ``num_nodes_per_sample`` is greater than the number of available nodes.
        """
        if self.num_nodes_per_sample > hdata.num_nodes:
            raise ValueError(
                f"Asked to create samples with {self.num_nodes_per_sample} nodes, but only {hdata.num_nodes} nodes are available."
            )

        device = hdata.device

        (
            sampled_hyperedge_indexes,
            sampled_hyperedge_attrs,
            sampled_negative_node_ids,
            new_hyperedge_id_offset,
        ) = self.__sample_loop(hdata=hdata, device=device, seed=seed)

        # NOTE: the returned x is always restricted to the sampled nodes (in ascending ID order),
        # regardless of whether the hyperedge index is rebased to 0-based IDs or not
        negative_node_ids_tensor = torch.tensor(sorted(sampled_negative_node_ids), device=device)
        new_x, num_negative_nodes = self._new_x(hdata.x, negative_node_ids_tensor)

        # Example: new_hyperedge_id_offset = 3 (if hdata.num_hyperedges was 3)
        #          num_negative_samples = 2
        #          -> num_hyperedges_including_negatives = 5
        num_hyperedges_including_negatives = new_hyperedge_id_offset + self.num_negative_samples
        negative_hyperedge_ids = torch.arange(
            new_hyperedge_id_offset,
            num_hyperedges_including_negatives,
            device=device,
        )

        negative_hyperedge_index = self._new_negative_hyperedge_index(
            sampled_hyperedge_indexes,
            negative_node_ids_tensor,
            negative_hyperedge_ids,
        )

        negative_hyperedge_attr = self._new_enriched_hyperedge_attr(
            hyperedge_attr_enricher=self.hyperedge_attr_enricher,
            negative_hyperedge_index=negative_hyperedge_index,
        )
        # Default to the random attributes if no enricher is provided and the input data has hyperedge attributes
        if negative_hyperedge_attr is None:
            negative_hyperedge_attr = self._new_hyperedge_attr(
                sampled_hyperedge_attrs=sampled_hyperedge_attrs, hyperedge_attr=hdata.hyperedge_attr
            )

        return HData(
            x=new_x,
            hyperedge_index=negative_hyperedge_index,
            hyperedge_weights=self._new_enriched_hyperedge_weights(
                hyperedge_weights_enricher=self.hyperedge_weights_enricher,
                negative_hyperedge_index=negative_hyperedge_index,
            ),
            hyperedge_attr=negative_hyperedge_attr,
            num_nodes=num_negative_nodes,
            num_hyperedges=self.num_negative_samples,
            global_node_ids=self._new_global_node_ids(
                global_node_ids=hdata.global_node_ids, negative_node_ids=negative_node_ids_tensor
            ),
        ).with_y_zeros()

    def __sample_loop(
        self,
        hdata: HData,
        device: torch.device,
        seed: int | None = None,
    ) -> tuple[list[Tensor], list[Tensor], set[int], int]:
        """
        Draw ``self.num_negative_samples`` random hyperedges, one per iteration.

        Args:
            hdata: The input data object; its node count, hyperedge count and hyperedge attributes are read.
            device: Device on which the sampled tensors (and the seeded generator) are created.
            seed: Optional random seed. When given, a dedicated generator is seeded with it and used for all draws.

        Returns:
            A tuple of:
            - the list of per-sample hyperedge index tensors (row 0: node IDs, row 1: the new hyperedge ID),
            - the list of per-sample random hyperedge attributes (empty if the input data has no hyperedge attributes),
            - the set of all node IDs used by at least one negative hyperedge,
            - the ID offset of the first negative hyperedge (``hdata.num_hyperedges``).
        """
        generator = None
        if seed is not None:
            generator = torch.Generator(device=device)
            generator.manual_seed(seed)

        sampled_negative_node_ids: set[int] = set()
        sampled_hyperedge_indexes: list[Tensor] = []
        sampled_hyperedge_attrs: list[Tensor] = []

        # Assign each node id equal probability of being selected by setting all of them to 1.
        # The weights never change between samples, so they are built once outside the loop.
        equal_probabilities = torch.ones(hdata.num_nodes, device=device)

        new_hyperedge_id_offset = hdata.num_hyperedges
        for new_hyperedge_id in range(self.num_negative_samples):
            # Sample with multinomial without replacement to ensure unique node ids
            # Example: num_nodes_per_sample=3, max_node_id=5
            #          -> possible output: [2, 0, 4]
            sampled_node_ids = torch.multinomial(
                input=equal_probabilities,
                num_samples=self.num_nodes_per_sample,
                replacement=False,
                generator=generator,
            )

            # Example: sampled_node_ids = [2, 0, 4], new_hyperedge_id=0, new_hyperedge_id_offset=3
            #          -> hyperedge_index = [[2, 0, 4],
            #                                [3, 3, 3]]  # this is sampled_hyperedge_id_tensor
            sampled_hyperedge_id_tensor = torch.full(
                (self.num_nodes_per_sample,),
                new_hyperedge_id + new_hyperedge_id_offset,
                device=device,
            )
            sampled_hyperedge_index = torch.stack(
                [sampled_node_ids, sampled_hyperedge_id_tensor], dim=0
            )
            sampled_hyperedge_indexes.append(sampled_hyperedge_index)

            # Example: nodes = [0, 1, 2],
            #          sampled_node_ids_0 = [0, 1], sampled_node_ids_1 = [1, 2],
            #          -> sampled_negative_node_ids = {0, 1, 2}
            sampled_negative_node_ids.update(sampled_node_ids.tolist())

            if hdata.hyperedge_attr is not None:
                # shape[1:] is the per-hyperedge attribute shape; unlike indexing row 0,
                # it also works when the input data currently holds zero hyperedges
                random_hyperedge_attr = torch.randn(
                    hdata.hyperedge_attr.shape[1:],
                    dtype=hdata.hyperedge_attr.dtype,
                    device=device,
                    generator=generator,
                )
                sampled_hyperedge_attrs.append(random_hyperedge_attr)

        return (
            sampled_hyperedge_indexes,
            sampled_hyperedge_attrs,
            sampled_negative_node_ids,
            new_hyperedge_id_offset,
        )

sample(hdata, seed=None)

Generate negative hyperedges by randomly sampling unique node IDs. Node IDs are sampled from the same node space as the input data, and the new negative hyperedge IDs start from the original number of hyperedges in the input data to avoid ID conflicts. The resulting negative samples are returned as a new :class:HData object with remapped 0-based node and hyperedge IDs, if self.return_0based_negatives == True. Otherwise, the negative samples retain their original global node and hyperedge IDs from the input data.

Examples:

With self.return_0based_negatives = True:

>>> num_negative_samples = 2
>>> num_nodes_per_sample = 3
>>> negative_hyperedge_index = [[0, 0, 1, 2, 3, 4],
...                             [0, 1, 1, 0, 1, 0]]

The negative hyperedge 0 connects nodes 0, 2, 4. The negative hyperedge 1 connects nodes 0, 1, 3.

>>> negative_x = data.x[[0, 1, 2, 3, 4]]
>>> negative_hyperedge_attr = random_attributes_for_2_negative_hyperedges

With self.return_0based_negatives = False:

>>> num_negative_samples = 2
>>> num_nodes_per_sample = 3
>>> negative_hyperedge_index = [[100, 120, 300, 450, 500, 501],
...                             [3, 3, 3, 4, 4, 4]]

Since node IDs are not remapped, the original feature matrix can be used directly.

>>> negative_x = data.x

Parameters:

Name Type Description Default
hdata HData

The input data object containing node and hyperedge information.

required
seed int | None

Optional random seed for reproducible negative sampling.

None

Returns:

Type Description
HData

A new :class:HData instance containing the negative samples.

Raises:

Type Description
ValueError

If num_nodes_per_sample is greater than the number of available nodes.

Source code in hyperbench/train/negative_sampler.py
def sample(self, hdata: HData, seed: int | None = None) -> HData:
    """
    Build negative hyperedges out of randomly drawn, unique node IDs.

    The nodes are drawn from the node space of ``hdata``. Negative hyperedge IDs begin at the
    number of hyperedges already present in ``hdata`` so that they cannot collide with existing ones.
    When ``self.return_0based_negatives == True`` the returned :class:`HData` uses remapped,
    0-based node and hyperedge IDs; otherwise the original global node and hyperedge IDs are kept.

    Examples:
        With ``self.return_0based_negatives = True``:

        >>> num_negative_samples = 2
        >>> num_nodes_per_sample = 3
        >>> negative_hyperedge_index = [[0, 0, 1, 2, 3, 4],
        ...                             [0, 1, 1, 0, 1, 0]]

        Negative hyperedge 0 connects nodes 0, 2, 4.
        Negative hyperedge 1 connects nodes 0, 1, 3.

        >>> negative_x = data.x[[0, 1, 2, 3, 4]]
        >>> negative_hyperedge_attr = random_attributes_for_2_negative_hyperedges

        With ``self.return_0based_negatives = False``:

        >>> num_negative_samples = 2
        >>> num_nodes_per_sample = 3
        >>> negative_hyperedge_index = [[100, 120, 300, 450, 500, 501],
        ...                             [3, 3, 3, 4, 4, 4]]

        Node IDs are not remapped here, so the original feature matrix is usable as is.

        >>> negative_x = data.x

    Args:
        hdata: Input data providing the node and hyperedge information to sample against.
        seed: Optional random seed, forwarded to the sampling loop for reproducibility.

    Returns:
        A new :class:`HData` holding only the negative samples, with ``y`` set via ``with_y_zeros()``.

    Raises:
        ValueError: If ``num_nodes_per_sample`` exceeds the number of nodes available in ``hdata``.
    """
    if self.num_nodes_per_sample > hdata.num_nodes:
        raise ValueError(
            f"Asked to create samples with {self.num_nodes_per_sample} nodes, but only {hdata.num_nodes} nodes are available."
        )

    device = hdata.device

    loop_output = self.__sample_loop(hdata=hdata, device=device, seed=seed)
    per_sample_indexes, per_sample_attrs, node_id_pool, first_negative_id = loop_output

    # Sorted so that the node ID tensor has a deterministic order.
    node_ids = torch.tensor(sorted(node_id_pool), device=device)
    new_x, num_negative_nodes = self._new_x(hdata.x, node_ids)

    # Example: first_negative_id = 3 (if hdata.num_hyperedges was 3)
    #          num_negative_samples = 2
    #          -> end_negative_id = 5, i.e. negative hyperedge IDs are [3, 4]
    end_negative_id = first_negative_id + self.num_negative_samples
    negative_hyperedge_ids = torch.arange(first_negative_id, end_negative_id, device=device)

    negative_hyperedge_index = self._new_negative_hyperedge_index(
        per_sample_indexes,
        node_ids,
        negative_hyperedge_ids,
    )

    negative_hyperedge_attr = self._new_enriched_hyperedge_attr(
        hyperedge_attr_enricher=self.hyperedge_attr_enricher,
        negative_hyperedge_index=negative_hyperedge_index,
    )
    if negative_hyperedge_attr is None:
        # The enricher step produced nothing: fall back to the attributes gathered by the sampling loop.
        negative_hyperedge_attr = self._new_hyperedge_attr(
            sampled_hyperedge_attrs=per_sample_attrs, hyperedge_attr=hdata.hyperedge_attr
        )

    negative_hyperedge_weights = self._new_enriched_hyperedge_weights(
        hyperedge_weights_enricher=self.hyperedge_weights_enricher,
        negative_hyperedge_index=negative_hyperedge_index,
    )
    negative_global_node_ids = self._new_global_node_ids(
        global_node_ids=hdata.global_node_ids, negative_node_ids=node_ids
    )

    negatives = HData(
        x=new_x,
        hyperedge_index=negative_hyperedge_index,
        hyperedge_weights=negative_hyperedge_weights,
        hyperedge_attr=negative_hyperedge_attr,
        num_nodes=num_negative_nodes,
        num_hyperedges=self.num_negative_samples,
        global_node_ids=negative_global_node_ids,
    )
    return negatives.with_y_zeros()

NegativeSamplingSchedule

Bases: Enum

When to run negative sampling during training.

Source code in hyperbench/train/negative_sampling_scheduler.py
class NegativeSamplingSchedule(Enum):
    """
    When to run negative sampling during training.

    Each member's value is a lowercase string identifier of the schedule.
    """

    # Sample once at epoch 0 and reuse that result afterwards.
    FIRST_EPOCH = "first_epoch"  # Only at epoch 0, cached for all subsequent epochs
    # Resample periodically; the period N is not stored on the enum member.
    EVERY_N_EPOCHS = "every_n_epochs"  # Every N epochs (N provided separately)
    # Resample on every epoch, never relying on a cached result.
    EVERY_EPOCH = "every_epoch"  # Negatives generated every epoch

NegativeSamplingScheduler

Manages when to perform negative sampling during training based on a specified schedule. This class allows for flexible scheduling of negative sampling, enabling it to be performed at different frequencies (e.g., every epoch, every N epochs, or only at the first epoch). The scheduler maintains a cache of the most recently sampled negatives, which can be reused across epochs if the schedule does not require resampling. This helps to optimize training by avoiding unnecessary sampling when the schedule dictates that negatives should only be generated at certain intervals.

Parameters:

Name Type Description Default
negative_sampler NegativeSampler

An instance of a NegativeSampler that defines how to sample negatives.

required
negative_sampling_schedule NegativeSamplingSchedule

An instance of NegativeSamplingSchedule that specifies the schedule for sampling negatives.

EVERY_EPOCH
negative_sampling_every_n int

An integer specifying the interval for sampling negatives when the schedule is set to EVERY_N_EPOCHS. This parameter is ignored for other schedules.

1
Source code in hyperbench/train/negative_sampling_scheduler.py
class NegativeSamplingScheduler:
    """
    Manages when to perform negative sampling during training based on a specified schedule.
    This class allows for flexible scheduling of negative sampling, enabling it to be performed at different frequencies (e.g., every epoch, every N epochs, or only at the first epoch).
        The scheduler maintains a cache of the most recently sampled negatives, which can be reused across epochs if the schedule does not require resampling. This helps to optimize training
        by avoiding unnecessary sampling when the schedule dictates that negatives should only be generated at certain intervals.

    Args:
        negative_sampler: An instance of a ``NegativeSampler`` that defines how to sample negatives.
        negative_sampling_schedule: An instance of ``NegativeSamplingSchedule`` that specifies the schedule for sampling negatives.
        negative_sampling_every_n: An integer specifying the interval for sampling negatives when the schedule is set to ``EVERY_N_EPOCHS``. This parameter is ignored for other schedules.
    """

    def __init__(
        self,
        negative_sampler: NegativeSampler,
        negative_sampling_schedule: NegativeSamplingSchedule = NegativeSamplingSchedule.EVERY_EPOCH,
        negative_sampling_every_n: int = 1,
    ) -> None:
        self.negative_sampler = negative_sampler
        self.negative_sampling_schedule = negative_sampling_schedule
        self.negative_sampling_every_n = negative_sampling_every_n

        self.__cached_negative_samples: HData | None = None

    @property
    def config(self) -> dict[str, Any]:
        """Returns the configuration of the negative sampling scheduler as a dictionary."""
        return {
            "negative_sampler": self.negative_sampler,
            "negative_sampling_schedule": self.negative_sampling_schedule,
            "negative_sampling_every_n": self.negative_sampling_every_n,
        }

    def should_sample(self, epoch: int) -> bool:
        """
        Whether to resample negatives for the current epoch.

        Args:
            epoch: The current epoch number, used to determine if sampling should occur based on the schedule.

        Returns:
            True if negatives should be resampled for the current epoch, False otherwise.
        """
        match self.negative_sampling_schedule:
            case NegativeSamplingSchedule.EVERY_N_EPOCHS:
                return epoch % self.negative_sampling_every_n == 0
            case NegativeSamplingSchedule.FIRST_EPOCH:
                return epoch == 0
            case _:  # Defaults to NegativeSamplingSchedule.EVERY_EPOCH
                return True

    def sample(self, batch: HData, epoch: int) -> HData:
        """
        Sample fresh negatives if the schedule requires it, otherwise return cache.

        Args:
            batch: The current batch of data for which to sample negatives.
            epoch: The current epoch number, used to determine if sampling should occur based on the schedule.

        Returns:
            A batch of negative samples, either freshly sampled or from cache.
        """
        if self.should_sample(epoch):
            self.__cached_negative_samples = self.negative_sampler.sample(batch)

        if self.__cached_negative_samples is None:
            raise ValueError(
                "Asked to sample negatives but no scheduling happen, "
                f"check that the configuration is correct: {self.config}"
            )

        return self.__cached_negative_samples

config property

Returns the configuration of the negative sampling scheduler as a dictionary.

should_sample(epoch)

Whether to resample negatives for the current epoch.

Parameters:

Name Type Description Default
epoch int

The current epoch number, used to determine if sampling should occur based on the schedule.

required

Returns:

Type Description
bool

True if negatives should be resampled for the current epoch, False otherwise.

Source code in hyperbench/train/negative_sampling_scheduler.py
def should_sample(self, epoch: int) -> bool:
    """
    Whether to resample negatives for the current epoch.

    Args:
        epoch: The current epoch number, used to determine if sampling should occur based on the schedule.

    Returns:
        True if negatives should be resampled for the current epoch, False otherwise.
    """
    match self.negative_sampling_schedule:
        case NegativeSamplingSchedule.EVERY_N_EPOCHS:
            return epoch % self.negative_sampling_every_n == 0
        case NegativeSamplingSchedule.FIRST_EPOCH:
            return epoch == 0
        case _:  # Defaults to NegativeSamplingSchedule.EVERY_EPOCH
            return True

sample(batch, epoch)

Sample fresh negatives if the schedule requires it, otherwise return cache.

Parameters:

Name Type Description Default
batch HData

The current batch of data for which to sample negatives.

required
epoch int

The current epoch number, used to determine if sampling should occur based on the schedule.

required

Returns:

Type Description
HData

A batch of negative samples, either freshly sampled or from cache.

Source code in hyperbench/train/negative_sampling_scheduler.py
def sample(self, batch: HData, epoch: int) -> HData:
    """
    Return negatives for ``batch``, resampling only when the schedule asks for it.

    Args:
        batch: The batch of data handed to the negative sampler when resampling happens.
        epoch: The current epoch number, passed to ``should_sample`` to decide whether to resample.

    Returns:
        The negatives sampled during this call, or the cached ones from an earlier call.

    Raises:
        ValueError: If the schedule did not request sampling and no negatives have been cached yet.
    """
    if self.should_sample(epoch):
        self.__cached_negative_samples = self.negative_sampler.sample(batch)

    cached_negatives = self.__cached_negative_samples
    if cached_negatives is not None:
        return cached_negatives

    # Nothing was sampled now and nothing was cached before: the schedule skipped the first call.
    raise ValueError(
        "Asked to sample negatives but no scheduling happen, "
        f"check that the configuration is correct: {self.config}"
    )

MultiModelTrainer

A trainer class to handle training multiple models with individual trainers.

Parameters:

Name Type Description Default
model_configs list[ModelConfig]

A list of ModelConfig objects, each containing a model and its associated trainer (if any).

required
experiment_name str | None

Name for this experiment run's log directory. When None (default), auto-increments as experiment_0, experiment_1, etc. under the log root directory. Only used when logger is not provided.

None
accelerator str | Accelerator

Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto") as well as custom accelerator instances.

'auto'
devices list[int] | str | int

The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Defaults to "auto".

'auto'
strategy str | Strategy

Supports different training strategies with aliases as well as custom strategies. Defaults to "auto".

'auto'
num_nodes int

Number of GPU nodes for distributed training. Defaults to 1.

1
precision Any | None

Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). Can be used on CPU, GPU, TPUs, or HPUs. Defaults to '32-true'.

None
max_epochs int | None

Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1.

None
min_epochs int | None

Force training for at least these many epochs. Disabled by default (None).

None
max_steps int

Stop training after this number of steps. Disabled by default (-1). If max_steps = -1 and max_epochs = None, will default to max_epochs = 1000. To enable infinite training, set max_epochs to -1.

-1
min_steps int | None

Force training for at least this number of steps. Disabled by default (None).

None
check_val_every_n_epoch int | None

Perform a validation loop after every N training epochs. If None, validation will be done solely based on the number of training batches, requiring val_check_interval to be an integer value. When used together with a time-based val_check_interval and check_val_every_n_epoch > 1, validation is aligned to epoch multiples: if the interval elapses before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch) and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch. For None or 1 cases, the time-based behavior of val_check_interval applies without additional alignment. Defaults to 1.

1
logger Logger | Iterable[Logger] | bool | None

Logger (or iterable collection of loggers) for experiment tracking. A True value uses the default TensorBoardLogger if it is installed, otherwise CSVLogger. False will disable logging. If multiple loggers are provided, local files (checkpoints, profiler traces, etc.) are saved in the log_dir of the first logger. Defaults to True.

None
default_root_dir str | Path | None

Default path for logs and weights when no logger/ckpt_callback passed. Defaults to os.getcwd(). Can be remote file paths such as s3://mybucket/path or 'hdfs://path/'

None
enable_autolog_hparams bool

Whether to log hyperparameters at the start of a run. Defaults to True.

True
log_every_n_steps int | None

How often to log within steps. Defaults to 50.

None
profiler Profiler | str | None

To profile individual steps during training and assist in identifying bottlenecks. Defaults to None.

None
fast_dev_run int | bool

Runs n if set to n (int) else 1 if set to True batch(es) of train, val and test to find any bugs (ie: a sort of unit test). Defaults to False.

False
enable_checkpointing bool

If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in :paramref:~hyperbench.train.MultiModelTrainer.callbacks. Defaults to True.

True
enable_progress_bar bool

Whether to enable the progress bar by default. Defaults to True.

True
enable_model_summary bool | None

Whether to enable model summarization by default. Defaults to True.

None
callbacks list[Callback] | Callback | None

Add a callback or list of callbacks. Defaults to None.

None
auto_start_tensorboard bool

When True and tensorboard is installed, automatically starts a TensorBoard server pointing at the experiment log directory. Using this option requires that TensorBoard is installed in the environment and moves control of the TensorBoard server lifecycle to the trainer, which will automatically terminate the server when the trainer is finalized (e.g., at the end of a with block or when the object is garbage collected). Enable auto_wait to keep the server alive after training completes so you can inspect results before the trainer is finalized. Defaults to False.

False
tensorboard_port int

Port for the auto-launched TensorBoard server. Defaults to 6006.

6006
auto_wait bool

When True and a TensorBoard server is running, automatically calls :meth:wait inside finalize before terminating the server, so the user can inspect results before the process is stopped. Defaults to False.

False
Source code in hyperbench/train/trainer.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
class MultiModelTrainer:
    """
    A trainer class to handle training multiple models with individual trainers.

    Args:
        model_configs: A list of ModelConfig objects, each containing a model and its associated trainer (if any).

        experiment_name: Name for this experiment run's log directory. When ``None`` (default),
            auto-increments as ``experiment_0``, ``experiment_1``, etc. under the log root directory.
            Only used when ``logger`` is not provided.

        accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto")
            as well as custom accelerator instances.

        devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices
            (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
            automatic selection based on the chosen accelerator. Defaults to ``"auto"``.

        strategy: Supports different training strategies with aliases as well custom strategies.
            Defaults to ``"auto"``.

        num_nodes: Number of GPU nodes for distributed training.
            Defaults to ``1``.

        precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
            16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
            Can be used on CPU, GPU, TPUs, or HPUs.
            Defaults to ``'32-true'``.

        max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
            If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
            To enable infinite training, set ``max_epochs = -1``.

        min_epochs: Force training for at least these many epochs. Disabled by default (None).

        max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
            and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
            ``max_epochs`` to ``-1``.

        min_steps: Force training for at least these number of steps. Disabled by default (``None``).

        check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
            validation will be done solely based on the number of training batches, requiring ``val_check_interval``
            to be an integer value. When used together with a time-based ``val_check_interval`` and
            ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses
            before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch)
            and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch.
            For ``None`` or ``1`` cases, the time-based behavior of ``val_check_interval`` applies without
            additional alignment.
            Defaults to ``1``.

        logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
            the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
            ``False`` will disable logging. If multiple loggers are provided, local files
            (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
            Defaults to ``True``.

        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
            Defaults to ``os.getcwd()``.
            Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

        enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
            Defaults to ``True``.

        log_every_n_steps: How often to log within steps.
            Defaults to ``50``.

        profiler: To profile individual steps during training and assist in identifying bottlenecks.
            Defaults to ``None``.

        fast_dev_run: Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
            of train, val and test to find any bugs (ie: a sort of unit test).
            Defaults to ``False``.

        enable_checkpointing: If ``True``, enable checkpointing.
            It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
            :paramref:`~hyperbench.train.MultiModelTrainer.callbacks`.
            Defaults to ``True``.

        enable_progress_bar: Whether to enable the progress bar by default.
            Defaults to ``True``.

        enable_model_summary: Whether to enable model summarization by default.
            Defaults to ``True``.

        callbacks: Add a callback or list of callbacks.
            Defaults to ``None``.

        auto_start_tensorboard: When ``True`` and tensorboard is installed, automatically starts
            a TensorBoard server pointing at the experiment log directory.
            Using this option requires that TensorBoard is installed in the environment and moves control
            of the TensorBoard server lifecycle to the trainer, which will automatically terminate the server
            when the trainer is finalized (e.g., at the end of a `with` block or when the object is garbage collected).
            Enable `auto_wait` to keep the server alive after training completes so you can inspect results before the trainer is finalized.
            Defaults to ``False``.

        tensorboard_port: Port for the auto-launched TensorBoard server.
            Defaults to ``6006``.

        auto_wait: When ``True`` and a TensorBoard server is running, automatically calls
            :meth:`wait` inside `finalize` before terminating the server, so the user
            can inspect results before the process is stopped.
            Defaults to ``False``.
    """

    DEFAULT_BASE_LOG_DIR = "hyperbench_logs"
    EXPERIMENT_NAME_PREFIX = "experiment"
    VERSION_NAME_PREFIX = "version"

    __UNKNOWN_DEVICE = "unknown"

    def __init__(
        self,
        model_configs: list[ModelConfig],
        experiment_name: str | None = None,
        # args to pass to each Trainer
        accelerator: str | Accelerator = "auto",
        devices: list[int] | str | int = "auto",
        strategy: str | Strategy = "auto",
        num_nodes: int = 1,
        precision: Any | None = None,  # Any as Lightning accepts multiple types (int, str, Literal, etc.)
        max_epochs: int | None = None,
        min_epochs: int | None = None,
        max_steps: int = -1,
        min_steps: int | None = None,
        check_val_every_n_epoch: int | None = 1,
        logger: Logger | Iterable[Logger] | bool | None = None,
        default_root_dir: str | Path | None = None,
        enable_autolog_hparams: bool = True,
        log_every_n_steps: int | None = None,
        profiler: Profiler | str | None = None,
        fast_dev_run: int | bool = False,
        enable_checkpointing: bool = True,
        enable_progress_bar: bool = True,
        enable_model_summary: bool | None = None,
        callbacks: list[Callback] | Callback | None = None,
        auto_start_tensorboard: bool = False,
        tensorboard_port: int = 6006,
        auto_wait: bool = False,
        **kwargs,
    ) -> None:
        """
        Store the settings, give every model config without a trainer its own ``L.Trainer``,
        and start TensorBoard when ``auto_start_tensorboard`` is enabled.

        The meaning of each argument is described in the class docstring. Extra ``kwargs`` are
        forwarded unchanged to every ``L.Trainer`` created here.
        """
        self.model_configs = model_configs
        self.log_dir = self.__setup_logdir(default_root_dir, experiment_name)

        self.auto_start_tensorboard = auto_start_tensorboard
        self.auto_wait = auto_wait
        self.tensorboard_port = tensorboard_port
        self.__tensorboard_process: subprocess.Popen | None = None

        # Arguments that are identical for every trainer built below.
        shared_trainer_args: dict[str, Any] = {
            "accelerator": accelerator,
            "devices": devices,
            "strategy": strategy,
            "num_nodes": num_nodes,
            "precision": precision,
            "max_epochs": max_epochs,
            "min_epochs": min_epochs,
            "max_steps": max_steps,
            "min_steps": min_steps,
            "check_val_every_n_epoch": check_val_every_n_epoch,
            "default_root_dir": default_root_dir,
            "enable_autolog_hparams": enable_autolog_hparams,
            "log_every_n_steps": log_every_n_steps,
            "profiler": profiler,
            "fast_dev_run": fast_dev_run,
            "enable_checkpointing": enable_checkpointing,
            "enable_progress_bar": enable_progress_bar,
            "enable_model_summary": enable_model_summary,
        }

        for model_config in model_configs:
            # A trainer supplied by the caller is kept untouched.
            if model_config.trainer is not None:
                continue

            model_config.trainer = L.Trainer(
                **shared_trainer_args,
                logger=self.__setup_logger(model_config, logger),
                # Each trainer gets its own copy so callback objects are not shared across models.
                callbacks=copy.deepcopy(callbacks),
                **kwargs,
            )

        print(f"Initialized trainer(models: {len(model_configs)}, log_dir: {self.log_dir})")
        self.__auto_start_tensorboard_if_enabled()

    def __enter__(self) -> "MultiModelTrainer":
        """Enter the ``with`` block and return this trainer itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Call ``finalize()`` when the ``with`` block is left.

        The exception details are ignored and nothing is returned, so an exception raised inside
        the block is not suppressed.
        """
        self.finalize()

    def __del__(self) -> None:
        """
        Run ``finalize()`` when the object is destroyed.

        Any ``Exception`` raised by the cleanup is reported as a ``UserWarning`` instead of propagating.
        """
        try:
            self.finalize()
        except Exception as error:
            message = f"Exception occurred during {self.__class__.__name__} cleanup. Error: {error}"
            warnings.warn(message, category=UserWarning, stacklevel=2)

    @property
    def models(self) -> list[L.LightningModule]:
        """Return the model of every model config, in the order of ``self.model_configs``."""
        return [model_config.model for model_config in self.model_configs]

    def model(self, name: str, version: str = "default") -> L.LightningModule | None:
        """
        Look up a model by the name and version of its model config.

        Args:
            name: Value compared against each config's ``name``.
            version: Value compared against each config's ``version``.

        Returns:
            The model of the first config matching both ``name`` and ``version``, or ``None`` if there is no match.
        """
        matching_models = (
            candidate.model
            for candidate in self.model_configs
            if candidate.name == name and candidate.version == version
        )
        return next(matching_models, None)

    def fit_all(
        self,
        train_dataloader: DataLoader | None = None,
        val_dataloader: DataLoader | None = None,
        datamodule: L.LightningDataModule | None = None,
        ckpt_path: CkptStrategy | None = None,
        verbose: bool = True,
    ) -> None:
        """
        Fit every trainable model with its own trainer, one after the other.

        Args:
            train_dataloader: Training dataloader used for configs that do not define their own ``train_dataloader``.
            val_dataloader: Validation dataloader used for configs that do not define their own ``val_dataloader``.
            datamodule: Forwarded to each ``trainer.fit`` call.
            ckpt_path: Forwarded to each ``trainer.fit`` call.
            verbose: When ``True``, print one progress line per model (fitted or skipped).

        Raises:
            ValueError: If there are no model configs, or a trainable config has no trainer.
        """
        if not self.model_configs:
            raise ValueError("No models to fit.")

        total_models = len(self.model_configs)
        for position, model_config in enumerate(self.model_configs, start=1):
            progress = f"[{position}/{total_models} models]"

            if not model_config.is_trainable:
                if verbose:
                    print(
                        f"Skipping training for model {model_config.full_model_name()} {progress} (is_trainable=False)"
                    )
                continue

            trainer = model_config.trainer
            if trainer is None:
                raise ValueError(f"Trainer not defined for model {model_config.full_model_name()}.")

            if verbose:
                print(
                    f"Fit model {model_config.full_model_name()} {progress} (device: {self.__device(trainer)})"
                )

            # Dataloaders set on the config take precedence over the ones passed to this method.
            train_dataloaders = model_config.train_dataloader
            if train_dataloaders is None:
                train_dataloaders = train_dataloader

            val_dataloaders = model_config.val_dataloader
            if val_dataloaders is None:
                val_dataloaders = val_dataloader

            trainer.fit(
                model=model_config.model,
                train_dataloaders=train_dataloaders,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule,
                ckpt_path=ckpt_path,
            )

    def test_all(
        self,
        dataloader: DataLoader | None = None,
        datamodule: L.LightningDataModule | None = None,
        ckpt_path: CkptStrategy | None = None,
        verbose: bool = True,
        verbose_loop: bool = True,
    ) -> Mapping[str, TestResult]:
        """
        Test every model with its own trainer and collect the results per model.

        Args:
            dataloader: Test dataloader used for configs that do not define their own ``test_dataloader``.
            datamodule: Forwarded to each ``trainer.test`` call.
            ckpt_path: Forwarded to each ``trainer.test`` call.
            verbose: When ``True``, print one progress line per model.
            verbose_loop: Passed to ``trainer.test`` as its ``verbose`` argument.

        Returns:
            A mapping from each config's full model name to the first entry returned by
            ``trainer.test``, or to an empty dict when nothing was returned.

        Raises:
            ValueError: If there are no model configs, or a config has no trainer.
        """
        if not self.model_configs:
            raise ValueError("No models to test.")

        results_by_model: dict[str, TestResult] = {}
        total_models = len(self.model_configs)

        for position, model_config in enumerate(self.model_configs, start=1):
            trainer = model_config.trainer
            if trainer is None:
                raise ValueError(f"Trainer not defined for model {model_config.full_model_name()}.")

            if verbose:
                print(
                    f"Test model {model_config.full_model_name()} "
                    f"[{position}/{total_models} models] "
                    f"(device: {self.__device(trainer)})"
                )

            # A dataloader set on the config takes precedence over the one passed to this method.
            test_dataloaders = model_config.test_dataloader
            if test_dataloaders is None:
                test_dataloaders = dataloader

            per_dataloader_results: list[TestResult] = trainer.test(
                model=model_config.model,
                dataloaders=test_dataloaders,
                datamodule=datamodule,
                ckpt_path=ckpt_path,
                verbose=verbose_loop,
            )

            # In Lightning, test() returns a list of dicts, one per dataloader, but we use a single dataloader
            first_result = per_dataloader_results[0] if per_dataloader_results else {}
            results_by_model[model_config.full_model_name()] = first_result

        return results_by_model

    def finalize(self) -> None:
        """
        Optionally wait for the user, then stop the TensorBoard process if one is tracked.

        When ``auto_wait`` is truthy, ``wait()`` is called first. Afterwards, a tracked
        TensorBoard process is terminated and the reference to it is cleared.
        """
        if self.auto_wait:
            self.wait()

        # Nothing to shut down when no TensorBoard process is tracked.
        if self.__tensorboard_process is None:
            return

        self.__tensorboard_process.terminate()
        self.__tensorboard_process = None

    def wait(self) -> None:
        """
        Block until the user presses Enter so the TensorBoard process stays alive.

        Returns immediately when no TensorBoard process is tracked.
        """
        # Only TensorBoard is waited on for now; other processes or conditions
        # could be supported here later if needed.
        if self.__tensorboard_process is not None:
            print(f"TensorBoard is running at http://localhost:{self.tensorboard_port}")

            try:
                input("Press Enter to stop...")
            except (KeyboardInterrupt, EOFError):
                # Ctrl+C or a closed stdin counts as a request to stop.
                print("Stopping TensorBoard...")

    def __auto_start_tensorboard_if_enabled(self) -> None:
        if self.auto_start_tensorboard:
            if self.__is_tensorboard_available():
                self.__tensorboard_process = self.__start_tensorboard_process()
            else:
                warnings.warn(
                    "TensorBoard is not available. "
                    "Install it with `pip install hyperbench[tensorboard]` or `pip install tensorboard`"
                    "to enable auto-start.",
                    category=UserWarning,
                    stacklevel=2,
                )

    def __is_tensorboard_available(self) -> bool:
        return importlib.util.find_spec("tensorboard") is not None

    def __start_tensorboard_process(self) -> subprocess.Popen | None:
        try:
            tensorboard_executable = shutil.which("tensorboard")
            if tensorboard_executable is None:
                return None

            log_dir = str(self.log_dir)
            tensorboard_port = str(self.tensorboard_port)
            process = subprocess.Popen(
                [tensorboard_executable, "--logdir", log_dir, "--port", tensorboard_port],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            print(f"TensorBoard started at http://localhost:{tensorboard_port} (logdir={log_dir})")
            return process
        except Exception as e:
            warnings.warn(
                f"Proceeding without starting TensorBoard as it failed: {e}",
                category=UserWarning,
                stacklevel=2,
            )
            return None

    def __device(self, trainer: L.Trainer) -> str:
        """
        Describe the device a trainer runs on.

        Returns the string form of the root device of the trainer's strategy, or
        the unknown-device placeholder when the strategy or its root device is None.
        """
        strategy = trainer.strategy
        root_device = None if strategy is None else strategy.root_device
        if root_device is None:
            return self.__UNKNOWN_DEVICE
        return str(root_device)

    def __next_experiment_name(self, save_dir: Path) -> Path:
        if not save_dir.exists():
            return Path(f"{self.EXPERIMENT_NAME_PREFIX}_0")

        existing_experiment_names: list[str] = [
            dir.name
            for dir in save_dir.iterdir()
            if dir.is_dir() and dir.name.startswith(self.EXPERIMENT_NAME_PREFIX)
        ]
        if len(existing_experiment_names) < 1:
            return Path(f"{self.EXPERIMENT_NAME_PREFIX}_0")

        last_experiment_number = max(
            int(experiment_name.split("_")[1])
            for experiment_name in existing_experiment_names
            if experiment_name.split("_")[1].isdigit()
        )
        return Path(f"{self.EXPERIMENT_NAME_PREFIX}_{last_experiment_number + 1}")

    def __setup_logdir(
        self,
        default_root_dir: str | Path | None,
        experiment_name: str | None,
    ) -> Path:
        base_dir = (
            Path(self.DEFAULT_BASE_LOG_DIR) if default_root_dir is None else Path(default_root_dir)
        )
        next_experiment_name = (
            self.__next_experiment_name(base_dir)
            if experiment_name is None
            else Path(experiment_name)
        )
        return base_dir / next_experiment_name

    def __setup_logger(
        self,
        model_config: ModelConfig,
        logger: Logger | Iterable[Logger] | bool | None,
    ) -> Logger | Iterable[Logger] | bool | None:
        """
        Resolve the logger(s) to use for a model.

        Args:
            model_config: Config of the model the loggers are created for.
            logger: Caller-provided logger setting; any value other than None is
                returned unchanged.

        Returns:
            ``logger`` itself when it is not None. Otherwise a list of default
            loggers: a CSV logger, a Markdown table logger, a LaTex table logger
            and, when the ``tensorboard`` package is importable, a TensorBoard logger.
        """
        if logger is not None:
            return logger

        experiment_name = str(self.__next_experiment_name(self.log_dir))
        full_model_name = model_config.full_model_name()
        version_name = f"{self.VERSION_NAME_PREFIX}_{model_config.version}"

        default_loggers: list[Logger] = []
        default_loggers.append(
            CSVLogger(
                save_dir=self.log_dir,
                name=model_config.name,
                version=version_name,
            )
        )
        default_loggers.append(
            MarkdownTableLogger(
                save_dir=self.log_dir,
                model_name=full_model_name,
                experiment_name=experiment_name,
            )
        )
        default_loggers.append(
            LaTexTableLogger(
                save_dir=self.log_dir,
                model_name=full_model_name,
                experiment_name=experiment_name,
                options={
                    "table_caption": "Results for Experiments",
                    "sort_by": ["des", "asc"],
                    "border": False,
                },
            )
        )

        if self.__is_tensorboard_available():
            # Imported lazily so the module still loads without tensorboard installed.
            from lightning.pytorch.loggers import TensorBoardLogger

            default_loggers.append(
                TensorBoardLogger(
                    save_dir=self.log_dir,
                    name=model_config.name,
                    version=version_name,
                )
            )

        return default_loggers

wait()

Wait until the user presses Enter, keeping the process alive. If no process is running, this method does nothing.

Source code in hyperbench/train/trainer.py
def wait(self) -> None:
    """
    Block until the user presses Enter so the TensorBoard process stays alive.

    Returns immediately when no TensorBoard process is tracked.
    """
    # Only TensorBoard is waited on for now; other processes or conditions
    # could be supported here later if needed.
    if self.__tensorboard_process is not None:
        print(f"TensorBoard is running at http://localhost:{self.tensorboard_port}")

        try:
            input("Press Enter to stop...")
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C or a closed stdin counts as a request to stop.
            print("Stopping TensorBoard...")