---
license: apache-2.0
language:
- en
tags:
- diffusion
- speculative-decoding
- rectified-flow
- dit
- qwen
- math-reasoning
- deltanet
datasets:
- AI-MO/NuminaMath-CoT
base_model:
- Qwen/Qwen3.5-9B
---

# Continuous Latent Speculative Decoding (CLSD)

**Architecture**: ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier

**Key Innovation**: First hybrid DeltaNet/Attention causal diffusion transformer for parallel token generation

**Status**: Stage A converged, Stage C alignment in progress

---

## How it was envisaged

A machine for constructing candidate inner machines that bias the verifier toward solution paths its default rollout would miss: a learned search policy that explores verifier-legible continuations and unlocks competence that ordinary autoregressive rollout fails to reach.

## What is cool

The base verifier has fixed weights, but its inference process is not exhausted by ordinary left-to-right decoding. A learned continuous proposer can search for hidden-state trajectories and token paths that the verifier can recognize as correct, even if the verifier would rarely or never reach them under standard autoregressive rollout. CLSD is a supervised-trained latent block proposer whose diffusion structure makes parallel search cheap enough to expose verifier-accessible solutions that autoregressive (AR) decoding misses.

## Thesis

Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT), a strided 12-layer slice of Qwen3.5-9B, operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the same 4096-dimensional manifold, the same tokenizer, and the same attention geometry, so there are no projection bridges and no dimensional translation loss.

Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern.
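
As a quick consistency check, the layer pattern can be reproduced in a few lines of Python. This is a minimal sketch that assumes full-attention layers sit at every index `i` with `i % 4 == 3`, which is the reading of the [3xDeltaNet, 1xAttention] pattern that matches the layer-type listing in the Strided Graft section. `layer_type` is a hypothetical helper, and `STRIDE_INDICES` here simply copies the source-layer list from that section rather than importing it from `clsd.grafted_dit`. The last line recovers the 9 DeltaNet + 3 full-attention split listed for the DiT.

```python
# Minimal sketch (assumption): in the 32-layer verifier, every 4th layer
# (i % 4 == 3) is full attention and the rest are Gated DeltaNet.
NUM_LAYERS = 32

# Copied from the Strided Graft listing; not imported from clsd.grafted_dit.
STRIDE_INDICES = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]


def layer_type(i: int) -> str:
    """Hypothetical helper: 'A' for full attention, 'D' for Gated DeltaNet."""
    return "A" if i % 4 == 3 else "D"


verifier_types = [layer_type(i) for i in range(NUM_LAYERS)]
dit_types = [layer_type(i) for i in STRIDE_INDICES]

print(verifier_types.count("D"), verifier_types.count("A"))  # 24 8
print("".join(dit_types))  # DADDDADDDDDA
print(dit_types.count("D"), dit_types.count("A"))  # 9 3
```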

The DiT preserves this hybrid structure and keeps **causal masking**, since DeltaNet's linear recurrence is strictly causal by design. The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps, and the verifier evaluates them in a single batched forward pass.

> **Why causal diffusion works**: The conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. The causal constraint forces the DiT to learn autoregressive-like internal logic, which mirrors what the frozen verifier expects.

---

## Architecture

| Role | Model | Params | Dim | Layers |
|------|-------|--------|-----|--------|
| **Generator (DiT)** | Qwen3.5-9B strided slice | ~4.0B | 4096 | 12 (9 DeltaNet + 3 FullAttn) |
| **Verifier (frozen)** | Qwen3.5-9B (text tower) | 9B | 4096 | 32 |

### The Strided Graft

```
Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
Layer types:   [D, A, D, D, D,  A,  D,  D,  D,  D,  D,  A]

D = DeltaNet (linear_attention), A = full_attention
```

### DiT Modifications

1. **adaLN-Zero modulators** per block: `nn.Linear(4096, 24576)` (24576 = 6 × 4096), zero-initialized
2. **Timestep conditioning**: sinusoidal embedding + conditioning vector C
3. **Learned local positional embedding**: `nn.Parameter(torch.zeros(1, 128, 4096))`
4. Causal masking preserved from the original Qwen weights

---

## Training Pipeline

### Pre-Flight: Embedding Extraction

Target embeddings come from **AI-MO/NuminaMath-CoT** (mathematical chain-of-thought):

- Tokenized with the Qwen tokenizer; embeddings looked up via the frozen embedding matrix
- Chunked into 128-token windows and stored as `[64, 128, 4096]` safetensors shards
- **146,790 total chunks** across 2,294 files

### Stage A: Rectified Flow (Velocity Regression) -- COMPLETE

The DiT learns the straight-line velocity field v = x1 - x0, where x0 is the noise sample and x1 is the target embedding:

```
x_t = (1-t)*noise + t*target,   t in [0,1]
L   = ||v_pred - (target - noise)||^2
```

| Parameter | Value |
|-----------|-------|
| Hardware | 1x NVIDIA B200 (183 GB) |
| Steps | 50,000 |
| Batch size | 32 |
| Optimizer | AdamW (lr=1e-4, cosine decay) |
| Wall-clock | 154.8 minutes |
| Final MSE | ~0.013 (converged by step 5K) |

### Stage C: CE Alignment -- IN PROGRESS

Gradients are backpropagated through the frozen 9B verifier (its weights are not updated) to teach the DiT semantic correctness:

```
noise -> DiT (2 Euler steps) -> draft_embeds -> frozen Qwen (32 layers) -> logits -> CE loss vs ground-truth tokens
```

`L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings)`

Beta anneals from 0.1 to 0, gradually shifting the objective from geometric to semantic alignment.

**Smoke test results** (50 steps, batch=1):

- CE dropped 12.8 -> 6.1: the verifier is starting to read the DiT output
- Gradients flow correctly through the frozen verifier

**Current run**: 2000 steps, batch=8, grad_accum=4 (effective batch 32) on B200, streaming to wandb

---

## Step 4: Live Inference (The Parallel Rollout)

1. User submits a reasoning prompt
2. 9B verifier forward pass -> conditioning vector C + KV cache
3. DiT generates **32 candidate 128-token branches** in 2 Euler steps
4. 9B verifier evaluates all 32 branches in one batched pass (shared prompt KV via PagedAttention)
5. Score each branch by mean log-probability across its 128 positions
6. **Causal Guillotine**: scan the Top-1 branch left to right and truncate at the first low-confidence position
7. Qwen samples the correct token, a new C is generated, and the loop repeats

**Target latency**: <500 ms per 128-token block

---

## Step 5: The Shadow Loop (Async RL -- Continuous Improvement)

The Primary Node never stops drafting. A Shadow Node continuously improves the DiT:

```
Primary Node --[Redis: 32 trajectories/cycle]--> Shadow Node
Shadow Node  --[Weight sync every 1000 steps]--> Primary Node
```

### Objective Verification (Reward Signal)

Feed the Top-1 decoded tokens through:

- **Lean 4**: formal mathematical proof verification
- **Python sandbox**: code execution for correctness

Then:

- If verified -> reward the continuous vectors (positive signal)
- If failed -> penalize them (negative signal)

This breaks the log-prob echo chamber. The DiT learns "alien intuition": solutions the 9B verifier would score as correct but would never stumble upon autoregressively.

### RL Objective

Objective verification supplies a reward signal independent of the verifier's log-probs, and this signal drives a policy-gradient update. The DiT explores the embedding space for novel solutions that:

1. The verifier accepts (high log-prob)
2. Actually solve the problem (Lean 4/sandbox verification)

This is an **infinite background process**: the system is designed to keep improving for as long as compute is available.
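
The reward logic above can be sketched as a generic REINFORCE-style update with a batch-mean baseline. This is a minimal, stdlib-only sketch under loud assumptions: the README does not specify the policy-gradient estimator, so the ±1 rewards, the mean baseline, and the helper names (`verification_rewards`, `advantages`, `reinforce_loss`) are all hypothetical. `logps` stands in for per-trajectory log-probabilities, which would require wrapping the DiT in a stochastic policy (for example, Gaussian exploration noise on its Euler updates). In a real trainer these would be autograd tensors so the loss could be backpropagated into the DiT; here they are plain floats to show the arithmetic.

```python
from typing import List


def verification_rewards(passed: List[bool]) -> List[float]:
    """Hypothetical mapping: +1 if Lean 4 / the sandbox verified the
    trajectory, -1 if verification failed."""
    return [1.0 if ok else -1.0 for ok in passed]


def advantages(rewards: List[float]) -> List[float]:
    """Subtract the batch-mean reward as a baseline to reduce variance.
    If every trajectory gets the same reward, all advantages are zero."""
    baseline = sum(rewards) / len(rewards)
    return [r - baseline for r in rewards]


def reinforce_loss(logps: List[float], advs: List[float]) -> float:
    """REINFORCE surrogate loss: -mean(A_i * log pi_i).

    Minimizing it raises the log-probability of above-baseline
    trajectories and lowers it for below-baseline ones. `logps` are
    assumed per-trajectory log-probabilities under a stochastic DiT policy.
    """
    if not logps or len(logps) != len(advs):
        raise ValueError("need matched, non-empty batches")
    return -sum(a * lp for a, lp in zip(advs, logps)) / len(logps)


# Toy batch of 4 Top-1 trajectories (the Shadow Node receives 32 per cycle).
passed = [True, False, False, True]
logps = [-2.0, -1.0, -3.0, -0.5]
advs = advantages(verification_rewards(passed))
print(advs)                         # [1.0, -1.0, -1.0, 1.0]
print(reinforce_loss(logps, advs))  # -0.375
```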

---

## Repository Contents

```
checkpoints/
  dit_stage_a_step_5000.pt    # Early training
  dit_stage_a_step_10000.pt   # Mid training
  dit_stage_a_step_30000.pt   # Late training
  dit_stage_a_final.pt        # 50K steps, converged (MSE=0.013)
  dit_stage_c_*.pt            # CE alignment checkpoints (when available)
embeddings_sample/            # 50 representative embedding shards
  batch_*.safetensors         # Each: [64, 128, 4096]
```

### Loading a Checkpoint

```python
import torch
from transformers import AutoModelForCausalLM
from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES

# Load the 9B verifier in bfloat16; its layers seed the grafted DiT.
qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)

# Build the 12-layer hybrid DiT from the strided slice of Qwen layers.
dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)

# Restore Stage A weights; map_location="cpu" lets GPU-saved checkpoints load anywhere.
state_dict = torch.load("checkpoints/dit_stage_a_final.pt", map_location="cpu", weights_only=True)
dit.load_state_dict(state_dict)
dit.eval()
```

---

## Roadmap

- [x] Pre-flight: embedding extraction (146K chunks from NuminaMath-CoT)
- [x] Step 1: Frankenstein graft (4.0B hybrid DiT from 9B)
- [x] Step 2: Stage A rectified flow (50K steps, converged)
- [x] Stage C smoke test (50 steps, pipeline validated)
- [ ] Step 3: Stage C full alignment (2000+ steps on B200)
- [ ] Step 4: Live inference with Causal Guillotine
- [ ] Step 5: Shadow Loop async RL with Lean 4/sandbox verification
- [ ] Scale to 8x H200 cluster for production training

## Wandb

- Stage A: [clsd-speedrun](https://wandb.ai/dalletest123/clsd-speedrun)
- Stage C smoke: [clsd-speedrun-smoke](https://wandb.ai/dalletest123/clsd-speedrun-smoke)

## License

Apache 2.0