feat: add multimodal (image) support for Kimi-K2.5/K2.6 training#9479
feat: add multimodal (image) support for Kimi-K2.5/K2.6 training#9479zhihanliu-collab wants to merge 9 commits into
Conversation
Implements replace_tag, _encode, _data_collator_mm_data, and
_post_encode in KimiK25Template to enable image training for
moonshotai/Kimi-K2.5 and moonshotai/Kimi-K2.6.
Key implementation details:
- replace_tag returns the standard Kimi media token sequence for images
- _encode calls KimiK25ImageProcessor.preprocess() with {'type':'image',...}
dicts, expands <|media_pad|> placeholders based on grid_thws and
merge_kernel_size from the processor config
- _data_collator_mm_data concatenates grid_thws tensors across batch items
- _post_encode extracts vision features via model._extract_image_features
and fills pre-expanded placeholder positions with the concatenated
feature tensors; includes DeepSpeed dummy-forward for parameter sync
Closes modelscope#9469
There was a problem hiding this comment.
Code Review
This pull request adds image support to the KimiK25Template in swift/template/templates/moonshot.py by implementing image tag replacement, encoding, data collation, and post-encoding embedding replacement. The review feedback suggests improving the robustness of the implementation by using safer attribute retrieval for the image processor configuration, avoiding direct access to nested internal properties of the vision tower, and explicitly casting tensor devices and data types to prevent potential runtime errors.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| image_processor = self.processor.image_processor | ||
| medias = [{'type': 'image', 'image': img} for img in inputs.images] | ||
| image_inputs = image_processor.preprocess(medias, return_tensors='pt') | ||
| grid_thws = image_inputs['grid_thws'] | ||
| merge_kernel_size = image_processor.media_proc_cfg['merge_kernel_size'] | ||
| merge_length = merge_kernel_size[0] * merge_kernel_size[1] |
There was a problem hiding this comment.
To prevent potential AttributeError or KeyError if media_proc_cfg or merge_kernel_size is missing or structured differently in future processor versions, it is safer to use getattr and .get with a fallback default value of (2, 2).
| image_processor = self.processor.image_processor | |
| medias = [{'type': 'image', 'image': img} for img in inputs.images] | |
| image_inputs = image_processor.preprocess(medias, return_tensors='pt') | |
| grid_thws = image_inputs['grid_thws'] | |
| merge_kernel_size = image_processor.media_proc_cfg['merge_kernel_size'] | |
| merge_length = merge_kernel_size[0] * merge_kernel_size[1] | |
| image_processor = self.processor.image_processor | |
| medias = [{'type': 'image', 'image': img} for img in inputs.images] | |
| image_inputs = image_processor.preprocess(medias, return_tensors='pt') | |
| grid_thws = image_inputs['grid_thws'] | |
| media_proc_cfg = getattr(image_processor, 'media_proc_cfg', {}) | |
| merge_kernel_size = media_proc_cfg.get('merge_kernel_size', (2, 2)) | |
| merge_length = merge_kernel_size[0] * merge_kernel_size[1] |
| if pixel_values is not None and pixel_values.size(0) > 0: | ||
| vision_dtype = model.vision_tower.patch_embed.proj.weight.dtype | ||
| pixel_values = pixel_values.to(vision_dtype) | ||
| image_features: list = model._extract_image_features(pixel_values, inputs['grid_thws']) | ||
| all_features = torch.cat(image_features, dim=0).to(inputs_embeds.dtype) | ||
| media_token_id = model.config.media_placeholder_token_id | ||
| inputs_embeds = inputs_embeds.clone() | ||
| inputs_embeds[input_ids == media_token_id] = all_features |
There was a problem hiding this comment.
Several improvements can be made here to increase robustness:
- Use
next(model.vision_tower.parameters()).dtypeinstead of accessing the nestedpatch_embed.proj.weight.dtypedirectly, which avoids relying on specific internal layer names of the vision tower. - Explicitly specify both
deviceanddtypewhen castingall_featuresto matchinputs_embedsto prevent device mismatch errors. - Safely retrieve
media_token_idfrommodel.configwith a fallback to the tokenizer's<|media_pad|>token ID to preventAttributeErrorif the config property is missing.
| if pixel_values is not None and pixel_values.size(0) > 0: | |
| vision_dtype = model.vision_tower.patch_embed.proj.weight.dtype | |
| pixel_values = pixel_values.to(vision_dtype) | |
| image_features: list = model._extract_image_features(pixel_values, inputs['grid_thws']) | |
| all_features = torch.cat(image_features, dim=0).to(inputs_embeds.dtype) | |
| media_token_id = model.config.media_placeholder_token_id | |
| inputs_embeds = inputs_embeds.clone() | |
| inputs_embeds[input_ids == media_token_id] = all_features | |
| if pixel_values is not None and pixel_values.size(0) > 0: | |
| vision_dtype = next(model.vision_tower.parameters()).dtype | |
| pixel_values = pixel_values.to(vision_dtype) | |
| image_features: list = model._extract_image_features(pixel_values, inputs['grid_thws']) | |
| all_features = torch.cat(image_features, dim=0).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) | |
| media_token_id = getattr(model.config, 'media_placeholder_token_id', None) or self.tokenizer.convert_tokens_to_ids('<|media_pad|>') | |
| inputs_embeds = inputs_embeds.clone() | |
| inputs_embeds[input_ids == media_token_id] = all_features |
| elif is_deepspeed_enabled(): | ||
| image_processor = self.processor.image_processor | ||
| dummy_image = Image.new('RGB', (32, 32), (0, 0, 0)) | ||
| dummy_inputs = image_processor.preprocess([{'type': 'image', 'image': dummy_image}], return_tensors='pt') | ||
| vision_dtype = model.vision_tower.patch_embed.proj.weight.dtype | ||
| dummy_pixels = dummy_inputs['pixel_values'].to(vision_dtype).to(inputs_embeds.device) | ||
| dummy_grid = dummy_inputs['grid_thws'].to(inputs_embeds.device) | ||
| image_features = model._extract_image_features(dummy_pixels, dummy_grid) | ||
| inputs_embeds = inputs_embeds + torch.cat(image_features, dim=0).mean() * 0. |
There was a problem hiding this comment.
Apply similar robustness improvements for the DeepSpeed dummy forward pass:
- Use
next(model.vision_tower.parameters()).dtypeinstead ofmodel.vision_tower.patch_embed.proj.weight.dtype. - Cast the dummy image features to
inputs_embeds.dtypebefore adding them to avoid potential dtype mismatch errors.
| elif is_deepspeed_enabled(): | |
| image_processor = self.processor.image_processor | |
| dummy_image = Image.new('RGB', (32, 32), (0, 0, 0)) | |
| dummy_inputs = image_processor.preprocess([{'type': 'image', 'image': dummy_image}], return_tensors='pt') | |
| vision_dtype = model.vision_tower.patch_embed.proj.weight.dtype | |
| dummy_pixels = dummy_inputs['pixel_values'].to(vision_dtype).to(inputs_embeds.device) | |
| dummy_grid = dummy_inputs['grid_thws'].to(inputs_embeds.device) | |
| image_features = model._extract_image_features(dummy_pixels, dummy_grid) | |
| inputs_embeds = inputs_embeds + torch.cat(image_features, dim=0).mean() * 0. | |
| elif is_deepspeed_enabled(): | |
| image_processor = self.processor.image_processor | |
| dummy_image = Image.new('RGB', (32, 32), (0, 0, 0)) | |
| dummy_inputs = image_processor.preprocess([{'type': 'image', 'image': dummy_image}], return_tensors='pt') | |
| vision_dtype = next(model.vision_tower.parameters()).dtype | |
| dummy_pixels = dummy_inputs['pixel_values'].to(vision_dtype).to(inputs_embeds.device) | |
| dummy_grid = dummy_inputs['grid_thws'].to(inputs_embeds.device) | |
| image_features = model._extract_image_features(dummy_pixels, dummy_grid) | |
| inputs_embeds = inputs_embeds + torch.cat(image_features, dim=0).mean().to(dtype=inputs_embeds.dtype) * 0. |
…ost_encode - grid_thws was not moved to the GPU before passing to _extract_image_features, causing a device mismatch when pixel_values is on GPU - boolean mask on 2-D input_ids produced wrong shape when indexing 3-D inputs_embeds; use reshape(-1) to flatten both tensors before indexing
Smoke test on live Kimi-K2.6 weights revealed two issues: - media_proc_cfg['merge_kernel_size'] is int 2, not tuple (2,2); add isinstance guard so merge_length = k*k for int or k[0]*k[1] for tuple - model.vision_tower.patch_embed.proj.weight is fragile under LoRA wrapping; switch to next(vision_tower.parameters()).dtype + get_base_model() unwrap so dtype/device access works correctly through PeftModel
- _encode: use getattr+.get with fallback (2,2) for media_proc_cfg access; keep int/tuple isinstance guard from smoke test finding - _post_encode: explicitly cast all_features to both device and dtype of inputs_embeds; add getattr fallback for media_placeholder_token_id with tokenizer.convert_tokens_to_ids as secondary fallback - DeepSpeed path: cast mean() to inputs_embeds.dtype to avoid bf16/fp32 mismatch
…i_k25 - get_batch_on_this_pp_rank: null out pixel_values/grid_thws and other MM tensors on non-first PP stages; vision encoder on stage-0 consumes them and subsequent stages receive image features via activation channel - moonshot.py: add mcore_model_type='kimi_k25' to model meta so that Megatron-SWIFT can locate the mcore_bridge implementation when it becomes available Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove mcore_model_type='kimi_k25' until mcore_bridge impl lands (would crash all kimi_k25 Megatron users with get_model_meta assertion) - Add explicit KeyError for pixel_values-present-but-grid_thws-missing - Guard torch.cat(image_features) in DeepSpeed dummy path against empty list Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
_extract_image_features returns un-projected vision_tower output; the HF forward applies mm_projector (PatchMergerMLP) to map mm_hidden_size -> language hidden_size before merging. The previous code fed raw vision features into inputs_embeds, causing a dimension mismatch at scatter. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
torch.nan_to_num on the zero-scaled dummy term prevents a non-finite value from the all-zero dummy forward (e.g. zero-variance normalization) leaking into text-only batches, since NaN * 0 == NaN in IEEE-754. Addresses review feedback on the companion mcore-bridge PR. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
Fixes #9469 — adds image training support for
moonshotai/Kimi-K2.5andmoonshotai/Kimi-K2.6inKimiK25Template.Changes in
swift/template/templates/moonshot.py:replace_tag: Returns the standard Kimi media token sequence<|media_start|>image<|media_content|><|media_pad|><|media_end|>forimagemedia type; raisesValueErrorfor unsupported types (video/audio) with a helpful message._encode: Wraps PIL images as[{'type': 'image', 'image': img}]dicts (the format expected byKimiK25ImageProcessor.preprocess), obtainspixel_valuesandgrid_thws, readsmerge_kernel_sizefromimage_processor.media_proc_cfg, then expands each<|media_pad|>placeholder to the correct number of tokens:grid_thws[i].prod() // (kH * kW)._data_collator_mm_data: Concatenates per-samplegrid_thwstensors along dim 0 for batched training._post_encode: Extracts vision features viamodel._extract_image_features(pixel_values, grid_thws), concatenates the per-image feature tensors, and fills the pre-expanded<|media_pad|>positions ininputs_embedsusing a boolean mask onmodel.config.media_placeholder_token_id. Includes a DeepSpeed dummy forward pass to keep all parameters in the computation graph.Implementation notes
The
KimiK25ImageProcessorreturnsgrid_thwsof shape(N, 3)(T, H, W in patch units), unlike Kimi-VL which returnsimage_grid_hwsof shape(N, 2)(H, W). Token count per image isT * H * W / (kH * kW)— for static images T=1, so this reduces toH * W / 4with the defaultmerge_kernel_size=(2,2).The merge is done by pre-expanding tokens in
_encode(same pattern asKimiVLTemplate) rather than calling the model's_merge_input_ids_with_image_features(which assumes single-token placeholders and changes sequence length), keeping the implementation consistent with the rest of ms-swift's template framework.