|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """URSA → URSA one-step distillation via Di[M]O-style on-policy training.
|
|
|
| Verified native inference regime (from A/B testing — ground truth):
|
| height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
|
| no_cfg (guidance_scale=1) does NOT produce valid output for this URSA checkpoint.
|
| All defaults below align to this verified regime.
|
|
|
| Algorithm (9 stages per iteration)
|
| ------------------------------------
|
| teacher : frozen URSA — provides supervision at pseudo-intermediate x_t.
|
| student : trainable copy — 1-step target.
|
| aux : trainable copy — approximates teacher at x_t; reduces REINFORCE variance.
|
|
|
| Stage 1 : tokenise prompts (cond + uncond when CFG enabled) → txt_ids [B,L]
|
| Stage 2 : sample x_init [B,T,H,W] ~ Uniform(K) (+ optional p_init mixing)
|
| Stage 3 : student 1-step forward on x_init (cond only) → x_hat, logp, H
|
| Stage 4 : pseudo-intermediate x_t = scheduler.add_noise(x_hat, t)
|
| Stage 5 : teacher forward on x_t (CFG=7 dual-branch is the default)
|
| Stage 6 : aux forward → Jeffrey KD
|
| Stage 7 : student forward on x_t → KL KD
|
| Stage 8 : reward = -KL(z_T_cond, z_S_cond) [detached]
|
| Stage 9 : two-backward student update
|
|
|
| Usage:
|
| # Smoke test (verified native regime):
|
| python scripts/train_onestep_ursa_dimo.py \\
|
| --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \\
|
| --enable_teacher_cfg --teacher_cfg_scale 7.0 \\
|
| --num_frames 49 --height 320 --width 512 --dry_run
|
|
|
| # Full training:
|
| python scripts/train_onestep_ursa_dimo.py \\
|
| --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \\
|
| --enable_teacher_cfg --teacher_cfg_scale 7.0 \\
|
| --num_frames 49 --height 320 --width 512 \\
|
| --batch_size 1 --num_steps 10000 --out_dir ./outputs/dimo_cfg
|
| """
|
|
|
| import argparse
|
| import copy
|
| import json
|
| import math
|
| import os
|
| import sys
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| from torch.utils.data import DataLoader
|
|
|
| _REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| if _REPO_ROOT not in sys.path:
|
| sys.path.insert(0, _REPO_ROOT)
|
|
|
| from diffnext.pipelines import URSAPipeline
|
| from src.distill.prompt_dataset import InfiniteDataLoader, PromptDataset, make_collate_fn, CSVSpec
|
| from src.distill.utils_ursa_inputs import (
|
| build_ursa_inputs,
|
| compute_latents_shape,
|
| corrupt_tokens,
|
| extract_visual_logits,
|
| sample_t_curriculum,
|
| )
|
|
|
| def _get_logits(out):
|
| if isinstance(out, (tuple, list)):
|
| return out[0]
|
| if hasattr(out, "sample"):
|
| return out.sample
|
| if hasattr(out, "logits"):
|
| return out.logits
|
| return out
|
|
|
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Build and parse the CLI for URSA DiMO one-step distillation.

    Args:
        argv: Optional explicit argument list (defaults to ``sys.argv[1:]``).
            Added as a backward-compatible parameter so the parser can be
            exercised programmatically/in tests.

    Returns:
        argparse.Namespace with all training hyper-parameters.

    Fix vs. previous version: ``--use_cfg_eval`` used
    ``action="store_true", default=True``, which made the flag impossible to
    disable from the command line. It now uses ``BooleanOptionalAction`` so
    ``--no-use_cfg_eval`` turns it off while ``--use_cfg_eval`` (and the
    default) keep the old behavior.
    """
    p = argparse.ArgumentParser(description="URSA DiMO one-step distillation")

    # --- required paths ---
    p.add_argument("--teacher_ckpt", required=True)
    p.add_argument("--prompt_file", required=True)
    p.add_argument("--out_dir", default="./outputs/dimo")

    # --- geometry (defaults = verified native URSA regime) ---
    p.add_argument("--num_frames", type=int, default=49)
    p.add_argument("--height", type=int, default=320)
    p.add_argument("--width", type=int, default=512)
    p.add_argument("--max_prompt_length", type=int, default=320)

    # --- optimisation ---
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--num_steps", type=int, default=10_000)
    p.add_argument("--lr_student", type=float, default=1e-5)
    p.add_argument("--lr_aux", type=float, default=1e-5)
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--mixed_precision", default="bf16", choices=["fp16", "bf16", "fp32"])
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--log_every", type=int, default=50)
    p.add_argument("--save_every", type=int, default=1000)

    # --- loss weights / temperatures ---
    p.add_argument("--lambda_pg", type=float, default=1.0)
    p.add_argument("--lambda_kd", type=float, default=0.5)
    p.add_argument("--lambda_ent", type=float, default=0.01)
    p.add_argument("--tau", type=float, default=1.0, help="Student sampling temperature")
    p.add_argument("--tau_kd", type=float, default=1.0, help="KD softmax temperature")

    # --- teacher-side CFG ---
    p.add_argument("--enable_teacher_cfg", action="store_true", default=False,
                   help="Enable teacher-side CFG for KD target. "
                        "False → prior single-branch behavior (fallback).")
    p.add_argument("--teacher_cfg_scale", type=float, default=7.0,
                   help="CFG scale s (verified working value=7)")
    p.add_argument("--teacher_cfg_prob", type=float, default=1.0,
                   help="Max prob of using guided target per sample (after warmup)")
    p.add_argument("--teacher_cfg_warmup_steps", type=int, default=2000,
                   help="Steps to ramp teacher_cfg_prob 0 → teacher_cfg_prob")
    p.add_argument("--teacher_cfg_trunc", type=float, default=0.9,
                   help="t threshold: when t >= trunc, s=1. Set >=1.0 to disable.")
    p.add_argument("--lambda_kd_uncond", type=float, default=0.3,
                   help="Weight for uncond-branch KD / aux loss")
    p.add_argument("--reward_use_guided", action="store_true", default=False,
                   help="[RISKY] Use guided teacher logits for REINFORCE reward.")

    # --- evaluation ---
    p.add_argument("--eval_cfg_scale", type=float, default=7.0)
    # BooleanOptionalAction: defaults to True, disable via --no-use_cfg_eval.
    # (Previously store_true+default=True, which could never be switched off.)
    p.add_argument("--use_cfg_eval", action=argparse.BooleanOptionalAction, default=True)

    # --- DiMO extras ---
    p.add_argument("--use_surrogate_grad", action="store_true",
                   help="DiMO surrogate MSE trick applied to Stage-3 logits")
    p.add_argument("--lambda_surr", type=float, default=1.0)
    p.add_argument("--fake_rounds", type=int, default=1,
                   help="Aux updates per generator update (DiMO=2)")

    # --- curriculum / init mixing / collapse monitor ---
    p.add_argument("--t_curriculum_steps", type=int, default=10_000)
    p.add_argument("--p_mix_corrupt_frac", type=float, default=0.2)
    p.add_argument("--p_init_mix_ratio", type=float, default=0.2)
    p.add_argument("--collapse_warn_frac", type=float, default=0.2)

    # --- debugging ---
    p.add_argument("--dry_run", action="store_true",
                   help="Run 1 step + grad-flow check, then exit")
    p.add_argument("--debug_dump", type=int, default=0,
                   help="Dump token histogram + x_hat every N steps (0=off)")

    p.add_argument("--device", type=int, default=0)
    return p.parse_args(argv)
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(model, path: str, name: str = "student"):
    """Serialise ``model.state_dict()`` to ``<path>/<name>.pt``.

    The directory is created if missing; the written path is logged.
    """
    os.makedirs(path, exist_ok=True)
    target = os.path.join(path, f"{name}.pt")
    torch.save(model.state_dict(), target)
    print(f"[save] {target}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _stable_kl(z_p: torch.Tensor, z_q: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
|
| """KL(p||q) from raw logits, float32 + log_softmax. → [B] (mean over N tokens).
|
|
|
| p = softmax(z_p/tau), q = softmax(z_q/tau)
|
| KL(p||q) = sum_k p_k * (log p_k - log q_k)
|
|
|
| Both log_p and log_q are computed via log_softmax to avoid
|
| log(softmax(...)) numerical issues.
|
| """
|
| lp = F.log_softmax(z_p.float() / tau, dim=-1)
|
| lq = F.log_softmax(z_q.float() / tau, dim=-1)
|
| return (lp.exp() * (lp - lq)).sum(-1).mean(-1)
|
|
|
|
|
def _stable_jeffrey(z_p: torch.Tensor, z_q: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Jeffrey divergence (symmetric KL) from logits, float32 + log_softmax. → [B]."""
    forward = _stable_kl(z_p, z_q, tau)
    reverse = _stable_kl(z_q, z_p, tau)
    return forward + reverse
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_dual_inputs(teacher_ref, txt_cond, txt_uncond, x_t, latents_shape, device):
    """Stack cond + uncond prompts into one [2B] batch for a single forward.

    Returns (ids_dual [2B, L+N+1], rpos_dual [2B, L+N+1, 3], N).
    After the forward pass: ``chunk(2, dim=0)`` → (z_cond [B], z_uncond [B]).

    Teacher, aux, and student all reuse the SAME ids_dual / rpos_dual, so
    token construction happens only once per step.
    """
    stacked_txt = torch.cat((txt_cond, txt_uncond), dim=0)
    stacked_vis = torch.cat((x_t, x_t), dim=0)
    return build_ursa_inputs(teacher_ref, stacked_txt, stacked_vis, latents_shape, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _probe_flex_attn(model, label: str = "") -> object:
|
| """Return the FlexAttentionCausal2D object if present, else None."""
|
| return getattr(model, "flex_attn", None)
|
|
|
|
|
def _print_flex_attn_state(model, label: str):
    """Log whether flex_attn exists and which of its packing fields are set."""
    fa = _probe_flex_attn(model, label)
    if fa is None:
        print(f" [flex_attn/{label}] not present on model")
        return
    mask_state = "set" if fa.block_mask is not None else "None"
    cu_state = "set" if fa.cu_offsets is not None else "None"
    print(
        f" [flex_attn/{label}] offsets={fa.offsets!r} "
        f"block_mask={mask_state} "
        f"cu_offsets={cu_state}"
    )
|
|
|
|
|
def _reset_flex_attn(model, label: str = "", verbose: bool = False):
    """Clear flex_attn packing state so standard causal attention is used.

    This distillation training treats each sample independently along the
    batch dimension, so block-packed attention (offsets != None) is unneeded
    and must be wiped to avoid cross-sample mask contamination.
    """
    fa = _probe_flex_attn(model, label)
    if fa is None:
        return
    previous = fa.offsets
    for field in ("offsets", "block_mask", "cu_offsets"):
        setattr(fa, field, None)
    if verbose:
        print(f" [flex_attn/{label}] reset: was={previous!r} → None (standard causal)")
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _compute_cfg_scale(t: torch.Tensor, cfg_scale: float, trunc: float) -> torch.Tensor:
|
| """Per-sample CFG scale [B]: s=cfg_scale when t < trunc, else s=1."""
|
| s = torch.full_like(t, cfg_scale)
|
| if trunc < 1.0:
|
| s = torch.where(t >= trunc, torch.ones_like(t), s)
|
| return s
|
|
|
|
|
| def _cfg_warmup_prob(step: int, cfg_prob: float, warmup_steps: int) -> float:
|
| """Linear warmup: 0 → cfg_prob over warmup_steps steps."""
|
| if warmup_steps <= 0:
|
| return cfg_prob
|
| return cfg_prob * min(1.0, step / warmup_steps)
|
|
|
|
|
| def _build_guided_logits(
|
| z_T_cond: torch.Tensor,
|
| z_T_uncond: torch.Tensor,
|
| t: torch.Tensor,
|
| cfg_scale: float,
|
| trunc: float,
|
| ) -> torch.Tensor:
|
| """z_guided = z_uncond + s*(z_cond - z_uncond), per-sample s [B,1,1]."""
|
| s = _compute_cfg_scale(t, cfg_scale, trunc).view(-1, 1, 1)
|
| return z_T_uncond + s * (z_T_cond - z_T_uncond)
|
|
|
|
|
| def _select_target(
|
| z_guided: torch.Tensor,
|
| z_cond: torch.Tensor,
|
| use_guided: torch.Tensor,
|
| ) -> torch.Tensor:
|
| """Per-sample: z_guided where use_guided[b]=True, else z_cond."""
|
| mask = use_guided.view(-1, 1, 1).expand_as(z_cond)
|
| return torch.where(mask, z_guided, z_cond)
|
|
|
|
|
|
|
|
|
|
|
|
|
def debug_grad_flow(
    teacher, student, aux,
    txt_cond, txt_uncond, x_t, latents_shape, device, K, N, tau, tau_kd,
    enable_teacher_cfg,
):
    """One fwd+bwd without optimizer.step().

    Asserts:
    - teacher: zero grads (frozen)
    - aux: non-zero grads after loss_aux.backward()
    - student: non-zero grads after loss_student.backward()

    All cond/uncond forwards are batch-concatenated per requirement (1).

    Side effects: leaves all student/aux grads cleared; prints diagnostics.
    Raises AssertionError when any of the three grad conditions fails.
    """
    print("\n" + "=" * 64)
    print("[grad_flow] Starting gradient flow debug …")
    B = txt_cond.size(0)

    # --- Stage-3 analogue: student 1-step forward on a random init, WITH grad,
    # so logp/H carry gradients into the later student backward.
    x_init_dbg = torch.randint(0, K, x_t.shape, device=device, dtype=torch.long)
    ids_init, rpos_init, _ = build_ursa_inputs(teacher, txt_cond, x_init_dbg, latents_shape, device)
    logits_s = student(ids_init, rope_pos=rpos_init).sample
    z_s = extract_visual_logits(logits_s.float(), N, K)
    p_s = F.softmax(z_s / tau, dim=-1)
    # Sampled tokens [B,N]; multinomial itself is non-differentiable, grads
    # flow via the gathered log-probs below (REINFORCE-style).
    x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N)
    logp = p_s.clamp(1e-8).log().gather(-1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1)
    H_mean = -(p_s * p_s.clamp(1e-8).log()).sum(-1).mean()

    # --- Teacher target at x_t: dual cond+uncond batch when CFG is enabled,
    # single cond branch otherwise. Teacher forward is under no_grad (frozen).
    if enable_teacher_cfg and txt_uncond is not None:
        ids_dual, rpos_dual, _ = _build_dual_inputs(teacher, txt_cond, txt_uncond, x_t, latents_shape, device)
        with torch.no_grad():
            logits_T_dual = teacher(ids_dual, rope_pos=rpos_dual).sample.float()
            z_T_dual = extract_visual_logits(logits_T_dual, N, K)
            z_T_cond_dbg, z_T_uncond_dbg = z_T_dual.chunk(2, dim=0)
            # Fixed t=0.5 and a mild scale s=3.0 for the debug pass only.
            t_dbg = torch.full((B,), 0.5, device=device, dtype=torch.float32)
            z_T_guided_dbg = _build_guided_logits(
                z_T_cond_dbg.float(), z_T_uncond_dbg.float(), t_dbg, 3.0, 0.9)
            z_T_target_dbg = z_T_guided_dbg.detach()
        print(f" [grad_flow] z_T_cond shape={z_T_cond_dbg.shape} "
              f"min={z_T_cond_dbg.min():.3f} max={z_T_cond_dbg.max():.3f}")
        print(f" [grad_flow] z_T_uncond shape={z_T_uncond_dbg.shape} "
              f"min={z_T_uncond_dbg.min():.3f} max={z_T_uncond_dbg.max():.3f}")
        print(f" [grad_flow] z_T_guided shape={z_T_guided_dbg.shape} "
              f"min={z_T_guided_dbg.min():.3f} max={z_T_guided_dbg.max():.3f}")
        # The cond half of the dual batch doubles as the single-branch inputs.
        ids_t_ref = ids_dual[:B]
        rpos_t_ref = rpos_dual[:B]
        ids_fwd = ids_dual
        rpos_fwd = rpos_dual
    else:
        ids_t_ref, rpos_t_ref, _ = build_ursa_inputs(teacher, txt_cond, x_t, latents_shape, device)
        with torch.no_grad():
            logits_T = teacher(ids_t_ref, rope_pos=rpos_t_ref).sample.float()
            z_T_target_dbg = extract_visual_logits(logits_T, N, K).detach()
        ids_fwd = ids_t_ref
        rpos_fwd = rpos_t_ref

    # --- Dual-path consistency: teacher and student must emit identically
    # shaped visual logits for the same cond inputs.
    with torch.no_grad():
        z_T_ref2 = extract_visual_logits(
            teacher(ids_t_ref, rope_pos=rpos_t_ref).sample.float(), N, K)
        z_S_ref2 = extract_visual_logits(
            student(ids_t_ref.detach(), rope_pos=rpos_t_ref.detach()).sample.float(), N, K)
    if z_T_ref2.shape != z_S_ref2.shape:
        raise RuntimeError(
            f"[FATAL] Dual-path shape mismatch: z_T={z_T_ref2.shape} z_S={z_S_ref2.shape}"
        )
    print(f" [grad_flow] Dual-path check OK: shape={z_T_ref2.shape}")

    # --- Aux branch: Jeffrey KD against the (detached) teacher target.
    logits_A = aux(ids_fwd.detach(), rope_pos=rpos_fwd.detach()).sample
    if enable_teacher_cfg and txt_uncond is not None:
        z_A_dual2 = extract_visual_logits(logits_A.float(), N, K)
        z_A_cond_dbg, _ = z_A_dual2.chunk(2, dim=0)
    else:
        z_A_cond_dbg = extract_visual_logits(logits_A.float(), N, K)
    loss_aux_sample = _stable_jeffrey(z_T_target_dbg, z_A_cond_dbg, tau_kd)
    loss_aux = loss_aux_sample.mean()
    loss_aux.backward()

    # Teacher must stay grad-free; aux must have received gradients.
    teacher_grads = [p.grad for p in teacher.parameters() if p.grad is not None]
    aux_grads = [p.grad.norm().item() for p in aux.parameters() if p.grad is not None]
    print(f" [grad_flow] teacher grads with non-None grad: {len(teacher_grads)} (must be 0)")
    if aux_grads:
        print(f" [grad_flow] aux grad norm min={min(aux_grads):.3e} "
              f"mean={sum(aux_grads)/len(aux_grads):.3e} max={max(aux_grads):.3e}")
    else:
        print(" [grad_flow] ⚠️ aux has NO grads")
    # Clear aux grads so they don't leak into the student check below.
    for param in aux.parameters():
        param.grad = None

    # --- Student branch: KL KD + REINFORCE-style term using the Stage-3 logp.
    logits_S = student(ids_t_ref.detach(), rope_pos=rpos_t_ref.detach()).sample
    z_S_cond = extract_visual_logits(logits_S.float(), N, K)
    loss_kd = _stable_kl(z_T_target_dbg, z_S_cond, tau_kd).mean()
    # Dummy unit advantage (detached) — only grad FLOW is under test here.
    adv = (loss_aux_sample.detach() * 0 + 1.0)
    assert not adv.requires_grad, "[BUG] adv must be detached"
    loss_student = -(adv * logp).mean() + loss_kd - 0.01 * H_mean
    loss_student.backward()

    student_grads = [p.grad.norm().item() for p in student.parameters() if p.grad is not None]
    if student_grads:
        print(f" [grad_flow] student grad norm min={min(student_grads):.3e} "
              f"mean={sum(student_grads)/len(student_grads):.3e} "
              f"max={max(student_grads):.3e}")
    else:
        print(" [grad_flow] ⚠️ student has NO grads — diagnosing:")
        print(f" logp.requires_grad={logp.requires_grad}")
        print(f" z_s.requires_grad={z_s.requires_grad}")

    assert len(teacher_grads) == 0, "teacher has grads — not frozen"
    assert len(aux_grads) > 0, "aux has no grads after loss_aux.backward()"
    assert len(student_grads) > 0, "student has no grads — grad flow broken"

    # Leave both trainable models with clean grads for the caller.
    for m in (student, aux):
        for param in m.parameters():
            param.grad = None

    print(" [grad_flow] All gradient assertions PASSED ✓")
    print("=" * 64 + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: set up models/data, then run the 9-stage DiMO loop."""
    args = parse_args()

    device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu")
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    compute_dtype = dtype_map[args.mixed_precision]
    torch.manual_seed(args.seed)
    os.makedirs(args.out_dir, exist_ok=True)

    # --- Warn when the run deviates from the verified native URSA regime
    # (see module docstring: 49×320×512 with CFG scale 7 is ground truth).
    _NATIVE = dict(height=320, width=512, num_frames=49, guidance_scale=7.0)
    is_native = (
        args.height == _NATIVE["height"]
        and args.width == _NATIVE["width"]
        and args.num_frames == _NATIVE["num_frames"]
    )
    print(f"[init] verified_native_regime={is_native} "
          f"geometry=({args.num_frames}×{args.height}×{args.width}) "
          f"teacher_cfg_scale={args.teacher_cfg_scale if args.enable_teacher_cfg else 'OFF'}")
    if not is_native:
        print(f"[WARN] Current geometry ({args.num_frames}×{args.height}×{args.width}) "
              f"is not the verified native URSA regime "
              f"({_NATIVE['num_frames']}×{_NATIVE['height']}×{_NATIVE['width']}). "
              "Distillation quality may degrade or become invalid.")
    if not args.enable_teacher_cfg:
        print("[WARN] Teacher CFG is DISABLED. no_cfg is known to produce "
              "blank/blurry output for this URSA checkpoint. "
              "Distillation without CFG is unlikely to produce useful results.")
    elif args.teacher_cfg_scale != _NATIVE["guidance_scale"]:
        print(f"[WARN] teacher_cfg_scale={args.teacher_cfg_scale} differs from "
              f"the verified working value ({_NATIVE['guidance_scale']}).")

    if args.enable_teacher_cfg and args.reward_use_guided:
        print("[WARN] --reward_use_guided is ON — can cause mode collapse, watch tok_entropy.")

    # --- Load pipeline; teacher transformer is the supervision source.
    print(f"[init] Loading from {args.teacher_ckpt} …")
    pipe = URSAPipeline.from_pretrained(
        args.teacher_ckpt, torch_dtype=compute_dtype, trust_remote_code=True
    ).to(device)

    tokenizer = pipe.tokenizer
    scheduler = pipe.scheduler
    scheduler.to(device=device)

    # Latent grid geometry derived from the VAE strides (defaults 4/8 if the
    # config does not expose them — TODO confirm against the checkpoint).
    vae_t_stride = getattr(pipe.vae.config, "temporal_stride", 4)
    vae_s_stride = getattr(pipe.vae.config, "spatial_stride", 8)
    latents_shape = compute_latents_shape(
        args.num_frames, args.height, args.width, vae_t_stride, vae_s_stride
    )
    T, H, W = latents_shape
    N = T * H * W          # visual tokens per sample
    K = scheduler.codebook_size
    print(
        f"[init] latents_shape=({T},{H},{W}) N={N} K={K} "
        f"CFG={'ON' if args.enable_teacher_cfg else 'OFF'}"
    )

    # Pre-tokenised empty prompt, expanded per-batch for the uncond branch.
    txt_uncond_base = tokenizer(
        [""], max_length=args.max_prompt_length, padding="max_length",
        padding_side="left", truncation=True, return_tensors="pt",
    ).input_ids.to(device)

    # teacher frozen; student/aux are trainable deep copies.
    teacher = pipe.transformer.eval().requires_grad_(False)
    student = copy.deepcopy(teacher).train().requires_grad_(True)
    aux = copy.deepcopy(teacher).train().requires_grad_(True)

    # --- Clear block-packed attention state on all three models (samples are
    # independent along the batch dim in this trainer).
    if args.dry_run:
        print("[init] flex_attn state before reset:")
        for m, lbl in ((teacher, "teacher"), (student, "student"), (aux, "aux")):
            _print_flex_attn_state(m, lbl)
    for m, lbl in ((teacher, "teacher"), (student, "student"), (aux, "aux")):
        _reset_flex_attn(m, lbl, verbose=True)
    if args.dry_run:
        print("[init] flex_attn state after reset:")
        for m, lbl in ((teacher, "teacher"), (student, "student"), (aux, "aux")):
            _print_flex_attn_state(m, lbl)

    opt_student = torch.optim.AdamW(
        student.parameters(), lr=args.lr_student, weight_decay=args.weight_decay
    )
    opt_aux = torch.optim.AdamW(
        aux.parameters(), lr=args.lr_aux, weight_decay=args.weight_decay
    )

    collate = make_collate_fn(tokenizer, args.max_prompt_length, device)

    dataset = PromptDataset(
        args.prompt_file,
        shuffle_files=True,
        shuffle_buffer=50000,
        seed=args.seed,
        infinite=True,
        csv=CSVSpec(caption_field="caption"),
    )

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,          # dataset already shuffles (infinite stream)
        drop_last=True,
        num_workers=2,
        collate_fn=collate,
        pin_memory=True,
    )
    inf_loader = InfiniteDataLoader(loader)

    _sanity_check_forward(teacher, scheduler, latents_shape, device, K, args.dry_run)

    baseline_ema: float = 0.0               # EMA baseline for REINFORCE advantage
    x_hat_prev = None                       # previous step's x_hat for init mixing
    initial_tok_entropy: float | None = None  # reference entropy for collapse check
    dump_dir = os.path.join(args.out_dir, "debug_dumps") if args.debug_dump > 0 else None

    num_steps = 1 if args.dry_run else args.num_steps
    print(f"[train] {'DRY RUN' if args.dry_run else f'{num_steps} steps'} "
          f"| CFG={args.enable_teacher_cfg}")

    for step in range(1, num_steps + 1):
        # --- Stage 1: tokenised prompts (cond, + uncond when CFG enabled).
        txt_cond = next(inf_loader)
        txt_cond = txt_cond.to(device, non_blocking=True)
        B = txt_cond.size(0)

        txt_uncond = None
        if args.enable_teacher_cfg:
            txt_uncond = txt_uncond_base.expand(B, -1)

        # --- Stage 2: x_init ~ Uniform(K), optionally mixed with corrupted
        # previous-step samples.
        x_init = _sample_x_init(B, T, H, W, K, device, x_hat_prev, args)

        # --- Stage 3 (sampling only, no grad): student 1-step forward →
        # x_hat. The differentiable policy forward is redone in Stage 9.
        with torch.no_grad():
            ids_init, rpos_init, _ = build_ursa_inputs(
                teacher, txt_cond, x_init, latents_shape, device)
            logits_s_init = student(ids_init, rope_pos=rpos_init).sample
            z_s = extract_visual_logits(logits_s_init.float(), N, K)
            p_s = F.softmax(z_s / args.tau, dim=-1)
            x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N)

        x_hat_4d = x_hat.view(B, T, H, W)

        # --- Stage 4: pseudo-intermediate x_t via the discrete scheduler.
        t = sample_t_curriculum(B, device, step, warmup_steps=args.t_curriculum_steps)
        with torch.no_grad():
            x_t = scheduler.add_noise(x_hat_4d, t)

        # --- Stage 5: frozen teacher forward on x_t (dual-branch under CFG).
        with torch.no_grad():
            if args.enable_teacher_cfg:
                ids_dual, rpos_dual, _ = _build_dual_inputs(
                    teacher, txt_cond, txt_uncond, x_t, latents_shape, device)
                logits_T_dual = teacher(ids_dual, rope_pos=rpos_dual).sample.float()
                z_T_dual = extract_visual_logits(logits_T_dual, N, K)
                z_T_cond, z_T_uncond = z_T_dual.chunk(2, dim=0)
                # Cond half of the dual inputs reused for single-branch passes.
                ids_t = ids_dual[:B]
                rpos_t = rpos_dual[:B]
            else:
                ids_t, rpos_t, _ = build_ursa_inputs(
                    teacher, txt_cond, x_t, latents_shape, device)
                logits_T = teacher(ids_t, rope_pos=rpos_t).sample.float()
                z_T_cond = extract_visual_logits(logits_T, N, K)
                z_T_uncond = None
                ids_dual = ids_t
                rpos_dual = rpos_t

        # --- Build the KD target: per-sample mix of guided / cond logits,
        # with the guided fraction ramped up over the warmup.
        z_T_guided = None
        if args.enable_teacher_cfg:
            z_T_cond_f = z_T_cond.float()
            z_T_uncond_f = z_T_uncond.float()
            z_T_guided = _build_guided_logits(
                z_T_cond_f, z_T_uncond_f, t,
                args.teacher_cfg_scale, args.teacher_cfg_trunc)

            p_guided = _cfg_warmup_prob(
                step, args.teacher_cfg_prob, args.teacher_cfg_warmup_steps)
            use_guided = torch.rand(B, device=device) < p_guided
            use_guided_ratio = use_guided.float().mean().item()
            z_T_target = _select_target(z_T_guided, z_T_cond_f, use_guided)
        else:
            use_guided = torch.zeros(B, dtype=torch.bool, device=device)
            use_guided_ratio = 0.0
            z_T_target = z_T_cond.float()

        # Target is pure supervision — never backprop into the teacher graph.
        z_T_target = z_T_target.detach()

        # --- Stage 6: aux updates (Jeffrey KD), fake_rounds times per step.
        loss_aux_cond_v_last = None
        loss_aux_uncond_v_last = None
        loss_aux_cond_sample_last = None   # kept for parity; currently unused below

        for _fr in range(args.fake_rounds):
            opt_aux.zero_grad()

            if args.enable_teacher_cfg:
                logits_A_dual = aux(ids_dual.detach(), rope_pos=rpos_dual.detach()).sample
                z_A_dual = extract_visual_logits(logits_A_dual.float(), N, K)
                z_A_cond, z_A_uncond = z_A_dual.chunk(2, dim=0)

                loss_aux_cond_sample = _stable_jeffrey(z_T_target, z_A_cond, args.tau_kd)
                loss_aux_cond_v = loss_aux_cond_sample.mean()

                z_T_uncond_det = z_T_uncond.float().detach()
                loss_aux_uncond_sample = _stable_jeffrey(z_T_uncond_det, z_A_uncond, args.tau_kd)
                loss_aux_uncond_v = loss_aux_uncond_sample.mean()

                loss_aux_v = loss_aux_cond_v + args.lambda_kd_uncond * loss_aux_uncond_v
            else:
                logits_A = aux(ids_t.detach(), rope_pos=rpos_t.detach()).sample
                z_A_cond = extract_visual_logits(logits_A.float(), N, K)

                loss_aux_cond_sample = _stable_jeffrey(z_T_target, z_A_cond, args.tau_kd)
                loss_aux_cond_v = loss_aux_cond_sample.mean()
                loss_aux_uncond_v = torch.tensor(0.0, device=device)
                loss_aux_v = loss_aux_cond_v

            loss_aux_v.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(aux.parameters(), args.grad_clip)
            opt_aux.step()

            # Explicitly drop grads between rounds (zero_grad above also runs).
            for p in aux.parameters():
                p.grad = None

            loss_aux_cond_v_last = loss_aux_cond_v.detach()
            loss_aux_uncond_v_last = loss_aux_uncond_v.detach()
            loss_aux_cond_sample_last = loss_aux_cond_sample.detach()

        # --- Stage 7: student forward on x_t (with grad) for the KL KD term.
        if args.enable_teacher_cfg:
            logits_S_dual = _get_logits(student(ids_dual.detach(), rope_pos=rpos_dual.detach())).float()
            z_S_dual = extract_visual_logits(logits_S_dual, N, K)
            z_S_cond, z_S_uncond = z_S_dual.chunk(2, dim=0)
        else:
            logits_S = _get_logits(student(ids_t.detach(), rope_pos=rpos_t.detach())).float()
            z_S_cond = extract_visual_logits(logits_S, N, K)
            z_S_uncond = None

        if z_T_cond.shape != z_S_cond.shape:
            raise RuntimeError(f"[FATAL] Dual-path shape mismatch: z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}")

        loss_kd_cond = _stable_kl(z_T_target, z_S_cond, args.tau_kd).mean()
        loss_kd_uncond_v = torch.tensor(0.0, device=device)
        if args.enable_teacher_cfg and (z_S_uncond is not None):
            loss_kd_uncond_v = _stable_kl(z_T_uncond.float().detach(), z_S_uncond, args.tau_kd).mean()
        loss_kd = loss_kd_cond + args.lambda_kd_uncond * loss_kd_uncond_v

        # --- Stage 8: REINFORCE reward (fully detached). Guided reward is
        # opt-in and flagged risky at startup.
        if args.enable_teacher_cfg and args.reward_use_guided:
            z_T_for_rew = z_T_target
        else:
            z_T_for_rew = z_T_cond.float().detach()

        reward = -_stable_kl(z_T_for_rew.detach(), z_S_cond.detach(), args.tau)
        assert not reward.requires_grad

        baseline_ema = 0.99 * baseline_ema + 0.01 * reward.mean().item()
        adv = (reward - baseline_ema).detach()
        assert not adv.requires_grad

        # --- Stage 9: two-backward student update (KD graph, then PG graph).
        opt_student.zero_grad(set_to_none=True)

        # Backward #1: KD through the Stage-7 graph.
        (args.lambda_kd * loss_kd).backward()

        # Differentiable re-run of the Stage-3 policy forward on x_init.
        ids_init, rpos_init, _ = build_ursa_inputs(teacher, txt_cond, x_init, latents_shape, device)
        logits_s_pol = _get_logits(student(ids_init, rope_pos=rpos_init)).float()
        z_s_pol = extract_visual_logits(logits_s_pol, N, K)

        logp_tok = F.log_softmax(z_s_pol / args.tau, dim=-1)
        p_s_pol = logp_tok.exp()

        # Per-sample mean log-prob of the Stage-3 sampled tokens.
        logp_sum = logp_tok.gather(-1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1)
        logp = logp_sum / N

        H_mean = -(p_s_pol * logp_tok).sum(-1).mean()

        loss_pg = -(adv * logp).mean()

        # Entropy bonus scaled up with guided-target usage (collapse guard).
        lambda_ent_eff = args.lambda_ent * (1.0 + 2.0 * use_guided_ratio)
        # Backward #2: policy-gradient + entropy through the policy graph.
        (loss_pg * args.lambda_pg - H_mean * lambda_ent_eff).backward()

        # Optional DiMO surrogate-gradient MSE on the policy logits.
        loss_surr = None
        if args.use_surrogate_grad:
            with torch.no_grad():
                logits_A_ref = _get_logits(aux(ids_t.detach(), rope_pos=rpos_t.detach())).float()
                z_A_ref = extract_visual_logits(logits_A_ref, N, K)
                p_A_ref = F.softmax(z_A_ref / args.tau_kd, dim=-1).detach()
                p_T_ref = F.softmax(z_T_target / args.tau_kd, dim=-1).detach()
                grad_surr = (p_A_ref - p_T_ref).detach()
            # MSE against a shifted detached copy: d(loss)/d(z_s_pol) == grad_surr.
            loss_surr = 0.5 * F.mse_loss(z_s_pol, (z_s_pol - grad_surr).detach())
            (args.lambda_surr * loss_surr).backward()

        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(student.parameters(), args.grad_clip)
        opt_student.step()

        # Feed this step's samples into the next step's init mixing.
        x_hat_prev = x_hat_4d.detach()

        # --- One-time structural assertions on the very first step.
        if step == 1:
            _run_assertions(
                x_init, ids_init, rpos_init,
                z_s, p_s, logp,
                z_T_cond, z_S_cond, x_t, K, N, B, T, H, W,
                teacher.config.lm_vocab_size,
                z_T_uncond=z_T_uncond,
                z_T_guided=z_T_guided,
                dry_run=args.dry_run,
            )

        # --- Mode-collapse monitor on the token-usage histogram.
        tok_entropy = _token_histogram_entropy(x_hat, K)
        if initial_tok_entropy is None:
            initial_tok_entropy = tok_entropy

        if tok_entropy < args.collapse_warn_frac * initial_tok_entropy:
            print(
                f"[COLLAPSE WARNING] step={step} tok_entropy={tok_entropy:.3f} "
                f"initial={initial_tok_entropy:.3f} "
                f"ratio={tok_entropy/max(initial_tok_entropy, 1e-8):.2f} < "
                f"{args.collapse_warn_frac}. "
                "Increase --lambda_ent (try 0.05) or --tau."
            )

        if step % args.log_every == 0 or args.dry_run:
            surr_str = f" loss_surr={loss_surr.item():.4f}" if loss_surr is not None else ""
            print(
                f"[step {step:>6d}] "
                f"loss_aux_cond={loss_aux_cond_v_last.item():.3e} "
                f"loss_aux_uncond={loss_aux_uncond_v_last.item():.3e} "
                f"loss_kd_cond={loss_kd_cond.item():.4f} "
                f"loss_kd_uncond={loss_kd_uncond_v.item():.4f} "
                f"loss_pg={loss_pg.item():.4f}"
                f"{surr_str} "
                f"H={H_mean.item():.3f} tok_H={tok_entropy:.3f} "
                f"guided_ratio={use_guided_ratio:.2f} "
                f"baseline={baseline_ema:.4f} "
                f"mean_logp_tok={logp.mean().item():.3f}"
            )

        if args.debug_dump > 0 and step % args.debug_dump == 0:
            _dump_debug(dump_dir, step, x_hat, K)

        if not args.dry_run and step % args.save_every == 0:
            ckpt_dir = os.path.join(args.out_dir, f"step_{step:06d}")
            save_checkpoint(student, ckpt_dir, "student")
            save_checkpoint(aux, ckpt_dir, "aux")

    # --- Dry-run epilogue: grad-flow harness + extra patch checks, then exit
    # without saving final weights.
    if args.dry_run:
        print("\n[dry_run] Running gradient flow debug …")
        txt_dbg = next(inf_loader)
        B_dbg = txt_dbg.size(0)
        x_t_dbg = torch.randint(0, K, (B_dbg, T, H, W), device=device, dtype=torch.long)
        txt_u_dbg = (txt_uncond_base.expand(B_dbg, -1)
                     if args.enable_teacher_cfg else None)
        debug_grad_flow(
            teacher, student, aux,
            txt_dbg, txt_u_dbg, x_t_dbg, latents_shape, device, K, N,
            args.tau, args.tau_kd, args.enable_teacher_cfg,
        )
        _dry_run_patches_789(teacher, latents_shape, K, N, device)
        print("[dry_run] Done. All checks (1-9) PASSED. Exiting.")
        return

    final_dir = os.path.join(args.out_dir, "final")
    save_checkpoint(student, final_dir, "student")
    save_checkpoint(aux, final_dir, "aux")
    print("[done] Training complete.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _sample_x_init(B, T, H, W, K, device, x_hat_prev, args):
|
| x_init = torch.randint(0, K, (B, T, H, W), device=device, dtype=torch.long)
|
| if x_hat_prev is not None and args.p_init_mix_ratio > 0:
|
| n_mix = max(1, int(B * args.p_init_mix_ratio))
|
| x_init[:n_mix] = corrupt_tokens(x_hat_prev[:n_mix], r=args.p_mix_corrupt_frac, K=K)
|
| return x_init
|
|
|
|
|
| def _token_histogram_entropy(x_hat: torch.Tensor, K: int) -> float:
|
| counts = x_hat.flatten().bincount(minlength=K).float()
|
| p = counts / counts.sum()
|
| p = p[p > 0]
|
| return float(-(p * p.log()).sum().item())
|
|
|
|
|
| def _dump_debug(dump_dir: str, step: int, x_hat: torch.Tensor, K: int):
|
| os.makedirs(dump_dir, exist_ok=True)
|
| counts = x_hat.flatten().bincount(minlength=K).tolist()
|
| with open(os.path.join(dump_dir, f"step_{step:06d}_hist.json"), "w") as fh:
|
| json.dump({"step": step, "counts": counts}, fh)
|
| torch.save(x_hat.cpu(), os.path.join(dump_dir, f"step_{step:06d}_xhat.pt"))
|
| print(f"[debug_dump] step={step} saved to {dump_dir}")
|
|
|
|
|
def _run_assertions(
    x_init, ids_init, rpos_init,
    z_s, p_s, logp,
    z_T_cond, z_S_cond, x_t,
    K, N, B, T, H, W, lm_vocab_size,
    z_T_uncond=None, z_T_guided=None,
    dry_run=False,
):
    """Full shape / value-domain / consistency assertions (run at step=1).

    Args:
        x_init: long tensor of raw codebook indices, values in [0, K).
        ids_init: [B, L+N+1] packed token sequence fed to the model
            (text ids first, shifted visual ids last — see checks below).
        rpos_init: [B, L+N+1, 3] rope positions (3 position axes per slot).
        z_s / p_s: student visual logits / probabilities, [B, N, K].
        logp: student log-probabilities of the sampled tokens.
        z_T_cond / z_S_cond: teacher / student logits at the
            pseudo-intermediate x_t, both [B, N, K].
        x_t: noised tokens from scheduler.add_noise, values in [0, K).
        K, N, B, T, H, W, lm_vocab_size: geometry / vocab constants.
        z_T_uncond / z_T_guided: optional CFG branch logits, each [B, N, K].
        dry_run: when True, also print z_T_cond statistics.

    Raises:
        AssertionError: on the first violated invariant.
    """
    print("[assert] Running shape/value assertions …")

    # Sequence layout: txt_len text tokens followed by N+1 further tokens
    # (N visual tokens plus one extra position — presumably a separator/BOV
    # slot; confirm against build_ursa_inputs).
    L_plus_N1 = ids_init.size(1)
    txt_len = L_plus_N1 - (N + 1)

    # -- x_init: dtype and codebook value domain ---------------------------
    assert x_init.dtype == torch.long, f"x_init dtype={x_init.dtype}"
    assert x_init.min() >= 0 and x_init.max() < K, \
        f"x_init out of [0,K): [{x_init.min()}, {x_init.max()}]"

    # -- ids_init: text ids stay below lm_vocab_size; the last N (visual)
    #    ids must live in [lm_vocab_size, lm_vocab_size + K) ----------------
    assert ids_init.shape == (B, L_plus_N1), f"ids_init.shape={ids_init.shape}"
    txt_part = ids_init[:, :txt_len]
    vis_part = ids_init[:, -N:]
    assert (txt_part < lm_vocab_size).all(), \
        f"text tokens bleed into visual range (max={txt_part.max()})"
    assert (vis_part >= lm_vocab_size).all(), \
        f"visual tokens not shifted (min={vis_part.min()}, lm_vocab_size={lm_vocab_size})"
    assert (vis_part < lm_vocab_size + K).all(), \
        f"visual tokens exceed lm_vocab_size+K (max={vis_part.max()})"

    # -- rope positions: one 3-component position per sequence slot --------
    assert rpos_init.shape == (B, L_plus_N1, 3), \
        f"rope_pos shape={rpos_init.shape} expected ({B},{L_plus_N1},3)"

    # -- student head: logits shape and probability normalisation ----------
    assert z_s.shape == (B, N, K), f"z_s.shape={z_s.shape}"
    p_err = (p_s.sum(-1) - 1).abs().max().item()
    assert p_err < 1e-3, f"p_s not normalised: max deviation={p_err:.2e}"

    # -- logp: must be finite (feeds the policy-gradient term) -------------
    assert not torch.isnan(logp).any(), "logp contains NaN"
    assert not torch.isinf(logp).any(), "logp contains Inf"

    # -- x_t: add_noise must keep tokens inside the codebook ---------------
    assert x_t.min() >= 0 and x_t.max() < K, \
        f"x_t out of [0,K) after add_noise: [{x_t.min()}, {x_t.max()}]"

    # -- teacher/student logits at x_t must be directly comparable ---------
    assert z_T_cond.shape == z_S_cond.shape, \
        f"Dual-path mismatch: z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}"
    assert z_T_cond.shape == (B, N, K), f"z_T_cond.shape={z_T_cond.shape}"

    # -- optional diagnostics and CFG-branch checks ------------------------
    if dry_run or z_T_uncond is not None:
        print(
            f"[assert] z_T_cond shape={z_T_cond.shape} "
            f"min={z_T_cond.min():.3f} max={z_T_cond.max():.3f} "
            f"mean={z_T_cond.mean():.3f}"
        )
    if z_T_uncond is not None:
        assert z_T_uncond.shape == (B, N, K), f"z_T_uncond.shape={z_T_uncond.shape}"
        print(
            f"[assert] z_T_uncond shape={z_T_uncond.shape} "
            f"min={z_T_uncond.min():.3f} max={z_T_uncond.max():.3f} "
            f"mean={z_T_uncond.mean():.3f}"
        )
    if z_T_guided is not None:
        assert z_T_guided.shape == (B, N, K), f"z_T_guided.shape={z_T_guided.shape}"
        g_min = z_T_guided.min().item()
        g_max = z_T_guided.max().item()
        g_mean = z_T_guided.mean().item()
        print(
            f"[assert] z_T_guided shape={z_T_guided.shape} "
            f"min={g_min:.3f} max={g_max:.3f} mean={g_mean:.3f}"
        )

        # CFG extrapolation can blow logits up; require finite and bounded.
        assert not torch.isnan(z_T_guided).any(), "z_T_guided contains NaN"
        assert not torch.isinf(z_T_guided).any(), "z_T_guided contains Inf"
        assert abs(g_min) < 1e4 and abs(g_max) < 1e4, (
            f"z_T_guided magnitude too large: min={g_min:.1e} max={g_max:.1e}. "
            f"Reduce --teacher_cfg_scale (currently may amplify outlier logits)."
        )

    print("[assert] All assertions PASSED ✓")
|
|
|
|
|
def _sanity_check_forward(teacher, scheduler, latents_shape, device, K, verbose=False):
    """Run one throwaway teacher forward on all-zero dummy inputs and verify
    the packed-sequence / logit geometry before training starts.

    Confirms that build_ursa_inputs yields ids of length L+N+1 with 3-axis
    rope positions, that extract_visual_logits recovers a (B, N, K) slab,
    and that the lm head is at least K wide. ``scheduler`` is accepted for
    signature parity but is not consulted here.
    """
    print("[init] Checking logit dimensions …")
    T, H, W = latents_shape
    B, L = 1, 16
    N = T * H * W
    dummy_txt = torch.zeros(B, L, dtype=torch.long, device=device)
    dummy_vis = torch.zeros(B, T, H, W, dtype=torch.long, device=device)
    with torch.no_grad():
        ids, rpos, _ = build_ursa_inputs(teacher, dummy_txt, dummy_vis, latents_shape, device)
        logits = teacher(ids, rope_pos=rpos).sample
    lm_head_size = teacher.config.lm_head_size
    lm_vocab = teacher.config.lm_vocab_size
    print(
        f"[init] logits={logits.shape} K={K} "
        f"lm_head={lm_head_size} lm_vocab={lm_vocab}"
    )
    expected_len = L + N + 1
    assert ids.shape == (B, expected_len), f"ids shape {ids.shape}"
    assert rpos.shape == (B, expected_len, 3), f"rpos shape {rpos.shape}"
    z = extract_visual_logits(logits.float(), N, K)
    assert z.shape == (B, N, K), f"z shape {z.shape}"
    assert lm_head_size >= K, f"lm_head_size={lm_head_size} < K={K}"
    if verbose:
        print("[init] flex_attn state during sanity check:")
        _print_flex_attn_state(teacher, "teacher")
    print("[init] Forward check OK ✓")
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dry_run_patches_789(teacher, latents_shape, K, N, device):
    """Three deep self-checks executed only during --dry_run.

    Patch 7 — extract_visual_logits end-to-end alignment:
        Run a real teacher forward, manually reconstruct z_manual from raw logits
        using the latent_shift / codebook_size convention, and assert the result
        matches extract_visual_logits(). Handles the common URSA case where
        lm_head outputs K logits directly (latent_shift not applied to logit dim).

    Patch 8 — flex_attn semantics sanity:
        If the model exposes set_offsets_by_lens, compare visual-logit mean-delta
        between offsets=None (standard causal) and a single-block offset. A large
        delta is expected and confirms that our training correctly uses offsets=None.
        Gracefully skips when flex_attention is unavailable at runtime.

    Patch 9 — logp / token reshape consistency:
        With a small (T=3, H=4, W=5) shape, verify x_hat reshape round-trips and
        spot-check 10 token positions against manually computed log-probability.

    Raises:
        AssertionError: if any of the three checks fails.
    """
    T, H, W = latents_shape
    L_test, B_test = 16, 1

    print("\n" + "=" * 64)
    print("[patch 7/8/9] Running additional dry_run self-checks …")

    # Shared fixture for patches 7 and 8: one real no-grad teacher forward
    # on all-zero dummy text / visual tokens.
    dummy_txt = torch.zeros(B_test, L_test, dtype=torch.long, device=device)
    dummy_vis = torch.zeros(B_test, T, H, W, dtype=torch.long, device=device)
    with torch.no_grad():
        ids_test, rpos_test, _ = build_ursa_inputs(
            teacher, dummy_txt, dummy_vis, latents_shape, device)
        logits_full = teacher(ids_test, rope_pos=rpos_test).sample.float()

    D = logits_full.size(-1)                     # width of the logit (vocab) dim
    latent_shift = teacher.config.lm_vocab_size  # input token-ID shift for visual codes

    # ------------------------------------------------------------- patch 7
    print("\n[7] extract_visual_logits end-to-end alignment …")
    z_vis = extract_visual_logits(logits_full, N, K)
    assert z_vis.shape == (B_test, N, K), f"z_vis.shape={z_vis.shape}"

    if D >= latent_shift + K:
        # Full-vocab head: visual logits live at [latent_shift, latent_shift+K).
        # Causal shift: position i predicts token i+1, hence the -(N+1):-1 slice.
        z_seq = logits_full[:, -(N + 1) : -1]
        z_manual = z_seq[..., latent_shift : latent_shift + K]
        delta = (z_vis - z_manual).abs().max().item()
        print(f" [7] path=full-vocab D={D} latent_shift+K={latent_shift + K}")
        print(f" [7] z_vis.shape={z_vis.shape} max|z_vis - z_manual|={delta:.2e}")
        assert delta < 1e-5, (
            f"extract_visual_logits mismatch (full-vocab path): delta={delta:.2e}. "
            "The function should return logits[..., latent_shift:latent_shift+K]."
        )
        print("[7] extract_visual_logits alignment PASSED ✓")

    else:
        # Head narrower than latent_shift+K: latent_shift indexing cannot
        # apply on the logit dim; use the fallback conventions instead.
        z_seq = logits_full[:, -(N + 1) : -1]
        if D == K:
            # lm_head emits exactly K visual logits — no vocab-dim slicing.
            delta = (z_vis - z_seq).abs().max().item()
            print(
                f" [7] SKIP latent_shift formula: D={D} == K={K} "
                f"latent_shift={latent_shift}.\n"
                f" [7] Explanation: URSA lm_head outputs K visual logits directly.\n"
                f" [7] latent_shift={latent_shift} is the input token-ID shift "
                f"(raw_code + lm_vocab_size), NOT a logit-dim offset.\n"
                f" [7] extract_visual_logits happy-path: z = logits[:, -(N+1):-1] "
                f"(no vocab-dim slicing).\n"
                f" [7] Fallback check: z_vis == raw causal slice "
                f"max_delta={delta:.2e}"
            )
            assert delta < 1e-5, (
                f"z_vis != raw causal slice when D==K: delta={delta:.2e}"
            )
        else:
            # K < D < latent_shift+K: visual logits are the tail K entries.
            offset = D - K
            z_manual = z_seq[..., offset:]
            delta = (z_vis - z_manual).abs().max().item()
            print(
                f" [7] SKIP latent_shift formula: D={D} < latent_shift+K={latent_shift + K}.\n"
                f" [7] extract_visual_logits uses offset={offset} (D-K). "
                f"max_delta={delta:.2e}"
            )
            assert delta < 1e-5, (
                f"z_vis != z_seq[..., D-K:]: delta={delta:.2e}"
            )
        print("[7] extract_visual_logits alignment PASSED (fallback path) ✓")

    # ------------------------------------------------------------- patch 8
    print("\n[8] flex_attn semantics sanity …")
    fa = _probe_flex_attn(teacher)
    if fa is None or not hasattr(fa, "set_offsets_by_lens"):
        print(" [8] flex_attn.set_offsets_by_lens not available — skip")
        print("[8] flex_attn semantics sanity PASSED (skipped — no flex_attn) ✓")
    else:
        # One block spanning the whole packed sequence (text + N+1 visual).
        txt_block = L_test + (N + 1)
        block_lens = [txt_block]

        try:
            # Baseline: offsets=None (reset state) → standard causal attention.
            _reset_flex_attn(teacher, "teacher")
            with torch.no_grad():
                logits_A = teacher(ids_test, rope_pos=rpos_test).sample.float()
            z_A = extract_visual_logits(logits_A, N, K)

            # Variant: a single packed block via set_offsets_by_lens.
            fa.set_offsets_by_lens(block_lens)
            with torch.no_grad():
                logits_B = teacher(ids_test, rope_pos=rpos_test).sample.float()
            z_B = extract_visual_logits(logits_B, N, K)

            delta_mean = (z_A - z_B).abs().mean().item()
            delta_max = (z_A - z_B).abs().max().item()
            print(
                f" [8] offsets=None vs set_offsets_by_lens({block_lens}):\n"
                f" [8] mean_abs_delta={delta_mean:.4e} max_abs_delta={delta_max:.4e}"
            )
            if delta_mean > 1e-3:
                print(
                    f" [8] WARNING: mean_delta={delta_mean:.2e} > 1e-3.\n"
                    " [8] Single-block flex_attn uses FULL (bidirectional) attention\n"
                    " [8] inside the block, whereas offsets=None gives standard CAUSAL\n"
                    " [8] attention. This difference is EXPECTED — it confirms our\n"
                    " [8] training correctly uses offsets=None (no packed sequences)."
                )
            else:
                print(" [8] delta ≤ 1e-3: attention semantics equivalent for this input.")
            print("[8] flex_attn semantics sanity PASSED ✓")

        # NOTE: this previously caught (NotImplementedError, RuntimeError,
        # Exception) — both named types subclass Exception, so one broad
        # handler is equivalent. Broad by design: any runtime failure of the
        # optional flex_attn path is a skip, not a fatal error.
        except Exception as exc:
            print(f" [8] flex_attn runtime not available ({type(exc).__name__}: {exc}) — skip")
            print("[8] flex_attn semantics sanity PASSED (runtime skip) ✓")
        finally:
            # Always restore the default (offsets=None) attention state.
            _reset_flex_attn(teacher, "teacher")

    # ------------------------------------------------------------- patch 9
    print("\n[9] logp/token reshape consistency …")
    T9, H9, W9 = 3, 4, 5
    N9, B9 = T9 * H9 * W9, 1

    torch.manual_seed(99)
    z9 = torch.randn(B9, N9, K)
    p9 = F.softmax(z9 / 1.0, dim=-1)

    # Sample one token per position, then round-trip [B,N] ↔ [B,T,H,W].
    x_hat_flat = torch.multinomial(p9.view(-1, K), 1)
    x_hat_1d = x_hat_flat.view(B9, N9)
    x_hat_4d = x_hat_1d.view(B9, T9, H9, W9)

    x_hat_back = x_hat_4d.view(B9, N9)
    assert torch.equal(x_hat_1d, x_hat_back), (
        f"reshape round-trip FAILED: x_hat_1d != x_hat_4d.view(B,N)\n"
        f" x_hat_1d.shape={x_hat_1d.shape} x_hat_back.shape={x_hat_back.shape}"
    )

    # Gather per-token log-probs the same way the training step does.
    logp_all = (
        p9.clamp(1e-8).log()
        .gather(-1, x_hat_1d.unsqueeze(-1))
        .squeeze(-1)
    )
    logp_sum = logp_all.sum(-1)

    # Spot-check 10 random positions against a manually computed log-prob.
    torch.manual_seed(7)
    positions = torch.randperm(N9)[:10].tolist()
    for pos in positions:
        tok_id = x_hat_1d[0, pos].item()
        logp_man = math.log(max(p9[0, pos, tok_id].item(), 1e-8))
        logp_gat = logp_all[0, pos].item()
        diff = abs(logp_man - logp_gat)
        assert diff < 1e-6, (
            f"logp mismatch at pos={pos}, tok={tok_id}: "
            f"manual={logp_man:.8f} gathered={logp_gat:.8f} diff={diff:.2e}"
        )

    # Cross-check the reduced sum against an explicit per-sample sum.
    logp_sum_manual = logp_all[0].sum().item()
    assert abs(logp_sum.item() - logp_sum_manual) < 1e-5, \
        f"logp_sum mismatch: {logp_sum.item():.6f} vs {logp_sum_manual:.6f}"

    print(
        f" [9] T={T9},H={H9},W={W9} N={N9} K={K} "
        f"x_hat reshape round-trip ✓ "
        f"10 logp spot-checks (pos={positions}) ✓ "
        f"logp_sum={logp_sum.item():.3f}"
    )
    print("[9] logp/token reshape consistency PASSED ✓")

    print("\n" + "=" * 64)
    print("[patch 7/8/9] All 3 additional dry_run checks PASSED ✓")
    print("=" * 64)
|
|
|
|
|
# Script entry point: run main() (defined earlier in this file) only when
# executed directly, not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|