from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Generic, Literal, TypeVar, cast, overload
import awkward as ak
import numba as nb
import numpy as np
import polars as pl
from genoray._utils import ContigNormalizer
from hirola import HashTable
from numpy.typing import ArrayLike, NDArray
from seqpro.rag import Ragged, lengths_to_offsets
from typing_extensions import Self
from .._flat import _Flat
from .._fasta_cache import ensure_cache
from .._ragged import RaggedSeqs, reverse_complement_masked, to_padded
from .._torch import TORCH_AVAILABLE, get_dataloader, no_torch_error
from .._types import Idx, StrIdx
from .._utils import is_dtype
from ._indexing import is_str_arr, s2i
from ._splice import SpliceMap, SplicePlan, build_splice_plan
from ._utils import bed_to_regions, padded_slice
# Largest representable int64; used as the default value of ``ends`` in
# ``Reference.fetch``.
INT64_MAX = np.iinfo(np.int64).max
[docs]
@dataclass(slots=True)
class Reference:
    """A reference genome kept in-memory. Typically this is only instantiated to be
    passed to :meth:`Dataset.open <genvarloader.Dataset.open>` and avoid data duplication.

    .. note::
        Do not instantiate this class directly. Use :meth:`Reference.from_path` instead.
    """

    path: Path
    """The path to the reference genome."""
    reference: NDArray[np.uint8]
    """The reference genome as a numpy array, with contigs concatenated."""
    offsets: NDArray[np.int64]
    """The offsets of the contigs in the reference genome. Shape: (n_contigs + 1)"""
    pad_char: int
    """The padding character used in the reference genome."""
    c_map: ContigNormalizer
    """Normalizes contig names and resolves them to indices into :attr:`offsets`."""

    @classmethod
    def from_path(
        cls,
        fasta: str | Path,
        contigs: list[str] | None = None,
        in_memory: bool = True,
    ) -> Self:
        """Load a reference genome from a FASTA file.

        Parameters
        ----------
        fasta
            Path to a ``.fa``/``.fa.bgz`` FASTA file or an existing ``.gvlfa``
            cache directory. When a FASTA path is given, a sibling ``.gvlfa``
            cache is built on first use and reused on subsequent calls; a legacy
            ``.fa.gvl`` flat cache is automatically migrated to the new format.
        contigs
            List of contig names to load. If None, all contigs in the FASTA file are loaded.
            Can be either UCSC or Ensembl style (i.e. with or without the "chr" prefix) and
            will be handled appropriately to match the underlying FASTA.
        in_memory
            Whether to load the reference genome into memory. If True, the reference genome
            is loaded into memory. If False, the reference genome is read on-demand from a
            memory mapped array. This will still be much faster than reading from FASTA but
            slower than keeping it in memory. This is useful if you need to work with many
            reference genomes or have very limited RAM. When False, the memory map always
            spans every contig of the cache, so ``contigs`` is only validated and the
            returned object exposes all contigs.

        Raises
        ------
        ValueError
            If any requested contig name cannot be matched to the reference file.
        """
        path = Path(fasta)
        meta, data_path = ensure_cache(fasta)
        full_contigs = meta.contigs
        ref_mmap = np.memmap(data_path, np.uint8, "r")
        full_offsets = lengths_to_offsets(np.array(list(full_contigs.values())))
        pad_char = ord("N")
        full_c_map = ContigNormalizer(full_contigs)

        if contigs is None:
            contigs = full_c_map.contigs
        else:
            normed = full_c_map.norm(contigs)
            if unmapped := [
                source for source, mapped in zip(contigs, normed) if mapped is None
            ]:
                raise ValueError(
                    f"Some of the given contig names are not present in reference file: {unmapped}"
                )
            contigs = cast(list[str], normed)

        if not in_memory:
            # The memory map is laid out in cache order and contains every
            # contig, so the contig -> index mapping must be the full one.
            # Pairing a subset mapping with the full offsets would make contig
            # index i point at the wrong stretch of the array.
            return cls(path, ref_mmap, full_offsets, pad_char, full_c_map)

        # Resolve each contig's position once instead of rebuilding a list and
        # scanning it for every contig.
        position = {name: i for i, name in enumerate(full_contigs)}
        reference = np.empty(sum(full_contigs[c] for c in contigs), np.uint8)
        cursor = 0
        for c in contigs:
            c_idx = position[c]
            o_s, o_e = full_offsets[c_idx], full_offsets[c_idx + 1]
            n_bases = o_e - o_s
            reference[cursor : cursor + n_bases] = ref_mmap[o_s:o_e]
            cursor += n_bases
        offsets = lengths_to_offsets(np.array([full_contigs[c] for c in contigs]))
        return cls(path, reference, offsets, pad_char, ContigNormalizer(contigs))

    @property
    def contigs(self) -> list[str]:
        """Names of the contigs addressable through this reference."""
        return self.c_map.contigs

    def fetch(
        self, contigs: ArrayLike, starts: ArrayLike = 0, ends: ArrayLike = INT64_MAX
    ) -> Ragged[np.bytes_]:
        """Fetch sequences for one or more intervals.

        Parameters
        ----------
        contigs
            Contig names, or integer contig indices into :attr:`offsets`.
        starts
            Start position(s). Broadcast against ``contigs`` and ``ends``.
        ends
            End position(s). Broadcast against ``contigs`` and ``starts``. The default
            (``INT64_MAX``) means "up to the end of the contig".

        Returns
        -------
        Ragged array of ``S1`` with one row per interval.

        Raises
        ------
        ValueError
            If a contig is not found, an integer contig index is out of range, the
            inputs cannot be broadcast together, or any end precedes its start.
        """
        contigs = np.atleast_1d(contigs)
        if not is_dtype(contigs, np.integer):
            c_idxs = np.asarray(self.c_map.c_idxs(contigs))
            if (c_idxs == -1).any():
                raise ValueError("Some contigs not found in reference.")
        else:
            c_idxs = contigs
            n_contigs = len(self.offsets) - 1
            # The numba kernel does no bounds checking, so reject bad indices here.
            if ((c_idxs < 0) | (c_idxs >= n_contigs)).any():
                raise ValueError("Some contigs not found in reference.")

        # Scalars (including the defaults) must be expanded to one value per
        # interval, otherwise the kernel indexes past the end of starts/ends.
        c_idxs, starts, ends = np.broadcast_arrays(
            c_idxs, np.atleast_1d(starts), np.atleast_1d(ends)
        )
        c_idxs = np.ascontiguousarray(c_idxs)
        starts = starts.astype(np.int64)
        ends = ends.astype(np.int64)

        # Resolve the "to end of contig" sentinel; taken literally it would ask
        # for an output buffer of ~2**63 bytes.
        contig_lengths = np.diff(self.offsets)[c_idxs]
        ends = np.where(ends == INT64_MAX, contig_lengths, ends).astype(np.int64)

        lengths = ends - starts
        if (lengths < 0).any():
            raise ValueError("Some ends are less than their corresponding starts.")

        offsets = lengths_to_offsets(lengths)
        seqs = np.empty(offsets[-1], np.uint8)
        _fetch_impl(
            c_idxs,
            starts,
            ends,
            self.reference,
            self.offsets,
            self.pad_char,
            seqs,
            offsets,
        )
        return Ragged.from_offsets(seqs.view("S1"), (len(c_idxs), None), offsets)
