Code walkthrough (key design decisions)

`utils_ursa_inputs.py`

`build_ursa_inputs(transformer, txt_ids, visual_tokens, latents_shape, device)` reproduces the token-concatenation logic of `URSAPipeline.__call__` exactly:

```python
img_ids = pad(latents_flat + lm_vocab_size, (1, 0), value=bov_token_id)
input_ids = cat([txt_ids, img_ids], dim=1)
blk_pos = flex_rope.get_pos(latents_shape, L)
rope_pos = cat([txt_pos, blk_pos[0]]).unsqueeze(0).expand(B, -1, -1)
```

`extract_visual_logits(logits, N, K)` guards against pitfall 1. It first takes the causal slice `z = logits[:, -(N+1):-1]`. In a causal LM the logit at position i predicts token i+1, so the predictions for the N visual tokens start at the BOV position. It then slices the last dimension again only if that dimension is not already equal to K.

`sample_t_curriculum`: for the first 10k steps, t is drawn as t = 1-(1-u)^2, which biases it toward large values. After that, sampling reverts to uniform.

`train_onestep_ursa_dimo.py` training loop

Each step runs 9 stages that together implement the full DiMO pipeline from the paper:

| Stage | Operation | Gradient |
|---|---|---|
| 1-2 | Tokenize + sample x_init (80% uniform / 20% corrupt) | None |
| 3 | Student 1-step forward on x_init → x_hat, logp, H | ✅ student |
| 4 | add_noise(x_hat, t) → x_t | None (discrete sampling cuts the graph) |
| 5 | Teacher on x_t → p_T | None (no_grad) |
| 6 | Aux on x_t → Jeffreys(p_T, p_A) → backward → aux update | ✅ aux only |
| 7 | Student on x_t → KL(p_T ‖ p_S_t) | ✅ student |
| 8 | REINFORCE: r = -loss_aux, adv = r - EMA, loss_pg = -(adv · logp) | ✅ student (via logp) |
| 9 | L_s = λ_pg·loss_pg + λ_kd·loss_kd - λ_ent·H → student update | ✅ student |

Example run commands

End-to-end smoke test (single GPU, 17 frames at 256×256, 2000 steps):

```bash
python scripts/train_onestep_ursa_dimo.py \
  --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
  --prompt_file /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \
  --num_frames 17 --height 256 --width 256 \
  --batch_size 1 --num_steps 2000 \
  --log_every 50 --save_every 500 \
  --out_dir ./outputs/dimo_test
```

Evaluation (1-step student vs. 25-step teacher):

```bash
python scripts/eval_onestep_ursa.py \
  --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
  --student_ckpt ./outputs/dimo_test/final/student.pt \
  --num_frames 17 --height 256 --width 256 \
  --teacher_steps 25 \
  --out_dir ./outputs/eval
```

Scaling up to full resolution (49 frames at 320×512):

```bash
python scripts/train_onestep_ursa_dimo.py \
  --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
  --prompt_file /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \
  --num_frames 49 --height 320 --width 512 \
  --batch_size 2 --num_steps 50000 \
  --lambda_ent 0.01 --t_curriculum_steps 10000 \
  --mixed_precision bf16 --out_dir ./outputs/dimo_full
```

Three stability mechanisms (all three are required)

1. t curriculum: for the first 10k steps t is biased toward large values. There the teacher distribution is sharper, so the KD signal is stronger and the student does not drift randomly early in training.
2. p_init mixing: 20% of batches start from `corrupt(x_hat_prev, r=0.2)`, which teaches the student to "repair in one step".
3. Entropy regularization λ_ent: start at 0.01, and raise it to 0.05 if `tok_entropy` is observed to drop.

8-GPU launch command

```bash
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
  --machine_rank 0 --num_machines 1 --num_processes 8 \
  scripts/train_distill_dimo.py \
  config=./configs/distill_dimo.yaml \
  experiment.output_dir=./experiments/distill_dimo \
  distill.teacher_ckpt=/gfs/space/private/fengzl/World_Model/URSA-1.7B \
  distill.prompt_source=/gfs/space/private/fengzl/World_Model/Koala-36M-v1 \
  distill.batch_size_per_gpu=1
```

Smoke test (50 steps, saves a checkpoint)

```bash
accelerate launch --num_processes 8 --mixed_precision bf16 \
  scripts/train_distill_dimo.py \
  config="./configs/distill_dimo.yaml" \
  experiment.output_dir="./experiments/smoke" \
  distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \
  distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1" \
  training.max_train_steps=50 \
  experiment.save_every=50
```

Loading student.pt for 1-step inference

```python
import torch
from diffnext.pipelines import URSAPipeline

pipe = URSAPipeline.from_pretrained(
    "/path/to/URSA-1.7B-IBQ1024",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")

# Replace the transformer weights with the distilled student's.
state = torch.load(
    "experiments/distill_dimo/checkpoints/final/student.pt", map_location="cuda"
)
pipe.transformer.load_state_dict(state, strict=True)

# 1-step generation (num_inference_steps=1).
frames = pipe(
    prompt="a dog running on a beach",
    height=256,
    width=256,
    num_frames=17,
    num_inference_steps=1,
    guidance_scale=3.0,
).frames
```

Latest command (after changing the resolution and CFG)

```bash
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
  --machine_rank 0 --num_machines 1 --num_processes 8 \
  scripts/train_distill_dimo.py \
  config="./configs/distill_dimo.yaml" \
  experiment.output_dir="./experiments/distill_dimo" \
  distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \
  distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1"
```
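The block below sketches the token-id half of `build_ursa_inputs` from the code walkthrough, using NumPy. It is a minimal illustration rather than the repository's code. The name `build_input_ids` is hypothetical, and `latents` is assumed to be an integer array of codebook indices. The RoPE half (`flex_rope.get_pos`) is left out because its signature is internal to URSA.

```python
import numpy as np


def build_input_ids(txt_ids, latents, lm_vocab_size, bov_token_id):
    """Hypothetical sketch of the id half of build_ursa_inputs.

    txt_ids: (B, L) int array of text token ids.
    latents: (B, T, H, W) int array of visual codebook indices.
    Returns (B, L + 1 + T*H*W): text ids, then BOV, then shifted visual ids.
    """
    B = latents.shape[0]
    latents_flat = latents.reshape(B, -1)
    # Shift codes past the text vocabulary and prepend BOV; this mirrors
    # pad(latents_flat + lm_vocab_size, (1, 0), value=bov_token_id) in torch.
    img_ids = np.pad(latents_flat + lm_vocab_size, ((0, 0), (1, 0)),
                     constant_values=bov_token_id)
    return np.concatenate([txt_ids, img_ids], axis=1)


txt_ids = np.array([[1, 2, 3]])
latents = np.array([[[[0, 1], [2, 3]]]])  # B=1, T=1, H=2, W=2
print(build_input_ids(txt_ids, latents, lm_vocab_size=100, bov_token_id=99).tolist())
# [[1, 2, 3, 99, 100, 101, 102, 103]]
```

The vocabulary offset and the BOV token are the parts that must match the pipeline exactly. If they don't, the student is trained on ids that inference never produces.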
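The causal slice in `extract_visual_logits` can be sketched as follows. The function body is an assumption. In particular, when the last dimension is the full vocabulary rather than K, it assumes the K visual codes occupy `[lm_vocab_size, lm_vocab_size + K)`. That matches the `latents_flat + lm_vocab_size` offset used for `input_ids`, but the notes do not confirm it.

```python
import numpy as np


def extract_visual_logits(logits, N, K, lm_vocab_size=None):
    """Return the (B, N, K) logits that predict the last N (visual) tokens.

    Causal LM: the logit at position i predicts token i + 1, so the N visual
    tokens are predicted from positions S-N-1 .. S-2 (starting at BOV).
    """
    z = logits[:, -(N + 1):-1]
    if z.shape[-1] != K:
        # Assumption: the K visual codes occupy [lm_vocab_size, lm_vocab_size + K)
        # of the full vocabulary, matching the latents + lm_vocab_size id offset.
        if lm_vocab_size is None:
            raise ValueError("lm_vocab_size is required when the last dim != K")
        z = z[..., lm_vocab_size:lm_vocab_size + K]
    return z
```

For example, with S = 6 positions and N = 3, the slice keeps positions 2, 3 and 4. The last position is dropped because it would predict a token past the end of the sequence.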
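The `sample_t_curriculum` schedule above can be sketched in pure Python. The signature `(step, curriculum_steps, rng)` is hypothetical. The math works like this: with u uniform, P(t ≤ x) = 1 - √(1 - x). The density 1 / (2√(1 - t)) therefore grows with t, and the mean is 2/3, compared with 1/2 after the curriculum.

```python
import random


def sample_t_curriculum(step, curriculum_steps=10_000, rng=random):
    """Sample t in [0, 1]: biased toward large t during the curriculum, uniform after."""
    u = rng.random()
    if step < curriculum_steps:
        # P(t <= x) = 1 - sqrt(1 - x): density 1 / (2 sqrt(1 - t)) grows with t; E[t] = 2/3.
        return 1.0 - (1.0 - u) ** 2
    return u  # uniform, E[t] = 1/2
```

Passing a seeded `random.Random` as `rng` makes the schedule reproducible across runs.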
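Stage 3 of the training loop produces three quantities: `x_hat`, the sequence log-probability `logp` that REINFORCE needs, and the mean token entropy `H`. This sketch computes them with NumPy. It assumes `x_hat` is drawn by independent categorical sampling per token from the K-way visual logits. The helper `sample_with_logp_entropy` is hypothetical. In the real torch loop these values keep their autograd graph, so `logp` and `H` carry gradients back to the student.

```python
import numpy as np


def sample_with_logp_entropy(z, rng):
    """Sample one token per position and return (x_hat, logp, H).

    z:   (B, N, K) visual logits from the student's single forward pass.
    rng: a numpy.random.Generator.
    """
    # Log-softmax over the codebook axis, shifted by the max for stability.
    z = z - z.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    probs = np.exp(log_probs)
    K = probs.shape[-1]

    # Inverse-CDF categorical sampling: index of the first CDF entry >= u.
    cdf = probs.cumsum(axis=-1)
    u = rng.random(probs.shape[:-1] + (1,))
    x_hat = np.minimum((u > cdf).sum(axis=-1), K - 1)  # clamp guards float round-off

    # Sequence log-probability of the sample (sum over tokens), used by REINFORCE.
    logp = np.take_along_axis(log_probs, x_hat[..., None], axis=-1)[..., 0].sum(axis=-1)
    # Mean per-token entropy, the H term in the student loss.
    H = float(-(probs * log_probs).sum(axis=-1).mean())
    return x_hat, logp, H
```

With uniform logits, H equals log K. A student whose H collapses toward 0 has stopped exploring, which is what the λ_ent term pushes against.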
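The two divergences in stages 6 and 7 each take only a few lines over the vocabulary axis. Jeffreys divergence is the symmetrized KL, KL(p‖q) + KL(q‖p) = Σ (p - q)(log p - log q). This is a NumPy sketch on probability vectors. The `eps` clamp is an assumption added for numerical safety.

```python
import numpy as np


def kl(p, q, eps=1e-12):
    """KL(p || q) summed over the last (vocabulary) axis."""
    return (p * (np.log(p + eps) - np.log(q + eps))).sum(axis=-1)


def jeffreys(p, q, eps=1e-12):
    """Jeffreys divergence: KL(p || q) + KL(q || p), symmetric in p and q."""
    return kl(p, q, eps) + kl(q, p, eps)


# Stage 6 trains the aux model on jeffreys(p_T, p_A);
# stage 7 trains the student on kl(p_T, p_S_t).
```

Because the Jeffreys term is symmetric, it penalizes the aux model both for missing teacher modes and for putting mass where the teacher has none.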
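Stages 8 and 9 combine the REINFORCE term with the KD and entropy terms. The sketch below uses plain floats. The batch-mean reduction, the EMA `decay`, and the defaults λ_pg = λ_kd = 1.0 are hypothetical. Only λ_ent = 0.01 comes from the notes. In the torch version, `adv` must be treated as a constant (computed under `no_grad` or detached), so the policy gradient flows only through `logp`.

```python
def student_loss(loss_aux, logp, loss_kd, H, baseline,
                 lam_pg=1.0, lam_kd=1.0, lam_ent=0.01, decay=0.99):
    """Stages 8-9 with plain floats. Returns (L_s, updated EMA baseline).

    loss_aux: per-sample aux losses from stage 6 (reward is their negative).
    logp:     per-sample log-probabilities of the student's samples (stage 3).
    loss_kd:  KL(p_T || p_S_t) from stage 7; H: token entropy from stage 3.
    """
    n = len(logp)
    rewards = [-l for l in loss_aux]                          # r = -loss_aux
    adv = [r - baseline for r in rewards]                     # adv = r - EMA (no gradient)
    loss_pg = -sum(a * lp for a, lp in zip(adv, logp)) / n    # -(adv · logp), batch mean
    L_s = lam_pg * loss_pg + lam_kd * loss_kd - lam_ent * H
    new_baseline = decay * baseline + (1 - decay) * sum(rewards) / n
    return L_s, new_baseline


L_s, b = student_loss([0.5, 1.5], [-2.0, -4.0], loss_kd=0.3, H=2.0,
                      baseline=-1.0, decay=0.9)
# L_s is about -0.22; b stays about -1.0
```

Subtracting the EMA baseline does not change the expected gradient. It only reduces its variance, which matters with one sample per prompt.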
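The notes do not define `corrupt()` for p_init mixing. The sketch below assumes it resamples each token position with probability r, drawing uniformly from the K-entry codebook. It also assumes the 80/20 decision is made once per batch. Both `corrupt` and `sample_x_init` are hypothetical helpers written for illustration.

```python
import numpy as np


def corrupt(x, r, K, rng):
    """Hypothetical corrupt(): resample each position with probability r, uniformly from [0, K)."""
    mask = rng.random(x.shape) < r
    noise = rng.integers(0, K, size=x.shape)
    return np.where(mask, noise, x)


def sample_x_init(shape, K, rng, x_hat_prev=None, p_corrupt=0.2, r=0.2):
    """Per batch: 20% corrupt(x_hat_prev, r), otherwise uniform random tokens.

    Returns (x_init, kind) where kind is "corrupt" or "uniform".
    """
    if x_hat_prev is not None and rng.random() < p_corrupt:
        return corrupt(x_hat_prev, r, K, rng), "corrupt"
    return rng.integers(0, K, size=shape), "uniform"
```

Before any previous student output exists (the first step), this sketch always falls back to uniform tokens.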
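The notes only say to raise λ_ent from 0.01 to 0.05 when `tok_entropy` drops. They do not say how a drop is detected. The controller below is a hypothetical choice: it keeps an EMA of the logged entropy and switches once, permanently, when that EMA falls more than `rel_drop` below the first logged value. The class name and the thresholds are assumptions.

```python
class EntropyLambdaSchedule:
    """Hypothetical lambda_ent controller: 0.01 by default, raised to 0.05 (and kept
    there) once the EMA of tok_entropy falls rel_drop below its first logged value."""

    def __init__(self, lam_init=0.01, lam_high=0.05, rel_drop=0.1, decay=0.9):
        self.lam = lam_init
        self.lam_high = lam_high
        self.rel_drop = rel_drop
        self.decay = decay
        self.ref = None  # first observed entropy, the reference level
        self.ema = None

    def update(self, tok_entropy):
        if self.ema is None:
            self.ref = self.ema = tok_entropy
        else:
            self.ema = self.decay * self.ema + (1 - self.decay) * tok_entropy
        if self.ema < (1 - self.rel_drop) * self.ref:
            self.lam = self.lam_high  # one-way switch
        return self.lam
```

Smoothing with an EMA keeps a single noisy batch from triggering the switch. The one-way switch avoids oscillating between the two values.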