Skip to content

feature(xjy): Refine PriorZero Implementation#441

Open
xiongjyu wants to merge 187 commits into
opendilab:dev-multitask-balance-clean-rftfrom
xiongjyu:dev-multitask-balance-clean-rft
Open

feature(xjy): Refine PriorZero Implementation#441
xiongjyu wants to merge 187 commits into
opendilab:dev-multitask-balance-clean-rftfrom
xiongjyu:dev-multitask-balance-clean-rft

Conversation

@xiongjyu

@xiongjyu xiongjyu commented Nov 20, 2025

Copy link
Copy Markdown
Collaborator

这个 PR 主要完善了 PriorZero的实现与开发流程,修复了若干影响训练正确性和稳定性的关键问题,并对训练逻辑、损失计算、数据采集进行了系统性的增强。

本 PR 已完成的工作
• 修复了 PriorZero 训练流程中的多个关键 bug,包括 game segment 构建、loss 计算、log-prob 对齐以及 action 处理中的错误。
• 完善了 REINFORCE / RFT 风格的策略优化实现,在 buffer 中正确存储并使用 old_logprob,保证策略更新的正确性。
• 补充并规范了训练过程中的统计指标,包括 KL divergence、policy entropy 等,用于更好地监控训练状态。
• 优化了 Collector 与 Replay Buffer 的数据流转逻辑,提升数据一致性与采样稳定性,减少隐式错误。
• 引入并验证了单卡场景下的 vLLM 权重同步机制。
• 多 GPU / 多节点场景下的 vLLM 权重同步与稳定性验证

Comment thread zoo/jericho/priorzero/priorzero_policy.py Outdated
Comment thread zoo/jericho/priorzero/priorzero_policy.py Outdated
@xiongjyu xiongjyu deleted the branch opendilab:dev-multitask-balance-clean-rft November 24, 2025 14:28
@xiongjyu xiongjyu closed this Nov 24, 2025
@xiongjyu xiongjyu deleted the dev-multitask-balance-clean-rft branch November 24, 2025 14:28
@xiongjyu xiongjyu reopened this Nov 24, 2025
@puyuan1996 puyuan1996 added the research Research work in progress label Nov 28, 2025
for i in range(num_engines):
bundle_indices = None
if tensor_parallel_size > 1:
bundle_indices = get_bundle_indices(shared_pg, i, tensor_parallel_size)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是参考的ray官方改进吗

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个vllm_engine基本和openrlhf这部分是一样的;不过目前只使用一个vllm,并且tensor_parallel_size =1;因为显存够

Comment thread zoo/jericho/priorzero/vllm_utils/vllm_engine_ray.py Outdated
…ple for world-model training; train LLM only on latest trajectories
xiongjyu and others added 30 commits April 4, 2026 17:52
Docs: Expand PriorZero (Jericho) README with detailed configuration and usage
…rft' into dev-multitask-balance-clean-rft-vl
…evels, per-level eval logging

  - Switch from single-task to multi-task training/eval across all 40 BabyAI levels
  - Add per-level TensorBoard logging in evaluator (WM+LLMPrior and LLMPrior modes)
  - Run initial evaluation before training loop starts
  - Align hyperparams: max_steps=20, prompt_max_len=512, model=qwen2.5-7b
  - Increase eval intervals (wm: 2000, llm: 200) for 40-level multi-task

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…Inter-RL levels

- Add 1D→2D unsqueeze guard in tokenizer and HFLanguageRepresentationNetwork
  to prevent BERT ValueError when batch dim is missing during evaluation
- Correct task list from 40 to 18 levels based on HF AgentGym-RL-Data-ID dataset
  (levels: 1-11, 19-21, 30-31, 33, 36)
