Skip to content

HLP

hyperbench.hlp

CommonNeighborsHlpModule

Bases: HlpModule

A LightningModule for the CommonNeighbors model with optional negative sampling.

Parameters:

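Name Type Description Default
train_hyperedge_index Tensor

Hyperedge index of the training hyperedges, used to pre-compute the neighbors of the training nodes.

required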
aggregation Literal['mean', 'min', 'sum']

The aggregation method for common neighbors ("mean", "min", or "sum").

'mean'
decoder Module | None

An optional decoder module. Defaults to :class:CommonNeighbors.

None
loss_fn Module | None

An optional loss function. Defaults to BCEWithLogitsLoss.

None
metrics MetricCollection | None

An optional metric collection used for evaluation.

None
Source code in hyperbench/hlp/common_neighbors_hlp.py
class CommonNeighborsHlpModule(HlpModule):
    """
    A LightningModule for the CommonNeighbors model with optional negative sampling.

    The model is a non-trainable heuristic: node neighborhoods are pre-computed once from
    the training hyperedges and then reused to score every hyperedge given to ``forward``.

    Args:
        train_hyperedge_index: Hyperedge index of the training hyperedges. It is used at
            construction time to pre-compute the neighbors of the training nodes.
        aggregation: The aggregation method for common neighbors ("mean", "min", or "sum").
        decoder: An optional decoder module. Defaults to :class:`CommonNeighbors`.
        loss_fn: An optional loss function. Defaults to ``BCEWithLogitsLoss``.
        metrics: An optional metric collection used for evaluation.
    """

    def __init__(
        self,
        train_hyperedge_index: Tensor,
        aggregation: Literal["mean", "min", "sum"] = "mean",
        decoder: nn.Module | None = None,
        loss_fn: nn.Module | None = None,
        metrics: MetricCollection | None = None,
    ):
        super().__init__(
            decoder=decoder if decoder is not None else CommonNeighbors(aggregation),
            loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(),
            metrics=metrics,
        )

        # Pre-compute neighbors of training nodes based on training edges only
        # to create a "known world" for the model to make predictions from
        self.node_to_neighbors = Hypergraph.from_hyperedge_index(
            train_hyperedge_index
        ).neighbors_of_all()

        # Disable automatic optimization since there is no training
        self.automatic_optimization = False

    def forward(self, hyperedge_index: Tensor) -> Tensor:
        """
        Compute common neighbor scores for the given hyperedges.

        Args:
            hyperedge_index: Tensor containing incidence information for the hyperedges to score.

        Returns:
            The decoder output for ``hyperedge_index`` given the pre-computed neighborhoods.
        """
        return self.decoder(hyperedge_index, self.node_to_neighbors)

    def on_fit_start(self) -> None:
        """
        Warn users if they are running unnecessary training epochs.

        A warning is emitted for every ``max_epochs`` value other than ``0``: a positive
        number, ``None``, and ``-1`` (which Lightning treats as "train indefinitely").
        """
        # Only ``max_epochs == 0`` means "no training epochs". Comparing with ``> 0`` would
        # silently accept ``-1``, i.e. an unbounded run of a model that never learns.
        if self.trainer.max_epochs != 0:
            warnings.warn(
                f"{self.__class__.__name__} is a non-trainable heuristic model. "
                "No optimization occurs. Set max_epochs=0 in your trainer for instant evaluation.",
                UserWarning,
                stacklevel=2,
            )

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return a constant zero loss; the batch is ignored because nothing is trained."""
        return torch.tensor(0.0, device=self.device)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return a constant zero loss; the batch is ignored and no metrics are computed."""
        return torch.tensor(0.0, device=self.device)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Score the batch and compute the loss and metrics for the test stage."""
        return self.__step(batch, stage=Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the scores for the hyperedges in ``batch``."""
        return self.forward(batch.hyperedge_index)

    def configure_optimizers(self):
        """Return ``None``: the model has nothing to optimize."""
        # No training, so no optimizers needed
        return None

    def __step(self, batch: HData, stage: Stage) -> Tensor:
        """
        Shared evaluation logic for all stages.

        Args:
            batch: :class:`HData` object containing the hypergraph.
            stage: The current stage of evaluation (e.g., ``Stage.TRAIN``, ``Stage.VAL``, ``Stage.TEST``).

        Returns:
            The computed loss.
        """
        scores = self.forward(batch.hyperedge_index)
        labels = batch.y

        # We need to use the number of hyperedges as batch size for logging purposes,
        # since each hyperedge is a separate prediction
        batch_size = batch.num_hyperedges

        loss = self._compute_loss(scores, labels, batch_size, stage)
        self._compute_metrics(scores, labels, batch_size, stage)

        return loss

forward(hyperedge_index)

Compute common neighbor scores for the given hyperedges.

Parameters:

Name Type Description Default
hyperedge_index Tensor

Tensor containing incidence information for the hyperedges to score.

required
Source code in hyperbench/hlp/common_neighbors_hlp.py
def forward(self, hyperedge_index: Tensor) -> Tensor:
    """
    Score the given hyperedges with the common-neighbors decoder.

    Args:
        hyperedge_index: Tensor containing incidence information for the hyperedges to score.

    Returns:
        The decoder output for ``hyperedge_index`` given the stored ``node_to_neighbors``.
    """
    score_hyperedges = self.decoder
    known_neighbors = self.node_to_neighbors
    return score_hyperedges(hyperedge_index, known_neighbors)

on_fit_start()

Warn users if they are running unnecessary training epochs.

Source code in hyperbench/hlp/common_neighbors_hlp.py
def on_fit_start(self) -> None:
    """
    Warn users if they are running unnecessary training epochs.

    A warning is emitted for every ``max_epochs`` value other than ``0``: a positive
    number, ``None``, and ``-1`` (which Lightning treats as "train indefinitely").
    """
    # Only ``max_epochs == 0`` means "no training epochs". Comparing with ``> 0`` would
    # silently accept ``-1``, i.e. an unbounded run of a model that never learns.
    if self.trainer.max_epochs != 0:
        warnings.warn(
            f"{self.__class__.__name__} is a non-trainable heuristic model. "
            "No optimization occurs. Set max_epochs=0 in your trainer for instant evaluation.",
            UserWarning,
            stacklevel=2,
        )

__step(batch, stage)

Shared evaluation logic for all stages.

Parameters:

Name Type Description Default
batch HData

:class:HData object containing the hypergraph.

required
stage Stage

The current stage of evaluation (e.g., Stage.TRAIN, Stage.VAL, Stage.TEST).

required

Returns:

Type Description
Tensor

The computed loss.

Source code in hyperbench/hlp/common_neighbors_hlp.py
def __step(self, batch: HData, stage: Stage) -> Tensor:
    """
    Score a batch, then compute its loss and metrics for the given stage.

    Args:
        batch: :class:`HData` object containing the hypergraph.
        stage: The current stage of evaluation (e.g., ``Stage.TRAIN``, ``Stage.VAL``, ``Stage.TEST``).

    Returns:
        The computed loss.
    """
    predictions = self.forward(batch.hyperedge_index)
    targets = batch.y

    # Every hyperedge is one prediction, so the hyperedge count is what gets
    # reported as the batch size for logging purposes.
    num_predictions = batch.num_hyperedges

    step_loss = self._compute_loss(predictions, targets, num_predictions, stage)
    self._compute_metrics(predictions, targets, num_predictions, stage)
    return step_loss

GCNEncoderConfig

Bases: TypedDict

Configuration for the GCN encoder in GCNHlpModule.

Parameters:

Name Type Description Default
in_channels

Number of input features per node.

required
out_channels

Number of output features (embedding size) per node.

required
hidden_channels

Number of hidden units in the intermediate GCN layers.

required
num_layers

Number of GCN layers. Defaults to 2.

required
drop_rate

Dropout rate applied after each hidden GCN layer. Defaults to 0.0.

required
bias

Whether to include bias terms. Defaults to True.

required
improved

Whether to use the improved GCN normalization. Defaults to False.

required
add_self_loops

Whether to add self-loops before convolution. Defaults to True.

required
normalize

Whether to normalize the adjacency matrix in GCNConv. Defaults to True.

required
cached

Whether to cache the normalized graph in GCNConv. Defaults to False.

required
graph_reduction_strategy

Strategy for reducing the hypergraph to a graph. Defaults to "clique_expansion".

required
activation_fn

Activation function to use after each hidden layer. Defaults to nn.ReLU.

required
activation_fn_kwargs

Keyword arguments for the activation function. Defaults to empty dict.

required
Source code in hyperbench/hlp/gcn_hlp.py
class GCNEncoderConfig(TypedDict):
    """
    Configuration for the GCN encoder in GCNHlpModule.

    Only ``in_channels`` and ``out_channels`` are required keys. Every other key is
    declared ``NotRequired`` and may be omitted from the dictionary.

    Args:
        in_channels: Number of input features per node. Required.
        out_channels: Number of output features (embedding size) per node. Required.
        hidden_channels: Number of hidden units in the intermediate GCN layers. Optional;
            no default is declared here.
        num_layers: Number of GCN layers. Defaults to ``2``.
        drop_rate: Dropout rate applied after each hidden GCN layer. Defaults to ``0.0``.
        bias: Whether to include bias terms. Defaults to ``True``.
        improved: Whether to use the improved GCN normalization. Defaults to ``False``.
        add_self_loops: Whether to add self-loops before convolution. Defaults to ``True``.
        normalize: Whether to normalize the adjacency matrix in ``GCNConv``. Defaults to ``True``.
        cached: Whether to cache the normalized graph in ``GCNConv``. Defaults to ``False``.
        graph_reduction_strategy: Strategy for reducing the hypergraph to a graph. Defaults to ``"clique_expansion"``.
        activation_fn: Activation function to use after each hidden layer. Defaults to ``nn.ReLU``.
        activation_fn_kwargs: Keyword arguments for the activation function. Defaults to empty dict.
    """

    # Required keys
    in_channels: int
    out_channels: int
    # Optional keys (may be absent from the dictionary)
    hidden_channels: NotRequired[int]
    num_layers: NotRequired[int]
    drop_rate: NotRequired[float]
    bias: NotRequired[bool]
    improved: NotRequired[bool]
    add_self_loops: NotRequired[bool]
    normalize: NotRequired[bool]
    cached: NotRequired[bool]
    # "clique_expansion" is the only strategy the type currently allows
    graph_reduction_strategy: NotRequired[Literal["clique_expansion"]]
    activation_fn: NotRequired[ActivationFn]
    activation_fn_kwargs: NotRequired[dict]

GCNHlpModule

Bases: HlpModule

A LightningModule for GCN-based HLP.

Uses a graph reduction of the input hypergraph to run GCN over nodes, aggregates node embeddings per hyperedge, and scores each hyperedge with a linear decoder.

Parameters:

Name Type Description Default
encoder_config GCNEncoderConfig

Configuration for the GCN encoder.

required
aggregation Literal['mean', 'max', 'min', 'sum']

Method to aggregate node embeddings per hyperedge. Defaults to "mean".

'mean'
loss_fn Module | None

Loss function. Defaults to BCEWithLogitsLoss.

None
lr float

Learning rate for the optimizer. Defaults to 0.001.

0.001
weight_decay float

L2 regularization. Defaults to 0.0.

0.0
metrics MetricCollection | None

Optional metric collection for evaluation.

None
Source code in hyperbench/hlp/gcn_hlp.py
class GCNHlpModule(HlpModule):
    """
    LightningModule performing hyperedge link prediction with a GCN encoder.

    The input hypergraph is reduced to a graph, a GCN encodes the nodes over that graph,
    the node embeddings are pooled per hyperedge, and a linear decoder scores each hyperedge.

    Args:
        encoder_config: Configuration for the GCN encoder.
        aggregation: Method to aggregate node embeddings per hyperedge. Defaults to ``"mean"``.
        loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``.
        lr: Learning rate for the optimizer. Defaults to ``0.001``.
        weight_decay: L2 regularization. Defaults to ``0.0``.
        metrics: Optional metric collection for evaluation.
    """

    def __init__(
        self,
        encoder_config: GCNEncoderConfig,
        aggregation: Literal["mean", "max", "min", "sum"] = "mean",
        loss_fn: nn.Module | None = None,
        lr: float = 0.001,
        weight_decay: float = 0.0,
        metrics: MetricCollection | None = None,
    ):
        cfg = encoder_config

        # Keys missing from the config fall back to the values given to ``get``;
        # keys read without a fallback are forwarded as ``None``.
        gcn_kwargs = dict(
            in_channels=cfg["in_channels"],
            out_channels=cfg["out_channels"],
            hidden_channels=cfg.get("hidden_channels"),
            num_layers=cfg.get("num_layers", 2),
            drop_rate=cfg.get("drop_rate", 0.0),
            bias=cfg.get("bias", True),
            activation_fn=cfg.get("activation_fn"),
            activation_fn_kwargs=cfg.get("activation_fn_kwargs"),
            improved=cfg.get("improved", False),
            add_self_loops=cfg.get("add_self_loops", True),
            normalize=cfg.get("normalize", True),
            cached=cfg.get("cached", False),
        )
        node_encoder = GCN(**gcn_kwargs)

        # One logit per hyperedge embedding
        hyperedge_scorer = SLP(in_channels=cfg["out_channels"], out_channels=1)

        criterion = nn.BCEWithLogitsLoss() if loss_fn is None else loss_fn

        super().__init__(
            encoder=node_encoder,
            decoder=hyperedge_scorer,
            loss_fn=criterion,
            metrics=metrics,
        )

        self.encoder_config = cfg
        self.aggregation = aggregation
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
        """
        Score hyperedges: graph reduction, GCN encoding, per-hyperedge pooling, linear decoding.

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
            hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``.

        Returns:
            Logit scores of shape ``(num_hyperedges,)``.

        Raises:
            ValueError: If the module has no encoder.
        """
        gcn = self.encoder
        if gcn is None:
            raise ValueError("Encoder is not defined for this HLP module.")

        # Turn the hypergraph into a plain graph, then drop self-loops
        reducer = HyperedgeIndex(hyperedge_index)
        strategy = self.encoder_config.get("graph_reduction_strategy", "clique_expansion")
        graph_edges = reducer.reduce(strategy=strategy)
        edge_index = EdgeIndex(graph_edges).remove_selfloops().item

        # Node embeddings produced by the GCN over the reduced graph
        node_embeddings: Tensor = gcn(x, edge_index)

        # One pooled embedding per hyperedge
        pooler = HyperedgeAggregator(hyperedge_index, node_embeddings)
        hyperedge_embeddings = pooler.pool(self.aggregation)

        logits = self.decoder(hyperedge_embeddings)
        return logits.squeeze(-1)

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute the loss and metrics of ``batch`` for the training stage."""
        return self.__eval_step(batch, Stage.TRAIN)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute the loss and metrics of ``batch`` for the validation stage."""
        return self.__eval_step(batch, Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute the loss and metrics of ``batch`` for the test stage."""
        return self.__eval_step(batch, Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the scores for the hyperedges in ``batch``."""
        return self.forward(batch.x, batch.hyperedge_index)

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters using ``lr`` and ``weight_decay``."""
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """Score ``batch``, compute loss and metrics for ``stage``, and return the loss."""
        predictions = self.forward(batch.x, batch.hyperedge_index)
        targets = batch.y
        # The number of hyperedges is passed on as the batch size
        num_predictions = batch.num_hyperedges

        step_loss = self._compute_loss(predictions, targets, num_predictions, stage)
        self._compute_metrics(predictions, targets, num_predictions, stage)
        return step_loss

forward(x, hyperedge_index)

Reduce the hypergraph to a graph, encode nodes with GCN, aggregate per hyperedge, and score.

Parameters:

Name Type Description Default
x Tensor

Node feature matrix of shape (num_nodes, in_channels).

required
hyperedge_index Tensor

Hyperedge connectivity of shape (2, num_incidences).

required

Returns:

Type Description
Tensor

Logit scores of shape (num_hyperedges,).

Source code in hyperbench/hlp/gcn_hlp.py
def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
    """
    Score hyperedges: graph reduction, GCN encoding, per-hyperedge pooling, linear decoding.

    Args:
        x: Node feature matrix of shape ``(num_nodes, in_channels)``.
        hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``.

    Returns:
        Logit scores of shape ``(num_hyperedges,)``.

    Raises:
        ValueError: If the module has no encoder.
    """
    gcn = self.encoder
    if gcn is None:
        raise ValueError("Encoder is not defined for this HLP module.")

    # Turn the hypergraph into a plain graph, then drop self-loops
    reducer = HyperedgeIndex(hyperedge_index)
    strategy = self.encoder_config.get("graph_reduction_strategy", "clique_expansion")
    graph_edges = reducer.reduce(strategy=strategy)
    edge_index = EdgeIndex(graph_edges).remove_selfloops().item

    # Node embeddings produced by the GCN over the reduced graph
    node_embeddings: Tensor = gcn(x, edge_index)

    # One pooled embedding per hyperedge
    pooler = HyperedgeAggregator(hyperedge_index, node_embeddings)
    hyperedge_embeddings = pooler.pool(self.aggregation)

    logits = self.decoder(hyperedge_embeddings)
    return logits.squeeze(-1)

HGNNHlpModule

Bases: HlpModule

A LightningModule for HGNN-based Hyperedge Link Prediction.

Uses HGNN as an encoder to produce structure-aware node embeddings via spectral hypergraph convolution, aggregates them per hyperedge, and scores each hyperedge with a linear decoder.

Parameters:

Name Type Description Default
encoder_config HGNNEncoderConfig

Configuration for the HGNN encoder.

required
aggregation Literal['mean', 'max', 'min', 'sum']

Method to aggregate node embeddings per hyperedge. Defaults to "mean".

'mean'
loss_fn Module | None

Loss function. Defaults to BCEWithLogitsLoss.

None
lr float

Learning rate for the optimizer. Defaults to 0.001.

0.001
weight_decay float

L2 regularization. Defaults to 5e-4.

0.0005
metrics MetricCollection | None

Optional metric collection for evaluation.

None
Source code in hyperbench/hlp/hgnn_hlp.py
class HGNNHlpModule(HlpModule):
    """
    A LightningModule for HGNN-based Hyperedge Link Prediction.

    Uses HGNN as an encoder to produce structure-aware node embeddings via
    spectral hypergraph convolution, aggregates them per hyperedge,
    and scores each hyperedge with a linear decoder.

    Args:
        encoder_config: Configuration for the HGNN encoder.
        aggregation: Method to aggregate node embeddings per hyperedge. Defaults to ``"mean"``.
        loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``.
        lr: Learning rate for the optimizer. Defaults to ``0.001``.
        weight_decay: L2 regularization. Defaults to ``5e-4``.
        metrics: Optional metric collection for evaluation.
    """

    def __init__(
        self,
        encoder_config: HGNNEncoderConfig,
        aggregation: Literal["mean", "max", "min", "sum"] = "mean",
        loss_fn: nn.Module | None = None,
        lr: float = 0.001,
        weight_decay: float = 5e-4,
        metrics: MetricCollection | None = None,
    ):
        encoder = HGNN(
            in_channels=encoder_config["in_channels"],
            hidden_channels=encoder_config["hidden_channels"],
            # The encoder's output width is used as the node embedding size
            num_classes=encoder_config["out_channels"],
            bias=encoder_config.get("bias", True),
            use_batch_normalization=encoder_config.get("use_batch_normalization", False),
            drop_rate=encoder_config.get("drop_rate", 0.5),
        )
        decoder = SLP(in_channels=encoder_config["out_channels"], out_channels=1)

        super().__init__(
            encoder=encoder,
            decoder=decoder,
            loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(),
            metrics=metrics,
        )

        self.aggregation = aggregation
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
        """
        Run the full HGNN-based hyperedge link prediction pipeline.

        The pipeline has three stages:

        1. Encode: HGNN applies two rounds of ``D_n^{-1/2} H D_e^{-1} H^T D_n^{-1/2}``
           smoothing to propagate information through the hypergraph topology (nodes ->
           hyperedges -> nodes). The output is a structure-aware node embedding matrix of
           shape ``(num_nodes, out_channels)``.
        2. Aggregate: For each hyperedge being scored, pool the embeddings of its member
           nodes using the configured strategy (mean/max/min/sum). This produces a hyperedge
           embedding that summarizes the collective representation of the hyperedge's nodes.
           Shape: ``(num_hyperedges, out_channels)``.
        3. Decode: A single linear layer (SLP) projects each hyperedge embedding to a
           scalar score representing the likelihood that the hyperedge is a real (positive)
           hyperedge. Shape: ``(num_hyperedges,)``.

        Examples:
            Given 5 nodes with 8 features and 2 hyperedges::

                >>> x.shape  # (5, 8) - all nodes in the hypergraph
                >>> hyperedge_index = [[0, 1, 2, 3, 4],  # node IDs
                ...                    [0, 0, 0, 1, 1]]  # hyperedge IDs

            The forward pass:

            1. HGNN encodes all 5 nodes using the hypergraph Laplacian.
               ``node_embeddings.shape = (5, out_channels)``
            2. Aggregate per hyperedge, giving
               ``hyperedge_embeddings.shape = (2, out_channels)``:

                - hyperedge 0: pool(emb[0], emb[1], emb[2])
                - hyperedge 1: pool(emb[3], emb[4])

            3. Decode: one scalar per hyperedge -> ``scores.shape = (2,)``

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
                Must contain **all** nodes referenced in ``hyperedge_index``.
            hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``,
                with row 0 containing global node IDs and row 1 hyperedge IDs.

        Returns:
            Logit scores of shape ``(num_hyperedges,)``. Pass through sigmoid to get
                probabilities, or use directly with ``BCEWithLogitsLoss``.

        Raises:
            ValueError: If the module has no encoder.
        """
        if self.encoder is None:
            raise ValueError("Encoder is not defined for this HLP module.")

        # Encode: two-hop HGNN smoothing (nodes -> hyperedges -> nodes), no graph reduction
        # Example: x: (num_nodes, in_channels)
        #          -> node_embeddings: (num_nodes, out_channels)
        node_embeddings: Tensor = self.encoder(x, hyperedge_index)

        # Aggregate: pool node embeddings per hyperedge
        # shape: (num_hyperedges, out_channels)
        hyperedge_embeddings = HyperedgeAggregator(hyperedge_index, node_embeddings).pool(
            self.aggregation
        )

        # Decode: linear projection to scalar score per hyperedge
        # shape: (num_hyperedges, 1) -> squeeze -> (num_hyperedges,)
        scores: Tensor = self.decoder(hyperedge_embeddings).squeeze(-1)
        return scores

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute the loss and metrics of ``batch`` for the training stage."""
        return self.__eval_step(batch, Stage.TRAIN)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute the loss and metrics of ``batch`` for the validation stage."""
        return self.__eval_step(batch, Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute the loss and metrics of ``batch`` for the test stage."""
        return self.__eval_step(batch, Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the scores for the hyperedges in ``batch``."""
        return self.forward(batch.x, batch.hyperedge_index)

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters using ``lr`` and ``weight_decay``."""
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """Score ``batch``, compute loss and metrics for ``stage``, and return the loss."""
        scores = self.forward(batch.x, batch.hyperedge_index)
        labels = batch.y
        # The number of hyperedges is passed on as the batch size
        batch_size = batch.num_hyperedges

        loss = self._compute_loss(scores, labels, batch_size, stage)
        self._compute_metrics(scores, labels, batch_size, stage)
        return loss

forward(x, hyperedge_index)

Run the full HGNN-based hyperedge link prediction pipeline.

The pipeline has three stages:

1. Encode: HGNN applies two rounds of D_n^{-1/2} H D_e^{-1} H^T D_n^{-1/2} smoothing to propagate information through the hypergraph topology (nodes -> hyperedges -> nodes). The output is a structure-aware node embedding matrix of shape (num_nodes, out_channels).
2. Aggregate: For each hyperedge being scored, pool the embeddings of its member nodes using the configured strategy (mean/max/min/sum). This produces a hyperedge embedding that summarizes the collective representation of the hyperedge's nodes. Shape: (num_hyperedges, out_channels).
3. Decode: A single linear layer (SLP) projects each hyperedge embedding to a scalar score representing the likelihood that the hyperedge is a real (positive) hyperedge. Shape: (num_hyperedges,).

Examples:

Given 5 nodes with 8 features and 2 hyperedges::

>>> x.shape  # (5, 8) - all nodes in the hypergraph
>>> hyperedge_index = [[0, 1, 2, 3, 4],  # node IDs
...                    [0, 0, 0, 1, 1]]  # hyperedge IDs

The forward pass:

1. HGNN encodes all 5 nodes using the hypergraph Laplacian. node_embeddings.shape = (5, out_channels)
2. Aggregate per hyperedge, giving hyperedge_embeddings.shape = (2, out_channels):
    - hyperedge 0: pool(emb[0], emb[1], emb[2])
    - hyperedge 1: pool(emb[3], emb[4])
3. Decode: one scalar per hyperedge -> scores.shape = (2,)

Parameters:

Name Type Description Default
x Tensor

Node feature matrix of shape (num_nodes, in_channels). Must contain all nodes referenced in hyperedge_index.

required
hyperedge_index Tensor

Hyperedge connectivity of shape (2, num_incidences), with row 0 containing global node IDs and row 1 hyperedge IDs.

required

Returns:

Type Description
Tensor

Logit scores of shape (num_hyperedges,). Pass through sigmoid to get probabilities, or use directly with BCEWithLogitsLoss.

Source code in hyperbench/hlp/hgnn_hlp.py
def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
    """
    Run the full HGNN-based hyperedge link prediction pipeline.

    The pipeline has three stages:

    1. Encode: HGNN applies two rounds of ``D_n^{-1/2} H D_e^{-1} H^T D_n^{-1/2}``
       smoothing to propagate information through the hypergraph topology (nodes ->
       hyperedges -> nodes). The output is a structure-aware node embedding matrix of
       shape ``(num_nodes, out_channels)``.
    2. Aggregate: For each hyperedge being scored, pool the embeddings of its member
       nodes using the configured strategy (mean/max/min/sum). This produces a hyperedge
       embedding that summarizes the collective representation of the hyperedge's nodes.
       Shape: ``(num_hyperedges, out_channels)``.
    3. Decode: A single linear layer (SLP) projects each hyperedge embedding to a
       scalar score representing the likelihood that the hyperedge is a real (positive)
       hyperedge. Shape: ``(num_hyperedges,)``.

    Examples:
        Given 5 nodes with 8 features and 2 hyperedges::

            >>> x.shape  # (5, 8) - all nodes in the hypergraph
            >>> hyperedge_index = [[0, 1, 2, 3, 4],  # node IDs
            ...                    [0, 0, 0, 1, 1]]  # hyperedge IDs

        The forward pass:

        1. HGNN encodes all 5 nodes using the hypergraph Laplacian.
           ``node_embeddings.shape = (5, out_channels)``
        2. Aggregate per hyperedge, giving
           ``hyperedge_embeddings.shape = (2, out_channels)``:

            - hyperedge 0: pool(emb[0], emb[1], emb[2])
            - hyperedge 1: pool(emb[3], emb[4])

        3. Decode: one scalar per hyperedge -> ``scores.shape = (2,)``

    Args:
        x: Node feature matrix of shape ``(num_nodes, in_channels)``.
            Must contain **all** nodes referenced in ``hyperedge_index``.
        hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``,
            with row 0 containing global node IDs and row 1 hyperedge IDs.

    Returns:
        Logit scores of shape ``(num_hyperedges,)``. Pass through sigmoid to get
            probabilities, or use directly with ``BCEWithLogitsLoss``.

    Raises:
        ValueError: If the module has no encoder.
    """
    if self.encoder is None:
        raise ValueError("Encoder is not defined for this HLP module.")

    # Encode: two-hop HGNN smoothing (nodes -> hyperedges -> nodes), no graph reduction
    # Example: x: (num_nodes, in_channels)
    #          -> node_embeddings: (num_nodes, out_channels)
    node_embeddings: Tensor = self.encoder(x, hyperedge_index)

    # Aggregate: pool node embeddings per hyperedge
    # shape: (num_hyperedges, out_channels)
    hyperedge_embeddings = HyperedgeAggregator(hyperedge_index, node_embeddings).pool(
        self.aggregation
    )

    # Decode: linear projection to scalar score per hyperedge
    # shape: (num_hyperedges, 1) -> squeeze -> (num_hyperedges,)
    scores: Tensor = self.decoder(hyperedge_embeddings).squeeze(-1)
    return scores

HGNNEncoderConfig

Bases: TypedDict

Configuration for the HGNN encoder in HGNNHlpModule.

Parameters:

Name Type Description Default
in_channels

Number of input features per node.

required
hidden_channels

Number of hidden units in the intermediate HGNN layer.

required
out_channels

Number of output features (embedding size) per node.

required
bias

Whether to include bias terms. Defaults to True.

required
use_batch_normalization

Whether to use batch normalization. Defaults to False.

required
drop_rate

Dropout rate. Defaults to 0.5.

required
Source code in hyperbench/hlp/hgnn_hlp.py
class HGNNEncoderConfig(TypedDict):
    """
    Configuration for the HGNN encoder in HGNNHlpModule.

    ``in_channels``, ``hidden_channels`` and ``out_channels`` are required keys. The
    remaining keys are declared ``NotRequired`` and may be omitted from the dictionary.

    Args:
        in_channels: Number of input features per node. Required.
        hidden_channels: Number of hidden units in the intermediate HGNN layer. Required.
        out_channels: Number of output features (embedding size) per node. Required.
        bias: Whether to include bias terms. Defaults to ``True``.
        use_batch_normalization: Whether to use batch normalization. Defaults to ``False``.
        drop_rate: Dropout rate. Defaults to ``0.5``.
    """

    # Required keys
    in_channels: int
    hidden_channels: int
    out_channels: int
    # Optional keys (may be absent from the dictionary)
    bias: NotRequired[bool]
    use_batch_normalization: NotRequired[bool]
    drop_rate: NotRequired[float]

HNHNEncoderConfig

Bases: TypedDict

Configuration for the HNHN encoder in HNHNHlpModule.

Parameters:

Name Type Description Default
in_channels

Number of input features per node.

required
hidden_channels

Number of hidden units in the intermediate HNHN layer.

required
out_channels

Number of output features (embedding size) per node.

required
bias

Whether to include bias terms. Defaults to True.

required
use_batch_normalization

Whether to use batch normalization. Defaults to False.

required
drop_rate

Dropout rate. Defaults to 0.5.

required
Source code in hyperbench/hlp/hnhn_hlp.py
class HNHNEncoderConfig(TypedDict):
    """
    Configuration for the HNHN encoder in HNHNHlpModule.

    ``in_channels``, ``hidden_channels`` and ``out_channels`` are required keys. The
    remaining keys are declared ``NotRequired`` and may be omitted from the dictionary.

    Args:
        in_channels: Number of input features per node. Required.
        hidden_channels: Number of hidden units in the intermediate HNHN layer. Required.
        out_channels: Number of output features (embedding size) per node. Required.
        bias: Whether to include bias terms. Defaults to ``True``.
        use_batch_normalization: Whether to use batch normalization. Defaults to ``False``.
        drop_rate: Dropout rate. Defaults to ``0.5``.
    """

    # Required keys
    in_channels: int
    hidden_channels: int
    out_channels: int
    # Optional keys (may be absent from the dictionary)
    bias: NotRequired[bool]
    use_batch_normalization: NotRequired[bool]
    drop_rate: NotRequired[float]

HNHNHlpModule

Bases: HlpModule

A LightningModule for HNHN-based Hyperedge Link Prediction.

Uses HNHN as an encoder to produce node embeddings through explicit hyperedge neurons, aggregates them per hyperedge, and scores each hyperedge with a linear decoder.

Parameters:

Name Type Description Default
encoder_config HNHNEncoderConfig

Configuration for the HNHN encoder.

required
aggregation Literal['mean', 'max', 'min', 'sum']

Method to aggregate node embeddings per hyperedge. Defaults to "mean".

'mean'
loss_fn Module | None

Loss function. Defaults to BCEWithLogitsLoss.

None
lr float

Learning rate for the optimizer. Defaults to 0.01.

0.01
weight_decay float

L2 regularization. Defaults to 5e-4.

0.0005
scheduler_step_size int

Step size for learning rate scheduler. Defaults to 100.

100
scheduler_gamma float

Multiplicative factor for learning rate decay. Defaults to 0.51.

0.51
metrics MetricCollection | None

Optional metric collection for evaluation.

None
Source code in hyperbench/hlp/hnhn_hlp.py
class HNHNHlpModule(HlpModule):
    """
    Hyperedge Link Prediction LightningModule built around an HNHN encoder.

    Node embeddings come from HNHN (which uses explicit hyperedge neurons),
    are pooled per hyperedge, and each pooled embedding is turned into a
    single score by a linear decoder.

    Args:
        encoder_config: Configuration for the HNHN encoder.
        aggregation: Method to aggregate node embeddings per hyperedge. Defaults to ``"mean"``.
        loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``.
        lr: Learning rate for the optimizer. Defaults to ``0.01``.
        weight_decay: L2 regularization. Defaults to ``5e-4``.
        scheduler_step_size: Step size for learning rate scheduler. Defaults to ``100``.
        scheduler_gamma: Multiplicative factor for learning rate decay. Defaults to ``0.51``.
        metrics: Optional metric collection for evaluation.
    """

    def __init__(
        self,
        encoder_config: HNHNEncoderConfig,
        aggregation: Literal["mean", "max", "min", "sum"] = "mean",
        loss_fn: nn.Module | None = None,
        lr: float = 0.01,
        weight_decay: float = 5e-4,
        scheduler_step_size: int = 100,
        scheduler_gamma: float = 0.51,
        metrics: MetricCollection | None = None,
    ):
        # Optional config keys fall back to the defaults documented on the config type.
        hnhn_encoder = HNHN(
            in_channels=encoder_config["in_channels"],
            hidden_channels=encoder_config["hidden_channels"],
            num_classes=encoder_config["out_channels"],
            bias=encoder_config.get("bias", True),
            use_batch_normalization=encoder_config.get("use_batch_normalization", False),
            drop_rate=encoder_config.get("drop_rate", 0.5),
        )
        # The decoder consumes pooled embeddings, so its input width is the encoder output width.
        scorer = SLP(in_channels=encoder_config["out_channels"], out_channels=1)
        criterion = nn.BCEWithLogitsLoss() if loss_fn is None else loss_fn

        super().__init__(
            encoder=hnhn_encoder,
            decoder=scorer,
            loss_fn=criterion,
            metrics=metrics,
        )

        # Pooling strategy used in ``forward``.
        self.aggregation = aggregation

        # Optimizer / scheduler hyper-parameters read by ``configure_optimizers``.
        self.lr = lr
        self.weight_decay = weight_decay
        self.scheduler_step_size = scheduler_step_size
        self.scheduler_gamma = scheduler_gamma

    def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
        """
        Encode nodes with HNHN, pool them per hyperedge, and score each hyperedge.

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
            hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``.

        Returns:
            Logit scores of shape ``(num_hyperedges,)``.

        Raises:
            ValueError: If the module has no encoder.
        """
        encoder = self.encoder
        if encoder is None:
            raise ValueError("Encoder is not defined for this HLP module.")

        # Encode: one embedding per node.
        embeddings: Tensor = encoder(x, hyperedge_index)

        # Aggregate: pool member-node embeddings into one embedding per hyperedge.
        aggregator = HyperedgeAggregator(hyperedge_index, embeddings)
        pooled = aggregator.pool(self.aggregation)

        # Decode: drop the trailing singleton dimension left by the decoder.
        logits: Tensor = self.decoder(pooled).squeeze(-1)
        return logits

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run the shared step under ``Stage.TRAIN`` and return its loss."""
        return self.__eval_step(batch, stage=Stage.TRAIN)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run the shared step under ``Stage.VAL`` and return its loss."""
        return self.__eval_step(batch, stage=Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run the shared step under ``Stage.TEST`` and return its loss."""
        return self.__eval_step(batch, stage=Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the scores from ``forward`` for the batch; no loss or metrics are computed."""
        return self.forward(batch.x, batch.hyperedge_index)

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters together with a ``StepLR`` scheduler."""
        adam = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        step_lr = optim.lr_scheduler.StepLR(
            adam,
            step_size=self.scheduler_step_size,
            gamma=self.scheduler_gamma,
        )
        return [adam], [step_lr]

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """
        Shared body of the train/val/test steps.

        Scores the batch, then delegates loss and metric computation (and their
        logging) to the base-class helpers, using the number of hyperedges as
        the batch size.

        Args:
            batch: The batch to score; its ``x``, ``hyperedge_index``, ``y`` and
                ``num_hyperedges`` attributes are read.
            stage: The stage the step runs in.

        Returns:
            The loss returned by ``_compute_loss``.
        """
        logits = self.forward(batch.x, batch.hyperedge_index)
        targets = batch.y
        num_hyperedges = batch.num_hyperedges

        step_loss = self._compute_loss(logits, targets, num_hyperedges, stage)
        self._compute_metrics(logits, targets, num_hyperedges, stage)
        return step_loss

forward(x, hyperedge_index)

Run the full HNHN-based hyperedge link prediction pipeline.

Parameters:

Name Type Description Default
x Tensor

Node feature matrix of shape (num_nodes, in_channels).

required
hyperedge_index Tensor

Hyperedge connectivity of shape (2, num_incidences).

required

Returns:

Type Description
Tensor

Logit scores of shape (num_hyperedges,).

Source code in hyperbench/hlp/hnhn_hlp.py
def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
    """
    Encode nodes with HNHN, pool them per hyperedge, and score each hyperedge.

    Args:
        x: Node feature matrix of shape ``(num_nodes, in_channels)``.
        hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``.

    Returns:
        Logit scores of shape ``(num_hyperedges,)``.

    Raises:
        ValueError: If the module has no encoder.
    """
    encoder = self.encoder
    if encoder is None:
        raise ValueError("Encoder is not defined for this HLP module.")

    # Encode: one embedding per node.
    embeddings: Tensor = encoder(x, hyperedge_index)

    # Aggregate: pool member-node embeddings into one embedding per hyperedge.
    aggregator = HyperedgeAggregator(hyperedge_index, embeddings)
    pooled = aggregator.pool(self.aggregation)

    # Decode: drop the trailing singleton dimension left by the decoder.
    logits: Tensor = self.decoder(pooled).squeeze(-1)
    return logits

HGNNPEncoderConfig

Bases: TypedDict

Configuration for the HGNN+ encoder in HGNNPHlpModule.

Parameters:

Name Type Description Default
in_channels

Number of input features per node.

required
hidden_channels

Number of hidden units in the intermediate HGNN+ layer.

required
out_channels

Number of output features (embedding size) per node.

required
bias

Whether to include bias terms. Defaults to True.

True
use_batch_normalization

Whether to use batch normalization. Defaults to False.

False
drop_rate

Dropout rate. Defaults to 0.5.

0.5
Source code in hyperbench/hlp/hgnnp_hlp.py
class HGNNPEncoderConfig(TypedDict):
    """
    Configuration for the HGNN+ encoder in HGNNPHlpModule.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Number of hidden units in the intermediate HGNN+ layer.
        out_channels: Number of output features (embedding size) per node.
        bias: Whether to include bias terms. Defaults to ``True``.
        use_batch_normalization: Whether to use batch normalization. Defaults to ``False``.
        drop_rate: Dropout rate. Defaults to ``0.5``.
    """

    in_channels: int
    hidden_channels: int
    out_channels: int
    bias: NotRequired[bool]
    use_batch_normalization: NotRequired[bool]
    drop_rate: NotRequired[float]

HGNNPHlpModule

Bases: HlpModule

A LightningModule for HGNN+-based Hyperedge Link Prediction.

Uses HGNN+ as an encoder to produce structure-aware node embeddings via row-stochastic hypergraph convolution, aggregates them per hyperedge, and scores each hyperedge with a linear decoder.

Parameters:

Name Type Description Default
encoder_config HGNNPEncoderConfig

Configuration for the HGNN+ encoder.

required
aggregation Literal['mean', 'max', 'min', 'sum']

Method to aggregate node embeddings per hyperedge. Defaults to "mean".

'mean'
loss_fn Module | None

Loss function. Defaults to BCEWithLogitsLoss.

None
lr float

Learning rate for the optimizer. Defaults to 0.01.

0.01
weight_decay float

L2 regularization. Defaults to 5e-4.

0.0005
metrics MetricCollection | None

Optional metric collection for evaluation.

None
Source code in hyperbench/hlp/hgnnp_hlp.py
class HGNNPHlpModule(HlpModule):
    """
    Hyperedge Link Prediction LightningModule built around an HGNN+ encoder.

    HGNN+ produces structure-aware node embeddings via row-stochastic
    hypergraph convolution; the embeddings are pooled per hyperedge and each
    pooled embedding is turned into a single score by a linear decoder.

    Args:
        encoder_config: Configuration for the HGNN+ encoder.
        aggregation: Method to aggregate node embeddings per hyperedge. Defaults to ``"mean"``.
        loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``.
        lr: Learning rate for the optimizer. Defaults to ``0.01``.
        weight_decay: L2 regularization. Defaults to ``5e-4``.
        metrics: Optional metric collection for evaluation.
    """

    def __init__(
        self,
        encoder_config: HGNNPEncoderConfig,
        aggregation: Literal["mean", "max", "min", "sum"] = "mean",
        loss_fn: nn.Module | None = None,
        lr: float = 0.01,
        weight_decay: float = 5e-4,
        metrics: MetricCollection | None = None,
    ):
        # Optional config keys fall back to the defaults documented on the config type.
        hgnnp_encoder = HGNNP(
            in_channels=encoder_config["in_channels"],
            hidden_channels=encoder_config["hidden_channels"],
            num_classes=encoder_config["out_channels"],
            bias=encoder_config.get("bias", True),
            use_batch_normalization=encoder_config.get("use_batch_normalization", False),
            drop_rate=encoder_config.get("drop_rate", 0.5),
        )
        # The decoder consumes pooled embeddings, so its input width is the encoder output width.
        scorer = SLP(in_channels=encoder_config["out_channels"], out_channels=1)
        criterion = nn.BCEWithLogitsLoss() if loss_fn is None else loss_fn

        super().__init__(
            encoder=hgnnp_encoder,
            decoder=scorer,
            loss_fn=criterion,
            metrics=metrics,
        )

        # Pooling strategy used in ``forward``.
        self.aggregation = aggregation

        # Optimizer hyper-parameters read by ``configure_optimizers``.
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
        """
        Score hyperedges with the HGNN+ encoder, a pooling step, and a linear decoder.

        Stages:
            1. Encode: HGNN+ applies two rounds of ``D_v^{-1} H D_e^{-1} H^T``
               smoothing (two-stage mean aggregation) to propagate information
               through the hypergraph topology, yielding a structure-aware node
               embedding matrix of shape ``(num_nodes, out_channels)``.
            2. Aggregate: the embeddings of each scored hyperedge's member nodes
               are pooled with the configured strategy (mean/max/min/sum),
               giving hyperedge embeddings of shape ``(num_hyperedges, out_channels)``.
            3. Decode: a single linear layer maps each hyperedge embedding to a
               scalar score, giving shape ``(num_hyperedges,)``.

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
                Must contain **all** nodes referenced in ``hyperedge_index``.
            hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``,
                with row 0 containing global node IDs and row 1 hyperedge IDs.

        Returns:
            Logit scores of shape ``(num_hyperedges,)``.

        Raises:
            ValueError: If the module has no encoder.
        """
        encoder = self.encoder
        if encoder is None:
            raise ValueError("Encoder is not defined for this HLP module.")

        # Encode with HGNN+; no graph reduction is applied.
        # x: (num_nodes, in_channels) -> embeddings: (num_nodes, out_channels)
        embeddings: Tensor = encoder(x, hyperedge_index)

        # Aggregate member-node embeddings per hyperedge.
        # pooled: (num_hyperedges, out_channels)
        aggregator = HyperedgeAggregator(hyperedge_index, embeddings)
        pooled = aggregator.pool(self.aggregation)

        # Decode to one scalar per hyperedge.
        # (num_hyperedges, 1) -> squeeze -> (num_hyperedges,)
        logits: Tensor = self.decoder(pooled).squeeze(-1)
        return logits

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run the shared step under ``Stage.TRAIN`` and return its loss."""
        return self.__eval_step(batch, stage=Stage.TRAIN)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run the shared step under ``Stage.VAL`` and return its loss."""
        return self.__eval_step(batch, stage=Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run the shared step under ``Stage.TEST`` and return its loss."""
        return self.__eval_step(batch, stage=Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the scores from ``forward`` for the batch; no loss or metrics are computed."""
        return self.forward(batch.x, batch.hyperedge_index)

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters using the configured ``lr`` and ``weight_decay``."""
        adam = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return adam

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """
        Shared body of the train/val/test steps.

        Scores the batch, then delegates loss and metric computation (and their
        logging) to the base-class helpers, using the number of hyperedges as
        the batch size.

        Args:
            batch: The batch to score; its ``x``, ``hyperedge_index``, ``y`` and
                ``num_hyperedges`` attributes are read.
            stage: The stage the step runs in.

        Returns:
            The loss returned by ``_compute_loss``.
        """
        logits = self.forward(batch.x, batch.hyperedge_index)
        targets = batch.y
        num_hyperedges = batch.num_hyperedges

        step_loss = self._compute_loss(logits, targets, num_hyperedges, stage)
        self._compute_metrics(logits, targets, num_hyperedges, stage)
        return step_loss

forward(x, hyperedge_index)

Run the full HGNN+-based hyperedge link prediction pipeline.

The pipeline has three stages:

1. Encode: HGNN+ applies two rounds of D_v^{-1} H D_e^{-1} H^T smoothing to propagate information through the hypergraph topology with two-stage mean aggregation. The output is a structure-aware node embedding matrix of shape (num_nodes, out_channels).
2. Aggregate: For each hyperedge being scored, pool the embeddings of its member nodes using the configured strategy (mean/max/min/sum). This produces a hyperedge embedding of shape (num_hyperedges, out_channels).
3. Decode: A single linear layer projects each hyperedge embedding to a scalar score. Shape: (num_hyperedges,).

Parameters:

Name Type Description Default
x Tensor

Node feature matrix of shape (num_nodes, in_channels). Must contain all nodes referenced in hyperedge_index.

required
hyperedge_index Tensor

Hyperedge connectivity of shape (2, num_incidences), with row 0 containing global node IDs and row 1 hyperedge IDs.

required

Returns:

Type Description
Tensor

Logit scores of shape (num_hyperedges,).

Source code in hyperbench/hlp/hgnnp_hlp.py
def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
    """
    Score hyperedges with the HGNN+ encoder, a pooling step, and a linear decoder.

    Stages:
        1. Encode: HGNN+ applies two rounds of ``D_v^{-1} H D_e^{-1} H^T``
           smoothing (two-stage mean aggregation) to propagate information
           through the hypergraph topology, yielding a structure-aware node
           embedding matrix of shape ``(num_nodes, out_channels)``.
        2. Aggregate: the embeddings of each scored hyperedge's member nodes
           are pooled with the configured strategy (mean/max/min/sum),
           giving hyperedge embeddings of shape ``(num_hyperedges, out_channels)``.
        3. Decode: a single linear layer maps each hyperedge embedding to a
           scalar score, giving shape ``(num_hyperedges,)``.

    Args:
        x: Node feature matrix of shape ``(num_nodes, in_channels)``.
            Must contain **all** nodes referenced in ``hyperedge_index``.
        hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``,
            with row 0 containing global node IDs and row 1 hyperedge IDs.

    Returns:
        Logit scores of shape ``(num_hyperedges,)``.

    Raises:
        ValueError: If the module has no encoder.
    """
    encoder = self.encoder
    if encoder is None:
        raise ValueError("Encoder is not defined for this HLP module.")

    # Encode with HGNN+; no graph reduction is applied.
    # x: (num_nodes, in_channels) -> embeddings: (num_nodes, out_channels)
    embeddings: Tensor = encoder(x, hyperedge_index)

    # Aggregate member-node embeddings per hyperedge.
    # pooled: (num_hyperedges, out_channels)
    aggregator = HyperedgeAggregator(hyperedge_index, embeddings)
    pooled = aggregator.pool(self.aggregation)

    # Decode to one scalar per hyperedge.
    # (num_hyperedges, 1) -> squeeze -> (num_hyperedges,)
    logits: Tensor = self.decoder(pooled).squeeze(-1)
    return logits

HlpModule

Bases: LightningModule

A LightningModule for HLP models with optional negative sampling.

Parameters:

Name Type Description Default
encoder Module | None

Optional encoder module. Defaults to None, as not all HLP models use an encoder.

None
decoder Module

Decoder module to use to predict whether hyperedges are positive or negative.

required
loss_fn Module

Loss function.

required
metrics MetricCollection | None

Optional MetricCollection of torchmetrics to compute during evaluation. Cloned per stage (train, val, test) for independent state accumulation.

None
negative_sampler NegativeSampler | None

Optional negative sampler. If None, no negative sampling is performed.

None
negative_sampling_schedule NegativeSamplingSchedule

When to perform negative sampling during training. Defaults to EVERY_EPOCH.

EVERY_EPOCH
negative_sampling_every_n int

If using EVERY_N_EPOCHS schedule, how many epochs between negative sampling runs. Defaults to 1.

1
Source code in hyperbench/hlp/common.py
class HlpModule(L.LightningModule):
    """
    A LightningModule for HLP models with optional negative sampling.

    Args:
        encoder: Optional encoder module. Defaults to ``None`` as not all HLP model will use an encoder.
        decoder: Decoder module to use to predict whether hyperedges are positive or negative.
        loss_fn: Loss function.
        metrics: Optional ``MetricCollection`` of torchmetrics to compute during evaluation.
            Cloned per stage (train, val, test) for independent state accumulation.
        negative_sampler: Optional negative sampler. If ``None``, no negative sampling is performed.
        negative_sampling_schedule: When to perform negative sampling during training. Defaults to ``EVERY_EPOCH``.
        negative_sampling_every_n: If using ``EVERY_N_EPOCHS`` schedule, how many epochs between negative sampling runs. Defaults to ``1``.
    """

    def __init__(
        self,
        decoder: nn.Module,
        loss_fn: nn.Module,
        encoder: nn.Module | None = None,
        metrics: MetricCollection | None = None,
        negative_sampler: NegativeSampler | None = None,
        negative_sampling_schedule: NegativeSamplingSchedule = NegativeSamplingSchedule.EVERY_EPOCH,
        negative_sampling_every_n: int = 1,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.loss_fn = loss_fn

        if metrics is not None:
            self.train_metrics = metrics.clone(prefix=f"{Stage.TRAIN.value}_")
            self.val_metrics = metrics.clone(prefix=f"{Stage.VAL.value}_")
            self.test_metrics = metrics.clone(prefix=f"{Stage.TEST.value}_")
        else:
            self.train_metrics = None
            self.val_metrics = None
            self.test_metrics = None

        self.__negative_sampling_scheduler = None
        if negative_sampler is not None:
            self.__negative_sampling_scheduler = NegativeSamplingScheduler(
                negative_sampler,
                negative_sampling_schedule,
                negative_sampling_every_n,
            )

    @property
    def negative_sampling_config(self) -> dict[str, Any]:
        if self.__negative_sampling_scheduler is None:
            return {}
        return self.__negative_sampling_scheduler.config

    def _compute_loss(
        self,
        scores: Tensor,
        labels: Tensor,
        batch_size: int,
        stage: Stage,
    ) -> Tensor:
        """
        Compute and log loss based on scores and labels.

        Args:
            scores: The predicted scores from the model.
            labels: The true labels corresponding to the scores.
            batch_size: The size of the current batch, used for logging.
            stage: The current stage (train/val/test) for logging purposes.

        Returns:
            The computed loss tensor.
        """
        loss = self.loss_fn(scores, labels)
        self.log(name=f"{stage.value}_loss", value=loss, prog_bar=True, batch_size=batch_size)
        return loss

    def _compute_metrics(
        self,
        scores: Tensor,
        labels: Tensor,
        batch_size: int,
        stage: Stage,
    ) -> None:
        """
        Compute and log metrics based on scores and labels.

        Uses class-based torchmetrics with proper multi-batch accumulation:
        1. ``update()`` accumulates predictions/targets across batches.
        2. Passing the MetricCollection to ``self.log_dict()`` tells Lightning to call ``compute()`` at epoch end and ``reset()`` automatically.

        Args:
            scores: The predicted scores (logits) from the model.
            labels: The true labels corresponding to the scores.
            batch_size: The size of the current batch, used for logging.
            stage: The current stage (train/val/test) for logging purposes.
        """
        stage_metrics = self._get_stage_metrics(stage)
        if stage_metrics is None:
            return  # No metrics to compute

        # Apply sigmoid to convert logits to probabilities as BinaryAUROC
        # and BinaryAveragePrecision expect probabilities in [0, 1]
        preds = torch.sigmoid(scores)
        targets = labels.long()

        # Accumulate predictions/targets for this batch
        stage_metrics.update(preds, targets)

        self.log_dict(
            stage_metrics,
            prog_bar=True,
            on_step=False,
            on_epoch=True,  # Compute and log metrics at epoch end, not per step, for proper accumulation
            batch_size=batch_size,
        )

    def _get_stage_metrics(self, stage: Stage) -> MetricCollection | None:
        """
        Return the metric collection for the given stage, or ``None``.

        Args:
            stage: The current stage (train/val/test) for which to get metrics.

        Returns:
            The metric collection corresponding to the given stage, or ``None`` if no metrics are configured.
        """
        match stage:
            case Stage.TRAIN:
                return self.train_metrics
            case Stage.VAL:
                return self.val_metrics
            case Stage.TEST:
                return self.test_metrics
            case _:
                raise ValueError(f"Unrecognized stage: {stage}")

    def _should_sample_negatives(self) -> bool:
        """Whether to resample negatives for the current epoch."""
        if self.__negative_sampling_scheduler is None:
            raise ValueError(
                "Asked to check negative sampling schedule but no negative sampler is configured."
            )
        return self.__negative_sampling_scheduler.should_sample(self.current_epoch)

    def _sample_negatives(self, batch: HData) -> HData:
        """
        Sample fresh negatives if the schedule requires it, otherwise return cache.

        Args:
            batch: The current batch of data for which to sample negatives.

        Returns:
            A batch of negative samples, either freshly sampled or from cache.
        """
        if self.__negative_sampling_scheduler is None:
            raise ValueError("Asked to sample negatives but no negative sampler is not configured.")
        return self.__negative_sampling_scheduler.sample(batch, self.current_epoch)

HyperGCNHlpModule

Bases: HlpModule

A LightningModule for HyperGCN-based Hyperedge Link Prediction.

Uses HyperGCN as an encoder to produce structure-aware node embeddings via graph convolution on the full hypergraph, aggregates them per hyperedge, and scores each hyperedge with a linear decoder.

Parameters:

Name Type Description Default
encoder_config HyperGCNEncoderConfig

Configuration for the HyperGCN encoder.

required
aggregation Literal['mean', 'max', 'min', 'sum']

Method to aggregate node embeddings per hyperedge. Defaults to "mean".

'mean'
loss_fn Module | None

Loss function. Defaults to BCEWithLogitsLoss.

None
lr float

Learning rate for the optimizer. Defaults to 0.01.

0.01
weight_decay float

L2 regularization. Defaults to 5e-4.

0.0005
metrics MetricCollection | None

Optional metric collection for evaluation.

None
Source code in hyperbench/hlp/hypergcn_hlp.py
class HyperGCNHlpModule(HlpModule):
    """
    Hyperedge Link Prediction LightningModule built around a HyperGCN encoder.

    HyperGCN produces structure-aware node embeddings via graph convolution on
    the full hypergraph; the embeddings are pooled per hyperedge and each
    pooled embedding is turned into a single score by a linear decoder.

    Args:
        encoder_config: Configuration for the HyperGCN encoder.
        aggregation: Method to aggregate node embeddings per hyperedge. Defaults to ``"mean"``.
        loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``.
        lr: Learning rate for the optimizer. Defaults to ``0.01``.
        weight_decay: L2 regularization. Defaults to ``5e-4``.
        metrics: Optional metric collection for evaluation.
    """

    def __init__(
        self,
        encoder_config: HyperGCNEncoderConfig,
        aggregation: Literal["mean", "max", "min", "sum"] = "mean",
        loss_fn: nn.Module | None = None,
        lr: float = 0.01,
        weight_decay: float = 5e-4,
        metrics: MetricCollection | None = None,
    ):
        # Keys read with ``.get`` are optional and fall back to the defaults given here.
        hypergcn_encoder = HyperGCN(
            in_channels=encoder_config["in_channels"],
            hidden_channels=encoder_config["hidden_channels"],
            num_classes=encoder_config["out_channels"],
            bias=encoder_config.get("bias", True),
            use_batch_normalization=encoder_config.get("use_batch_normalization", False),
            drop_rate=encoder_config.get("drop_rate", 0.5),
            use_mediator=encoder_config.get("use_mediator", False),
            fast=encoder_config.get("fast", True),
        )
        # The decoder consumes pooled embeddings, so its input width is the encoder output width.
        scorer = SLP(in_channels=encoder_config["out_channels"], out_channels=1)
        criterion = nn.BCEWithLogitsLoss() if loss_fn is None else loss_fn

        super().__init__(
            encoder=hypergcn_encoder,
            decoder=scorer,
            loss_fn=criterion,
            metrics=metrics,
        )

        # Pooling strategy used in ``forward``.
        self.aggregation = aggregation

        # Optimizer hyper-parameters read by ``configure_optimizers``.
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
        """
        Score hyperedges with the HyperGCN encoder, a pooling step, and a linear decoder.

        Steps:
            1. Encode: HyperGCN builds a GCN Laplacian from ``hyperedge_index``
               and applies message passing to produce structure-aware node embeddings.
            2. Aggregate: For each hyperedge, aggregate its member nodes' embeddings
               using the configured pooling method (mean/max/min/sum).
            3. Decode: A linear layer scores each hyperedge embedding.

        Examples:
            Given 5 nodes with 3 features and 2 hyperedges::

                >>> x.shape  # (5, 3) — all nodes in the hypergraph
                >>> hyperedge_index = [[0, 1, 2, 3, 4],  # node IDs (global)
                ...                    [0, 0, 0, 1, 1]]  # hyperedge IDs

            The forward pass:
                1. HyperGCN encodes all 5 nodes using the full graph Laplacian.
                   ``node_embeddings.shape = (5, out_channels)``
                2. Aggregate per hyperedge:
                   - hyperedge 0: pool(emb[0], emb[1], emb[2])
                   - hyperedge 1: pool(emb[3], emb[4])
                3. Decode: one scalar score per hyperedge → ``scores.shape = (2,)``

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
                Must contain **all** nodes in the hypergraph.
            hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``
                with **global** node IDs.

        Returns:
            Scores of shape ``(num_hyperedges,)``.

        Raises:
            ValueError: If the module has no encoder.
        """
        encoder = self.encoder
        if encoder is None:
            raise ValueError("Encoder is not defined for this HLP module.")

        # Encode with HyperGCN's Laplacian-based message passing.
        # x: (num_nodes, in_channels) -> embeddings: (num_nodes, out_channels)
        embeddings: Tensor = encoder(x, hyperedge_index)

        # Aggregate member-node embeddings per hyperedge.
        # pooled: (num_hyperedges, out_channels)
        aggregator = HyperedgeAggregator(hyperedge_index, embeddings)
        pooled = aggregator.pool(self.aggregation)

        # Decode to one scalar per hyperedge.
        # (num_hyperedges, 1) -> squeeze -> (num_hyperedges,)
        logits: Tensor = self.decoder(pooled).squeeze(-1)
        return logits

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run the shared step under ``Stage.TRAIN`` and return its loss."""
        return self.__eval_step(batch, stage=Stage.TRAIN)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run the shared step under ``Stage.VAL`` and return its loss."""
        return self.__eval_step(batch, stage=Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run the shared step under ``Stage.TEST`` and return its loss."""
        return self.__eval_step(batch, stage=Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the scores from ``forward`` for the batch; no loss or metrics are computed."""
        return self.forward(batch.x, batch.hyperedge_index)

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters using the configured ``lr`` and ``weight_decay``."""
        adam = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return adam

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """
        Shared body of the train/val/test steps.

        Scores the batch, then delegates loss and metric computation (and their
        logging) to the base-class helpers, using the number of hyperedges as
        the batch size.

        Args:
            batch: The batch to score; its ``x``, ``hyperedge_index``, ``y`` and
                ``num_hyperedges`` attributes are read.
            stage: The stage the step runs in.

        Returns:
            The loss returned by ``_compute_loss``.
        """
        logits = self.forward(batch.x, batch.hyperedge_index)
        targets = batch.y
        num_hyperedges = batch.num_hyperedges

        step_loss = self._compute_loss(logits, targets, num_hyperedges, stage)
        self._compute_metrics(logits, targets, num_hyperedges, stage)
        return step_loss

forward(x, hyperedge_index)

Encode node features via HyperGCN, aggregate per hyperedge, and score.

Steps
  1. Encode: HyperGCN builds a GCN Laplacian from hyperedge_index and applies message passing to produce structure-aware node embeddings.
  2. Aggregate: For each hyperedge, aggregate its member nodes' embeddings using the configured pooling method (mean/max/min/sum).
  3. Decode: A linear layer scores each hyperedge embedding.

Examples:

Given 5 nodes with 3 features and 2 hyperedges:

>>> x.shape  # (5, 3) — all nodes in the hypergraph
>>> hyperedge_index = [[0, 1, 2, 3, 4],  # node IDs (global)
...                    [0, 0, 0, 1, 1]]  # hyperedge IDs

The forward pass:

1. HyperGCN encodes all 5 nodes using the full graph Laplacian: node_embeddings.shape = (5, out_channels)
2. Aggregate per hyperedge:
   - hyperedge 0: pool(emb[0], emb[1], emb[2])
   - hyperedge 1: pool(emb[3], emb[4])
3. Decode: one scalar score per hyperedge → scores.shape = (2,)

Parameters:

Name Type Description Default
x Tensor

Node feature matrix of shape (num_nodes, in_channels). Must contain all nodes in the hypergraph.

required
hyperedge_index Tensor

Hyperedge connectivity of shape (2, num_incidences) with global node IDs.

required

Returns:

Type Description
Tensor

Scores of shape (num_hyperedges,).

Source code in hyperbench/hlp/hypergcn_hlp.py
def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
    """
    Score every hyperedge: encode nodes with HyperGCN, pool per hyperedge, decode.

    Pipeline:
        1. Encode: the encoder receives ``x`` together with ``hyperedge_index``
           and returns one embedding per node.
        2. Aggregate: the embeddings of each hyperedge's member nodes are pooled
           with the method stored in ``self.aggregation``.
        3. Decode: the decoder turns each pooled embedding into a score, and the
           trailing dimension is squeezed away.

    Examples:
        Given 5 nodes with 3 features and 2 hyperedges::

            >>> x.shape  # (5, 3) — all nodes in the hypergraph
            >>> hyperedge_index = [[0, 1, 2, 3, 4],  # node IDs (global)
            ...                    [0, 0, 0, 1, 1]]  # hyperedge IDs

        The forward pass:
            1. All 5 nodes are encoded: ``node_embeddings.shape = (5, out_channels)``
            2. Pooling per hyperedge:
               - hyperedge 0: pool(emb[0], emb[1], emb[2])
               - hyperedge 1: pool(emb[3], emb[4])
            3. One scalar score per hyperedge → ``scores.shape = (2,)``

    Args:
        x: Node feature matrix of shape ``(num_nodes, in_channels)``.
            Must contain **all** nodes in the hypergraph.
        hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``
            with **global** node IDs.

    Returns:
        Scores of shape ``(num_hyperedges,)``.

    Raises:
        ValueError: If the module has no encoder.
    """
    encoder = self.encoder
    if encoder is None:
        raise ValueError("Encoder is not defined for this HLP module.")

    # (num_nodes, in_channels) -> (num_nodes, out_channels)
    embeddings: Tensor = encoder(x, hyperedge_index)

    # (num_nodes, out_channels) -> (num_hyperedges, out_channels)
    aggregator = HyperedgeAggregator(hyperedge_index, embeddings)
    pooled = aggregator.pool(self.aggregation)

    # (num_hyperedges, 1) -> (num_hyperedges,)
    raw_scores: Tensor = self.decoder(pooled)
    return raw_scores.squeeze(-1)

HyperGCNEncoderConfig

Bases: TypedDict

Configuration for the HyperGCN encoder in HyperGCNHlpModule.

Parameters:

Name Type Description Default
in_channels

Number of input features per node.

required
hidden_channels

Number of hidden units in the intermediate HyperGCN layer.

required
out_channels

Number of output features (embedding size) per node.

required
bias

Whether to include bias terms. Defaults to True.

required
use_batch_normalization

Whether to use batch normalization. Defaults to False.

required
drop_rate

Dropout rate. Defaults to 0.5.

required
use_mediator

Whether to use mediator nodes for hyperedge-to-edge conversion. Defaults to False.

required
fast

Whether to cache the graph structure after first computation. Defaults to True.

required
Source code in hyperbench/hlp/hypergcn_hlp.py
class HyperGCNEncoderConfig(TypedDict):
    """
    Configuration for the HyperGCN encoder in HyperGCNHlpModule.

    ``in_channels``, ``hidden_channels`` and ``out_channels`` are required keys;
    every other key is declared ``NotRequired`` and may be omitted.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Number of hidden units in the intermediate HyperGCN layer.
        out_channels: Number of output features (embedding size) per node.
        bias: Whether to include bias terms. Defaults to ``True``.
        use_batch_normalization: Whether to use batch normalization. Defaults to ``False``.
        drop_rate: Dropout rate. Defaults to ``0.5``.
        use_mediator: Whether to use mediator nodes for hyperedge-to-edge conversion. Defaults to ``False``.
        fast: Whether to cache the graph structure after first computation. Defaults to ``True``.
    """

    # Required keys.
    in_channels: int
    hidden_channels: int
    out_channels: int
    # Optional keys.
    bias: NotRequired[bool]
    use_batch_normalization: NotRequired[bool]
    drop_rate: NotRequired[float]
    use_mediator: NotRequired[bool]
    fast: NotRequired[bool]

MLPHlpModule

Bases: HlpModule

A LightningModule for MLP-based Hyperedge Link Prediction.

Uses an MLP encoder to produce node embeddings, aggregates them per hyperedge using the configured pooling method (mean pooling by default), and scores each hyperedge with a linear decoder.

Parameters:

Name Type Description Default
encoder_config MlpEncoderConfig

Configuration for the MLP encoder.

required
aggregation Literal['mean', 'max', 'min', 'sum']

Method to aggregate node embeddings per hyperedge.

'mean'
loss_fn Module | None

Loss function. Defaults to BCEWithLogitsLoss.

None
lr float

Learning rate for the optimizer. Defaults to 0.001.

0.001
metrics MetricCollection | None

Optional dictionary of metric functions.

None
Source code in hyperbench/hlp/mlp_hlp.py
class MLPHlpModule(HlpModule):
    """
    A LightningModule for MLP-based Hyperedge Link Prediction.

    Uses an MLP encoder to produce node embeddings, aggregates them per hyperedge
    with the configured pooling method (mean pooling by default), and scores each
    hyperedge with a linear decoder.

    Args:
        encoder_config: Configuration for the MLP encoder.
        aggregation: Method to aggregate node embeddings per hyperedge
            (``"mean"``, ``"max"``, ``"min"`` or ``"sum"``). Defaults to ``"mean"``.
        loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``.
        lr: Learning rate for the optimizer. Defaults to ``0.001``.
        metrics: Optional dictionary of metric functions.
    """

    def __init__(
        self,
        encoder_config: MlpEncoderConfig,
        aggregation: Literal["mean", "max", "min", "sum"] = "mean",
        loss_fn: nn.Module | None = None,
        lr: float = 0.001,
        metrics: MetricCollection | None = None,
    ):
        # The embedding size is shared by the encoder output and the decoder input,
        # so it is resolved once to keep both in sync.
        out_channels = encoder_config.get("out_channels", 1)

        # The encoder outputs node embeddings of shape (num_nodes, out_channels).
        encoder = MLP(
            in_channels=encoder_config["in_channels"],
            hidden_channels=encoder_config.get("hidden_channels"),
            out_channels=out_channels,
            num_layers=encoder_config.get("num_layers", 1),
            activation_fn=encoder_config.get("activation_fn"),
            activation_fn_kwargs=encoder_config.get("activation_fn_kwargs"),
            normalization_fn=encoder_config.get("normalization_fn"),
            normalization_fn_kwargs=encoder_config.get("normalization_fn_kwargs"),
            bias=encoder_config.get("bias", True),
            drop_rate=encoder_config.get("drop_rate", 0.0),
        )

        # The decoder takes in the aggregated hyperedge embeddings of shape (num_hyperedges, out_channels)
        # and produces a score for each hyperedge of shape (num_hyperedges, 1).
        decoder = SLP(in_channels=out_channels, out_channels=1)

        super().__init__(
            encoder=encoder,
            decoder=decoder,
            loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(),
            metrics=metrics,
        )

        self.aggregation = aggregation
        self.lr = lr

    def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
        """
        Encode node features, aggregate per hyperedge with ``self.aggregation``, and score.

        Examples:
            Given 4 nodes with 3 features each and 2 hyperedges:
                >>> x = [[0.1, 0.2, 0.3],   # node 0
                ...      [0.4, 0.5, 0.6],   # node 1
                ...      [0.7, 0.8, 0.9],   # node 2
                ...      [1.0, 1.1, 1.2]]   # node 3

                >>> # hyperedge 0 = {node 0, node 1, node 2}
                >>> # hyperedge 1 = {node 2, node 3}
                >>> hyperedge_index = [[0, 1, 2, 2, 3],   # node ids
                ...                    [0, 0, 0, 1, 1]]   # hyperedge ids

            The forward pass, with the default ``aggregation="mean"``:
                1. Encoder maps each node to an embedding vector.
                2. Aggregate embeddings by summing them per hyperedge:
                    - hyperedge 0: emb[0] + emb[1] + emb[2]
                    - hyperedge 1: emb[2] + emb[3]
                3. Sums are divided by the number of nodes per hyperedge (mean pooling):
                    - hyperedge 0: (emb[0] + emb[1] + emb[2]) / 3
                    - hyperedge 1: (emb[2] + emb[3]) / 2
                4. Decoder scores each hyperedge embedding, producing one scalar per hyperedge.

            Other aggregation methods replace steps 2-3 with the corresponding pooling.

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
            hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``.

        Returns:
            Scores of shape ``(num_hyperedges,)``.

        Raises:
            ValueError: If the module has no encoder.
        """
        if self.encoder is None:
            raise ValueError("Encoder is not defined for this HLP module.")

        # Encode: map each node raw features to an embedding vector.
        # x: (num_nodes, in_channels) -> node_embeddings: (num_nodes, out_channels)
        # Example: in_channels=3, out_channels=2
        #          -> node 0: [0.1, 0.2, 0.3] -> [e00, e01]
        #          -> node 1: [0.4, 0.5, 0.6] -> [e10, e11]
        #          -> node 2: [0.7, 0.8, 0.9] -> [e20, e21]
        #          -> node 3: [1.0, 1.1, 1.2] -> [e30, e31]
        node_embeddings: Tensor = self.encoder(x)

        # Aggregate: for each hyperedge, aggregate the embeddings of its member nodes.
        # Example:
        # - hyperedge 0 contains node 0, 1, 2 -> aggregate([e00, e01], [e10, e11], [e20, e21]) -> [pooled_0, pooled_1]
        # - hyperedge 1 contains node 2, 3 -> aggregate([e20, e21], [e30, e31]) -> [pooled_0, pooled_1]
        # shape: (num_hyperedges, out_channels)
        hyperedge_embeddings = HyperedgeAggregator(hyperedge_index, node_embeddings).pool(
            self.aggregation,
        )

        # Decode: score each hyperedge embedding, producing one scalar per hyperedge.
        # Example:
        # - hyperedge 0: [pooled_0, pooled_1] -> score_0
        # - hyperedge 1: [pooled_0, pooled_1] -> score_1
        # shape: (num_hyperedges, 1) -> squeeze -> (num_hyperedges,), i.e. (2,) in the example
        scores: Tensor = self.decoder(hyperedge_embeddings).squeeze(-1)
        return scores

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute and return the training loss for ``batch`` (metrics are computed as well)."""
        # Training runs exactly the same score/loss/metrics sequence as evaluation,
        # only tagged with the TRAIN stage.
        return self.__eval_step(batch, Stage.TRAIN)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute and return the validation loss for ``batch``."""
        return self.__eval_step(batch, Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute and return the test loss for ``batch``."""
        return self.__eval_step(batch, Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the hyperedge scores for ``batch`` without computing loss or metrics."""
        return self.forward(batch.x, batch.hyperedge_index)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """
        Score ``batch`` and compute the loss and metrics for ``stage``.

        Args:
            batch: Batch providing ``x``, ``hyperedge_index``, ``y`` and ``num_hyperedges``.
            stage: Stage forwarded to ``_compute_loss`` and ``_compute_metrics``.

        Returns:
            The value returned by ``_compute_loss``.
        """
        scores = self.forward(batch.x, batch.hyperedge_index)
        labels = batch.y
        batch_size = batch.num_hyperedges

        loss = self._compute_loss(scores, labels, batch_size, stage)
        self._compute_metrics(scores, labels, batch_size, stage)
        return loss

forward(x, hyperedge_index)

Encode node features, aggregate per hyperedge with the configured pooling method (mean pooling by default), and score.

Examples:

Given 4 nodes with 3 features each and 2 hyperedges:

>>> x = [[0.1, 0.2, 0.3],   # node 0
...      [0.4, 0.5, 0.6],   # node 1
...      [0.7, 0.8, 0.9],   # node 2
...      [1.0, 1.1, 1.2]]   # node 3

>>> # hyperedge 0 = {node 0, node 1, node 2}
>>> # hyperedge 1 = {node 2, node 3}
>>> hyperedge_index = [[0, 1, 2, 2, 3],   # node ids
...                    [0, 0, 0, 1, 1]]   # hyperedge ids

The forward pass (with the default mean aggregation):

1. Encoder maps each node to an embedding vector.
2. Aggregate embeddings by summing them per hyperedge:
   - hyperedge 0: emb[0] + emb[1] + emb[2]
   - hyperedge 1: emb[2] + emb[3]
3. Sums are divided by the number of nodes per hyperedge (mean pooling):
   - hyperedge 0: (emb[0] + emb[1] + emb[2]) / 3
   - hyperedge 1: (emb[2] + emb[3]) / 2
4. Decoder scores each hyperedge embedding, producing one scalar per hyperedge.

Parameters:

Name Type Description Default
x Tensor

Node feature matrix of shape (num_nodes, in_channels).

required
hyperedge_index Tensor

Hyperedge connectivity of shape (2, num_incidences).

required

Returns:

Type Description
Tensor

Scores of shape (num_hyperedges,).

Source code in hyperbench/hlp/mlp_hlp.py
def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
    """
    Score every hyperedge: encode nodes, pool per hyperedge, decode.

    The pooling method is the one stored in ``self.aggregation``.

    Examples:
        Given 4 nodes with 3 features each and 2 hyperedges:
            >>> x = [[0.1, 0.2, 0.3],   # node 0
            ...      [0.4, 0.5, 0.6],   # node 1
            ...      [0.7, 0.8, 0.9],   # node 2
            ...      [1.0, 1.1, 1.2]]   # node 3

            >>> # hyperedge 0 = {node 0, node 1, node 2}
            >>> # hyperedge 1 = {node 2, node 3}
            >>> hyperedge_index = [[0, 1, 2, 2, 3],   # node ids
            ...                    [0, 0, 0, 1, 1]]   # hyperedge ids

        The forward pass, when mean pooling is configured:
            1. The encoder maps each node to an embedding vector.
            2. Embeddings are summed per hyperedge:
                - hyperedge 0: emb[0] + emb[1] + emb[2]
                - hyperedge 1: emb[2] + emb[3]
            3. Each sum is divided by the hyperedge size:
                - hyperedge 0: (emb[0] + emb[1] + emb[2]) / 3
                - hyperedge 1: (emb[2] + emb[3]) / 2
            4. The decoder produces one scalar per hyperedge embedding.

    Args:
        x: Node feature matrix of shape ``(num_nodes, in_channels)``.
        hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``.

    Returns:
        Scores of shape ``(num_hyperedges,)``.

    Raises:
        ValueError: If the module has no encoder.
    """
    encoder = self.encoder
    if encoder is None:
        raise ValueError("Encoder is not defined for this HLP module.")

    # Step 1 - encode. Only the node features are passed to the encoder here.
    # (num_nodes, in_channels) -> (num_nodes, out_channels)
    # e.g. in_channels=3, out_channels=2: node 0 [0.1, 0.2, 0.3] -> [e00, e01]
    embeddings: Tensor = encoder(x)

    # Step 2 - pool the member-node embeddings of each hyperedge.
    # e.g. hyperedge 0 = {0, 1, 2} -> aggregate([e00, e01], [e10, e11], [e20, e21])
    #      hyperedge 1 = {2, 3}    -> aggregate([e20, e21], [e30, e31])
    # (num_nodes, out_channels) -> (num_hyperedges, out_channels)
    aggregator = HyperedgeAggregator(hyperedge_index, embeddings)
    pooled = aggregator.pool(self.aggregation)

    # Step 3 - decode each pooled embedding into a scalar and drop the last axis.
    # (num_hyperedges, 1) -> (num_hyperedges,), i.e. (2,) in the example above
    raw_scores: Tensor = self.decoder(pooled)
    return raw_scores.squeeze(-1)

MlpEncoderConfig

Bases: TypedDict

Configuration for the MLP encoder in MLPHlpModule.

Parameters:

Name Type Description Default
in_channels

Number of input features per node.

required
out_channels

Number of output features (embedding size) per node.

required
num_layers

Number of layers in the MLP encoder.

required
hidden_channels

Optional number of hidden units per layer. If None, no hidden layers are used and the encoder is a simple linear layer.

required
activation_fn

Optional activation function class to use in the MLP encoder. If None, no activation function is applied.

required
activation_fn_kwargs

Optional dictionary of keyword arguments to pass to the activation function constructor.

required
normalization_fn

Optional normalization function class to use in the MLP encoder. If None, no normalization is applied.

required
normalization_fn_kwargs

Optional dictionary of keyword arguments to pass to the normalization function constructor.

required
bias

Whether to include bias terms in the MLP layers. Defaults to True.

required
drop_rate

Dropout rate to apply after each MLP layer (except the last one). Defaults to 0.0 (no dropout).

required
Source code in hyperbench/hlp/mlp_hlp.py
class MlpEncoderConfig(TypedDict):
    """
    Configuration for the MLP encoder in MLPHlpModule.

    Only ``in_channels`` is a required key; every other key is declared
    ``NotRequired`` and MLPHlpModule substitutes the default listed below
    when the key is omitted.

    Args:
        in_channels: Number of input features per node.
        out_channels: Number of output features (embedding size) per node. Defaults to ``1``.
        num_layers: Number of layers in the MLP encoder. Defaults to ``1``.
        hidden_channels: Optional number of hidden units per layer. If ``None``, no hidden layers are used and the encoder is a simple linear layer.
        activation_fn: Optional activation function class to use in the MLP encoder. If ``None``, no activation function is applied.
        activation_fn_kwargs: Optional dictionary of keyword arguments to pass to the activation function constructor.
        normalization_fn: Optional normalization function class to use in the MLP encoder. If ``None``, no normalization is applied.
        normalization_fn_kwargs: Optional dictionary of keyword arguments to pass to the normalization function constructor.
        bias: Whether to include bias terms in the MLP layers. Defaults to ``True``.
        drop_rate: Dropout rate to apply after each MLP layer (except the last one). Defaults to ``0.0`` (no dropout).
    """

    # Required key.
    in_channels: int
    # Optional keys.
    out_channels: NotRequired[int]
    num_layers: NotRequired[int]
    hidden_channels: NotRequired[int | None]
    activation_fn: NotRequired[ActivationFn | None]
    activation_fn_kwargs: NotRequired[dict | None]
    normalization_fn: NotRequired[NormalizationFn | None]
    normalization_fn_kwargs: NotRequired[dict | None]
    bias: NotRequired[bool]
    drop_rate: NotRequired[float]

NHPEncoderConfig

Bases: TypedDict

Configuration for the NHP encoder/scorer to be used for hyperedge link prediction.

Parameters:

Name Type Description Default
in_channels

Number of input features per node.

required
hidden_channels

Number of hidden channels for incidence embeddings. Defaults to 512.

required
aggregation

Hyperedge scoring aggregation. "maxmin" uses the paper's element-wise range representation; "mean" uses mean pooling.

required
bias

Whether to include bias terms. Defaults to True.

required
Source code in hyperbench/hlp/nhp_hlp.py
class NHPEncoderConfig(TypedDict):
    """
    Configuration for the NHP encoder/scorer to be used for hyperedge link prediction.

    Only ``in_channels`` is a required key; every other key is declared
    ``NotRequired`` and NHPHlpModule substitutes the default listed below
    when the key is omitted.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Number of hidden channels for incidence embeddings. Defaults to ``512``.
        activation_fn: Optional activation function forwarded to the NHP encoder.
            ``None`` is forwarded when the key is omitted.
        activation_fn_kwargs: Optional keyword arguments forwarded alongside ``activation_fn``.
            ``None`` is forwarded when the key is omitted.
        aggregation: Hyperedge scoring aggregation. ``"maxmin"`` uses the paper's
            element-wise range representation; ``"mean"`` uses mean pooling.
            Defaults to ``"maxmin"``.
        bias: Whether to include bias terms. Defaults to ``True``.
    """

    # Required key.
    in_channels: int
    # Optional keys.
    hidden_channels: NotRequired[int]
    activation_fn: NotRequired[ActivationFn | None]
    activation_fn_kwargs: NotRequired[dict | None]
    aggregation: NotRequired[Literal["mean", "maxmin"]]
    bias: NotRequired[bool]

NHPHlpModule

Bases: HlpModule

A LightningModule for undirected NHP hyperedge link prediction.

NHP encodes and scores candidate hyperedges in a single pass. Unlike encoder wrappers that produce reusable global node embeddings, NHP builds candidate-specific incidence embeddings before pooling and scoring each hyperedge.

Parameters:

Name Type Description Default
encoder_config NHPEncoderConfig

Configuration for the NHP encoder/scorer.

required
loss_fn Module | None

Loss function. Defaults to :class:NHPRankingLoss.

None
lr float

Learning rate for the optimizer. Defaults to 0.001.

0.001
weight_decay float

L2 regularization. Defaults to 5e-4.

0.0005
metrics MetricCollection | None

Optional metric collection for evaluation.

None
Source code in hyperbench/hlp/nhp_hlp.py
class NHPHlpModule(HlpModule):
    """
    LightningModule wrapping NHP for undirected hyperedge link prediction.

    The NHP encoder both encodes and scores candidate hyperedges in one pass,
    so the decoder of this module is an identity. Unlike encoder wrappers that
    produce reusable global node embeddings, NHP builds candidate-specific
    incidence embeddings before pooling and scoring each hyperedge.

    Args:
        encoder_config: Configuration for the NHP encoder/scorer.
        loss_fn: Loss function. Defaults to :class:`NHPRankingLoss`.
        lr: Learning rate for the optimizer. Defaults to ``0.001``.
        weight_decay: L2 regularization. Defaults to ``5e-4``.
        metrics: Optional metric collection for evaluation.
    """

    def __init__(
        self,
        encoder_config: NHPEncoderConfig,
        loss_fn: nn.Module | None = None,
        lr: float = 0.001,
        weight_decay: float = 5e-4,
        metrics: MetricCollection | None = None,
    ):
        # Optional config keys fall back to the defaults given here.
        scorer = NHP(
            in_channels=encoder_config["in_channels"],
            hidden_channels=encoder_config.get("hidden_channels", 512),
            activation_fn=encoder_config.get("activation_fn"),
            activation_fn_kwargs=encoder_config.get("activation_fn_kwargs"),
            aggregation=encoder_config.get("aggregation", "maxmin"),
            bias=encoder_config.get("bias", True),
        )

        # The encoder already emits final scores, hence the pass-through decoder.
        passthrough = nn.Identity()
        criterion = NHPRankingLoss() if loss_fn is None else loss_fn

        super().__init__(
            encoder=scorer,
            decoder=passthrough,
            loss_fn=criterion,
            metrics=metrics,
        )

        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
        """
        Encode and score each candidate hyperedge.

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
            hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``.

        Returns:
            Scores of shape ``(num_hyperedges,)``.

        Raises:
            ValueError: If the module has no encoder.
        """
        scorer = self.encoder
        if scorer is None:
            raise ValueError("Encoder is not defined for this HLP module.")
        return scorer(x, hyperedge_index)

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute and return the training loss for ``batch``."""
        return self.__eval_step(batch, Stage.TRAIN)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute and return the validation loss for ``batch``."""
        return self.__eval_step(batch, Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute and return the test loss for ``batch``."""
        return self.__eval_step(batch, Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the hyperedge scores for ``batch`` without computing loss or metrics."""
        predictions = self.forward(batch.x, batch.hyperedge_index)
        return predictions

    def configure_optimizers(self):
        """Return an Adam optimizer configured with ``self.lr`` and ``self.weight_decay``."""
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        return optimizer

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """
        Score ``batch`` and compute the loss and metrics for ``stage``.

        Shared by the training, validation and test steps.

        Args:
            batch: Batch providing ``x``, ``hyperedge_index``, ``y`` and ``num_hyperedges``.
            stage: Stage forwarded to ``_compute_loss`` and ``_compute_metrics``.

        Returns:
            The value returned by ``_compute_loss``.
        """
        predictions = self.forward(batch.x, batch.hyperedge_index)
        targets = batch.y
        num_hyperedges = batch.num_hyperedges

        stage_loss = self._compute_loss(predictions, targets, num_hyperedges, stage)
        self._compute_metrics(predictions, targets, num_hyperedges, stage)
        return stage_loss

forward(x, hyperedge_index)

Encode and score each candidate hyperedge.

Parameters:

Name Type Description Default
x Tensor

Node feature matrix of shape (num_nodes, in_channels).

required
hyperedge_index Tensor

Hyperedge connectivity of shape (2, num_incidences).

required

Returns:

Type Description
Tensor

Scores of shape (num_hyperedges,).

Source code in hyperbench/hlp/nhp_hlp.py
def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
    """
    Score each candidate hyperedge by delegating to the encoder.

    The encoder's output is returned unchanged.

    Args:
        x: Node feature matrix of shape ``(num_nodes, in_channels)``.
        hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``.

    Returns:
        Scores of shape ``(num_hyperedges,)``.

    Raises:
        ValueError: If the module has no encoder.
    """
    scorer = self.encoder
    if scorer is None:
        raise ValueError("Encoder is not defined for this HLP module.")
    return scorer(x, hyperedge_index)

NHPRankingLoss

Bases: Module

Ranking loss that pushes positive hyperedges above sampled negatives.

Examples:

>>> logits = torch.tensor([2.0, 1.0, -1.0])
>>> labels = torch.tensor([1.0, 1.0, 0.0])
>>> loss_fn = NHPRankingLoss()
>>> loss = loss_fn(logits, labels)
>>> loss.ndim
0
Source code in hyperbench/nn/loss.py
class NHPRankingLoss(nn.Module):
    """
    Ranking loss that pushes positive hyperedges above sampled negatives.

    Every positive hyperedge is compared against the mean sigmoid score of the
    negative hyperedges, and the softplus of that gap is averaged over the positives.

    Examples:
        >>> logits = torch.tensor([2.0, 1.0, -1.0])
        >>> labels = torch.tensor([1.0, 1.0, 0.0])
        >>> loss_fn = NHPRankingLoss()
        >>> loss = loss_fn(logits, labels)
        >>> loss.ndim
        0
    """

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        """
        Compute the ranking loss.

        Args:
            logits: Logit scores for each candidate hyperedge, of shape ``(num_hyperedges,)``.
            labels: Binary labels indicating positive (1) and negative (0) hyperedges, of shape ``(num_hyperedges,)``.

        Returns:
            Scalar loss value.

        Raises:
            ValueError: If ``labels`` contains no positive or no negative hyperedge.
        """
        # Split logits by label as we need to compare positive scores against negative scores.
        # Example: logits = [2.0, 1.0, -1.0]
        #          labels = [1.0, 1.0, 0.0]
        #          -> positive_logits = [2.0, 1.0]
        #          -> negative_logits = [-1.0]
        positive_logits = logits[labels == 1]
        negative_logits = logits[labels == 0]

        # Fail fast, before any further computation: the mean over an empty
        # group would otherwise turn the loss into NaN.
        if positive_logits.numel() == 0 or negative_logits.numel() == 0:
            raise ValueError("NHPRankingLoss requires both positive and negative hyperedges.")

        positive_scores = torch.sigmoid(positive_logits)
        negative_scores = torch.sigmoid(negative_logits)

        # Objective: enforce that each positive score is higher than the average negative score.
        # For each positive score pos_i:
        #   margin_i = mean(negative_scores) - pos_i
        # We interpret margin_i as follows:
        # - If pos_i > mean(negatives), then margin_i < 0    -> desirable
        # - If pos_i <= mean(negatives), then margin_i >= 0  -> violation
        margins = negative_scores.mean() - positive_scores

        # softplus(margin_i) is smaller the more negative margin_i is (better ranking)
        # and grows smoothly as margin_i increases (penalizing violations).
        # Final loss is the average over all positive samples.
        return F.softplus(margins).mean()

forward(logits, labels)

Compute the ranking loss.

Parameters:

Name Type Description Default
logits Tensor

Logit scores for each candidate hyperedge, of shape (num_hyperedges,).

required
labels Tensor

Binary labels indicating positive (1) and negative (0) hyperedges, of shape (num_hyperedges,).

required

Returns:

Type Description
Tensor

Scalar loss value.

Source code in hyperbench/nn/loss.py
def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
    """
    Compute the ranking loss between positive and negative hyperedges.

    Args:
        logits: Logit scores for each candidate hyperedge, of shape ``(num_hyperedges,)``.
        labels: Binary labels indicating positive (1) and negative (0) hyperedges, of shape ``(num_hyperedges,)``.

    Returns:
        Scalar loss value.

    Raises:
        ValueError: If ``labels`` contains no positive or no negative hyperedge.
    """
    # Partition the candidates by label and squash each group into (0, 1).
    # e.g. logits = [2.0, 1.0, -1.0], labels = [1.0, 1.0, 0.0]
    #      -> positives come from [2.0, 1.0], negatives from [-1.0]
    is_positive = labels == 1
    is_negative = labels == 0
    pos_probs = torch.sigmoid(logits[is_positive])
    neg_probs = torch.sigmoid(logits[is_negative])

    # Both groups are needed for a ranking comparison.
    if pos_probs.numel() == 0 or neg_probs.numel() == 0:
        raise ValueError("NHPRankingLoss requires both positive and negative hyperedges.")

    # gap_i = mean(negatives) - positive_i
    #   gap_i < 0  -> positive i already ranks above the average negative
    #   gap_i >= 0 -> ranking violation
    negative_baseline = neg_probs.mean()
    gaps = negative_baseline - pos_probs

    # softplus shrinks as the gap gets more negative and grows smoothly as it
    # increases; the penalties are averaged over the positives.
    penalties = F.softplus(gaps)
    return penalties.mean()

Node2VecGCNHlpConfig

Bases: TypedDict

Configuration for the GCN model.

Parameters:

Name Type Description Default
out_channels

Dimension of the output node embeddings from the GCN layers.

required
hidden_channels

Dimension of the hidden node embeddings in the GCN layers.

required
num_layers

Number of GCN layers. Must be at least 1. Defaults to 2.

required
drop_rate

Dropout rate applied after each GCN layer (except the last one). Defaults to 0.0 (no dropout).

required
bias

Whether to include a bias term in the GCN layers. Defaults to True.

required
improved

Whether to use the improved version of GCNConv. Defaults to False.

required
add_self_loops

Whether to add self-loops to the input graph. Defaults to True.

required
normalize

Whether to symmetrically normalize the adjacency matrix in GCNConv. Defaults to True.

required
cached

Whether to cache the normalized adjacency matrix in GCNConv. Only applicable if the graph structure does not change between epochs. Defaults to False.

required
graph_reduction_strategy

Strategy for reducing the hyperedge graph. Defaults to clique_expansion.

required
activation_fn

Activation function to use after each hidden layer. Defaults to nn.ReLU.

required
activation_fn_kwargs

Keyword arguments for the activation function. Defaults to empty dict.

required
Source code in hyperbench/hlp/node2vec_common.py
class Node2VecGCNHlpConfig(TypedDict):
    """
    Configuration for the GCN model.

    Only ``out_channels`` is a required key; every other key is declared
    ``NotRequired`` and may be omitted from the dict.

    NOTE(review): the defaults listed below are not applied by this class itself;
    presumably they are filled in where the config is consumed (``_to_gcn_config``) -- confirm.

    Args:
        out_channels: Dimension of the output node embeddings from the GCN layers.
        hidden_channels: Dimension of the hidden node embeddings in the GCN layers.
        num_layers: Number of GCN layers. Must be at least 1. Defaults to ``2``.
        drop_rate: Dropout rate applied after each GCN layer (except the last one). Defaults to ``0.0`` (no dropout).
        bias: Whether to include a bias term in the GCN layers. Defaults to ``True``.
        improved: Whether to use the improved version of GCNConv. Defaults to ``False``.
        add_self_loops: Whether to add self-loops to the input graph. Defaults to ``True``.
        normalize: Whether to symmetrically normalize the adjacency matrix in GCNConv. Defaults to ``True``.
        cached: Whether to cache the normalized adjacency matrix in GCNConv.
            Only applicable if the graph structure does not change between epochs. Defaults to ``False``.
        graph_reduction_strategy: Strategy for reducing the hyperedge graph. Defaults to ``clique_expansion``.
        activation_fn: Activation function to use after each hidden layer. Defaults to ``nn.ReLU``.
        activation_fn_kwargs: Keyword arguments for the activation function. Defaults to empty dict.
    """

    # Required key.
    out_channels: int
    # Optional keys.
    hidden_channels: NotRequired[int]
    num_layers: NotRequired[int]
    drop_rate: NotRequired[float]
    bias: NotRequired[bool]
    improved: NotRequired[bool]
    add_self_loops: NotRequired[bool]
    normalize: NotRequired[bool]
    cached: NotRequired[bool]
    # "clique_expansion" is the only strategy the type currently admits.
    graph_reduction_strategy: NotRequired[Literal["clique_expansion"]]
    activation_fn: NotRequired[ActivationFn]
    activation_fn_kwargs: NotRequired[dict]

Node2VecHlpConfig

Bases: TypedDict

Configuration for the Node2Vec encoder.

Parameters:

Name Type Description Default
context_size

Skip-gram context size for Node2Vec. For example, if context_size=2 and walk_length=5, then for a random walk [v0, v1, v2, v3, v4], the context for v2 would be [v0, v1, v3, v4] as we take neighbors within distance 2 in the walk. The pairs generated by skip-gram would be [(v2, v0), (v2, v1), (v2, v3), (v2, v4)]. Rule of thumb: Graphs with strong local structure (5-10), Graphs with communities/long-range patterns (10-20). Defaults to 10.

required
walk_length

Length of each random walk.

required
num_walks_per_node

Number of walks sampled per node.

required
p

Node2Vec return parameter. Controls the probability of stepping back to the node visited in the previous step. Lower values of p make immediate backtracking more likely, while higher values discourage returning to the previous node.

required
q

Node2Vec in-out parameter. Controls whether walks stay near the source node or explore further outward. Lower values of q bias the walk toward DFS-like exploration and structural similarity, while higher values bias it toward BFS-like exploration and local community structure and homophily.

required
num_negative_samples

Number of negative samples per positive walk context. If set to X, then for each positive pair (u, v) generated from the random walks, X negative pairs (u, v_neg) will be generated, where v_neg is a node sampled uniformly at random from all nodes in the graph. Defaults to 1, meaning one negative sample per positive pair.

required
num_nodes

Number of nodes in the stable node space. Defaults to the number of nodes in the hyperedge_index if not provided.

required
train_hyperedge_index

Training hyperedge index used to build the Node2Vec walk graph. Required in joint mode.

required
graph_reduction_strategy

Strategy for reducing the hyperedge graph. Defaults to clique_expansion.

required
random_walk_batch_size

Batch size used by the walk sampler in joint mode.

required
node2vec_loss_weight

Weight applied to the Node2Vec walk loss in joint mode. This is to decide how much the loss of Node2Vec contributes to the overall loss in joint training, relative to the HLP loss. Defaults to 1.0 (equal weighting). Set to a higher value to prioritize learning good node embeddings, or a lower value to prioritize the HLP loss. Ignored in precomputed mode.

required
sparse

Whether to use sparse gradients in the Node2Vec encoder. Defaults to False.

required
Source code in hyperbench/hlp/node2vec_common.py
class Node2VecHlpConfig(TypedDict):
    """
    Configuration for the Node2Vec encoder.

    Every key is declared ``NotRequired``, so an empty dict is a valid value for the type.

    Args:
        context_size: Skip-gram context size for Node2Vec.
            For example, if ``context_size=2`` and ``walk_length=5``, then for a random walk ``[v0, v1, v2, v3, v4]``,
            the context for ``v2`` would be ``[v0, v1, v3, v4]`` as we take neighbors within distance 2 in the walk.
            The pairs generated by skip-gram would be ``[(v2, v0), (v2, v1), (v2, v3), (v2, v4)]``.
            Rule of thumb: Graphs with strong local structure (5-10), Graphs with communities/long-range patterns (10-20).
            Defaults to ``10``.
        walk_length: Length of each random walk. Defaults to ``20``.
        num_walks_per_node: Number of walks sampled per node. Defaults to ``10``.
        p: Node2Vec return parameter. Controls the probability of stepping back to the node visited
            in the previous step. Lower values of ``p`` make immediate backtracking more likely,
            while higher values discourage returning to the previous node. Defaults to ``1.0``.
        q: Node2Vec in-out parameter. Controls whether walks stay near the source node or explore
            further outward. Lower values of ``q`` bias the walk toward DFS-like exploration and
            structural similarity, while higher values bias it toward BFS-like exploration and
            local community structure and homophily. Defaults to ``1.0``.
        num_negative_samples: Number of negative samples per positive walk context.
            If set to ``X``, then for each positive pair ``(u, v)`` generated from the random walks,
            ``X`` negative pairs ``(u, v_neg)`` will be generated,
            where ``v_neg`` is a node sampled uniformly at random from all nodes in the graph.
            Defaults to ``1``, meaning one negative sample per positive pair.
        num_nodes: Number of nodes in the stable node space. Defaults to the number of nodes in the ``hyperedge_index`` if not provided.
        train_hyperedge_index: Training hyperedge index used to build the Node2Vec walk graph. Required in ``joint`` mode.
        graph_reduction_strategy: Strategy for reducing the hyperedge graph. Defaults to ``clique_expansion``.
        random_walk_batch_size: Batch size used by the walk sampler in joint mode. Defaults to ``128``.
        node2vec_loss_weight: Weight applied to the Node2Vec walk loss in joint mode.
            This is to decide how much the loss of Node2Vec contributes to the overall loss in joint training, relative to the HLP loss.
            Defaults to ``1.0`` (equal weighting). Set to a higher value to prioritize learning good node embeddings,
            or a lower value to prioritize the HLP loss. Ignored in precomputed mode.
        sparse: Whether to use sparse gradients in the Node2Vec encoder. Defaults to ``False``.
    """

    # Random-walk / skip-gram hyperparameters.
    context_size: NotRequired[int]
    walk_length: NotRequired[int]
    num_walks_per_node: NotRequired[int]
    p: NotRequired[float]
    q: NotRequired[float]
    num_negative_samples: NotRequired[int]
    # Walk-graph definition.
    num_nodes: NotRequired[int]
    train_hyperedge_index: NotRequired[Tensor]
    graph_reduction_strategy: NotRequired[Literal["clique_expansion"]]
    # Joint-training controls.
    random_walk_batch_size: NotRequired[int]
    node2vec_loss_weight: NotRequired[float]
    sparse: NotRequired[bool]

Node2VecGCNEncoderConfig

Bases: TypedDict

Configuration for the Node2Vec encoder in Node2VecGCNHlpModule.

Parameters:

Name Type Description Default
mode

Whether to use precomputed node embeddings from x or train a Node2Vec encoder jointly inside the module.

required
num_features

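Dimension of the Node2Vec (or precomputed) node embeddings fed to the GCN layers. The decoder consumes the GCN output, whose dimension is gcn_config["out_channels"].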

required
node2vec_config

Shared Node2Vec configuration used in joint mode, or metadata for validating precomputed embeddings.

required
gcn_config

Configuration for the GCN layers.

required
Source code in hyperbench/hlp/node2vecgcn_hlp.py
class Node2VecGCNEncoderConfig(TypedDict):
    """
    Configuration for the Node2Vec encoder in ``Node2VecGCNHlpModule``.

    ``num_features``, ``node2vec_config`` and ``gcn_config`` are required keys; ``mode`` is optional.

    Args:
        mode: Whether to use precomputed node embeddings from ``x`` or train a Node2Vec encoder jointly inside the module.
            ``Node2VecGCNHlpModule`` falls back to joint mode when the key is omitted.
        num_features: Dimension of the Node2Vec (or precomputed) node embeddings that are fed to the GCN layers.
            The decoder itself consumes the GCN output, whose dimension is ``gcn_config["out_channels"]``.
        node2vec_config: Shared Node2Vec configuration used in joint mode, or metadata for validating precomputed embeddings.
        gcn_config: Configuration for the GCN layers.
    """

    # Optional key.
    mode: NotRequired[Node2VecMode]
    # Required keys.
    num_features: int
    node2vec_config: Node2VecHlpConfig
    gcn_config: Node2VecGCNHlpConfig

Node2VecGCNHlpModule

Bases: HlpModule

A LightningModule for Node2Vec-based Hyperedge Link Prediction with GCN encoder.

Supports two modes:

- precomputed: use node embeddings already stored in batch.x.
- joint: train a Node2Vec encoder jointly with the GCN layers and hyperedge decoder.

Parameters:

Name Type Description Default
encoder_config Node2VecGCNEncoderConfig

Configuration for the Node2Vec encoder and GCN layers.

required
aggregation Literal['mean', 'max', 'min', 'sum']

Method to aggregate node embeddings per hyperedge.

'mean'
loss_fn Module | None

Loss function. Defaults to BCEWithLogitsLoss.

None
lr float

Learning rate for the optimizer. Defaults to 0.001.

0.001
weight_decay float

Weight decay (L2 regularization) for the optimizer. Defaults to 0.0 (no weight decay).

0.0
metrics MetricCollection | None

Optional dictionary of metric functions.

None
Source code in hyperbench/hlp/node2vecgcn_hlp.py
class Node2VecGCNHlpModule(HlpModule):
    """
    Hyperedge Link Prediction module that feeds Node2Vec embeddings through GCN layers.

    The resulting node embeddings are pooled per hyperedge and scored by a
    single-output ``SLP`` decoder.

    Supported modes:

    - ``precomputed``: node embeddings are read from ``batch.x`` and passed to a GCN encoder.
    - ``joint``: a Node2Vec encoder is trained together with the GCN layers and the hyperedge decoder.

    Args:
        encoder_config: Configuration for the Node2Vec encoder and GCN layers.
        aggregation: Method used to pool node embeddings into one embedding per hyperedge.
        loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``.
        lr: Learning rate for the Adam optimizer. Defaults to ``0.001``.
        weight_decay: Weight decay (L2 regularization) for the optimizer. Defaults to ``0.0`` (no weight decay).
        metrics: Optional ``MetricCollection`` to compute at every step.
    """

    def __init__(
        self,
        encoder_config: Node2VecGCNEncoderConfig,
        aggregation: Literal["mean", "max", "min", "sum"] = "mean",
        loss_fn: nn.Module | None = None,
        lr: float = 0.001,
        weight_decay: float = 0.0,
        metrics: MetricCollection | None = None,
    ):
        # The configuration is unpacked first because the encoder and decoder
        # must exist before they can be handed to the base class.
        self.mode = encoder_config.get("mode", NODE2VEC_JOINT_MODE)
        self.embedding_dim = encoder_config["num_features"]

        self.node2vec_hlp_config = encoder_config["node2vec_config"]
        self.gcn_hlp_config = encoder_config["gcn_config"]

        # A combined Node2Vec + GCN encoder is only built in joint mode.
        joint_encoder = None
        if self.mode == NODE2VEC_JOINT_MODE:
            joint_encoder = self.__build_node2vecgcn_encoder(
                embedding_dim=self.embedding_dim,
                node2vec_config=self.node2vec_hlp_config,
                gcn_config=self.gcn_hlp_config,
                mode=self.mode,
            )

        # The decoder sits on top of the GCN output, hence ``out_channels``.
        scorer = SLP(in_channels=self.gcn_hlp_config["out_channels"], out_channels=1)

        if loss_fn is None:
            loss_fn = nn.BCEWithLogitsLoss()

        super().__init__(
            encoder=joint_encoder,
            decoder=scorer,
            loss_fn=loss_fn,
            metrics=metrics,
        )

        # In precomputed mode a stand-alone GCN consumes ``batch.x`` directly.
        standalone_gcn = None
        if self.mode == NODE2VEC_PRECOMPUTED_MODE:
            standalone_gcn = self.__build_gcn_encoder(self.embedding_dim, self.gcn_hlp_config)
        self.precomputed_gcn_encoder = standalone_gcn

        self.aggregation = aggregation
        self.lr = lr
        self.weight_decay = weight_decay
        self.random_walk_batch_size = self.node2vec_hlp_config.get("random_walk_batch_size", 128)
        self.node2vec_loss_weight = self.node2vec_hlp_config.get("node2vec_loss_weight", 1.0)

        self.__walk_loader_state = Node2VecWalkLoaderState()

    def forward(
        self,
        x: Tensor,
        hyperedge_index: Tensor,
        global_node_ids: Tensor | None = None,
    ) -> Tensor:
        """
        Score every hyperedge in ``hyperedge_index``.

        Args:
            x: Precomputed node embeddings; only read in precomputed mode.
            hyperedge_index: Hyperedge index of the batch.
            global_node_ids: Node ids handed to the joint encoder; only read in joint mode.

        Returns:
            The decoder output with its last dimension squeezed.

        Raises:
            ValueError: In precomputed mode, if ``x`` does not have ``num_features`` columns
                or the precomputed GCN encoder was never built.
        """
        gcn_edge_index = self.__to_gcn_edge_index(hyperedge_index)

        if self.mode != NODE2VEC_JOINT_MODE:
            if x.size(1) != self.embedding_dim:
                raise ValueError(
                    f"Expected precomputed node embeddings with dimension "
                    f"{self.embedding_dim}, got {x.size(1)}."
                )
            gcn = self.precomputed_gcn_encoder
            if gcn is None:
                raise ValueError("Precomputed GCN encoder is not initialized.")
            node_embeddings = gcn(x, gcn_edge_index)
        else:
            joint_encoder = _to_node2vec_encoder(self.encoder, self.mode)
            _validate_global_node_ids(joint_encoder.num_embeddings, global_node_ids, self.mode)
            node_embeddings = joint_encoder(batch=global_node_ids, edge_index=gcn_edge_index)

        # Pool node embeddings per hyperedge, then decode to one score each.
        aggregator = HyperedgeAggregator(hyperedge_index, node_embeddings)
        pooled = aggregator.pool(self.aggregation)
        return self.decoder(pooled).squeeze(-1)

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run one training step and return the loss to optimize."""
        scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids)
        labels = batch.y
        batch_size = batch.num_hyperedges

        if self.mode != NODE2VEC_JOINT_MODE:
            loss = self._compute_loss(scores, labels, batch_size, Stage.TRAIN)
        else:
            loss = self.__joint_training_loss(scores, labels, batch_size)

        self._compute_metrics(scores, labels, batch_size, Stage.TRAIN)
        return loss

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Evaluate ``batch`` under the validation stage."""
        return self.__eval_step(batch, Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Evaluate ``batch`` under the test stage."""
        return self.__eval_step(batch, Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the raw hyperedge scores for ``batch``."""
        return self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids)

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters of the module."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer

    def __joint_training_loss(self, scores: Tensor, labels: Tensor, batch_size: int) -> Tensor:
        """Combine the HLP loss with the weighted Node2Vec walk loss and log all three values."""
        positive_walks, negative_walks = _next_walk_batch(
            mode=self.mode,
            encoder=self.encoder,
            batch_size=self.random_walk_batch_size,
            state=self.__walk_loader_state,
        )
        positive_walks = positive_walks.to(self.device)
        negative_walks = negative_walks.to(self.device)

        hlp_loss = self.loss_fn(scores, labels)
        node2vec_encoder = _to_node2vec_encoder(self.encoder, self.mode)
        node2vec_loss = node2vec_encoder.loss(positive_walks, negative_walks)
        total_loss = hlp_loss + (self.node2vec_loss_weight * node2vec_loss)

        logged_values = (
            ("train_hlp_loss", hlp_loss),
            ("train_node2vec_loss", node2vec_loss),
            ("train_loss", total_loss),
        )
        for log_name, log_value in logged_values:
            self.log(log_name, log_value, prog_bar=True, batch_size=batch_size)

        return total_loss

    def __build_gcn_encoder(self, embedding_dim: int, gcn_config: Node2VecGCNHlpConfig) -> GCN:
        """Build the stand-alone GCN used in precomputed mode."""
        gcn_kwargs = _to_gcn_config(embedding_dim, gcn_config)
        return GCN(**gcn_kwargs)

    def __build_node2vecgcn_encoder(
        self,
        embedding_dim: int,
        node2vec_config: Node2VecHlpConfig,
        gcn_config: Node2VecGCNHlpConfig,
        mode: Node2VecMode,
    ) -> Node2VecGCN:
        """Build the combined Node2Vec + GCN encoder used in joint mode."""
        walk_length = node2vec_config.get("walk_length", 20)
        context_size = node2vec_config.get("context_size", 10)
        _validate_walk_length_and_context_size(
            walk_length=walk_length,
            context_size=context_size,
        )

        walk_edge_index, node_count = _to_node2vec_edge_index(node2vec_config, mode)

        encoder_settings: Node2VecConfig = {
            "edge_index": walk_edge_index,
            "embedding_dim": embedding_dim,
            "walk_length": walk_length,
            "context_size": context_size,
            "num_walks_per_node": node2vec_config.get("num_walks_per_node", 10),
            "p": node2vec_config.get("p", 1.0),
            "q": node2vec_config.get("q", 1.0),
            "num_negative_samples": node2vec_config.get("num_negative_samples", 1),
            "num_nodes": node_count,
            "sparse": node2vec_config.get("sparse", False),
        }
        gcn_settings = _to_gcn_config(embedding_dim, gcn_config)

        return Node2VecGCN(node2vec_config=encoder_settings, gcn_config=gcn_settings)

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """Score ``batch``, then compute loss and metrics for ``stage``."""
        scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids)
        targets = batch.y
        num_hyperedges = batch.num_hyperedges

        stage_loss = self._compute_loss(scores, targets, num_hyperedges, stage)
        self._compute_metrics(scores, targets, num_hyperedges, stage)
        return stage_loss

    def __to_gcn_edge_index(self, hyperedge_index: Tensor) -> Tensor:
        """Reduce the hyperedges to a graph edge index and strip self-loops from it."""
        strategy = self.gcn_hlp_config.get("graph_reduction_strategy", "clique_expansion")
        reduced = HyperedgeIndex(hyperedge_index).reduce(strategy)
        without_selfloops = EdgeIndex(reduced).remove_selfloops()
        return without_selfloops.item

Node2VecSLPEncoderConfig

Bases: TypedDict

Configuration for the Node2Vec encoder in Node2VecSLPHlpModule.

Parameters:

Name Type Description Default
mode

Whether to use precomputed node embeddings from x or train a Node2Vec encoder jointly inside the module.

required
num_features

Dimension of the node embeddings consumed by the decoder.

required
node2vec_config

Shared Node2Vec configuration used in joint mode, or metadata for validating precomputed embeddings.

required
Source code in hyperbench/hlp/node2vecslp_hlp.py
class Node2VecSLPEncoderConfig(TypedDict):
    """
    Configuration for the Node2Vec encoder in ``Node2VecSLPHlpModule``.

    ``num_features`` and ``node2vec_config`` are required keys; ``mode`` is optional.

    Args:
        mode: Whether to use precomputed node embeddings from ``x`` or train a Node2Vec encoder jointly inside the module.
            ``Node2VecSLPHlpModule`` falls back to joint mode when the key is omitted.
        num_features: Dimension of the node embeddings consumed by the decoder.
        node2vec_config: Shared Node2Vec configuration used in joint mode, or metadata for validating precomputed embeddings.
    """

    # Optional key.
    mode: NotRequired[Node2VecMode]
    # Required keys.
    num_features: int
    node2vec_config: Node2VecHlpConfig

Node2VecSLPHlpModule

Bases: HlpModule

A LightningModule for Node2Vec-based Hyperedge Link Prediction.

Supports two modes:

- precomputed: use node embeddings already stored in batch.x.
- joint: train a Node2Vec encoder jointly with the hyperedge decoder.

Parameters:

Name Type Description Default
encoder_config Node2VecSLPEncoderConfig

Configuration for the Node2Vec encoder.

required
aggregation Literal['mean', 'max', 'min', 'sum']

Method to aggregate node embeddings per hyperedge.

'mean'
loss_fn Module | None

Loss function. Defaults to BCEWithLogitsLoss.

None
lr float

Learning rate for the optimizer. Defaults to 0.001.

0.001
weight_decay float

Weight decay (L2 regularization) for the optimizer. Defaults to 0.0 (no weight decay).

0.0
metrics MetricCollection | None

Optional dictionary of metric functions.

None
Source code in hyperbench/hlp/node2vecslp_hlp.py
class Node2VecSLPHlpModule(HlpModule):
    """
    Hyperedge Link Prediction module that scores pooled Node2Vec embeddings with an ``SLP``.

    Supported modes:

    - ``precomputed``: node embeddings are read from ``batch.x``.
    - ``joint``: a Node2Vec encoder is trained together with the hyperedge decoder.

    Args:
        encoder_config: Configuration for the Node2Vec encoder.
        aggregation: Method used to pool node embeddings into one embedding per hyperedge.
        loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``.
        lr: Learning rate for the Adam optimizer. Defaults to ``0.001``.
        weight_decay: Weight decay (L2 regularization) for the optimizer. Defaults to ``0.0`` (no weight decay).
        metrics: Optional ``MetricCollection`` to compute at every step.
    """

    def __init__(
        self,
        encoder_config: Node2VecSLPEncoderConfig,
        aggregation: Literal["mean", "max", "min", "sum"] = "mean",
        loss_fn: nn.Module | None = None,
        lr: float = 0.001,
        weight_decay: float = 0.0,
        metrics: MetricCollection | None = None,
    ):
        # The configuration is unpacked first because the encoder and decoder
        # must exist before they can be handed to the base class.
        self.mode = encoder_config.get("mode", NODE2VEC_JOINT_MODE)
        self.embedding_dim = encoder_config["num_features"]
        node2vec_config = encoder_config["node2vec_config"]

        # A trainable Node2Vec encoder is only needed in joint mode.
        joint_encoder = None
        if self.mode == NODE2VEC_JOINT_MODE:
            joint_encoder = self.__build_node2vec_encoder(
                self.embedding_dim, node2vec_config, self.mode
            )

        scorer = SLP(in_channels=self.embedding_dim, out_channels=1)

        if loss_fn is None:
            loss_fn = nn.BCEWithLogitsLoss()

        super().__init__(
            encoder=joint_encoder,
            decoder=scorer,
            loss_fn=loss_fn,
            metrics=metrics,
        )

        self.aggregation = aggregation
        self.lr = lr
        self.weight_decay = weight_decay
        self.random_walk_batch_size = node2vec_config.get("random_walk_batch_size", 128)
        self.node2vec_loss_weight = node2vec_config.get("node2vec_loss_weight", 1.0)

        self.__walk_loader_state = Node2VecWalkLoaderState()

    def forward(
        self,
        x: Tensor,
        hyperedge_index: Tensor,
        global_node_ids: Tensor | None = None,
    ) -> Tensor:
        """
        Score every hyperedge in ``hyperedge_index``.

        Args:
            x: Precomputed node embeddings; only read in precomputed mode.
            hyperedge_index: Hyperedge index of the batch.
            global_node_ids: Node ids handed to the joint encoder; only read in joint mode.

        Returns:
            The decoder output with its last dimension squeezed.

        Raises:
            ValueError: In precomputed mode, if ``x`` does not have ``num_features`` columns.
        """
        # Step 1 - node embeddings: taken as-is from ``x`` or produced by the joint encoder.
        if self.mode != NODE2VEC_JOINT_MODE:
            if x.size(1) != self.embedding_dim:
                raise ValueError(
                    f"Expected precomputed node embeddings with dimension "
                    f"{self.embedding_dim}, got {x.size(1)}."
                )
            node_embeddings = x
        else:
            joint_encoder = _to_node2vec_encoder(self.encoder, self.mode)
            _validate_global_node_ids(joint_encoder.num_embeddings, global_node_ids, self.mode)
            node_embeddings = joint_encoder(batch=global_node_ids)

        # Step 2 - one pooled embedding per hyperedge.
        aggregator = HyperedgeAggregator(hyperedge_index, node_embeddings)
        pooled = aggregator.pool(self.aggregation)

        # Step 3 - project each pooled embedding to a score and drop the trailing axis.
        scores: Tensor = self.decoder(pooled).squeeze(-1)
        return scores

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Run one training step and return the loss to optimize."""
        scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids)
        labels = batch.y
        batch_size = batch.num_hyperedges

        if self.mode != NODE2VEC_JOINT_MODE:
            loss = self._compute_loss(scores, labels, batch_size, Stage.TRAIN)
        else:
            loss = self.__joint_training_loss(scores, labels, batch_size)

        self._compute_metrics(scores, labels, batch_size, Stage.TRAIN)
        return loss

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Evaluate ``batch`` under the validation stage."""
        return self.__eval_step(batch, Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Evaluate ``batch`` under the test stage."""
        return self.__eval_step(batch, Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the raw hyperedge scores for ``batch``."""
        return self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids)

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters of the module."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer

    def __joint_training_loss(self, scores: Tensor, labels: Tensor, batch_size: int) -> Tensor:
        """Combine the HLP loss with the weighted Node2Vec walk loss and log all three values."""
        # A single walk batch per step is used: the Node2Vec loss is already a
        # stochastic estimate over sampled walks, so averaging several walk
        # batches would only add cost.
        positive_walks, negative_walks = _next_walk_batch(
            mode=self.mode,
            encoder=self.encoder,
            batch_size=self.random_walk_batch_size,
            state=self.__walk_loader_state,
        )
        positive_walks = positive_walks.to(self.device)
        negative_walks = negative_walks.to(self.device)

        hlp_loss = self.loss_fn(scores, labels)
        node2vec_encoder = _to_node2vec_encoder(self.encoder, self.mode)
        node2vec_loss = node2vec_encoder.loss(positive_walks, negative_walks)
        total_loss = hlp_loss + (self.node2vec_loss_weight * node2vec_loss)

        loss_prefix = Stage.TRAIN.value
        logged_values = (
            (f"{loss_prefix}_hlp_loss", hlp_loss),
            (f"{loss_prefix}_node2vec_loss", node2vec_loss),
            (f"{loss_prefix}_loss", total_loss),
        )
        for log_name, log_value in logged_values:
            self.log(log_name, log_value, prog_bar=True, batch_size=batch_size)

        return total_loss

    def __build_node2vec_encoder(
        self,
        embedding_dim: int,
        node2vec_config: Node2VecHlpConfig,
        mode: Node2VecMode,
    ) -> Node2Vec:
        """Build the trainable Node2Vec encoder used in joint mode."""
        walk_length = node2vec_config.get("walk_length", 20)
        context_size = node2vec_config.get("context_size", 10)
        _validate_walk_length_and_context_size(
            walk_length=walk_length,
            context_size=context_size,
        )

        walk_edge_index, node_count = _to_node2vec_edge_index(node2vec_config, mode)

        return Node2Vec(
            edge_index=walk_edge_index,
            embedding_dim=embedding_dim,
            walk_length=walk_length,
            context_size=context_size,
            num_walks_per_node=node2vec_config.get("num_walks_per_node", 10),
            p=node2vec_config.get("p", 1.0),
            q=node2vec_config.get("q", 1.0),
            num_negative_samples=node2vec_config.get("num_negative_samples", 1),
            num_nodes=node_count,
            sparse=node2vec_config.get("sparse", False),
        )

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """Score ``batch``, then compute loss and metrics for ``stage``."""
        scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids)
        targets = batch.y
        num_hyperedges = batch.num_hyperedges

        stage_loss = self._compute_loss(scores, targets, num_hyperedges, stage)
        self._compute_metrics(scores, targets, num_hyperedges, stage)
        return stage_loss

VilLainEncoderConfig

Bases: TypedDict

Configuration for VilLainHlpModule.

Parameters:

Name Type Description Default
num_nodes

Total number of trainable nodes.

required
embedding_dim

Returned node and hyperedge embedding dimension. Defaults to 128.

required
labels_per_subspace

Number of virtual labels per subspace. Defaults to 2.

required
training_steps

Propagation steps used for VilLain loss. Defaults to 4.

required
generation_steps

Propagation steps averaged by forward. Defaults to 100.

required
tau

Gumbel-Softmax temperature. Defaults to 1.0.

required
eps

Numerical stability constant. Defaults to 1e-10.

required
villain_loss_weight

Weight applied to VilLain self-supervision. Defaults to 1.0.

required
Source code in hyperbench/hlp/villain_hlp.py
class VilLainEncoderConfig(TypedDict):
    """
    Configuration for ``VilLainHlpModule``.

    Only ``num_nodes`` is a required key; every other key is declared ``NotRequired``.
    The defaults listed below are the fallbacks ``VilLainHlpModule`` uses when a key is absent.

    Args:
        num_nodes: Total number of trainable nodes.
        embedding_dim: Returned node and hyperedge embedding dimension. Defaults to ``128``.
        labels_per_subspace: Number of virtual labels per subspace. Defaults to ``2``.
        training_steps: Propagation steps used for VilLain loss. Defaults to ``4``.
        generation_steps: Propagation steps averaged by ``forward``. Defaults to ``100``.
        tau: Gumbel-Softmax temperature. Defaults to ``1.0``.
        eps: Numerical stability constant. Defaults to ``1e-10``.
        villain_loss_weight: Weight applied to VilLain self-supervision. Defaults to ``1.0``.
    """

    # Required key.
    num_nodes: int
    # Optional keys.
    embedding_dim: NotRequired[int]
    labels_per_subspace: NotRequired[int]
    training_steps: NotRequired[int]
    generation_steps: NotRequired[int]
    tau: NotRequired[float]
    eps: NotRequired[float]
    villain_loss_weight: NotRequired[float]

VilLainHlpModule

Bases: HlpModule

Feature-free VilLain Hyperedge Link Prediction module.

Parameters:

Name Type Description Default
encoder_config VilLainEncoderConfig

Configuration for the VilLain encoder.

required
embedding_mode Literal['node', 'hyperedge']

Whether to return node or hyperedge embeddings from the VilLain encoder.

'node'
aggregation Literal['mean', 'max', 'min', 'maxmin', 'sum']

Aggregation method to pool node embeddings into hyperedge embeddings when embedding_mode="node". Ignored when embedding_mode="hyperedge". Defaults to maxmin.

'maxmin'
loss_fn Module | None

Loss function for the HLP task. Defaults to nn.BCEWithLogitsLoss().

None
lr float

Learning rate for the optimizer. Defaults to 0.01.

0.01
weight_decay float

Weight decay for the optimizer. Defaults to 0.0.

0.0
metrics MetricCollection | None

Metrics to compute during training and evaluation. Defaults to None.

None
Source code in hyperbench/hlp/villain_hlp.py
class VilLainHlpModule(HlpModule):
    """
    Feature-free VilLain Hyperedge Link Prediction module.

    A ``VilLain`` encoder produces embeddings and an ``SLP`` decoder with a single output
    channel turns each hyperedge embedding into a score. During training, the HLP loss is
    combined with the VilLain self-supervision loss, scaled by ``villain_loss_weight``.

    Args:
        encoder_config: Configuration for the VilLain encoder.
        embedding_mode: Whether to return node or hyperedge embeddings from the VilLain encoder.
        aggregation: Aggregation method to pool node embeddings into hyperedge embeddings when ``embedding_mode="node"``.
            Ignored when ``embedding_mode="hyperedge"``. Defaults to ``maxmin``.
        loss_fn: Loss function for the HLP task. Defaults to ``nn.BCEWithLogitsLoss()``.
        lr: Learning rate for the optimizer. Defaults to ``0.01``.
        weight_decay: Weight decay for the optimizer. Defaults to ``0.0``.
        metrics: Metrics to compute during training and evaluation. Defaults to ``None``.
    """

    def __init__(
        self,
        encoder_config: VilLainEncoderConfig,
        embedding_mode: Literal["node", "hyperedge"] = "node",
        aggregation: Literal["mean", "max", "min", "maxmin", "sum"] = "maxmin",
        loss_fn: nn.Module | None = None,
        lr: float = 0.01,
        weight_decay: float = 0.0,
        metrics: MetricCollection | None = None,
    ):
        # Build the sub-modules from local values so that nothing has to be stored
        # on ``self`` before the parent initializer runs.
        embedding_dim = encoder_config.get("embedding_dim", 128)

        encoder = VilLain(
            num_nodes=encoder_config["num_nodes"],
            embedding_dim=embedding_dim,
            labels_per_subspace=encoder_config.get("labels_per_subspace", 2),
            training_steps=encoder_config.get("training_steps", 4),
            generation_steps=encoder_config.get("generation_steps", 100),
            tau=encoder_config.get("tau", 1.0),
            eps=encoder_config.get("eps", 1e-10),
        )
        decoder = SLP(in_channels=embedding_dim, out_channels=1)

        # The parent ``__init__`` must run before any attribute is assigned on the
        # module (PyTorch ``nn.Module`` contract); this also matches the other HLP
        # modules in this package.
        super().__init__(
            encoder=encoder,
            decoder=decoder,
            loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(),
            metrics=metrics,
        )

        self.embedding_dim = embedding_dim
        self.aggregation = aggregation
        self.lr = lr
        self.weight_decay = weight_decay
        self.villain_loss_weight = encoder_config.get("villain_loss_weight", 1.0)
        self.embedding_mode = embedding_mode

    def forward(
        self,
        hyperedge_index: Tensor,
        global_node_ids: Tensor | None = None,
        num_hyperedges: int | None = None,
    ) -> Tensor:
        """
        Score the hyperedges described by ``hyperedge_index``.

        Args:
            hyperedge_index: Hyperedge index passed to the encoder (and to the aggregator
                in node mode).
            global_node_ids: Node ids forwarded to the encoder as ``node_ids``.
            num_hyperedges: Number of hyperedges, forwarded to the encoder and the aggregator.

        Returns:
            The decoder output with its last dimension squeezed.

        Raises:
            ValueError: If the module's encoder is not a ``VilLain`` instance.
        """
        encoder = self.__to_villain_encoder()

        if self.embedding_mode == "hyperedge":
            # The encoder directly provides one embedding per hyperedge.
            hyperedge_embeddings = encoder.hyperedge_embeddings(
                hyperedge_index=hyperedge_index,
                node_ids=global_node_ids,
                num_hyperedges=num_hyperedges,
            )
        else:
            # Any other mode is treated as "node": embed the nodes first, then pool
            # them into hyperedge embeddings with the configured aggregation.
            node_embeddings = encoder.node_embeddings(
                hyperedge_index=hyperedge_index,
                node_ids=global_node_ids,
                num_hyperedges=num_hyperedges,
            )
            hyperedge_embeddings = HyperedgeAggregator(
                hyperedge_index=hyperedge_index,
                node_embeddings=node_embeddings,
                num_hyperedges=num_hyperedges,
            ).pool(self.aggregation)

        scores: Tensor = self.decoder(hyperedge_embeddings).squeeze(-1)
        return scores

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """
        Run one training step and return the combined loss.

        The returned loss is ``hlp_loss + villain_loss_weight * villain_loss``. The HLP loss,
        the VilLain loss, its ``local_loss`` and ``global_loss`` parts and the combined loss
        are all logged under the training-stage prefix.

        Args:
            batch: Batch providing ``hyperedge_index``, ``global_node_ids``,
                ``num_hyperedges`` and the labels ``y``.
            batch_idx: Index of the batch (unused).

        Returns:
            The combined training loss.
        """
        scores = self.__score_batch(batch)

        labels = batch.y
        batch_size = batch.num_hyperedges

        hlp_loss = self.loss_fn(scores, labels)
        villain_loss, villain_loss_parts = self.__to_villain_encoder().loss(
            hyperedge_index=batch.hyperedge_index,
            node_ids=batch.global_node_ids,
            num_hyperedges=batch.num_hyperedges,
        )
        loss = hlp_loss + (self.villain_loss_weight * villain_loss)

        loss_prefix = Stage.TRAIN.value
        # (metric suffix, value, show on progress bar) -- logged in this order.
        logged_values = (
            ("hlp_loss", hlp_loss, True),
            ("villain_loss", villain_loss, True),
            ("local_loss", villain_loss_parts["local_loss"], False),
            ("global_loss", villain_loss_parts["global_loss"], False),
            ("loss", loss, True),
        )
        for suffix, value, on_prog_bar in logged_values:
            self.log(
                f"{loss_prefix}_{suffix}",
                value,
                prog_bar=on_prog_bar,
                batch_size=batch_size,
            )

        self._compute_metrics(scores, labels, batch_size, Stage.TRAIN)
        return loss

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Evaluate ``batch`` for the validation stage and return its loss."""
        return self.__eval_step(batch, Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Evaluate ``batch`` for the test stage and return its loss."""
        return self.__eval_step(batch, Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return the hyperedge scores for ``batch`` without computing any loss."""
        return self.__score_batch(batch)

    def configure_optimizers(self) -> optim.Optimizer:
        """Create an Adam optimizer over all parameters using ``lr`` and ``weight_decay``."""
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def __score_batch(self, batch: HData) -> Tensor:
        """Run ``forward`` on the hyperedge index, node ids and hyperedge count of ``batch``."""
        return self.forward(
            hyperedge_index=batch.hyperedge_index,
            global_node_ids=batch.global_node_ids,
            num_hyperedges=batch.num_hyperedges,
        )

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """
        Shared validation/test logic: score the batch, then compute loss and metrics.

        Args:
            batch: Batch to evaluate.
            stage: Stage passed to ``_compute_loss`` and ``_compute_metrics``.

        Returns:
            The value returned by ``_compute_loss``.
        """
        scores = self.__score_batch(batch)
        labels = batch.y
        batch_size = batch.num_hyperedges

        loss = self._compute_loss(scores, labels, batch_size, stage)
        self._compute_metrics(scores, labels, batch_size, stage)
        return loss

    def __to_villain_encoder(self) -> VilLain:
        """
        Return ``self.encoder`` narrowed to ``VilLain``.

        Raises:
            ValueError: If the encoder is missing or is not a ``VilLain`` instance.
        """
        encoder = self.encoder
        # ``isinstance`` is already False for ``None``, so one check covers both cases.
        if not isinstance(encoder, VilLain):
            raise ValueError("VilLain requires a VilLain encoder, but none was provided.")
        return encoder