from __future__ import annotations
from collections.abc import Callable, Iterable, Mapping, Sequence
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Generic, Literal, NoReturn, TypeVar, overload
import awkward as ak
import numpy as np
import polars as pl
import seqpro as sp
from loguru import logger
from numpy.typing import NDArray
from seqpro.rag import Ragged
from typing_extensions import Self, assert_never
from .._ragged import (
INTERVAL_DTYPE,
RaggedAnnotatedHaps,
RaggedIntervals,
RaggedSeqs,
RaggedTracks,
)
from .._torch import TORCH_AVAILABLE, TorchDataset, get_dataloader
from .._types import AnnotatedHaps, Idx, StrIdx
from .._utils import (
lengths_to_offsets,
normalize_contig_name,
)
from ._indexing import DatasetIndexer, SpliceIndexer, is_str_arr
from ._insertion_fill import InsertionFill
from ._rag_variants import RaggedVariants
from ._reconstruct import (
Haps,
HapsTracks,
Ref,
RefTracks,
Tracks,
TrackType,
_build_reconstructor,
)
from ._reference import Reference
from ._splice import SpliceMap
from ._utils import regions_to_bed
# torch is optional: only import it when the project's ``_torch`` module
# reports that it is available.
if TORCH_AVAILABLE:
    import torch
    import torch.utils.data as td
# Module-level alias for the builtin ``open``.
# NOTE(review): presumably kept because ``Dataset.open`` reuses the name
# ``open``; confirm where ``_py_open`` is used in the rest of the module.
_py_open = open
[docs]
@dataclass(slots=True, frozen=True)
class Dataset:
"""A dataset of genotypes, reference sequences, and intervals.
.. note::
This class is not meant to be instantiated directly. Use the :py:meth:`Dataset.open() <genvarloader.Dataset.open()>`
method to open a dataset after writing the data with :py:func:`genvarloader.write()` or the GenVarLoader CLI.
GVL Datasets act like a collection of lazy ragged arrays that can be lazily subset or eagerly indexed as a 2D NumPy array. They
have an effective shape of :code:`(n_regions, n_samples, [tracks], [ploidy], output_length)`, but only the region and sample
dimensions can be indexed directly since the return value is generally a tuple of arrays.
**Eager indexing**
.. code-block:: python
dataset[0, 9] # first region, 10th sample
dataset[:10] # first 10 regions and all samples
dataset[:10, :5] # first 10 regions and 5 samples
dataset[[2, 2], [0, 1]] # 3rd region, 1st and 2nd samples
**Lazy indexing**
See :meth:`Dataset.subset_to() <Dataset.subset_to()>`. This is useful, for example, to create
splits for training, validation, and testing, or filter out regions or samples after writing a full dataset.
This is also necessary if you intend to create a Pytorch :class:`DataLoader <torch.utils.data.DataLoader>`
from the Dataset using :meth:`Dataset.to_dataloader() <Dataset.to_dataloader()>`.
**Return values**
The return value depends on the :code:`Dataset` state, namely :attr:`sequence_type <Dataset.sequence_type>`,
:attr:`active_tracks <Dataset.active_tracks>`, and :attr:`output_length <Dataset.output_length>`.
These can all be modified after opening a :code:`Dataset` using the following methods:
- :meth:`Dataset.with_seqs() <Dataset.with_seqs()>`
- :meth:`Dataset.with_tracks() <Dataset.with_tracks()>`
- :meth:`Dataset.with_len() <Dataset.with_len()>`
"""
# Typing-only overload: when no ``reference`` is given, the result is typed
# ``RaggedDataset[MaybeRSEQ, MaybeRTRK]``.
@staticmethod
@overload
def open(
    path: str | Path,
    reference: None = ...,
    jitter: int = 0,
    rng: int | np.random.Generator | None = False,
    deterministic: bool = True,
    rc_neg: bool = True,
    min_af: float | None = None,
    max_af: float | None = None,
    var_fields: list[str] | None = None,
    region_names: str | None = None,
    splice_info: str | tuple[str, str] | None = None,
    var_filter: Literal["exonic"] | None = None,
    *,
    svar: str | Path | None = None,
) -> RaggedDataset[MaybeRSEQ, MaybeRTRK]: ...
# Typing-only overload: when a ``reference`` is given, the result is typed
# ``RaggedDataset[RaggedSeqs, MaybeRTRK]``.
@staticmethod
@overload
def open(
    path: str | Path,
    reference: str | Path | Reference,
    jitter: int = 0,
    rng: int | np.random.Generator | None = False,
    deterministic: bool = True,
    rc_neg: bool = True,
    min_af: float | None = None,
    max_af: float | None = None,
    var_fields: list[str] | None = None,
    region_names: str | None = None,
    splice_info: str | tuple[str, str] | None = None,
    var_filter: Literal["exonic"] | None = None,
    *,
    svar: str | Path | None = None,
) -> RaggedDataset[RaggedSeqs, MaybeRTRK]: ...
[docs]
@staticmethod
def open(
    path: str | Path,
    reference: str | Path | Reference | None = None,
    jitter: int = 0,
    rng: int | np.random.Generator | None = False,
    deterministic: bool = True,
    rc_neg: bool = True,
    min_af: float | None = None,
    max_af: float | None = None,
    var_fields: list[str] | None = None,
    region_names: str | None = None,
    splice_info: str | tuple[str, str] | None = None,
    var_filter: Literal["exonic"] | None = None,
    *,
    svar: str | Path | None = None,
) -> RaggedDataset[MaybeRSEQ, MaybeRTRK]:
    """Open a dataset from a path. If no reference genome is provided, the dataset cannot yield sequences.

    Will initialize the dataset such that it will return tracks and haplotypes (reference sequences if no genotypes) if possible.
    If tracks are available, they will be set to be returned in alphabetical order.

    Parameters
    ----------
    path
        Path to a dataset.
    reference
        Path to a reference genome, or an already-constructed :class:`Reference`.
    jitter
        Amount of jitter to use, cannot be more than the maximum jitter of the dataset.
    rng
        Random seed or np.random.Generator for any stochastic operations.
    deterministic
        Whether to use randomized or deterministic algorithms. If set to True, this will disable random
        shifting of longer-than-requested haplotypes.
    rc_neg
        Whether to reverse-complement sequences and reverse tracks on negative strands.
    min_af
        The minimum allele frequency to include in the dataset. If dataset is not backed by SVAR genotypes, this will raise an error.
    max_af
        The maximum allele frequency to include in the dataset. If dataset is not backed by SVAR genotypes, this will raise an error.
    var_fields
        The variant fields to include in the dataset. Defaults to the
        minimum useful set ``["alt", "ilen", "start"]``. Pass additional
        field names (e.g. ``"ref"``, ``"dosage"``, or any info column
        present in the source variants table) to load them eagerly at open
        time. Must be a subset of :attr:`available_var_fields`.
    region_names
        Optional identifier for region names (presumably a column of :attr:`regions`).
        Selecting regions by name, e.g. with :meth:`subset_to`, requires region names to be set.
    splice_info
        A string or tuple of strings representing the splice information to use.
        If a string, it will be used as the transcript ID and the exons are expected to be in order.
        If a tuple of strings, the first string will be used as the transcript ID and the second string will be used as the exon number.
        Defaults to :code:`None` (no splice information). Splicing can be changed later with
        :meth:`with_settings`.
    var_filter
        Whether to filter variants. If set to :code:`"exonic"`, only exonic variants will be applied.
    svar
        Override the recorded SVAR location. Use when the original SVAR has
        moved and the dataset cannot find it via the stored relative/absolute
        path or by sibling discovery.

    Returns
    -------
    A :class:`RaggedDataset` configured from the on-disk dataset and these options.
    """
    # Local import, kept as in the original.
    # NOTE(review): presumably this avoids a circular import with ._open; confirm.
    from ._open import OpenRequest

    return OpenRequest(
        path=Path(path),
        reference=reference,
        jitter=jitter,
        rng=rng,
        deterministic=deterministic,
        rc_neg=rc_neg,
        min_af=min_af,
        max_af=max_af,
        var_fields=var_fields,
        region_names=region_names,
        splice_info=splice_info,
        var_filter=var_filter,
        svar=svar,
    ).resolve()
