---
license: apache-2.0
language:
- en
tags:
- diffusion
- speculative-decoding
- rectified-flow
- dit
- qwen
- math-reasoning
- deltanet
datasets:
- AI-MO/NuminaMath-CoT
base_model:
- Qwen/Qwen3.5-9B
---

# Continuous Latent Speculative Decoding (CLSD)

**Architecture**: ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier
**Key Innovation**: First hybrid DeltaNet/Attention causal diffusion transformer for parallel token generation
**Status**: Stage A converged, Stage C alignment in progress

---

## How it was envisaged

A machine for constructing candidate inner-machines that bias the verifier toward solution paths its default rollout would miss: a learned search policy that explores verifier-legible continuations and unlocks competence that ordinary autoregressive rollout fails to reach.

## What is cool

The base verifier has fixed weights, but its inference process is not exhausted by ordinary left-to-right decoding. A learned continuous proposer can search for hidden-state trajectories and token paths that the verifier recognizes as correct, even if the verifier would rarely or never reach them under standard autoregressive rollout.

CLSD is a supervised-trained latent block proposer whose diffusion structure makes parallel search cheap enough to expose verifier-accessible solutions that AR decoding misses.

## Thesis

Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT) -- a strided 12-layer slice of Qwen3.5-9B -- operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the same 4096-dimensional manifold, the same tokenizer, and the same attention geometry: no projection bridges, no dimensional translation loss.

Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern. The DiT preserves this hybrid structure and keeps **causal masking** -- DeltaNet linear recurrence is strictly causal by design.

The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps. The verifier evaluates them in a single batched forward pass.
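
The 2-step draft is plain Euler integration of the rectified-flow ODE. A minimal sketch, assuming the DiT exposes a velocity predictor with a hypothetical `dit(x_t, t, cond)` signature:

```python
import torch

def euler_draft(dit, cond, noise, steps=2):
    """Integrate dx/dt = v(x, t, cond) from t=0 (pure noise) to t=1 (data).

    `dit` is assumed to return the predicted velocity field; with steps=2
    this is the 2-Euler-step drafting pass described above.
    """
    x = noise.clone()
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((x.shape[0],), i * dt)  # current time, one scalar per branch
        x = x + dt * dit(x, t, cond)           # one Euler step toward t=1
    return x

# Drafting 32 branches of 128 embeddings each:
# drafts = euler_draft(dit, cond, torch.randn(32, 128, 4096))
```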

> **Why causal diffusion works**: The conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. The causal constraint forces the DiT to learn autoregressive-like internal logic, which mirrors the frozen verifier's expectations.

---

## Architecture

| Role | Model | Params | Dim | Layers |
|------|-------|--------|-----|--------|
| **Generator (DiT)** | Qwen3.5-9B strided slice | ~4.0B | 4096 | 12 (9 DeltaNet + 3 FullAttn) |
| **Verifier (frozen)** | Qwen3.5-9B (text tower) | 9B | 4096 | 32 |

### The Strided Graft

```
Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
Layer types:   [D, A, D, D, D, A, D, D, D, D, D, A]

D = DeltaNet (linear_attention), A = full_attention
```
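
The graft itself amounts to an index-select over the donor's decoder stack. A sketch, assuming the standard Hugging Face `model.layers` `ModuleList` layout:

```python
import torch.nn as nn

STRIDE_INDICES = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]

def slice_layers(layers, indices=STRIDE_INDICES):
    """Keep a strided 12-of-32 subset of decoder blocks, weights intact.

    The hybrid D/A pattern above falls out of the donor's own layer types;
    nothing is retrained at graft time.
    """
    return nn.ModuleList(layers[i] for i in indices)
```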

### DiT Modifications
1. **adaLN-Zero modulators** per block: nn.Linear(4096, 24576), zero-initialized
2. **Timestep conditioning**: sinusoidal embedding + conditioning vector C
3. **Learned local positional embedding**: nn.Parameter(zeros(1, 128, 4096))
4. Causal masking preserved from original Qwen weights
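
Item 1 can be sketched as below: 24576 = 6 x 4096, i.e. shift/scale/gate for the attention sub-block and again for the MLP sub-block, zero-initialized so each block starts as the identity (the adaLN-Zero trick). The class name is illustrative, not the repo's actual module:

```python
import torch
import torch.nn as nn

class AdaLNZeroModulator(nn.Module):
    """Map the conditioning vector C to six per-block modulation signals."""

    def __init__(self, dim=4096):
        super().__init__()
        self.proj = nn.Linear(dim, 6 * dim)  # 4096 -> 24576
        nn.init.zeros_(self.proj.weight)     # zero-init: modulation starts at zero
        nn.init.zeros_(self.proj.bias)

    def forward(self, c):
        # (shift, scale, gate) for attention, then (shift, scale, gate) for MLP
        return self.proj(c).chunk(6, dim=-1)
```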

---

## Training Pipeline

### Pre-Flight: Embedding Extraction

Target embeddings from **AI-MO/NuminaMath-CoT** (mathematical chain-of-thought):
- Tokenized with the Qwen tokenizer; embeddings looked up via the frozen embedding matrix
- Chunked into 128-token windows: [64, 128, 4096] safetensors shards
- **146,790 total chunks** across 2,294 files
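
The chunking step can be sketched as a reshape-and-split over one document's embedding sequence. This is a hypothetical helper (assuming trailing partial windows are dropped, which the text does not state):

```python
import torch

def shard_embeddings(embeds, chunk_len=128, shard_size=64):
    """[seq_len, dim] -> list of [<=shard_size, chunk_len, dim] shards."""
    n = embeds.shape[0] // chunk_len              # number of full 128-token windows
    chunks = embeds[: n * chunk_len].reshape(n, chunk_len, -1)
    return list(chunks.split(shard_size))         # safetensors-ready shards
```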

### Stage A: Rectified Flow (Velocity Regression) -- COMPLETE

The DiT learns the straight-line velocity field v = x1 - x0:

```
x_t = (1-t)*noise + t*target, t in [0,1]
L = ||v_pred - (target - noise)||^2
```
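
One training step of this objective, as a sketch (again assuming the hypothetical `dit(x_t, t, cond)` velocity interface):

```python
import torch
import torch.nn.functional as F

def stage_a_loss(dit, target, cond):
    """Rectified-flow velocity regression against v = target - noise."""
    noise = torch.randn_like(target)
    t = torch.rand(target.shape[0])                  # t ~ U[0, 1], one per sample
    tb = t.view(-1, *([1] * (target.dim() - 1)))     # broadcast over [B, 128, 4096]
    x_t = (1 - tb) * noise + tb * target             # straight-line interpolant
    v_pred = dit(x_t, t, cond)
    return F.mse_loss(v_pred, target - noise)
```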

| Parameter | Value |
|-----------|-------|
| Hardware | 1x NVIDIA B200 (183 GB) |
| Steps | 50,000 |
| Batch size | 32 |
| Optimizer | AdamW (lr=1e-4, cosine decay) |
| Wall-clock | 154.8 minutes |
| Final MSE | ~0.013 (converged by step 5K) |

### Stage C: CE Alignment -- IN PROGRESS

Backpropagate through the frozen 9B verifier to teach the DiT semantic correctness:

```
noise -> DiT (2 Euler steps) -> draft_embeds
      -> frozen Qwen 32 layers -> logits -> CE loss vs ground truth tokens
```

L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings)

Beta anneals from 0.1 to 0, gradually shifting from geometric to semantic alignment.
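
The combined objective and anneal, sketched below. A linear schedule is assumed here; the text does not specify the actual anneal shape:

```python
import torch
import torch.nn.functional as F

def beta_at(step, total_steps, beta0=0.1):
    """Linearly anneal the geometric-anchor weight from beta0 down to 0."""
    return beta0 * max(0.0, 1.0 - step / total_steps)

def stage_c_loss(logits, targets, drafts, true_embeds, beta):
    """CE through the frozen verifier, plus a fading MSE anchor to the data."""
    ce = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
    return ce + beta * F.mse_loss(drafts, true_embeds)
```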

**Smoke test results** (50 steps, batch=1):
- CE dropped 12.8 -> 6.1: the verifier is starting to read DiT output
- Gradients flow correctly through the frozen verifier

**Current run**: 2000 steps, batch=8, grad_accum=4 on a B200 -- streaming to wandb

---

## Step 4: Live Inference (The Parallel Rollout)

1. User submits a reasoning prompt
2. 9B verifier forward pass -> conditioning vector C + KV cache
3. DiT generates **32 candidate 128-token branches** in 2 Euler steps
4. 9B verifier evaluates all 32 branches in one batched pass (shared prompt KV via PagedAttention)
5. Score by mean log-probability across 128 positions
6. **Causal Guillotine**: scan the Top-1 branch left-to-right, truncate at the first low-confidence position
7. Qwen samples the correct token, a new C is generated, the loop repeats

**Target latency**: <500ms per 128-token block
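
Steps 5-6 above can be sketched as follows; the confidence `threshold` is a hypothetical hyperparameter, not a value from the training runs:

```python
import torch

def guillotine(logprobs, token_ids, threshold=-2.5):
    """Pick the branch with the best mean log-prob, then truncate it at the
    first position whose token log-prob falls below `threshold`."""
    # logprobs: [branches, block, vocab]; token_ids: [branches, block]
    tok_lp = logprobs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
    best = int(tok_lp.mean(dim=-1).argmax())             # step 5: mean log-prob score
    low = (tok_lp[best] < threshold).nonzero()
    cut = int(low[0]) if len(low) else tok_lp.shape[-1]  # step 6: first weak token
    return best, token_ids[best, :cut]
```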

---

## Step 5: The Shadow Loop (Async RL -- Continuous Improvement)

The Primary Node never stops drafting. A Shadow Node continuously improves the DiT:

```
Primary Node --[Redis: 32 trajectories/cycle]--> Shadow Node
Shadow Node  --[Weight sync every 1000 steps]--> Primary Node
```

### Objective Verification (Reward Signal)

Feed the Top-1 decoded tokens through:
- **Lean 4**: formal mathematical proof verification
- **Python sandbox**: code execution for correctness

If verified -> reward the continuous vectors (positive signal)
If failed -> penalize (negative signal)

This breaks the log-prob echo chamber. The DiT learns "alien intuition" -- solutions the 9B verifier would score as correct but would never stumble upon autoregressively.

### RL Objective

Policy gradient from objective verification creates a reward signal independent of the verifier log-probs. The DiT explores the embedding space for novel solutions that:
1. The verifier accepts (high log-prob)
2. Actually solve the problem (Lean4/sandbox verification)

This is an **infinite background process** -- the system improves continuously as long as compute is available.
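
Under the stated reward rule (verified positive, failed negative), one hypothetical REINFORCE-style surrogate looks like this; the Shadow Node's actual update is not specified above:

```python
import torch

def shadow_update_loss(draft_logprob, verified, r_pos=1.0, r_neg=-1.0):
    """REINFORCE surrogate: -reward * log-prob of the drafted trajectory.

    Minimizing this pushes up verified trajectories and pushes down failed
    ones, independently of the verifier's own log-prob preferences.
    """
    reward = r_pos if verified else r_neg
    return -reward * draft_logprob
```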

---

## Repository Contents

```
checkpoints/
  dit_stage_a_step_5000.pt    # Early training
  dit_stage_a_step_10000.pt   # Mid training
  dit_stage_a_step_30000.pt   # Late training
  dit_stage_a_final.pt        # 50K steps, converged (MSE=0.013)
  dit_stage_c_*.pt            # CE alignment checkpoints (when available)
embeddings_sample/            # 50 representative embedding shards
  batch_*.safetensors         # Each: [64, 128, 4096]
```

### Loading a Checkpoint

```python
from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES
from transformers import AutoModelForCausalLM
import torch

qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)
state_dict = torch.load("checkpoints/dit_stage_a_final.pt", weights_only=True)
dit.load_state_dict(state_dict)
```

---

## Roadmap

- [x] Pre-flight: embedding extraction (146K chunks from NuminaMath-CoT)
- [x] Step 1: Frankenstein graft (4.0B hybrid DiT from 9B)
- [x] Step 2: Stage A rectified flow (50K steps, converged)
- [x] Stage C smoke test (50 steps, pipeline validated)
- [ ] Step 3: Stage C full alignment (2000+ steps on B200)
- [ ] Step 4: Live inference with Causal Guillotine
- [ ] Step 5: Shadow Loop async RL with Lean4/sandbox verification
- [ ] Scale to 8x H200 cluster for production training

## Wandb

- Stage A: [clsd-speedrun](https://wandb.ai/dalletest123/clsd-speedrun)
- Stage C smoke: [clsd-speedrun-smoke](https://wandb.ai/dalletest123/clsd-speedrun-smoke)

## License

Apache 2.0