Add HunYuan Dense V1 (hunyuan_v1_dense) model support#2045
amdrajeevp1 wants to merge 6 commits into microsoft:main
Conversation
Adds builder and runtime support for tencent/HY-MT series models (`HunYuanDenseV1ForCausalLM`). Key implementation details:

- New `HunyuanDenseV1Model` builder (`src/python/py/models/builders/hunyuan.py`) that overrides `make_attention_qk_subgraph` to apply QK norms AFTER RoPE, matching the model's architecture (the standard base class applies QK norm before RoPE).
- Dynamic NTK-alpha RoPE scaling is baked into a static `rope_theta` at export time: `effective_theta = rope_theta * alpha^(head_dim/(head_dim-2))`.
- Forces `disable_qkv_fusion` and `use_rope_in_attn=False` to enable the separate Q/K paths required for post-RoPE QK norm insertion.
- Registers `"hunyuan_v1_dense"` in the LLM array in `model_type.h` (size 21 -> 22).
- Includes an example inference script using the HF tokenizer for correct special-token handling.

Supports the 1.8B and 7B variants of tencent/HY-MT1.5 (same architecture class). Requires transformers>=4.57 for native HunYuanDenseV1 support.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
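The export-time theta baking described above can be sketched as a small helper. This is an illustrative sketch of the formula from the PR description, not the PR's actual code; the function name `bake_ntk_theta` is hypothetical:

```python
def bake_ntk_theta(rope_theta: float, alpha: float, head_dim: int) -> float:
    """Fold dynamic NTK-alpha RoPE scaling into a static base at export time.

    effective_theta = rope_theta * alpha ** (head_dim / (head_dim - 2))
    With alpha == 1.0 the base is left unchanged.
    """
    return rope_theta * alpha ** (head_dim / (head_dim - 2))

# alpha = 1.0 is the degenerate case: theta stays at its configured value
print(bake_ntk_theta(10000.0, 1.0, 128))  # 10000.0
```

Because the scaled theta is a constant for a given config, it can be written into the exported graph once, so no dynamic scaling logic is needed at runtime.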
| if "position_ids" not in self.input_names: | ||
| self.input_names.append("position_ids") | ||
|
|
||
| self.model_type = "hunyuan_v1_dense" |
Check warning
Code scanning / CodeQL
Overwriting attribute in super-class or sub-class Warning
Pull request overview
Adds ONNX export-time builder support (Python) and runtime model-type registration (C++) for Tencent HY-MT / HunYuan Dense V1 models (hunyuan_v1_dense), including post-RoPE QK norm ordering and NTK-alpha RoPE theta baking.
Changes:
- Introduces the `HunyuanDenseV1Model` builder with post-RoPE Q/K LayerNorm placement and export-time RoPE theta adjustment.
- Wires the new builder into the Python model factory and builders package exports.
- Registers `hunyuan_v1_dense` as an LLM type in the C++ runtime and adds a Python example script.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| `src/python/py/models/builders/hunyuan.py` | New builder implementing HunYuan Dense V1 architecture quirks (post-RoPE QK norm, baked RoPE theta, fusion toggles). |
| `src/python/py/models/builders/__init__.py` | Exposes the new builder via package imports/`__all__`. |
| `src/python/py/models/builder.py` | Routes HF architecture `HunYuanDenseV1ForCausalLM` to the new builder. |
| `src/models/model_type.h` | Registers `hunyuan_v1_dense` as an LLM model type for runtime classification. |
| `examples/python/test_hy_mt.py` | Adds a minimal inference example using an HF tokenizer for prompt encoding. |
| attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" | ||
| self.make_attention_op( | ||
| attn_name, | ||
| q_path=self.attention_attrs["q_path"], | ||
| k_path=self.attention_attrs["k_path"], | ||
| v_path=self.attention_attrs["v_path"], | ||
| past_k=past_k, | ||
| past_v=past_v, | ||
| present_k=present_k, | ||
| present_v=present_v, | ||
| cos_cache=cos_cache_name, | ||
| sin_cache=sin_cache_name, | ||
| sinks=sinks_name, | ||
| **kwargs, |
make_attention_op is called without root_input, which is required when the selected attention op type is the packed Attention kernel. Also, this override assumes q_path/k_path/v_path exist, which is not true when use_matmul_in_attn is enabled for packed Attention. Either explicitly force this model off the packed Attention path (e.g., ensure op_type is MultiHeadAttention/GroupQueryAttention) or handle the packed Attention case separately and pass the required kwargs.
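The first of the reviewer's two suggested fixes, forcing the model off the packed Attention path, could look roughly like this. A sketch only: the op-type strings mirror the ONNX Runtime contrib ops the builder chooses between, and the helper name is hypothetical:

```python
def force_unpacked_attention(attention_attrs: dict) -> dict:
    # If the builder selected the packed Attention kernel, fall back to
    # MultiHeadAttention so separate q_path/k_path/v_path exist and
    # post-RoPE QK norms can be inserted between RoPE and the attention op.
    if attention_attrs.get("op_type") == "Attention":
        attention_attrs = {**attention_attrs, "op_type": "MultiHeadAttention"}
    return attention_attrs

attrs = force_unpacked_attention({"op_type": "Attention"})
print(attrs["op_type"])  # MultiHeadAttention
```

The alternative the comment mentions, handling the packed case separately and passing `root_input` through, would keep the packed kernel available but requires branching inside the override.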
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
```python
# Disable QKV fusion so separate q_path/k_path/v_path are created in
# make_attention_input_proj, required so our override can apply QK norms
# on individual Q and K paths after RoPE.
extra_options = {**extra_options, "disable_qkv_fusion": True}
```
Can we find a way to auto-infer this information rather than setting any extra options? For example, we can turn on q_norm and k_norm since there is a QK norm subgraph instead of explicitly setting any extra options.
onnxruntime-genai/src/python/py/models/builders/base.py
Lines 516 to 526 in 537b676
```python
# GQA fuses RoPE inside the attention op (use_rope_in_attn=True) which makes
# it impossible to insert QK norms between RoPE output and the attention op.
# Force explicit RotaryEmbedding nodes so our override can place QK norms after them.
if self.attention_attrs.get("use_rope_in_attn", False):
```
Can we just set use_rope_in_attn to false before we run make_attention_init instead?
| if "position_ids" not in self.input_names: | ||
| self.input_names["position_ids"] = "position_ids" | ||
|
|
||
| self.model_type = "hunyuan_v1_dense" |
Can we use the auto-generated model type and remove this? From the config, it seems it will be hunyuandensev1.
```python
# Step 4: Sinks (attention sink tokens, rarely used)
sinks_name = ""
if self.attention_attrs["sinks"]:
```
This was for GPT-OSS and is not needed here.
```python
The query_layernorm / key_layernorm weight attributes are aliased to
q_norm / k_norm so the existing make_qk_norm() infrastructure can be reused.
"""
```
Instead of rewriting the entire method, can we instead rewrite the following
onnxruntime-genai/src/python/py/models/builders/base.py
Lines 2965 to 2989 in 537b676
into something such as
```python
def make_attention_qk_norm(self, layer_id, attention):
    # Make Q/K SimplifiedLayerNorm nodes
    if self.attention_attrs["q_norm"] and self.attention_attrs["k_norm"]:
        self.make_qk_norm(layer_id, attention)

def make_attention_qk_rope(self, layer_id, **kwargs):
    # Make RotaryEmbedding nodes
    cos_cache_name, sin_cache_name = "", ""
    if self.attention_attrs["rope"]:
        if self.attention_attrs["use_rope_in_attn"]:
            cos_cache_name, sin_cache_name = self.make_rotary_embedding_caches()
        else:
            q_rotary_name = f"/model/layers.{layer_id}/attn/q_rotary/RotaryEmbedding"
            self.make_rotary_embedding(
                q_rotary_name,
                root_input=self.attention_attrs["q_path"],
                position_ids=kwargs.get("position_ids", self.input_names["position_ids"]),
            )
            self.attention_attrs["q_path"] = f"{q_rotary_name}/output_0"

            k_rotary_name = f"/model/layers.{layer_id}/attn/k_rotary/RotaryEmbedding"
            self.make_rotary_embedding(
                k_rotary_name,
                root_input=self.attention_attrs["k_path"],
                position_ids=kwargs.get("position_ids", self.input_names["position_ids"]),
            )
            self.attention_attrs["k_path"] = f"{k_rotary_name}/output_0"

def make_attention_qk_rope_and_norm(self, layer_id, attention, **kwargs):
    self.make_attention_qk_norm(layer_id, attention)
    self.make_attention_qk_rope(layer_id, **kwargs)

def make_attention_qk_subgraph(self, layer_id, attention, root_input, **kwargs):
    self.make_attention_qk_rope_and_norm(layer_id, attention, **kwargs)
    ...
```
and just override the `make_attention_qk_rope_and_norm` method instead?
Summary
Adds model builder and runtime support for Tencent's HY-MT series (architecture class: `HunYuanDenseV1ForCausalLM`), covering both the 1.8B and 7B parameter variants.
Changes
- `src/python/py/models/builders/hunyuan.py`: new `HunyuanDenseV1Model` builder subclassing the base `Model`. Key overrides: `make_attention_qk_subgraph` applies QK norms (query/key LayerNorm) after RoPE, the correct order for this architecture (the base class applies them before); bakes NTK-alpha scaling into `rope_theta` at export time (`effective_theta = rope_theta * alpha^(head_dim/(head_dim-2))`), then clears `rope_scaling` so the standard RoPE codepath is used; forces `disable_qkv_fusion=True` and `use_rope_in_attn=False` to create the separate Q/K paths required for post-RoPE QK norm insertion.
- `src/models/model_type.h`: registers `"hunyuan_v1_dense"` in the LLM model type array (21 -> 22 entries).
- `src/python/py/models/builder.py` and `builders/__init__.py`: wire `HunyuanDenseV1Model` under the `hunyuan_v1_dense` model type key.
- `examples/python/test_hy_mt.py`: example inference script using the HF tokenizer for correct special-token handling.
Architecture Notes
HunYuan Dense V1 differs from Llama-style models in two ways:
- QK norms are applied after the rotary position embedding, not before.
- Dynamic NTK-alpha RoPE scaling is folded into a static `rope_theta` at export time to avoid runtime overhead.
All weight names are standard (no custom mapping needed).
Requirement: `transformers >= 4.57` for `HunYuanDenseV1ForCausalLM`.