@nb.njit(parallel=True, nogil=True, cache=True)
def _fetch_impl(
    c_idxs: NDArray[np.integer],
    starts: NDArray[np.integer],
    ends: NDArray[np.integer],
    reference: NDArray[np.integer],
    ref_offsets: NDArray[np.integer],
    pad_char: int,
    out: NDArray[np.uint8],
    out_offsets: NDArray[np.integer],
):
    """Fill ``out`` with one sequence per interval and return it.

    For interval ``i`` the contig is the slice of ``reference`` bounded by
    ``ref_offsets[c_idxs[i]]`` and ``ref_offsets[c_idxs[i] + 1]``; the result
    is written into ``out[out_offsets[i]:out_offsets[i + 1]]`` by
    ``padded_slice`` (presumably padding out-of-range positions with
    ``pad_char`` -- confirm against ``padded_slice``).
    """
    n_intervals = len(c_idxs)
    for row in nb.prange(n_intervals):
        contig = c_idxs[row]
        contig_seq = reference[ref_offsets[contig] : ref_offsets[contig + 1]]
        dest = out[out_offsets[row] : out_offsets[row + 1]]
        padded_slice(contig_seq, starts[row], ends[row], pad_char, dest)
    return out
# Type of the value returned by ``RefDataset.__getitem__``: a dense byte array
# or ragged sequences (see the ``RefDataset.with_len`` overloads).
T = TypeVar("T", NDArray[np.bytes_], RaggedSeqs)
[docs]
@dataclass(slots=True)
class RefDataset(Generic[T]):
    """A reference dataset for pulling out sequences from a reference genome.

    When ``splice_info`` is provided, the dataset returns per-transcript
    concatenated reference sequence, with one row per splice group instead of
    one row per BED region. Same semantics as
    :meth:`Dataset.open(splice_info=...) <genvarloader.Dataset.open>`.
    """

    reference: Reference
    """The reference genome."""
    full_bed: pl.DataFrame
    """A table of regions to extract from the reference genome. The table must have the following columns:

    - `chrom`: The name of the contig (e.g. "chr1", "chr2", etc.)
    - `chromStart`: The start position of the region (0-based).
    - `chromEnd`: The end position of the region (0-based).

    A `strand` column can also be included, in which case the regions will be reverse complemented if the strand is -1
    and the `rc_neg` parameter is set to True.
    """
    _subset_bed: pl.DataFrame = field(init=False)
    _subset_regions: NDArray[np.int32] = field(init=False)
    jitter: int = 0
    """The maximum length for randomly shifting start positions."""
    output_length: Literal["ragged", "variable"] | int = "ragged"
    """The output length of the dataset. Same meaning as :attr:`Dataset.output_length`."""
    deterministic: bool = True
    """If true, fixed length sequences will be right truncated from their full length to the output length.
    If false, fixed length sequences will be randomly shifted to be within the output length.
    """
    rc_neg: bool = True
    """Whether to reverse complement the regions that are on the negative strand."""
    seed: int | np.random.Generator | None = None
    _rng: np.random.Generator = field(init=False)
    """A random number generator."""
    region_names: str | None = None
    """The name of the column in the full_bed table to use as the region names."""
    _region_map: HashTable | None = field(init=False)
    splice_info: str | tuple[str, str] | None = None
    """If set, the dataset is spliced. Either the column name with rows already
    in splice order or a (group_col, sort_col) pair applied against ``full_bed``."""
    _splice_map: SpliceMap | None = field(init=False, default=None)
    _spliced_bed: pl.DataFrame | None = field(init=False, default=None)

    def __post_init__(self):
        """Validate the inputs and derive the non-init state from ``full_bed``."""
        if self.full_bed.height == 0:
            raise ValueError("Table of regions has a height of zero.")
        if self.jitter < 0:
            raise ValueError(f"jitter ({self.jitter}) must be a non-negative integer.")
        elif self.jitter > (
            min_len := self.full_bed.select(
                (pl.col("chromEnd") - pl.col("chromStart")).min()
            ).item()
        ):
            raise ValueError(
                f"jitter ({self.jitter}) must be less than the minimum region length ({min_len})."
            )
        self._subset_bed = self.full_bed
        self._subset_regions = bed_to_regions(self.full_bed, self.reference.c_map)
        self._rng = np.random.default_rng(self.seed)
        if self.region_names is not None:
            region_names = self.full_bed[self.region_names].to_numpy().astype(np.str_)
            self._region_map = HashTable(
                max=len(region_names) * 2,  # type: ignore[bad-argument-type] # hirola HashTable.max typed as numpy.Number but accepts int
                dtype=region_names.dtype,
            )
            self._region_map.add(region_names)
        else:
            self._region_map = None
        if self.splice_info is not None:
            sm, sp_bed = SpliceMap.from_bed(self.splice_info, self.full_bed)
            self._splice_map = sm
            self._spliced_bed = sp_bed
            self._check_valid_state()
        else:
            self._splice_map = None
            self._spliced_bed = None

    def _check_valid_state(self):
        """Raise ``RuntimeError`` if the settings are incompatible with splicing."""
        if self._splice_map is None:
            return
        if self.jitter > 0:
            raise RuntimeError(
                "Jitter is not supported with splicing. Please set jitter to 0."
            )
        if not self.deterministic:
            raise RuntimeError(
                "Non-deterministic algorithms are not supported with splicing."
                " Please set deterministic to True."
            )
        if isinstance(self.output_length, int):
            raise RuntimeError(
                "Splicing requires output_length='ragged' or 'variable',"
                " not a fixed integer length."
            )

    def _evolve(self, keep_subset: bool, **changes) -> Self:
        """Return a copy with ``changes`` applied, optionally keeping the current subset.

        ``dataclasses.replace`` rebuilds the object through ``__init__`` and
        ``__post_init__``, which resets every ``init=False`` field to the full
        dataset. When ``keep_subset`` is true the subset state of ``self`` is
        copied over afterwards so the copy covers the same rows.
        """
        out = replace(self, **changes)
        if keep_subset:
            out._subset_bed = self._subset_bed
            out._subset_regions = self._subset_regions
            if self._splice_map is not None:
                out._splice_map = self._splice_map
                out._spliced_bed = self._spliced_bed
        return out

    @property
    def regions(self) -> pl.DataFrame:
        """The BED rows of the current subset."""
        return self._subset_bed

    @property
    def is_spliced(self) -> bool:
        """Whether the dataset is spliced."""
        return self._splice_map is not None

    @property
    def spliced_regions(self) -> pl.DataFrame:
        """The spliced BED, subset to the current row subset."""
        if self._spliced_bed is None or self._splice_map is None:
            raise ValueError("Dataset does not have splice information.")
        subset = self._splice_map.row_subset_idxs
        if subset is None:
            return self._spliced_bed
        return self._spliced_bed[subset]

    @property
    def shape(self) -> tuple[int]:
        """Shape of the dataset."""
        return (len(self),)

    def __len__(self) -> int:
        """Length of the dataset."""
        if self._splice_map is not None:
            return self._splice_map.n_rows
        return self.regions.height

    @overload
    def with_len(self, output_length: Literal["ragged"]) -> RefDataset[RaggedSeqs]: ...
    @overload
    def with_len(
        self, output_length: Literal["variable"] | int
    ) -> RefDataset[NDArray[np.bytes_]]: ...
    def with_len(
        self, output_length: Literal["ragged", "variable"] | int
    ) -> RefDataset:
        """Return a copy of the dataset with a different output length.

        The current subset is preserved.

        Raises
        ------
        ValueError
            If a fixed length is not positive, or if ``output_length + 2 * jitter``
            exceeds the shortest region of the current subset.
        RuntimeError
            If a fixed length is requested for a spliced dataset.
        """
        if isinstance(output_length, int):
            if output_length < 1:
                raise ValueError(
                    f"Output length ({output_length}) must be a positive integer."
                )
            min_r_len: int = (
                self._subset_regions[:, 2] - self._subset_regions[:, 1]
            ).min()
            max_output_length = min_r_len
            eff_length = output_length + 2 * self.jitter
            if eff_length > max_output_length:
                # Report the requested length, not the dataset's current one.
                raise ValueError(
                    f"Jitter-expanded output length (out_len={output_length}) + 2 * ({self.jitter=}) = {eff_length} must be less"
                    f" than or equal to the maximum output length of the dataset ({max_output_length})."
                    f" The maximum output length is the minimum region length ({min_r_len})."
                )
        out = self._evolve(keep_subset=True, output_length=output_length)
        out._check_valid_state()
        return out

    def with_settings(
        self,
        jitter: int | None = None,
        deterministic: bool | None = None,
        rc_neg: bool | None = None,
        seed: int | np.random.Generator | None = None,
        splice_info: str | tuple[str, str] | Literal[False] | None = None,
    ) -> Self:
        """Return a copy of the dataset with some settings changed.

        Arguments left as ``None`` keep their current value. Pass
        ``splice_info=False`` to turn splicing off. The current subset is
        preserved unless ``splice_info`` is given, in which case the copy
        covers the full dataset because the row space changes.

        Raises
        ------
        ValueError
            If ``jitter`` is negative or larger than the shortest region.
        RuntimeError
            If the resulting settings are incompatible with splicing.
        """
        to_evolve = {}
        if jitter is not None:
            if jitter < 0:
                raise ValueError(f"jitter ({jitter}) must be a non-negative integer.")
            # Reduce to a scalar first so the message shows the minimum length
            # rather than the whole array of lengths.
            min_len = int(
                (self._subset_regions[:, 2] - self._subset_regions[:, 1]).min()
            )
            if jitter > min_len:
                raise ValueError(
                    f"jitter ({jitter}) must be less than the minimum region length ({min_len})."
                )
            to_evolve["jitter"] = jitter
        if deterministic is not None:
            to_evolve["deterministic"] = deterministic
        if rc_neg is not None:
            to_evolve["rc_neg"] = rc_neg
        if seed is not None:
            to_evolve["seed"] = np.random.default_rng(seed)
        if splice_info is not None:
            # __post_init__ of the copy builds (or clears) the splice map from
            # this value, so it is not built a second time here.
            to_evolve["splice_info"] = None if splice_info is False else splice_info
        out = self._evolve(keep_subset=splice_info is None, **to_evolve)
        out._check_valid_state()
        return out

    def subset_to(self, regions: StrIdx) -> Self:
        """Subset the dataset to a subset of regions (or transcripts, when spliced).

        Modifies the dataset in place and returns it. Integer-like selections
        (scalars, slices, integer arrays, sequences of integers) select rows by
        position; anything else is passed to :meth:`polars.DataFrame.filter`.

        Raises
        ------
        ValueError
            If regions are given by name but ``region_names`` was not set.
        """
        if self._splice_map is not None:
            new_map = self._splice_map.subset_to(regions)
            flat = ak.flatten(new_map.splice_map, None).to_numpy()
            self._splice_map = new_map
            self._subset_bed = self.full_bed[flat]
            self._subset_regions = bed_to_regions(
                self._subset_bed, self.reference.c_map
            )
            return self

        if self._region_map is not None:
            regions = s2i(regions, self._region_map)
        elif is_str_arr(regions):
            raise ValueError(
                "Cannot subset to regions by name because no region name was set."
            )

        if isinstance(regions, Sequence) and len(regions) == 0:
            # An empty selection is an empty positional index.
            regions = np.asarray(regions, dtype=np.intp)

        # bool is a subclass of int, so it must be excluded explicitly or a
        # boolean mask given as a list would be used as row positions.
        is_int_sequence = (
            isinstance(regions, Sequence)
            and len(regions) > 0
            and isinstance(regions[0], (int, np.integer))
            and not isinstance(regions[0], bool)
        )
        if (
            isinstance(regions, (int, np.integer, slice))
            or is_dtype(regions, np.integer)
            or is_int_sequence
        ):
            self._subset_bed = self.full_bed[regions]  # type: ignore[bad-index] # polars DataFrame.__getitem__ doesn't accept all our union members but runtime branch ensures valid kinds
        else:
            self._subset_bed = self.full_bed.filter(regions)  # type: ignore[bad-argument-type] # polars filter accepts predicates / bool arrays; our union has equivalent shapes
        self._subset_regions = bed_to_regions(self._subset_bed, self.reference.c_map)
        return self

    def to_full_dataset(self) -> Self:
        """Reset the dataset to the full dataset."""
        if self._splice_map is not None:
            self._splice_map = self._splice_map.to_full()
        self._subset_bed = self.full_bed
        self._subset_regions = bed_to_regions(self._subset_bed, self.reference.c_map)
        return self

    def __getitem__(self, idx: Idx) -> T:
        """Fetch the sequence(s) at ``idx`` from the current subset."""
        if self._splice_map is not None:
            return self._getitem_spliced(idx)
        return self._getitem_unspliced(idx)

    def _getitem_spliced(self, idx: Idx) -> T:
        """Fetch concatenated per-transcript sequences for the rows in ``idx``."""
        assert self._splice_map is not None
        assert not isinstance(self.output_length, int)
        flat_r_idx, offsets, out_reshape, squeeze = self._splice_map.parse_rows(idx)
        # flat_r_idx values are absolute indices into full_bed (not _subset_regions).
        # polars accepts a 1-D numpy integer array directly — no .tolist() needed.
        regions = bed_to_regions(self.full_bed[flat_r_idx], self.reference.c_map)
        lengths = (regions[:, 2] - regions[:, 1]).astype(np.int32, copy=False)
        n_rows = offsets.shape[0] - 1
        plan = build_splice_plan(
            lengths=lengths,
            splice_row_offsets=offsets,
            n_samples=1,
            n_rows=n_rows,
        )
        # Delegate kernel dispatch to the shared helper (eliminates duplication
        # with Ref.__call__'s splice branch). Returns a per-element _Flat (n_elements, None)
        # already in permuted write order.
        per_elem = _fetch_spliced_ref(
            regions=regions,
            plan=plan,
            reference=self.reference.reference,
            ref_offsets=self.reference.offsets,
            pad_char=self.reference.pad_char,
        )
        if self.rc_neg:
            to_rc_unperm = regions[:, 3] == -1
            if to_rc_unperm.any():
                from .._ragged import _COMP

                to_rc_perm = to_rc_unperm[plan.permutation]
                per_elem = per_elem.reverse_masked(to_rc_perm, comp=_COMP)
        # Rewrap with group_offsets at (n_rows, None) — skip the (n_rows, 1, None)
        # + squeeze(1) trick since RefDataset has no sample axis.
        ref = cast(
            Ragged[np.bytes_],
            _Flat.from_offsets(
                per_elem.data, (n_rows, None), plan.group_offsets
            ).to_ragged(),
        )
        if out_reshape is not None:
            ref = ref.reshape(out_reshape)
        if self.output_length == "ragged":
            out = ref
        elif self.output_length == "variable":
            out = to_padded(ref, pad_value=bytes([self.reference.pad_char]))
        else:
            raise AssertionError(
                "splice + fixed-length output should be blocked earlier"
            )
        if squeeze:
            out = out.squeeze(0)
        return cast(T, out)

    def _getitem_unspliced(self, idx: Idx) -> T:
        """Fetch one sequence per BED region for the rows in ``idx``."""
        # (... 4)
        regions = self._subset_regions[idx].copy()
        out_reshape = None
        squeeze = False
        if regions.ndim > 2:
            out_reshape = regions.shape[:-1]
        elif regions.ndim == 1:
            squeeze = True
        regions = regions.reshape(-1, 4)
        batch_size = len(regions)
        lengths = regions[:, 2] - regions[:, 1]
        if isinstance(self.output_length, int):
            # (b)
            out_lengths = np.full(batch_size, self.output_length, dtype=np.int32)
        else:
            out_lengths = lengths
        # (b)
        if self.deterministic:
            extra_len = np.full(batch_size, 0)
        else:
            extra_len = (lengths - out_lengths).clip(min=0)
        max_shift = extra_len + 2 * self.jitter
        shifts = self._rng.integers(0, max_shift + 1, dtype=np.int32)
        regions[:, 1] += shifts - self.jitter
        regions[:, 2] = regions[:, 1] + out_lengths
        # (b+1)
        out_offsets = lengths_to_offsets(out_lengths)
        # ragged (b ~l)
        ref = get_reference(
            regions=regions,
            out_offsets=out_offsets,
            reference=self.reference.reference,
            ref_offsets=self.reference.offsets,
            pad_char=self.reference.pad_char,
        ).view("S1")
        ref = cast(
            Ragged[np.bytes_], Ragged.from_offsets(ref, (batch_size, None), out_offsets)
        )
        # Honor rc_neg, as the spliced path does; previously negative-strand
        # regions were reverse complemented even with rc_neg=False.
        if self.rc_neg:
            to_rc = regions[:, 3] == -1
            if to_rc.any():
                ref = reverse_complement_masked(ref, to_rc)
        if out_reshape is not None:
            ref = ref.reshape(out_reshape)
        if self.output_length == "ragged":
            out = ref
        elif self.output_length == "variable":
            out = to_padded(ref, pad_value=bytes([self.reference.pad_char]))
        else:
            out = ref.to_numpy()
        if squeeze:
            out = out.squeeze(0)
        return cast(T, out)

    def to_torch_dataset(
        self, return_indices: bool = False, transform: Callable | None = None
    ) -> TorchDataset:
        """Convert the dataset to a PyTorch dataset.

        Parameters
        ----------
        return_indices
            If True, each batch also carries the dataset indices it was built from.
        transform
            A function to transform the data. It is called as ``transform(sequences)``, or as
            ``transform(sequences, indices)`` when ``return_indices`` is True, where ``sequences``
            is a numpy array of S1 with shape (batch_size, length).

        Raises
        ------
        ValueError
            If the dataset's output length is ``"ragged"``.
        """
        if self.output_length == "ragged":
            raise ValueError(
                "Cannot convert to PyTorch dataset with ragged output length."
            )
        self = cast(RefDataset[NDArray[np.bytes_]], self)
        return TorchDataset(self, include_indices=return_indices, transform=transform)

    def to_dataloader(
        self,
        batch_size: int = 1,
        shuffle: bool = False,
        sampler: td.Sampler | Iterable | None = None,
        num_workers: int = 0,
        collate_fn: Callable | None = None,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0,
        worker_init_fn: Callable | None = None,
        multiprocessing_context: Callable | None = None,
        generator: torch.Generator | None = None,
        *,
        prefetch_factor: int | None = None,
        persistent_workers: bool = False,
        pin_memory_device: str = "",
        return_indices: bool = False,
        transform: Callable | None = None,
    ) -> td.DataLoader:
        """Convert the dataset to a PyTorch :class:`DataLoader <torch.utils.data.DataLoader>`. The parameters are the same as a
        :class:`DataLoader <torch.utils.data.DataLoader>` with a few omissions e.g. :code:`batch_sampler`.
        Requires PyTorch to be installed.

        Parameters
        ----------
        batch_size
            How many samples per batch to load.
        shuffle
            Set to True to have the data reshuffled at every epoch.
        sampler
            Defines the strategy to draw samples from the dataset. Can be any :py:class:`Iterable <typing.Iterable>` with :code:`__len__` implemented. If specified, shuffle must not be specified.

            .. important::
                Do not provide a :class:`BatchSampler <torch.utils.data.BatchSampler>` here. GVL Datasets use multithreading when indexed with batches of indices to avoid the overhead of multi-processing.
                To leverage this, GVL will automatically wrap the :code:`sampler` with a :class:`BatchSampler <torch.utils.data.BatchSampler>`
                so that lists of indices are given to the GVL Dataset instead of one index at a time. See `this post <https://discuss.pytorch.org/t/dataloader-sample-by-slices-from-dataset/113005>`_
                for more information.
        num_workers
            How many subprocesses to use for dataloading. :code:`0` means that the data will be loaded in the main process.

            .. tip::
                For GenVarLoader, it is generally best to set this to 0 or 1 since almost everything in
                GVL is multithreaded. However, if you are using a transform that is compute intensive and single threaded, there may
                be a benefit to setting this > 1.
        collate_fn
            Merges a list of samples to form a mini-batch of Tensor(s).
        pin_memory
            If :code:`True`, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your :code:`collate_fn` returns a batch that is a custom type, see the example below.
        drop_last
            Set to :code:`True` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If :code:`False` and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
        timeout
            If positive, the timeout value for collecting a batch from workers. Should always be non-negative.
        worker_init_fn
            If not :code:`None`, this will be called on each worker subprocess with the worker id (an int in :code:`[0, num_workers - 1]`) as input, after seeding and before data loading.
        multiprocessing_context
            If :code:`None`, the default multiprocessing context of your operating system will be used.
        generator
            If not :code:`None`, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate :code:`base_seed` for workers.
        prefetch_factor
            Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
        persistent_workers
            If :code:`True`, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive.
        pin_memory_device
            The device to :code:`pin_memory` to if :code:`pin_memory` is :code:`True`.
        return_indices
            If True, each batch also carries the dataset indices it was built from.
        transform
            A function to transform the data. It is called as ``transform(sequences)``, or as
            ``transform(sequences, indices)`` when ``return_indices`` is True, where ``sequences``
            is a numpy array of S1 with shape (batch_size, length).
        """
        return get_dataloader(
            dataset=self.to_torch_dataset(return_indices, transform),
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            multiprocessing_context=multiprocessing_context,
            generator=generator,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            pin_memory_device=pin_memory_device,
        )
