ERNIE-RNA

ERNIE-RNA is an RNA-specific large language model that incorporates RNA base-pairing potential as a recurrent 2D structural bias into each attention layer, enabling the model to capture secondary structure information during pretraining.

Architecture

| Parameter | Value |
|---|---|
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| FFN dimension | 3072 |
| Vocabulary size | 25 |
| Positional encoding | Sinusoidal (fairseq-style) |
| Architecture | Pre-LN Transformer with recurrent 2D RNA pairing bias |
| Max sequence length | 1024 |
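As a rough sanity check on the table, the back-of-envelope count below estimates the size of the transformer stack. It assumes standard biased Q/K/V/output projections, a two-layer FFN and two LayerNorms per layer. It ignores the twod_proj MLP, any embedding or final LayerNorm, and the MLM head. The sinusoidal positional encoding has no learned weights. `estimate_encoder_params` is an illustrative name, not part of the port.

```python
def estimate_encoder_params(layers=12, d_model=768, d_ffn=3072, vocab=25):
    """Approximate parameter count for a Pre-LN Transformer encoder (illustrative only)."""
    attn = 4 * (d_model * d_model + d_model)      # Q, K, V, output projections with biases
    ffn = 2 * d_model * d_ffn + d_ffn + d_model   # up- and down-projections with biases
    norms = 2 * (2 * d_model)                     # two LayerNorms (weight + bias) per layer
    per_layer = attn + ffn + norms
    embeddings = vocab * d_model                  # token embedding table only
    return per_layer, layers * per_layer + embeddings


per_layer, total = estimate_encoder_params()
print(f"{per_layer:,} params per layer, ~{total / 1e6:.1f}M total")
```

Under these assumptions, the estimate comes to about 7.09M parameters per layer and roughly 85M in total.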

Vocabulary

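| Token | ID | Notes |
|---|---|---|
| `<cls>` | 0 | Prepended to every sequence |
| `<pad>` | 1 | Padding token |
| `<eos>` | 2 | Appended to every sequence |
| `<unk>` | 3 | Unknown token |
| `G` | 4 | |
| `A` | 5 | |
| `U` | 6 | T is silently mapped to U during tokenization |
| `C` | 7 | |
| `N` | 8 | Ambiguous nucleotide |
| `Y`-`I` | 9-20 | IUPAC ambiguity codes |
| `madeupword0`-`madeupword2` | 21-23 | Filler entries inherited from the original vocabulary (distinct from `<pad>`) |
| `<mask>` | 24 | MLM mask token |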
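The toy encoder below shows how the IDs in the table combine into model input. It is a minimal sketch, not the shipped tokenizer. It covers only the special tokens and G/A/U/C/N, and sends everything else (including the IUPAC codes at IDs 9-20, whose exact mapping is not reproduced here) to `<unk>`. It also assumes right padding. `encode` and `pad_batch` are hypothetical names.

```python
# Token IDs copied from the vocabulary table; a toy encoder, not the real tokenizer.
SPECIAL = {"<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3, "<mask>": 24}
NUCLEOTIDES = {"G": 4, "A": 5, "U": 6, "C": 7, "N": 8}


def encode(seq: str) -> list[int]:
    """Wrap a sequence in <cls> ... <eos>, mapping T to U and unknown characters to <unk>."""
    ids = [SPECIAL["<cls>"]]
    for ch in seq:
        ch = "U" if ch == "T" else ch
        ids.append(NUCLEOTIDES.get(ch, SPECIAL["<unk>"]))
    ids.append(SPECIAL["<eos>"])
    return ids


def pad_batch(batch: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]:
    """Right-pad to the longest sequence with <pad>; return (input_ids, attention_mask)."""
    width = max(len(ids) for ids in batch)
    padded = [ids + [SPECIAL["<pad>"]] * (width - len(ids)) for ids in batch]
    mask = [[1] * len(ids) + [0] * (width - len(ids)) for ids in batch]
    return padded, mask
```

For example, `encode("AUGT")` yields `<cls> A U G U <eos>`, because the trailing T is read as U.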

2D RNA Pairing Bias

ERNIE-RNA computes a pairwise RNA base-pairing potential matrix from the input sequence at the start of each forward pass. A 2-layer MLP (1 -> 6 -> H, with GELU) projects this matrix from shape [B, T, T, 1] to [B, H, T, T] (B = batch size, T = sequence length, H = number of attention heads). The result is added to the attention logits of the first layer. Each layer's pre-softmax attention scores then serve as the 2D bias for the next layer, creating a recurrent structural information pathway across all 12 transformer layers.
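The recurrence can be sketched in plain PyTorch. This is a simplified illustration, not the port's code, and it makes several simplifications:

- Each "layer" is a single biased self-attention step with a residual connection. LayerNorm, FFN and the output projection are omitted.
- The weights are random, and a random tensor stands in for the real pairing matrix.
- It assumes the bias carried forward is the full pre-softmax logit tensor, including the previous layer's bias.

`make_pair_projection` and `biased_self_attention` are hypothetical names.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def make_pair_projection(num_heads: int) -> nn.Module:
    """2-layer MLP lifting one pairing score per (i, j) to one bias per head: 1 -> 6 -> H."""
    return nn.Sequential(nn.Linear(1, 6), nn.GELU(), nn.Linear(6, num_heads))


def biased_self_attention(x, w_qkv, bias, num_heads):
    """Multi-head self-attention with an additive [B, H, T, T] bias.

    Returns the attention output and the pre-softmax logits,
    which the next layer reuses as its bias.
    """
    B, T, D = x.shape
    head_dim = D // num_heads
    q, k, v = (x @ w_qkv).split(D, dim=-1)
    # [B, T, D] -> [B, H, T, head_dim]
    q, k, v = (t.view(B, T, num_heads, head_dim).transpose(1, 2) for t in (q, k, v))
    logits = q @ k.transpose(-2, -1) / math.sqrt(head_dim) + bias
    out = F.softmax(logits, dim=-1) @ v
    return out.transpose(1, 2).reshape(B, T, D), logits


torch.manual_seed(0)
B, T, D, H, n_layers = 2, 10, 64, 4, 3
x = torch.randn(B, T, D)
pair = torch.rand(B, T, T, 1)  # stand-in for the base-pairing potential matrix
w_qkvs = [torch.randn(D, 3 * D) * 0.02 for _ in range(n_layers)]

with torch.no_grad():
    # First-layer bias comes from the MLP: [B, T, T, H] -> [B, H, T, T]
    bias = make_pair_projection(H)(pair).permute(0, 3, 1, 2)
    for w_qkv in w_qkvs:
        attn_out, bias = biased_self_attention(x, w_qkv, bias, H)  # logits become next bias
        x = x + attn_out  # residual only
```

Because each layer's logits already contain the previous bias, structural information accumulates through the stack rather than being injected only once.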

Base-pairing scores: A-U = 2.0, G-C = 3.0, G-U wobble = 0.8.
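The sketch below builds a [T, T, 1] pairing matrix from these three scores. It is an assumption-laden illustration, not the original routine, which may weight or constrain pairs differently. It covers nucleotide positions only (no rows for `<cls>`/`<eos>`), applies no minimum loop length, and scores every other combination as 0. `PAIR_SCORES` and `pairing_matrix` are hypothetical names.

```python
import torch

# Pair scores from the model card, entered in both orientations so the matrix is symmetric.
PAIR_SCORES = {
    ("A", "U"): 2.0, ("U", "A"): 2.0,
    ("G", "C"): 3.0, ("C", "G"): 3.0,
    ("G", "U"): 0.8, ("U", "G"): 0.8,
}


def pairing_matrix(seq: str) -> torch.Tensor:
    """Return a [T, T, 1] tensor whose (i, j) entry scores nucleotides i and j as a pair."""
    seq = seq.replace("T", "U")  # mirror the tokenizer's T -> U mapping
    n = len(seq)
    mat = torch.zeros(n, n)
    for i, a in enumerate(seq):
        for j, b in enumerate(seq):
            mat[i, j] = PAIR_SCORES.get((a, b), 0.0)
    return mat.unsqueeze(-1)
```

`pairing_matrix("GAUC").unsqueeze(0)` gives a batch of one with shape [1, 4, 4, 1]. This is the layout the MLP projection expects.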

Pretraining

  • Objective: Masked language modeling (MLM) on RNA sequences
  • Data: RNAcentral (non-redundant RNA sequences)
  • Source checkpoint: ERNIE-RNA_pretrain.pt
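For readers new to the MLM objective, here is a minimal masking sketch. The 15% rate is the common BERT default, not a value stated for ERNIE-RNA. The sketch also simplifies BERT's recipe: every selected position becomes `<mask>`, with no 80/10/10 split. Treating IDs 4-20 as maskable and leaving special tokens untouched are assumptions. `mask_for_mlm` is a hypothetical name.

```python
import torch

MASK_ID = 24
NUCLEOTIDE_IDS = torch.arange(4, 21)  # G, A, U, C, N and IUPAC codes (IDs 4-20 in the table)


def mask_for_mlm(input_ids: torch.Tensor, mask_prob: float = 0.15, generator=None):
    """Replace a random subset of nucleotide tokens with <mask>.

    Returns (masked_ids, labels). Labels hold the original ID at masked positions
    and -100 elsewhere, the index nn.CrossEntropyLoss ignores by default.
    """
    is_nucleotide = torch.isin(input_ids, NUCLEOTIDE_IDS)
    chosen = (torch.rand(input_ids.shape, generator=generator) < mask_prob) & is_nucleotide
    labels = torch.where(chosen, input_ids, torch.full_like(input_ids, -100))
    masked = torch.where(chosen, torch.full_like(input_ids, MASK_ID), input_ids)
    return masked, labels
```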

Checkpoint selection

Single pretrained checkpoint from the original repository. Used as-is; no fine-tuned variants are included in this release.

Parity Verification

Hidden-state representations were verified to match the original implementation (max abs diff = 6.48e-05) at all 13 representation levels (embedding + 12 transformer layers). Verified on GPU with PyTorch 2.7 / CUDA 12.

SDPA implementation additionally verified to match eager output exactly (max diff = 0.00e+00) when no padding is present.
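A check in the spirit of the SDPA-vs-eager comparison can be scripted as follows. This is a sketch, not the script that produced the numbers above. It compares this port's eager and SDPA paths with each other; comparing against the original fairseq model would require that codebase. Passing `attention_mask` excludes padded positions. `max_abs_diff` and `compare_eager_and_sdpa` are hypothetical helper names.

```python
import torch


def max_abs_diff(hidden_a, hidden_b, attention_mask=None) -> float:
    """Largest |a - b| over matching hidden-state tensors, optionally ignoring padding."""
    if len(hidden_a) != len(hidden_b):
        raise ValueError("expected the same number of representation levels")
    worst = 0.0
    for a, b in zip(hidden_a, hidden_b):
        diff = (a.float() - b.float()).abs()  # [B, T, D]
        if attention_mask is not None:
            diff = diff * attention_mask[..., None].to(diff.dtype)  # zero out padded positions
        worst = max(worst, diff.max().item())
    return worst


def compare_eager_and_sdpa(model_id="Taykhoom/ERNIE-RNA", sequences=("AUGCAUGCAUGC",)):
    """Run both attention paths on the same input and report the worst hidden-state gap."""
    from transformers import AutoModel, AutoTokenizer  # imported lazily; needs the Hub

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    enc = tokenizer(list(sequences), return_tensors="pt", padding=True)
    hidden = {}
    for impl in ("eager", "sdpa"):
        model = AutoModel.from_pretrained(
            model_id, trust_remote_code=True, attn_implementation=impl
        ).eval()
        with torch.no_grad():
            hidden[impl] = model(**enc, output_hidden_states=True).hidden_states
    return max_abs_diff(hidden["eager"], hidden["sdpa"], enc["attention_mask"])
```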

Related Models

See the full ERNIE-RNA collection.

| Model | Notes |
|---|---|
| Taykhoom/ERNIE-RNA | Pretrained model (this model) |

Usage

Embedding generation

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model.eval()

sequences = ["AUGCAUGCAUGC", "GGGGCCCCGGGG"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)

cls_emb   = out.last_hidden_state[:, 0, :]   # (batch, 768) -- CLS token
token_emb = out.last_hidden_state             # (batch, seq_len, 768)

# Intermediate layers: hidden_states[0] is the embedding output,
# hidden_states[i] is the output of transformer layer i (13 entries in total)
with torch.no_grad():
    out_all = model(**enc, output_hidden_states=True)
layer6_emb = out_all.hidden_states[6]         # (batch, seq_len, 768)
```
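The card recommends the `<cls>` embedding for sequence-level work. Masked mean pooling is a common alternative for encoder models, though the card does not evaluate it. The helper below is a minimal sketch that excludes padding; note that the average still includes the `<cls>` and `<eos>` positions. `masked_mean_pool` is a hypothetical name.

```python
import torch


def masked_mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over real (non-padding) positions: [B, T, D] -> [B, D]."""
    mask = attention_mask[..., None].to(hidden.dtype)  # [B, T, 1]
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)             # avoid division by zero
    return summed / counts
```

With the example above, `masked_mean_pool(out.last_hidden_state, enc["attention_mask"])` returns a (batch, 768) tensor.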

MLM logits

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model.eval()

enc = tokenizer(["AUG<mask>AUG"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits   # (1, seq_len, 25)
```
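To turn these logits into a nucleotide call, take the argmax at each `<mask>` position, restricted to the four canonical bases. The helper below is a sketch that hard-codes IDs from the vocabulary table, assuming `<mask>` encodes to 24. `tokenizer.mask_token_id` is the more robust way to obtain the mask ID. `predict_masked_bases` is a hypothetical name.

```python
import torch

# ID -> nucleotide for the four canonical bases, from the vocabulary table.
BASES = {4: "G", 5: "A", 6: "U", 7: "C"}
MASK_ID = 24


def predict_masked_bases(logits: torch.Tensor, input_ids: torch.Tensor) -> list[list[str]]:
    """For each sequence, return the most likely canonical base at every <mask> position."""
    base_ids = torch.tensor(sorted(BASES))  # restrict choices to G/A/U/C
    preds = []
    for seq_logits, seq_ids in zip(logits, input_ids):
        positions = (seq_ids == MASK_ID).nonzero(as_tuple=True)[0]
        best = seq_logits[positions][:, base_ids].argmax(dim=-1)  # index into base_ids
        preds.append([BASES[int(base_ids[i])] for i in best])
    return preds
```

With the example above, call it as `predict_masked_bases(logits, enc["input_ids"])`.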

Fine-tuning

Use the CLS token embedding (last_hidden_state[:, 0, :]) as input to a prediction head for sequence-level tasks. For token-level tasks, use last_hidden_state directly.
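Both options can be sketched as small PyTorch heads on top of `last_hidden_state`. The dropout value, label counts and class names are illustrative assumptions, not part of this release.

```python
import torch
import torch.nn as nn


class SequenceHead(nn.Module):
    """Sequence-level head on the <cls> embedding (last_hidden_state[:, 0, :])."""

    def __init__(self, hidden_size: int = 768, num_labels: int = 2, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.dropout(last_hidden_state[:, 0, :]))  # [B, num_labels]


class TokenHead(nn.Module):
    """Token-level head applied to every position of last_hidden_state."""

    def __init__(self, hidden_size: int = 768, num_labels: int = 3):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
        return self.classifier(last_hidden_state)  # [B, T, num_labels]
```

For token-level tasks, position 0 holds `<cls>` and the last real position holds `<eos>`. Drop or mask those positions when aligning predictions with per-nucleotide labels.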

Faster inference with SDPA

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Taykhoom/ERNIE-RNA",
    trust_remote_code=True,
    attn_implementation="sdpa",
)
```

Note: the ERNIE-RNA 2D bias mechanism needs the pre-softmax attention scores, which F.scaled_dot_product_attention does not return. The SDPA path therefore computes the QK logits explicitly to carry the 2D bias to the next layer, and uses F.scaled_dot_product_attention only to produce the attention output. In the parity check, its output matched the eager implementation exactly when no padding was present.
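One plausible way to build such a hybrid is to compute the biased logits by hand and pass the bias to SDPA as an additive float `attn_mask`. This is an assumption about the port's internals, not a copy of them. `hybrid_sdpa` is a hypothetical name. In this sketch the two outputs agree up to floating-point rounding, not necessarily bit-for-bit.

```python
import math

import torch
import torch.nn.functional as F


def hybrid_sdpa(q, k, v, bias):
    """q, k, v: [B, H, T, d]; bias: [B, H, T, T].

    The output comes from SDPA with the bias as an additive float attn_mask.
    The logits are computed explicitly so they can become the next layer's bias.
    """
    logits = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + bias
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    return out, logits


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 10, 16) for _ in range(3))
bias = torch.randn(2, 4, 10, 10)
out_sdpa, logits = hybrid_sdpa(q, k, v, bias)
out_eager = torch.softmax(logits, dim=-1) @ v  # reference eager computation
```

Since the [B, H, T, T] logits are still materialized, this approach keeps the quadratic memory cost of eager attention. Any speedup comes only from the fused softmax-and-value step.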

Implementation Notes

The original ERNIE-RNA codebase uses fairseq and standard scaled dot-product attention (eager). This HF port adds:

  • attn_implementation="sdpa" support via a hybrid approach: pre-softmax logits are computed explicitly to maintain the recurrent 2D bias update, while F.scaled_dot_product_attention is used for the output tensor.
  • attn_implementation="flash_attention_2" support: falls back to eager, because flash attention does not expose pre-softmax scores needed for the recurrent 2D bias update.
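The dispatch rule in the list above can be summarized as a small function. This is a hypothetical sketch of the decision, not the port's actual code path.

```python
import warnings
from typing import Optional


def resolve_attn_implementation(requested: Optional[str]) -> str:
    """Map a requested attn_implementation to a path that can return pre-softmax scores."""
    if requested in (None, "eager"):
        return "eager"
    if requested == "sdpa":
        return "sdpa"  # hybrid path: explicit logits + F.scaled_dot_product_attention
    if requested == "flash_attention_2":
        warnings.warn(
            "flash_attention_2 does not expose pre-softmax scores needed for the "
            "recurrent 2D bias; falling back to eager attention."
        )
        return "eager"
    raise ValueError(f"unsupported attn_implementation: {requested!r}")
```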

The twod_proj MLP is always run in float32 (matching the original) regardless of the model's compute dtype.
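A minimal sketch of one way to enforce this follows. It keeps the projection's weights in float32, disables autocast around the call, and casts only the resulting bias back to the compute dtype. `project_pair_bias_fp32` is a hypothetical name, and the port may achieve the same effect differently. Running this small MLP in float32 plausibly limits rounding error in a bias that is then carried through every layer.

```python
import torch
import torch.nn as nn


def project_pair_bias_fp32(mlp: nn.Module, pair: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    """Run the pairing-bias MLP in float32, then cast its [B, H, T, T] output to out_dtype."""
    mlp.float()  # keep this module's weights in float32 even if the model is bf16/fp16 (idempotent)
    with torch.autocast(device_type=pair.device.type, enabled=False):  # prevent autocast downcasting
        bias = mlp(pair.float())  # [B, T, T, 1] -> [B, T, T, H]
    return bias.permute(0, 3, 1, 2).to(out_dtype)
```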

Citation

@article{yin2025_ernierna,
  title   = {{ERNIE-RNA}: an {RNA} language model with structure-enhanced representations},
  author  = {Yin, Weijie and Zhang, Zhaoyu and He, Liang and Jiang, Rui and Zhang, Shuo and Liu, Gan and Zeng, Xuezhi and Zhao, Wen and Gao, Xiaowo},
  journal = {Nature Communications},
  volume  = {16},
  number  = {1},
  pages   = {8407},
  year    = {2025},
  doi     = {10.1038/s41467-025-64972-0}
}

Credits

Original model and code by Yin et al. Source: GitHub. The HF conversion code was authored primarily by Claude Code and reviewed manually by Taykhoom Dalal.

License

Apache 2.0, following the original repository.
