---
license: apache-2.0
language:
- en
tags:
- diffusion
- speculative-decoding
- rectified-flow
- dit
- qwen
- math-reasoning
- deltanet
datasets:
- AI-MO/NuminaMath-CoT
base_model:
- Qwen/Qwen3.5-9B
---
# Continuous Latent Speculative Decoding (CLSD)

- **Architecture:** ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier
- **Key innovation:** First hybrid DeltaNet/Attention causal diffusion transformer for parallel token generation
- **Status:** Stage A converged, Stage C alignment in progress
## How it was envisaged
A machine for constructing candidate inner-machines that bias the verifier toward solution paths its default rollout would miss: a learned search policy that explores verifier-legible continuations and unlocks competence that ordinary autoregressive rollout fails to reach.
## What is cool
The base verifier has fixed weights, but its inference process is not exhausted by ordinary left-to-right decoding. A learned continuous proposer can search for hidden-state trajectories and token paths that the verifier can recognize as correct, even if the verifier would rarely or never reach them under standard autoregressive rollout. CLSD is a supervised-trained latent block proposer whose diffusion structure makes parallel search cheap enough to expose verifier-accessible solutions that AR decoding misses.
## Thesis
Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT) -- a strided 12-layer slice of Qwen3.5-9B -- operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the exact same 4096-dimensional manifold, the same tokenizer, and the same attention geometry. No projection bridges, no dimensional translation loss.
Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern. The DiT preserves this hybrid structure and keeps causal masking -- DeltaNet linear recurrence is strictly causal by design.
The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps. The verifier evaluates them in a single batched forward pass.
**Why causal diffusion works:** the conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. The causal constraint forces the DiT to learn autoregressive-like internal logic, which mirrors the frozen verifier's expectations.
## Architecture
| Role | Model | Params | Dim | Layers |
|---|---|---|---|---|
| Generator (DiT) | Qwen3.5-9B strided slice | ~4.0B | 4096 | 12 (9 DeltaNet + 3 FullAttn) |
| Verifier (frozen) | Qwen3.5-9B (text tower) | 9B | 4096 | 32 |
### The Strided Graft

```
Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
Layer types:   [D, A, D, D, D,  A,  D,  D,  D,  D,  D,  A ]

D = DeltaNet (linear_attention), A = full_attention
```
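The strided selection can be sketched directly. This is an illustrative reconstruction, not the repo's `graft_dit_from_qwen` implementation; `graft_slice` is a hypothetical helper:

```python
# Strided 12-of-32 layer selection, mirroring the table above.
STRIDE_INDICES = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]

def graft_slice(full_layers, indices=STRIDE_INDICES):
    """Hypothetical helper: pick the decoder blocks at the source indices."""
    return [full_layers[i] for i in indices]

# With Qwen3.5-9B's repeating [3xDeltaNet, 1xAttention] pattern, source
# layer i is full attention exactly when (i + 1) % 4 == 0.
layer_types = ["A" if (i + 1) % 4 == 0 else "D" for i in STRIDE_INDICES]
# -> ['D', 'A', 'D', 'D', 'D', 'A', 'D', 'D', 'D', 'D', 'D', 'A']
```

This reproduces the 9 DeltaNet + 3 full-attention split in the architecture table.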
### DiT Modifications

- adaLN-Zero modulators per block: `nn.Linear(4096, 24576)`, zero-initialized
- Timestep conditioning: sinusoidal embedding + conditioning vector C
- Learned local positional embedding: `nn.Parameter(zeros(1, 128, 4096))`
- Causal masking preserved from the original Qwen weights
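A minimal sketch of the adaLN-Zero modulator shape, assuming the standard DiT split into shift/scale/gate terms for two sublayers (6 x 4096 = 24576); the class name and wiring are illustrative, not the repo's code:

```python
import torch
import torch.nn as nn

class AdaLNZeroModulator(nn.Module):
    """Maps the conditioning vector C to six per-channel modulation terms:
    shift/scale/gate for the attention sublayer and for the MLP sublayer.
    Zero-init means every grafted block starts as an identity perturbation."""

    def __init__(self, dim=4096):
        super().__init__()
        self.proj = nn.Linear(dim, 6 * dim)   # 4096 -> 24576, as above
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, c):
        # c: [batch, dim] -> six [batch, dim] terms, broadcast over positions
        return self.proj(c).chunk(6, dim=-1)

mod = AdaLNZeroModulator(dim=64)              # small dim for illustration
shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = mod(torch.randn(2, 64))
```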
## Training Pipeline
### Pre-Flight: Embedding Extraction

Target embeddings from AI-MO/NuminaMath-CoT (mathematical chain-of-thought):
- Tokenized with the Qwen tokenizer, embeddings looked up via the frozen embedding matrix
- Chunked into 128-token windows: `[64, 128, 4096]` safetensors shards
- 146,790 total chunks across 2,294 files
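The chunking layout can be sketched as follows; `chunk_embeddings` is a hypothetical helper, with a toy embedding matrix standing in for Qwen's frozen one, and only the `[64, 128, 4096]` shard shape comes from the card:

```python
import torch

def chunk_embeddings(token_ids, embed_matrix, window=128, shard_size=64):
    """Look up embeddings for a token stream and cut it into fixed
    128-token windows grouped as [shard_size, window, dim] shards
    (i.e. [64, 128, 4096] with Qwen's real embedding matrix)."""
    n_windows = len(token_ids) // window
    ids = torch.tensor(token_ids[: n_windows * window]).view(n_windows, window)
    embeds = embed_matrix[ids]                       # [n_windows, window, dim]
    n_shards = n_windows // shard_size
    return embeds[: n_shards * shard_size].view(n_shards, shard_size, window, -1)

# Toy run: 16384 tokens with a random 1000 x 16 "embedding matrix".
vocab, dim = 1000, 16
ids = [i % vocab for i in range(64 * 128 * 2)]
shards = chunk_embeddings(ids, torch.randn(vocab, dim))   # two shards
```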
### Stage A: Rectified Flow (Velocity Regression) -- COMPLETE

The DiT learns the straight-line velocity field v = x1 - x0 (x0 = noise, x1 = target):

```
x_t = (1 - t) * noise + t * target,   t in [0, 1]
L   = || v_pred - (target - noise) ||^2
```
| Parameter | Value |
|---|---|
| Hardware | 1x NVIDIA B200 (183 GB) |
| Steps | 50,000 |
| Batch size | 32 |
| Optimizer | AdamW (lr=1e-4, cosine decay) |
| Wall-clock | 154.8 minutes |
| Final MSE | ~0.013 (converged by step 5K) |
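One Stage A step under this loss can be sketched as below; the `dit(x_t, t)` call signature is an assumption, and a dummy zero-velocity model stands in for the real DiT:

```python
import torch

def rectified_flow_loss(dit, target):
    """Sample noise x0 and a per-example time t, form the interpolant
    x_t = (1 - t) * noise + t * target, and regress the predicted
    velocity onto the constant target velocity (target - noise)."""
    noise = torch.randn_like(target)
    t = torch.rand(target.shape[0], 1, 1)        # one t per sample, broadcast
    x_t = (1 - t) * noise + t * target
    v_pred = dit(x_t, t)                          # assumed signature
    return ((v_pred - (target - noise)) ** 2).mean()

# Dummy stand-in that always predicts zero velocity.
dummy_dit = lambda x, t: torch.zeros_like(x)
loss = rectified_flow_loss(dummy_dit, torch.randn(4, 128, 32))
```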
### Stage C: CE Alignment -- IN PROGRESS
Backpropagate through the frozen 9B verifier to teach the DiT semantic correctness:
```
noise -> DiT (2 Euler steps) -> draft_embeds
      -> frozen Qwen 32 layers -> logits -> CE loss vs ground-truth tokens

L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings)
```
Beta anneals from 0.1 to 0, gradually shifting from geometric to semantic alignment.
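The combined objective can be sketched as follows; the linear anneal schedule is an assumption (the card only states that beta goes from 0.1 to 0):

```python
import torch
import torch.nn.functional as F

def stage_c_loss(logits, targets, drafts, true_embeds,
                 step, total_steps, beta0=0.1):
    """L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings),
    with beta annealed (here: linearly, an assumption) from beta0 to 0."""
    beta = beta0 * max(0.0, 1.0 - step / total_steps)
    ce = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                         targets.reshape(-1))
    mse = F.mse_loss(drafts, true_embeds)
    return ce + beta * mse, beta
```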
Smoke test results (50 steps, batch=1):
- CE dropped from 12.8 to 6.1: the verifier is starting to read the DiT output
- Gradients flow correctly through the frozen verifier

Current run: 2000 steps, batch=8, grad_accum=4 on B200, streaming to wandb.
## Step 4: Live Inference (The Parallel Rollout)
- User submits reasoning prompt
- 9B Verifier forward pass -> conditioning vector C + KV cache
- DiT generates 32 candidate 128-token branches in 2 Euler steps
- 9B Verifier evaluates all 32 branches in one batched pass (shared prompt KV via PagedAttention)
- Score by mean log-probability across 128 positions
- Causal Guillotine: scan Top-1 left-to-right, truncate at first low-confidence position
- Qwen samples the correct token at the cut point, a new C is generated, and the loop repeats
Target latency: <500ms per 128-token block
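The branch scoring and Causal Guillotine can be sketched as below; the confidence threshold is an illustrative value, not one taken from the card:

```python
import torch

def causal_guillotine(branch_logps, threshold=-2.0):
    """Score each drafted branch by mean per-token log-prob, pick the
    Top-1, then scan it left-to-right and cut at the first position
    whose log-prob falls below the (illustrative) threshold."""
    scores = branch_logps.mean(dim=-1)                # [num_branches]
    best = scores.argmax().item()
    logps = branch_logps[best]
    low = (logps < threshold).nonzero()
    cut = low[0].item() if len(low) else logps.numel()
    return best, cut                                  # accept prefix [0, cut)

# 32 branches x 128 positions: branch 7 is best but falters at position 40.
lps = torch.full((32, 128), -0.5)
lps[7] = -0.1
lps[7, 40] = -5.0
best, cut = causal_guillotine(lps)                    # -> best=7, cut=40
```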
## Step 5: The Shadow Loop (Async RL -- Continuous Improvement)
The Primary Node never stops drafting. A Shadow Node continuously improves the DiT:
```
Primary Node --[Redis: 32 trajectories/cycle]--> Shadow Node
Shadow Node  --[Weight sync every 1000 steps]--> Primary Node
```
### Objective Verification (Reward Signal)
Feed Top-1 decoded tokens through:
- Lean 4: formal mathematical proof verification
- Python sandbox: code execution for correctness
If verified, the continuous vectors are rewarded (positive signal); if failed, they are penalized (negative signal).
This breaks the log-prob echo chamber. The DiT learns "alien intuition" -- solutions the 9B verifier would score as correct but would never stumble upon autoregressively.
### RL Objective
Policy gradient from objective verification creates a reward signal independent of the verifier log-probs. The DiT explores the embedding space for novel solutions that:
- The verifier accepts (high log-prob)
- Actually solve the problem (Lean4/sandbox verification)
This is an infinite background process -- the system improves continuously as long as compute is available.
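The card leaves the exact policy-gradient estimator open; one simple instantiation is advantage-weighted regression, sketched here with hypothetical names (`per_draft_loss` would be a per-draft reconstruction loss; rewards come from Lean 4 / sandbox verification):

```python
import torch

def shadow_loss(per_draft_loss, rewards):
    """One possible Shadow Loop objective (an assumption, not the repo's):
    subtract a batch-mean baseline from the verification rewards, keep only
    the positive advantages, and use them to weight each draft's loss.
    Clamping avoids the unbounded 'push away' a raw negative weight gives."""
    adv = rewards - rewards.mean()
    w = torch.clamp(adv, min=0.0)
    return (w * per_draft_loss).sum() / (w.sum() + 1e-8)

rewards = torch.tensor([1.0, -1.0, 1.0, -1.0])   # Lean4/sandbox verdicts
losses = torch.tensor([0.2, 0.9, 0.4, 0.7])      # hypothetical per-draft losses
loss = shadow_loss(losses, rewards)               # reinforces drafts 0 and 2
```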
## Repository Contents

```
checkpoints/
  dit_stage_a_step_5000.pt    # Early training
  dit_stage_a_step_10000.pt   # Mid training
  dit_stage_a_step_30000.pt   # Late training
  dit_stage_a_final.pt        # 50K steps, converged (MSE=0.013)
  dit_stage_c_*.pt            # CE alignment checkpoints (when available)
embeddings_sample/            # 50 representative embedding shards
  batch_*.safetensors         # Each: [64, 128, 4096]
```
## Loading a Checkpoint

```python
from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES
from transformers import AutoModelForCausalLM
import torch

# Rebuild the grafted DiT skeleton from the frozen 9B verifier...
qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)

# ...then load the trained Stage A weights.
state_dict = torch.load("checkpoints/dit_stage_a_final.pt", weights_only=True)
dit.load_state_dict(state_dict)
```
## Roadmap

- [x] Pre-flight: embedding extraction (146K chunks from NuminaMath-CoT)
- [x] Step 1: Frankenstein graft (4.0B hybrid DiT from 9B)
- [x] Step 2: Stage A rectified flow (50K steps, converged)
- [x] Stage C smoke test (50 steps, pipeline validated)
- [ ] Step 3: Stage C full alignment (2000+ steps on B200)
- [ ] Step 4: Live inference with Causal Guillotine
- [ ] Step 5: Shadow Loop async RL with Lean4/sandbox verification
- [ ] Scale to 8x H200 cluster for production training
## Wandb
- Stage A: clsd-speedrun
- Stage C smoke: clsd-speedrun-smoke
## License
Apache 2.0