mRNABERT

Weights and tokenizer for mRNABERT (Xiong et al., Nature Communications 2025), loaded with the bug-fixed model code from Taykhoom/MosaicBERT-updated.

mRNABERT is a language model pre-trained on 18 million mRNA sequences. Its pre-training adds a contrastive learning objective that integrates the semantic features of amino acids.

This repo contains only weights and tokenizer files. The model code is loaded automatically from Taykhoom/MosaicBERT-updated via trust_remote_code=True. See that repo for the full list of bugs fixed relative to the original MosaicBERT implementation.

Architecture

mRNABERT uses the MosaicBERT architecture with an mRNA-specific vocabulary.

| Parameter | Value |
| --- | --- |
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| Vocabulary size | 74 (5 special + 5 single-nt + 64 codons) |
| Positional encoding | ALiBi (no position embeddings) |
| Attention | Flash Attention (packed QKV) |
| FFN | Gated Linear Units (GeGLU) |
| Padding | Unpadding (tokens concatenated, no padding overhead) |
| Max sequence length | 1024 tokens |
| Parameters | ~114M |
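
ALiBi replaces learned position embeddings with a per-head linear penalty on token distance that is added to the attention scores. The sketch below illustrates the idea only. It assumes a symmetric (bidirectional) penalty of `-slope * |i - j|` and the power-of-two slope schedule from the ALiBi paper; the helper names are made up here, and the exact slopes the model code uses for 12 heads are not reproduced.

```python
def alibi_slopes(n_heads):
    """Slopes from the ALiBi paper for a power-of-two head count: 2^(-8k/n)."""
    if n_heads < 1 or n_heads & (n_heads - 1):
        raise ValueError("this sketch only covers power-of-two head counts")
    return [2 ** (-8 * k / n_heads) for k in range(1, n_heads + 1)]


def alibi_bias(seq_len, slope):
    """Symmetric bias matrix: 0 on the diagonal, more negative with distance."""
    return [[slope * -abs(i - j) for j in range(seq_len)] for i in range(seq_len)]


slopes = alibi_slopes(8)          # 8 heads chosen for illustration (assumption)
print(slopes[0], slopes[-1])      # 0.5 0.00390625
print(alibi_bias(3, slopes[0]))
# [[0.0, -0.5, -1.0], [-0.5, 0.0, -0.5], [-1.0, -0.5, 0.0]]
```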

Vocabulary

The tokenizer uses BertTokenizer with a hybrid vocabulary. Sequences are encoded in the DNA alphabet (T, not U) even though the model is trained on mRNA.

| ID range | Tokens | Use |
| --- | --- | --- |
| 0-4 | [PAD] [UNK] [CLS] [SEP] [MASK] | Special tokens |
| 5-9 | A T C G N | Single nucleotides (UTR regions) |
| 10-73 | AAA ... GGG | All 64 codons (CDS regions) |
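
The ID layout can be reproduced with a short sketch. The codon order used here (all triplets over A, T, C, G, in that nucleotide order) is an assumption that merely agrees with the AAA ... GGG endpoints listed above; the vocabulary file shipped with the tokenizer is authoritative.

```python
from itertools import product

SPECIAL = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]   # IDs 0-4
NUCLEOTIDES = ["A", "T", "C", "G", "N"]                      # IDs 5-9
# Assumed ordering: every triplet over A, T, C, G (64 codons, none containing N).
CODONS = ["".join(p) for p in product("ATCG", repeat=3)]     # IDs 10-73

vocab = {token: idx for idx, token in enumerate(SPECIAL + NUCLEOTIDES + CODONS)}

print(len(vocab))                    # 74
print(vocab["[MASK]"], vocab["N"])   # 4 9
print(vocab["AAA"], vocab["GGG"])    # 10 73
```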

Pretraining

  • Objective: Masked Language Modeling + contrastive learning (amino-acid semantic features)
  • Data: 18 million curated mRNA sequences
  • Source checkpoint: pytorch_model.bin from YYLY66/mRNABERT

Parity Verification

Hidden states match the original implementation to within a maximum absolute difference of 2.4e-05 at all 13 representation levels (embedding + 12 transformer layers). Both models use flash_attn_varlen_qkvpacked_func, so the small numerical differences reflect Flash Attention rounding rather than a correctness issue. The maximum absolute difference between the SDPA and eager attention paths is 1.81e-05. Verified on GPU with PyTorch 2.7 / CUDA 12.9.
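
A check of this kind can be sketched as follows. `max_abs_diff_per_level` is an illustrative helper, not the script that produced the numbers above, and it assumes both models can return one hidden-state tensor per representation level. The demo uses synthetic tensors so it runs without either model.

```python
import torch


def max_abs_diff_per_level(hidden_a, hidden_b):
    """Return the maximum absolute difference at each representation level.

    hidden_a / hidden_b: sequences of equally shaped tensors, one per level
    (for this model that would be 13: embedding + 12 transformer layers).
    """
    if len(hidden_a) != len(hidden_b):
        raise ValueError("both models must expose the same number of levels")
    return [(a.float() - b.float()).abs().max().item()
            for a, b in zip(hidden_a, hidden_b)]


# Synthetic stand-ins for two sets of hidden states: (batch, seq_len, 768)
torch.manual_seed(0)
reference = [torch.randn(2, 16, 768) for _ in range(13)]
candidate = [h + 1e-5 * torch.randn_like(h) for h in reference]

diffs = max_abs_diff_per_level(reference, candidate)
print(len(diffs))          # 13
print(max(diffs) < 1e-3)   # True for this small synthetic perturbation
```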

Usage

mRNABERT requires CDS-aware preprocessing: UTR regions must be written as space-separated single nucleotides, and CDS regions as space-separated codons. When a CDS track is available, the tokenizer handles this automatically via batch_encode_with_cds(). For simple use cases you can instead pass pre-formatted strings directly.

Sequences use T (not U).
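
To make the expected input format concrete, the sketch below converts a raw sequence plus a per-nucleotide CDS track into the space-separated form. `format_with_cds` is a hypothetical helper written for illustration; it assumes that a nonzero track value marks the first nucleotide of a codon, which is the convention used in the CDS-track example in this README. It is not the implementation behind batch_encode_with_cds(), which may treat edge cases differently.

```python
def format_with_cds(sequence, cds_track):
    """Hypothetical helper: render a raw sequence as mRNABERT-style input text.

    Assumes cds_track[i] != 0 marks the first nucleotide of a codon.
    Nucleotides outside codons are emitted one at a time (UTR style).
    """
    if len(sequence) != len(cds_track):
        raise ValueError("sequence and cds_track must have the same length")

    sequence = sequence.upper().replace("U", "T")  # the model expects T, not U
    tokens, i = [], 0
    while i < len(sequence):
        if cds_track[i] != 0 and i + 3 <= len(sequence):
            tokens.append(sequence[i:i + 3])  # full codon
            i += 3
        else:
            # UTR nucleotide, or a truncated codon at the very end
            tokens.append(sequence[i])
            i += 1
    return " ".join(tokens)


print(format_with_cds("ATCGATGTTTCCC", [0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0]))
# A T C GAT GTT TCC C
print(format_with_cds("AAUGCCC", [0, 1, 0, 0, 1, 0, 0]))
# A ATG CCC
```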

Embedding generation with CDS tracks (recommended)

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model.eval()

# Raw sequences (T not U) + per-nucleotide CDS track
# cds[i] != 0 marks the start of a codon at position i
sequences = ["ATCGATGTTTCCC", "AATGCCC"]
cds_tracks = [
    np.array([0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0]),  # CDS starts at pos 3
    np.array([0, 1, 0, 0, 1, 0, 0]),                      # CDS starts at pos 1
]

enc, chunk_counts = tokenizer.batch_encode_with_cds(
    sequences, cds_tracks, return_tensors="pt", padding=True
)

with torch.no_grad():
    out = model(**enc)

mask = enc["attention_mask"].unsqueeze(-1).float()
mean_emb = (out.last_hidden_state * mask).sum(1) / mask.sum(1)  # (batch, 768)

Embedding generation without CDS tracks

Pass pre-formatted space-separated strings directly when no CDS annotation is available:

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model.eval()

# Space-separated: single nt for UTRs, codons for CDS; use T not U
sequences = [
    "A T C G G A GGG CCC TTT AAA",   # mixed UTR + CDS
    "ATG TTT CCC GAC TAA",            # CDS only
]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)

mask = enc["attention_mask"].unsqueeze(-1).float()
mean_emb = (out.last_hidden_state * mask).sum(1) / mask.sum(1)  # (batch, 768)

MLM logits

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model.eval()

enc = tokenizer(["A T C G [MASK] CCC TTT"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits  # (1, seq_len, 74)
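
To turn the logits into predictions at the masked position, one option is to take the top-k token IDs there. The helper below is a sketch built from standard PyTorch calls; `top_k_at_mask` is a name introduced here. The demo uses random logits and a made-up input (the two codon IDs are arbitrary) so it runs without downloading the model. With the real model, pass `logits`, `enc["input_ids"]`, and `tokenizer.mask_token_id`, then map IDs back to tokens with `tokenizer.convert_ids_to_tokens`.

```python
import torch


def top_k_at_mask(logits, input_ids, mask_token_id, k=5):
    """Return (batch_index, position, top-k token IDs, probabilities) per [MASK]."""
    results = []
    # Coordinates of every masked token as [batch_index, position] pairs
    for b, pos in (input_ids == mask_token_id).nonzero(as_tuple=False).tolist():
        probs = torch.softmax(logits[b, pos], dim=-1)
        top = torch.topk(probs, k)
        results.append((b, pos, top.indices.tolist(), top.values.tolist()))
    return results


# Demo with random logits; ID 4 is [MASK] per the vocabulary table.
torch.manual_seed(0)
# [CLS] A T C G [MASK] <codon> <codon> [SEP]; codon IDs 30 and 40 are arbitrary
input_ids = torch.tensor([[2, 5, 6, 7, 8, 4, 30, 40, 3]])
logits = torch.randn(1, input_ids.shape[1], 74)

for b, pos, ids, probs in top_k_at_mask(logits, input_ids, mask_token_id=4, k=3):
    print(b, pos, ids)
```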

Attention implementation

# SDPA (default on PyTorch >= 2.0)
model = AutoModel.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True,
                                   attn_implementation="sdpa")

# Flash Attention 2 (requires: pip install flash-attn --no-build-isolation)
model = AutoModel.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True,
                                   attn_implementation="flash_attention_2")

Fine-tuning

import torch.nn as nn
from transformers import AutoModel

class mRNABERTClassifier(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
        self.head = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids, attention_mask=attention_mask)
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1)
        return self.head(pooled)
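
A single optimization step for a classifier like this could look like the sketch below. The function name, the cross-entropy loss, and the batch layout are illustrative assumptions, not part of the repo. The demo substitutes a tiny stand-in module with the same forward(input_ids, attention_mask) signature so it runs without downloading weights.

```python
import torch
import torch.nn as nn


def train_step(model, optimizer, input_ids, attention_mask, labels):
    """Run one supervised step and return the scalar loss."""
    model.train()
    optimizer.zero_grad()
    logits = model(input_ids, attention_mask)           # (batch, num_labels)
    loss = nn.functional.cross_entropy(logits, labels)
    loss.backward()
    optimizer.step()
    return loss.item()


class TinyStandIn(nn.Module):
    """Stand-in with the classifier's call signature (embedding + mean pooling)."""

    def __init__(self, vocab_size=74, dim=16, num_labels=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, num_labels)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (self.emb(input_ids) * mask).sum(1) / mask.sum(1)
        return self.head(pooled)


torch.manual_seed(0)
model = TinyStandIn()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
input_ids = torch.randint(5, 74, (4, 10))               # random non-special token IDs
attention_mask = torch.ones(4, 10, dtype=torch.long)
labels = torch.tensor([0, 1, 0, 1])

loss = train_step(model, optimizer, input_ids, attention_mask, labels)
print(loss > 0)
```

To train the real model, swap TinyStandIn() for mRNABERTClassifier(num_labels) and feed it the tokenizer's input_ids and attention_mask.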

Citation

@article{xiong2025_mrnabert,
  title   = {{mRNABERT}: advancing {mRNA} sequence design with a universal language model and comprehensive dataset},
  author  = {Xiong, Ying and Wang, Aowen and Kang, Yu and Shen, Chao and Hsieh, Chang-Yu and Hou, Tingjun},
  journal = {Nature Communications},
  volume  = {16},
  number  = {1},
  pages   = {10371},
  year    = {2025},
  doi     = {10.1038/s41467-025-65340-8}
}

Credits

Original mRNABERT model and weights by Xiong et al. Source: GitHub. Bug-fixed model code by Taykhoom/MosaicBERT-updated, authored primarily by Claude Code and reviewed manually by Taykhoom Dalal.

License

Apache 2.0, following the original repository.