- Increase evaluator_env_num from 2 to 8 for ~4x faster multi-task eval

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ajectory dump

  - Add _should_continue_collect + drain_vllm_iter to PriorZeroCollector to
    fix the vLLM-TP deadlock that hung _sync_prompts_for_tp when DDP ranks
    finished episodes at different step counts (mirrors the existing eval fix)
  - Add _sync_prompts_for_tp / drain_vllm_iter in DataProcessor: all_gather
    prompts over the TP subgroup so partners submit matched vllm.generate
    calls; required for any vllm_tensor_parallel_size > 1 with DDP > 1
  - Introduce setup_priorzero_logging with priorzero.main/train/eval loggers
    (file + rank-0 console, NullHandler elsewhere); replace scattered
    loguru/print calls in entry, trainer, datafactory, vllm worker
  - Evaluator: dist.barrier between WM and WM_LLMPrior eval, save per-episode
    eval trajectories as JSON, extend per-level TB stats with mean/std/min/max
  - MuZero evaluator: add total_finishes hard counter to prevent eval hang
    when episodes are unevenly distributed across envs
  - Guards: empty-batch return in UniZeroPolicy._forward_eval, zero-length
    input return in HFLanguageRepresentationNetwork
  - BabyAI: bump prompt_max_len 512→4096 to fit obs budget, align
    evaluator_env_num with eval level count

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…size floor, Fix dual CUDA OOM in train→vLLM handoff and large-batch forward pass

  - Remove `max(micro_train_batch_size, 32)` floor in PolicyModel.forward and
    ReferenceModel.forward; use micro_train_batch_size directly (2 vs 32),
    reducing per-chunk logits from ~37 GiB to ~2.3 GiB for Qwen2.5-7B
  - Remove redundant batch-level .to(device) before chunking loop to avoid
    duplicating full batch on GPU alongside per-chunk slices
  - Reduce default micro_train_batch_size from 4 to 2 for 7B model headroom

  - Reorder offload_states() before _broadcast_to_vllm() in train_batch so
    DeepSpeed optimizer states are freed to CPU before vLLM wake_up() reclaims
    ~43 GiB for weights+KV cache (was OOM at cumem_allocator.cpp:62)
  - Defer logits.to(float32) to training path only (return_entropy=True),
    avoiding a 42 GiB fp32 allocation during old_action_log_probs forward;
    bf16 path in log_probs_from_logits already handles numerical stability
  - Reduce micro_train_batch_size 4→2 in BabyAI config

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…aseline config

  - Fix eval_only_llm_prior episode refill: with env_num=4 and n_episode=18,
    finished envs were never re-added, so only 4 of 18 levels were evaluated.
    Add refill logic mirroring eval_with_llm_prior to cover all levels.
  - Reduce format_weight 0.5→0.1 to prevent constant-positive advantage bias
    from always-correct format rewards causing response length collapse (88→10)
  - Increase rft_kl_coef 0.001→0.01 and entropy_loss_coef 0→0.01 to anchor
    policy against drift and entropy collapse
  - Accept **kwargs in GameSegment.append() so env-specific fields like
    raw_obs_text pass through without TypeError
  - Add BabyAI UniZero baseline config (no LLM) with WM hyperparameters
    aligned to PriorZero for fair ablation comparison

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…or, harden RFT training

  - Refactor PriorZero evaluator TB logging: unified x-axis (env_step) for
    all three eval modes (wm_only, wm_llm, llm_only) with new hierarchical
    tags eval/{mode}/agg/* and eval/{mode}/per_level/*; old tags preserved
    under deprecated/ prefix for transition
  - Add env_step-based eval frequency (wm_eval_freq_envsteps,
    llm_eval_freq_envsteps) to both BabyAI and Jericho configs, with
    iter-based fallback when set to 0
  - Remove phase guard on eval_wm_only and eval_llm_only so all three
    modes run in every phase, enabling consistent cross-phase comparison
  - Add eval_wm_only() method to PriorZero evaluator with per-level
    tracking (previously delegated to super().eval() without level info)
  - Create MuZeroPerLevelEvaluator for UniZero baseline with dual x-axis
    TB logging (env_step primary, train_iter secondary) and per-level
    reward breakdown matching PriorZero tag structure
  - Add KL early stopping in BatchPPOTrainer: skip remaining gradient
    updates when ref_kl exceeds kl_early_stop_threshold (default 0 =
    disabled)
  - Replace hard assert with warning + filter in DataProcessor for
    tokenizer round-trip mismatches to avoid training crashes
  - Tune BabyAI RFT hyperparams: format_weight 0.1→0.3, rft_kl_coef
    0.01→0.1, entropy_loss_coef 0.01→0.001, replay_buffer 300k→500k,
    collect_num_simulations 50→25

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…-babyai-textcraft

feature(pu): add babyai and textcraft env and configs for priorzero
feature(pu): add image-based/vlm version of priorzero
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

research Research work in progress

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants