from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, cast
import awkward as ak
import numba as nb
import numpy as np
import seqpro as sp
from awkward.contents import (
Content,
ListArray,
ListOffsetArray,
NumpyArray,
RegularArray,
)
from genoray._types import DOSAGE_TYPE, POS_TYPE, V_IDX_TYPE
from numpy.typing import NDArray
from seqpro.rag import OFFSET_TYPE, Ragged, is_rag_dtype, lengths_to_offsets
from typing_extensions import Self
from .._ragged import reverse_complement_masked
from .._torch import TORCH_AVAILABLE, requires_torch
if TORCH_AVAILABLE or TYPE_CHECKING:
import torch
from torch.nested import nested_tensor_from_jagged as nt_jag
from torch.nested._internal.nested_tensor import NestedTensor
class RaggedVariant(ak.Record):
    """Placeholder :class:`ak.Record` subclass that adds no behaviour of its own.

    NOTE(review): presumably intended as the scalar (single-record) counterpart
    of :class:`RaggedVariants`, but nothing in this file registers it in
    :data:`ak.behavior` -- confirm whether that registration is missing.
    """

    pass
[docs]
class RaggedVariants(ak.Array):
    """An awkward record array with shape :code:`(batch, ploidy, ~variants, [~length])`.

    Guaranteed to at least have the field :code:`"alt"` and :code:`"start"` and one of :code:`"ref"` or :code:`"ilen"`.
    """

    def __init__(
        self,
        alt: ak.Array,
        start: Ragged[POS_TYPE],
        ref: ak.Array | None = None,
        ilen: Ragged[np.int32] | None = None,
        dosage: Ragged[DOSAGE_TYPE] | None = None,
        **kwargs: Ragged[np.number],
    ):
        """Zip the given per-field arrays into one record array.

        Parameters
        ----------
        alt
            Alternative alleles.
        start
            0-based start positions.
        ref
            Reference alleles. At least one of ``ref`` and ``ilen`` is required.
        ilen
            Indel lengths. At least one of ``ref`` and ``ilen`` is required.
        dosage
            Optional dosages, stored under the field name ``"dosage"``.
        **kwargs
            Any additional fields, stored under their keyword names.

        Raises
        ------
        ValueError
            If both ``ref`` and ``ilen`` are ``None``.
        """
        if ref is None and ilen is None:
            raise ValueError("Must provide one of refs or ilens.")

        to_zip: dict[str, Any] = {"alt": alt, "start": start}
        if ref is not None:
            to_zip["ref"] = ref
        if ilen is not None:
            to_zip["ilen"] = ilen
        if dosage is not None:
            to_zip["dosage"] = dosage

        # depth_limit=1: fields are only zipped at the outermost axis, so fields
        # of different depth (e.g. alt vs. start) can live in the same record.
        arr = ak.zip(
            to_zip | kwargs, 1, parameters={"__record__": RaggedVariants.__name__}
        )
        super().__init__(arr)

    @classmethod
    def from_ak(cls, arr: ak.Array) -> RaggedVariants:
        """Create a RaggedVariants object from an awkward array.

        Parameters
        ----------
        arr
            The awkward array to create a RaggedVariants object from.

        Raises
        ------
        ValueError
            If ``arr`` lacks ``"alt"`` or ``"start"``, or has neither ``"ref"``
            nor ``"ilen"``.
        """
        fields = set(arr.fields)
        if missing := {"alt", "start"} - fields:
            raise ValueError(f"Missing required fields: {missing}")
        if {"ref", "ilen"}.isdisjoint(fields):
            raise ValueError("Must have one of ref or ilen.")

        def find_and_convert_to_ragged(content: Content, depth_context: dict, **kwargs):
            # Count how many variable-length list nodes we have descended through.
            is_list = isinstance(content, (ListArray, ListOffsetArray))
            if is_list:
                depth_context["n_varlen"] += 1

            if (
                # is a varlen leaf
                is_list
                and isinstance(content.content, NumpyArray)
                # is the only varlen leaf in this branch
                and depth_context["n_varlen"] == 1
                # has no parameters that might conflict with Ragged
                and len(content.parameters) == 0
            ):
                return ak.with_parameter(content, "__list__", "Ragged", highlevel=False)
            # Implicitly return None for every other node so that ak.transform
            # keeps descending.

        arr = ak.transform(  # type: ignore[bad-assignment] # ak.transform stub returns Array|tuple|None; we know it's Array here
            find_and_convert_to_ragged, arr, depth_context={"n_varlen": 0}
        )
        return ak.with_parameter(arr, "__record__", RaggedVariants.__name__)

    @property
    def alt(self) -> ak.Array:
        """Alternative alleles."""
        return cast(ak.Array, super().__getitem__("alt"))

    @property
    def start(self) -> Ragged[POS_TYPE]:
        """0-based start positions."""
        return cast(Ragged[POS_TYPE], super().__getitem__("start"))

    @property
    def ilen(self) -> Ragged[np.int32]:
        """Indel lengths. Infallible.

        When no ``"ilen"`` field is stored, it is computed as
        ``len(alt) - len(ref)``.
        """
        if "ilen" not in self.fields:
            ilen = ak.str.length(self.alt) - ak.str.length(self.ref)  # type: ignore[missing-attribute] # ak.str submodule isn't exposed in awkward's top-level type stubs
            return Ragged(ilen)
        return cast(Ragged[np.int32], super().__getitem__("ilen"))

    @property
    def shape(self) -> tuple[int | None, ...]:
        """Shape of the ``start`` field."""
        return self.start.shape

    @property
    def end(self) -> Ragged[POS_TYPE]:
        """0-based, exclusive end positions."""
        # Membership test on `fields` (rather than hasattr) matches `ilen`.
        if "ref" in self.fields:
            ref = cast(Ragged[np.bytes_], self.ref)
            return self.start + ak.num(ref, -1)

        # Without ref alleles: only negative ilens (deletions) extend the
        # variant beyond a single reference position.
        return self.start - np.clip(self.ilen, None, 0) + 1

    def reshape(self, shape: tuple[int | None, ...]) -> Self:
        """Reshape leading, regular axes. Assumes no trailing regular axes.

        Parameters
        ----------
        shape
            New shape. ``None`` entries mark ragged axes.
        """
        reshaped = {}
        for field in self.fields:
            arr = cast(Ragged | ak.Array, self[field])
            if isinstance(arr, Ragged):
                arr = arr.reshape(shape)
            else:
                # strip regular axes
                node = arr.layout
                while isinstance(node, RegularArray):
                    node = node.content
                # create new regular axes, innermost first; the outermost length
                # (shape[0]) is implied by the data and needs no wrapper
                for len_ in reversed(shape[1:]):
                    if len_ is None:
                        continue
                    node = RegularArray(node, len_)
                arr = ak.Array(node)
            reshaped[field] = arr
        return type(self)(**reshaped)

    def squeeze(self, axis: int | None = None, **kwargs) -> Self:
        """Squeeze first axis.

        ``axis`` and ``**kwargs`` are accepted but ignored: this always returns
        the first element along the leading axis, whatever that axis's length.
        """
        return self[0]

    def infer_germline_ccfs_(
        self, ccf_field: str = "dosages", max_ccf: float = 1.0
    ) -> Self:
        """Infer germline CCFs in-place.

        Germline variants are identified by having missing CCFs i.e. they have a variant
        index but missing CCFs. Missing CCFs are inferred to be :code:`max_ccf` - sum(overlapping CCFs).

        Parameters
        ----------
        ccf_field
            Name of the field holding the CCFs.
            NOTE(review): the constructor stores its ``dosage`` argument under
            ``"dosage"``, so the default ``"dosages"`` only matches arrays that
            carry an extra field of that name -- confirm the intended default.
        max_ccf
            Maximum CCF value.

        Raises
        ------
        ValueError
            If ``ccf_field`` is not a field of this array, or is not a Ragged
            array of ``DOSAGE_TYPE``.
        """
        # `fields` membership instead of hasattr: hasattr is also true for
        # ordinary attributes such as "shape", which are not CCF fields.
        if ccf_field not in self.fields:
            raise ValueError(f"Cannot infer germline CCFs without {ccf_field}.")

        ccfs = self[ccf_field]
        if not isinstance(ccfs, Ragged) or not is_rag_dtype(ccfs, DOSAGE_TYPE):
            raise ValueError(f"{ccf_field} must be a Ragged array of {DOSAGE_TYPE}.")

        _infer_germline_ccfs(
            ccfs.data,
            self.start.offsets,
            self.start.data,
            self.ilen.data,
            max_ccf=max_ccf,
        )
        return self

    def to_packed(self) -> Self:
        """Pack all fields into contiguous, zero-based arrays.

        Replaces the previous :func:`ak.to_packed` call with field-wise packing:
        seqpro :meth:`~seqpro.rag.Ragged.to_packed` for numeric :class:`~seqpro.rag.Ragged`
        fields, and an allele-level seqpro pack + group-offset rebase +
        :func:`~._haps._build_allele_layout` rebuild for the doubly-nested ``alt``/``ref``
        fields.
        """
        # local import to avoid circular dependency (_haps imports RaggedVariants)
        from ._haps import _alt_layout_parts, _build_allele_layout

        packed: dict = {}
        for field in self.fields:
            arr = self[field]
            if field not in ("alt", "ref"):
                packed[field] = (
                    arr.to_packed()
                    if isinstance(arr, Ragged)
                    else Ragged(arr).to_packed()
                )
                continue

            leaf, allele_off, group_off, ploidy = _alt_layout_parts(arr)
            # _alt_layout_parts returns the FULL (un-sliced) leaf and allele_off even
            # for a sliced view -- only group_off carries the slice's extent. Use
            # group_off[0] / group_off[-1] to find the first allele of this view and
            # one-past its last allele, then slice and zero-base both allele_off and
            # leaf so that _build_allele_layout sees a clean, contiguous layout.
            g0, g1 = int(group_off[0]), int(group_off[-1])
            rebased_group = np.asarray(group_off, np.int64) - g0
            # byte range covered by the alleles in this view
            a0, a1 = int(allele_off[g0]), int(allele_off[g1])
            # g1 + 1: an offsets array has one more entry than there are alleles.
            # Bounding the upper end drops alleles that lie after the view.
            sliced_allele_off = np.asarray(allele_off[g0 : g1 + 1], np.int64) - a0
            sliced_leaf = leaf[a0:a1]
            # pack the allele (byte) level: contiguates bytes
            allele_lvl = Ragged.from_offsets(
                sliced_leaf.view("S1"),
                (sliced_allele_off.size - 1, None),
                sliced_allele_off,
            ).to_packed()
            packed[field] = _build_allele_layout(
                np.asarray(allele_lvl.data).view(np.uint8),
                np.asarray(allele_lvl.offsets),
                rebased_group,
                ploidy,
            )
        return type(self)(**packed)

    def rc_(self, to_rc: NDArray[np.bool_] | None = None) -> Self:
        """Reverse complement the alleles. This is an in-place operation.

        Parameters
        ----------
        to_rc
            A boolean mask with one entry per batch element. Where :code:`True`, the
            alleles (``alt`` and, if present, ``ref``) of that batch element are
            reverse complemented. If :code:`None`, all alleles are reverse complemented.

        Returns
        -------
        The RaggedVariants object with the alleles reverse complemented.
        """
        if to_rc is None:
            to_rc = np.ones(self.shape[0], np.bool_)  # type: ignore[no-matching-overload] # ak.Array shape may contain None; np.ones overload expects int|Sequence[int]
        elif not to_rc.any():
            return self

        # local import: _haps imports RaggedVariants (avoid circular import)
        from ._haps import _alt_layout_parts

        for field in ("alt", "ref"):
            if field not in self.fields:
                continue
            arr = self[field]
            leaf, allele_off, group_off, ploidy = _alt_layout_parts(arr)
            # per-allele mask: to_rc is per-batch; broadcast across ploidy then variants
            per_bp = np.repeat(np.ascontiguousarray(to_rc, np.bool_), ploidy)
            per_allele = np.repeat(per_bp, np.diff(group_off))

            # leaf and allele_off describe the FULL buffer even when `arr` is a
            # sliced view (see to_packed); only group_off carries the slice. Restrict
            # both to this view's alleles so that they line up with per_allele.
            g0, g1 = int(group_off[0]), int(group_off[-1])
            a0, a1 = int(allele_off[g0]), int(allele_off[g1])
            view_off = np.asarray(allele_off[g0 : g1 + 1]) - a0
            view = Ragged.from_offsets(
                # basic slicing yields a view, so the mutation below still reaches
                # the buffer shared with `arr`
                leaf[a0:a1].view("S1"),
                (per_allele.size, None),
                view_off,
            )
            # in-place: mutates `leaf`, which shares memory with `arr`'s buffer
            reverse_complement_masked(view, per_allele)
        return self

    @requires_torch
    def to_nested_tensor_batch(
        self,
        device: str | torch.device = "cpu",
        tokenizer: Literal["seqpro"]
        | Callable[[NDArray[np.bytes_]], NDArray[np.integer]]
        | None = None,
    ) -> dict[str, NestedTensor | int]:
        """Convert a RaggedVariants object to a dictionary of nested tensors. Will flatten across
        the ploidy dimension for numeric fields such that their shapes are (batch * ploidy, ~variants).
        For the alleles, will flatten across both the ploidy and variant dimensions such that the
        shape is (batch * ploidy * ~variants, ~allele_len).

        .. important::
            This function assumes all variant data is packed (see :meth:`to_packed`).

        Parameters
        ----------
        device
            The device to move the tensors to.
        tokenizer
            The tokenizer to use for the alleles.

            - If :code:`"seqpro"`, will use :func:`seqpro.tokenize` to convert :code:`ACGTN -> 0 1 2 3 4`.
            - If :code:`None`, will use the integer ASCII value of each character i.e. :code:`ACGTN -> 65 67 71 84 78`.
            - Otherwise, will use the provided callable to convert the alleles to an array of integers.

        Returns
        -------
        Dictionary of `nested tensors <https://docs.pytorch.org/docs/stable/nested.html>`_ and integers.
        Keys are the field names of this array plus a few maxima:

        - one entry per :class:`~seqpro.rag.Ragged` field (e.g. :code:`"start"`, :code:`"ilen"`,
          :code:`"dosage"`) with shape :code:`(batch * ploidy, ~variants)`
        - :code:`"alt"` (and :code:`"ref"` if present) with shape :code:`(batch * ploidy * ~variants, ~allele_len)`
        - :code:`"max_n_vars"`: int, maximum number of variants
        - :code:`"max_alt_len"`: int, maximum length of an alternative allele
        - :code:`"max_ref_len"`: int, maximum length of a reference allele (only if :code:`"ref"` is present)

        Fields that are neither Ragged nor :code:`"alt"`/:code:`"ref"` are skipped.
        """
        batch: dict[str, Any] = {}
        variant_offsets = None
        for field in self.fields:
            arr = cast(Ragged | ak.Array, self[field])
            if isinstance(arr, Ragged):
                data = torch.from_numpy(arr.data).to(device)
                if variant_offsets is None:
                    # Every Ragged field reuses the offsets of the first one.
                    variant_offsets = torch.from_numpy(
                        arr.offsets.astype(np.int32)
                    ).to(device)
                    n_vars = np.diff(arr.offsets)
                    # .max() raises on an empty array (zero-length batch)
                    batch["max_n_vars"] = int(n_vars.max()) if n_vars.size else 0
                batch[field] = nt_jag(data, variant_offsets)
            elif field in {"ref", "alt"}:
                data, offsets, max_alen = _alleles_to_nested_tensor(arr, tokenizer)
                # keep values and offsets on the same device, as for Ragged fields
                data = data.to(device)
                offsets = offsets.to(device)
                batch[f"max_{field}_len"] = max_alen
                batch[field] = nt_jag(data, offsets)
        return batch

    def pad(
        self,
        allele: str | bytes = b"N",
        ilen: int = 0,
        start: int = -1,
        dosage: float = 0.0,
        **pad_values: Any,
    ) -> Self:
        """Append a pad variant so that every group is guaranteed to have at least 1 variant. If the group has variants,
        no variant is appended.

        Parameters
        ----------
        allele
            The allele to use for ALTs and REFs
        ilen
            The indel length to use for the pad variant
        start
            The start position to use for the pad variant
        dosage
            The dosage to use for the pad variant
        **pad_values
            Additional values to use for each field. Raises a ValueError if any field does not have a pad value.
            Entries named ``alt``, ``ref``, ``ilen``, ``start`` or ``dosage`` are overwritten
            by the explicit arguments above.

        Returns
        -------
        The RaggedVariants object with the pad variant appended to each group that has no variants.

        Raises
        ------
        ValueError
            If any field of this array has no pad value.
        """
        if isinstance(allele, str):
            allele = allele.encode()

        pad_values |= {
            "alt": allele,
            "ref": allele,
            "ilen": ilen,
            "start": start,
            "dosage": dosage,
        }

        if missing_fields := set(self.fields) - set(pad_values.keys()):
            raise ValueError(f"Missing pad values for fields: {missing_fields}")

        # NOTE(review): axis=-1 is applied to the record as a whole; confirm that
        # for the doubly-nested alt/ref fields this pads the variant axis rather
        # than the allele-length axis.
        arr = ak.pad_none(self, 1, -1)
        for field in self.fields:
            value = pad_values[field]
            arr = ak.with_field(arr, ak.fill_none(arr[field], value, -1), field)
        return arr
