diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py
index ac6c112c08c..be426fa984d 100644
--- a/examples/models/qwen3_5_moe/export.py
+++ b/examples/models/qwen3_5_moe/export.py
@@ -554,6 +554,30 @@ def export_and_lower(model, config, args):
         _export_cuda(model, config, args)
 
 
+def _strip_sampler_from_forward(model):
+    """Bind ``model.forward`` to a minimal ``(tokens, input_pos) -> logits``
+    variant for non-CUDA export.
+
+    The default ``Qwen35MoE.forward`` carries an optional temperature input and
+    a sampling branch used only by the on-device CUDA sampler; non-CUDA
+    backends sample on the host, so that branch is dead code at trace time.
+    Even when statically eliminated, the extra parameter and branch perturb
+    the program ``torch.export`` produces enough to shift kernel selection in
+    the lowered MLX/Metal graph and slow execution by 10-30%. Eager callers
+    and the CUDA export path are unaffected.
+    """
+    import types
+
+    def _clean_forward(self, tokens, input_pos):
+        x = self.embed_tokens(tokens)
+        for layer in self.layers:
+            x = layer(x, input_pos)
+        x = self.norm(x)
+        return self.lm_head(x)
+
+    model.forward = types.MethodType(_clean_forward, model)
+
+
 def _export_mlx(model, config, args):
     """Export model to .pte via torch.export + MLX backend."""
     import gc
@@ -568,6 +592,8 @@ def _export_mlx(model, config, args):
     from executorch.exir.passes import MemoryPlanningPass
     from torch.export import Dim, export
 
+    _strip_sampler_from_forward(model)
+
     example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
     example_input_pos = torch.tensor([0, 1], dtype=torch.long)
     seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
@@ -650,6 +676,7 @@ def _export_metal(model, config, args):
     inductor_config.coordinate_descent_tuning = False
     inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
 
+    _strip_sampler_from_forward(model)
 
     # --- Decode method (T=1, static shape) ---
     print("Exporting decode method...")
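The rebound forward is only safe if it stays numerically interchangeable with the `temperature=None` path of the full `forward`. A minimal pre-export sanity check is sketched below; `check_strip_equivalence`, the toy vocab size, and the 4-token prompt are hypothetical (not part of this patch), and the model is deep-copied so cache state mutated by the reference call cannot leak into the stripped call:

```python
import copy

import torch

# Hypothetical helper; assumes _strip_sampler_from_forward (defined above in
# export.py) and a constructed Qwen35MoE `model` are in scope.
def check_strip_equivalence(model, vocab_size: int = 128) -> None:
    tokens = torch.randint(0, vocab_size, (1, 4), dtype=torch.long)
    input_pos = torch.arange(4, dtype=torch.long)

    # Deep-copy first so KV-cache / recurrent-state mutations from the
    # reference call cannot influence the stripped call.
    stripped = copy.deepcopy(model)
    _strip_sampler_from_forward(stripped)

    with torch.no_grad():
        ref = model(tokens, input_pos)  # full forward, temperature=None branch
        out = stripped(tokens, input_pos)  # rebound (tokens, input_pos) -> logits

    torch.testing.assert_close(out, ref)  # both [B, T, V] in model dtype
```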
diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp
index c5024890645..e642cc585fb 100644
--- a/examples/models/qwen3_5_moe/main.cpp
+++ b/examples/models/qwen3_5_moe/main.cpp
@@ -25,6 +25,8 @@
 
 #ifdef EXECUTORCH_BUILD_CUDA
 #include <cuda_runtime.h>
+#else
+#include <executorch/extension/llm/runner/util.h>  // llm::logits_to_token
 #endif
 
 DEFINE_string(model_path, "", "Model .pte file path.");
@@ -37,7 +39,10 @@ DEFINE_string(
     "Path to file containing prompt text (overrides --prompt).");
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
-DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");
+DEFINE_bool(
+    cuda_graph,
+    false,
+    "Enable CUDA graph for decode method. CUDA only.");
 
 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -48,10 +53,18 @@
 using ::executorch::runtime::EValue;
 
 using SizesType = executorch::aten::SizesType;
 
-// Read a sampled token from the model output tensor [B, 1].
-// The model performs Gumbel-max sampling on-device and returns a single
-// float token ID. This function copies it from GPU and casts to uint64.
+// Convert a model output tensor to the next sampled token id.
+//
+// On the CUDA build, the model fuses the sampler in (see sampler.py /
+// Qwen35MoE.forward) and returns a single sampled token id as a [B, 1]
+// float tensor; we just copy that scalar back from device.
+//
+// On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits
+// of shape [B, T, V] in the model dtype (typically bf16). We sample on
+// CPU via the shared `llm::logits_to_token` helper, which accepts a
+// temperature (0 = greedy / argmax).
 static uint64_t read_token(const executorch::aten::Tensor& output) {
+#ifdef EXECUTORCH_BUILD_CUDA
  const void* ptr = output.const_data_ptr();
  cudaPointerAttributes attrs;
@@ -73,6 +86,13 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
    memcpy(&val, ptr, sizeof(float));
  }
  return static_cast<uint64_t>(val);
+#else
+  // logits_to_token handles 2D / 3D logits and Float / Half / BFloat16 /
+  // UInt16 dtypes. Negative temperatures are clamped to 0 (greedy).
+  const float temp =
+      FLAGS_temperature <= 0.0 ? 0.0f : static_cast<float>(FLAGS_temperature);
+  return static_cast<uint64_t>(llm::logits_to_token(output, temp));
+#endif
 }
 
 int main(int argc, char** argv) {
@@ -133,6 +153,7 @@
   }
   auto metadata = metadata_result.get();
 
+#ifdef EXECUTORCH_BUILD_CUDA
   // Set CUDA graph option if requested (must be before load_method)
   if (FLAGS_cuda_graph) {
     executorch::runtime::BackendOptions<2> cuda_opts;
@@ -140,9 +161,15 @@
     executorch::runtime::set_option("CudaBackend", cuda_opts.view());
     printf("CUDA graph enabled for decode method\n");
   }
+#else
+  if (FLAGS_cuda_graph) {
+    ET_LOG(Info, "--cuda_graph ignored on non-CUDA build");
+  }
+#endif
 
   printf("Loading methods...\n");
 
+#ifdef EXECUTORCH_BUILD_CUDA
   // Enable cross-method per-FQN weight sharing in the CUDA backend so that
   // prefill and decode (which share KV cache and other mutable buffers /
   // weights) avoid duplicate GPU allocations. This is critical for fitting
@@ -170,6 +197,7 @@
       return 1;
     }
   }
+#endif
 
   auto err = module->load_method("prefill");
   if (err != Error::Ok) {
@@ -224,12 +252,16 @@
   // ---------------------------------------------------------------
   auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };
 
-  // Use a very small temperature for greedy to avoid division by zero
-  // while keeping the Gumbel noise negligible relative to logit differences.
+#ifdef EXECUTORCH_BUILD_CUDA
+  // CUDA build: model fuses the sampler in. Pass a temperature tensor as
+  // a third input. Use a very small temperature for greedy to avoid
+  // division by zero while keeping the Gumbel noise negligible relative
+  // to logit differences.
   float temp_val = FLAGS_temperature <= 0.0 ?
       1e-6f : static_cast<float>(FLAGS_temperature);
   auto temp_tensor =
       from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
+#endif
 
   // ---------------------------------------------------------------
   // Prefill
@@ -260,7 +292,9 @@
   std::vector<EValue> prefill_inputs;
   prefill_inputs.push_back(tokens_tensor);
   prefill_inputs.push_back(pos_tensor);
+#ifdef EXECUTORCH_BUILD_CUDA
   prefill_inputs.push_back(temp_tensor);
+#endif
 
   auto prefill_result = module->execute(run_method, prefill_inputs);
   if (prefill_result.error() != Error::Ok) {
@@ -308,7 +342,9 @@
   std::vector<EValue> decode_inputs;
   decode_inputs.push_back(EValue(decode_tokens));
   decode_inputs.push_back(EValue(decode_pos));
+#ifdef EXECUTORCH_BUILD_CUDA
   decode_inputs.push_back(EValue(temp_tensor));
+#endif
 
   auto decode_result = module->execute("decode", decode_inputs);
   if (decode_result.error() != Error::Ok) {
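For intuition, the non-CUDA `read_token` path above behaves like the following Python sketch. `host_sample` is a hypothetical stand-in for `llm::logits_to_token` (whose exact dtype handling is not shown in this patch); it assumes batch size 1 and mirrors the stated temperature semantics: zero or below means greedy argmax, otherwise a draw from `softmax(logits/T)` at the last position:

```python
import torch

# Hypothetical model of the non-CUDA read_token() path; assumes batch size 1.
def host_sample(logits: torch.Tensor, temperature: float) -> int:
    if logits.dim() == 3:  # [B, T, V] -> keep only the last position
        logits = logits[:, -1, :]
    logits = logits.float()  # upcast bf16/fp16 logits before softmax
    if temperature <= 0.0:  # zero (or clamped negative) temperature = greedy
        return int(torch.argmax(logits, dim=-1).item())
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())
```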
diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py
index 81c093f5652..a8d72febcf2 100644
--- a/examples/models/qwen3_5_moe/model.py
+++ b/examples/models/qwen3_5_moe/model.py
@@ -21,7 +21,6 @@
 
 import torch
 import torch.nn as nn
-
 from executorch.examples.models.qwen3_5_moe.sampler import sample
 from torch.nn import functional as F
 
@@ -186,7 +185,6 @@ def _apply_rotary(x, cos, sin):
 
 
 class KVCache(nn.Module):
-
     def __init__(self, n_kv_heads, head_dim, max_seq_len):
         super().__init__()
         self.register_buffer(
@@ -207,7 +205,6 @@ def update(self, input_pos, k_val, v_val):
 
 
 class FullAttention(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.n_heads = config.num_attention_heads
@@ -318,7 +315,6 @@ def forward(self, x, input_pos):
 
 
 class GatedDeltaNet(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.num_k_heads = config.linear_num_key_heads
@@ -540,7 +536,6 @@ def forward(self, x):
 
 
 class SparseMoE(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.top_k = config.num_experts_per_tok
@@ -574,7 +569,6 @@ def forward(self, x):
 
 
 class Block(nn.Module):
-
     def __init__(self, config, layer_idx):
         super().__init__()
         self.layer_type = config.layer_types[layer_idx]
@@ -599,7 +593,6 @@ def forward(self, x, input_pos):
 
 
 class Qwen35MoE(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -620,12 +613,8 @@ def forward(
         for layer in self.layers:
             x = layer(x, input_pos)
         x = self.norm(x)
-        # When no sampling is requested, return the full ``[B, T, V]``
-        # logits so callers (eval, custom samplers) can inspect every
-        # position. Otherwise apply the prefill optimization and only
-        # materialize ``[B, V]`` for the last token.
         if temperature is None:
-            return self.lm_head(x).float()  # [B, T, V] float32
+            return self.lm_head(x)  # [B, T, V] in model dtype
         logits = self.lm_head(x[:, -1, :]).float()  # [B, V] float32
         # GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) is
         # equivalent to drawing from softmax(logits/T) but stays entirely
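The equivalence asserted in that last comment, that `argmax(logits/T + gumbel_noise)` is a draw from `softmax(logits/T)`, can be checked empirically. The following standalone sketch is illustrative only (its names are hypothetical and it is independent of `sampler.py`):

```python
import torch

def gumbel_max_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1).
    u = torch.rand_like(logits).clamp_min(1e-10)  # avoid log(0)
    gumbel = -torch.log(-torch.log(u))
    return torch.argmax(logits / temperature + gumbel, dim=-1)

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.0])
temperature = 0.8
draws = torch.stack(
    [gumbel_max_sample(logits, temperature) for _ in range(20_000)]
)
empirical = torch.bincount(draws, minlength=3).float() / draws.numel()
expected = torch.softmax(logits / temperature, dim=-1)
print(empirical)  # ~[0.73, 0.21, 0.06], matching softmax(logits/T)
print(expected)
```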