# Composer Replication Framework — User Guide

A zero-to-training walkthrough for the open replication of Cursor Composer 2.5.
Pace: an ML engineer who knows GRPO/DPO at a textbook level but has never
opened this repo. Every step references real code, and every kwarg name
listed below has been imported and verified against
`composer_replication/` source.

---

## 1. What is this framework?

A pure-PyTorch replication of the **3-channel composer loss** that powers
agentic-coding model training. One model, one optimizer, three additive
loss terms — composed every step:

```
                      ┌────────────────────────────────────────────┐
                      │            compose_loss(model, batch)       │
                      └────────────────────────────────────────────┘

        ┌─────────────────────────────────┼─────────────────────────────────┐
        ▼                                 ▼                                 ▼
┌───────────────────┐          ┌──────────────────────┐         ┌──────────────────────┐
│  Channel 1 (RL)   │          │  Channel 2 (SDPO)    │         │  Channel 3 (replay)  │
│  GRPO             │          │  hint-distillation   │         │  multi-teacher DPO   │
│  → lm_ce stub in  │          │  generalized JSD     │         │  on (chosen,         │
│  verification     │          │  student vs teacher  │         │  rejected) pairs     │
│  harness          │          │  (hint-conditioned)  │         │  from N teachers     │
└─────────┬─────────┘          └──────────┬───────────┘         └──────────┬───────────┘
          │  weight = 1 (always on)       │ alpha_sdpo            beta_replay │
          └────────────────┬──────────────┴─────────────────┬──────────────────┘
                           ▼                                ▼
                   total = lm_ce + α·sdpo_jsd + β·trace_replay_dpo
                   (channel auto-disables if its weight=0 OR its inputs are missing)
```
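The gating rule in the last line of the diagram can be sketched in plain Python. This is a minimal illustration, not the body of `compose_loss`: the helper name `combine_channels` is hypothetical, scalar floats stand in for loss tensors, and `None` stands in for "inputs missing from the batch".

```python
from typing import Optional


def combine_channels(
    lm_ce: float,
    sdpo_jsd: Optional[float],
    trace_replay_dpo: Optional[float],
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
) -> float:
    """Weighted sum of the three channels (hypothetical helper).

    A channel is skipped when its weight is zero or its value is None,
    where None stands in for missing batch inputs.
    """
    total = lm_ce  # channel 1 is always on, with weight 1
    if alpha_sdpo > 0.0 and sdpo_jsd is not None:
        total += alpha_sdpo * sdpo_jsd
    if beta_replay > 0.0 and trace_replay_dpo is not None:
        total += beta_replay * trace_replay_dpo
    return total


# Channel 2 inputs missing, channel 3 live: total = 2.0 + 0.05 * 0.4
partial = combine_channels(lm_ce=2.0, sdpo_jsd=None, trace_replay_dpo=0.4)

# Both weights zero: only channel 1 contributes.
ce_only = combine_channels(2.0, 5.0, 5.0, alpha_sdpo=0.0, beta_replay=0.0)
```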

Two API surfaces, on purpose:

- **Verification harness:** `compose_loss(model, batch, ...)` is a free
  function (channel 1 = LM cross-entropy, the GRPO limit under deterministic
  rewards). Use it for CPU smokes, unit tests, and gradient-flow debugging.
- **Production trainer:** `ComposerReplicationTrainer` is a `trl.GRPOTrainer`
  subclass that overrides `_compute_loss` with the same 3 channels on top of
  TRL's real reward + advantage machinery.

The verification harness is what you'll use for sections 2–6; the production
trainer (and its alternates VeRL/PRIME-RL/Monarch) is covered in section 8.

Source of truth: `composer_replication/loss.py` for `compose_loss`,
`composer_replication/trainer/composer_trainer.py` for the trainer subclass.

---

## 2. Install — which extras to pick

Always start with the core install:

```bash
git clone https://huggingface.co/Codeseys/composer-replication-framework
cd composer-replication-framework
pip install -e .
```

That gets you `torch>=2.0` + `transformers>=4.46` and is enough for the
verification harness on CPU (sections 3, 5, 6).

The seven optional extras are declared in `pyproject.toml` `[project.optional-dependencies]`:

```
                              Do you need …

        ┌──────────────────────────┼──────────────────────────┐
        ▼                          ▼                          ▼
   real teacher calls          DiLoCo on                  production
   over OpenRouter?            >1 GPU?                    GRPO training?
        │                          │                          │
        │ yes                      │ yes                      │ yes
        ▼                          ▼                          ▼
   pip install -e ".[replay]"  pip install -e ".[diloco]"  pip install -e ".[train]"
   (httpx)                     (torchft-nightly)           (trl, peft, accelerate, datasets)
        │                          │                          │
        │ + want CPU-side          │ + scaling beyond a        │ + want PRIME-RL
        │ DPO normalization?       │ single host?              │ (Recipe C)?
        ▼                          ▼                          ▼
   pip install -e ".[replaysim]"  pip install -e ".[serverless]"  pip install -e ".[prime-rl]"
   (data-juicer; depends         (fsspec, huggingface_hub)    (prime-rl>=0.5)
    on [replay])
                                                              │ + Monarch actor mesh?

                                                          pip install -e ".[monarch]"
                                                          (monarch>=0.4.1)
```

Quick decision table:

| Goal                                                  | Install                                  |
|-------------------------------------------------------|------------------------------------------|
| CPU smoke / verification (sections 3, 5, 6)           | `pip install -e .`                       |
| Section 4 (replaysim DJNormalizer)                    | `pip install -e ".[replaysim]"`          |
| Section 7 dev loop (LocalProcessExecutor + file://)   | `pip install -e ".[serverless]"`         |
| Real DiLoCo outer-loop                                | `pip install -e ".[diloco,serverless]"`  |
| Section 8 Recipe A (TRL GRPO)                         | `pip install -e ".[train]"`              |
| Section 8 Recipe C (PRIME-RL)                         | `pip install -e ".[prime-rl]"`           |
| Section 8 Recipe C+D (PRIME-RL + Monarch)             | `pip install -e ".[prime-rl,monarch]"`   |
| Everything for development                            | `pip install -e ".[dev]"`                |

---

## 3. Quickstart: `examples/qwen_05b_quickstart` end-to-end on CPU

The fastest way to convince yourself the framework works on a real HF model.
~3–5 min wall-clock on CPU, ~1 GB disk for Qwen2.5-0.5B weights.

```bash
pip install -e .
python examples/qwen_05b_quickstart/run.py
```

What the script does (read the source at
`examples/qwen_05b_quickstart/run.py`):

1. Pin RNG (`random.seed(42)`, `torch.manual_seed(42)`) so the per-step
   numbers below are reproducible.
2. Load `Qwen/Qwen2.5-0.5B-Instruct` on CPU in fp32, set `model.train()`.
3. `batch = build_batch(tokenizer, device="cpu")` — a real chat-template-formatted
   batch with all keys the 3-channel composer might consume.
4. Five backward steps with `compose_loss(model, batch, alpha_sdpo=0.1,
   beta_replay=0.05)`; an `AdamW(lr=1e-5)` optimizer; finite-grad check
   after each step.
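The finite-grad check in step 4 can be sketched on a toy model. This is an illustration under stated assumptions: `all_grads_finite` is a hypothetical helper (the quickstart script may implement the check differently), and a small `torch.nn.Linear` stands in for Qwen2.5-0.5B so the snippet runs in well under a second.

```python
import torch


def all_grads_finite(model: torch.nn.Module) -> bool:
    """True if every parameter that received a gradient holds only finite values."""
    return all(
        torch.isfinite(p.grad).all().item()
        for p in model.parameters()
        if p.grad is not None
    )


# Tiny stand-in model; the same loop shape applies to a real HF model.
torch.manual_seed(42)
toy = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(toy.parameters(), lr=1e-5)

loss = toy(torch.randn(3, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
finite = all_grads_finite(toy)  # check before the optimizer consumes the grads
optimizer.step()
```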

Expected output (transcribed from `examples/qwen_05b_quickstart/run.log`):

```
[quickstart] loading Qwen/Qwen2.5-0.5B-Instruct (CPU, fp32) ...
[quickstart] loaded — 0.494B params
[quickstart] building real chat-template batch ...
[quickstart] running 5 backward steps ...
  step 0: total=0.7390  lm_ce=0.7358  sdpo=0.0000  dpo=0.0639  finite=True
  step 1: total=0.0379  lm_ce=0.0351  sdpo=0.0000  dpo=0.0563  finite=True
  step 2: total=0.0122  lm_ce=0.0110  sdpo=0.0000  dpo=0.0240  finite=True
  step 3: total=0.0060  lm_ce=0.0055  sdpo=0.0000  dpo=0.0098  finite=True
  step 4: total=0.0031  lm_ce=0.0029  sdpo=0.0000  dpo=0.0044  finite=True
========================================================
  Initial loss: 0.7390  →  Final loss: 0.0031  →  Reduction: 99.6%
  Verdict: PASS
========================================================
```

How to read this:

- **`total` collapses by ~99%.** The model successfully memorizes the
  single batch, which is exactly what you expect from five AdamW steps on a
  0.5B model with one fixed input. This is a wiring check, not a generalization claim.
- **`lm_ce` carries almost all the magnitude.** Channel 1 (the GRPO stub)
  is doing the work — the response tokens are short and have low entropy
  under the trained model.
- **`sdpo=0.0000` on every step.** Channel 2 has auto-disabled because the
  default `build_batch` does not include `ctx_teacher_input_ids`. Compare
  the conditional in `compose_loss`:
  ```python
  if (alpha_sdpo > 0.0
      and "ctx_teacher_input_ids" in inputs
      and inputs["ctx_teacher_input_ids"].numel() > 0):
  ```
  In other words, the channel is off whenever its weight is zero or its inputs are missing.
- **`dpo > 0` and trending down.** The batch *does* include
  `dpo_chosen_input_ids`, `dpo_chosen_response_mask`,
  `dpo_chosen_ref_logprobs` (and the rejected counterparts), so channel 3
  is live.
- **`finite=True`** — every step's `p.grad` was finite for every parameter.
  This is the wiring contract; if it ever flips to `False` the smoke fails.

If you see `Verdict: PASS`, the framework is correctly installed and
gradients flow through all live channels. You are ready for section 4.
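As a sanity check on the log, the logged `total` can be recomposed from the per-channel values and the weights passed to `compose_loss`. The sketch below parses two lines copied from the expected output above; the regular expression and the tolerance are choices made for this example, not part of the framework.

```python
import re

LOG = """\
  step 0: total=0.7390  lm_ce=0.7358  sdpo=0.0000  dpo=0.0639  finite=True
  step 4: total=0.0031  lm_ce=0.0029  sdpo=0.0000  dpo=0.0044  finite=True
"""

ALPHA_SDPO, BETA_REPLAY = 0.1, 0.05  # weights used in the quickstart

rows = [
    {k: float(v) for k, v in re.findall(r"(total|lm_ce|sdpo|dpo)=([0-9.]+)", line)}
    for line in LOG.splitlines()
]

for row in rows:
    recomposed = row["lm_ce"] + ALPHA_SDPO * row["sdpo"] + BETA_REPLAY * row["dpo"]
    # Logged values are rounded to 4 decimals, so allow a small tolerance.
    assert abs(recomposed - row["total"]) < 2e-4

reduction_pct = 100.0 * (1.0 - rows[-1]["total"] / rows[0]["total"])
print(f"reduction: {reduction_pct:.1f}%")
```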

---

## 4. Adding the trace-replay channel

The quickstart batch *had* DPO inputs, but they were synthetic — the
`build_batch` helper bakes them in. To get **real** DPO pairs from
multi-teacher disagreement, use the replaysim package.

### 4a. Spin up `replay_trace`

```python
import asyncio
from composer_replication import (
    DEFAULT_TEACHERS, replay_trace, extract_dpo_pairs,
)

# Trace must be a list[TraceState]; see composer_replication/teacher_replay.py
# for the exact TypedDict shape. Each state holds a chat-messages prefix +
# the student's actual action at that step.
states = [...]   # your frozen agentic trace; see spike 001 for a 50-step example

teacher_actions = asyncio.run(
    replay_trace(
        states=states,
        teachers=DEFAULT_TEACHERS,    # claude-opus-4.7 + gpt-5 + deepseek-v4-pro
        max_total_usd=10.0,           # hard ceiling (spike 001 measured $0.98/trace mean)
    )
)
```

The 3 teachers are queried in parallel via OpenRouter
(`OPENROUTER_API_KEY` in env or `~/.hermes/.env`), latencies recorded,
costs tracked.

### 4b. Get `DPOPair`s from disagreement

```python
pairs = extract_dpo_pairs(
    states=states,
    teacher_actions=teacher_actions,
    agreement_threshold=2,    # at least 2/3 teachers must agree on the chosen action
)
```

Each pair is a `DPOPair` TypedDict with the exact shape that the
`DJNormalizer` and downstream training expect:

```python
from typing import TypedDict

class DPOPair(TypedDict):
    state_id:            str
    state_messages:      list[dict]    # conversation context
    chosen:              str           # teacher-consensus action
    rejected:            str           # student action
    n_teachers_agreeing: int

(verified in `composer_replication/teacher_replay.py:99–105`).
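The idea behind disagreement-based pairing can be sketched as a majority vote. This is a simplified stand-in, not the logic of `extract_dpo_pairs`: the helper `consensus_pair` is hypothetical, and it assumes actions can be compared by exact string equality, which the real implementation may not do.

```python
from collections import Counter
from typing import Optional


def consensus_pair(
    student_action: str,
    teacher_actions: list[str],
    agreement_threshold: int = 2,
) -> Optional[dict]:
    """Return a chosen/rejected pair when enough teachers agree on an action
    that differs from the student's; otherwise return None (hypothetical helper)."""
    action, votes = Counter(teacher_actions).most_common(1)[0]
    if votes < agreement_threshold or action == student_action:
        return None
    return {"chosen": action, "rejected": student_action, "n_teachers_agreeing": votes}


# Two of three teachers agree and the student did something else: a pair is emitted.
pair = consensus_pair(
    "edit the file immediately",
    ["read more files", "read more files", "run the tests"],
)

# The student already matches the consensus: nothing to learn from this state.
no_pair = consensus_pair("read more files", ["read more files"] * 3)
```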

### 4c. Run `DJNormalizer` with `default.yaml`

```python
from composer_replication.replaysim import DJNormalizer

normalizer = DJNormalizer()        # uses recipes/replaysim/default.yaml
normalized = normalizer.normalize(pairs)
# → list[NormalizedDPOPair]
```

`DJNormalizer` shells out to data-juicer's `DefaultExecutor` under the hood
(file-in / file-out contract). The default recipe at
`composer_replication/recipes/replaysim/default.yaml` runs four CPU-only ops
in order:

1. `text_length_filter` (8 ≤ chars ≤ 32000) on `chosen` and `rejected`
2. `words_num_filter` (2 ≤ words ≤ 4096) on both
3. `special_characters_filter` (≤50% non-alpha) on both
4. `document_deduplicator` (per-batch hashing, lowercase, ignore non-character) on `chosen`
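To make the four thresholds concrete, here is a plain-Python approximation of what the recipe accepts and rejects. It does not call data-juicer, and the real ops define "special characters" and deduplication hashing in their own way; every function name below is hypothetical and only the numeric bounds come from the list above.

```python
def passes_length(text: str, lo: int = 8, hi: int = 32000) -> bool:
    return lo <= len(text) <= hi


def passes_word_count(text: str, lo: int = 2, hi: int = 4096) -> bool:
    return lo <= len(text.split()) <= hi


def passes_special_chars(text: str, max_ratio: float = 0.5) -> bool:
    # Approximation: a character is "special" if it is neither alphanumeric
    # nor whitespace. data-juicer uses its own character set.
    if not text:
        return False
    special = sum(1 for ch in text if not (ch.isalnum() or ch.isspace()))
    return special / len(text) <= max_ratio


def dedup_key(text: str) -> str:
    # Lowercase and drop non-alphanumeric characters before comparing.
    return "".join(ch for ch in text.lower() if ch.isalnum())


def filter_pairs(pairs: list[dict]) -> list[dict]:
    seen, kept = set(), []
    for pair in pairs:
        texts = (pair["chosen"], pair["rejected"])
        if not all(
            passes_length(t) and passes_word_count(t) and passes_special_chars(t)
            for t in texts
        ):
            continue
        key = dedup_key(pair["chosen"])  # dedup runs on `chosen` only
        if key in seen:
            continue
        seen.add(key)
        kept.append(pair)
    return kept


pairs = [
    {"chosen": "Read the failing test first.", "rejected": "Edit auth.py right away."},
    {"chosen": "ok", "rejected": "Edit auth.py right away."},                 # too short
    {"chosen": "READ the failing test first!", "rejected": "Delete the test file."},  # duplicate
]
kept = filter_pairs(pairs)
```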

Records carry **two parallel shapes** for `chosen`/`rejected`:
- flat strings (`chosen`, `rejected`) → consumed by data-juicer's text_key-based filters
- chat-messages lists (`chosen_messages`, `rejected_messages`) → preserved for chat-aware ops + round-trip

This dual-shape design (verified in the test
`test_dpo_pair_to_dj_record_shape`,
`composer_replication/replaysim/tests/test_replaysim.py:44`) is what
unblocked the data-juicer integration in Wave 14.

### 4d. The 3-record fixture from spike 001

The fixture lives at
`spikes/001-teacher-replay-cost/states.jsonl` (50 states) and
`spikes/001-teacher-replay-cost/results.jsonl` (the teacher responses, all
priced and timed). The first 3 states are:

```jsonl
{"id": "state-000", "task": "Fix the failing test in tests/test_auth.py::test_login_with_email", ...}
{"id": "state-001", "task": "Add rate-limiting middleware to the Flask app", ...}
{"id": "state-002", "task": "Refactor the parse_config function — it's 200 lines and has 3 responsibilities", ...}
```

For each, all 3 teachers answered (claude-opus-4.7, gpt-5, deepseek-v4-pro).
For state-000 and state-001, teacher agreement on choice `(c)` (read more
files / check schema first) yields a clean DPO pair, with the teacher
consensus as `chosen` and the student's action as `rejected`. For state-002,
all 3 agreed on `(c)` (write characterization tests first), giving another
clean pair. These three records pass through the `DJNormalizer` default
recipe unchanged (length, word count, and special-character ratio all in
bounds; no duplicates).

Replaying the full 50-state trace cost a mean of **$0.98** end-to-end across
all three teachers (spike 001 verdict). The framework's cost ceiling
(`max_total_usd`) and VOI gating are projected to bring this down to ~$0.30/trace.

### 4e. End-to-end one-liner

```python
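import asyncio

from composer_replication import DEFAULT_TEACHERS
from composer_replication.replaysim import replay_and_normalize_trace


async def main(states):
    # `states` is the same list[TraceState] used in 4a.
    return await replay_and_normalize_trace(
        states=states,
        teachers=DEFAULT_TEACHERS,
        agreement_threshold=2,
        max_total_usd=10.0,
    )

teacher_actions, normalized_pairs = asyncio.run(main(states))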
```

(`replay_and_normalize_trace` is an `async def`; sync callers can use the sibling
`replay_and_normalize_trace_sync` in `composer_replication.replaysim.normalize`.)

---

## 5. Switching DPO → SimPO: one kwarg

```python
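from composer_replication.loss import compose_loss

# `model` and `batch` are the same objects as in the section 3 quickstart.
components = compose_loss(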
    model, batch,
    alpha_sdpo=0.1,
    beta_replay=0.05,
    dpo_variant="simpo",      # ← the only line that changes
    simpo_beta=2.0,           # paper default
    simpo_gamma=1.0,          # paper default
)
```

The kwarg is verified in `compose_loss`'s signature
(`composer_replication/loss.py:81`):

```python
dpo_variant: Literal["dpo", "simpo"] = "dpo",
```

### What changes in the loss curve

- **Channel 3 input requirements drop.** `compose_loss` no longer reads
  `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`. Reference-model
  VRAM cost goes to zero. (Source: `composer_replication/loss.py:111–113`
  and `composer_replication/distillation/simpo.py:23–27`.)
- **Loss scale shifts.** Standard DPO is
  `-logsigmoid(β·[(logπ(c) - logπ_ref(c)) - (logπ(r) - logπ_ref(r))])`.
  SimPO is `-logsigmoid(β·[avg_logπ(c) - avg_logπ(r)] - γ)` — average
  per-token log-prob (length-normalized) and a constant target margin γ.
- **Loss falls as chosen/rejected separation grows.** The
  unit test `test_simpo_loss_lower_for_better_separation`
  (`composer_replication/distillation/tests/test_distillation_losses.py:35`)
  verifies that a wider chosen-vs-rejected gap drives lower SimPO loss.
  The test does not compare SimPO against DPO, so treat the following as a
  rule of thumb rather than a guarantee: SimPO curves tend to be *steeper*
  than DPO when the preference signal is strong, and *flatter* when it's weak.
- **No KL-against-reference regularization.** This is both the upside (no
  ref-model serving) and the risk (more tendency to drift). Watch for
  reward-hacking-style degeneracies if your preference data has noise.
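The two formulas above can be written side by side in a few lines of PyTorch. This is a sketch of the math only, not the code in `composer_replication/distillation/simpo.py`; the function names are hypothetical, the inputs are assumed to be precomputed per-sequence log-probs, and the DPO default `beta=0.1` is an assumption borrowed from common practice rather than from this repo.

```python
import torch
import torch.nn.functional as F


def dpo_loss(logp_c, logp_r, ref_logp_c, ref_logp_r, beta=0.1):
    """Standard DPO on summed sequence log-probs; needs reference log-probs.
    beta=0.1 is an assumed default, not taken from the framework."""
    margin = (logp_c - ref_logp_c) - (logp_r - ref_logp_r)
    return -F.logsigmoid(beta * margin).mean()


def simpo_loss(avg_logp_c, avg_logp_r, beta=2.0, gamma=1.0):
    """SimPO on per-token average log-probs; no reference model involved."""
    return -F.logsigmoid(beta * (avg_logp_c - avg_logp_r) - gamma).mean()


# A narrow gap between chosen and rejected versus a wide one.
narrow = simpo_loss(torch.tensor([-1.0]), torch.tensor([-1.2]))
wide = simpo_loss(torch.tensor([-0.5]), torch.tensor([-3.0]))

# When the policy equals the reference, the DPO margin is 0 and the loss is log(2).
lp_c, lp_r = torch.tensor([-4.0]), torch.tensor([-6.0])
dpo_at_ref = dpo_loss(lp_c, lp_r, lp_c, lp_r)
```

Note that `simpo_loss` never touches `ref_logp_*`, which is why the reference-model cost disappears when `dpo_variant="simpo"`.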

### When to use SimPO

- **GPU-poor.** You can't afford to keep a frozen reference policy resident
  alongside the trainer.
- **Cold-start preference data.** Length-normalization (avg_logπ vs sum)
  helps when chosen/rejected lengths are wildly imbalanced — common in
  agentic traces where the student's failed attempt is short and the
  teacher's correction is long.
- **You don't have ref logprobs precomputed.** SimPO needs nothing from
  the reference policy.

When to **stay on DPO**: when you need the explicit KL anchor against
a known-good reference (e.g., when training over a long horizon and you
want to bound the drift), or when your preference data is very noisy and
the reference acts as a regularizer.

---

## 6. Adding TAID / Entropy-Aware OPD wrappers

Channel 2 (SDPO/OPSD) can be replaced by **TAID** (Sakana AI,
arXiv:2501.16937) for capacity-gap distillation, or by
**Entropy-Aware OPD** (ICLR 2026 Spotlight) for per-token forward/reverse-KL
gating. Both are wired through `compose_loss`:

```python
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
taid_t: float | None = None,         # current TAID interpolation coeff
entropy_opd_h_max: float | None = None,
```

(verified at `composer_replication/loss.py:82–93`.)

### TAID (upstream-faithful port)

> **Wave 15 rewrite, breaking change.** The previous in-tree TAID was
> algorithmically different from the paper (it mixed in probability space
> against a frozen step-0 student snapshot and wrapped a symmetric JSD
> criterion). It has been replaced with an upstream-faithful port:
> logit-space mix, current-student-detached anchor, forward-KL criterion.
> Old kwargs `taid_schedule_step`, `taid_total_steps`, `taid_schedule`,
> `taid_alpha_min`, `taid_alpha_max`, plus `inputs["student_init_logits"]` /
> `inputs["student_init_input_ids"]` are **gone**. They have no upstream
> analogue. Use `taid_t` (and optionally `TAIDScheduler`) instead.

The TAID criterion is forward-KL against a logit-space-interpolated target:

```
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
L   = - mean_token  Σ_v  p_t(v) · log_softmax(student_logits)(v)
```

where `t ∈ [0, 1]` is the interpolation coefficient. At `t=0` the target
is the (detached) student itself — the loss is the entropy of that
distribution and contributes no gradient to the student. At `t=1` it
reduces to standard forward-KL distillation against the teacher.
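The criterion above fits in a few lines of PyTorch. The sketch below follows the two formulas literally and is not the in-tree `taid_loss`; the function name `taid_loss_sketch`, the tensor shapes, and the masked-mean reduction are assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def taid_loss_sketch(student_logits, teacher_logits, mask, t):
    """student/teacher logits: [batch, seq, vocab]; mask: [batch, seq]; t in [0, 1]."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    # Logit-space interpolation against the detached current student.
    mixed = (1.0 - t) * student_logits.detach() + t * teacher_logits
    target = F.softmax(mixed, dim=-1)
    log_q = F.log_softmax(student_logits, dim=-1)
    per_token = -(target * log_q).sum(dim=-1)          # [batch, seq]
    return (per_token * mask).sum() / mask.sum().clamp(min=1)


torch.manual_seed(0)
student = torch.randn(1, 2, 4, requires_grad=True)
teacher = torch.zeros(1, 2, 4)
teacher[..., 0] = 10.0
mask = torch.ones(1, 2)

# At t=0 the target equals softmax(student.detach()), so the gradient vanishes.
loss_t0 = taid_loss_sketch(student, teacher, mask, t=0.0)
loss_t0.backward()
max_grad_at_t0 = student.grad.abs().max().item()

# At t=1 the loss is the cross-entropy of the student against the teacher.
loss_t1 = taid_loss_sketch(student, teacher, mask, t=1.0)
```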

The schedule that produces `t` is the **trainer's** responsibility. The
package ships an optional `TAIDScheduler` that mirrors the paper's
adaptive momentum scheme:

```python
from composer_replication.distillation import TAIDScheduler
from composer_replication.loss import compose_loss

num_train_steps = 10_000
sched = TAIDScheduler(num_train_steps=num_train_steps)   # paper defaults
for step in range(num_train_steps):
    optimizer.zero_grad()             # clear grads left over from the previous step
    components = compose_loss(
        model, batch,
        sdpo_wrapper="taid",
        taid_t=sched.t,
    )
    components.total.backward()
    optimizer.step()
    sched.update_t(components.sdpo_jsd.detach(), global_step=step)
```

`TAIDScheduler` defaults match upstream: `t_start=0.4`, `t_end=1.0`,
`alpha=5e-4`, `beta=0.99`. Pass `disable_adaptive=True` to fall back to
the deterministic linear schedule
`t = t_start + progress · (t_end - t_start)`.

If you want a simple fixed schedule (no scheduler), just compute `t`
yourself and pass it in — `compose_loss` validates `taid_t ∈ [0, 1]`.
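For the do-it-yourself route, the linear formula above is enough. A minimal sketch, assuming you want the same `t_start`/`t_end` defaults as the scheduler; `linear_taid_t` is a hypothetical helper name, not something the package exports.

```python
def linear_taid_t(
    step: int,
    num_train_steps: int,
    t_start: float = 0.4,
    t_end: float = 1.0,
) -> float:
    """t = t_start + progress * (t_end - t_start), with progress clamped to [0, 1]."""
    progress = min(max(step / max(num_train_steps, 1), 0.0), 1.0)
    return t_start + progress * (t_end - t_start)


t_first = linear_taid_t(0, 100)     # start of training
t_mid = linear_taid_t(50, 100)      # halfway
t_last = linear_taid_t(100, 100)    # end of training
```

The result can be passed straight to `compose_loss(..., sdpo_wrapper="taid", taid_t=...)`.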

### Upstream-parity test

`composer_replication/distillation/tests/test_taid_parity.py` skip-imports
the upstream reference at `/tmp/taid-clone` (clone with
`git clone --depth 1 https://github.com/SakanaAI/TAID /tmp/taid-clone`)
and asserts our `taid_loss(student, teacher, mask, t)` matches upstream
`TAID.compute_loss(...)` within `atol=rtol=1e-5` across `t ∈ {0.0, 0.1, 0.4,
0.5, 0.9, 1.0}`. This is the load-bearing parity guarantee.

### Entropy-Aware OPD

Drop-in for channel 2 — gates between forward KL (mode-covering) and
reverse KL (mode-seeking) per token, weighted by the teacher's entropy:

```
L = Σ_t [ w(t) · KL_fwd_t  +  (1 - w(t)) · KL_rev_t ]
w(t) = clamp(H_teacher(t) / h_max, 0, 1)
```

`entropy_opd_h_max=None` (the default) auto-sets `h_max = log(V)` (the
maximum-entropy bound for a vocab-V softmax).
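A compact PyTorch sketch of this gate is given below. It is written from the two formulas above rather than from the in-tree implementation: the function name `entropy_opd_sketch`, the tensor shapes, and the masked-sum reduction are assumptions, and forward KL is taken to mean KL(teacher ‖ student).

```python
import math

import torch
import torch.nn.functional as F


def entropy_opd_sketch(student_logits, teacher_logits, mask, h_max=None):
    """Logits: [batch, seq, vocab]; mask: [batch, seq]. Returns a scalar loss."""
    log_p = F.log_softmax(teacher_logits, dim=-1)     # teacher
    log_q = F.log_softmax(student_logits, dim=-1)     # student
    p, q = log_p.exp(), log_q.exp()
    if h_max is None:
        h_max = math.log(teacher_logits.shape[-1])    # log(V)
    entropy = -(p * log_p).sum(dim=-1)                # teacher entropy per token
    w = (entropy / h_max).clamp(0.0, 1.0)
    kl_fwd = (p * (log_p - log_q)).sum(dim=-1)        # KL(teacher || student)
    kl_rev = (q * (log_q - log_p)).sum(dim=-1)        # KL(student || teacher)
    per_token = w * kl_fwd + (1.0 - w) * kl_rev
    return (per_token * mask).sum()


torch.manual_seed(0)
student = torch.randn(1, 2, 4)
mask = torch.ones(1, 2)

# A uniform teacher has maximal entropy, so w = 1 and only forward KL remains.
uniform_teacher = torch.zeros(1, 2, 4)
loss_uniform = entropy_opd_sketch(student, uniform_teacher, mask)

# If the student already matches the teacher, both KL terms are zero.
loss_matched = entropy_opd_sketch(student, student.clone(), mask)
```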

### Boundary-condition unit test (proof of correctness)

The test `test_taid_loss_t_zero_target_matches_detached_student`
(`composer_replication/distillation/tests/test_distillation_losses.py`)
pins TAID's `t=0` invariant — the teacher is *completely* hidden from the
gradient because the target collapses to `softmax(student.detach())`:

```python
def test_taid_loss_t_zero_target_matches_detached_student():
    s1 = torch.randn(1, 2, 4, requires_grad=True)
    teacher_a = torch.zeros(1, 2, 4); teacher_a[..., 0] = 10.0
    teacher_b = torch.zeros(1, 2, 4); teacher_b[..., 3] = 10.0
    mask = torch.ones(1, 2)
    loss_a = taid_loss(s1, teacher_a, mask, t=0.0)
    loss_b = taid_loss(s1, teacher_b, mask, t=0.0)
    # Two completely different teachers must give the same loss at t=0.
    assert abs(float(loss_a) - float(loss_b)) < 1e-6
```

This is the load-bearing test for TAID: if the `t=0` endpoint ever leaks
teacher signal into the gradient, this test fires and the contract is
broken. The companion test `test_taid_loss_t_one_is_pure_forward_kl`
pins the `t=1` endpoint by hand-computing `-Σ p_teacher · log_q` and
asserting equality.

---

## 7. Going multi-replica with serverless DiLoCo

DiLoCo is the outer-loop optimizer that lets you run N replicas in
parallel, sync them every H inner steps, and tolerate slow links — see
`docs/adrs/ADR-005-serverless-diloco.md` for the design. The framework
gives you three increasingly-distant deployments:

### Step 1 — `LocalProcessExecutor` for development

```python
from composer_replication.diloco.serverless import (
    LocalProcessExecutor, ObjectStoreAllReduce,
)
import tempfile

with tempfile.TemporaryDirectory() as td:
    rendezvous = ObjectStoreAllReduce(td, rank=0, world_size=4)
    executor = LocalProcessExecutor()
    handles = executor.launch_replicas(
        n_replicas=4,
        entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
        entrypoint_args={"rendezvous_uri": td, "rank_env": "REPLICA_RANK"},
    )
    results = executor.collect(handles, timeout=600)
```

`LocalProcessExecutor` (`composer_replication/diloco/serverless/executor.py:160`)
spawns N child processes via `multiprocessing.get_context("spawn")` and
sets `REPLICA_RANK={0..N-1}` in each child's env. It satisfies the
`ServerlessExecutor` Protocol (line 35) — the same Protocol the cloud
adapters implement. So the dev-loop code is byte-identical to the cloud
deploy: only the executor instance changes.
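The "same Protocol, different executor" idea can be illustrated with `typing.Protocol`. Everything in this sketch is hypothetical: `ExecutorLike` is a stand-in for the real `ServerlessExecutor`, whose exact method signatures are not reproduced here, and `InlineExecutor` is a toy that never spawns a process.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ExecutorLike(Protocol):
    """Hypothetical stand-in for the ServerlessExecutor Protocol."""

    def launch_replicas(self, n_replicas: int, entrypoint: str, entrypoint_args: dict) -> list: ...
    def poll(self, handles: list) -> list: ...
    def collect(self, handles: list, timeout: float) -> list: ...


class InlineExecutor:
    """Toy executor that 'runs' each replica inline and records its rank."""

    def launch_replicas(self, n_replicas, entrypoint, entrypoint_args):
        return [{"rank": r, "entrypoint": entrypoint, "done": True} for r in range(n_replicas)]

    def poll(self, handles):
        return [h["done"] for h in handles]

    def collect(self, handles, timeout):
        return [h["rank"] for h in handles]


def run(executor: ExecutorLike, n_replicas: int = 4) -> list:
    # Caller code depends only on the Protocol, never on a concrete executor.
    handles = executor.launch_replicas(n_replicas, "my_pkg.replica_entrypoint", {})
    return executor.collect(handles, timeout=600)


ranks = run(InlineExecutor())
```

Swapping in a different executor changes only the argument to `run`, which is the property the paragraph above relies on.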

### Step 2 — `ObjectStoreAllReduce` as the rendezvous

```python
# Local file:// for tests
rendezvous = ObjectStoreAllReduce("/tmp/diloco-runs/run42/", rank=0, world_size=4)

# Real S3 (after `pip install -e ".[serverless]"`)
rendezvous = ObjectStoreAllReduce(
    "s3://my-bucket/diloco-runs/run42/",
    rank=0, world_size=4,
    timeout_s=1800.0,
)
```

The communication pattern is `S3 PutObject + N GetObjects` once every
H inner steps (matching DiLoCo's actual sync cadence,
arXiv:2311.08105 §3.2). For a 1B-parameter model in bf16, that's ~2 GB every
30 minutes per replica, well within S3 free-tier. On the inner side the framework
exposes a `MockManager` that drops into the `torchft.Manager` slot, so
you can validate the rendezvous logic before plugging in real torchft
(verified by `test_serverless_diloco_integration.py`).
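The put-once, get-N pattern can be sketched with the local filesystem standing in for the object store. This is an illustration of the rendezvous idea only, not the `ObjectStoreAllReduce` implementation: the helper names and the JSON file layout are hypothetical, and a real version would need atomic writes and tensor serialization.

```python
import json
import tempfile
import time
from pathlib import Path


def put_shard(root: Path, round_id: int, rank: int, values: list[float]) -> None:
    """Each replica publishes exactly one object per sync round."""
    path = root / f"round-{round_id}" / f"rank-{rank}.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(values))


def get_mean(root: Path, round_id: int, world_size: int, timeout_s: float = 5.0) -> list[float]:
    """Wait until all N shards exist, then read them and average element-wise."""
    deadline = time.monotonic() + timeout_s
    paths = [root / f"round-{round_id}" / f"rank-{r}.json" for r in range(world_size)]
    while not all(p.exists() for p in paths):
        if time.monotonic() > deadline:
            raise TimeoutError("not all replicas published their shard")
        time.sleep(0.05)
    shards = [json.loads(p.read_text()) for p in paths]
    return [sum(col) / world_size for col in zip(*shards)]


with tempfile.TemporaryDirectory() as td:
    root = Path(td)
    # Simulate four replicas in one process; each would normally run put_shard itself.
    for rank in range(4):
        put_shard(root, round_id=0, rank=rank, values=[float(rank), 1.0])
    averaged = get_mean(root, round_id=0, world_size=4)
```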

### Step 3 — point at `ModalExecutor` / `HFJobsExecutor`

```python
# Modal (skeleton at composer_replication/diloco/serverless/modal.py)
from composer_replication.diloco.serverless.modal import ModalExecutor
executor = ModalExecutor(image="modal:python3.11", gpu="A100")

# HuggingFace Jobs (skeleton at composer_replication/diloco/serverless/hf_jobs.py)
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
executor = HFJobsExecutor(hardware="a10g-large")

# Same Protocol — same launch_replicas / poll / collect calls as Local
handles = executor.launch_replicas(n_replicas=4, ...)
```

Both adapters check their cloud SDK at `__init__` time (not at module
import) so they don't break the package if you don't have `modal` or
`huggingface_hub` installed. Production maturity: dev-ready for cloud
trial; per ADR-005, full HA-cluster fan-out lives in v0.2+.
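The deferred SDK check described above is a common pattern and can be sketched generically. `LazySDKExecutor` is a hypothetical class, not the real `ModalExecutor` or `HFJobsExecutor`; the stdlib `json` module stands in for an installed SDK so the snippet runs anywhere.

```python
import importlib


class LazySDKExecutor:
    """Hypothetical adapter: the SDK is imported in __init__, not at module import,
    so importing this module never fails when the SDK is absent."""

    def __init__(self, sdk_module: str):
        try:
            self._sdk = importlib.import_module(sdk_module)
        except ImportError as exc:
            raise ImportError(
                f"{sdk_module!r} is required for this executor; install it first"
            ) from exc


ok = LazySDKExecutor("json")  # stdlib module stands in for an installed SDK

try:
    LazySDKExecutor("definitely_not_installed_sdk_xyz")
    missing_raised = False
except ImportError:
    missing_raised = True
```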

---

## 8. Picking an RL backend

Four canonical recipes, each tied to an upstream framework. Source:
`docs/INTEGRATION_ARCHITECTURE.md` Recipes A–D plus
`docs/adrs/ADR-006-rl-frameworks.md`.

### Recipe A — TRL `GRPOTrainer` subclass

`ComposerReplicationTrainer` is a `trl.GRPOTrainer` subclass that
overrides `_compute_loss(model, inputs)` to compose the same 3 channels
on top of TRL's real reward + advantage machinery. Install:
`pip install -e ".[train]"`. Then:

```python
from composer_replication import ComposerReplicationTrainer
trainer = ComposerReplicationTrainer(model=..., reward_funcs=[...], ...)
trainer.train()
```

**When to use it:** This is the v0.0/v0.1 recommended path. You want
real GRPO with rewards, you have HF model + dataset + (mostly) standard
GRPO infrastructure, and you don't need >100B-param scale. TRL is
mature, the trainer is a small subclass, and the same `compose_loss`
math runs in both the verification harness and in production with no
re-coding.

→ See `docs/INTEGRATION_ARCHITECTURE.md` § "Recipe A: TRL `GRPOTrainer`
subclass" (line 205).

### Recipe B — VeRL custom `adv_estimator` + DataProto extension

VeRL replaces TRL's reward+advantage machinery with a Ray-driven actor
graph that's specifically optimized for distributed RL training of
large LMs. Composition with the framework: extend `DataProto` with the
hint-conditioned columns + DPO pair fields, register a custom
`adv_estimator` that calls the same `compose_loss` body.

**When to use it:** You're past 7B-param, you have multi-host setup
(Ray cluster), and TRL's single-process trainer is the bottleneck. VeRL
is the recommended scale path for v0.2+. Trade-off: the extension surface
is larger and the docs are sparser than TRL's.

→ See `docs/INTEGRATION_ARCHITECTURE.md` § "Recipe B: VeRL custom
`adv_estimator`" (line 289).

### Recipe C — PRIME-RL with DPPO-clip details

`composer_replication/recipes/prime_rl/composer_loss.py` ships a
`loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2,
dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3)` adapter that maps
PRIME-RL's `LossInputs` struct (1-D per-sample tensors:
`trainer_logprobs`, `inference_logprobs`, `teacher_logprobs`,
`advantages`, `loss_mask`) into our 3-channel composition.

The DPPO+KL bit is what makes PRIME-RL distinctive — and we mirror
PRIME-RL's upstream `default_loss_fn` exactly (verified against
`prime_rl/trainer/rl/loss.py` lines 116-165):

```python
log_ir       = trainer_logprobs - inference_logprobs
ir           = exp(log_ir)                                  # importance ratio
probs_diff   = exp(trainer_logprobs) - exp(inference_logprobs)
invalid_high = probs_diff >  dppo_mask_high                 # for positive-advantage tokens
invalid_low  = probs_diff < -dppo_mask_low                  # for negative-advantage tokens
invalid      = where(advantages > 0, invalid_high, invalid_low)
keep         = loss_mask & ~invalid
pg_loss      = keep      * (adv_tau * advantages) * ir
kl_loss      = loss_mask * log_ir**2
loss         = (-pg_loss + kl_tau * kl_loss).sum()
```

Three things to remember: (1) the mask gate is on **probability-space**
`exp(trainer_lp) - exp(inference_lp)`, not on the log-ratio; (2) the
policy-gradient term is multiplied by the importance ratio
`exp(trainer_lp - inference_lp)`, not by `trainer_lp` directly (proper
IS-corrected gradient, not REINFORCE); (3) the mask is **conditioned on
the sign of the advantage** — positive-advantage tokens are dropped on
the upper bound, negative-advantage tokens on the lower. Defaults
`dppo_mask_high=dppo_mask_low=0.2` and `adv_tau=1.0, kl_tau=1e-3`
match PRIME-RL's `DefaultLossConfig` (all fields `Field(..., ge=0)`).
SDPO (channel 2) raises `NotImplementedError` in v0 because PRIME-RL
exposes log-probs, not full-vocabulary logits; channel 3 (trace-replay DPO)
emits a warning if `beta_dpo != 0`.
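The pseudocode above becomes runnable with a handful of `torch` calls. The sketch below transcribes it line for line; it is not the adapter in `composer_loss.py`, the function name `dppo_clip_loss` is hypothetical, and the three-token example is made up to show the sign-conditioned mask.

```python
import torch


def dppo_clip_loss(
    trainer_logprobs, inference_logprobs, advantages, loss_mask,
    dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3,
):
    """All inputs are 1-D per-token tensors; loss_mask is boolean."""
    log_ir = trainer_logprobs - inference_logprobs
    ir = torch.exp(log_ir)                                   # importance ratio
    probs_diff = torch.exp(trainer_logprobs) - torch.exp(inference_logprobs)
    invalid_high = probs_diff > dppo_mask_high               # checked when advantage > 0
    invalid_low = probs_diff < -dppo_mask_low                # checked otherwise
    invalid = torch.where(advantages > 0, invalid_high, invalid_low)
    keep = loss_mask.bool() & ~invalid
    pg_loss = keep * (adv_tau * advantages) * ir
    kl_loss = loss_mask * log_ir**2
    return (-pg_loss + kl_tau * kl_loss).sum()


# Token 0: prob rose 0.5 -> 0.9 with positive advantage  -> masked (0.4 > 0.2).
# Token 1: prob unchanged with positive advantage        -> kept.
# Token 2: prob rose 0.5 -> 0.9 with negative advantage  -> kept (only the lower bound applies).
trainer_lp = torch.log(torch.tensor([0.9, 0.5, 0.9]))
inference_lp = torch.log(torch.tensor([0.5, 0.5, 0.5]))
advantages = torch.tensor([1.0, 1.0, -1.0])
loss_mask = torch.tensor([True, True, True])

loss = dppo_clip_loss(trainer_lp, inference_lp, advantages, loss_mask)
```

The KL term still counts token 0 even though its policy-gradient term is masked, because `kl_loss` uses `loss_mask` rather than `keep`.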

**When to use it:** You're already in the PRIME-Intellect / decentralized
training universe, you want INTELLECT-style scaling on a long-horizon
training run, and DPPO masking is part of your existing reward+advantage
recipe. Install: `pip install -e ".[prime-rl]"`.

→ See `composer_replication/recipes/prime_rl/prime_rl_recipe.md` and
`docs/INTEGRATION_ARCHITECTURE.md` § "Recipe C: TorchForge + Monarch"
(line 356).

### Recipe C+D — Monarch as actor mesh

Monarch (the actor framework underpinning TorchForge) hosts the
trainer/generator/manager actors in a topology-aware mesh. The framework
ships *skeleton* actor definitions at
`composer_replication/recipes/monarch/actors.py` (TrainerActor,
GeneratorActor) and a layout doc at `monarch_actor_layout.md`. v0
intentionally *fails fast* if you try to instantiate them
(`raise NotImplementedError("v0 skeleton; deferred to v0.2 per ADR-006")`)
because the upstream Monarch API is still moving.
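
The fail-fast behaviour can be sketched as follows. `SkeletonTrainerActor`
and `try_build_actor` are hypothetical stand-ins written for this
illustration; they are not the classes in `actors.py`, and only the error
message is taken from the text above.

```python
class SkeletonTrainerActor:
    """Hypothetical stand-in for a v0 skeleton actor."""

    def __init__(self, *args, **kwargs):
        # Refuse to construct: the skeleton documents shape, not behaviour.
        raise NotImplementedError("v0 skeleton; deferred to v0.2 per ADR-006")


def try_build_actor(actor_cls, *args, **kwargs):
    """Return (actor, None) on success, or (None, message) on fail-fast."""
    try:
        return actor_cls(*args, **kwargs), None
    except NotImplementedError as exc:
        return None, str(exc)


actor, reason = try_build_actor(SkeletonTrainerActor)
print(actor, reason)
```

A caller that wants to probe for Monarch support can catch the
`NotImplementedError` this way and fall back to another recipe instead of
crashing mid-run.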

**When to use it:** Reference-pattern reading only in v0. Decision point
is v0.2 once the upstream actor API stabilizes. Treat the skeleton as
shape-of-the-final-answer documentation, not as a production target.
Install: `pip install -e ".[prime-rl,monarch]"` for the full surface.

→ See `composer_replication/recipes/monarch/monarch_actor_layout.md`
and `docs/adrs/ADR-006-rl-frameworks.md`.

---

## Common pitfalls + what tests catch them

The framework's 115-test suite (post-Wave-15) is structured so each pitfall has a
specific test-file home. If you hit one of these in production, the
corresponding test is your fastest reproducer.

| Pitfall                                                                                       | Symptom                                              | Test file (catches it)                                                                                              |
|-----------------------------------------------------------------------------------------------|------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------|
| Forgetting `taid_schedule_step` when `sdpo_wrapper="taid"`                                    | `ValueError` at first step                           | `composer_replication/tests/test_compose_loss_integration.py` (kwarg validation)                                    |
| TAID α=0 endpoint leaks teacher signal                                                        | Teacher swap changes the loss when α should be 0     | `test_taid_loss_alpha_zero_ignores_teacher` in `composer_replication/distillation/tests/test_distillation_losses.py:153` |
| TAID α=1 endpoint differs from plain SDPO                                                     | Bit-difference vs reference SDPO at the schedule end | `test_taid_blended_logits_endpoints` in `composer_replication/distillation/tests/test_distillation_losses.py:115`   |
| SimPO loss not differentiable through the log-of-sigmoid path                                 | `chosen.grad is None` after backward                 | `test_simpo_loss_differentiable` in `composer_replication/distillation/tests/test_distillation_losses.py:50`        |
| SimPO shape-mismatch slips through silently                                                   | Broadcasting bug, NaN downstream                     | `test_simpo_loss_shape_mismatch_raises` in `composer_replication/distillation/tests/test_distillation_losses.py:61` |
| Entropy-OPD failing to zero out when distributions match                                      | Loss > 0 when student≡teacher                        | `test_entropy_aware_opd_zero_when_distributions_match` in `composer_replication/distillation/tests/test_distillation_losses.py:217` |
| Entropy of one-hot ≠ 0 / entropy of uniform ≠ log(V)                                          | Wrong gating weights w(t)                            | `test_teacher_entropy_one_hot_is_zero` and `test_teacher_entropy_uniform_is_log_v` in `composer_replication/distillation/tests/test_distillation_losses.py:175,183` |
| `DJNormalizer` records missing the chat-messages shape                                        | Filters silently no-op or crash                      | `test_dpo_pair_to_dj_record_shape` in `composer_replication/replaysim/tests/test_replaysim.py:44`                   |
| `DJNormalizer` round-trip drops `state_messages` / metadata                                   | Lost provenance                                      | `test_dj_record_to_normalized_roundtrip` and `test_dj_record_to_normalized_preserves_state_messages` in `composer_replication/replaysim/tests/test_replaysim.py` |
| `ObjectStoreAllReduce` accepts an out-of-bounds rank                                          | Silent corruption of the all-reduce average          | `test_object_store_allreduce_init_validates_rank` in `composer_replication/diloco/serverless/tests/test_serverless_local.py:31` |
| `ObjectStoreAllReduce(world_size=1)` doesn't pass through cleanly                             | False all-reduce on single replica                   | `test_object_store_allreduce_world_size_1_passthrough` in `composer_replication/diloco/serverless/tests/test_serverless_local.py:46` |
| `LocalProcessExecutor` doesn't propagate child failures to `collect()`                        | Silent test pass when a replica crashed              | `test_serverless_diloco_integration.py` in `composer_replication/diloco/serverless/tests/`                          |
| PRIME-RL adapter accidentally uses `(B, T)` shape instead of per-sample `(seq,)`              | Shape mismatch / wrong reduction                     | `composer_replication/recipes/prime_rl/tests/test_composer_loss.py` (10 tests covering shape and DPPO mask edges)   |
| Channel 2/3 fails to auto-disable when its inputs are absent                                  | Crash on missing key, not graceful skip              | `composer_replication/tests/test_compose_loss_integration.py` (`(a) defaults reproduce existing compose_loss output bit-exact`) |

Run the full suite with `pytest` from the repo root.
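
As an illustration of what the entropy-endpoint row in the table checks,
here is a minimal pure-Python sketch. The `entropy` helper is an assumption
written for this example, not the framework's teacher-entropy function; it
only shows the two endpoint identities the tests are named after.

```python
import math

def entropy(probs):
    """Shannon entropy in nats; terms with p == 0 contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

V = 8                                 # illustrative vocabulary size
one_hot = [1.0] + [0.0] * (V - 1)     # all mass on one token
uniform = [1.0 / V] * V               # mass spread evenly

h_one_hot = entropy(one_hot)          # expected: 0
h_uniform = entropy(uniform)          # expected: log(V)
print(h_one_hot, h_uniform, math.log(V))
```

If either endpoint is off, any gating weight w(t) derived from the teacher
entropy is scaled incorrectly, which is why both endpoints have dedicated
tests.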

---

**File path:** `/mnt/e/CS/HF/composer-replication-framework/docs/USER_GUIDE.md`