[docs]
def with_settings(
self,
jitter: int | None = None,
rng: int | np.random.Generator | None = None,
deterministic: bool | None = None,
rc_neg: bool | None = None,
min_af: float | Literal[False] | None = None,
max_af: float | Literal[False] | None = None,
var_fields: list[str] | None = None,
splice_info: str | tuple[str, str] | Literal[False] | None = None,
var_filter: Literal[False, "exonic"] | None = None,
) -> Self:
"""Modify settings of the dataset, returning a new dataset without modifying the old one.
Parameters
----------
jitter
How much jitter to use. Must be non-negative and <= the :attr:`max_jitter <genvarloader.Dataset.max_jitter>` of the dataset.
rng
Random seed or np.random.Generator for non-deterministic operations e.g. jittering and shifting longer-than-requested haplotypes.
deterministic
Whether to use randomized or deterministic algorithms. If set to True, this will disable random
shifting of longer-than-requested haplotypes and, for unphased variants, will enable deterministic variant assignment
and always apply the highest CCF group. Note that for unphased variants, this will mean not all possible haplotypes
can be returned.
rc_neg
Whether to reverse-complement sequences and reverse tracks on negative strands.
min_af
The minimum allele frequency to include in the dataset. If set to :code:`False`, disables this filter.
If dataset is not backed by SVAR genotypes, this will raise an error.
max_af
The maximum allele frequency to include in the dataset. If set to :code:`False`, disables this filter.
If dataset is not backed by SVAR genotypes, this will raise an error.
var_fields
The variant fields to include in the dataset.
splice_info
A string or tuple of strings representing the splice information to use.
If a string, it will be used as the transcript ID and the exons are expected to be in order.
If a tuple of strings, the first string will be used as the transcript ID and the second string will be used as the exon number.
If a dictionary, the keys will be used as the transcript ID and the values should be the row number for each exon, in order.
If False, splicing will be disabled.
var_filter
Whether to filter variants. If set to :code:`"exonic"`, only exonic variants will be applied.
"""
to_evolve = {}
if jitter is not None:
if jitter != self.jitter:
if isinstance(self.output_length, int):
min_r_len: int = (
self._full_regions[:, 2] - self._full_regions[:, 1]
).min()
max_output_length = min_r_len + 2 * self.max_jitter
eff_length = self.output_length + 2 * jitter
if eff_length > max_output_length:
raise ValueError(
f"Jitter-expanded output length (out_len={self.output_length}) + 2 * ({jitter=}) = {eff_length} must be less"
f" than or equal to the maximum output length of the dataset ({max_output_length})."
f" The maximum output length is the minimum region length ({min_r_len}) + 2 * (max_jitter={self.max_jitter})."
)
to_evolve["jitter"] = jitter
if rng is not None:
to_evolve["_rng"] = np.random.default_rng(rng)
if deterministic is not None:
to_evolve["deterministic"] = deterministic
if rc_neg is not None:
to_evolve["rc_neg"] = rc_neg
if min_af is not None or max_af is not None:
if not isinstance(self._seqs, Haps):
raise ValueError("Dataset has no genotypes to filter.")
if min_af is None:
min_af = self._seqs.min_af
elif min_af is False:
min_af = None
if max_af is None:
max_af = self._seqs.max_af
elif max_af is False:
max_af = None
haps = to_evolve.get("_seqs", self._seqs)
haps = replace(haps, min_af=min_af, max_af=max_af)
to_evolve["_seqs"] = haps
if var_fields is not None:
missing = list(set(var_fields) - set(self.available_var_fields))
if missing or not isinstance(self._seqs, Haps):
raise ValueError(f"Missing variant fields: {missing}")
haps = to_evolve.get("_seqs", self._seqs)
# Lazily load any newly-requested info columns into the existing
# _Variants struct (mutates haps.variants.info in place).
builtin = {"alt", "ilen", "start", "ref", "dosage"}
new_info_fields = [
f
for f in var_fields
if f not in builtin and f not in haps.variants.info
]
if new_info_fields:
haps.variants.load_info(new_info_fields)
# Lazily memmap dosages if newly requested.
if "dosage" in var_fields and haps.dosages is None:
haps = _lazy_load_dosages(self, haps)
haps = replace(haps, var_fields=var_fields)
to_evolve["_seqs"] = haps
if splice_info is not None:
if splice_info is False:
splice_idxer = None
spliced_bed = None
else:
sm, spliced_bed = SpliceMap.from_bed(splice_info, self._full_bed)
if (
ak.max(sm.splice_map, None) >= self._idxer.n_regions
or ak.min(sm.splice_map, None) < -self._idxer.n_regions
):
raise ValueError(
"Found indices in the splice map that are out of bounds for the dataset."
)
splice_idxer = SpliceIndexer(map=sm, dsi=self._idxer)
to_evolve["_sp_idxer"] = splice_idxer
to_evolve["_spliced_bed"] = spliced_bed
if var_filter is not None:
if not isinstance(self._seqs, Haps):
raise ValueError(
"Filtering variants can only be done when the dataset has variants."
)
if var_filter is False:
var_filter = None
if var_filter != self._seqs.filter:
haps = to_evolve.get("_seqs", self._seqs)
to_evolve["_seqs"] = replace(haps, filter=var_filter)
# If any source state changed, rebuild _recon via the factory.
if "_seqs" in to_evolve or "_tracks" in to_evolve:
new_seqs = to_evolve.get("_seqs", self._seqs)
new_tracks = to_evolve.get("_tracks", self._tracks)
to_evolve["_recon"] = _build_reconstructor(
new_seqs, new_tracks, self._seqs_kind
)
self = replace(self, **to_evolve)
self._check_valid_state()
return self
def _check_valid_state(self):
if self.is_spliced:
if self.jitter > 0:
raise RuntimeError(
"Jitter is not supported with splicing. Please set jitter to 0."
)
if not self.deterministic:
raise RuntimeError(
"Non-deterministic algorithms are not supported with splicing. Please set deterministic to True."
)
if self.sequence_type == "variants":
raise ValueError("Splicing is not supported with variants.")
if self.jitter < 0:
raise ValueError(f"Jitter ({self.jitter}) must be a non-negative integer.")
elif self.jitter > self.max_jitter:
raise ValueError(
f"Jitter ({self.jitter}) must be less than or equal to the maximum jitter of the dataset ({self.max_jitter})."
)
if isinstance(self.output_length, int):
if self.sequence_type == "variants":
raise ValueError(
"Output length must be ragged when the sequence type is variants."
)
if self.output_length < 1:
raise ValueError(
f"Output length ({self.output_length}) must be a positive integer."
)
min_r_len: int = (self._full_regions[:, 2] - self._full_regions[:, 1]).min()
max_output_length = min_r_len + 2 * self.max_jitter
eff_length = self.output_length + 2 * self.jitter
if eff_length > max_output_length:
raise ValueError(
f"Effective length (out_len={self.output_length}) + 2 * ({self.jitter=}) = {eff_length} must be less"
f" than or equal to the maximum output length of the dataset ({max_output_length})."
f" The maximum output length is the minimum region length ({min_r_len}) + 2 * (max_jitter={self.max_jitter})."
)
elif self.output_length == "variable" and self.sequence_type == "variants":
raise ValueError(
"Output length must be ragged when the sequence type is variants."
)
[docs]
def with_len(
self, output_length: Literal["ragged", "variable"] | int
) -> ArrayDataset | RaggedDataset:
"""Modify the output length of the dataset, returning a new dataset without modifying the old one.
Parameters
----------
output_length
The output length. Can be set to :code:`"ragged"` or :code:`"variable"` to allow for variable length sequences.
If set to an integer, all sequences will be padded or truncated to this length. See the
`online documentation <https://genvarloader.readthedocs.io/en/latest/dataset.html>`_ for more information.
"""
if isinstance(output_length, int) or output_length == "variable":
if isinstance(output_length, int):
if output_length < 1:
raise ValueError(
f"Output length ({output_length}) must be a positive integer."
)
min_r_len: int = (
self._full_regions[:, 2] - self._full_regions[:, 1]
).min()
max_output_length = min_r_len + 2 * self.max_jitter
eff_length = output_length + 2 * self.jitter
if eff_length > max_output_length:
raise ValueError(
f"Jitter-expanded output length (out_len={self.output_length}) + 2 * ({self.jitter=}) = {eff_length} must be less"
f" than or equal to the maximum output length of the dataset ({max_output_length})."
f" The maximum output length is the minimum region length ({min_r_len}) + 2 * (max_jitter={self.max_jitter})."
)
return ArrayDataset(
path=self.path,
output_length=output_length,
max_jitter=self.max_jitter,
jitter=self.jitter,
contigs=self.contigs,
return_indices=self.return_indices,
rc_neg=self.rc_neg,
deterministic=self.deterministic,
_idxer=self._idxer,
_sp_idxer=self._sp_idxer,
_full_bed=self._full_bed,
_spliced_bed=self._spliced_bed,
_full_regions=self._full_regions,
_seqs=self._seqs,
_tracks=self._tracks,
_seqs_kind=self._seqs_kind,
_recon=self._recon,
_rng=self._rng,
)
else:
out = RaggedDataset(
path=self.path,
output_length=output_length,
max_jitter=self.max_jitter,
jitter=self.jitter,
contigs=self.contigs,
return_indices=self.return_indices,
rc_neg=self.rc_neg,
deterministic=self.deterministic,
_idxer=self._idxer,
_sp_idxer=self._sp_idxer,
_full_bed=self._full_bed,
_spliced_bed=self._spliced_bed,
_full_regions=self._full_regions,
_seqs=self._seqs,
_tracks=self._tracks,
_seqs_kind=self._seqs_kind,
_recon=self._recon,
_rng=self._rng,
)
out._check_valid_state()
return out
[docs]
def with_seqs(
self, kind: Literal["reference", "haplotypes", "annotated", "variants"] | None
):
"""Return a new dataset with the specified sequence type. The sequence type can be one of the following:
- :code:`"reference"`: reference sequences.
- :code:`"haplotypes"`: personalized haplotype sequences.
- :code:`"annotated"`: annotated haplotype sequences, which includes personalized haplotypes along with annotations.
- :code:`"variants"`: no sequences, just variants as :class:`RaggedVariants`
Annotated haplotypes are returned as the :class:`~genvarloader._types.AnnotatedHaps` class which is roughly:
.. code-block:: python
class AnnotatedHaps:
haps: NDArray[np.bytes_]
var_idxs: NDArray[np.int32]
ref_coords: NDArray[np.int32]
where :code:`haps` are the haplotypes as bytes/S1, and :code:`var_idxs` and :code:`ref_coords` are
arrays with the same shape as :code:`haps` that annotate every nucleotide with the variant index and
reference coordinate it corresponds to. A variant index of -1 corresponds to a reference nucleotide, and a reference
coordinate of -1 corresponds to padded nucleotides that were added for regions beyond the bounds of the reference genome.
i.e. if the region's start position is negative or the end position is beyond the end of the reference genome.
For example, a toy result for :code:`chr1:1-10` could be:
.. code-block:: text
haps: A C G T ... T T A ...
var_idxs: -1 3 3 -1 ... -1 4 -1 ...
ref_coords: 1 2 2 3 ... 6 7 9 ...
where variant 3 is a 1 bp :code:`CG` insertion and variant 4 is a 1 bp deletion :code:`T-`. Note that the first nucleotide
of every indel maps to a reference position since :func:`gvl.write() <genvarloader.write()>` expects that variants
are all left-aligned.
.. important::
The :code:`var_idxs` are numbered with respect to the full set of variants even if the variants were extracted from per-chromosome VCFs/PGENs.
So a variant index of 0 corresponds to the first variant across all chromosomes. Thus, if you want to map the variant index to per-chromosome VCFs/PGENs, you will
need to subtract the number of variants on all other chromosomes before the variant index to get the correct variant index in the VCF/PGEN. Relevant values
can be obtained by instantiating a `gvl.Variants` class from the VCFs/PGENs and accessing the `Variants.records.contig_offsets` attribute.
If the Dataset's output length is :code:`"ragged"`, then annotated haplotypes will be :class:`~genvarloader._ragged.RaggedAnnotatedHaps` where each
field is a Ragged array instead of NumPy arrays.
Parameters
----------
kind
The type of sequences to return. Can be one of :code:`"reference"`, :code:`"haplotypes"`, :code:`"annotated"`, :code:`"variants"`, or :code:`None`
to return no sequences.
"""
# Validate the requested kind against storage state.
if kind is None:
tracks_active = self._tracks is not None and bool(
self._tracks.active_tracks
)
if not tracks_active:
raise RuntimeError(
"Dataset is set to only return sequences, so setting sequence_type to None would"
" result in a Dataset that cannot return anything."
)
elif kind == "reference":
if not isinstance(self._seqs, (Haps, Ref)):
raise ValueError("Dataset has no reference to yield sequences from.")
if self._seqs.reference is None:
raise ValueError(
"Dataset has no reference genome to reconstruct sequences from."
)
elif kind in ("haplotypes", "annotated", "variants"):
if not isinstance(self._seqs, Haps):
raise ValueError(
"Dataset has no genotypes to yield haplotypes/variants from."
)
else:
assert_never(kind)
new_recon = _build_reconstructor(self._seqs, self._tracks, kind)
return replace(self, _seqs_kind=kind, _recon=new_recon)
[docs]
def with_tracks(
    self,
    tracks: str | list[str] | Literal[False] | None = None,
    kind: Literal["tracks", "intervals"] | None = None,
) -> Self:
    """Modify which tracks to return, returning a new dataset without modifying the old one.

    Parameters
    ----------
    tracks
        The tracks to return. Can be a (list of) track names or :code:`False` to return no tracks.
        If :code:`None`, the currently active tracks are kept.
    kind
        Whether to return tracks as :code:`"tracks"` (:class:`RaggedTracks`) or
        :code:`"intervals"` (:class:`RaggedIntervals`). If :code:`None`, the current kind is kept.
        Not applied when ``tracks`` is :code:`False`.

    Raises
    ------
    RuntimeError
        If the resulting dataset would return neither sequences nor tracks.
    """
    if self._tracks is None:
        logger.warning("Dataset has no tracks, so this method has no effect.")
        return self

    if tracks is None:
        # self._tracks is not None here, so this is the list of active tracks.
        tracks = list(self._tracks.active_tracks)

    if kind == "tracks":
        _kind = RaggedTracks
    elif kind == "intervals":
        _kind = RaggedIntervals
    elif kind is None:
        _kind = self._tracks.kind
    else:
        assert_never(kind)

    # Compute the new tracks state (active set + kind).
    if tracks is False:
        # User-deactivate all tracks.
        new_tracks = self._tracks.with_tracks(None)
    else:
        names = [tracks] if isinstance(tracks, str) else tracks
        new_tracks = self._tracks.with_tracks(names).to_kind(
            _kind,  # type: ignore[bad-argument-type] # _kind is broader union; runtime branch ensures correct subtype
        )

    # Validate: at least one of (seqs, tracks) must remain active.
    seqs_active = self._seqs_kind is not None and self._seqs is not None
    tracks_active = bool(new_tracks.active_tracks)
    if not seqs_active and not tracks_active:
        raise RuntimeError(
            "Dataset is set to only return tracks, so setting tracks to None would"
            " result in a Dataset that cannot return anything."
        )
    new_recon = _build_reconstructor(self._seqs, new_tracks, self._seqs_kind)
    return replace(self, _tracks=new_tracks, _recon=new_recon)