@nb.njit(parallel=True, nogil=True, cache=True)
def get_reference(
    regions: NDArray[np.integer],
    out_offsets: NDArray[np.integer],
    reference: NDArray[np.integer],
    ref_offsets: NDArray[np.integer],
    pad_char: int,
) -> NDArray[np.uint8]:
    """Fetch the reference bytes of every region into one flat buffer.

    Row ``i`` of ``regions`` supplies ``(contig index, start, end)`` in its
    first three columns. Its bytes are written by ``padded_slice`` into
    ``out[out_offsets[i]:out_offsets[i + 1]]`` (presumably padding
    out-of-range positions with ``pad_char`` -- confirm against
    ``padded_slice``). The buffer has ``out_offsets[-1]`` elements.
    """
    n_regions = len(regions)
    buffer = np.empty(out_offsets[-1], np.uint8)
    for row in nb.prange(n_regions):
        contig = regions[row, 0]
        begin = regions[row, 1]
        stop = regions[row, 2]
        contig_seq = reference[ref_offsets[contig] : ref_offsets[contig + 1]]
        dest = buffer[out_offsets[row] : out_offsets[row + 1]]
        padded_slice(contig_seq, begin, stop, pad_char, dest)
    return buffer
def _fetch_spliced_ref(
    regions: NDArray[np.integer],
    plan: SplicePlan,
    reference: NDArray[np.uint8],
    ref_offsets: NDArray[np.int64],
    pad_char: int,
) -> "_Flat[np.bytes_]":
    """Read reference bytes for ``regions`` in the order given by ``plan``.

    The regions are reordered by ``plan.permutation`` and fetched into a flat
    uint8 buffer laid out by ``plan.permuted_out_offsets``. The buffer is
    wrapped as a flat ragged of shape ``(n_elements, None)`` and viewed as
    ``S1``.

    This is the kernel-dispatch core shared by :class:`Ref.__call__`'s splice
    branch and :meth:`RefDataset._getitem_spliced`.
    """
    write_offsets = plan.permuted_out_offsets
    # uint8 flat buffer, one contiguous run per permuted region
    flat_bytes = get_reference(
        regions=regions[plan.permutation],
        out_offsets=write_offsets,
        reference=reference,
        ref_offsets=ref_offsets,
        pad_char=pad_char,
    )
    n_elements = plan.permuted_lengths.shape[0]
    per_element = _Flat.from_offsets(flat_bytes, (n_elements, None), write_offsets)
    return cast("_Flat[np.bytes_]", per_element.view("S1"))
