| | """Train a self-driving lab planner with TRL GRPO and OpenEnv rewards."""
|
| |
|
| | from __future__ import annotations
|
| |
|
| | import argparse
|
| | import json
|
| | import random
|
| | import re
|
| | from numbers import Real
|
| | from pathlib import Path
|
| | from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| |
|
| | from client import BioExperimentEnv
|
| | from models import (
|
| | ActionType,
|
| | ExperimentAction,
|
| | ExperimentObservation,
|
| | build_agent_observation_context,
|
| | build_agent_system_prompt,
|
| | )
|
| | from server.hackathon_environment import BioExperimentEnvironment
|
| | from server.tasks.scenarios import SCENARIO_LIBRARY
|
| |
|
# Defaults shared by the CLI and the notebook-facing helpers.
DEFAULT_MODEL_ID = "Qwen/Qwen3.5-4B"
DEFAULT_OUTPUT_DIR = "training/grpo-output"
DEFAULT_BASE_URL = "http://localhost:8000"
DEFAULT_COMPLETION_TOKEN_BUDGET = 160
# Reward assigned when a completion cannot be parsed into an action.
INVALID_ACTION_PENALTY = -2.0
# Reward assigned when the environment raises while scoring an action.
ENVIRONMENT_ERROR_PENALTY = -4.0

# System prompt prepended to every training example (see build_training_prompt).
SYSTEM_PROMPT = build_agent_system_prompt()

# Canonical action-type value strings accepted by the environment.
ACTION_TYPES = [action.value for action in ActionType]
# Near-miss names observed in model output, mapped to canonical values.
ACTION_TYPE_ALIASES = {
    "collect_samples": ActionType.COLLECT_SAMPLE.value,
    "collect_sample_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
    "collect_samples_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
    "prepare_sc_library": ActionType.PREPARE_LIBRARY.value,
    "sequence_single_cells": ActionType.SEQUENCE_CELLS.value,
    "qc": ActionType.RUN_QC.value,
    "run_quality_control": ActionType.RUN_QC.value,
    "cluster": ActionType.CLUSTER_CELLS.value,
    "de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "differential_expression_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "trajectory_inference": ActionType.TRAJECTORY_ANALYSIS.value,
    "infer_trajectory": ActionType.TRAJECTORY_ANALYSIS.value,
    "network_inference": ActionType.REGULATORY_NETWORK_INFERENCE.value,
    "select_markers": ActionType.MARKER_SELECTION.value,
    "final_conclusion": ActionType.SYNTHESIZE_CONCLUSION.value,
}

# Canonical single-cell pipeline order, used both by the scripted collection
# policy and by the "Remaining steps" hint rendered into prompts.
HEURISTIC_SEQUENCE = [
    ActionType.COLLECT_SAMPLE,
    ActionType.SELECT_COHORT,
    ActionType.PREPARE_LIBRARY,
    ActionType.SEQUENCE_CELLS,
    ActionType.RUN_QC,
    ActionType.FILTER_DATA,
    ActionType.NORMALIZE_DATA,
    ActionType.INTEGRATE_BATCHES,
    ActionType.CLUSTER_CELLS,
    ActionType.DIFFERENTIAL_EXPRESSION,
    ActionType.PATHWAY_ENRICHMENT,
    ActionType.MARKER_SELECTION,
    ActionType.TRAJECTORY_ANALYSIS,
    ActionType.REGULATORY_NETWORK_INFERENCE,
    ActionType.SYNTHESIZE_CONCLUSION,
]

# Set view of ACTION_TYPES for O(1) membership checks.
VALID_ACTION_TYPES = set(ACTION_TYPES)
|
| |
|
| |
|
def compact_preview(value: Any, max_chars: int = 160) -> str:
    """Return a single-line preview of *value*, truncated to *max_chars*.

    Serialises with deterministic key order; falls back to ``str(value)``
    when the value is not JSON-serialisable (``TypeError``) or contains a
    circular reference (``ValueError`` — previously uncaught and would have
    propagated out of this helper).
    """
    try:
        text = json.dumps(value, ensure_ascii=True, sort_keys=True)
    except (TypeError, ValueError):
        text = str(value)
    # Collapse all runs of whitespace so the preview stays on one line.
    text = re.sub(r"\s+", " ", text).strip()
    if len(text) <= max_chars:
        return text
    return text[: max_chars - 3] + "..."
