Evaluate Basenji2 for personal gene expression
Here we will reproduce the results from Huang et al. 2023. The steps match those in the original paper with two exceptions: we skip the 1 bp jittering and the forward and reverse complementing of the haplotypes. Neither makes an apparent difference to the results, and omitting them saves quite a bit of time.
Even with this time-saving simplification, processing all 3,259 eGenes from Lappalainen et al. 2013 may take hours to days depending on your GPU. Thus, we’ll use only the first 100 genes for this tutorial, which should run in a few minutes. If you do not have a GPU, we recommend skipping this tutorial, as Basenji2 is prohibitively slow without one.
[1]:
from math import log
import genvarloader as gvl
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import polars.selectors as cs
import scipy.stats as st
import seaborn as sns
import seqpro as sp
import torch
from basenji2_pytorch import Basenji2, basenji2_params, basenji2_weights
from einops import rearrange
from genoray import PGEN
from genoray.exprs import is_biallelic, is_snp
from tqdm.auto import tqdm
Here, we’ll use GenVarLoader’s data registry to fetch all the relevant files for the Geuvadis dataset:
- A list of eGenes and their TSS coordinates from Lappalainen et al. 2013
- Gene expression from Geuvadis
- A PGEN file of phased and unphased 1,000 Genomes genotypes for 465 individuals
- A table of sample IDs that are phased
- A table annotating the tracks (targets) that Basenji2 outputs
[2]:
ds_path = input("Dataset path (should end in .gvl): ")
pred_path = input("Prediction path (should end in .npy): ")
basenji2_in_len = 2**17
basenji2_out_len = 896
paths = gvl.data_registry.fetch("geuvadis_ebi")
samples = pl.read_csv(paths["samples"], separator="\t").to_series().to_list()
[3]:
genes = pl.read_csv(
    paths["genes"],
    has_header=False,
    new_columns=["gene_id", "chrom", "chromStart", "hgnc", "strand"],
)
genes.head()
[3]:
| gene_id | chrom | chromStart | hgnc | strand |
|---|---|---|---|---|
| str | i64 | i64 | str | str |
| "ENSG00000000457" | 1 | 169863408 | "SCYL3" | "-" |
| "ENSG00000001630" | 7 | 91772266 | "CYP51A1" | "-" |
| "ENSG00000002549" | 4 | 17578815 | "LAP3" | "+" |
| "ENSG00000002745" | 7 | 120965421 | "WNT16" | "+" |
| "ENSG00000003056" | 12 | 9102551 | "M6PR" | "-" |
[43]:
expr = (
    pl.read_csv(paths["expr"], separator="\t")
    .rename({"TargetID": "gene_id"})
    .drop("Gene_Symbol", "Chr", "Coord")
    .with_columns(
        pl.col("gene_id").str.split(".").list.get(0),
        cs.numeric().log1p() / log(2),  # log2(RPKM + 1)
    )
    .join(genes, "gene_id")
    .sort("gene_id")
    .select("gene_id", *samples)
)
expr.head()
[43]:
| gene_id | HG00097 | HG00099 | HG00100 | HG00101 | HG00102 | HG00103 | HG00104 | HG00106 | HG00108 | HG00109 | HG00110 | HG00111 | HG00112 | HG00114 | HG00116 | HG00117 | HG00118 | HG00119 | HG00120 | HG00121 | HG00122 | HG00123 | HG00124 | HG00125 | HG00126 | HG00127 | HG00128 | HG00129 | HG00130 | HG00131 | HG00133 | HG00134 | HG00135 | HG00136 | HG00137 | HG00138 | … | NA20770 | NA20771 | NA20772 | NA20773 | NA20774 | NA20778 | NA20783 | NA20785 | NA20786 | NA20787 | NA20790 | NA20792 | NA20795 | NA20796 | NA20797 | NA20798 | NA20799 | NA20800 | NA20801 | NA20802 | NA20803 | NA20804 | NA20805 | NA20806 | NA20807 | NA20808 | NA20809 | NA20810 | NA20811 | NA20812 | NA20813 | NA20814 | NA20815 | NA20816 | NA20819 | NA20826 | NA20828 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| str | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | … | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| "ENSG00000000457" | 2.883367 | 1.834309 | 2.64608 | 2.585935 | 2.34584 | 2.925417 | 2.65781 | 2.789464 | 2.444504 | 2.75724 | 2.925114 | 2.680441 | 2.913031 | 2.689902 | 2.998056 | 2.697751 | 2.766225 | 3.042996 | 2.592133 | 2.867706 | 2.805805 | 2.765552 | 3.013281 | 2.484207 | 2.903942 | 2.904265 | 2.641283 | 2.685263 | 2.665948 | 3.092614 | 2.709707 | 2.602347 | 2.754889 | 3.092115 | 2.83426 | 2.444628 | … | 2.794595 | 2.784521 | 2.976292 | 2.700462 | 2.405552 | 2.506383 | 3.117941 | 3.252319 | 2.96167 | 2.72472 | 2.819782 | 3.059721 | 2.820945 | 2.876582 | 2.569356 | 3.077712 | 2.897556 | 2.628595 | 2.436556 | 2.437954 | 2.667586 | 2.787217 | 2.868959 | 2.551098 | 2.860837 | 2.700339 | 2.69241 | 2.745681 | 2.660285 | 2.554025 | 2.998785 | 2.825764 | 2.686573 | 2.872107 | 2.979091 | 2.894534 | 2.933071 |
| "ENSG00000001630" | 4.969557 | 4.224181 | 4.976058 | 5.218726 | 5.182351 | 4.767849 | 5.413671 | 4.758416 | 5.097579 | 5.25007 | 5.123745 | 5.073436 | 5.249919 | 5.124941 | 5.129572 | 5.13768 | 5.046579 | 5.088643 | 5.319505 | 4.862542 | 5.124061 | 4.70492 | 4.88879 | 4.696655 | 4.808061 | 4.801197 | 4.939249 | 5.109147 | 5.146362 | 5.185362 | 5.034132 | 5.057867 | 5.156479 | 5.437459 | 5.094595 | 5.130883 | … | 5.312725 | 5.104296 | 5.040871 | 4.911511 | 5.040483 | 5.20505 | 4.880228 | 4.601109 | 4.928203 | 5.244682 | 4.964349 | 5.099042 | 5.40614 | 5.257256 | 4.523433 | 5.279652 | 4.843644 | 5.203806 | 5.06995 | 5.086101 | 5.17526 | 4.878176 | 5.143902 | 5.335128 | 5.132463 | 4.878686 | 5.056312 | 5.12613 | 5.195856 | 5.252959 | 5.390342 | 4.933299 | 4.928922 | 5.257504 | 5.356906 | 5.191717 | 5.075509 |
| "ENSG00000002549" | 6.579422 | 5.805355 | 6.826235 | 6.520504 | 6.818759 | 6.527228 | 6.858863 | 6.43286 | 6.822276 | 6.527388 | 6.735795 | 6.33409 | 6.855515 | 6.859368 | 5.567008 | 6.420919 | 6.488021 | 6.804768 | 6.539446 | 7.0783 | 6.773451 | 6.681058 | 7.04007 | 6.570902 | 6.607061 | 6.269434 | 6.653516 | 6.510893 | 6.644596 | 6.452672 | 6.897555 | 6.780814 | 6.569071 | 6.54277 | 6.882278 | 7.068692 | … | 6.493295 | 6.865757 | 6.602131 | 6.882519 | 6.302862 | 6.556039 | 6.584511 | 6.675576 | 6.38825 | 6.653772 | 6.680885 | 6.700284 | 6.742045 | 6.667749 | 6.524001 | 6.427709 | 6.583828 | 6.415492 | 6.575268 | 6.444256 | 6.485439 | 6.3251 | 6.396262 | 6.784344 | 6.644984 | 6.518831 | 6.666653 | 6.223493 | 6.621568 | 6.901742 | 6.665184 | 6.65501 | 6.765441 | 6.371931 | 6.612746 | 6.47355 | 6.548355 |
| "ENSG00000002745" | 0.192883 | 0.029851 | 0.195664 | 0.126503 | 0.297257 | 0.129629 | 0.021852 | 0.275263 | 0.071207 | 0.263726 | 0.131601 | 0.230152 | 0.156213 | 0.204568 | 0.040416 | 0.225189 | 0.111335 | 0.164453 | 0.20008 | 0.145553 | 0.200483 | 0.209593 | 0.359152 | -0.049574 | 0.191282 | 0.021413 | 0.227732 | 0.276992 | 0.143957 | 0.091331 | 0.589119 | -0.037316 | 0.233788 | 0.993584 | 0.187406 | 0.291823 | … | 0.131191 | 0.100634 | 0.241065 | 0.097617 | 0.034683 | 0.077639 | 0.123341 | 0.08147 | 0.170642 | 0.085312 | -0.027111 | 0.084868 | 0.537418 | 0.064881 | 0.057349 | 0.127728 | 0.128894 | 0.836652 | 0.034484 | 0.134283 | 0.151141 | 0.313646 | 0.01313 | 0.01391 | 0.463116 | 0.252047 | 0.06253 | 0.177781 | 0.392925 | 0.252537 | 0.07996 | 0.08126 | 0.14503 | 0.131564 | -0.038256 | 0.073507 | -0.062844 |
| "ENSG00000003056" | 6.664455 | 5.774003 | 6.621526 | 6.600959 | 6.868067 | 6.685243 | 6.656745 | 6.68273 | 6.585196 | 6.805682 | 6.768248 | 6.805896 | 7.043169 | 6.600454 | 6.652283 | 6.648638 | 6.907678 | 6.852528 | 6.939126 | 6.661728 | 6.610243 | 6.842894 | 6.705818 | 6.049611 | 7.000083 | 6.702603 | 6.49256 | 6.901974 | 6.533254 | 6.632719 | 6.696538 | 6.826387 | 6.803077 | 6.591067 | 6.601555 | 6.695223 | … | 6.829099 | 6.897351 | 6.992783 | 6.643068 | 6.48466 | 6.348297 | 6.941982 | 6.788952 | 6.922804 | 6.946785 | 6.239979 | 6.820033 | 6.74771 | 6.558058 | 6.517936 | 6.674047 | 6.696774 | 6.479763 | 6.708801 | 6.477054 | 6.400093 | 7.030301 | 6.608396 | 6.698758 | 6.190337 | 6.861345 | 6.815563 | 6.796029 | 7.081206 | 6.910327 | 6.565901 | 6.666187 | 6.484757 | 6.8308 | 6.826336 | 6.574483 | 6.77982 |
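The comment in the cell above describes the transform as log2(RPKM + 1), computed as `log1p(x) / log(2)`. As a quick sanity check of that identity, here is a minimal sketch in plain Python; the input values are hypothetical and chosen only for illustration.

```python
from math import log, log1p, log2


def log2p1(x: float) -> float:
    """Return log2(x + 1), computed as ln(1 + x) / ln(2)."""
    return log1p(x) / log(2)


# Hypothetical expression values, not taken from the Geuvadis table.
values = (0.0, 1.0, 3.0, 100.0)
max_abs_diff = max(abs(log2p1(v) - log2(v + 1)) for v in values)
print(max_abs_diff)  # expected to be at floating-point noise level
```

Using `log1p` rather than `log(x + 1)` keeps precision for values close to zero, which is presumably why the cell is written that way.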
The genotypes from Geuvadis include some structural and multiallelic variants which we will filter out. To match Huang et al.’s analysis, we will also filter out indels despite the fact that GenVarLoader supports them.
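To make the filtering criteria concrete, the following is a minimal sketch of what "biallelic SNP" means in terms of REF and ALT alleles. It is not genoray's implementation of `is_snp` or `is_biallelic`; the helper `is_biallelic_snp` and the variant records are hypothetical.

```python
def is_biallelic_snp(ref: str, alts: list[str]) -> bool:
    """True when there is exactly one ALT allele and both alleles are single bases."""
    bases = {"A", "C", "G", "T"}
    return len(alts) == 1 and ref in bases and alts[0] in bases


# Hypothetical variant records as name -> (REF, [ALT, ...]).
variants = {
    "snp": ("A", ["G"]),
    "multiallelic": ("A", ["G", "T"]),
    "insertion": ("A", ["AT"]),
    "deletion": ("AT", ["A"]),
    "symbolic_sv": ("A", ["<DEL>"]),
}

kept = [name for name, (ref, alts) in variants.items() if is_biallelic_snp(ref, alts)]
print(kept)
```

Only the single-base substitution with one alternate allele survives; indels, multiallelic sites, and symbolic structural variants are dropped.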
[ ]:
pgen = PGEN(
    paths["pgen"], filter=is_snp & is_biallelic & ~pl.col("ID").str.starts_with("sv")
)
bed = gvl.with_length(
    genes.with_columns(chromEnd=pl.col("chromStart")), basenji2_in_len
).with_columns(pl.col("chrom").cast(pl.Utf8))
bed.head()
2025-06-05 10:28:24.258 | INFO | genoray._pgen:_read_index:1164 - Loading genoray index.
| gene_id | chrom | chromStart | hgnc | strand | chromEnd |
|---|---|---|---|---|---|
| str | str | i64 | str | str | i64 |
| "ENSG00000000457" | "1" | 169797872 | "SCYL3" | "-" | 169928944 |
| "ENSG00000001630" | "7" | 91706730 | "CYP51A1" | "-" | 91837802 |
| "ENSG00000002549" | "4" | 17513279 | "LAP3" | "+" | 17644351 |
| "ENSG00000002745" | "7" | 120899885 | "WNT16" | "+" | 121030957 |
| "ENSG00000003056" | "12" | 9037015 | "M6PR" | "-" | 9168087 |
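Comparing this table with the gene table printed earlier suggests that each window is the TSS position extended by half the input length on either side. The arithmetic can be checked with a small sketch; `centered_window` is a hypothetical helper inferred from the printed coordinates, not GenVarLoader's `with_length` itself.

```python
def centered_window(tss: int, length: int) -> tuple[int, int]:
    """Return (start, end) of a window of the given even length centered on tss."""
    half = length // 2
    return tss - half, tss + half


basenji2_in_len = 2**17  # 131,072 bp, as defined earlier in the notebook

# TSS of SCYL3 (ENSG00000000457) from the gene table above.
start, end = centered_window(169863408, basenji2_in_len)
print(start, end, end - start)
```

The resulting start and end agree with the first row of the table (169797872 and 169928944).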
[9]:
n_genes = 100
assert n_genes > 0
assert n_genes <= genes.height
gvl.write(ds_path, bed[:n_genes], variants=pgen, samples=samples, overwrite=True)
2025-06-05 10:29:40.346 | INFO | genvarloader._dataset._write:write:75 - Writing dataset to /cellar/users/dlaub/projects/GenVarLoader/data/geuvadis/ds.gvl
2025-06-05 10:29:40.352 | INFO | genvarloader._dataset._write:write:147 - Using 420 samples.
2025-06-05 10:29:40.352 | INFO | genvarloader._dataset._write:write:153 - Writing genotypes.
2025-06-05 10:29:47.391 | WARNING | genvarloader._dataset._write:_write_from_pgen:397 - A region has no variants for any sample. This could be expected depending on the region lengths and source of variants. However, this can also be caused by a mismatch between the reference genome used for the BED file coordinates and the one used for the variants.
2025-06-05 10:29:50.967 | INFO | genvarloader._dataset._write:write:177 - Finished writing.
[10]:
ref = "/carter/shared/genomes/homo_sapiens/grch37.primary.fa.bgz"
ds = gvl.Dataset.open(ds_path, ref).with_len(basenji2_in_len)
2025-06-05 10:29:51.034 | INFO | genvarloader._dataset._impl:open:191 - Loading reference genome into memory. This typically has a modest memory footprint (a few GB) and greatly improves performance.
2025-06-05 10:29:51.982 | INFO | genvarloader._dataset._reconstruct:from_path:183 - Loading variant data.
2025-06-05 10:29:52.666 | INFO | genvarloader._dataset._impl:open:276 - Opened dataset:
GVL store at /cellar/users/dlaub/projects/GenVarLoader/data/geuvadis/ds.gvl
Is subset: False
# of regions: 100
# of samples: 420
Output length: ragged
Jitter: 0 (max: 0)
Deterministic: True
Sequence type: reference [haplotypes] annotated variants
Active tracks: None
Tracks available: None
This will find the index in Basenji2’s output that corresponds to lymphoblastoid cells.
[11]:
target_info = pl.read_csv(paths["basenji2_targets"], separator="\t")
targets = target_info.filter(pl.col("description").str.contains(r"lymphoblastoid"))[
    "index"
].to_list()
targets
[11]:
[5110]
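The lookup above is a regular-expression match on the `description` column (polars' `str.contains` treats its pattern as a regex by default and is case-sensitive). A dependency-free sketch of the same selection logic follows; the rows are hypothetical stand-ins, and only the column names come from the cell above.

```python
import re

# Hypothetical rows standing in for the target annotation table.
rows = [
    {"index": 0, "description": "DNASE:example cell type"},
    {"index": 1, "description": "CAGE:example lymphoblastoid cell line"},
    {"index": 2, "description": "CAGE:example liver"},
]

pattern = re.compile(r"lymphoblastoid")
toy_targets = [row["index"] for row in rows if pattern.search(row["description"])]
print(toy_targets)
```

Because the match is case-sensitive, a description spelled "Lymphoblastoid" would be missed; a pattern such as `(?i)lymphoblastoid` would be needed to catch it.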
This creates an array on disk and memory-maps it to store the predictions.
[12]:
ploidy = 2
# (r s p t l)
pred_expr = np.memmap(
    pred_path,
    dtype=np.float32,
    mode="w+",
    shape=(ds.n_regions, ds.n_samples, ploidy, len(targets), basenji2_out_len),
)
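For readers unfamiliar with `np.memmap`, here is a small self-contained sketch of the write-then-reopen pattern used in this tutorial. The shape and file name are toy values; only the `mode`, `dtype`, and `shape` arguments mirror the cell above.

```python
import os
import tempfile

import numpy as np

shape = (3, 4, 2, 1, 5)  # toy (regions, samples, ploidy, targets, length)
path = os.path.join(tempfile.mkdtemp(), "toy_pred.npy")

# mode="w+" creates (or overwrites) the file and zero-fills it.
writer = np.memmap(path, dtype=np.float32, mode="w+", shape=shape)
writer[1, 2] = 7.0  # fill one (ploidy, targets, length) block
writer.flush()  # push pending changes to disk
del writer

# The file holds raw bytes only, so dtype and shape must be supplied again.
reader = np.memmap(path, dtype=np.float32, mode="r", shape=shape)
file_size = os.path.getsize(path)
print(file_size, float(reader.sum()))
```

Note that `np.memmap` writes a raw buffer without the `.npy` header, so a file written this way cannot be read back with `np.load`; it has to be reopened with `np.memmap` and the same `dtype` and `shape`.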
[13]:
basenji2 = Basenji2(basenji2_params["model"])
basenji2.load_state_dict(torch.load(basenji2_weights()))
device = "cuda" if torch.cuda.is_available() else "cpu"
basenji2 = basenji2.to(device).eval()
basenji2 = torch.compile(basenji2)
Inference for 100 genes took ~9 minutes on a single A30 GPU, including compilation time.
[14]:
def transform(haps):
    # One-hot encode the haplotypes and move the alphabet axis ahead of length
    haps = sp.DNA.ohe(haps).swapaxes(-2, -1)
    return haps
batch_size = 32
with torch.no_grad(), torch.autocast(device, torch.bfloat16):
    for batch_idx, batch in enumerate(
        tqdm(ds.to_dataloader(batch_size, transform=transform))
    ):
        bsize = len(batch)
        # Flat dataset indices covered by this batch, mapped to (region, sample)
        s = batch_idx * batch_size
        e = s + bsize
        ds_idx = np.arange(s, e)
        r_idx, s_idx = np.unravel_index(ds_idx, ds.shape)
        # Fold ploidy into the batch axis so each haplotype is a separate input
        batch = rearrange(batch, "b p a l -> (b p) a l").to(device, torch.float32)
        # Keep only the selected target track(s), then unfold ploidy again
        pred = basenji2(batch)[..., targets]
        pred = rearrange(pred, "(b p) l t -> b p t l", b=bsize)
        pred_expr[r_idx, s_idx] = pred.numpy(force=True)
W0605 10:30:10.567000 459199 site-packages/torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode
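The loop above recovers region and sample indices from a running batch counter with `np.unravel_index`. This only works if the dataloader yields items in unshuffled, row-major order over `(regions, samples)`, which is an assumption of this sketch rather than something verified here. The helper `batch_indices` is hypothetical; the shape matches the dataset summary printed earlier.

```python
import numpy as np

n_regions, n_samples = 100, 420  # dataset shape reported when it was opened
batch_size = 32


def batch_indices(batch_idx: int, bsize: int):
    """Map the flat indices of one batch to (region, sample) index arrays."""
    start = batch_idx * batch_size
    flat = np.arange(start, start + bsize)
    return np.unravel_index(flat, (n_regions, n_samples))


r0, s0 = batch_indices(0, 32)  # flat 0..31: all in region 0
r13, s13 = batch_indices(13, 32)  # flat 416..447: straddles regions 0 and 1
print(r13.tolist())
print(s13.tolist())
```

Batch 13 shows why the mapping is needed: its first four items are samples 416 to 419 of region 0, and the remaining 28 are samples 0 to 27 of region 1.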
[15]:
ploidy = 2
# (r s p t l)
pred_expr = np.memmap(
    pred_path,
    dtype=np.float32,
    mode="r",
    shape=(ds.n_regions, ds.n_samples, ploidy, len(targets), basenji2_out_len),
)
[16]:
# (g s p t l) -> (g s)
huang_expr = pred_expr[..., basenji2_out_len // 2 - 5 : basenji2_out_len // 2 + 5].mean(
    axis=(2, 3, 4)
)
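The slice above averages the 10 output bins at the center of the prediction, along with both haplotypes and the selected track. To see how much sequence those bins cover, here is a short worked calculation. It assumes a 128 bp bin width and an output window cropped symmetrically from the input; both are properties of the Basenji2 architecture as we understand it and are not defined anywhere in this notebook.

```python
basenji2_in_len = 2**17
basenji2_out_len = 896
bin_width = 128  # assumption: bp per Basenji2 output bin

center = basenji2_out_len // 2
bins = list(range(center - 5, center + 5))  # same bounds as the slice above

# Assumption: the output covers the middle of the input, cropped equally on both sides.
crop = (basenji2_in_len - basenji2_out_len * bin_width) // 2
tss_offset = basenji2_in_len // 2  # the TSS sits at the center of the input window

span_start = crop + bins[0] * bin_width - tss_offset
span_end = crop + (bins[-1] + 1) * bin_width - tss_offset
print(bins[0], bins[-1], span_start, span_end)
```

Under these assumptions, the averaged bins are 443 through 452 and span 640 bp on either side of the TSS.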
[19]:
gene_rho = np.diag(
    st.spearmanr(expr[:n_genes, 1:].to_numpy(), huang_expr, axis=0).statistic,
    ds.n_samples,
)
indiv_rho = np.diag(
    st.spearmanr(expr[:n_genes, 1:].to_numpy(), huang_expr, axis=1).statistic,
    ds.n_regions,
)
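The cell above relies on the block structure of the correlation matrix: when two arrays with `n` variables each are passed together, the result is `2n x 2n`, and the `n`-th offset diagonal holds the correlation of each variable in the first array with its counterpart in the second. The sketch below illustrates this with NumPy only, swapping in rank-then-Pearson for `scipy.stats.spearmanr` (equivalent when there are no ties); the data are random toy values.

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, n_samples = 6, 4
obs = rng.normal(size=(n_genes, n_samples))  # toy observed expression
pred = rng.normal(size=(n_genes, n_samples))  # toy predicted expression


def ranks(x: np.ndarray, axis: int) -> np.ndarray:
    """Rank values along an axis (valid when there are no ties)."""
    return x.argsort(axis=axis).argsort(axis=axis).astype(float)


# Columns as variables (like axis=0): one rho per sample, computed across genes.
full_cols = np.corrcoef(ranks(obs, 0), ranks(pred, 0), rowvar=False)
per_sample = np.diag(full_cols, n_samples)

# Rows as variables (like axis=1): one rho per gene, computed across individuals.
full_rows = np.corrcoef(ranks(obs, 1), ranks(pred, 1))
per_gene = np.diag(full_rows, n_genes)

# Direct column-by-column computation for comparison.
direct = np.array(
    [
        np.corrcoef(ranks(obs, 0)[:, j], ranks(pred, 0)[:, j])[0, 1]
        for j in range(n_samples)
    ]
)
print(np.allclose(per_sample, direct))
```

This is why `gene_rho` has one entry per individual (correlation across genes) and `indiv_rho` has one entry per gene (correlation across individuals).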
Since we’re only using the first 100 genes, the results do not match Huang et al.’s exactly, but they are consistent with them. Using personalized sequences as input, Basenji2 is able to predict differences in expression between genes but not between individuals.
[42]:
fig, ax = plt.subplots()
sns.ecdfplot(gene_rho.ravel(), label=r"$\rho$ across genes", ax=ax)
sns.ecdfplot(indiv_rho.ravel(), label=r"$\rho$ across individuals", ax=ax)
ax.axvline(np.nanmean(gene_rho), color="C0", linestyle="--")
ax.axvline(np.nanmean(indiv_rho), color="C1", linestyle="--")
ax.legend(loc="best")
ax.set(xlabel=r"Spearman $\rho$");