if TORCH_AVAILABLE:
    import torch
    import torch.utils.data as td

    class TorchDataset(td.Dataset):
        """Adapter exposing a :class:`RefDataset` as a map-style PyTorch dataset.

        It is indexed with a list of indices and returns a whole batch.
        """

        dataset: RefDataset[NDArray[np.bytes_]]
        include_indices: bool
        transform: Callable | None

        def __init__(
            self,
            dataset: RefDataset[NDArray[np.bytes_]],
            include_indices: bool,
            transform: Callable | None,
        ):
            self.dataset = dataset
            self.include_indices = include_indices
            self.transform = transform

        def __len__(self) -> int:
            return len(self.dataset)

        def __getitem__(self, idx: list[int]):
            """Return the batch for ``idx``.

            The batch is ``(sequences,)`` or ``(sequences, indices)`` and is
            passed unpacked to ``transform`` when one is set. A one-element
            tuple is unwrapped to its single item.
            """
            batch = (self.dataset[idx],)
            if self.include_indices:
                _idx = np.atleast_1d(idx)
                batch = (*batch, _idx)
            if self.transform is not None:
                batch = self.transform(*batch)
            # Only unwrap singleton tuples. A transform may return a bare
            # array or tensor whose length is the batch size, and indexing
            # that would drop the batch axis for batches of one.
            if isinstance(batch, tuple) and len(batch) == 1:
                batch = batch[0]
            return batch

else:
    TorchDataset = no_torch_error