from __future__ import annotations
from collections.abc import Sequence
from pathlib import Path
from typing import Generic, overload
import numba as nb
import numpy as np
import pandera.polars as pa
import polars as pl
import seqpro as sp
from genoray import VCF
from genoray._utils import ContigNormalizer
from numpy.typing import NDArray
from .._dataset._impl import SEQ, ArrayDataset, MaybeTRK
from .._dataset._indexing import DatasetIndexer
from .._types import AnnotatedHaps, Idx
from ._records import RaggedAlleles
[docs]
def sites_vcf_to_table(
    vcf: str | Path | VCF,
    attributes: list[str] | None = None,
    info_fields: list[str] | None = None,
) -> pl.DataFrame:
    """Extract a table of variant site info from a VCF. All sites must be bi-allelic.

    Parameters
    ----------
    vcf
        Path to a VCF or a :class:`genoray.VCF` instance. Note that :class:`genoray.VCF` can accept a filter function.
    attributes
        A list of attributes to include in the output table. Note that "CHROM", "POS", "REF", and "ALT" are always included
        even if not in this list. Duplicates are ignored.
    info_fields
        A list of INFO fields to include in the output table.

    Returns
    -------
    pl.DataFrame
        The table produced by :meth:`genoray.VCF.get_record_info` for the requested fields, with the
        "ALT" column converted from a list of alleles to its single allele string.

    Raises
    ------
    ValueError
        If any site does not have exactly one ALT allele.
    """
    if not isinstance(vcf, VCF):
        vcf = VCF(vcf)

    min_attrs = ["CHROM", "POS", "REF", "ALT"]
    if attributes is None:
        attrs = min_attrs
    else:
        # dict.fromkeys de-duplicates while preserving the caller's order, so a repeated
        # attribute is not requested twice.
        attrs = min_attrs + [
            attr for attr in dict.fromkeys(attributes) if attr not in min_attrs
        ]

    df = vcf.get_record_info(fields=attrs, info=info_fields)

    # Require exactly one ALT allele per site. More than one is multi-allelic. An empty ALT
    # list (presumably how a missing "." ALT is represented -- confirm against genoray) would
    # otherwise silently become a null ALT after `.list.first()` below.
    if df.select((pl.col("ALT").list.len() != 1).any()).item():
        raise ValueError("All sites must be bi-allelic.")

    # Collapse the one-element ALT list into a plain allele string.
    return df.with_columns(pl.col("ALT").list.first())
[docs]
class SitesSchema(pa.DataFrameModel):
    """Schema to validate a table of variant sites.

    Expects one row per site with a single ALT allele string, e.g. the output of
    ``sites_vcf_to_table``. Columns not declared here are not required by this schema.
    """

    # Contig (chromosome) name of the site.
    CHROM: str
    # Site position. It is converted to a 0-based start by subtracting 1 when building BED-like
    # coordinates, i.e. it is treated as 1-based (VCF convention).
    POS: int
    # Reference allele; its byte length determines the site's end coordinate.
    REF: str
    # Single alternate allele as a string (not a list of alleles).
    ALT: str
def _sites_table_to_bedlike(sites: pl.DataFrame) -> pl.DataFrame:
    """Validate a sites table and convert it to BED-like coordinates.

    The table is first checked against :class:`SitesSchema`. The 1-based ``POS`` column is
    then replaced by a 0-based half-open interval ``[chromStart, chromEnd)`` spanning the
    REF allele, and ``CHROM`` is renamed to ``chrom``. All other columns are carried along.
    """
    validated = SitesSchema.validate(sites)

    one_based_pos = pl.col("POS")
    ref_span = pl.col("REF").str.len_bytes()

    with_intervals = validated.with_columns(
        chromStart=one_based_pos - 1,
        chromEnd=one_based_pos + ref_span - 1,
    )
    without_pos = with_intervals.drop("POS")
    return without_pos.rename({"CHROM": "chrom"})
