
[DataProcessor] merge processor #7747

Open
luukunn wants to merge 11 commits into PaddlePaddle:develop from luukunn:merge_2

Conversation

luukunn (Collaborator) commented May 7, 2026

Motivation

Consolidate the VL-model multimodal processors scattered across the codebase (QwenVL, Qwen3VL, ERNIE 4.5 VL, PaddleOCR-VL) into a unified fastdeploy/input/multimodal/ directory, unifying the image/video preprocessing interfaces, reducing code duplication, and easing future extension and maintenance. Also fix the IndexError risk on empty lists in _add_request, and simplify the content parsing in parse_chat_messages so that the messages format can be passed through directly.

Modifications

  • Add the fastdeploy/input/multimodal/ package:
    • common.py: shared image-resizing utilities (smart_resize_qwen, smart_resize_paddleocr, is_scaled_image, etc.)
    • mm_processor.py: the MMProcessor base class
    • ernie4_5_vl.py: ERNIE 4.5 VL multimodal processor (with 3D position IDs and video-frame timestamp rendering)
    • qwen_vl.py, qwen3_vl.py, paddleocr_vl.py: processors for the other VL models
    • image_processors/: model-specific image preprocessors (ernie, qwen, qwen3, paddleocr)
  • Refactor fastdeploy/entrypoints/chat_utils.py: simplify parse_chat_messages; string/None content is passed through unchanged
  • Fix fastdeploy/entrypoints/llm.py (see the sketch after this list):
    • _add_request gains a list[dict] (messages-format) branch and fixes the empty-list IndexError (adds a len > 0 guard)
    • The chat() method passes messages directly instead of wrapping them in a {"messages": ...} dict
  • Add unit tests: coverage for each VL processor under tests/input/multimodal/
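A minimal sketch of the guard and dispatch described above; the function name and return labels are illustrative only, not the exact FastDeploy implementation:

from typing import Union

def _classify_prompts(prompts: Union[str, dict, list]) -> str:
    """Illustrative dispatch only; the real _add_request does much more."""
    # Check len(prompts) > 0 BEFORE touching prompts[0], so an empty list
    # no longer raises IndexError (the bug this PR fixes).
    if isinstance(prompts, list) and len(prompts) > 0 and isinstance(prompts[0], dict):
        if "role" in prompts[0]:
            return "messages"  # chat-style list[dict], passed through directly
        return "batch"         # list of {"prompt": ...} request dicts
    if isinstance(prompts, (str, dict)):
        return "single"
    return "empty_or_other"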

Usage or Command

N/A

Accuracy Tests

N/A

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings May 7, 2026 15:21
paddle-bot (Bot) commented May 7, 2026

Thanks for your contribution!

Copilot AI (Contributor) left a comment


Pull request overview

This PR merges and refactors the input preprocessing/data processing pipeline: it introduces a unified Processor (text plus optional multimodal), adds a set of multimodal implementations based on the MMProcessor abstract class, and adjusts the LLM chat argument path and the related unit tests accordingly.

Changes:

  • Add fastdeploy/input/processor.py as the unified entry point for request preprocessing and response decoding, and replace the original Text/MM processor creation logic in InputPreprocessor.
  • Add the multimodal submodule fastdeploy/input/multimodal/* (the MMProcessor abstract class, the Qwen/Ernie/PaddleOCR processors, and the image processors).
  • Adjust how LLM.chat() passes prompts/messages, plus some test cases.

PR title/description check (needs completion)

  • The title should follow the convention, with a space and a clearer verb, e.g. [DataProcessor] Merge processors / [DataProcessor] Merge input processor
  • The PR description is still the template and lacks Motivation / Modifications / Usage / Accuracy Tests; this PR touches a lot of code, so please fill these in, explain compatibility and migration impact, and update the relevant docs where needed.

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 7 comments.

Show a summary per file
| File | Description |
| --- | --- |
| tests/input/test_preprocess.py | Adapt the test for Processor replacing the original TextProcessor creation logic |
| tests/entrypoints/test_generation.py | Adjust the chat consistency test (currently has a coverage-regression problem) |
| tests/entrypoints/test_chat.py | Hook process_messages to capture the assembled prompt |
| fastdeploy/input/processor.py | New unified Processor (text/multimodal entry point) |
| fastdeploy/input/preprocess.py | InputPreprocessor now creates the unified Processor and attaches mm_processor |
| fastdeploy/entrypoints/llm.py | Adjust chat argument passing; _add_request supports messages/batch forms |
| fastdeploy/entrypoints/chat_utils.py | Adjust content normalization in parse_chat_messages (currently has a compatibility problem) |
| fastdeploy/input/multimodal/mm_processor.py | New MMProcessor abstract base class and multimodal processing template flow (field contract currently inconsistent with the engine side) |
| fastdeploy/input/multimodal/qwen_vl.py | New Qwen2.5-VL multimodal processor |
| fastdeploy/input/multimodal/qwen3_vl.py | New Qwen3-VL multimodal processor |
| fastdeploy/input/multimodal/paddleocr_vl.py | New PaddleOCR-VL multimodal processor |
| fastdeploy/input/multimodal/ernie_vl.py | New ERNIE4.5-VL multimodal processor |
| fastdeploy/input/multimodal/common.py | New shared multimodal resize utilities |
| fastdeploy/input/multimodal/__init__.py | Export the multimodal processors |
| fastdeploy/input/multimodal/image_processors/* | New/reorganized Qwen/Qwen3/PaddleOCR image processors and exports |
| fastdeploy/input/multimodal/image_processors/__init__.py | Aggregate the image processor exports |
Comments suppressed due to low confidence (1)

fastdeploy/entrypoints/chat_utils.py:209

  • parse_chat_messages now returns None/str directly when content is None or content is a str, which breaks the downstream assumption that content is a list[dict] (for example, MultiModalProcessor._extract_mm_items calls item.get(...) directly, so a str raises AttributeError). It would be safer to keep the return shape stable: None -> [], str -> [{"type": "text", "text": ...}], and only parse into a multimodal part list when the original is a list.
        role = message["role"]
        content = message["content"]

        if content is None:
            parsed_content = content
        elif isinstance(content, str):
            parsed_content = content
        else:
            parsed_content = [parse_content_part(mm_parser, part) for part in content]

        conversation.append({"role": role, "content": parsed_content})
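A minimal sketch of the normalization this comment suggests, assuming downstream consumers always expect a list of parts (the helper name is hypothetical):

def normalize_content(content):
    # Keep the return shape stable for downstream consumers.
    if content is None:
        return []
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    # Already a list of parts; each part would still go through the mm parser.
    return content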

Comment on lines 129 to 133
        for prompt_token_ids in self.TOKEN_IDS:
            with self.subTest(prompt_token_ids=prompt_token_ids):
                output1 = self.llm.chat(messages=[prompt_token_ids], sampling_params=sampling_params)
                output2 = self.llm.chat(
                    [{"prompt": "", "prompt_token_ids": prompt_token_ids}], sampling_params=sampling_params
                )
                output2 = self.llm.chat(messages=[prompt_token_ids], sampling_params=sampling_params)
                self.assert_outputs_equal(output1, output2)
"video_cnt": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
"mm_positions": [],
Comment on lines +390 to +409
        hashes_to_cache, items_to_cache = [], []
        for idx, item in enumerate(all_items):
            # Items fetched from cache (data is tuple) should not be re-cached
            if isinstance(item.data, tuple):
                continue
            # Build pixel_values and meta for this item
            if outputs["images"] is None or idx >= len(outputs["images"]):
                continue
            pixel_values = outputs["images"][idx]
            # Compute hash: prefer uuid, fallback to content hash
            cache_key = item.uuid if item.uuid else MultimodalHasher.hash_features(pixel_values)
            meta = {}
            if idx < len(outputs.get("grid_thw", []) or []):
                grid_thw = np.asarray(outputs["grid_thw"][idx]) if outputs["grid_thw"] is not None else None
                if grid_thw is not None:
                    if grid_thw.ndim > 1:
                        t_val, h, w = grid_thw[0]
                    else:
                        t_val, h, w = grid_thw
                    meta["thw"] = (int(t_val), int(h), int(w))
Comment on lines +460 to +470
def process_messages(self, request):
    """Convert the messages format into prompt + multimodal_data (generic, model-agnostic).

    Responsibilities:
    1. Extract multimodal content (images/videos) from messages
       → write request["multimodal_data"] = {"image": [...], "video": [...], "mm_order": [...]}
    2. Call tokenizer.apply_chat_template(messages) to assemble the prompt
       → write request["prompt"]

    Invoked when: request contains "messages" and has no "prompt"/"prompt_token_ids" yet.
    """
Comment on lines +616 to +618
        for seq in stop_sequences:
            if seq != self.tokenizer.eos_token_id:
                stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
Comment on lines +754 to +764
            if prompt_token_ids[0] > self.tokenizer.vocab_size:
                if not add_prefix_space:
                    log_request(
                        level=1,
                        message="bad_words: '{prompt}' token id {token_id} > vocab_size, skipping",
                        prompt=prompt,
                        token_id=prompt_token_ids[0],
                    )
                continue
            if prompt_token_ids not in token_ids:
                token_ids.extend(prompt_token_ids)
Comment on lines +769 to +770
        if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
            return self.tokenizer.eos_token
PaddlePaddle-bot: This comment was marked as outdated.

PaddlePaddle-bot commented May 7, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-09 19:59:35

The CI report is generated from the following code (refreshed every 30 minutes):


1 Task overview

1 required task failed (Approval) and 2 required tasks are still running; please wait for CI to finish and handle the approval.

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Waiting | Skipped |
| --- | --- | --- | --- | --- | --- | --- |
| 38(0) | 38 | 28 | 3 | 2 | 2 | 2 |

2 Task status summary

2.1 Required tasks: 5/10 passed

Required tasks block merging; failures must be handled first.

| Status | Task | Duration | Root cause | Suggested fix | Log | Rerun |
| --- | --- | --- | --- | --- | --- | --- |
| ❌ | Approval | 10s | Environment issue: the PR has not been approved by the required reviewers (exit code 6) | Have a maintainer with approve permission approve the PR | Job | - |
| ⏳ | Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage | - | Running | - | Job | - |
| ⏳ | Extracted partial CE model tasks to run in CI. / run_ce_cases | - | Running | - | Job | - |
| ✅ | The remaining 5 required tasks passed | - | - | - | - | - |

2.2 Optional tasks — 23/28 passed

Optional tasks do not block merging; failures are informational only.

| Status | Task | Duration | Log | Rerun |
| --- | --- | --- | --- | --- |
| ❌ | Check PR Template | 12s | Job | - |
| ❌ | Trigger Jenkins for PR | 16m4s | Job | - |
| ⏸️ | CI_HPU | - | - | - |
| ⏸️ | Run iluvatar Tests / run_iluvatar_cases | - | - | - |
| ✅ | The remaining 23 optional tasks passed | - | - | - |

3 Failure details (required only)

Approval — infrastructure (confidence: medium)

Approval

  • Status: ❌ Failed
  • Error type: infrastructure
  • Confidence: medium
  • Root cause summary: the PR has not been approved by the required reviewers; exit code 6
  • Analyzer: generic analysis (fallback)

Root cause details:
The Approval workflow is a fast CI gate that checks the PR's approval status (it runs in only 10 seconds). Exit code 6 means the PR has not yet been approved by the required reviewers, so the CI system refuses to proceed. This is unrelated to the PR code itself; it is a process review gate.

Key log:

[FAILURE]: Process completed with exit code 6.

Fix suggestions:

  1. Have a Maintainer/Reviewer with approve permission click "Approve" on the PR page
  2. Once approved, CI will automatically re-trigger the workflow

Fix summary: have a maintainer with approve permission approve this PR

Link: view log

PaddlePaddle-bot: This comment was marked as outdated.

codecov-commenter commented May 7, 2026

Codecov Report

❌ Patch coverage is 77.92350% with 404 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@d70f33d). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| fastdeploy/input/processor.py | 41.35% | 245 Missing and 40 partials ⚠️ |
| fastdeploy/input/multimodal/qwen_vl.py | 85.40% | 25 Missing and 9 partials ⚠️ |
| fastdeploy/input/multimodal/ernie4_5_vl.py | 92.01% | 18 Missing and 3 partials ⚠️ |
| fastdeploy/input/multimodal/mm_processor.py | 93.33% | 9 Missing and 8 partials ⚠️ |
| fastdeploy/input/multimodal/paddleocr_vl.py | 86.40% | 13 Missing and 4 partials ⚠️ |
| ...tdeploy/input/multimodal/image_processors/ernie.py | 91.97% | 8 Missing and 5 partials ⚠️ |
| fastdeploy/input/preprocess.py | 54.54% | 10 Missing ⚠️ |
| ...stdeploy/input/multimodal/image_processors/qwen.py | 93.02% | 3 Missing and 3 partials ⚠️ |
| fastdeploy/entrypoints/chat_utils.py | 50.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7747   +/-   ##
==========================================
  Coverage           ?   71.75%           
==========================================
  Files              ?      409           
  Lines              ?    57510           
  Branches           ?     9009           
==========================================
  Hits               ?    41268           
  Misses             ?    13419           
  Partials           ?     2823           
| Flag | Coverage Δ |
| --- | --- |
| GPU | 71.75% <77.92%> (?) |

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Copilot AI review requested due to automatic review settings May 9, 2026 07:40
Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 23 out of 24 changed files in this pull request and generated 12 comments.

Comment thread fastdeploy/input/processor.py
Comment thread fastdeploy/input/processor.py
Comment thread fastdeploy/input/processor.py
Comment thread fastdeploy/input/processor.py (lines +460 to +470, process_messages docstring, quoted above)
Comment thread fastdeploy/input/multimodal/mm_processor.py
Comment thread fastdeploy/input/multimodal/qwen_vl.py
Comment thread fastdeploy/input/multimodal/qwen_vl.py
Comment thread fastdeploy/input/multimodal/qwen_vl.py
Comment on lines +92 to +105
    def get(self, hashes: list) -> list:
        """Retrieve cached multimodal data by hash list."""
        req = pickle.dumps(hashes)
        self.socket.send_multipart([b"", req])
        _, resp = self.socket.recv_multipart()
        items = pickle.loads(resp)
        data_processor_logger.info(f"Get cache of mm_hashes: {hashes}")
        return items

    def put(self, hashes: list, items: list) -> None:
        """Write processed multimodal items to cache."""
        req = pickle.dumps((hashes, items))
        self.socket.send_multipart([b"", req])
        data_processor_logger.info(f"Update cache of mm_hashes: {hashes}")
Comment on lines 125 to 133
    def test_consistency_single_prompt_tokens_chat(self):
        """Test consistency between different prompt input formats"""
        sampling_params = SamplingParams(temperature=1.0, top_p=0.0)

        for prompt_token_ids in self.TOKEN_IDS:
            with self.subTest(prompt_token_ids=prompt_token_ids):
                output1 = self.llm.chat(messages=[prompt_token_ids], sampling_params=sampling_params)
                output2 = self.llm.chat(
                    [{"prompt": "", "prompt_token_ids": prompt_token_ids}], sampling_params=sampling_params
                )
                output2 = self.llm.chat(messages=[prompt_token_ids], sampling_params=sampling_params)
                self.assert_outputs_equal(output1, output2)
PaddlePaddle-bot: This comment was marked as outdated.

PaddlePaddle-bot: This comment was marked as outdated.

Copilot AI review requested due to automatic review settings May 9, 2026 11:35
Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 4 comments.

Comment on lines +229 to +233
        # Interleaved type order: directly from mm_data, or default images-then-videos.
        mm_order = mm_data.get("mm_order")
        if not mm_order:
            mm_order = ["image"] * len(images) + ["video"] * len(videos)

Comment thread fastdeploy/input/processor.py (lines +460 to +470, process_messages docstring, quoted above)
Comment on lines 125 to 133
    def test_consistency_single_prompt_tokens_chat(self):
        """Test consistency between different prompt input formats"""
        """Test deterministic output for prompt_token_ids via chat interface"""
        sampling_params = SamplingParams(temperature=1.0, top_p=0.0)

        for prompt_token_ids in self.TOKEN_IDS:
            with self.subTest(prompt_token_ids=prompt_token_ids):
                output1 = self.llm.chat(messages=[prompt_token_ids], sampling_params=sampling_params)
                output2 = self.llm.chat(
                    [{"prompt": "", "prompt_token_ids": prompt_token_ids}], sampling_params=sampling_params
                )
                output2 = self.llm.chat(messages=[prompt_token_ids], sampling_params=sampling_params)
                self.assert_outputs_equal(output1, output2)

    def pack_position_ids(self, outputs):
        """Qwen: concatenate 3xN arrays, then transpose to Nx3."""
        outputs["position_ids"] = np.concatenate(outputs["position_ids"], axis=1, dtype=np.int64)
PaddlePaddle-bot left a comment


🤖 Paddle-CI-Agent | pr_review | 2026-05-09 19:46:12

📋 Review summary

PR overview: consolidates the VL-model multimodal processors (QwenVL, Qwen3VL, ERNIE 4.5 VL, PaddleOCR-VL) into the unified fastdeploy/input/multimodal/ directory, fixes the empty-list IndexError in _add_request, and simplifies the parse_chat_messages logic.
Scope of change: fastdeploy/input/multimodal/, fastdeploy/entrypoints/
Impact tags: [DataProcessor] [APIServer]


📝 PR convention check

The PR title is missing a space and is too terse; the Motivation, Modifications, Usage or Command, and Accuracy Tests sections of the PR body are empty, and no Checklist item is checked, which does not meet the template requirements.

Title suggestion (copy-paste ready):

  • [DataProcessor] Merge multimodal processors into fastdeploy/input/multimodal/

PR description suggestion (copy-paste ready):

## Motivation
Consolidate the VL-model multimodal processors scattered across the codebase (QwenVL, Qwen3VL, ERNIE 4.5 VL, PaddleOCR-VL) into a unified `fastdeploy/input/multimodal/` directory, unifying the image/video preprocessing interfaces, reducing code duplication, and easing future extension and maintenance. Also fix the IndexError risk on empty lists in `_add_request` and simplify the content parsing in `parse_chat_messages` to support passing the messages format directly.

## Modifications
- Add the `fastdeploy/input/multimodal/` package:
  - `common.py`: shared image-resizing utilities (`smart_resize_qwen`, `smart_resize_paddleocr`, `is_scaled_image`, etc.)
  - `mm_processor.py`: the MMProcessor base class
  - `ernie4_5_vl.py`: ERNIE 4.5 VL multimodal processor (with 3D position IDs and video-frame timestamp rendering)
  - `qwen_vl.py`, `qwen3_vl.py`, `paddleocr_vl.py`: processors for the other VL models
  - `image_processors/`: model-specific image preprocessors (ernie, qwen, qwen3, paddleocr)
- Refactor `fastdeploy/entrypoints/chat_utils.py`: simplify `parse_chat_messages`; string/None content is passed through unchanged
- Fix `fastdeploy/entrypoints/llm.py`:
  - `_add_request` gains a `list[dict]` (messages-format) branch and fixes the empty-list IndexError (adds a `len > 0` guard)
  - The `chat()` method passes messages directly instead of wrapping them in a `{"messages": ...}` dict
- Add unit tests: coverage for each VL processor under `tests/input/multimodal/`

## Usage or Command
N/A

## Accuracy Tests
N/A

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Issues

| Level | File | Summary |
| --- | --- | --- |
| 🟡 Suggestion | fastdeploy/input/multimodal/image_processors/ernie.py:148 | set_pixels uses assert for runtime validation of user input, which is disabled under -O |
| 🟡 Suggestion | fastdeploy/input/multimodal/image_processors/ernie.py:209 | _preprocess uses assert to validate argument matching; same issue |
| ❓ Question | fastdeploy/entrypoints/chat_utils.py:205 | For content is None, parsed_content changes from [] to None; downstream compatibility needs confirmation |
| 🟡 Suggestion | fastdeploy/input/multimodal/common.py:130 | smart_resize_paddleocr lacks the post-hoc validity check that smart_resize_qwen has |

Overall assessment

The PR is well structured, the processor consolidation direction is sound, and unit test coverage is thorough. The main concerns are the two uses of assert for runtime validation (recommended: raise ValueError instead) and the changed interface semantics for None content in chat_utils.py, whose downstream compatibility needs confirmation.


    def set_pixels(self, min_pixels=None, max_pixels=None, msg=""):
        if min_pixels is not None:
            assert isinstance(min_pixels, int) and min_pixels >= 0, "min_pixels must be positive int"

🟡 Suggestion: set_pixels uses assert for runtime validation of user input.

asserts are skipped when Python runs with the -O flag, so invalid arguments would pass through silently; use an explicit raise ValueError instead:

if not (isinstance(min_pixels, int) and min_pixels >= 0):
    raise ValueError("min_pixels must be a non-negative int")


        if predetermined_grid_thw is not None:
            assert len(predetermined_grid_thw) == len(
                images

🟡 Suggestion: _preprocess likewise uses assert to check that predetermined_grid_thw and images have matching lengths.

Suggested change:

if len(predetermined_grid_thw) != len(images):
    raise ValueError(
        f"len(predetermined_grid_thw) {len(predetermined_grid_thw)} != len(images) {len(images)}"
    )

            parsed_content = content
        elif isinstance(content, str):
            parsed_content = [{"type": "text", "text": content}]
            parsed_content = content

❓ Question: previously, content is None produced parsed_content = [] (an empty list); now it produces parsed_content = content (i.e. None).

If any downstream code iterates over parsed_content or calls len() on it, this will raise a TypeError. Please confirm that every consumer path tolerates None, or change this to parsed_content = [] to stay backward compatible.

    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

🟡 Suggestion: smart_resize_paddleocr does not validate the result before its final return.

smart_resize_qwen checks before returning:

if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
    raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")

The PaddleOCR version lacks this check: if h_bar * w_bar is still out of range after the floor/ceil adjustments, it silently returns an invalid size. Add the same post-condition check.
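For reference, a standalone sketch of where the suggested post-condition would sit. The constants and rounding follow the common Qwen-style smart_resize pattern and are assumptions, not the exact FastDeploy code:

import math

def smart_resize_sketch(height, width, factor=28,
                        min_pixels=56 * 56, max_pixels=28 * 28 * 1280):
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    # The post-condition the review asks for, mirroring smart_resize_qwen:
    if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
        raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")
    return h_bar, w_bar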