|
| |
|
| |
|
| | def _edit_distance(a: str, b: str) -> int:
|
| | if len(a) < len(b):
|
| | return _edit_distance(b, a)
|
| | if not b:
|
| | return len(a)
|
| | prev = list(range(len(b) + 1))
|
| | for i, ca in enumerate(a):
|
| | curr = [i + 1]
|
| | for j, cb in enumerate(b):
|
| | curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (ca != cb)))
|
| | prev = curr
|
| | return prev[-1]
|
| |
|
| |
|
def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
    """Fetch a payload value by exact, case-insensitive, or fuzzy key match.

    Keys are tried in three passes of decreasing strictness; within each
    pass the caller-supplied *names* order expresses preference. Returns
    None when no key matches.
    """
    # Pass 1: exact key match.
    for name in names:
        if name in payload:
            return payload[name]

    # Pass 2: case-insensitive match against lowercased keys.
    lowered = {str(key).lower(): value for key, value in payload.items()}
    for name in names:
        folded = name.lower()
        if folded in lowered:
            return lowered[folded]

    # Pass 3: fuzzy match, tolerating roughly one typo per three characters.
    for key, value in lowered.items():
        for name in names:
            max_typos = max(2, len(name) // 3)
            if _edit_distance(key, name.lower()) <= max_typos:
                return value
    return None
|
| |
|
| |
|
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the CLI parser shared by ``parse_args`` and ``make_training_args``."""
    parser = argparse.ArgumentParser(
        description="Train a GRPO policy against the OpenEnv bio experiment environment."
    )
    add = parser.add_argument

    # Model selection and artifact locations.
    add("--model-id", default=DEFAULT_MODEL_ID)
    add("--output-dir", default=DEFAULT_OUTPUT_DIR)

    # Prompt-dataset construction.
    add("--dataset-episodes", type=int, default=8)
    add("--rollout-steps", type=int, default=6)
    add(
        "--collection-policy",
        choices=["random", "heuristic"],
        default="heuristic",
        help="Policy used to build prompt states for GRPO training.",
    )

    # Reward scoring backend.
    add(
        "--reward-backend",
        choices=["local", "remote"],
        default="local",
        help="Use local in-process scoring or a live OpenEnv server.",
    )
    add(
        "--base-url",
        default=DEFAULT_BASE_URL,
        help="Base URL for the OpenEnv server when reward-backend=remote.",
    )
    add(
        "--scenario-name",
        action="append",
        default=None,
        help="Repeatable scenario selector. Defaults to all curated scenarios.",
    )
    add(
        "--domain-randomise",
        action="store_true",
        help="Enable domain randomisation while building prompts and local rewards.",
    )

    # GRPO / trainer hyper-parameters.
    add("--num-generations", type=int, default=2)
    add(
        "--max-completion-length",
        type=int,
        default=DEFAULT_COMPLETION_TOKEN_BUDGET,
    )
    add("--max-prompt-length", type=int, default=768)
    add("--per-device-train-batch-size", type=int, default=2)
    add("--gradient-accumulation-steps", type=int, default=1)
    add("--learning-rate", type=float, default=5e-6)
    add("--num-train-epochs", type=float, default=1.0)
    add("--logging-steps", type=int, default=1)
    add("--save-steps", type=int, default=50)
    add(
        "--plot-metric-key",
        default=None,
        help="Optional extra metric key from trainer log history to plot.",
    )
    add("--seed", type=int, default=0)

    # Run-mode switches.
    add(
        "--load-model-only",
        action="store_true",
        help="Download and load the selected model and tokenizer, then exit.",
    )
    add(
        "--trust-remote-code",
        action="store_true",
        help="Pass trust_remote_code=True to model/tokenizer loading.",
    )
    add(
        "--dry-run",
        action="store_true",
        help="Build the prompt dataset and smoke-test the reward function without training.",
    )
    add(
        "--push-to-hub",
        type=str,
        default=None,
        help="HuggingFace Hub repo id to push the trained model to (e.g. 'myuser/my-model').",
    )
    return parser
|
| |
|
| |
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse CLI arguments (``argv=None`` falls back to ``sys.argv``)."""
    parser = build_argument_parser()
    return parser.parse_args(argv)
|
| |
|
| |
|
def make_training_args(**overrides: Any) -> argparse.Namespace:
    """Build an argparse-style namespace for notebooks and scripts.

    Starts from the parser's defaults and applies *overrides*, rejecting
    any keyword that does not correspond to a CLI option.
    """
    defaults = vars(build_argument_parser().parse_args([]))
    unexpected = sorted(set(overrides) - set(defaults))
    if unexpected:
        raise ValueError(f"Unknown training args: {', '.join(unexpected)}")
    defaults.update(overrides)
    return argparse.Namespace(**defaults)
|
| |
|
| |
|
def format_observation(obs: ExperimentObservation) -> str:
    """Render an environment observation as the user-turn prompt text.

    The rendered prompt lists the task and resource status, tool context,
    recent pipeline history, completed/remaining steps, latest data preview,
    violations, and evidence discovered so far, and ends with a strict
    JSON-only output instruction that downstream parsing relies on.
    """
    parts = [
        f"TASK: {obs.task.problem_statement}",
        f"Organism: {obs.task.organism} | Tissue: {obs.task.tissue}",
        f"Conditions: {', '.join(obs.task.conditions) or 'N/A'}",
        (
            "Step: "
            f"{obs.step_index} | Budget: ${obs.resource_usage.budget_remaining:,.0f} "
            f"| Time: {obs.resource_usage.time_remaining_days:.0f}d"
        ),
    ]
    # Tool/assay context from the shared prompt helpers, trimmed for length.
    context = build_agent_observation_context(obs, max_tools=5, max_assays=2)
    if context:
        parts.append(context)
    if obs.pipeline_history:
        # Only the five most recent steps are shown to keep the prompt short.
        last5 = obs.pipeline_history[-5:]
        parts.append("Recent history:")
        for step in last5:
            tag = "OK" if step.success else "FAIL"
            line = f" [{tag}] {step.action_type.value}"
            if step.method:
                line += f" ({step.method})"
            line += f": {step.output_summary[:80]}"
            parts.append(line)
        completed = {
            step.action_type for step in obs.pipeline_history if step.success
        }
        if completed:
            parts.append(
                "Completed steps (do NOT repeat): "
                + ", ".join(sorted(action.value for action in completed))
            )
        # Steer the model toward the canonical pipeline order.
        remaining = [
            action.value for action in HEURISTIC_SEQUENCE if action not in completed
        ]
        if remaining:
            parts.append(f"Remaining steps (choose one): {', '.join(remaining)}")
    if obs.latest_output and obs.latest_output.data:
        parts.append(
            f"Latest data: {compact_preview(obs.latest_output.data, 200)}"
        )
    if obs.rule_violations:
        parts.append(f"VIOLATIONS: {obs.rule_violations}")
    if obs.discovered_markers:
        parts.append(f"Markers found so far: {obs.discovered_markers[:5]}")
    if obs.candidate_mechanisms:
        parts.append(f"Candidate mechanisms: {obs.candidate_mechanisms[:5]}")
    # Hard output-format instruction; parse_action_completion expects this shape.
    parts.append(
        'Output ONLY a single JSON object with these exact keys, no comments, no extra text:\n'
        '{"action_type": "<one of the remaining steps>", "method": null, "parameters": {}, "justification": "<why>", "confidence": 0.8}'
    )
    return "\n".join(parts)
|
| |
|
| |
|
def build_training_prompt(obs: ExperimentObservation) -> str:
    """Combine the shared system prompt with the formatted observation."""
    observation_text = format_observation(obs)
    return f"{SYSTEM_PROMPT}\n\n{observation_text}"
|
| |
|
| |
|
def heuristic_next_action(history: Sequence[ActionType], step_index: int) -> ActionType:
    """Return the first canonical pipeline step not yet present in *history*."""
    performed = set(history)
    for candidate in HEURISTIC_SEQUENCE:
        if candidate not in performed:
            return candidate
    # Pipeline exhausted: validate a marker once past the opening steps,
    # otherwise conclude.
    if ActionType.VALIDATE_MARKER not in performed and step_index >= 2:
        return ActionType.VALIDATE_MARKER
    return ActionType.SYNTHESIZE_CONCLUSION
|
| |
|
| |
|
def pick_action(policy: str, step_index: int, history: Sequence[ActionType]) -> ActionType:
    """Select the next action under the given collection policy."""
    if policy != "random":
        return heuristic_next_action(history, step_index)
    return random.choice(list(ActionType))
|
| |
|
| |
|
def default_comparison_name(conditions: Sequence[str]) -> str:
    """Derive a comparison label from the experiment's condition names."""
    lowered = {condition.lower() for condition in conditions}
    if "healthy" in lowered and "ipf" in lowered:
        return "IPF_vs_healthy"
    has_treated = any("treated" in condition for condition in lowered)
    has_untreated = any("untreated" in condition for condition in lowered)
    if has_treated and has_untreated:
        return "treated_vs_untreated"
    return "disease_vs_healthy"
|
| |
|
| |
|
def build_experiment_action(
    action_type: ActionType,
    discovered_markers: Sequence[str],
    candidate_mechanisms: Sequence[str],
    conditions: Sequence[str],
) -> ExperimentAction:
    """Construct a sensible default ExperimentAction for *action_type*.

    Used by the scripted collection policies: each branch fills a reasonable
    method/parameters/justification for that step, falling back to generic
    placeholder values (e.g. "STAT3", "SPP1") when no evidence has been
    discovered yet. Unlisted action types keep the generic defaults below.
    """
    method = None
    parameters: Dict[str, object] = {}
    justification = f"Advance the experiment with {action_type.value}."

    # Wet-lab / sample handling steps.
    if action_type == ActionType.COLLECT_SAMPLE:
        parameters = {"n_samples": 6}
        justification = "Collect enough samples to start the experiment."
    elif action_type == ActionType.SELECT_COHORT:
        parameters = {
            "comparison": default_comparison_name(conditions),
            "conditions": list(conditions[:2]) or ["disease", "healthy"],
        }
        justification = "Define the cohort split before committing to downstream analysis."
    elif action_type == ActionType.PREPARE_LIBRARY:
        method = "10x_chromium"
        justification = "Prepare a single-cell library for sequencing."
    elif action_type == ActionType.CULTURE_CELLS:
        method = "organoid_culture"
        parameters = {"duration_days": 7}
        justification = "Expand viable cells before a perturbation or profiling step."
    elif action_type == ActionType.PERTURB_GENE:
        method = "CRISPRi"
        parameters = {"target_gene": candidate_mechanisms[0] if candidate_mechanisms else "STAT3"}
        justification = "Test whether a candidate regulator causally shifts cell state."
    elif action_type == ActionType.PERTURB_COMPOUND:
        method = "small_molecule_screen"
        parameters = {"compound": candidate_mechanisms[0] if candidate_mechanisms else "TGFb_inhibitor"}
        justification = "Probe the pathway hypothesis with a targeted compound perturbation."
    elif action_type == ActionType.SEQUENCE_CELLS:
        method = "NovaSeq"
        justification = "Generate reads for downstream single-cell analysis."
    # Computational analysis steps.
    elif action_type == ActionType.RUN_QC:
        method = "scanpy.pp.calculate_qc_metrics"
        justification = "Measure technical quality before filtering."
    elif action_type == ActionType.FILTER_DATA:
        method = "scanpy.pp.filter_cells"
        justification = "Remove low-quality cells and technical artifacts."
    elif action_type == ActionType.NORMALIZE_DATA:
        method = "scanpy.pp.normalize_total"
        justification = "Normalize counts for comparable expression profiles."
    elif action_type == ActionType.INTEGRATE_BATCHES:
        method = "scanorama.integrate"
        justification = "Correct batch effects before comparing cellular programs."
    elif action_type == ActionType.CLUSTER_CELLS:
        method = "scanpy.tl.leiden"
        justification = "Resolve cell states before interpretation."
    elif action_type == ActionType.DIFFERENTIAL_EXPRESSION:
        method = "scanpy.tl.rank_genes_groups"
        parameters = {"comparison": default_comparison_name(conditions)}
        justification = "Identify genes associated with the phenotype of interest."
    elif action_type == ActionType.TRAJECTORY_ANALYSIS:
        method = "scanpy.tl.dpt"
        justification = "Recover pseudotime and lineage structure."
    elif action_type == ActionType.PATHWAY_ENRICHMENT:
        method = "gseapy.prerank"
        justification = "Translate gene-level changes into pathway programs."
    elif action_type == ActionType.MARKER_SELECTION:
        method = "scanpy.tl.rank_genes_groups"
        justification = "Nominate marker genes for validation."
    elif action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
        method = "pySCENIC"
        justification = "Infer upstream regulators behind the observed state changes."
    # Validation / planning / synthesis steps.
    elif action_type == ActionType.VALIDATE_MARKER:
        method = "qPCR"
        parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
        justification = "Validate the strongest discovered marker."
    elif action_type == ActionType.DESIGN_FOLLOWUP:
        method = "followup_plan"
        parameters = {"priority_hypothesis": candidate_mechanisms[0] if candidate_mechanisms else "fibrotic_activation"}
        justification = "Propose the next experiment to disambiguate remaining uncertainty."
    elif action_type == ActionType.REQUEST_SUBAGENT_REVIEW:
        method = "peer_review"
        parameters = {"focus": "experimental_design"}
        justification = "Request a review of the current self-driving lab plan."
    elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
        # Build a single claim from whatever evidence exists so far.
        top = list(discovered_markers[:5]) if discovered_markers else []
        parameters = {
            "claims": [{
                "top_markers": top,
                "causal_mechanisms": list(candidate_mechanisms[:5]),
                "predicted_pathways": {
                    mechanism: 0.6
                    for mechanism in list(candidate_mechanisms[:3])
                },
                "confidence": 0.6,
                "claim_type": "causal" if candidate_mechanisms else "correlational",
                "claim": f"Synthesis for {default_comparison_name(conditions)}.",
            }],
        }
        justification = "Summarize the current evidence into a conclusion."

    return ExperimentAction(
        action_type=action_type,
        method=method,
        parameters=parameters,
        justification=justification,
        confidence=0.75,
    )
|
| |
|
| |
|
def selected_scenarios(requested: Optional[Sequence[str]]) -> List[str]:
    """Resolve scenario names, defaulting to every curated + procedural one.

    Raises ValueError when any requested name is not in the catalogue.
    """
    # Imported lazily so the procedural generator loads only when needed.
    from server.tasks.procedural_generator import generate_procedural_scenarios

    catalogue = list(SCENARIO_LIBRARY) + generate_procedural_scenarios(n=20, seed=42)
    available = [scenario.name for scenario in catalogue]
    if not requested:
        return available
    unknown = sorted(set(requested) - set(available))
    if unknown:
        raise ValueError(f"Unknown scenarios requested: {', '.join(unknown)}")
    return list(requested)
|
| |
|
| |
|
def action_completion_json(action: ExperimentAction) -> str:
    """Serialise an action into the JSON completion format used for training."""
    return json.dumps({
        "action_type": action.action_type.value,
        "method": action.method,
        "parameters": action.parameters,
        "justification": action.justification,
        "confidence": action.confidence,
    })
|
| |
|
| |
|
def build_prompt_examples(
    *,
    dataset_episodes: int,
    rollout_steps: int,
    collection_policy: str,
    scenario_names: Sequence[str],
    seed: int,
    domain_randomise: bool,
) -> List[Dict[str, str]]:
    """Roll out scripted episodes and collect one prompt row per step.

    Each row carries the prompt text plus the metadata the reward function
    needs to replay the same state later: scenario name, JSON-serialized
    action history, the environment's latent RNG seed, a reference action,
    and the problem statement.
    """
    rng = random.Random(seed)
    examples: List[Dict[str, str]] = []
    scenario_cycle = list(scenario_names)
    rng.shuffle(scenario_cycle)

    for episode_idx in range(dataset_episodes):
        # Round-robin over the shuffled scenario list.
        scenario_name = scenario_cycle[episode_idx % len(scenario_cycle)]
        env = BioExperimentEnvironment(
            scenario_name=scenario_name,
            domain_randomise=domain_randomise,
        )
        obs = env.reset()
        history_actions: List[ExperimentAction] = []

        for step_idx in range(rollout_steps):
            if obs.done:
                break

            # Scripted policy chooses the next action; the record stores it
            # as the "reference_action" completion for this prompt.
            next_action = build_experiment_action(
                action_type=pick_action(
                    collection_policy,
                    step_idx,
                    [action.action_type for action in history_actions],
                ),
                discovered_markers=obs.discovered_markers,
                candidate_mechanisms=obs.candidate_mechanisms,
                conditions=obs.task.conditions,
            )
            examples.append({
                "prompt": build_training_prompt(obs),
                "scenario_name": scenario_name,
                "history_actions": json.dumps(
                    [action.model_dump() for action in history_actions]
                ),
                # NOTE(review): reaches into the environment's private
                # `_latent` attribute to capture the seed — confirm no public
                # accessor exists.
                "rng_seed": str(env._latent.rng_seed),
                "reference_action": action_completion_json(next_action),
                "problem_statement": obs.task.problem_statement,
            })

            history_actions.append(next_action)
            obs = env.step(next_action)

    return examples
|
| |
|
| |
|
def completion_to_text(completion: Any) -> str:
    """Extract the final text content from a TRL completion payload.

    Accepts a plain string, a single chat message dict, or a list of chat
    messages (searched newest-first); anything else is stringified.
    """
    if isinstance(completion, str):
        return completion.strip()
    if isinstance(completion, dict):
        return content_to_text(completion.get("content", ""))
    if isinstance(completion, list):
        # Walk backwards to find the last non-empty message.
        for entry in reversed(completion):
            if isinstance(entry, dict) and "content" in entry:
                extracted = content_to_text(entry["content"])
                if extracted:
                    return extracted
            if isinstance(entry, str) and entry.strip():
                return entry.strip()
    return str(completion).strip()
|
| |
|
| |
|
def content_to_text(content: Any) -> str:
    """Flatten chat-message content (a string or a list of parts) into text."""
    if isinstance(content, str):
        return content.strip()
    if not isinstance(content, list):
        return str(content).strip()
    pieces: List[str] = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict):
            # Prefer the "text" field; fall back to a string "content" field.
            text = part.get("text")
            if isinstance(text, str):
                pieces.append(text)
            else:
                inner = part.get("content")
                if isinstance(inner, str):
                    pieces.append(inner)
    return "".join(pieces).strip()
