feat: add multimodal (image) support for Kimi-K2.5/K2.6 training #9479

Draft
zhihanliu-collab wants to merge 9 commits into
modelscope:mainfrom
zhihanliu-collab:feat/kimi-k26-multimodal

@zhihanliu-collab

Summary

Fixes #9469 — adds image training support for moonshotai/Kimi-K2.5 and moonshotai/Kimi-K2.6 in KimiK25Template.

Changes in swift/template/templates/moonshot.py:

  • replace_tag: Returns the standard Kimi media token sequence <|media_start|>image<|media_content|><|media_pad|><|media_end|> for image media type; raises ValueError for unsupported types (video/audio) with a helpful message.

  • _encode: Wraps PIL images as [{'type': 'image', 'image': img}] dicts (the format expected by KimiK25ImageProcessor.preprocess), obtains pixel_values and grid_thws, reads merge_kernel_size from image_processor.media_proc_cfg, then expands each <|media_pad|> placeholder to the correct number of tokens: grid_thws[i].prod() // (kH * kW).

  • _data_collator_mm_data: Concatenates per-sample grid_thws tensors along dim 0 for batched training.

  • _post_encode: Extracts vision features via model._extract_image_features(pixel_values, grid_thws), concatenates the per-image feature tensors, and fills the pre-expanded <|media_pad|> positions in inputs_embeds using a boolean mask on model.config.media_placeholder_token_id. Includes a DeepSpeed dummy forward pass to keep all parameters in the computation graph.
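
As a rough illustration of the `replace_tag` behaviour described in the first bullet, the sketch below returns the media token sequence for images and rejects other media types. It is a simplification: `kimi_media_tag` is a hypothetical standalone function, the error text is invented, and the real method in ms-swift takes additional arguments and is a member of `KimiK25Template`.

```python
# Placeholder sequence quoted in the PR description.
MEDIA_TAG = '<|media_start|>image<|media_content|><|media_pad|><|media_end|>'


def kimi_media_tag(media_type: str) -> str:
    """Return the Kimi placeholder sequence for one media item.

    Hypothetical helper: the real replace_tag has a different signature.
    """
    if media_type == 'image':
        return MEDIA_TAG
    # Video and audio are out of scope for this PR, so fail loudly.
    raise ValueError(
        f'Unsupported media type {media_type!r}; only images are supported here.')


print(kimi_media_tag('image'))
```

Raising early keeps an unsupported sample from silently producing a prompt with no placeholder to expand later.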

Implementation notes

The KimiK25ImageProcessor returns grid_thws of shape (N, 3) (T, H, W in patch units), unlike Kimi-VL which returns image_grid_hws of shape (N, 2) (H, W). Token count per image is T * H * W / (kH * kW) — for static images T=1, so this reduces to H * W / 4 with the default merge_kernel_size=(2,2).
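
The token-count formula above can be checked with a few lines of plain Python. This is only a sketch: `num_media_tokens` is a hypothetical helper name, and the grid `(1, 28, 28)` is a made-up example rather than an output of the real processor.

```python
from math import prod


def num_media_tokens(grid_thw, merge_kernel_size=(2, 2)):
    """Number of placeholder tokens for one image: T * H * W // (kH * kW)."""
    k_h, k_w = merge_kernel_size
    return prod(grid_thw) // (k_h * k_w)


# A static image (T=1) with a 28 x 28 patch grid and the default 2 x 2 merge.
print(num_media_tokens((1, 28, 28)))  # 196
```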

The merge is done by pre-expanding tokens in _encode (same pattern as KimiVLTemplate) rather than calling the model's _merge_input_ids_with_image_features (which assumes single-token placeholders and changes sequence length), keeping the implementation consistent with the rest of ms-swift's template framework.
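
A minimal sketch of the pre-expansion idea, using plain Python lists instead of tensors. The function name `expand_media_pads` and the token id `9` are assumptions made for illustration; the actual `_encode` may also have to expand labels and other per-token fields in the same way, which is omitted here.

```python
def expand_media_pads(input_ids, pad_id, token_counts):
    """Replace the i-th pad_id in input_ids with token_counts[i] copies of it."""
    num_pads = sum(1 for tok in input_ids if tok == pad_id)
    if num_pads != len(token_counts):
        raise ValueError(
            f'{num_pads} placeholders but {len(token_counts)} images')
    counts = iter(token_counts)
    expanded = []
    for tok in input_ids:
        if tok == pad_id:
            # One placeholder becomes as many tokens as the image has features.
            expanded.extend([pad_id] * next(counts))
        else:
            expanded.append(tok)
    return expanded


# Two images needing 3 and 2 tokens; 9 stands in for the <|media_pad|> id.
print(expand_media_pads([1, 9, 2, 9, 3], pad_id=9, token_counts=[3, 2]))
# [1, 9, 9, 9, 2, 9, 9, 3]
```

Because the sequence already has its final length after this step, the later embedding fill can overwrite positions in place without changing any shapes.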

Implements replace_tag, _encode, _data_collator_mm_data, and
_post_encode in KimiK25Template to enable image training for
moonshotai/Kimi-K2.5 and moonshotai/Kimi-K2.6.

Key implementation details:
- replace_tag returns the standard Kimi media token sequence for images
- _encode calls KimiK25ImageProcessor.preprocess() with {'type':'image',...}
  dicts, expands <|media_pad|> placeholders based on grid_thws and
  merge_kernel_size from the processor config
- _data_collator_mm_data concatenates grid_thws tensors across batch items
- _post_encode extracts vision features via model._extract_image_features
  and fills pre-expanded placeholder positions with the concatenated
  feature tensors; includes DeepSpeed dummy-forward for parameter sync

Closes modelscope#9469
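
The fill step in `_post_encode` can be pictured with plain lists standing in for tensors. This is a sketch under stated assumptions: `fill_media_positions` is a hypothetical name, each "embedding" is a short list of floats, and the PR itself does this with a boolean mask on tensors rather than a Python loop.

```python
def fill_media_positions(input_ids, embeds, pad_id, features):
    """Overwrite the embedding at every pad_id position with the next feature row."""
    positions = [i for i, tok in enumerate(input_ids) if tok == pad_id]
    if len(positions) != len(features):
        # Pre-expansion in _encode is what makes these two counts agree.
        raise ValueError(
            f'{len(positions)} placeholder tokens but {len(features)} feature rows')
    filled = list(embeds)  # copy so the caller's list is left untouched
    for pos, feat in zip(positions, features):
        filled[pos] = feat
    return filled


ids = [1, 9, 9, 2]  # 9 stands in for the <|media_pad|> id
text_embeds = [[0.1], [0.0], [0.0], [0.2]]
image_feats = [[5.0], [6.0]]
print(fill_media_positions(ids, text_embeds, pad_id=9, features=image_feats))
# [[0.1], [5.0], [6.0], [0.2]]
```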

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds image support to the KimiK25Template in swift/template/templates/moonshot.py by implementing image tag replacement, encoding, data collation, and post-encoding embedding replacement. The review feedback suggests improving the robustness of the implementation by using safer attribute retrieval for the image processor configuration, avoiding direct access to nested internal properties of the vision tower, and explicitly casting tensor devices and data types to prevent potential runtime errors.

Comment thread swift/template/templates/moonshot.py Outdated
Comment on lines +128 to +133
image_processor = self.processor.image_processor
medias = [{'type': 'image', 'image': img} for img in inputs.images]
image_inputs = image_processor.preprocess(medias, return_tensors='pt')
grid_thws = image_inputs['grid_thws']
merge_kernel_size = image_processor.media_proc_cfg['merge_kernel_size']
merge_length = merge_kernel_size[0] * merge_kernel_size[1]

medium

To prevent potential AttributeError or KeyError if media_proc_cfg or merge_kernel_size is missing or structured differently in future processor versions, it is safer to use getattr and .get with a fallback default value of (2, 2).

Suggested change
  image_processor = self.processor.image_processor
  medias = [{'type': 'image', 'image': img} for img in inputs.images]
  image_inputs = image_processor.preprocess(medias, return_tensors='pt')
  grid_thws = image_inputs['grid_thws']
- merge_kernel_size = image_processor.media_proc_cfg['merge_kernel_size']
+ media_proc_cfg = getattr(image_processor, 'media_proc_cfg', {})
+ merge_kernel_size = media_proc_cfg.get('merge_kernel_size', (2, 2))
  merge_length = merge_kernel_size[0] * merge_kernel_size[1]

Comment thread swift/template/templates/moonshot.py Outdated
Comment on lines +159 to +166
if pixel_values is not None and pixel_values.size(0) > 0:
    vision_dtype = model.vision_tower.patch_embed.proj.weight.dtype
    pixel_values = pixel_values.to(vision_dtype)
    image_features: list = model._extract_image_features(pixel_values, inputs['grid_thws'])
    all_features = torch.cat(image_features, dim=0).to(inputs_embeds.dtype)
    media_token_id = model.config.media_placeholder_token_id
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[input_ids == media_token_id] = all_features

medium

Several improvements can be made here to increase robustness:

  1. Use next(model.vision_tower.parameters()).dtype instead of accessing the nested patch_embed.proj.weight.dtype directly, which avoids relying on specific internal layer names of the vision tower.
  2. Explicitly specify both device and dtype when casting all_features to match inputs_embeds to prevent device mismatch errors.
  3. Safely retrieve media_token_id from model.config with a fallback to the tokenizer's <|media_pad|> token ID to prevent AttributeError if the config property is missing.

Suggested change
  if pixel_values is not None and pixel_values.size(0) > 0:
-     vision_dtype = model.vision_tower.patch_embed.proj.weight.dtype
+     vision_dtype = next(model.vision_tower.parameters()).dtype
      pixel_values = pixel_values.to(vision_dtype)
      image_features: list = model._extract_image_features(pixel_values, inputs['grid_thws'])
-     all_features = torch.cat(image_features, dim=0).to(inputs_embeds.dtype)
-     media_token_id = model.config.media_placeholder_token_id
+     all_features = torch.cat(image_features, dim=0).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
+     media_token_id = getattr(model.config, 'media_placeholder_token_id', None) or self.tokenizer.convert_tokens_to_ids('<|media_pad|>')
      inputs_embeds = inputs_embeds.clone()
      inputs_embeds[input_ids == media_token_id] = all_features

Comment thread swift/template/templates/moonshot.py Outdated
Comment on lines +167 to +175
elif is_deepspeed_enabled():
    image_processor = self.processor.image_processor
    dummy_image = Image.new('RGB', (32, 32), (0, 0, 0))
    dummy_inputs = image_processor.preprocess([{'type': 'image', 'image': dummy_image}], return_tensors='pt')
    vision_dtype = model.vision_tower.patch_embed.proj.weight.dtype
    dummy_pixels = dummy_inputs['pixel_values'].to(vision_dtype).to(inputs_embeds.device)
    dummy_grid = dummy_inputs['grid_thws'].to(inputs_embeds.device)
    image_features = model._extract_image_features(dummy_pixels, dummy_grid)
    inputs_embeds = inputs_embeds + torch.cat(image_features, dim=0).mean() * 0.

medium

Apply similar robustness improvements for the DeepSpeed dummy forward pass:

  1. Use next(model.vision_tower.parameters()).dtype instead of model.vision_tower.patch_embed.proj.weight.dtype.
  2. Cast the dummy image features to inputs_embeds.dtype before adding them to avoid potential dtype mismatch errors.

Suggested change
  elif is_deepspeed_enabled():
      image_processor = self.processor.image_processor
      dummy_image = Image.new('RGB', (32, 32), (0, 0, 0))
      dummy_inputs = image_processor.preprocess([{'type': 'image', 'image': dummy_image}], return_tensors='pt')
-     vision_dtype = model.vision_tower.patch_embed.proj.weight.dtype
+     vision_dtype = next(model.vision_tower.parameters()).dtype
      dummy_pixels = dummy_inputs['pixel_values'].to(vision_dtype).to(inputs_embeds.device)
      dummy_grid = dummy_inputs['grid_thws'].to(inputs_embeds.device)
      image_features = model._extract_image_features(dummy_pixels, dummy_grid)
-     inputs_embeds = inputs_embeds + torch.cat(image_features, dim=0).mean() * 0.
+     inputs_embeds = inputs_embeds + torch.cat(image_features, dim=0).mean().to(dtype=inputs_embeds.dtype) * 0.

…ost_encode

- grid_thws was not moved to the GPU before passing to _extract_image_features,
  causing a device mismatch when pixel_values is on GPU
- boolean mask on 2-D input_ids produced wrong shape when indexing 3-D
  inputs_embeds; use reshape(-1) to flatten both tensors before indexing
Smoke test on live Kimi-K2.6 weights revealed two issues:
- media_proc_cfg['merge_kernel_size'] is int 2, not tuple (2,2); add
  isinstance guard so merge_length = k*k for int or k[0]*k[1] for tuple
- model.vision_tower.patch_embed.proj.weight is fragile under LoRA wrapping;
  switch to next(vision_tower.parameters()).dtype + get_base_model() unwrap
  so dtype/device access works correctly through PeftModel
- _encode: use getattr+.get with fallback (2,2) for media_proc_cfg
  access; keep int/tuple isinstance guard from smoke test finding
- _post_encode: explicitly cast all_features to both device and dtype
  of inputs_embeds; add getattr fallback for media_placeholder_token_id
  with tokenizer.convert_tokens_to_ids as secondary fallback
- DeepSpeed path: cast mean() to inputs_embeds.dtype to avoid bf16/fp32
  mismatch
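
The config handling described in these commits (a `getattr` and `.get` fallback plus the int-or-tuple guard) might look roughly like the sketch below. `merge_length` is a hypothetical helper name, and `SimpleNamespace` objects stand in for the real image processor.

```python
from types import SimpleNamespace


def merge_length(image_processor):
    """Tokens merged per output feature: k*k for an int, kH*kW for a pair."""
    cfg = getattr(image_processor, 'media_proc_cfg', None) or {}
    kernel = cfg.get('merge_kernel_size', (2, 2))
    if isinstance(kernel, int):
        # The smoke test found an int (2) here rather than a tuple.
        return kernel * kernel
    k_h, k_w = kernel
    return k_h * k_w


print(merge_length(SimpleNamespace(media_proc_cfg={'merge_kernel_size': 2})))       # 4
print(merge_length(SimpleNamespace(media_proc_cfg={'merge_kernel_size': (2, 2)})))  # 4
print(merge_length(SimpleNamespace()))                                              # 4 (fallback)
```
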
@zhihanliu-collab zhihanliu-collab marked this pull request as draft June 3, 2026 00:05
zhihanliu-collab and others added 5 commits June 2, 2026 17:41
…i_k25

- get_batch_on_this_pp_rank: null out pixel_values/grid_thws and other
  MM tensors on non-first PP stages; vision encoder on stage-0 consumes
  them and subsequent stages receive image features via activation channel
- moonshot.py: add mcore_model_type='kimi_k25' to model meta so that
  Megatron-SWIFT can locate the mcore_bridge implementation when it
  becomes available

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
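
The pipeline-parallel change can be pictured as follows. This is a loose sketch only: `drop_mm_inputs_off_first_stage` and the `MM_KEYS` tuple are hypothetical, the batch is a plain dict, and the real `get_batch_on_this_pp_rank` has its own signature and handles further multimodal tensors not listed here.

```python
# Keys named in the commit message; the real code covers other MM tensors too.
MM_KEYS = ('pixel_values', 'grid_thws')


def drop_mm_inputs_off_first_stage(batch, is_first_stage):
    """Keep vision inputs on pipeline stage 0 and null them out elsewhere."""
    if is_first_stage:
        return batch
    # Later stages receive image information through activations instead.
    return {k: (None if k in MM_KEYS else v) for k, v in batch.items()}


batch = {'input_ids': [1, 2, 3], 'pixel_values': 'px', 'grid_thws': 'grid'}
print(drop_mm_inputs_off_first_stage(batch, is_first_stage=False))
# {'input_ids': [1, 2, 3], 'pixel_values': None, 'grid_thws': None}
```
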
- Remove mcore_model_type='kimi_k25' until mcore_bridge impl lands
  (would crash all kimi_k25 Megatron users with get_model_meta assertion)
- Add explicit KeyError for pixel_values-present-but-grid_thws-missing
- Guard torch.cat(image_features) in DeepSpeed dummy path against empty list

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
_extract_image_features returns un-projected vision_tower output; the HF
forward applies mm_projector (PatchMergerMLP) to map mm_hidden_size ->
language hidden_size before merging. The previous code fed raw vision
features into inputs_embeds, causing a dimension mismatch at scatter.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
torch.nan_to_num on the zero-scaled dummy term prevents a non-finite value
from the all-zero dummy forward (e.g. zero-variance normalization) leaking
into text-only batches, since NaN * 0 == NaN in IEEE-754. Addresses review
feedback on the companion mcore-bridge PR.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
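
The IEEE-754 point in the last commit message is easy to reproduce with plain Python floats. In this sketch `nan_to_zero` is a simplified stand-in for `torch.nan_to_num` that handles only the NaN case, and the NaN "dummy mean" is assumed for illustration.

```python
import math


def nan_to_zero(x: float) -> float:
    """Simplified stand-in for torch.nan_to_num: map NaN to 0.0."""
    return 0.0 if math.isnan(x) else x


dummy_mean = float('nan')      # e.g. the mean of a non-finite dummy forward
zero_term = dummy_mean * 0.0   # still NaN: multiplying by zero does not help
print(math.isnan(zero_term))   # True

text_embed = 1.5
print(text_embed + nan_to_zero(zero_term))  # 1.5
```

An infinite dummy output behaves the same way, since `inf * 0.0` is also NaN.
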

Development

Successfully merging this pull request may close these issues.

Multimodal (image) support for Kimi-K2.6 training