Skip to content

Data

hyperbench.data

Dataset

Bases: Dataset

A dataset class for loading and processing hypergraph data.

Args:
- hdata: The processed hypergraph data in HData format.
- sampling_strategy: The strategy used for sampling sub-hypergraphs (e.g., by node IDs or hyperedge IDs). If not provided, defaults to SamplingStrategy.HYPEREDGE.

Source code in hyperbench/data/dataset.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
class Dataset(TorchDataset):
    """
    A dataset class for loading and processing hypergraph data.

    Args:
        hdata: The processed hypergraph data in HData format.
        sampling_strategy: The strategy used for sampling sub-hypergraphs (e.g., by node IDs or hyperedge IDs).
            If not provided, defaults to ``SamplingStrategy.HYPEREDGE``.
    """

    def __init__(
        self,
        hdata: HData | None = None,
        sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
    ) -> None:
        """
        Initialize the Dataset.

        Args:
            hdata: Optional HData object to initialize the dataset with.
                If ``None``, the dataset starts from an empty ``HData.empty()``.
            sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``.
        """

        self.__sampler = create_sampler_from_strategy(sampling_strategy)
        self.sampling_strategy = sampling_strategy
        self.hdata = hdata if hdata is not None else HData.empty()

    def __len__(self) -> int:
        """Return the number of samplable items, as reported by the strategy's sampler for ``self.hdata``."""
        return self.__sampler.len(self.hdata)

    def __getitem__(self, index: int | list[int]) -> HData:
        """
        Sample a sub-hypergraph based on the sampling strategy and return it as HData.
        If:
        - Sampling by node IDs, the sub-hypergraph will contain all hyperedges incident to the sampled nodes and all nodes incident to those hyperedges.
        - Sampling by hyperedge IDs, the sub-hypergraph will contain all nodes incident to the sampled hyperedges.

        Args:
            index: An integer or a list of integers representing node or hyperedge IDs to sample, depending on the sampling strategy.

        Returns:
            An HData instance containing the sampled sub-hypergraph.

        Raises:
            ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of nodes/hyperedges).
            IndexError: If any node/hyperedge ID is out of bounds.
        """
        return self.__sampler.sample(index, self.hdata)

    @classmethod
    def from_hdata(
        cls,
        hdata: HData,
        sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
    ) -> "Dataset":
        """
        Create a :class:`Dataset` instance from an :class:`HData` object.

        Args:
            hdata: :class:`HData` object containing the hypergraph data.
            sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``.

        Returns:
            The :class:`Dataset` instance with the provided :class:`HData`.
        """
        return cls(hdata=hdata, sampling_strategy=sampling_strategy)

    @classmethod
    def from_url(
        cls,
        url: str,
        sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
        save_on_disk: bool = False,
    ) -> "Dataset":
        """
        Create a :class:`Dataset` instance by loading a hypergraph from a URL pointing to a .json or .json.zst file in HIF format.

        Args:
            url: The URL to the .json or .json.zst file containing the HIF hypergraph data.
            sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``.
            save_on_disk: Whether to save the downloaded file on disk.

        Returns:
            The :class:`Dataset` instance with the loaded hypergraph data.
        """
        hdata = HIFLoader.load_from_url(url=url, save_on_disk=save_on_disk)
        dataset = cls.from_hdata(hdata=hdata, sampling_strategy=sampling_strategy)
        return dataset

    @classmethod
    def from_path(
        cls,
        filepath: str,
        sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
    ) -> "Dataset":
        """
        Create a :class:`Dataset` instance by loading a hypergraph from a local file path pointing to a .json or .json.zst file in HIF format.

        Args:
            filepath: The local file path to the .json or .json.zst file containing the HIF hypergraph data.
            sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``.

        Returns:
            The :class:`Dataset` instance with the loaded hypergraph data.
        """
        hypergraph = HIFLoader.load_from_path(filepath=filepath)
        dataset = cls.from_hdata(hdata=hypergraph, sampling_strategy=sampling_strategy)
        return dataset

    def enrich_node_features(
        self,
        enricher: NodeEnricher,
        enrichment_mode: EnrichmentMode | None = None,
    ) -> None:
        """
        Enrich node features using the provided node feature enricher.

        Args:
            enricher: An instance of NodeEnricher to generate structural node features from hypergraph topology.
            enrichment_mode: How to combine generated features with existing ``hdata.x``.
                ``concatenate`` appends new features as additional columns.
                ``replace`` substitutes ``hdata.x`` entirely.
        """
        self.hdata = self.hdata.enrich_node_features(enricher, enrichment_mode)

    def enrich_node_features_from(
        self,
        dataset_with_features: "Dataset",
        node_space_setting: NodeSpaceSetting = "transductive",
        fill_value: NodeSpaceFiller | None = None,
    ) -> None:
        """
        Enrich node features from another dataset by copying features by ``global_node_ids``.

        Examples:
            In a transductive setting, the full node space is preserved across datasets:
            >>> val_dataset.enrich_node_features_from(train_dataset)

            In inductive setting, missing node features can be filled with 0.0:
            >>> test_dataset.enrich_node_features_from(
            ...     train_dataset,
            ...     node_space_setting="inductive",
            ...     fill_value=0.0,  # torch.tensor(0.0) also works and will be broadcast to the appropriate shape
            ... )

        Args:
            dataset_with_features: Source dataset providing node features.
            node_space_setting: The setting for the node space, determining how nodes are handled.
                ``transductive`` (default) preserves the full node space of the target dataset.
                ``inductive`` allows the target dataset to have a different node space, filling missing features with ``fill_value``.
            fill_value: Scalar or vector used to fill missing node features when ``node_space_setting`` is not transductive.

        Raises:
            ValueError: If the source dataset's node features cannot be aligned with the target dataset's nodes.
        """
        self.hdata = self.hdata.enrich_node_features_from(
            hdata_with_features=dataset_with_features.hdata,
            node_space_setting=node_space_setting,
            fill_value=fill_value,
        )

    def enrich_hyperedge_attr(
        self,
        enricher: HyperedgeEnricher,
        enrichment_mode: EnrichmentMode | None = None,
    ) -> None:
        """Enrich hyperedge features using the provided hyperedge feature enricher.

        Args:
            enricher: An instance of HyperedgeEnricher to generate structural hyperedge features from hypergraph topology.
            enrichment_mode: How to combine generated features with existing ``hdata.hyperedge_attr``.
                ``concatenate`` appends new features as additional columns.
                ``replace`` substitutes ``hdata.hyperedge_attr`` entirely.
        """
        self.hdata = self.hdata.enrich_hyperedge_attr(enricher, enrichment_mode)

    def enrich_hyperedge_weights(
        self,
        enricher: HyperedgeEnricher,
        enrichment_mode: EnrichmentMode | None = None,
    ) -> None:
        """Enrich hyperedge weights using the provided hyperedge enricher.

        Args:
            enricher: An instance of HyperedgeEnricher used to generate structural hyperedge values
                from hypergraph topology, which are then used as hyperedge weights.
            enrichment_mode: How to combine the generated values with existing ``hdata.hyperedge_weights``.
                ``concatenate`` appends the new values to the existing weights.
                ``replace`` substitutes ``hdata.hyperedge_weights`` entirely.
        """
        self.hdata = self.hdata.enrich_hyperedge_weights(enricher, enrichment_mode)

    def update_from_hdata(self, hdata: HData) -> "Dataset":
        """
        Create a new dataset of the same class and sampling strategy that wraps ``hdata``.

        The current instance is not modified.

        Args:
            hdata: :class:`HData` object containing the hypergraph data for the new dataset.

        Returns:
            A new instance of ``type(self)`` holding ``hdata`` and this dataset's ``sampling_strategy``.
        """
        return self.__class__(hdata=hdata, sampling_strategy=self.sampling_strategy)

    def add_negative_samples(
        self,
        negative_sampler: "NegativeSampler",
        seed: int | None = None,
    ) -> "Dataset":
        """
        Create a new :class:`Dataset` with sampled negative hyperedges added.

        The current dataset's ``hdata`` is cloned first, so this instance is left unchanged.

        Args:
            negative_sampler: Sampler used to generate negative hyperedges from this dataset's ``hdata``.
            seed: Optional random seed used for both negative sampling and the final shuffle.

        Returns:
            A new :class:`Dataset` instance with positives and sampled negatives.
        """
        hdata_with_negatives = self.hdata.clone()
        hdata_with_negatives = hdata_with_negatives.add_negative_samples(
            negative_sampler=negative_sampler,
            seed=seed,
        )
        return self.update_from_hdata(hdata_with_negatives)

    def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> None:
        """
        Remove hyperedges that have fewer than k incident nodes.

        Args:
            k: The minimum number of nodes a hyperedge must have to be retained.
        """
        self.hdata = self.hdata.remove_hyperedges_with_fewer_than_k_nodes(k)

    def split(
        self,
        ratios: list[float],
        shuffle: bool | None = False,
        seed: int | None = None,
        node_space_setting: NodeSpaceSetting = "transductive",
        assign_node_space_to: NodeSpaceAssignment | None = "first",
    ) -> list["Dataset"]:
        """
        Split the dataset by hyperedges into partitions with contiguous 0-based hyperedge IDs.

        Boundaries are computed using cumulative floor to prevent early splits from
        over-consuming edges. The last split absorbs any rounding remainder.

        Examples:
            Transductive split keeping the full node space only on the first split (default):
            >>> train, test = dataset.split([0.8, 0.2])
            >>> train.hdata.num_nodes == dataset.hdata.num_nodes
            >>> test.hdata.num_nodes <= dataset.hdata.num_nodes

            Transductive split keeping the full node space on all splits:
            >>> train, test = dataset.split(
            ...     [0.8, 0.2],
            ...     node_space_setting="transductive",
            ...     assign_node_space_to="all",
            ... )
            >>> train.hdata.num_nodes == dataset.hdata.num_nodes
            >>> test.hdata.num_nodes == dataset.hdata.num_nodes

            Inductive split (``assign_node_space_to`` must be passed as ``None`` explicitly,
            because its default ``"first"`` is rejected in the inductive setting):
            >>> train, test = dataset.split(
            ...     [0.8, 0.2],
            ...     node_space_setting="inductive",
            ...     assign_node_space_to=None,
            ... )
            >>> train.hdata.num_nodes <= dataset.hdata.num_nodes
            >>> test.hdata.num_nodes <= dataset.hdata.num_nodes

        Args:
            ratios: Non-negative floats summing to ``1.0``, e.g., ``[0.8, 0.1, 0.1]``.
            shuffle: Whether to shuffle hyperedges before splitting. Defaults to ``False`` for deterministic splits.
            seed: Optional random seed for reproducibility. Ignored if shuffle is set to ``False``.
                When shuffling without a seed, torch's global RNG is used.
            node_space_setting: Whether to preserve the full node space in the splits.
                ``transductive`` (default) ensures all nodes are present in every split,
                while ``inductive`` allows splits to have disjoint node spaces.
            assign_node_space_to: Which split(s) preserve the full node space when
                ``node_space_setting="transductive"``.
                ``first`` preserves only the first returned split. ``all`` preserves all splits.

        Returns:
            List of Dataset objects, one per split, each with contiguous IDs.

        Raises:
            ValueError: If ``ratios`` does not sum to ``1.0`` (within ``1e-6``), if any ratio is
                negative, or if ``assign_node_space_to`` is not ``None`` in the inductive setting.
        """
        # Allow small imprecision in sum of ratios, but raise error if it's significant
        # Example: ratios = [0.8, 0.1, 0.1] -> sum = 1.0 (valid)
        #          ratios = [0.8, 0.1, 0.05] -> sum = 0.95 (invalid, raises ValueError)
        #          ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision)
        if abs(sum(ratios) - 1.0) > 1e-6:
            raise ValueError(f"Split ratios must sum to 1.0, got {sum(ratios)}.")
        # Negative ratios would produce non-monotonic cut points and silently empty splits.
        if any(ratio < 0 for ratio in ratios):
            raise ValueError(f"Split ratios must be non-negative, got {ratios}.")
        if is_inductive_setting(node_space_setting) and assign_node_space_to is not None:
            raise ValueError(
                "assign_node_space_to can only be provided when node_space_setting='transductive'."
            )

        device = self.hdata.device
        num_hyperedges = self.hdata.num_hyperedges
        hyperedge_ids_permutation = self.__get_hyperedge_ids_permutation(
            num_hyperedges, shuffle, seed
        )

        # Compute cumulative ratio boundaries to avoid independent rounding errors.
        # Independent rounding (e.g., round(0.5*3)=2, round(0.25*3)=1, round(0.25*3)=1 -> total=4)
        # can over-allocate edges to early splits and starve later ones.
        # Cumulative floor boundaries guarantee monotonically non-decreasing cut points.
        # The product is rounded to 6 decimals before flooring to absorb floating-point drift
        # in the running sum: 0.7 + 0.1 == 0.7999999999999999, and a bare int(7.999999999999999)
        # would cut at 7 instead of 8.
        # Example: ratios = [0.5, 0.25, 0.25], num_hyperedges = 3
        #          cut points = [int(1.5), int(2.25), 3] = [1, 2, 3]
        boundaries = []
        cumsum = 0.0
        for ratio in ratios[:-1]:
            cumsum += ratio
            boundary = int(round(cumsum * num_hyperedges, 6))
            boundaries.append(min(boundary, num_hyperedges))
        # Last split gets everything remaining, absorbing any rounding remainder
        boundaries.append(num_hyperedges)

        split_datasets = []
        start = 0
        for i, end in enumerate(boundaries):
            # Example: i=0 -> permutation[0:1] = [0] (1 edge)
            #          i=1 -> permutation[1:2] = [1] (1 edge)
            #          i=2 -> permutation[2:3] = [2] (1 edge)
            split_hyperedge_ids = hyperedge_ids_permutation[start:end]

            use_transductive_node_space = is_transductive_split(
                node_space_setting, assign_node_space_to, split_num=i
            )
            split_hdata = HData.split(
                self.hdata,
                split_hyperedge_ids,
                node_space_setting="transductive" if use_transductive_node_space else "inductive",
            ).to(device=device)

            split_dataset = self.__class__(
                hdata=split_hdata,
                sampling_strategy=self.sampling_strategy,
            )
            split_datasets.append(split_dataset)

            start = end

        return split_datasets

    def to(self, device: torch.device) -> "Dataset":
        """
        Move the dataset's HData to the specified device.

        Args:
            device: The target device (e.g., ``torch.device('cuda')`` or ``torch.device('cpu')``).

        Returns:
            The Dataset instance moved to the specified device.
        """
        self.hdata = self.hdata.to(device)
        return self

    def transform_node_attrs(
        self,
        attrs: dict[str, Any],
        attr_keys: list[str] | None = None,
    ) -> Tensor:
        """
        Convert a node attribute dictionary into a tensor.

        Delegates to ``HIFProcessor.transform_attrs``.

        Args:
            attrs: Mapping of attribute names to values.
            attr_keys: Optional subset/order of attribute keys to use; semantics are defined by
                ``HIFProcessor.transform_attrs``.

        Returns:
            The tensor produced by ``HIFProcessor.transform_attrs``.
        """
        return HIFProcessor.transform_attrs(attrs, attr_keys)

    def transform_hyperedge_attrs(
        self,
        attrs: dict[str, Any],
        attr_keys: list[str] | None = None,
    ) -> Tensor:
        """
        Convert a hyperedge attribute dictionary into a tensor.

        Delegates to ``HIFProcessor.transform_attrs``.

        Args:
            attrs: Mapping of attribute names to values.
            attr_keys: Optional subset/order of attribute keys to use; semantics are defined by
                ``HIFProcessor.transform_attrs``.

        Returns:
            The tensor produced by ``HIFProcessor.transform_attrs``.
        """
        return HIFProcessor.transform_attrs(attrs, attr_keys)

    def stats(self) -> dict[str, Any]:
        """
        Compute statistics for the dataset.
        This method currently delegates to the underlying HData's stats method.
        The fields returned in the dictionary include:
        - ``shape_x``: The shape of the node feature matrix ``x``.
        - ``shape_hyperedge_attr``: The shape of the hyperedge attribute matrix, or ``None`` if hyperedge attributes are not present.
        - ``num_nodes``: The number of nodes in the hypergraph.
        - ``num_hyperedges``: The number of hyperedges in the hypergraph.
        - ``avg_degree_node_raw``: The average degree of nodes, calculated as the mean number of hyperedges each node belongs to.
        - ``avg_degree_node``: The floored node average degree.
        - ``avg_degree_hyperedge_raw``: The average size of hyperedges, calculated as the mean number of nodes each hyperedge contains.
        - ``avg_degree_hyperedge``: The floored hyperedge average size.
        - ``node_degree_max``: The maximum degree of any node in the hypergraph.
        - ``hyperedge_degree_max``: The maximum size of any hyperedge in the hypergraph.
        - ``node_degree_median``: The median degree of nodes in the hypergraph.
        - ``hyperedge_degree_median``: The median size of hyperedges in the hypergraph.
        - ``distribution_node_degree``: A list where the value at index ``i`` represents the count of nodes with degree ``i``.
        - ``distribution_hyperedge_size``: A list where the value at index ``i`` represents the count of hyperedges with size ``i``.
        - ``distribution_node_degree_hist``: A dictionary where the keys are node degrees and the values are the count of nodes with that degree.
        - ``distribution_hyperedge_size_hist``: A dictionary where the keys are hyperedge sizes and the values are the count of hyperedges with that size.

        Returns:
            A dictionary containing various statistics about the hypergraph.
        """

        return self.hdata.stats()

    def __get_hyperedge_ids_permutation(
        self,
        num_hyperedges: int,
        shuffle: bool | None,
        seed: int | None,
    ) -> Tensor:
        """
        Return the order in which hyperedge IDs are assigned to splits.

        Args:
            num_hyperedges: Number of hyperedges to order.
            shuffle: If truthy, return a random permutation; otherwise ``0..num_hyperedges-1``.
            seed: Seed for the random permutation. ``None`` uses torch's global RNG.

        Returns:
            A 1-D tensor of hyperedge IDs on the same device as ``self.hdata``.
        """
        device = self.hdata.device

        # Keep original order for deterministic splits when no shuffle is requested
        if not shuffle:
            return torch.arange(num_hyperedges, device=device)

        # A freshly constructed torch.Generator always starts from the same default seed,
        # so an unseeded private generator would return the same "random" permutation on
        # every call. Only create one when an explicit seed is given; otherwise use the
        # global RNG (generator=None), which also honors torch.manual_seed.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=device)
            generator.manual_seed(seed)

        return torch.randperm(
            n=num_hyperedges,
            generator=generator,
            device=device,
        )

__init__(hdata=None, sampling_strategy=SamplingStrategy.HYPEREDGE)

Initialize the Dataset.

Parameters:

Name Type Description Default
hdata HData | None

Optional HData object to initialize the dataset with. If None, the dataset starts from an empty HData (HData.empty()).

None
sampling_strategy SamplingStrategy

The sampling strategy to use for the dataset. If not provided, defaults to SamplingStrategy.HYPEREDGE.

HYPEREDGE
Source code in hyperbench/data/dataset.py
def __init__(
    self,
    hdata: HData | None = None,
    sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
) -> None:
    """
    Initialize the Dataset.

    Args:
        hdata: Optional HData object to initialize the dataset with.
            If ``None``, the dataset starts from an empty ``HData.empty()``.
        sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``.
    """

    self.__sampler = create_sampler_from_strategy(sampling_strategy)
    self.sampling_strategy = sampling_strategy
    self.hdata = hdata if hdata is not None else HData.empty()

__getitem__(index)

Sample a sub-hypergraph based on the sampling strategy and return it as HData:

- When sampling by node IDs, the sub-hypergraph contains all hyperedges incident to the sampled nodes and all nodes incident to those hyperedges.
- When sampling by hyperedge IDs, the sub-hypergraph contains all nodes incident to the sampled hyperedges.

Parameters:

Name Type Description Default
index int | list[int]

An integer or a list of integers representing node or hyperedge IDs to sample, depending on the sampling strategy.

required

Returns:

Type Description
HData

An HData instance containing the sampled sub-hypergraph.

Raises:

Type Description
ValueError

If the provided index is invalid (e.g., empty list or list length exceeds number of nodes/hyperedges).

IndexError

If any node/hyperedge ID is out of bounds.

Source code in hyperbench/data/dataset.py
def __getitem__(self, index: int | list[int]) -> HData:
    """
    Return the sub-hypergraph selected by ``index`` under the dataset's sampling strategy.

    What ``index`` refers to depends on the strategy:
    - Node sampling: the result holds every hyperedge incident to the sampled nodes, plus every node incident to those hyperedges.
    - Hyperedge sampling: the result holds every node incident to the sampled hyperedges.

    Args:
        index: A single ID or a list of IDs (node or hyperedge IDs, according to the sampling strategy).

    Returns:
        An HData instance holding the sampled sub-hypergraph.

    Raises:
        ValueError: If ``index`` is invalid (e.g., an empty list, or more IDs than there are nodes/hyperedges).
        IndexError: If any node/hyperedge ID is out of bounds.
    """
    sampler = self.__sampler
    sub_hypergraph = sampler.sample(index, self.hdata)
    return sub_hypergraph

from_hdata(hdata, sampling_strategy=SamplingStrategy.HYPEREDGE) classmethod

Create a Dataset instance from an HData object.

Parameters:

Name Type Description Default
hdata HData

HData object containing the hypergraph data.

required
sampling_strategy SamplingStrategy

The sampling strategy to use for the dataset. If not provided, defaults to SamplingStrategy.HYPEREDGE.

HYPEREDGE

Returns:

Name Type Description
Dataset

The Dataset instance with the provided HData.

Source code in hyperbench/data/dataset.py
@classmethod
def from_hdata(
    cls,
    hdata: HData,
    sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
) -> "Dataset":
    """
    Build a :class:`Dataset` that wraps an existing :class:`HData` object.

    Args:
        hdata: :class:`HData` object holding the hypergraph data.
        sampling_strategy: Sampling strategy for the new dataset. Defaults to ``SamplingStrategy.HYPEREDGE``.

    Returns:
        A new instance of ``cls`` holding ``hdata``.
    """
    instance = cls(hdata=hdata, sampling_strategy=sampling_strategy)
    return instance

from_url(url, sampling_strategy=SamplingStrategy.HYPEREDGE, save_on_disk=False) classmethod

Create a Dataset instance by loading a hypergraph from a URL pointing to a .json or .json.zst file in HIF format.

Parameters:

Name Type Description Default
url str

The URL to the .json or .json.zst file containing the HIF hypergraph data.

required
sampling_strategy SamplingStrategy

The sampling strategy to use for the dataset. If not provided, defaults to SamplingStrategy.HYPEREDGE.

HYPEREDGE
save_on_disk bool

Whether to save the downloaded file on disk.

False

Returns:

Name Type Description
Dataset

The Dataset instance with the loaded hypergraph data.

Source code in hyperbench/data/dataset.py
@classmethod
def from_url(
    cls,
    url: str,
    sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
    save_on_disk: bool = False,
) -> "Dataset":
    """
    Load an HIF hypergraph (.json or .json.zst) from a URL and wrap it in a :class:`Dataset`.

    Args:
        url: Location of the .json or .json.zst file with the HIF hypergraph data.
        sampling_strategy: Sampling strategy for the new dataset. Defaults to ``SamplingStrategy.HYPEREDGE``.
        save_on_disk: Whether the downloaded file should also be written to disk.

    Returns:
        A dataset built from the loaded hypergraph via ``from_hdata``.
    """
    loaded_hdata = HIFLoader.load_from_url(url=url, save_on_disk=save_on_disk)
    return cls.from_hdata(hdata=loaded_hdata, sampling_strategy=sampling_strategy)

from_path(filepath, sampling_strategy=SamplingStrategy.HYPEREDGE) classmethod

Create a Dataset instance by loading a hypergraph from a local file path pointing to a .json or .json.zst file in HIF format.

Parameters:

Name Type Description Default
filepath str

The local file path to the .json or .json.zst file containing the HIF hypergraph data.

required
sampling_strategy SamplingStrategy

The sampling strategy to use for the dataset. If not provided, defaults to SamplingStrategy.HYPEREDGE.

HYPEREDGE

Returns:

Name Type Description
Dataset

The Dataset instance with the loaded hypergraph data.

Source code in hyperbench/data/dataset.py
@classmethod
def from_path(
    cls,
    filepath: str,
    sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
) -> "Dataset":
    """
    Load an HIF hypergraph (.json or .json.zst) from a local path and wrap it in a :class:`Dataset`.

    Args:
        filepath: Local path of the .json or .json.zst file with the HIF hypergraph data.
        sampling_strategy: Sampling strategy for the new dataset. Defaults to ``SamplingStrategy.HYPEREDGE``.

    Returns:
        A dataset built from the loaded hypergraph via ``from_hdata``.
    """
    loaded_hdata = HIFLoader.load_from_path(filepath=filepath)
    return cls.from_hdata(hdata=loaded_hdata, sampling_strategy=sampling_strategy)

enrich_node_features(enricher, enrichment_mode=None)

Enrich node features using the provided node feature enricher.

Parameters:

Name Type Description Default
enricher NodeEnricher

An instance of NodeEnricher to generate structural node features from hypergraph topology.

required
enrichment_mode EnrichmentMode | None

How to combine generated features with existing hdata.x. concatenate appends new features as additional columns. replace substitutes hdata.x entirely.

None
Source code in hyperbench/data/dataset.py
def enrich_node_features(
    self,
    enricher: NodeEnricher,
    enrichment_mode: EnrichmentMode | None = None,
) -> None:
    """
    Replace ``self.hdata`` with a copy whose node features were enriched by ``enricher``.

    Args:
        enricher: NodeEnricher that derives structural node features from the hypergraph topology.
        enrichment_mode: How the generated features combine with the existing ``hdata.x``:
            ``concatenate`` adds them as extra columns, ``replace`` overwrites ``hdata.x``.
    """
    enriched_hdata = self.hdata.enrich_node_features(enricher, enrichment_mode)
    self.hdata = enriched_hdata

enrich_node_features_from(dataset_with_features, node_space_setting='transductive', fill_value=None)

Enrich node features from another dataset by copying features by global_node_ids.

Examples:

In a transductive setting, the full node space is preserved across datasets:

>>> val_dataset.enrich_node_features_from(train_dataset)

In inductive setting, missing node features can be filled with 0.0:

>>> test_dataset.enrich_node_features_from(
...     train_dataset,
...     node_space_setting="inductive",
...     fill_value=0.0,  # torch.tensor(0.0) also works and will be broadcast to the appropriate shape
... )

Parameters:

Name Type Description Default
dataset_with_features Dataset

Source dataset providing node features.

required
node_space_setting NodeSpaceSetting

The setting for the node space, determining how nodes are handled. transductive (default) preserves the full node space of the target dataset. inductive allows the target dataset to have a different node space, filling missing features with fill_value.

'transductive'
fill_value NodeSpaceFiller | None

Scalar or vector used to fill missing node features when node_space_setting is not transductive.

None

Raises:

Type Description
ValueError

If the source dataset's node features cannot be aligned with the target dataset's nodes.

Source code in hyperbench/data/dataset.py
def enrich_node_features_from(
    self,
    dataset_with_features: "Dataset",
    node_space_setting: NodeSpaceSetting = "transductive",
    fill_value: NodeSpaceFiller | None = None,
) -> None:
    """
    Copy node features from another dataset, matching nodes by ``global_node_ids``.

    The alignment is delegated to ``HData.enrich_node_features_from``; the resulting
    ``HData`` replaces ``self.hdata``.

    Examples:
        Transductive setting, where the full node space is shared between datasets:
        >>> val_dataset.enrich_node_features_from(train_dataset)

        Inductive setting, where features of unseen nodes are filled with 0.0:
        >>> test_dataset.enrich_node_features_from(
        ...     train_dataset,
        ...     node_space_setting="inductive",
        ...     fill_value=0.0,  # torch.tensor(0.0) also works and will be broadcast to the appropriate shape
        ... )

    Args:
        dataset_with_features: Dataset whose ``hdata`` supplies the node features.
        node_space_setting: How the node space is treated.
            ``transductive`` (default) keeps the full node space of this dataset.
            ``inductive`` lets this dataset have a different node space; missing
            features are filled with ``fill_value``.
        fill_value: Scalar or vector used for missing node features when
            ``node_space_setting`` is not transductive.

    Raises:
        ValueError: Propagated when the source node features cannot be aligned with this dataset's nodes.
    """
    source_hdata = dataset_with_features.hdata
    self.hdata = self.hdata.enrich_node_features_from(
        fill_value=fill_value,
        node_space_setting=node_space_setting,
        hdata_with_features=source_hdata,
    )

enrich_hyperedge_attr(enricher, enrichment_mode=None)

Enrich hyperedge features using the provided hyperedge feature enricher.

Parameters:

Name Type Description Default
enricher HyperedgeEnricher

An instance of HyperedgeEnricher to generate structural hyperedge features from hypergraph topology.

required
enrichment_mode EnrichmentMode | None

How to combine generated features with existing hdata.hyperedge_attr. concatenate appends new features as additional columns. replace substitutes hdata.hyperedge_attr entirely.

None
Source code in hyperbench/data/dataset.py
def enrich_hyperedge_attr(
    self,
    enricher: HyperedgeEnricher,
    enrichment_mode: EnrichmentMode | None = None,
) -> None:
    """
    Add enricher-generated hyperedge attributes to this dataset.

    Delegates to ``HData.enrich_hyperedge_attr`` and stores the returned ``HData``
    back on ``self.hdata``.

    Args:
        enricher: ``HyperedgeEnricher`` that derives structural hyperedge features from the hypergraph topology.
        enrichment_mode: Strategy for merging the generated features into ``hdata.hyperedge_attr``.
            ``concatenate`` appends them as extra columns, ``replace`` overwrites ``hdata.hyperedge_attr``.
    """
    enriched_hdata = self.hdata.enrich_hyperedge_attr(enricher, enrichment_mode)
    self.hdata = enriched_hdata

enrich_hyperedge_weights(enricher, enrichment_mode=None)

Enrich hyperedge weights using the provided hyperedge weight enricher.

Parameters:

Name Type Description Default
enricher HyperedgeEnricher

An instance of HyperedgeEnricher to generate structural hyperedge weights from hypergraph topology.

required
enrichment_mode EnrichmentMode | None

How to combine generated weights with existing hdata.hyperedge_weights. concatenate appends the new weights as additional columns. replace substitutes hdata.hyperedge_weights entirely.

None
Source code in hyperbench/data/dataset.py
def enrich_hyperedge_weights(
    self,
    enricher: HyperedgeEnricher,
    enrichment_mode: EnrichmentMode | None = None,
) -> None:
    """
    Add enricher-generated hyperedge weights to this dataset.

    Delegates to ``HData.enrich_hyperedge_weights`` and stores the returned ``HData``
    back on ``self.hdata``.

    Args:
        enricher: ``HyperedgeEnricher`` that derives structural hyperedge weights from the hypergraph topology.
        enrichment_mode: Strategy for merging the generated weights into ``hdata.hyperedge_weights``.
            ``concatenate`` appends them as extra columns, ``replace`` overwrites ``hdata.hyperedge_weights``.
    """
    weighted_hdata = self.hdata.enrich_hyperedge_weights(enricher, enrichment_mode)
    self.hdata = weighted_hdata

update_from_hdata(hdata)

Create a new `Dataset` instance from an `HData` object, reusing this dataset's sampling strategy.

Parameters:

Name Type Description Default
hdata HData

`HData` object containing the hypergraph data.

required

Returns:

Name Type Description
Dataset

The `Dataset` instance with the provided `HData`.

Source code in hyperbench/data/dataset.py
def update_from_hdata(self, hdata: HData) -> "Dataset":
    """
    Build a new dataset of the same class around ``hdata``.

    ``self`` is not modified; only its ``sampling_strategy`` is carried over.

    Args:
        hdata: :class:`HData` object holding the hypergraph data for the new dataset.

    Returns:
        A fresh :class:`Dataset` (same concrete class as ``self``) wrapping ``hdata``.
    """
    dataset_cls = self.__class__
    strategy = self.sampling_strategy
    return dataset_cls(hdata=hdata, sampling_strategy=strategy)

add_negative_samples(negative_sampler, seed=None)

Create a new `Dataset` with sampled negative hyperedges added.

Parameters:

Name Type Description Default
negative_sampler NegativeSampler

Sampler used to generate negative hyperedges from this dataset's hdata.

required
seed int | None

Optional random seed used for both negative sampling and the final shuffle.

None

Returns:

Type Description
Dataset

A new `Dataset` instance with positives and sampled negatives.

Source code in hyperbench/data/dataset.py
def add_negative_samples(
    self,
    negative_sampler: "NegativeSampler",
    seed: int | None = None,
) -> "Dataset":
    """
    Create a new :class:`Dataset` with sampled negative hyperedges added.

    Args:
        negative_sampler: Sampler used to generate negative hyperedges from this dataset's ``hdata``.
        seed: Optional random seed used for both negative sampling and the final shuffle.

    Returns:
        A new :class:`Dataset` instance with positives and sampled negatives.
    """
    hdata_with_negatives = self.hdata.clone()
    hdata_with_negatives = hdata_with_negatives.add_negative_samples(
        negative_sampler=negative_sampler,
        seed=seed,
    )
    return self.update_from_hdata(hdata_with_negatives)

remove_hyperedges_with_fewer_than_k_nodes(k)

Remove hyperedges that have fewer than k incident nodes.

Parameters:

Name Type Description Default
k int

The minimum number of nodes a hyperedge must have to be retained.

required
Source code in hyperbench/data/dataset.py
def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> None:
    """
    Drop every hyperedge with fewer than ``k`` incident nodes, in place on this dataset.

    The filtering is done by ``HData.remove_hyperedges_with_fewer_than_k_nodes``;
    its result replaces ``self.hdata``.

    Args:
        k: Minimum number of nodes a hyperedge needs in order to be kept.
    """
    filtered_hdata = self.hdata.remove_hyperedges_with_fewer_than_k_nodes(k)
    self.hdata = filtered_hdata

split(ratios, shuffle=False, seed=None, node_space_setting='transductive', assign_node_space_to='first')

Split the dataset by hyperedges into partitions with contiguous 0-based hyperedge IDs.

Boundaries are computed using cumulative floor to prevent early splits from over-consuming edges. The last split absorbs any rounding remainder.

Examples:

Transductive split keeping the full node space only on the first split (default):

>>> train, test = dataset.split([0.8, 0.2])
>>> train.hdata.num_nodes == dataset.hdata.num_nodes
>>> test.hdata.num_nodes <= dataset.hdata.num_nodes

Transductive split keeping the full node space on all splits:

>>> train, test = dataset.split(
...     [0.8, 0.2],
...     node_space_setting="transductive",
...     assign_node_space_to="all",
... )
>>> train.hdata.num_nodes == dataset.hdata.num_nodes
>>> test.hdata.num_nodes == dataset.hdata.num_nodes

Inductive split:

>>> train, test = dataset.split(
...     [0.8, 0.2],
...     node_space_setting="inductive",
...     assign_node_space_to=None,
... )
>>> train.hdata.num_nodes <= dataset.hdata.num_nodes
>>> test.hdata.num_nodes <= dataset.hdata.num_nodes

Parameters:

Name Type Description Default
ratios list[float]

List of floats summing to 1.0, e.g., [0.8, 0.1, 0.1].

required
shuffle bool | None

Whether to shuffle hyperedges before splitting. Defaults to False for deterministic splits.

False
seed int | None

Optional random seed for reproducibility. Ignored if shuffle is set to False.

None
node_space_setting NodeSpaceSetting

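Whether to preserve the full node space in the splits. With transductive (default), the split(s) selected by assign_node_space_to keep the full node space, while the remaining splits use the inductive node space. With inductive, all splits may have different (possibly disjoint) node spaces, and assign_node_space_to must be None.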

'transductive'
assign_node_space_to NodeSpaceAssignment | None

Which split(s) preserve the full node space when node_space_setting="transductive". first preserves only the first returned split. all preserves all splits.

'first'

Returns:

Type Description
list[Dataset]

List of Dataset objects, one per split, each with contiguous IDs.

Source code in hyperbench/data/dataset.py
def split(
    self,
    ratios: list[float],
    shuffle: bool | None = False,
    seed: int | None = None,
    node_space_setting: NodeSpaceSetting = "transductive",
    assign_node_space_to: NodeSpaceAssignment | None = "first",
) -> list["Dataset"]:
    """
    Split the dataset by hyperedges into partitions with contiguous 0-based hyperedge IDs.

    Boundaries are computed using cumulative floor to prevent early splits from
    over-consuming edges. The last split absorbs any rounding remainder.

    Examples:
        Transductive split keeping the full node space only on the first split (default):
        >>> train, test = dataset.split([0.8, 0.2])
        >>> train.hdata.num_nodes == dataset.hdata.num_nodes
        >>> test.hdata.num_nodes <= dataset.hdata.num_nodes

        Transductive split keeping the full node space on all splits:
        >>> train, test = dataset.split(
        ...     [0.8, 0.2],
        ...     node_space_setting="transductive",
        ...     assign_node_space_to="all",
        ... )
        >>> train.hdata.num_nodes == dataset.hdata.num_nodes
        >>> test.hdata.num_nodes == dataset.hdata.num_nodes

        Inductive split:
        >>> train, test = dataset.split(
        ...     [0.8, 0.2],
        ...     node_space_setting="inductive",
        ...     assign_node_space_to=None,
        ... )
        >>> train.hdata.num_nodes <= dataset.hdata.num_nodes
        >>> test.hdata.num_nodes <= dataset.hdata.num_nodes

    Args:
        ratios: Non-negative floats summing to ``1.0``, e.g., ``[0.8, 0.1, 0.1]``.
        shuffle: Whether to shuffle hyperedges before splitting. Defaults to ``False`` for deterministic splits.
        seed: Optional random seed for reproducibility. Ignored if shuffle is set to ``False``.
        node_space_setting: Whether to preserve the full node space in the splits.
            With ``transductive`` (default), the split(s) selected by ``assign_node_space_to``
            keep the full node space and the remaining splits use the inductive node space.
            With ``inductive``, every split uses the inductive node space, which may be
            smaller than the full one; ``assign_node_space_to`` must then be ``None``.
        assign_node_space_to: Which split(s) preserve the full node space when
            ``node_space_setting="transductive"``.
            ``first`` preserves only the first returned split. ``all`` preserves all splits.

    Returns:
        List of Dataset objects, one per split, each with contiguous IDs.

    Raises:
        ValueError: If ``ratios`` do not sum to ``1.0`` (within ``1e-6``), if any ratio is
            negative, or if ``assign_node_space_to`` is given for an inductive split.
    """
    # Allow small imprecision in sum of ratios, but raise error if it's significant
    # Example: ratios = [0.8, 0.1, 0.1] -> sum = 1.0 (valid)
    #          ratios = [0.8, 0.1, 0.05] -> sum = 0.95 (invalid, raises ValueError)
    #          ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision)
    if abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError(f"Split ratios must sum to 1.0, got {sum(ratios)}.")
    # A negative ratio would move the cut point backwards and make consecutive
    # splits overlap (the same hyperedge would land in two splits).
    if any(ratio < 0 for ratio in ratios):
        raise ValueError(f"Split ratios must be non-negative, got {ratios}.")
    if is_inductive_setting(node_space_setting) and assign_node_space_to is not None:
        raise ValueError(
            "assign_node_space_to can only be provided when node_space_setting='transductive'."
        )

    device = self.hdata.device
    num_hyperedges = self.hdata.num_hyperedges
    hyperedge_ids_permutation = self.__get_hyperedge_ids_permutation(
        num_hyperedges, shuffle, seed
    )

    # Cumulative floor boundaries (rather than rounding each split independently)
    # guarantee monotonically increasing cut points and never over-allocate early splits.
    # Example: ratios = [0.5, 0.25, 0.25], num_hyperedges = 3
    #          cumulative ratios = [0.5, 0.75, 1.0] -> cut points 1, 2, 3
    last_split = len(ratios) - 1
    split_datasets = []
    start = 0
    cumulative_ratio = 0.0
    for i, ratio in enumerate(ratios):
        cumulative_ratio += ratio
        if i == last_split:
            # Last split gets everything remaining, absorbing any rounding remainder.
            end = num_hyperedges
        else:
            # Strip float accumulation noise before flooring: 0.1 + 0.7 evaluates to
            # 0.7999999999999999, and int(0.7999999999999999 * 10) would give 7 instead of 8.
            # Rounding to 6 decimals removes such noise without touching genuine fractions
            # (e.g. 0.5 * 3 = 1.5 still floors to 1).
            end = int(round(cumulative_ratio * num_hyperedges, 6))
            # Ratios may sum to slightly more than 1.0 (tolerance above); stay in range.
            end = min(end, num_hyperedges)

        split_hyperedge_ids = hyperedge_ids_permutation[start:end]

        use_transductive_node_space = is_transductive_split(
            node_space_setting, assign_node_space_to, split_num=i
        )
        split_hdata = HData.split(
            self.hdata,
            split_hyperedge_ids,
            node_space_setting="transductive" if use_transductive_node_space else "inductive",
        ).to(device=device)

        split_datasets.append(
            self.__class__(
                hdata=split_hdata,
                sampling_strategy=self.sampling_strategy,
            )
        )

        start = end

    return split_datasets

to(device)

Move the dataset's HData to the specified device.

Parameters:

Name Type Description Default
device device

The target device (e.g., torch.device('cuda') or torch.device('cpu')).

required

Returns:

Type Description
Dataset

The Dataset instance moved to the specified device.

Source code in hyperbench/data/dataset.py
def to(self, device: torch.device) -> "Dataset":
    """
    Move this dataset's ``HData`` onto ``device``.

    ``self.hdata`` is replaced by the moved copy and the dataset itself is returned,
    so calls can be chained.

    Args:
        device: Destination device, e.g. ``torch.device('cuda')`` or ``torch.device('cpu')``.

    Returns:
        This same Dataset instance, now backed by ``HData`` on ``device``.
    """
    moved_hdata = self.hdata.to(device)
    self.hdata = moved_hdata
    return self

stats()

Compute statistics for the dataset. This method currently delegates to the underlying HData's stats method. The fields returned in the dictionary include:

- shape_x: The shape of the node feature matrix x.
- shape_hyperedge_attr: The shape of the hyperedge attribute matrix, or None if hyperedge attributes are not present.
- num_nodes: The number of nodes in the hypergraph.
- num_hyperedges: The number of hyperedges in the hypergraph.
- avg_degree_node_raw: The average degree of nodes, calculated as the mean number of hyperedges each node belongs to.
- avg_degree_node: The floored node average degree.
- avg_degree_hyperedge_raw: The average size of hyperedges, calculated as the mean number of nodes each hyperedge contains.
- avg_degree_hyperedge: The floored hyperedge average size.
- node_degree_max: The maximum degree of any node in the hypergraph.
- hyperedge_degree_max: The maximum size of any hyperedge in the hypergraph.
- node_degree_median: The median degree of nodes in the hypergraph.
- hyperedge_degree_median: The median size of hyperedges in the hypergraph.
- distribution_node_degree: A list where the value at index i represents the count of nodes with degree i.
- distribution_hyperedge_size: A list where the value at index i represents the count of hyperedges with size i.
- distribution_node_degree_hist: A dictionary where the keys are node degrees and the values are the count of nodes with that degree.
- distribution_hyperedge_size_hist: A dictionary where the keys are hyperedge sizes and the values are the count of hyperedges with that size.

Returns:

Type Description
dict[str, Any]

A dictionary containing various statistics about the hypergraph.

Source code in hyperbench/data/dataset.py
def stats(self) -> dict[str, Any]:
    """
    Return summary statistics of the dataset's hypergraph.

    This is a pass-through to ``HData.stats``. The dictionary includes:
    - ``shape_x``: The shape of the node feature matrix ``x``.
    - ``shape_hyperedge_attr``: The shape of the hyperedge attribute matrix, or ``None`` if hyperedge attributes are not present.
    - ``num_nodes``: The number of nodes in the hypergraph.
    - ``num_hyperedges``: The number of hyperedges in the hypergraph.
    - ``avg_degree_node_raw``: The average degree of nodes, calculated as the mean number of hyperedges each node belongs to.
    - ``avg_degree_node``: The floored node average degree.
    - ``avg_degree_hyperedge_raw``: The average size of hyperedges, calculated as the mean number of nodes each hyperedge contains.
    - ``avg_degree_hyperedge``: The floored hyperedge average size.
    - ``node_degree_max``: The maximum degree of any node in the hypergraph.
    - ``hyperedge_degree_max``: The maximum size of any hyperedge in the hypergraph.
    - ``node_degree_median``: The median degree of nodes in the hypergraph.
    - ``hyperedge_degree_median``: The median size of hyperedges in the hypergraph.
    - ``distribution_node_degree``: A list where the value at index ``i`` represents the count of nodes with degree ``i``.
    - ``distribution_hyperedge_size``: A list where the value at index ``i`` represents the count of hyperedges with size ``i``.
    - ``distribution_node_degree_hist``: A dictionary where the keys are node degrees and the values are the count of nodes with that degree.
    - ``distribution_hyperedge_size_hist``: A dictionary where the keys are hyperedge sizes and the values are the count of hyperedges with that size.

    Returns:
        The statistics dictionary produced by ``HData.stats``.
    """
    underlying = self.hdata
    return underlying.stats()

HIFLoader

A utility class to load hypergraphs from HIF format.

Source code in hyperbench/data/hif.py
class HIFLoader:
    """A utility class to load hypergraphs from HIF format."""

    @classmethod
    def load_from_url(cls, url: str, save_on_disk: bool = False) -> HData:
        """
        Load a hypergraph from a given URL pointing to a .json or .json.zst file in HIF format.

        The format is chosen from the extension of the URL path (query string and
        fragment are ignored). The downloaded payload is staged in a temporary file,
        which is removed once the hypergraph has been parsed.

        Args:
            url (str): The URL to the .json or .json.zst file containing the HIF hypergraph data.
            save_on_disk (bool): Whether to save the downloaded file on disk.

        Returns:
            HData: The loaded hypergraph object.

        Raises:
            ValueError: If the URL path ends in neither ``.json`` nor ``.zst``, if the
                download does not return HTTP 200, or if the payload is not HIF-compliant.
        """
        from urllib.parse import urlparse

        url = validate_http_url(url)

        # Decide the format from the URL itself. The temp file name is chosen here and
        # therefore cannot tell a ``.json`` download from a ``.json.zst`` one.
        url_path = urlparse(url).path
        if url_path.endswith(".zst"):
            is_compressed = True
        elif url_path.endswith(".json"):
            is_compressed = False
        else:
            raise ValueError(
                f"Unsupported file format for URL '{url}'. Expected .json or .json.zst"
            )
        saved_filename = os.path.basename(url_path)

        response = requests.get(url, timeout=20)
        if response.status_code != 200:
            raise ValueError(
                f"Failed to download dataset from URL '{url}' with status code {response.status_code}"
            )

        with tempfile.NamedTemporaryFile(
            mode="wb", suffix=".json.zst" if is_compressed else ".json", delete=False
        ) as tmp_file:
            tmp_file.write(response.content)
            downloaded_filename = tmp_file.name

        try:
            if is_compressed:
                if save_on_disk:
                    write_to_disk(saved_filename, response.content)
                output = decompress_zst(downloaded_filename)
            else:
                if save_on_disk:
                    # NOTE(review): the compressed payload is stored under the original
                    # ``.json`` name; confirm this is what ``write_to_disk`` users expect.
                    compressed = compress_to_zst(downloaded_filename)
                    write_to_disk(saved_filename, compressed)
                output = downloaded_filename
            hypergraph = cls.__extract_hif(output)
        finally:
            cls.__remove_file_quietly(downloaded_filename)

        return HIFProcessor.process_hypergraph(hypergraph)

    @classmethod
    def load_from_path(cls, filepath: str) -> HData:
        """
        Load a hypergraph from a local file path pointing to a .json or .json.zst file in HIF format.

        Args:
            filepath (str): The local file path to the .json or .json.zst file
                containing the HIF hypergraph data.

        Returns:
            HData: The loaded hypergraph object.

        Raises:
            ValueError: If the file does not exist, has an unsupported extension, or is not HIF-compliant.
        """
        if not os.path.exists(filepath):
            raise ValueError(f"File '{filepath}' does not exist.")

        if filepath.endswith(".zst"):
            output = decompress_zst(filepath)
        elif filepath.endswith(".json"):
            output = filepath
        else:
            raise ValueError(
                f"Unsupported file format for filepath '{filepath}'. Expected .json or .json.zst"
            )

        hypergraph = cls.__extract_hif(output)
        return HIFProcessor.process_hypergraph(hypergraph)

    @classmethod
    def load_by_name(
        cls,
        dataset_name: str,
        hf_sha: str | None = None,
        save_on_disk: bool = False,
    ) -> HData:
        """
        Load a named dataset, downloading it when no local copy exists.

        The ``datasets`` directory next to this module is checked first for
        ``<dataset_name>.json.zst``. Otherwise the file is fetched from the GitHub
        datasets repository at the pinned ``GITHUB_COMMIT_SHA``. If GitHub does not
        answer with HTTP 200, a ``UserWarning`` is emitted and the Hugging Face Hub
        is tried instead, which requires ``hf_sha``.

        Args:
            dataset_name: Name of the dataset, without the ``.json.zst`` extension.
            hf_sha: Hugging Face Hub revision used for the fallback download. Without it,
                a failed GitHub download raises immediately.
            save_on_disk: Whether to store the downloaded file in the local ``datasets``
                directory so later calls find it there.

        Returns:
            HData: The loaded hypergraph object.

        Raises:
            ValueError: If the dataset cannot be downloaded or is not HIF-compliant.
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        datasets_dir = os.path.join(current_dir, "datasets")
        zst_filename = os.path.join(datasets_dir, f"{dataset_name}.json.zst")
        tmp_zst_filename: str | None = None

        if not os.path.exists(zst_filename):
            content = cls.__download_dataset(dataset_name, hf_sha)
            if save_on_disk:
                os.makedirs(datasets_dir, exist_ok=True)
                with open(zst_filename, "wb") as f:
                    f.write(content)
            else:
                # Stage the download in a temp file; it is removed right after decompression.
                with tempfile.NamedTemporaryFile(
                    mode="wb", suffix=".json.zst", delete=False
                ) as tmp_zst_file:
                    tmp_zst_file.write(content)
                    tmp_zst_filename = tmp_zst_file.name
                zst_filename = tmp_zst_filename

        try:
            output = decompress_zst(zst_filename)
        finally:
            if tmp_zst_filename is not None:
                cls.__remove_file_quietly(tmp_zst_filename)

        hypergraph = cls.__extract_hif(output)
        return HIFProcessor.process_hypergraph(hypergraph)

    @staticmethod
    def __download_dataset(dataset_name: str, hf_sha: str | None) -> bytes:
        """
        Return the raw ``.json.zst`` bytes of a named dataset.

        GitHub is tried first. On a non-200 response a warning is emitted and the
        Hugging Face Hub is used instead, which requires ``hf_sha``.
        """
        github_url = f"https://raw.githubusercontent.com/hypernetwork-research-group/datasets/{GITHUB_COMMIT_SHA}/{dataset_name}.json.zst"
        response = requests.get(github_url, timeout=20)
        if response.status_code == 200:
            return response.content

        warnings.warn(
            f"GitHub raw download failed for dataset '{dataset_name}' with status code {response.status_code}\n"
            "Falling back to Hugging Face Hub download for dataset",
            category=UserWarning,
            # One frame deeper than load_by_name, so 3 still points at load_by_name's caller.
            stacklevel=3,
        )

        if hf_sha is None:
            raise ValueError(
                f"Failed to download dataset '{dataset_name}' from GitHub with status code {response.status_code} and no SHA provided for Hugging Face Hub fallback."
            )

        try:
            downloaded_path = hf_hub_download(
                repo_id=f"HypernetworkRG/{dataset_name}",
                filename=f"{dataset_name}.json.zst",
                repo_type="dataset",
                revision=hf_sha,
            )
        except Exception as e:
            raise ValueError(
                f"Failed to download dataset '{dataset_name}' from GitHub and Hugging Face Hub. GitHub error: {response.status_code} | Hugging Face error: {e!s}"
            ) from e

        with open(downloaded_path, "rb") as hf_file:
            return hf_file.read()

    @staticmethod
    def __remove_file_quietly(path: str) -> None:
        """Best-effort removal of a temporary file; a failure must not mask the load result."""
        try:
            os.remove(path)
        except OSError:
            pass

    @classmethod
    def __extract_hif(cls, json_file: str) -> HIFHypergraph:
        """Parse ``json_file`` and convert it to an ``HIFHypergraph`` after checking HIF compliance."""
        # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
        with open(json_file, encoding="utf-8") as f:
            hiftext = json.load(f)
        if not validate_hif_json(json_file):
            raise ValueError(f"Dataset from file '{json_file}' is not HIF-compliant.")
        return HIFHypergraph.from_hif(hiftext)

load_from_url(url, save_on_disk=False) classmethod

Load a hypergraph from a given URL pointing to a .json or .json.zst file in HIF format.

Args:

- url (str): The URL to the .json or .json.zst file containing the HIF hypergraph data.
- save_on_disk (bool): Whether to save the downloaded file on disk.

Returns:

- HData: The loaded hypergraph object.

Source code in hyperbench/data/hif.py
@classmethod
def load_from_url(cls, url: str, save_on_disk: bool = False) -> HData:
    """
    Load a hypergraph from a given URL pointing to a .json or .json.zst file in HIF format.

    The format is chosen from the extension of the URL path (query string and
    fragment are ignored). The downloaded payload is staged in a temporary file,
    which is removed once the hypergraph has been parsed.

    Args:
        url (str): The URL to the .json or .json.zst file containing the HIF hypergraph data.
        save_on_disk (bool): Whether to save the downloaded file on disk.

    Returns:
        HData: The loaded hypergraph object.

    Raises:
        ValueError: If the URL path ends in neither ``.json`` nor ``.zst``, if the
            download does not return HTTP 200, or if the payload is not HIF-compliant.
    """
    from urllib.parse import urlparse

    url = validate_http_url(url)

    # Decide the format from the URL itself. The temp file name is chosen here and
    # therefore cannot tell a ``.json`` download from a ``.json.zst`` one.
    url_path = urlparse(url).path
    if url_path.endswith(".zst"):
        is_compressed = True
    elif url_path.endswith(".json"):
        is_compressed = False
    else:
        raise ValueError(
            f"Unsupported file format for URL '{url}'. Expected .json or .json.zst"
        )
    saved_filename = os.path.basename(url_path)

    response = requests.get(url, timeout=20)
    if response.status_code != 200:
        raise ValueError(
            f"Failed to download dataset from URL '{url}' with status code {response.status_code}"
        )

    with tempfile.NamedTemporaryFile(
        mode="wb", suffix=".json.zst" if is_compressed else ".json", delete=False
    ) as tmp_file:
        tmp_file.write(response.content)
        downloaded_filename = tmp_file.name

    try:
        if is_compressed:
            if save_on_disk:
                write_to_disk(saved_filename, response.content)
            output = decompress_zst(downloaded_filename)
        else:
            if save_on_disk:
                # NOTE(review): the compressed payload is stored under the original
                # ``.json`` name; confirm this is what ``write_to_disk`` users expect.
                compressed = compress_to_zst(downloaded_filename)
                write_to_disk(saved_filename, compressed)
            output = downloaded_filename
        hypergraph = cls.__extract_hif(output)
    finally:
        # Best-effort cleanup of the staged download; never mask the real result/error.
        try:
            os.remove(downloaded_filename)
        except OSError:
            pass

    return HIFProcessor.process_hypergraph(hypergraph)

load_from_path(filepath) classmethod

Load a hypergraph from a local file path pointing to a .json or .json.zst file in HIF format.

Args:

- filepath (str): The local file path to the .json or .json.zst file containing the HIF hypergraph data.

Returns:

- HData: The loaded hypergraph object.

Source code in hyperbench/data/hif.py
@classmethod
def load_from_path(cls, filepath: str) -> HData:
    """
    Read a HIF hypergraph from a local ``.json`` or ``.json.zst`` file.

    Args:
        filepath (str): Path to the local HIF file; a ``.zst`` suffix means it is
            zstd-compressed and is decompressed first.

    Returns:
        HData: The loaded hypergraph object.

    Raises:
        ValueError: If the file is missing or its extension is neither ``.json`` nor ``.zst``.
    """
    if not os.path.exists(filepath):
        raise ValueError(f"File '{filepath}' does not exist.")

    is_compressed = filepath.endswith(".zst")
    if not is_compressed and not filepath.endswith(".json"):
        raise ValueError(
            f"Unsupported file format for filepath '{filepath}'. Expected .json or .json.zst"
        )

    json_path = decompress_zst(filepath) if is_compressed else filepath
    parsed = cls.__extract_hif(json_path)
    return HIFProcessor.process_hypergraph(parsed)

HIFProcessor

A utility class to process HIF hypergraph data into `HData` format.

Source code in hyperbench/data/hif.py
class HIFProcessor:
    """
    A utility class to process HIF hypergraph data into :class:`HData` format.

    All methods are static or class methods; the public entry point is
    :meth:`process_hypergraph`.
    """

    @staticmethod
    def transform_attrs(
        attrs: dict[str, Any],
        attr_keys: list[str] | None = None,
    ) -> Tensor:
        """
        Extract and encode numeric attributes to a 1-D float tensor.

        Only ``int`` and ``float`` values are kept. ``bool`` values (a subclass of ``int``)
        and non-numeric values are discarded.

        Args:
            attrs: Dictionary of attributes.
            attr_keys: Optional list of attribute keys to encode. If provided, the output
                follows this order, and keys that are missing (or non-numeric) in ``attrs``
                are filled with ``0.0``.

        Returns:
            Tensor of numeric attribute values. Without ``attr_keys`` the values follow the
            insertion order of ``attrs``; with no numeric values the tensor is empty.
        """
        numeric_attrs = {
            key: value
            for key, value in attrs.items()
            if isinstance(value, (int, float)) and not isinstance(value, bool)
        }

        if attr_keys is not None:
            values = [float(numeric_attrs.get(key, 0.0)) for key in attr_keys]
        else:
            values = [float(value) for value in numeric_attrs.values()]

        return torch.tensor(values, dtype=torch.float)

    @classmethod
    def process_hypergraph(cls, hypergraph: HIFHypergraph) -> HData:
        """
        Process the loaded hypergraph into :class:`HData` format, mapping HIF structure to tensors.

        Node IDs are remapped to 0-based indices following the order of ``hypergraph.nodes``.
        Hyperedge IDs are remapped to 0-based indices in order of first appearance in
        ``hypergraph.incidences``. Every node without any incidence is given its own singleton
        ("self-loop") hyperedge, indexed after all real hyperedges.

        Args:
            hypergraph: The loaded HIF hypergraph.

        Returns:
            The processed hypergraph data.

        Raises:
            KeyError: If an incidence references a node ID not listed in ``hypergraph.nodes``.
        """
        num_nodes = len(hypergraph.nodes)
        x = cls.__process_x(hypergraph, num_nodes)

        # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order
        node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)}

        # Only hyperedges that appear in incidences get an index, so that hyperedges
        # without incidences do not inflate the hyperedge count
        hyperedge_id_to_idx: dict[Any, int] = {}

        node_ids: list[int] = []
        hyperedge_ids: list[int] = []
        nodes_with_incidences: set[int] = set()
        for incidence in hypergraph.incidences:
            hyperedge_id = incidence.get("edge", 0)
            if hyperedge_id not in hyperedge_id_to_idx:
                # Hyperedges start from 0 and are indexed in order of first appearance
                hyperedge_id_to_idx[hyperedge_id] = len(hyperedge_id_to_idx)

            node_idx = node_id_to_idx[incidence.get("node", 0)]
            node_ids.append(node_idx)
            hyperedge_ids.append(hyperedge_id_to_idx[hyperedge_id])
            nodes_with_incidences.add(node_idx)

        # Handle isolated nodes by assigning each one a new unique hyperedge (self-loop)
        for node_idx in range(num_nodes):
            if node_idx in nodes_with_incidences:
                continue
            self_loop_idx = len(hyperedge_id_to_idx)
            # Tuple key reserves the index without any risk of colliding with a real HIF
            # edge ID (IDs decoded from JSON are never tuples). A string key could equal
            # an actual edge name, overwrite its entry and make num_hyperedges too small.
            hyperedge_id_to_idx[("__self_loop__", node_idx)] = self_loop_idx
            node_ids.append(node_idx)
            hyperedge_ids.append(self_loop_idx)

        num_hyperedges = len(hyperedge_id_to_idx)
        hyperedge_attr = cls.__process_hyperedge_attr(
            hypergraph=hypergraph,
            hyperedge_id_to_idx=hyperedge_id_to_idx,
            num_hyperedges=num_hyperedges,
        )

        hyperedge_weights = cls.__process_hyperedge_weights(
            hypergraph=hypergraph,
            hyperedge_id_to_idx=hyperedge_id_to_idx,
            num_hyperedges=num_hyperedges,
        )

        hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long)

        return HData(
            x=x,
            hyperedge_index=hyperedge_index,
            hyperedge_weights=hyperedge_weights,
            hyperedge_attr=hyperedge_attr,
            num_nodes=num_nodes,
            num_hyperedges=num_hyperedges,
        )

    @classmethod
    def __collect_attr_keys(cls, attr_keys: list[dict[str, Any]]) -> list[str]:
        """
        Collect unique numeric attribute keys from a list of attribute dictionaries.

        A key is collected if at least one dictionary maps it to an ``int`` or ``float``.
        ``bool`` values are ignored, exactly as in :meth:`transform_attrs`, so a key that
        only ever holds booleans does not produce an all-zero feature column.

        Args:
            attr_keys: List of attribute dictionaries (the dictionaries themselves, not keys).

        Returns:
            List of unique numeric attribute keys, in first-seen order.
        """
        # dict preserves insertion order and gives O(1) membership checks
        unique_keys: dict[str, None] = {}
        for attrs in attr_keys:
            for key, value in attrs.items():
                if isinstance(value, (int, float)) and not isinstance(value, bool):
                    unique_keys.setdefault(key, None)

        return list(unique_keys)

    @classmethod
    def __process_hyperedge_attr(
        cls,
        hypergraph: HIFHypergraph,
        hyperedge_id_to_idx: dict[Any, int],
        num_hyperedges: int,
    ) -> Tensor | None:
        """
        Build the hyperedge attribute matrix of shape ``[num_hyperedges, num_hyperedge_attrs]``.

        Args:
            hypergraph: The loaded HIF hypergraph.
            hyperedge_id_to_idx: Mapping from hyperedge ID (including self-loop keys) to index.
            num_hyperedges: Number of hyperedge indices to emit rows for.

        Returns:
            ``None`` if no hyperedge record has an ``"attrs"`` entry. Otherwise a float tensor
            whose row ``i`` encodes the numeric attributes of the hyperedge with index ``i``.
            Self-loops and hyperedges without ``"attrs"`` get all-zero rows.
        """
        has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0
        has_any_hyperedge_attrs = has_hyperedges and any(
            "attrs" in edge for edge in hypergraph.hyperedges
        )
        if not has_any_hyperedge_attrs:
            return None

        hyperedge_id_to_attrs: dict[Any, dict[str, Any]] = {
            edge.get("edge"): edge.get("attrs", {}) for edge in hypergraph.hyperedges
        }
        hyperedge_attr_keys = cls.__collect_attr_keys(list(hyperedge_id_to_attrs.values()))

        # Invert the mapping to build rows in exact index order (0 to num_hyperedges - 1)
        hyperedge_idx_to_id = {idx: edge_id for edge_id, idx in hyperedge_id_to_idx.items()}

        rows = [
            cls.transform_attrs(
                # Real hyperedges use their attrs; self-loops are absent and get an empty dict
                attrs=hyperedge_id_to_attrs.get(hyperedge_idx_to_id[hyperedge_idx], {}),
                attr_keys=hyperedge_attr_keys,
            )
            for hyperedge_idx in range(num_hyperedges)
        ]

        if not rows:
            # torch.stack rejects an empty sequence; return an empty matrix of the same rank
            return torch.empty((0, len(hyperedge_attr_keys)), dtype=torch.float)
        return torch.stack(rows)

    @classmethod
    def __process_x(cls, hypergraph: HIFHypergraph, num_nodes: int) -> Tensor:
        """
        Build the node feature matrix of shape ``[num_nodes, num_node_features]``.

        Args:
            hypergraph: The loaded HIF hypergraph.
            num_nodes: Number of nodes (rows).

        Returns:
            Numeric node attributes in ``hypergraph.nodes`` order, or a ``[num_nodes, 1]``
            tensor of ones when no node has a numeric attribute.
        """
        # Collect all attribute keys so every node's tensor has the same size
        node_attr_keys = cls.__collect_attr_keys(
            [node.get("attrs", {}) for node in hypergraph.nodes]
        )

        if node_attr_keys:
            x = torch.stack(
                [
                    cls.transform_attrs(node.get("attrs", {}), attr_keys=node_attr_keys)
                    for node in hypergraph.nodes
                ]
            )
        else:
            # Fallback to ones if no node features, 1 is better as it can help during
            # training (e.g., avoid zero multiplication), especially in first epochs
            x = torch.ones((num_nodes, 1), dtype=torch.float)

        return x

    @classmethod
    def __process_hyperedge_weights(
        cls,
        hypergraph: HIFHypergraph,
        hyperedge_id_to_idx: dict[Any, int],
        num_hyperedges: int,
    ) -> Tensor | None:
        """
        Build the 1-D hyperedge weight vector from each hyperedge's ``attrs["weight"]``.

        Args:
            hypergraph: The loaded HIF hypergraph.
            hyperedge_id_to_idx: Mapping from hyperedge ID (including self-loop keys) to index.
            num_hyperedges: Number of weights to emit.

        Returns:
            ``None`` if no hyperedge record has an ``"attrs"`` entry. Otherwise a float tensor
            of length ``num_hyperedges``, where missing weights (including self-loops)
            default to ``1.0``.
        """
        has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0
        has_any_hyperedge_attrs = has_hyperedges and any(
            "attrs" in edge for edge in hypergraph.hyperedges
        )

        # Keep old behavior for fixtures where edges have no attrs at all.
        if not has_any_hyperedge_attrs:
            return None

        # Map real edge id -> attrs (self-loops are absent and will default to 1.0)
        hyperedge_id_to_attrs: dict[Any, dict[str, Any]] = {
            edge.get("edge"): edge.get("attrs", {}) for edge in hypergraph.hyperedges
        }

        # Build in exact hyperedge index order, defaulting missing weights to 1.0.
        hyperedge_idx_to_id = {idx: edge_id for edge_id, idx in hyperedge_id_to_idx.items()}
        weights = [
            float(hyperedge_id_to_attrs.get(hyperedge_idx_to_id[idx], {}).get("weight", 1.0))
            for idx in range(num_hyperedges)
        ]

        return torch.tensor(weights, dtype=torch.float)

transform_attrs(attrs, attr_keys=None) staticmethod

Extract and encode numeric attributes to tensor. Non-numeric attributes are discarded. Missing attributes are filled with 0.0.

Parameters:

Name Type Description Default
attrs dict[str, Any]

Dictionary of attributes

required
attr_keys list[str] | None

Optional list of attribute keys to encode. If provided, it fixes the output order and fills missing attributes with 0.0.

None

Returns:

Type Description
Tensor

Tensor of numeric attribute values

Source code in hyperbench/data/hif.py
@staticmethod
def transform_attrs(
    attrs: dict[str, Any],
    attr_keys: list[str] | None = None,
) -> Tensor:
    """
    Extract and encode numeric attributes to tensor.
    Non-numeric attributes are discarded. Missing attributes are filled with ``0.0``.

    Args:
        attrs: Dictionary of attributes
        attr_keys: Optional list of attribute keys to encode. If provided, ensures consistent ordering and fill missing with ``0.0``.

    Returns:
        Tensor of numeric attribute values
    """
    numeric_attrs = {
        key: value
        for key, value in attrs.items()
        if isinstance(value, (int, float)) and not isinstance(value, bool)
    }

    if attr_keys is not None:
        values = [float(numeric_attrs.get(key, 0.0)) for key in attr_keys]
        return torch.tensor(values, dtype=torch.float)

    if not numeric_attrs:
        return torch.tensor([], dtype=torch.float)

    values = [float(value) for value in numeric_attrs.values()]
    return torch.tensor(values, dtype=torch.float)

process_hypergraph(hypergraph) classmethod

Process the loaded hypergraph into HData format, mapping HIF structure to tensors.

Returns:

Type Description
HData

The processed hypergraph data.

Source code in hyperbench/data/hif.py
@classmethod
def process_hypergraph(cls, hypergraph: HIFHypergraph) -> HData:
    """
    Process the loaded hypergraph into :class:`HData` format, mapping HIF structure to tensors.

    Node IDs are remapped to 0-based indices following the order of ``hypergraph.nodes``.
    Hyperedge IDs are remapped to 0-based indices in order of first appearance in
    ``hypergraph.incidences``. Every node without any incidence is given its own singleton
    ("self-loop") hyperedge, indexed after all real hyperedges.

    Args:
        hypergraph: The loaded HIF hypergraph.

    Returns:
        The processed hypergraph data.

    Raises:
        KeyError: If an incidence references a node ID not listed in ``hypergraph.nodes``.
    """
    num_nodes = len(hypergraph.nodes)
    x = cls.__process_x(hypergraph, num_nodes)

    # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order
    node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)}

    # Only hyperedges that appear in incidences get an index, so that hyperedges
    # without incidences do not inflate the hyperedge count
    hyperedge_id_to_idx: dict[Any, int] = {}

    node_ids: list[int] = []
    hyperedge_ids: list[int] = []
    nodes_with_incidences: set[int] = set()
    for incidence in hypergraph.incidences:
        hyperedge_id = incidence.get("edge", 0)
        if hyperedge_id not in hyperedge_id_to_idx:
            # Hyperedges start from 0 and are indexed in order of first appearance
            hyperedge_id_to_idx[hyperedge_id] = len(hyperedge_id_to_idx)

        node_idx = node_id_to_idx[incidence.get("node", 0)]
        node_ids.append(node_idx)
        hyperedge_ids.append(hyperedge_id_to_idx[hyperedge_id])
        nodes_with_incidences.add(node_idx)

    # Handle isolated nodes by assigning each one a new unique hyperedge (self-loop)
    for node_idx in range(num_nodes):
        if node_idx in nodes_with_incidences:
            continue
        self_loop_idx = len(hyperedge_id_to_idx)
        # Tuple key reserves the index without any risk of colliding with a real HIF
        # edge ID (IDs decoded from JSON are never tuples). A string key could equal
        # an actual edge name, overwrite its entry and make num_hyperedges too small.
        hyperedge_id_to_idx[("__self_loop__", node_idx)] = self_loop_idx
        node_ids.append(node_idx)
        hyperedge_ids.append(self_loop_idx)

    num_hyperedges = len(hyperedge_id_to_idx)
    hyperedge_attr = cls.__process_hyperedge_attr(
        hypergraph=hypergraph,
        hyperedge_id_to_idx=hyperedge_id_to_idx,
        num_hyperedges=num_hyperedges,
    )

    hyperedge_weights = cls.__process_hyperedge_weights(
        hypergraph=hypergraph,
        hyperedge_id_to_idx=hyperedge_id_to_idx,
        num_hyperedges=num_hyperedges,
    )

    hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long)

    return HData(
        x=x,
        hyperedge_index=hyperedge_index,
        hyperedge_weights=hyperedge_weights,
        hyperedge_attr=hyperedge_attr,
        num_nodes=num_nodes,
        num_hyperedges=num_hyperedges,
    )

__collect_attr_keys(attr_keys) classmethod

Collect unique numeric attribute keys from a list of attribute dictionaries.

Parameters:

Name Type Description Default
attr_keys list[dict[str, Any]]

List of attribute dictionaries.

required

Returns:

Type Description
list[str]

List of unique numeric attribute keys.

Source code in hyperbench/data/hif.py
@classmethod
def __collect_attr_keys(cls, attr_keys: list[dict[str, Any]]) -> list[str]:
    """
    Collect unique numeric attribute keys from a list of attribute dictionaries.

    A key is collected if at least one dictionary maps it to an ``int`` or ``float``.
    ``bool`` values are ignored, exactly as in ``transform_attrs``, so a key that only
    ever holds booleans does not produce an all-zero feature column.

    Args:
        attr_keys: List of attribute dictionaries (the dictionaries themselves, not keys).

    Returns:
        List of unique numeric attribute keys, in first-seen order.
    """
    # dict preserves insertion order and gives O(1) membership checks
    unique_keys: dict[str, None] = {}
    for attrs in attr_keys:
        for key, value in attrs.items():
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                unique_keys.setdefault(key, None)

    return list(unique_keys)

DataLoader

Bases: DataLoader

Source code in hyperbench/data/loader.py
class DataLoader(TorchDataLoader):
    """
    Data loader that batches :class:`HData` samples drawn from a :class:`Dataset`.

    Batches are assembled by :meth:`collate`. It gathers features from the dataset's full
    :class:`HData`, cached at construction time, using the IDs found in the sampled items.

    Args:
        dataset: The dataset to load from. Its ``hdata`` attribute is cached and used by
            :meth:`collate`.
        batch_size: Number of samples per batch. Ignored when ``sample_full_hypergraph``
            is ``True``, in which case ``len(dataset)`` is used.
        shuffle: Forwarded to the parent data loader.
        sample_full_hypergraph: If ``True``, every batch is a clone of the dataset's full
            :class:`HData` instead of a collation of the sampled items.
        **kwargs: Forwarded to the parent data loader. ``collate_fn`` must not be passed,
            since it is always set to :meth:`collate` (passing it raises ``TypeError``).
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        shuffle: bool | None = False,
        sample_full_hypergraph: bool = False,
        **kwargs,
    ) -> None:
        self.__sample_full_hypergraph = sample_full_hypergraph

        super().__init__(
            dataset=dataset,
            # In full-hypergraph mode a single batch covers the whole dataset
            batch_size=len(dataset) if sample_full_hypergraph else batch_size,
            shuffle=shuffle,
            collate_fn=self.collate,
            **kwargs,
        )

        self.__cached_dataset_hdata = dataset.hdata

    def collate(self, batch: list[HData]) -> HData:
        """
        Collate a list of :class:`HData` samples into a single batched :class:`HData`.

        Steps:

        - With ``sample_full_hypergraph=True``, return a clone of the dataset's full
          :class:`HData`; the samples only determine the target device.
        - Otherwise, concatenate the samples' ``hyperedge_index`` tensors along the
          incidence dimension and wrap them in ``HyperedgeIndex``, calling
          ``remove_duplicate_edges()`` (presumably so a hyperedge drawn by several samples
          appears once — NOTE(review): confirm against ``HyperedgeIndex``).
        - Gather ``x``, ``global_node_ids``, ``y``, ``hyperedge_attr`` and
          ``hyperedge_weights`` from the dataset's full :class:`HData`, indexed by the
          batch's node and hyperedge IDs. Sample IDs must therefore live in the dataset's
          ID space, as samplers return sub-hypergraphs with global IDs. Optional fields
          that are ``None`` in the dataset stay ``None`` in the batch.
        - Convert the index with ``to_0based()`` (presumably a contiguous 0-based
          remapping).

        Args:
            batch: List of :class:`HData` objects to collate.

        Returns:
            A single :class:`HData` object containing the collated data, moved to the
            device of ``batch[0]``.
        """
        if self.__sample_full_hypergraph:
            return self.__cached_dataset_hdata.clone().to(batch[0].device)

        collated_hyperedge_index = torch.cat([data.hyperedge_index for data in batch], dim=1)
        hyperedge_index_wrapper = HyperedgeIndex(collated_hyperedge_index).remove_duplicate_edges()

        hyperedge_ids = hyperedge_index_wrapper.hyperedge_ids
        node_ids = hyperedge_index_wrapper.node_ids

        collated_x = self.__cached_dataset_hdata.x[node_ids]

        # Guarded like the other optional fields: HData may be built without labels
        # (e.g. by HIFProcessor), and indexing None would raise TypeError
        collated_y = None
        if self.__cached_dataset_hdata.y is not None:
            collated_y = self.__cached_dataset_hdata.y[hyperedge_ids]

        collated_global_node_ids = None
        if self.__cached_dataset_hdata.global_node_ids is not None:
            collated_global_node_ids = self.__cached_dataset_hdata.global_node_ids[node_ids]

        collated_hyperedge_attr = None
        if self.__cached_dataset_hdata.hyperedge_attr is not None:
            collated_hyperedge_attr = self.__cached_dataset_hdata.hyperedge_attr[hyperedge_ids]

        collated_hyperedge_weights = None
        if self.__cached_dataset_hdata.hyperedge_weights is not None:
            collated_hyperedge_weights = self.__cached_dataset_hdata.hyperedge_weights[
                hyperedge_ids
            ]

        collated_hyperedge_index = hyperedge_index_wrapper.to_0based().item

        collated_hdata = HData(
            x=collated_x,
            hyperedge_index=collated_hyperedge_index,
            hyperedge_weights=collated_hyperedge_weights,
            hyperedge_attr=collated_hyperedge_attr,
            num_nodes=hyperedge_index_wrapper.num_nodes,
            num_hyperedges=hyperedge_index_wrapper.num_hyperedges,
            global_node_ids=collated_global_node_ids,
            y=collated_y,
        )

        return collated_hdata.to(batch[0].device)

collate(batch)

Collates a list of HData objects into a single batched HData object.

This function combines multiple separate samples into a single batched representation suitable for mini-batch training. It handles: - Concatenating node features from all samples. - Concatenating and offsetting hyperedges from all samples. - Concatenating hyperedge attributes from all samples, if present. - Concatenating hyperedge weights from all samples, if present.

Examples:

Given batch = [HData_0, HData_1]:

For node features:

>>> HData_0.x.shape  # (3, 64) — 3 nodes with 64 features
>>> HData_1.x.shape  # (2, 64) — 2 nodes with 64 features
>>> x.shape  # (5, 64) — all 5 nodes concatenated

For hyperedge index:

  • HData_0 (3 nodes, 2 hyperedges):
>>> hyperedge_index = [[0, 1, 1, 2],  # Nodes 0, 1, 1, 2
...                    [0, 0, 1, 1]]  # Hyperedge 0 contains {0,1}, Hyperedge 1 contains {1,2}
  • HData_1 (2 nodes, 1 hyperedge):
>>> hyperedge_index = [[0, 1],  # Nodes 0, 1
...                    [0, 0]]  # Hyperedge 0 contains {0,1}

Batched result:

>>> hyperedge_index = [[0, 1, 1, 2, 3, 4],  # Node indices: original then offset by 3
...                    [0, 0, 1, 1, 2, 2]]  # Hyperedge IDs: original then offset by 2

Parameters:

Name Type Description Default
batch list[HData]

List of HData objects to collate.

required

Returns:

Type Description
HData

A single HData object containing the collated data.

Source code in hyperbench/data/loader.py
def collate(self, batch: list[HData]) -> HData:
    """
    Collate a list of :class:`HData` samples into a single batched :class:`HData`.

    Steps:

    - With ``sample_full_hypergraph=True``, return a clone of the dataset's full
      :class:`HData`; the samples only determine the target device.
    - Otherwise, concatenate the samples' ``hyperedge_index`` tensors along the
      incidence dimension and wrap them in ``HyperedgeIndex``, calling
      ``remove_duplicate_edges()`` (presumably so a hyperedge drawn by several samples
      appears once — NOTE(review): confirm against ``HyperedgeIndex``).
    - Gather ``x``, ``global_node_ids``, ``y``, ``hyperedge_attr`` and
      ``hyperedge_weights`` from the dataset's full :class:`HData`, indexed by the
      batch's node and hyperedge IDs. Sample IDs must therefore live in the dataset's
      ID space, as samplers return sub-hypergraphs with global IDs. Optional fields
      that are ``None`` in the dataset stay ``None`` in the batch.
    - Convert the index with ``to_0based()`` (presumably a contiguous 0-based remapping).

    Args:
        batch: List of :class:`HData` objects to collate.

    Returns:
        A single :class:`HData` object containing the collated data, moved to the
        device of ``batch[0]``.
    """
    if self.__sample_full_hypergraph:
        return self.__cached_dataset_hdata.clone().to(batch[0].device)

    collated_hyperedge_index = torch.cat([data.hyperedge_index for data in batch], dim=1)
    hyperedge_index_wrapper = HyperedgeIndex(collated_hyperedge_index).remove_duplicate_edges()

    hyperedge_ids = hyperedge_index_wrapper.hyperedge_ids
    node_ids = hyperedge_index_wrapper.node_ids

    collated_x = self.__cached_dataset_hdata.x[node_ids]

    # Guarded like the other optional fields: HData may be built without labels
    # (e.g. by HIFProcessor), and indexing None would raise TypeError
    collated_y = None
    if self.__cached_dataset_hdata.y is not None:
        collated_y = self.__cached_dataset_hdata.y[hyperedge_ids]

    collated_global_node_ids = None
    if self.__cached_dataset_hdata.global_node_ids is not None:
        collated_global_node_ids = self.__cached_dataset_hdata.global_node_ids[node_ids]

    collated_hyperedge_attr = None
    if self.__cached_dataset_hdata.hyperedge_attr is not None:
        collated_hyperedge_attr = self.__cached_dataset_hdata.hyperedge_attr[hyperedge_ids]

    collated_hyperedge_weights = None
    if self.__cached_dataset_hdata.hyperedge_weights is not None:
        collated_hyperedge_weights = self.__cached_dataset_hdata.hyperedge_weights[
            hyperedge_ids
        ]

    collated_hyperedge_index = hyperedge_index_wrapper.to_0based().item

    collated_hdata = HData(
        x=collated_x,
        hyperedge_index=collated_hyperedge_index,
        hyperedge_weights=collated_hyperedge_weights,
        hyperedge_attr=collated_hyperedge_attr,
        num_nodes=hyperedge_index_wrapper.num_nodes,
        num_hyperedges=hyperedge_index_wrapper.num_hyperedges,
        global_node_ids=collated_global_node_ids,
        y=collated_y,
    )

    return collated_hdata.to(batch[0].device)

BaseSampler

Bases: ABC

Source code in hyperbench/data/sampling.py
class BaseSampler(ABC):
    """
    Abstract base for strategies that sample sub-hypergraphs out of an :class:`HData`.

    Subclasses implement :meth:`sample` and :meth:`len`. The protected helpers normalize
    and bounds-check requested IDs and filter a hyperedge index.
    """

    @abstractmethod
    def sample(self, index: int | list[int], hdata: HData) -> HData:
        """
        Sample a sub-hypergraph and return HData with global IDs.

        Args:
            index: An integer or list of integers specifying which items to sample.
            hdata: The original HData to sample from.

        Returns:
            A new HData instance containing only the sampled items and their associated data.
        """
        raise NotImplementedError("Subclasses must implement the sample method.")

    @abstractmethod
    def len(self, hdata: HData) -> int:
        """
        Return the number of sampleable items (nodes or hyperedges).

        Args:
            hdata: The HData to query for the number of sampleable items.

        Returns:
            The number of sampleable items.
        """
        raise NotImplementedError("Subclasses must implement the len method.")

    def _normalize_index(self, index: int | list[int], size: int) -> list[int]:
        """
        Convert index to list, deduplicate, validate length.

        Duplicates are removed *before* the length check. A list such as ``[0, 0, 1]`` is
        therefore accepted when ``size == 2``. The order of the returned IDs is not
        guaranteed to match the input order.

        Args:
            index: An integer or a list of integers representing IDs to sample.
            size: The total number of sampleable items (e.g., nodes or hyperedges) for validation.

        Returns:
            List of unique IDs to sample.

        Raises:
            ValueError: If ``index`` is an empty list, or contains more unique IDs than
                there are sampleable items.
        """
        if not isinstance(index, list):
            return [index]

        if not index:
            raise ValueError("Index list cannot be empty.")

        unique_ids = list(set(index))
        if len(unique_ids) > size:
            raise ValueError(
                f"Index list length ({len(unique_ids)}) cannot exceed the number of sampleable items ({size})."
            )
        return unique_ids

    def _sample_hyperedge_index(
        self,
        hyperedge_index: Tensor,
        sampled_hyperedge_ids: Tensor,
    ) -> Tensor:
        """
        Sample the hyperedge index to keep only incidences belonging to the specified sampled hyperedge IDs.

        Args:
            hyperedge_index: The original hyperedge index tensor of shape ``[2, num_incidences]``.
            sampled_hyperedge_ids: A tensor containing the IDs of hyperedges to sample.

        Returns:
            A new hyperedge index tensor containing only the incidences of the sampled
            hyperedges, with the original (unremapped) IDs.
        """
        hyperedge_ids = hyperedge_index[1]

        # Find incidences where the hyperedge is in our sampled hyperedges
        # Example: hyperedge_ids = [0, 0, 0, 1, 2, 2], sampled_hyperedge_ids = [0, 2]
        #          -> sampled_hyperedges_mask = [True, True, True, False, True, True]
        sampled_hyperedges_mask = torch.isin(hyperedge_ids, sampled_hyperedge_ids)

        # Keep all incidences belonging to the sampled hyperedges
        # Example: hyperedge_index = [[0, 0, 1, 2, 3, 4],
        #                             [0, 0, 0, 1, 2, 2]],
        #          sampled_hyperedges_mask = [True, True, True, False, True, True]
        #          -> sampled_hyperedge_index = [[0, 0, 1, 3, 4],
        #                                        [0, 0, 0, 2, 2]]
        return hyperedge_index[:, sampled_hyperedges_mask]

    def _validate_bounds(self, ids: list[int], size: int, label: str) -> None:
        """
        Check that all IDs are in ``[0, size)``.

        Args:
            ids: List of IDs to validate.
            size: The total number of sampleable items (e.g., nodes or hyperedges).
            label: A string label for error messages (e.g., "Node ID" or "Hyperedge ID").

        Raises:
            IndexError: If any ID is out of bounds (reported for the first offending ID).
        """
        for item_id in ids:
            if item_id < 0 or item_id >= size:
                raise IndexError(f"{label} {item_id} is out of bounds (0, {size - 1}).")

sample(index, hdata) abstractmethod

Sample a sub-hypergraph and return HData with global IDs.

Parameters:

Name Type Description Default
index int | list[int]

An integer or list of integers specifying which items to sample.

required
hdata HData

The original HData to sample from.

required

Returns:

Type Description
HData

A new HData instance containing only the sampled items and their associated data.

Source code in hyperbench/data/sampling.py
@abstractmethod
def sample(self, index: int | list[int], hdata: HData) -> HData:
    """
    Draw a sub-hypergraph from ``hdata``; concrete samplers must override this.

    The returned :class:`HData` keeps the global IDs of ``hdata``.

    Args:
        index: A single ID or a list of IDs selecting the items to sample.
        hdata: The source :class:`HData`.

    Returns:
        A new :class:`HData` restricted to the sampled items and their associated data.

    Raises:
        NotImplementedError: Always, in this abstract base implementation.
    """
    message = "Subclasses must implement the sample method."
    raise NotImplementedError(message)

len(hdata) abstractmethod

Return the number of sampleable items (nodes or hyperedges).

Parameters:

Name Type Description Default
hdata HData

The HData to query for the number of sampleable items.

required
Source code in hyperbench/data/sampling.py
@abstractmethod
def len(self, hdata: HData) -> int:
    """
    Count the sampleable items (nodes or hyperedges) in ``hdata``; subclasses override this.

    Args:
        hdata: The :class:`HData` whose sampleable items are counted.

    Returns:
        The number of sampleable items.

    Raises:
        NotImplementedError: Always, in this abstract base implementation.
    """
    message = "Subclasses must implement the len method."
    raise NotImplementedError(message)

HyperedgeSampler

Bases: BaseSampler

Source code in hyperbench/data/sampling.py
class HyperedgeSampler(BaseSampler):
    """Sampling strategy that selects sub-hypergraphs by hyperedge ID."""

    def sample(self, index: int | list[int], hdata: HData) -> HData:
        """
        Return the sub-hypergraph made of the requested hyperedges and their incident nodes.

        The sampled incidences keep the node and hyperedge IDs of ``hdata``, as shown below.

        Examples:
            >>> hyperedge_index = torch.tensor([[0, 0, 1, 2, 3, 4],
            ...                                 [0, 0, 0, 1, 2, 2]])
            >>> hdata = HData.from_hyperedge_index(hyperedge_index)
            >>> sampled_hdata = HyperedgeSampler().sample([0, 2], hdata)
            >>> sampled_hdata.hyperedge_index
            tensor([[0, 0, 1, 3, 4],
                    [0, 0, 0, 2, 2]])

        Args:
            index: A single hyperedge ID or a list of hyperedge IDs to sample.
            hdata: The original HData to sample from.

        Returns:
            An HData instance containing only the sampled hyperedges and their incident nodes.

        Raises:
            ValueError: If ``index`` is an empty list or has more entries than there are hyperedges.
            IndexError: If any hyperedge ID is out of bounds.
        """
        num_hyperedges = self.len(hdata)
        requested_ids = self._normalize_index(index, num_hyperedges)
        self._validate_bounds(requested_ids, num_hyperedges, "Hyperedge ID")

        full_index = hdata.hyperedge_index
        # Keep every incidence whose hyperedge row entry is one of the requested IDs
        kept_index = self._sample_hyperedge_index(
            full_index,
            torch.tensor(requested_ids, device=full_index.device),
        )
        return HData.from_hyperedge_index(kept_index)

    def len(self, hdata: HData) -> int:
        """
        Count the sampleable items, which for this strategy are hyperedges.

        Args:
            hdata: The HData whose hyperedges are counted.

        Returns:
            ``hdata.num_hyperedges``.
        """
        return hdata.num_hyperedges

sample(index, hdata)

Sample hyperedges by their IDs and return the sub-hypergraph containing only those hyperedges and their incident nodes.

Examples:

Given hyperedge_index = [[0, 0, 1, 2, 3, 4], [0, 0, 0, 1, 2, 2]], hdata = HData.from_hyperedge_index(hyperedge_index) and strategy = HyperedgeSampler(), calling sampled_hdata = strategy.sample([0, 2], hdata) yields sampled_hdata.hyperedge_index == tensor([[0, 0, 1, 3, 4], [0, 0, 0, 2, 2]]).

Parameters:

Name Type Description Default
index int | list[int]

An integer or a list of integers representing hyperedge IDs to sample.

required
hdata HData

The original HData to sample from.

required

Returns:

Type Description
HData

An HData instance containing only the sampled hyperedges and their incident nodes.

Raises:

Type Description
ValueError

If the provided index is invalid (e.g., empty list or list length exceeds number of hyperedges).

IndexError

If any hyperedge ID is out of bounds.

Source code in hyperbench/data/sampling.py
def sample(self, index: int | list[int], hdata: HData) -> HData:
    """
    Sample hyperedges by their IDs and return the sub-hypergraph containing only those hyperedges and their incident nodes.

    Examples:
    >>> hyperedge_index = [[0, 0, 1, 2, 3, 4],
    ...                    [0, 0, 0, 1, 2, 2]]
    >>> hdata = HData.from_hyperedge_index(hyperedge_index)
    >>> strategy = HyperedgeSampler()
    >>> sampled_hdata = strategy.sample([0, 2], hdata)
    >>> sampled_hdata.hyperedge_index
    tensor([[0, 0, 1, 3, 4],
            [0, 0, 0, 2, 2]])

    Args:
        index: An integer or a list of integers representing hyperedge IDs to sample.
        hdata: The original HData to sample from.

    Returns:
        An HData instance containing only the sampled hyperedges and their incident nodes.

    Raises:
        ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of hyperedges).
        IndexError: If any hyperedge ID is out of bounds.
    """
    # Query the hyperedge count once and reuse it for normalization and bounds checking.
    num_hyperedges = self.len(hdata)
    ids = self._normalize_index(index, num_hyperedges)
    self._validate_bounds(ids, num_hyperedges, "Hyperedge ID")

    hyperedge_index = hdata.hyperedge_index

    # Build the ID tensor on the same device as the incidences so the mask can be computed there.
    sampled_hyperedge_ids = torch.tensor(ids, device=hyperedge_index.device)

    # Example: sampled_hyperedge_ids = [0, 2],
    #          hyperedge_index = [[0, 0, 1, 2, 3, 4],
    #                             [0, 0, 0, 1, 2, 2]],
    #          -> sampled_hyperedges_mask = [True, True, True, False, True, True]
    #          -> sampled_hyperedge_index = [[0, 0, 1, 3, 4],
    #                                        [0, 0, 0, 2, 2]]
    sampled_hyperedge_index = self._sample_hyperedge_index(
        hyperedge_index, sampled_hyperedge_ids
    )

    return HData.from_hyperedge_index(sampled_hyperedge_index)

len(hdata)

Return the number of hyperedges in the given HData.

Parameters:

Name Type Description Default
hdata HData

The HData to query for the number of hyperedges.

required

Returns:

Type Description
int

The number of hyperedges in the HData.

Source code in hyperbench/data/sampling.py
def len(self, hdata: HData) -> int:
    """
    Count the hyperedges stored in ``hdata``.

    The sampler uses this count as the size of the valid hyperedge-ID range.

    Args:
        hdata: The HData whose hyperedges are counted.

    Returns:
        The value of ``hdata.num_hyperedges``.
    """
    hyperedge_count = hdata.num_hyperedges
    return hyperedge_count

NodeSampler

Bases: BaseSampler

Source code in hyperbench/data/sampling.py
class NodeSampler(BaseSampler):
    """
    Sampler that selects sub-hypergraphs by node ID.

    Node IDs are validated against ``hdata.num_nodes`` (see ``len``).
    """

    def sample(self, index: int | list[int], hdata: HData) -> HData:
        """
        Sample nodes by their IDs and return the sub-hypergraph formed by every hyperedge incident to at least one sampled node.

        All incidences of the selected hyperedges are kept, so the result can also contain
        nodes that were not sampled (nodes 1 and 4 in the example below).

        Examples:
        >>> hyperedge_index = [[0, 0, 1, 2, 3, 4],
        ...                    [0, 0, 0, 1, 2, 2]]
        >>> hdata = HData.from_hyperedge_index(hyperedge_index)
        >>> strategy = NodeSampler()
        >>> sampled_hdata = strategy.sample([0, 3], hdata)
        >>> sampled_hdata.hyperedge_index
        tensor([[0, 0, 1, 3, 4],
                [0, 0, 0, 2, 2]])

        Args:
            index: An integer or a list of integers representing node IDs to sample.
            hdata: The original HData to sample from.

        Returns:
            An HData instance containing the hyperedges incident to the sampled nodes, together with all nodes of those hyperedges.

        Raises:
            ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of nodes).
            IndexError: If any node ID is out of bounds.
        """
        # Query the node count once and reuse it for normalization and bounds checking.
        num_nodes = self.len(hdata)
        ids = self._normalize_index(index, num_nodes)
        self._validate_bounds(ids, num_nodes, "Node ID")

        hyperedge_index = hdata.hyperedge_index
        node_ids = hyperedge_index[0]
        hyperedge_ids = hyperedge_index[1]

        # Build the ID tensor on the same device as the incidences so isin runs there.
        sampled_node_ids = torch.tensor(ids, device=node_ids.device)

        # Find incidences where the node is in our sampled nodes
        # Example: node_ids = [0, 0, 1, 2, 3, 4], sampled_node_ids = [0, 3]
        #          -> sampled_nodes_mask = [True, True, False, False, True, False]
        sampled_nodes_mask = torch.isin(node_ids, sampled_node_ids)

        # Get unique hyperedges that have at least one sampled node
        # Example: hyperedge_ids = [0, 0, 0, 1, 2, 2], sampled_nodes_mask = [True, True, False, False, True, False]
        #          -> sampled_hyperedge_ids = [0, 2] as they connect to sampled nodes
        sampled_hyperedge_ids = hyperedge_ids[sampled_nodes_mask].unique()

        # Keep every incidence of the selected hyperedges, including unsampled nodes.
        # Example: sampled_hyperedge_ids = [0, 2],
        #          hyperedge_index = [[0, 0, 1, 2, 3, 4],
        #                             [0, 0, 0, 1, 2, 2]],
        #          -> sampled_hyperedges_mask = [True, True, True, False, True, True]
        #          -> sampled_hyperedge_index = [[0, 0, 1, 3, 4],
        #                                        [0, 0, 0, 2, 2]]
        sampled_hyperedge_index = self._sample_hyperedge_index(
            hyperedge_index, sampled_hyperedge_ids
        )

        return HData.from_hyperedge_index(sampled_hyperedge_index)

    def len(self, hdata: HData) -> int:
        """
        Return the number of nodes in the given HData.

        Used by ``sample`` as the size of the valid node-ID range.

        Args:
            hdata: The HData to query for the number of nodes.

        Returns:
            The number of nodes in the HData (``hdata.num_nodes``).
        """
        return hdata.num_nodes

sample(index, hdata)

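Sample nodes by their IDs and return the sub-hypergraph formed by every hyperedge incident to at least one sampled node. All incidences of those hyperedges are kept, so the result can also contain nodes that were not sampled (nodes 1 and 4 in the example below).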

Examples:

>>> hyperedge_index = [[0, 0, 1, 2, 3, 4],
...                    [0, 0, 0, 1, 2, 2]]
>>> hdata = HData.from_hyperedge_index(hyperedge_index)
>>> strategy = NodeSampler()
>>> sampled_hdata = strategy.sample([0, 3], hdata)
>>> sampled_hdata.hyperedge_index
tensor([[0, 0, 1, 3, 4],
        [0, 0, 0, 2, 2]])

Parameters:

Name Type Description Default
index int | list[int]

An integer or a list of integers representing node IDs to sample.

required
hdata HData

The original HData to sample from.

required

Returns:

Type Description
HData

An HData instance containing the hyperedges incident to the sampled nodes, together with all nodes of those hyperedges.

Raises:

Type Description
ValueError

If the provided index is invalid (e.g., empty list or list length exceeds number of nodes).

IndexError

If any node ID is out of bounds.

Source code in hyperbench/data/sampling.py
def sample(self, index: int | list[int], hdata: HData) -> HData:
    """
    Sample nodes by their IDs and return the sub-hypergraph formed by every hyperedge incident to at least one sampled node.

    All incidences of the selected hyperedges are kept, so the result can also contain
    nodes that were not sampled (nodes 1 and 4 in the example below).

    Examples:
    >>> hyperedge_index = [[0, 0, 1, 2, 3, 4],
    ...                    [0, 0, 0, 1, 2, 2]]
    >>> hdata = HData.from_hyperedge_index(hyperedge_index)
    >>> strategy = NodeSampler()
    >>> sampled_hdata = strategy.sample([0, 3], hdata)
    >>> sampled_hdata.hyperedge_index
    tensor([[0, 0, 1, 3, 4],
            [0, 0, 0, 2, 2]])

    Args:
        index: An integer or a list of integers representing node IDs to sample.
        hdata: The original HData to sample from.

    Returns:
        An HData instance containing the hyperedges incident to the sampled nodes, together with all nodes of those hyperedges.

    Raises:
        ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of nodes).
        IndexError: If any node ID is out of bounds.
    """
    # Query the node count once and reuse it for normalization and bounds checking.
    num_nodes = self.len(hdata)
    ids = self._normalize_index(index, num_nodes)
    self._validate_bounds(ids, num_nodes, "Node ID")

    hyperedge_index = hdata.hyperedge_index
    node_ids = hyperedge_index[0]
    hyperedge_ids = hyperedge_index[1]

    # Build the ID tensor on the same device as the incidences so isin runs there.
    sampled_node_ids = torch.tensor(ids, device=node_ids.device)

    # Find incidences where the node is in our sampled nodes
    # Example: node_ids = [0, 0, 1, 2, 3, 4], sampled_node_ids = [0, 3]
    #          -> sampled_nodes_mask = [True, True, False, False, True, False]
    sampled_nodes_mask = torch.isin(node_ids, sampled_node_ids)

    # Get unique hyperedges that have at least one sampled node
    # Example: hyperedge_ids = [0, 0, 0, 1, 2, 2], sampled_nodes_mask = [True, True, False, False, True, False]
    #          -> sampled_hyperedge_ids = [0, 2] as they connect to sampled nodes
    sampled_hyperedge_ids = hyperedge_ids[sampled_nodes_mask].unique()

    # Keep every incidence of the selected hyperedges, including unsampled nodes.
    # Example: sampled_hyperedge_ids = [0, 2],
    #          hyperedge_index = [[0, 0, 1, 2, 3, 4],
    #                             [0, 0, 0, 1, 2, 2]],
    #          -> sampled_hyperedges_mask = [True, True, True, False, True, True]
    #          -> sampled_hyperedge_index = [[0, 0, 1, 3, 4],
    #                                        [0, 0, 0, 2, 2]]
    sampled_hyperedge_index = self._sample_hyperedge_index(
        hyperedge_index, sampled_hyperedge_ids
    )

    return HData.from_hyperedge_index(sampled_hyperedge_index)

len(hdata)

Return the number of nodes in the given HData.

Parameters:

Name Type Description Default
hdata HData

The HData to query for the number of nodes.

required

Returns:

Type Description
int

The number of nodes in the HData.

Source code in hyperbench/data/sampling.py
def len(self, hdata: HData) -> int:
    """
    Count the nodes stored in ``hdata``.

    The sampler uses this count as the size of the valid node-ID range.

    Args:
        hdata: The HData whose nodes are counted.

    Returns:
        The value of ``hdata.num_nodes``.
    """
    node_count = hdata.num_nodes
    return node_count

create_sampler_from_strategy(strategy)

Factory function to create a sampler instance based on the provided sampling strategy type.

Parameters:

Name Type Description Default
strategy SamplingStrategy

An instance of SamplingStrategy enum indicating which sampling strategy to use.

required

Returns:

Type Description
BaseSampler

An instance of a subclass of BaseSampler corresponding to the specified strategy. If strategy is not recognized, defaults to HyperedgeSampler.

Source code in hyperbench/data/sampling.py
def create_sampler_from_strategy(strategy: SamplingStrategy) -> BaseSampler:
    """
    Build the sampler that implements the given sampling strategy.

    Args:
        strategy: The SamplingStrategy member selecting which sampler to build.

    Returns:
        A ``NodeSampler`` when ``strategy`` equals ``SamplingStrategy.NODE``; otherwise
        (including for unrecognized values) a ``HyperedgeSampler``.
    """
    if strategy == SamplingStrategy.NODE:
        return NodeSampler()
    # Any other strategy, recognized or not, falls back to hyperedge sampling.
    return HyperedgeSampler()