-
Notifications
You must be signed in to change notification settings - Fork 1.5k
[model] support diffusion_gemma4 #9548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6cc3146
f23902e
d7f32b0
d181b98
a77e4c6
713c43b
491f32f
c84bda8
eda195f
efdf36e
722c7b7
f0a9bef
187c57d
7549e98
45131f3
03df244
4f4fa9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| # 2 * 60GiB | ||
| # This is just a demo for DiffusionGemma training. | ||
| # Notes: | ||
| # 1. Currently only --per_device_train_batch_size 1 is supported, | ||
| # and the response length of a single sample must be less than config.canvas_length. | ||
| # 2. --gradient_checkpointing false must be set. DiffusionGemma's encoder passes | ||
| # KV to the decoder via DynamicCache, and gradient checkpointing causes errors | ||
| # when recomputing the forward pass during backward. | ||
| # 3. For customizing the specific training loss, refer to: | ||
| # https://github.com/Jintao-Huang/llmscope/blob/7549e98709cfe8a9d6866ccfa560975b2c7bd375/swift/template/templates/gemma.py#L386-L428 | ||
| CUDA_VISIBLE_DEVICES=0,1 \ | ||
| NPROC_PER_NODE=2 \ | ||
| swift sft \ | ||
| --model google/diffusiongemma-26B-A4B-it \ | ||
| --dataset 'sapientinc/sudoku-extreme-1k' \ | ||
| --load_from_cache_file true \ | ||
| --split_dataset_ratio 0.01 \ | ||
| --tuner_type lora \ | ||
| --torch_dtype bfloat16 \ | ||
| --per_device_train_batch_size 1 \ | ||
| --per_device_eval_batch_size 1 \ | ||
| --learning_rate 1e-4 \ | ||
| --num_train_epochs 3 \ | ||
| --loss_scale ignore_empty_think \ | ||
| --gradient_checkpointing false \ | ||
| --lora_rank 8 \ | ||
| --lora_alpha 32 \ | ||
| --target_modules all-linear \ | ||
| --freeze_vit true \ | ||
| --freeze_aligner true \ | ||
| --gradient_accumulation_steps 4 \ | ||
| --eval_steps 100 \ | ||
| --save_steps 100 \ | ||
| --save_total_limit 2 \ | ||
| --logging_steps 5 \ | ||
| --max_length 4096 \ | ||
| --output_dir output \ | ||
| --warmup_ratio 0.05 \ | ||
| --dataset_num_proc 4 \ | ||
| --deepspeed zero2 \ | ||
| --dataloader_num_workers 4 | ||
|
|
||
| CUDA_VISIBLE_DEVICES=0 \ | ||
| swift infer \ | ||
| --adapters output/vx-xxx/checkpoint-xxx \ | ||
| --load_data_args true \ | ||
| --enable_thinking false |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,14 +5,16 @@ | |
| from dataclasses import dataclass, field | ||
| from typing import Any, Dict, List, Literal, Optional | ||
|
|
||
| from swift.utils import upper_bound | ||
| from swift.utils import get_logger, upper_bound | ||
| from ..base import Template | ||
| from ..constant import LLMTemplateType, MLLMTemplateType | ||
| from ..register import TemplateMeta, register_template | ||
| from ..template_inputs import StdTemplateInputs | ||
| from ..utils import Context, Prompt, Word, findall | ||
| from ..vision_utils import load_audio | ||
|
|
||
| logger = get_logger() | ||
|
|
||
|
|
||
| @dataclass | ||
| class GemmaTemplateMeta(TemplateMeta): | ||
|
|
@@ -266,7 +268,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int | |
|
|
||
| def _get_system(self, inputs: StdTemplateInputs) -> Optional[str]: | ||
| system = super()._get_system(inputs) | ||
| if self._get_enable_thinking(inputs): | ||
| if not self.is_training and self._get_enable_thinking(inputs): | ||
| system = '<|think|>\n' + (system or '') | ||
| return system | ||
|
|
||
|
|
@@ -298,7 +300,10 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: | |
|
|
||
| idx_list = [] | ||
| for key in ['image', 'video', 'audio']: | ||
| idx_list += findall(input_ids, getattr(self.config, f'{key}_token_id')) | ||
| token_id = getattr(self.config, f'{key}_token_id', None) | ||
| if token_id is None: | ||
| continue | ||
| idx_list += findall(input_ids, token_id) | ||
| sorted_order = sorted(range(len(idx_list)), key=lambda i: idx_list[i]) | ||
| idx_list = [idx_list[i] for i in sorted_order] | ||
| splited_tokens = [splited_tokens[i] for i in sorted_order] | ||
|
Comment on lines
302
to
309
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

