datasysdev committed on
Commit
3fd6ee7
·
verified ·
1 Parent(s): d911ae7

Upload README.md with huggingface_hub

  1. README.md +205 -0
README.md ADDED
@@ -0,0 +1,205 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ tags:
+ - diffusion
+ - speculative-decoding
+ - rectified-flow
+ - dit
+ - qwen
+ - math-reasoning
+ datasets:
+ - AI-MO/NuminaMath-CoT
+ base_model:
+ - Qwen/Qwen3.5-9B
+ ---
+
+ # Continuous Latent Speculative Decoding (CLSD)
+
+ **Architecture**: ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier
+ **Target**: SOTA mathematical reasoning via continuous latent speculative decoding
+ **Key Innovation**: To our knowledge, the first hybrid DeltaNet/Attention causal diffusion transformer
+
+ ---
+
+ ## Thesis
+
+ Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a
+ hybrid causal Diffusion Transformer (DiT), built from a strided 12-layer slice of Qwen3.5-9B,
+ that operates in the continuous embedding space of the same frozen Qwen3.5-9B verifier.
+ Both models share the exact same 4096-dimensional manifold, the same tokenizer,
+ and the same attention geometry. No projection bridges, no dimensional translation loss.
+
+ Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8
+ standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern.
+ The DiT preserves this hybrid structure and keeps **causal masking** -- DeltaNet linear
+ recurrence is strictly causal by design and cannot be flipped to bidirectional.
+
+ The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps.
+ The verifier evaluates them in a single batched forward pass. The DiT is aligned by
+ backpropagating a cross-entropy loss through the frozen verifier.
+
+ > **Why causal diffusion works**: The conditioning vector C is injected via adaLN into
+ > every position simultaneously, providing global context regardless of attention mask.
+ > Token 1 does not need to see token 128 -- C already carries the full prompt context.
+ > The causal constraint actually forces the DiT to learn autoregressive-like internal
+ > logic, which mirrors what the frozen verifier expects.
+
+ ---
+
+ ## Architecture
+
+ ### Models
+
+ | Role | Model | Params | Dim | Layers | Attn Heads | KV Heads |
+ |------|-------|--------|-----|--------|-----------|----------|
+ | **Generator (DiT)** | Qwen3.5-9B -> strided 12-layer slice | ~4.0B | 4096 | 12 | 16 | 4 |
+ | **Verifier (frozen)** | Qwen3.5-9B (text tower) | 9B | 4096 | 32 | 16 | 4 |
+
+ ### The Strided Graft
+
+ ```
+ Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
+ Layer types:   [D, A, D, D, D,  A,  D,  D,  D,  D,  D,  A ]
+ DiT indices:   [0, 1, 2, 3, 4,  5,  6,  7,  8,  9,  10, 11]
+
+ D = DeltaNet (linear_attention), A = full_attention
+ Result: 9 DeltaNet + 3 full_attention layers
+ ```
+
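The layer-type row above can be re-derived from the [3xDeltaNet, 1xAttention] pattern. The following is a minimal sketch, assuming that full attention sits at every fourth layer (index % 4 == 3) and that `STRIDE_INDICES` holds the source layers listed in the block above; the helper `layer_type` is hypothetical and not part of the `clsd` package.

```python
# Assumption: attention occupies every 4th layer (index % 4 == 3), which is
# one reading of the repeating [3xDeltaNet, 1xAttention] pattern.
NUM_LAYERS = 32
STRIDE_INDICES = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]  # copied from the README


def layer_type(index: int) -> str:
    """Return 'A' for full attention and 'D' for Gated DeltaNet."""
    return "A" if index % 4 == 3 else "D"


full_stack = [layer_type(i) for i in range(NUM_LAYERS)]
dit_types = [layer_type(i) for i in STRIDE_INDICES]

print("verifier:", full_stack.count("D"), "DeltaNet +", full_stack.count("A"), "attention")
print("DiT     :", dit_types.count("D"), "DeltaNet +", dit_types.count("A"), "attention")
print(dit_types)
```

Under that assumption the derived types agree with the table: the slice picks up attention layers 3, 15, and 31, and every other selected layer is DeltaNet.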
+ ### Modifications to Grafted Layers
+
+ 1. **Strip the LM head** -- the DiT outputs continuous embeddings, not logits
+ 2. **Keep causal masking** -- preserves 100% of pre-trained weight integrity
+ 3. **Inject adaLN-Zero modulators** -- one per block, nn.Linear(4096, 24576), i.e. 6 x 4096 outputs
+ 4. **Zero-initialize** -- the modulators start at zero, so at training step 0 the network acts as identity
+ 5. **Timestep conditioning** -- sinusoidal embedding + conditioning vector C
+ 6. **Learned local positional embedding** -- nn.Parameter(zeros(1, 128, 4096))
+
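Items 3 and 4 can be illustrated with a toy adaLN-Zero block. This is a sketch under stated assumptions, not the repository's implementation: it assumes the 24576 outputs split into six 4096-wide vectors (shift, scale, gate for each of two sub-layers, as in the original DiT design), it uses a plain layer norm as a stand-in for whatever normalization the grafted layers use, and `mixer` and `mlp` are hypothetical placeholders for the DeltaNet/attention and feed-forward sub-layers.

```python
import math

DIM = 8          # toy width; the README uses 4096
NUM_CHUNKS = 6   # 24576 / 4096 = 6 modulation vectors per block (assumed layout)


def linear(weight, bias, c):
    """y = W c + b for a list-of-rows weight matrix."""
    return [sum(w * v for w, v in zip(row, c)) + b for row, b in zip(weight, bias)]


def layer_norm(x, eps=1e-6):
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def modulate(x, shift, scale):
    return [v * (1.0 + s) + b for v, s, b in zip(x, scale, shift)]


def adaln_zero_block(x, c, weight, bias, mixer, mlp):
    """Residual block whose two sub-layers are gated by the conditioning vector c."""
    mod = linear(weight, bias, c)  # length NUM_CHUNKS * DIM
    shift1, scale1, gate1, shift2, scale2, gate2 = (
        mod[i * DIM:(i + 1) * DIM] for i in range(NUM_CHUNKS)
    )
    h = mixer(modulate(layer_norm(x), shift1, scale1))
    x = [xi + g * hi for xi, g, hi in zip(x, gate1, h)]
    h = mlp(modulate(layer_norm(x), shift2, scale2))
    return [xi + g * hi for xi, g, hi in zip(x, gate2, h)]


# Zero-initialized modulator: every gate is 0, so both residual branches vanish.
zero_weight = [[0.0] * DIM for _ in range(NUM_CHUNKS * DIM)]
zero_bias = [0.0] * (NUM_CHUNKS * DIM)
x = [0.5 * i for i in range(DIM)]
c = [1.0] * DIM
out = adaln_zero_block(x, c, zero_weight, zero_bias,
                       mixer=lambda h: [2.0 * v for v in h],
                       mlp=lambda h: [v + 1.0 for v in h])
print(out == x)
```

With zero weights and biases the gates multiply each sub-layer output by zero, so the block returns its input unchanged, which is the identity behaviour item 4 describes.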
+ ---
+
+ ## Training Pipeline
+
+ ### Pre-Flight: Embedding Extraction
+
+ Target embeddings are pre-computed from **AI-MO/NuminaMath-CoT** (mathematical chain-of-thought reasoning):
+ - Tokenize reasoning paths with the Qwen tokenizer
+ - Look up embeddings in the frozen Qwen3.5-9B embedding matrix E (248320 x 4096)
+ - Chunk into fixed 128-token windows
+ - Save as [64, 128, 4096] safetensors shards
+
+ **Result**: 2,294 shard files of up to 64 chunks each, **146,790 total chunks** (~144 GB). The total is 26 short of 2,294 x 64 = 146,816, which suggests a partial final shard.
+
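The chunking and shard counts above can be checked with a few lines of arithmetic. This sketch assumes that a ragged tail shorter than 128 tokens is dropped and that embeddings are stored at 2 bytes per value (for example bfloat16); neither detail is stated in the README, and the helper names are hypothetical.

```python
CHUNK_LEN = 128       # tokens per window (from the README)
SHARD_SIZE = 64       # chunks per shard (from the README)
DIM = 4096            # embedding width (from the README)
BYTES_PER_VALUE = 2   # assumption: 16-bit storage such as bfloat16


def chunk_token_ids(token_ids, chunk_len=CHUNK_LEN):
    """Split token ids into full windows; the ragged tail is dropped (assumption)."""
    n_full = len(token_ids) // chunk_len
    return [token_ids[i * chunk_len:(i + 1) * chunk_len] for i in range(n_full)]


def shard_count(total_chunks, shard_size=SHARD_SIZE):
    """Number of shard files needed, counting a partial final shard."""
    return -(-total_chunks // shard_size)  # ceiling division


chunks = chunk_token_ids(list(range(300)))   # a 300-token toy sequence
total_chunks = 146_790
size_gib = total_chunks * CHUNK_LEN * DIM * BYTES_PER_VALUE / 2**30

print(len(chunks), "full windows from 300 tokens")
print(shard_count(total_chunks), "shards for", total_chunks, "chunks")
print(round(size_gib, 1), "GiB at 2 bytes per value")
```

Under the 16-bit assumption the dataset comes to roughly 143 GiB, in line with the ~144 GB quoted above.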
+ ### Stage A: Rectified Flow (Velocity Regression)
+
+ Teach the DiT the straight-line velocity field from noise to embeddings using Rectified Flow.
+ Here x_0 is a Gaussian noise sample, x_1 is the target embedding chunk, and C is the conditioning vector:
+
+ x_t = (1 - t) * x_0 + t * x_1, t in [0, 1]
+
+ L_RF = ||v_theta(x_t, t, C) - (x_1 - x_0)||^2
+
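The two formulas above translate directly into code. The sketch below uses plain Python lists on a toy 4-dimensional vector rather than real [128, 4096] tensors, and it replaces the trained network v_theta with a hypothetical `oracle` that returns the exact target velocity, so it only shows why a straight-line field can be integrated in two Euler steps.

```python
import random


def interpolate(x0, x1, t):
    """x_t = (1 - t) * x_0 + t * x_1, elementwise."""
    return [(1 - t) * a + t * b for a, b in zip(x0, x1)]


def velocity_target(x0, x1):
    """Regression target of L_RF: the constant velocity x_1 - x_0."""
    return [b - a for a, b in zip(x0, x1)]


def rf_loss(pred, x0, x1):
    """Mean squared error between predicted and target velocity."""
    target = velocity_target(x0, x1)
    return sum((p - v) ** 2 for p, v in zip(pred, target)) / len(target)


def euler_sample(velocity_fn, x0, num_steps=2):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with equal Euler steps."""
    x, dt = list(x0), 1.0 / num_steps
    for k in range(num_steps):
        v = velocity_fn(x, k * dt)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


random.seed(0)
x1 = [0.25, -1.0, 2.0, 0.5]                      # toy "embedding"
x0 = [random.gauss(0.0, 1.0) for _ in x1]        # noise sample
oracle = lambda x, t: velocity_target(x0, x1)    # stand-in for a perfect v_theta
sample = euler_sample(oracle, x0, num_steps=2)
print(max(abs(s - b) for s, b in zip(sample, x1)))
```

Because the target velocity is constant along the path, the number of Euler steps does not matter for a perfect model; in practice the learned field is only approximately straight, which is why the README budgets 1-2 steps rather than exactly one.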
+ | Property | DDPM + LCM (old) | Rectified Flow (this work) |
+ |----------|-------------------|---------------------------|
+ | Training objective | Noise prediction | Velocity prediction (v) |
+ | Trajectory shape | Curved (needs 1000 steps) | **Straight line** |
+ | Distillation required? | Yes | **No** |
+ | Native inference steps | 2 (after distillation) | **1-2 Euler steps natively** |
+
+ **This release**: Stage A trained on 1x NVIDIA B200 for 50,000 steps:
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Optimizer | AdamW (lr=1e-4, warmup 100 steps, cosine decay) |
+ | Batch size | 32 |
+ | Steps | 50,000 |
+ | Wall-clock | 154.8 minutes |
+ | Final MSE loss | ~0.013 (converged by step 5K) |
+ | Checkpoints included | 5K, 10K, 20K, 30K, 40K, final |
+
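The optimizer row describes a warmup-then-cosine schedule. A minimal sketch of one common form follows; the linear warmup shape and the decay floor of 0 are assumptions, since the README only gives the peak rate, the warmup length, and the total step count.

```python
import math


def lr_at(step, base_lr=1e-4, warmup_steps=100, total_steps=50_000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr (floor is an assumption)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return min_lr + (base_lr - min_lr) * cosine


for step in (0, 99, 100, 25_050, 50_000):
    print(step, f"{lr_at(step):.2e}")
```

For scale, the table's figures imply about 5.4 optimizer steps per second (50,000 steps in 154.8 minutes) and 1.6M training chunks seen at batch size 32, roughly 10.9 passes over the 146,790 pre-computed chunks.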
+ ### Stage C: CE Alignment (Next)
+
+ Shift the DiT from outputs that look like embeddings to outputs that make
+ the 9B verifier produce correct tokens:
+
+ ```
+ z ~ N(0,I) -> DiT(z, C) -> [2 Euler steps] -> X (128x4096)
+            -> Qwen_frozen(X, past_kv) -> logits (128x248320)
+ ```
+
+ L_total = alpha * CE(logits, targets) + beta * MSE(X, E(targets))
+
+ - alpha = 1.0 (CE drives alignment)
+ - beta = 0.1 -> 0 over training (MSE regularizer anneals)
+
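The combined objective can be written out for a single position. This is an illustrative sketch only: the README does not say how beta is annealed, so the linear schedule in `beta_at` is an assumption, and the function names are hypothetical. A real implementation would compute the same quantities over [128, 248320] logits with a tensor library.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax of the target index, computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def mse(x, y):
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def beta_at(step, total_steps, beta_start=0.1):
    """Linear anneal from beta_start to 0 (schedule shape is an assumption)."""
    return beta_start * max(0.0, 1.0 - step / total_steps)


def total_loss(logits, target, x, target_embedding, step, total_steps, alpha=1.0):
    """L_total = alpha * CE(logits, targets) + beta * MSE(X, E(targets))."""
    ce = cross_entropy(logits, target)
    reg = mse(x, target_embedding)
    return alpha * ce + beta_at(step, total_steps) * reg


logits = [2.0, 0.5, -1.0]        # toy 3-token vocabulary
x = [0.1, 0.2, 0.3, 0.4]         # drafted embedding
e_target = [0.0, 0.2, 0.3, 0.5]  # E(target) for the correct token
print(total_loss(logits, 0, x, e_target, step=0, total_steps=1000))
print(total_loss(logits, 0, x, e_target, step=1000, total_steps=1000))
```

At the final step the MSE term has weight 0, so the loss reduces to the cross-entropy alone, matching the "CE drives alignment" note above.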
+ ---
+
+ ## Live Inference (Target)
+
+ 1. User submits a reasoning prompt
+ 2. 9B Verifier runs a forward pass -> extracts C (4096-d) + KV cache
+ 3. DiT samples 32 noise vectors and generates 32 candidate 128-token branches in **2 Euler steps**
+ 4. 9B Verifier evaluates all 32 branches in one batched forward pass
+ 5. **Causal Guillotine**: Scan the Top-1 draft left-to-right and truncate at the first position where the verifier log-prob drops below a threshold
+ 6. The verifier (Qwen) samples the correct token at the truncation point, a new C is generated, and the loop repeats
+
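Step 5 can be sketched as a short scan over per-token log-probabilities. The README does not specify how the Top-1 draft is ranked or what the threshold is, so ranking by mean log-prob and the threshold of -1.0 below are both assumptions, and the function names are hypothetical.

```python
def causal_guillotine(token_logprobs, threshold):
    """Return how many leading draft tokens are accepted.

    Scans left to right and stops at the first position whose verifier
    log-probability falls below `threshold`.
    """
    for position, logprob in enumerate(token_logprobs):
        if logprob < threshold:
            return position
    return len(token_logprobs)


def pick_top1(branches):
    """Index of the branch with the highest mean log-prob (ranking rule is an assumption)."""
    return max(range(len(branches)), key=lambda i: sum(branches[i]) / len(branches[i]))


# Toy verifier log-probs for 3 branches of 4 tokens (the README uses 32 x 128).
branches = [
    [-0.1, -0.2, -3.5, -0.1],   # strong start, fails at position 2
    [-0.3, -0.2, -0.4, -0.6],   # acceptable throughout
    [-2.9, -0.1, -0.1, -0.1],   # fails immediately
]
best = pick_top1(branches)
accepted = causal_guillotine(branches[best], threshold=-1.0)
print("best branch:", best, "accepted tokens:", accepted)
```

The accepted prefix is kept as-is; step 6 then has the verifier sample the token at the cut position before the next block is drafted.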
+ **Target latency**: <500ms per 128-token block
+
+ ---
+
+ ## Repository Contents
+
+ ```
+ embeddings/                      # Pre-computed NuminaMath-CoT embeddings (146K chunks)
+   batch_0000.safetensors         # Each: [64, 128, 4096]
+   ...
+ checkpoints/
+   dit_stage_a_step_5000.pt
+   dit_stage_a_step_10000.pt
+   dit_stage_a_step_20000.pt
+   dit_stage_a_step_30000.pt
+   dit_stage_a_step_40000.pt
+   dit_stage_a_final.pt           # 50K steps, converged
+ ```
+
+ ### Loading a Checkpoint
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM
+
+ from clsd.grafted_dit import STRIDE_INDICES, graft_dit_from_qwen
+
+ # Build the DiT architecture by slicing layers out of the base model
+ qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
+ dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)
+
+ # Load the trained Stage A weights (tensors only, mapped to CPU first)
+ state_dict = torch.load(
+     "checkpoints/dit_stage_a_final.pt", map_location="cpu", weights_only=True
+ )
+ dit.load_state_dict(state_dict)
+ dit.eval()  # disable training-only behaviour before inference
+ ```
+
+ ---
+
+ ## Key Architectural Decisions
+
+ 1. **Shared 4096-d space**: Generator and verifier operate in the same embedding geometry natively. No projection layers, no information bottlenecks.
+ 2. **Strided layer slice**: DiT inherits geometric knowledge from early, middle, and late layers of the 9B.
+ 3. **Rectified Flow over DDPM**: Linear trajectories -> no distillation stage -> native 2-step generation.
+ 4. **Instruct/Instruct architecture**: Both drafter and verifier are derived from the same model, so there is no distributional gap between their weights at initialization.
+ 5. **Monte Carlo parallel search**: 32 branches x 128 tokens = 4,096 candidate tokens per inference step.
+
+ ---
+
+ ## Citation
+
+ ```bibtex
+ @misc{clsd2026,
+   title={Continuous Latent Speculative Decoding: A Hybrid Causal DiT for Parallel Reasoning},
+   year={2026},
+   url={https://huggingface.co/datasysdev/clsd}
+ }
+ ```
+
+ ## License
+
+ Apache 2.0