---
license: apache-2.0
language:
- en
tags:
- diffusion
- speculative-decoding
- rectified-flow
- dit
- qwen
- math-reasoning
- deltanet
datasets:
- AI-MO/NuminaMath-CoT
base_model:
- Qwen/Qwen3.5-9B
---
# Continuous Latent Speculative Decoding (CLSD)
**Architecture**: ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier
**Key Innovation**: First hybrid DeltaNet/Attention causal diffusion transformer for parallel token generation
**Status**: Stage A converged, Stage C alignment in progress
---
## How it was envisaged
A machine for constructing candidate inner-machines that bias the verifier toward solution paths its default rollout would miss: a learned search policy that explores verifier-legible continuations and unlocks competence that ordinary autoregressive rollout fails to reach.
## What is cool
The base verifier has fixed weights, but its inference process is not exhausted by ordinary left-to-right decoding. A learned continuous proposer can search for hidden-state trajectories and token paths that the verifier can recognize as correct, even if the verifier would rarely or never reach them under standard autoregressive rollout.
CLSD is a supervised latent block proposer whose diffusion structure makes parallel search cheap enough to expose verifier-accessible solutions that AR decoding misses.
## Thesis
Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT) -- a strided 12-layer slice of Qwen3.5-9B -- operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the exact same 4096-dimensional manifold, the same tokenizer, and the same attention geometry. No projection bridges, no dimensional translation loss.
Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern. The DiT preserves this hybrid structure and keeps **causal masking** -- DeltaNet linear recurrence is strictly causal by design.
The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps. The verifier evaluates them in a single batched forward pass.
> **Why causal diffusion works**: The conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. The causal constraint forces the DiT to learn autoregressive-like internal logic, mirroring the frozen verifier's expectations.
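
To make "2 Euler steps" concrete, here is a minimal sampling sketch. The `dit(x, t, cond)` call signature is an assumption for illustration; the repo's module may expose a different interface.
```python
import torch

@torch.no_grad()
def euler_sample(dit, cond, branches=32, seq_len=128, dim=4096, steps=2):
    """Integrate the learned velocity field from noise to draft embeddings.

    NOTE: `dit(x, t, cond)` is an assumed signature, not the repo's actual API.
    """
    x = torch.randn(branches, seq_len, dim)   # x_0: pure noise
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((branches,), i * dt)   # current time in [0, 1)
        v = dit(x, t, cond)                   # predicted velocity v = x1 - x0
        x = x + dt * v                        # Euler step along the flow
    return x                                  # [32, 128, 4096] draft embeddings
```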
---
## Architecture
| Role | Model | Params | Dim | Layers |
|------|-------|--------|-----|--------|
| **Generator (DiT)** | Qwen3.5-9B strided slice | ~4.0B | 4096 | 12 (9 DeltaNet + 3 FullAttn) |
| **Verifier (frozen)** | Qwen3.5-9B (text tower) | 9B | 4096 | 32 |
### The Strided Graft
```
Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
Layer types: [D, A, D, D, D, A, D, D, D, D, D, A ]
D = DeltaNet (linear_attention), A = full_attention
```
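
As a rough illustration of the graft, the slice can be taken straight from the Hugging Face module list. This is a simplified sketch: the actual `graft_dit_from_qwen` also attaches the DiT modifications listed below.
```python
import copy
import torch.nn as nn

STRIDE_INDICES = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]

def take_strided_slice(qwen):
    # Qwen-style HF causal LMs expose their decoder blocks at model.layers.
    blocks = qwen.model.layers
    # Deep-copy so the frozen 32-layer verifier keeps its own weights intact.
    return nn.ModuleList(copy.deepcopy(blocks[i]) for i in STRIDE_INDICES)
```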
### DiT Modifications
1. **adaLN-Zero modulators** per block: nn.Linear(4096, 24576), zero-initialized (see the sketch after this list)
2. **Timestep conditioning**: sinusoidal embedding + conditioning vector C
3. **Learned local positional embedding**: nn.Parameter(zeros(1, 128, 4096))
4. Causal masking preserved from original Qwen weights
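
The 24576-wide modulator in modification 1 is 6 x 4096: shift, scale, and gate for both the attention and MLP sublayers, in the adaLN-Zero style of the original DiT paper. A minimal sketch with illustrative names:
```python
import torch.nn as nn

class AdaLNZeroModulator(nn.Module):
    """Maps conditioning c -> 6 modulation vectors (shift/scale/gate x attn/MLP).

    Zero-init means each block starts as an identity perturbation on top of the
    inherited Qwen weights, so training departs smoothly from the pretrained function.
    """
    def __init__(self, dim=4096):
        super().__init__()
        self.proj = nn.Linear(dim, 6 * dim)  # 4096 -> 24576
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, c):
        # c: [batch, 4096] = timestep embedding + conditioning vector C
        return self.proj(c).chunk(6, dim=-1)
```
Inside a block the vectors would be applied as, e.g., `x = x + gate_a * attn((1 + scale_a) * norm(x) + shift_a)`.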
---
## Training Pipeline
### Pre-Flight: Embedding Extraction
Target embeddings from **AI-MO/NuminaMath-CoT** (mathematical chain-of-thought):
- Tokenized with Qwen tokenizer, embeddings looked up via frozen embedding matrix
- Chunked into 128-token windows: [64, 128, 4096] safetensors shards
- **146,790 total chunks** across 2,294 files
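
A hedged sketch of that extraction, using only standard Hugging Face and safetensors APIs; the dataset iteration and shard-accumulation loop are omitted, and the shard naming just follows the `batch_*.safetensors` convention listed under Repository Contents.
```python
import torch
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-9B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
embed = model.get_input_embeddings()             # frozen embedding matrix

@torch.no_grad()
def chunk_embeddings(text, window=128):
    ids = tok(text, return_tensors="pt").input_ids[0]
    ids = ids[: (len(ids) // window) * window]   # drop the ragged tail
    return embed(ids.view(-1, window))           # [n_chunks, 128, 4096]

# Accumulate 64 chunks, then flush one [64, 128, 4096] shard:
# save_file({"embeddings": batch.contiguous()}, f"batch_{shard_idx:05d}.safetensors")
```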
### Stage A: Rectified Flow (Velocity Regression) -- COMPLETE
The DiT learns the straight-line velocity field v = x1 - x0:
```
x_t = (1-t)*noise + t*target, t in [0,1]
L = ||v_pred - (target - noise)||^2
```
| Parameter | Value |
|-----------|-------|
| Hardware | 1x NVIDIA B200 (183 GB) |
| Steps | 50,000 |
| Batch size | 32 |
| Optimizer | AdamW (lr=1e-4, cosine decay) |
| Wall-clock | 154.8 minutes |
| Final MSE | ~0.013 (converged by step 5K) |
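
One Stage A optimization step, written out as a minimal sketch of the loss above; the `dit(x_t, t)` call signature is an assumption, and any extra conditioning is omitted.
```python
import torch
import torch.nn.functional as F

def stage_a_step(dit, opt, target):
    # target: [B, 128, 4096] ground-truth embedding chunk from a shard
    noise = torch.randn_like(target)                            # x_0
    t = torch.rand(target.size(0), 1, 1, device=target.device)  # t ~ U[0, 1]
    x_t = (1 - t) * noise + t * target                          # straight-line interpolant
    v_pred = dit(x_t, t.flatten())                              # assumed signature
    loss = F.mse_loss(v_pred, target - noise)                   # ||v_pred - (x1 - x0)||^2
    loss.backward()
    opt.step()
    opt.zero_grad()
    return loss.item()
```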
### Stage C: CE Alignment -- IN PROGRESS
Backpropagate through the frozen 9B verifier to teach the DiT semantic correctness:
```
noise -> DiT (2 Euler steps) -> draft_embeds
-> frozen Qwen 32 layers -> logits -> CE loss vs ground truth tokens
L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings)
```
Beta anneals from 0.1 to 0, gradually shifting from geometric to semantic alignment.
**Smoke test results** (50 steps, batch=1):
- CE dropped 12.8 -> 6.1: verifier starting to read DiT output
- Gradients flow correctly through frozen verifier
**Current run**: 2000 steps, batch=8, grad_accum=4 on B200 -- streaming to wandb
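
The combined objective as a sketch. `inputs_embeds` is the standard HF entry point for feeding continuous vectors to a causal LM; gradients flow back to the drafts even though every verifier parameter has `requires_grad=False`. The next-token shift is omitted for brevity.
```python
import torch.nn.functional as F

def stage_c_loss(drafts, true_embeds, target_ids, verifier, step, total_steps):
    # Frozen 9B verifier reads the drafted embeddings directly.
    logits = verifier(inputs_embeds=drafts).logits
    ce = F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1))
    beta = 0.1 * (1.0 - step / total_steps)      # anneal beta: 0.1 -> 0
    return ce + beta * F.mse_loss(drafts, true_embeds)
```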
---
## Step 4: Live Inference (The Parallel Rollout)
1. User submits reasoning prompt
2. 9B Verifier forward pass -> conditioning vector C + KV cache
3. DiT generates **32 candidate 128-token branches** in 2 Euler steps
4. 9B Verifier evaluates all 32 branches in one batched pass (shared prompt KV via PagedAttention)
5. Score each branch by its mean log-probability across the 128 positions
6. **Causal Guillotine**: scan the Top-1 branch left-to-right, truncate at the first low-confidence position (sketched below)
7. The verifier samples the corrected token at the cut, a new conditioning vector C is generated, and the loop repeats
**Target latency**: <500ms per 128-token block
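
Steps 5-6 as a sketch: score every branch by mean token log-probability, then guillotine the Top-1 branch at its first low-confidence position. The threshold value is illustrative, not a repo constant.
```python
import torch
import torch.nn.functional as F

def score_and_guillotine(logits, token_ids, threshold=-2.5):
    # logits: [32, 128, V] verifier outputs; token_ids: [32, 128] drafted tokens
    logp = F.log_softmax(logits.float(), dim=-1)
    tok_logp = logp.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)  # [32, 128]
    scores = tok_logp.mean(dim=-1)                # step 5: mean log-prob per branch
    best = scores.argmax()
    low = (tok_logp[best] < threshold).nonzero()  # step 6: first weak position
    cut = int(low[0]) if low.numel() else token_ids.size(1)
    return token_ids[best, :cut], scores[best].item()
```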
---
## Step 5: The Shadow Loop (Async RL -- Continuous Improvement)
The Primary Node never stops drafting. A Shadow Node continuously improves the DiT:
```
Primary Node --[Redis: 32 trajectories/cycle]--> Shadow Node
Shadow Node --[Weight sync every 1000 steps]--> Primary Node
```
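
A minimal sketch of that transport, assuming plain redis-py plus torch serialization; the key name and connection details are placeholders.
```python
import io

import redis
import torch

r = redis.Redis(host="localhost", port=6379)      # placeholder connection

def push_cycle(drafts, scores):
    # Primary Node: ship one cycle of 32 trajectories to the Shadow Node.
    buf = io.BytesIO()
    torch.save({"drafts": drafts.cpu(), "scores": scores.cpu()}, buf)
    r.lpush("clsd:trajectories", buf.getvalue())

def pop_cycle():
    # Shadow Node: block until the next cycle arrives.
    _, blob = r.brpop("clsd:trajectories")
    return torch.load(io.BytesIO(blob))
```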
### Objective Verification (Reward Signal)
Feed Top-1 decoded tokens through:
- **Lean 4**: formal mathematical proof verification
- **Python sandbox**: code execution for correctness
If verified -> reward the continuous vectors (positive signal)
If failed -> penalize (negative signal)
This breaks the log-prob echo chamber. The DiT learns "alien intuition" -- solutions the 9B verifier would score as correct but would never stumble upon autoregressively.
### RL Objective
Policy gradient from objective verification creates a reward signal independent of the verifier log-probs. The DiT explores the embedding space for novel solutions that:
1. The verifier accepts (high log-prob)
2. Actually solve the problem (Lean4/sandbox verification)
This is an **infinite background process** -- the system improves continuously as long as compute is available.
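
A REINFORCE-style sketch of the update, assuming the proposer exposes a differentiable per-branch log-likelihood proxy `branch_logp`; the binary `verified` reward comes from Lean 4 or the sandbox. This is one plausible form of the policy gradient, not the repo's committed objective.
```python
import torch

def shadow_update(opt, branch_logp, verified):
    # branch_logp: [32] differentiable proposer scores; verified: [32] 0/1 rewards
    reward = verified.float()
    advantage = reward - reward.mean()            # mean baseline reduces variance
    loss = -(advantage * branch_logp).mean()      # reinforce verified drafts
    loss.backward()
    opt.step()
    opt.zero_grad()
    return loss.item()
```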
---
## Repository Contents
```
checkpoints/
dit_stage_a_step_5000.pt # Early training
dit_stage_a_step_10000.pt # Mid training
dit_stage_a_step_30000.pt # Late training
dit_stage_a_final.pt # 50K steps, converged (MSE=0.013)
dit_stage_c_*.pt # CE alignment checkpoints (when available)
embeddings_sample/ # 50 representative embedding shards
batch_*.safetensors # Each: [64, 128, 4096]
```
### Loading a Checkpoint
```python
from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES
from transformers import AutoModelForCausalLM
import torch
# Rebuild the 12-layer hybrid DiT by slicing the frozen 9B backbone.
qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)

# Stage A weights: sliced blocks + adaLN-Zero modulators + local positional embedding.
state_dict = torch.load("checkpoints/dit_stage_a_final.pt", weights_only=True)
dit.load_state_dict(state_dict)
```
---
## Roadmap
- [x] Pre-flight: embedding extraction (146K chunks from NuminaMath-CoT)
- [x] Step 1: Frankenstein graft (4.0B hybrid DiT from 9B)
- [x] Step 2: Stage A rectified flow (50K steps, converged)
- [x] Stage C smoke test (50 steps, pipeline validated)
- [ ] Step 3: Stage C full alignment (2000+ steps on B200)
- [ ] Step 4: Live inference with Causal Guillotine
- [ ] Step 5: Shadow Loop async RL with Lean4/sandbox verification
- [ ] Scale to 8x H200 cluster for production training
## Wandb
- Stage A: [clsd-speedrun](https://wandb.ai/dalletest123/clsd-speedrun)
- Stage C smoke: [clsd-speedrun-smoke](https://wandb.ai/dalletest123/clsd-speedrun-smoke)
## License
Apache 2.0