|
| |
|
| |
|
| | def _repair_truncated_json(text: str) -> Optional[str]:
|
| | """Try to repair JSON truncated mid-value (common with small LLMs)."""
|
| | s = text.strip()
|
| | if not s.startswith("{"):
|
| | return None
|
| |
|
| | s = re.sub(r',\s*"[^"\n]*$', '', s)
|
| | s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s)
|
| |
|
| | in_string = False
|
| | escape = False
|
| | for ch in s:
|
| | if escape:
|
| | escape = False
|
| | continue
|
| | if ch == "\\":
|
| | escape = True
|
| | continue
|
| | if ch == '"':
|
| | in_string = not in_string
|
| |
|
| | if in_string:
|
| | s += '"'
|
| |
|
| | open_braces = s.count("{") - s.count("}")
|
| | open_brackets = s.count("[") - s.count("]")
|
| | s += "]" * max(0, open_brackets)
|
| | s += "}" * max(0, open_braces)
|
| |
|
| | try:
|
| | obj = json.loads(s)
|
| | if isinstance(obj, dict):
|
| | return s
|
| | except json.JSONDecodeError:
|
| | pass
|
| |
|
| | s = re.sub(r',\s*([}\]])', r'\1', s)
|
| | try:
|
| | obj = json.loads(s)
|
| | if isinstance(obj, dict):
|
| | return s
|
| | except json.JSONDecodeError:
|
| | pass
|
| | return None
|
| |
|
| |
|
def _normalize_jsonish_text(text: str) -> str:
    """Normalize common near-JSON artifacts emitted by small local models."""
    cleaned = _strip_js_comments(text)
    # Python literals -> JSON literals, only in value position (after ": ").
    cleaned = re.sub(r'(?<=:\s)\bNone\b', 'null', cleaned)
    cleaned = re.sub(r'(?<=:\s)\bTrue\b', 'true', cleaned)
    cleaned = re.sub(r'(?<=:\s)\bFalse\b', 'false', cleaned)
    # Repair keys that swallowed their colon: '"key:",' -> '"key": "",'.
    cleaned = re.sub(r'"([^"\n]+?):"\s*,', r'"\1": "",', cleaned)
    return cleaned
