Skip to content
47 changes: 47 additions & 0 deletions examples/models/gemma4/diffusion_gemma.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Resource usage: 2 * 60GiB (two GPUs are used; see CUDA_VISIBLE_DEVICES below).
# This is a demo of DiffusionGemma LoRA training, followed by inference with the trained adapter.
# Notes:
# 1. Currently only --per_device_train_batch_size 1 is supported,
#    and the response length of a single sample must be less than config.canvas_length.
# 2. --gradient_checkpointing false must be set. DiffusionGemma's encoder passes
#    KV to the decoder via DynamicCache, and gradient checkpointing causes errors
#    when recomputing the forward pass during backward.
# 3. To customize the training loss, refer to:
#    https://github.com/Jintao-Huang/llmscope/blob/7549e98709cfe8a9d6866ccfa560975b2c7bd375/swift/template/templates/gemma.py#L386-L428
CUDA_VISIBLE_DEVICES=0,1 \
NPROC_PER_NODE=2 \
swift sft \
--model google/diffusiongemma-26B-A4B-it \
--dataset 'sapientinc/sudoku-extreme-1k' \
--load_from_cache_file true \
--split_dataset_ratio 0.01 \
--tuner_type lora \
--torch_dtype bfloat16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-4 \
--num_train_epochs 3 \
--loss_scale ignore_empty_think \
--gradient_checkpointing false \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--freeze_vit true \
--freeze_aligner true \
--gradient_accumulation_steps 4 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 4096 \
--output_dir output \
--warmup_ratio 0.05 \
--dataset_num_proc 4 \
--deepspeed zero2 \
--dataloader_num_workers 4

# Inference with the adapter produced by the training run above.
# `output/vx-xxx/checkpoint-xxx` is a placeholder: replace it with the actual
# checkpoint directory saved under the training --output_dir (`output`).
CUDA_VISIBLE_DEVICES=0 \
swift infer \
--adapters output/vx-xxx/checkpoint-xxx \
--load_data_args true \
--enable_thinking false
2 changes: 1 addition & 1 deletion examples/train/grpo/internal/qlora.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ swift rlhf \
--vllm_max_model_len 10240 \
--vllm_enable_lora true \
--sleep_level 1 \
--train_type lora \
--tuner_type lora \
--quant_method bnb \
--quant_bits 4 \
--bnb_4bit_quant_type nf4 \
Expand Down
2 changes: 1 addition & 1 deletion requirements/framework.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ sortedcontainers>=1.5.9
tensorboard
tiktoken
tqdm
transformers>=4.33,<5.11.0
transformers>=4.33,<5.13.0
transformers_stream_generator
trl>=0.15,<1.0
uvicorn
Expand Down
26 changes: 26 additions & 0 deletions swift/dataset/dataset/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,3 +931,29 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
hf_dataset_id='open-r1/DAPO-Math-17k-Processed',
subsets=['all'],
tags=['math', 'rlvr']))