[docs]
class DatasetWithSites(Generic[MaybeTRK]):
    """A fixed-layout :class:`ArrayDataset` paired with a table of site-only variants.

    Each row of this object is one (dataset region, site) pair where the site overlaps the
    region. Indexing returns the wildtype haplotypes, the haplotypes with the site's ALT allele
    applied, per-haplotype flags, and (when the dataset has tracks) the tracks.
    """

    dataset: ArrayDataset[AnnotatedHaps, MaybeTRK]
    """Dataset of haplotypes and potentially tracks."""
    sites: pl.DataFrame
    """Table of variant site information."""
    rows: pl.DataFrame
    """Rows of this object, where each row is a combination of a dataset region and a site."""
    _row_map: NDArray[np.uint32]
    """Map from row index to dataset row index and site row index."""
    _idxer: DatasetIndexer

    @property
    def n_rows(self) -> int:
        """Number of (region, site) rows."""
        return self._idxer.n_regions

    @property
    def n_samples(self) -> int:
        """Number of samples."""
        return self._idxer.n_samples

    @property
    def shape(self) -> tuple[int, int]:
        """Shape as ``(n_rows, n_samples)``."""
        return self._idxer.shape

    def __len__(self) -> int:
        """Total number of (row, sample) combinations."""
        return self.n_rows * self.n_samples

    def __init__(
        self,
        dataset: ArrayDataset[SEQ, MaybeTRK],
        sites: pl.DataFrame,
        max_variants_per_region: int = 1,
    ):
        """Dataset with variant sites, used to apply site-only variants e.g. from ClinVar to a Dataset of haplotypes.

        Currently only supports bi-allelic SNPs. Takes the intersection of the dataset regions and the sites, and
        applies the site-only variants to the Dataset's haplotypes.

        Accessed just like a Dataset, but where the rows are combinations of dataset regions and sites. Will return
        two :class:`AnnotatedHaps` and per-haplotype flags indicating whether the variant was applied, whether its
        position was deleted from (absent in) the haplotype, or whether the haplotype already carried the ALT allele.
        The flags are 0 for applied, 1 for deleted, and 2 for existed, and have shape ``(..., ploidy)``. If the dataset
        has tracks, they will be returned as well. The first :class:`AnnotatedHaps` is the wildtype haplotypes and the
        second is the mutated haplotypes. The mutant haplotypes have their variant indices updated: locations where a
        site-only variant was applied have a variant index of -2. Reference coordinates are returned unchanged, since
        only SNPs are supported.

        Parameters
        ----------
        dataset
            Dataset of haplotypes and potentially tracks.
        sites
            Table of variant site information.
        max_variants_per_region
            Maximum number of variants per region. Currently only 1 is supported.

        Raises
        ------
        NotImplementedError
            If ``max_variants_per_region > 1``.
        ValueError
            If ``dataset`` is not an :class:`ArrayDataset`, or any site is not a SNP.
        RuntimeError
            If no site overlaps any dataset region.

        Examples
        --------
        .. code-block:: python

            import genvarloader as gvl

            sites = gvl.sites_vcf_to_table("path/to/variants.vcf")
            ds = gvl.Dataset.open("path/to/dataset.gvl", "path/to/reference.fasta")
            ds_sites = gvl.DatasetWithSites(ds, sites)
            wt_haps, mut_haps, flags = ds_sites[0, 0]
            # flags is a np.uint8 array with one entry per haplotype, shape (ploidy,) here;
            # more leading dimensions when accessing multiple rows/samples.
            ds_sites.dataset = ds_sites.dataset.with_tracks("read-depth")
            wt_haps, mut_haps, flags, tracks = ds_sites[0, 0]
        """
        if max_variants_per_region > 1:
            raise NotImplementedError("max_variants_per_region > 1 not yet supported")

        if not isinstance(dataset, ArrayDataset):
            raise ValueError(
                'Dataset output_length must either be "variable" or a fixed length integer.'
            )

        sites = _sites_table_to_bedlike(sites)

        ref_len = pl.col("REF").str.len_bytes()
        alt_len = pl.col("ALT").str.len_bytes()
        # Both comparisons must be parenthesized. `|` binds tighter than `!=`, so the
        # unparenthesized form evaluated `((len(REF) != 1) | len(ALT)) != 1`, which does not
        # correctly reject sites whose ALT is not a single base.
        if sites.select(((ref_len != 1) | (alt_len != 1)).any()).item():
            raise ValueError(
                "All sites must be SNPs. Consider filtering the VCF as either a preprocessing step or via the sites_vcf_to_table function."
            )

        # Normalize site contig names against the dataset's contigs. Names the normalizer
        # returns None for are kept unchanged.
        c_norm = ContigNormalizer(dataset.contigs)
        chroms: list[str] = sites["chrom"].to_list()
        norm_chroms = c_norm.norm(chroms)
        norm_chroms = [
            c if norm is None else norm for c, norm in zip(chroms, norm_chroms)
        ]
        sites = sites.with_columns(chrom=pl.Series(norm_chroms))

        ds_bed = dataset.regions.with_row_index("region_idx")
        if isinstance(dataset.output_length, int):
            # For fixed-length output, use the output length as the region extent when
            # computing overlaps.
            ds_bed = ds_bed.with_columns(
                chromEnd=pl.col("chromStart") + dataset.output_length
            )

        # One output row per overlapping (region, site) pair.
        ds_pyr = sp.bed.to_pyr(ds_bed)
        sites_pyr = sp.bed.to_pyr(sites.with_row_index("site_idx"))
        rows = pl.from_pandas(ds_pyr.join(sites_pyr, suffix="_site").df)
        if rows.height == 0:
            raise RuntimeError("No overlap between dataset regions and sites.")

        rows = rows.rename(
            {
                "Chromosome": "chrom",
                "Start": "chromStart",
                "End": "chromEnd",
                "Strand": "strand",
                "Start_site": "POS0",
            },
            strict=False,
        ).drop("End_site")

        # Annotated haplotypes are needed to locate the site via reference coordinates.
        # Deterministic, unjittered output keeps rows aligned with the overlap table.
        _dataset = dataset.with_seqs("annotated").with_settings(
            deterministic=True, jitter=0
        )

        self.sites = sites
        self.dataset = _dataset
        self.rows = rows
        self._row_map = rows.select("region_idx", "site_idx").to_numpy()
        self._idxer = DatasetIndexer.from_region_and_sample_idxs(
            np.arange(self.rows.height), np.arange(dataset.n_samples), dataset.samples
        )

    @overload
    def __getitem__(
        self: DatasetWithSites[None],
        idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
    ) -> tuple[AnnotatedHaps, AnnotatedHaps, NDArray[np.uint8]]: ...
    @overload
    def __getitem__(
        self: DatasetWithSites[NDArray[np.float32]],
        idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]],
    ) -> tuple[
        AnnotatedHaps, AnnotatedHaps, NDArray[np.uint8], NDArray[np.float32]
    ]: ...
    def __getitem__(
        self, idx: Idx | tuple[Idx] | tuple[Idx, Idx | str | Sequence[str]]
    ) -> (
        tuple[AnnotatedHaps, AnnotatedHaps, NDArray[np.uint8]]
        | tuple[AnnotatedHaps, AnnotatedHaps, NDArray[np.uint8], NDArray[np.float32]]
    ):
        """Return ``(wt_haps, mut_haps, flags[, tracks])`` for the given (row, sample) index.

        ``flags`` holds one code per haplotype (0 applied, 1 deleted, 2 existed). ``tracks``
        is only returned when the underlying dataset returns tracks.
        """
        idx, squeeze, out_reshape = self._idxer.parse_idx(idx)
        row_idx, s_idx = np.unravel_index(idx, self.shape)
        ds_rows = self._row_map[row_idx, 0]
        out = self.dataset[ds_rows, s_idx]
        if isinstance(out, tuple):
            wt_haps, tracks = out
        else:
            wt_haps = out

        ploidy = wt_haps.shape[-2]
        length = wt_haps.shape[-1]

        sites = self.rows[row_idx]
        starts = sites["POS0"].to_numpy()  # 0-based
        alts = RaggedAlleles.from_polars(sites["ALT"])

        # (b p ~l)
        wt_haps = wt_haps.reshape((-1, ploidy, length))
        # flags: (b p)
        # haps and var_idxs are copied because the kernel writes into them in place.
        # ref_coords is not modified, so mut_haps shares that array with wt_haps.
        mut_haps, v_idxs, ref_coords, flags = apply_site_only_variants(
            haps=wt_haps.haps.view(np.uint8).copy(),  # (b p ~l)
            v_idxs=wt_haps.var_idxs.copy(),  # (b p ~l)
            ref_coords=wt_haps.ref_coords,  # (b p ~l)
            site_starts=starts,
            alt_alleles=alts.data.view(np.uint8),
            alt_offsets=alts.offsets,
        )
        mut_haps = AnnotatedHaps(
            haps=mut_haps.view("S1"), var_idxs=v_idxs, ref_coords=ref_coords
        )

        if squeeze:
            wt_haps = wt_haps.squeeze(0)
            mut_haps = mut_haps.squeeze(0)
            flags = flags.squeeze(0)

        if out_reshape is not None:
            wt_haps = wt_haps.reshape((*out_reshape, ploidy, length))
            mut_haps = mut_haps.reshape((*out_reshape, ploidy, length))
            flags = flags.reshape(*out_reshape, ploidy)

        if isinstance(out, tuple):
            return (
                wt_haps,
                mut_haps,
                flags,
                tracks,  # type: ignore[unbound-name] # tracks is bound when isinstance(out, tuple) branch is taken
            )
        else:
            return wt_haps, mut_haps, flags
