Gemma 4 builder scaffolding (Part 1 of N, #2062)#2088

Draft
jaburges wants to merge 1 commit into microsoft:main from jaburges:feat/gemma4-builder-scaffolding

Conversation

@jaburges

Description

Scaffolds Gemma 4 support in the Python model builder, addressing part of #2062. Intentionally opened as a Draft because this is Part 1 of N — the builder-side work only. See "Scope" below for exactly what is and isn't covered, and why a functional end-to-end path requires paired C++ runtime changes.

Credit to @elbruno for the original architectural write-up in #2062; this PR follows that decomposition.

Motivation

Google released the Gemma 4 family (E2B / E4B / 26B-A4B / 31B) in April 2026. All variants use a new architecture with three features that don't fit GenAI's current assumptions:

  1. Per-Layer Embeddings (PLE) — embed_tokens emits two tensors; each decoder layer consumes its own slice.
  2. Variable attention head dimension — sliding layers use head_dim=256, full-attention layers use global_head_dim=512.
  3. KV cache sharing — E2B has 35 decoder layers but only 15 own unique KV caches (num_kv_shared_layers=20).
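
The three features can be made concrete with a toy sketch using the E2B numbers above. Everything here is illustrative — the variable names, the hypothetical batch/sequence/hidden sizes, and especially the KV aliasing scheme are not the real builder API or the model's actual mapping:

```python
# Toy illustration of the three Gemma 4 (E2B) architectural features.
# Names, sizes, and the aliasing scheme are illustrative only.

NUM_LAYERS = 35
SLIDING_HEAD_DIM = 256      # head_dim on sliding-attention layers
FULL_HEAD_DIM = 512         # global_head_dim on full-attention layers
NUM_KV_SHARED = 20          # num_kv_shared_layers

# (1) PLE: embed_tokens emits a second tensor with one slice per layer.
batch, seq, hidden_per_layer = 1, 8, 256          # hypothetical sizes
per_layer_inputs_shape = (batch, seq, NUM_LAYERS, hidden_per_layer)

# (2) Variable head dim: full-attention layers sit at 4, 9, 14, ..., 34.
full_layers = set(range(4, NUM_LAYERS, 5))
head_dims = [FULL_HEAD_DIM if i in full_layers else SLIDING_HEAD_DIM
             for i in range(NUM_LAYERS)]

# (3) KV sharing: only 15 layers own a cache; the other 20 alias an
# earlier layer's cache (one plausible scheme, not the model's real one).
num_unique = NUM_LAYERS - NUM_KV_SHARED
kv_owner = [i if i < num_unique else i % num_unique for i in range(NUM_LAYERS)]

print(per_layer_inputs_shape)        # (1, 8, 35, 256)
print(sorted(full_layers))           # [4, 9, 14, 19, 24, 29, 34]
print(len(set(kv_owner)))            # 15 unique caches
```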

Existing workarounds (patching builder.py to route through Gemma 3, or using the onnx-community/gemma-4-E2B-it-ONNX export) fail with shape errors or I/O-contract mismatches. See #2062 for the full analysis.

Scope

✅ Covered by this PR

  • builders/gemma4.py (new): Gemma4TextModel(Gemma3Model) overriding:
    • is_local(layer_id) driven by config.layer_types (Gemma 3 used a hard-coded every-6th-layer rule; Gemma 4 exposes the full list).
    • make_key_value_cache_shape(layer_id, shape) using the existing per-layer extension point to emit head_dim=256 on sliding layers and global_head_dim=512 on full-attention layers.
    • make_rotary_embedding_multi_cache() building the two RoPE caches Gemma 4 needs (sliding: theta=10_000, partial_rotary_factor=1.0; full: theta=1_000_000, partial_rotary_factor=0.25, rope_type=proportional) from config.rope_parameters.
  • builder.py: dispatch for Gemma4ForConditionalGeneration and Gemma4ForCausalLM, with text_config flattening and exclude_embeds=true defaulting, matching the existing Gemma 3 pattern.
  • builders/__init__.py: exports Gemma4TextModel.
  • test/python/test_gemma4_builder.py (new): 10 unit tests for the Gemma 4-specific overrides. Runs in a clean Python env — no torch / transformers / onnxruntime needed — by stubbing the Gemma3Model parent.
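
The shape of the two core overrides can be sketched as below, with the Gemma3Model parent stubbed out the way the unit tests do it. Method names follow the PR; the dict-style config access and the exact layer_types strings are assumptions for the sketch, not the real builder interface:

```python
# Minimal sketch of the Gemma4TextModel overrides, with a stubbed parent.
# Config access and layer_types values are illustrative assumptions.

class Gemma3Model:          # stub standing in for the real builder parent
    def __init__(self, config):
        self.config = config

class Gemma4TextModel(Gemma3Model):
    def is_local(self, layer_id):
        # Driven by config.layer_types rather than a hard-coded pattern.
        return self.config["layer_types"][layer_id] == "sliding_attention"

    def make_key_value_cache_shape(self, layer_id, shape):
        # Swap only the head-dim axis: sliding layers keep head_dim,
        # full-attention layers get global_head_dim.
        head_dim = (self.config["head_dim"] if self.is_local(layer_id)
                    else self.config["global_head_dim"])
        return shape[:-1] + [head_dim]

# Hypothetical E2B-like config: full attention at layers 4, 9, ..., 34.
config = {
    "layer_types": ["full_attention" if i % 5 == 4 else "sliding_attention"
                    for i in range(35)],
    "head_dim": 256,
    "global_head_dim": 512,
}
model = Gemma4TextModel(config)
print(model.is_local(0))                                        # True
print(model.make_key_value_cache_shape(4, ["B", 8, "S", 256]))  # ['B', 8, 'S', 512]
```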

❌ Not covered — requires C++ runtime changes (Parts 2/3/4)

These are explicitly out of scope. The builder raises NotImplementedError listing exactly which of them were tripped, so users get an actionable error rather than a broken graph:

  1. Per-Layer Embeddings routing. Needs a new per_layer_inputs side-channel input both in the ONNX graph (new make_inputs_and_outputs branch) and in the C++ model loader. Touches src/models/model.cc and friends.
  2. KV cache sharing. src/models/kv_cache.cpp currently allocates one past/present pair per layer. Gemma 4 needs layer aliasing (20 of 35 layers reuse an earlier layer's KV).
  3. Per-layer head_size in the runtime. Even with this builder emitting correct per-layer past_kv shapes, kv_cache.cpp:17 reads a single model_.config_->model.decoder.head_size and applies it uniformly. The genai_config.json schema and the KV cache manager need per-layer awareness, or a head_size + global_head_size split driven by the same layer_types pattern the builder already understands.

Until those land, this builder supports --extra_options config_only=true only. That's still useful for validating the dispatcher and for downstream tooling that only needs genai_config.json / processing files. The gate makes that explicit rather than silently emitting a broken graph.
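
The gate described above might look roughly like this — collect every blocker present in the config, raise one error naming all of them, and stay quiet in config-only mode. Function and field names are illustrative, not the PR's actual code:

```python
# Sketch of the unsupported-feature gate: names are illustrative.

def check_runtime_support(config, config_only=False):
    if config_only:
        return  # genai_config.json / processing files need no graph
    blockers = []
    if config.get("hidden_size_per_layer_input"):          # PLE routing
        blockers.append("Per-Layer Embeddings (needs per_layer_inputs input)")
    if config.get("num_kv_shared_layers", 0) > 0:          # KV sharing
        blockers.append("KV cache sharing (runtime allocates one pair per layer)")
    if config.get("global_head_dim") != config.get("head_dim"):
        blockers.append("per-layer head_size (runtime reads a single value)")
    if blockers:
        raise NotImplementedError(
            "Gemma 4 needs C++ runtime support for: " + "; ".join(blockers)
            + ". Re-run with --extra_options config_only=true.")

e2b = {"hidden_size_per_layer_input": 256, "num_kv_shared_layers": 20,
       "head_dim": 256, "global_head_dim": 512}
check_runtime_support(e2b, config_only=True)   # quiet in config-only mode
try:
    check_runtime_support(e2b)
except NotImplementedError as e:
    print(e)                                   # lists all three blockers
```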

Verification

cd test/python && python test_gemma4_builder.py -v
test_gate_quiet_in_config_only_mode ... ok
test_gate_raises_on_kv_sharing ... ok
test_gate_raises_on_ple ... ok
test_gate_raises_on_variable_head_dim ... ok
test_is_local_falls_back_when_layer_types_missing ... ok
test_is_local_uses_layer_types ... ok
test_kv_cache_shape_preserves_non_head_dim_axes ... ok
test_kv_cache_shape_uses_full_head_dim_on_full_layers ... ok
test_kv_cache_shape_uses_sliding_head_dim_on_sliding_layers ... ok
test_rope_caches_are_emitted_for_both_attention_types ... ok

Ran 10 tests in 0.009s
OK

The tests pin down:

  • is_local matches the expected E2B pattern — full layers at exactly 4, 9, 14, 19, 24, 29, 34.
  • make_key_value_cache_shape emits 512 on those seven layers and 256 everywhere else.
  • The unsupported-feature gate raises individually for each of the three runtime blockers.
  • config_only=true mode bypasses the gate cleanly.
  • Both RoPE caches are emitted with the correct theta and partial-rotary factor.
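
For intuition on that last point, the two caches use the same inverse-frequency formula with different theta and partial-rotary factor. A back-of-the-envelope sketch (standard RoPE inv_freq; both caches sized to the sliding head_dim=256 as the PR currently does, with the rope_type=proportional scaling omitted):

```python
# Standard RoPE inverse frequencies; proportional scaling omitted for brevity.

def inv_freq(theta, head_dim, partial_rotary_factor):
    rotary_dim = int(head_dim * partial_rotary_factor)
    return [theta ** (-2.0 * i / rotary_dim) for i in range(rotary_dim // 2)]

sliding = inv_freq(theta=10_000, head_dim=256, partial_rotary_factor=1.0)
full = inv_freq(theta=1_000_000, head_dim=256, partial_rotary_factor=0.25)

print(len(sliding))   # 128 frequencies (all dims rotate)
print(len(full))      # 32  (only a quarter of the dims rotate)
```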

Request for maintainer review

Three specific things I'd welcome feedback on before taking Parts 2/3 further:

  1. PR decomposition. Happy to split Parts 2/3/4 into separate PRs (runtime PLE, runtime KV sharing, runtime per-layer head_size) or keep them together — whichever matches your preference.
  2. Config schema. For per-layer head_size, do you prefer (a) a head_size list in genai_config.json, (b) a head_size + global_head_size pair plus a layer_types-like pattern, or (c) keeping both values in the model metadata only and teaching kv_cache.cpp to read them from the ONNX graph? I have mild preference for (b) for symmetry with how is_local already works.
  3. RoPE caches at different head dims. Currently the builder emits both RoPE caches sized to the sliding head_dim. Once per-layer head_size support lands in the runtime, the builder should emit the full-attention RoPE cache at global_head_dim — open to suggestions on the cleanest way to thread that through make_rotary_embedding_caches without duplicating logic.
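
For reference, option (b) from item 2 might look like this in genai_config.json. The global_head_size and layer_types fields are hypothetical — the current schema carries only a single head_size:

```json
{
  "model": {
    "decoder": {
      "head_size": 256,
      "global_head_size": 512,
      "layer_types": [
        "sliding_attention", "sliding_attention", "sliding_attention",
        "sliding_attention", "full_attention", "..."
      ]
    }
  }
}
```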

Scaffolds Gemma 4 support in the Python model builder. This is intentionally
Part 1 of N: it wires the dispatcher, introduces `Gemma4TextModel`, and
implements the Gemma 4-specific builder hooks that can be handled entirely
on the builder side. Features that require C++ runtime changes (per-layer
head_size, Per-Layer Embeddings, KV cache sharing) raise a clear
`NotImplementedError` unless `config_only=true`.

What this change covers
-----------------------

* `builders/gemma4.py` (new): `Gemma4TextModel(Gemma3Model)` with:
  - `is_local(layer_id)` driven by `config.layer_types` instead of
    Gemma 3's hard-coded every-6th rule. Gemma 4 exposes the full list
    (28x sliding + 7x full on E2B).
  - `make_key_value_cache_shape(layer_id, shape)` overridden to emit
    `head_dim=256` for sliding layers and `global_head_dim=512` for
    full-attention layers via the existing per-layer extension point.
  - `make_rotary_embedding_multi_cache()` that builds both RoPE caches
    (sliding: theta=10_000, partial_rotary=1.0; full: theta=1_000_000,
    partial_rotary=0.25, rope_type=proportional) from
    `config.rope_parameters`.
  - Explicit unsupported-feature gate that raises `NotImplementedError`
    with an actionable message when PLE, KV cache sharing, or variable
    head dim are actually required (i.e. not in `config_only` mode).

* `builder.py` / `builders/__init__.py`: dispatch for
  `Gemma4ForConditionalGeneration` and `Gemma4ForCausalLM`. Flattens
  `text_config` onto the top-level config the way the Gemma 3 branch
  does, sets `exclude_embeds=true`, stamps `model_type="gemma4_text"`.

* `test/python/test_gemma4_builder.py` (new): 10 unit tests for the
  Gemma 4-specific overrides. Uses a stubbed `Gemma3Model` parent so the
  tests run in a clean Python env without torch/transformers/onnxruntime.

What this change does NOT cover
-------------------------------

These are explicitly out of scope for Part 1 because each requires a
paired C++ runtime change:

1. Per-Layer Embeddings (PLE) routing. Gemma 4's `embed_tokens` emits
   `per_layer_inputs` [B, S, num_hidden_layers, hidden_size_per_layer_input]
   alongside `inputs_embeds`. Each decoder layer consumes its own slice.
   Needs a new side-channel input both in the ONNX graph and in the C++
   model loader.

2. KV cache sharing. Gemma 4 E2B has 35 layers but only 15 own unique
   KV caches (`num_kv_shared_layers=20`). `kv_cache.cpp` currently
   allocates one pair per layer.

3. Per-layer `head_size` in the C++ runtime. Even with this builder
   emitting correct per-layer past_kv shapes, `kv_cache.cpp` allocates
   using a single `model.decoder.head_size` from `genai_config.json`
   for all layers. The config schema and the runtime's KV cache manager
   need to grow per-layer awareness (or a `head_size` + `global_head_size`
   split driven by the same `layer_types` pattern).

The builder raises `NotImplementedError` listing exactly which of the
above gaps were tripped, so downstream consumers get an actionable
error rather than a broken graph.

Verification
------------

`python test/python/test_gemma4_builder.py -v` passes 10/10:
  - is_local correctly reflects `layer_types` for every layer 0..34
  - make_key_value_cache_shape emits 512 for layers 4/9/14/19/24/29/34
    and 256 for every other layer
  - The unsupported-feature gate raises for PLE, KV sharing, and variable
    head dim in isolation
  - config_only mode bypasses the gate cleanly
  - Both RoPE caches are emitted with the correct theta and partial
    rotary factor

References
----------

* Issue: microsoft#2062
* Gemma 4 E2B config:
  https://huggingface.co/google/gemma-4-E2B-it/blob/main/config.json
* Community ONNX export (architectural reference only; uses an
  incompatible I/O contract): https://huggingface.co/onnx-community/gemma-4-E2B-it-ONNX

Credit to @elbruno for the original architectural analysis in microsoft#2062.

Made-with: Cursor
@jaburges
Author

@microsoft-github-policy-service agree
