
Integration Architecture: 3-Channel Reward Composition Across the Agentic-RL Stack

Status: Architecture spec — verified against framework source code via DeepWiki on 2026-05-25. Companion doc: docs/COMPOSER_RECIPE_MAPPING.md defines the three reward channels (RLVR / Composer-SDPO / N-Teacher-Replay). This document specifies where each one hooks into each framework — the actual function names, decorator surfaces, and DataProto fields you'd touch. Working code skeleton at spikes/005-integrated-trainer-skeleton/.

TL;DR — the unified loss

For any framework choice, the v0.1 trainer computes:

total_loss = grpo_loss
           + α * sdpo_kl_loss        (Composer hint-distill, channel 2)
           + β * trace_replay_loss   (N-teacher novel channel, channel 3)

Where:

  • grpo_loss = standard GRPO+DAPO over RLVR scalar rewards (channel 1, the substrate).
  • sdpo_kl_loss = generalized_jsd_loss(student_logits, teacher_logits, labels=…, beta=0.5, …) — single-model self-distillation, where teacher_logits come from a forward pass on the student model with a hint inserted into the context. Lifted verbatim from siyan-zhao/OPSD::generalized_jsd_loss (verified self-contained static method, MIT licensed).
  • trace_replay_loss = DPO-style preference loss (or PRM-style score regression) over (chosen, rejected) pairs derived from N external teacher disagreements at each step.
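To make the second term concrete, here is a minimal single-position sketch of a generalized Jensen-Shannon divergence in plain Python. It is not the OPSD implementation, which works on batched logits and also handles `labels` masking and `token_clip`. The function name `generalized_jsd`, the mixture convention (here `beta` weights the teacher) and the toy logits are assumptions made for illustration.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over one vocabulary-sized list."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    lse = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - lse for x in scaled]


def kl(log_p, log_q):
    """KL(P || Q) for two log-probability vectors over the same vocabulary."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def generalized_jsd(student_logits, teacher_logits, beta=0.5, temperature=1.0):
    """beta * KL(teacher || M) + (1 - beta) * KL(student || M), for 0 < beta < 1."""
    log_s = log_softmax(student_logits, temperature)
    log_t = log_softmax(teacher_logits, temperature)
    # Assumed convention: mixture M = beta * teacher + (1 - beta) * student.
    log_m = [
        math.log(beta * math.exp(t) + (1.0 - beta) * math.exp(s))
        for s, t in zip(log_s, log_t)
    ]
    return beta * kl(log_t, log_m) + (1.0 - beta) * kl(log_s, log_m)


# Identical distributions give (numerically) zero; different ones give a positive value.
jsd_same = generalized_jsd([2.0, 0.5, -1.0], [2.0, 0.5, -1.0])
jsd_diff = generalized_jsd([2.0, 0.5, -1.0], [-1.0, 0.5, 2.0])
```

At `beta=0.5` this reduces to the symmetric JSD, which is bounded above by ln 2, so the SDPO term cannot blow up the way an unbounded forward KL can.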

The novel architectural claim is that all three channels can run simultaneously in a single trainer step, with the cost split as: (1) one extra forward pass per error site for SDPO, (2) N teacher API calls per replayed step for trace-replay. Spike 001 verified the API economics (✅ $0.98/trace, 5× headroom).

Stack-by-stack integration matrix

| Component | TRL | VeRL | TorchForge | Monarch | OpenEnv |
|---|---|---|---|---|---|
| Channel 1 (RLVR/GRPO) | GRPOTrainer._compute_loss(model, inputs): base class behavior, no change | core_algos.compute_grpo_outcome_advantage (registered via @register_adv_est("grpo")) | forge.controller.GRPO recipe (paused; pattern reference only) | Orchestrates rollout/trainer/rewarder ActorMeshes | Env exposes RLVR-shaped reward via step() |
| Channel 2 (SDPO hint-distill) | Subclass override of _compute_loss; lift generalized_jsd_loss from OPSD | New advantage estimator registered as @register_adv_est("grpo_sdpo"); reads data.batch["sdpo_teacher_logprobs"]; OR keep adv_estimator=grpo and add SDPO term in critic worker's compute_loss | Add a new ActorMesh SDPOTeacherActor that re-runs forward with hint-conditioned context; wire into trainer's loss | No-op at orchestration layer (just routes hint pairs) | Env emits "error site" markers in tool response so trainer knows where to insert hints |
| Channel 3 (N-teacher trace-replay) | Subclass override of _compute_loss; add DPO-pair term using teacher logprobs in inputs["teacher_action_distributions"] | Custom adv_estimator; teacher distributions stashed in data.non_tensor_batch["teacher_actions"]; precedent: distillation already attaches teacher_log_probs to rollout DataProto | Add a new TeacherReplayActor ActorMesh that holds OpenRouter client; called on a delayed-reward channel (RFC-004) | Routes teacher queries via service.spawn(TeacherReplayActor, n=K) for K parallel teacher pools | Env's state() API exposes step-level state needed for teacher replay |
| Multi-turn rollout async | Blocking: tool-call stalls GPU | AsyncServer + AgentLoop async; tool-call doesn't block GPU ✅ | Generator ActorMesh async via vLLM; tool-call waits don't block trainer ✅ | ActorMesh + supervision tree; native async | Env supports async via WebSocket multiplexed sessions |
| Weight sync (vLLM ↔ FSDP) | Co-located vLLM (no resharding) | 3D-HybridEngine (resharding between FSDP↔TP), most efficient | TorchStore RDMA weight broadcast | Monarch RDMA data plane | N/A (env-side) |
| Scale ceiling | ~32 GPUs / 70B FSDP | ✅ 671B+ proven, Megatron-LM | Reference patterns only (paused) | Thousands of GPUs (mesh) | 10K+ concurrent env sessions |

Reading the matrix: rows are "what each reward channel touches in each framework." Columns are framework choices. The matrix shows the v0.1 framework choice is non-trivial:

  • TRL = simplest extension story (one subclass override) but doesn't async-decouple tool calls and caps at ~70B.
  • VeRL = most flexible at scale (custom adv_estimator + DataProto extension is well-trodden) and has async agent loop, but Ray-heavy and steeper curve.
  • TorchForge + Monarch = cleanest abstraction but Forge is "development paused" — use as reference, not foundation.
  • OpenEnv = orthogonal substrate — works with all of the above; not a choice, a default.

Architecture diagrams (mechanism-level, all three channels)

1. Composer SDPO hint-distill flow (single model, hint-conditioned self-teacher)

                                    ┌─────────────────────┐
                                    │  Hint Generator     │
                                    │  - templates v0.1   │
                                    │  - LLM-driven v0.2  │
                                    └──────────┬──────────┘
                                               │ generates hint text
                                               ▼ at error sites
       Trace, mid-rollout:                ┌────────────────┐
       …turn_4 (OK)                       │ Build paired   │
       turn_5 (ERROR: tool not found) ────│ contexts:      │
       …turn_6 (OK)                       │   ctx_student  │
                                          │   ctx_teacher  │
                                          │  (= ctx_student│
                                          │   + hint at    │
                                          │   turn_5)      │
                                          └───────┬────────┘
                                                  │
                                ┌─────────────────┴──────────────────┐
                                │                                    │
                                ▼                                    ▼
                      ┌──────────────────┐              ┌────────────────────┐
                      │ Student forward  │              │ Teacher forward    │
                      │ on ctx_student   │              │ (SAME MODEL on     │
                      │   → student_logits│             │  ctx_teacher)      │
                      │                  │              │   → teacher_logits │
                      └──────────┬───────┘              └────────┬───────────┘
                                 │                               │
                                 └──────────┬────────────────────┘
                                            │ feed both into
                                            ▼
                          ┌─────────────────────────────────────────┐
                          │ generalized_jsd_loss(                  │
                          │   student_logits=…,                    │
                          │   teacher_logits=…,                    │
                          │   labels=… (mask non-error turns),     │
                          │   beta=0.5,    # JSD                   │
                          │   temperature=1.0,                     │
                          │   token_clip=…)                         │
                          │                                         │
                          │ → sdpo_kl_loss (a scalar)              │
                          └──────────────┬──────────────────────────┘
                                         │
                                         ▼
                              add to total_loss with α weight

Key implementation note: Per the DeepWiki audit, OPSD's SelfDistillationDataCollator builds two prompts per example:

  • ctx_student = problem only (or problem + rollout up to error turn).
  • ctx_teacher = problem + privileged info (in OPSD's case, the verified solution; in our case, the hint).

For Composer-style hint-distill, we adapt this: ctx_teacher = ctx_student + injected_hint at the specific turn boundary, with labels masked to keep loss only at the post-hint tokens of that turn.
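A minimal sketch of that adaptation on raw token-id lists follows. The helper name `build_paired_contexts`, the `IGNORE_INDEX = -100` label convention and the toy token ids are assumptions for illustration; OPSD's actual collator is not reproduced here.

```python
IGNORE_INDEX = -100  # assumed "no loss here" label, following the common PyTorch/HF convention


def build_paired_contexts(turns, error_turn, hint_ids):
    """Build student/teacher contexts plus labels that keep loss only on the error turn.

    turns:      list of token-id lists, one per turn
    error_turn: index of the turn where the error marker fired
    hint_ids:   token ids of the hint, injected at that turn boundary (teacher only)
    """
    ctx_student, ctx_teacher = [], []
    student_labels, teacher_labels = [], []
    for i, turn_ids in enumerate(turns):
        if i == error_turn:
            # Teacher sees the hint right before the failing turn; hint tokens carry no loss.
            ctx_teacher.extend(hint_ids)
            teacher_labels.extend([IGNORE_INDEX] * len(hint_ids))
        ctx_student.extend(turn_ids)
        ctx_teacher.extend(turn_ids)
        labels = list(turn_ids) if i == error_turn else [IGNORE_INDEX] * len(turn_ids)
        student_labels.extend(labels)
        teacher_labels.extend(labels)
    return ctx_student, ctx_teacher, student_labels, teacher_labels


student, teacher, s_labels, t_labels = build_paired_contexts(
    turns=[[1, 2], [3, 4, 5], [6]], error_turn=1, hint_ids=[90, 91]
)
```

Both label lists select the same number of positions (the tokens of the error turn), which is what lets the student and teacher logits be compared position by position even though the teacher context is longer by the hint.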

2. N-Teacher trace-replay flow (N external teachers, novel)

       Trace, frozen post-rollout:
       turn_1 (state_1, action_1_student, reward=…)
       turn_2 (state_2, action_2_student, reward=…)
       …
       turn_50 (state_50, action_50_student, reward=…)
                                │
                                │ for each turn t in trace:
                                ▼
                        ┌───────────────────────────┐
                        │ teacher pool (frozen)     │
                        │  ┌──────────────────────┐ │
                        │  │ Opus 4.7 (anthro)    │ │
                        │  │ GPT-5 (openai)       │ │
                        │  │ DeepSeek V4 Pro      │ │
                        │  └──────────────────────┘ │
                        │  parallel API calls       │
                        └───────────┬───────────────┘
                                    │ teacher_t = [a_t^Opus, a_t^GPT, a_t^DS]
                                    ▼
                        ┌───────────────────────────────────┐
                        │ disagreement scorer:              │
                        │  if 2+ teachers agree on X        │
                        │     and student picked Y ≠ X:     │
                        │       chosen=X, rejected=Y        │
                        │       (DPO pair)                  │
                        │  else if all 3 disagree:          │
                        │       skip (no signal)            │
                        │  else if all agree with student:  │
                        │       skip (no signal)            │
                        └──────────────┬────────────────────┘
                                       │ DPO pairs[]
                                       ▼
                        ┌───────────────────────────────────┐
                        │ DPO loss term:                    │
                        │  L = -log σ(β·(logπ(chosen|s)     │
                        │           − logπ_ref(chosen|s)    │
                        │           − logπ(rejected|s)      │
                        │           + logπ_ref(rejected|s)))│
                        │                                   │
                        │ → trace_replay_loss (a scalar)    │
                        └──────────────┬────────────────────┘
                                       │
                                       ▼
                          add to total_loss with β weight

Key implementation note: unlike SDPO, this happens post-rollout, not during. The trace is frozen, teacher calls are batched, DPO pairs are extracted offline, and the loss is computed in a follow-up training step. This decouples teacher-API-call latency from the trainer's GPU loop entirely. Spike 001 verified ~20s p95 step latency for parallel 3-teacher calls — acceptable at offline-batch cadence.
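The disagreement scorer in the diagram can be sketched as a majority vote. This is an illustration under stated assumptions: the function name `extract_dpo_pair` and the `min_agree` parameter are hypothetical, and actions are compared by exact string equality, whereas real tool calls would likely need normalization first.

```python
from collections import Counter


def extract_dpo_pair(student_action, teacher_actions, min_agree=2):
    """Return (chosen, rejected) for one step, or None when the step carries no signal."""
    top_action, votes = Counter(teacher_actions).most_common(1)[0]
    if votes < min_agree:
        return None  # teachers all disagree: no consensus to learn from
    if top_action == student_action:
        return None  # consensus already matches the student
    return top_action, student_action


# One frozen trace: (student action, [teacher actions]) per step.
trace = [
    ("rm build/", ["ls build/", "ls build/", "cat Makefile"]),  # 2 of 3 agree, student differs
    ("ls src/", ["ls src/", "ls src/", "pwd"]),                 # consensus matches student
    ("pytest", ["make test", "tox", "nox"]),                    # all three disagree
]
dpo_pairs = [
    pair
    for student_action, teachers in trace
    if (pair := extract_dpo_pair(student_action, teachers)) is not None
]
```

Only the first step yields a pair here, which shows why the number of usable DPO pairs per trace can be much smaller than the number of replayed steps.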

3. The combined trainer step (all three channels)

            ┌──────────────────────────────────────────────────────────┐
            │              ROLLOUT PHASE (per episode)                 │
            │  Generator (vLLM) → Env (OpenEnv) → trace JSONL          │
            │  → emits (state_t, action_t, reward_t, error_marker_t)   │
            └────────────────────────┬─────────────────────────────────┘
                                     │
                  ┌──────────────────┼──────────────────────────┐
                  │                  │                          │
        ┌─────────▼─────────┐ ┌──────▼─────────┐    ┌───────────▼─────────┐
        │ RLVR scoring      │ │ Hint detection │    │ Teacher replay      │
        │ (test pass etc.)  │ │ at error_marker│    │ (post-rollout, async│
        │                   │ │   → hint_text  │    │  via OpenRouter API)│
        │ → reward_outcome  │ │ → ctx_teacher  │    │ → teacher_actions[] │
        └─────────┬─────────┘ └──────┬─────────┘    └───────────┬─────────┘
                  │                  │                          │
                  │                  │            ┌─────────────┘
                  │                  │            │ disagreement→DPO pairs
                  │                  │            │
                  └──────────────────┼────────────┘
                                     ▼
            ┌──────────────────────────────────────────────────────────┐
            │              TRAINING PHASE (per gradient step)          │
            │                                                          │
            │  forward(student, ctx_rollout) → student_logits          │
            │  forward(student, ctx_teacher) → teacher_logits ← SDPO   │
            │                                                          │
            │  grpo_loss        = compute_grpo_loss(reward_outcome)    │
            │  sdpo_kl_loss     = generalized_jsd_loss(s_logits,       │
            │                       t_logits, labels=error_mask)        │
            │  trace_replay_loss= dpo_loss(student_logprobs,           │
            │                              ref_logprobs, dpo_pairs)    │
            │                                                          │
            │  total_loss = grpo_loss + α*sdpo_kl_loss + β*replay_loss │
            │                                                          │
            │  total_loss.backward()                                   │
            │  optimizer.step()                                        │
            └──────────────────────────────────────────────────────────┘

Cost composition per training step (v0.0/v0.1 estimate):

| Operation | Cost |
|---|---|
| Rollout forward (vLLM, async) | k tokens × inference TFLOPs |
| Teacher forward (training-mode FSDP, hint-conditioned) | ~1 extra FW pass per error site (sparse, maybe 5% of tokens) |
| RLVR reward eval | ~test execution overhead, env-bound, async |
| Teacher API replay (post-rollout, batched) | ~$0.02/step for all 3 teachers in parallel × 50 steps ≈ $1/trace (verified by spike 001) |
| GRPO + SDPO + DPO loss compute | Negligible vs forward passes |
| Backward + optimizer step | Standard FSDP step |

The SDPO channel is forward-pass-bound (one extra FW per error site). The trace-replay channel is API-call-bound (offline, post-rollout, ~$0.30/trace with VOI gating in v0.1). They don't compete for the same resource.
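The replay-cost figures can be reproduced with simple arithmetic. In this sketch the function name `replay_cost_per_trace` is hypothetical, and the 30% replay fraction is an assumption: it is simply the gating rate that would be consistent with the ~$0.30/trace figure, not a number reported by spike 001.

```python
def replay_cost_per_trace(steps, cost_per_step_usd, replay_fraction=1.0):
    """API cost of replaying one trace when only a fraction of steps is sent to teachers."""
    return steps * cost_per_step_usd * replay_fraction


# Ungated: every one of the 50 steps is replayed by the 3-teacher pool.
full_cost = replay_cost_per_trace(steps=50, cost_per_step_usd=0.02)

# Gated: assume (hypothetically) that VOI gating keeps about 30% of the steps.
gated_cost = replay_cost_per_trace(steps=50, cost_per_step_usd=0.02, replay_fraction=0.3)
```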

Per-framework integration recipes

Recipe A: TRL GRPOTrainer subclass (recommended for v0.0/v0.1)

Why this is the right v0.1 choice: simplest extension; OPSD code lifts cleanly; Qwen3-7B fits comfortably in TRL's scale ceiling; first-class OpenEnv integration via environment_factory.

import torch
import torch.nn.functional as F
from trl import GRPOTrainer
from opsd_trainer import generalized_jsd_loss  # lifted from siyan-zhao/OPSD


class ComposerReplicationTrainer(GRPOTrainer):
    """v0.1 trainer: GRPO + SDPO hint-distill + N-teacher trace-replay-DPO."""

    def __init__(self, *args, alpha_sdpo=0.1, beta_replay=0.05, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha_sdpo = alpha_sdpo
        self.beta_replay = beta_replay

    def _compute_loss(self, model, inputs):
        # Channel 1: standard GRPO loss
        grpo_loss = super()._compute_loss(model, inputs)

        # Channel 2: SDPO hint-distill at error sites
        sdpo_kl = self._compute_sdpo_loss(model, inputs)

        # Channel 3: trace-replay DPO from teacher disagreement
        replay_dpo = self._compute_trace_replay_loss(model, inputs)

        # Compose
        total_loss = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo

        # Log all three components for ablation
        if self.state.global_step % self.args.logging_steps == 0:
            self.log({
                "loss/grpo": grpo_loss.detach().item(),
                "loss/sdpo_kl": sdpo_kl.detach().item(),
                "loss/trace_replay_dpo": replay_dpo.detach().item(),
                "loss/total": total_loss.detach().item(),
            })

        return total_loss

    def _compute_sdpo_loss(self, model, inputs):
        if "ctx_teacher_input_ids" not in inputs or inputs["ctx_teacher_input_ids"].numel() == 0:
            # No error sites in this batch — SDPO is a no-op.
            return torch.tensor(0.0, device=model.device)

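        # NOTE: ctx_teacher is longer than the student context by the hint length.
        # Unless the collator already aligns them, both logit tensors must be sliced
        # to the shared post-hint span before they are compared position by position.
        student_logits = model(input_ids=inputs["input_ids"]).logits
        with torch.no_grad():
            # Teacher = same model, hint-injected context. NO grad.
            teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits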

        return generalized_jsd_loss(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
            labels=inputs["sdpo_loss_mask"],  # only error-turn tokens
            beta=0.5,
            temperature=1.0,
            token_clip=10.0,
        )

    def _compute_trace_replay_loss(self, model, inputs):
        if "dpo_chosen_input_ids" not in inputs:
            return torch.tensor(0.0, device=model.device)

        # Standard DPO loss using teacher-disagreement-derived pairs
        chosen_logprobs = self._get_logprobs(model, inputs["dpo_chosen_input_ids"])
        rejected_logprobs = self._get_logprobs(model, inputs["dpo_rejected_input_ids"])
        ref_chosen_logprobs = inputs["dpo_chosen_ref_logprobs"]  # precomputed
        ref_rejected_logprobs = inputs["dpo_rejected_ref_logprobs"]

        beta_dpo = 0.1
        logits = beta_dpo * (chosen_logprobs - ref_chosen_logprobs
                             - rejected_logprobs + ref_rejected_logprobs)
        return -F.logsigmoid(logits).mean()
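
The closing lines of `_compute_trace_replay_loss` can be sanity-checked on scalars. The sketch below is a plain-Python stand-in for a single pair, assuming each argument is a summed sequence log-probability; the function name `dpo_loss` and the numbers are illustrative. Note that `beta_dpo` is the DPO temperature and is unrelated to the replay weight β in `total_loss`.

```python
import math


def dpo_loss(chosen_lp, rejected_lp, ref_chosen_lp, ref_rejected_lp, beta_dpo=0.1):
    """-log(sigmoid(margin)) for one (chosen, rejected) pair of sequence log-probs."""
    margin = beta_dpo * ((chosen_lp - ref_chosen_lp) - (rejected_lp - ref_rejected_lp))
    # -log(sigmoid(m)) == softplus(-m); this form stays stable for large |m|.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Policy identical to the reference: margin is 0, loss is log(2).
loss_neutral = dpo_loss(-5.0, -5.0, -5.0, -5.0)

# Policy has moved toward the teacher-preferred action: loss drops below log(2).
loss_improved = dpo_loss(chosen_lp=-2.0, rejected_lp=-8.0,
                         ref_chosen_lp=-5.0, ref_rejected_lp=-5.0)
```

A freshly initialized run, where the policy equals the reference, should therefore log a trace-replay loss near ln 2 ≈ 0.693; a value far from that on step 0 suggests the reference log-probs were computed under a different model or tokenization.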

The data collator (a sibling to OPSD's SelfDistillationDataCollator) is responsible for assembling the extra fields:

  • ctx_teacher_input_ids — the hint-augmented context, when error markers fire
  • sdpo_loss_mask — which token positions are post-hint and should contribute to KL
  • dpo_chosen_input_ids / dpo_rejected_input_ids — pairs from spike-003-style extraction
  • dpo_*_ref_logprobs — precomputed under the reference (student-init) policy

OpenEnv plumbing stays untouched — the environment_factory=… kwarg of GRPOTrainer already handles the SWE-bench-lite env.

Recipe B: VeRL custom adv_estimator + DataProto extension (recommended for v0.2 scale)

Why this is the right v0.2 choice: VeRL has the only proven 70B+/671B RL story; HybridFlow's 3D-HybridEngine is the production reference for FSDP↔vLLM resharding; VeRL has precedent for exactly this pattern (teacher_log_probs already used for distillation per the DeepWiki audit).

# verl_extensions/composer_adv.py
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import register_adv_est


@register_adv_est("grpo_composer")
def compute_grpo_composer_advantage(token_level_rewards, eos_mask, index, **kwargs):
    """GRPO advantage with SDPO + N-teacher trace-replay shaping.

    Reads from kwargs (passed via DataProto.batch / non_tensor_batch):
      - sdpo_teacher_logprobs: per-token logprobs from hint-conditioned forward
      - teacher_actions:       list of N teacher action distributions per step
      - alpha_sdpo, beta_replay: weights
    """
    # Standard GRPO advantage (same as built-in). The built-in estimator returns
    # an (advantages, returns) pair, so unpack it before shaping the advantages.
    base_adv, base_returns = core_algos.compute_grpo_outcome_advantage(
        token_level_rewards, eos_mask, index
    )

    # SDPO shaping: at error-site tokens, add an extra advantage term
    # proportional to (teacher_logprob - student_logprob). This nudges
    # the policy gradient toward the hint-conditioned distribution.
    sdpo_teacher_lp = kwargs.get("sdpo_teacher_logprobs")
    if sdpo_teacher_lp is not None:
        student_lp = kwargs["old_log_prob"]
        sdpo_term = kwargs["alpha_sdpo"] * (sdpo_teacher_lp - student_lp)
        # Only apply at error-mask positions
        sdpo_term = sdpo_term * kwargs["sdpo_error_mask"]
        base_adv = base_adv + sdpo_term

    # Trace-replay shaping: per-step PRM signal from teacher consensus.
    # compute_teacher_consensus_prm is a project helper, not part of VeRL.
    teacher_actions = kwargs.get("teacher_actions")
    if teacher_actions is not None:
        prm_signal = compute_teacher_consensus_prm(teacher_actions, kwargs["student_actions"])
        base_adv = base_adv + kwargs["beta_replay"] * prm_signal

    # Registered estimators return the same (advantages, returns) pair.
    return base_adv, base_returns

In the run config:

# ppo_trainer.yaml
algorithm:
  adv_estimator: grpo_composer
  alpha_sdpo: 0.1
  beta_replay: 0.05

In the rollout worker, attach the extra fields to DataProto:

# verl_extensions/composer_rollout.py
from verl import DataProto


def attach_composer_fields(data: DataProto, sdpo_teacher_lp, teacher_actions):
    data.batch["sdpo_teacher_logprobs"] = sdpo_teacher_lp
    data.batch["sdpo_error_mask"]       = build_error_mask(...)
    data.non_tensor_batch["teacher_actions"] = teacher_actions
    return data

This pattern is identical to how VeRL already handles distillation rollouts (per the DeepWiki audit: "teacher log-probabilities are stashed on the rollout output and later concatenated into the per-batch DataProto for the student training step").
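Recipe B's estimator calls `compute_teacher_consensus_prm` without defining it. One possible shape for that helper is sketched below in plain Python. The +1 / 0 / -1 scoring scheme and the `min_agree` parameter are assumptions, not something the VeRL codebase or spike 001 prescribes, and broadcasting the per-step scores onto token positions (needed before adding them to a token-level advantage tensor) is left out.

```python
from collections import Counter


def compute_teacher_consensus_prm(teacher_actions, student_actions, min_agree=2):
    """Per-step score: +1 student matches teacher consensus, -1 contradicts it, 0 no consensus.

    teacher_actions: list (one entry per step) of lists of N teacher actions
    student_actions: list (one entry per step) of the student's action
    """
    scores = []
    for teachers, student in zip(teacher_actions, student_actions):
        top_action, votes = Counter(teachers).most_common(1)[0]
        if votes < min_agree:
            scores.append(0.0)   # teachers split: no process-reward signal
        elif top_action == student:
            scores.append(1.0)   # student agrees with the consensus
        else:
            scores.append(-1.0)  # student contradicts the consensus
    return scores


prm_scores = compute_teacher_consensus_prm(
    teacher_actions=[["a", "a", "b"], ["a", "b", "c"], ["x", "x", "x"]],
    student_actions=["a", "a", "y"],
)
```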

Recipe C: TorchForge + Monarch (reference patterns only, not a production target)

Forge is "development paused" per the upstream banner; lift patterns, don't depend on it. The relevant patterns are:

  • SDPOTeacherActor ActorMesh — runs the hint-conditioned forward pass on a separate compute group, returns logits via TorchStore RDMA back to the trainer. Useful when SDPO forward is expensive enough to warrant offload.
  • TeacherReplayActor ActorMesh — pool of K parallel actors, each holding an OpenRouter HTTP client. Trainer calls service.spawn(TeacherReplayActor).query(state, n=3) and gets back N teacher distributions.
  • Delayed-reward channel (OpenEnv RFC-004) — for teacher replay where the signal arrives post-rollout, not at step(). Map to a separate reward stream that the trainer subscribes to.
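The fan-out behind the TeacherReplayActor pattern can be illustrated with plain `asyncio`. This is deliberately not Monarch or Forge code: `query_teacher`, `replay_step`, the teacher names and the simulated delays are all hypothetical stand-ins for an HTTP client call per teacher.

```python
import asyncio


async def query_teacher(name, state, delay_s):
    """Stand-in for one teacher API round trip; returns (teacher name, proposed action)."""
    await asyncio.sleep(delay_s)  # simulated network latency
    return name, f"action-for-{state}"


async def replay_step(state, teacher_delays):
    """Query every teacher concurrently for one step and collect their actions."""
    results = await asyncio.gather(
        *(query_teacher(name, state, delay) for name, delay in teacher_delays.items())
    )
    return dict(results)


teacher_answers = asyncio.run(
    replay_step("state_5", {"teacher_a": 0.01, "teacher_b": 0.02, "teacher_c": 0.01})
)
```

Because the calls are gathered rather than awaited in sequence, step latency tracks the slowest teacher instead of the sum over teachers, which is the property the parallel-pool design relies on.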

If/when Monarch's K8s story matures and we move to v0.2 multi-cluster decentralized scale, lift these patterns into the VeRL stack rather than building on Forge directly.

Recipe D: OpenEnv (substrate, not a choice)

OpenEnv is orthogonal — it works with TRL, VeRL, TorchForge, and any custom trainer. The contract:

  • Env exposes reset(...), step(action), state(), close().
  • Env optionally exposes tools via MCP (RFC-003).
  • Env optionally emits delayed rewards (RFC-004).
  • Container deploys via Docker; trainer connects via WebSocket multiplexed sessions.

For our framework, the env contract needs two lightweight extensions (both backward-compatible):

  1. Error-site markers in tool responses. When a tool call fails (404, type error, runtime exception), the env's step() response includes meta["error_kind"] and meta["hint_template_key"] — pre-defined keys the trainer's hint generator dispatches on. This lets the trainer decide where in the trace to insert hints without re-running the env.
  2. State-replay endpoint. For trace-replay, the env supports state(t) returning the exact same observation the agent saw at step t — needed so external teachers see identical context. This is purely additive; existing OpenEnv envs without this can fall back to "feed teacher the conversation history" mode.
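On the trainer side, the first extension could be consumed roughly as follows. Only `meta["error_kind"]` and `meta["hint_template_key"]` come from the proposal above; the template keys, the template wording, the `meta["tool"]` field and the function name `hint_for_step` are hypothetical.

```python
# Hypothetical template table keyed by meta["hint_template_key"].
HINT_TEMPLATES = {
    "tool_not_found": "The tool `{tool}` does not exist. Check the available tool list first.",
    "type_error": "The arguments passed to `{tool}` had the wrong type. Re-read the tool schema.",
}


def hint_for_step(step_response):
    """Return hint text for a step() response, or None when no error site is marked."""
    meta = step_response.get("meta", {})
    key = meta.get("hint_template_key")
    if meta.get("error_kind") is None or key not in HINT_TEMPLATES:
        return None  # no error, or a key the hint generator does not know
    return HINT_TEMPLATES[key].format(tool=meta.get("tool", "unknown"))


failed_step = {
    "observation": "ERROR: tool not found",
    "meta": {"error_kind": "404", "hint_template_key": "tool_not_found", "tool": "grep_repo"},
}
hint_text = hint_for_step(failed_step)
ok_step_hint = hint_for_step({"observation": "ok", "meta": {}})
```

Envs that do not emit the markers simply produce responses without these `meta` keys, so the lookup returns None and the SDPO channel stays a no-op for that step.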

We'll publish both extensions as proposed RFCs against meta-pytorch/OpenEnv once the v0.0 spike validates the full framework.

Why all three channels can run simultaneously (the architectural argument)

These three channels do not compete for any shared resource:

| Resource | Channel 1 (RLVR) | Channel 2 (SDPO) | Channel 3 (replay) |
|---|---|---|---|
| GPU forward pass | rollout (vLLM, async) | extra FW per error (training, FSDP) | none; uses precomputed logprobs |
| GPU backward pass | yes | yes (added to total_loss) | yes (added to total_loss) |
| External API budget | none | none | $0.30–1/trace (verified, spike 001) |
| Latency-critical path | yes; gates next rollout | minor; extra FW <5% of tokens | no; async, post-rollout |
| Storage | rollout JSONL | extra ctx + mask in collator | DPO pairs JSONL (separate dataset repo) |

Furthermore, the gradients are additive by design: the SDPO and replay terms carry their own weights (α and β respectively), so either auxiliary channel, or both, can be ablated by setting its weight to 0 while GRPO stays on as the substrate. The v0.1 ablation matrix:

| Run | α (SDPO) | β (replay) | Tests |
|---|---|---|---|
| Baseline | 0 | 0 | pure GRPO+RLVR |
| +SDPO only | 0.1 | 0 | Composer recipe replication |
| +Replay only | 0 | 0.05 | the v0.0 novel claim, scaled to 32B |
| Full | 0.1 | 0.05 | combined channel test (v0.1 winner candidate) |

This 4-arm A/B at 32B is the v0.1 terminal experiment. Total cost ~$1200 (4 runs × 3 seeds × ~$100 each).
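The ablation arms map directly onto the `alpha_sdpo` / `beta_replay` keyword arguments of the Recipe A trainer. A small sketch of the run grid and its budget, where the arm names and the seed values are assumptions and the per-run cost is the ~$100 estimate quoted above:

```python
# Weights per arm, matching the ablation matrix.
ABLATION_ARMS = {
    "baseline":    {"alpha_sdpo": 0.0, "beta_replay": 0.0},
    "sdpo_only":   {"alpha_sdpo": 0.1, "beta_replay": 0.0},
    "replay_only": {"alpha_sdpo": 0.0, "beta_replay": 0.05},
    "full":        {"alpha_sdpo": 0.1, "beta_replay": 0.05},
}
SEEDS = (0, 1, 2)         # assumed seed values; only the count (3) comes from the plan
COST_PER_RUN_USD = 100    # rough per-run estimate

# One entry per (arm, seed) training run.
run_grid = [
    {"arm": arm, "seed": seed, **weights}
    for arm, weights in ABLATION_ARMS.items()
    for seed in SEEDS
]
total_cost_usd = len(run_grid) * COST_PER_RUN_USD
```

Each entry's `alpha_sdpo` and `beta_replay` can be passed straight to `ComposerReplicationTrainer(...)`, so the four arms differ only in configuration, not in code path.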

Open questions / followups (for v0.1 design phase, not v0.0)

  1. Hint generator architecture (open since the recipe-mapping doc). Templates first; LLM-driven generator if templates plateau on style/communication errors.
  2. SDPO weight α schedule. OPSD paper used constant; SDPO paper uses constant; Cursor never says. Likely warmup-from-0 then constant; ablate.
  3. DPO pair extraction threshold. Spike 003 will determine: do we want only "2-of-3 teachers agree" pairs (high signal, fewer pairs), or also "1-of-3 differs from student" (more pairs, noisier)?
  4. Teacher pool composition. Spike 001 used Opus 4.7 + GPT-5 + DeepSeek V4 Pro. Question for v0.1: should we add a fourth teacher (Qwen3-Max-MoE? Kimi K2.5?) as a same-family voice to balance Anthropic/OpenAI? Cost adds linearly.
  5. Reward hacking monitoring. Cursor mentioned (without specifics) "agentic monitoring tools." Our v0.1 environment needs sandbox hardening: disable find, unzip, bytecode tools, and Python type-cache reads, so the model can't reverse-engineer deleted features the way Composer 2.5's model did.

Citations

Primary sources verified for this document:

  • TRL GRPOTrainer._compute_loss — verified via DeepWiki query against huggingface/trl repo on 2026-05-25. environment_factory kwarg confirmed for OpenEnv plumbing.
  • VeRL @register_adv_est + DataProto — verified via DeepWiki query against volcengine/verl repo on 2026-05-25. Distillation precedent (teacher_log_probs already attached to rollout DataProto) confirms the pattern.
  • OPSD generalized_jsd_loss — verified via DeepWiki query against siyan-zhao/OPSD repo on 2026-05-25. Static method, self-contained, MIT licensed, FlashAttention-2 compatible. Function signature reproduced verbatim above.
  • Cursor blog: Introducing Composer 2.5, read directly via tavily_extract advanced mode. Footnote 1 cites the three self-distillation papers.
  • SDPO paper — Hübotter et al., arXiv:2601.20802, ICLR 2026 Scaling Post-training Workshop.
  • OPSD paper — Zhao et al., arXiv:2601.18734, code at github.com/siyan-zhao/OPSD (MIT).
  • Existing research notes: research/03-monarch-torchforge-openenv.md (Monarch/Forge/OpenEnv) and research/04-verl-trl.md (VeRL/TRL) for framework-level context. Audit notes on those files apply: trust extension-point claims here over framework-level claims there when in conflict.

This document is the bridge between the conceptual 3-channel composition (in COMPOSER_RECIPE_MAPPING.md) and the executable trainer skeleton (in spikes/005-integrated-trainer-skeleton/). Anyone implementing v0.1 starts here, then opens the skeleton.