# Per-haplotype outcome codes written into the flags array by `apply_site_only_variants`:
#   APPLIED -> the ALT allele was written into the haplotype,
#   DELETED -> the site's reference position is not present in the haplotype,
#   EXISTED -> the haplotype already carried the ALT allele at the site.
APPLIED, DELETED, EXISTED = (np.uint8(code) for code in (0, 1, 2))
# * fixed length, SNPs only
@nb.njit(parallel=True, nogil=True, cache=True)
def apply_site_only_variants(
    haps: NDArray[np.uint8],  # (b p ~l)
    v_idxs: NDArray[np.int32],  # (b p ~l)
    ref_coords: NDArray[np.int32],  # (b p ~l)
    site_starts: NDArray[np.int32],  # (b)
    alt_alleles: NDArray[np.uint8],  # (b ~a)
    alt_offsets: NDArray[np.int64],  # (b+1)
) -> tuple[NDArray[np.uint8], NDArray[np.int32], NDArray[np.int32], NDArray[np.uint8]]:
    """Apply one site-only variant per batch element to each of its haplotypes.

    For batch element ``b``, the ALT allele ``alt_alleles[alt_offsets[b]:alt_offsets[b+1]]``
    is written at the haplotype position whose reference coordinate equals ``site_starts[b]``.

    ``haps`` and ``v_idxs`` are modified IN PLACE (callers should pass copies). Written positions
    get a variant index of -2. ``ref_coords`` is only read; it is assumed sorted ascending along
    the last axis, as required by ``np.searchsorted``.

    Returns
    -------
    haps, v_idxs, ref_coords, flags
        The (mutated) input arrays, plus a ``(batch, ploidy)`` uint8 array of outcome codes:
        APPLIED, DELETED (site position absent from the haplotype), or EXISTED (ALT already
        present).
    """
    batch_size, ploidy, _ = haps.shape
    flags = np.empty((batch_size, ploidy), dtype=np.uint8)
    for b in nb.prange(batch_size):
        # The site and its ALT allele are shared by every haplotype of this batch element.
        pos = site_starts[b]
        alt = alt_alleles[alt_offsets[b] : alt_offsets[b + 1]]
        # Numba serializes a prange nested inside another prange, so a plain range is
        # equivalent here and states that explicitly.
        for p in range(ploidy):
            bp_hap = haps[b, p]
            bp_idx = v_idxs[b, p]
            bp_ref_coord = ref_coords[b, p]

            rel_start = np.searchsorted(bp_ref_coord, pos)
            # searchsorted returns len(bp_ref_coord) when the site lies past the last
            # reference coordinate of this haplotype. Numba does not bounds-check by default,
            # so that index must not be read. Either way the position is absent.
            if rel_start >= len(bp_ref_coord) or bp_ref_coord[rel_start] != pos:
                flags[b, p] = DELETED
                continue

            # For SNPs rel_end == rel_start + 1, which fits because rel_start is in bounds.
            rel_end = rel_start + len(alt)
            if np.all(bp_hap[rel_start:rel_end] == alt):
                flags[b, p] = EXISTED
                continue

            flags[b, p] = APPLIED
            bp_hap[rel_start:rel_end] = alt
            bp_idx[rel_start:rel_end] = -2
    return haps, v_idxs, ref_coords, flags