Eagle3 Long-Context Draft Head for Nemotron-Cascade-2-30B-A3B (Sliding-Window 4k)
This is an Eagle3 speculative-decoding draft head trained against
`nvidia/Nemotron-Cascade-2-30B-A3B` as the verifier. To our knowledge it
is the first Eagle3 head trained against a hybrid Mamba-Transformer MoE
verifier, and the first to explore sliding-window attention on the draft
as a long-context optimization.
TL;DR
- Verifier: `nvidia/Nemotron-Cascade-2-30B-A3B` (30B-param hybrid Mamba-Transformer MoE, 52 layers = 23 Mamba + 23 MLP/MoE + 6 GQA attention, per the verifier's `hybrid_override_pattern` config).
- Draft architecture: 1-layer Llama transformer block with `hidden_size=2688` (matches verifier residual stream), `intermediate_size=8064` (~3x hidden), `head_dim=128`, `num_attention_heads=32`, `num_key_value_heads=2` (GQA), `sliding_window=4096` (band causal), `max_position_embeddings=262144` (matches verifier so vLLM does not clamp serving `max_model_len`), `vocab_size=131072`, `draft_vocab_size=32000`.
- Aux hidden state layers captured from verifier: layers 2 / 26 / 48 (~4% / 51% / 94% depth). See the "Layer indexing" section below for the exact mapping and the engine-portability caveat. Layout follows NVIDIA's gpt-oss-120b long-context Eagle3 reference.
Layer indexing -- which verifier layers feed the Eagle3 draft
Eagle3 draft heads consume hidden states from THREE specific layers of the verifier (early / middle / late) and learn to map them into the draft's prediction. The choice of which 3 layers matters: too early and the draft has no semantics; too late and it's just predicting from the verifier's own next-token distribution.
This draft is trained against verifier layers 2 / 26 / 48 out of the verifier's 52 hybrid layers. The indexing convention is critical and engine-dependent.
Verifier layer pattern (canonical, from HF config.json)
NemotronH publishes its hybrid layer arrangement as
`hybrid_override_pattern` in the verifier's `config.json`. For
`nvidia/Nemotron-Cascade-2-30B-A3B`:
```
hybrid_override_pattern = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME"
```
where each character corresponds to one transformer block in the 52-layer stack (0-indexed):
- `M` = Mamba (state-space) layer
- `E` = MLP / MoE feed-forward layer
- `*` = GQA attention layer
Counts: 23 M + 23 E + 6 * = 52.
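The counts and the attention-layer positions can be checked mechanically from the pattern string (a quick sanity-check sketch, not part of the training code):

```python
# Sanity-check the hybrid_override_pattern quoted from the verifier's config.json.
from collections import Counter

pattern = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME"

counts = Counter(pattern)
print(len(pattern), dict(counts))            # 52 blocks: 23 M, 23 E, 6 *
attn_layers = [i for i, c in enumerate(pattern) if c == "*"]
print(attn_layers)                           # [5, 12, 19, 26, 33, 42]
print(pattern[2], pattern[26], pattern[48])  # M * M  (the aux-layer triad)
```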
What layers 2 / 26 / 48 actually are
Index counting all 52 transformer blocks starting from 0
(matches HF transformers `model.backbone.layers[i]` ordering):
| Index | Char | Type | Depth | Notes |
|---|---|---|---|---|
| 2 | `M` | Mamba | ~4% | Early Mamba (3rd block overall, 2nd Mamba block) |
| 26 | `*` | Attention | ~51% | Middle attention -- the 4th of 6 attention layers, central position |
| 48 | `M` | Mamba | ~94% | Late Mamba (4th-from-last block) |
The 6 attention layers are at canonical indices [5, 12, 19, 26, 33, 42].
Layer 26 is the central one.
Engine-portability warning
Different inference engines may number layers differently because of how they fuse, group, or skip blocks. Before serving this draft on a new engine, verify that the engine's "layer i" corresponds to the same character in the hybrid pattern as HF transformers does.
- vLLM (0.19+): uses the full 52-layer `ModuleList` in its NemotronH model implementation, 0-indexed contiguous, matching HF transformers. The `target_hidden_state_indices: [2, 26, 48]` in this draft's `config.json` is read directly by vLLM's Eagle3 speculative decoder and passed to the verifier as raw indices into `model.backbone.layers`. Tested working on vLLM 0.19 with the in-flight L=32k iteration.
- SGLang: SGLang's NemotronH implementation also uses 0-indexed contiguous layer counting that matches HF, but as of late 2025 it does NOT yet support sliding-window attention on Eagle3 drafts (this draft has `sliding_window: 4096` in its config). Once SGLang ships sliding-window support, the layer indices should map cleanly. Until then, prefer vLLM for serving this draft.
- HF transformers (raw, no engine): the verifier exposes its layers as `model.backbone.layers[0..51]`, where `layers[i]` is the i-th character of `hybrid_override_pattern`. This is the canonical reference. The training pipeline in this fork (SpecForge with the NemotronH backend patch) reads from these indices via the `aux_hidden_state_layers` config field on the draft.
- TensorRT-LLM / others: NOT tested. If you port to one of these, cross-check by running a single forward pass on a fixed input and comparing the hidden state dump at indices 2/26/48 against vLLM's dump for the same input. The hidden states should be bit-equivalent (or near-equivalent up to dtype) at the same indices if the numbering convention matches.
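As an illustration of the cross-check mechanism (a toy stand-in, NOT the real verifier), forward hooks on a `ModuleList` capture hidden states at the aux indices the same way you would hook `model.backbone.layers[i]` on the HF model:

```python
# Toy illustration: capture hidden states at layer indices 2 / 26 / 48 via
# forward hooks. On the real verifier you would hook model.backbone.layers[i]
# instead of this dummy 52-layer stack.
import torch
import torch.nn as nn

class ToyBackbone(nn.Module):
    def __init__(self, n_layers=52, hidden=8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = ToyBackbone()
captured = {}
aux_indices = [2, 26, 48]  # same indices as target_hidden_state_indices
hooks = [
    model.layers[i].register_forward_hook(
        lambda mod, inp, out, i=i: captured.__setitem__(i, out.detach())
    )
    for i in aux_indices
]
_ = model(torch.randn(1, 4, 8))
for h in hooks:
    h.remove()
print(sorted(captured))  # [2, 26, 48]
```

Dumping these tensors to disk for a fixed input on two engines and diffing them is the portability check described above.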
Why these specific indices
These three positions follow NVIDIA's gpt-oss-120b long-context Eagle3 recipe (early/mid/late triad). The choice of 2 / 26 / 48 as opposed to e.g. 0 / 26 / 51 was made to:
- Avoid layer 0 (the embedding-adjacent layer), where hidden states are still mostly raw token embeddings without much contextualization.
- Skip the very last layer (51), where hidden states are essentially the verifier's own pre-lm_head representation -- the draft would learn a near-identity from there to the verifier's output and not generalize.
- Pick a Mamba+Attention+Mamba triad rather than three of the same type, so the draft sees a mix of state-space (long-range, Mamba) and attention (short-range, GQA) representations of the same input.
- Trainable parameters: ~196 M (excludes the frozen embedding layer loaded from the verifier).
- Checkpoint: `epoch_2_step_12000` (epoch 2, step 12000).
Training schedule: 2-stage (SFT bootstrap + on-policy specialization)
This draft head is trained in two stages to first build a general-language foundation and then specialize to the on-policy reasoning trace distribution that the verifier will actually emit at inference time.
| Stage | Source | Rows | LR | Epochs | Bucketing | Notes |
|---|---|---|---|---|---|---|
| Stage 1 (SFT bootstrap) | `cascade2_sft_train.jsonl` | 20,000 | 1e-4 | 1 | on (length-bucketed sampler, stage 1 only) | Broad SFT distribution; long-tail length variance, so length-bucketed sampling kills the straggler-rank bottleneck on DP=2 |
| Stage 2 (on-policy fine-tune) | `c2_traces_train.jsonl` | 9,658 | 5e-5 | 2-3 | off | Loaded from stage 1 checkpoint via `--ckpt-dir` (weight-only, fresh optimizer). Bucketing intentionally disabled to maximize gradient diversity per step on the narrower on-policy distribution we're specializing to. |
The CoT trace set (`c2_traces_cot_train.jsonl`, 0.9k rows) is excluded
from this iteration because it has significantly higher mean tokens-per-row
(50k) than the plain traces (21k), which would slow stage 2 wall-clock for
marginal benefit on the first baseline. It will be folded back in for a
follow-up stage 2 fine-tune if the deployed draft underperforms on
CoT-heavy reasoning.
Key training details
| Setting | Value |
|---|---|
| Training framework | SpecForge fork at https://github.com/hav4ik/SpecForge-Nemotron3, branch `nemotron-cascade-2-experiments` |
| Git commit | f517199ed2db |
| Training date | 2026-04-09 UTC |
| Max sequence length | 32768 tokens (this iteration is the L=32768 fast-iteration baseline; an L=65536 follow-up is planned once this baseline ships -- see "Iteration plan" section below) |
| Truncation loss at this max_length | Stage 1: ~26% of tokens lost to L=32768 truncation (avg tok/row 18904 → 13898). Stage 2: ~15% of tokens lost (avg 21468 → 18179). See "Training data accounting" section below for the per-stage table. |
| TTT length | 6 (Eagle3 test-time-training unroll depth) |
| Sliding window | 4096 tokens on the draft attention |
| Optimizer | BF16Optimizer, lr=5e-05, warmup_ratio=0.02 |
| Epochs (this stage) | 3 |
| Per-rank batch size | 1 (no sequence packing) |
| Tensor parallel | TP=1, DP=2 |
| Global effective batch per step | 2 conversations |
Training data accounting (the full token-count picture)
Three different but equally valid token counts exist for this dataset, and they don't agree because they measure different things. Documented here to prevent future confusion:
| Metric | Stage 1 (cascade2_sft_train) | Stage 2 (c2_traces_train) | Source |
|---|---|---|---|
| Rows / conversations | 20,000 | 9,658 | dataset card + jsonl line count |
| Untruncated total tokens (1) | 378,073,468 | 207,336,404 | dataset card (`apply_chat_template(messages, tokenize=True)`) |
| Truncated total tokens at L=32768 (2) | 277,963,707 | 175,570,455 | this fork's tokenizer pipeline w/ max_length=32768 |
| Truncation loss vs (1) | 100 M / 26.5% | 32 M / 15.3% | (1) - (2) |
| Loss-masked tokens (assistant turns only, what loss is computed on) (3) | 132,941,808 | 158,793,164 | this fork, loss_mask == 1 positions in the tokenized sequence |
| Loss-mask fraction of (2) | 47.8% | 90.4% | (3) / (2) |
| Avg untruncated tok/row | 18,904 | 21,468 | dataset card |
| Avg truncated tok/row | 13,898 | 18,179 | (2) / rows |
| Avg loss-masked tok/row | 6,647 | 16,442 | (3) / rows |
| jsonl size on disk | 1521 MiB | 660 MiB | filesystem |
The three counts measure:
- Untruncated total -- every token in the conversation including system prompt + user turns + assistant turns + special tokens, with no max-length cap. This is what the dataset card publishes.
- Truncated total -- same thing but capped at the training `max_length`. Any sample longer than the cap loses its tail. Stage 1 loses ~26% of total tokens at L=32768 because the average SFT conversation is ~19k tokens and many exceed 32k. Stage 2 loses ~15%. This is the main reason an L=65536 follow-up is planned.
- Loss-masked -- the subset of (2) where `loss_mask == 1`, i.e. only the assistant-turn positions. These are the only positions Eagle3's distillation loss is computed on. Stage 2's 90% loss-mask fraction reflects that reasoning traces are mostly the model's output (short user prompt, very long assistant trace). Stage 1's 48% reflects more balanced SFT dialogue.
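The derived columns of the table follow directly from the raw counts; a quick arithmetic sketch:

```python
# Re-derive the truncation-loss, loss-mask-fraction, and per-row averages
# in the accounting table from the raw token totals.
rows = {
    #          rows,   untruncated,   truncated@32k, loss-masked
    "stage1": (20_000, 378_073_468, 277_963_707, 132_941_808),
    "stage2": (9_658,  207_336_404, 175_570_455, 158_793_164),
}
for name, (n, untrunc, trunc, masked) in rows.items():
    print(name,
          f"trunc loss {100 * (1 - trunc / untrunc):.1f}%",   # 26.5% / 15.3%
          f"mask frac {100 * masked / trunc:.1f}%",           # 47.8% / 90.4%
          f"avg untrunc {untrunc / n:,.0f}",                  # 18,904 / 21,468
          f"avg masked {masked / n:,.0f}")                    # 6,647 / 16,442
```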
Eagle3 vocabulary pruning (the union d2t / t2d mapping)
Eagle3's draft head predicts logits over a smaller draft vocabulary
than the verifier's full vocab to keep the lm_head shape manageable.
Concretely: target_vocab_size=131072 → draft_vocab_size=32000. The
mapping (d2t and t2d buffers on the draft model) is the top-32000
most frequent target tokens in the training set, computed from
loss-masked positions.
Multi-stage gotcha (fixed in this fork): the standard SpecForge
pipeline auto-generates the d2t / t2d from whatever training set the
current run is processing, and unconditionally calls
draft_model.load_vocab_mapping(...) AFTER loading the checkpoint passed
via --ckpt-dir. For 2-stage training this would mean stage 2 OVERWRITES
the d2t / t2d buffers loaded from the stage 1 checkpoint with a NEW
mapping derived from stage 2's train set. The lm_head weights loaded from
stage 1 would still be aligned to stage 1's mapping → silent index
permutation mismatch in every gradient step of stage 2.
Fix: a union vocab mapping
built from the combined token frequencies of stage 1 + stage 2,
saved once before training, and passed to BOTH stage launchers via
--vocab-mapping-path. This way the lm_head is index-aligned end to end.
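A minimal sketch of the union-mapping construction (toy vocabulary sizes and made-up frequencies; the frequency-merge and index-map logic is the point, not the real SpecForge implementation):

```python
# Sketch: merge per-stage loss-masked token frequencies, keep the global
# top-K, and build the draft->target (d2t) and target->draft (t2d) maps.
# Toy K; the real mapping is 131072 target ids -> 32000 draft ids.
from collections import Counter

def build_union_mapping(stage_freqs, draft_vocab_size, target_vocab_size):
    union = Counter()
    for freq in stage_freqs:
        union.update(freq)
    # Most frequent target ids across ALL stages; ties broken by token id
    # so the mapping is deterministic across runs.
    top = sorted(union, key=lambda t: (-union[t], t))[:draft_vocab_size]
    d2t = sorted(top)                # draft index -> target token id
    t2d = [-1] * target_vocab_size   # target id -> draft index (-1 = pruned)
    for d, t in enumerate(d2t):
        t2d[t] = d
    return d2t, t2d

stage1 = Counter({0: 50, 3: 40, 7: 10})   # hypothetical stage 1 frequencies
stage2 = Counter({3: 5, 9: 60})           # hypothetical stage 2 frequencies
d2t, t2d = build_union_mapping([stage1, stage2], draft_vocab_size=3, target_vocab_size=10)
print(d2t)  # [0, 3, 9]
print(t2d)  # [0, -1, -1, 1, -1, -1, -1, -1, -1, 2]
```

Because both stages share this one mapping, the lm_head trained in stage 1 stays index-aligned through stage 2, which is exactly the overwrite bug the fix avoids.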
Per-stage coverage with the union top-32000:
| Stage | Total loss-masked tokens | Unique token ids | Coverage by union top-32k | Tokens lost (out of vocab) | Unique-in-vocab |
|---|---|---|---|---|---|
| Stage 1 | 132,941,808 | 81,439 | 99.5328% | 621,131 (0.47%) | 31,978 / 81,439 (39.27%) |
| Stage 2 | 158,793,164 | 36,510 | 99.9621% | 60,180 (0.04%) | 26,017 / 36,510 (71.26%) |
| Union | 291,734,972 | 83,209 | 99.77% | 671,311 (0.23%) | 32,000 / 83,209 (38.46%) |
Stage 2 has better coverage than stage 1 even though the union top-K was built jointly -- because stage 2's reasoning-trace distribution is much narrower (~half the unique tokens of stage 1 despite ~20% more total tokens). Stage 2 contributes 1,770 unique tokens not present in stage 1's loss-masked positions (mostly latex / reasoning markers); these would have been silently dropped from the draft vocab if the mapping had been built from stage 1 only.
Both stages comfortably exceed 99% frequency coverage with draft_vocab_size=32000,
so the chosen draft vocab is big enough -- no need to grow it.
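The coverage percentages follow from the loss-masked totals and the out-of-vocab counts:

```python
# Re-derive the per-stage coverage column: 1 - (out-of-vocab / total loss-masked).
rows = {
    "stage1": (132_941_808, 621_131),
    "stage2": (158_793_164, 60_180),
    "union":  (291_734_972, 671_311),
}
for name, (total, oov) in rows.items():
    print(name, f"{100 * (1 - oov / total):.4f}%")  # 99.5328 / 99.9621 / 99.7699
```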
Iteration plan: this is the L=32768 fast-iteration baseline
This checkpoint is the first Eagle3 head trained against Nemotron-Cascade-2 in this fork's iteration sequence:
| Iteration | max_length | Status | Goal |
|---|---|---|---|
| L=32768 (this) | 32768 | in flight / shipped | Fast-iteration baseline. Get a working draft with all the engineering knobs validated end-to-end. Pays a ~26% truncation loss on stage 1 and ~15% on stage 2. |
| L=65536 (next) | 65536 | planned | Long-context follow-up. Re-train both stages at the higher cap to recover the truncated tail. Will need --draft-mlp-grad-checkpoint re-enabled to fit on a 96 GB GPU per rank. |
The per-length eval data (16k / 32k / 64k validation sets, see below) lets us measure how well the L=32k draft generalizes to longer contexts without training at long context. If the L=32k → L=64k generalization gap on the eval set is small, the L=65k follow-up may not be worth re-training. If the gap is large, the follow-up is justified.
Note on batch size + length-bucketed sampling
This draft is trained with per-rank batch_size=1 and no sequence
packing. With dp_size=2 data-parallel ranks the global
effective batch is 2 conversations per optimizer step.
This is the standard configuration for long-context Eagle3 training
(NVIDIA's gpt-oss-120b long-context Eagle3 uses the same per-rank
batch=1 layout) -- packing into fixed-token-budget batches gives
diminishing returns at long context because each conversation already
saturates the GPU activation budget.
Per-step token counts vary by ~20x because conversation lengths in the training distribution range from ~30 to ~30000+ tokens. Without length-aware sampling, every optimizer step is bottlenecked by the slowest rank: in a (2k, 32k) DP-pair, the rank holding the 2k sample sits idle for most of the step while the other rank chews through 32k tokens, wasting half the GPU.
Stage 1 fixes this with length-bucketed sampling
(`--with-data-bucketing`, see
`specforge/data/utils.py:LengthBucketDistributedSampler`):
the sampler sorts the dataset by per-sample length, partitions the
sorted index into contiguous "global batches" of size `num_replicas * batch_size`, shuffles only the order of those global batches per
epoch with a deterministic seed, and assigns each rank its slice of
every global batch. Effect: all ranks within a single optimizer step
see samples of similar length, eliminating the straggler-rank waste.
Bias mitigation: pure length-sorted iteration would bias gradient
updates by difficulty across an epoch (long samples first or last). The
per-epoch global-batch-order shuffle mitigates this; the intra-batch
order is intentionally NOT shuffled so each step still sees bucketed
lengths. Same tradeoff as HF Trainer's `group_by_length=True`.
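The scheme can be sketched in a few lines (a simplified illustration, not the fork's actual `LengthBucketDistributedSampler`):

```python
# Sketch of length-bucketed distributed sampling: sort by length, cut into
# global batches of num_replicas * batch_size, shuffle only the batch order
# (deterministic seed, same on every rank), then give each rank its slice.
import random

def bucketed_indices(lengths, num_replicas, batch_size, rank, seed=0):
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    gb = num_replicas * batch_size
    batches = [order[i:i + gb] for i in range(0, len(order), gb)]
    random.Random(seed).shuffle(batches)  # per-epoch global-batch shuffle
    picked = []
    for b in batches:
        picked.extend(b[rank * batch_size:(rank + 1) * batch_size])
    return picked

lengths = [30, 32000, 2000, 31000, 100, 2100]  # hypothetical sample lengths
r0 = bucketed_indices(lengths, num_replicas=2, batch_size=1, rank=0)
r1 = bucketed_indices(lengths, num_replicas=2, batch_size=1, rank=1)
for i, j in zip(r0, r1):
    print(lengths[i], lengths[j])  # each step pairs similar lengths across ranks
```

Each optimizer step pairs (30, 100), (2000, 2100), (31000, 32000) in some shuffled order, so neither rank ever waits on a far longer sample.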
Stage 2 leaves bucketing OFF intentionally. The on-policy reasoning trace distribution is narrow enough that gradient diversity per step matters more than throughput, and the per-step difficulty drift would work against the "specialize precisely" goal of stage 2.
| Wandb run | https://wandb.ai/hav4ik/nemotron-cascade-2-eagle3/runs/yz5vd3qw |
⚠️ CRITICAL: Mamba SSM state precision
This draft was trained with the Mamba SSM scan boundary state in
float32, matching vLLM 0.19's default behavior for NemotronH
(`--mamba_ssm_cache_dtype float32`).
The Nemotron team has explicitly confirmed (HF discussion) that downcasting the Mamba SSM state to bf16 causes a ~10% absolute regression on AIME-class math benchmarks (88.3% vs 99.17% on AIME 2025 was the SGLang-vs-vLLM gap they measured from this single issue alone).
Both training and inference must use fp32 SSM state to avoid this precision regression. Concretely:
- Training: the SpecForge fork includes a monkey-patch (`specforge/_mamba_fp32_patch.py`) that forces upstream `mamba_ssm`'s `_state_passing_fwd` to use `out_dtype=torch.float32`. The patch is applied at the very top of `scripts/train_eagle3.py` before any other module-level imports. It is bit-equivalent to vLLM 0.19's NemotronH default behavior.
- Inference (vLLM 0.19+): works automatically for NemotronH-Cascade-2 because the verifier's `config.json` ships with `mamba_ssm_cache_dtype: "float32"`, which triggers vLLM's `NemotronHForCausalLMConfig.verify_and_update_config` to set the SSM cache to fp32. No explicit flag needed when serving NemotronH-Cascade-2 with vLLM 0.19+.
- Inference (SGLang): pass `--mamba-ssm-dtype float32` (the SGLang flag is not on by default, per the Nemotron team's recommendation).
- Inference (other frameworks / older vLLM): explicitly pass `--mamba_ssm_cache_dtype float32` (vLLM) or the equivalent.
For the full audit story (call chain, the exact line numbers in
upstream mamba_ssm 2.3.1 where the bf16 downcast happens, vLLM's
4-layer override chain, etc.), see the
HANDOFF.md "Mamba SSM precision -- the full intricacies" section.
Inference: serving with vLLM
```shell
vllm serve nvidia/Nemotron-Cascade-2-30B-A3B \
  --speculative-config '{
    "model": "<path-to-this-checkpoint-or-hf-repo>",
    "method": "eagle3",
    "num_speculative_tokens": 5
  }' \
  --trust-remote-code \
  --max-model-len 262144 \
  --tensor-parallel-size 2
```
vLLM 0.19+ will automatically:
- Load the verifier with fp32 Mamba SSM state (auto-read from the verifier's config.json -- no explicit flag needed)
- Apply sliding-window attention to the draft using the `sliding_window: 4096` and `layer_types: ["sliding_attention"]` fields in this checkpoint's config.json. The Eagle3-aware `target_layer_count` offset is auto-computed from the verifier (no need to set it manually).
- Bound the draft KV cache to the 4096-token window via vLLM's `SlidingWindowSpec` in the hybrid KV cache manager.
Inference: serving with SGLang
SGLang does not yet support sliding-window attention on Eagle3 drafts (as of late 2025 -- check the SGLang changelog for status). When SGLang adds support, the equivalent invocation will be:
```shell
python -m sglang.launch_server \
  --model nvidia/Nemotron-Cascade-2-30B-A3B \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path <path-to-this-checkpoint-or-hf-repo> \
  --speculative-num-steps 6 \
  --speculative-eagle-topk 4 \
  --speculative-num-draft-tokens 16 \
  --mamba-ssm-dtype float32 \
  --trust-remote-code \
  --tp 2
```
Until SGLang ships sliding-window support for Eagle3 drafts, prefer serving this draft via vLLM 0.19+.
Sliding-window attention -- the experimental hypothesis
The draft is a 1-layer Llama transformer; without windowing its self-attention is O(L^2) per query position. The verifier (Nemotron-Cascade-2) has only 6 attention layers among its 52 blocks (the other 46 are Mamba and MoE feed-forward layers), so the verifier's compute scales gracefully with context length, but the Eagle3 draft becomes the long-context bottleneck.
Hypothesis: most reasoning-trace tokens are local-pattern continuations -- the draft only needs to predict locally-coherent next tokens; the verifier rejects any draft proposal that diverges globally. So the draft should be able to do well with a 4096-token sliding window, collapsing its attention cost from O(L^2) to O(L*window) and freeing us to push training context length up.
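At toy scale, the band-causal mask and its cost look like this (illustrative sketch; the real draft uses a 4096-token window):

```python
# Sketch of the band-causal (sliding-window) attention pattern: each query
# position i attends only to positions j with j <= i and i - j < window,
# so attended pairs grow as O(L * window) instead of O(L^2).
def band_causal_mask(L, window):
    return [[1 if 0 <= i - j < window else 0 for j in range(L)] for i in range(L)]

L, window = 8, 3
mask = band_causal_mask(L, window)
full_causal = sum(range(1, L + 1))       # full causal pairs: L*(L+1)/2
banded = sum(sum(row) for row in mask)   # banded pairs, at most L*window
print(full_causal, banded)               # 36 21
```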
This checkpoint is the first published Eagle3 head built around this hypothesis. Acceptance rates / downstream speedup numbers (when available) will be added to a follow-up section here.
Memory engineering required to train at long context
Training a 1-layer Llama-style draft against a 30B-param frozen
Mamba-Transformer verifier at long context (L=32k or L=65k) with
ttt_length=6 does not fit on a 96 GB GPU per rank without
significant engineering. The SpecForge fork that produced this
checkpoint adds the following training-time optimizations:
- Chunked MLP forward (`--draft-mlp-chunk-size 4096`, active for this checkpoint): Liger-style per-chunk MLP forward over the seq dim, drops the per-step transient peak from ~4 GiB to ~800 MiB per MLP forward. Cheap defensive memory saving with negligible compute overhead.
- Chunked fused linear + soft-target CE (`--fused-linear-loss --fused-linear-loss-chunk-size 4096`, active for this checkpoint): chunks the lm_head + KL-distillation loss over the seq dim with per-chunk grad checkpointing, never materializes the full `[B, T, V]` logits tensor. Saves ~14 GiB at L=32k and ~28 GiB at L=65k. Mathematically equivalent to the unchunked path (numerical equivalence verified in the SpecForge fork's `specforge/core/loss.py` `__main__` block).
- Sliding-window attention on the draft (`sliding_window=4096`, active for this checkpoint): collapses draft self-attention from O(L²) to O(L * window). See the "Sliding-window attention" section.
- Mamba SSM fp32 monkey-patch (always on): see the Mamba SSM section.
- Patch to free verifier full-vocab logits before draft TTT unrolling so the ~16 GiB `[L, V]` tensor doesn't stay alive across the backward pass.
- In-place padding for the per-step shift operations on big logit tensors.
Disabled at L=32k, re-enabled at L=65k:
- Grad-checkpoint on the draft MLP (`--draft-mlp-grad-checkpoint`, OFF for this L=32k checkpoint, will be ON for the L=65k follow-up): recomputes the `gate_proj` / `silu` / `up_proj` / `(gate*up)` intermediates during backward, dropping them from the saved-for-backward set across TTT unrolls. Required at L=65k to fit on 96 GB but costs ~25-30% on the draft forward, so we leave it off at L=32k where we have memory headroom and want maximum step throughput.
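The chunked fused-linear-loss idea can be sketched at toy scale (hard-label cross-entropy stands in for the fork's KL-distillation loss, and no grad checkpointing is shown; the point is that only one `[chunk, V]` logits slab is alive at a time):

```python
# Sketch: compute lm_head logits + loss one sequence-chunk at a time so the
# full [T, V] logits tensor is never materialized. Toy sizes.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, H, V, chunk = 64, 16, 100, 16
hidden = torch.randn(T, H)
lm_head = torch.randn(V, H)
targets = torch.randint(0, V, (T,))

# Unchunked reference: full [T, V] logits in memory at once.
ref = F.cross_entropy(hidden @ lm_head.T, targets, reduction="sum")

# Chunked: only [chunk, V] logits alive at a time; sums to the same loss.
chunked = sum(
    F.cross_entropy(hidden[s:s + chunk] @ lm_head.T, targets[s:s + chunk],
                    reduction="sum")
    for s in range(0, T, chunk)
)
print(torch.allclose(ref, chunked, atol=1e-5))  # True
```

With `reduction="sum"` the chunked total is mathematically identical to the unchunked loss, which is the equivalence property the fork verifies in its own test block.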
See `experiments/nemotron-cascade-2/README.md` for the complete memory-math reference and the per-commit engineering tour.
License and attribution
- Verifier model (`nvidia/Nemotron-Cascade-2-30B-A3B`): refer to the verifier's HF page for its license. This draft head is derivative in the sense that it was trained against the verifier's outputs.
- This draft head: Apache 2.0 (matching the SpecForge license).
- Training framework: SpecForge fork at
https://github.com/hav4ik/SpecForge-Nemotron3, derived from
https://github.com/sgl-project/SpecForge with additional engineering
patches for hybrid Mamba-Transformer verifiers (see branch
nemotron-cascade-2-experiments).