[docs]
def with_insertion_fill(
self,
fill: InsertionFill | Mapping[str, InsertionFill],
) -> Self:
"""Configure how track values are filled at insertion sites.
Only meaningful when the dataset returns haplotypes *and* tracks (i.e.
when the reconstructor is :class:`HapsTracks`). Pure-reference and
pure-haplotype datasets have no insertion fill to configure.
Parameters
----------
fill
Either a single :class:`InsertionFill` strategy applied to every
active track, or a dict mapping track name to strategy. Tracks not
in the dict fall back to :class:`Repeat5p`.
"""
if self._tracks is None:
raise ValueError("Dataset has no tracks; cannot configure insertion fill.")
if self._seqs_kind not in ("haplotypes", "annotated", "variants"):
raise ValueError(
"with_insertion_fill is only meaningful for datasets with both "
"haplotypes and tracks (use with_seqs to activate haplotypes first)."
)
if not self._tracks.active_tracks:
raise ValueError(
"with_insertion_fill is only meaningful when tracks are active "
"(use with_tracks to activate tracks first)."
)
new_tracks = self._tracks.with_insertion_fill(fill)
new_recon = _build_reconstructor(self._seqs, new_tracks, self._seqs_kind)
return replace(self, _tracks=new_tracks, _recon=new_recon)
# Dataclass fields. The order below determines the generated ``__init__``
# signature; with_len passes them by keyword.
path: Path
"""Path to the dataset."""
output_length: Literal["ragged", "variable"] | int
"""The output length. Can be set to :code:`"ragged"` or :code:`"variable"` to allow for variable length sequences.
If set to an integer, all sequences will be padded or truncated to this length. See the
`online documentation <https://genvarloader.readthedocs.io/en/latest/dataset.html>`_ for more information."""
max_jitter: int
"""Maximum jitter."""
return_indices: bool
"""Whether to return row and sample indices corresponding to the full dataset (no subsetting)."""
contigs: list[str]
"""List of unique contigs."""
jitter: int
"""How much jitter to use."""
deterministic: bool
"""Whether to use randomized or deterministic algorithms. If set to :code:`False`, this will enable random
shifting of longer-than-requested haplotypes and, for unphased variants, enable choosing sets of compatible variants proportional to their CCF;
otherwise the dataset will always apply compatible sets with the highest CCF.
.. note::
This setting is independent of :attr:`~Dataset.jitter`, if you want no :attr:`~Dataset.jitter` you should set it to 0.
"""
rc_neg: bool
"""Whether to reverse-complement the sequences on negative strands."""
# Full (unsubset) input regions table; the ``regions`` property indexes into it.
_full_bed: pl.DataFrame
# Spliced regions table; None when no splice information is set.
_spliced_bed: pl.DataFrame | None
_full_regions: NDArray[np.int32]
"""Unjittered, sorted regions matching order on-disk."""
# Region/sample indexer; replaced by ``subset_to`` to record subsetting.
_idxer: DatasetIndexer
# Indexer over spliced rows; None when the dataset is not spliced.
_sp_idxer: SpliceIndexer | None
# Sequence source: Ref (reference only) or Haps (genotypes); None when the
# dataset cannot yield sequences.
_seqs: (
Ref | Haps[RaggedSeqs] | Haps[RaggedAnnotatedHaps] | Haps[RaggedVariants] | None
)
# Track source; None when the dataset has no tracks.
_tracks: Tracks[RaggedTracks] | Tracks[RaggedIntervals] | None
# Which sequence type is returned (exposed as ``sequence_type``); None means
# no sequences are returned.
_seqs_kind: Literal["haplotypes", "reference", "annotated", "variants"] | None
# Reconstructor produced by ``_build_reconstructor(_seqs, _tracks, _seqs_kind)``;
# the with_* methods rebuild it whenever any of those inputs change.
_recon: (
Ref
| Haps[RaggedSeqs]
| Haps[RaggedAnnotatedHaps]
| Haps[RaggedVariants]
| Tracks
| RefTracks
| HapsTracks[RaggedSeqs, RaggedTracks]
| HapsTracks[RaggedAnnotatedHaps, RaggedTracks]
| HapsTracks[RaggedVariants, RaggedTracks]
| HapsTracks[RaggedSeqs, RaggedIntervals]
| HapsTracks[RaggedAnnotatedHaps, RaggedIntervals]
| HapsTracks[RaggedVariants, RaggedIntervals]
)
# Random generator for stochastic operations (e.g. jittering and shifting
# longer-than-requested haplotypes); set via np.random.default_rng in with_settings.
_rng: np.random.Generator
@property
def is_subset(self) -> bool:
"""Whether the dataset is a subset."""
return self._idxer.is_subset
@property
def is_spliced(self) -> bool:
"""Whether the dataset is spliced."""
return self._sp_idxer is not None
@property
def has_reference(self) -> bool:
"""Whether the dataset was provided a reference genome."""
return self._seqs is not None
@property
def reference(self) -> Reference | None:
"""The reference genome."""
if self._seqs is None:
return None
return self._seqs.reference
@property
def has_genotypes(self) -> bool:
    """Whether the dataset has genotypes, i.e. its sequence source is :class:`Haps`."""
    return isinstance(self._seqs, Haps)
@property
def has_intervals(self) -> bool:
"""Whether the dataset has intervals."""
return self._tracks is not None
@property
def samples(self) -> list[str]:
"""The samples in the dataset."""
return self._idxer.samples
@property
def regions(self) -> pl.DataFrame:
"""The input regions in the dataset as they were provided to :func:`gvl.write() <genvarloader.write()>` i.e. with all BED columns plus any
extra columns that were present."""
if self._idxer.region_subset_idxs is None:
return self._full_bed
return self._full_bed[self._idxer.region_subset_idxs]
@property
def n_regions(self) -> int:
"""The number of (spliced) regions in the dataset."""
return self.shape[0]
@property
def spliced_regions(self) -> pl.DataFrame | None:
"""The spliced regions in the dataset."""
if self._spliced_bed is None or self._sp_idxer is None:
raise ValueError("Dataset does not have splice information.")
if self._sp_idxer.row_subset_idxs is None:
return self._spliced_bed
else:
return self._spliced_bed[self._sp_idxer.row_subset_idxs]
@property
def n_samples(self) -> int:
"""The number of samples in the dataset."""
return self._idxer.n_samples
@property
def ploidy(self) -> int | None:
    """The ploidy of the dataset (size of the second-to-last genotype axis), or ``None`` without genotypes."""
    if not isinstance(self._seqs, Haps):
        return None
    return self._seqs.genotypes.shape[-2]