|
| |
|
| |
|
| | def _strip_js_comments(text: str) -> str:
|
| | """Remove // and /* */ comments that small LLMs inject into JSON."""
|
| | text = re.sub(r'//[^\n]*', '', text)
|
| | text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
|
| | return text
|
| |
|
| |
|
def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Find and parse the first well-formed JSON object inside *text*.

    Candidates are: the whole (normalized, de-fenced) text, every
    balanced ``{...}`` span found by brace counting, and a repaired
    version of the text from the first ``{`` onward. They are tried
    longest-first; returns None when nothing parses to a dict.
    """
    stripped = _normalize_jsonish_text(text).strip()
    # Unwrap a ``` fenced code block when the whole payload is fenced.
    fence_prefix = "```"
    if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
        lines = stripped.splitlines()
        if len(lines) >= 3:
            stripped = "\n".join(lines[1:-1]).strip()

    # Collect every balanced-brace candidate span.
    # NOTE(review): the depth counter does not skip braces inside string
    # literals; the repair fallback below usually compensates — confirm
    # this is acceptable for the expected model outputs.
    candidates: List[str] = [stripped]
    start = stripped.find("{")
    while start != -1:
        depth = 0
        for idx in range(start, len(stripped)):
            char = stripped[idx]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    candidates.append(stripped[start:idx + 1])
                    break
        start = stripped.find("{", start + 1)

    # Also try repairing a truncated object starting at the first brace.
    first_brace = stripped.find("{")
    if first_brace != -1:
        repaired = _repair_truncated_json(stripped[first_brace:])
        if repaired is not None:
            candidates.append(repaired)

    # Longest first, so nested objects do not win over the enclosing one.
    candidates.sort(key=len, reverse=True)

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    return None
|
| |
|
| |
|
def normalize_optional_string(value: Any) -> Optional[str]:
    """Coerce a payload value into a clean string, or None if empty/absent."""
    # bool is an int subclass, so it must be rejected before the numeric branch.
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, str):
        stripped = value.strip()
        return stripped if stripped else None
    if isinstance(value, (int, float)):
        return str(value)
    # Anything structured gets a short single-line preview.
    return compact_preview(value, 80)
|
| |
|
| |
|
def normalize_action_type(raw_action_type: Any) -> Optional[str]:
    """Map a model-emitted action name onto a canonical ActionType value.

    Resolution order: exact/alias match on the lowercased string, the same
    after squashing punctuation to underscores, then keyword heuristics.
    Returns None when nothing matches or the input is not a string.
    """
    if not isinstance(raw_action_type, str):
        return None

    candidate = raw_action_type.strip().lower()
    if candidate in ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]

    # Normalise punctuation/whitespace runs to underscores and retry.
    candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
    if candidate in ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]

    # Keyword heuristics, checked in order; specific fragment sets must come
    # before their generic prefixes. In particular ("validat", "marker") is
    # placed BEFORE ("marker",): any candidate matching the former also
    # contains "marker", so the original ordering made the validate-marker
    # rule unreachable and misrouted such inputs to marker_selection.
    heuristics = [
        (("collect", "sample"), ActionType.COLLECT_SAMPLE.value),
        (("cohort",), ActionType.SELECT_COHORT.value),
        (("library",), ActionType.PREPARE_LIBRARY.value),
        (("culture",), ActionType.CULTURE_CELLS.value),
        (("perturb", "gene"), ActionType.PERTURB_GENE.value),
        (("perturb", "compound"), ActionType.PERTURB_COMPOUND.value),
        (("sequence",), ActionType.SEQUENCE_CELLS.value),
        (("qc",), ActionType.RUN_QC.value),
        (("quality", "control"), ActionType.RUN_QC.value),
        (("filter",), ActionType.FILTER_DATA.value),
        (("normal",), ActionType.NORMALIZE_DATA.value),
        (("integrat", "batch"), ActionType.INTEGRATE_BATCHES.value),
        (("cluster",), ActionType.CLUSTER_CELLS.value),
        (("differential", "expression"), ActionType.DIFFERENTIAL_EXPRESSION.value),
        (("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
        (("trajectory",), ActionType.TRAJECTORY_ANALYSIS.value),
        (("network",), ActionType.REGULATORY_NETWORK_INFERENCE.value),
        (("validat", "marker"), ActionType.VALIDATE_MARKER.value),
        (("marker",), ActionType.MARKER_SELECTION.value),
        (("followup",), ActionType.DESIGN_FOLLOWUP.value),
        (("review",), ActionType.REQUEST_SUBAGENT_REVIEW.value),
        (("conclusion",), ActionType.SYNTHESIZE_CONCLUSION.value),
    ]
    for fragments, normalized in heuristics:
        if all(fragment in candidate for fragment in fragments):
            return normalized
    return None
|
| |
|
| |
|
def _unique_nonempty(items: Sequence[Any], limit: int = 5) -> List[str]:
    """Deduplicate (case-insensitively) and keep up to *limit* clean strings."""
    kept: List[str] = []
    seen_keys: set[str] = set()
    for raw in items:
        cleaned = normalize_optional_string(raw)
        if not cleaned:
            continue
        dedupe_key = cleaned.upper()
        if dedupe_key not in seen_keys:
            seen_keys.add(dedupe_key)
            kept.append(cleaned)
            if len(kept) >= limit:
                break
    return kept
|
| |
|
| |
|
def _infer_conclusion_evidence(
    obs: ExperimentObservation,
) -> Tuple[List[str], List[str], Dict[str, float]]:
    """Harvest markers, mechanisms, and pathway scores from the observation.

    Prefers the observation's explicit evidence fields; any slot still empty
    is filled by walking the outputs newest-to-oldest, using the first
    successful output that provides data for that slot.
    """
    top_markers = _unique_nonempty(list(obs.discovered_markers), limit=5)
    causal_mechanisms = _unique_nonempty(list(obs.candidate_mechanisms), limit=5)
    predicted_pathways: Dict[str, float] = {}

    for output in reversed(obs.all_outputs):
        if not output.success:
            continue
        data = output.data or {}

        if not top_markers:
            markers = data.get("markers", [])
            if isinstance(markers, list):
                top_markers = _unique_nonempty(markers, limit=5)
        if not causal_mechanisms:
            regulators = data.get("top_regulators", [])
            if isinstance(regulators, list):
                causal_mechanisms = _unique_nonempty(regulators, limit=5)
        if not predicted_pathways:
            # Keep at most five pathway scores from the first output that has any.
            for item in data.get("top_pathways", []):
                if not isinstance(item, dict):
                    continue
                pathway = normalize_optional_string(item.get("pathway"))
                score = item.get("score")
                if pathway and isinstance(score, (int, float)):
                    predicted_pathways[pathway] = float(score)
                if len(predicted_pathways) >= 5:
                    break
        # Stop as soon as every evidence slot is filled.
        if top_markers and causal_mechanisms and predicted_pathways:
            break

    return top_markers, causal_mechanisms, predicted_pathways
|
| |
|
| |
|
def ensure_conclusion_claims(
    obs: ExperimentObservation,
    action: ExperimentAction,
) -> ExperimentAction:
    """Guarantee a synthesize_conclusion action carries a non-empty claims list.

    Non-conclusion actions pass through untouched. Model-supplied claims are
    kept (filtered to dict-shaped entries); otherwise a single claim is
    synthesized from the observation's evidence. The action is never mutated
    in place — updates go through ``model_copy``.
    """
    if action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
        return action

    parameters = dict(action.parameters or {})
    raw_claims = parameters.get("claims")
    if isinstance(raw_claims, list):
        # Keep only dict-shaped claims supplied by the model.
        normalized_claims = [claim for claim in raw_claims if isinstance(claim, dict)]
        if normalized_claims:
            parameters["claims"] = normalized_claims
            # Copy only when the filtering actually changed something.
            if parameters != action.parameters:
                return action.model_copy(update={"parameters": parameters})
            return action

    # No usable claims: build one from the discovered evidence.
    top_markers, causal_mechanisms, predicted_pathways = _infer_conclusion_evidence(obs)
    claim_type = "causal" if causal_mechanisms else "correlational"
    conditions = " vs ".join(obs.task.conditions[:2]) if obs.task.conditions else "the task conditions"
    claim = action.justification or f"Final synthesis for {conditions}."

    parameters["claims"] = [{
        "top_markers": top_markers,
        "causal_mechanisms": causal_mechanisms,
        "predicted_pathways": predicted_pathways,
        "confidence": action.confidence,
        "claim_type": claim_type,
        "claim": claim,
    }]
    if not action.justification:
        action = action.model_copy(update={"justification": claim})
    return action.model_copy(update={"parameters": parameters})
|
| |
|
| |
|
def parse_action_completion(text: str) -> Optional[ExperimentAction]:
    """Parse a model completion into an ExperimentAction, or None on failure.

    First tries full JSON extraction; when that fails, falls back to regex
    field scraping so near-JSON output still yields a usable action (with
    empty parameters). ``action_type`` is mandatory in both paths.
    """
    payload = extract_json_object(text)
    if payload is not None:
        action_type = normalize_action_type(get_payload_value(payload, "action_type"))
        if action_type is None:
            return None

        parameters = get_payload_value(payload, "parameters", "params") or {}
        if not isinstance(parameters, dict):
            parameters = {}

        # Default confidence to 0.5 when missing or non-numeric.
        confidence = get_payload_value(payload, "confidence")
        if confidence is None:
            confidence = 0.5
        try:
            confidence = float(confidence)
        except (TypeError, ValueError):
            confidence = 0.5

        justification = get_payload_value(
            payload, "justification", "reasoning", "rationale", "reason"
        )
        # Structured justifications get flattened to a short preview string.
        if justification is not None and not isinstance(justification, str):
            justification = compact_preview(justification, 200)

        return ExperimentAction(
            action_type=ActionType(action_type),
            method=normalize_optional_string(get_payload_value(payload, "method")),
            parameters=parameters,
            justification=justification,
            confidence=min(1.0, max(0.0, confidence)),
        )

    # Regex fallback: action_type is required, everything else is optional.
    action_match = re.search(
        r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
        text,
        re.IGNORECASE,
    )
    if not action_match:
        return None

    action_type = normalize_action_type(action_match.group(1))
    if action_type is None:
        return None

    # Method may be a quoted string, a bare literal, or a number.
    method_match = re.search(
        r'["\']method["\']\s*:\s*("((?:[^"\\]|\\.)*)"|null|none|true|false|-?\d+(?:\.\d+)?)',
        text,
        re.IGNORECASE,
    )
    confidence_match = re.search(
        r'["\']confidence["\']\s*:\s*([0-9]*\.?[0-9]+)',
        text,
        re.IGNORECASE,
    )
    # Matches a possibly-unterminated quoted justification value.
    justification_match = re.search(
        r'["\'](?:justif\w*|reasoning|rationale|reason)["\']\s*:\s*"((?:[^"\\]|\\.)*)',
        text,
        re.DOTALL | re.IGNORECASE,
    )

    confidence = 0.5
    if confidence_match:
        try:
            confidence = float(confidence_match.group(1))
        except ValueError:
            confidence = 0.5

    # Decode JSON escape sequences in the justification when possible.
    justification = None
    if justification_match:
        try:
            justification = json.loads(f'"{justification_match.group(1)}"')
        except json.JSONDecodeError:
            justification = justification_match.group(1)

    method = None
    if method_match:
        raw_method = method_match.group(1)
        if raw_method.startswith('"') and raw_method.endswith('"'):
            try:
                method = json.loads(raw_method)
            except json.JSONDecodeError:
                method = raw_method.strip('"')
        elif raw_method.lower() not in {"null", "none", "true", "false"}:
            # Bare literal that is not a null/boolean keyword.
            method = raw_method

    return ExperimentAction(
        action_type=ActionType(action_type),
        method=normalize_optional_string(method),
        parameters={},
        justification=justification,
        confidence=min(1.0, max(0.0, confidence)),
    )
|
| |
|
| |
|
def decode_history_actions(history_actions: Optional[str]) -> List[ExperimentAction]:
    """Rehydrate a JSON-encoded action-history column into model objects.

    Non-dict entries in the decoded list are silently skipped.
    """
    if not history_actions:
        return []
    decoded = json.loads(history_actions)
    actions: List[ExperimentAction] = []
    for payload in decoded:
        if isinstance(payload, dict):
            actions.append(ExperimentAction(**payload))
    return actions
|
| |
|
| |
|
def normalise_column(values: Any, length: int) -> List[Any]:
    """Broadcast, truncate, or pad a dataset column to exactly *length* entries."""
    if values is None:
        return [None] * length
    if not isinstance(values, list):
        # Scalar value: repeat it for every row.
        return [values] * length
    if len(values) == length:
        return values
    if len(values) == 1:
        # Single-element column: broadcast it.
        return values * length
    # Truncate long columns; right-pad short ones with None.
    return values[:length] + [None] * max(0, length - len(values))
|
| |
|
| |
|
class OpenEnvReward:
    """Reward function compatible with TRL GRPOTrainer.

    Each completion is parsed into an ``ExperimentAction`` and scored by
    replaying its episode in an OpenEnv environment — either in-process
    ("local") or over HTTP ("remote"). Unparseable completions receive
    ``invalid_action_penalty``; any backend failure receives
    ``environment_error_penalty``.
    """

    def __init__(
        self,
        *,
        reward_backend: str,
        base_url: str,
        invalid_action_penalty: float = INVALID_ACTION_PENALTY,
        environment_error_penalty: float = ENVIRONMENT_ERROR_PENALTY,
        domain_randomise: bool = False,
    ) -> None:
        # TRL reports reward statistics under the callable's __name__.
        self.__name__ = "openenv_reward"
        self.reward_backend = reward_backend
        self.base_url = base_url
        self.invalid_action_penalty = invalid_action_penalty
        self.environment_error_penalty = environment_error_penalty
        self.domain_randomise = domain_randomise

    def __call__(
        self,
        completions: List[Any],
        scenario_name: Optional[List[str]] = None,
        history_actions: Optional[List[str]] = None,
        rng_seed: Optional[List[str]] = None,
        **_: Any,
    ) -> List[float]:
        """Score a batch of completions; extra dataset columns are ignored."""
        batch = len(completions)
        scenario_column = normalise_column(scenario_name, batch)
        history_column = normalise_column(history_actions, batch)
        seed_column = normalise_column(rng_seed, batch)

        scores: List[float] = []
        rows = zip(completions, scenario_column, history_column, seed_column)
        for completion, scenario, history, seed in rows:
            parsed = parse_action_completion(completion_to_text(completion))
            if parsed is None:
                scores.append(self.invalid_action_penalty)
                continue

            try:
                if self.reward_backend == "remote":
                    score = self._score_remote(parsed, scenario, history)
                else:
                    score = self._score_local(parsed, scenario, history, seed)
            except Exception:
                # One bad rollout must not abort the whole training batch,
                # so backend failures map to a fixed penalty.
                score = self.environment_error_penalty

            scores.append(float(score))

        return scores

    def _score_local(
        self,
        action: ExperimentAction,
        scenario_name: Optional[str],
        history_actions: Optional[str],
        rng_seed: Optional[str] = None,
    ) -> float:
        """Replay the episode in-process, then score *action* at its end."""
        env = BioExperimentEnvironment(
            scenario_name=scenario_name,
            domain_randomise=self.domain_randomise,
        )
        obs = env.reset(seed=int(rng_seed) if rng_seed else None)
        for replayed in decode_history_actions(history_actions):
            obs = env.step(replayed)
            if obs.done:
                # Episode already terminated during replay; use its reward.
                return float(obs.reward)
        obs = env.step(ensure_conclusion_claims(obs, action))
        return float(obs.reward)

    def _score_remote(
        self,
        action: ExperimentAction,
        scenario_name: Optional[str],
        history_actions: Optional[str],
    ) -> float:
        """Replay the episode against the HTTP backend and score *action*.

        NOTE(review): *scenario_name* is accepted but unused — the remote
        ``reset()`` takes no scenario argument here; confirm against client.
        """
        with BioExperimentEnv(base_url=self.base_url) as env:
            result = env.reset()
            for replayed in decode_history_actions(history_actions):
                result = env.step(replayed)
                if result.done:
                    return float(result.reward or 0.0)
            result = env.step(action)
            if result.reward is not None:
                return float(result.reward)
            return float(result.observation.reward)
|
| |
|
| |
|
def is_numeric_log_value(value: Any) -> bool:
    """Return True for real-number log values, excluding booleans."""
    if isinstance(value, bool):
        # bool is a subclass of int, but True/False are flags, not metrics.
        return False
    return isinstance(value, Real)
|
| |
|
| |
|
def available_numeric_log_keys(log_history: Sequence[Dict[str, Any]]) -> List[str]:
    """Collect the sorted set of numeric metric keys across all log entries.

    Non-dict entries are ignored, as are the "step" key, booleans, and any
    non-numeric values.
    """
    seen = set()
    for entry in log_history:
        if not isinstance(entry, dict):
            continue
        for key, value in entry.items():
            if key == "step":
                continue
            if isinstance(value, Real) and not isinstance(value, bool):
                seen.add(key)
    return sorted(seen)
|
| |
|
| |
|
def extract_log_series(
    log_history: Sequence[Dict[str, Any]],
    key: Optional[str],
) -> List[Tuple[float, float]]:
    """Extract (step, value) pairs for *key* from a trainer log history.

    Entries missing the key or holding a non-numeric value are skipped.
    When an entry lacks a numeric "step", a synthetic 1-based counter is
    used instead so the series stays plottable.
    """
    if not key:
        return []

    points: List[Tuple[float, float]] = []
    fallback_step = 0
    for entry in log_history:
        if not isinstance(entry, dict):
            continue
        if key not in entry:
            continue
        value = entry.get(key)
        if not (isinstance(value, Real) and not isinstance(value, bool)):
            continue

        raw_step = entry.get("step")
        if isinstance(raw_step, Real) and not isinstance(raw_step, bool):
            step_value = float(raw_step)
        else:
            fallback_step += 1
            step_value = float(fallback_step)

        points.append((step_value, float(value)))

    return points
|
| |
|
| |
|
def select_reward_key(log_history: Sequence[Dict[str, Any]]) -> Optional[str]:
    """Pick the logged numeric key that best represents the training reward.

    Considers only numeric keys containing "reward". Well-known names are
    preferred; otherwise the shortest, least-nested reward-like key wins.
    Returns None when no reward-like metric was logged.
    """
    numeric_keys = available_numeric_log_keys(log_history)
    reward_keys = [key for key in numeric_keys if "reward" in key.lower()]
    if not reward_keys:
        return None

    # Fix: the list previously only contained "rewards/open_env_reward",
    # which can never match — the reward callable registers itself with
    # __name__ = "openenv_reward", and TRL logs per-function rewards under
    # "rewards/{func.__name__}" (with a "/mean" suffix in newer versions).
    # The old entry is kept for backward compatibility.
    preferred = [
        "reward",
        "mean_reward",
        "reward_mean",
        "rewards/openenv_reward",
        "rewards/openenv_reward/mean",
        "rewards/open_env_reward",
    ]
    lowered = {key.lower(): key for key in reward_keys}
    for key in preferred:
        if key in lowered:
            return lowered[key]

    # Fallback heuristic: prefer top-level keys, then shorter, then
    # alphabetical — a stable, deterministic choice.
    reward_keys.sort(key=lambda key: ("/" in key, len(key), key))
    return reward_keys[0]
|
| |
|
| |
|
def select_metric_key(
    log_history: Sequence[Dict[str, Any]],
    *,
    reward_key: Optional[str],
    requested_key: Optional[str] = None,
) -> Optional[str]:
    """Choose a secondary numeric metric to plot alongside loss and reward.

    An explicitly requested key wins after validation (raising ``ValueError``
    if it was never logged). Otherwise a list of preferred diagnostics is
    tried, then any remaining non-excluded, non-reward metric, and finally
    learning_rate/epoch as a last resort.
    """
    # Gather every numeric metric key (excluding "step" and booleans).
    numeric_keys = sorted({
        key
        for entry in log_history
        if isinstance(entry, dict)
        for key, value in entry.items()
        if key != "step" and isinstance(value, Real) and not isinstance(value, bool)
    })

    if requested_key:
        if requested_key not in numeric_keys:
            available = ", ".join(numeric_keys) or "none"
            raise ValueError(
                f"Requested plot metric '{requested_key}' was not logged. "
                f"Available numeric keys: {available}"
            )
        return requested_key

    # Bookkeeping values that make poor standalone plots.
    excluded = {
        "epoch",
        "loss",
        "learning_rate",
        "step",
        "total_flos",
        "train_loss",
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
    }
    if reward_key:
        excluded.add(reward_key)

    numeric_set = set(numeric_keys)
    for candidate in (
        "kl",
        "objective/kl",
        "completion_length",
        "mean_completion_length",
        "grad_norm",
        "entropy",
        "accuracy",
        "learning_rate",
        "epoch",
    ):
        if candidate in numeric_set and candidate not in excluded:
            return candidate

    leftovers = [
        key for key in numeric_keys
        if key not in excluded and "reward" not in key.lower()
    ]
    if leftovers:
        return leftovers[0]

    # Nothing better was logged; settle for bookkeeping series if present.
    for fallback in ("learning_rate", "epoch"):
        if fallback in numeric_set:
            return fallback

    return None
|
| |
|
| |
|
def save_plot(
    path: Path,
    *,
    series: Sequence[Tuple[float, float]],
    title: str,
    ylabel: str,
) -> None:
    """Render one (step, value) series to *path* as a PNG.

    An empty series produces a figure with a placeholder message instead
    of a line plot, so downstream tooling always finds the file.
    """
    import matplotlib

    # Headless backend: training hosts typically have no display server.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(8, 4.5))
    if not series:
        ax.text(
            0.5,
            0.5,
            "No logged data available",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
    else:
        steps, values = zip(*series)
        ax.plot(steps, values, marker="o", linewidth=1.8)
    ax.set_title(title)
    ax.set_xlabel("Step")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
|
| |
|
| |
|
def save_training_plots(
    log_history: Sequence[Dict[str, Any]],
    output_dir: str | Path,
    metric_key: Optional[str] = None,
) -> Dict[str, str]:
    """Render loss/reward/metric plots plus a combined dashboard.

    Writes four PNGs (loss, reward, metric, three-panel dashboard) and a
    JSON manifest into *output_dir*, creating it if needed.

    Args:
        log_history: ``trainer.state.log_history``-style list of dicts.
        output_dir: Directory to write all artifacts into.
        metric_key: Optional explicit key for the third plot; when omitted,
            a heuristic picks a diagnostic metric. Raises ``ValueError``
            (via select_metric_key) if the requested key was never logged.

    Returns:
        Mapping of plot names ("loss", "reward", "metric", "dashboard")
        to their file paths.
    """
    import matplotlib

    # Headless backend: no display server is assumed.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Decide which logged keys represent the reward and the extra metric.
    reward_key = select_reward_key(log_history)
    selected_metric_key = select_metric_key(
        log_history,
        reward_key=reward_key,
        requested_key=metric_key,
    )

    loss_series = extract_log_series(log_history, "loss")
    reward_series = extract_log_series(log_history, reward_key)
    metric_series = extract_log_series(log_history, selected_metric_key)

    loss_path = output_path / "training_loss.png"
    reward_path = output_path / "training_reward.png"
    metric_path = output_path / "training_metric.png"
    dashboard_path = output_path / "training_dashboard.png"
    manifest_path = output_path / "training_plot_manifest.json"

    # One standalone PNG per series.
    save_plot(loss_path, series=loss_series, title="Training Loss", ylabel="Loss")
    save_plot(
        reward_path,
        series=reward_series,
        title=f"Training Reward ({reward_key or 'not logged'})",
        ylabel="Reward",
    )
    save_plot(
        metric_path,
        series=metric_series,
        title=f"Training Metric ({selected_metric_key or 'not logged'})",
        ylabel=selected_metric_key or "Metric",
    )

    # Combined three-panel dashboard showing the same data side by side.
    fig, axes = plt.subplots(3, 1, figsize=(10, 12))
    plot_specs = [
        (axes[0], loss_series, "Training Loss", "Loss"),
        (axes[1], reward_series, f"Training Reward ({reward_key or 'not logged'})", "Reward"),
        (
            axes[2],
            metric_series,
            f"Training Metric ({selected_metric_key or 'not logged'})",
            selected_metric_key or "Metric",
        ),
    ]
    for axis, series, title, ylabel in plot_specs:
        if series:
            x_values, y_values = zip(*series)
            axis.plot(x_values, y_values, marker="o", linewidth=1.8)
        else:
            # Placeholder text keeps the panel present even with no data.
            axis.text(
                0.5,
                0.5,
                "No logged data available",
                ha="center",
                va="center",
                transform=axis.transAxes,
            )
        axis.set_title(title)
        axis.set_xlabel("Step")
        axis.set_ylabel(ylabel)
        axis.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(dashboard_path, dpi=150)
    plt.close(fig)

    # Manifest records the chosen keys and output paths for tooling.
    manifest = {
        "available_numeric_keys": available_numeric_log_keys(log_history),
        "reward_key": reward_key,
        "metric_key": selected_metric_key,
        "plots": {
            "loss": str(loss_path),
            "reward": str(reward_path),
            "metric": str(metric_path),
            "dashboard": str(dashboard_path),
        },
    }
    manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    return manifest["plots"]
|
| |
|
| |
|
def run_dry_run_preview(
    examples: Sequence[Dict[str, str]],
    reward_fn: OpenEnvReward,
    output_dir: str,
) -> None:
    """Print a one-example preview of the generated prompts and rewards.

    Scores the first example's reference action through *reward_fn* and
    prints a short summary plus the full sample prompt.

    Raises:
        ValueError: when no prompt examples were generated.
    """
    if not examples:
        raise ValueError("No training prompts were generated for the dry run.")

    first = examples[0]
    # Wrap the reference action in the chat-completion shape the reward
    # callable expects for a single-item batch.
    reference_completion = [{"role": "assistant", "content": first["reference_action"]}]
    reference_reward = reward_fn(
        completions=[reference_completion],
        scenario_name=[first["scenario_name"]],
        history_actions=[first["history_actions"]],
    )[0]

    print(f"Built {len(examples)} prompt states.")
    print(f"Output directory: {Path(output_dir)}")
    print(f"Sample scenario: {first['scenario_name']}")
    print(f"Sample reward for reference action: {reference_reward:+.3f}")
    print("\nSample prompt:\n")
    print(first["prompt"])
|
| |
|
| |
|
def resolve_torch_runtime() -> Dict[str, Any]:
    """Probe torch for the device/dtype configuration to train with.

    Prefers bfloat16 on CUDA hardware that supports it, float16 on other
    CUDA hardware, and float32 on CPU.
    """
    import torch

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Older torch builds may lack is_bf16_supported; treat as False.
        supports_bf16 = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
    else:
        supports_bf16 = False

    if supports_bf16:
        dtype = torch.bfloat16
    elif cuda_ready:
        dtype = torch.float16
    else:
        dtype = torch.float32

    return {
        "use_cuda": cuda_ready,
        "device": "cuda:0" if cuda_ready else "cpu",
        "dtype": dtype,
        "bf16": supports_bf16,
        "fp16": cuda_ready and not supports_bf16,
        "device_name": torch.cuda.get_device_name(0) if cuda_ready else "cpu",
    }
|
| |
|
| |
|
def _guard_invalid_torchao_version() -> None:
    """Treat malformed torchao installs as unavailable for HF imports.

    Installs two idempotent monkey-patches:

    1. Wraps ``importlib.metadata.version`` so a torchao distribution whose
       version string fails ``packaging`` parsing raises
       ``PackageNotFoundError`` instead of leaking a bad version string.
    2. If ``transformers.utils.import_utils`` is already imported, replaces
       its ``_is_package_available`` and ``is_torchao_available`` helpers
       with variants that validate torchao's version string first, and
       refreshes the re-exported alias on ``transformers.utils``.
    """
    import functools
    import importlib.metadata as importlib_metadata
    import sys
    from packaging.version import InvalidVersion, Version

    # --- Patch 1: importlib.metadata.version (process-wide, install once) ---
    if getattr(importlib_metadata, "_openenv_torchao_guard_installed", False):
        metadata_guard_installed = True
    else:
        original_version = importlib_metadata.version

        def guarded_version(distribution_name: str) -> str:
            # Only torchao gets extra validation; all other distributions
            # pass straight through to the original implementation.
            version = original_version(distribution_name)
            if distribution_name.lower() == "torchao":
                try:
                    Version(version)
                except InvalidVersion as exc:
                    raise importlib_metadata.PackageNotFoundError(
                        f"Malformed torchao version metadata: {version!r}"
                    ) from exc
            return version

        importlib_metadata.version = guarded_version
        importlib_metadata._openenv_torchao_guard_installed = True
        metadata_guard_installed = False

    # --- Patch 2: transformers' import helpers (only if already loaded) -----
    import_utils = sys.modules.get("transformers.utils.import_utils")
    if import_utils is not None and not getattr(import_utils, "_openenv_torchao_guard_installed", False):
        original_is_package_available = import_utils._is_package_available

        def guarded_is_package_available(
            pkg_name: str,
            return_version: bool = False,
        ):
            # Non-torchao packages are delegated unchanged.
            if pkg_name != "torchao":
                return original_is_package_available(pkg_name, return_version=return_version)
            is_available, package_version = original_is_package_available(
                pkg_name,
                return_version=True,
            )
            if not is_available:
                return (False, package_version) if return_version else (False, None)
            try:
                Version(package_version)
            except InvalidVersion:
                # Unparseable version string -> report torchao as absent.
                return (False, "0") if return_version else (False, None)
            return (True, package_version) if return_version else (True, None)

        min_version = getattr(import_utils, "TORCHAO_MIN_VERSION", "0")

        @functools.lru_cache
        def guarded_is_torchao_available(min_version_override: str = min_version) -> bool:
            # Mirrors transformers' min-version gate, but never raises on a
            # malformed version string — it just reports unavailable.
            is_available, package_version = guarded_is_package_available(
                "torchao",
                return_version=True,
            )
            if not is_available:
                return False
            try:
                return Version(package_version) >= Version(min_version_override)
            except InvalidVersion:
                return False

        # Clear any result cached by the original check before swapping.
        if hasattr(import_utils.is_torchao_available, "cache_clear"):
            import_utils.is_torchao_available.cache_clear()
        import_utils._is_package_available = guarded_is_package_available
        import_utils.is_torchao_available = guarded_is_torchao_available
        import_utils._openenv_torchao_guard_installed = True

        # Keep the alias re-exported from transformers.utils in sync.
        transformers_utils = sys.modules.get("transformers.utils")
        if transformers_utils is not None:
            transformers_utils.is_torchao_available = guarded_is_torchao_available

    # NOTE(review): this final guard is a no-op — the function returns here
    # regardless; the condition changes nothing observable.
    if metadata_guard_installed and import_utils is None:
        return
|
| |
|
| |
|
def _guard_partial_vllm_install() -> None:
    """Treat partial vLLM installs as unavailable for TRL imports.

    Replaces TRL's ``is_vllm_available`` with a probe that also imports the
    vLLM distributed submodules, so an install that only half-imports is
    reported as unavailable instead of crashing later. Installed at most
    once per process; silently does nothing when TRL cannot be imported.
    """
    import functools
    import importlib

    try:
        import trl.import_utils as trl_import_utils
    except Exception:
        # TRL itself is missing or broken: nothing to patch.
        return

    # Idempotence: skip if the guard was already installed.
    if getattr(trl_import_utils, "_openenv_vllm_guard_installed", False):
        return

    def _has_usable_vllm() -> bool:
        # Probe the distributed submodules as well as the top-level package;
        # any import failure means the install is unusable here.
        try:
            importlib.import_module("vllm")
            importlib.import_module("vllm.distributed.device_communicators.pynccl")
            importlib.import_module("vllm.distributed.utils")
        except Exception:
            return False
        return True

    @functools.lru_cache
    def guarded_is_vllm_available(*args: Any, **kwargs: Any) -> bool:
        # Cached: the probe imports modules, which is expensive to repeat.
        return _has_usable_vllm()

    # Drop any result cached by the original check before replacing it.
    if hasattr(trl_import_utils.is_vllm_available, "cache_clear"):
        trl_import_utils.is_vllm_available.cache_clear()
    trl_import_utils.is_vllm_available = guarded_is_vllm_available
    trl_import_utils._openenv_vllm_guard_installed = True
|
| |
|
| |
|
def load_model_artifacts(
    model_id: str,
    *,
    trust_remote_code: bool,
):
    """Load the tokenizer and causal-LM weights for *model_id*.

    Returns a ``(tokenizer, model)`` pair with the model moved onto the
    device chosen by :func:`resolve_torch_runtime`.
    """
    _guard_invalid_torchao_version()
    from transformers import AutoModelForCausalLM, AutoTokenizer

    runtime = resolve_torch_runtime()

    print(f"Loading tokenizer for {model_id} ...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
    )
    # Reuse EOS as padding when the checkpoint ships without a pad token.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model for {model_id} ...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
        torch_dtype=runtime["dtype"],
    )
    # runtime["device"] is "cuda:0" when CUDA is available and "cpu"
    # otherwise, so a single .to() covers both placement branches.
    model = model.to(runtime["device"])
    return tokenizer, model
|
| |
|
| |
|
def build_openenv_reward(args: argparse.Namespace) -> OpenEnvReward:
    """Return the OpenEnv-compatible reward callable used by GRPO."""
    # Penalty values fall back to the module-level defaults.
    reward = OpenEnvReward(
        reward_backend=args.reward_backend,
        base_url=args.base_url,
        domain_randomise=args.domain_randomise,
    )
    return reward
|
| |
|
| |
|
def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the OpenEnv rollout states that seed GRPO prompts.

    Returns a dict with the resolved scenario names and the generated
    prompt examples.
    """
    chosen_scenarios = selected_scenarios(args.scenario_name)
    prompt_examples = build_prompt_examples(
        dataset_episodes=args.dataset_episodes,
        rollout_steps=args.rollout_steps,
        collection_policy=args.collection_policy,
        scenario_names=chosen_scenarios,
        seed=args.seed,
        domain_randomise=args.domain_randomise,
    )
    return {
        "scenario_names": chosen_scenarios,
        "examples": prompt_examples,
    }
|
| |
|
| |
|
def build_grpo_config(
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Create a GRPOConfig, dropping fields this TRL version cannot accept.

    Fields are filtered against the installed GRPOConfig's __init__
    signature so the script works across TRL releases; skipped fields are
    reported on stdout.
    """
    import inspect

    _guard_invalid_torchao_version()
    _guard_partial_vllm_install()
    from trl import GRPOConfig

    requested = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "max_completion_length": args.max_completion_length,
        "max_prompt_length": args.max_prompt_length,
        "num_train_epochs": args.num_train_epochs,
        "logging_steps": args.logging_steps,
        "save_steps": args.save_steps,
        "bf16": runtime["bf16"],
        "fp16": runtime["fp16"],
        "report_to": "none",
        "remove_unused_columns": False,
    }
    supported = set(inspect.signature(GRPOConfig.__init__).parameters)

    # Older TRL releases expose a single max_length knob instead of the
    # separate prompt/completion limits; translate when that is the case.
    legacy_max_length_only = (
        "max_length" in supported
        and "max_prompt_length" not in supported
        and "max_completion_length" not in supported
    )
    if legacy_max_length_only:
        requested["max_length"] = (
            args.max_prompt_length + args.max_completion_length
        )

    accepted = {
        name: value
        for name, value in requested.items()
        if name in supported
    }
    dropped = sorted(set(requested) - set(accepted))
    if dropped:
        print(
            "GRPOConfig compatibility: skipping unsupported fields "
            f"{', '.join(dropped)}"
        )

    return GRPOConfig(**accepted)
|
| |
|
| |
|
def build_grpo_trainer(
    *,
    model: Any,
    tokenizer: Any,
    reward_func: Any,
    train_dataset: Any,
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Assemble a GRPOTrainer wired to the OpenEnv reward callable."""
    _guard_invalid_torchao_version()
    _guard_partial_vllm_install()
    from trl import GRPOTrainer

    trainer_config = build_grpo_config(args, runtime)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=trainer_config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )
    return trainer
|
| |
|
| |
|
def generate_action_with_model(
    model: Any,
    tokenizer: Any,
    prompt_or_observation: str | ExperimentObservation,
    *,
    max_new_tokens: int = DEFAULT_COMPLETION_TOKEN_BUDGET,
    temperature: float = 0.2,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> Dict[str, Any]:
    """Sample one completion from *model* and parse it into an action.

    Accepts either a ready-made prompt string or an ExperimentObservation
    (rendered into a training prompt first). Returns a dict with the
    prompt, the raw decoded completion text, and the parsed action
    (``None`` when the completion could not be parsed).
    """
    import torch

    if isinstance(prompt_or_observation, ExperimentObservation):
        prompt = build_training_prompt(prompt_or_observation)
    else:
        prompt = str(prompt_or_observation)

    # Prefer the model's own device; fall back to the runtime probe when
    # the object exposes no .device attribute.
    target_device = getattr(model, "device", None)
    if target_device is None:
        target_device = resolve_torch_runtime()["device"]

    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}
    prompt_length = encoded["input_ids"].shape[1]

    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Inference only: gradients are never needed for sampling.
    with torch.no_grad():
        generated = model.generate(**encoded, **sampling_kwargs)

    # Strip the prompt tokens so only the fresh completion is decoded.
    completion_ids = generated[0][prompt_length:]
    response_text = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    action = parse_action_completion(response_text)
    if action is not None and isinstance(prompt_or_observation, ExperimentObservation):
        action = ensure_conclusion_claims(prompt_or_observation, action)
    return {
        "prompt": prompt,
        "response_text": response_text,
        "action": action,
    }
|
| |
|
| |
|
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    """Drive the full GRPO run: prompts, reward, trainer, plots.

    Supports three modes selected by *args*:
      * ``load_model_only``: load tokenizer/model, print diagnostics, exit.
      * ``dry_run``: build prompts and preview one reference reward, exit.
      * default: train with TRL's GRPOTrainer, save model/tokenizer and
        plots, optionally push the output folder to the HuggingFace Hub.

    Returns a dict of the artifacts produced; its keys vary by mode.
    """
    random.seed(args.seed)
    runtime = resolve_torch_runtime()

    # Mode 1: model-loading smoke test — no prompts or trainer needed.
    if args.load_model_only:
        tokenizer, model = load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
        )
        device = getattr(model, "device", "unknown")
        print(f"Model ready: {args.model_id}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"Model device: {device}")
        print(f"Runtime device name: {runtime['device_name']}")
        return {
            "args": args,
            "runtime": runtime,
            "tokenizer": tokenizer,
            "model": model,
        }

    # Roll out environment states for prompts and build the reward callable.
    prompt_data = prepare_prompt_examples(args)
    scenario_names = prompt_data["scenario_names"]
    examples = prompt_data["examples"]
    reward_fn = build_openenv_reward(args)

    # Mode 2: print a one-example preview, then exit without training.
    if args.dry_run:
        run_dry_run_preview(examples, reward_fn, args.output_dir)
        return {
            "args": args,
            "runtime": runtime,
            "scenario_names": scenario_names,
            "examples": examples,
            "reward_fn": reward_fn,
        }

    # Mode 3: full training. `datasets` is imported lazily so the two
    # light-weight modes above never require it.
    from datasets import Dataset
    train_dataset = Dataset.from_list(examples)
    tokenizer, model = load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
    )

    print(
        f"Training runtime: device={runtime['device']} "
        f"name={runtime['device_name']} "
        f"dtype={runtime['dtype']}"
    )
    print(
        "OpenEnv reward: "
        f"backend={args.reward_backend} scenarios={len(scenario_names)} "
        f"examples={len(examples)}"
    )

    trainer = build_grpo_trainer(
        model=model,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        reward_func=reward_fn,
        args=args,
        runtime=runtime,
    )
    trainer.train()
    # Persist the trained weights and the tokenizer next to each other.
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    # args.push_to_hub holds the target Hub repo id (falsy disables push).
    if args.push_to_hub:
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.push_to_hub,
            repo_type="model",
            create_pr=False,
        )
        print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")
    # Render loss/reward/metric plots from the trainer's log history.
    plot_paths = save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f" - {plot_name}: {plot_path}")

    return {
        "args": args,
        "runtime": runtime,
        "scenario_names": scenario_names,
        "examples": examples,
        "reward_fn": reward_fn,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }
|
| |
|
| |
|
def main() -> None:
    """CLI entry point: parse command-line arguments, then run training."""
    run_training(parse_args())


if __name__ == "__main__":
    main()
|
| |
|