[Refactor] Unify agent loop output protocol on AgentMessage#1904
Conversation
…]; judger / _fill_rollout_state read finish_info / content from it directly.
- Move tool_turns counting into agent_in_localhost_loop (single source); trainer / evaluator just read extra_fields[agent_tool_turns] and emit tool_turns/{mean,min,max} to tensorboard.
- Trajectory dump: drop redundant fields (final_assistant_finish_reason, raw_prompt, lengths.input_tokens) and helper chain; messages/tools/tool_turns moved under agent.* sub-key. Switch to one-line JSONL records.
- Drop _to_json_safe / _dump_jsonl_record wrappers; explicit model_dump only where needed (StageRecord). Remove dead isinstance(Tensor) branches.
- Rollout: add EADDRINUSE retry for session_server and worker launch; add sample_timeout_s; disable replay buffer checkpoint by default.
- Add asyncio_diagnostics module (USR1 task dump + asyncio_run watchdog) for prod deadlock troubleshooting.
- Drop local _TokenBucket (use lagent.utils.rate_limiter.FairAsyncTokenBucket via API-level rate limiting in tools).
d6fcac7 to
c21ce32
Compare
|
@claude review |
| try: | ||
| # reward_model stays as-is (dataset-provided ground_truth/style etc.). | ||
| # Per-rollout artifacts (response message, agent trace) flow through extra_fields. | ||
| reward_model = dict(item.reward_model or {}) | ||
| segment = item.artifacts["messages"][-1] | ||
| response_message = item.artifacts.get("response_message") or {} | ||
| content = response_message.get("content") | ||
| response = content if isinstance(content, str) else (str(content) if content is not None else "") | ||
|
|
||
| messages = item.artifacts["messages"][-1]["messages"] | ||
| tool_turns = sum( | ||
| 1 for message in messages if isinstance(message.get("tool_calls"), list) and message["tool_calls"] | ||
| ) | ||
| reward_model.setdefault("agent_trace", messages) | ||
| reward_model.setdefault("num_turns", tool_turns) | ||
|
|
||
| response = str(item.artifacts.get("response") or "") | ||
| rollout_state = RolloutState( | ||
| message=[{"role": "user", "content": item.instruction}], | ||
| response=response, | ||
| reward_model=reward_model, | ||
| extra_fields={ | ||
| "agent_messages": segment["messages"], | ||
| "response_message": response_message, | ||
| }, | ||
| status=Status.COMPLETED, |
There was a problem hiding this comment.
Claude: Warning — Breaking change to judger contract.
Previously, reward_model.setdefault("agent_trace", messages) and reward_model.setdefault("num_turns", tool_turns) made conversation history available to custom judger implementations via rollout_state.reward_model. This PR moves that data to extra_fields["agent_messages"] and extra_fields["response_message"].
Any custom/external judger that reads rollout_state.reward_model["agent_trace"] or rollout_state.reward_model["num_turns"] will silently receive None (dict lookup miss) instead of the expected data, causing incorrect reward computation without raising an error.
Consider preserving the setdefault calls alongside the new extra_fields plumbing, or documenting this as a breaking protocol change in the PR description.
| ) | ||
| return self._fail_rollout_state( | ||
| rollout_state, |
There was a problem hiding this comment.
Claude: Warning — finish_reason semantic expansion affects ProduceBatchResult.
Previously, finish_reason was either "stop" or "error". This PR introduces "timeout" as a third value (from the new asyncio.TimeoutError handler) and also propagates arbitrary values from response_message.get("finish_reason") — effectively making finish_reason open-ended.
Downstream consumers that switch on finish_reason == "error" to classify failures will no longer catch timeouts. If ProduceBatchResult accounting or retry logic depends on enumerating finish reasons, this is a behavioral change.
ProduceBatchResult impact: finish_reason can now be "timeout" or any string from the LLM response, changing how failed/completed samples are classified by downstream consumers.
| ) | ||
|
|
||
|
|
||
| def _extract_reward_payload(item: AgentRolloutItem) -> dict[str, Any] | None: | ||
| for record in item.judgers.values(): | ||
| reward = record.metadata.get("reward") | ||
| if isinstance(reward, dict): | ||
| payload = dict(reward) | ||
| if item.reward is not None: | ||
| payload.setdefault("score", item.reward) | ||
| return payload | ||
| if item.reward is not None: | ||
| return {"score": item.reward} |
There was a problem hiding this comment.
Claude: Warning — _extract_reward_payload changes the shape of rollout_state.reward.
Previously: {"score": item.reward} (always single key).
Now: may include additional keys from record.metadata["reward"] (the full judger payload), plus "score" set via setdefault.
Downstream code that does data.reward["score"] will still work. But if any consumer iterates over data.reward keys or serializes the full dict (e.g. trajectory files, metrics), it will now see extra judger-specific keys. The trajectory dump already writes "reward_payload": data.reward, which captures this expanded dict — verify this is intentional and that downstream JSONL parsers tolerate variable-width reward dicts.
| response_message = output.model_dump(mode="json") if hasattr(output, "model_dump") else None | ||
| if response_message is None: | ||
| raise TypeError("Agent forward must return an AgentMessage-like object.") |
There was a problem hiding this comment.
Claude: Warning — Hard failure on non-AgentMessage output.
Previously, the stage was permissive — output.content if hasattr(output, "content") else output handled plain strings. Now anything without model_dump raises TypeError. If any configured agent returns a raw string or dict (e.g. a simpler agent or a mock in tests), this will crash the entire sample.
If backward-compatibility with non-AgentMessage outputs is needed, consider a softer fallback (e.g. constructing a minimal response_message dict from the raw output).
| raise TypeError("Agent messages trace segment.messages must be a list.") | ||
| item.artifacts["messages"] = messages | ||
| result = StageResult(stdout=item.artifacts["response"], return_code=0) | ||
| content = response_message.get("content") |
There was a problem hiding this comment.
Claude: Nit — The content extraction pattern content if isinstance(content, str) else (str(content) if content is not None else "") appears 4 times across this PR (here, agent_in_localhost_loop.py:249 and :278, and judger.py:50). Consider a small helper to keep the semantics in one place.
|
Claude: ## Summary Solid refactoring that unifies the agent loop output protocol around the full IssuesWarning
Nit
VerdictCOMMENT — The protocol changes are intentional per the PR description but warrant explicit documentation of the breaking judger contract change. No correctness bugs found. |
| ground_truth = None | ||
| if data.reward_model is not None: | ||
| ground_truth = data.reward_model.get("ground_truth") | ||
| response_text_len = len(self.tokenizer.encode(response or "", add_special_tokens=False)) |
There was a problem hiding this comment.
response_text_len 和 response_train_len 区别是啥,都是 id 长度吧?
| message=[{"role": "user", "content": item.instruction}], | ||
| response=response, | ||
| reward_model=reward_model, | ||
| extra_fields={ |
| "total_len": len(rewards), | ||
| } | ||
| json.dump(summary, f, ensure_ascii=False, indent=2) | ||
| json.dump(summary, f, ensure_ascii=False, separators=(",", ":")) |
| "status": data.status.value if hasattr(data.status, "value") else str(data.status), | ||
| "finish_reason": data.finish_reason, | ||
| "error_msg": data.error_msg, | ||
| "prompt": data.message, |
There was a problem hiding this comment.
raw_prompt 要保留,用于确认对话模板是否正确
| "total_len": len(rewards), | ||
| } | ||
| json.dump(summary, f, ensure_ascii=False, indent=2) | ||
| json.dump(summary, f, ensure_ascii=False, separators=(",", ":")) |
| "response_len": response_len, | ||
| "lengths": { | ||
| "num_tokens": data.num_tokens, | ||
| "response_train_tokens": response_train_len, |
There was a problem hiding this comment.
既然都放到 lenths 里面了,那外面的response_len就去掉,而且 response_train_tokens 和 response_text_tokens区别是啥?我跑了下发现是一样的
- Drop redundant `lengths` nested struct: response_train_tokens duplicates
response_len, response_text_tokens is misleading in agent loops (single_turn
matches response_train but localhost only covers the last turn). Keep one
definition: response_len = len(response_ids).
- Promote prompt length to top-level as `prompt_len` (was `lengths.num_tokens`,
which actually held len(prompt_ids) per local_run.py).
- Unify second trajectory dump path with the first (response_len = len(response_ids)
instead of encode(response); behaviour matches in eval where response_ids is
None and _get_trajectory_response_ids falls back to encode(data.response)).
- Drop dead extra_fields["raw_prompt"] write in agent_in_localhost_loop and
the matching dead .get("raw_prompt") read in the trajectory item — no
reader remained after the earlier trajectory cleanup.
…#1904) * - Stage surfaces full AgentMessage dump as artifacts[response_message]; judger / _fill_rollout_state read finish_info / content from it directly. - Move tool_turns counting into agent_in_localhost_loop (single source); trainer / evaluator just read extra_fields[agent_tool_turns] and emit tool_turns/{mean,min,max} to tensorboard. - Trajectory dump: drop redundant fields (final_assistant_finish_reason, raw_prompt, lengths.input_tokens) and helper chain; messages/tools/tool_turns moved under agent.* sub-key. Switch to one-line JSONL records. - Drop _to_json_safe / _dump_jsonl_record wrappers; explicit model_dump only where needed (StageRecord). Remove dead isinstance(Tensor) branches. - Rollout: add EADDRINUSE retry for session_server and worker launch; add sample_timeout_s; disable replay buffer checkpoint by default. - Add asyncio_diagnostics module (USR1 task dump + asyncio_run watchdog) for prod deadlock troubleshooting. - Drop local _TokenBucket (use lagent.utils.rate_limiter.FairAsyncTokenBucket via API-level rate limiting in tools). * [Refactor] Simplify trajectory dump length fields - Drop redundant `lengths` nested struct: response_train_tokens duplicates response_len, response_text_tokens is misleading in agent loops (single_turn matches response_train but localhost only covers the last turn). Keep one definition: response_len = len(response_ids). - Promote prompt length to top-level as `prompt_len` (was `lengths.num_tokens`, which actually held len(prompt_ids) per local_run.py). - Unify second trajectory dump path with the first (response_len = len(response_ids) instead of encode(response); behaviour matches in eval where response_ids is None and _get_trajectory_response_ids falls back to encode(data.response)). - Drop dead extra_fields["raw_prompt"] write in agent_in_localhost_loop and the matching dead .get("raw_prompt") read in the trajectory item — no reader remained after the earlier trajectory cleanup.
Uh oh!
There was an error while loading. Please reload this page.