@property
def shape(self) -> tuple[int, int]:
"""Return the shape of the dataset. :code:`(n_rows, n_samples)`"""
if self._sp_idxer is None:
return self._idxer.shape
else:
return self._sp_idxer.shape
@property
def full_shape(self) -> tuple[int, int]:
"""Return the full shape of the dataset, ignoring any subsetting. :code:`(n_rows, n_samples)`"""
if self._sp_idxer is None:
return self._idxer.full_shape
else:
return self._sp_idxer.full_shape
@property
def available_var_fields(self) -> list[str]:
    """Variant fields that the genotype source can provide; empty without genotypes."""
    source = self._seqs
    if isinstance(source, Haps):
        return source.available_var_fields
    return []
@property
def active_var_fields(self) -> list[str]:
"""Active variant fields."""
match self._recon:
case (Haps() as haps) | HapsTracks(haps=haps):
return haps.var_fields
case _:
return []
@property
def available_tracks(self) -> list[str] | None:
"""The available tracks in the dataset."""
if self._tracks is None:
return
return list(self._tracks.intervals)
@property
def active_tracks(self) -> list[str] | None:
"""The active tracks in the dataset."""
if self._tracks is None:
return
return list(self._tracks.active_tracks)
@property
def _available_sequences(self) -> list[str] | None:
"""The available sequences in the dataset."""
match self._seqs:
case None:
return None
case Ref():
return ["reference"]
case Haps():
return ["reference", "haplotypes", "annotated", "variants"]
case s:
assert_never(s)
@property
def sequence_type(
self,
) -> Literal["haplotypes", "reference", "annotated", "variants"] | None:
"""The type of sequences in the dataset."""
return self._seqs_kind
def __len__(self):
return self.n_regions * self.n_samples
def __str__(self) -> str:
splice_status = "Spliced" if self.is_spliced else "Unspliced"
if self._available_sequences is None:
seq_type = None
else:
seqs = self._available_sequences
if self.sequence_type is not None:
seqs[seqs.index(self.sequence_type)] = f"[{self.sequence_type}]"
seq_type = " ".join(seqs)
if self.available_tracks is None:
tracks = None
else:
tracks = f"{', '.join(self.available_tracks[:5])}"
if len(self.available_tracks) > 5:
tracks += f" + {len(self.available_tracks) - 5} more"
if self.active_tracks is None:
act_tracks = None
else:
act_tracks = f"{', '.join(self.active_tracks[:5])}"
if len(self.active_tracks) > 5:
act_tracks += f" + {len(self.active_tracks) - 5} more"
return (
splice_status + f" GVL dataset at {self.path}\n"
f"Is subset: {self.is_subset}\n"
f"# of regions: {self.n_regions}\n"
f"# of samples: {self.n_samples}\n"
f"Output length: {self.output_length}\n"
f"Jitter: {self.jitter} (max: {self.max_jitter})\n"
f"Deterministic: {self.deterministic}\n"
f"Sequence type: {seq_type}\n"
f"Active tracks: {act_tracks}\n"
f"Tracks available: {tracks}\n"
)
def __repr__(self) -> str:
return str(self)
[docs]
def subset_to(
self,
regions: StrIdx | None = None,
samples: StrIdx | None = None,
) -> Self:
"""Subset the dataset to specific regions and/or samples by index or a boolean mask. If regions or samples
are not provided, the corresponding dimension will not be subset.
Parameters
----------
regions
The regions to subset to.
samples
The samples to subset to.
Examples
--------
Subsetting to the first 10 regions:
.. code-block:: python
ds.subset_to(slice(10))
Subsetting to the 2nd and 4th samples:
.. code-block:: python
ds.subset_to(samples=[1, 3])
Subsetting to chromosome 1, assuming it's labeled :code:`"chr1"`:
.. code-block:: python
r_idx = ds.regions["chrom"] == "chr1"
ds.subset_to(regions=r_idx)
Subsetting to regions labeled by a column "split", assuming "split" existed in the input regions:
.. code-block:: python
r_idx = ds.regions["split"] == "train"
ds.subset_to(regions=r_idx)
Subsetting to the intersection with another set of regions:
.. code-block:: python
import seqpro as sp
regions = gvl.read_bedlike("regions.bed")
regions_pr = sp.bed.to_pyr(regions)
ds_regions_pr = sp.bed.to_pyr(ds.regions.with_row_index())
r_idx = ds_regions_pr.overlap(regions_pr).df["index"].to_numpy()
ds.subset_to(regions=r_idx)
"""
if regions is None and samples is None:
return self
if is_str_arr(regions) and self._idxer.r2i_map is None:
raise ValueError(
"Cannot subset to regions by name because no region name was set."
)
if self._sp_idxer is None:
idxer = self._idxer.subset_to(regions=regions, samples=samples)
return replace(self, _idxer=idxer)
else:
sp_idxer, sub_dsi = self._sp_idxer.subset_to(rows=regions, samples=samples)
return replace(self, _idxer=sub_dsi, _sp_idxer=sp_idxer)
[docs]
def to_full_dataset(self) -> Self:
"""Return a full sized dataset, undoing any subsetting."""
if self._sp_idxer is None:
return replace(self, _idxer=self._idxer.to_full_dataset())
else:
return replace(
self,
_idxer=self._idxer.to_full_dataset(),
_sp_idxer=self._sp_idxer.to_full_dataset(),
)
[docs]
def haplotype_lengths(
    self,
    regions: Idx | None = None,
    samples: Idx | str | Sequence[str] | None = None,
) -> NDArray[np.int32] | None:
    """Lengths of the jitter-extended haplotypes for the requested regions and samples.

    Only available when the dataset's sequences are haplotypes and the dataset is
    deterministic. Otherwise this returns :code:`None`, because haplotype lengths
    are then not guaranteed to be consistent due to randomness in which variants
    are used.

    .. note::

        Currently not implemented for spliced datasets.

    Parameters
    ----------
    regions
        Regions to compute haplotype lengths for. Defaults to all regions.
    samples
        Samples to compute haplotype lengths for. Defaults to all samples.

    Returns
    -------
    Array with a trailing ploidy axis, or :code:`None` (see above).

    Raises
    ------
    NotImplementedError
        If the dataset is spliced.
    """
    if self._sp_idxer is not None:
        raise NotImplementedError(
            "Haplotype lengths are not yet implemented for spliced datasets."
        )
    haps = self._seqs
    if not (isinstance(haps, Haps) and self.deterministic):
        return None

    query = (
        slice(None) if regions is None else regions,
        slice(None) if samples is None else samples,
    )
    ds_idx, squeeze, out_reshape = self._idxer.parse_idx(query)
    r_idx, _ = np.unravel_index(ds_idx, self.full_shape)

    # One region row per instance. Integer-array indexing yields a copy, so the
    # in-place jitter widening below leaves the stored regions untouched
    # (assumes _full_regions is a NumPy array — TODO confirm).
    bounds = self._full_regions[r_idx]
    jitter = self.jitter
    bounds[:, 1] -= jitter
    bounds[:, 2] += jitter

    # reference span (b, 1) + per-haplotype length change (b, p) -> (b, p);
    # _haplotype_ilens presumably returns the net indel length per haplotype.
    span = bounds[:, 2, None] - bounds[:, 1, None]
    hap_lens = span + haps._haplotype_ilens(ds_idx, bounds, self.deterministic)

    if squeeze:
        hap_lens = hap_lens.squeeze(0)
    if out_reshape is not None:
        ploidy = haps.genotypes.shape[-2]
        hap_lens = hap_lens.reshape(*out_reshape, ploidy)
    return hap_lens
[docs]
def n_variants(
    self,
    regions: Idx | None = None,
    samples: StrIdx | None = None,
) -> NDArray[np.int32]:
    """The number of variants in the dataset for specified regions and samples.

    Parameters
    ----------
    regions
        Regions to compute the number of variants for. Defaults to all regions.
    samples
        Samples to compute the number of variants for. Defaults to all samples.

    Returns
    -------
    Array with shape (..., ploidy). The number of variants in the dataset for the
    specified regions and samples. If the dataset has no genotypes (its sequence
    source is not haplotype-based), the result is all zeros with a trailing axis
    of length 1 in place of ploidy.
    """
    if regions is None:
        regions = slice(None)
    if samples is None:
        samples = slice(None)
    ds_idx, squeeze, out_reshape = self._idxer.parse_idx((regions, samples))
    # r_idx and s_idx are parallel arrays: element i is the (region, sample) of
    # instance i, so the number of instances is len(r_idx).
    r_idx, s_idx = np.unravel_index(ds_idx, self.full_shape)
    if not isinstance(self._seqs, Haps):
        # One row per instance, matching the (n_instances, P) layout of the
        # Haps branch. An (n_regions x n_samples) grid here would be n x n and
        # break the reshape below.
        n_vars = np.zeros((len(r_idx), 1), dtype=np.int32)
    else:
        # ((...), P)
        n_vars = self._seqs.n_variants[r_idx, s_idx]
    if squeeze:
        # (1, P) -> (P)
        n_vars = n_vars.squeeze(0)
    if out_reshape is not None:
        # ((...), P) -> (..., P)
        n_vars = n_vars.reshape(*out_reshape, n_vars.shape[-1])
    return n_vars
