scGPT + ESM2 prior: Replogle K562/Jurkat/HepG2 → RPE1

Produced as part of the sc-interp single-cell model comparison repo.

Provenance

Base model

scGPT whole-human pretrained (Cui et al. 2024), augmented with a frozen ESM2-15B per-gene prior table from arcinstitute/SE-600M. The prior is projected through a learnable linear layer (5120 → 512) and added to the existing learnable gene-token embedding before the transformer body. Otherwise architecturally identical to scGPT: 12 transformer blocks, 8 heads, d_model=512, max_seq_len=1536.
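A minimal sketch of the injection point described above, with hypothetical class and attribute names (the real table covers scGPT's full 60,697-token vocab; a small vocab is used here for illustration):

```python
import torch
import torch.nn as nn

class GeneEmbeddingWithESMPrior(nn.Module):
    """Sketch: learnable gene-token embedding plus a projected, frozen ESM2 prior."""

    def __init__(self, ntoken: int, d_model: int = 512, d_esm: int = 5120):
        super().__init__()
        self.tok_emb = nn.Embedding(ntoken, d_model)  # learnable, as in base scGPT
        # frozen per-gene ESM2-15B prior table, one row per vocab token;
        # registered as a buffer so it receives no gradient updates
        self.register_buffer("esm_prior", torch.randn(ntoken, d_esm))
        self.proj = nn.Linear(d_esm, d_model)  # learnable 5120 -> 512 projection

    def forward(self, gene_ids: torch.Tensor) -> torch.Tensor:
        # projected prior is added to the token embedding before the transformer body
        return self.tok_emb(gene_ids) + self.proj(self.esm_prior[gene_ids])
```

The buffer registration is what keeps the prior frozen while the projection and token embedding train normally.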

Training

Source dataset: arcinstitute/State-Replogle-Filtered (CRISPRi essential-genome screens from Replogle et al. 2022 and Nadig et al. 2025). Training: 362,327 cells from K562 + Jurkat + HepG2 with 1,383 perturbations and 8,569 validation pairs (held-out K562 perturbations). Evaluation: 109,207 RPE1 cells perturbed by the 1,047 genes overlapping the K562 training perturbation set, plus 10,691 real RPE1 controls.

Fine-tuned the scGPT whole-human pretrained checkpoint with the ESM2 gene prior injected at the gene embedding layer. Used --stop-metric pearson_delta (per-perturbation Pearson on Δ-expression) for early stopping and best-checkpoint selection; this metric directly measures perturbation-effect prediction quality, whereas full-expression pearson is dominated by the unchanged-genes baseline.
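The stop metric can be sketched under its plain reading above: correlate predicted and observed deltas relative to the control mean (the exact implementation in the training code may differ, e.g. in gene subsetting):

```python
import numpy as np

def pearson_delta(pred: np.ndarray, truth: np.ndarray, ctrl_mean: np.ndarray) -> float:
    """Pearson correlation of predicted vs. true delta-expression for one perturbation."""
    d_pred = pred - ctrl_mean   # predicted effect relative to unperturbed controls
    d_true = truth - ctrl_mean  # observed effect
    return float(np.corrcoef(d_pred, d_true)[0, 1])

# Full-expression pearson looks high even for a predictor that just copies the
# control mean, because most genes are unchanged; correlating deltas removes
# that baseline, which is why it is the better stopping signal here.
```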

Budget and stopping

Hardware: NVIDIA H100 PCIe (80 GB)
Train batch size: 192
Eval batch size: 192
Max epochs: 30
Early-stop patience: 10
Stop metric: pearson_delta
Epochs trained: 19 (early-stopped)
Best epoch: 9
Best val pearson_delta: 0.1944
Training cells seen: 3,420,399
Wall clock: 254.6 min (~4.25 h)
AMP: fp16
Optimizer: Adam, lr=1e-4, StepLR γ=0.9
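The optimizer settings above translate to roughly the following; the StepLR decay interval is not stated in the card, so step_size=1 (decay every epoch) is an assumption:

```python
import torch

model = torch.nn.Linear(512, 512)  # stand-in for the fine-tuned scGPT body
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# gamma=0.9 as listed above; step_size is an assumption, not from the card
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
scaler = torch.cuda.amp.GradScaler()  # AMP fp16, as listed in the budget table
```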

Test set metrics (cell-eval)

metric mean median max
pearson_delta 0.5080 0.5057 0.6985
pr_auc 0.5304 0.5239 0.9234
roc_auc 0.3878 0.3868 0.4872
overlap_at_N 0.5533 0.5596 0.9260
de_sig_genes_recall 0.6401 0.6540 0.9552
de_direction_match 0.6174 0.6315 0.7602
discrimination_score_l1 0.5037 0.5014 1.0000
mae_delta 0.1650 0.1624 0.2221
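Two of the table's metrics can be sketched under their plain readings; cell-eval's exact definitions (choice of N, tie handling, significance thresholds) are not specified here, so these are illustrative:

```python
import numpy as np

def overlap_at_n(pred_delta: np.ndarray, true_delta: np.ndarray, n: int = 50) -> float:
    """Fraction of the top-n genes by |delta| shared between prediction and truth."""
    top_pred = set(np.argsort(-np.abs(pred_delta))[:n])
    top_true = set(np.argsort(-np.abs(true_delta))[:n])
    return len(top_pred & top_true) / n

def de_direction_match(pred_delta: np.ndarray, true_delta: np.ndarray,
                       de_idx: np.ndarray) -> float:
    """Fraction of significant DE genes whose predicted delta has the true sign."""
    return float(np.mean(np.sign(pred_delta[de_idx]) == np.sign(true_delta[de_idx])))
```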

These cell-eval full-profile numbers are not directly comparable to the scGPT or State paper headline numbers, which report different splits (within-cell-line on Norman, or 4-line leave-one-out via Replogle-Nadig). Companion runs in progress: (a) base scGPT without the ESM prior on the same K562/Jurkat/HepG2 → RPE1 split; (b) crosscoder-style model-diffing on the two fine-tuned models. The headline scientific question, whether the ESM prior improves cross-cell-line transfer in scGPT, needs (a) before this artifact alone can answer it.

Known limitations

  • Single seed; no variance estimates.
  • Test set restricted to RPE1 perturbations that overlap K562 training perturbations (1,047 / 1,499 RPE1 perts); generalization to truly unseen perturbed genes is not assessed by this eval.
  • DE-aware metrics (pr_auc, roc_auc, de_sig_genes_recall, de_direction_match) are computed against K562-derived DE rankings; ideal for an RPE1 test set would be RPE1-derived DE rankings on the truth side.
  • Base scGPT (no ESM) companion run is in progress on separate hardware; the does-ESM-help claim is not yet established by this artifact alone.

Files

  • best_model.pt: fine-tuned weights (~210 MB)
  • args.json: scGPT architecture config (inherited from scGPT_human)
  • vocab.json: scGPT gene → token id mapping (inherited)
  • scgpt_esm_prior.safetensors: frozen ESM2-15B per-gene prior aligned to scGPT's 60,697-token vocab; required at load time, since the model expects ESM-augmented embeddings
  • training_stats.json: epoch count, best metric, wall clock, wandb URL
  • predictions/scgpt_replogle_test.h5ad: .X = predicted expression, .layers['truth'] = ground truth; includes 10,691 real RPE1 controls; 119,898 cells × 6,546 genes
  • eval/agg_results.csv: cell-eval full-profile aggregated stats across 1,047 RPE1 test perturbations
  • eval/results.csv: cell-eval full-profile per-perturbation metrics

Usage

from pathlib import Path
import torch, json
from huggingface_hub import snapshot_download
from scgpt.model import TransformerGenerator
from scgpt.model.gene_priors import GenePriorEncoder
from scgpt.tokenizer.gene_tokenizer import GeneVocab

ckpt = Path(snapshot_download('matthewshu/scgpt-replogle-esm-ft'))

# gene -> token id vocab; make sure the special tokens exist
vocab = GeneVocab.from_file(str(ckpt / 'vocab.json'))
for tok in ('<pad>', '<cls>', '<eoc>'):
    if tok not in vocab:
        vocab.append_token(tok)

with open(ckpt / 'args.json') as f:
    margs = json.load(f)

# frozen ESM2-15B per-gene prior table, projected to d_model inside the model
gene_prior = GenePriorEncoder.from_safetensors(
    ckpt / 'scgpt_esm_prior.safetensors', d_model=margs['embsize']
)
model = TransformerGenerator(
    ntoken=len(vocab), d_model=margs['embsize'], nhead=margs['nheads'],
    d_hid=margs['d_hid'], nlayers=margs['nlayers'], nlayers_cls=3,
    n_cls=1, vocab=vocab, dropout=margs.get('dropout', 0.0),
    pad_token=margs.get('pad_token', '<pad>'),
    pad_value=margs.get('pad_value', 0),
    pert_pad_id=margs.get('pert_pad_id', 2),
    use_fast_transformer=False, gene_prior=gene_prior,
)
state = torch.load(ckpt / 'best_model.pt', map_location='cpu')
model.load_state_dict(state)
model.eval()

Citation

If you use this model, please cite:
