[model] support diffusion_gemma4#9548
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for the diffusion_gemma model, adding its configurations, architecture registration, model loader, and template. It also includes robustness improvements when handling missing generation configuration attributes and modality token IDs. However, a critical bug was identified in swift/template/templates/gemma.py where skipping None token IDs without filtering the corresponding splited_tokens can lead to length mismatches and incorrect token alignment during sorting. A fix has been suggested to align active split tokens only with valid token IDs.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| for key in ['image', 'video', 'audio']: | ||
| idx_list += findall(input_ids, getattr(self.config, f'{key}_token_id')) | ||
| token_id = getattr(self.config, f'{key}_token_id', None) | ||
| if token_id is None: | ||
| continue | ||
| idx_list += findall(input_ids, token_id) | ||
| sorted_order = sorted(range(len(idx_list)), key=lambda i: idx_list[i]) | ||
| idx_list = [idx_list[i] for i in sorted_order] | ||
| splited_tokens = [splited_tokens[i] for i in sorted_order] |
There was a problem hiding this comment.
There is a potential mismatch/alignment bug when any of the modality token IDs (e.g., video_token_id or audio_token_id) is None (unsupported by the model) but the corresponding media inputs are still present in inputs.\n\nBecause media_inputs is processed with all media types unconditionally, splited_tokens will contain elements for the unsupported modality. However, since token_id is None, it is skipped when building idx_list. This causes len(idx_list) to be less than len(splited_tokens). When sorting and slicing, the tokens of different modalities will get mismatched (e.g., audio tokens might be used in place of image tokens), leading to silent correctness bugs or runtime errors.\n\nWe can fix this by only keeping the split tokens for modalities that have a valid token_id in the config.
active_splited_tokens = []\n token_offset = 0\n for key in ['image', 'video', 'audio']:\n token_id = getattr(self.config, f'{key}_token_id', None)\n num_items = len(getattr(inputs, f'{key}s') or [])\n if token_id is not None:\n idx_list += findall(input_ids, token_id)\n active_splited_tokens += splited_tokens[token_offset:token_offset + num_items]\n token_offset += num_items\n sorted_order = sorted(range(len(idx_list)), key=lambda i: idx_list[i])\n idx_list = [idx_list[i] for i in sorted_order]\n splited_tokens = [active_splited_tokens[i] for i in sorted_order]|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces support for the DiffusionGemma model, including its architecture registration, custom model loader, dataset preprocessing for Sudoku, and a specialized template (DiffusionGemmaTemplate) with custom input preparation and loss computation. It also refactors SFT loss calculation by delegating it to the template class and updates dependencies. The review feedback highlights several robustness and safety improvements: handling potential None or zero values for num_items_in_batch and self._loss_scale to prevent runtime errors, using x0.shape instead of canvas_length to avoid broadcasting issues, and safely preserving the default skip_prompt behavior for other encoder-decoder models.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| self._loss_scale_cache[self._loss_scale] = get_loss_scale(self._loss_scale) | ||
| return self._loss_scale_cache[self._loss_scale] | ||
| loss_scale = self._loss_scale_cache[self._loss_scale] | ||
| if self.is_training and self.template_meta.is_thinking and self.template_meta.non_thinking_prefix and 'ignore_empty_think' not in self._loss_scale: |
There was a problem hiding this comment.
If self._loss_scale is None, the membership check 'ignore_empty_think' not in self._loss_scale will raise a TypeError. Adding a check to ensure self._loss_scale is not None or empty makes this safer.
| if self.is_training and self.template_meta.is_thinking and self.template_meta.non_thinking_prefix and 'ignore_empty_think' not in self._loss_scale: | |
| if self.is_training and self.template_meta.is_thinking and self.template_meta.non_thinking_prefix and (not self._loss_scale or 'ignore_empty_think' not in self._loss_scale): |
No description provided.