def _output_bytes_per_instance(
    self,
    regions: Idx | None = None,
    samples: Idx | str | Sequence[str] | None = None,
    *,
    include_offsets: bool = False,
) -> NDArray[np.int64]:
    """Exact bytes one (region, sample) instance materializes to under the
    current schema. Shape: (n_instances,) of int64.

    Parameters
    ----------
    regions
        Regions to size. Defaults to all regions.
    samples
        Samples to size. Defaults to all samples.
    include_offsets
        If ``False`` (default), return the *payload* bytes — the
        ``numpy.nbytes`` of the materialized output. If ``True``, add the
        per-instance share of the int64 offset/lengths arrays that the
        shared-memory chunk serialization writes alongside the payload
        (see ``_shm_layout.write_chunk``): ``8 * ploidy`` per ragged
        output array (outer offsets) and, for ``variants`` ``alt``/``ref``
        fields, ``8 * n_variants`` (inner allele offsets). This is the
        footprint that must fit in a ``double_buffered`` slot; payload
        alone undersizes the slot for ragged outputs. The per-chunk
        ``+1`` offset terminators and 8-byte alignment padding are not
        included here — they are absorbed by the slot's fixed slack.

    Returns
    -------
    int64 byte counts, one per selected instance. When the parsed index is
    squeezed the flat array is returned without reshaping; otherwise it is
    reshaped to the indexer's ``out_reshape`` when one is given.

    Raises
    ------
    NotImplementedError
        For spliced datasets.
    ValueError
        For non-deterministic datasets when with_seqs is in {"haplotypes", "annotated"},
        or when haplotype lengths are unavailable.
    AssertionError
        For "variants" without haplotype data, or an unrecognized sequence type.
    """
    if self._sp_idxer is not None:
        raise NotImplementedError(
            "_output_bytes_per_instance is not implemented for spliced datasets."
        )
    if regions is None:
        regions = slice(None)
    if samples is None:
        samples = slice(None)
    idx = (regions, samples)
    ds_idx, squeeze, out_reshape = self._idxer.parse_idx(idx)
    r_idx, _s_idx = np.unravel_index(ds_idx, self.full_shape)
    seq_kind = (
        self.sequence_type
    )  # "reference" | "haplotypes" | "annotated" | "variants" | None
    total = np.zeros(len(r_idx), dtype=np.int64)
    # Per-instance share of int64 offset/lengths arrays (filled when
    # include_offsets); added to `total` just before the final reshape.
    offset_total = np.zeros(len(r_idx), dtype=np.int64)
    OFF = 8  # int64 offset entry
    # Ploidy comes from the haplotype variant counts; non-haplotype sources count as 1.
    ploidy = self._seqs.n_variants.shape[-1] if isinstance(self._seqs, Haps) else 1
    # These are computed conditionally below; declared here to satisfy the type checker.
    hap_len_sum: NDArray[np.int64] = np.empty(0, dtype=np.int64)
    region_lens: NDArray[np.int64] = np.empty(0, dtype=np.int64)
    # --- seqs payload ---
    if seq_kind == "reference":
        # region length × 1 byte/nt (S1), no ploidy expansion.
        regions_arr = self._full_regions[r_idx].copy()
        regions_arr[:, 1] -= self.jitter
        regions_arr[:, 2] += self.jitter
        region_lens = (regions_arr[:, 2] - regions_arr[:, 1]).astype(np.int64)
        total += region_lens
        # NOTE(review): offsets are only counted here for ragged output, while the
        # haplotype/variant/track branches below count them regardless of
        # output_length — confirm against _shm_layout.write_chunk.
        if include_offsets and self.output_length == "ragged":
            # ragged reference: 1 outer-offset entry per instance (no ploidy).
            offset_total += OFF
    elif seq_kind in ("haplotypes", "annotated"):
        # Haplotype lengths are only fixed when variant selection is deterministic.
        if not self.deterministic:
            raise ValueError(
                f"with_seqs={seq_kind!r} requires deterministic=True for "
                "_output_bytes_per_instance. Use dataset.with_settings(deterministic=True)."
            )
        hap_lens = self.haplotype_lengths(regions, samples)
        if hap_lens is None:
            raise ValueError(
                f"with_seqs={seq_kind!r} requires haplotype_lengths() to be available."
            )
        # hap_lens shape: (..., ploidy). Flatten to (n_inst, ploidy).
        hap_lens_flat = hap_lens.reshape(-1, hap_lens.shape[-1]).astype(np.int64)
        hap_len_sum = hap_lens_flat.sum(-1)  # sum over ploidy
        total += hap_len_sum  # haps S1: 1 byte/nt
        if seq_kind == "annotated":
            # annotated: var_idxs and ref_coords are per-position (same
            # length as haps), not per-variant. Both are int32 (4 bytes).
            total += hap_len_sum * 4  # var_idxs int32
            total += hap_len_sum * 4  # ref_coords int32
        if include_offsets:
            # Each ragged array's outer offsets carry `ploidy` entries per
            # instance: 1 array for haplotypes, 3 (haps/var_idxs/ref_coords)
            # for annotated.
            n_seq_arrays = 1 if seq_kind == "haplotypes" else 3
            offset_total += OFF * ploidy * n_seq_arrays
    elif seq_kind == "variants":
        if not isinstance(self._seqs, Haps):
            raise AssertionError("variants mode requires Haps")
        haps_obj = self._seqs
        var_fields = haps_obj.var_fields
        n_vars = self.n_variants(regions, samples)  # (n_inst, ploidy)
        n_vars_flat = n_vars.reshape(-1, n_vars.shape[-1]).astype(np.int64)
        n_vars_total = n_vars_flat.sum(-1)  # over ploidy → (n_inst,)
        ploidy = n_vars.shape[-1]
        # Fixed-width fields cost n_variants × itemsize; allele fields are
        # variable-length and need a byte scan.
        for f in var_fields:
            if f == "start":
                total += n_vars_total * haps_obj.variants.start.dtype.itemsize
            elif f == "ilen":
                total += n_vars_total * haps_obj.variants.ilen.dtype.itemsize
            elif f == "dosage":
                # NOTE(review): no payload is counted when dosages are not loaded,
                # but the offsets below still count this field via len(var_fields).
                if haps_obj.dosages is None:
                    continue
                dosage_dtype = haps_obj.dosages.data.dtype
                total += n_vars_total * dosage_dtype.itemsize
            elif f in ("alt", "ref"):
                # Allele scan: _allele_bytes_sum returns (len(ds_idx) * ploidy,).
                per_ploid = haps_obj._allele_bytes_sum(ds_idx, f)
                total += per_ploid.reshape(-1, ploidy).sum(-1)
            else:
                # INFO column: numeric, known dtype from on-disk schema.
                info_dtype = haps_obj.variants.info[f].dtype
                total += n_vars_total * info_dtype.itemsize
        if include_offsets:
            # RaggedVariants (kind=2) writes, per field: outer offsets
            # (ploidy entries/instance) and, for alt/ref allele fields,
            # inner offsets (one entry per variant → n_vars_total/instance).
            n_allele_fields = sum(1 for f in var_fields if f in ("alt", "ref"))
            offset_total += OFF * ploidy * len(var_fields)
            offset_total += OFF * n_vars_total * n_allele_fields
    elif seq_kind is None:
        pass
    else:
        raise AssertionError(f"unknown sequence_type {seq_kind!r}")
    # --- tracks payload ---
    if self.active_tracks:
        n_tracks = len(self.active_tracks)
        track_itemsize = np.dtype(np.float32).itemsize  # tracks are always float32
        if seq_kind in ("haplotypes", "annotated"):
            # Tracks have shape (b, t, p, ~l): length = haplotype length per ploid.
            # hap_len_sum already sums over ploidy → total per instance = hap_len_sum * n_tracks.
            total += hap_len_sum * n_tracks * track_itemsize
        else:
            # reference, variants, or no-seq: tracks have shape (b, t, ~l), length = region length.
            # "reference" already computed region_lens above; others need to compute it now.
            if seq_kind != "reference":
                regions_arr = self._full_regions[r_idx].copy()
                regions_arr[:, 1] -= self.jitter
                regions_arr[:, 2] += self.jitter
                region_lens = (regions_arr[:, 2] - regions_arr[:, 1]).astype(
                    np.int64
                )
            total += region_lens * n_tracks * track_itemsize
        if include_offsets:
            # Each track is its own ragged array. Grouped by (instance ×
            # ploidy) for haplotype-shaped outputs, else by instance.
            if seq_kind in ("haplotypes", "annotated"):
                offset_total += OFF * ploidy * n_tracks
            else:
                offset_total += OFF * n_tracks
    if include_offsets:
        total += offset_total
    if squeeze:
        return total
    if out_reshape is not None:
        return total.reshape(out_reshape)
    return total
[docs]
def n_intervals(
    self,
    regions: Idx | None = None,
    samples: StrIdx | None = None,
) -> NDArray[np.int32]:
    """The number of intervals in the dataset for specified regions and samples.

    Parameters
    ----------
    regions
        Regions to compute the number of intervals for. Defaults to all regions.
    samples
        Samples to compute the number of intervals for. Defaults to all samples.

    Returns
    -------
    Array with shape (..., tracks), one column per active track in activation
    order. If the dataset has no tracks, or none are active, the result is an
    empty array whose trailing tracks axis has length 0.
    """
    if regions is None:
        regions = slice(None)
    if samples is None:
        samples = slice(None)
    ds_idx, squeeze, out_reshape = self._idxer.parse_idx((regions, samples))
    # r_idx and s_idx are parallel arrays: element i is the (region, sample) of
    # instance i, so the number of instances is len(r_idx).
    r_idx, s_idx = np.unravel_index(ds_idx, self.full_shape)

    tracks = self._tracks
    active = tracks.active_tracks if tracks is not None else {}
    if tracks is None or not active:
        # One row per instance with zero track columns. This also avoids
        # np.stack([]) raising when tracks exist but none are active.
        n_itvs = np.zeros((len(r_idx), 0), dtype=np.int32)
    else:
        ls = []
        for name, kind in active.items():
            itv_lengths = tracks.intervals[name].values.lengths
            if kind is TrackType.SAMPLE:
                # per-sample track: indexed by (region, sample)
                ls.append(itv_lengths[r_idx, s_idx])
            elif kind is TrackType.ANNOT:
                # annotation track: shared across samples, indexed by region only
                ls.append(itv_lengths[r_idx])
            else:
                assert_never(kind)
        n_itvs = np.stack(ls, axis=-1)

    if squeeze:
        # (1, T) -> (T)
        n_itvs = n_itvs.squeeze(0)
    if out_reshape is not None:
        # ((...), T) -> (..., T)
        n_itvs = n_itvs.reshape(*out_reshape, n_itvs.shape[-1])
    return n_itvs
