"""Tabular interval track source for :func:`gvl.write()`.
Mirrors the :class:`BigWigs` reader API surface so that
:func:`genvarloader._dataset._write._write_track` can dispatch to either.
"""
from __future__ import annotations
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import polars as pl
from ._utils import normalize_contig_name
if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray
from ._ragged import RaggedIntervals
#: Canonical column names, in output order. ``Table._normalize_input`` raises if
#: any of them is missing after renaming and selects exactly these columns.
CANONICAL_COLS = ("sample_id", "chrom", "start", "end", "value")
#: ``gvl.Table`` is temporarily disabled: its backend, ``polars-bio``,
#: intermittently segfaults the interpreter during overlap queries (a
#: non-deterministic native-runtime issue observed on CPython 3.12 and 3.13).
#: ``polars-bio`` has been removed from genvarloader's direct dependencies (it
#: remains an indirect dependency of ``genoray``). Re-enable by reverting this
#: commit once the upstream issue is resolved.
#: Upstream: https://github.com/biodatageeks/polars-bio/issues/395
_TABLE_DISABLED_MSG = (
"gvl.Table is temporarily disabled because its polars-bio backend "
"intermittently segfaults during overlap queries (upstream "
"https://github.com/biodatageeks/polars-bio/issues/395). Use gvl.BigWigs "
"for track input in the meantime."
)
[docs]
class Table:
    """Long-form interval track keyed by ``(sample_id, chrom, start, end, value)``.

    .. warning::
        Temporarily disabled. Constructing a :class:`Table` raises
        :class:`NotImplementedError` while the ``polars-bio`` segfault
        (`#395 <https://github.com/biodatageeks/polars-bio/issues/395>`_) is
        unresolved. Use :class:`BigWigs` for track input in the meantime.
    """

    # Track name supplied by the caller.
    name: str
    # Sorted unique values of the ``sample_id`` column.
    samples: list[str]
    # Maps each ``chrom`` value to the largest ``end`` observed on it.
    contigs: Mapping[str, int]

    def __init__(
        self,
        name: str,
        data: pl.DataFrame | Mapping[str, pl.DataFrame],
        column_map: Mapping[str, str] | None = None,
    ) -> None:
        """Build a table from one long-form frame or a per-sample mapping of frames.

        Parameters
        ----------
        name
            Name of the track.
        data
            Either a single frame that already carries a ``sample_id`` column,
            or a mapping of sample ID to a frame without one.
        column_map
            Optional mapping of canonical column name -> actual column name,
            used to rename the input to the canonical schema.

        Raises
        ------
        NotImplementedError
            Always, while the class is disabled.
        """
        raise NotImplementedError(_TABLE_DISABLED_MSG)
        # Everything below is intentionally unreachable: it is the original
        # implementation, kept in place so that re-enabling the class only
        # requires removing the ``raise`` above.
        self.name = name
        df = self._normalize_input(data, column_map)
        df = df.cast(
            {
                "sample_id": pl.Utf8,
                "chrom": pl.Utf8,
                "start": pl.Int64,
                "end": pl.Int64,
                "value": pl.Float32,
            }
        ).sort("chrom", "sample_id", "start")
        self._df = df
        self.samples = sorted(df["sample_id"].unique().to_list())
        self.contigs = {
            row["chrom"]: int(row["max_end"])
            for row in df.group_by("chrom")
            .agg(pl.col("end").max().alias("max_end"))
            .iter_rows(named=True)
        }

    @classmethod
    def from_path(
        cls,
        name: str,
        path: str | Path | Mapping[str, str | Path],
        column_map: Mapping[str, str] | None = None,
    ) -> Table:
        """Construct a table by reading one file, or one file per sample.

        Parameters
        ----------
        name
            Name of the track.
        path
            A single path, or a mapping of sample ID to path. Each file is
            read with :meth:`_read_path`, so the format is chosen by extension.
        column_map
            Forwarded unchanged to the constructor.
        """
        if isinstance(path, Mapping):
            data: dict[str, pl.DataFrame] = {
                sid: cls._read_path(Path(p)) for sid, p in path.items()
            }
            return cls(name, data, column_map)
        return cls(name, cls._read_path(Path(path)), column_map)

    @staticmethod
    def _read_path(p: Path) -> pl.DataFrame:
        """Read ``p`` into a frame, choosing the reader from the file suffix.

        The suffix is compared case-insensitively.

        Raises
        ------
        ValueError
            If the suffix is not one of the supported extensions.
        """
        suf = p.suffix.lower()
        if suf == ".csv":
            return pl.read_csv(p)
        if suf in (".tsv", ".txt"):
            return pl.read_csv(p, separator="\t")
        if suf == ".parquet":
            return pl.read_parquet(p)
        if suf in (".arrow", ".ipc"):
            return pl.read_ipc(p)
        raise ValueError(
            f"Unsupported file extension {suf!r}. "
            "Expected one of .csv, .tsv, .txt, .parquet, .arrow, .ipc."
        )

    @staticmethod
    def _normalize_input(
        data: pl.DataFrame | Mapping[str, pl.DataFrame],
        column_map: Mapping[str, str] | None,
    ) -> pl.DataFrame:
        """Rename, merge and validate the input into the canonical column layout.

        Returns
        -------
        pl.DataFrame
            A frame containing exactly ``CANONICAL_COLS``, in that order.

        Raises
        ------
        ValueError
            If an empty mapping is passed, or if a canonical column is still
            missing after renaming.
        """
        if isinstance(data, pl.DataFrame):
            df = Table._apply_column_map(data, column_map, expect_sample_id=True)
        else:
            # dict[sample_id, df] without sample_id col
            frames: list[pl.DataFrame] = []
            for sid, sub in data.items():
                renamed = Table._apply_column_map(
                    sub, column_map, expect_sample_id=False
                )
                # The mapping key is the sample ID for every row of this frame.
                frames.append(renamed.with_columns(sample_id=pl.lit(sid)))
            if not frames:
                raise ValueError("Empty mapping passed to Table.")
            df = pl.concat(frames, how="vertical_relaxed")
        missing = [c for c in CANONICAL_COLS if c not in df.columns]
        if missing:
            raise ValueError(
                f"Missing required column(s) {missing}. "
                f"Use `column_map` to rename if your columns differ from {CANONICAL_COLS}."
            )
        return df.select(*CANONICAL_COLS)

    @staticmethod
    def _apply_column_map(
        df: pl.DataFrame,
        column_map: Mapping[str, str] | None,
        expect_sample_id: bool,
    ) -> pl.DataFrame:
        """Rename the columns of ``df`` to their canonical names.

        Parameters
        ----------
        df
            Frame whose columns should be renamed.
        column_map
            Mapping of canonical name -> actual name. If ``None`` or empty,
            ``df`` is returned unchanged.
        expect_sample_id
            Whether ``df`` is expected to provide the ``sample_id`` column
            itself. When ``False`` the sample ID comes from elsewhere, so any
            ``sample_id`` entry in ``column_map`` is ignored.
        """
        if not column_map:
            return df
        # column_map is canonical -> actual; invert to actual -> canonical for
        # rename. Entries whose actual column is absent from ``df`` are skipped.
        # The ``sample_id`` entry is filtered by its *canonical* name: when the
        # caller supplies the sample ID, no data column may be renamed onto
        # ``sample_id`` (it could clash with an existing column of that name),
        # while a source column that merely happens to be called ``sample_id``
        # must still be renamable to another canonical column.
        rename = {
            actual: canonical
            for canonical, actual in column_map.items()
            if actual in df.columns
            and (expect_sample_id or canonical != "sample_id")
        }
        return df.rename(rename)

    def count_intervals(
        self,
        contig: str,
        starts: ArrayLike,
        ends: ArrayLike,
        sample: str | list[str] | None = None,
        **kwargs,
    ) -> NDArray[np.int32]:
        """Count the table intervals overlapping each query region, per sample.

        Parameters
        ----------
        contig
            Contig name; resolved against :attr:`contigs` with
            ``normalize_contig_name``.
        starts, ends
            Query coordinates, one entry per region (scalars are promoted to
            length-1 arrays).
        sample
            Sample(s) to count; ``None`` means all samples.
        **kwargs
            Accepted and ignored (presumably for signature parity with the
            ``BigWigs`` reader -- verify against the caller).

        Returns
        -------
        NDArray[np.int32]
            Array of shape ``(n_regions, n_samples)``. All zeros when the
            contig cannot be resolved or has no rows for the requested samples.

        Raises
        ------
        ValueError
            If a requested sample is not in the table.
        """
        import polars_bio as pb

        # pb.set_option is idempotent; called per-method to avoid relying on import order.
        pb.set_option("datafusion.bio.coordinate_system_check", "false")
        pb.set_option("datafusion.bio.coordinate_system_zero_based", True)
        samples = self._resolve_samples(sample)
        sample_to_si = {s: i for i, s in enumerate(samples)}
        starts_arr = np.atleast_1d(np.asarray(starts, dtype=np.int64))
        ends_arr = np.atleast_1d(np.asarray(ends, dtype=np.int64))
        n_regions = len(starts_arr)
        n_samples = len(samples)
        _contig = normalize_contig_name(contig, self.contigs)
        if _contig is None:
            return np.zeros((n_regions, n_samples), dtype=np.int32)
        contig = _contig
        contig_subset = self._df.filter(
            (pl.col("chrom") == contig) & pl.col("sample_id").is_in(samples)
        )
        if contig_subset.height == 0:
            return np.zeros((n_regions, n_samples), dtype=np.int32)
        # ``_q`` tags every query row with its region index so that overlap
        # rows can be attributed back to the region that produced them.
        queries = pl.DataFrame(
            {
                "chrom": [contig] * n_regions,
                "start": starts_arr,
                "end": ends_arr,
                "_q": np.arange(n_regions, dtype=np.int64),
            }
        )
        result = pb.overlap(
            queries,
            contig_subset.select("chrom", "start", "end", "sample_id"),
            cols1=["chrom", "start", "end"],
            cols2=["chrom", "start", "end"],
            output_type="polars.DataFrame",
        )
        out = np.zeros((n_regions, n_samples), dtype=np.int32)
        if result.height == 0:
            return out
        # Output columns carry ``_1`` (query side) and ``_2`` (table side) suffixes.
        q_idx = result["_q_1"].to_numpy()
        si_idx = (
            result.select(
                pl.col("sample_id_2").replace_strict(
                    sample_to_si, return_dtype=pl.Int64
                )
            )
            .to_series()
            .to_numpy()
        )
        # np.add.at accumulates correctly when the same (region, sample) cell
        # appears in several overlap rows.
        np.add.at(out, (q_idx, si_idx), 1)
        return out

    def _intervals_from_offsets(
        self,
        contig: str,
        starts: ArrayLike,
        ends: ArrayLike,
        offsets: NDArray[np.int64],
        sample: str | list[str] | None = None,
        **kwargs,
    ) -> RaggedIntervals:
        """Collect overlapping intervals into flat buffers laid out by ``offsets``.

        Parameters
        ----------
        contig
            Contig name; resolved against :attr:`contigs` with
            ``normalize_contig_name``.
        starts, ends
            Query coordinates, one entry per region.
        offsets
            Start position of every ``(region, sample)`` cell in the flat
            buffers, indexed by ``region * n_samples + sample``; the last
            element is the total number of intervals. Assumed to be consistent
            with :meth:`count_intervals` for the same query -- verify against
            the caller.
        sample
            Sample(s) to read; ``None`` means all samples.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        RaggedIntervals
            Ragged starts, ends and values of shape
            ``(n_regions, n_samples, None)``. Within a cell, intervals are
            ordered by table start.

        Raises
        ------
        ValueError
            If a requested sample is not in the table.
        """
        import polars_bio as pb
        from seqpro.rag import Ragged

        from ._ragged import RaggedIntervals

        # pb.set_option is idempotent; called per-method to avoid relying on import order.
        pb.set_option("datafusion.bio.coordinate_system_check", "false")
        pb.set_option("datafusion.bio.coordinate_system_zero_based", True)
        samples = self._resolve_samples(sample)
        sample_to_si = {s: i for i, s in enumerate(samples)}
        starts_arr = np.atleast_1d(np.asarray(starts, dtype=np.int64))
        ends_arr = np.atleast_1d(np.asarray(ends, dtype=np.int64))
        n_regions = len(starts_arr)
        n_samples = len(samples)
        shape = (n_regions, n_samples, None)
        total = int(offsets[-1])
        # NOTE(review): the buffers are uninitialized; a slot keeps arbitrary
        # contents unless the scatter below writes it, which relies on
        # ``offsets`` matching the overlaps found here.
        flat_starts = np.empty(total, dtype=np.int32)
        flat_ends = np.empty(total, dtype=np.int32)
        flat_values = np.empty(total, dtype=np.float32)
        _contig = normalize_contig_name(contig, self.contigs)
        if _contig is not None and total > 0:
            contig = _contig
            contig_subset = self._df.filter(
                (pl.col("chrom") == contig) & pl.col("sample_id").is_in(samples)
            )
            if contig_subset.height > 0:
                queries = pl.DataFrame(
                    {
                        "chrom": [contig] * n_regions,
                        "start": starts_arr,
                        "end": ends_arr,
                        "_q": np.arange(n_regions, dtype=np.int64),
                    }
                )
                joined = pb.overlap(
                    queries,
                    contig_subset.select("chrom", "start", "end", "sample_id", "value"),
                    cols1=["chrom", "start", "end"],
                    cols2=["chrom", "start", "end"],
                    output_type="polars.DataFrame",
                )
                if joined.height > 0:
                    # Sort by query index, sample index, then table start (matches BigWigs order).
                    si_idx = (
                        joined.select(
                            pl.col("sample_id_2").replace_strict(
                                sample_to_si, return_dtype=pl.Int64
                            )
                        )
                        .to_series()
                        .to_numpy()
                    )
                    q_idx = joined["_q_1"].to_numpy()
                    j_starts_raw = joined["start_2"].to_numpy()
                    order = np.lexsort(
                        (
                            j_starts_raw,
                            si_idx,
                            q_idx,
                        )
                    )  # last key = primary
                    q_idx = q_idx[order]
                    si_idx = si_idx[order]
                    j_starts = j_starts_raw[order].astype(np.int32, copy=False)
                    j_ends = (
                        joined["end_2"].to_numpy()[order].astype(np.int32, copy=False)
                    )
                    j_values = (
                        joined["value_2"]
                        .to_numpy()[order]
                        .astype(np.float32, copy=False)
                    )
                    # Flattened (region, sample) index of every sorted row.
                    cell_idx = q_idx * n_samples + si_idx
                    # Row positions at which a new cell begins in the sorted order.
                    boundaries = np.concatenate(
                        (
                            [0],
                            np.where(np.diff(cell_idx) != 0)[0] + 1,
                        )
                    )
                    counts_per_cell = np.diff(
                        np.concatenate((boundaries, [len(cell_idx)]))
                    )
                    # Rank of each row within its own cell (0, 1, 2, ...).
                    intra = np.arange(len(cell_idx)) - np.repeat(
                        boundaries, counts_per_cell
                    )
                    # Destination slot = start of the cell + rank within the cell.
                    write_pos = offsets[cell_idx] + intra
                    flat_starts[write_pos] = j_starts
                    flat_ends[write_pos] = j_ends
                    flat_values[write_pos] = j_values
        return RaggedIntervals(
            Ragged.from_offsets(flat_starts, shape, offsets),
            Ragged.from_offsets(flat_ends, shape, offsets),
            Ragged.from_offsets(flat_values, shape, offsets),
        )

    def _resolve_samples(self, sample: str | list[str] | None) -> list[str]:
        """Turn the ``sample`` argument into a validated list of sample IDs.

        ``None`` selects every sample of the table and a single string is
        wrapped in a list. The order given by the caller is preserved.

        Raises
        ------
        ValueError
            If any requested sample is not in :attr:`samples`.
        """
        if sample is None:
            return list(self.samples)
        if isinstance(sample, str):
            samples = [sample]
        else:
            samples = list(sample)
        if missing := set(samples) - set(self.samples):
            raise ValueError(f"Sample(s) {missing} not found in Table.")
        return samples