---
license: apache-2.0
language:
- en
tags:
- diffusion
- speculative-decoding
- rectified-flow
- dit
- qwen
- math-reasoning
- deltanet
datasets:
- AI-MO/NuminaMath-CoT
base_model:
- Qwen/Qwen3.5-9B
---

# Continuous Latent Speculative Decoding (CLSD)

- **Architecture**: ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier
- **Key Innovation**: First hybrid DeltaNet/Attention causal diffusion transformer for parallel token generation
- **Status**: Stage A converged, Stage C alignment in progress

---

## How it was envisaged

CLSD was envisaged as a machine for constructing candidate inner machines that bias the verifier toward solution paths its default rollout would miss. It learns a search policy that explores verifier-legible continuations, unlocking competence that ordinary autoregressive rollout fails to reach.

## What is cool

The base verifier has fixed weights, but its inference process is not exhausted by ordinary left-to-right decoding. A learned continuous proposer can search for hidden-state trajectories and token paths that the verifier can recognize as correct, even if the verifier would rarely or never reach them under standard autoregressive rollout.
CLSD is a latent block proposer trained with supervised objectives, whose diffusion structure makes parallel search cheap enough to expose verifier-accessible solutions that autoregressive (AR) decoding misses.

## Thesis

Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT) -- a strided 12-layer slice of Qwen3.5-9B -- operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the exact same 4096-dimensional manifold, the same tokenizer, and the same attention geometry. No projection bridges, no dimensional translation loss.

Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern. The DiT preserves this hybrid structure and keeps **causal masking** -- DeltaNet linear recurrence is strictly causal by design.

The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps. The verifier evaluates them in a single batched forward pass.

> **Why causal diffusion works**: The conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. The causal constraint forces the DiT to learn autoregressive-like internal logic, which mirrors the frozen verifier's expectations.
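The 2-step drafting can be sketched as a fixed-step Euler integrator over the DiT's velocity field. This is a minimal NumPy sketch, not the repository's sampler: velocity_fn is a hypothetical stand-in for the DiT (which would also close over C), and the toy velocity in the usage lines is the exact rectified-flow field for a single fixed target, chosen so the result can be checked.

```python
import numpy as np


def euler_sample(velocity_fn, shape, num_steps=2, rng=None):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (embeddings) with fixed Euler steps."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.standard_normal(shape)  # x_0 ~ N(0, I), e.g. shape (32, 128, 4096)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity_fn(x, t)  # one Euler step along the predicted velocity
    return x


# Toy check: for a single fixed target z the rectified-flow field is (z - x) / (1 - t).
# Its paths are straight lines, so 2 Euler steps land exactly on z.
z = np.ones((32, 128, 8))  # 32 candidates, 128 positions, toy width 8 instead of 4096
drafts = euler_sample(lambda x, t: (z - x) / (1.0 - t), z.shape, num_steps=2,
                      rng=np.random.default_rng(0))
```

Straight paths are what make such a small step count plausible: the more curved the learned paths, the more discretization error 2 steps would introduce.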

---

## Architecture

| Role | Model | Params | Dim | Layers |
|------|-------|--------|-----|--------|
| **Generator (DiT)** | Qwen3.5-9B strided slice | ~4.0B | 4096 | 12 (9 DeltaNet + 3 FullAttn) |
| **Verifier (frozen)** | Qwen3.5-9B (text tower) | 9B | 4096 | 32 |

### The Strided Graft

```
Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
Layer types:   [D, A, D, D, D,  A,  D,  D,  D,  D,  D,  A ]

D = DeltaNet (linear_attention), A = full_attention
```
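The layer-type row can be reproduced from the [3xDeltaNet, 1xAttention] pattern, assuming 0-indexed layers where every fourth layer (index % 4 == 3) is full attention, which is what the table implies. The index list is copied from the table; that it equals STRIDE_INDICES in clsd.grafted_dit is an assumption.

```python
# Source layers copied from the graft table above (assumed to equal STRIDE_INDICES).
SOURCE_LAYERS = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
NUM_VERIFIER_LAYERS = 32


def layer_type(idx: int, period: int = 4) -> str:
    """'A' for full attention (last layer of each group of 4), 'D' for Gated DeltaNet."""
    return "A" if idx % period == period - 1 else "D"


dit_types = [layer_type(i) for i in SOURCE_LAYERS]
verifier_types = [layer_type(i) for i in range(NUM_VERIFIER_LAYERS)]

print("".join(dit_types))                                    # DADDDADDDDDA
print(dit_types.count("D"), dit_types.count("A"))            # 9 3
print(verifier_types.count("D"), verifier_types.count("A"))  # 24 8
```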

### DiT Modifications
1. **adaLN-Zero modulators** per block: nn.Linear(4096, 24576), zero-initialized
2. **Timestep conditioning**: sinusoidal embedding + conditioning vector C
3. **Learned local positional embedding**: nn.Parameter(zeros(1, 128, 4096))
4. Causal masking preserved from original Qwen weights
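Items 1 and 2 can be illustrated with a NumPy sketch of adaLN-Zero conditioning. Several details are assumptions not stated above: the 24576 outputs are read as 6 x 4096 (shift, scale and gate for the attention and MLP sub-layers, as in the original DiT), the conditioning is the sum of the timestep embedding and C, SiLU precedes the modulator, and a plain LayerNorm stands in for whatever norm the grafted blocks actually use. Because the modulator is zero-initialized, every gate starts at 0, so each block begins as an identity map on the residual stream.

```python
import numpy as np


def timestep_embedding(t, dim, max_period=10000.0):
    """Sinusoidal embedding of t in [0, 1]; t is scaled by 1000 as in DiT (an assumption here)."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = 1000.0 * t * freqs
    return np.concatenate([np.cos(args), np.sin(args)])


def silu(x):
    return x / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-6):
    """Non-affine LayerNorm over the hidden dimension."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


class AdaLNZero:
    """Hypothetical modulator: Linear(dim -> 6 * dim), zero-initialized."""

    def __init__(self, dim):
        self.weight = np.zeros((dim, 6 * dim))
        self.bias = np.zeros(6 * dim)

    def __call__(self, cond):
        # Six [dim] vectors: shift/scale/gate for attention, then for the MLP.
        return np.split(silu(cond) @ self.weight + self.bias, 6)


def modulated_block(x, cond, modulator, attn_fn, mlp_fn):
    """x: [seq, dim] hidden states; cond: [dim] = timestep embedding + C."""
    shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = modulator(cond)
    x = x + gate_a * attn_fn(layer_norm(x) * (1.0 + scale_a) + shift_a)
    x = x + gate_m * mlp_fn(layer_norm(x) * (1.0 + scale_m) + shift_m)
    return x


rng = np.random.default_rng(0)
dim, seq = 8, 128  # toy width; the real DiT uses 4096
x = rng.standard_normal((seq, dim))
C = rng.standard_normal(dim)  # stands in for the verifier's conditioning vector
cond = timestep_embedding(0.5, dim) + C
out = modulated_block(x, cond, AdaLNZero(dim), attn_fn=np.tanh, mlp_fn=np.tanh)
assert np.allclose(out, x)  # zero-init gates make the block an identity at start
```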

---

## Training Pipeline

### Pre-Flight: Embedding Extraction

Target embeddings from **AI-MO/NuminaMath-CoT** (mathematical chain-of-thought):
- Tokenized with the Qwen tokenizer; embeddings looked up via the frozen embedding matrix
- Chunked into 128-token windows and stored as safetensors shards of shape [64, 128, 4096] (64 chunks per shard)
- **146,790 total chunks** across 2,294 files
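The extraction step can be sketched in NumPy as follows. This is an illustration under assumptions: incomplete trailing windows are dropped rather than padded, and the embedding matrix is a small random stand-in for Qwen's frozen input embeddings. Writing each shard to disk (for example with safetensors.numpy.save_file) is omitted.

```python
import numpy as np

WINDOW = 128     # tokens per chunk
SHARD_SIZE = 64  # chunks per shard -> shard shape [64, 128, dim]


def chunk_tokens(token_ids, window=WINDOW):
    """Split one tokenized document into full 128-token windows (tail dropped)."""
    n_full = len(token_ids) // window
    return [token_ids[i * window:(i + 1) * window] for i in range(n_full)]


def build_shards(chunks, embed_matrix, shard_size=SHARD_SIZE):
    """Look up embeddings and group chunks into [<=64, 128, dim] arrays."""
    shards = []
    for start in range(0, len(chunks), shard_size):
        ids = np.asarray(chunks[start:start + shard_size])  # [n, 128] token ids
        shards.append(embed_matrix[ids])                     # [n, 128, dim]
    return shards


vocab, dim = 1000, 16  # toy sizes; Qwen's embedding width is 4096
embed_matrix = np.random.default_rng(0).standard_normal((vocab, dim)).astype(np.float32)
docs = [list(range(300)), list(range(500, 1000))]  # two fake tokenized documents
chunks = [c for doc in docs for c in chunk_tokens(doc)]
shards = build_shards(chunks, embed_matrix)
print(len(chunks), [s.shape for s in shards])  # 5 [(5, 128, 16)]
```

If every shard but the last is full, 146,790 chunks fill 2,293 shards of 64 plus a final shard of 38 chunks, which is consistent with the 2,294 files reported above.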

### Stage A: Rectified Flow (Velocity Regression) -- COMPLETE

The DiT learns the straight-line velocity field v = x1 - x0, where x0 = noise and x1 = target:

```
x_t = (1-t)*noise + t*target,  t in [0,1]
L = ||v_pred - (target - noise)||^2
```
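For readers who want the objective in executable form, here is a NumPy sketch of one Monte Carlo estimate of this loss. Sampling t uniformly with one t per sequence is an assumption (other schedules, such as logit-normal t, are also common), and velocity_fn is a hypothetical stand-in for the DiT.

```python
import numpy as np


def rectified_flow_loss(velocity_fn, target, rng):
    """Stage A objective: regress v_pred onto the straight-line velocity (target - noise).

    target: [B, 128, dim] clean embeddings (x1); noise: x0 ~ N(0, I).
    """
    noise = rng.standard_normal(target.shape)
    t = rng.uniform(0.0, 1.0, size=(target.shape[0], 1, 1))  # one t per sequence
    x_t = (1.0 - t) * noise + t * target                      # point on the straight path
    v_target = target - noise
    v_pred = velocity_fn(x_t, t)
    return np.mean((v_pred - v_target) ** 2)


rng = np.random.default_rng(0)
zero_target = np.zeros((4, 128, 8))
# For an all-zero target, x_t = (1 - t) * noise, so v = -noise = -x_t / (1 - t) exactly.
oracle = lambda x_t, t: -x_t / (1.0 - t)
print(rectified_flow_loss(oracle, zero_target, rng))  # ~0 up to float error
```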

| Parameter | Value |
|-----------|-------|
| Hardware | 1x NVIDIA B200 (183 GB) |
| Steps | 50,000 |
| Batch size | 32 |
| Optimizer | AdamW (lr=1e-4, cosine decay) |
| Wall-clock | 154.8 minutes |
| Final MSE | ~0.013 (converged by step 5K) |

### Stage C: CE Alignment -- IN PROGRESS

Backpropagate through the frozen 9B verifier to teach the DiT semantic correctness:

```
noise -> DiT (2 Euler steps) -> draft_embeds
  -> frozen Qwen 32 layers -> logits -> CE loss vs ground truth tokens
```

L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings)

Beta anneals from 0.1 to 0, gradually shifting the objective from geometric (embedding-space) to semantic (token-level) alignment.
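The combined objective and the beta anneal can be sketched in NumPy. The linear shape of the schedule is an assumption (only the endpoints 0.1 and 0 are given), logits are assumed to be already aligned with their target tokens, and the sketch only computes the loss value; in training, the CE gradient would reach the DiT by backpropagating through the frozen verifier.

```python
import numpy as np


def beta_schedule(step, total_steps, beta0=0.1):
    """Anneal beta linearly from beta0 to 0 over total_steps (linear shape assumed)."""
    return beta0 * (1.0 - min(step / total_steps, 1.0))


def cross_entropy(logits, targets):
    """Mean token-level CE. logits: [N, vocab]; targets: [N] integer token ids."""
    z = logits - logits.max(axis=-1, keepdims=True)  # stabilize the softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


def stage_c_loss(logits, targets, drafts, true_embeds, step, total_steps):
    """L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings)."""
    beta = beta_schedule(step, total_steps)
    mse = np.mean((drafts - true_embeds) ** 2)
    return cross_entropy(logits, targets) + beta * mse


# Uniform logits over a 10-token vocab give CE = log(10); the MSE term here is 1.0.
logits, targets = np.zeros((128, 10)), np.zeros(128, dtype=int)
drafts, true_embeds = np.ones((128, 8)), np.zeros((128, 8))
print(stage_c_loss(logits, targets, drafts, true_embeds, step=0, total_steps=2000))
# log(10) + 0.1 * 1.0
print(stage_c_loss(logits, targets, drafts, true_embeds, step=2000, total_steps=2000))
# log(10)
```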

**Smoke test results** (50 steps, batch=1):
- CE dropped 12.8 -> 6.1: the verifier is starting to read the DiT output
- Gradients flow correctly through the frozen verifier

**Current run**: 2000 steps, batch=8, grad_accum=4 (effective batch 32) on a B200, streaming to wandb

---

## Step 4: Live Inference (The Parallel Rollout)

1. User submits reasoning prompt
2. 9B Verifier forward pass -> conditioning vector C + KV cache
3. DiT generates **32 candidate 128-token branches** in 2 Euler steps
4. 9B Verifier evaluates all 32 branches in one batched pass (shared prompt KV via PagedAttention)
5. Score each branch by its mean log-probability across the 128 positions
6. **Causal Guillotine**: scan the Top-1 branch left-to-right and truncate at the first low-confidence position
7. The verifier samples the correct token at the truncation point, a new C is generated, and the loop repeats
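Steps 5 and 6 can be sketched in plain Python. The sketch assumes the verifier's per-position log-probabilities for each branch are already available (how a continuous draft is mapped to the token being scored is not specified here), and the acceptance threshold is a hypothetical value, not one taken from CLSD.

```python
import math

# Hypothetical acceptance threshold: reject a position whose log-prob is below log(0.1).
LOGPROB_THRESHOLD = math.log(0.1)


def mean_logprob(token_logprobs):
    """Branch score: mean log-probability over its positions (128 in CLSD)."""
    return sum(token_logprobs) / len(token_logprobs)


def top1_branch(branches):
    """Index of the branch with the highest mean log-probability."""
    return max(range(len(branches)), key=lambda i: mean_logprob(branches[i]))


def causal_guillotine(token_logprobs, threshold=LOGPROB_THRESHOLD):
    """Scan left to right; return how many leading positions are accepted."""
    for pos, logprob in enumerate(token_logprobs):
        if logprob < threshold:
            return pos  # truncate at the first low-confidence position
    return len(token_logprobs)  # whole block accepted


branches = [
    [-0.1, -0.2, -3.0, -0.1],  # mean -0.85
    [-0.3, -0.4, -0.2, -0.5],  # mean -0.35 -> Top-1
]
best = top1_branch(branches)
print(best, causal_guillotine(branches[best]))  # 1 4
print(causal_guillotine(branches[0]))           # 2
```

The accepted prefix is kept and, per step 7, the verifier supplies the token at the truncation point before the next block is drafted.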

**Target latency**: <500ms per 128-token block

---

## Step 5: The Shadow Loop (Async RL -- Continuous Improvement)

The Primary Node never stops drafting. A Shadow Node continuously improves the DiT:

```
Primary Node --[Redis: 32 trajectories/cycle]--> Shadow Node
Shadow Node  --[Weight sync every 1000 steps]--> Primary Node
```
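The Shadow Node side of this protocol can be sketched with an in-process queue standing in for Redis (a deployment would use a Redis list or stream in that role). train_step and publish_weights are hypothetical callbacks, not functions from the CLSD code.

```python
import queue

SYNC_EVERY = 1000  # push weights to the Primary Node every 1000 training steps


def shadow_loop(trajectories, train_step, publish_weights, max_steps):
    """Consume trajectory cycles, train on each, and periodically sync weights."""
    for step in range(1, max_steps + 1):
        cycle = trajectories.get()  # blocks until the Primary Node pushes a cycle
        train_step(cycle)
        if step % SYNC_EVERY == 0:
            publish_weights(step)


# Simulate the Primary Node pushing 2500 cycles of 32 trajectories each.
trajectories = queue.Queue()
for cycle_id in range(2500):
    trajectories.put([f"traj-{cycle_id}-{k}" for k in range(32)])

synced_at = []
shadow_loop(trajectories, train_step=lambda cycle: None,
            publish_weights=synced_at.append, max_steps=2500)
print(synced_at)  # [1000, 2000]
```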

### Objective Verification (Reward Signal)

Feed Top-1 decoded tokens through:
- **Lean 4**: formal mathematical proof verification
- **Python sandbox**: code execution for correctness

- If verified -> reward the continuous vectors (positive signal)
- If failed -> penalize them (negative signal)

This breaks the log-prob echo chamber. The DiT learns "alien intuition" -- solutions the 9B verifier would score as correct but would never stumble upon autoregressively.

### RL Objective

Policy gradient from objective verification creates a reward signal independent of the verifier log-probs. The DiT explores the embedding space for novel solutions that:
1. The verifier accepts (high log-prob)
2. Actually solve the problem (Lean4/sandbox verification)

This is an **infinite background process** -- the system improves continuously as long as compute is available.
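The path from verification result to gradient can be sketched as a REINFORCE-style surrogate. Several pieces are assumptions rather than details given above: rewards of +1/-1 for verified/failed trajectories, the per-cycle mean reward as a baseline, and a per-trajectory log-likelihood logp from the DiT policy (obtaining such a likelihood from a flow model is itself a design choice this sketch does not settle).

```python
def verification_reward(verified: bool) -> float:
    """+1 if Lean 4 / the sandbox accepts the decoded Top-1 tokens, -1 otherwise."""
    return 1.0 if verified else -1.0


def advantages(rewards):
    """Subtract the cycle's mean reward as a simple variance-reducing baseline."""
    baseline = sum(rewards) / len(rewards)
    return [r - baseline for r in rewards]


def policy_gradient_loss(advs, logps):
    """Surrogate whose gradient is the REINFORCE estimate: -mean(A_i * log p_i)."""
    return -sum(a * lp for a, lp in zip(advs, logps)) / len(advs)


verified = [True, False, True, True]
rewards = [verification_reward(v) for v in verified]  # [1.0, -1.0, 1.0, 1.0]
advs = advantages(rewards)                             # [0.5, -1.5, 0.5, 0.5]
print(policy_gradient_loss(advs, logps=[-1.0, -2.0, -1.0, -1.0]))  # -0.375
```

Minimizing this surrogate raises the likelihood of trajectories that verified above the cycle average and lowers it for the rest, which is how an external check can steer the DiT independently of the verifier's log-probs.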

---

## Repository Contents

```
checkpoints/
  dit_stage_a_step_5000.pt      # Early training
  dit_stage_a_step_10000.pt     # Mid training
  dit_stage_a_step_30000.pt     # Late training
  dit_stage_a_final.pt          # 50K steps, converged (MSE=0.013)
  dit_stage_c_*.pt              # CE alignment checkpoints (when available)
embeddings_sample/              # 50 representative embedding shards
  batch_*.safetensors           # Each: [64, 128, 4096]
```

### Loading a Checkpoint

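```python
import torch
from transformers import AutoModelForCausalLM

from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES

# Load the 9B model in bf16; the DiT is grafted from its strided layers.
qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)

# Map the checkpoint to CPU so it loads regardless of the device it was saved from;
# load_state_dict then copies the values into the DiT's existing parameters.
state_dict = torch.load("checkpoints/dit_stage_a_final.pt", map_location="cpu", weights_only=True)
dit.load_state_dict(state_dict)
dit.eval()  # inference mode for drafting
```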

---

## Roadmap

- [x] Pre-flight: embedding extraction (146K chunks from NuminaMath-CoT)
- [x] Step 1: Frankenstein graft (4.0B hybrid DiT from 9B)
- [x] Step 2: Stage A rectified flow (50K steps, converged)
- [x] Stage C smoke test (50 steps, pipeline validated)
- [ ] Step 3: Stage C full alignment (2000+ steps on B200)
- [ ] Step 4: Live inference with Causal Guillotine
- [ ] Step 5: Shadow Loop async RL with Lean4/sandbox verification
- [ ] Scale to 8x H200 cluster for production training

## Wandb

- Stage A: [clsd-speedrun](https://wandb.ai/dalletest123/clsd-speedrun)
- Stage C smoke: [clsd-speedrun-smoke](https://wandb.ai/dalletest123/clsd-speedrun-smoke)

## License

Apache 2.0