def _alleles_to_nested_tensor(
    alleles: ak.Array,
    tokenizer: Literal["seqpro"]
    | Callable[[NDArray[np.bytes_]], NDArray[np.integer]]
    | None = None,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Flatten an allele array into jagged ``(values, offsets)`` tensors.

    Parameters
    ----------
    alleles
        Allele array whose layout ends in a list node over a flat character buffer.
    tokenizer
        - If :code:`"seqpro"`, tokenize with :func:`seqpro.tokenize` using the
          first four letters of :code:`sp.DNA.alphabet` and 4 for anything else.
        - If :code:`None`, reinterpret the characters as :code:`uint8`.
        - Otherwise, a callable applied to the flat character buffer.

    Returns
    -------
    values
        Tensor of the (tokenized) flat character buffer.
    offsets
        :code:`int32` tensor of offsets delimiting each allele in ``values``.
    max_alen
        Length of the longest allele, or 0 if there are no alleles.

    Raises
    ------
    ValueError
        If the layout contains no variable-length (list) axis.
    """
    # Walk down to the flat character buffer, remembering the innermost list
    # node on the way: it is the one that delimits individual alleles.
    node = cast(Content, alleles.layout)
    innermost_list = None
    while not isinstance(node, NumpyArray):
        if isinstance(node, (ListArray, ListOffsetArray)):
            innermost_list = node
        node = cast(Content, node.content)

    if innermost_list is None:
        # Previously this surfaced later as an UnboundLocalError.
        raise ValueError("Alleles must have at least one variable-length axis.")

    chars = cast(NDArray[np.bytes_], node.data)
    if tokenizer == "seqpro":
        tokens = sp.tokenize(chars, dict(zip(sp.DNA.alphabet, range(4))), 4)
    elif tokenizer is not None:
        tokens = tokenizer(chars)
    else:
        tokens = chars.view(np.uint8)
    values = torch.from_numpy(tokens)

    # (N ~V ~L) -> (N ~V) -> (N*~V)
    if isinstance(innermost_list, ListArray):
        lengths = cast(
            NDArray, innermost_list.stops.data - innermost_list.starts.data
        )
        np_offsets = lengths_to_offsets(lengths, np.int32)
    else:
        np_offsets = innermost_list.offsets.data.astype(np.int32)  # type: ignore[missing-attribute] # awkward Index.data typed as ArrayLike; numpy ndarray method missing on stub
        lengths = np.diff(np_offsets)

    # .max() raises on an empty array
    max_alen = lengths.max().item() if len(lengths) else 0

    return values, torch.from_numpy(np_offsets), max_alen
# Register RaggedVariants with awkward's global behavior table under the record
# name that __init__ and from_ak attach as the "__record__" parameter.
ak.behavior["*", RaggedVariants.__name__] = RaggedVariants
@nb.njit(parallel=True, nogil=True, cache=True)
def _infer_germline_ccfs(
    ccfs: NDArray[DOSAGE_TYPE],
    v_offsets: NDArray[OFFSET_TYPE],
    v_starts: NDArray[POS_TYPE],
    ilens: NDArray[np.int32],
    max_ccf: float = 1.0,
):
    """Infer germline CCFs from the variant indices and variant starts. Updates CCFs in-place.

    Germline variants are identified by having missing (NaN) CCFs.
    i.e. they have a variant index but missing CCFs. Germline CCFs are inferred
    to be :code:`max_ccf` - sum(overlapping somatic CCFs). Of a set of overlapping
    germline variants only the first keeps an inferred CCF; the others are set to 0.

    Parameters
    ----------
    ccfs
        Shape: (alts) raveled view of ragged cancer cell fractions.
    v_offsets
        Shape: (groups + 1) offsets into :code:`ccfs`.
    v_starts
        Shape: (alts) 0-based start positions. The merge below consumes them in
        order, so they are assumed to be sorted within each group -- TODO confirm.
    ilens
        Shape: (alts) indel lengths.
    max_ccf
        Maximum cancer cell fraction.
    """
    n_sp = len(v_offsets) - 1
    for o_idx in nb.prange(n_sp):
        o_s, o_e = v_offsets[o_idx], v_offsets[o_idx + 1]
        n_variants: int = o_e - o_s
        if n_variants == 0:
            continue

        ccf = ccfs[o_s:o_e]
        # Snapshot which variants are germline BEFORE any CCF is written. Reading
        # NaN-ness from `ccf` inside the sweep is unreliable because germline
        # entries are overwritten while the sweep is still running.
        is_germline = np.isnan(ccf)
        if not is_germline.any():
            continue

        v_start = v_starts[o_s:o_e]
        ilen = ilens[o_s:o_e]
        v_end = (
            v_start - np.minimum(0, ilen) + 1
        )  # +1 for atomic variants, +shared_len for non-atomic
        v_end_sorter = np.argsort(v_end)
        v_end = v_end[v_end_sorter]

        # sorted merge by starts then ends
        # ends are marked by being negative
        starts_ends = np.empty(n_variants * 2, POS_TYPE)
        se_local_idx = np.empty(n_variants * 2, V_IDX_TYPE)
        start_idx = 0
        end_idx = 0
        for i in range(n_variants * 2):
            end = v_end[end_idx]
            if start_idx < n_variants and v_start[start_idx] < end:
                starts_ends[i] = v_start[start_idx]
                se_local_idx[i] = start_idx
                start_idx += 1
            else:
                starts_ends[i] = -end
                se_local_idx[i] = v_end_sorter[end_idx]
                end_idx += 1

        running_ccf = DOSAGE_TYPE(0)
        # use -1 to mark that we are not currently within a germline variant
        g_idx = V_IDX_TYPE(-1)
        # set g_end to maximum possible value
        g_end = np.iinfo(POS_TYPE).max
        for i in range(n_variants * 2):
            pos: POS_TYPE = starts_ends[i]
            local_idx: V_IDX_TYPE = se_local_idx[i]
            is_germ = is_germline[local_idx]
            # ends are stored negated and are >= 1 for non-negative starts, so a
            # non-negative pos is a start event (0 counts as a start)
            is_start = pos >= 0

            # end of variant overlaps with end of current germline variant
            #! without this we will decrement the running CCF before setting the germline CCF
            # this is because tied ends are sorted by start, but the ends are 0-based exclusive
            # so we need to set the germline CCF before we start any decrementing
            if -pos >= g_end:
                ccf[g_idx] = max_ccf - running_ccf
                g_idx = -1
                g_end = np.iinfo(POS_TYPE).max

            if is_germ:
                # Germline variants never contribute to running_ccf: their own
                # CCF is what is being inferred. Their end events therefore need
                # no handling here (adding the still-NaN value would poison
                # running_ccf for every later germline variant in the group).
                if is_start:
                    # for now: check for overlapping variants and set to zero
                    # to correspond to behavior of haplotype reconstruction
                    # which only keeps first variant out of an overlapping set
                    # TODO: handle overlapping germline vars without excessive memory
                    # iterate over all g_ends, matching running ccf for each?
                    if g_idx != -1 and np.isnan(ccf[g_idx]):
                        ccf[local_idx] = 0
                    else:
                        g_idx = local_idx
                        # have to recompute the end because we sorted them above so the local idx points
                        # to the wrong place
                        g_end = pos - min(0, ilen[local_idx]) + 1
            else:
                # somatic variant: add its CCF on start, remove it on end
                running_ccf += (2 * is_start - 1) * ccf[local_idx]

        # any germline variant that was never closed falls back to max_ccf
        np.nan_to_num(ccf, copy=False, nan=max_ccf)