class SudokuPreprocessor(ResponsePreprocessor):
    """Build a Sudoku prompt/response pair from flat puzzle and solution strings.

    Both the puzzle (``row['query']``) and the solution (``row['response']``)
    are re-flowed into lines of nine characters before being handed to the
    parent preprocessor.
    """
    prompt = ('Solve the following 9x9 Sudoku puzzle. '
              "Empty cells are marked with '0'. "
              'Provide the completed grid as your answer.\n\n'
              'Puzzle:\n{puzzle}')

    @staticmethod
    def _format_grid(s: str) -> str:
        """Split ``s`` into consecutive 9-character chunks joined by newlines."""
        rows = []
        for start in range(0, len(s), 9):
            rows.append(s[start:start + 9])
        return '\n'.join(rows)

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Format the puzzle into the prompt and the solution into the response."""
        # '.' characters in the raw puzzle are rewritten to '0', the empty-cell
        # marker announced in the prompt text.
        puzzle_grid = self._format_grid(row['query'].replace('.', '0'))
        solution_grid = self._format_grid(row['response'])
        formatted_row = {
            'query': self.prompt.format(puzzle=puzzle_grid),
            'response': solution_grid,
        }
        return super().preprocess(formatted_row)


# Register the sudoku-extreme-1k dataset (same id for ms_dataset_id and
# hf_dataset_id); rows are converted by SudokuPreprocessor.
register_dataset(
DatasetMeta(
ms_dataset_id='sapientinc/sudoku-extreme-1k',
hf_dataset_id='sapientinc/sudoku-extreme-1k',
preprocess_func=SudokuPreprocessor(),
))
4 changes: 2 additions & 2 deletions swift/infer_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ def prepare_generation_config(model_generation_config: Optional[GenerationConfig
for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams']:
new_value = getattr(request_config, key)
if new_value is None:
kwargs[key] = getattr(model_generation_config, key)
kwargs[key] = getattr(model_generation_config, key, None)
else:
kwargs[key] = new_value

if kwargs.get('top_k') is not None and kwargs['top_k'] <= 0:
kwargs['top_k'] = None

if not model_generation_config.do_sample and request_config.temperature in {0, None}:
if not getattr(model_generation_config, 'do_sample', False) and request_config.temperature in {0, None}:
kwargs['temperature'] = 0
if kwargs['temperature'] == 0:
kwargs['do_sample'] = False
Expand Down
1 change: 1 addition & 0 deletions swift/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ class MLLMModelType:
gemma3n = 'gemma3n'
gemma4 = 'gemma4'
gemma4_unified = 'gemma4_unified'
diffusion_gemma = 'diffusion_gemma'
mistral3 = 'mistral3'
mistral3_2506 = 'mistral3_2506'
paddle_ocr = 'paddle_ocr'
Expand Down
9 changes: 9 additions & 0 deletions swift/model/model_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class MLLMModelArch:
valley = 'valley'
gemma3n = 'gemma3n'
gemma4_unified = 'gemma4_unified'
diffusion_gemma = 'diffusion_gemma'
keye_vl = 'keye_vl'

midashenglm = 'midashenglm'
Expand Down Expand Up @@ -759,6 +760,14 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
aligner=['model.embed_vision', 'model.embed_audio'],
))

# Module groups for the diffusion_gemma architecture: the language-model group
# covers the encoder's language model, the decoder and lm_head; the vision tower
# and the vision embedding (aligner) both live under model.encoder.
register_model_arch(
MultiModelKeys(
MLLMModelArch.diffusion_gemma,
language_model=['model.encoder.language_model', 'model.decoder', 'lm_head'],
vision_tower=['model.encoder.vision_tower'],
aligner=['model.encoder.embed_vision'],
))

register_model_arch(
MultiModelKeys(
MLLMModelArch.keye_vl,
Expand Down
27 changes: 27 additions & 0 deletions swift/model/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,30 @@ def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrain
model_arch=ModelArch.gemma4_unified,
requires=['transformers>=5.10.1'],
))


class DiffusionGemmaLoader(ModelLoader):
    """Model loader that defaults to ``DiffusionGemmaForBlockDiffusion``."""

    def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
        """Load the model via the parent loader, then adjust it after loading.

        After loading, ``prepare_inputs_for_generation`` is set to ``None`` and
        ``config.use_cache`` is forced to ``True``.
        """
        # Imported lazily inside the method, as in the original code.
        from transformers import DiffusionGemmaForBlockDiffusion

        # Keep an explicitly configured model class; otherwise use the default.
        if not self.auto_model_cls:
            self.auto_model_cls = DiffusionGemmaForBlockDiffusion
        loaded = super().get_model(model_dir, config, processor, model_kwargs)
        loaded.prepare_inputs_for_generation = None
        loaded.config.use_cache = True
        return loaded


# Register the diffusion_gemma model type: a single model group
# (google/diffusiongemma-26B-A4B-it) bound to the diffusion_gemma template and
# loaded through DiffusionGemmaLoader; requires transformers>=5.11.
register_model(
ModelMeta(
MLLMModelType.diffusion_gemma,
[
ModelGroup([
Model('google/diffusiongemma-26B-A4B-it', 'google/diffusiongemma-26B-A4B-it'),
],
template=TemplateType.diffusion_gemma),
],
DiffusionGemmaLoader,
architectures=['DiffusionGemmaForBlockDiffusion'],
model_arch=ModelArch.diffusion_gemma,
requires=['transformers>=5.11'],
))
2 changes: 1 addition & 1 deletion swift/model/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def _init_generation_config(self, model, model_dir):
model.generation_config = GenerationConfig.from_pretrained(model_dir) if os.path.isfile(
generation_config_path) else None
# fix llama2 warning
if getattr(model, 'generation_config', None):
if getattr(model, 'generation_config', None) and hasattr(model.generation_config, 'do_sample'):
fix_do_sample_warning(model.generation_config)

def _get_model_processor(self, model_dir, config):
Expand Down
13 changes: 11 additions & 2 deletions swift/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ def __init__(
agent_template = agent_template or template_meta.agent_template
self._agent_template = agent_template
self.norm_bbox = norm_bbox or self.norm_bbox
if self.is_encoder_decoder:
self.skip_prompt = False
self.mode: Literal['transformers', 'vllm', 'lmdeploy', 'sglang', 'train', 'rlhf', 'kto'] = 'transformers'
Comment thread
Jintao-Huang marked this conversation as resolved.
self.task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'prm', 'reranker',
'generative_reranker'] = 'causal_lm'
Expand Down Expand Up @@ -767,6 +765,17 @@ def generate(self, model, *args, **kwargs):
kwargs['use_model_defaults'] = False
return model.generate(*args, **kwargs)

def compute_sft_loss(self, model, inputs: Dict[str, Any], num_items_in_batch: Optional[int] = None, trainer=None):
    """Run the model forward pass and return its outputs with the default SFT loss.

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs: Keyword arguments for the forward pass; may contain ``labels``.
        num_items_in_batch: Optional token count used to rescale the loss.
        trainer: Unused by this default implementation.

    Returns:
        The model outputs; ``outputs.loss`` is only modified when ``labels``
        is present in ``inputs``.
    """
    outputs = model(**inputs)
    if 'labels' not in inputs:
        return outputs

    labels = inputs['labels']
    outputs.loss = outputs.loss.to(labels.device)
    # fix https://github.com/huggingface/transformers/issues/34263
    if num_items_in_batch is not None:
        # Count of non-ignored labels, skipping the first position of each row.
        num_label_tokens = (labels[:, 1:] != -100).sum()
        outputs.loss = outputs.loss * (num_label_tokens / num_items_in_batch)
    return outputs

def skip_stop_tokens(self, generate_ids: List[int], is_finished: bool = True) -> List[int]:
# Do not print template_meta.suffix_stop and eos_token.
# However, other stop_words will be printed.
Expand Down
1 change: 1 addition & 0 deletions swift/template/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class MLLMTemplateType:
gemma3n = 'gemma3n'
gemma4 = 'gemma4'
gemma4_nothinking = 'gemma4_nothinking'
diffusion_gemma = 'diffusion_gemma'
mistral_2503 = 'mistral_2503'
mistral_2506 = 'mistral_2506'
mistral_2512 = 'mistral_2512'
Expand Down
84 changes: 81 additions & 3 deletions swift/template/templates/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

from swift.utils import upper_bound
from swift.utils import get_logger, upper_bound
from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, Word, findall
from ..vision_utils import load_audio

logger = get_logger()


@dataclass
class GemmaTemplateMeta(TemplateMeta):
Expand Down Expand Up @@ -266,7 +268,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int

def _get_system(self, inputs: StdTemplateInputs) -> Optional[str]:
system = super()._get_system(inputs)
if self._get_enable_thinking(inputs):
if not self.is_training and self._get_enable_thinking(inputs):
system = '<|think|>\n' + (system or '')
return system

Expand Down Expand Up @@ -298,7 +300,10 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:

idx_list = []
for key in ['image', 'video', 'audio']:
idx_list += findall(input_ids, getattr(self.config, f'{key}_token_id'))
token_id = getattr(self.config, f'{key}_token_id', None)
if token_id is None:
continue
idx_list += findall(input_ids, token_id)
sorted_order = sorted(range(len(idx_list)), key=lambda i: idx_list[i])
idx_list = [idx_list[i] for i in sorted_order]
splited_tokens = [splited_tokens[i] for i in sorted_order]
Comment on lines 302 to 309

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a potential mismatch/alignment bug when any of the modality token IDs (e.g., `video_token_id` or `audio_token_id`) is `None` (unsupported by the model) but the corresponding media inputs are still present in `inputs`.

Because `media_inputs` is processed with all media types unconditionally, `splited_tokens` will contain elements for the unsupported modality. However, since `token_id` is `None`, that modality is skipped when building `idx_list`. This causes `len(idx_list)` to be less than `len(splited_tokens)`. When sorting and slicing, the tokens of different modalities will get mismatched (e.g., audio tokens might be used in place of image tokens), leading to silent correctness bugs or runtime errors.

We can fix this by only keeping the split tokens for modalities that have a valid `token_id` in the config.

        active_splited_tokens = []
        token_offset = 0
        for key in ['image', 'video', 'audio']:
            token_id = getattr(self.config, f'{key}_token_id', None)
            num_items = len(getattr(inputs, f'{key}s') or [])
            if token_id is not None:
                idx_list += findall(input_ids, token_id)
                active_splited_tokens += splited_tokens[token_offset:token_offset + num_items]
            token_offset += num_items
        sorted_order = sorted(range(len(idx_list)), key=lambda i: idx_list[i])
        idx_list = [idx_list[i] for i in sorted_order]
        splited_tokens = [active_splited_tokens[i] for i in sorted_order]

Expand Down Expand Up @@ -357,3 +362,76 @@ class Gemma4TemplateMeta(TemplateMeta):
agent_template='gemma4',
is_thinking=True,
non_thinking_prefix='<|channel>thought\n<channel|>'))


class DiffusionGemmaTemplate(Gemma4Template):
    """Template for DiffusionGemma (block-diffusion, encoder-decoder) training.

    While training, the collated batch is rewritten so that the prompt becomes
    ``input_ids``, a noised fixed-length canvas becomes ``decoder_input_ids``
    and the clean canvas (response + eos) becomes ``labels``.
    """
    is_encoder_decoder = True
    skip_prompt = True

    @property
    def loss_scale(self):
        """Return the parent's loss scale, forcing ``last_round`` while training."""
        loss_scale = super().loss_scale
        if self.is_training and loss_scale.base_strategy != 'last_round':
            logger.warning_once('DiffusionGemmaTemplate only supports the `last_round` base strategy for loss scaling. '
                                'Setting loss_scale.base_strategy to `last_round`.')
            loss_scale.base_strategy = 'last_round'
        return loss_scale

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate with the parent collator, then build diffusion inputs when training."""
        inputs = super()._data_collator(batch, padding_to=padding_to)
        if self.is_training:
            inputs = self._update_inputs(inputs)
        return inputs

    # Code reference: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/DiffusionGemma_(26B-A4B)-Sudoku.ipynb # noqa
    def _update_inputs(self, inputs):
        """Convert a collated single-sample batch into diffusion training inputs.

        Args:
            inputs: Collated batch; ``input_ids`` and ``labels`` are read.

        Returns:
            A new dict holding only ``input_ids`` (the prompt),
            ``decoder_input_ids`` (the noised canvas) and ``labels`` (the clean
            canvas, ``-100`` after the eos slot). Any other key of ``inputs``
            is not carried over.

        Raises:
            ValueError: If the batch has more than one sample, or the response
                does not fit in ``canvas_length - 1`` tokens.
        """
        canvas_length = self.config.canvas_length
        if inputs['labels'].shape[0] > 1:
            raise ValueError('per_device_train_batch_size must be 1 for diffusion gemma')
        # index of the first supervised token; everything before it is the prompt.
        first_idx = (inputs['labels'] != -100).int().argmax().item()
        prompt_ids = inputs['input_ids'][:, :first_idx]
        # reserve one slot at the end of the canvas for the explicit eos token expected by
        # the diffusion sampler as the termination signal.
        response_length = inputs['input_ids'].shape[1] - first_idx
        if response_length > canvas_length - 1:
            raise ValueError(f'response length ({response_length}) exceeds canvas_length-1 ({canvas_length - 1}); '
                             'please use a shorter response or increase canvas_length.')
        canvas_content = inputs['input_ids'][:, first_idx:first_idx + canvas_length - 1]
        # x0: clean canvas padded to canvas_length; loss is only computed on response + eos.
        device = prompt_ids.device
        eos_token_id = self.tokenizer.eos_token_id
        pad_token_id = self.tokenizer.pad_token_id
        x0 = torch.full((prompt_ids.shape[0], canvas_length), pad_token_id, dtype=torch.long, device=device)
        n = canvas_content.shape[1]
        x0[:, :n] = canvas_content
        # explicitly append eos as the canvas-end signal expected by the diffusion sampler.
        # without it, sampler keeps denoising the trailing positions during inference and emits garbage.
        x0[:, n] = eos_token_id
        labels = x0.clone()
        labels[:, n + 1:] = -100

        # forward diffusion: per-sample noise level t ∈ [min, max], replace tokens with random vocab ids
        t = torch.empty((), device=device).uniform_(0.1, 1.)
        noise_mask = torch.rand(canvas_length, device=device) < t
        random_tokens = torch.randint(0, self.config.text_config.vocab_size, (canvas_length, ), device=device)
        decoder_input_ids = torch.where(noise_mask, random_tokens, x0)
        return {'input_ids': prompt_ids, 'decoder_input_ids': decoder_input_ids, 'labels': labels}

    def compute_sft_loss(self, model, inputs: Dict[str, Any], num_items_in_batch: Optional[int] = None, trainer=None):
        """Compute a summed token cross-entropy over the canvas and normalize it.

        Args:
            model: Callable invoked as ``model(**inputs)``; its output must
                expose ``logits``.
            inputs: Forward-pass keyword arguments, including ``labels``.
            num_items_in_batch: Normalizer for the summed loss. When ``None``,
                the number of non-ignored labels in ``inputs`` is used.
            trainer: Optional trainer; when given, its
                ``args.gradient_checkpointing`` flag is validated.

        Returns:
            The model outputs with ``outputs.loss`` set.

        Raises:
            ValueError: If the supplied trainer has gradient checkpointing enabled.
        """
        # `trainer` defaults to None, so only inspect its args when one is supplied.
        if trainer is not None and trainer.args.gradient_checkpointing:
            raise ValueError('Gradient checkpointing is not supported for diffusion gemma')
        outputs = model(**inputs)
        logits = outputs.logits.view(-1, outputs.logits.shape[-1])
        labels = inputs['labels'].view(-1)
        summed_loss = F.cross_entropy(logits, labels, reduction='sum')
        if num_items_in_batch is None:
            # The parameter defaults to None (presumably e.g. loss-only evaluation
            # calls omit it — confirm against the trainer); dividing by None would
            # raise TypeError, so fall back to the local count of supervised tokens.
            num_items_in_batch = (labels != -100).sum().clamp(min=1)
        outputs.loss = summed_loss / num_items_in_batch
        return outputs


# Register the diffusion_gemma template type with DiffusionGemmaTemplate as its
# implementation class; the agent template and thinking settings are passed
# explicitly below.
register_template(
Gemma4TemplateMeta(
MLLMTemplateType.diffusion_gemma,
template_cls=DiffusionGemmaTemplate,
agent_template='gemma4',
is_thinking=True,
non_thinking_prefix='<|channel>thought\n<channel|>'))
1 change: 1 addition & 0 deletions swift/template/templates/microsoft.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class FlorenceTemplate(Template):
# If it's an encoder-decoder architecture, the default settings are
# loss_scale: 'last_round' and skip_prompt: False.
is_encoder_decoder = True
skip_prompt = False

@staticmethod
def _add_default_tags(inputs: StdTemplateInputs) -> None:
Expand Down
7 changes: 1 addition & 6 deletions swift/trainers/seq2seq_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
logger.warning_once('The cross_entropy loss function defined in Liger Kernel will not '
'take effect, potentially leading to increased GPU memory consumption.')
labels = inputs.pop('labels')
outputs = model(**inputs)
outputs = self.template.compute_sft_loss(model, inputs, num_items_in_batch=num_items_in_batch, trainer=self)
mode = 'train' if self.model.training else 'eval'
if getattr(outputs, 'aux_loss', None) is not None:
self.custom_metrics[mode]['aux_loss'].update(outputs.aux_loss)
Expand All @@ -147,11 +147,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N

if labels is None:
labels = inputs['labels']
outputs.loss = outputs.loss.to(labels.device)
# fix https://github.com/huggingface/transformers/issues/34263
if num_items_in_batch is not None:
outputs.loss = outputs.loss * ((labels[:, 1:] != -100).sum() / num_items_in_batch)

if isinstance(outputs, dict) and 'loss' not in outputs:
raise ValueError(
'The model did not return a loss from the inputs, only the following keys: '
Expand Down
Loading