diff --git a/docs/design/sep_code.md b/docs/design/sep_code.md
new file mode 100644
index 0000000000..5a462f70e8
--- /dev/null
+++ b/docs/design/sep_code.md
@@ -0,0 +1,492 @@
+# Design: Splitting Colocated and Disaggregated Production Code
+
+## 1. Background
+
+`AgentLoopManager` currently serves two production modes:
+
+- Colocated training: a single `produce_batch()` call runs rollout production, drains pending tasks, and takes the training batch from the replay buffer.
+- Disaggregated training: a **Background Producer** keeps writing to the replay buffer while a foreground **Training Consumer** consumes through `get_batch()`, switching state across **Expired Produce Batch**, weight sync, evaluation, and checkpointing.
+
+Both modes share one `AgentLoopManager`, one `ProduceProgress`, and one `AsyncProduceStrategy` implementation. As a result:
+
+- The colocated path has to understand disaggregated state such as `_status / _update_event / _model_step / _produce_progress`.
+- Changes to the disaggregated path can easily alter the synchronous behavior of the colocated `produce_batch()`.
+- The pending tasks of `AsyncProduceStrategy` are treated both as state local to a single call and as cross-call background state for disaggregated training.
+
+This design splits the production-side code so that colocated and disaggregated production each get their own **Module**, their own **Interface**, and their own state. It keeps the asynchronous production capability of `AsyncProduceStrategyConfig` for colocated training and uses `DisaggAsyncProduceStrategyConfig` to express the disaggregated background producer explicitly.
+
+## 2. Goals
+
+1. Changes to colocated production do not affect disaggregated production.
+2. Changes to the disaggregated **Background Producer** / **Training Consumer** state machine do not affect the colocated `produce_batch()`.
+3. Colocated and disaggregated training use different strategy configs: `AsyncProduceStrategyConfig` builds the colocated `AsyncProduceStrategy`, and `DisaggAsyncProduceStrategyConfig` builds the disaggregated `DisaggAsyncProduceStrategy`.
+4. Colocated async production stays simple: pending tasks belong to a single `AgentLoopManager.produce_batch()` call and are not kept across manager calls.
+5. Disaggregated async production keeps background pending tasks, pause/continue, Expired Produce Batch, and checkpoint/resume.
+
+## 3. Non-goals
+
+- Using one `AsyncProduceStrategyConfig(...)` for both colocated and disaggregated training is not supported; disaggregated training configs must explicitly use `DisaggAsyncProduceStrategyConfig(...)`.
+- The domain semantics of the replay buffer do not change.
+- The colocated path does not gain the disaggregated state machine.
+- Not every shared helper becomes a public interface; shared logic can stay inside the manager package as private Implementation.
+
+## 4. Overall design
+
+Split the current wide `AgentLoopManager` into two manager **Modules**:
+
+- `AgentLoopManager`
+- `DisaggAgentLoopManager`
+
+Split the current `ProduceProgress` into two progress **Modules**:
+
+- `ProduceProgress`
+- `DisaggProduceProgress`
+
+Split the current full `AsyncProduceStrategy` into two concrete strategy **Adapters**:
+
+- `AsyncProduceStrategy`
+- `DisaggAsyncProduceStrategy`
+
+Colocated and disaggregated training build their concrete Adapter from different configs, and `mode` is not passed to the strategy config's `build(...)`:
+
+```python
+AsyncProduceStrategyConfig(...).build(...)
+# -> AsyncProduceStrategy
+
+DisaggAsyncProduceStrategyConfig(...).build(...)
+# -> DisaggAsyncProduceStrategy
+```
+
+In other words, the split is along execution mode; colocated async is not removed.
+
+Design constraints:
+
+- `AsyncProduceStrategy` and `DisaggAsyncProduceStrategy` do not inherit from a common parent class. Each holds its config fields explicitly, and the small amount of shared algorithm is expressed as module-level helper functions.
+- The colocated and disaggregated strategy **Interfaces** are separate: `ProduceStrategy` defines the colocated `produce_batch(ctx)` and `pause_produce(ctx)`; `DisaggProduceStrategy` defines the disaggregated `produce_batch(ctx)`, `pause_produce(ctx)`, and `pending_task_count()`. The colocated `pause_produce(ctx)` only drains the pending tasks of the current manager call and carries no disaggregated checkpoint / update-event semantics.
+- `AgentLoopManager` and `DisaggAgentLoopManager` do not inherit from a common parent class. Shared logic such as task batch allocation, staleness refresh, take batch, and result aggregation is expressed as module-level helper functions.
+- The critical ordering of `pause_produce` and the pending drain protocol must reuse the semantics of the current production code. The core drain protocol is extracted into `pause_pending_tasks(...)` rather than hidden in a manager parent class or an async strategy parent class.
+- Every Config exposes a single `build(...)`; `AgentLoopManagerConfig.mode` only selects the manager type, and the strategy type is determined by the concrete config type under `ProduceStrategyConfig` / `DisaggProduceStrategyConfig`.
+- No single-task / multi-task manager is added, and this refactor adds no private single/multi branches; the current task batch allocation and result aggregation logic is reused.
+
+## 5. Module responsibilities
+
+| Module | Interface | Implementation |
+| --- | --- | --- |
+| `AgentLoopManagerConfig` | `build(...)` | Builds the task runner, sampler, agent loop, and manager according to `mode`, and validates the strategy config type |
+| `AgentLoopManager` | `produce_batch(batch_size, train_step, model_step)` | Colocated single-call production, local drain, and taking the training batch |
+| `DisaggAgentLoopManager` | `produce_loop`, `get_batch`, `pause_produce`, `continue_produce`, `shutdown` | Disaggregated background production and consumption state machine |
+| `ProduceProgress` | `build`, `add_raw_rewards`, `add_produced`, `add_produce_time` | Window for a single colocated production call; not checkpointed |
+| `DisaggProduceProgress` | `ensure_target_upto`, `begin_consume`, `mark_consumed`, `state_dict` | Disaggregated absolute cumulative target/consumed counts and resume state |
+| `ProduceStrategy` | `produce_batch(ctx)`, `pause_produce(ctx)` | Abstract interface base class for colocated strategies; accepts only `ProduceContext` |
+| `DisaggProduceStrategy` | `produce_batch(ctx)`, `pause_produce(ctx)`, `pending_task_count()` | Abstract interface base class for disaggregated strategies; accepts only `DisaggProduceContext` |
+| `AsyncProduceStrategy` | `ProduceStrategy` | Holds the local pending set of the current manager call; `produce_batch` only produces, and only `pause_produce` drains |
+| `DisaggAsyncProduceStrategy` | `DisaggProduceStrategy` | Keeps `_PendingTasks` across calls and handles update events and model expiry |
+
+Suggested shared helpers:
+
+| Helper | Purpose |
+| --- | --- |
+| `allocate_task_batch_sizes(...)` | Reuses the current logic that allocates batch sizes by task weight |
+| `validate_task_batch_sizes(...)` | Reuses batch size validation |
+| `refresh_for_all_tasks(...)` | Reuses the completed / aborted staleness refresh |
+| `take_train_batch(...)` | Reuses replay buffer take, consumed accounting, leftover statistics, and result aggregation |
+| `pause_pending_tasks(...)` | Reuses the pending task pause / drain / cancel protocol |
+
+These helpers are Implementation reuse, not a new business **Interface**. Callers still see only the mode-specific manager and strategy.
+
+## 6. Config build rules
+
+`TaskSpecConfig.produce_strategy_config` accepts two kinds of config:
+
+- Colocated: `ProduceStrategyConfig`, for example `SyncProduceStrategyConfig` / `AsyncProduceStrategyConfig`.
+- Disaggregated: `DisaggProduceStrategyConfig`, for example `DisaggAsyncProduceStrategyConfig`.
+
+Every Config keeps a single `build(...)`. There are no mode-specific wrappers such as `build_colocate(...)` / `build_disaggregated(...)`, and `mode` is not passed to the strategy config's `build(...)`:
+
+```python
+class AgentLoopManagerConfig:
+    mode: Literal["colocate", "disaggregated"] = "colocate"
+
+    def build(self, ...):
+        if self.mode == "colocate":
+            assert isinstance(task_cfg.produce_strategy_config, ProduceStrategyConfig)
+        if self.mode == "disaggregated":
+            assert isinstance(task_cfg.produce_strategy_config, DisaggProduceStrategyConfig)
+
+        strategy = task_cfg.produce_strategy_config.build(
+            sync_weights_interval=sync_weights_interval,
+            rollout_controller=rollout_controller,
+        )
+        if self.mode == "colocate":
+            return AgentLoopManager(task_runners, replay_buffer, logger)
+        if self.mode == "disaggregated":
+            return DisaggAgentLoopManager(task_runners, replay_buffer, logger)
+```
+
+`SyncProduceStrategyConfig` only builds the plain `SyncProduceStrategy`. The disaggregated training producer has only the disaggregated async config, so `AgentLoopManagerConfig(mode="disaggregated").build(...)` should fail fast when it meets a sync or colocated async config:
+
+```python
+class SyncProduceStrategyConfig:
+    def build(self, *, sync_weights_interval, rollout_controller):
+        return SyncProduceStrategy(...)
+```
+
+`AsyncProduceStrategyConfig` only builds the colocated async strategy:
+
+```python
+class AsyncProduceStrategyConfig:
+    def build(self, *, sync_weights_interval, rollout_controller):
+        return AsyncProduceStrategy(...)
+```
+
+`DisaggAsyncProduceStrategyConfig` only builds the disaggregated background async strategy:
+
+```python
+class DisaggAsyncProduceStrategyConfig:
+    def build(self, *, sync_weights_interval, rollout_controller):
+        return DisaggAsyncProduceStrategy(...)
+```
+
+Disaggregated evaluation is not a background producer and does not build a `DisaggProduceStrategy`:
+
+```python
+eval_agent_loop_manager_cfg = cfg.eval_agent_loop_manager_cfg.model_copy(update={"mode": "colocate"})
+self.eval_agent_loop_manager = eval_agent_loop_manager_cfg.build(...)
+# eval tasks can keep using SyncProduceStrategyConfig -> SyncProduceStrategy
+```
+
+The value of this build boundary: the colocated manager receives a `ProduceStrategy`, the disaggregated manager receives a `DisaggProduceStrategy`, and the config type itself expresses the execution environment. The two strategy **Interfaces** differ in name and method signatures, so disaggregated pending / checkpoint semantics neither leak into the colocated strategy nor hide in a `mode` branch of the strategy config.
+
+### 6.1 Strategy Context
+
+Strategy methods take only a mode-specific context; `Progress` is not passed as a loose second argument:
+
+```python
+class BaseProduceContext:
+    ...
+
+class ProduceStrategy:
+    async def produce_batch(self, ctx: ProduceContext): ...
+    async def pause_produce(self, ctx: ProduceContext): ...
+
+class DisaggProduceStrategy:
+    async def produce_batch(self, ctx: DisaggProduceContext): ...
+    async def pause_produce(self, ctx: DisaggProduceContext): ...
+```
+
+The reason: `Progress` is still held by the context with its current internal field layout, but it should not become a generic second parameter in strategy method signatures:
+
+- `BaseProduceContext` keeps the internal field layout of the current `ProduceContext`, for example `task_batch_size`, `progress`, and `stale_threshold`, along with the `sample_group()` / `generate_group()` / `put_generated_group()` behavior.
+- `ProduceContext` is the simplified colocated version. It only drops what disaggregated training needs (`update_event`, absolute consumed/target access, and checkpoint semantics) and does not restructure raw rewards / produced samples / produce time into a `metrics` field.
+- `DisaggProduceContext` inherits from `BaseProduceContext` and additionally exposes `update_event`, `available_count()`, `target_abs`, and `DisaggProduceProgress`.
+- This keeps the simple shape of the original `SyncProduceStrategy.produce_batch(ctx)` instead of changing it to `produce_batch(ctx, progress)`.
+
+## 7. Colocated production flow
+
+The colocated path allows exactly one public entry point:
+
+```python
+await manager.produce_batch(batch_size, train_step, model_step=model_step)
+```
+
+Flow:
+
+1. Compute the task batch sizes from `train_step`.
+2. Create a `ProduceProgress`.
+3. Call `continue_generation()` to enter the rollout phase.
+4. Each task concurrently calls `produce_batch(ctx)` on its strategy, which only produces into the replay buffer.
+5. After `produce_batch(ctx)` has returned for every active task, the manager calls `pause_produce(ctx)` one task at a time, and each strategy reuses `pause_pending_tasks(...)` internally to drain this call's pending tasks.
+6. Take completed rollout groups from the replay buffer.
+7. Call `pause_generation()` to return to the idle state.
+8. Return a non-empty `ProduceBatchResult`.
+
+Key invariant: in colocated multi-task production, a task that reaches its target first must not call `pause_pending_tasks(...)` as soon as its own `produce_batch(ctx)` finishes. Otherwise it would send pause to the rollout workers early and block the tasks that are still producing. The pending drain must be orchestrated by `AgentLoopManager.produce_batch()` after all tasks have finished producing.
+
+Exception semantics: the manager proceeds to pause/drain and take batch only after `asyncio.gather(...)` has returned normally for every task's `produce_batch(ctx)`. If any task raises, the manager does not catch the exception, does not convert it into a `ProduceBatchStatus`, and does no best-effort cleanup. The original exception propagates straight to the trainer and aborts training, which avoids a secondary exception in a `finally` block masking the real failure stack.
+
+Business constraint: a single colocated `AgentLoopManager` instance does not support concurrent `produce_batch()` calls. `AsyncProduceStrategy` holds the local pending set of the current manager call; this constraint is guaranteed by the trainer's calling model, and the manager adds no elaborate defenses.
+
+The colocated path never contains:
+
+- `_status`
+- `_update_event`
+- `_finish_event`
+- `DisaggProduceProgress`
+- `_PendingTasks`
+- `produce_loop`
+- `get_batch`
+- `continue_produce`
+
+## 8. Disaggregated production flow
+
+The disaggregated path is driven by two cooperating public entry points:
+
+```python
+producer_task = asyncio.create_task(manager.produce_loop(batch_size))
+get_batch_task = asyncio.create_task(manager.get_batch(batch_size, train_step=train_step))
+done, _ = await asyncio.wait({producer_task, get_batch_task}, return_when=asyncio.FIRST_COMPLETED)
+if producer_task in done:
+    producer_task.result()
+produce_result = get_batch_task.result()
+```
+
+`DisaggAgentLoopManager` exclusively owns the following state:
+
+- `status`
+- `update_event`
+- `finish_event`
+- `model_step`
+- `pause_time_s`
+- `DisaggProduceProgress`
+
+Core invariants:
+
+- The **Background Producer** advances `producer_future_step` only in the `NORMAL` state.
+- The **Training Consumer** advances `consumed_samples` and `next_consumer_step` after it successfully takes a non-empty batch.
+- An **Expired Produce Batch** may return an empty batch and skip training only when the training side already has a newer **Model Step**.
+- `pause_produce()` must be called before weight sync, and `continue_produce(model_step=...)` must be called after sync/evaluation.
+- A **Background Producer** exception is a terminal training failure and is not converted into a manager status. While waiting on `get_batch()`, the trainer must also watch `producer_task` and use `producer_task.result()` to surface the original exception stack and abort training.
+- The disaggregated exception path does no best-effort cleanup either; `shutdown()` is called explicitly, waiting for the background producer to exit, only when training ends normally.
+
+## 9. Splitting the async strategy
+
+The full implementation of the old `AsyncProduceStrategy` is split into two concrete Adapters, each built by its own config.
+
+### 9.1 `AsyncProduceStrategy`
+
+Responsibilities:
+
+- Hold a local `pending_tasks = set()` for the duration of the current manager `produce_batch()` call.
+- Schedule rollout groups according to the `over_sample_threshold`, tail batch, and partial rollout rules.
+- Keep the production budget semantics of the current async producer: in normal mode the oversample budget is `ceil(over_sample_threshold * task_batch_size)`; in tail-batch mode it samples from the expired / aborted pool and no longer widens the oversample window.
+- After receiving a finished result, filter it, write it to the replay buffer, and update this call's statistics fields.
+- Return once this call's batch target is reached; do not pause the agent loop inside `produce_batch(ctx)`.
+- `pause_produce(ctx)` reuses `pause_pending_tasks(...)` to drain this call's pending tasks; only the manager may call it, after `produce_batch(ctx)` has returned for all tasks.
+
+It is not responsible for:
+
+- Keeping pending tasks across calls.
+- Observing `update_event`.
+- Returning `UPDATE_WEIGHT_AND_ABORT`.
+- Maintaining the `model_step` state machine.
+- Checkpointing pending tasks.
+- Inheriting from a common async parent class.
+
+### 9.2 `DisaggAsyncProduceStrategy`
+
+Responsibilities:
+
+- Hold `_PendingTasks`, allowing pending tasks to live across multiple `produce_batch()` calls.
+- Observe `ctx.should_abort()`.
+- Decide **Expired Produce Batch** from `model_step / producer_future_step`.
+- Keep the production budget semantics of the current async producer: in normal mode the oversample budget is `ceil(over_sample_threshold * task_batch_size)`; in tail-batch mode it samples from the expired / aborted pool and no longer widens the oversample window.
+- `pause_produce()` drains or cancels pending tasks.
+- Provide `pending_task_count()` for checkpointing.
+
+It is not responsible for:
+
+- Taking training batches from the replay buffer.
+- Advancing the consumer step of `DisaggProduceProgress`.
+- Triggering weight sync.
+- Inheriting from a common async parent class.
+
+### 9.3 pause pending helper
+
+The current `pause_produce` has two layers:
+
+1. Manager layer: set the pause signal first, switch the manager state, then pause the rollout controller.
+2. Strategy layer: if pending tasks remain, periodically send the agent loop pause, claim finished tasks and store them, and cancel the remaining pending tasks once the timeout is exceeded.
+
+The split keeps this order but extracts the strategy-layer pending drain into a module-level helper. On the colocated path the manager layer sets no disaggregated update-event/status and is only responsible for the ordering "drain once, after all tasks have finished producing"; on the disaggregated path the manager layer still sets the pause signal and state:
+
+```python
+async def pause_pending_tasks(
+    *,
+    pending_tasks: set[asyncio.Task] | _PendingTasks,
+    ctx,
+    put_claimed_task,
+) -> float:
+    if isinstance(pending_tasks, set):
+        pending = _LocalPendingTasks(pending_tasks)
+    else:
+        pending = pending_tasks
+
+    if pending.count() == 0:
+        return 0.0
+
+    pending_pause_tasks = {create_task(request_agent_loop_pause(ctx))}
+    deadline = now() + PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S
+    next_periodic_pause = now() + PERIODIC_ABORT_INTERVAL_S
+
+    while pending.count() > 0:
+        if now() > deadline:
+            await pending.cancel_all()
+            break
+
+        if now() >= next_periodic_pause:
+            pending_pause_tasks.add(create_task(request_agent_loop_pause(ctx)))
+            next_periodic_pause += PERIODIC_ABORT_INTERVAL_S
+
+        claimed = await pending.wait_and_claim(timeout_s=1)
+        for task in claimed:
+            await put_claimed_task(task)
+
+    await cancel_and_drain(pending_pause_tasks)
+    return elapsed()
+```
+
+The colocated path passes this call's local `set[Task]` straight to the helper, which wraps it in `_LocalPendingTasks` internally; the disaggregated path passes `_PendingTasks` directly. The pause protocol is shared, but the pending lifetimes stay independent:
+
+- Colocated: the pending lifetime equals one `produce_batch()` call.
+- Disaggregated: the pending lifetime spans multiple background `produce_batch()` calls.
+
+## 10. Splitting Progress
+
+### 10.1 `ProduceProgress`
+
+The only construction entry point is `build(...)`:
+
+```python
+ProduceProgress.build(
+    task_names=task_names,
+    target_samples=task_batch_sizes,
+    train_step=train_step,
+)
+```
+
+Fields:
+
+- `producer_future_step`
+- `target_samples`
+- Statistics for this call: raw reward / produced samples / produced tokens / produce time
+
+Characteristics:
+
+- Not saved to checkpoints.
+- Does not maintain `next_consumer_step / consumed_samples / target_upto_future_step`.
+- Adds no `model_step` field; `model_step` remains a runtime argument passed in when the manager builds the `ProduceContext`.
+- Does not rename the current `target_samples` to `task_batch_sizes`. On the colocated path `target_samples` is the local target of this `produce_batch()` call, not the disaggregated absolute cumulative target.
+- Does not maintain the advancing semantics of the disaggregated background producer; `producer_future_step` is only the field written for this call's staleness / future step.
+- Does not expose `state_dict()`.
+
+### 10.2 `DisaggProduceProgress`
+
+Fields:
+
+- `producer_future_step`
+- `next_consumer_step`
+- `target_samples`
+- `consumed_samples`
+- `target_upto_future_step`
+- Background producer statistics fields
+
+Characteristics:
+
+- `target_samples / consumed_samples` use absolute cumulative counts.
+- `state_dict / load_state_dict` are part of the disaggregated checkpoint.
+- The producer and the consumer share the same object reference.
+
+## 11. ReplayBuffer stays shared
+
+The replay buffer is a genuinely shared deep **Module** and is not split by colocated/disaggregated mode. It provides:
+
+- `put(...)`
+- `refresh_staleness(...)`
+- `is_ready(...)`
+- `take_batch(...)`
+- `count_statuses(...)`
+
+Why it is shared:
+
+- Both colocated and disaggregated training need to store results and take completed rollout groups.
+- The replay buffer does not understand the manager state machine.
+- The replay buffer's **Interface** already expresses storage / replay ordering behavior well enough.
+
+## 12. Trainer integration
+
+Colocated trainer:
+
+```python
+cfg.agent_loop_manager_cfg.mode = "colocate"
+# task.produce_strategy_config = SyncProduceStrategyConfig(...) or AsyncProduceStrategyConfig(...)
+self.agent_loop_manager = cfg.agent_loop_manager_cfg.build(...)
+```
+
+Disaggregated trainer:
+
+```python
+cfg.agent_loop_manager_cfg.mode = "disaggregated"
+# task.produce_strategy_config = DisaggAsyncProduceStrategyConfig(...)
+self.agent_loop_manager = cfg.agent_loop_manager_cfg.build(...)
+```
+
+The evaluation manager should always be a colocated manager:
+
+```python
+cfg.eval_agent_loop_manager_cfg.mode = "colocate"
+self.eval_agent_loop_manager = cfg.eval_agent_loop_manager_cfg.build(...)
+```
+
+Reason: evaluation is a one-shot `produce_batch()`, not a **Background Producer**.
+
+## 13. Migration steps
+
+1. Express `AgentLoopManagerConfig.mode` as `Literal["colocate", "disaggregated"]`; `AgentLoopManagerConfig` keeps only `build(...)`.
+2. Add `AgentLoopManager` and move the colocated logic of the current `produce_batch()` into it.
+3. Add `DisaggAgentLoopManager` and move `produce_loop/get_batch/pause/continue/shutdown/save/resume` into it.
+4. Split out `ProduceProgress` and `DisaggProduceProgress`; `ProduceProgress` keeps `build(...)` as its only construction entry point.
+5. Split the current `AsyncProduceStrategy` into `AsyncProduceStrategy` and `DisaggAsyncProduceStrategy`, and add `DisaggAsyncProduceStrategyConfig`.
+6. Extract batch allocation, refresh, take batch, and pause pending drain into module-level helpers.
+7. The trainer sets `mode` on the manager config and then calls the same `build(...)`; disaggregated training configs switch to `DisaggAsyncProduceStrategyConfig` at the same time.
+8. Keep the necessary compatibility exports, but not the compatibility semantics of "one strategy config switched by mode".
+
+## 14. Testing suggestions
+
+Colocated manager:
+
+- A manager built by `AgentLoopManagerConfig(mode="colocate").build(...)` returns a non-empty training batch through the public `produce_batch(...)`.
+- Colocated multi-task `produce_batch(...)` reliably returns training data according to task weights.
+- In colocated multi-task async production, a task that finishes first does not pause the rollout workers early; pending tasks are drained together only after all active tasks finish producing.
+- After a colocated async `produce_batch(...)` returns, the next call still produces normally and is unaffected by the previous pending drain.
+
+Disaggregated manager:
+
+- A manager built by `AgentLoopManagerConfig(mode="disaggregated").build(...)` returns training batches through the cooperating public `produce_loop(...)` / `get_batch(...)`.
+- The public behavior of `produce_loop/get_batch/pause_produce` is the same for single-task and multi-task, which avoids duplicating the background state machine.
+- `produce_loop/get_batch` still handle empty and non-empty **Expired Produce Batch**.
+- The `pause_produce/continue_produce` order is unchanged.
+- After checkpoint/resume, the public `get_batch(...)` / `continue_produce(...)` continue from the producer progress and model step saved earlier.
+
+Strategies:
+
+- `AsyncProduceStrategy`: cover oversample, partial rollout, tail batch, and the result of this call's pending drain through the public `produce_batch(ctx)` / `pause_produce(ctx)`.
+- `DisaggAsyncProduceStrategy`: cover cross-call pending, abort, expired, and fail fast before checkpoint through the public `produce_batch(ctx)` / `pause_produce(ctx)` / `pending_task_count()`.
+
+Trainer:
+
+- The colocated trainer depends only on `produce_batch()`.
+- The disaggregated trainer depends only on `produce_loop/get_batch/pause/continue/shutdown`.
+- The eval manager of the disaggregated trainer is a colocate manager; after the initial evaluate, the producer is resumed as disaggregated training requires.
+
+## 15. Key judgment
+
+The domain meaning of `AsyncProduceStrategy` is not "disaggregated strategy" but "colocated asynchronous rollout production strategy". The disaggregated background producer needs an explicit `DisaggAsyncProduceStrategyConfig` / `DisaggAsyncProduceStrategy`.
+
+What really needs isolating is the execution environment:
+
+- Colocated execution environment: local pending tasks, completed within a single call.
+- Disaggregated execution environment: background pending tasks, a cross-call state machine.
+
+So the final code shape should be:
+
+```python
+AsyncProduceStrategyConfig
+  -> AsyncProduceStrategy
+
+DisaggAsyncProduceStrategyConfig
+  -> DisaggAsyncProduceStrategy
+```
+
+rather than letting `AsyncProduceStrategyConfig` keep recognizing both the colocated and disaggregated execution protocols internally.
+
+The latter would leave one Config aware of two execution protocols; the complexity only moves and does not provide enough **Locality**.
diff --git a/docs/design/sep_code_demo.py b/docs/design/sep_code_demo.py
new file mode 100644
index 0000000000..4debd27c39
--- /dev/null
+++ b/docs/design/sep_code_demo.py
@@ -0,0 +1,1472 @@
+"""Pseudocode for splitting colocated and disaggregated production code.
+
+Notes:
+- This is design pseudocode that shows how the Modules, Interfaces, and Adapters relate; it is not a runnable implementation.
+- The focus is separating colocated synchronous production from the disaggregated Background Producer / Training Consumer.
+- The colocated AsyncProduceStrategyConfig and the disaggregated DisaggAsyncProduceStrategyConfig are different config types;
+  strategy config.build(...) does not switch on mode.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import math
+import time
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum, auto
+from pathlib import Path
+from typing import Any, Awaitable, Callable, Literal, Protocol, TypeAlias
+
+
+class Status(Enum):
+    INIT = auto()
+    COMPLETED = auto()
+    ABORTED = auto()
+    EXPIRED = auto()
+    FAILED = auto()
+    FILTERED = auto()
+
+
+class ProduceBatchStatus(Enum):
+    NORMAL = auto()
+    UPDATE_WEIGHT_AND_ABORT = auto()
+    EXPIRED_BATCH = auto()
+
+
+ProducerMode: TypeAlias = Literal["colocate", "disaggregated"]
+
+
+class DisaggManagerStatus(Enum):
+    NORMAL = auto()
+    UPDATE_WEIGHT_AND_ABORT = auto()
+    EXPIRED_BATCH = auto()
+    FINISH = auto()
+
+
+def get_group_status(group: list[Any]) -> Status:
+    """Aggregate the status of a rollout group.
+
+    This only reads status and never modifies samples. Filtering and expiry flips must happen in explicit business logic.
+    """
+
+    ...
+
+
+def calculate_seq_staleness(model_step: int, train_step: int) -> int:
+    ...
+
+
+AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S = 10.0
+PERIODIC_ABORT_INTERVAL_S = 5.0
+PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S = 60.0
+
+
+def calculate_stale_threshold(max_staleness: int, sync_weights_interval: int) -> int:
+    return (max_staleness + 1) * sync_weights_interval
+
+
+@dataclass
+class ProduceBatchResult:
+    rollout_states: list[list[Any]]
+    status: ProduceBatchStatus = ProduceBatchStatus.NORMAL
+    group_gen_count: int | None = None
+    group_gen_mean_s: float | None = None
+    group_gen_p50_s: float | None = None
+    group_gen_p99_s: float | None = None
+    group_gen_p99_p50_ratio: float | None = None
+    group_gen_pause_time_s: float | None = None
+    leftover_init: int = 0
+    leftover_completed: int = 0
+    leftover_aborted: int = 0
+    leftover_expired: int = 0
+    leftover_failed: int = 0
+    leftover_filtered: int = 0
+    raw_rewards_sum: float = 0.0
+    raw_rewards_count: int = 0
+    produced_samples: int = 0
+    produced_tokens: int = 0
+    produce_time_s: float = 0.0
+    task_batch_sizes: dict[str, int] | None = None
+    task_results: dict[str, "ProduceBatchResult"] | None = None
+
+
+class ReplayBuffer(Protocol):
+    async def put(
+        self,
+        group: list[Any],
+        task_name: str,
+        *,
+        model_step: int | None = None,
+        current_train_step: int | None = None,
+        stale_threshold: int | None = None,
+    ) -> None: ...
+
+    async def count(self, task_name: str, group_status: Status) -> int: ...
+
+    async def refresh_staleness(
+        self,
+        *,
+        task_stale_thresholds: dict[str, int],
+        current_train_step: int,
+        statuses: list[Status],
+    ) -> dict[str, int]: ...
+
+    async def is_ready(self, task_batch_sizes: dict[str, int]) -> bool: ...
+
+    async def take_batch(
+        self,
+        task_batch_sizes: dict[str, int],
+    ) -> tuple[dict[str, list[list[Any]]], dict[str, int]]: ...
+
+    async def count_statuses(
+        self,
+        task_names: list[str],
+        statuses: list[Status],
+    ) -> dict[str, dict[Status, int]]: ...
+
+    async def save(self, checkpoint_path: Path) -> None: ...
+ + async def resume(self, checkpoint_path: Path) -> None: ... + + +class Sampler(Protocol): + async def sample( + self, + *, + task_name: str, + group_status: Status | list[Status] | None = None, + ) -> list[Any]: ... + + def save(self, checkpoint_path: Path) -> None: ... + + def resume(self, checkpoint_path: Path) -> None: ... + + +class AgentLoop(Protocol): + async def generate_group( + self, + group: list[Any], + *, + enable_partial_rollout: bool = False, + ) -> list[Any]: ... + + async def pause(self) -> None: ... + + +class RolloutController(Protocol): + async def continue_generation(self) -> None: ... + + async def pause_generation(self) -> None: ... + + +class ShouldContinueFn(Protocol): + def __call__(self, completed_count: int, batch_size: int, **kwargs: Any) -> bool: ... + + +class IsValidSampleFn(Protocol): + def __call__(self, samples: list[Any]) -> bool: ... + + +def default_should_continue_fn(completed_count: int, batch_size: int, **kwargs: Any) -> bool: + return completed_count < batch_size + + +def default_is_valid_sample_fn(samples: list[Any]) -> bool: + return True + + +@dataclass +class ProduceProgress: + """共卡单次 produce_batch 的局部进度。 + + 中文不变量: + - 只表达本次调用,不进入 checkpoint。 + - pending task 由具体 strategy 在本次调用内持有。 + - 裁剪非共卡需要的 producer_future_step / next_consumer_step / consumed_samples / target_upto_future_step / state_dict。 + - 不新增 model_step,model_step 仍由 manager 放进 ProduceContext。 + """ + + target_samples: dict[str, int] + raw_rewards_sum: dict[str, float] = field(default_factory=dict) + raw_rewards_count: dict[str, int] = field(default_factory=dict) + produced_samples: dict[str, int] = field(default_factory=dict) + produced_tokens: dict[str, int] = field(default_factory=dict) + produce_time_s: float = 0.0 + + @classmethod + def build( + cls, + *, + task_names: list[str], + target_samples: dict[str, int], + ) -> "ProduceProgress": + return cls( + target_samples=dict(target_samples), + raw_rewards_sum={name: 0.0 for name in task_names}, + 
raw_rewards_count={name: 0 for name in task_names}, + produced_samples={name: 0 for name in task_names}, + produced_tokens={name: 0 for name in task_names}, + ) + + def add_raw_rewards(self, task_name: str, rewards_sum: float, rewards_count: int) -> None: + self.raw_rewards_sum[task_name] += rewards_sum + self.raw_rewards_count[task_name] += rewards_count + + def add_produced(self, task_name: str, samples: int, tokens: int) -> None: + self.produced_samples[task_name] += samples + self.produced_tokens[task_name] += tokens + + def add_produce_time(self, elapsed_s: float) -> None: + self.produce_time_s += elapsed_s + + +@dataclass +class DisaggProduceProgress: + """非共卡 Background Producer / Training Consumer 共享进度。 + + 中文不变量: + - target_samples / consumed_samples 使用绝对累计口径。 + - consumer 从 replay buffer 取走样本后只增加 consumed,不回退 target。 + - producer_future_step 只由后台 producer 正常完成生产后推进。 + - 该对象会进入 checkpoint/resume。 + """ + + task_names: list[str] + producer_future_step: int = 1 + next_consumer_step: int = 1 + target_upto_future_step: int = 0 + consumed_samples: dict[str, int] = field(default_factory=dict) + target_samples: dict[str, int] = field(default_factory=dict) + raw_rewards_sum: dict[str, float] = field(default_factory=dict) + raw_rewards_count: dict[str, int] = field(default_factory=dict) + produced_samples: dict[str, int] = field(default_factory=dict) + produced_tokens: dict[str, int] = field(default_factory=dict) + produce_time_s: float = 0.0 + + @classmethod + def build(cls, task_names: list[str]) -> "DisaggProduceProgress": + return cls( + task_names=task_names, + consumed_samples={name: 0 for name in task_names}, + target_samples={name: 0 for name in task_names}, + raw_rewards_sum={name: 0.0 for name in task_names}, + raw_rewards_count={name: 0 for name in task_names}, + produced_samples={name: 0 for name in task_names}, + produced_tokens={name: 0 for name in task_names}, + ) + + def ensure_target_upto( + self, + *, + batch_size: int, + future_step: int, + 
allocate_batch_sizes: Callable[[int, int], dict[str, int]], + ) -> dict[str, int]: + if future_step > self.target_upto_future_step: + for step in range(self.target_upto_future_step + 1, future_step + 1): + task_sizes = allocate_batch_sizes(batch_size, step) + for task_name, task_size in task_sizes.items(): + self.target_samples[task_name] += task_size + self.target_upto_future_step = future_step + return allocate_batch_sizes(batch_size, future_step) + + def begin_consume(self, train_step: int) -> None: + self.next_consumer_step = train_step + + def mark_consumed(self, consumed_counts: dict[str, int]) -> None: + for task_name, count in consumed_counts.items(): + self.consumed_samples[task_name] += count + + def finish_consume(self, train_step: int) -> None: + self.next_consumer_step = train_step + 1 + + def advance_future_step(self) -> None: + self.producer_future_step += 1 + + def add_raw_rewards(self, task_name: str, rewards_sum: float, rewards_count: int) -> None: + self.raw_rewards_sum[task_name] += rewards_sum + self.raw_rewards_count[task_name] += rewards_count + + def add_produced(self, task_name: str, samples: int, tokens: int) -> None: + self.produced_samples[task_name] += samples + self.produced_tokens[task_name] += tokens + + def add_produce_time(self, elapsed_s: float) -> None: + self.produce_time_s += elapsed_s + + def state_dict(self) -> dict[str, Any]: + return { + "producer_future_step": self.producer_future_step, + "next_consumer_step": self.next_consumer_step, + "target_upto_future_step": self.target_upto_future_step, + "consumed_samples": dict(self.consumed_samples), + "target_samples": dict(self.target_samples), + "raw_rewards_sum": dict(self.raw_rewards_sum), + "raw_rewards_count": dict(self.raw_rewards_count), + "produced_samples": dict(self.produced_samples), + "produced_tokens": dict(self.produced_tokens), + "produce_time_s": self.produce_time_s, + } + + def load_state_dict(self, state: dict[str, Any]) -> None: + self.producer_future_step = 
state["producer_future_step"] + self.next_consumer_step = state["next_consumer_step"] + self.target_upto_future_step = state["target_upto_future_step"] + self.consumed_samples.clear() + self.consumed_samples.update(state["consumed_samples"]) + self.target_samples.clear() + self.target_samples.update(state["target_samples"]) + self.raw_rewards_sum.clear() + self.raw_rewards_sum.update(state.get("raw_rewards_sum", {})) + self.raw_rewards_count.clear() + self.raw_rewards_count.update(state.get("raw_rewards_count", {})) + self.produced_samples.clear() + self.produced_samples.update(state.get("produced_samples", {})) + self.produced_tokens.clear() + self.produced_tokens.update(state.get("produced_tokens", {})) + self.produce_time_s = state.get("produce_time_s", 0.0) + + +@dataclass +class BaseProduceContext: + """strategy 生产一个 task 时看到的公共上下文。 + + 共卡和非共卡共享生成、采样、入库能力;具体 target / abort 语义由子类表达。 + """ + + task_name: str + agent_loop: AgentLoop + sampler: Sampler + replay_buffer: ReplayBuffer + task_batch_size: int + train_step: int + model_step: int + progress: ProduceProgress | DisaggProduceProgress + is_valid_sample_fn: IsValidSampleFn + stale_threshold: int | None + + @property + def current_train_step_for_staleness(self) -> int: + return self.train_step + + async def sample_group(self, *, from_expired_pool: bool) -> list[Any]: + statuses = [Status.EXPIRED, Status.ABORTED] if from_expired_pool else [Status.ABORTED] + return await self.sampler.sample(task_name=self.task_name, group_status=statuses) + + async def expired_count(self) -> int: + return await self.replay_buffer.count(self.task_name, Status.EXPIRED) + + async def generate_group( + self, + group: list[Any], + *, + enable_partial_rollout: bool, + ) -> list[Any]: + start = time.perf_counter() + result = await self.agent_loop.generate_group( + group, + enable_partial_rollout=enable_partial_rollout, + ) + self.progress.add_produce_time(time.perf_counter() - start) + return result + + async def 
put_generated_group(self, group: list[Any]) -> bool: + """统一处理生成结果过滤、统计和入库。 + + 中文设计点: + - 只有 completed group 才执行业务过滤。 + - ReplayBuffer.put 负责写 model_step、刷新 staleness、按阈值转 expired。 + - put 之后重新判断 group 状态,因为 completed 可能在入库前被转成 expired。 + """ + + is_completed = get_group_status(group) == Status.COMPLETED + produced_tokens = sum(len(getattr(item, "response_ids", []) or []) for item in group) + if is_completed: + # 真实实现按当前字段结构写 raw_rewards_sum/raw_rewards_count,不把这些字段重构成 metrics 对象。 + self.progress.add_raw_rewards(self.task_name, rewards_sum=0.0, rewards_count=0) + if not self.is_valid_sample_fn(group): + for item in group: + item.status = Status.FILTERED + + await self.replay_buffer.put( + group, + self.task_name, + model_step=self.model_step, + current_train_step=self.current_train_step_for_staleness, + stale_threshold=self.stale_threshold, + ) + self.progress.add_produced(self.task_name, samples=len(group), tokens=produced_tokens) + return get_group_status(group) == Status.COMPLETED + + +@dataclass +class ProduceContext(BaseProduceContext): + """共卡生产 context。 + + 中文设计点: + - 去掉非共卡 update_event / absolute consumed / checkpoint progress 语义。 + - 保留当前 ProduceContext 内部字段结构:task_batch_size、progress、stale_threshold 等仍按原形状传递。 + """ + + def should_abort(self) -> bool: + return False + + +@dataclass +class DisaggProduceContext(BaseProduceContext): + update_event: asyncio.Event + progress: DisaggProduceProgress + + @property + def current_train_step_for_staleness(self) -> int: + return self.progress.next_consumer_step + + def should_abort(self) -> bool: + return self.update_event.is_set() + + async def available_count(self) -> int: + completed = await self.replay_buffer.count(self.task_name, Status.COMPLETED) + return self.progress.consumed_samples[self.task_name] + completed + + @property + def target_abs(self) -> int: + return self.progress.target_samples[self.task_name] + + +class ProduceStrategy(ABC): + @abstractmethod + async def produce_batch(self, ctx: 
ProduceContext) -> ProduceBatchStatus: ... + + async def pause_produce(self, ctx: ProduceContext) -> float: + return 0.0 + + def is_model_expired(self, train_step: int, model_step: int) -> bool: + return False + + +class DisaggProduceStrategy(ABC): + @abstractmethod + async def produce_batch(self, ctx: DisaggProduceContext) -> ProduceBatchStatus: ... + + async def pause_produce(self, ctx: DisaggProduceContext) -> float: + return 0.0 + + def is_model_expired(self, train_step: int, model_step: int) -> bool: + return False + + def pending_task_count(self) -> int: + return 0 + + +ModeSpecificProduceStrategy = ProduceStrategy | DisaggProduceStrategy + + +class ProduceStrategyConfig(Protocol): + def build( + self, + *, + sync_weights_interval: int, + rollout_controller: RolloutController, + ) -> ProduceStrategy: ... + + +class DisaggProduceStrategyConfig(Protocol): + def build( + self, + *, + sync_weights_interval: int, + rollout_controller: RolloutController, + ) -> DisaggProduceStrategy: ... 
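The two config Protocols above declare the same `build(...)` member. If the mode check is written as `isinstance(cfg, ProduceStrategyConfig)`, Python raises `TypeError` for a `Protocol` that is not `@runtime_checkable`, and a runtime-checkable Protocol only compares member names, so it could not tell the two config families apart. The following is a minimal sketch of one way to keep the fail-fast check meaningful, assuming nominal marker base classes. `StructuralConfig`, `ColocateConfigBase`, `DisaggConfigBase`, `DemoAsyncConfig`, `DemoDisaggAsyncConfig`, and `check_strategy_config` are hypothetical names used only for illustration and are not part of this patch.

```python
from typing import Literal, Protocol, runtime_checkable


@runtime_checkable
class StructuralConfig(Protocol):
    """Hypothetical structural stand-in: any object with build() matches."""

    def build(self) -> object: ...


class ColocateConfigBase:
    """Hypothetical nominal marker for colocated strategy configs."""


class DisaggConfigBase:
    """Hypothetical nominal marker for disaggregated strategy configs."""


class DemoAsyncConfig(ColocateConfigBase):
    def build(self) -> object:
        return "AsyncProduceStrategy"


class DemoDisaggAsyncConfig(DisaggConfigBase):
    def build(self) -> object:
        return "DisaggAsyncProduceStrategy"


def check_strategy_config(mode: Literal["colocate", "disaggregated"], cfg: object) -> None:
    """Raise TypeError when the config family does not match the manager mode."""
    expected = ColocateConfigBase if mode == "colocate" else DisaggConfigBase
    if not isinstance(cfg, expected):
        raise TypeError(
            f"mode={mode!r} requires {expected.__name__}, got {type(cfg).__name__}"
        )


# The structural check accepts both configs, so it cannot separate the modes.
both_match = isinstance(DemoAsyncConfig(), StructuralConfig) and isinstance(
    DemoDisaggAsyncConfig(), StructuralConfig
)

# The nominal check accepts the matching family and rejects the other one.
check_strategy_config("colocate", DemoAsyncConfig())
try:
    check_strategy_config("disaggregated", DemoAsyncConfig())
    rejected = False
except TypeError:
    rejected = True
```

Raising `TypeError` instead of using `assert` is a choice made in this sketch so that the check also runs under `python -O`; whether the real `AgentLoopManagerConfig.build(...)` should do the same is left to the patch author.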
+ + +@dataclass +class SyncProduceStrategyConfig: + is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn + should_continue_fn: ShouldContinueFn = default_should_continue_fn + + def build( + self, + *, + sync_weights_interval: int, + rollout_controller: RolloutController, + ) -> ProduceStrategy: + return SyncProduceStrategy( + is_valid_sample_fn=self.is_valid_sample_fn, + should_continue_fn=self.should_continue_fn, + ) + + +@dataclass +class AsyncProduceStrategyConfig: + over_sample_threshold: float = 0.0 + enable_partial_rollout: bool = False + max_staleness: int = 0 + tail_batch_trigger_size: int = 0 + is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn + should_continue_fn: ShouldContinueFn = default_should_continue_fn + + def build( + self, + *, + sync_weights_interval: int, + rollout_controller: RolloutController, + ) -> ProduceStrategy: + return AsyncProduceStrategy( + over_sample_threshold=self.over_sample_threshold, + enable_partial_rollout=self.enable_partial_rollout, + max_staleness=self.max_staleness, + sync_weights_interval=sync_weights_interval, + tail_batch_trigger_size=self.tail_batch_trigger_size, + is_valid_sample_fn=self.is_valid_sample_fn, + should_continue_fn=self.should_continue_fn, + ) + + +@dataclass +class DisaggAsyncProduceStrategyConfig: + over_sample_threshold: float = 0.0 + enable_partial_rollout: bool = False + max_staleness: int = 0 + tail_batch_trigger_size: int = 0 + is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn + should_continue_fn: ShouldContinueFn = default_should_continue_fn + + def build( + self, + *, + sync_weights_interval: int, + rollout_controller: RolloutController, + ) -> DisaggProduceStrategy: + return DisaggAsyncProduceStrategy( + over_sample_threshold=self.over_sample_threshold, + enable_partial_rollout=self.enable_partial_rollout, + max_staleness=self.max_staleness, + sync_weights_interval=sync_weights_interval, + tail_batch_trigger_size=self.tail_batch_trigger_size, + 
is_valid_sample_fn=self.is_valid_sample_fn, + should_continue_fn=self.should_continue_fn, + ) + + +class SyncProduceStrategy(ProduceStrategy): + def __init__( + self, + *, + is_valid_sample_fn: IsValidSampleFn, + should_continue_fn: ShouldContinueFn, + ) -> None: + self.is_valid_sample_fn = is_valid_sample_fn + self.should_continue_fn = should_continue_fn + + async def produce_batch(self, ctx: ProduceContext) -> ProduceBatchStatus: + pending: set[asyncio.Task] = set() + completed = await ctx.replay_buffer.count(ctx.task_name, Status.COMPLETED) + + for _ in range(ctx.task_batch_size): + group = await ctx.sampler.sample(task_name=ctx.task_name) + pending.add(asyncio.create_task(ctx.generate_group(group, enable_partial_rollout=False))) + + while self.should_continue_fn(completed, ctx.task_batch_size): + if not pending: + break + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + group = task.result() + if await ctx.put_generated_group(group): + completed += 1 + + while len(pending) + completed < ctx.task_batch_size and self.should_continue_fn( + completed, + ctx.task_batch_size, + ): + group = await ctx.sampler.sample(task_name=ctx.task_name) + pending.add(asyncio.create_task(ctx.generate_group(group, enable_partial_rollout=False))) + + return ProduceBatchStatus.NORMAL + + +class AsyncProduceStrategy(ProduceStrategy): + def __init__( + self, + *, + over_sample_threshold: float, + enable_partial_rollout: bool, + max_staleness: int, + sync_weights_interval: int, + tail_batch_trigger_size: int, + is_valid_sample_fn: IsValidSampleFn, + should_continue_fn: ShouldContinueFn, + ) -> None: + self.over_sample_threshold = over_sample_threshold + self.enable_partial_rollout = enable_partial_rollout + self.max_staleness = max_staleness + self.sync_weights_interval = sync_weights_interval + self.tail_batch_trigger_size = tail_batch_trigger_size + self.is_valid_sample_fn = is_valid_sample_fn + self.should_continue_fn = 
should_continue_fn + self.stale_threshold = calculate_stale_threshold(max_staleness, sync_weights_interval) + self._pending_tasks: set[asyncio.Task] = set() + + def is_model_expired(self, train_step: int, model_step: int) -> bool: + return calculate_seq_staleness(model_step, train_step) >= self.stale_threshold + + async def produce_batch( + self, + ctx: ProduceContext, + ) -> ProduceBatchStatus: + """Async production for colocated training. + + Invariants: + - pending belongs only to the current manager.produce_batch call and is not kept across manager calls. + - This function only produces into the replay buffer; it does not pause/drain here. + - The manager must wait until produce_batch has returned for every task before calling pause_produce to wind down pending. + - Does not read update_event and never returns UPDATE_WEIGHT_AND_ABORT. + """ + + self._pending_tasks = set() + expired_count = await ctx.expired_count() + sample_from_expired = self.tail_batch_trigger_size > 0 and expired_count >= self.tail_batch_trigger_size + + # Keep the current implementation's semantics: normal mode only adds a fixed oversample budget based on this task's batch size; + # tail-batch mode only fills the necessary gap and always samples from the expired/aborted pool. + oversample_budget = 0 if sample_from_expired else math.ceil(self.over_sample_threshold * ctx.task_batch_size) + scheduled_target = ctx.task_batch_size + oversample_budget + completed = await ctx.replay_buffer.count(ctx.task_name, Status.COMPLETED) + + async def schedule_one() -> None: + group = await ctx.sample_group(from_expired_pool=sample_from_expired) + self._pending_tasks.add( + asyncio.create_task( + ctx.generate_group( + group, + enable_partial_rollout=self.enable_partial_rollout, + ) + ) + ) + + while len(self._pending_tasks) + completed < scheduled_target: + await schedule_one() + + while self.should_continue_fn(completed, ctx.task_batch_size): + if not self._pending_tasks: + break + done, pending = await asyncio.wait(self._pending_tasks, return_when=asyncio.FIRST_COMPLETED) + self._pending_tasks = pending + for task in done: + if await ctx.put_generated_group(task.result()): + completed += 1 + + while len(self._pending_tasks) + completed < scheduled_target and self.should_continue_fn( + completed, ctx.task_batch_size +
): + await schedule_one() + + return ProduceBatchStatus.NORMAL + + async def pause_produce(self, ctx: ProduceContext) -> float: + pending_tasks = self._pending_tasks + self._pending_tasks = set() + return await pause_pending_tasks( + pending_tasks=pending_tasks, + ctx=ctx, + put_claimed_task=lambda task: ctx.put_generated_group(task.result()), + ) + + +class _LocalPendingTasks: + """Wrap the local set of the current colocated call into the shape the pause helper expects.""" + + def __init__(self, tasks: set[asyncio.Task]) -> None: + self._tasks = tasks + + def count(self) -> int: + return len(self._tasks) + + async def wait_and_claim(self, timeout_s: float) -> set[asyncio.Task]: + if not self._tasks: + return set() + done, _ = await asyncio.wait(self._tasks, timeout=timeout_s, return_when=asyncio.FIRST_COMPLETED) + self._tasks.difference_update(done) + return done + + async def cancel_all(self) -> int: + tasks = set(self._tasks) + self._tasks.clear() + for task in tasks: + task.cancel() + return len(tasks) + + +class _PendingTasks: + """Pending set used only by disaggregated training. + + Colocated training does not use it, because colocated pending tasks do not outlive a produce_batch call. + """ + + def __init__(self) -> None: + self._tasks: set[asyncio.Task] = set() + self._lock = asyncio.Lock() + + def count(self) -> int: + return len(self._tasks) + + async def claim_ready(self) -> set[asyncio.Task]: + async with self._lock: + ready = {task for task in self._tasks if task.done()} + self._tasks.difference_update(ready) + return ready + + async def schedule_one( + self, + *, + max_pending: int, + should_abort: Callable[[], bool], + spawn_one: Callable[[], Awaitable[asyncio.Task]], + ) -> bool: + async with self._lock: + if should_abort() or len(self._tasks) >= max_pending: + return False + self._tasks.add(await spawn_one()) + return True + + async def wait_and_claim(self, timeout_s: float) -> set[asyncio.Task]: + async with self._lock: + snapshot = set(self._tasks) + if not snapshot: + return set() + done, _ = await asyncio.wait(snapshot, timeout=timeout_s, return_when=asyncio.FIRST_COMPLETED) + async
with self._lock: + claimed = done & self._tasks + self._tasks.difference_update(claimed) + return claimed + + async def cancel_all(self) -> int: + async with self._lock: + tasks = set(self._tasks) + self._tasks.clear() + for task in tasks: + task.cancel() + return len(tasks) + + +PendingTasksInput = set[asyncio.Task] | _PendingTasks + + +async def request_agent_loop_pause(ctx: BaseProduceContext, *, pending_count: int) -> None: + """Send one agent loop pause request. + + In the latest production code, pause_produce periodically calls agent_loop.pause(). This protocol is extracted into a module-level helper here + so that the colocated local pending wind-down and the disaggregated background pending drain share the same timeout/logging semantics. + """ + + try: + await asyncio.wait_for(ctx.agent_loop.pause(), timeout=AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S) + except asyncio.TimeoutError: + # The real implementation calls logger.warning; the pseudocode keeps only the key context. + print( + f"Agent loop pause timed out: task={ctx.task_name}, " + f"timeout_s={AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S}, pending={pending_count}" + ) + except Exception: + print(f"Agent loop pause failed: task={ctx.task_name}, pending={pending_count}") + + +async def pause_pending_tasks( + *, + pending_tasks: PendingTasksInput, + ctx: BaseProduceContext, + put_claimed_task: Callable[[asyncio.Task], Awaitable[Any]], +) -> float: + """Reuse the pending drain protocol of the current pause_produce. + + Invariants: + - Send pause first, then wait for pending tasks to produce output. + - While pending is not empty, re-send pause periodically to tolerate lost or delayed abort signals in the backend. + - After the timeout, cancel the remaining pending tasks so that no task still writes to the buffer before checkpoint/save. + - A finished task must be claimed before it is put, so that produce and pause never store the same done task twice. + """ + + pending = _LocalPendingTasks(pending_tasks) if isinstance(pending_tasks, set) else pending_tasks + pause_start = time.perf_counter() + if pending.count() == 0: + return 0.0 + + pending_pause_tasks = { + asyncio.create_task(request_agent_loop_pause(ctx, pending_count=pending.count())) + } + cleanup_start_time = time.perf_counter() + next_periodic_abort_time = cleanup_start_time + PERIODIC_ABORT_INTERVAL_S + + while True: + elapsed_time = time.perf_counter() - cleanup_start_time + if elapsed_time > PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S: +
cancelled_count = await pending.cancel_all() + print( + f"Cleanup timeout reached. Forcefully cancelling {cancelled_count} " + f"remaining tasks for task={ctx.task_name}." + ) + break + + if pending.count() == 0: + break + + current_time = time.perf_counter() + pending_pause_tasks = {task for task in pending_pause_tasks if not task.done()} + if PERIODIC_ABORT_INTERVAL_S > 0 and current_time >= next_periodic_abort_time: + pending_pause_tasks.add( + asyncio.create_task(request_agent_loop_pause(ctx, pending_count=pending.count())) + ) + next_periodic_abort_time += PERIODIC_ABORT_INTERVAL_S + + claimed_done = await pending.wait_and_claim(timeout_s=1.0) + for task in claimed_done: + await put_claimed_task(task) + + for task in pending_pause_tasks: + task.cancel() + if pending_pause_tasks: + await asyncio.gather(*pending_pause_tasks, return_exceptions=True) + + return time.perf_counter() - pause_start + + +class DisaggAsyncProduceStrategy(DisaggProduceStrategy): + def __init__( + self, + *, + over_sample_threshold: float, + enable_partial_rollout: bool, + max_staleness: int, + sync_weights_interval: int, + tail_batch_trigger_size: int, + is_valid_sample_fn: IsValidSampleFn, + should_continue_fn: ShouldContinueFn, + ) -> None: + self.over_sample_threshold = over_sample_threshold + self.enable_partial_rollout = enable_partial_rollout + self.max_staleness = max_staleness + self.sync_weights_interval = sync_weights_interval + self.tail_batch_trigger_size = tail_batch_trigger_size + self.is_valid_sample_fn = is_valid_sample_fn + self.should_continue_fn = should_continue_fn + self.stale_threshold = calculate_stale_threshold(max_staleness, sync_weights_interval) + self._pending_tasks = _PendingTasks() + + def is_model_expired(self, train_step: int, model_step: int) -> bool: + return calculate_seq_staleness(model_step, train_step) >= self.stale_threshold + + def pending_task_count(self) -> int: + return self._pending_tasks.count() + + async def produce_batch( + self, + ctx: 
DisaggProduceContext, + ) -> ProduceBatchStatus: + """Background async production for disaggregated training. + + Invariants: + - pending may persist across multiple produce_batch calls. + - Every loop iteration checks update_event and model expiry. + - Only produces into the replay buffer; it does not take training batches. + """ + + if ctx.should_abort(): + return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT + if self.is_model_expired(ctx.progress.producer_future_step, ctx.model_step): + return ProduceBatchStatus.EXPIRED_BATCH + + await self._put_claimed(await self._pending_tasks.claim_ready(), ctx) + + expired_count = await ctx.expired_count() + sample_from_expired = self.tail_batch_trigger_size > 0 and expired_count >= self.tail_batch_trigger_size + + # Keep the current implementation's semantics: normal mode only adds a fixed oversample budget based on this task's batch size; + # tail-batch mode only fills the necessary gap and always samples from the expired/aborted pool. + target_abs = ctx.target_abs + oversample_budget = 0 if sample_from_expired else math.ceil(self.over_sample_threshold * ctx.task_batch_size) + scheduled_target = target_abs + oversample_budget + + async def spawn_one() -> asyncio.Task: + group = await ctx.sample_group(from_expired_pool=sample_from_expired) + return asyncio.create_task( + ctx.generate_group( + group, + enable_partial_rollout=self.enable_partial_rollout, + ) + ) + + while True: + if ctx.should_abort(): + return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT + if self.is_model_expired(ctx.progress.producer_future_step, ctx.model_step): + return ProduceBatchStatus.EXPIRED_BATCH + + available = await ctx.available_count() + if not self.should_continue_fn(available, target_abs): + return ProduceBatchStatus.NORMAL + + desired_pending = max(0, scheduled_target - available) + while await self._pending_tasks.schedule_one( + max_pending=desired_pending, + should_abort=ctx.should_abort, + spawn_one=spawn_one, + ): + pass + + claimed = await self._pending_tasks.wait_and_claim(timeout_s=1.0) + await self._put_claimed(claimed, ctx) + + async def pause_produce( + self, + ctx: DisaggProduceContext, + ) -> float: + return await pause_pending_tasks( + pending_tasks=self._pending_tasks, +
ctx=ctx, + put_claimed_task=lambda task: ctx.put_generated_group(task.result()), + ) + + async def _put_claimed( + self, + claimed: set[asyncio.Task], + ctx: BaseProduceContext, + ) -> None: + for task in claimed: + await ctx.put_generated_group(task.result()) + + +@dataclass(frozen=True) +class TaskRunner: + task_name: str + agent_loop: AgentLoop + sampler: Sampler + produce_strategy: ModeSpecificProduceStrategy + weight: float = 1.0 + order: int = 0 + + @property + def stale_threshold(self) -> int | None: + return getattr(self.produce_strategy, "stale_threshold", None) + + +class AgentLoopManagerConfig: + def __init__(self, tasks: list[Any], mode: ProducerMode = "colocate") -> None: + self.tasks = tasks + self.mode = mode + + def build( + self, + *, + rollout_controller: RolloutController, + tokenizer: Any, + replay_buffer: ReplayBuffer, + logger: Any, + sync_weights_interval: int, + ) -> "AgentLoopManager | DisaggAgentLoopManager": + mode = self.mode + runners = self._build_task_runners( + mode=mode, + rollout_controller=rollout_controller, + tokenizer=tokenizer, + replay_buffer=replay_buffer, + logger=logger, + sync_weights_interval=sync_weights_interval, + ) + if mode == "colocate": + return AgentLoopManager(runners, replay_buffer, rollout_controller, logger) + return DisaggAgentLoopManager(runners, replay_buffer, rollout_controller, logger) + + def _build_task_runners( + self, + *, + mode: ProducerMode, + rollout_controller: RolloutController, + tokenizer: Any, + replay_buffer: ReplayBuffer, + logger: Any, + sync_weights_interval: int, + ) -> list[TaskRunner]: + runners: list[TaskRunner] = [] + for task_cfg in self.tasks: + # The manager mode only selects the manager type; the strategy's execution environment is expressed by the config type. + if mode == "colocate" and not isinstance( + task_cfg.produce_strategy_config, + (SyncProduceStrategyConfig, AsyncProduceStrategyConfig), + ): + raise ValueError("colocate mode expects ProduceStrategyConfig") + if mode == "disaggregated" and not isinstance( +
task_cfg.produce_strategy_config, + DisaggAsyncProduceStrategyConfig, + ): + raise ValueError("disaggregated mode expects DisaggProduceStrategyConfig") + strategy = task_cfg.produce_strategy_config.build( + sync_weights_interval=sync_weights_interval, + rollout_controller=rollout_controller, + ) + runners.append( + TaskRunner( + task_name=task_cfg.task_name, + agent_loop=task_cfg.agent_loop_config.build(rollout_controller, logger), + sampler=task_cfg.sampler_config.build(tokenizer, replay_buffer), + produce_strategy=strategy, + weight=task_cfg.weight, + order=len(runners), + ) + ) + return runners + + +def allocate_task_batch_sizes( + task_runners: list[TaskRunner], + global_batch_size: int, + train_step: int, +) -> dict[str, int]: + # The real implementation keeps the current allocation by task weight; it stays a module-level helper so the two managers need not inherit from a common parent class. + ... + + +def validate_task_batch_sizes( + task_runners: list[TaskRunner], + task_sizes: dict[str, int], + global_batch_size: int, +) -> None: + ... + + +async def refresh_for_all_tasks( + *, + task_runners: list[TaskRunner], + replay_buffer: ReplayBuffer, + train_step: int, +) -> None: + thresholds = { + task.task_name: task.stale_threshold or 1 + for task in task_runners + } + await replay_buffer.refresh_staleness( + task_stale_thresholds=thresholds, + current_train_step=train_step, + statuses=[Status.COMPLETED, Status.ABORTED], + ) + + +async def take_train_batch( + *, + task_runners: list[TaskRunner], + replay_buffer: ReplayBuffer, + task_sizes: dict[str, int], + progress: ProduceProgress | DisaggProduceProgress, +) -> ProduceBatchResult: + batch_by_task, consumed_counts = await replay_buffer.take_batch(task_sizes) + if isinstance(progress, DisaggProduceProgress): + progress.mark_consumed(consumed_counts) + + counts = await replay_buffer.count_statuses( + [task.task_name for task in task_runners], + [Status.INIT, Status.COMPLETED, Status.ABORTED, Status.EXPIRED, Status.FAILED, Status.FILTERED], + ) + return build_produce_batch_result( + task_runners=task_runners, +
batch_by_task=batch_by_task, + leftover_counts=counts, + progress=progress, + ) + + +def build_produce_batch_result( + *, + task_runners: list[TaskRunner], + batch_by_task: dict[str, list[list[Any]]], + leftover_counts: dict[str, dict[Status, int]], + progress: ProduceProgress | DisaggProduceProgress, +) -> ProduceBatchResult: + # The real implementation handles task result aggregation, timing aggregation, and leftover aggregation. + ... + + +class AgentLoopManager: + def __init__( + self, + task_runners: list[TaskRunner], + replay_buffer: ReplayBuffer, + rollout_controller: RolloutController, + logger: Any, + ) -> None: + self.task_runners = task_runners + self.replay_buffer = replay_buffer + self.rollout_controller = rollout_controller + self.logger = logger + self.task_names = [task.task_name for task in task_runners] + + def get_task_batch_sizes(self, global_batch_size: int, train_step: int) -> dict[str, int]: + return allocate_task_batch_sizes(self.task_runners, global_batch_size, train_step) + + async def produce_batch( + self, + batch_size: int, + train_step: int, + *, + model_step: int, + ) -> ProduceBatchResult: + """The only production entry point for colocated training. + + Invariants: + - Does not touch the disaggregated status/update_event. + - Pending is wound down in one pass only after every active task has finished producing. + - produce_batch is never called concurrently on the same manager instance; strategy pending is state local to this call. + - The returned training batch must be non-empty. + """ + + task_sizes = ( + {self.task_runners[0].task_name: batch_size} + if len(self.task_runners) == 1 + else self.get_task_batch_sizes(batch_size, train_step) + ) + validate_task_batch_sizes(self.task_runners, task_sizes, batch_size) + progress = ProduceProgress.build( + task_names=self.task_names, + target_samples=task_sizes, + ) + active_contexts = [ + ( + task, + ProduceContext( + task_name=task.task_name, + agent_loop=task.agent_loop, + sampler=task.sampler, + replay_buffer=self.replay_buffer, + task_batch_size=task_sizes[task.task_name], + train_step=train_step, + model_step=model_step, + progress=progress, + is_valid_sample_fn=getattr(task.produce_strategy, "is_valid_sample_fn", default_is_valid_sample_fn), +
stale_threshold=task.stale_threshold, + ), + ) + for task in self.task_runners + if task_sizes[task.task_name] > 0 + ] + + await self.rollout_controller.continue_generation() + await refresh_for_all_tasks( + task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + train_step=train_step, + ) + await asyncio.gather(*[task.produce_strategy.produce_batch(ctx) for task, ctx in active_contexts]) + # Key ordering for colocated multi-task: pause/drain pending in one pass only after every task has finished producing normally. + # If production above raises, the exception propagates and aborts training; the manager does no best-effort cleanup. + for task, ctx in active_contexts: + await task.produce_strategy.pause_produce(ctx) + result = await take_train_batch( + task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + task_sizes=task_sizes, + progress=progress, + ) + await self.rollout_controller.pause_generation() + + assert result.rollout_states, "Colocated produce_batch must return a non-empty training batch." + return result + + async def save(self, checkpoint_path: Path, model_step: int) -> None: + # Colocated checkpoints do not save DisaggProduceProgress. + for task in self.task_runners: + task.sampler.save(checkpoint_path / "tasks" / task.task_name) + await self.replay_buffer.save(checkpoint_path) + + async def resume(self, checkpoint_path: Path) -> int: + for task in self.task_runners: + task.sampler.resume(checkpoint_path / "tasks" / task.task_name) + await self.replay_buffer.resume(checkpoint_path) + return 0 + + +class DisaggAgentLoopManager: + def __init__( + self, + task_runners: list[TaskRunner], + replay_buffer: ReplayBuffer, + rollout_controller: RolloutController, + logger: Any, + ) -> None: + self.task_runners = task_runners + self.replay_buffer = replay_buffer + self.rollout_controller = rollout_controller + self.logger = logger + self.task_names = [task.task_name for task in task_runners] + self.status = DisaggManagerStatus.NORMAL + self.update_event = asyncio.Event() + self.finish_event = asyncio.Event() + self.model_step = 0 + self.pause_time_s = 0.0 + self.progress =
DisaggProduceProgress.build(self.task_names) + + def get_task_batch_sizes(self, global_batch_size: int, train_step: int) -> dict[str, int]: + return allocate_task_batch_sizes(self.task_runners, global_batch_size, train_step) + + async def produce_loop(self, batch_size: int) -> None: + """Background Producer for disaggregated training.""" + + while not self.finish_event.is_set(): + if self.status == DisaggManagerStatus.FINISH: + break + if self.status in ( + DisaggManagerStatus.UPDATE_WEIGHT_AND_ABORT, + DisaggManagerStatus.EXPIRED_BATCH, + ): + await self._wait_for_status_exit(self.status) + continue + + task_sizes = self.progress.ensure_target_upto( + batch_size=batch_size, + future_step=self.progress.producer_future_step, + allocate_batch_sizes=self.get_task_batch_sizes, + ) + validate_task_batch_sizes(self.task_runners, task_sizes, batch_size) + statuses = await asyncio.gather( + *[ + task.produce_strategy.produce_batch( + DisaggProduceContext( + task_name=task.task_name, + agent_loop=task.agent_loop, + sampler=task.sampler, + replay_buffer=self.replay_buffer, + task_batch_size=task_sizes[task.task_name], + train_step=self.progress.producer_future_step, + model_step=self.model_step, + progress=self.progress, + is_valid_sample_fn=getattr( + task.produce_strategy, + "is_valid_sample_fn", + default_is_valid_sample_fn, + ), + stale_threshold=task.stale_threshold, + update_event=self.update_event, + ) + ) + for task in self.task_runners + if task_sizes[task.task_name] > 0 + ] + ) + + if ProduceBatchStatus.EXPIRED_BATCH in statuses: + self.status = DisaggManagerStatus.EXPIRED_BATCH + elif ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT in statuses: + self.status = DisaggManagerStatus.UPDATE_WEIGHT_AND_ABORT + else: + self.progress.advance_future_step() + + await asyncio.sleep(0) + + async def get_batch(self, batch_size: int, train_step: int) -> ProduceBatchResult: + """Training Consumer for disaggregated training.""" + + self.progress.begin_consume(train_step) + await refresh_for_all_tasks( +
task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + train_step=train_step, + ) + task_sizes = self.get_task_batch_sizes(batch_size, train_step) + validate_task_batch_sizes(self.task_runners, task_sizes, batch_size) + current_model_step = train_step - 1 + + while not self.finish_event.is_set(): + if self.status == DisaggManagerStatus.EXPIRED_BATCH: + if current_model_step > self.model_step: + return ProduceBatchResult([], status=ProduceBatchStatus.EXPIRED_BATCH) + if not await self.replay_buffer.is_ready(task_sizes): + raise RuntimeError("The Expired Produce Batch cannot be skipped, and the current training batch is not ready.") + + if await self.replay_buffer.is_ready(task_sizes): + result = await take_train_batch( + task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + task_sizes=task_sizes, + progress=self.progress, + ) + if self.status == DisaggManagerStatus.EXPIRED_BATCH: + result.status = ProduceBatchStatus.EXPIRED_BATCH + if result.rollout_states: + self.progress.finish_consume(train_step) + await refresh_for_all_tasks( + task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + train_step=train_step + 1, + ) + return result + + await asyncio.sleep(1.0) + + return ProduceBatchResult([]) + + async def pause_produce(self) -> float: + """Explicit pause entry point before a disaggregated weight sync.""" + + self.update_event.set() + self.status = DisaggManagerStatus.UPDATE_WEIGHT_AND_ABORT + await self.rollout_controller.pause_generation() + + pause_time_s = 0.0 + for task in self.task_runners: + ctx = DisaggProduceContext( + task_name=task.task_name, + agent_loop=task.agent_loop, + sampler=task.sampler, + replay_buffer=self.replay_buffer, + task_batch_size=0, + train_step=self.progress.producer_future_step, + model_step=self.model_step, + progress=self.progress, + is_valid_sample_fn=getattr(task.produce_strategy, "is_valid_sample_fn", default_is_valid_sample_fn), + stale_threshold=task.stale_threshold, + update_event=self.update_event, + ) + pause_time_s += await task.produce_strategy.pause_produce(ctx)
+ self.pause_time_s = pause_time_s + return pause_time_s + + async def continue_produce(self, model_step: int) -> None: + self.model_step = model_step + await self.rollout_controller.continue_generation() + self.status = DisaggManagerStatus.NORMAL + self.update_event.clear() + + def shutdown(self) -> None: + self.status = DisaggManagerStatus.FINISH + self.update_event.set() + self.finish_event.set() + + async def save(self, checkpoint_path: Path, model_step: int) -> None: + pending = { + task.task_name: task.produce_strategy.pending_task_count() + for task in self.task_runners + if task.produce_strategy.pending_task_count() > 0 + } + if pending: + raise RuntimeError(f"The producer must be paused before saving a checkpoint: {pending}") + + for task in self.task_runners: + task.sampler.save(checkpoint_path / "tasks" / task.task_name) + await self.replay_buffer.save(checkpoint_path) + self._save_manager_state(checkpoint_path, model_step) + + async def resume(self, checkpoint_path: Path) -> int: + for task in self.task_runners: + task.sampler.resume(checkpoint_path / "tasks" / task.task_name) + await self.replay_buffer.resume(checkpoint_path) + saved_model_step = self._load_manager_state(checkpoint_path) + + self.update_event = asyncio.Event() + self.finish_event = asyncio.Event() + self.update_event.set() + self.status = DisaggManagerStatus.UPDATE_WEIGHT_AND_ABORT + self.model_step = saved_model_step + return saved_model_step + + async def _wait_for_status_exit(self, blocked_status: DisaggManagerStatus) -> None: + while not self.finish_event.is_set() and self.status == blocked_status: + await asyncio.sleep(1.0) + + def _save_manager_state(self, checkpoint_path: Path, model_step: int) -> None: + ... + + def _load_manager_state(self, checkpoint_path: Path) -> int: + ...
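# Illustrative sketch, not part of the design: a stripped-down model of the
# pause/continue handshake between the background producer and the trainer.
# DemoProducerGate and demo_pause_continue are hypothetical names; the real
# manager also tracks status, drains pending tasks, and polls with a longer sleep.
import asyncio


class DemoProducerGate:
    """Producer loop that stops producing while update_event is set."""

    def __init__(self) -> None:
        self.update_event = asyncio.Event()
        self.produced = 0

    async def produce_loop(self, stop: asyncio.Event) -> None:
        while not stop.is_set():
            if self.update_event.is_set():
                # Paused: yield control without producing until the event is cleared.
                await asyncio.sleep(0)
                continue
            self.produced += 1
            await asyncio.sleep(0)

    def pause_produce(self) -> None:
        self.update_event.set()

    def continue_produce(self) -> None:
        self.update_event.clear()


async def demo_pause_continue() -> tuple[int, int, int]:
    """Return the produced counts before pause, while paused, and after continue."""
    gate = DemoProducerGate()
    stop = asyncio.Event()
    producer = asyncio.create_task(gate.produce_loop(stop))

    for _ in range(3):
        await asyncio.sleep(0)  # let the producer run a few iterations
    before_pause = gate.produced

    gate.pause_produce()
    for _ in range(3):
        await asyncio.sleep(0)  # producer runs but must not produce
    while_paused = gate.produced

    gate.continue_produce()
    for _ in range(3):
        await asyncio.sleep(0)  # producer resumes producing
    after_continue = gate.produced

    stop.set()
    await producer
    return before_pause, while_paused, after_continue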
+ + +class RLTrainer: + def __init__(self, cfg: Any) -> None: + cfg.agent_loop_manager_cfg.mode = "colocate" + self.agent_loop_manager = cfg.agent_loop_manager_cfg.build( + rollout_controller=cfg.rollout_controller, + tokenizer=cfg.tokenizer, + replay_buffer=cfg.replay_buffer_config.build(), + logger=cfg.logger, + sync_weights_interval=cfg.sync_weights_interval, + ) + if cfg.eval_agent_loop_manager_cfg is not None: + cfg.eval_agent_loop_manager_cfg.mode = "colocate" + self.eval_agent_loop_manager = cfg.eval_agent_loop_manager_cfg.build(...) + + def fit(self) -> None: + for train_step in range(1, self.total_train_steps + 1): + produce_result = asyncio.run( + self.agent_loop_manager.produce_batch( + self.train_batch_size, + train_step=train_step, + model_step=self._current_rollout_model_step(train_step), + ) + ) + self._train_one_batch(produce_result.rollout_states, train_step) + + +class RLDisaggTrainer: + def __init__(self, cfg: Any) -> None: + train_replay_buffer = cfg.replay_buffer_config.build() + cfg.agent_loop_manager_cfg.mode = "disaggregated" + self.agent_loop_manager = cfg.agent_loop_manager_cfg.build( + rollout_controller=cfg.rollout_controller, + tokenizer=cfg.tokenizer, + replay_buffer=train_replay_buffer, + logger=cfg.logger, + sync_weights_interval=cfg.sync_weights_interval, + ) + # eval is a one-shot synchronous produce_batch and should not be built as a background manager. + if cfg.eval_agent_loop_manager_cfg is not None: + cfg.eval_agent_loop_manager_cfg.mode = "colocate" + self.eval_agent_loop_manager = cfg.eval_agent_loop_manager_cfg.build( + rollout_controller=cfg.rollout_controller, + tokenizer=cfg.tokenizer, + replay_buffer=train_replay_buffer, + logger=cfg.logger, + sync_weights_interval=cfg.sync_weights_interval, + ) + + async def _fit(self) -> None: + producer_task = asyncio.create_task( + self.agent_loop_manager.produce_loop(batch_size=self.train_batch_size) + ) + train_step = self.cur_step + 1 + while train_step <= self.total_train_steps: + get_batch_task = asyncio.create_task( +
self.agent_loop_manager.get_batch( + self.train_batch_size, + train_step=train_step, + ) + ) + # Disaggregated fail-fast: while waiting for a batch, the consumer must also watch the background producer. + # A producer exception is not a business status; result() surfaces the original traceback directly. + done, _ = await asyncio.wait( + {producer_task, get_batch_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + if producer_task in done: + producer_task.result() + raise RuntimeError("The disaggregated background producer exited before training finished.") + produce_result = get_batch_task.result() + + empty_expired = ( + produce_result.status == ProduceBatchStatus.EXPIRED_BATCH + and not produce_result.rollout_states + ) + if not empty_expired: + self._train_one_batch(produce_result.rollout_states, train_step) + sync_model_step = train_step + else: + sync_model_step = train_step - 1 + + if self._need_sync(sync_model_step, produce_result): + await self.agent_loop_manager.pause_produce() + await self._sync_weights_and_save(sync_model_step) + await self.agent_loop_manager.continue_produce(model_step=sync_model_step) + + if empty_expired: + continue + self.cur_step = train_step + train_step += 1 + + self.agent_loop_manager.shutdown() + await producer_task diff --git a/examples/v1/config/rl_disagg_multi.py b/examples/v1/config/rl_disagg_multi.py index fb1c30bd5a..394baf0b0f 100644 --- a/examples/v1/config/rl_disagg_multi.py +++ b/examples/v1/config/rl_disagg_multi.py @@ -35,15 +35,16 @@ from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig from xtuner.v1.rl.agent_loop_manager import ( AgentLoopManagerConfig, - AsyncProduceStrategyConfig, + DisaggAsyncProduceStrategyConfig, + DisaggAgentLoopManagerConfig, + DisaggTaskSpecConfig, SamplerConfig, - SyncProduceStrategyConfig, TaskSpecConfig, ) from xtuner.v1.rl.evaluator import EvaluatorConfig from xtuner.v1.rl.judger import DapoMathJudgerConfig from xtuner.v1.rl.loss import GRPOLossConfig -from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig from xtuner.v1.rl.rollout.worker import
RolloutConfig from xtuner.v1.rl.trainer import WorkerConfig from xtuner.v1.rl.utils import AcceleratorResourcesConfig, get_eos_token @@ -221,19 +222,17 @@ ), ) -if over_sample_threshold > 0 or partial_rollout: - produce_strategy_config = AsyncProduceStrategyConfig( - over_sample_threshold=over_sample_threshold, - enable_partial_rollout=partial_rollout, - tail_batch_trigger_size=tail_batch_trigger_size, - max_staleness=max_staleness, - ) -else: - produce_strategy_config = SyncProduceStrategyConfig() - -agent_loop_manager_cfg = AgentLoopManagerConfig( +# The disaggregated background producer uses its own Disagg* config and does not reuse the colocated AsyncProduceStrategyConfig. +produce_strategy_config = DisaggAsyncProduceStrategyConfig( + over_sample_threshold=over_sample_threshold, + enable_partial_rollout=partial_rollout, + tail_batch_trigger_size=tail_batch_trigger_size, + max_staleness=max_staleness, +) + +agent_loop_manager_cfg = DisaggAgentLoopManagerConfig( tasks=[ - TaskSpecConfig( + DisaggTaskSpecConfig( task_name="train_task:dapo_math", weight=dapo_task_weight, agent_loop_config=dapo_train_agent_loop_config, @@ -241,7 +240,7 @@ produce_strategy_config=produce_strategy_config, sampler_config=dapo_train_sampler_config, ), - TaskSpecConfig( + DisaggTaskSpecConfig( task_name="train_task:gsm8k", weight=gsm8k_task_weight, agent_loop_config=gsm8k_train_agent_loop_config, @@ -335,7 +334,7 @@ def compute_metric(samples): train_worker_cfg=train_worker_cfg, rollout_config=rollout_config, tokenizer_path=model_path, - replay_buffer_config=SyncReplayBufferConfig(), + replay_buffer_config=AsyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, evaluator_config=evaluator_config, diff --git a/examples/v1/config/rl_disagg_single.py b/examples/v1/config/rl_disagg_single.py index 23d7aedbcd..6ef2250e37 100644 --- a/examples/v1/config/rl_disagg_single.py +++ b/examples/v1/config/rl_disagg_single.py @@ -45,15 +45,16 @@ from xtuner.v1.rl.agent_loop import
SingleTurnAgentLoopConfig from xtuner.v1.rl.agent_loop_manager import ( AgentLoopManagerConfig, - AsyncProduceStrategyConfig, + DisaggAsyncProduceStrategyConfig, + DisaggAgentLoopManagerConfig, + DisaggTaskSpecConfig, SamplerConfig, - SyncProduceStrategyConfig, TaskSpecConfig, ) from xtuner.v1.rl.evaluator import EvaluatorConfig from xtuner.v1.rl.judger import GSM8KJudgerConfig from xtuner.v1.rl.loss import GRPOLossConfig -from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig from xtuner.v1.rl.rollout.worker import RolloutConfig from xtuner.v1.rl.trainer import WorkerConfig from xtuner.v1.rl.utils import AcceleratorResourcesConfig @@ -193,17 +194,15 @@ hf_checkpoint=model_path, sample_params=training_sample_params, ) -if over_sample_threshold > 0 or partial_rollout: - produce_strategy_config = AsyncProduceStrategyConfig( - over_sample_threshold=over_sample_threshold, - enable_partial_rollout=partial_rollout, - tail_batch_trigger_size=tail_batch_trigger_size, - max_staleness=max_staleness, - ) -else: - produce_strategy_config = SyncProduceStrategyConfig() -agent_loop_manager_cfg = AgentLoopManagerConfig( - tasks=TaskSpecConfig( +# 非共卡后台 producer 使用独立的 Disagg* config,不复用共卡 AsyncProduceStrategyConfig。 +produce_strategy_config = DisaggAsyncProduceStrategyConfig( + over_sample_threshold=over_sample_threshold, + enable_partial_rollout=partial_rollout, + tail_batch_trigger_size=tail_batch_trigger_size, + max_staleness=max_staleness, +) +agent_loop_manager_cfg = DisaggAgentLoopManagerConfig( + tasks=DisaggTaskSpecConfig( task_name="train_task", agent_loop_config=agent_loop_config, judger_config=judger_config, @@ -258,7 +257,7 @@ train_worker_cfg=train_worker_cfg, rollout_config=rollout_config, tokenizer_path=model_path, - replay_buffer_config=SyncReplayBufferConfig(), + replay_buffer_config=AsyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, 
eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, evaluator_config=evaluator_config, diff --git a/tests/rl/fast/pr_fast/test_multi_task_agent_loop_manager.py b/tests/rl/fast/pr_fast/test_multi_task_agent_loop_manager.py index cfc3ef9cb9..6a3191d218 100644 --- a/tests/rl/fast/pr_fast/test_multi_task_agent_loop_manager.py +++ b/tests/rl/fast/pr_fast/test_multi_task_agent_loop_manager.py @@ -2,14 +2,14 @@ 本文件从旧的 test_multi_task_agent_loop_manager.py 迁入共卡路径测试: - produce_batch 按 task 权重分配 batch,并按 task 名稳定返回训练数据。 -- 自定义 get_task_batch_sizes 可以禁用某些 task。 - produce_batch 会汇总 producer 收尾耗时和 group 生成耗时。 - 共卡 produce_batch 必须返回非空训练 batch。 -- 多 task 中任一 task 返回 UPDATE_WEIGHT_AND_ABORT 时,public 结果状态会体现该中断。 +- 共卡 produce_batch 不聚合 producer status,public 状态保持 NORMAL。 非共卡 get_batch / produce_loop / pause_continue 编排暂不迁入 PR-fast。 """ +import asyncio import unittest from typing import Any from unittest.mock import AsyncMock, MagicMock @@ -19,9 +19,13 @@ AgentLoopManager, AgentLoopManagerConfig, TaskSpecConfig, +) +from xtuner.v1.rl.agent_loop_manager.disagg_agent_loop_manager import DisaggAgentLoopManager +from xtuner.v1.rl.agent_loop_manager.produce_utils import ( + GROUP_GENERATE_TIME_KEY, + ProduceBatchStatus, _TaskRunner, ) -from xtuner.v1.rl.agent_loop_manager.producer import GROUP_GENERATE_TIME_KEY, ProduceBatchStatus class _FakeSampler: @@ -35,31 +39,30 @@ def __len__(self) -> int: class _FakeProduceStrategy: def __init__( self, - status: ProduceBatchStatus = ProduceBatchStatus.NORMAL, cleanup_pause_time_s: float = 0.0, stale_threshold: int = 1, ): - self.status = status self.cleanup_pause_time_s = cleanup_pause_time_s self.stale_threshold = stale_threshold self.called_batch_sizes: list[int] = [] self.called_train_steps: list[int] = [] self.called_model_steps: list[int] = [] - self.called_update_events: list[object | None] = [] - self.called_update_event_states: list[bool | None] = [] self.called_progresses: list[object] = [] self.cleanup_model_steps: list[int] = [] 
self.cleanup_progresses: list[object | None] = [] self.cleanup_call_count = 0 - async def produce_batch(self, ctx) -> ProduceBatchStatus: + async def produce_batch(self, ctx) -> None: self.called_batch_sizes.append(ctx.task_batch_size) self.called_train_steps.append(ctx.train_step) self.called_model_steps.append(ctx.model_step) - self.called_update_events.append(ctx.update_event) - self.called_update_event_states.append(None if ctx.update_event is None else ctx.update_event.is_set()) + self.assert_colocate_context(ctx) self.called_progresses.append(ctx.progress) - return self.status + + def assert_colocate_context(self, ctx) -> None: + for disagg_only_name in ("update_event", "available_count", "total_target"): + if hasattr(ctx, disagg_only_name): + raise AssertionError(f"colocate ProduceContext should not expose {disagg_only_name}") async def pause_produce(self, ctx) -> float: self.cleanup_call_count += 1 @@ -67,30 +70,24 @@ async def pause_produce(self, ctx) -> float: self.cleanup_progresses.append(ctx.progress) return self.cleanup_pause_time_s - def is_model_expired(self, train_step: int, model_step: int) -> bool: - return False - -class _FakeStatusProduceStrategy: - def __init__(self, status: ProduceBatchStatus, pause_time_s: float): - self.status = status +class _FakeTimedProduceStrategy: + def __init__(self, pause_time_s: float): self.pause_time_s = pause_time_s self.cleanup_call_count = 0 self.called_train_steps: list[int] = [] self.called_model_steps: list[int] = [] - self.called_update_events: list[object | None] = [] - self.called_update_event_states: list[bool | None] = [] self.called_progresses: list[object] = [] self.cleanup_model_steps: list[int] = [] self.cleanup_progresses: list[object | None] = [] - async def produce_batch(self, ctx) -> ProduceBatchStatus: + async def produce_batch(self, ctx) -> None: self.called_train_steps.append(ctx.train_step) self.called_model_steps.append(ctx.model_step) - self.called_update_events.append(ctx.update_event) - 
self.called_update_event_states.append(None if ctx.update_event is None else ctx.update_event.is_set()) + for disagg_only_name in ("update_event", "available_count", "total_target"): + if hasattr(ctx, disagg_only_name): + raise AssertionError(f"colocate ProduceContext should not expose {disagg_only_name}") self.called_progresses.append(ctx.progress) - return self.status async def pause_produce(self, ctx) -> float: self.cleanup_call_count += 1 @@ -98,8 +95,13 @@ async def pause_produce(self, ctx) -> float: self.cleanup_progresses.append(ctx.progress) return self.pause_time_s - def is_model_expired(self, train_step: int, model_step: int) -> bool: - return False + +class _FailingProduceStrategy: + async def produce_batch(self, ctx) -> None: + raise RuntimeError("original produce failure") + + async def pause_produce(self, ctx) -> float: + raise RuntimeError("cleanup failure") class _FakeRolloutState: @@ -171,6 +173,14 @@ def _fake_agent_loop(): return agent_loop +def _fake_rollout_controller(): + rollout_controller = MagicMock() + rollout_controller.continue_generation.remote = AsyncMock() + rollout_controller.pause_generation.remote = AsyncMock() + rollout_controller.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + return rollout_controller + + class TestMultiTaskAgentLoopManager(unittest.IsolatedAsyncioTestCase): def test_manager_config_accepts_single_task_spec(self): # 单 task 配置可以直接传入,兼容最小 AgentLoopManager 配置。 @@ -231,6 +241,7 @@ async def test_produce_batch_allocates_by_weight_and_returns_task_sorted_results ), ], replay_buffer=replay_buffer, + rollout_controller=_fake_rollout_controller(), ) result = await multi_task_manager.produce_batch(batch_size=7, train_step=3, model_step=2) @@ -247,56 +258,49 @@ async def test_produce_batch_allocates_by_weight_and_returns_task_sorted_results self.assertIn("task_b", result.task_results) self.assertIn("task_c", result.task_results) - async def 
test_custom_get_task_batch_sizes_can_disable_tasks(self): - # 自定义 task batch size 可以禁用某个 task,训练 batch 只从启用 task 取数。 - strategy_a = _FakeProduceStrategy() - strategy_b = _FakeProduceStrategy() + async def test_disagg_get_batch_aggregates_multi_task_results_without_colocate_surface(self): + # 非共卡 get_batch 使用后台 progress 消费 replay buffer,不依赖共卡 produce_batch 继承面。 replay_buffer = _FakeReplayBuffer( rollout_states_by_task={ - "task_a": [["a-0"]], - "task_b": [["b-0"], ["b-1"]], + "task_a": [["a-0"], ["a-1"]], + "task_b": [["b-0"]], + }, + leftover_counts={ + ("task_a", Status.COMPLETED): 2, + ("task_b", Status.COMPLETED): 1, }, - leftover_counts={}, ) - - class _CustomBatchManager(AgentLoopManager): - def get_task_batch_sizes(self, global_batch_size: int, train_step: int) -> dict[str, int]: - self.observed_train_step = train_step - return {"task_a": 0, "task_b": global_batch_size} - - multi_task_manager = _CustomBatchManager( + manager = DisaggAgentLoopManager( task_runners=[ _TaskRunner( task_name="task_a", agent_loop=_fake_agent_loop(), - produce_strategy=strategy_a, + produce_strategy=_FakeProduceStrategy(), sampler=_FakeSampler(), - weight=1.0, + weight=2.0, order=0, ), _TaskRunner( task_name="task_b", agent_loop=_fake_agent_loop(), - produce_strategy=strategy_b, + produce_strategy=_FakeProduceStrategy(), sampler=_FakeSampler(), weight=1.0, order=1, ), ], replay_buffer=replay_buffer, + rollout_controller=_fake_rollout_controller(), ) - result = await multi_task_manager.produce_batch(batch_size=2, train_step=9, model_step=8) + result = await manager.get_batch(batch_size=3, train_step=2) - self.assertEqual(multi_task_manager.observed_train_step, 9) - self.assertEqual(result.task_batch_sizes, {"task_a": 0, "task_b": 2}) - self.assertEqual(strategy_a.called_batch_sizes, []) - self.assertEqual(strategy_b.called_batch_sizes, [2]) - self.assertEqual(result.rollout_states, [["b-0"], ["b-1"]]) + self.assertEqual(result.rollout_states, [["a-0"], ["a-1"], ["b-0"]]) + 
self.assertEqual(result.task_batch_sizes, {"task_a": 2, "task_b": 1}) - async def test_status_returning_strategy_uses_cleanup_and_reconstructs_group_timing_stats(self): + async def test_produce_batch_uses_cleanup_and_reconstructs_group_timing_stats(self): # 共卡 produce_batch 会把 producer 收尾耗时和 rollout group 生成耗时汇总到结果中。 - strategy = _FakeStatusProduceStrategy(status=ProduceBatchStatus.NORMAL, pause_time_s=1.25) + strategy = _FakeTimedProduceStrategy(pause_time_s=1.25) replay_buffer = _FakeReplayBuffer( rollout_states_by_task={ "task_a": [ @@ -319,6 +323,7 @@ async def test_status_returning_strategy_uses_cleanup_and_reconstructs_group_tim ), ], replay_buffer=replay_buffer, + rollout_controller=_fake_rollout_controller(), ) result = await manager.produce_batch(batch_size=2, train_step=7, model_step=6) @@ -336,26 +341,27 @@ async def test_produce_batch_requires_non_empty_rollout_states(self): _TaskRunner( task_name="task_a", agent_loop=_fake_agent_loop(), - produce_strategy=_FakeStatusProduceStrategy(status=ProduceBatchStatus.NORMAL, pause_time_s=0.0), + produce_strategy=_FakeTimedProduceStrategy(pause_time_s=0.0), sampler=_FakeSampler(), weight=1.0, order=0, ), ], replay_buffer=_FakeReplayBuffer(rollout_states_by_task={}, leftover_counts={}), + rollout_controller=_fake_rollout_controller(), ) with self.assertRaisesRegex(AssertionError, "must return non-empty rollout_states"): await manager.produce_batch(batch_size=1, train_step=3, model_step=2) - async def test_produce_batch_returns_update_abort_when_any_task_requests_abort(self): - # 多 task 共卡生产时,任一 task 返回 UPDATE_WEIGHT_AND_ABORT 会体现在 public 结果状态中。 + async def test_produce_batch_status_stays_normal_for_colocate_flow(self): + # 共卡生产不再消费 producer status;只要训练 batch 非空,public 状态保持 NORMAL。 manager = AgentLoopManager( task_runners=[ _TaskRunner( task_name="task_a", agent_loop=_fake_agent_loop(), - produce_strategy=_FakeProduceStrategy(status=ProduceBatchStatus.NORMAL), + produce_strategy=_FakeProduceStrategy(), 
sampler=_FakeSampler(), weight=1.0, order=0, @@ -363,7 +369,7 @@ async def test_produce_batch_returns_update_abort_when_any_task_requests_abort(s _TaskRunner( task_name="task_b", agent_loop=_fake_agent_loop(), - produce_strategy=_FakeProduceStrategy(status=ProduceBatchStatus.EXPIRED_BATCH), + produce_strategy=_FakeProduceStrategy(), sampler=_FakeSampler(), weight=1.0, order=1, @@ -371,7 +377,7 @@ async def test_produce_batch_returns_update_abort_when_any_task_requests_abort(s _TaskRunner( task_name="task_c", agent_loop=_fake_agent_loop(), - produce_strategy=_FakeProduceStrategy(status=ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT), + produce_strategy=_FakeProduceStrategy(), sampler=_FakeSampler(), weight=1.0, order=2, @@ -385,9 +391,87 @@ async def test_produce_batch_returns_update_abort_when_any_task_requests_abort(s }, leftover_counts={}, ), + rollout_controller=_fake_rollout_controller(), ) result = await manager.produce_batch(batch_size=3, train_step=6, model_step=5) - self.assertEqual(result.status, ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT) + self.assertEqual(result.status, ProduceBatchStatus.NORMAL) self.assertEqual(result.rollout_states, [["a-0"], ["b-0"], ["c-0"]]) + + async def test_produce_batch_waits_all_tasks_before_any_pause(self): + # fast task 不能在 slow task 仍在生产时提前 pause rollout worker。 + events: list[str] = [] + fast_done = asyncio.Event() + test_case = self + + class _FastStrategy(_FakeProduceStrategy): + async def produce_batch(self, ctx) -> None: + events.append("fast_produce_done") + fast_done.set() + + async def pause_produce(self, ctx) -> float: + test_case.assertIn("slow_produce_done", events) + events.append("fast_pause") + return 0.0 + + class _SlowStrategy(_FakeProduceStrategy): + async def produce_batch(self, ctx) -> None: + await fast_done.wait() + events.append("slow_produce_done") + + async def pause_produce(self, ctx) -> float: + test_case.assertIn("slow_produce_done", events) + events.append("slow_pause") + return 0.0 + + manager = 
AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_fast", + agent_loop=_fake_agent_loop(), + produce_strategy=_FastStrategy(), + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + _TaskRunner( + task_name="task_slow", + agent_loop=_fake_agent_loop(), + produce_strategy=_SlowStrategy(), + sampler=_FakeSampler(), + weight=1.0, + order=1, + ), + ], + replay_buffer=_FakeReplayBuffer( + rollout_states_by_task={"task_fast": [["fast-0"]], "task_slow": [["slow-0"]]}, + leftover_counts={}, + ), + rollout_controller=_fake_rollout_controller(), + ) + + result = await manager.produce_batch(batch_size=2, train_step=4, model_step=3) + + self.assertEqual(result.rollout_states, [["fast-0"], ["slow-0"]]) + self.assertEqual(events[:2], ["fast_produce_done", "slow_produce_done"]) + + async def test_produce_batch_preserves_original_terminal_exception(self): + # 终止性生产异常应直接暴露给 trainer,保留生产调用处的调试堆栈。 + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=_FailingProduceStrategy(), + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=_FakeReplayBuffer(rollout_states_by_task={}, leftover_counts={}), + rollout_controller=_fake_rollout_controller(), + ) + + with self.assertRaisesRegex(RuntimeError, "original produce failure"): + await manager.produce_batch(batch_size=1, train_step=3, model_step=2) diff --git a/tests/rl/fast/pr_fast/test_pending_tasks.py b/tests/rl/fast/pr_fast/test_pending_tasks.py index 465ffeef76..60e634a2e7 100644 --- a/tests/rl/fast/pr_fast/test_pending_tasks.py +++ b/tests/rl/fast/pr_fast/test_pending_tasks.py @@ -18,11 +18,49 @@ import asyncio import unittest +from unittest.mock import AsyncMock, MagicMock -from xtuner.v1.rl.agent_loop_manager.producer import _PendingTasks +from xtuner.v1.rl.agent_loop_manager.produce_utils import _PendingTasks, pause_pending_tasks class TestPendingTasks(unittest.IsolatedAsyncioTestCase): + async def 
test_pause_pending_tasks_drains_local_set_and_pending_container(self): + # 验证共卡本地 set 和非共卡 _PendingTasks 都复用同一套 pause/drain 协议。 + for pending_tasks in (set(), _PendingTasks()): + ctx = MagicMock() + ctx.task_name = "test_pause_helper" + ctx.agent_loop.pause = AsyncMock(return_value=None) + claimed_results: list[str] = [] + + async def done(): + return "done" + + if isinstance(pending_tasks, set): + pending_tasks.add(asyncio.create_task(done())) + else: + async def spawn_one(): + return asyncio.create_task(done()) + + await pending_tasks.schedule_one( + max_pending=1, + should_abort=lambda: False, + spawn_one=spawn_one, + ) + + async def put_claimed_task(task: asyncio.Task) -> None: + claimed_results.append(task.result()) + + pause_time_s = await pause_pending_tasks( + pending_tasks=pending_tasks, + ctx=ctx, + put_claimed_task=put_claimed_task, + ) + + self.assertGreaterEqual(pause_time_s, 0.0) + self.assertEqual(claimed_results, ["done"]) + self.assertEqual(len(pending_tasks) if isinstance(pending_tasks, set) else pending_tasks.count(), 0) + self.assertGreaterEqual(ctx.agent_loop.pause.await_count, 1) + async def test_claim_ready_returns_each_done_task_once(self): # 验证 done task 被 claim 后会从 pending 集合移除,避免 producer/pause 重复处理同一结果。 pending_tasks = _PendingTasks() diff --git a/tests/rl/fast/pr_fast/test_produce_progress.py b/tests/rl/fast/pr_fast/test_produce_progress.py index c720debede..651f05b3ba 100644 --- a/tests/rl/fast/pr_fast/test_produce_progress.py +++ b/tests/rl/fast/pr_fast/test_produce_progress.py @@ -1,7 +1,7 @@ -"""ProduceProgress 的深模块契约测试。 +"""ProduceProgress / DisaggProduceProgress 的深模块契约测试。 Good Tests: -- 通过 ProduceProgress 的公开构造和方法验证“累计 target / consumed / metrics”的领域行为。 +- 通过公开构造和方法验证共卡本地窗口、非共卡累计 target / consumed / metrics 的领域行为。 - 测试描述 progress 对 producer/consumer 可见的行为结果,而不是 manager 或 strategy 如何调用它。 - 这些测试应在字段内部实现调整后仍保持稳定,只要公开方法语义不变。 @@ -10,21 +10,21 @@ - 不通过 mock manager/strategy 调用顺序来间接证明 progress 行为。 本文件主要覆盖的 public 行为: -- global progress 按 
future step 维护绝对累计 target。 -- local progress 表达单次共卡 produce_batch,不污染 global progress。 +- DisaggProduceProgress 按 future step 维护绝对累计 target。 +- ProduceProgress 表达单次共卡 produce_batch,不污染 disagg progress。 - consumer 按实际取出的 Rollout Group 推进 consumed 和 next consumer step。 - state_dict/load_state_dict 原地恢复状态,metrics 被读取一次后清零。 """ import unittest -from xtuner.v1.rl.agent_loop_manager import ProduceProgress +from xtuner.v1.rl.agent_loop_manager import DisaggProduceProgress, ProduceProgress class TestProduceProgress(unittest.TestCase): def test_global_progress_accumulates_absolute_targets_by_future_step(self): - # 验证 global progress 按 future step 累计绝对 target,而不是只记录当前 batch 缺口。 - progress = ProduceProgress.build(["task_a", "task_b"]) + # 验证非共卡 progress 按 future step 累计绝对 target,而不是只记录当前 batch 缺口。 + progress = DisaggProduceProgress.build(["task_a", "task_b"]) def allocate(batch_size: int, step: int) -> dict[str, int]: self.assertEqual(batch_size, 4) @@ -37,12 +37,13 @@ def allocate(batch_size: int, step: int) -> dict[str, int]: ) self.assertEqual(current_sizes, {"task_a": 2, "task_b": 2}) + self.assertIsInstance(progress, DisaggProduceProgress) self.assertEqual(progress.target_samples, {"task_a": 3, "task_b": 5}) self.assertEqual(progress.target_upto_future_step, 2) def test_consumption_records_actual_taken_groups_and_next_consumer_step(self): # 验证 consumer 只按实际取出的 rollout group 数更新 consumed,并推进下一消费 step。 - progress = ProduceProgress.build(["task_a", "task_b"]) + progress = DisaggProduceProgress.build(["task_a", "task_b"]) progress.begin_consume(2) progress.mark_consumed({"task_a": 1, "task_b": 2}) @@ -53,7 +54,7 @@ def test_consumption_records_actual_taken_groups_and_next_consumer_step(self): def test_producer_future_step_advances_independently_from_consumer_step(self): # 验证 producer future step 是独立的生产进度,不会被 consumer step 更新隐式推进。 - progress = ProduceProgress.build(["task_a"]) + progress = DisaggProduceProgress.build(["task_a"]) progress.begin_consume(5) 
progress.finish_consume(5) @@ -64,23 +65,29 @@ def test_producer_future_step_advances_independently_from_consumer_step(self): def test_local_progress_keeps_global_window_untouched(self): # 验证共卡 local progress 只表达本次 produce_batch,不污染非共卡 global progress。 - global_progress = ProduceProgress.build(["task_a", "task_b"]) - global_progress.ensure_target_upto( + disagg_progress = DisaggProduceProgress.build(["task_a", "task_b"]) + disagg_progress.ensure_target_upto( batch_size=4, future_step=2, allocate_batch_sizes=lambda batch_size, step: {"task_a": step, "task_b": batch_size - step}, ) - local_progress = ProduceProgress.build_local(["task_a", "task_b"], {"task_a": 1, "task_b": 3}, 7) + local_progress = ProduceProgress.build( + task_names=["task_a", "task_b"], + target_samples={"task_a": 1, "task_b": 3}, + ) - self.assertEqual(local_progress.next_consumer_step, 7) - self.assertEqual(local_progress.producer_future_step, 7) self.assertEqual(local_progress.target_samples, {"task_a": 1, "task_b": 3}) - self.assertEqual(global_progress.target_samples, {"task_a": 3, "task_b": 5}) + self.assertEqual(disagg_progress.target_samples, {"task_a": 3, "task_b": 5}) + self.assertFalse(hasattr(local_progress, "producer_future_step")) + self.assertFalse(hasattr(local_progress, "next_consumer_step")) + self.assertFalse(hasattr(local_progress, "consumed_samples")) + self.assertFalse(hasattr(local_progress, "target_upto_future_step")) + self.assertFalse(hasattr(local_progress, "state_dict")) def test_load_state_dict_updates_existing_dicts_in_place(self): # 验证 resume/load 原地更新 dict,避免 strategy 或 context 持有的旧引用失效。 - progress = ProduceProgress.build(["task_a", "task_b"]) + progress = DisaggProduceProgress.build(["task_a", "task_b"]) consumed_ref = progress.consumed_samples target_ref = progress.target_samples raw_sum_ref = progress.raw_rewards_sum @@ -112,7 +119,7 @@ def test_load_state_dict_updates_existing_dicts_in_place(self): def test_metrics_are_consumed_once_and_reset(self): # 验证 producer 
侧统计被 trainer 读取一次后清零,避免后续 step 重复上报。 - progress = ProduceProgress.build(["task_a"]) + progress = DisaggProduceProgress.build(["task_a"]) progress.add_raw_rewards("task_a", 1.25, 2) progress.add_produced("task_a", samples=3, tokens=30) progress.add_produce_time(0.5) diff --git a/tests/rl/fast/pr_fast/test_producer.py b/tests/rl/fast/pr_fast/test_producer.py index 5e5fde27df..af290328ae 100644 --- a/tests/rl/fast/pr_fast/test_producer.py +++ b/tests/rl/fast/pr_fast/test_producer.py @@ -13,7 +13,8 @@ 本文件主要覆盖的 public 行为: - sampler 优先复用可重试 Rollout Group,耗尽后回退 dataloader。 - ProduceContext 统一处理生成结果落库、过滤、raw reward 和模型版本记录。 -- SyncProduceStrategy / AsyncProduceStrategy 返回 NORMAL、UPDATE_WEIGHT_AND_ABORT、EXPIRED_BATCH 的行为。 +- SyncProduceStrategy / AsyncProduceStrategy 通过共卡入口完成生产,不返回状态控制信号。 +- DisaggAsyncProduceStrategy 返回 UPDATE_WEIGHT_AND_ABORT、EXPIRED_BATCH 的后台生产状态。 - AsyncProduceStrategy 的 oversample、tail-batch、partial rollout、pause drain 和 staleness 结果。 """ @@ -24,6 +25,9 @@ from xtuner.v1.data_proto.rl_data import RolloutState, Status from xtuner.v1.rl.agent_loop_manager import ( AsyncProduceStrategyConfig, + DisaggAsyncProduceStrategyConfig, + DisaggProduceContext, + DisaggProduceProgress, ProduceBatchStatus, ProduceContext, ProduceProgress, @@ -81,20 +85,15 @@ def _build_progress( self, task_name: str, target: int, - train_step: int = 0, consumed: int = 0, - producer_future_step: int | None = None, target_upto_future_step: int | None = None, ) -> ProduceProgress: - progress = ProduceProgress.build([task_name]) - progress.next_consumer_step = train_step - progress.producer_future_step = producer_future_step if producer_future_step is not None else train_step - progress.consumed_samples[task_name] = consumed - progress.target_samples[task_name] = target - progress.target_upto_future_step = ( - target_upto_future_step if target_upto_future_step is not None else train_step + if consumed != 0 or target_upto_future_step is not None: + raise ValueError("Use 
_build_disagg_progress for absolute consumed/target progress.") + return ProduceProgress.build( + task_names=[task_name], + target_samples={task_name: target}, ) - return progress def _build_agent_loop(self, sleep_by_id: dict[int, float] | None = None): mock_agent_loop = MagicMock() @@ -130,14 +129,60 @@ def _build_context( train_step: int = 0, model_step: int = 0, progress: ProduceProgress | None = None, - update_event: asyncio.Event | None = None, ) -> ProduceContext: # 测试只走新的 ProduceContext 入口,不再覆盖旧散装参数兼容逻辑。 if progress is None: - progress = self._build_progress(task_name, target=batch_size, train_step=train_step) + progress = self._build_progress(task_name, target=batch_size) + return ProduceContext( + agent_loop=agent_loop, + sampler=sampler, + replay_buffer=self.replay_buffer, + task_batch_size=batch_size, + task_name=task_name, + train_step=train_step, + model_step=model_step, + progress=progress, + is_valid_sample_fn=strategy.is_valid_sample_fn, + stale_threshold=getattr(strategy, "stale_threshold", None), + ) + + def _build_disagg_progress( + self, + task_name: str, + target: int, + train_step: int = 0, + consumed: int = 0, + producer_future_step: int | None = None, + target_upto_future_step: int | None = None, + ) -> DisaggProduceProgress: + progress = DisaggProduceProgress.build([task_name]) + progress.next_consumer_step = train_step + progress.producer_future_step = producer_future_step if producer_future_step is not None else train_step + progress.consumed_samples[task_name] = consumed + progress.target_samples[task_name] = target + progress.target_upto_future_step = ( + target_upto_future_step if target_upto_future_step is not None else train_step + ) + return progress + + def _build_disagg_context( + self, + strategy, + task_name: str, + agent_loop, + sampler, + *, + batch_size: int, + train_step: int = 0, + model_step: int = 0, + progress: DisaggProduceProgress | None = None, + update_event: asyncio.Event | None = None, + ) -> DisaggProduceContext: 
+ if progress is None: + progress = self._build_disagg_progress(task_name, target=batch_size, train_step=train_step) if update_event is None: update_event = asyncio.Event() - return ProduceContext( + return DisaggProduceContext( agent_loop=agent_loop, sampler=sampler, replay_buffer=self.replay_buffer, @@ -151,6 +196,49 @@ def _build_context( stale_threshold=getattr(strategy, "stale_threshold", None), ) + async def test_contexts_keep_colocate_and_disagg_control_surface_separate(self): + # 共卡 context 只表达一次本地生产窗口;update_event / abort / 绝对累计进度只属于非共卡 context。 + task_name = "test_context_surface" + sampler = self._build_sampler() + agent_loop = self._build_agent_loop() + + colocate_strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0).build() + colocate_ctx = self._build_context( + colocate_strategy, + task_name, + agent_loop, + sampler, + batch_size=1, + train_step=3, + model_step=2, + progress=ProduceProgress.build( + task_names=[task_name], + target_samples={task_name: 1}, + ), + ) + self.assertEqual(colocate_ctx.batch_target, 1) + for disagg_only_name in ("update_event", "should_abort", "available_count", "total_target"): + self.assertFalse(hasattr(colocate_ctx, disagg_only_name), disagg_only_name) + + disagg_strategy = DisaggAsyncProduceStrategyConfig(over_sample_threshold=0.0).build() + update_event = asyncio.Event() + disagg_ctx = self._build_disagg_context( + disagg_strategy, + task_name, + agent_loop, + sampler, + batch_size=1, + train_step=3, + model_step=2, + progress=self._build_disagg_progress(task_name, target=2, train_step=3, consumed=1), + update_event=update_event, + ) + self.assertEqual(disagg_ctx.total_target, 2) + self.assertEqual(await disagg_ctx.available_count(), 1) + self.assertFalse(disagg_ctx.should_abort()) + update_event.set() + self.assertTrue(disagg_ctx.should_abort()) + async def test_sampler_with_replay_buffer(self): # 验证 sampler 优先复用 replay buffer 中可重试的 rollout group,耗尽后回退 dataloader。 task_name = "test_task" @@ -250,10 +338,9 
@@ async def test_sync_produce_strategy(self): batch_size=2, train_step=4, model_step=3, - progress=self._build_progress(task_name, target=2, train_step=4), + progress=self._build_progress(task_name, target=2), ) - status = await strategy.produce_batch(ctx) - self.assertEqual(status, ProduceBatchStatus.NORMAL) + await strategy.produce_batch(ctx) # 验证:ReplayBuffer 中应该有 2 条 COMPLETED 数据 final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED) @@ -261,6 +348,48 @@ async def test_sync_produce_strategy(self): self.assertEqual(final_data[0][0].message_uid, 0) self.assertEqual(final_data[1][0].message_uid, 1) + async def test_sync_produce_strategy_refills_after_filtered_and_aborted_groups(self): + # 验证 filtered / aborted group 不占用 completed quota,sync producer 会继续补齐训练 batch。 + task_name = "test_sync_refill" + + def is_valid_sample_fn(samples): + return samples[0].message_uid != 0 + + async def mock_gen(rs, **kwargs): + for r in rs: + if r.message_uid == 1: + r.status = Status.ABORTED + r.response = "" + r.response_ids = [] + else: + r.status = Status.COMPLETED + r.response = "ok" + r.response_ids = [1, 2] + r.reward = {"score": 1.0} + return rs + + mock_agent_loop = self._build_agent_loop() + mock_agent_loop.generate_group = mock_gen + strategy = SyncProduceStrategyConfig(is_valid_sample_fn=is_valid_sample_fn).build() + sampler = self._build_sampler() + ctx = self._build_context( + strategy, + task_name, + mock_agent_loop, + sampler, + batch_size=2, + train_step=4, + model_step=3, + progress=self._build_progress(task_name, target=2), + ) + + await strategy.produce_batch(ctx) + completed = await self.replay_buffer.get(10, task_name, Status.COMPLETED) + self.assertEqual(len(completed), 2) + self.assertEqual(sorted(group[0].message_uid for group in completed), [2, 3]) + self.assertEqual(await self.replay_buffer.count(task_name, Status.FILTERED), 1) + self.assertEqual(await self.replay_buffer.count(task_name, Status.ABORTED), 1) + async def 
test_async_produce_strategy_oversamples_and_retries_aborted_groups(self): # 验证异步生产策略会按超发预算生产,并优先重试 replay buffer 中的 aborted group。 # 这个async_produce_strategy的测试主要验证超发逻辑 + staleness 优先get的逻辑 @@ -300,8 +429,7 @@ async def mock_gen(rs, **kwargs): model_step=0, progress=self._build_progress(task_name, target=2), ) - status = await strategy.produce_batch(ctx) - self.assertEqual(status, ProduceBatchStatus.NORMAL) + await strategy.produce_batch(ctx) # 验证:ReplayBuffer 中应该有 4 条 COMPLETED 数据。 final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED) @@ -314,7 +442,7 @@ async def test_async_produce_strategy_accepts_context_entrypoint(self): mock_agent_loop = self._build_agent_loop() sampler = self._build_sampler() strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0).build() - progress = self._build_progress(task_name, target=1, train_step=1) + progress = self._build_progress(task_name, target=1) ctx = self._build_context( strategy, task_name, @@ -325,9 +453,7 @@ async def test_async_produce_strategy_accepts_context_entrypoint(self): progress=progress, ) - status = await strategy.produce_batch(ctx) - - self.assertEqual(status, ProduceBatchStatus.NORMAL) + await strategy.produce_batch(ctx) self.assertEqual(await self.replay_buffer.count(task_name, Status.COMPLETED), 1) async def test_async_produce_strategy_uses_live_consumed_progress(self): @@ -346,8 +472,8 @@ async def mock_gen(rs, **kwargs): mock_agent_loop.generate_group = mock_gen sampler = self._build_sampler() # 该用例验证版本记录顺序,放宽 stale 策略避免在生产入口提前返回。 - strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0, max_staleness=3).build() - progress = self._build_progress( + strategy = DisaggAsyncProduceStrategyConfig(over_sample_threshold=0.0, max_staleness=3).build() + progress = self._build_disagg_progress( task_name, target=2, train_step=1, @@ -356,7 +482,7 @@ async def mock_gen(rs, **kwargs): target_upto_future_step=2, ) - ctx = self._build_context( + ctx = self._build_disagg_context( 
strategy, task_name, mock_agent_loop, @@ -375,6 +501,8 @@ async def mock_gen(rs, **kwargs): async def test_async_produce_strategy_uses_fixed_batch_oversample_budget(self): # 验证超发预算按当前 task batch size 固定计算,而不是按剩余缺口缩小。 task_name = "test_fixed_oversample" + for sample_id in range(9): + await self.replay_buffer.put([make_rollout_state(sample_id, status=Status.COMPLETED)], task_name) sampler = MagicMock() sample_ids = iter(range(100, 200)) @@ -385,7 +513,7 @@ async def sample(task_name, group_status=None): sampler.sample = AsyncMock(side_effect=sample) mock_agent_loop = self._build_agent_loop() strategy = AsyncProduceStrategyConfig(over_sample_threshold=1.0).build() - progress = self._build_progress(task_name, target=10, consumed=9) + progress = self._build_progress(task_name, target=10) ctx = self._build_context( strategy, @@ -396,13 +524,11 @@ async def sample(task_name, group_status=None): model_step=0, progress=progress, ) - status = await strategy.produce_batch(ctx) - - self.assertEqual(status, ProduceBatchStatus.NORMAL) + await strategy.produce_batch(ctx) # 当前只缺 1 个样本,但 over-sample 预算固定为 over * batch_size = 4, # 因此本轮最多调度到 target + 4,对应初始发射 5 个任务。 self.assertEqual(sampler.sample.await_count, 5) - self.assertEqual(await self.replay_buffer.count(task_name, Status.COMPLETED), 5) + self.assertEqual(await self.replay_buffer.count(task_name, Status.COMPLETED), 14) async def test_async_produce_strategy_tail_batch_is_static_and_no_oversample(self): # 验证 tail-batch 模式固定从 expired/aborted pool 补必要缺口,并禁用额外超发。 @@ -434,9 +560,7 @@ async def instrumented_sample(task_name, group_status=None): model_step=0, progress=self._build_progress(task_name, target=2), ) - status = await strategy.produce_batch(ctx) - - self.assertEqual(status, ProduceBatchStatus.NORMAL) + await strategy.produce_batch(ctx) # tail-batch 模式在本轮优先走 EXPIRED pool,并且不使用 over-sample 额外发射。 self.assertEqual(sampled_statuses, [[Status.EXPIRED, Status.ABORTED], [Status.EXPIRED, Status.ABORTED]]) completed = await 
self.replay_buffer.get(10, task_name, Status.COMPLETED) @@ -450,14 +574,8 @@ async def test_async_produce_strategy_fails_fast_on_invalid_progress(self): sampler = MagicMock() sampler.sample = AsyncMock(side_effect=AssertionError("sampler.sample should not be called")) - missing_consumed = ProduceProgress( - next_consumer_step=1, - producer_future_step=1, - consumed_samples={}, - target_samples={task_name: 1}, - target_upto_future_step=1, - ) - with self.assertRaisesRegex(KeyError, "consumed_samples"): + missing_target = ProduceProgress(target_samples={}) + with self.assertRaisesRegex(KeyError, "target_samples"): ctx = self._build_context( strategy, task_name, @@ -466,29 +584,31 @@ async def test_async_produce_strategy_fails_fast_on_invalid_progress(self): batch_size=1, train_step=1, model_step=0, - progress=missing_consumed, + progress=missing_target, ) await strategy.produce_batch(ctx) - missing_target = ProduceProgress( - next_consumer_step=1, + disagg_strategy = DisaggAsyncProduceStrategyConfig(over_sample_threshold=0.0).build() + missing_consumed = DisaggProduceProgress( + task_names=[task_name], producer_future_step=1, - consumed_samples={task_name: 0}, - target_samples={}, + next_consumer_step=1, + consumed_samples={}, + target_samples={task_name: 1}, target_upto_future_step=1, ) - with self.assertRaisesRegex(KeyError, "target_samples"): - ctx = self._build_context( - strategy, + with self.assertRaisesRegex(KeyError, "consumed_samples"): + ctx = self._build_disagg_context( + disagg_strategy, task_name, mock_agent_loop, sampler, batch_size=1, train_step=1, model_step=0, - progress=missing_target, + progress=missing_consumed, ) - await strategy.produce_batch(ctx) + await disagg_strategy.produce_batch(ctx) async def test_async_produce_strategy_records_sample_version_before_staleness_refresh(self): # Verify that newly generated tokens first record the Rollout Model Step, then staleness is refreshed against the consumer step. @@ -515,11 +635,9 @@ async def mock_gen(rs, **kwargs): batch_size=1, train_step=5, model_step=3, - 
progress=self._build_progress(task_name, target=1, train_step=5), + progress=self._build_progress(task_name, target=1), ) - status = await strategy.produce_batch(ctx) - - self.assertEqual(status, ProduceBatchStatus.NORMAL) + await strategy.produce_batch(ctx) completed = await self.replay_buffer.get(1, task_name, Status.COMPLETED) self.assertEqual(completed[0][0].response_model_steps, [3, 3]) self.assertEqual(completed[0][0].seq_staleness, 1) @@ -557,17 +675,16 @@ async def mock_gen(rs, **kwargs): batch_size=1, train_step=5, model_step=3, - progress=self._build_progress(task_name, target=1, train_step=5), + progress=self._build_progress(task_name, target=1), ) - status = await strategy.produce_batch(ctx) + await strategy.produce_batch(ctx) - self.assertEqual(status, ProduceBatchStatus.NORMAL) completed = await self.replay_buffer.get(1, task_name, Status.COMPLETED) self.assertEqual(completed[0][0].response_model_steps, [1, 3, 3]) self.assertEqual(completed[0][0].seq_staleness, 3) - async def test_async_produce_strategy_reclaims_cross_call_pending_and_records_timing(self): - # Verify that pending tasks left over across produce_batch calls are reclaimed and that generation timing metrics are written. + async def test_async_produce_strategy_does_not_reclaim_previous_call_pending(self): + # Co-located async pending tasks belong to a single produce_batch call; the next call must not reclaim results left over from the previous one. task_name = "test_task" mock_agent_loop = self._build_agent_loop({0: 0.01, 1: 0.05, 2: 0.05}) produce_strategy_cfg = AsyncProduceStrategyConfig(over_sample_threshold=2.0, enable_partial_rollout=True) @@ -584,8 +701,7 @@ async def test_async_produce_strategy_reclaims_cross_call_pending_and_records_ti model_step=0, progress=progress, ) - status = await strategy.produce_batch(ctx) - self.assertEqual(status, ProduceBatchStatus.NORMAL) + await strategy.produce_batch(ctx) self.assertGreater(strategy.pending_task_count(), 0) await asyncio.sleep(0.08) @@ -599,13 +715,12 @@ async def test_async_produce_strategy_reclaims_cross_call_pending_and_records_ti model_step=0, progress=progress, ) - status = await 
strategy.produce_batch(ctx) - self.assertEqual(status, ProduceBatchStatus.NORMAL) + await strategy.produce_batch(ctx) self.assertEqual(strategy.pending_task_count(), 0) final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED) - self.assertEqual(len(final_data), 3) - self.assertEqual(sorted(group[0].message_uid for group in final_data), [0, 1, 2]) + self.assertEqual(len(final_data), 1) + self.assertEqual(final_data[0][0].message_uid, 0) for group in final_data: self.assertIn("group_generate_time_s", group[0].extra_fields) self.assertGreater(group[0].extra_fields["group_generate_time_s"], 0.0) @@ -674,14 +789,14 @@ async def test_async_produce_strategy_pause_produce_collects_without_cancelling( async def test_async_produce_strategy_returns_update_abort_without_sampling(self): # Verify that when update_event is already set the strategy returns UPDATE_WEIGHT_AND_ABORT immediately and samples no new rollout. task_name = "test_update_abort" - strategy = AsyncProduceStrategyConfig(over_sample_threshold=1.0).build() + strategy = DisaggAsyncProduceStrategyConfig(over_sample_threshold=1.0).build() mock_agent_loop = self._build_agent_loop() sampler = MagicMock() sampler.sample = AsyncMock(side_effect=AssertionError("sampler.sample should not be called")) update_event = asyncio.Event() update_event.set() - ctx = self._build_context( + ctx = self._build_disagg_context( strategy, task_name, mock_agent_loop, @@ -690,7 +805,7 @@ async def test_async_produce_strategy_returns_update_abort_without_sampling(self train_step=1, model_step=1, update_event=update_event, - progress=self._build_progress(task_name, target=1, train_step=1), + progress=self._build_disagg_progress(task_name, target=1, train_step=1), ) status = await strategy.produce_batch(ctx) @@ -700,11 +815,11 @@ async def test_async_produce_strategy_returns_update_abort_without_sampling(self async def test_async_produce_strategy_returns_update_abort_after_schedule_pause(self): # Verify that once pause is triggered midway through the scheduling critical section, the strategy stops scheduling and returns UPDATE_WEIGHT_AND_ABORT. task_name = 
"test_update_abort_after_schedule" - strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0).build() + strategy = DisaggAsyncProduceStrategyConfig(over_sample_threshold=0.0).build() mock_agent_loop = self._build_agent_loop({0: 0.05}) sampler = MagicMock() update_event = asyncio.Event() - progress = self._build_progress(task_name, target=1) + progress = self._build_disagg_progress(task_name, target=1) async def sample(task_name, group_status=None): # Simulate the manager triggering pause midway through the scheduling critical section; the current sample goes into pending and no further scheduling should happen. @@ -713,7 +828,7 @@ async def sample(task_name, group_status=None): sampler.sample = AsyncMock(side_effect=sample) - ctx = self._build_context( + ctx = self._build_disagg_context( strategy, task_name, mock_agent_loop, @@ -731,16 +846,16 @@ async def sample(task_name, group_status=None): await strategy.pause_produce(ctx) self.assertEqual(strategy.pending_task_count(), 0) - async def test_async_produce_strategy_returns_expired_batch_before_processing_leftovers(self): - # Verify that when the Rollout Model Step has expired the strategy returns EXPIRED_BATCH first and does not consume existing completed leftovers. + async def test_disagg_async_produce_strategy_returns_expired_batch_before_processing_leftovers(self): + # Verify that when the disaggregated Rollout Model Step has expired the strategy returns EXPIRED_BATCH first and does not consume existing completed leftovers. task_name = "test_expired_batch" - strategy = AsyncProduceStrategyConfig(max_staleness=0).build() + strategy = DisaggAsyncProduceStrategyConfig(max_staleness=0).build() mock_agent_loop = self._build_agent_loop() sampler = MagicMock() sampler.sample = AsyncMock(side_effect=AssertionError("sampler.sample should not be called")) await self.replay_buffer.put([make_rollout_state(999, status=Status.COMPLETED)], task_name) - ctx = self._build_context( + ctx = self._build_disagg_context( strategy, task_name, mock_agent_loop, @@ -748,7 +863,7 @@ async def test_async_produce_strategy_returns_expired_batch_before_processing_le batch_size=1, train_step=3, model_step=1, - progress=self._build_progress(task_name, target=1, train_step=3), + 
progress=self._build_disagg_progress(task_name, target=1, train_step=3), ) status = await strategy.produce_batch(ctx) diff --git a/tests/rl/fast/pr_fast/test_rl_colocate_trainer.py b/tests/rl/fast/pr_fast/test_rl_colocate_trainer.py index 6c2582d089..f8527e97a4 100644 --- a/tests/rl/fast/pr_fast/test_rl_colocate_trainer.py +++ b/tests/rl/fast/pr_fast/test_rl_colocate_trainer.py @@ -30,7 +30,8 @@ from xtuner.v1.data_proto.rl_data import RolloutState, Status from xtuner.v1.rl.agent_loop_manager import AsyncProduceStrategyConfig, ProduceBatchResult -from xtuner.v1.rl.agent_loop_manager.agent_loop_manager import AgentLoopManager, _TaskRunner +from xtuner.v1.rl.agent_loop_manager.agent_loop_manager import AgentLoopManager +from xtuner.v1.rl.agent_loop_manager.produce_utils import _TaskRunner from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig, SerializedRayObjectRef from xtuner.v1.train.rl_trainer import RLColocateTrainer, RLThroughputBenchmark @@ -71,11 +72,16 @@ async def sample(self, task_name, group_status=None, **kwargs): return [item] -def _build_fake_agent_loop(): +def _build_fake_rollout_controller(): rollout_ctl = MagicMock() rollout_ctl.continue_generation.remote = AsyncMock(return_value=None) rollout_ctl.pause_generation.remote = AsyncMock(return_value=None) rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + return rollout_ctl + + +def _build_fake_agent_loop(): + rollout_ctl = _build_fake_rollout_controller() agent_loop = MagicMock() agent_loop.rollout_ctl = rollout_ctl @@ -178,6 +184,7 @@ def test_fit_accepts_async_strategy_manager_on_colocate_path(self): ) ], replay_buffer=replay_buffer, + rollout_controller=_build_fake_rollout_controller(), ) trainer = self._make_trainer(manager) diff --git a/tests/rl/fast/pr_fast/test_rl_disaggregated_trainer.py b/tests/rl/fast/pr_fast/test_rl_disaggregated_trainer.py index 8bda9a0c9b..c5f81122d8 100644 --- a/tests/rl/fast/pr_fast/test_rl_disaggregated_trainer.py +++ 
b/tests/rl/fast/pr_fast/test_rl_disaggregated_trainer.py @@ -54,8 +54,8 @@ async def get_batch(self, batch_size: int, train_step: int): self.calls.append(("get_batch", batch_size, train_step)) return self._results.pop(0) - async def pause_produce(self, *, use_global_progress: bool): - self.calls.append(("pause_produce", use_global_progress)) + async def pause_produce(self): + self.calls.append("pause_produce") return 0.25 async def continue_produce(self, model_step: int): @@ -83,6 +83,17 @@ async def produce_loop(self, batch_size: int): self.calls.append("produce_loop_exit") +class _FailingProducerManager(_FakeManager): + async def produce_loop(self, batch_size: int): + self.calls.append(("produce_loop_start", batch_size)) + raise RuntimeError("background producer failed") + + async def get_batch(self, batch_size: int, train_step: int): + self.calls.append(("get_batch", batch_size, train_step)) + await asyncio.sleep(0.05) + return self._results.pop(0) + + class TestRLDisaggregatedTrainer(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -267,6 +278,18 @@ def blocking_train_one_batch(*args, **kwargs): self.assertIn("produce_loop_tick_during_training", manager.calls) self.assertEqual(trainer._cur_step, 1) + def test_fit_observes_background_producer_failure_before_training_waited_batch(self): + # A background producer exception is a terminal failure; it must surface immediately while the foreground get_batch is still waiting, rather than training first and failing afterwards. + train_sample = SimpleNamespace(message_uid=1, uid=1) + manager = _FailingProducerManager([ProduceBatchResult(rollout_states=[[train_sample]])]) + trainer = self._make_trainer(manager) + + with self.assertRaisesRegex(RuntimeError, "background producer failed"): + self._run_fit(trainer) + + trainer.train_controller.fit.assert_not_called() + self.assertIn(("get_batch", 2, 1), manager.calls) + def test_fit_runs_eval_before_reset_and_stops_producer(self): # Verify that eval runs before the producer resumes, so the production side does not grab rollout resources early. # Deterministic ordering depends on RolloutState's message_uid and uid; lightweight objects are enough to mock them in tests. diff --git 
a/tests/rl/fast/pr_fast/test_staleness_policy.py b/tests/rl/fast/pr_fast/test_staleness_policy.py index 29d63f370b..df97a545fe 100644 --- a/tests/rl/fast/pr_fast/test_staleness_policy.py +++ b/tests/rl/fast/pr_fast/test_staleness_policy.py @@ -2,21 +2,25 @@ from pydantic import ValidationError -from xtuner.v1.rl.agent_loop_manager import AsyncProduceStrategyConfig, calculate_stale_threshold +from xtuner.v1.rl.agent_loop_manager import ( + AsyncProduceStrategyConfig, + DisaggAsyncProduceStrategyConfig, + calculate_stale_threshold, +) class TestStalenessPolicy(unittest.TestCase): def test_max_staleness_zero_uses_sync_interval_as_threshold(self): # max_staleness=0 means only the minimal lag inherent to one sync interval is accepted. self.assertEqual(calculate_stale_threshold(max_staleness=0, sync_weights_interval=4), 4) - strategy = AsyncProduceStrategyConfig(max_staleness=0).build(sync_weights_interval=4) + strategy = DisaggAsyncProduceStrategyConfig(max_staleness=0).build(sync_weights_interval=4) self.assertFalse(strategy.is_model_expired(train_step=8, model_step=4)) self.assertTrue(strategy.is_model_expired(train_step=9, model_step=4)) def test_max_staleness_one_allows_one_extra_sync_interval(self): self.assertEqual(calculate_stale_threshold(max_staleness=1, sync_weights_interval=4), 8) - strategy = AsyncProduceStrategyConfig(max_staleness=1).build(sync_weights_interval=4) + strategy = DisaggAsyncProduceStrategyConfig(max_staleness=1).build(sync_weights_interval=4) self.assertFalse(strategy.is_model_expired(train_step=12, model_step=4)) self.assertTrue(strategy.is_model_expired(train_step=13, model_step=4)) @@ -24,6 +28,8 @@ def test_max_staleness_one_allows_one_extra_sync_interval(self): def test_negative_max_staleness_is_invalid(self): with self.assertRaises(ValidationError): AsyncProduceStrategyConfig(max_staleness=-1) + with self.assertRaises(ValidationError): + DisaggAsyncProduceStrategyConfig(max_staleness=-1) if __name__ == "__main__": diff --git a/tests/rl/test_agent_loop_manager_checkpoint.py 
b/tests/rl/test_agent_loop_manager_checkpoint.py index ba424ff41d..4ec363a3d2 100644 --- a/tests/rl/test_agent_loop_manager_checkpoint.py +++ b/tests/rl/test_agent_loop_manager_checkpoint.py @@ -13,11 +13,11 @@ Public behaviors mainly covered by this file: - Under co-located produce_batch, both SyncProduceStrategy and AsyncProduceStrategy, after resume, continue the same sampler sequence. -- Under disaggregated produce_loop/get_batch, AsyncProduceStrategy also, after resume, +- Under disaggregated produce_loop/get_batch, DisaggAsyncProduceStrategy also, after resume, continues the same sampler sequence. - After save/resume, completed rollout groups in the checkpoint that have not been consumed yet can still be fetched via get_batch. - If AsyncProduceStrategy still has pending rollout tasks at save time, save fails fast to avoid persisting incomplete state. -- After resume, the AsyncProduceStrategy background producer must wait for the trainer to explicitly call continue_produce before it resumes. +- After resume, the DisaggAsyncProduceStrategy background producer must wait for the trainer to explicitly call continue_produce before it resumes. """ import asyncio @@ -36,6 +36,8 @@ from xtuner.v1.rl.agent_loop_manager import ( AgentLoopManagerConfig, AsyncProduceStrategyConfig, + DisaggAgentLoopManagerConfig, + DisaggAsyncProduceStrategyConfig, SamplerConfig, SyncProduceStrategyConfig, ) @@ -126,10 +128,19 @@ def _write_dataset(self, dataset_path: Path): ) dataset_path.write_text("\n".join(json.dumps(row) for row in rows) + "\n", encoding="utf-8") - def _build_manager(self, replay_buffer_config, *, rollout_controller=None, produce_strategy_config=None): + def _build_manager( + self, + replay_buffer_config, + *, + rollout_controller=None, + produce_strategy_config=None, + mode="colocate", + ): assert QWEN3_4B_PATH is not None rollout_controller = rollout_controller or _FakeRolloutController() - produce_strategy_config = produce_strategy_config or SyncProduceStrategyConfig() + produce_strategy_config = produce_strategy_config or ( + DisaggAsyncProduceStrategyConfig() if mode == "disaggregated" else SyncProduceStrategyConfig() + ) dataloader_cfg = DataloaderConfig( dataset_config_list=[ { @@ -149,7 +160,8 @@ def _build_manager(self, replay_buffer_config, *, rollout_controller=None, produ num_workers=0, 
round_up=False, ) - manager_cfg = AgentLoopManagerConfig( + manager_config_cls = DisaggAgentLoopManagerConfig if mode == "disaggregated" else AgentLoopManagerConfig + manager_cfg = manager_config_cls( tasks=[ { "task_name": "unit_task", @@ -160,7 +172,7 @@ def _build_manager(self, replay_buffer_config, *, rollout_controller=None, produ "produce_strategy_config": produce_strategy_config, "sampler_config": SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=2), } - ] + ], ) return manager_cfg.build( rollout_controller=rollout_controller, @@ -169,13 +181,22 @@ def _build_manager(self, replay_buffer_config, *, rollout_controller=None, produ ) def _build_async_manager(self, *, rollout_controller=None): - with patch("xtuner.v1.rl.agent_loop_manager.producer.ray.get", side_effect=lambda ref, *_, **__: ref): + with patch("ray.get", side_effect=lambda ref, *_, **__: ref): return self._build_manager( AsyncReplayBufferConfig(), rollout_controller=rollout_controller, produce_strategy_config=AsyncProduceStrategyConfig(over_sample_threshold=0.0), ) + def _build_disagg_async_manager(self, *, rollout_controller=None): + with patch("ray.get", side_effect=lambda ref, *_, **__: ref): + return self._build_manager( + AsyncReplayBufferConfig(), + rollout_controller=rollout_controller, + produce_strategy_config=DisaggAsyncProduceStrategyConfig(over_sample_threshold=0.0), + mode="disaggregated", + ) + def _build_sync_produce_batch_manager(self): return self._build_manager(SyncReplayBufferConfig()) @@ -238,14 +259,14 @@ async def test_produce_batch_resume_continues_same_sampler_suffix_for_sync_and_a @unittest.skipUnless(QWEN3_4B_PATH, "QWEN3_4B_PATH is required for AgentLoopManager checkpoint tests") async def test_async_produce_loop_resume_continues_same_sampler_suffix_after_checkpoint(self): - # 验证非共卡 AsyncProduceStrategy: sample1 后保存 checkpoint,正常继续生产 sample2/sample3。 + # 验证非共卡 DisaggAsyncProduceStrategy: sample1 后保存 checkpoint,正常继续生产 sample2/sample3。 # 从 checkpoint resume 
后也必须继续生产同一段 sample2/sample3。 - manager = self._build_async_manager() + manager = self._build_disagg_async_manager() produce_task = asyncio.create_task(manager.produce_loop(batch_size=1)) try: sample1_index = await self._consume_async_index(manager, train_step=1) - await manager.pause_produce(use_global_progress=True) + await manager.pause_produce() with tempfile.TemporaryDirectory() as tmp_dir: checkpoint_path = Path(tmp_dir) / "ckpt" @@ -256,7 +277,7 @@ async def test_async_produce_loop_resume_continues_same_sampler_suffix_after_che await self._continue_and_consume_async_index(manager, train_step=3, model_step=2), ] - restored_manager = self._build_async_manager() + restored_manager = self._build_disagg_async_manager() restored_model_step = await restored_manager.resume(checkpoint_path) restored_produce_task = asyncio.create_task(restored_manager.produce_loop(batch_size=1)) try: @@ -282,7 +303,7 @@ async def test_async_produce_loop_resume_continues_same_sampler_suffix_after_che @unittest.skipUnless(QWEN3_4B_PATH, "QWEN3_4B_PATH is required for AgentLoopManager checkpoint tests") async def test_resume_keeps_unconsumed_completed_groups_available_to_get_batch(self): # 验证 save 时 replay buffer 中未消费的 completed group,resume 后仍可被 get_batch 消费。 - manager = self._build_manager(AsyncReplayBufferConfig()) + manager = self._build_disagg_async_manager() buffered_group = [ RolloutState( uid=9000 + idx, @@ -303,7 +324,7 @@ async def test_resume_keeps_unconsumed_completed_groups_available_to_get_batch(s checkpoint_path = Path(tmp_dir) / "ckpt" await manager.save(checkpoint_path, model_step=4) - restored_manager = self._build_manager(AsyncReplayBufferConfig()) + restored_manager = self._build_disagg_async_manager() restored_model_step = await restored_manager.resume(checkpoint_path) result = await restored_manager.get_batch(batch_size=1, train_step=5) @@ -316,11 +337,12 @@ async def test_save_rejects_while_async_rollout_task_is_pending(self): # 验证后台异步 rollout 还在进行时不能保存。 # 否则 
checkpoint 会丢失未入库的生产结果。 rollout_controller = _BlockingRolloutController() - with patch("xtuner.v1.rl.agent_loop_manager.producer.ray.get", side_effect=lambda ref, *_, **__: ref): + with patch("ray.get", side_effect=lambda ref, *_, **__: ref): manager = self._build_manager( AsyncReplayBufferConfig(), rollout_controller=rollout_controller, - produce_strategy_config=AsyncProduceStrategyConfig(over_sample_threshold=0.0), + produce_strategy_config=DisaggAsyncProduceStrategyConfig(over_sample_threshold=0.0), + mode="disaggregated", ) produce_task = asyncio.create_task(manager.produce_loop(batch_size=1)) @@ -341,21 +363,23 @@ async def test_resume_requires_continue_before_async_producer_generates(self): with tempfile.TemporaryDirectory() as tmp_dir: checkpoint_path = Path(tmp_dir) / "ckpt" rollout_controller = _FakeRolloutController() - with patch("xtuner.v1.rl.agent_loop_manager.producer.ray.get", side_effect=lambda ref, *_, **__: ref): + with patch("ray.get", side_effect=lambda ref, *_, **__: ref): manager = self._build_manager( AsyncReplayBufferConfig(), rollout_controller=rollout_controller, - produce_strategy_config=AsyncProduceStrategyConfig(over_sample_threshold=0.0), + produce_strategy_config=DisaggAsyncProduceStrategyConfig(over_sample_threshold=0.0), + mode="disaggregated", ) await manager.save(checkpoint_path, model_step=1) restored_rollout_controller = _FakeRolloutController() - with patch("xtuner.v1.rl.agent_loop_manager.producer.ray.get", side_effect=lambda ref, *_, **__: ref): + with patch("ray.get", side_effect=lambda ref, *_, **__: ref): restored_manager = self._build_manager( AsyncReplayBufferConfig(), rollout_controller=restored_rollout_controller, - produce_strategy_config=AsyncProduceStrategyConfig(over_sample_threshold=0.0), - ) + produce_strategy_config=DisaggAsyncProduceStrategyConfig(over_sample_threshold=0.0), + mode="disaggregated", + ) restored_model_step = await restored_manager.resume(checkpoint_path) produce_task = 
asyncio.create_task(restored_manager.produce_loop(batch_size=1)) diff --git a/tests/rl/test_rl_trainer_checkpoint.py b/tests/rl/test_rl_trainer_checkpoint.py index b8178980e7..ec314ce590 100644 --- a/tests/rl/test_rl_trainer_checkpoint.py +++ b/tests/rl/test_rl_trainer_checkpoint.py @@ -31,7 +31,13 @@ from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig from xtuner.v1.model.dense.qwen3 import Qwen3Dense4BConfig from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig -from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerConfig, + DisaggAgentLoopManagerConfig, + DisaggAsyncProduceStrategyConfig, + SamplerConfig, + SyncProduceStrategyConfig, +) from xtuner.v1.rl.loss import GRPOLossConfig from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig, SyncReplayBufferConfig from xtuner.v1.rl.rollout.worker import RolloutConfig @@ -206,6 +212,7 @@ def build_rollout_controller(rollout_cfg, placement_group): patch("xtuner.v1.train.rl_trainer.AutoAcceleratorWorkers.build_placement_group", side_effect=build_pg), patch("xtuner.v1.train.rl_trainer.CPUResourceManager", _FakeCPUResourceManager), patch("xtuner.v1.train.rl_trainer.set_cpu_resource_manager", lambda manager: None), + patch("xtuner.v1.train.rl_trainer.get_rollout_engine_version", return_value={}), patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj, timeout=None: obj), patch("xtuner.v1.train.rl_trainer.BaseRLTrainer._release_trace_store", return_value=None), patch.object(WorkerConfig, "build", autospec=True, side_effect=build_train_controller), @@ -231,7 +238,13 @@ def _build_train_worker_config(self, model_path: str) -> WorkerConfig: pack_max_length=256, ) - def _build_agent_loop_manager_config(self, model_path: str) -> AgentLoopManagerConfig: + def _build_agent_loop_manager_config( + self, + model_path: str, + *, + mode: str = "colocate", + 
produce_strategy_config=None, + ) -> AgentLoopManagerConfig | DisaggAgentLoopManagerConfig: dataloader_cfg = DataloaderConfig( dataset_config_list=[ { @@ -251,7 +264,11 @@ def _build_agent_loop_manager_config(self, model_path: str) -> AgentLoopManagerC num_workers=0, round_up=False, ) - return AgentLoopManagerConfig( + manager_config_cls = DisaggAgentLoopManagerConfig if mode == "disaggregated" else AgentLoopManagerConfig + produce_strategy_config = produce_strategy_config or ( + DisaggAsyncProduceStrategyConfig() if mode == "disaggregated" else SyncProduceStrategyConfig() + ) + return manager_config_cls( tasks=[ { "task_name": "unit_task", @@ -259,10 +276,10 @@ def _build_agent_loop_manager_config(self, model_path: str) -> AgentLoopManagerC hf_checkpoint=model_path, sample_params=SampleParams(max_tokens=2, temperature=0.0, top_k=1), ), - "produce_strategy_config": SyncProduceStrategyConfig(), + "produce_strategy_config": produce_strategy_config, "sampler_config": SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=2), } - ] + ], ) def _build_rollout_config(self, model_path: str) -> RolloutConfig: @@ -329,7 +346,11 @@ def _build_disaggregated_config( rollout_config=self._build_rollout_config(QWEN3_4B_PATH), tokenizer_path=QWEN3_4B_PATH, replay_buffer_config=AsyncReplayBufferConfig(), - agent_loop_manager_cfg=self._build_agent_loop_manager_config(QWEN3_4B_PATH), + agent_loop_manager_cfg=self._build_agent_loop_manager_config( + QWEN3_4B_PATH, + mode="disaggregated", + produce_strategy_config=DisaggAsyncProduceStrategyConfig(over_sample_threshold=0.0), + ), load_from=QWEN3_4B_PATH, total_train_steps=total_train_steps, train_batch_size=1, diff --git a/xtuner/v1/rl/agent_loop_manager/__init__.py b/xtuner/v1/rl/agent_loop_manager/__init__.py index c58c1f5cc5..ddd90fd6be 100644 --- a/xtuner/v1/rl/agent_loop_manager/__init__.py +++ b/xtuner/v1/rl/agent_loop_manager/__init__.py @@ -1,21 +1,32 @@ from .agent_loop_manager import ( AgentLoopManager, 
AgentLoopManagerConfig, - AgentLoopManagerStatus, - ProduceBatchResult, TaskSpecConfig, ) +from .disagg_agent_loop_manager import ( + AgentLoopManagerStatus, + DisaggAgentLoopManager, + DisaggAgentLoopManagerConfig, + DisaggTaskSpecConfig, +) +from .disagg_producer import ( + DisaggAsyncProduceStrategy, + DisaggAsyncProduceStrategyConfig, + DisaggProduceContext, + DisaggProduceProgress, + DisaggProduceStrategy, + DisaggProduceStrategyConfig, +) +from .produce_utils import ProduceBatchResult, ProduceBatchStatus, calculate_stale_threshold from .producer import ( AsyncProduceStrategy, AsyncProduceStrategyConfig, - ProduceBatchStatus, ProduceContext, ProduceProgress, ProduceStrategy, ProduceStrategyConfig, SyncProduceStrategy, SyncProduceStrategyConfig, - calculate_stale_threshold, ) from .sampler import Sampler, SamplerConfig @@ -24,18 +35,27 @@ __all__ = [ "AgentLoopManagerConfig", "AgentLoopManager", + "DisaggAgentLoopManager", + "DisaggAgentLoopManagerConfig", "AgentLoopManagerStatus", "TaskSpecConfig", + "DisaggTaskSpecConfig", "ProduceBatchResult", "ProduceStrategyConfig", + "DisaggProduceStrategyConfig", + "DisaggProduceProgress", + "DisaggProduceContext", + "DisaggProduceStrategy", "SyncProduceStrategyConfig", "AsyncProduceStrategyConfig", + "DisaggAsyncProduceStrategyConfig", "ProduceBatchStatus", "ProduceContext", "ProduceProgress", "ProduceStrategy", "SyncProduceStrategy", "AsyncProduceStrategy", + "DisaggAsyncProduceStrategy", "calculate_stale_threshold", "SamplerConfig", "Sampler", diff --git a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py index c2257afbb0..2b45a6ca15 100644 --- a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py +++ b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py @@ -1,216 +1,43 @@ import asyncio import json -import math import time -from dataclasses import dataclass -from enum import Enum, auto from pathlib import Path +from typing import Any, cast from pydantic 
import BaseModel, ConfigDict, Field from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from xtuner.v1.data_proto.rl_data import RolloutState, Status -from xtuner.v1.rl.agent_loop import AgentLoopConfig, AgentLoopSpec, get_agent_loop_rollout_ctl +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.agent_loop import AgentLoopConfig from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger from xtuner.v1.rl.replay_buffer import ReplayBuffer from xtuner.v1.rl.rollout import RolloutController from xtuner.v1.utils import get_logger +from .produce_utils import ( + _MANAGER_STATE_PATH, + _STATUS_POLL_INTERVAL_S, + _TASK_CHECKPOINT_DIR, + ProduceBatchResult, + _TaskRunner, + _TaskSamplerView, + allocate_task_batch_sizes, + get_pending_task_counts, + manager_state_path, + refresh_for_all_tasks, + take_train_batch, + task_checkpoint_path, +) from .producer import ( - GROUP_GENERATE_TIME_KEY, - ProduceBatchStatus, ProduceContext, ProduceProgress, ProduceStrategy, ProduceStrategyConfig, SyncProduceStrategyConfig, - default_is_valid_sample_fn, ) from .sampler import Sampler, SamplerConfig -@dataclass -class ProduceBatchResult: - """Result of a single ``produce_batch`` call. - - Args: - rollout_states (list[list[RolloutState]]): Completed rollout groups retrieved from the replay buffer for training. - group_gen_count (int | None): Number of generate-group calls finished in this batch (None if no generations ran). - group_gen_mean_s (float | None): Mean wall-clock time per generate-group call, in seconds. - group_gen_p50_s (float | None): Median (p50) generate-group time, in seconds. - group_gen_p99_s (float | None): 99th percentile generate-group time, in seconds. - group_gen_p99_p50_ratio (float | None): Ratio of p99 to p50, indicating tail-latency skew. - group_gen_pause_time_s (float | None): Time spent in pause/cleanup phase (async strategy only), in seconds. 
- leftover_init (int): Number of init groups remaining in the replay buffer after this batch. - leftover_completed (int): Number of completed groups remaining in the replay buffer after this batch. - leftover_aborted (int): Number of aborted groups remaining in the replay buffer. - leftover_expired (int): Number of expired groups remaining in the replay buffer. - leftover_failed (int): Number of failed groups remaining in the replay buffer. - leftover_filtered (int): Number of filtered groups remaining in the replay buffer. - raw_rewards_sum (float): Sum of rewards produced before replay-buffer insertion for the current window. - raw_rewards_count (int): Number of reward-bearing samples included in ``raw_rewards_sum``. - produced_samples (int): Number of rollout samples produced in the current produce window. - produced_tokens (int): Number of response tokens produced in the current produce window. - produce_time_s (float): Wall-clock production time consumed by the current produce window. - """ - - rollout_states: list[list[RolloutState]] - status: ProduceBatchStatus = ProduceBatchStatus.NORMAL - # per-group generation timing stats (all None if no generations occurred) - group_gen_count: int | None = None - group_gen_mean_s: float | None = None - group_gen_p50_s: float | None = None - group_gen_p99_s: float | None = None - group_gen_p99_p50_ratio: float | None = None - group_gen_pause_time_s: float | None = None - # leftover samples remaining in replay buffer after batch retrieval - leftover_init: int = 0 - leftover_completed: int = 0 - leftover_aborted: int = 0 - leftover_expired: int = 0 - leftover_failed: int = 0 - leftover_filtered: int = 0 - # rewards produced during the current produce window, including completed and filtered groups. 
- raw_rewards_sum: float = 0.0 - raw_rewards_count: int = 0 - produced_samples: int = 0 - produced_tokens: int = 0 - produce_time_s: float = 0.0 - task_batch_sizes: dict[str, int] | None = None - task_results: dict[str, "ProduceBatchResult"] | None = None - - -@dataclass(frozen=True) -class _TaskRunner: - task_name: str - agent_loop: AgentLoopSpec - produce_strategy: ProduceStrategy - sampler: Sampler - weight: float = 1.0 - order: int = 0 - - -class _TaskSamplerView: - def __init__(self, samplers: list[Sampler]): - self._samplers = samplers - - def __len__(self) -> int: - return sum(len(sampler) for sampler in self._samplers) - - -class AgentLoopManagerStatus(Enum): - """AgentLoopManager 的全局状态. - - 按下面的路径流转: - - 初始状态是 NORMAL - - NORMAL -> UPDATE_WEIGHT_AND_ABORT - - trainer 开始做权重同步前触发 - - UPDATE_WEIGHT_AND_ABORT -> NORMAL - - 权重同步完成后调用 continue_product() - - NORMAL -> EXPIRED_BATCH - - 当前 rollout model 已经过旧 - - EXPIRED_BATCH -> UPDATE_WEIGHT_AND_ABORT - - trainer 检测到过期后,进入权重同步阶段 - - 任意状态 -> FINISH - - 训练结束 - - 这里有一个重要区分: - - AgentLoopManagerStatus 是“后台 producer 的全局运行状态” - - ProduceBatchStatus 是“单次调度调用的局部结果” - """ - - NORMAL = auto() - UPDATE_WEIGHT_AND_ABORT = auto() - EXPIRED_BATCH = auto() - FINISH = auto() - - -def _fill_produce_timing_stats( - result: ProduceBatchResult, generate_times_s: list[float], pause_time_s: float = 0.0 -) -> None: - if not generate_times_s: - if pause_time_s > 0: - result.group_gen_pause_time_s = pause_time_s - return - sorted_times = sorted(generate_times_s) - n = len(sorted_times) - mean_s = sum(sorted_times) / n - p50_s = sorted_times[n // 2] - p99_s = sorted_times[int(n * 0.99)] - ratio = p99_s / p50_s if p50_s > 0 else float("inf") - result.group_gen_count = n - result.group_gen_mean_s = mean_s - result.group_gen_p50_s = p50_s - result.group_gen_p99_s = p99_s - result.group_gen_p99_p50_ratio = ratio - result.group_gen_pause_time_s = pause_time_s - - -def _fill_group_timing_stats( - result: ProduceBatchResult, rollout_states: 
list[list[RolloutState]], pause_time_s: float = 0.0 -) -> None: - generate_times: list[float] = [] - for group in rollout_states: - if not group: - continue - group_time = getattr(group[0], "extra_fields", {}).get(GROUP_GENERATE_TIME_KEY) - if group_time is not None: - generate_times.append(group_time) - - _fill_produce_timing_stats(result, generate_times, pause_time_s=pause_time_s) - - -def _aggregate_status(statuses: list[ProduceBatchStatus]) -> ProduceBatchStatus: - if any(status == ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT for status in statuses): - return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT - if any(status == ProduceBatchStatus.EXPIRED_BATCH for status in statuses): - return ProduceBatchStatus.EXPIRED_BATCH - return ProduceBatchStatus.NORMAL - - -_LEFTOVER_STATUSES = [ - Status.INIT, - Status.COMPLETED, - Status.ABORTED, - Status.EXPIRED, - Status.FAILED, - Status.FILTERED, -] - - -def _fill_leftover_counts(result: ProduceBatchResult, status_counts: dict[Status, int]) -> None: - result.leftover_init = status_counts.get(Status.INIT, 0) - result.leftover_completed = status_counts.get(Status.COMPLETED, 0) - result.leftover_aborted = status_counts.get(Status.ABORTED, 0) - result.leftover_expired = status_counts.get(Status.EXPIRED, 0) - result.leftover_failed = status_counts.get(Status.FAILED, 0) - result.leftover_filtered = status_counts.get(Status.FILTERED, 0) - - -def _build_produce_context( - task_runner: _TaskRunner, - replay_buffer: ReplayBuffer, - batch_size: int, - train_step: int, - model_step: int, - update_event: asyncio.Event, - progress: ProduceProgress, -) -> ProduceContext: - return ProduceContext( - agent_loop=task_runner.agent_loop, - sampler=task_runner.sampler, - replay_buffer=replay_buffer, - task_batch_size=batch_size, - task_name=task_runner.task_name, - train_step=train_step, - update_event=update_event, - model_step=model_step, - progress=progress, - is_valid_sample_fn=getattr(task_runner.produce_strategy, "is_valid_sample_fn", 
default_is_valid_sample_fn), - stale_threshold=getattr(task_runner.produce_strategy, "stale_threshold", None), - ) - - class TaskSpecConfig(BaseModel): """Configuration for one task managed by ``AgentLoopManager``. @@ -335,19 +162,28 @@ def build( return AgentLoopManager( task_runners=task_runners, replay_buffer=replay_buffer, + rollout_controller=rollout_controller, logger=logger, ) class AgentLoopManager: - _TASK_CHECKPOINT_DIR = "tasks" - _MANAGER_STATE_PATH = "agent_loop_manager_state.json" - _STATUS_POLL_INTERVAL_S = 1.0 + _TASK_CHECKPOINT_DIR = _TASK_CHECKPOINT_DIR + _MANAGER_STATE_PATH = _MANAGER_STATE_PATH + _STATUS_POLL_INTERVAL_S = _STATUS_POLL_INTERVAL_S + task_runners: list[_TaskRunner] + replay_buffer: ReplayBuffer + _rollout_controller: RolloutController + data_sampler: Sampler | _TaskSamplerView + name: str + logger: Any + task_names: list[str] def __init__( self, task_runners: list[_TaskRunner], replay_buffer: ReplayBuffer, + rollout_controller: RolloutController, logger=None, ): if not task_runners: @@ -357,408 +193,15 @@ def __init__( self.task_runners = task_runners self.replay_buffer = replay_buffer + self._rollout_controller = rollout_controller self.data_sampler = ( task_runners[0].sampler if len(task_runners) == 1 else _TaskSamplerView([task.sampler for task in task_runners]) ) self.name = task_runners[0].task_name if len(task_runners) == 1 else "multi_task" - if logger is None: - self.logger = get_logger() - else: - self.logger = logger - - self.task_names = [task.task_name for task in self.task_runners] - - # 非共卡并发控制信号:consumer 在同步权重前置位,producer / strategy 应直接观察 - # event 状态并尽快停止继续发新 rollout;不要用额外布尔快照替代这个 event。 - self._update_event = asyncio.Event() - - self._finish_event = asyncio.Event() - - # 非共卡 producer 读取的 model_step:rollout 侧当前使用的是哪个 train_step 同步后的模型。 - # 权重更新前必须先 pause 并清空 pending task,因此一个 pending 生命周期内只对应一个 model_step。 - self._model_step = 0 - - # 非共卡 producer / consumer 共享的控制状态。produce_loop / get_batch 应直接读取 - # 
self._status,不要跨 await 缓存局部快照,避免错过同步、过期或结束状态变化。 - self._status = AgentLoopManagerStatus.NORMAL - - # pause_produce 写入、下一次 get batch 读取并清零的耗时指标。 - # 只用于消费侧日志/metrics;读写不构成生产正确性依赖。 - self._pause_time_s = 0.0 - - # 非共卡 producer / consumer 共享的绝对累计进度。对象引用必须保持稳定; - # consumer 原地更新字段,producer / strategy 需要字段值时直接读取 progress.xxx, - # 不要把字段值复制成跨 await 使用的局部快照。 - self._produce_progress = ProduceProgress.build(self.task_names) - - def get_task_batch_sizes(self, global_batch_size: int, train_step: int) -> dict[str, int]: - """Return the per-task batch sizes for the current train step. - - Subclasses may override this method to implement custom dynamic batch allocation policies. Returning 0 for a - task effectively disables that task for the current produce_batch call. - """ - if global_batch_size < 0: - raise ValueError(f"global_batch_size must be non-negative, got {global_batch_size}") - - total_weight = sum(task.weight for task in self.task_runners) - if total_weight <= 0: - raise ValueError("Sum of task weights must be positive.") - if global_batch_size == 0: - return {task.task_name: 0 for task in self.task_runners} - - raw_allocations = [global_batch_size * task.weight / total_weight for task in self.task_runners] - floor_allocations = [math.floor(raw) for raw in raw_allocations] - remaining = global_batch_size - sum(floor_allocations) - - task_batch_sizes = {task.task_name: floor_allocations[idx] for idx, task in enumerate(self.task_runners)} - if remaining <= 0: - return task_batch_sizes - - ranked_tasks = sorted( - enumerate(self.task_runners), - key=lambda item: ( - -(raw_allocations[item[0]] - floor_allocations[item[0]]), - item[1].order, - ), - ) - for idx, task in ranked_tasks[:remaining]: - task_batch_sizes[task.task_name] += 1 - return task_batch_sizes - - def _validate_task_batch_sizes(self, task_batch_sizes: dict[str, int], global_batch_size: int) -> None: - expected_task_names = {task.task_name for task in self.task_runners} - actual_task_names = 
set(task_batch_sizes.keys()) - if actual_task_names != expected_task_names: - missing_task_names = expected_task_names - actual_task_names - extra_task_names = actual_task_names - expected_task_names - raise ValueError( - "Invalid task batch sizes returned by get_task_batch_sizes: " - f"missing={sorted(missing_task_names)}, extra={sorted(extra_task_names)}" - ) - - negative_batch_sizes = { - task_name: task_batch_size - for task_name, task_batch_size in task_batch_sizes.items() - if task_batch_size < 0 - } - if negative_batch_sizes: - raise ValueError(f"Task batch sizes must be non-negative, got {negative_batch_sizes}") - - total_batch_size = sum(task_batch_sizes.values()) - if total_batch_size != global_batch_size: - raise ValueError( - "Task batch sizes must sum to the requested global batch size, " - f"got total={total_batch_size}, expected={global_batch_size}" - ) - - async def _refresh_for_all_tasks(self, train_step: int, statuses: list[Status]) -> None: - task_stale_thresholds: dict[str, int] = {} - for task in self.task_runners: - # colocate / disagg 都统一刷新 staleness;同步策略没有 stale_threshold 时使用 1。 - stale_threshold = getattr(task.produce_strategy, "stale_threshold", 1) - task_stale_thresholds[task.task_name] = stale_threshold - - expired_counts = await self.replay_buffer.refresh_staleness( - task_stale_thresholds=task_stale_thresholds, - current_train_step=train_step, - statuses=statuses, - ) - for task_name, expired_count in expired_counts.items(): - self.logger.info( - f"[AgentLoopManager][{self.name}] Refresh staleness for task {task_name}: expired_count={expired_count}" - ) - - def _get_task_batch_sizes_for_step(self, batch_size: int, train_step: int) -> dict[str, int]: - if len(self.task_runners) == 1: - return {self.task_runners[0].task_name: batch_size} - - task_batch_sizes = self.get_task_batch_sizes(batch_size, train_step) - self._validate_task_batch_sizes(task_batch_sizes, batch_size) - return task_batch_sizes - - @staticmethod - def 
_aggregate_task_results( - ordered_tasks: list[_TaskRunner], task_results: dict[str, ProduceBatchResult] - ) -> ProduceBatchResult: - rollout_states: list[list[RolloutState]] = [] - leftover_init = 0 - leftover_completed = 0 - leftover_aborted = 0 - leftover_expired = 0 - leftover_failed = 0 - leftover_filtered = 0 - total_group_count = 0 - weighted_group_mean_sum = 0.0 - weighted_group_p50_sum = 0.0 - weighted_group_p99_sum = 0.0 - weighted_group_ratio_sum = 0.0 - total_pause_time_s = 0.0 - raw_rewards_sum = 0.0 - raw_rewards_count = 0 - produced_samples = 0 - produced_tokens = 0 - produce_time_s = 0.0 - - for task in ordered_tasks: - result = task_results[task.task_name] - rollout_states.extend(result.rollout_states) - leftover_init += result.leftover_init - leftover_completed += result.leftover_completed - leftover_aborted += result.leftover_aborted - leftover_expired += result.leftover_expired - leftover_failed += result.leftover_failed - leftover_filtered += result.leftover_filtered - raw_rewards_sum += result.raw_rewards_sum - raw_rewards_count += result.raw_rewards_count - produced_samples += result.produced_samples - produced_tokens += result.produced_tokens - produce_time_s += result.produce_time_s - if result.group_gen_count is not None and result.group_gen_mean_s is not None: - total_group_count += result.group_gen_count - weighted_group_mean_sum += result.group_gen_count * result.group_gen_mean_s - weighted_group_p50_sum += result.group_gen_count * (result.group_gen_p50_s or 0.0) - weighted_group_p99_sum += result.group_gen_count * (result.group_gen_p99_s or 0.0) - weighted_group_ratio_sum += result.group_gen_count * (result.group_gen_p99_p50_ratio or 0.0) - total_pause_time_s += result.group_gen_pause_time_s or 0.0 - - aggregated = ProduceBatchResult( - rollout_states=rollout_states, - leftover_init=leftover_init, - leftover_completed=leftover_completed, - leftover_aborted=leftover_aborted, - leftover_expired=leftover_expired, - 
leftover_failed=leftover_failed, - leftover_filtered=leftover_filtered, - raw_rewards_sum=raw_rewards_sum, - raw_rewards_count=raw_rewards_count, - produced_samples=produced_samples, - produced_tokens=produced_tokens, - produce_time_s=produce_time_s, - task_results={task.task_name: task_results[task.task_name] for task in ordered_tasks}, - ) - if total_group_count > 0: - aggregated.group_gen_count = total_group_count - aggregated.group_gen_mean_s = weighted_group_mean_sum / total_group_count - aggregated.group_gen_p50_s = weighted_group_p50_sum / total_group_count - aggregated.group_gen_p99_s = weighted_group_p99_sum / total_group_count - aggregated.group_gen_p99_p50_ratio = weighted_group_ratio_sum / total_group_count - aggregated.group_gen_pause_time_s = total_pause_time_s - return aggregated - - async def _produce_batch_to_buffer( - self, - task_batch_sizes: dict[str, int], - progress: ProduceProgress, - ) -> ProduceBatchStatus: - current_future_step = progress.producer_future_step - model_step = self._model_step - expired_tasks = [ - task.task_name - for task in self.task_runners - if task.produce_strategy.is_model_expired(current_future_step, model_step) - ] - if expired_tasks: - self.logger.info( - f"[AgentLoopManager][{self.name}] EXPIRED_BATCH: " - f"future_step={current_future_step}, tasks={expired_tasks}" - ) - return ProduceBatchStatus.EXPIRED_BATCH - - active_tasks = [task for task in self.task_runners if progress.target_samples[task.task_name] > 0] - assert active_tasks, "No active tasks found" - - produce_start = time.perf_counter() - try: - statuses = await asyncio.gather( - *[ - task.produce_strategy.produce_batch( - _build_produce_context( - task, - self.replay_buffer, - task_batch_sizes[task.task_name], - current_future_step, - model_step, - self._update_event, - progress, - ) - ) - for task in active_tasks - ] - ) - finally: - progress.add_produce_time(time.perf_counter() - produce_start) - return _aggregate_status(statuses) - - async def 
pause_produce( - self, - *, - use_global_progress: bool, - progress: ProduceProgress | None = None, - ) -> float: - # 这是 producer 的“显式刹车”接口。 - # - # 设计动机: - # - 旧 colocate 语义里,一次 produce_batch() 结束后就自然收尾; - # - 非共卡后,producer 可能在后台持续运行,何时停下来必须交给 trainer 明确控制。 - # - # 因此调用方必须显式说明是否使用全局 progress: - # - use_global_progress=True:非共卡后台生产循环在权重同步点前暂停; - # - use_global_progress=False:共卡同步 produce_batch 的本次调用收尾,使用本地 progress。 - # 返回值 `pause_time_s` 不是业务语义,而是日志/诊断信息, - # 供训练侧在下一次消费 batch 时上报。 - # use_global_progress=False 模式会在下一次 produce_batch 入口通过 continue_produce 恢复; - # use_global_progress=True 模式则由 trainer 在权重同步和评测完成后显式恢复。 - if use_global_progress: - if progress is not None: - raise ValueError("progress must not be provided when use_global_progress=True.") - pause_progress = self._produce_progress - else: - if progress is None: - raise ValueError("progress must be provided when use_global_progress=False.") - pause_progress = progress - - # 合法参数确认后,统一拉起 manager 级暂停信号,阻止仍在运行的 produce_batch 继续调度新 rollout。 - self._update_event.set() - self._status = AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT - - # 必须先让 producer / strategy 看到暂停状态,再暂停 rollout controller,避免暂停过程中继续调度新请求。 - rollout_ctl = await get_agent_loop_rollout_ctl(self.task_runners[0].agent_loop) - await rollout_ctl.pause_generation.remote() # type: ignore[attr-defined] - - pause_time_s = 0.0 - for task in self.task_runners: - ctx = _build_produce_context( - task, - self.replay_buffer, - 0, - pause_progress.producer_future_step, - self._model_step, - self._update_event, - pause_progress, - ) - pause_time_s += await task.produce_strategy.pause_produce( - ctx, - ) - self._pause_time_s = pause_time_s - return pause_time_s - - def _log_buffer_counts( - self, - task_batch_sizes: dict[str, int], - batch_by_task: dict[str, list[list[RolloutState]]], - leftover_counts: dict[str, dict[Status, int]], - ) -> None: - for task in self.task_runners: - task_name = task.task_name - task_counts = leftover_counts.get(task_name, {}) - 
self.logger.info( - f"[AgentLoopManager][{self.name}] get_batch from buffer for task {task_name}: " - f"requested={task_batch_sizes[task_name]}, retrieved={len(batch_by_task.get(task_name, []))}, " - f"leftover_init={task_counts.get(Status.INIT, 0)}, " - f"leftover_completed={task_counts.get(Status.COMPLETED, 0)}, " - f"leftover_aborted={task_counts.get(Status.ABORTED, 0)}, " - f"leftover_expired={task_counts.get(Status.EXPIRED, 0)}, " - f"leftover_failed={task_counts.get(Status.FAILED, 0)}, " - f"leftover_filtered={task_counts.get(Status.FILTERED, 0)}" - ) - - def _build_result_from_batch( - self, - task_batch_sizes: dict[str, int], - batch_by_task: dict[str, list[list[RolloutState]]], - leftover_counts: dict[str, dict[Status, int]], - *, - progress: ProduceProgress, - pause_time_s: float, - ) -> ProduceBatchResult: - if len(self.task_runners) == 1: - task = self.task_runners[0] - raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name) - produced_samples, produced_tokens = progress.consume_produced(task.task_name) - produce_time_s = progress.consume_produce_time() - result = ProduceBatchResult( - rollout_states=batch_by_task.get(task.task_name, []), - raw_rewards_sum=raw_rewards_sum, - raw_rewards_count=raw_rewards_count, - produced_samples=produced_samples, - produced_tokens=produced_tokens, - produce_time_s=produce_time_s, - ) - _fill_leftover_counts(result, leftover_counts.get(task.task_name, {})) - _fill_group_timing_stats(result, result.rollout_states, pause_time_s=pause_time_s) - return result - - task_results: dict[str, ProduceBatchResult] = {} - produce_time_s = progress.consume_produce_time() - for task in self.task_runners: - raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name) - produced_samples, produced_tokens = progress.consume_produced(task.task_name) - result = ProduceBatchResult( - rollout_states=batch_by_task.get(task.task_name, []), - raw_rewards_sum=raw_rewards_sum, - 
raw_rewards_count=raw_rewards_count, - produced_samples=produced_samples, - produced_tokens=produced_tokens, - ) - _fill_leftover_counts(result, leftover_counts.get(task.task_name, {})) - task_results[task.task_name] = result - - ordered_tasks = sorted(self.task_runners, key=lambda task: (task.task_name, task.order)) - aggregated = self._aggregate_task_results(ordered_tasks, task_results) - aggregated.produce_time_s = produce_time_s - aggregated.task_batch_sizes = {task.task_name: task_batch_sizes[task.task_name] for task in ordered_tasks} - _fill_group_timing_stats(aggregated, aggregated.rollout_states, pause_time_s=pause_time_s) - return aggregated - - async def _get_batch_from_buffer( - self, - *, - batch_size: int, - task_batch_sizes: dict[str, int], - consume_progress: ProduceProgress, - ) -> ProduceBatchResult: - pause_time_s = self._pause_time_s - self._pause_time_s = 0.0 - - self._validate_task_batch_sizes(task_batch_sizes, batch_size) - batch_by_task, consumed_counts = await self.replay_buffer.take_batch(task_batch_sizes) - consume_progress.mark_consumed(consumed_counts) - leftover_counts = await self.replay_buffer.count_statuses(self.task_names, _LEFTOVER_STATUSES) - self._log_buffer_counts(task_batch_sizes, batch_by_task, leftover_counts) - return self._build_result_from_batch( - task_batch_sizes, - batch_by_task, - leftover_counts, - progress=consume_progress, - pause_time_s=pause_time_s, - ) - - async def continue_produce(self, model_step: int) -> None: - # - # 它和 pause_produce(use_global_progress=True) 是一对: - # - pause_produce(...) 负责让 producer 停下来; - # - continue_produce(...) 
负责在同步/评测完成后解除暂停。 - # - # 这里同步更新 `_model_step`,表示 rollout 侧接下来生成样本时, - # 应把“当前正在使用的是哪一版权重”记录成这个版本号。 - self._model_step = model_step - rollout_ctl = await get_agent_loop_rollout_ctl(self.task_runners[0].agent_loop) - await rollout_ctl.continue_generation.remote() # type: ignore[attr-defined] - # rollout controller 真正恢复后,再把 manager 暴露成 NORMAL,produce_loop 才能继续生产。 - self._status = AgentLoopManagerStatus.NORMAL - self._update_event.clear() - - def shutdown(self) -> None: - # 公开收口后台 producer 的退出信号,避免 trainer 直接写 manager 私有状态。 - self._status = AgentLoopManagerStatus.FINISH - self._update_event.set() - self._finish_event.set() - - async def _wait_for_status_exit(self, blocked_status: AgentLoopManagerStatus) -> None: - while not self._finish_event.is_set() and self._status == blocked_status: - await asyncio.sleep(self._STATUS_POLL_INTERVAL_S) + self.logger = get_logger() if logger is None else logger + self.task_names = [task.task_name for task in task_runners] async def produce_batch( self, @@ -767,52 +210,83 @@ async def produce_batch( *, model_step: int, ) -> ProduceBatchResult: - # `produce_batch()` 是保留给 colocate 路径的同步入口。 - # - # 它虽然名字没变,但内部已经改成三段式: - # 1. `_produce_batch_to_buffer()` 只负责生产,把结果写入 replay buffer - # 2. `pause_produce()` 显式收尾 pending rollout - # 3. 
`_get_batch_from_buffer()` 再把训练 batch 取出来 - # - # 这也是为什么这里要求返回非空 batch: - # - colocate 语义下,调用它就是为了拿一批可训练 completed groups - # - 如果需要合法返回空 batch + 特殊状态,那应该走 disagg 的 `get_batch()` + # 共卡同步入口:生产入 buffer -> pause/drain 本轮 pending -> 取非空训练 batch。 if batch_size <= 0: raise ValueError(f"produce_batch expects batch_size > 0, got {batch_size}") start = time.perf_counter() self.logger.info( f"[AgentLoopManager][{self.name}] Start produce_batch: train_step={train_step} model_step={model_step} batch_size={batch_size}" ) - current_sizes = self._get_task_batch_sizes_for_step(batch_size, train_step) + current_sizes = allocate_task_batch_sizes(self.task_runners, batch_size, train_step) active_tasks = [task for task in self.task_runners if current_sizes[task.task_name] > 0] assert active_tasks, "No active tasks found" - # 共卡路径不复用非共卡的 paused producer 状态机。 - # 即使 manager 是从 resume() 恢复出来、当前仍处在 UPDATE_WEIGHT_AND_ABORT, - # produce_batch() 也应视作一次独立的同步生产过程,从干净状态开始。 - # - # 共卡路径下,produce_batch() 对应 rollout worker 当前持有的权重版本。 - await self.continue_produce(model_step=model_step) - local_progress = ProduceProgress.build_local(self.task_names, current_sizes, train_step) - status = ProduceBatchStatus.NORMAL - try: - # 共卡 produce_batch 也是消费入口;生产前先刷新 buffer 中已有 completed / aborted。 - await self._refresh_for_all_tasks(train_step, [Status.COMPLETED, Status.ABORTED]) - status = await self._produce_batch_to_buffer( - task_batch_sizes=current_sizes, - progress=local_progress, + await self._rollout_controller.continue_generation.remote() # type: ignore[attr-defined] + local_progress = ProduceProgress.build( + task_names=self.task_names, + target_samples=current_sizes, + ) + # 生产前刷新已有 completed / aborted 的 staleness。 + await refresh_for_all_tasks( + task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + logger=self.logger, + manager_name=self.name, + train_step=train_step, + statuses=[Status.COMPLETED, Status.ABORTED], + ) + produce_start = time.perf_counter() + produce_futures = [] + 
for task in active_tasks: + produce_strategy = cast(ProduceStrategy, task.produce_strategy) + produce_futures.append( + produce_strategy.produce_batch( + ProduceContext( + agent_loop=task.agent_loop, + sampler=task.sampler, + replay_buffer=self.replay_buffer, + task_batch_size=current_sizes[task.task_name], + task_name=task.task_name, + train_step=train_step, + model_step=model_step, + progress=local_progress, + is_valid_sample_fn=task.is_valid_sample_fn, + stale_threshold=task.stale_threshold, + ) + ) ) - finally: - await self.pause_produce( - use_global_progress=False, - progress=local_progress, + await asyncio.gather(*produce_futures) + local_progress.add_produce_time(time.perf_counter() - produce_start) + + # pause 只收尾本轮本地 pending。 + await self._rollout_controller.pause_generation.remote() # type: ignore[attr-defined] + + pause_time_s = 0.0 + for task in active_tasks: + produce_strategy = cast(ProduceStrategy, task.produce_strategy) + pause_time_s += await produce_strategy.pause_produce( + ProduceContext( + agent_loop=task.agent_loop, + sampler=task.sampler, + replay_buffer=self.replay_buffer, + task_batch_size=0, + task_name=task.task_name, + train_step=train_step, + model_step=model_step, + progress=local_progress, + is_valid_sample_fn=task.is_valid_sample_fn, + stale_threshold=task.stale_threshold, + ) ) - result = await self._get_batch_from_buffer( - batch_size=batch_size, + result = await take_train_batch( + task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + logger=self.logger, + manager_name=self.name, task_batch_sizes=current_sizes, - consume_progress=local_progress, + progress=local_progress, + pause_time_s=pause_time_s, ) - result.status = status assert result.rollout_states, ( "AgentLoopManager.produce_batch() must return non-empty rollout_states for colocated training. " "Use get_batch() for disaggregated empty/expired reads." 
@@ -824,142 +298,6 @@ async def produce_batch( ) return result - async def produce_loop(self, batch_size: int) -> None: - # `produce_loop()` 是非共卡新增的后台生产循环。 - # batch_size 表示每个 future train_step 的目标生产规模;producer 需要它来推进累计目标, - # 所以这个参数保留在后台生产入口,而不是从 get_batch() 的消费请求里推断。 - # - # 和 colocate 最大的区别是: - # - 它不直接把 batch 返回给 trainer - # - 它只是持续把样本“喂”进 replay buffer - # - trainer 前台通过 `get_batch()` 异步消费 - # - # 因此这里的核心职责不是“凑出一批训练数据”,而是根据 manager 的全局状态机 - # 决定什么时候继续生产、什么时候暂停等待、什么时候彻底退出。 - while not self._finish_event.is_set(): - if self._status == AgentLoopManagerStatus.FINISH: - break - if self._status in (AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT, AgentLoopManagerStatus.EXPIRED_BATCH): - # 同步前主动暂停和模型过期都只能由 trainer 调用 continue_produce() 恢复。 - await self._wait_for_status_exit(self._status) - continue - - task_batch_sizes = self._produce_progress.ensure_target_upto( - batch_size=batch_size, - future_step=self._produce_progress.producer_future_step, - allocate_batch_sizes=self._get_task_batch_sizes_for_step, - ) - produce_status = await self._produce_batch_to_buffer( - task_batch_sizes=task_batch_sizes, - progress=self._produce_progress, - ) - - if produce_status == ProduceBatchStatus.EXPIRED_BATCH: - # 注意: - # - EXPIRED_BATCH 是 producer 在生产过程中自己检测出来的“立即停下”信号 - # - UPDATE_WEIGHT_AND_ABORT 则是 trainer 在同步前通过 pause_produce() 主动设置的 - self._status = AgentLoopManagerStatus.EXPIRED_BATCH - elif produce_status == ProduceBatchStatus.NORMAL: - # 只有正常完成一轮生产时,producer 自己维护的 train_step 才前进一步。 - self._produce_progress.advance_future_step() - - # 主动让出事件循环,避免 fake strategy / 极快路径在测试里造成忙等空转。 - await asyncio.sleep(0) - - async def get_batch(self, batch_size: int, train_step: int) -> ProduceBatchResult: - # `get_batch()` 是非共卡路径给 trainer 的消费接口。 - # - # 设计上它和 `produce_batch()` 明确分工: - # - `produce_batch()`:colocate,一次调用内完成“生产+收尾+取数” - # - `get_batch()`:disagg,等待 replay buffer 准备好当前训练步所需 batch 后再取数 - # - # 因而这里允许返回空 batch 的唯一合法场景是: - # - manager 已进入 EXPIRED_BATCH - # - 当前训练侧已有比 rollout 侧更新的 
Model Step,可以直接同步过去 - # 如果没有更新的模型版本,则要么消费当前已准备好的 batch,要么 fail fast 暴露不变量破坏。 - progress = self._produce_progress - progress.begin_consume(train_step) - await self._refresh_for_all_tasks(train_step, [Status.COMPLETED, Status.ABORTED]) - task_batch_sizes = self._get_task_batch_sizes_for_step(batch_size, train_step) - current_model_step = train_step - 1 - - while not self._finish_event.is_set(): - if self._status == AgentLoopManagerStatus.EXPIRED_BATCH: - # 只有训练侧已经有更新的 Model Step,空 expired 才能跳过训练并直接同步。 - if current_model_step > self._model_step: - pause_time_s = self._pause_time_s - self._pause_time_s = 0.0 - result = ProduceBatchResult( - rollout_states=[], - status=ProduceBatchStatus.EXPIRED_BATCH, - ) - if pause_time_s > 0: - result.group_gen_pause_time_s = pause_time_s - return result - # 没有更新模型且当前 batch 不 ready 时,producer 已停且无法靠同步恢复,必须立即暴露不变量。 - if not await self.replay_buffer.is_ready(task_batch_sizes): - leftover_counts = await self.replay_buffer.count_statuses(self.task_names, _LEFTOVER_STATUSES) - raise RuntimeError( - "AgentLoopManager reached EXPIRED_BATCH without a newer model or a ready batch: " - f"train_step={train_step}, current_model_step={current_model_step}, " - f"rollout_model_step={self._model_step}, manager_status={self._status.name}, " - f"producer_future_step={progress.producer_future_step}, " - f"next_consumer_step={progress.next_consumer_step}, " - f"target_upto_future_step={progress.target_upto_future_step}, " - f"target_samples={progress.target_samples}, " - f"consumed_samples={progress.consumed_samples}, " - f"task_batch_sizes={task_batch_sizes}, " - f"leftover_status_counts={leftover_counts}" - ) - if await self.replay_buffer.is_ready(task_batch_sizes): - result = await self._get_batch_from_buffer( - batch_size=batch_size, - task_batch_sizes=task_batch_sizes, - consume_progress=progress, - ) - if self._status == AgentLoopManagerStatus.EXPIRED_BATCH: - # expired 但带数据表示 trainer 仍需完成本 step,再用新 Model Step 恢复 producer。 - result.status = 
ProduceBatchStatus.EXPIRED_BATCH - if result.rollout_states: - progress.finish_consume(train_step) - await self._refresh_for_all_tasks(train_step + 1, [Status.COMPLETED, Status.ABORTED]) - return result - await asyncio.sleep(self._STATUS_POLL_INTERVAL_S) - - return ProduceBatchResult(rollout_states=[]) - - def _task_checkpoint_path(self, checkpoint_path: Path | str, task_name: str) -> Path: - checkpoint_path = Path(checkpoint_path) - return checkpoint_path / self._TASK_CHECKPOINT_DIR / task_name - - def _manager_state_path(self, checkpoint_path: Path | str) -> Path: - checkpoint_path = Path(checkpoint_path) - return checkpoint_path / self._MANAGER_STATE_PATH - - def _progress_state_without_replay_buffer(self, progress_state: dict) -> dict: - progress_state = dict(progress_state) - consumed_samples = dict(progress_state["consumed_samples"]) - task_names = list(consumed_samples) - next_consumer_step = int(progress_state["next_consumer_step"]) - - progress_state["producer_future_step"] = next_consumer_step - progress_state["target_samples"] = dict(consumed_samples) - progress_state["target_upto_future_step"] = max(0, next_consumer_step - 1) - progress_state["raw_rewards_sum"] = {task_name: 0.0 for task_name in task_names} - progress_state["raw_rewards_count"] = {task_name: 0 for task_name in task_names} - progress_state["produced_samples"] = {task_name: 0 for task_name in task_names} - progress_state["produced_tokens"] = {task_name: 0 for task_name in task_names} - progress_state["produce_time_s"] = 0.0 - return progress_state - - def _get_pending_task_counts(self) -> dict[str, int]: - pending_task_counts: dict[str, int] = {} - for task in self.task_runners: - pending_count = task.produce_strategy.pending_task_count() - if pending_count > 0: - pending_task_counts[task.task_name] = pending_count - return pending_task_counts - async def save( self, checkpoint_path: Path | str, @@ -970,34 +308,27 @@ async def save( """Save all task sampler states and the shared replay 
buffer.""" checkpoint_path = Path(checkpoint_path) checkpoint_path.mkdir(parents=True, exist_ok=True) - pending_task_counts = self._get_pending_task_counts() + pending_task_counts = get_pending_task_counts(self.task_runners) if pending_task_counts: raise RuntimeError( "Cannot save AgentLoopManager while pending rollout tasks still exist: " - f"{pending_task_counts}. Call pause_produce() first." + f"{pending_task_counts}. Finish the current produce_batch before saving." ) - # 保存前显式记录当前 checkpoint 对应的模型步数,resume 时直接恢复这一份状态。 - self._model_step = model_step for task in self.task_runners: - task_checkpoint_path = self._task_checkpoint_path(checkpoint_path, task.task_name) - task_checkpoint_path.mkdir(parents=True, exist_ok=True) - task.sampler.save(task_checkpoint_path) + checkpoint_dir = task_checkpoint_path(checkpoint_path, task.task_name) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + task.sampler.save(checkpoint_dir) # manager 层保持 async 语义;同步入口只允许在 trainer 边界用 asyncio_run 包起来。 if no_save_replay_buffer: self.logger.info(f"Skip saving replay buffer to {checkpoint_path}") else: await self.replay_buffer.save(checkpoint_path) - manager_state_path = self._manager_state_path(checkpoint_path) - progress_state = self._produce_progress.state_dict() - if no_save_replay_buffer: - progress_state = self._progress_state_without_replay_buffer(progress_state) - with manager_state_path.open("w") as f: + state_path = manager_state_path(checkpoint_path) + with state_path.open("w") as f: json.dump( { - "status": self._status.name, - "model_step": self._model_step, + "model_step": model_step, "replay_buffer_saved": not no_save_replay_buffer, - **progress_state, }, f, ) @@ -1006,10 +337,10 @@ async def resume(self, checkpoint_path: Path | str) -> int: """Resume all task sampler states and the shared replay buffer.""" checkpoint_path = Path(checkpoint_path) for task in self.task_runners: - task.sampler.resume(self._task_checkpoint_path(checkpoint_path, task.task_name)) + 
task.sampler.resume(task_checkpoint_path(checkpoint_path, task.task_name)) - manager_state_path = self._manager_state_path(checkpoint_path) - with manager_state_path.open("r") as f: + state_path = manager_state_path(checkpoint_path) + with state_path.open("r") as f: manager_state = json.load(f) if manager_state.get("replay_buffer_saved", True): # replay buffer 恢复是 async I/O,不能在已有 event loop 中再次嵌套 asyncio_run。 @@ -1018,13 +349,4 @@ async def resume(self, checkpoint_path: Path | str) -> int: raise RuntimeError("Cannot resume without replay buffer checkpoint into a non-empty buffer") else: self.logger.info(f"Skip replay buffer resume for checkpoint without replay buffer: {checkpoint_path}") - saved_model_step = manager_state["model_step"] - self._produce_progress.load_state_dict(manager_state) - - self._update_event = asyncio.Event() - self._finish_event = asyncio.Event() - self._update_event.set() - self._status = AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT - self._pause_time_s = 0.0 - self._model_step = saved_model_step - return saved_model_step + return manager_state["model_step"] diff --git a/xtuner/v1/rl/agent_loop_manager/disagg_agent_loop_manager.py b/xtuner/v1/rl/agent_loop_manager/disagg_agent_loop_manager.py new file mode 100644 index 0000000000..e73dc63da3 --- /dev/null +++ b/xtuner/v1/rl/agent_loop_manager/disagg_agent_loop_manager.py @@ -0,0 +1,475 @@ +import asyncio +import json +import time +from enum import Enum, auto +from pathlib import Path +from typing import Any, cast + +from pydantic import BaseModel, ConfigDict, Field + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.agent_loop import AgentLoopConfig +from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger +from xtuner.v1.rl.replay_buffer import ReplayBuffer +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.utils import get_logger + +from .disagg_producer 
import ( + DisaggAsyncProduceStrategyConfig, + DisaggProduceContext, + DisaggProduceProgress, + DisaggProduceStrategy, + DisaggProduceStrategyConfig, +) +from .produce_utils import ( + _LEFTOVER_STATUSES, + _MANAGER_STATE_PATH, + _STATUS_POLL_INTERVAL_S, + _TASK_CHECKPOINT_DIR, + ProduceBatchResult, + ProduceBatchStatus, + _TaskRunner, + _TaskSamplerView, + allocate_task_batch_sizes, + get_pending_task_counts, + manager_state_path, + refresh_for_all_tasks, + take_train_batch, + task_checkpoint_path, +) +from .sampler import Sampler, SamplerConfig + + +class DisaggTaskSpecConfig(BaseModel): + """单个非共卡 RL 数据源配置。""" + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + task_name: str + weight: float = Field(default=1.0, ge=0.0) + agent_loop_config: AgentLoopConfig + judger_config: JudgerConfig | ComposedJudgerConfig | None = None + produce_strategy_config: DisaggProduceStrategyConfig = DisaggAsyncProduceStrategyConfig() + sampler_config: SamplerConfig + + +class DisaggAgentLoopManagerConfig(BaseModel): + """非共卡 rollout 后台生产侧配置。""" + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + tasks: list[DisaggTaskSpecConfig] | DisaggTaskSpecConfig + + def build( + self, + rollout_controller: RolloutController, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + replay_buffer: ReplayBuffer, + logger=None, + sync_weights_interval: int = 1, + ) -> "DisaggAgentLoopManager": + tasks = self.tasks if isinstance(self.tasks, list) else [self.tasks] + if not tasks: + raise ValueError("DisaggAgentLoopManagerConfig requires at least one task config.") + + seen_task_names: set[str] = set() + task_runners: list[_TaskRunner] = [] + for order, task_cfg in enumerate(tasks): + if task_cfg.task_name in seen_task_names: + raise ValueError(f"Duplicate task_name found in DisaggAgentLoopManagerConfig: {task_cfg.task_name}") + seen_task_names.add(task_cfg.task_name) + + agent_loop = task_cfg.agent_loop_config.build( + 
rollout_controller=rollout_controller, + judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None, + logger=logger, + ) + produce_strategy = task_cfg.produce_strategy_config.build( + sync_weights_interval=sync_weights_interval, + rollout_controller=rollout_controller, + ) + sampler = task_cfg.sampler_config.build(tokenizer=tokenizer, replay_buffer=replay_buffer) + task_runners.append( + _TaskRunner( + task_name=task_cfg.task_name, + agent_loop=agent_loop, + produce_strategy=produce_strategy, + sampler=sampler, + weight=task_cfg.weight, + order=order, + ) + ) + + return DisaggAgentLoopManager( + task_runners=task_runners, + replay_buffer=replay_buffer, + rollout_controller=rollout_controller, + logger=logger, + ) + + +class AgentLoopManagerStatus(Enum): + """Global status of AgentLoopManager. + + Transitions follow these paths: + - The initial status is NORMAL + - NORMAL -> UPDATE_WEIGHT_AND_ABORT + - triggered before the trainer starts weight sync + - UPDATE_WEIGHT_AND_ABORT -> NORMAL + - continue_produce() is called after weight sync completes + - NORMAL -> EXPIRED_BATCH + - the current rollout model is too stale + - EXPIRED_BATCH -> UPDATE_WEIGHT_AND_ABORT + - the trainer detects the expiry and enters the weight sync phase + - any status -> FINISH + - training ends + + An important distinction: + - AgentLoopManagerStatus is the global running status of the background producer + - ProduceBatchStatus is the local result of a single scheduling call + """ + + NORMAL = auto() + UPDATE_WEIGHT_AND_ABORT = auto() + EXPIRED_BATCH = auto() + FINISH = auto() + + +def _aggregate_status(statuses: list[ProduceBatchStatus]) -> ProduceBatchStatus: + if any(status == ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT for status in statuses): + return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT + if any(status == ProduceBatchStatus.EXPIRED_BATCH for status in statuses): + return ProduceBatchStatus.EXPIRED_BATCH + return ProduceBatchStatus.NORMAL + + +class DisaggAgentLoopManager: + """State machine for the disaggregated background producer / foreground consumer.""" + + _TASK_CHECKPOINT_DIR = _TASK_CHECKPOINT_DIR + _MANAGER_STATE_PATH = _MANAGER_STATE_PATH + _STATUS_POLL_INTERVAL_S = _STATUS_POLL_INTERVAL_S + task_runners: list[_TaskRunner] + 
replay_buffer: ReplayBuffer + _rollout_controller: RolloutController + data_sampler: Sampler | _TaskSamplerView + name: str + logger: Any + task_names: list[str] + + def __init__( + self, + task_runners: list[_TaskRunner], + replay_buffer: ReplayBuffer, + rollout_controller: RolloutController, + logger=None, + ): + if not task_runners: + raise ValueError("DisaggAgentLoopManager requires at least one task runner.") + if sum(task.weight for task in task_runners) <= 0: + raise ValueError("At least one task weight must be positive for DisaggAgentLoopManager.") + + self.task_runners = task_runners + self.replay_buffer = replay_buffer + self._rollout_controller = rollout_controller + self.data_sampler = ( + task_runners[0].sampler + if len(task_runners) == 1 + else _TaskSamplerView([task.sampler for task in task_runners]) + ) + self.name = task_runners[0].task_name if len(task_runners) == 1 else "multi_task" + self.logger = get_logger() if logger is None else logger + self.task_names = [task.task_name for task in task_runners] + + # consumer 同步权重前置位;producer / strategy 直接观察 event。 + self._update_event = asyncio.Event() + self._finish_event = asyncio.Event() + + # rollout 侧当前模型版本;pause 清空 pending 后才能更新。 + self._model_step = 0 + + # 跨 await 直接读 self._status,避免错过状态变化。 + self._status = AgentLoopManagerStatus.NORMAL + + # pause_produce 写入,下一次 get_batch 消费并清零。 + self._pause_time_s = 0.0 + + # producer / consumer 共享绝对进度;对象引用保持稳定。 + self._produce_progress = DisaggProduceProgress.build(self.task_names) + + def _consume_pause_time(self) -> float: + pause_time_s = self._pause_time_s + self._pause_time_s = 0.0 + return pause_time_s + + async def _produce_batch_to_buffer( + self, + task_batch_sizes: dict[str, int], + progress: DisaggProduceProgress, + ) -> ProduceBatchStatus: + producer_train_step = progress.producer_future_step + expired_tasks = [] + for task in self.task_runners: + produce_strategy = cast(DisaggProduceStrategy, task.produce_strategy) + if 
produce_strategy.is_model_expired(producer_train_step, self._model_step): + expired_tasks.append(task.task_name) + if expired_tasks: + self.logger.info( + f"[DisaggAgentLoopManager][{self.name}] EXPIRED_BATCH: " + f"future_step={producer_train_step}, tasks={expired_tasks}" + ) + return ProduceBatchStatus.EXPIRED_BATCH + + active_tasks = [task for task in self.task_runners if progress.target_samples[task.task_name] > 0] + assert active_tasks, "No active tasks found" + + produce_start = time.perf_counter() + produce_futures = [] + for task in active_tasks: + produce_strategy = cast(DisaggProduceStrategy, task.produce_strategy) + produce_futures.append( + produce_strategy.produce_batch( + DisaggProduceContext( + agent_loop=task.agent_loop, + sampler=task.sampler, + replay_buffer=self.replay_buffer, + task_batch_size=task_batch_sizes[task.task_name], + task_name=task.task_name, + train_step=producer_train_step, + model_step=self._model_step, + progress=progress, + update_event=self._update_event, + is_valid_sample_fn=task.is_valid_sample_fn, + stale_threshold=task.stale_threshold, + ) + ) + ) + produce_status = _aggregate_status(await asyncio.gather(*produce_futures)) + progress.add_produce_time(time.perf_counter() - produce_start) + return produce_status + + async def pause_produce(self) -> float: + # 非共卡显式刹车;共卡没有 public pause。 + self._status = AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT + self._update_event.set() + await self._rollout_controller.pause_generation.remote() # type: ignore[attr-defined] + + pause_time_s = 0.0 + for task in self.task_runners: + produce_strategy = cast(DisaggProduceStrategy, task.produce_strategy) + ctx = DisaggProduceContext( + agent_loop=task.agent_loop, + sampler=task.sampler, + replay_buffer=self.replay_buffer, + task_batch_size=0, + task_name=task.task_name, + train_step=self._produce_progress.producer_future_step, + model_step=self._model_step, + progress=self._produce_progress, + update_event=self._update_event, + 
is_valid_sample_fn=task.is_valid_sample_fn, + stale_threshold=task.stale_threshold, + ) + pause_time_s += await produce_strategy.pause_produce(ctx) + self._pause_time_s = pause_time_s + return pause_time_s + + async def continue_produce(self, model_step: int) -> None: + # Paired with pause_produce: after sync/evaluation completes, resume the background producer with the new model_step. + self._model_step = model_step + await self._rollout_controller.continue_generation.remote() # type: ignore[attr-defined] + self._status = AgentLoopManagerStatus.NORMAL + self._update_event.clear() + + def shutdown(self) -> None: + self._status = AgentLoopManagerStatus.FINISH + self._update_event.set() + self._finish_event.set() + + async def _wait_for_status_exit(self, blocked_status: AgentLoopManagerStatus) -> None: + while not self._finish_event.is_set() and self._status == blocked_status: + await asyncio.sleep(self._STATUS_POLL_INTERVAL_S) + + async def produce_loop(self, batch_size: int) -> None: + # Produce continuously in the background; the foreground consumes via get_batch. + while not self._finish_event.is_set(): + if self._status == AgentLoopManagerStatus.FINISH: + break + if self._status in (AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT, AgentLoopManagerStatus.EXPIRED_BATCH): + # Paused/expired states can only be resumed by the trainer calling continue_produce. + await self._wait_for_status_exit(self._status) + continue + + task_batch_sizes = self._produce_progress.ensure_target_upto( + batch_size=batch_size, + future_step=self._produce_progress.producer_future_step, + allocate_batch_sizes=lambda current_batch_size, future_step: allocate_task_batch_sizes( + self.task_runners, + current_batch_size, + future_step, + ), + ) + produce_status = await self._produce_batch_to_buffer(task_batch_sizes, self._produce_progress) + + if produce_status == ProduceBatchStatus.EXPIRED_BATCH: + # EXPIRED_BATCH is a "stop immediately" signal detected by the producer itself. + self._status = AgentLoopManagerStatus.EXPIRED_BATCH + elif produce_status == ProduceBatchStatus.NORMAL: + # The producer's own train_step advances only after a production round completes normally. + self._produce_progress.advance_future_step() + + 
# 极快路径下主动让出事件循环。 + await asyncio.sleep(0) + + async def get_batch(self, batch_size: int, train_step: int) -> ProduceBatchResult: + # 非共卡消费入口;空 batch 只表示已过期且已有更新模型可同步。 + progress = self._produce_progress + progress.begin_consume(train_step) + await refresh_for_all_tasks( + task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + logger=self.logger, + manager_name=self.name, + train_step=train_step, + statuses=[Status.COMPLETED, Status.ABORTED], + ) + task_batch_sizes = allocate_task_batch_sizes(self.task_runners, batch_size, train_step) + current_model_step = train_step - 1 + + while not self._finish_event.is_set(): + if self._status == AgentLoopManagerStatus.EXPIRED_BATCH: + if current_model_step > self._model_step: + pause_time_s = self._consume_pause_time() + result = ProduceBatchResult( + rollout_states=[], + status=ProduceBatchStatus.EXPIRED_BATCH, + ) + if pause_time_s > 0: + result.group_gen_pause_time_s = pause_time_s + return result + # producer 已停且没有新模型可同步,立即暴露坏状态。 + if not await self.replay_buffer.is_ready(task_batch_sizes): + leftover_counts = await self.replay_buffer.count_statuses(self.task_names, _LEFTOVER_STATUSES) + raise RuntimeError( + "AgentLoopManager reached EXPIRED_BATCH without a newer model or a ready batch: " + f"train_step={train_step}, current_model_step={current_model_step}, " + f"rollout_model_step={self._model_step}, manager_status={self._status.name}, " + f"producer_future_step={progress.producer_future_step}, " + f"next_consumer_step={progress.next_consumer_step}, " + f"target_upto_future_step={progress.target_upto_future_step}, " + f"target_samples={progress.target_samples}, " + f"consumed_samples={progress.consumed_samples}, " + f"task_batch_sizes={task_batch_sizes}, " + f"leftover_status_counts={leftover_counts}" + ) + if await self.replay_buffer.is_ready(task_batch_sizes): + result = await take_train_batch( + task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + logger=self.logger, + 
manager_name=self.name, + task_batch_sizes=task_batch_sizes, + progress=progress, + pause_time_s=self._consume_pause_time(), + ) + if self._status == AgentLoopManagerStatus.EXPIRED_BATCH: + # 有数据的 expired batch 仍需训练本 step。 + result.status = ProduceBatchStatus.EXPIRED_BATCH + if result.rollout_states: + progress.finish_consume(train_step) + await refresh_for_all_tasks( + task_runners=self.task_runners, + replay_buffer=self.replay_buffer, + logger=self.logger, + manager_name=self.name, + train_step=train_step + 1, + statuses=[Status.COMPLETED, Status.ABORTED], + ) + return result + await asyncio.sleep(self._STATUS_POLL_INTERVAL_S) + + return ProduceBatchResult(rollout_states=[]) + + def _progress_state_without_replay_buffer(self, progress_state: dict) -> dict: + progress_state = dict(progress_state) + consumed_samples = dict(progress_state["consumed_samples"]) + task_names = list(consumed_samples) + next_consumer_step = int(progress_state["next_consumer_step"]) + + progress_state["producer_future_step"] = next_consumer_step + progress_state["target_samples"] = dict(consumed_samples) + progress_state["target_upto_future_step"] = max(0, next_consumer_step - 1) + progress_state["raw_rewards_sum"] = {task_name: 0.0 for task_name in task_names} + progress_state["raw_rewards_count"] = {task_name: 0 for task_name in task_names} + progress_state["produced_samples"] = {task_name: 0 for task_name in task_names} + progress_state["produced_tokens"] = {task_name: 0 for task_name in task_names} + progress_state["produce_time_s"] = 0.0 + return progress_state + + async def save( + self, + checkpoint_path: Path | str, + model_step: int, + *, + no_save_replay_buffer: bool = False, + ) -> None: + """保存非共卡 sampler、replay buffer 和后台生产进度。""" + checkpoint_path = Path(checkpoint_path) + checkpoint_path.mkdir(parents=True, exist_ok=True) + pending_task_counts = get_pending_task_counts(self.task_runners) + if pending_task_counts: + raise RuntimeError( + "Cannot save AgentLoopManager while 
pending rollout tasks still exist: " + f"{pending_task_counts}. Call pause_produce() first." + ) + self._model_step = model_step + for task in self.task_runners: + checkpoint_dir = task_checkpoint_path(checkpoint_path, task.task_name) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + task.sampler.save(checkpoint_dir) + if no_save_replay_buffer: + self.logger.info(f"Skip saving replay buffer to {checkpoint_path}") + else: + await self.replay_buffer.save(checkpoint_path) + state_path = manager_state_path(checkpoint_path) + progress_state = self._produce_progress.state_dict() + if no_save_replay_buffer: + progress_state = self._progress_state_without_replay_buffer(progress_state) + with state_path.open("w") as f: + json.dump( + { + "status": self._status.name, + "model_step": self._model_step, + "replay_buffer_saved": not no_save_replay_buffer, + **progress_state, + }, + f, + ) + + async def resume(self, checkpoint_path: Path | str) -> int: + """恢复非共卡 sampler、replay buffer 和后台生产进度。""" + checkpoint_path = Path(checkpoint_path) + for task in self.task_runners: + task.sampler.resume(task_checkpoint_path(checkpoint_path, task.task_name)) + + state_path = manager_state_path(checkpoint_path) + with state_path.open("r") as f: + manager_state = json.load(f) + if manager_state.get("replay_buffer_saved", True): + # replay buffer 恢复是 async I/O,不能在已有 event loop 中再次嵌套 asyncio_run。 + await self.replay_buffer.resume(checkpoint_path) + elif len(self.replay_buffer) > 0: + raise RuntimeError("Cannot resume without replay buffer checkpoint into a non-empty buffer") + else: + self.logger.info(f"Skip replay buffer resume for checkpoint without replay buffer: {checkpoint_path}") + saved_model_step = manager_state["model_step"] + self._produce_progress.load_state_dict(manager_state) + + self._update_event = asyncio.Event() + self._finish_event = asyncio.Event() + self._update_event.set() + self._status = AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT + self._pause_time_s = 0.0 + 
self._model_step = saved_model_step + return saved_model_step diff --git a/xtuner/v1/rl/agent_loop_manager/disagg_producer.py b/xtuner/v1/rl/agent_loop_manager/disagg_producer.py new file mode 100644 index 0000000000..90422eae84 --- /dev/null +++ b/xtuner/v1/rl/agent_loop_manager/disagg_producer.py @@ -0,0 +1,409 @@ +import asyncio +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Optional + +from pydantic import BaseModel, ConfigDict, Field + +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.utils import calculate_seq_staleness, create_task +from xtuner.v1.utils import get_logger + +from .produce_utils import ( + PERIODIC_ABORT_INTERVAL_S, + BaseProduceContext, + IsValidSampleFn, + ProduceBatchStatus, + ShouldContinueFn, + _PendingTasks, + _ProgressDisplayer, + _put_claimed_tasks, + calculate_stale_threshold, + default_is_valid_sample_fn, + default_should_continue_fn, + pause_pending_tasks, +) + + +if TYPE_CHECKING: + from xtuner.v1.rl.rollout.controller import RolloutControllerProxy + + +logger = get_logger() + + +@dataclass +class DisaggProduceProgress: + """非共卡 producer / consumer 共享的绝对进度。""" + + task_names: list[str] = field(default_factory=list) + producer_future_step: int = 1 + next_consumer_step: int = 1 + target_upto_future_step: int = 0 + consumed_samples: dict[str, int] = field(default_factory=dict) + target_samples: dict[str, int] = field(default_factory=dict) + raw_rewards_sum: dict[str, float] = field(default_factory=dict) + raw_rewards_count: dict[str, int] = field(default_factory=dict) + produced_samples: dict[str, int] = field(default_factory=dict) + produced_tokens: dict[str, int] = field(default_factory=dict) + produce_time_s: float = 0.0 + + @classmethod + def build(cls, task_names: list[str]) -> "DisaggProduceProgress": + return cls( + task_names=list(task_names), + consumed_samples={task_name: 0 for task_name in task_names}, + 
target_samples={task_name: 0 for task_name in task_names}, + raw_rewards_sum={task_name: 0.0 for task_name in task_names}, + raw_rewards_count={task_name: 0 for task_name in task_names}, + produced_samples={task_name: 0 for task_name in task_names}, + produced_tokens={task_name: 0 for task_name in task_names}, + ) + + def ensure_target_upto( + self, + *, + batch_size: int, + future_step: int, + allocate_batch_sizes: Callable[[int, int], dict[str, int]], + ) -> dict[str, int]: + """把累计 target 推进到指定 future step,并返回该 step 的 task batch size。""" + + current_task_batch_sizes: dict[str, int] | None = None + if future_step > self.target_upto_future_step: + for step in range(self.target_upto_future_step + 1, future_step + 1): + current_task_batch_sizes = allocate_batch_sizes(batch_size, step) + for task_name, task_batch_size in current_task_batch_sizes.items(): + self.target_samples[task_name] += task_batch_size + self.target_upto_future_step = future_step + + if current_task_batch_sizes is None: + current_task_batch_sizes = allocate_batch_sizes(batch_size, future_step) + return current_task_batch_sizes + + def begin_consume(self, train_step: int) -> None: + self.next_consumer_step = train_step + + def mark_consumed(self, consumed_counts: dict[str, int]) -> None: + # target 不回退;producer 用 consumed + completed 判断真实缺口。 + for task_name, count in consumed_counts.items(): + self.consumed_samples[task_name] += count + + def finish_consume(self, train_step: int) -> None: + self.next_consumer_step = train_step + 1 + + def advance_future_step(self) -> None: + self.producer_future_step += 1 + + def add_raw_rewards(self, task_name: str, rewards_sum: float, rewards_count: int) -> None: + self.raw_rewards_sum[task_name] += rewards_sum + self.raw_rewards_count[task_name] += rewards_count + + def add_produced(self, task_name: str, samples: int, tokens: int) -> None: + self.produced_samples[task_name] += samples + self.produced_tokens[task_name] += tokens + + def add_produce_time(self, 
elapsed_s: float) -> None: + self.produce_time_s += elapsed_s + + def consume_produced(self, task_name: str) -> tuple[int, int]: + samples = self.produced_samples[task_name] + tokens = self.produced_tokens[task_name] + self.produced_samples[task_name] = 0 + self.produced_tokens[task_name] = 0 + return samples, tokens + + def consume_produce_time(self) -> float: + produce_time_s = self.produce_time_s + self.produce_time_s = 0.0 + return produce_time_s + + def consume_raw_rewards(self, task_name: str) -> tuple[float, int]: + rewards_sum = self.raw_rewards_sum[task_name] + rewards_count = self.raw_rewards_count[task_name] + self.raw_rewards_sum[task_name] = 0.0 + self.raw_rewards_count[task_name] = 0 + return rewards_sum, rewards_count + + def state_dict(self) -> dict[str, Any]: + return { + "producer_future_step": self.producer_future_step, + "next_consumer_step": self.next_consumer_step, + "target_upto_future_step": self.target_upto_future_step, + "consumed_samples": dict(self.consumed_samples), + "target_samples": dict(self.target_samples), + "raw_rewards_sum": dict(self.raw_rewards_sum), + "raw_rewards_count": dict(self.raw_rewards_count), + "produced_samples": dict(self.produced_samples), + "produced_tokens": dict(self.produced_tokens), + "produce_time_s": self.produce_time_s, + } + + def load_state_dict(self, state: dict[str, Any]) -> None: + # 原地更新,避免 strategy / context 持有旧引用。 + self.producer_future_step = state["producer_future_step"] + self.next_consumer_step = state["next_consumer_step"] + self.target_upto_future_step = state["target_upto_future_step"] + self.consumed_samples.clear() + self.consumed_samples.update(state["consumed_samples"]) + self.target_samples.clear() + self.target_samples.update(state["target_samples"]) + task_names = set(self.consumed_samples) | set(self.target_samples) + self.raw_rewards_sum.clear() + self.raw_rewards_sum.update( + {task_name: float(state.get("raw_rewards_sum", {}).get(task_name, 0.0)) for task_name in task_names} + ) + 
self.raw_rewards_count.clear() + self.raw_rewards_count.update( + {task_name: int(state.get("raw_rewards_count", {}).get(task_name, 0)) for task_name in task_names} + ) + produced_samples_state = state.get("produced_samples", {}) + produced_tokens_state = state.get("produced_tokens", {}) + self.produced_samples.clear() + self.produced_samples.update( + {task_name: int(produced_samples_state.get(task_name, 0)) for task_name in task_names} + ) + self.produced_tokens.clear() + self.produced_tokens.update( + {task_name: int(produced_tokens_state.get(task_name, 0)) for task_name in task_names} + ) + self.produce_time_s = float(state.get("produce_time_s", 0.0)) + + +@dataclass(kw_only=True) +class DisaggProduceContext(BaseProduceContext): + """非共卡后台生产上下文。""" + + progress: DisaggProduceProgress + update_event: asyncio.Event = field(default_factory=asyncio.Event) + + @property + def consumer_step(self) -> int: + return self.progress.next_consumer_step + + @property + def total_target(self) -> int: + return self.progress.target_samples[self.task_name] + + def should_abort(self) -> bool: + return self.update_event.is_set() + + async def available_count(self) -> int: + completed_count = await self.replay_buffer.count(task_name=self.task_name, group_status=Status.COMPLETED) + return self.progress.consumed_samples[self.task_name] + completed_count + + +class DisaggProduceStrategyConfig(ABC, BaseModel): + """非共卡后台 producer strategy 配置。""" + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn + should_continue_fn: ShouldContinueFn = default_should_continue_fn + + @abstractmethod + def build( + self, + *, + sync_weights_interval: int = 1, + rollout_controller: "Optional[RolloutControllerProxy]" = None, + ) -> "DisaggProduceStrategy": ... 
+ + +class DisaggAsyncProduceStrategyConfig(DisaggProduceStrategyConfig): + """非共卡异步生产配置。""" + + over_sample_threshold: float = 0.0 + enable_partial_rollout: bool = False + max_staleness: int = Field(default=0, ge=0) + tail_batch_trigger_size: int = 0 + + def build( + self, + *, + sync_weights_interval: int = 1, + rollout_controller: "Optional[RolloutControllerProxy]" = None, + ) -> "DisaggAsyncProduceStrategy": + if rollout_controller is not None: + import ray + + ray.get(rollout_controller.set_enable_partial_rollout.remote(self.enable_partial_rollout)) + return DisaggAsyncProduceStrategy( + over_sample_threshold=self.over_sample_threshold, + enable_partial_rollout=self.enable_partial_rollout, + max_staleness=self.max_staleness, + sync_weights_interval=sync_weights_interval, + tail_batch_trigger_size=self.tail_batch_trigger_size, + is_valid_sample_fn=self.is_valid_sample_fn, + should_continue_fn=self.should_continue_fn, + ) + + +class DisaggProduceStrategy(ABC): + def __init__( + self, + is_valid_sample_fn: IsValidSampleFn, + should_continue_fn: ShouldContinueFn, + ): + self.is_valid_sample_fn = is_valid_sample_fn + self.should_continue_fn = should_continue_fn + + @abstractmethod + async def produce_batch(self, ctx: DisaggProduceContext) -> ProduceBatchStatus: ... 
+ + async def pause_produce(self, ctx: DisaggProduceContext) -> float: + return 0.0 + + def is_model_expired(self, train_step: int, model_step: int) -> bool: + return False + + def pending_task_count(self) -> int: + return 0 + + +class DisaggAsyncProduceStrategy(DisaggProduceStrategy): + """Disaggregated async strategy; pending tasks persist across background production rounds.""" + + PERIODIC_ABORT_INTERVAL_S = PERIODIC_ABORT_INTERVAL_S + + def __init__( + self, + over_sample_threshold: float, + enable_partial_rollout: bool, + tail_batch_trigger_size: int, + max_staleness: int, + sync_weights_interval: int, + is_valid_sample_fn: IsValidSampleFn, + should_continue_fn: ShouldContinueFn, + ): + super().__init__(is_valid_sample_fn, should_continue_fn) + + if not enable_partial_rollout and max_staleness > 0: + logger.warning( + "max_staleness > 0 but enable_partial_rollout is False; this will reduce rollout efficiency because tail_batch_max_tries logic is not supported yet" + ) + + self.over_sample_threshold = over_sample_threshold + self.enable_partial_rollout = enable_partial_rollout + self.max_staleness = max_staleness + self.sync_weights_interval = sync_weights_interval + self.stale_threshold = calculate_stale_threshold(max_staleness, sync_weights_interval) + self.tail_batch_trigger_size = tail_batch_trigger_size + self._pending_tasks = _PendingTasks() + + def is_model_expired(self, train_step: int, model_step: int) -> bool: + staleness = calculate_seq_staleness(model_step, train_step) + return staleness >= self.stale_threshold + + def pending_task_count(self) -> int: + return self._pending_tasks.count() + + async def pause_produce(self, ctx: DisaggProduceContext) -> float: + return await pause_pending_tasks( + pending_tasks=self._pending_tasks, + ctx=ctx, + put_claimed_task=lambda task: ctx.put_generated_group(task.result()), + ) + + async def produce_batch(self, ctx: DisaggProduceContext) -> ProduceBatchStatus: + if ctx.task_name not in ctx.progress.consumed_samples: + raise 
KeyError(f"DisaggProduceProgress.consumed_samples missing task_name={ctx.task_name!r}") + if ctx.task_name not in ctx.progress.target_samples: + raise KeyError(f"DisaggProduceProgress.target_samples missing task_name={ctx.task_name!r}") + + # TODO: move this check to just before the while loop + if ctx.should_abort(): + return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT + if self.is_model_expired(ctx.train_step, ctx.model_step): + return ProduceBatchStatus.EXPIRED_BATCH + + # Reclaim finished old pending tasks before entering the next round. + claimed_done = await self._pending_tasks.claim_ready() + await _put_claimed_tasks(claimed_done, ctx) + + # TODO: remove this check + if ctx.should_abort(): + return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT + if self.is_model_expired(ctx.train_step, ctx.model_step): + return ProduceBatchStatus.EXPIRED_BATCH + + if ctx.total_target <= 0: + return ProduceBatchStatus.NORMAL + + expired_count = await ctx.expired_count() + sample_from_expired = self.tail_batch_trigger_size > 0 and expired_count >= self.tail_batch_trigger_size + if sample_from_expired: + logger.info( + f"Tail batch trigger condition met: {expired_count} expired samples " + f"(threshold: {self.tail_batch_trigger_size}). Enabling tail batch mode." + ) + + # Normal mode uses a fixed oversampling budget; tail-batch mode only fills the necessary gap. + total_target = ctx.total_target + oversample_budget = 0 if sample_from_expired else math.ceil(self.over_sample_threshold * ctx.task_batch_size) + scheduled_target = total_target + oversample_budget + logger.info( + f"Starting produce_batch for task {ctx.task_name} with total_target={total_target}, " + f"oversample_budget={oversample_budget}, scheduled_target={scheduled_target}." 
+ ) + + async def spawn_one() -> asyncio.Task: + rollout_state = await ctx.sample_group(from_expired_pool=sample_from_expired) + return create_task( + ctx.generate_group( + rollout_state, + enable_partial_rollout=self.enable_partial_rollout, + ) + ) + + initial_available = await ctx.available_count() + progress_displayer = _ProgressDisplayer.create( + strategy_name=self.__class__.__name__, + task_name=ctx.task_name, + total=ctx.total_target, + initial=initial_available, + ) + produce_status = ProduceBatchStatus.NORMAL + while True: + if ctx.should_abort(): + produce_status = ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT + break + if self.is_model_expired(ctx.train_step, ctx.model_step): + produce_status = ProduceBatchStatus.EXPIRED_BATCH + break + + available = await ctx.available_count() + progress_displayer.update(available) + if not self.should_continue_fn(available, total_target): + break + + pending_count = self._pending_tasks.count() + desired_pending = max(0, scheduled_target - available) + if available + pending_count < scheduled_target: + while await self._pending_tasks.schedule_one( + max_pending=desired_pending, + should_abort=ctx.should_abort, + spawn_one=spawn_one, + ): + pass + if ctx.should_abort(): + produce_status = ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT + break + + # TODO: remove this redundant check; should_abort() is already checked when exiting the if block above + if ctx.should_abort(): + produce_status = ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT + break + if self._pending_tasks.count() == 0: + logger.warning("All tasks are done but not enough samples collected.") + break + + claimed_done = await self._pending_tasks.wait_and_claim(timeout_s=1) + await _put_claimed_tasks( + claimed_done, + ctx, + available_base=available, + progress_displayer=progress_displayer, + ) + progress_displayer.close() + return produce_status diff --git a/xtuner/v1/rl/agent_loop_manager/produce_utils.py b/xtuner/v1/rl/agent_loop_manager/produce_utils.py new file mode 100644 index 
0000000000..c9e990934f --- /dev/null +++ b/xtuner/v1/rl/agent_loop_manager/produce_utils.py @@ -0,0 +1,777 @@ +import asyncio +import math +import time +from dataclasses import dataclass +from enum import Enum, auto +from pathlib import Path +from typing import Any, Awaitable, Callable, Protocol, runtime_checkable + +import ray +import tqdm +from mmengine.dist import get_rank + +from xtuner.v1.data_proto.rl_data import RolloutState, Status, get_group_status, reset_rollout_response +from xtuner.v1.rl.agent_loop import AgentLoopSpec +from xtuner.v1.rl.replay_buffer import ReplayBuffer +from xtuner.v1.rl.utils import ( + AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S, + PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S, + cancel_and_drain, + create_task, +) +from xtuner.v1.utils import get_logger + +from .sampler import Sampler + + +logger = get_logger() +GROUP_GENERATE_TIME_KEY = "group_generate_time_s" +PERIODIC_ABORT_INTERVAL_S = 5.0 + + +class _ProgressDisplayer: + def __init__(self, progress_bar: Any | None) -> None: + self._tqdm = progress_bar + + @classmethod + def create(cls, *, strategy_name: str, task_name: str, total: int, initial: int) -> "_ProgressDisplayer": + total = max(0, total) + initial = min(total, max(0, initial)) + if total <= 0 or get_rank() != 0: + return cls(None) + return cls( + tqdm.tqdm( + total=total, + initial=initial, + desc=f"{strategy_name} {task_name}", + unit="sample", + dynamic_ncols=True, + mininterval=30, + leave=False, + ) + ) + + def update(self, value: int) -> None: + if self._tqdm is None: + return + total = max(0, int(self._tqdm.total or 0)) + value = min(total, max(0, value)) + delta = value - self._tqdm.n + if delta > 0: + self._tqdm.update(delta) + self._tqdm.n = value + + def close(self) -> None: + if self._tqdm is not None: + self._tqdm.close() + self._tqdm = None + + +class ProduceBatchStatus(Enum): + NORMAL = auto() + UPDATE_WEIGHT_AND_ABORT = auto() + EXPIRED_BATCH = auto() + + +def default_is_valid_sample_fn(samples: list[RolloutState]) -> 
bool: + return True + + +def default_should_continue_fn(completed_count: int, batch_size: int, **kwargs) -> bool: + return completed_count < batch_size + + +def calculate_stale_threshold(max_staleness: int, sync_weights_interval: int) -> int: + if max_staleness < 0: + raise ValueError(f"max_staleness must be non-negative, got {max_staleness}.") + if sync_weights_interval <= 0: + raise ValueError(f"sync_weights_interval must be positive, got {sync_weights_interval}.") + + # max_staleness is counted in sync intervals; the +1 covers the lag of the current sync interval, which training must always accept. + return (max_staleness + 1) * sync_weights_interval + + +@runtime_checkable +class IsValidSampleFn(Protocol): + def __call__(self, samples: list[RolloutState]) -> bool: ... + + +@runtime_checkable +class ShouldContinueFn(Protocol): + def __call__(self, completed_count: int, batch_size: int, **kwargs) -> bool: ... + + +@dataclass(kw_only=True) +class BaseProduceContext: + """Sample, generate, and put capabilities shared by colocated and disaggregated production.""" + + agent_loop: AgentLoopSpec + sampler: Sampler + replay_buffer: ReplayBuffer + task_batch_size: int + task_name: str + train_step: int + model_step: int + progress: Any + is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn + stale_threshold: int | None = None + + @property + def consumer_step(self) -> int: + return self.train_step + + async def expired_count(self) -> int: + return await self.replay_buffer.count(task_name=self.task_name, group_status=Status.EXPIRED) + + async def sample_group(self, *, from_expired_pool: bool) -> list[RolloutState]: + group_status = [Status.EXPIRED, Status.ABORTED] if from_expired_pool else [Status.ABORTED] + return await self.sampler.sample(task_name=self.task_name, group_status=group_status) + + async def generate_group( + self, + rollout_state: list[RolloutState], + *, + enable_partial_rollout: bool = False, + ) -> list[RolloutState]: + # The strategy does not care whether agent_loop is a Ray actor or a local object. + start = time.perf_counter() + if isinstance(self.agent_loop, ray.actor.ActorHandle): + result = await 
self.agent_loop.generate_group.remote( + rollout_state, + enable_partial_rollout=enable_partial_rollout, + ) + else: + result = await self.agent_loop.generate_group( + rollout_state, + enable_partial_rollout=enable_partial_rollout, + ) + elapsed = time.perf_counter() - start + for item in result: + extra_fields = getattr(item, "extra_fields", None) + if extra_fields is None: + extra_fields = {} + setattr(item, "extra_fields", extra_fields) + extra_fields[GROUP_GENERATE_TIME_KEY] = elapsed + return result + + async def put_generated_group(self, group: list[RolloutState]) -> bool: + # Only COMPLETED groups need business-level filtering; ABORTED / EXPIRED groups keep their original status. + is_completed = get_group_status(group) == Status.COMPLETED + produced_tokens = sum(len(item.response_ids) for item in group if item.response_ids is not None) + if is_completed: + rewards_sum = 0.0 + rewards_count = 0 + for item in group: + if item.reward is None or "score" not in item.reward: + logger.warning( + f"Missing reward score in item (uid: {item.uid}) of completed group for task {self.task_name}. This item will be skipped in reward statistics." + ) + continue + rewards_sum += float(item.reward["score"]) # type: ignore[index] + rewards_count += 1 + self.progress.add_raw_rewards(self.task_name, rewards_sum, rewards_count) + is_valid = self.is_valid_sample_fn(group) + if not is_valid: + for item in group: + item.status = Status.FILTERED + reset_rollout_response(item) + await self.replay_buffer.put( + group, + self.task_name, + model_step=self.model_step, + current_train_step=self.consumer_step, + stale_threshold=self.stale_threshold, + ) + self.progress.add_produced(self.task_name, samples=len(group), tokens=produced_tokens) + # replay_buffer.put may turn the group into EXPIRED because of staleness. + is_completed = get_group_status(group) == Status.COMPLETED + return is_completed + + +@dataclass +class ProduceBatchResult: + """Result of a single ``produce_batch`` call. 
+ + Args: + rollout_states (list[list[RolloutState]]): Completed rollout groups retrieved from the replay buffer for training. + group_gen_count (int | None): Number of generate-group calls finished in this batch (None if no generations ran). + group_gen_mean_s (float | None): Mean wall-clock time per generate-group call, in seconds. + group_gen_p50_s (float | None): Median (p50) generate-group time, in seconds. + group_gen_p99_s (float | None): 99th percentile generate-group time, in seconds. + group_gen_p99_p50_ratio (float | None): Ratio of p99 to p50, indicating tail-latency skew. + group_gen_pause_time_s (float | None): Time spent in pause/cleanup phase (async strategy only), in seconds. + leftover_init (int): Number of init groups remaining in the replay buffer after this batch. + leftover_completed (int): Number of completed groups remaining in the replay buffer after this batch. + leftover_aborted (int): Number of aborted groups remaining in the replay buffer. + leftover_expired (int): Number of expired groups remaining in the replay buffer. + leftover_failed (int): Number of failed groups remaining in the replay buffer. + leftover_filtered (int): Number of filtered groups remaining in the replay buffer. + raw_rewards_sum (float): Sum of rewards produced before replay-buffer insertion for the current window. + raw_rewards_count (int): Number of reward-bearing samples included in ``raw_rewards_sum``. + produced_samples (int): Number of rollout samples produced in the current produce window. + produced_tokens (int): Number of response tokens produced in the current produce window. + produce_time_s (float): Wall-clock production time consumed by the current produce window. 
+ """ + + rollout_states: list[list[RolloutState]] + status: ProduceBatchStatus = ProduceBatchStatus.NORMAL + # per-group generation timing stats (all None if no generations occurred) + group_gen_count: int | None = None + group_gen_mean_s: float | None = None + group_gen_p50_s: float | None = None + group_gen_p99_s: float | None = None + group_gen_p99_p50_ratio: float | None = None + group_gen_pause_time_s: float | None = None + # leftover samples remaining in replay buffer after batch retrieval + leftover_init: int = 0 + leftover_completed: int = 0 + leftover_aborted: int = 0 + leftover_expired: int = 0 + leftover_failed: int = 0 + leftover_filtered: int = 0 + # rewards produced during the current produce window, including completed and filtered groups. + raw_rewards_sum: float = 0.0 + raw_rewards_count: int = 0 + produced_samples: int = 0 + produced_tokens: int = 0 + produce_time_s: float = 0.0 + task_batch_sizes: dict[str, int] | None = None + task_results: dict[str, "ProduceBatchResult"] | None = None + + +@dataclass(frozen=True) +class _TaskRunner: + task_name: str + agent_loop: AgentLoopSpec + produce_strategy: Any + sampler: Sampler + weight: float = 1.0 + order: int = 0 + + @property + def is_valid_sample_fn(self) -> IsValidSampleFn: + return getattr(self.produce_strategy, "is_valid_sample_fn", default_is_valid_sample_fn) + + @property + def stale_threshold(self) -> int | None: + return getattr(self.produce_strategy, "stale_threshold", None) + + +class _TaskSamplerView: + def __init__(self, samplers: list[Sampler]): + self._samplers = samplers + + def __len__(self) -> int: + return sum(len(sampler) for sampler in self._samplers) + + +def _fill_produce_timing_stats( + result: ProduceBatchResult, generate_times_s: list[float], pause_time_s: float = 0.0 +) -> None: + if not generate_times_s: + if pause_time_s > 0: + result.group_gen_pause_time_s = pause_time_s + return + sorted_times = sorted(generate_times_s) + n = len(sorted_times) + mean_s = 
sum(sorted_times) / n + p50_s = sorted_times[n // 2] + p99_s = sorted_times[int(n * 0.99)] + ratio = p99_s / p50_s if p50_s > 0 else float("inf") + result.group_gen_count = n + result.group_gen_mean_s = mean_s + result.group_gen_p50_s = p50_s + result.group_gen_p99_s = p99_s + result.group_gen_p99_p50_ratio = ratio + result.group_gen_pause_time_s = pause_time_s + + +def _fill_group_timing_stats( + result: ProduceBatchResult, rollout_states: list[list[RolloutState]], pause_time_s: float = 0.0 +) -> None: + generate_times: list[float] = [] + for group in rollout_states: + if not group: + continue + group_time = getattr(group[0], "extra_fields", {}).get(GROUP_GENERATE_TIME_KEY) + if group_time is not None: + generate_times.append(group_time) + + _fill_produce_timing_stats(result, generate_times, pause_time_s=pause_time_s) + + +_LEFTOVER_STATUSES = [ + Status.INIT, + Status.COMPLETED, + Status.ABORTED, + Status.EXPIRED, + Status.FAILED, + Status.FILTERED, +] +_TASK_CHECKPOINT_DIR = "tasks" +_MANAGER_STATE_PATH = "agent_loop_manager_state.json" +_STATUS_POLL_INTERVAL_S = 1.0 + + +def _fill_leftover_counts(result: ProduceBatchResult, status_counts: dict[Status, int]) -> None: + result.leftover_init = status_counts.get(Status.INIT, 0) + result.leftover_completed = status_counts.get(Status.COMPLETED, 0) + result.leftover_aborted = status_counts.get(Status.ABORTED, 0) + result.leftover_expired = status_counts.get(Status.EXPIRED, 0) + result.leftover_failed = status_counts.get(Status.FAILED, 0) + result.leftover_filtered = status_counts.get(Status.FILTERED, 0) + + +def allocate_task_batch_sizes( + task_runners: list[_TaskRunner], + global_batch_size: int, + train_step: int, +) -> dict[str, int]: + # train_step is kept only so the signature matches the background progress callback; allocation currently uses static weights. + if global_batch_size < 0: + raise ValueError(f"global_batch_size must be non-negative, got {global_batch_size}") + + total_weight = sum(task.weight for task in task_runners) + if total_weight <= 0: + raise ValueError("Sum of 
task weights must be positive.") + if global_batch_size == 0: + task_batch_sizes = {task.task_name: 0 for task in task_runners} + else: + raw_allocations = [global_batch_size * task.weight / total_weight for task in task_runners] + floor_allocations = [math.floor(raw) for raw in raw_allocations] + remaining = global_batch_size - sum(floor_allocations) + + task_batch_sizes = {task.task_name: floor_allocations[idx] for idx, task in enumerate(task_runners)} + ranked_tasks = sorted( + enumerate(task_runners), + key=lambda item: ( + -(raw_allocations[item[0]] - floor_allocations[item[0]]), + item[1].order, + ), + ) + for idx, task in ranked_tasks[:remaining]: + task_batch_sizes[task.task_name] += 1 + + expected_task_names = {task.task_name for task in task_runners} + actual_task_names = set(task_batch_sizes.keys()) + if actual_task_names != expected_task_names: + missing_task_names = expected_task_names - actual_task_names + extra_task_names = actual_task_names - expected_task_names + raise ValueError( + "Invalid task batch sizes allocated: " + f"missing={sorted(missing_task_names)}, extra={sorted(extra_task_names)}" + ) + + negative_batch_sizes = { + task_name: task_batch_size for task_name, task_batch_size in task_batch_sizes.items() if task_batch_size < 0 + } + if negative_batch_sizes: + raise ValueError(f"Task batch sizes must be non-negative, got {negative_batch_sizes}") + + total_batch_size = sum(task_batch_sizes.values()) + if total_batch_size != global_batch_size: + raise ValueError( + "Task batch sizes must sum to the requested global batch size, " + f"got total={total_batch_size}, expected={global_batch_size}" + ) + return task_batch_sizes + + +async def refresh_for_all_tasks( + *, + task_runners: list[_TaskRunner], + replay_buffer: ReplayBuffer, + logger, + manager_name: str, + train_step: int, + statuses: list[Status], +) -> None: + task_stale_thresholds: dict[str, int] = {} + for task in task_runners: + # Sync strategies without a stale_threshold are treated as having a threshold of 1. + 
task_stale_thresholds[task.task_name] = task.stale_threshold if task.stale_threshold is not None else 1 + + expired_counts = await replay_buffer.refresh_staleness( + task_stale_thresholds=task_stale_thresholds, + current_train_step=train_step, + statuses=statuses, + ) + for task_name, expired_count in expired_counts.items(): + logger.info( + f"[AgentLoopManager][{manager_name}] Refresh staleness for task {task_name}: expired_count={expired_count}" + ) + + +def aggregate_task_results( + ordered_tasks: list[_TaskRunner], task_results: dict[str, ProduceBatchResult] +) -> ProduceBatchResult: + rollout_states: list[list[RolloutState]] = [] + leftover_init = 0 + leftover_completed = 0 + leftover_aborted = 0 + leftover_expired = 0 + leftover_failed = 0 + leftover_filtered = 0 + total_group_count = 0 + weighted_group_mean_sum = 0.0 + weighted_group_p50_sum = 0.0 + weighted_group_p99_sum = 0.0 + weighted_group_ratio_sum = 0.0 + total_pause_time_s = 0.0 + raw_rewards_sum = 0.0 + raw_rewards_count = 0 + produced_samples = 0 + produced_tokens = 0 + produce_time_s = 0.0 + + for task in ordered_tasks: + result = task_results[task.task_name] + rollout_states.extend(result.rollout_states) + leftover_init += result.leftover_init + leftover_completed += result.leftover_completed + leftover_aborted += result.leftover_aborted + leftover_expired += result.leftover_expired + leftover_failed += result.leftover_failed + leftover_filtered += result.leftover_filtered + raw_rewards_sum += result.raw_rewards_sum + raw_rewards_count += result.raw_rewards_count + produced_samples += result.produced_samples + produced_tokens += result.produced_tokens + produce_time_s += result.produce_time_s + if result.group_gen_count is not None and result.group_gen_mean_s is not None: + total_group_count += result.group_gen_count + weighted_group_mean_sum += result.group_gen_count * result.group_gen_mean_s + weighted_group_p50_sum += result.group_gen_count * (result.group_gen_p50_s or 0.0) + 
weighted_group_p99_sum += result.group_gen_count * (result.group_gen_p99_s or 0.0) + weighted_group_ratio_sum += result.group_gen_count * (result.group_gen_p99_p50_ratio or 0.0) + total_pause_time_s += result.group_gen_pause_time_s or 0.0 + + aggregated = ProduceBatchResult( + rollout_states=rollout_states, + leftover_init=leftover_init, + leftover_completed=leftover_completed, + leftover_aborted=leftover_aborted, + leftover_expired=leftover_expired, + leftover_failed=leftover_failed, + leftover_filtered=leftover_filtered, + raw_rewards_sum=raw_rewards_sum, + raw_rewards_count=raw_rewards_count, + produced_samples=produced_samples, + produced_tokens=produced_tokens, + produce_time_s=produce_time_s, + task_results={task.task_name: task_results[task.task_name] for task in ordered_tasks}, + ) + if total_group_count > 0: + aggregated.group_gen_count = total_group_count + aggregated.group_gen_mean_s = weighted_group_mean_sum / total_group_count + aggregated.group_gen_p50_s = weighted_group_p50_sum / total_group_count + aggregated.group_gen_p99_s = weighted_group_p99_sum / total_group_count + aggregated.group_gen_p99_p50_ratio = weighted_group_ratio_sum / total_group_count + aggregated.group_gen_pause_time_s = total_pause_time_s + return aggregated + + +def log_buffer_counts( + logger, + *, + manager_name: str, + task_runners: list[_TaskRunner], + task_batch_sizes: dict[str, int], + batch_by_task: dict[str, list[list[RolloutState]]], + leftover_counts: dict[str, dict[Status, int]], +) -> None: + for task in task_runners: + task_name = task.task_name + task_counts = leftover_counts.get(task_name, {}) + logger.info( + f"[AgentLoopManager][{manager_name}] get_batch from buffer for task {task_name}: " + f"requested={task_batch_sizes[task_name]}, retrieved={len(batch_by_task.get(task_name, []))}, " + f"leftover_init={task_counts.get(Status.INIT, 0)}, " + f"leftover_completed={task_counts.get(Status.COMPLETED, 0)}, " + f"leftover_aborted={task_counts.get(Status.ABORTED, 0)}, " 
+ f"leftover_expired={task_counts.get(Status.EXPIRED, 0)}, " + f"leftover_failed={task_counts.get(Status.FAILED, 0)}, " + f"leftover_filtered={task_counts.get(Status.FILTERED, 0)}" + ) + + +def build_produce_batch_result( + *, + task_runners: list[_TaskRunner], + task_batch_sizes: dict[str, int], + batch_by_task: dict[str, list[list[RolloutState]]], + leftover_counts: dict[str, dict[Status, int]], + progress: Any, + pause_time_s: float, +) -> ProduceBatchResult: + if len(task_runners) == 1: + task = task_runners[0] + raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name) + produced_samples, produced_tokens = progress.consume_produced(task.task_name) + produce_time_s = progress.consume_produce_time() + result = ProduceBatchResult( + rollout_states=batch_by_task.get(task.task_name, []), + raw_rewards_sum=raw_rewards_sum, + raw_rewards_count=raw_rewards_count, + produced_samples=produced_samples, + produced_tokens=produced_tokens, + produce_time_s=produce_time_s, + ) + _fill_leftover_counts(result, leftover_counts.get(task.task_name, {})) + _fill_group_timing_stats(result, result.rollout_states, pause_time_s=pause_time_s) + return result + + task_results: dict[str, ProduceBatchResult] = {} + produce_time_s = progress.consume_produce_time() + for task in task_runners: + raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name) + produced_samples, produced_tokens = progress.consume_produced(task.task_name) + result = ProduceBatchResult( + rollout_states=batch_by_task.get(task.task_name, []), + raw_rewards_sum=raw_rewards_sum, + raw_rewards_count=raw_rewards_count, + produced_samples=produced_samples, + produced_tokens=produced_tokens, + ) + _fill_leftover_counts(result, leftover_counts.get(task.task_name, {})) + task_results[task.task_name] = result + + ordered_tasks = sorted(task_runners, key=lambda task: (task.task_name, task.order)) + aggregated = aggregate_task_results(ordered_tasks, task_results) + 
aggregated.produce_time_s = produce_time_s + aggregated.task_batch_sizes = {task.task_name: task_batch_sizes[task.task_name] for task in ordered_tasks} + _fill_group_timing_stats(aggregated, aggregated.rollout_states, pause_time_s=pause_time_s) + return aggregated + + +async def take_train_batch( + *, + task_runners: list[_TaskRunner], + replay_buffer: ReplayBuffer, + logger, + manager_name: str, + task_batch_sizes: dict[str, int], + progress: Any, + pause_time_s: float = 0.0, +) -> ProduceBatchResult: + batch_by_task, consumed_counts = await replay_buffer.take_batch(task_batch_sizes) + if hasattr(progress, "mark_consumed"): + progress.mark_consumed(consumed_counts) + task_names = [task.task_name for task in task_runners] + leftover_counts = await replay_buffer.count_statuses(task_names, _LEFTOVER_STATUSES) + log_buffer_counts( + logger, + manager_name=manager_name, + task_runners=task_runners, + task_batch_sizes=task_batch_sizes, + batch_by_task=batch_by_task, + leftover_counts=leftover_counts, + ) + return build_produce_batch_result( + task_runners=task_runners, + task_batch_sizes=task_batch_sizes, + batch_by_task=batch_by_task, + leftover_counts=leftover_counts, + progress=progress, + pause_time_s=pause_time_s, + ) + + +def task_checkpoint_path(checkpoint_path: Path | str, task_name: str) -> Path: + return Path(checkpoint_path) / _TASK_CHECKPOINT_DIR / task_name + + +def manager_state_path(checkpoint_path: Path | str) -> Path: + return Path(checkpoint_path) / _MANAGER_STATE_PATH + + +def get_pending_task_counts(task_runners: list[_TaskRunner]) -> dict[str, int]: + pending_task_counts: dict[str, int] = {} + for task in task_runners: + pending_count = task.produce_strategy.pending_task_count() + if pending_count > 0: + pending_task_counts[task.task_name] = pending_count + return pending_task_counts + + +class _PendingTasks: + """Concurrent pending-task set for AsyncProduceStrategy. + + This class only encapsulates the concurrency protocol of the pending set; it knows nothing about the sampler / rollout / replay buffer: + - wait works on a snapshot and must 
claim again afterwards, so produce and pause never handle the same done task twice. + - cancel first atomically claims and clears the set, so no other path can claim the tasks after they are cancelled. + - schedule one checks abort and the pending count together under the lock, so no new task is added once pause has been triggered. + """ + + def __init__(self) -> None: + self._tasks: set[asyncio.Task] = set() + self._lock = asyncio.Lock() + + def count(self) -> int: + return len(self._tasks) + + async def claim_ready(self) -> set[asyncio.Task]: + async with self._lock: + ready = {task for task in self._tasks if task.done()} + self._tasks.difference_update(ready) + return ready + + async def wait_and_claim(self, *, timeout_s: float) -> set[asyncio.Task]: + async with self._lock: + snapshot = set(self._tasks) + if not snapshot: + return set() + + done, _ = await asyncio.wait(snapshot, timeout=timeout_s, return_when=asyncio.FIRST_COMPLETED) + async with self._lock: + claimed = done & self._tasks + self._tasks.difference_update(claimed) + return claimed + + async def schedule_one( + self, + *, + max_pending: int, + should_abort: Callable[[], bool], + spawn_one: Callable[[], Awaitable[asyncio.Task]], + ) -> bool: + async with self._lock: + if should_abort() or len(self._tasks) >= max_pending: + return False + self._tasks.add(await spawn_one()) + return True + + async def _claim_all(self) -> set[asyncio.Task]: + async with self._lock: + claimed = set(self._tasks) + self._tasks.clear() + return claimed + + async def cancel_all(self) -> int: + tasks = await self._claim_all() + if not tasks: + return 0 + logger.warning(f"Cancelling {len(tasks)} pending rollout tasks.") + await cancel_and_drain(list(tasks)) + return len(tasks) + + +class _LocalPendingTasks: + """Adapt the colocated path's per-call local pending set to the unified drain protocol. + + Colocated pending tasks never outlive a produce_batch call; the set passed in is mutated in place so that pending_task_count() still reflects the number of remaining local tasks during pause. + """ + + def __init__(self, tasks: set[asyncio.Task]) -> None: + self._tasks = tasks + + def count(self) -> int: + return len(self._tasks) + + async def wait_and_claim(self, *, timeout_s: float) -> set[asyncio.Task]: + if not self._tasks: + return 
set() + done, _ = await asyncio.wait(set(self._tasks), timeout=timeout_s, return_when=asyncio.FIRST_COMPLETED) + self._tasks.difference_update(done) + return done + + async def cancel_all(self) -> int: + tasks = set(self._tasks) + self._tasks.clear() + if not tasks: + return 0 + logger.warning(f"Cancelling {len(tasks)} pending rollout tasks.") + await cancel_and_drain(list(tasks)) + return len(tasks) + + +async def request_agent_loop_pause(ctx: BaseProduceContext, *, pending_count: int) -> None: + """Send a single agent loop pause request.""" + + pause_request_start = time.perf_counter() + if isinstance(ctx.agent_loop, ray.actor.ActorHandle): + pause_future = ctx.agent_loop.pause.remote() + else: + pause_future = ctx.agent_loop.pause() + try: + await asyncio.wait_for(pause_future, timeout=AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S) + except asyncio.TimeoutError: + logger.warning( + f"Agent loop pause timed out: task={ctx.task_name}, timeout_s={AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S}, " + f"elapsed={time.perf_counter() - pause_request_start:.2f}s, pending={pending_count}" + ) + except Exception: + logger.exception( + f"Agent loop pause failed: task={ctx.task_name}, " + f"elapsed={time.perf_counter() - pause_request_start:.2f}s, pending={pending_count}" + ) + + +async def pause_pending_tasks( + *, + pending_tasks: set[asyncio.Task] | _PendingTasks, + ctx: BaseProduceContext, + put_claimed_task: Callable[[asyncio.Task], Awaitable[Any]], +) -> float: + """Pause and drain pending tasks; cancel whatever remains after the timeout.""" + + pending = _LocalPendingTasks(pending_tasks) if isinstance(pending_tasks, set) else pending_tasks + pause_start = time.perf_counter() + if pending.count() == 0: + return 0.0 + + initial_pending_count = pending.count() + logger.info( + f"Pause signal loop started for task {ctx.task_name}. " + f"Waiting for {initial_pending_count} pending tasks to complete. 
" + f"periodic_abort_interval_s={PERIODIC_ABORT_INTERVAL_S}, " + f"producer_pause_pending_task_timeout_s={PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S}" + ) + + pending_pause_tasks = {create_task(request_agent_loop_pause(ctx, pending_count=initial_pending_count))} + cleanup_start_time = time.perf_counter() + next_periodic_abort_time = cleanup_start_time + PERIODIC_ABORT_INTERVAL_S + while True: + elapsed_time = time.perf_counter() - cleanup_start_time + if elapsed_time > PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S: + cancelled_count = await pending.cancel_all() + logger.warning( + f"Cleanup timeout of {PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S}s reached. " + f"Forcefully cancelling {cancelled_count} remaining tasks. task={ctx.task_name}" + ) + break + + if pending.count() == 0: + break + current_time = time.perf_counter() + pending_pause_tasks = {task for task in pending_pause_tasks if not task.done()} + + # Resend the pause signal periodically, in case the backend missed the first pause and pending tasks linger for a long time. + if PERIODIC_ABORT_INTERVAL_S > 0 and current_time >= next_periodic_abort_time: + pending_pause_tasks.add(create_task(request_agent_loop_pause(ctx, pending_count=pending.count()))) + next_periodic_abort_time += PERIODIC_ABORT_INTERVAL_S + + claimed_done = await pending.wait_and_claim(timeout_s=1) + for task in claimed_done: + await put_claimed_task(task) + + await cancel_and_drain(list(pending_pause_tasks)) + pause_time = time.perf_counter() - pause_start + logger.info(f"pause_produce completed for task {ctx.task_name} within {pause_time}s.") + return pause_time + + +async def _put_claimed_tasks( + claimed_tasks: set[asyncio.Task], + ctx: BaseProduceContext, + *, + available_base: int | None = None, + progress_displayer: _ProgressDisplayer | None = None, +) -> None: + completed_count = 0 + for task in claimed_tasks: + is_completed = await ctx.put_generated_group(task.result()) + if is_completed: + completed_count += 1 + if is_completed and available_base is not None and progress_displayer is not None: + 
progress_displayer.update(available_base + completed_count) diff --git a/xtuner/v1/rl/agent_loop_manager/producer.py b/xtuner/v1/rl/agent_loop_manager/producer.py index bf6dafe54d..59af4b2fed 100644 --- a/xtuner/v1/rl/agent_loop_manager/producer.py +++ b/xtuner/v1/rl/agent_loop_manager/producer.py @@ -1,118 +1,41 @@ import asyncio import math -import time from abc import ABC, abstractmethod from dataclasses import dataclass, field -from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from xtuner.v1.rl.rollout.controller import RolloutControllerProxy - -import ray -import tqdm -from mmengine.dist import get_rank from pydantic import BaseModel, ConfigDict, Field -from xtuner.v1.data_proto.rl_data import ( - RolloutState, - Status, - get_group_status, - reset_rollout_response, -) -from xtuner.v1.rl.agent_loop import AgentLoopSpec -from xtuner.v1.rl.replay_buffer import ReplayBuffer -from xtuner.v1.rl.utils import ( - AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S, - PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S, - calculate_seq_staleness, - cancel_and_drain, - create_task, -) +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.utils import create_task from xtuner.v1.utils import get_logger -from .sampler import Sampler - - -logger = get_logger() -GROUP_GENERATE_TIME_KEY = "group_generate_time_s" - +from .produce_utils import ( + PERIODIC_ABORT_INTERVAL_S, + BaseProduceContext, + IsValidSampleFn, + ShouldContinueFn, + _ProgressDisplayer, + _put_claimed_tasks, + calculate_stale_threshold, + default_is_valid_sample_fn, + default_should_continue_fn, + pause_pending_tasks, +) -class _ProgressDisplayer: - def __init__(self, progress_bar: Any | None) -> None: - self._tqdm = progress_bar - @classmethod - def create(cls, *, strategy_name: str, task_name: str, total: int, initial: int) -> "_ProgressDisplayer": - total = max(0, total) - initial 
= min(total, max(0, initial)) - if total <= 0 or get_rank() != 0: - return cls(None) - return cls( - tqdm.tqdm( - total=total, - initial=initial, - desc=f"{strategy_name} {task_name}", - unit="sample", - dynamic_ncols=True, - mininterval=30, - leave=False, - ) - ) +if TYPE_CHECKING: + from xtuner.v1.rl.rollout.controller import RolloutControllerProxy - def update(self, value: int) -> None: - if self._tqdm is None: - return - total = max(0, int(self._tqdm.total or 0)) - value = min(total, max(0, value)) - delta = value - self._tqdm.n - if delta > 0: - self._tqdm.update(delta) - self._tqdm.n = value - def close(self) -> None: - if self._tqdm is not None: - self._tqdm.close() - self._tqdm = None +logger = get_logger() @dataclass class ProduceProgress: - """生产者和消费者共享的 live 进度对象。 - - 设计目标: - - Manager / 调用方负责初始化并原地更新这个对象,strategy 只接收引用并读取最新进度。 - - target / consumed 使用全局绝对累计口径,避免 consumer 取走 buffer 中的 completed 后, - producer 把已消费样本误判成缺口并重复补发。 - - 同一套语义同时服务非共卡全局 progress 和共卡 produce_batch 的局部 progress。 - - 使用注意: - - 不要在 strategy 中补 key 或用 dict.get(..., 0) 兜底;缺少 task key 应 fail fast。 - - 除非语义明确要求冻结本轮 produce_batch 的 target / scheduled_target, - 否则不要把字段值复制成局部快照后跨 await 使用;需要字段值时直接读 progress.xxx, - 让并发更新后的 next_consumer_step / consumed_samples 能尽早生效。 - - 运行中不要整体替换 ProduceProgress 对象;resume 时也应原地更新字段,避免旧引用失效。 - - 字段含义: - - next_consumer_step:producer 写入新样本时应面向的训练 step。get_batch(i) 入口设为 i, - 成功取出非空 batch 后设为 i + 1。 - - producer_future_step:producer 当前准备生产的 future step。 - - consumed_samples:各 task 已被 consumer 从 replay buffer 取走的 group 绝对累计数。 - - target_samples:各 task 截至 target_upto_future_step 应生产出的 group 绝对累计目标。 - - target_upto_future_step:target_samples 已覆盖到的最大 future step。 - - raw_rewards_sum / raw_rewards_count:各 task 自上次 consumer 取 batch 后,producer 实际生成出的 - completed group reward 统计。filtered group 在过滤前仍按 completed 生成结果计入。 - - produced_samples / produced_tokens:各 task 自上次 consumer 取 batch 后,producer 实际返回的样本数和 - response token 数,包含 filtered / aborted / 未被训练消费的 completed 样本。 
- - produce_time_s:自上次 consumer 取 batch 后,producer 实际执行 produce_batch 的累计 wall time。 - """ + """共卡单次 produce_batch 的局部指标,不进入 checkpoint。""" - next_consumer_step: int = 1 - producer_future_step: int = 1 - consumed_samples: dict[str, int] = field(default_factory=dict) target_samples: dict[str, int] = field(default_factory=dict) - target_upto_future_step: int = 0 raw_rewards_sum: dict[str, float] = field(default_factory=dict) raw_rewards_count: dict[str, int] = field(default_factory=dict) produced_samples: dict[str, int] = field(default_factory=dict) @@ -120,62 +43,20 @@ class ProduceProgress: produce_time_s: float = 0.0 @classmethod - def build(cls, task_names: list[str]) -> "ProduceProgress": - return cls( - consumed_samples={task_name: 0 for task_name in task_names}, - target_samples={task_name: 0 for task_name in task_names}, - raw_rewards_sum={task_name: 0.0 for task_name in task_names}, - raw_rewards_count={task_name: 0 for task_name in task_names}, - produced_samples={task_name: 0 for task_name in task_names}, - produced_tokens={task_name: 0 for task_name in task_names}, - ) - - @classmethod - def build_local( + def build( cls, + *, task_names: list[str], - task_batch_sizes: dict[str, int], - train_step: int, + target_samples: dict[str, int], ) -> "ProduceProgress": - # 共卡路径使用局部 progress,只表达本次 produce_batch 的目标,不污染非共卡累计窗口。 return cls( - next_consumer_step=train_step, - producer_future_step=train_step, - consumed_samples={task_name: 0 for task_name in task_names}, - target_samples=dict(task_batch_sizes), - target_upto_future_step=train_step, + target_samples=dict(target_samples), raw_rewards_sum={task_name: 0.0 for task_name in task_names}, raw_rewards_count={task_name: 0 for task_name in task_names}, produced_samples={task_name: 0 for task_name in task_names}, produced_tokens={task_name: 0 for task_name in task_names}, ) - def ensure_target_upto( - self, - *, - batch_size: int, - future_step: int, - allocate_batch_sizes: Callable[[int, int], dict[str, int]], - 
) -> dict[str, int]: - """把累计 target 推进到指定 future step,并返回该 step 的 task batch size。""" - - if future_step > self.target_upto_future_step: - for step in range(self.target_upto_future_step + 1, future_step + 1): - task_batch_sizes = allocate_batch_sizes(batch_size, step) - for task_name, task_batch_size in task_batch_sizes.items(): - self.target_samples[task_name] += task_batch_size - self.target_upto_future_step = future_step - - return allocate_batch_sizes(batch_size, future_step) - - def begin_consume(self, train_step: int) -> None: - self.next_consumer_step = train_step - - def mark_consumed(self, consumed_counts: dict[str, int]) -> None: - # consumer 真实取出多少就累计多少,target 不回退,避免 producer 把已消费样本当成缺口。 - for task_name, count in consumed_counts.items(): - self.consumed_samples[task_name] += count - def add_raw_rewards(self, task_name: str, rewards_sum: float, rewards_count: int) -> None: self.raw_rewards_sum[task_name] += rewards_sum self.raw_rewards_count[task_name] += rewards_count @@ -206,196 +87,17 @@ def consume_raw_rewards(self, task_name: str) -> tuple[float, int]: self.raw_rewards_count[task_name] = 0 return rewards_sum, rewards_count - def finish_consume(self, train_step: int) -> None: - self.next_consumer_step = train_step + 1 - - def advance_future_step(self) -> None: - self.producer_future_step += 1 - - def state_dict(self) -> dict[str, Any]: - return { - "next_consumer_step": self.next_consumer_step, - "producer_future_step": self.producer_future_step, - "consumed_samples": dict(self.consumed_samples), - "target_samples": dict(self.target_samples), - "target_upto_future_step": self.target_upto_future_step, - "raw_rewards_sum": dict(self.raw_rewards_sum), - "raw_rewards_count": dict(self.raw_rewards_count), - "produced_samples": dict(self.produced_samples), - "produced_tokens": dict(self.produced_tokens), - "produce_time_s": self.produce_time_s, - } - - def load_state_dict(self, state: dict[str, Any]) -> None: - # 原地更新 dict,避免 strategy / context 持有旧引用。 - 
self.next_consumer_step = state["next_consumer_step"] - self.producer_future_step = state["producer_future_step"] - self.target_upto_future_step = state["target_upto_future_step"] - self.consumed_samples.clear() - self.consumed_samples.update(state["consumed_samples"]) - self.target_samples.clear() - self.target_samples.update(state["target_samples"]) - task_names = set(self.consumed_samples) | set(self.target_samples) - self.raw_rewards_sum.clear() - self.raw_rewards_sum.update( - {task_name: float(state.get("raw_rewards_sum", {}).get(task_name, 0.0)) for task_name in task_names} - ) - self.raw_rewards_count.clear() - self.raw_rewards_count.update( - {task_name: int(state.get("raw_rewards_count", {}).get(task_name, 0)) for task_name in task_names} - ) - produced_samples_state = state.get("produced_samples", {}) - produced_tokens_state = state.get("produced_tokens", {}) - self.produced_samples.clear() - self.produced_samples.update( - {task_name: int(produced_samples_state.get(task_name, 0)) for task_name in task_names} - ) - self.produced_tokens.clear() - self.produced_tokens.update( - {task_name: int(produced_tokens_state.get(task_name, 0)) for task_name in task_names} - ) - self.produce_time_s = float(state.get("produce_time_s", 0.0)) - -class ProduceBatchStatus(Enum): - NORMAL = auto() - UPDATE_WEIGHT_AND_ABORT = auto() - EXPIRED_BATCH = auto() - - -def default_is_valid_sample_fn(samples: list[RolloutState]) -> bool: - return True - - -def default_should_continue_fn(completed_count: int, batch_size: int, **kwargs) -> bool: - return completed_count < batch_size - - -def calculate_stale_threshold(max_staleness: int, sync_weights_interval: int) -> int: - if max_staleness < 0: - raise ValueError(f"max_staleness must be non-negative, got {max_staleness}.") - if sync_weights_interval <= 0: - raise ValueError(f"sync_weights_interval must be positive, got {sync_weights_interval}.") - - # max_staleness is counted in sync intervals; the +1 accounts for the current sync interval's lag that training must inherently accept. - return (max_staleness + 1) *
sync_weights_interval - - -@runtime_checkable -class IsValidSampleFn(Protocol): - def __call__(self, samples: list[RolloutState]) -> bool: ... - - -@runtime_checkable -class ShouldContinueFn(Protocol): - def __call__(self, completed_count: int, batch_size: int, **kwargs) -> bool: ... - - -@dataclass -class ProduceContext: - """Production context for a single task. - - This centralizes the runtime contracts of AsyncProduceStrategy that are easiest to pass incorrectly: - - the strategy only accepts ProduceContext and no longer supports loose-argument entry points; - - target / consumed are both read as absolute cumulative counts; - - pausing only reads the update_event passed in by the manager; - - the ray/local difference in rollout generate and the writing of timing fields; - - generated results are first filtered for business validity, then handed to ReplayBuffer to write versions, refresh staleness, and apply expiration. - """ - - agent_loop: AgentLoopSpec - sampler: Sampler - replay_buffer: ReplayBuffer - task_batch_size: int - task_name: str - train_step: int - update_event: asyncio.Event - model_step: int - progress: ProduceProgress - is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn - stale_threshold: int | None = None +@dataclass(kw_only=True) +class ProduceContext(BaseProduceContext): + """Colocated local production window; does not expose disaggregated state-machine fields.""" @property - def consumer_step(self) -> int: - return self.progress.next_consumer_step - - @property - def target_abs(self) -> int: + def batch_target(self) -> int: return self.progress.target_samples[self.task_name] - def should_abort(self) -> bool: - return self.update_event.is_set() - - async def expired_count(self) -> int: - return await self.replay_buffer.count(task_name=self.task_name, group_status=Status.EXPIRED) - - async def available_count(self) -> int: - completed_count = await self.replay_buffer.count(task_name=self.task_name, group_status=Status.COMPLETED) - return self.progress.consumed_samples[self.task_name] + completed_count - - async def sample_group(self, *, from_expired_pool: bool) -> list[RolloutState]: - group_status = [Status.EXPIRED, Status.ABORTED] if from_expired_pool else [Status.ABORTED] - return await self.sampler.sample(task_name=self.task_name, group_status=group_status) - - async def generate_group( - self, -
rollout_state: list[RolloutState], - *, - enable_partial_rollout: bool = False, - ) -> list[RolloutState]: - # The strategy only expresses "generate"; it does not care whether agent_loop is a ray actor or a local object. - start = time.perf_counter() - if isinstance(self.agent_loop, ray.actor.ActorHandle): - result = await self.agent_loop.generate_group.remote( - rollout_state, - enable_partial_rollout=enable_partial_rollout, - ) - else: - result = await self.agent_loop.generate_group( - rollout_state, - enable_partial_rollout=enable_partial_rollout, - ) - elapsed = time.perf_counter() - start - for item in result: - extra_fields = getattr(item, "extra_fields", None) - if extra_fields is None: - extra_fields = {} - setattr(item, "extra_fields", extra_fields) - extra_fields[GROUP_GENERATE_TIME_KEY] = elapsed - return result - - async def put_generated_group(self, group: list[RolloutState]) -> bool: - # Only fully generated groups need business-validity filtering; ABORTED / EXPIRED keep their original status for retry or statistics. - is_completed = get_group_status(group) == Status.COMPLETED - produced_tokens = sum(len(item.response_ids) for item in group if item.response_ids is not None) - if is_completed: - rewards_sum = 0.0 - rewards_count = 0 - for item in group: - if item.reward is None or "score" not in item.reward: - logger.warning( - f"Missing reward score in item (uid: {item.uid}) of completed group for task {self.task_name}. This item will be skipped in reward statistics."
- ) - continue - rewards_sum += float(item.reward["score"]) # type: ignore[index] - rewards_count += 1 - self.progress.add_raw_rewards(self.task_name, rewards_sum, rewards_count) - is_valid = self.is_valid_sample_fn(group) - if not is_valid: - for item in group: - item.status = Status.FILTERED - reset_rollout_response(item) - await self.replay_buffer.put( - group, - self.task_name, - model_step=self.model_step, - current_train_step=self.consumer_step, - stale_threshold=self.stale_threshold, - ) - self.progress.add_produced(self.task_name, samples=len(group), tokens=produced_tokens) - # replay_buffer.put may turn a stale group into EXPIRED; re-check before returning whether it is still trainable. - is_completed = get_group_status(group) == Status.COMPLETED - return is_completed + async def completed_count(self) -> int: + return await self.replay_buffer.count(task_name=self.task_name, group_status=Status.COMPLETED) class ProduceStrategyConfig(ABC, BaseModel): @@ -460,12 +162,12 @@ def build( class AsyncProduceStrategyConfig(ProduceStrategyConfig): - """Configuration for asynchronous rollout production. + """Configuration for colocated asynchronous rollout production. - The asynchronous strategy keeps producing rollout samples in the background - and stores them in the replay buffer. It can oversample, allow partial - rollout continuation, and discard samples that are too stale relative to the - current training step. + The colocated asynchronous strategy produces rollout samples concurrently + within one ``AgentLoopManager.produce_batch`` call and stores them in the + replay buffer. It can oversample, allow partial rollout continuation, and + discard samples that are too stale relative to the current training step. Args: is_valid_sample_fn (IsValidSampleFn): Function used to decide whether a @@ -530,89 +232,19 @@ def __init__( self.should_continue_fn = should_continue_fn @abstractmethod - async def produce_batch(self, ctx: ProduceContext) -> ProduceBatchStatus: ...
+ async def produce_batch(self, ctx: ProduceContext) -> None: ... async def pause_produce(self, ctx: ProduceContext) -> float: return 0.0 - def is_model_expired(self, train_step: int, model_step: int) -> bool: - # The default sync strategy has no background samples spanning weight versions; only the async strategy needs to check model expiry. - return False - def pending_task_count(self) -> int: return 0 -class _PendingTasks: - """Concurrent pending-task set for AsyncProduceStrategy. - - This only encapsulates the concurrency protocol of the pending set and knows nothing about sampler / rollout / replay buffer: - - wait uses a snapshot and must then claim a second time, so produce and pause do not process the same done task twice. - - before cancel, atomically claim and clear the set, so tasks are not claimed by another path after cancel. - - schedule one checks abort and the pending count together under the lock, so no new task is added after pause has been triggered. - """ - - def __init__(self) -> None: - self._tasks: set[asyncio.Task] = set() - self._lock = asyncio.Lock() - - def count(self) -> int: - # Only expose the number of tasks already in the pending set. - return len(self._tasks) - - async def claim_ready(self) -> set[asyncio.Task]: - async with self._lock: - ready = {task for task in self._tasks if task.done()} - self._tasks.difference_update(ready) - return ready - - async def wait_and_claim(self, *, timeout_s: float) -> set[asyncio.Task]: - async with self._lock: - snapshot = set(self._tasks) - if not snapshot: - return set() - - done, _ = await asyncio.wait(snapshot, timeout=timeout_s, return_when=asyncio.FIRST_COMPLETED) - async with self._lock: - claimed = done & self._tasks - self._tasks.difference_update(claimed) - return claimed - - async def schedule_one( - self, - *, - max_pending: int, - should_abort: Callable[[], bool], - spawn_one: Callable[[], Awaitable[asyncio.Task]], - ) -> bool: - async with self._lock: - if should_abort() or len(self._tasks) >= max_pending: - return False - # Keep the "check abort / pending count / add task" sequence atomic. - self._tasks.add(await spawn_one()) - return True - - async def _claim_all(self) -> set[asyncio.Task]: - async with self._lock: - claimed = set(self._tasks) - self._tasks.clear() - return claimed - - async def cancel_all(self) -> int: - tasks = await self._claim_all() - if not tasks: - return
0 - logger.warning(f"Cancelling {len(tasks)} pending rollout tasks.") - await cancel_and_drain(list(tasks)) - return len(tasks) - - class SyncProduceStrategy(ProduceStrategy): - async def produce_batch(self, ctx: ProduceContext) -> ProduceBatchStatus: + async def produce_batch(self, ctx: ProduceContext) -> None: pending_tasks = set() completed_sample_count = await ctx.replay_buffer.count(task_name=ctx.task_name, group_status=Status.COMPLETED) - # TODO: should SyncProduceStrategy be supported in disaggregated mode? If so, comment out the line below? - # assert completed_sample_count == 0, "SyncProduceStrategy assumes no completed samples at the start." for _ in range(ctx.task_batch_size): rollout_state = await ctx.sampler.sample(task_name=ctx.task_name) @@ -624,44 +256,38 @@ async def produce_batch(self, ctx: ProduceContext) -> ProduceBatchStatus: progress_displayer = _ProgressDisplayer.create( strategy_name=self.__class__.__name__, task_name=ctx.task_name, - total=ctx.target_abs, + total=ctx.batch_target, initial=completed_sample_count, ) - try: - while self.should_continue_fn(completed_sample_count, ctx.task_batch_size): - if not pending_tasks: - logger.warning("[SyncProduceStrategy] All tasks are done but not enough samples collected.") - break - done_tasks, pending_tasks = await asyncio.wait( - pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED - ) - # If filtering is needed, handle it here, then add to the replay buffer - # Filtered data goes into the put_to_filtered pool - for task in done_tasks: - items = task.result() - - is_completed = await ctx.put_generated_group(items) - if not is_completed: - continue + while self.should_continue_fn(completed_sample_count, ctx.task_batch_size): + if not pending_tasks: + logger.warning("[SyncProduceStrategy] All tasks are done but not enough samples collected.") + break + done_tasks, pending_tasks = await asyncio.wait( + pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED + ) + # put_generated_group handles filtering and storing into the buffer. + for task in done_tasks: + items = task.result() - completed_sample_count += 1 -
progress_displayer.update(completed_sample_count) + is_completed = await ctx.put_generated_group(items) + if not is_completed: + continue - while len(pending_tasks) + completed_sample_count < ctx.task_batch_size and self.should_continue_fn( - completed_sample_count, ctx.task_batch_size - ): - rollout_state = await ctx.sampler.sample(task_name=ctx.task_name) - task = create_task(ctx.generate_group(rollout_state)) - pending_tasks.add(task) - finally: - progress_displayer.close() + completed_sample_count += 1 + progress_displayer.update(completed_sample_count) - return ProduceBatchStatus.NORMAL + while len(pending_tasks) + completed_sample_count < ctx.task_batch_size and self.should_continue_fn( + completed_sample_count, ctx.task_batch_size + ): + rollout_state = await ctx.sampler.sample(task_name=ctx.task_name) + task = create_task(ctx.generate_group(rollout_state)) + pending_tasks.add(task) + progress_displayer.close() class AsyncProduceStrategy(ProduceStrategy): - # Local retry interval for re-sending pause/abort while pending tasks drain. 
- PERIODIC_ABORT_INTERVAL_S = 5.0 + PERIODIC_ABORT_INTERVAL_S = PERIODIC_ABORT_INTERVAL_S def __init__( self, @@ -693,123 +319,27 @@ def __init__( self.sync_weights_interval = sync_weights_interval self.stale_threshold = calculate_stale_threshold(max_staleness, sync_weights_interval) self.tail_batch_trigger_size = tail_batch_trigger_size - self._pending_tasks = _PendingTasks() - - def is_model_expired(self, train_step: int, model_step: int) -> bool: - staleness = calculate_seq_staleness(model_step, train_step) - return staleness >= self.stale_threshold + self._local_pending_tasks: set[asyncio.Task] = set() def pending_task_count(self) -> int: - return self._pending_tasks.count() - - async def _put_claimed( - self, - claimed_tasks: set[asyncio.Task], - ctx: ProduceContext, - available_base: int | None = None, - progress_displayer: _ProgressDisplayer | None = None, - ) -> None: - completed_count = 0 - for task in claimed_tasks: - items = task.result() - is_completed = await ctx.put_generated_group(items) - if is_completed: - completed_count += 1 - if is_completed and available_base is not None and progress_displayer is not None: - progress_displayer.update(available_base + completed_count) - - async def _pause_agent_loop(self, ctx: ProduceContext) -> None: - pause_request_start = time.perf_counter() - if isinstance(ctx.agent_loop, ray.actor.ActorHandle): - pause_future = ctx.agent_loop.pause.remote() - else: - pause_future = ctx.agent_loop.pause() - try: - await asyncio.wait_for(pause_future, timeout=AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S) - except asyncio.TimeoutError: - logger.warning( - f"Agent loop pause timed out: task={ctx.task_name}, timeout_s={AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S}, " - f"elapsed={time.perf_counter() - pause_request_start:.2f}s, " - f"pending={self._pending_tasks.count()}" - ) - except Exception: - logger.exception( - f"Agent loop pause failed: task={ctx.task_name}, " - f"elapsed={time.perf_counter() - pause_request_start:.2f}s, " - 
f"pending={self._pending_tasks.count()}" - ) + return len(self._local_pending_tasks) async def pause_produce(self, ctx: ProduceContext) -> float: - pause_start = time.perf_counter() - if self._pending_tasks.count() == 0: - return 0.0 - - pending_pause_tasks = {create_task(self._pause_agent_loop(ctx))} - initial_pending_count = self._pending_tasks.count() - - logger.info( - f"Pause signal loop started for task {ctx.task_name}. " - f"Waiting for {initial_pending_count} pending tasks to complete. " - f"periodic_abort_interval_s={self.PERIODIC_ABORT_INTERVAL_S}, " - f"producer_pause_pending_task_timeout_s={PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S}" + return await pause_pending_tasks( + pending_tasks=self._local_pending_tasks, + ctx=ctx, + put_claimed_task=lambda task: ctx.put_generated_group(task.result()), ) - cleanup_start_time = time.perf_counter() - next_periodic_abort_time = cleanup_start_time + self.PERIODIC_ABORT_INTERVAL_S - while True: - elapsed_time = time.perf_counter() - cleanup_start_time - if elapsed_time > PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S: - # On timeout, forcibly cancel all pending tasks - cancelled_count = await self._pending_tasks.cancel_all() - logger.warning( - f"Cleanup timeout of {PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S}s reached. " - f"Forcefully cancelling {cancelled_count} remaining tasks. 
" - f"task={ctx.task_name}" - ) - break - if self._pending_tasks.count() == 0: - break - current_time = time.perf_counter() - pending_pause_tasks = {task for task in pending_pause_tasks if not task.done()} - - # Periodically resend the pause signal - if self.PERIODIC_ABORT_INTERVAL_S > 0 and current_time >= next_periodic_abort_time: - pending_pause_tasks.add(create_task(self._pause_agent_loop(ctx))) - next_periodic_abort_time += self.PERIODIC_ABORT_INTERVAL_S - - claimed_done = await self._pending_tasks.wait_and_claim(timeout_s=1) - for task in claimed_done: - paused_items = task.result() - await ctx.put_generated_group(paused_items) - await cancel_and_drain(list(pending_pause_tasks)) - pause_time = time.perf_counter() - pause_start - logger.info(f"pause_produce completed for task {ctx.task_name} within {pause_time}s.") - return pause_time - - async def produce_batch(self, ctx: ProduceContext) -> ProduceBatchStatus: - if ctx.task_name not in ctx.progress.consumed_samples: - raise KeyError(f"ProduceProgress.consumed_samples missing task_name={ctx.task_name!r}") + async def produce_batch(self, ctx: ProduceContext) -> None: if ctx.task_name not in ctx.progress.target_samples: raise KeyError(f"ProduceProgress.target_samples missing task_name={ctx.task_name!r}") - if ctx.target_abs <= 0: - return ProduceBatchStatus.NORMAL - - # TODO: place this check just before while loop - if ctx.should_abort(): - return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT - if self.is_model_expired(ctx.train_step, ctx.model_step): - return ProduceBatchStatus.EXPIRED_BATCH + # Colocated async pending tasks belong only to this produce_batch call. + self._local_pending_tasks = set() - # First reclaim completed tasks left over from previous produce_batch calls so done tasks do not linger in the pending set. - claimed_done = await self._pending_tasks.claim_ready() - await self._put_claimed(claimed_done, ctx) - - # TODO: remove this check - if ctx.should_abort(): - return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT - if self.is_model_expired(ctx.train_step, ctx.model_step): - return
ProduceBatchStatus.EXPIRED_BATCH + if ctx.batch_target <= 0: + return expired_count = await ctx.expired_count() sample_from_expired = self.tail_batch_trigger_size > 0 and expired_count >= self.tail_batch_trigger_size @@ -819,13 +349,12 @@ async def produce_batch(self, ctx: ProduceContext) -> ProduceBatchStatus: f"(threshold: {self.tail_batch_trigger_size}). Enabling tail batch mode." ) - # The required cumulative target of this produce_batch round is fixed; normal mode only adds a fixed oversample budget based on the current task batch. - # tail-batch mode only fills the required shortfall; new tasks always come from the EXPIRED pool and the oversample window is not enlarged. - target_abs = ctx.target_abs + # normal uses a fixed oversample budget; tail-batch only fills the required shortfall. + batch_target = ctx.batch_target oversample_budget = 0 if sample_from_expired else math.ceil(self.over_sample_threshold * ctx.task_batch_size) - scheduled_target = target_abs + oversample_budget + scheduled_target = batch_target + oversample_budget logger.info( - f"Starting produce_batch for task {ctx.task_name} with target_abs={target_abs}, " + f"Starting produce_batch for task {ctx.task_name} with batch_target={batch_target}, " f"oversample_budget={oversample_budget}, scheduled_target={scheduled_target}."
) @@ -838,50 +367,37 @@ async def spawn_one() -> asyncio.Task: ) ) - initial_available = await ctx.available_count() + initial_available = await ctx.completed_count() progress_displayer = _ProgressDisplayer.create( strategy_name=self.__class__.__name__, task_name=ctx.task_name, - total=ctx.target_abs, + total=ctx.batch_target, initial=initial_available, ) - try: - while True: - if ctx.should_abort(): - return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT - if self.is_model_expired(ctx.train_step, ctx.model_step): - return ProduceBatchStatus.EXPIRED_BATCH - - available = await ctx.available_count() - progress_displayer.update(available) - if not self.should_continue_fn(available, target_abs): - return ProduceBatchStatus.NORMAL - - pending_count = self._pending_tasks.count() - desired_pending = max(0, scheduled_target - available) - if available + pending_count < scheduled_target: - while await self._pending_tasks.schedule_one( - max_pending=desired_pending, - should_abort=ctx.should_abort, - spawn_one=spawn_one, - ): - pass - # TODO: remove this check, because will check it when exit if statement, it's redundant - if ctx.should_abort(): - return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT - - if ctx.should_abort(): - return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT - if self._pending_tasks.count() == 0: - logger.warning("All tasks are done but not enough samples collected.") - return ProduceBatchStatus.NORMAL - - claimed_done = await self._pending_tasks.wait_and_claim(timeout_s=1) - await self._put_claimed( - claimed_done, - ctx, - available_base=available, - progress_displayer=progress_displayer, - ) - finally: - progress_displayer.close() + while True: + available = await ctx.completed_count() + progress_displayer.update(available) + if not self.should_continue_fn(available, batch_target): + break + + pending_count = len(self._local_pending_tasks) + desired_pending = max(0, scheduled_target - available) + if available + pending_count < scheduled_target: + while 
len(self._local_pending_tasks) < desired_pending: + self._local_pending_tasks.add(await spawn_one()) + + if not self._local_pending_tasks: + logger.warning("All tasks are done but not enough samples collected.") + break + + done_tasks, _ = await asyncio.wait( + set(self._local_pending_tasks), timeout=1, return_when=asyncio.FIRST_COMPLETED + ) + self._local_pending_tasks.difference_update(done_tasks) + await _put_claimed_tasks( + done_tasks, + ctx, + available_base=available, + progress_displayer=progress_displayer, + ) + progress_displayer.close() diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 42be1c976f..797e4950da 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -25,11 +25,14 @@ from xtuner.v1.patch import patch_default_save_plan from xtuner.v1.rl.advantage import BaseAdvantageConfig, GRPOAdvantageConfig from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManager, AgentLoopManagerConfig, + DisaggAgentLoopManager, + DisaggAgentLoopManagerConfig, ProduceBatchResult, ProduceBatchStatus, ) -from xtuner.v1.rl.agent_loop_manager.producer import default_should_continue_fn +from xtuner.v1.rl.agent_loop_manager.produce_utils import default_should_continue_fn from xtuner.v1.rl.evaluator import EvaluatorConfig from xtuner.v1.rl.replay_buffer import ( AsyncReplayBufferConfig, @@ -194,7 +197,7 @@ def get_train_seq_ctx( seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu") position_ids = _to_cpu_tensor(position_ids, dtype=torch.long) if position_ids is not None and len(position_ids.shape) == 3: - # qwen3vl needs special handling; the rest need no extra handling + # VLM position encodings need the response segment appended. max_value = position_ids.max(dim=-1).values # (3,1) response_position_ids = max_value.unsqueeze(-1).expand(-1, -1, len_response_ids) + torch.arange( 1, len_response_ids + 1, device=max_value.device @@ -315,7 +318,7 @@ class BaseRLTrainerConfig(BaseModel): rollout_config: RolloutConfig tokenizer_path: str | Path replay_buffer_config:
SyncReplayBufferConfig | AsyncReplayBufferConfig = SyncReplayBufferConfig() - agent_loop_manager_cfg: AgentLoopManagerConfig + agent_loop_manager_cfg: AgentLoopManagerConfig | DisaggAgentLoopManagerConfig eval_agent_loop_manager_cfg: AgentLoopManagerConfig | None = None evaluator_config: EvaluatorConfig | None = None load_from: str | Path @@ -450,6 +453,7 @@ class RLColocateTrainerConfig(BaseRLTrainerConfig): ) """ + agent_loop_manager_cfg: AgentLoopManagerConfig resources: AcceleratorResourcesConfig def build(self) -> "RLColocateTrainer": @@ -540,6 +544,7 @@ class RLDisaggregatedTrainerConfig(BaseRLTrainerConfig): ) """ + agent_loop_manager_cfg: DisaggAgentLoopManagerConfig train_resources: AcceleratorResourcesConfig rollout_resources: AcceleratorResourcesConfig @@ -555,6 +560,8 @@ class BaseRLTrainer: train_controller: TrainingController rollout_controller: RolloutControllerProxy + agent_loop_manager: AgentLoopManager | DisaggAgentLoopManager + eval_agent_loop_manager: AgentLoopManager _debug_train_files: dict[int, Path] def _init_common(self, cfg: BaseRLTrainerConfig, *, meta_path: str, logger_tag: str) -> None: @@ -707,23 +714,27 @@ def _init_runtime_flags(self, cfg: BaseRLTrainerConfig) -> None: def _build_agent_loop_components(self, cfg: BaseRLTrainerConfig, replay_buffer) -> None: self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path, trust_remote_code=True) - self.agent_loop_manager = cfg.agent_loop_manager_cfg.build( + agent_loop_manager = cfg.agent_loop_manager_cfg.build( rollout_controller=self.rollout_controller, tokenizer=self.tokenizer, replay_buffer=replay_buffer, logger=self.logger, sync_weights_interval=cfg.sync_weights_interval, ) + self.agent_loop_manager = cast(AgentLoopManager | DisaggAgentLoopManager, agent_loop_manager) if self._enable_evaluate: assert cfg.eval_agent_loop_manager_cfg is not None - self._eval_replay_buffer = SyncReplayBufferConfig().build() - self.eval_agent_loop_manager = cfg.eval_agent_loop_manager_cfg.build( - 
rollout_controller=self.rollout_controller, - tokenizer=self.tokenizer, - replay_buffer=self._eval_replay_buffer, - logger=self.logger, - sync_weights_interval=cfg.sync_weights_interval, + # Evaluation always runs a single synchronous rollout. + self.eval_agent_loop_manager = cast( + AgentLoopManager, + cfg.eval_agent_loop_manager_cfg.build( + rollout_controller=self.rollout_controller, + tokenizer=self.tokenizer, + replay_buffer=replay_buffer, + logger=self.logger, + sync_weights_interval=cfg.sync_weights_interval, + ), ) total_eval_samples = len(self.eval_agent_loop_manager.data_sampler) @@ -759,7 +770,7 @@ def _resolve_load_checkpoint_cfg( return load_checkpoint_cfg def _resume_train_controller_and_state(self, checkpoint_path: Path | str) -> Path: - # Subclasses only reuse the training worker and train_state resume; each maintains its own weight-sync flow. + # Weight-sync resume is handled separately by the colocated / disaggregated subclasses. self.logger.info(f"Resume train controller and state from {checkpoint_path}") checkpoint_path = Path(checkpoint_path) self.train_controller.resume(self._load_checkpoint_cfg) @@ -773,11 +784,7 @@ def _resume_train_controller_and_state(self, checkpoint_path: Path | str) -> Pat async def _resume_agent_loop_manager(self, checkpoint_path: Path | str) -> int: self.logger.info(f"Resume agent_loop_manager from {checkpoint_path}") checkpoint_path = Path(checkpoint_path) - # asyncio_run may only appear at the trainer's synchronous boundary: - # - colocate's __init__/fit/_sync_weights_and_save are still synchronous entry points and can wrap it explicitly; - # - disaggregated's _fit already runs inside the event loop started by asyncio_run, so everything inside must use await. - # Therefore save/resume of agent_loop_manager / replay_buffer must stay async; if they called - # asyncio_run internally, save/resume would fail with a nested asyncio_run inside the disaggregated training loop. + # manager/replay_buffer stay async; asyncio_run only lives at the trainer's synchronous boundary. saved_model_step = await self.agent_loop_manager.resume(checkpoint_path) return saved_model_step @@ -919,7 +926,7 @@ def _train_one_batch( self._save_trajectories(train_batch, train_trajectory_path) self.logger.info(f"Train step {train_step} train trajectories saved to {train_trajectory_path}") - # Colocated mode must first confirm the rollout workers are recoverable, then release 
rollout, and finally onload the training workers; disaggregated mode skips these steps. + # Colocated resource switch before training: check rollout -> offload rollout -> onload train. if offload_rollout_before_train: ray.get( self.rollout_controller.ensure_workers_healthy_before_training.remote(), @@ -1133,7 +1140,6 @@ def _prepare_train_data( response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist() # The returned routed_experts excludes the eos value, which is not needed anyway, so subtract one - # TODO: is this needed for the verl tool agent loop? input_ids = prompt_ids + response_ids[:-1] prompt_len_list.append(len(prompt_ids)) @@ -1533,8 +1539,9 @@ def _check_chat_completions_with_retry(base_url: str, max_attempts: int = 5, int class RLColocateTrainer(BaseRLTrainer): _META_PATH = ".xtuner_rl_colocate_trainer" + agent_loop_manager: AgentLoopManager - # The colocated trainer keeps its own resource orchestration, resume, main loop, and weight sync; common saving and logging stay in BaseRLTrainer. + # Colocated keeps the resource-switch and weight-sync flow; common saving and logging live in BaseRLTrainer. def __init__(self, cfg: RLColocateTrainerConfig): self._init_common(cfg, meta_path=self._META_PATH, logger_tag="RLTrainer") self._num_workers = float(cfg.resources.num_workers) @@ -1579,9 +1586,7 @@ def __init__(self, cfg: RLColocateTrainerConfig): ) return - # Free trainer-side GPU memory before bringing up colocated rollout workers. - # Backends like sglang may size KV cache against their own target utilization - # instead of the trainer's transient footprint, which can cause init-time OOM.
+ # Free training GPU memory first, then start the colocated rollout workers. self.train_controller.offload(target="all") self.rollout_controller = self._rollout_config.build(self._pg) @@ -1634,7 +1639,7 @@ def fit(self): self.logger.info(f"Train step {train_step}/{self._total_train_steps} start") step_timer_dict = {} with timer("step", step_timer_dict): - # The colocated path completes rollout production and replay buffer consumption in a single call. + # Colocated completes production and consumption in a single call. self.logger.info( f"[Step {train_step}] start to generate rollout experience for train step {train_step} with model step {model_step}" ) @@ -1712,8 +1717,7 @@ def _get_colocate_rollout_model_step(self, train_step: int) -> int: return previous_step - (previous_step % self._sync_weights_interval) def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool: - """Save state and switch colocated resources back to rollout - workers.""" + """Save, then switch colocated resources back to rollout.""" should_sync_weights = train_step % self._sync_weights_interval == 0 will_evaluate = self._enable_evaluate and train_step % self._evaluate_step == 0 needs_rollout_ready = train_step < self._total_train_steps or will_evaluate @@ -1745,6 +1749,7 @@ def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool class RLDisaggregatedTrainer(BaseRLTrainer): _META_PATH = ".xtuner_rl_disaggregated_trainer" + agent_loop_manager: DisaggAgentLoopManager def __init__(self, cfg: RLDisaggregatedTrainerConfig): self._init_common(cfg, meta_path=self._META_PATH, logger_tag="RLDisaggTrainer") @@ -1765,9 +1770,7 @@ def __init__(self, cfg: RLDisaggregatedTrainerConfig): replay_buffer = cfg.replay_buffer_config.build() self._build_agent_loop_components(cfg, replay_buffer) - # In disaggregated mode, the producer and consumer run concurrently - # For them to cooperate, production must not use early stopping; otherwise too little is produced and the consumer blocks - # So should_continue_fn must be default_should_continue_fn + # The disaggregated producer must not stop early, or the consumer may wait forever for a batch. for task_runner in self.agent_loop_manager.task_runners: if task_runner.produce_strategy.should_continue_fn is not default_should_continue_fn: raise ValueError( @@ -1816,9
+1819,37 @@ def _resume_from_checkpoint(self, checkpoint_path: Path | str) -> None: asyncio_run(self.agent_loop_manager.continue_produce(model_step=saved_model_step)) def fit(self): - # Keep a synchronous fit interface externally; internally an async loop organizes producer/consumer. + # Synchronous fit externally; internally an async loop organizes producer/consumer. return asyncio_run(self._fit()) + async def _get_batch_or_raise_producer_failure( + self, + producer_task: asyncio.Task, + *, + batch_size: int, + train_step: int, + ) -> ProduceBatchResult: + # While the consumer waits for a batch, also watch the background producer; a producer exception terminates training immediately. + get_batch_task = create_task( + self.agent_loop_manager.get_batch(batch_size, train_step=train_step), + done_callbacks=[], + ) + done, _ = await asyncio.wait( + {get_batch_task, producer_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + if producer_task in done: + if not get_batch_task.done(): + get_batch_task.cancel() + await asyncio.gather(get_batch_task, return_exceptions=True) + if producer_task.cancelled(): + raise asyncio.CancelledError + if producer_task.exception() is not None: + producer_task.result() + raise RuntimeError("Disaggregated background producer exited before training finished.") + + return get_batch_task.result() + async def _fit(self): self.logger.info("Start RL disaggregated training") if self._cur_step >= self._total_train_steps: @@ -1827,16 +1858,19 @@ async def _fit(self): if self._enable_initial_evaluate: await self._run_initial_evaluate() + # The initial eval pauses rollout generation; resume it before starting the background producer. + await self.agent_loop_manager.continue_produce(model_step=self._cur_step) self._benchmark_start_time_s = time.perf_counter() self._benchmark_training_samples = 0 self._benchmark_training_tokens = 0 - # The background producer only keeps writing data to the replay buffer; the foreground trainer consumes via get_batch. + # The background producer writes the buffer; the foreground trainer takes batches. producer_task = create_task( self.agent_loop_manager.produce_loop( batch_size=self.train_batch_size, - ) + ), + done_callbacks=[], ) try: # train_step means "the next training step to complete"; an empty expired batch does not count as completed, so a while loop is needed to retry the same step. @@ -1848,8
+1882,10 @@ async def _fit(self): eval_log_info = {} with timer("step", step_timer_dict): with timer("get_batch", step_timer_dict): - produce_result = await self.agent_loop_manager.get_batch( - self.train_batch_size, train_step=train_step + produce_result = await self._get_batch_or_raise_producer_failure( + producer_task, + batch_size=self.train_batch_size, + train_step=train_step, ) if XTUNER_DETERMINISTIC: produce_result.rollout_states = sort_rollout_state_for_deterministic( @@ -1857,23 +1893,20 @@ async def _fit(self): ) train_batch = produce_result.rollout_states - # EXPIRED_BATCH 分两类:空 batch 是控制面同步;非空 batch 仍然是可训练数据。 + # 空 expired 只触发同步;非空 expired 仍需训练。 empty_expired_batch = produce_result.status == ProduceBatchStatus.EXPIRED_BATCH and not train_batch if empty_expired_batch: - # 没有完成训练,能同步的只能是上一轮已经完成的 Model Step。 sync_model_step = train_step - 1 self.logger.info( "Skip train step because rollout model is expired and a newer model already exists; " f"sync completed model_step={sync_model_step} first." ) else: - # 非空 expired 必须训练出当前 step 的新模型版本,否则 producer 没有更新权重可恢复。 assert train_batch, ( "RLDisaggregatedTrainer expects get_batch() to return non-empty rollout_states " "unless status is empty EXPIRED_BATCH." 
) - # 非共卡训练要求后台 producer 在训练当前 batch 时继续推进; - # 同步训练路径放到线程里执行,避免 ray.get / 文件写入阻塞事件循环。 + # 训练路径放到线程里执行,避免阻塞事件循环。 train_log_info = await asyncio.to_thread( self._train_one_batch, train_batch, @@ -1884,7 +1917,7 @@ async def _fit(self): ) sync_model_step = train_step - # 后续保存、同步、评测、恢复 producer 都以“已完成的 Model Step”为唯一口径。 + # 保存、同步、评测、恢复 producer 都以已完成 model_step 为口径。 need_sync = ( empty_expired_batch or produce_result.status == ProduceBatchStatus.EXPIRED_BATCH @@ -1893,9 +1926,9 @@ async def _fit(self): ) if need_sync: - # 同步前先暂停后台 producer,避免 save/sync 时还有 pending rollout 继续写 buffer。 + # 同步前暂停 producer,避免 pending rollout 继续写 buffer。 with timer("pause_produce", step_timer_dict): - await self.agent_loop_manager.pause_produce(use_global_progress=True) + await self.agent_loop_manager.pause_produce() await self._sync_weights_and_save(sync_model_step, step_timer_dict) @@ -1904,14 +1937,14 @@ async def _fit(self): and sync_model_step > 0 and sync_model_step % self._evaluate_step == 0 ): - # eval 放在恢复 producer 前,避免后台生产抢占 rollout 资源。 + # eval 在恢复 producer 前执行,避免资源抢占。 with timer("evaluation", step_timer_dict): eval_log_info.update(await self._run_evaluation(sync_model_step)) await self.agent_loop_manager.continue_produce(model_step=sync_model_step) if empty_expired_batch: - # 空 expired 没有完成训练,不能 log 成已完成 step,也不能推进 _cur_step。 + # 空 expired 不推进训练步。 continue self._log_step(train_step, step_timer_dict, produce_result, train_log_info, eval_log_info) self._cur_step = train_step @@ -1921,7 +1954,7 @@ async def _fit(self): await producer_task async def _sync_weights_and_save(self, model_step: int, step_timer_dict: dict): - # 非共卡已经在 _fit 里暂停 producer;这里保持静止态下的 save -> bind -> update 顺序。 + # producer 已暂停;保持 save -> bind -> update 顺序。 with timer("save_ckpt", step_timer_dict): await self._maybe_save_checkpoint(model_step) self._maybe_save_hf(model_step) @@ -1932,6 +1965,6 @@ async def _sync_weights_and_save(self, model_step: int, step_timer_dict: dict): self.update_weights() def 
update_weights(self): - # producer 的 pause/continue 由 AgentLoopManager 控制,避免这里提前恢复 rollout 影响 eval 顺序。 + # rollout 恢复由 AgentLoopManager 控制。 self.train_controller.update_weights() self.logger.info("Rollout workers update weights successfully in disaggregated mode")