[docs]
def write_annot_tracks(
    self, tracks: dict[str, str | Path | pl.DataFrame], overwrite: bool = False
) -> Self:
    """Write annotation tracks to the dataset. Returns a new dataset with the
    tracks available. Activate them with :meth:`with_tracks()`.

    Parameters
    ----------
    tracks
        Paths to the annotation tracks (or literal tables) in BED-like format.
        Keys should be the track names and values should be the paths to the BED files
        or polars.DataFrames.

        .. note::
            Only supports BED files for now.

    overwrite
        Whether to overwrite the existing tracks, by default False

    Returns
    -------
    A new dataset whose tracks are reloaded from disk, so the written tracks are
    available. Tracks that were active on this dataset remain active.

    Raises
    ------
    ValueError
        If any name in ``tracks`` already exists and ``overwrite`` is False.
    """
    if (
        self.available_tracks is not None
        and (exists := set(tracks) & set(self.available_tracks))
        and not overwrite
    ):
        raise ValueError(f"Some tracks already exists in the dataset: {exists}")

    # The dataset regions as a BED table in on-disk order. It is the same for
    # every track, so build it once rather than once per track.
    full_bed = regions_to_bed(self._full_regions, self.contigs)

    for name, bedlike in tracks.items():
        if isinstance(bedlike, (str, Path)):
            bedlike = sp.bed.read(bedlike)
        itvs = _annot_to_intervals(full_bed, bedlike)

        # Create the track directory only after the annotation was read and
        # intersected, so a bad input does not leave an empty directory behind.
        out_dir = self.path / "annot_intervals" / name
        out_dir.mkdir(parents=True, exist_ok=True)

        # Interval records (start, end, value). np.memmap writes a raw buffer,
        # without an .npy header despite the file name.
        out = np.memmap(
            out_dir / "intervals.npy",
            dtype=INTERVAL_DTYPE,
            mode="w+",
            shape=itvs.values.data.shape,
        )
        out["start"] = itvs.starts.data
        out["end"] = itvs.ends.data
        out["value"] = itvs.values.data
        out.flush()

        # Per-region offsets into the interval records.
        out = np.memmap(
            out_dir / "offsets.npy",
            dtype=itvs.values.offsets.dtype,
            mode="w+",
            shape=len(itvs.values.offsets),
        )
        out[:] = itvs.values.offsets
        out.flush()

    ds_tracks = Tracks.from_path(self.path, *self.full_shape).with_tracks(None)
    # Re-activate the same tracks on the newly loaded ds_tracks object,
    # then route through the factory to keep _recon consistent with view-state.
    cur_active = self._tracks.active_tracks if self._tracks is not None else {}
    new_tracks = (
        ds_tracks.with_tracks(cur_active.keys()) if cur_active else ds_tracks
    )
    recon = _build_reconstructor(self._seqs, new_tracks, self._seqs_kind)
    # Store the same Tracks object the reconstructor was built from. Storing
    # ds_tracks here would drop the re-activated tracks from the dataset's
    # view state while _recon still produced them.
    return replace(self, _tracks=new_tracks, _recon=recon)
[docs]
def to_torch_dataset(
    self, return_indices: bool, transform: Callable | None
) -> TorchDataset:
    """Wrap this dataset in a PyTorch :class:`Dataset <torch.utils.data.Dataset>`. Requires PyTorch to be installed.

    Parameters
    ----------
    return_indices
        Whether to append arrays of row and sample indices of the non-subset dataset to each batch.
    transform
        The transform to apply to each batch of data. The transform should take input matching the output of the dataset and can
        return anything that can be converted to a PyTorch tensor. In combination with indices, this allows you to combine arbitrary
        row- and sample-specific data with dataset output on-the-fly.

        .. note::
            Depending on how transforms are implemented, they can easily introduce a dataloading bottleneck. If you find
            dataloading is slow, it's often a good idea to try disabling your transform to see if it's impacting throughput.

    Notes
    -----
    Ragged output is not rejected; only a warning is logged, since ragged output
    cannot be converted to tensors.
    """
    is_ragged = self.output_length == "ragged"
    if is_ragged:
        logger.warning(
            '`output_length` is currently set to "ragged" and ragged output cannot be converted to PyTorch Tensors.'
        )
    wrapped = TorchDataset(self, return_indices, transform)
    return wrapped
[docs]
def to_dataloader(
    self,
    batch_size: int = 1,
    shuffle: bool = False,
    sampler: td.Sampler | Iterable | None = None,
    num_workers: int = 0,
    collate_fn: Callable | None = None,
    pin_memory: bool = False,
    drop_last: bool = False,
    timeout: float = 0,
    worker_init_fn: Callable | None = None,
    multiprocessing_context: Callable | None = None,
    generator: torch.Generator | None = None,
    *,
    prefetch_factor: int | None = None,
    persistent_workers: bool = False,
    pin_memory_device: str = "",
    return_indices: bool = False,
    transform: Callable | None = None,
    mode: str | None = None,
    buffer_bytes: int = 2 * 1024**3,
    copy: bool = True,
    heartbeat_seconds: float = 60.0,
) -> td.DataLoader:
    """Convert the dataset to a PyTorch :class:`DataLoader <torch.utils.data.DataLoader>`. The parameters are the same as a
    :class:`DataLoader <torch.utils.data.DataLoader>` with a few omissions e.g. :code:`batch_sampler`.
    Requires PyTorch to be installed.

    Parameters
    ----------
    batch_size
        How many samples per batch to load.
    shuffle
        Set to True to have the data reshuffled at every epoch.
    sampler
        Defines the strategy to draw samples from the dataset. Can be any :py:class:`Iterable <typing.Iterable>` with :code:`__len__` implemented. If specified, shuffle must not be specified.

        .. important::
            Do not provide a :class:`BatchSampler <torch.utils.data.BatchSampler>` here. GVL Datasets use multithreading when indexed with batches of indices to avoid the overhead of multi-processing.
            To leverage this, GVL will automatically wrap the :code:`sampler` with a :class:`BatchSampler <torch.utils.data.BatchSampler>`
            so that lists of indices are given to the GVL Dataset instead of one index at a time. See `this post <https://discuss.pytorch.org/t/dataloader-sample-by-slices-from-dataset/113005>`_
            for more information.

    num_workers
        How many subprocesses to use for dataloading. :code:`0` means that the data will be loaded in the main process.

        .. tip::
            For GenVarLoader, it is generally best to set this to 0 or 1 since almost everything in
            GVL is multithreaded. However, if you are using a transform that is compute intensive and single threaded, there may
            be a benefit to setting this > 1.

    collate_fn
        Merges a list of samples to form a mini-batch of Tensor(s).
    pin_memory
        If :code:`True`, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your :code:`collate_fn` returns a batch that is a custom type, see the example below.
    drop_last
        Set to :code:`True` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If :code:`False` and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
    timeout
        If positive, the timeout value for collecting a batch from workers. Should always be non-negative.
    worker_init_fn
        If not :code:`None`, this will be called on each worker subprocess with the worker id (an int in :code:`[0, num_workers - 1]`) as input, after seeding and before data loading.
    multiprocessing_context
        If :code:`None`, the default multiprocessing context of your operating system will be used.
    generator
        If not :code:`None`, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate :code:`base_seed` for workers.
    prefetch_factor
        Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
    persistent_workers
        If :code:`True`, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive.
    pin_memory_device
        The device to :code:`pin_memory` to if :code:`pin_memory` is :code:`True`.
    return_indices
        Whether to append arrays of row and sample indices of the non-subset dataset to each batch.
    transform
        The transform to apply to each batch of data. The transform should take input matching the output of the dataset and can
        return anything that can be converted to a PyTorch tensor. In combination with indices, this allows you to combine arbitrary
        row- and sample-specific data with dataset output on-the-fly.

        .. note::
            Depending on how transforms are implemented, they can easily introduce a dataloading bottleneck. If you find
            dataloading is slow, it's often a good idea to try disabling your transform to see if it's impacting throughput.

    mode
        If not :code:`None`, build a buffered loader that indexes this dataset directly instead of
        wrapping it with :meth:`to_torch_dataset`. The accepted values are defined by the buffered
        loader implementation. In this mode :code:`return_indices`, :code:`transform`,
        :code:`timeout`, :code:`worker_init_fn`, :code:`multiprocessing_context`,
        :code:`prefetch_factor`, and :code:`persistent_workers` are not used, and a warning is
        logged if any of them is set to a non-default value.
    buffer_bytes
        Buffer size in bytes for the buffered loader. Only used when :code:`mode` is set.
    copy
        Forwarded to the buffered loader. Only used when :code:`mode` is set.
    heartbeat_seconds
        Forwarded to the buffered loader. Only used when :code:`mode` is set.
    """
    if mode is not None:
        # Buffered modes operate directly on the Dataset, not on a TorchDataset wrapper,
        # because they need access to _output_bytes_per_instance and raw indexing.
        # The arguments below only apply to the TorchDataset path and are not
        # forwarded here; surface that instead of dropping them silently (a lost
        # transform would otherwise yield untransformed batches without notice).
        ignored = [
            arg_name
            for arg_name, is_set in (
                ("return_indices", return_indices),
                ("transform", transform is not None),
                ("timeout", timeout != 0),
                ("worker_init_fn", worker_init_fn is not None),
                ("multiprocessing_context", multiprocessing_context is not None),
                ("prefetch_factor", prefetch_factor is not None),
                ("persistent_workers", persistent_workers),
            )
            if is_set
        ]
        if ignored:
            logger.warning(
                f"to_dataloader(mode={mode!r}) ignores the following arguments: "
                f"{', '.join(ignored)}."
            )
        return get_dataloader(
            dataset=self,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            generator=generator,
            pin_memory_device=pin_memory_device,
            mode=mode,
            buffer_bytes=buffer_bytes,
            copy=copy,
            heartbeat_seconds=heartbeat_seconds,
        )
    return get_dataloader(
        dataset=self.to_torch_dataset(return_indices, transform),
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn,
        multiprocessing_context=multiprocessing_context,
        generator=generator,
        prefetch_factor=prefetch_factor,
        persistent_workers=persistent_workers,
        pin_memory_device=pin_memory_device,
    )
def __getitem__(
    self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx]
) -> (
    Ragged[np.bytes_ | np.float32]
    | RaggedAnnotatedHaps
    | RaggedVariants
    | RaggedIntervals
    | NDArray[np.bytes_ | np.float32]
    | AnnotatedHaps
    | tuple[
        Ragged[np.bytes_ | np.float32]
        | RaggedAnnotatedHaps
        | RaggedVariants
        | RaggedIntervals
        | NDArray[np.bytes_ | np.float32]
        | AnnotatedHaps,
        ...,
    ]
):
    """Eagerly materialize the data selected by a region (and optional sample) index.

    This is a thin facade: it snapshots the dataset state the query needs into a
    ``QueryView`` and delegates all of the work to ``._query.getitem``.
    """
    from ._query import QueryView
    from ._query import getitem as run_query

    view_state = {
        "idxer": self._idxer,
        "sp_idxer": self._sp_idxer,
        "full_regions": self._full_regions,
        "rng": self._rng,
        "recon": self._recon,
        "output_length": self.output_length,
        "jitter": self.jitter,
        "deterministic": self.deterministic,
        "rc_neg": self.rc_neg,
    }
    return run_query(QueryView(**view_state), idx)
def _lazy_load_dosages(dataset: Dataset, haps: Haps) -> Haps:
    """Open the dosages memmap for a Haps that didn't request them at open time.

    Resolves the SVAR directory the same way ``Haps.from_path`` does: via the
    ``svar_link`` recorded in ``metadata.json``, falling back to the legacy
    ``genotypes/link.svar`` link. Returns a new ``Haps`` with ``dosages``
    populated; the input is not mutated.

    Parameters
    ----------
    dataset
        Not read by this function.
    haps
        Haplotype source whose ``path`` is the dataset directory.

    Raises
    ------
    ValueError
        If the dataset is not SVAR-backed (no ``genotypes/svar_meta.json``) or the
        resolved SVAR has no ``dosages.npy``.
    """
    import json as _json

    from genoray._types import DOSAGE_TYPE

    from ._svar_link import _resolve_svar
    from ._write import Metadata

    root = haps.path
    geno_dir = root / "genotypes"
    meta_file = geno_dir / "svar_meta.json"
    if not meta_file.exists():
        raise ValueError(
            "Dosage requested but this dataset is not SVAR-backed; no dosages.npy possible."
        )
    # svar_meta.json describes the shape and dtype of genotypes/offsets.npy.
    svar_meta = _json.loads(meta_file.read_text())
    offsets_shape = tuple(svar_meta["shape"])
    offsets_dtype = np.dtype(svar_meta["dtype"])
    offsets_file = geno_dir / "offsets.npy"

    # Dataset does not retain Metadata, so re-read metadata.json from disk.
    meta = Metadata.model_validate_json((root / "metadata.json").read_text())
    link = meta.svar_link
    if link is None:
        svar_dir = (geno_dir / "link.svar").resolve()
    else:
        svar_dir = _resolve_svar(root, link, None)

    dosage_file = svar_dir / "dosages.npy"
    if not dosage_file.exists():
        raise ValueError(
            f"Dosage requested but {dosage_file} does not exist. "
            f"Check the SVAR was built with dosages."
        )

    offsets = np.memmap(offsets_file, shape=offsets_shape, dtype=offsets_dtype, mode="r")
    dosage_data = np.memmap(dosage_file, dtype=DOSAGE_TYPE, mode="r")
    # The leading axis of the offsets shape is folded into the 2-row offsets
    # array (presumably starts/ends); the rest becomes the ragged shape.
    dosages = Ragged.from_offsets(
        dosage_data, (*offsets_shape[1:], None), offsets.reshape(2, -1)
    )
    return replace(haps, dosages=dosages)
def _annot_to_intervals(regions: pl.DataFrame, annot: pl.DataFrame) -> RaggedIntervals:
    """Group the annotation intervals that overlap each dataset region.

    Parameters
    ----------
    regions
        BED-like table of the dataset regions. Its row order defines the order
        of the output rows.
    annot
        BED-like annotation table. Its contig names are normalized to the naming
        used in ``regions`` before intersecting. A ``score`` column is expected
        in the intersection and is used as the interval value.

    Returns
    -------
    RaggedIntervals with one ragged row per row of ``regions``; regions with no
    overlapping annotation get an empty row.
    """
    # Normalize annotation contig names to the naming used by the regions;
    # contigs without a counterpart are left unchanged.
    reg_c = regions["chrom"].unique()
    annot_c = annot["chrom"].unique()
    new_names = (normalize_contig_name(c, reg_c) for c in annot_c)
    renamer = {old: new for old, new in zip(annot_c, new_names) if new is not None}
    annot = annot.with_columns(chrom=pl.col("chrom").replace(renamer))

    # Find the intersection; "index" is the row number of the overlapped region,
    # and sorting by it groups each region's intervals contiguously.
    intersect = sp.bed.from_pyr(
        sp.bed.to_pyr(annot).join(sp.bed.to_pyr(regions.with_row_index()))
    ).sort("index", "chrom", "chromStart")

    # Compute offsets, giving regions with no overlaps a zero-length row.
    i, nonzero_counts = np.unique(intersect["index"], return_counts=True)
    counts = np.zeros(regions.height, dtype=np.int32)
    counts[i] = nonzero_counts
    offsets = lengths_to_offsets(counts)
    shape = (len(offsets) - 1, None)

    # Ragged views of each interval field, all sharing the same offsets.
    starts = Ragged.from_offsets(intersect["chromStart"].to_numpy(), shape, offsets)
    ends = Ragged.from_offsets(intersect["chromEnd"].to_numpy(), shape, offsets)
    values = Ragged.from_offsets(intersect["score"].to_numpy(), shape, offsets)
    return RaggedIntervals(starts, ends, values)
# Output element types for eager (fixed/variable length) indexing via ArrayDataset.
SEQ = TypeVar("SEQ", NDArray[np.bytes_], AnnotatedHaps, RaggedVariants)
TRK = TypeVar("TRK", NDArray[np.float32], RaggedIntervals)
# Same as above, plus None for "this output is disabled".
MaybeSEQ = TypeVar("MaybeSEQ", None, NDArray[np.bytes_], AnnotatedHaps, RaggedVariants)
MaybeTRK = TypeVar("MaybeTRK", None, NDArray[np.float32], RaggedIntervals)

