---
license: apache-2.0
language:
- en
tags:
- diffusion
- speculative-decoding
- rectified-flow
- dit
- qwen
- math-reasoning
- deltanet
datasets:
- AI-MO/NuminaMath-CoT
base_model:
- Qwen/Qwen3.5-9B
---
# Continuous Latent Speculative Decoding (CLSD)
- **Architecture**: ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier
- **Key Innovation**: First hybrid DeltaNet/Attention causal diffusion transformer for parallel token generation
- **Status**: Stage A converged, Stage C alignment in progress
---
## How it was envisaged
A machine for constructing candidate inner machines that bias the verifier toward solution paths its default rollout would miss. It learns a search policy that explores verifier-legible continuations and unlocks competence that ordinary autoregressive rollout fails to reach.
## What is cool
The base verifier has fixed weights, but its inference process is not exhausted by ordinary left-to-right decoding. A learned continuous proposer can search for hidden-state trajectories and token paths that the verifier can recognize as correct, even if the verifier would rarely or never reach them under standard autoregressive rollout.
CLSD is a latent block proposer trained with supervised objectives. Its diffusion structure makes parallel search cheap enough to expose verifier-accessible solutions that autoregressive (AR) decoding misses.
## Thesis
Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT) -- a strided 12-layer slice of Qwen3.5-9B -- operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the exact same 4096-dimensional manifold, the same tokenizer, and the same attention geometry. No projection bridges, no dimensional translation loss.
Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern. The DiT preserves this hybrid structure and keeps **causal masking** -- DeltaNet linear recurrence is strictly causal by design.
The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps. The verifier evaluates them in a single batched forward pass.
> **Why causal diffusion works**: The conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. The causal constraint forces the DiT to learn autoregressive-like internal logic, which mirrors the frozen verifier's expectations.
---
## Architecture
| Role | Model | Params | Dim | Layers |
|------|-------|--------|-----|--------|
| **Generator (DiT)** | Qwen3.5-9B strided slice | ~4.0B | 4096 | 12 (9 DeltaNet + 3 FullAttn) |
| **Verifier (frozen)** | Qwen3.5-9B (text tower) | 9B | 4096 | 32 |
### The Strided Graft
```
Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
Layer types: [D, A, D, D, D, A, D, D, D, D, D, A ]
D = DeltaNet (linear_attention), A = full_attention
```
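One quick way to sanity-check the graft table is to use Qwen3.5-9B's repeating [3xDeltaNet, 1xAttention] pattern. Under that pattern, layer `i` is full attention exactly when `i % 4 == 3`. The sketch below recomputes the layer-type row from the source indices. It is an illustration only, not the repo's `graft_dit_from_qwen` code.

```python
# Illustrative check of the strided graft (not the repo's grafting code).
# Assumes Qwen3.5-9B's 32 layers follow the repeating [D, D, D, A] pattern.
NUM_LAYERS = 32
STRIDE_INDICES = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]  # source layers listed above


def layer_type(idx: int) -> str:
    """Return 'A' for full attention (every 4th layer), else 'D' for Gated DeltaNet."""
    return "A" if idx % 4 == 3 else "D"


full_types = [layer_type(i) for i in range(NUM_LAYERS)]
slice_types = [layer_type(i) for i in STRIDE_INDICES]

print(full_types.count("D"), full_types.count("A"))    # 24 8
print("".join(slice_types))                            # DADDDADDDDDA
print(slice_types.count("D"), slice_types.count("A"))  # 9 3
```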
### DiT Modifications
1. **adaLN-Zero modulators** per block: nn.Linear(4096, 24576), zero-initialized
2. **Timestep conditioning**: sinusoidal embedding + conditioning vector C
3. **Learned local positional embedding**: nn.Parameter(zeros(1, 128, 4096))
4. Causal masking preserved from original Qwen weights
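The modulator width 24576 is exactly 6 x 4096. That matches the six per-block modulation vectors of DiT's adaLN-Zero: shift, scale, and gate for the token mixer, and the same three for the MLP. The sketch below makes several assumptions:

- the 24576 outputs split in that way;
- the conditioning input is the sum of the sinusoidal timestep embedding and C;
- plain LayerNorm and toy sub-modules stand in for the Qwen DeltaNet/attention layers the real model wraps.

`AdaLNZeroBlock` and `timestep_embedding` are hypothetical names.

```python
import math

import torch
import torch.nn as nn


def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal embedding of t in [0, 1]: shape [B] -> [B, dim] (dim must be even)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = (t.float() * 1000.0)[:, None] * freqs[None, :]  # rescale t to a 0-1000 range
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


class AdaLNZeroBlock(nn.Module):
    """Hypothetical DiT-style wrapper around a token mixer and an MLP."""

    def __init__(self, dim: int, mixer: nn.Module, mlp: nn.Module):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mixer, self.mlp = mixer, mlp
        # 6 * dim outputs (24576 when dim=4096): shift/scale/gate for the mixer and the MLP
        self.modulation = nn.Linear(dim, 6 * dim)
        nn.init.zeros_(self.modulation.weight)  # adaLN-Zero: all gates start at 0,
        nn.init.zeros_(self.modulation.bias)    # so each block starts as the identity

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: [B, L, dim]; cond: [B, dim]
        mods = self.modulation(cond)[:, None, :].chunk(6, dim=-1)
        shift1, scale1, gate1, shift2, scale2, gate2 = mods
        x = x + gate1 * self.mixer(self.norm1(x) * (1 + scale1) + shift1)
        return x + gate2 * self.mlp(self.norm2(x) * (1 + scale2) + shift2)


# Toy usage with dim=16 instead of 4096.
dim, block_len = 16, 128
block = AdaLNZeroBlock(
    dim,
    mixer=nn.Linear(dim, dim),  # stands in for a DeltaNet or attention layer
    mlp=nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)),
)
pos_emb = nn.Parameter(torch.zeros(1, block_len, dim))  # learned local positional embedding
C = torch.randn(2, dim)                                 # conditioning vector from the verifier
t = torch.tensor([0.25, 0.75])
with torch.no_grad():
    x = torch.randn(2, block_len, dim) + pos_emb
    out = block(x, timestep_embedding(t, dim) + C)
print(torch.allclose(out, x))  # True: zero-initialized gates make the block an identity map
```

Zero initialization means the grafted Qwen layers initially contribute nothing through the gated residual paths. Training then opens the gates gradually instead of disturbing the pretrained weights from step 0.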
---
## Training Pipeline
### Pre-Flight: Embedding Extraction
Target embeddings from **AI-MO/NuminaMath-CoT** (mathematical chain-of-thought):
- Tokenized with the Qwen tokenizer; embeddings looked up via the frozen embedding matrix
- Chunked into 128-token windows and packed into [64, 128, 4096] safetensors shards (64 chunks per shard)
- **146,790 total chunks** across 2,294 files
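The chunking step can be sketched as a toy with 8-dimensional embeddings instead of 4096. `chunk_embeddings` and `pack_shards` are hypothetical helpers. The source does not say how the real pipeline handles document boundaries and leftover tokens, so dropping each document's trailing partial window is an assumption. Writing each shard to a `.safetensors` file is left out.

The last line checks the reported numbers. With 64 chunks per shard, 146,790 chunks need ceil(146,790 / 64) = 2,294 files, which matches the file count.

```python
import math

import torch

WINDOW = 128           # tokens per chunk
CHUNKS_PER_SHARD = 64  # leading dimension of each [64, 128, 4096] shard


def chunk_embeddings(embeds: torch.Tensor, window: int = WINDOW) -> torch.Tensor:
    """Split one document's [T, D] embeddings into [T // window, window, D].

    Trailing tokens that do not fill a whole window are dropped (an assumption).
    """
    n = embeds.shape[0] // window
    return embeds[: n * window].reshape(n, window, embeds.shape[1])


def pack_shards(chunks: torch.Tensor, per_shard: int = CHUNKS_PER_SHARD) -> list:
    """Group [N, window, D] chunks into shards of at most per_shard chunks."""
    return list(torch.split(chunks, per_shard, dim=0))


# Toy run: vocab 1000 and D=8 instead of Qwen's vocabulary and 4096.
torch.manual_seed(0)
embed_matrix = torch.randn(1000, 8)                    # stands in for the frozen embedding matrix
token_ids = torch.randint(0, 1000, (128 * 130 + 50,))  # one tokenized document
chunks = chunk_embeddings(embed_matrix[token_ids])     # embedding lookup by row indexing
shards = pack_shards(chunks)
print(chunks.shape, len(shards), shards[-1].shape)

# Reported numbers: 146,790 chunks at 64 per shard -> 2,294 files.
print(math.ceil(146_790 / CHUNKS_PER_SHARD))  # 2294
```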
### Stage A: Rectified Flow (Velocity Regression) -- COMPLETE
The DiT learns the straight-line velocity field v = x1 - x0, where x0 is the noise sample and x1 is the target embedding:
```
x_t = (1-t)*noise + t*target, t in [0,1]
L = ||v_pred - (target - noise)||^2
```
| Parameter | Value |
|-----------|-------|
| Hardware | 1x NVIDIA B200 (183 GB) |
| Steps | 50,000 |
| Batch size | 32 |
| Optimizer | AdamW (lr=1e-4, cosine decay) |
| Wall-clock | 154.8 minutes |
| Final MSE | ~0.013 (converged by step 5K) |
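The Stage A objective and the 2-step Euler sampler used for drafting can be sketched in PyTorch as below. Two things are assumptions for illustration:

- the `model(x_t, t, cond)` call signature is a guess at how the DiT is invoked;
- `OracleVelocity` is a toy velocity field that already knows the answer, used only to show that the formulas above are self-consistent.

The loss is averaged over elements to match the MSE reported in the table.

```python
from typing import Optional

import torch
import torch.nn as nn


def rectified_flow_loss(model: nn.Module, target: torch.Tensor, cond: Optional[torch.Tensor]) -> torch.Tensor:
    """Stage A objective: x_t = (1-t)*noise + t*target, regress v = target - noise."""
    noise = torch.randn_like(target)
    t = torch.rand(target.shape[0], device=target.device)  # one t ~ U[0, 1) per chunk
    t_b = t.view(-1, *([1] * (target.dim() - 1)))          # broadcast over [L, D]
    x_t = (1 - t_b) * noise + t_b * target
    v_pred = model(x_t, t, cond)
    return ((v_pred - (target - noise)) ** 2).mean()       # mean squared error


@torch.no_grad()
def euler_sample(model: nn.Module, noise: torch.Tensor, cond: Optional[torch.Tensor], steps: int = 2) -> torch.Tensor:
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (embeddings) with fixed Euler steps."""
    x, dt = noise, 1.0 / steps
    for i in range(steps):
        t = torch.full((x.shape[0],), i * dt, device=x.device)
        x = x + dt * model(x, t, cond)
    return x


class OracleVelocity(nn.Module):
    """Toy stand-in for the DiT that knows the target: v = (target - x) / (1 - t)."""

    def __init__(self, target: torch.Tensor):
        super().__init__()
        self.target = target

    def forward(self, x, t, cond):
        return (self.target - x) / (1 - t).view(-1, 1, 1)


torch.manual_seed(0)
target = torch.randn(4, 128, 16)  # 4 toy chunks with D=16 instead of 4096
oracle = OracleVelocity(target)
loss_value = rectified_flow_loss(oracle, target, cond=None).item()
draft = euler_sample(oracle, torch.randn_like(target), cond=None, steps=2)
print(loss_value)                                # close to 0
print(torch.allclose(draft, target, atol=1e-5))  # True
```

With a perfectly straight velocity field, two Euler steps land on the target. This is why a well-trained rectified flow can draft in as few as 2 steps.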
### Stage C: CE Alignment -- IN PROGRESS
Backpropagate through the frozen 9B verifier to teach the DiT semantic correctness:
```
noise -> DiT (2 Euler steps) -> draft_embeds
-> frozen Qwen 32 layers -> logits -> CE loss vs ground truth tokens
```
`L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings)`

Beta anneals from 0.1 to 0, gradually shifting the objective from geometric to semantic alignment.
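The Stage C objective and its beta annealing can be sketched as below. The source gives only the schedule's endpoints (0.1 to 0), so the linear shape of `beta_schedule` is an assumption. The sketch also assumes logits and targets arrive already aligned. Both helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def beta_schedule(step: int, total_steps: int, beta_start: float = 0.1) -> float:
    """Anneal beta from beta_start to 0; the linear shape is an assumption."""
    frac = min(step / max(total_steps, 1), 1.0)
    return beta_start * (1.0 - frac)


def stage_c_loss(logits, targets, drafts, true_embeddings, beta: float) -> torch.Tensor:
    """L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings).

    logits: [B, L, V] from the frozen verifier; targets: [B, L] ground-truth token ids
    (assumed already aligned, i.e. any next-token shift is done by the caller).
    """
    ce = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
    mse = F.mse_loss(drafts, true_embeddings)
    return ce + beta * mse


# Toy shapes: vocab 32, D=16.
torch.manual_seed(0)
logits = torch.randn(2, 128, 32)
targets = torch.randint(0, 32, (2, 128))
drafts = torch.randn(2, 128, 16, requires_grad=True)
true_embeddings = torch.randn(2, 128, 16)
loss = stage_c_loss(logits, targets, drafts, true_embeddings, beta_schedule(500, 2000))
# In this toy only the MSE term reaches `drafts`; in the real pipeline the CE term
# also flows back through the frozen verifier into the DiT.
loss.backward()
```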
**Smoke test results** (50 steps, batch=1):
- CE dropped from 12.8 to 6.1: the verifier is starting to read the DiT output
- Gradients flow correctly through the frozen verifier

**Current run**: 2000 steps, batch=8, grad_accum=4 on B200, streaming to wandb
---
## Step 4: Live Inference (The Parallel Rollout)
1. User submits reasoning prompt
2. 9B Verifier forward pass -> conditioning vector C + KV cache
3. DiT generates **32 candidate 128-token branches** in 2 Euler steps
4. 9B Verifier evaluates all 32 branches in one batched pass (shared prompt KV via PagedAttention)
5. Score by mean log-probability across 128 positions
6. **Causal Guillotine**: scan Top-1 left-to-right, truncate at first low-confidence position
7. The verifier (Qwen) samples the correct token at the cut position, a new C is generated, and the loop repeats

**Target latency**: <500ms per 128-token block
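Steps 5 and 6 can be sketched as below. The sketch assumes the verifier's per-token log-probs for every branch are already available as a `[32, 128]` tensor. The source does not specify how branches are decoded to tokens or what counts as "low confidence". The fixed log-prob `threshold` and both helper names are therefore hypothetical.

```python
import torch


def score_branches(token_logprobs: torch.Tensor) -> torch.Tensor:
    """Step 5: mean verifier log-prob per branch. Input [num_branches, block_len]."""
    return token_logprobs.mean(dim=-1)


def causal_guillotine(token_logprobs: torch.Tensor, threshold: float) -> tuple:
    """Step 6: pick the Top-1 branch and return (branch_index, accepted_prefix_len).

    Scans the Top-1 branch left to right and cuts at the first position whose
    log-prob is below `threshold` (a hypothetical hyperparameter).
    """
    best = int(torch.argmax(score_branches(token_logprobs)))
    low = (token_logprobs[best] < threshold).nonzero()
    cut = int(low[0]) if low.numel() > 0 else token_logprobs.shape[1]
    return best, cut


logp = torch.full((32, 128), -0.1)  # toy verifier log-probs: 32 branches x 128 tokens
logp[7] = -0.05                     # branch 7 is the most confident on average...
logp[7, 90] = -5.0                  # ...but has one low-confidence token at position 90
best, cut = causal_guillotine(logp, threshold=-2.0)
print(best, cut)  # 7 90: keep tokens 0..89 of branch 7; the verifier resamples position 90
```

Scoring by the mean picks the globally best branch. The left-to-right cut keeps a single bad token from poisoning everything after it, since later positions were drafted conditioned on it.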
---
## Step 5: The Shadow Loop (Async RL -- Continuous Improvement)
The Primary Node never stops drafting. A Shadow Node continuously improves the DiT:
```
Primary Node --[Redis: 32 trajectories/cycle]--> Shadow Node
Shadow Node --[Weight sync every 1000 steps]--> Primary Node
```
### Objective Verification (Reward Signal)
Decode the Top-1 branch to tokens and check it with:
- **Lean 4**: formal mathematical proof verification
- **Python sandbox**: code execution for correctness

If verification passes, the continuous vectors behind the branch are rewarded (positive signal). If it fails, they are penalized (negative signal).

This is meant to break the log-prob echo chamber. The goal is for the DiT to learn "alien intuition": solutions the 9B verifier would score as correct but would rarely or never stumble upon autoregressively.
### RL Objective
Policy gradient from objective verification provides a reward signal independent of the verifier's log-probs. The DiT explores the embedding space for novel solutions that:
1. the verifier accepts (high log-prob), and
2. actually solve the problem (Lean 4/sandbox verification).

This is designed as an **infinite background process**: the system keeps improving as long as compute is available.
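The objective is stated only as "policy gradient"; the estimator, exploration noise, and baseline are not specified. The sketch below is one hypothetical instance:

- each DiT draft is treated as the mean of a Gaussian policy with a fixed `sigma`;
- the draft is perturbed to get the vectors actually decoded and verified;
- REINFORCE with a batch-mean baseline is applied to +1/-1 rewards from Lean 4/sandbox checks.

A verifier log-prob term could be added to the reward to encode condition 1. It is omitted here.

```python
import torch
from torch.distributions import Normal


def reinforce_loss(mean: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor,
                   sigma: float = 0.1) -> torch.Tensor:
    """REINFORCE with a batch-mean baseline (hypothetical instance of the policy gradient).

    mean:    [N, L, D] DiT drafts, treated as the mean of a Gaussian policy
    actions: [N, L, D] perturbed vectors that were decoded and verified
    rewards: [N] +1 if Lean 4 / sandbox verification passed, -1 otherwise
    """
    log_prob = Normal(mean, sigma).log_prob(actions).sum(dim=(1, 2))  # log pi(a) per trajectory
    advantage = rewards - rewards.mean()                               # baseline lowers variance
    return -(advantage.detach() * log_prob).mean()


# Toy cycle: 32 trajectories, D=16 instead of 4096.
torch.manual_seed(0)
mean = torch.zeros(32, 128, 16, requires_grad=True)
actions = (mean + 0.1 * torch.randn_like(mean)).detach()
rewards = torch.full((32,), -1.0)
rewards[:8] = 1.0  # pretend the first 8 trajectories passed verification
loss = reinforce_loss(mean, actions, rewards)
# Gradient descent now moves rewarded drafts toward their perturbations
# and penalized drafts away from theirs.
loss.backward()
```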
---
## Repository Contents
```
checkpoints/
  dit_stage_a_step_5000.pt     # Early training
  dit_stage_a_step_10000.pt    # Mid training
  dit_stage_a_step_30000.pt    # Late training
  dit_stage_a_final.pt         # 50K steps, converged (MSE=0.013)
  dit_stage_c_*.pt             # CE alignment checkpoints (when available)
embeddings_sample/             # 50 representative embedding shards
  batch_*.safetensors          # Each: [64, 128, 4096]
```
### Loading a Checkpoint
```python
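import torch
from transformers import AutoModelForCausalLM

from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES

# Load the 9B base model whose layers are sliced into the DiT
# (older transformers versions name this argument torch_dtype)
qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)

# Graft the 12-layer strided DiT (and the embedding layer) from the base model
dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)

# Load the Stage A weights; map_location="cpu" avoids needing the device the checkpoint was saved on
state_dict = torch.load("checkpoints/dit_stage_a_final.pt", map_location="cpu", weights_only=True)
dit.load_state_dict(state_dict)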
```
---
## Roadmap
- [x] Pre-flight: embedding extraction (146K chunks from NuminaMath-CoT)
- [x] Step 1: Frankenstein graft (4.0B hybrid DiT from 9B)
- [x] Step 2: Stage A rectified flow (50K steps, converged)
- [x] Stage C smoke test (50 steps, pipeline validated)
- [ ] Step 3: Stage C full alignment (2000+ steps on B200)
- [ ] Step 4: Live inference with Causal Guillotine
- [ ] Step 5: Shadow Loop async RL with Lean4/sandbox verification
- [ ] Scale to 8x H200 cluster for production training
## Wandb
- Stage A: [clsd-speedrun](https://wandb.ai/dalletest123/clsd-speedrun)
- Stage C smoke: [clsd-speedrun-smoke](https://wandb.ai/dalletest123/clsd-speedrun-smoke)
## License
Apache 2.0