Continuous Latent Speculative Decoding (CLSD)

Architecture: ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier
Key Innovation: First hybrid DeltaNet/Attention causal diffusion transformer for parallel token generation
Status: Stage A converged, Stage C alignment in progress


How it was envisaged

A machine for constructing candidate inner-machines that bias the verifier toward solution paths its default rollout would miss: a learned search policy that explores verifier-legible continuations and unlocks competence ordinary autoregressive rollout fails to reach.

What is cool

The base verifier has fixed weights, but its inference process is not exhausted by ordinary left-to-right decoding. A learned continuous proposer can search for hidden-state trajectories and token paths that the verifier can recognize as correct, even if the verifier would rarely or never reach them under standard autoregressive rollout. CLSD is a supervised-trained latent block proposer whose diffusion structure makes parallel search cheap enough to expose verifier-accessible solutions that AR decoding misses.

Thesis

Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT) -- a strided 12-layer slice of Qwen3.5-9B -- operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the exact same 4096-dimensional manifold, the same tokenizer, and the same attention geometry. No projection bridges, no dimensional translation loss.

Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern. The DiT preserves this hybrid structure and keeps causal masking -- DeltaNet linear recurrence is strictly causal by design.

The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps. The verifier evaluates them in a single batched forward pass.

Why causal diffusion works: The conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. The causal constraint forces the DiT to learn autoregressive-like internal logic, mirroring the frozen verifier's expectations.
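The 2-step Euler rollout described above can be sketched as a fixed-step ODE integrator over the rectified-flow velocity field. This is an illustrative sketch, not the project's sampler: the `euler_sample` helper and the toy velocity field are assumptions, standing in for the trained DiT.

```python
import torch

def euler_sample(v_model, noise, cond, num_steps=2):
    """Integrate the rectified-flow ODE dx/dt = v(x, t, C) from t=0 (noise)
    to t=1 (data) with a fixed-step Euler solver."""
    x = noise
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((x.shape[0],), i * dt)
        x = x + dt * v_model(x, t, cond)  # one Euler step
    return x

# Toy velocity field standing in for the DiT: ignores t and C,
# and simply points from x toward a fixed target.
target = torch.ones(4, 128, 8)
v_toy = lambda x, t, c: target - x
out = euler_sample(v_toy, torch.zeros(4, 128, 8), cond=None, num_steps=2)
```

With batch size 32 in the real system, the `noise` tensor would be `[32, 128, 4096]`, so all 32 candidate blocks come out of the same two Euler steps.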


Architecture

| Role              | Model                    | Params | Dim  | Layers                       |
|-------------------|--------------------------|--------|------|------------------------------|
| Generator (DiT)   | Qwen3.5-9B strided slice | ~4.0B  | 4096 | 12 (9 DeltaNet + 3 FullAttn) |
| Verifier (frozen) | Qwen3.5-9B (text tower)  | 9B     | 4096 | 32                           |

The Strided Graft

Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
Layer types:   [D, A, D, D, D,  A,  D,  D,  D,  D,  D,  A ]

D = DeltaNet (linear_attention), A = full_attention

DiT Modifications

  1. adaLN-Zero modulators per block: nn.Linear(4096, 24576), zero-initialized
  2. Timestep conditioning: sinusoidal embedding + conditioning vector C
  3. Learned local positional embedding: nn.Parameter(zeros(1, 128, 4096))
  4. Causal masking preserved from original Qwen weights
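Modification 1 can be sketched as follows. The output width 24576 is 6 x 4096: shift, scale, and gate vectors for each of the attention and MLP sub-blocks, in the adaLN-Zero style. The `AdaLNZero` class name and forward signature are illustrative assumptions; only the zero-initialized `nn.Linear(4096, 24576)` shape comes from the list above.

```python
import torch
import torch.nn as nn

class AdaLNZero(nn.Module):
    """Project the conditioning vector C to six modulation vectors
    (shift/scale/gate for the attention and MLP sub-blocks).
    Zero-init means the grafted block starts as a no-op modulation,
    so training begins from the pretrained Qwen behavior."""
    def __init__(self, dim=4096):
        super().__init__()
        self.mod = nn.Linear(dim, 6 * dim)  # 4096 -> 24576
        nn.init.zeros_(self.mod.weight)
        nn.init.zeros_(self.mod.bias)

    def forward(self, c):
        # Six chunks: shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp
        return self.mod(c).chunk(6, dim=-1)

ada = AdaLNZero(dim=32)  # small dim for illustration
mods = ada(torch.randn(2, 32))
```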

Training Pipeline

Pre-Flight: Embedding Extraction

Target embeddings from AI-MO/NuminaMath-CoT (mathematical chain-of-thought):

  • Tokenized with Qwen tokenizer, embeddings looked up via frozen embedding matrix
  • Chunked into 128-token windows: [64, 128, 4096] safetensors shards
  • 146,790 total chunks across 2,294 files
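The chunking step above can be sketched as a reshape into non-overlapping 128-token windows grouped into `[64, 128, dim]` shards. The `chunk_embeddings` helper and its drop-the-tail policy are assumptions for illustration; the actual extraction script may pad or overlap differently.

```python
import torch

def chunk_embeddings(embeds, window=128, shard_batch=64):
    """Split a [seq_len, dim] embedding matrix into non-overlapping
    `window`-token chunks, then group chunks into [shard_batch, window, dim]
    shards. Trailing tokens that do not fill a full window are dropped."""
    seq_len, dim = embeds.shape
    n = seq_len // window
    windows = embeds[: n * window].reshape(n, window, dim)
    return [windows[i : i + shard_batch] for i in range(0, n, shard_batch)]

# A 300-token sequence yields two full 128-token windows in one shard.
shards = chunk_embeddings(torch.randn(300, 16), window=128, shard_batch=64)
```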

Stage A: Rectified Flow (Velocity Regression) -- COMPLETE

The DiT learns the straight-line velocity field v = x1 - x0:

x_t = (1-t)*noise + t*target,  t in [0,1]
L = ||v_pred - (target - noise)||^2

| Parameter  | Value                         |
|------------|-------------------------------|
| Hardware   | 1x NVIDIA B200 (183 GB)       |
| Steps      | 50,000                        |
| Batch size | 32                            |
| Optimizer  | AdamW (lr=1e-4, cosine decay) |
| Wall-clock | 154.8 minutes                 |
| Final MSE  | ~0.013 (converged by step 5K) |
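The Stage A objective above can be written as a single loss evaluation. This is a minimal sketch, assuming a velocity model with signature `v_model(x_t, t)`; the real DiT also takes the conditioning vector C.

```python
import torch

def rectified_flow_loss(v_model, target, t=None):
    """One Stage-A objective evaluation: sample noise x0 and time t,
    form the straight-line interpolant x_t = (1-t)*x0 + t*x1, and
    regress the predicted velocity onto v = x1 - x0."""
    noise = torch.randn_like(target)
    if t is None:
        t = torch.rand(target.shape[0], 1, 1)  # per-example timestep
    x_t = (1 - t) * noise + t * target
    v_true = target - noise
    v_pred = v_model(x_t, t)
    return ((v_pred - v_true) ** 2).mean()

# Dummy model predicting zero velocity, on a toy [2, 8, 4] "embedding" batch.
loss = rectified_flow_loss(lambda x, t: torch.zeros_like(x), torch.randn(2, 8, 4))
```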

Stage C: CE Alignment -- IN PROGRESS

Backpropagate through the frozen 9B verifier to teach the DiT semantic correctness:

noise -> DiT (2 Euler steps) -> draft_embeds
  -> frozen Qwen 32 layers -> logits -> CE loss vs ground truth tokens

L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings)

Beta anneals from 0.1 to 0, gradually shifting from geometric to semantic alignment.
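The combined Stage C objective and the beta anneal can be sketched as below. The linear anneal schedule is an assumption (the README only says beta goes from 0.1 to 0); the `stage_c_loss` name and argument layout are illustrative.

```python
import torch
import torch.nn.functional as F

def stage_c_loss(logits, targets, drafts, true_embeds,
                 step, total_steps, beta0=0.1):
    """Stage-C objective: cross-entropy through the frozen verifier plus a
    geometric MSE anchor whose weight beta anneals (here: linearly) from
    beta0 to 0 over training."""
    beta = beta0 * max(0.0, 1.0 - step / total_steps)
    ce = F.cross_entropy(logits.reshape(-1, logits.shape[-1]),
                         targets.reshape(-1))
    mse = F.mse_loss(drafts, true_embeds)
    return ce + beta * mse, beta

logits = torch.randn(2, 4, 10)          # verifier logits over a toy vocab
targets = torch.randint(0, 10, (2, 4))  # ground-truth tokens
drafts = torch.randn(2, 4, 8)           # DiT draft embeddings
loss_start, beta_start = stage_c_loss(logits, targets, drafts, drafts.clone(), step=0, total_steps=2000)
loss_end, beta_end = stage_c_loss(logits, targets, drafts, drafts.clone(), step=2000, total_steps=2000)
```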

Smoke test results (50 steps, batch=1):

  • CE dropped 12.8 -> 6.1: verifier starting to read DiT output
  • Gradients flow correctly through frozen verifier

Current run: 2000 steps, batch=8, grad_accum=4 on B200 -- streaming to wandb


Step 4: Live Inference (The Parallel Rollout)

  1. User submits reasoning prompt
  2. 9B Verifier forward pass -> conditioning vector C + KV cache
  3. DiT generates 32 candidate 128-token branches in 2 Euler steps
  4. 9B Verifier evaluates all 32 branches in one batched pass (shared prompt KV via PagedAttention)
  5. Score by mean log-probability across 128 positions
  6. Causal Guillotine: scan Top-1 left-to-right, truncate at first low-confidence position
  7. Qwen samples correct token, new C generated, loop repeats

Target latency: <500ms per 128-token block
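Step 6, the Causal Guillotine, can be sketched as a left-to-right scan over per-position verifier log-probs. The `causal_guillotine` helper and the threshold value are illustrative assumptions, not the project's actual acceptance rule.

```python
import torch

def causal_guillotine(logprobs, threshold=-2.0):
    """Scan per-position verifier log-probs left to right and return the
    length of the prefix before the first low-confidence position.
    Accepting only this prefix keeps the committed output causally sound:
    everything kept was verified conditioned only on kept context."""
    below = (logprobs < threshold).nonzero(as_tuple=True)[0]
    return int(below[0]) if below.numel() > 0 else logprobs.numel()

# Position 3 falls below threshold, so only the first 3 tokens are accepted.
lp = torch.tensor([-0.1, -0.3, -0.2, -3.5, -0.1])
accept_len = causal_guillotine(lp)
```

After truncation, the verifier resamples the token at the cut position (step 7), a new conditioning vector C is produced, and the loop repeats from there.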


Step 5: The Shadow Loop (Async RL -- Continuous Improvement)

The Primary Node never stops drafting. A Shadow Node continuously improves the DiT:

Primary Node --[Redis: 32 trajectories/cycle]--> Shadow Node
Shadow Node  --[Weight sync every 1000 steps]--> Primary Node
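The producer/consumer shape of this loop can be sketched with an in-memory queue standing in for Redis. Everything here (batch contents, sync cadence, the `weight_box` stand-in for a weight broadcast) is illustrative; the real system pushes 32 trajectories per cycle over Redis and syncs weights every 1000 steps.

```python
from queue import Queue
from threading import Thread

traj_queue: Queue = Queue()   # stand-in for the Redis trajectory channel
weight_box = {"version": 0}   # stand-in for the weight-sync broadcast

def primary(cycles):
    # Primary Node: never stops drafting; ships trajectories each cycle.
    for c in range(cycles):
        traj_queue.put([f"traj_{c}_{i}" for i in range(32)])

def shadow(cycles, sync_every=2):
    # Shadow Node: trains on each batch, periodically publishes weights back.
    for step in range(cycles):
        batch = traj_queue.get()
        # ... train the DiT on `batch` ...
        if (step + 1) % sync_every == 0:
            weight_box["version"] += 1

p = Thread(target=primary, args=(4,))
s = Thread(target=shadow, args=(4,))
p.start(); s.start(); p.join(); s.join()
```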

Objective Verification (Reward Signal)

Feed Top-1 decoded tokens through:

  • Lean 4: formal mathematical proof verification
  • Python sandbox: code execution for correctness

If verified -> reward the continuous vectors (positive signal)
If failed -> penalize (negative signal)

This breaks the log-prob echo chamber. The DiT learns "alien intuition" -- solutions the 9B verifier would score as correct but would never stumble upon autoregressively.

RL Objective

Policy gradient from objective verification creates a reward signal independent of the verifier log-probs. The DiT explores the embedding space for novel solutions that:

  1. The verifier accepts (high log-prob)
  2. Actually solve the problem (Lean4/sandbox verification)

This is an infinite background process -- the system improves continuously as long as compute is available.
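One way this reward could enter a policy-gradient update is a REINFORCE-style objective with a mean baseline. This is a sketch of the idea only; the actual RL algorithm, reward values, and baseline are not specified in this card and are assumptions here.

```python
import torch

def verified_pg_loss(draft_logprob_sum, verified):
    """REINFORCE-style objective: reward +1 if the decoded draft passes
    objective verification (Lean 4 / sandbox), -1 otherwise. Because the
    reward is independent of verifier log-probs, it can push the proposer
    toward solutions the verifier scores well but never samples itself."""
    rewards = torch.where(verified,
                          torch.tensor(1.0), torch.tensor(-1.0))
    baseline = rewards.mean()  # variance-reduction baseline
    return -((rewards - baseline) * draft_logprob_sum).mean()

# Two drafts with equal log-prob mass: one verified, one failed.
logprob_sums = torch.tensor([0.5, 0.5], requires_grad=True)
loss = verified_pg_loss(logprob_sums, torch.tensor([True, False]))
```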


Repository Contents

checkpoints/
  dit_stage_a_step_5000.pt      # Early training
  dit_stage_a_step_10000.pt     # Mid training
  dit_stage_a_step_30000.pt     # Late training
  dit_stage_a_final.pt          # 50K steps, converged (MSE=0.013)
  dit_stage_c_*.pt              # CE alignment checkpoints (when available)
embeddings_sample/              # 50 representative embedding shards
  batch_*.safetensors           # Each: [64, 128, 4096]

Loading a Checkpoint

from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES
from transformers import AutoModelForCausalLM
import torch

# Load the frozen base model, then graft the 12-layer strided DiT from it.
qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)

# Restore Stage A weights (weights_only=True avoids arbitrary pickle execution).
state_dict = torch.load("checkpoints/dit_stage_a_final.pt", weights_only=True)
dit.load_state_dict(state_dict)

Roadmap

  • Pre-flight: embedding extraction (146K chunks from NuminaMath-CoT)
  • Step 1: Frankenstein graft (4.0B hybrid DiT from 9B)
  • Step 2: Stage A rectified flow (50K steps, converged)
  • Stage C smoke test (50 steps, pipeline validated)
  • Step 3: Stage C full alignment (2000+ steps on B200)
  • Step 4: Live inference with Causal Guillotine
  • Step 5: Shadow Loop async RL with Lean4/sandbox verification
  • Scale to 8x H200 cluster for production training

Wandb

License

Apache 2.0