# Output element types for ragged indexing via RaggedDataset.
RSEQ = TypeVar("RSEQ", RaggedSeqs, RaggedAnnotatedHaps, RaggedVariants)
RTRK = TypeVar("RTRK", Ragged[np.float32], RaggedIntervals)
# Same as above, plus None for "this output is disabled".
MaybeRSEQ = TypeVar("MaybeRSEQ", None, RaggedSeqs, RaggedAnnotatedHaps, RaggedVariants)
MaybeRTRK = TypeVar("MaybeRTRK", None, Ragged[np.float32], RaggedIntervals)
[docs]
class ArrayDataset(Dataset, Generic[MaybeSEQ, MaybeTRK]):
    """Only for type checking purposes, you should never instantiate this class directly.

    Static view of a :class:`Dataset` whose ``output_length`` is ``"variable"`` or a
    fixed integer. ``MaybeSEQ`` is the sequence output type (``None`` when sequences
    are disabled) and ``MaybeTRK`` the track output type (``None`` when tracks are
    disabled). Every method simply defers to :class:`Dataset` at runtime; the
    overloads only refine static return types.
    """

    # Narrows the inherited attribute's static type; no runtime effect.
    output_length: Literal["variable"] | int

    @overload
    def with_len(
        self: ArrayDataset[NDArray[np.bytes_], None],
        output_length: Literal["ragged"],
    ) -> RaggedDataset[RaggedSeqs, None]: ...
    @overload
    def with_len(
        self: ArrayDataset[AnnotatedHaps, None],
        output_length: Literal["ragged"],
    ) -> RaggedDataset[RaggedAnnotatedHaps, None]: ...
    @overload
    def with_len(
        self: ArrayDataset[None, NDArray[np.float32]],
        output_length: Literal["ragged"],
    ) -> RaggedDataset[None, Ragged[np.float32]]: ...
    @overload
    def with_len(
        self: ArrayDataset[NDArray[np.bytes_], NDArray[np.float32]],
        output_length: Literal["ragged"],
    ) -> RaggedDataset[RaggedSeqs, Ragged[np.float32]]: ...
    @overload
    def with_len(
        self: ArrayDataset[AnnotatedHaps, NDArray[np.float32]],
        output_length: Literal["ragged"],
    ) -> RaggedDataset[RaggedAnnotatedHaps, Ragged[np.float32]]: ...
    @overload
    def with_len(
        self,
        output_length: Literal["variable"] | int,
    ) -> ArrayDataset[MaybeSEQ, MaybeTRK]: ...
    def with_len(
        self, output_length: Literal["ragged", "variable"] | int
    ) -> RaggedDataset[MaybeRSEQ, MaybeRTRK] | ArrayDataset[MaybeSEQ, MaybeTRK]:
        """Typed pass-through to :meth:`Dataset.with_len`.

        Changing between ``"variable"`` and a fixed length keeps the sequence and
        track output types; switching to ``"ragged"`` yields a :class:`RaggedDataset`.
        """
        return super().with_len(output_length)

    @overload
    def with_seqs(self, kind: None) -> ArrayDataset[None, MaybeTRK]: ...
    @overload
    def with_seqs(
        self, kind: Literal["reference", "haplotypes"]
    ) -> ArrayDataset[NDArray[np.bytes_], MaybeTRK]: ...
    @overload
    def with_seqs(
        self, kind: Literal["annotated"]
    ) -> ArrayDataset[AnnotatedHaps, MaybeTRK]: ...
    @overload
    def with_seqs(
        self, kind: Literal["variants"]
    ) -> ArrayDataset[RaggedVariants, MaybeTRK]: ...
    def with_seqs(
        self, kind: Literal["reference", "haplotypes", "annotated", "variants"] | None
    ) -> ArrayDataset:
        """Typed pass-through to :meth:`Dataset.with_seqs`."""
        return super().with_seqs(kind)

    @overload
    def with_tracks(self, tracks: None = None, kind: None = None) -> Self: ...
    @overload
    def with_tracks(
        self, *, tracks: None = None, kind: Literal["tracks"]
    ) -> ArrayDataset[MaybeSEQ, NDArray[np.float32]]: ...
    @overload
    def with_tracks(
        self, *, tracks: None = None, kind: Literal["intervals"]
    ) -> ArrayDataset[MaybeSEQ, RaggedIntervals]: ...
    @overload
    def with_tracks(
        self,
        tracks: Literal[False],
        kind: Literal["tracks", "intervals"] | None = None,
    ) -> ArrayDataset[MaybeSEQ, None]: ...
    @overload
    def with_tracks(
        self, tracks: str | list[str], kind: None = None
    ) -> ArrayDataset[MaybeSEQ, MaybeTRK]: ...
    @overload
    def with_tracks(
        self, tracks: str | list[str], kind: Literal["tracks"]
    ) -> ArrayDataset[MaybeSEQ, NDArray[np.float32]]: ...
    @overload
    def with_tracks(
        self, tracks: str | list[str], kind: Literal["intervals"]
    ) -> ArrayDataset[MaybeSEQ, RaggedIntervals]: ...
    def with_tracks(
        self,
        tracks: str | list[str] | Literal[False] | None = None,
        kind: Literal["tracks", "intervals"] | None = None,
    ) -> ArrayDataset:
        """Typed pass-through to :meth:`Dataset.with_tracks`."""
        return super().with_tracks(tracks, kind)

    @overload
    def __getitem__(
        self: ArrayDataset[SEQ, None],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> SEQ: ...
    @overload
    def __getitem__(
        self: ArrayDataset[None, NDArray[np.float32]],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> NDArray[np.float32]: ...
    @overload
    def __getitem__(
        self: ArrayDataset[SEQ, NDArray[np.float32]],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> tuple[SEQ, NDArray[np.float32]]: ...
    @overload
    def __getitem__(
        self: ArrayDataset[None, RaggedIntervals],
        idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
    ) -> RaggedIntervals: ...
    @overload
    def __getitem__(
        self: ArrayDataset[SEQ, RaggedIntervals],
        idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
    ) -> tuple[SEQ, RaggedIntervals]: ...
    @overload
    def __getitem__(
        self: ArrayDataset[SEQ, MaybeTRK],
        idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
    ) -> SEQ | tuple[SEQ, NDArray[np.float32]]: ...
    @overload
    def __getitem__(
        self: ArrayDataset[MaybeSEQ, NDArray[np.float32]],
        idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
    ) -> NDArray[np.float32] | tuple[SEQ, NDArray[np.float32]]: ...
    @overload
    def __getitem__(
        self: ArrayDataset[MaybeSEQ, RaggedIntervals],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> RaggedIntervals | tuple[SEQ, RaggedIntervals]: ...
    @overload
    def __getitem__(
        self: ArrayDataset[MaybeSEQ, MaybeTRK],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> SEQ | NDArray[np.float32] | tuple[SEQ, NDArray[np.float32]]: ...
    def __getitem__(
        self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx]
    ) -> SEQ | TRK | tuple[SEQ, TRK]:
        """Typed pass-through to :meth:`Dataset.__getitem__`."""
        return super().__getitem__(idx)  # type: ignore[bad-return] # base Dataset returns broad union; SEQ/TRK typevars narrow at use sites
[docs]
class RaggedDataset(Dataset, Generic[MaybeRSEQ, MaybeRTRK]):
    """Only for type checking purposes, you should never instantiate this class directly.

    Static view of a :class:`Dataset` whose ``output_length`` is ``"ragged"``.
    ``MaybeRSEQ`` is the ragged sequence output type (``None`` when sequences are
    disabled) and ``MaybeRTRK`` the ragged track output type (``None`` when tracks
    are disabled). Every method simply defers to :class:`Dataset` at runtime; the
    overloads only refine static return types, and their order matters to type
    checkers (first match wins).
    """

    # Narrows the inherited attribute's static type; no runtime effect.
    output_length: Literal["ragged"]

    @overload
    def with_len(
        self: RaggedDataset[RaggedSeqs, None],
        output_length: Literal["variable"] | int,
    ) -> ArrayDataset[NDArray[np.bytes_], None]: ...
    @overload
    def with_len(
        self: RaggedDataset[RaggedAnnotatedHaps, None],
        output_length: Literal["variable"] | int,
    ) -> ArrayDataset[AnnotatedHaps, None]: ...
    @overload
    def with_len(
        self: RaggedDataset[None, Ragged[np.float32]],
        output_length: Literal["variable"] | int,
    ) -> ArrayDataset[None, NDArray[np.float32]]: ...
    @overload
    def with_len(
        self: RaggedDataset[None, MaybeRTRK],
        output_length: Literal["variable"] | int,
    ) -> ArrayDataset[None, MaybeTRK]: ...
    @overload
    def with_len(
        self: RaggedDataset[RaggedSeqs, Ragged[np.float32]],
        output_length: Literal["variable"] | int,
    ) -> ArrayDataset[NDArray[np.bytes_], NDArray[np.float32]]: ...
    @overload
    def with_len(
        self: RaggedDataset[RaggedAnnotatedHaps, Ragged[np.float32]],
        output_length: Literal["variable"] | int,
    ) -> ArrayDataset[AnnotatedHaps, NDArray[np.float32]]: ...
    @overload
    def with_len(
        self,
        output_length: Literal["ragged"],
    ) -> RaggedDataset[MaybeRSEQ, MaybeRTRK]: ...
    def with_len(
        self, output_length: Literal["ragged", "variable"] | int
    ) -> RaggedDataset[MaybeRSEQ, MaybeRTRK] | ArrayDataset[MaybeSEQ, MaybeTRK]:
        """Typed pass-through to :meth:`Dataset.with_len`.

        Staying ``"ragged"`` keeps this type; switching to ``"variable"`` or a fixed
        length yields an :class:`ArrayDataset`.
        """
        return super().with_len(output_length)

    @overload
    def with_seqs(self, kind: None) -> RaggedDataset[None, MaybeRTRK]: ...
    @overload
    def with_seqs(
        self, kind: Literal["reference", "haplotypes"]
    ) -> RaggedDataset[RaggedSeqs, MaybeRTRK]: ...
    @overload
    def with_seqs(
        self, kind: Literal["annotated"]
    ) -> RaggedDataset[RaggedAnnotatedHaps, MaybeRTRK]: ...
    @overload
    def with_seqs(
        self, kind: Literal["variants"]
    ) -> RaggedDataset[RaggedVariants, MaybeRTRK]: ...
    def with_seqs(
        self, kind: Literal["reference", "haplotypes", "annotated", "variants"] | None
    ) -> RaggedDataset:
        """Typed pass-through to :meth:`Dataset.with_seqs`."""
        return super().with_seqs(kind)

    @overload
    def with_tracks(self, tracks: None = None, kind: None = None) -> Self: ...
    @overload
    def with_tracks(
        self, *, tracks: None = None, kind: Literal["tracks"]
    ) -> RaggedDataset[MaybeRSEQ, RaggedTracks]: ...
    @overload
    def with_tracks(
        self, *, tracks: None = None, kind: Literal["intervals"]
    ) -> RaggedDataset[MaybeRSEQ, RaggedIntervals]: ...
    @overload
    def with_tracks(
        self,
        tracks: Literal[False],
        kind: Literal["tracks", "intervals"] | None = None,
    ) -> RaggedDataset[MaybeRSEQ, None]: ...
    @overload
    def with_tracks(
        self, tracks: str | list[str], kind: None = None
    ) -> RaggedDataset[MaybeRSEQ, MaybeRTRK]: ...
    @overload
    def with_tracks(
        self, tracks: str | list[str], kind: Literal["tracks"]
    ) -> RaggedDataset[MaybeRSEQ, RaggedTracks]: ...
    @overload
    def with_tracks(
        self, tracks: str | list[str], kind: Literal["intervals"]
    ) -> RaggedDataset[MaybeRSEQ, RaggedIntervals]: ...
    def with_tracks(
        self,
        tracks: str | list[str] | Literal[False] | None = None,
        kind: Literal["tracks", "intervals"] | None = None,
    ) -> RaggedDataset:
        """Typed pass-through to :meth:`Dataset.with_tracks`."""
        return super().with_tracks(tracks, kind)

    # With both sequences and tracks disabled there is nothing to return.
    @overload
    def __getitem__(
        self: RaggedDataset[None, None],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> NoReturn: ...
    @overload
    def __getitem__(
        self: RaggedDataset[RSEQ, None],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> RSEQ: ...
    @overload
    def __getitem__(
        self: RaggedDataset[None, Ragged[np.float32]],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> Ragged[np.float32]: ...
    @overload
    def __getitem__(
        self: RaggedDataset[RSEQ, Ragged[np.float32]],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> tuple[RSEQ, Ragged[np.float32]]: ...
    @overload
    def __getitem__(
        self: RaggedDataset[None, RaggedIntervals],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> RaggedIntervals: ...
    @overload
    def __getitem__(
        self: RaggedDataset[RSEQ, RaggedIntervals],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> tuple[RSEQ, RaggedIntervals]: ...
    @overload
    def __getitem__(
        self: RaggedDataset[RSEQ, MaybeRTRK],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> RSEQ | tuple[RSEQ, Ragged[np.float32]]: ...
    @overload
    def __getitem__(
        self: RaggedDataset[MaybeRSEQ, Ragged[np.float32]],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> Ragged[np.float32] | tuple[RSEQ, Ragged[np.float32]]: ...
    @overload
    def __getitem__(
        self: RaggedDataset[MaybeRSEQ, RaggedIntervals],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> RaggedIntervals | tuple[RSEQ, RaggedIntervals]: ...
    @overload
    def __getitem__(
        self: RaggedDataset[MaybeRSEQ, MaybeRTRK],
        idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx],
    ) -> RSEQ | Ragged[np.float32] | tuple[RSEQ, Ragged[np.float32]]: ...
    def __getitem__(
        self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx]
    ) -> RSEQ | RTRK | tuple[RSEQ, RTRK]:
        """Typed pass-through to :meth:`Dataset.__getitem__`."""
        return super().__getitem__(idx)  # type: ignore[bad-return] # base Dataset returns broad union; RSEQ/RTRK typevars narrow at use sites