There is a potential mismatch/alignment bug when any of the modality token IDs (e.g., `audio_token_id`) is `None`: the loop then skips that modality when building `idx_list`, but `splited_tokens` still contains its entries, so the two lists no longer line up when they are reordered together. [The original comment text is truncated at this point in the extracted page; the explanation above is reconstructed from the suggested code.]

Suggested change:

    active_splited_tokens = []
    token_offset = 0
    for key in ['image', 'video', 'audio']:
        token_id = getattr(self.config, f'{key}_token_id', None)
        num_items = len(getattr(inputs, f'{key}s') or [])
        if token_id is not None:
            idx_list += findall(input_ids, token_id)
            active_splited_tokens += splited_tokens[token_offset:token_offset + num_items]
        token_offset += num_items
    sorted_order = sorted(range(len(idx_list)), key=lambda i: idx_list[i])
    idx_list = [idx_list[i] for i in sorted_order]
    splited_tokens = [active_splited_tokens[i] for i in sorted_order]
||
|
|
@@ -357,3 +362,76 @@ class Gemma4TemplateMeta(TemplateMeta): | |
| agent_template='gemma4', | ||
| is_thinking=True, | ||
| non_thinking_prefix='<|channel>thought\n<channel|>')) | ||
|
|
||
|
|
||
class DiffusionGemmaTemplate(Gemma4Template):
    """Gemma4 template variant used for DiffusionGemma training and inference.

    During training the collated batch is rewritten into an encoder/decoder
    layout: the prompt becomes ``input_ids``, and the response is placed on a
    fixed-size "canvas" (``config.canvas_length`` tokens) that is randomly
    corrupted and fed as ``decoder_input_ids``. Only a batch size of 1 is
    supported, and gradient checkpointing must be disabled.
    """
    is_encoder_decoder = True
    skip_prompt = True

    @property
    def loss_scale(self):
        """Return the parent's loss scale, forcing ``last_round`` while training.

        When training with any other base strategy, a one-time warning is
        logged and ``base_strategy`` is overwritten in place on the object
        returned by the parent property.
        """
        loss_scale = super().loss_scale
        if self.is_training and loss_scale.base_strategy != 'last_round':
            logger.warning_once('DiffusionGemmaTemplate only supports the `last_round` base strategy for loss scaling. '
                                'Setting loss_scale.base_strategy to `last_round`.')
            loss_scale.base_strategy = 'last_round'
        return loss_scale

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate ``batch`` with the parent collator; when training, convert it to diffusion inputs."""
        inputs = super()._data_collator(batch, padding_to=padding_to)
        if self.is_training:
            inputs = self._update_inputs(inputs)
        return inputs

    # Code reference: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/DiffusionGemma_(26B-A4B)-Sudoku.ipynb # noqa
    def _update_inputs(self, inputs):
        """Build DiffusionGemma training inputs from a collated batch.

        The sequence is split at the first supervised label (the first label
        that is not -100): everything before it is the prompt, everything from
        it onwards is the response that is written onto the canvas.

        Args:
            inputs: Collated batch containing at least ``input_ids`` and
                ``labels`` tensors, indexed here as (batch, sequence).

        Returns:
            A dict with ``input_ids`` (the prompt), ``decoder_input_ids`` (the
            noised canvas) and ``labels`` (the clean canvas, with positions
            after the appended eos set to -100).

        Raises:
            ValueError: If the batch holds more than one sample, if no label is
                supervised, or if the response does not fit into
                ``canvas_length - 1`` tokens.
        """
        canvas_length = self.config.canvas_length
        if inputs['labels'].shape[0] > 1:
            raise ValueError('per_device_train_batch_size must be 1 for diffusion gemma')
        supervised = inputs['labels'] != -100
        # `argmax` returns 0 when nothing is supervised, which would silently treat the whole
        # sequence as the response with an empty prompt; fail loudly instead.
        if not bool(supervised.any()):
            raise ValueError('no supervised label found in the sample; diffusion gemma training '
                             'requires a non-empty response.')
        first_idx = supervised.int().argmax().item()
        prompt_ids = inputs['input_ids'][:, :first_idx]
        # reserve one slot at the end of the canvas for the explicit eos token expected by
        # the diffusion sampler as the termination signal.
        response_length = inputs['input_ids'].shape[1] - first_idx
        if response_length > canvas_length - 1:
            raise ValueError(f'response length ({response_length}) exceeds canvas_length-1 ({canvas_length - 1}); '
                             'please use a shorter response or increase canvas_length.')
        canvas_content = inputs['input_ids'][:, first_idx:first_idx + canvas_length - 1]
        # x0: clean canvas padded to canvas_length; loss is only computed on response + eos.
        device = prompt_ids.device
        eos_token_id = self.tokenizer.eos_token_id
        pad_token_id = self.tokenizer.pad_token_id
        x0 = torch.full((prompt_ids.shape[0], canvas_length), pad_token_id, dtype=torch.long, device=device)
        n = canvas_content.shape[1]
        x0[:, :n] = canvas_content
        # explicitly append eos as the canvas-end signal expected by the diffusion sampler.
        # without it, sampler keeps denoising the trailing positions during inference and emits garbage.
        x0[:, n] = eos_token_id
        labels = x0.clone()
        labels[:, n + 1:] = -100

        # forward diffusion: draw one noise level t uniformly from [0.1, 1.0] and replace each
        # canvas position with a random vocab id with probability t.
        t = torch.empty((), device=device).uniform_(0.1, 1.)
        noise_mask = torch.rand(canvas_length, device=device) < t
        random_tokens = torch.randint(0, self.config.text_config.vocab_size, (canvas_length, ), device=device)
        decoder_input_ids = torch.where(noise_mask, random_tokens, x0)
        return {'input_ids': prompt_ids, 'decoder_input_ids': decoder_input_ids, 'labels': labels}

    def compute_sft_loss(self, model, inputs: Dict[str, Any], num_items_in_batch: Optional[int] = None, trainer=None):
        """Run the model and attach a summed cross-entropy loss normalised by token count.

        Args:
            model: Callable invoked as ``model(**inputs)``; its result must
                expose ``logits``.
            inputs: Model inputs, including ``labels`` (-100 marks ignored
                positions).
            num_items_in_batch: Normaliser for the summed loss. When ``None``,
                the number of non-ignored labels is used instead.
            trainer: Optional trainer; when given, its
                ``args.gradient_checkpointing`` flag is validated.

        Returns:
            The model outputs with ``loss`` overwritten.

        Raises:
            ValueError: If gradient checkpointing is enabled on ``trainer``.
        """
        # `trainer` defaults to None, so only validate its args when one was actually passed.
        if trainer is not None and trainer.args.gradient_checkpointing:
            raise ValueError('Gradient checkpointing is not supported for diffusion gemma')
        outputs = model(**inputs)
        logits = outputs.logits.view(-1, outputs.logits.shape[-1])
        labels = inputs['labels'].view(-1)
        loss = F.cross_entropy(logits, labels, reduction='sum')
        if num_items_in_batch is None:
            # Fall back to a per-token mean over supervised positions; clamp avoids division by zero.
            num_items_in_batch = (labels != -100).sum().clamp(min=1)
        outputs.loss = loss / num_items_in_batch
        return outputs
|
|
||
|
|
||
# Register DiffusionGemmaTemplate under MLLMTemplateType.diffusion_gemma,
# described by a Gemma4TemplateMeta with the gemma4 agent template.
register_template(
    Gemma4TemplateMeta(
        MLLMTemplateType.diffusion_gemma,
        template_cls=DiffusionGemmaTemplate,
        agent_template='gemma4',
        is_thinking=True,
        non_thinking_prefix='<|channel>thought\n<channel|>',
    ))
Uh oh!
There was an error while loading. Please reload this page.