[1]:
import sys
from pathlib import Path
from tempfile import TemporaryDirectory

import genvarloader as gvl
import numba as nb
import numpy as np
import polars as pl
import pooch
import seqpro as sp
from einops import rearrange
from loguru import logger
from tqdm.auto import tqdm

Tutorial: Geuvadis

In this tutorial we’ll see how to use GenVarLoader (GVL) to:

  1. Write a GVL dataset

  2. Add transforms

  3. Lazily subset it (train/test splits)

  4. Get a PyTorch DataLoader

  5. Cache transformed tracks on disk (optional)

  6. Evaluate Basenji2 across genes and individuals (optional)

This tutorial also assumes you have read “What’s a gvl.Dataset?”.

Logging

A quick note on logging: GenVarLoader uses loguru for logging. We will enable it at the “INFO” level to get some additional information from GVL for this tutorial.

[2]:
logger.remove()
logger.add(sys.stderr, level="INFO")
logger.enable("genvarloader")

Download the data

The Geuvadis dataset comprises 451 individuals from the 1000 Genomes Project that have both whole genome sequencing and RNA-seq from lymphoblastoid cell lines (LCLs) derived from blood samples. We’ll see how to use GVL to get a high-performance dataloader that yields haplotypes and tracks for training or running inference with sequence models. To keep this tutorial short, we’ll only work with chromosome 22 so everything can run in a few minutes.

Downloading this data should take ~5-10 minutes and is the slowest step in this notebook.

[3]:
# GRCh38 chromosome 22 sequence
reference = pooch.retrieve(
    url="https://ftp.ensembl.org/pub/release-112/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna.chromosome.22.fa.gz",
    known_hash="sha256:974f97ac8ef7ffae971b63b47608feda327403be40c27e391ee4a1a78b800df5",
    progressbar=True,
)
if not Path(f"{reference[:-3]}.bgz").exists():
    !gzip -dc {reference} | bgzip > {reference[:-3]}.bgz
reference = reference[:-3] + ".bgz"

# PLINK 2 files
variants = pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/1kGP.chr22.pgen",
    known_hash="md5:31aba970e35f816701b2b99118dfc2aa",
    progressbar=True,
    fname="1kGP.chr22.pgen",
)
pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/1kGP.chr22.psam",
    known_hash="md5:eefa7aad5acffe62bf41df0a4600129c",
    progressbar=True,
    fname="1kGP.chr22.psam",
)
pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/1kGP.chr22.pvar",
    known_hash="md5:5f922af91c1a2f6822e2f1bb4469d12b",
    progressbar=True,
    fname="1kGP.chr22.pvar",
)

# BigWigs and sample ID mapping
bw_paths = pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/bw_chr22.tar.gz",
    known_hash="md5:14bf72e9e9d3e2318d07315c4a2675fb",
    progressbar=True,
    processor=pooch.Untar(),
)
bw_table_path = pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/bigwig_table.csv",
    known_hash="md5:7fe7c55b61c7dfa66cfd0a49336f3b08",
    progressbar=True,
)

# BED
bed_path = pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/chr22_egenes.bed",
    known_hash="md5:ccb55548e4ddd416d50dbe6638459421",
    progressbar=True,
)

Writing the GVL dataset

We’ll specify a directory to store the dataset (similar to Zarr stores).

[4]:
tmp_dir = TemporaryDirectory(suffix=".gvl")
ds_path = tmp_dir.name

We’ll also need a table or dictionary specifying the sample name for each BigWig. Tables must have at least the columns sample and path, as seen below. The join below updates the paths to match the actual download locations.

[5]:
bigwig_table = (
    pl.read_csv(bw_table_path)
    .join(
        pl.Series(bw_paths).to_frame("realpath"),
        left_on="path",
        right_on=pl.col("realpath").str.split("/").list.get(-1),
    )
    .drop("path")
    .rename({"realpath": "path"})
)
bigwig_table.head()
[5]:
shape: (5, 3)
sample     read_count  path
str        i64         str
"HG00236"  34548283    "/carter/users/dlaub/.cache/poo…
"HG00259"  53041143    "/carter/users/dlaub/.cache/poo…
"NA20519"  36620358    "/carter/users/dlaub/.cache/poo…
"NA20811"  24398971    "/carter/users/dlaub/.cache/poo…
"NA20768"  30019566    "/carter/users/dlaub/.cache/poo…
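The join above matches rows by file name: the path column of the CSV holds bare file names, and each downloaded path is matched on its last "/"-separated component. A plain-Python sketch of the same matching, using a hypothetical resolve_paths helper and made-up file names (the real BigWig names may differ):

```python
from pathlib import Path


def resolve_paths(table_paths, downloaded_paths):
    """Map each file name listed in a table to its full downloaded path.

    Hypothetical helper mirroring the polars join: a download matches a table
    entry when its final path component (the file name) is equal.
    """
    by_name = {Path(p).name: p for p in downloaded_paths}
    # Entries without a matching download are dropped, like an inner join.
    return {name: by_name[name] for name in table_paths if name in by_name}


resolved = resolve_paths(
    ["ERR188040.bw", "ERR188041.bw"],  # made-up table entries
    ["/cache/bw_chr22/ERR188040.bw", "/cache/bw_chr22/ERR188041.bw"],
)
```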

Finally, we’ll need a BED file specifying what regions to include in the dataset. We can either specify a path or a polars DataFrame. We’ll use gvl.read_bedlike to conveniently read the BED file into memory and subset it to just the top 20 eGenes for this tutorial. The BED file provided corresponds to transcription start sites of eGenes, sorted in descending order by their absolute sum of coefficients.

[6]:
bed = gvl.read_bedlike(bed_path)[:20]
bed.head()
[6]:
shape: (5, 6)
chrom    chromStart  chromEnd  name               score  strand
str      i64         i64       str                f64    str
"chr22"  41699499    41699499  "ENSG00000167077"  null   "+"
"chr22"  42835412    42835412  "ENSG00000100266"  null   "-"
"chr22"  20858983    20858983  "ENSG00000099940"  null   "+"
"chr22"  20707691    20707691  "ENSG00000241973"  null   "-"
"chr22"  49918167    49918167  "ENSG00000184164"  null   "+"

Now we’re ready to write the dataset.

The BED above specifies the transcription start site of each gene, so chromStart == chromEnd. We’ll expand these regions to \(2^{17}\) (131,072) bp, the input length for Basenji2, using gvl.with_length.

We’ll also instantiate a gvl.BigWigs from the above table (we could also use a dictionary). We’ll name this track “read-depth” so we can manage different transformations of the track data or provide multiple tracks for the same samples. Later, we’ll add a transformed track for \(\ln(\text{CPM}+1)\) to see this in action.

Finally, we’ll set max_jitter to 128 bp, which allows the sequences and tracks to be randomly jittered by up to 128 bp in either direction. Note that this only sets the upper bound: as the log output below shows (Jitter: 0 (max: 128)), the dataset opens with jitter disabled.

[7]:
gvl.write(
    path=ds_path,
    bed=gvl.with_length(bed, 2**17),  # change region length to 131,072 bp
    variants=variants,
    tracks=gvl.BigWigs.from_table(name="read-depth", table=bigwig_table),
    max_jitter=128,  # allow up to 128 bp jitter
    overwrite=True,
)
2025-03-19 20:27:36.453 | INFO     | genvarloader._dataset._write:write:99 - Writing dataset to /tmp/tmpmcduh61m.gvl
2025-03-19 20:27:36.454 | INFO     | genvarloader._dataset._write:write:104 - Found existing GVL store, overwriting.
2025-03-19 20:27:36.541 | INFO     | genvarloader._dataset._write:write:172 - Using 451 samples.
2025-03-19 20:27:36.541 | INFO     | genvarloader._dataset._write:write:178 - Writing genotypes.
2025-03-19 20:27:42.068 | INFO     | genvarloader._dataset._write:write:197 - Writing BigWig intervals.
2025-03-19 20:27:46.202 | INFO     | genvarloader._dataset._write:write:204 - Finished writing.
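To make the region-length and jitter settings concrete, here is a minimal sketch of what they mean for coordinates, assuming gvl.with_length recenters each region on its midpoint and that jitter shifts the whole window by a random integer offset. The helpers center_to_length and jittered_window are hypothetical; GVL's exact rounding and sampling may differ.

```python
import numpy as np


def center_to_length(start, end, length):
    """Resize [start, end) to `length` bp around its midpoint (hypothetical helper)."""
    mid = (start + end) // 2
    new_start = mid - length // 2
    return new_start, new_start + length


def jittered_window(start, end, max_jitter, rng):
    """Shift [start, end) by a random integer in [-max_jitter, max_jitter] (hypothetical)."""
    shift = int(rng.integers(-max_jitter, max_jitter + 1))
    return start + shift, end + shift


rng = np.random.default_rng(0)
# First TSS in the BED above, where chromStart == chromEnd
start, end = center_to_length(41699499, 41699499, 2**17)
print(start, end, end - start)  # 41633963 41765035 131072

# Span that would need to be stored so any jittered window stays in bounds
stored = (start - 128, end + 128)
j_start, j_end = jittered_window(start, end, 128, rng)
assert stored[0] <= j_start and j_end <= stored[1]
```

Under this reading, supporting jitter presumably means the written data must cover max_jitter extra bp on each side of every region.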

Note that gvl.write will also automatically use the intersection of samples from source files. In this case, they are perfectly matched to each other. But, if we had used PLINK files for the full 3,202 samples from the 1000 Genomes Project then it would have identified and used the 451 intersecting samples.
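A minimal sketch of taking the intersection of sample names across sources. Keeping the order of the first source is an assumption made here for illustration; the order gvl.write actually uses isn't stated in this tutorial and need not match the BigWig table, so it's safest to align other per-sample data by name against ds.samples. The sample lists below are made up.

```python
def intersect_samples(*sources):
    """Samples present in every source, in the order of the first source (hypothetical)."""
    common = set(sources[0]).intersection(*sources[1:])
    return [s for s in sources[0] if s in common]


plink_samples = ["HG00096", "HG00236", "HG00259", "NA20519"]  # made-up
bigwig_samples = ["NA20519", "HG00236", "HG00259"]            # made-up
shared = intersect_samples(plink_samples, bigwig_samples)
```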

Dataloader

Now that the dataset is written, we can add a transform, split it, and get a PyTorch dataloader in ~10 lines of code.

[8]:
def transform(haps, tracks):
    # One-hot encode haplotypes and move the alphabet axis before length:
    # (batch, ploidy, length) -> (batch, ploidy, alphabet, length)
    haps = rearrange(
        sp.DNA.ohe(haps), "... length alphabet -> ... alphabet length"
    ).astype(np.float32)
    return haps, tracks


ds = (
    gvl.Dataset.open(ds_path, reference=reference)
    .with_seqs("haplotypes")
    .with_tracks("read-depth")
    .with_len(2**17)
    .with_transform(transform)
)
# Lazily split by individuals: the first 80% of samples are used for training
n_train = round(ds.n_samples * 0.8)
train_ds = ds.subset_to(samples=slice(0, n_train))
dl = train_ds.to_dataloader(batch_size=16, num_workers=0, shuffle=True)
2025-03-19 20:27:46.208 | INFO     | genvarloader._dataset._impl:open:227 - Loading reference genome into memory. This typically has a modest memory footprint (a few GB) and greatly improves performance.
2025-03-19 20:27:46.243 | INFO     | genvarloader._dataset._impl:open:269 - Opened dataset:
GVL store at /tmp/tmpmcduh61m.gvl
Is subset: False
# of regions: 20
# of samples: 451
Output length: ragged
Jitter: 0 (max: 128)
Deterministic: True
Sequence type: reference [haplotypes] annotated
Active tracks: read-depth
Tracks available: read-depth
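The cell above builds only the training split, holding out individuals rather than genes. The complementary held-out samples are the remaining indices; with GVL this would presumably be ds.subset_to(samples=slice(n_train, None)), assuming open-ended slices are accepted just like slice(0, n_train). A NumPy sketch of the index arithmetic:

```python
import numpy as np

n_samples = 451                    # individuals in the dataset above
n_train = round(n_samples * 0.8)   # 360.8 rounds to 361
all_idx = np.arange(n_samples)
train_idx = all_idx[slice(0, n_train)]       # same slice as the cell above
heldout_idx = all_idx[slice(n_train, None)]  # complementary held-out samples
```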

GVL uses Numba JIT-compiled functions extensively, so the first call to gvl.write, the first batch from a dataloader, etc. will often take much longer than subsequent calls due to compilation. Numba also lets GVL be multithreaded almost everywhere it can be, so num_workers=0 or 1 is usually the best choice for dataloader throughput.

[9]:
haps, tracks = next(iter(dl))
print(haps.shape, tracks.shape)
torch.Size([16, 2, 4, 131072]) torch.Size([16, 1, 2, 131072])

After one-hot encoding, the haplotypes have shape (batch, ploidy, alphabet, length) and the tracks have shape (batch, tracks, ploidy, length).
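For intuition about these shapes, here is a NumPy-only sketch of what sp.DNA.ohe followed by the einops rearrange does. It assumes an ACGT alphabet order and that any other byte (e.g. N) becomes an all-zero column; seqpro's exact conventions may differ, and ohe_channels_first is a hypothetical name.

```python
import numpy as np


def ohe_channels_first(seqs, alphabet=b"ACGT"):
    """One-hot encode an S1 array (..., length) into float32 (..., alphabet, length)."""
    codes = np.frombuffer(alphabet, dtype=np.uint8)
    # Compare each byte against the alphabet: (..., length, alphabet) booleans
    ohe = seqs.view(np.uint8)[..., None] == codes
    # Move the alphabet axis in front of length, as the rearrange pattern does
    return np.moveaxis(ohe, -1, -2).astype(np.float32)


haps = np.array([[[b"A", b"C", b"G", b"T", b"N"]] * 2], dtype="S1")  # (batch=1, ploidy=2, length=5)
x = ohe_channels_first(haps)
print(x.shape)  # (1, 2, 4, 5)
```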

Pre-computing transformed tracks (optional)

Suppose we would like to normalize the read depth across the dataset to account for library size. We could compute this on-the-fly, but GVL also offers a way to write this data back to disk to cache this computation and potentially improve performance. Note that this is the most technical part of this tutorial, so feel free to skip this and come back later.

[10]:
sample_library_sizes = (
    pl.Series(ds.samples)
    .to_frame("sample")
    # keep the dataset's sample order so index i corresponds to ds.samples[i]
    .join(bigwig_table, on="sample", how="left", maintain_order="left")["read_count"]
    .to_numpy()
)
sample_library_sizes[:5]
[10]:
array([27256165, 43941108, 39687917, 22341838, 23258231])

For this step, we’ll use Dataset.write_transformed_track which expects a transform function to be given. From the docs:

The arguments given to the transform will be the dataset indices, region indices, and sample indices as numpy arrays and the tracks themselves as a Ragged array with shape (regions, samples). The tracks must be a Ragged array since regions may be different lengths to accommodate indels. This function should then return the transformed tracks as a Ragged array with the same shape and lengths.

Below, you can see an example of a transform of ragged data that uses Numba to accelerate the computation. Note that working with Ragged arrays is often not necessary with on-the-fly transformations, since data for deep learning is readily processed to be uniform length before any transformations.
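Before looking at the Numba version, it may help to see the same computation in plain NumPy on a toy ragged layout: a flat data array plus offsets, where entry i spans data[offsets[i]:offsets[i + 1]]. This sketch assumes, as the kernel below does, that s_idx holds one sample index per ragged entry; log_cpm_ragged is a hypothetical name and the toy values are made up. It is handy as a reference to check a Numba kernel against.

```python
import numpy as np


def log_cpm_ragged(data, offsets, s_idx, library_sizes):
    """NumPy reference: ln(depth / library_size * 1e6 + 1) for each ragged entry."""
    lengths = np.diff(offsets)
    # Broadcast each entry's library size to every position in that entry
    per_pos_lib = np.repeat(library_sizes[s_idx], lengths)
    return np.log1p(data / per_pos_lib * 1e6)


data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
offsets = np.array([0, 2, 5])         # entry 0 -> data[0:2], entry 1 -> data[2:5]
s_idx = np.array([1, 0])              # sample of each entry
library_sizes = np.array([1e6, 2e6])  # reads per sample
out = log_cpm_ragged(data, offsets, s_idx, library_sizes)
```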

[11]:
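# Numba freezes globals such as sample_library_sizes as compile-time constants,
# so re-run this cell if the library sizes change.
@nb.njit(parallel=True, nogil=True, fastmath=True)
def inner_transform(s_idx, data, offsets):
    log_cpm = np.empty_like(data)
    # Ragged entry i spans data[offsets[i]:offsets[i + 1]]; entries run in parallel
    for i in nb.prange(len(offsets) - 1):
        start = offsets[i]
        end = offsets[i + 1]
        sample = s_idx[i]
        # ln(CPM + 1), normalizing by this entry's sample library size
        log_cpm[start:end] = np.log1p(
            data[start:end] / sample_library_sizes[sample] * 1e6
        )
    return log_cpm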


def log_cpm(r_idx, s_idx, tracks: gvl.Ragged[np.float32]):
    data = inner_transform(s_idx, tracks.data, tracks.offsets)
    return gvl.Ragged.from_offsets(data, tracks.shape, tracks.offsets)


ds = ds.write_transformed_track(
    "lcpb", "read-depth", log_cpm, overwrite=True, max_mem=1 * 2**30
)
ds
[11]:
GVL store at /tmp/tmpmcduh61m.gvl
Is subset: False
# of regions: 20
# of samples: 451
Output length: 131072
Jitter: 0 (max: 128)
Deterministic: True
Sequence type: reference [haplotypes] annotated
Active tracks: read-depth
Tracks available: lcpb, read-depth

If the above cell crashes the kernel, it may have run out of RAM; reducing max_mem can fix this.

After writing the transformed track to disk, we can see the dataset now has the "lcpb" track available (note the list of available tracks is always sorted).

Evaluating Basenji2 on personalized expression (optional)

Note: this section requires PyTorch and basenji2-pytorch to be installed.

Here, we’ll show a (very) quick and dirty demo of some of the results found by Huang et al. (Nature Genetics, 2023) with Basenji2. We also recommend running this on a GPU, since inference with Basenji2 may otherwise take quite a while.

[12]:
human_targets = pl.read_csv(
    "https://github.com/calico/basenji/blob/master/manuscripts/cross2020/targets_human.txt?raw=true",
    separator="\t",
)
# CAGE track for the GM12878 lymphoblastoid cell line
gm12878_cage = human_targets.filter(
    pl.col("description").str.contains(r"(?i)cage.*gm12878")
)
target = gm12878_cage.item(0, "index")
gm12878_cage
[12]:
shape: (1, 8)
index  genome  identifier   file                                clip  scale  sum_stat  description
i64    i64     str          str                                 i64   i64    str       str
5110   0       "CNhs12333"  "/home/drk/tillage/datasets/hum…  384   1      "sum"     "CAGE:B lymphoblastoid cell lin…

If the above cell is taking more than a few seconds, try restarting its execution; sometimes GitHub fails to respond, so the file doesn’t download. The same applies to the next cell, since recount3 can also get stuck.

[13]:
count_df = pl.read_csv(
    "https://duffel.rail.bio/recount3/human/data_sources/sra/gene_sums/42/ERP001942/sra.gene_sums.ERP001942.G029.gz",
    separator="\t",
    comment_prefix="#",
)
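# Order the ERR accessions by ds.samples so the columns of `counts` line up
# with the sample axis of the predictions computed below.
accessions = (
    pl.Series(ds.samples)
    .to_frame("sample")
    .join(bigwig_table, on="sample", how="left", maintain_order="left")
    .select(pl.col("path").str.extract(r"(ERR\d+)"))
    .to_series()
)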
counts = (
    bed.join(
        count_df.select("gene_id", *accessions),
        left_on="name",
        right_on=pl.col("gene_id").str.split(".").list.get(0),
        maintain_order="left",
    )
    .select(*accessions)
    .to_numpy()
)
counts.shape
[13]:
(20, 451)
[14]:
import torch
from basenji2_pytorch import Basenji2, basenji2_params, basenji2_weights

device = "cuda" if torch.cuda.is_available() else "cpu"

torch.set_float32_matmul_precision("medium")

basenji2 = Basenji2(basenji2_params["model"]).to(device)
basenji2.load_state_dict(torch.load(basenji2_weights(), weights_only=True))
basenji2.eval();
[15]:
def transform(haps, *args):
    haps = rearrange(
        sp.DNA.ohe(haps), "... length alphabet -> ... alphabet length"
    ).astype(np.float32)
    return haps, *args


ds = (
    gvl.Dataset.open(ds_path, reference)
    .with_len(2**17)
    .with_indices(True)
    .with_tracks(None)
    .with_transform(transform)
)
2025-03-19 20:27:59.608 | INFO     | genvarloader._dataset._impl:open:227 - Loading reference genome into memory. This typically has a modest memory footprint (a few GB) and greatly improves performance.
2025-03-19 20:27:59.639 | INFO     | genvarloader._dataset._impl:open:269 - Opened dataset:
GVL store at /tmp/tmpmcduh61m.gvl
Is subset: False
# of regions: 20
# of samples: 451
Output length: ragged
Jitter: 0 (max: 128)
Deterministic: True
Sequence type: reference [haplotypes] annotated
Active tracks: lcpb, read-depth
Tracks available: lcpb, read-depth

If you’re using a GPU, you may need to use a smaller batch size depending on how much GPU RAM you have.

[16]:
batch_size = 48
# number of output bins for Basenji2, each corresponds to 128 bp of sequence
n_bins = 896
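The 896 bins do not span the whole input: Basenji2 splits its 131,072 bp input into 1,024 bins of 128 bp and crops 64 bins from each end, so predictions cover only the central 114,688 bp. A quick check of that arithmetic, with a hypothetical bin_span helper giving the input coordinates of each output bin:

```python
seq_len = 2**17   # Basenji2 input length in bp
bin_size = 128    # bp per output bin
n_bins = 896      # bins kept after cropping

covered_bp = n_bins * bin_size         # 114,688 bp of predictions
crop_bp = (seq_len - covered_bp) // 2  # 8,192 bp cropped from each end
crop_bins = crop_bp // bin_size        # 64 bins per side


def bin_span(k):
    """Input coordinates [start, end) covered by output bin k (hypothetical helper)."""
    start = crop_bp + k * bin_size
    return start, start + bin_size
```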

Compute predictions for all genes using reference sequences:

[17]:
ref_preds = np.empty((ds.n_regions, n_bins), dtype=np.float32)
with torch.no_grad():
    for ref, r_idx, _ in tqdm(
        ds.subset_to(samples=0)
        .with_seqs("reference")
        .to_dataloader(batch_size=batch_size, num_workers=0)
    ):
        ref_preds[r_idx] = basenji2(ref.to(device))[..., target].numpy(force=True)

Next we’ll compute the Pearson correlation between the predicted (transformed) CAGE-seq read depth and each gene’s mean expression across individuals. We’ll average a 5 bin (640 bp) window starting 9 bins (1152 bp) upstream of the TSS, since this window gave the highest correlation in limited testing. That the best window is offset from the TSS is somewhat expected, since CAGE-seq reads should fall in the 5’ UTR region, and we haven’t thoroughly confirmed that the TSS coordinates we’re using are exactly the same as what Basenji2 trained on.

[18]:
ref_x_gene = np.corrcoef(
    ref_preds[..., n_bins // 2 - 9 : n_bins // 2 - 4].mean(-1), counts.mean(-1)
)[0, 1]
ref_x_gene
[18]:
np.float64(0.47156382152055587)
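The window used in the cell above can be translated into base pairs relative to the TSS. This sketch assumes the TSS sits exactly at the sequence midpoint (from centering the region) and Basenji2's symmetric cropping, so bin 448 begins at the TSS; "upstream" here means lower coordinates on the forward strand.

```python
n_bins, bin_size = 896, 128
center = n_bins // 2                    # first bin after the sequence midpoint
window = slice(center - 9, center - 4)  # bins 439..443, as in the cell above

# bp offsets of the window relative to the midpoint (negative = lower coordinates)
start_bp = (window.start - center) * bin_size
end_bp = (window.stop - center) * bin_size
print(start_bp, end_bp, end_bp - start_bp)  # -1152 -512 640
```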

We’d expect this to be the highest possible correlation Basenji2 can achieve on these genes. Let’s see how it does across individuals and across genes with haplotypes.
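"Across individuals" means one correlation per gene, computed over samples; "across genes" means one correlation per individual, computed over genes. A NumPy sketch on synthetic data (rowwise_pearson is a hypothetical helper) showing that the np.diag(np.corrcoef(...), k) pattern in the cells below extracts exactly these paired correlations. It also makes explicit that predictions and counts must share the same gene order along rows and the same sample order along columns.

```python
import numpy as np


def rowwise_pearson(a, b):
    """Pearson correlation between a[i] and b[i] for every row i."""
    a = a - a.mean(axis=-1, keepdims=True)
    b = b - b.mean(axis=-1, keepdims=True)
    return (a * b).sum(-1) / np.sqrt((a * a).sum(-1) * (b * b).sum(-1))


rng = np.random.default_rng(0)
n_genes, n_samples = 20, 451
pred = rng.normal(size=(n_genes, n_samples))        # synthetic predictions
obs = pred + rng.normal(size=(n_genes, n_samples))  # synthetic "expression"

across_individuals = rowwise_pearson(pred, obs)  # one r per gene, shape (20,)
across_genes = rowwise_pearson(pred.T, obs.T)    # one r per individual, shape (451,)

# The diag-of-corrcoef pattern picks out the same paired correlations
diag_idv = np.diag(np.corrcoef(pred, obs), n_genes)
diag_gene = np.diag(np.corrcoef(pred, obs, rowvar=False), n_samples)
```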

[19]:
preds = np.empty(ds.full_shape + (n_bins,), dtype=np.float64)
with torch.no_grad():
    for haps, r_idx, s_idx in tqdm(
        ds.to_dataloader(batch_size=batch_size, num_workers=0)
    ):
        # Basenji2 expects (batch, 4, length), so predict from the first haplotype only
        preds[r_idx, s_idx] = basenji2(haps[:, 0].to(device))[..., target].numpy(
            force=True
        )
[20]:
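window = slice(n_bins // 2 - 9, n_bins // 2 - 4)
# per-gene correlation across individuals (rows = genes)
ave_pearson_x_idv = np.diag(
    np.corrcoef(preds[..., window].mean(-1), counts), ds.n_regions
).mean()
# per-individual correlation across genes (columns = individuals)
ave_pearson_x_gene = np.diag(
    np.corrcoef(preds[..., window].mean(-1), counts, rowvar=False),
    ds.n_samples,
).mean()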
ave_pearson_x_idv, ave_pearson_x_gene
[20]:
(np.float64(-0.005341827841689839), np.float64(0.4334261598507151))

The average correlation across genes with haplotypes is only slightly lower than with reference sequences, but just as Huang et al. and others have found, the correlation across individuals is approximately 0 on average.