diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index d64612e02af..dd7f9de52d3 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from .annotate_avg_pool1d import AnnotateAvgPool1D +from .annotate_concat_requant import AnnotateConcatRequant from .annotate_quant_attrs import AnnotateQuantAttrs from .annotate_stack import AnnotateStack from .annotate_unbind import AnnotateUnbind @@ -60,6 +61,7 @@ __all__ = [ AnnotateAvgPool1D, + AnnotateConcatRequant, AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, diff --git a/backends/qualcomm/_passes/annotate_concat_requant.py b/backends/qualcomm/_passes/annotate_concat_requant.py new file mode 100644 index 00000000000..2bfdcd16b5b --- /dev/null +++ b/backends/qualcomm/_passes/annotate_concat_requant.py @@ -0,0 +1,97 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict + +import torch +from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops +from executorch.backends.qualcomm.utils.constants import ( + QCOM_DTYPE, + QCOM_ENCODING, + QCOM_QUANT_MAX, + QCOM_QUANT_MIN, + QCOM_REQUANTIZE, + QCOM_SCALE, + QCOM_ZERO_POINT, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import get_quant_attrs + + +EDGE_CAT_OPS = { + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.concat.default, +} + + +class AnnotateConcatRequant(ExportPass): + """ + Record explicit requantization needs for concat inputs whose concrete + post-calibration qparams do not match concat's output domain. + """ + + def __init__( + self, + edge_program: torch.export.ExportedProgram, + skip_advanced_requant: bool = False, + ): + super(AnnotateConcatRequant, self).__init__() + self.edge_program = edge_program + self.skip_advanced_requant = skip_advanced_requant + + def _is_requant_needed(self, src_attrs: Dict[str, Any], dst_attrs: Dict[str, Any]): + if self.skip_advanced_requant: + return src_attrs[QCOM_DTYPE] != dst_attrs[QCOM_DTYPE] + + return any( + src_attrs[attr] != dst_attrs[attr] + for attr in [ + QCOM_SCALE, + QCOM_ZERO_POINT, + QCOM_QUANT_MIN, + QCOM_QUANT_MAX, + QCOM_DTYPE, + ] + ) + + def _annotate_concat_input_requant(self, quant_node: torch.fx.Node) -> None: + cat_node = quant_node.args[0] + if cat_node.target not in EDGE_CAT_OPS: + return + + output_q_attrs = get_quant_attrs(self.edge_program, quant_node) + for input_node in cat_node.args[0]: + if input_node.target not in dq_ops: + continue + + source_q_node = input_node.args[0] + if source_q_node.target not in q_ops: + continue + + source_q_attrs = get_quant_attrs(self.edge_program, source_q_node) + if not self._is_requant_needed(source_q_attrs, output_q_attrs): + continue + + source_node = source_q_node.args[0] + if not isinstance(source_node, torch.fx.Node): + continue + + requant_attrs = output_q_attrs.copy() + requant_attrs[QCOM_ENCODING] = source_q_attrs[QCOM_ENCODING] + source_node.meta.setdefault(QCOM_REQUANTIZE, {}) + source_node.meta[QCOM_REQUANTIZE][cat_node.name] = requant_attrs + + def call(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if ( + node.target in q_ops + and isinstance(node.args[0], torch.fx.Node) + and node.args[0].target in EDGE_CAT_OPS + ): + self._annotate_concat_input_requant(node) + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index 6077d51b099..16f36eebb84 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -79,6 +79,21 @@ def _find_last_dq_nodes(self, node: torch.fx.node.Node) -> torch.fx.node.Node: return last_dq_nodes + def _is_requant_needed(self, src_attrs: Dict[str, Any], dst_attrs: Dict[str, Any]): + if self.skip_advanced_requant: + return src_attrs[QCOM_DTYPE] != dst_attrs[QCOM_DTYPE] + + return any( + src_attrs[attr] != dst_attrs[attr] + for attr in [ + QCOM_SCALE, + QCOM_ZERO_POINT, + QCOM_QUANT_MIN, + QCOM_QUANT_MAX, + QCOM_DTYPE, + ] + ) + def _annotate_requant(self, n): # Record requant attributes: # node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> .... @@ -96,28 +111,7 @@ def _annotate_requant(self, n): # that has multiple outputs that requires quant attributes. # Determine if requantization is needed based on configuration and attribute mismatch. - is_requant_needed = False - if self.skip_advanced_requant: - # In skip_advanced_requant mode, only consider requant if dtypes differ. - if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]: - is_requant_needed = True - else: - # In full requant mode, consider requant if any key attribute differs. - # This aims to improve accuracy by adjusting scale, zero_point, etc. - # Users can disable this if it causes regressions. - if any( - q_attrs[attr] != dq_attrs[attr] - for attr in [ - QCOM_SCALE, - QCOM_ZERO_POINT, - QCOM_QUANT_MIN, - QCOM_QUANT_MAX, - QCOM_DTYPE, - ] - ): - is_requant_needed = True - - if is_requant_needed: + if self._is_requant_needed(q_attrs, dq_attrs): dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] user_node = list(dq_node.users)[0] n.args[0].meta.setdefault(QCOM_REQUANTIZE, {}) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 71b36cd746a..80d504b48d6 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -10,6 +10,7 @@ from executorch.backends.qualcomm._passes import ( AnnotateAvgPool1D, + AnnotateConcatRequant, AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, @@ -99,6 +100,7 @@ def get_capture_program_passes(): default_passes_and_setting = [ (AnnotateAvgPool1D, True), (AnnotateQuantAttrs, True), + (AnnotateConcatRequant, True), (AnnotateStack, True), (AnnotateUnbind, True), (ConvertBmmToMatmul, False), diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 548d91219ef..f40e1c22ee6 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -61,6 +61,7 @@ def get_passes_dependency_for_capture_program(): """ from executorch.backends.qualcomm._passes import ( AnnotateAvgPool1D, + AnnotateConcatRequant, AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, @@ -89,6 +90,7 @@ def get_passes_dependency_for_capture_program(): return { AnnotateAvgPool1D: [RemoveRedundancy], + AnnotateConcatRequant: [AnnotateQuantAttrs], AnnotateQuantAttrs: [ ConvertBmmToMatmul, RecomposePixelUnshuffle, @@ -108,9 +110,15 @@ def get_passes_dependency_for_capture_program(): DecomposeTrunc: [RemoveRedundancy], ExpandBroadcastTensorShape: [FoldQDQ], FixedLinearKeepDim: [FoldQDQ], - FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind], + FoldQDQ: [ + AnnotateConcatRequant, + AnnotateQuantAttrs, + AnnotateStack, + AnnotateUnbind, + ], I64toI32: [RemoveRedundancy], LayoutTransform: [ + AnnotateConcatRequant, AnnotateQuantAttrs, ExpandBroadcastTensorShape, FixedLinearKeepDim, diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py index cd65d02c752..580b6177973 100644 --- a/backends/qualcomm/quantizer/annotators/htp_rules.py +++ b/backends/qualcomm/quantizer/annotators/htp_rules.py @@ -12,7 +12,6 @@ import executorch.backends.qualcomm.builders.qnn_constants as QnnConstants import torch - from executorch.backends.qualcomm.quantizer.observers.concat_observer import ( ConcatObserver, ) @@ -235,15 +234,16 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: return input_qspec_map, input_nodes = {}, node.args[0] - for input in input_nodes: - input_qspec = input.meta.get(Q_ANNOTATION_KEY, None) + for input_node in input_nodes: + assert isinstance(input_node, Node) + input_qspec = input_node.meta.get(Q_ANNOTATION_KEY, None) qspec = getattr(input_qspec, "output_qspec", None) - # keep shared qspec here for propagation the data range - # without introducing extra requantizations + # Preserve shared upstream qspecs, but derive concat's output domain + # from the merged output range to avoid clipping wider branches. if isinstance(qspec, SharedQuantizationSpec): - input_qspec_map[input] = SharedQuantizationSpec(input) + input_qspec_map[input_node] = SharedQuantizationSpec(input_node) else: - input_qspec_map[input] = quantization_config.input_activation + input_qspec_map[input_node] = quantization_config.input_activation output_qspec = QuantizationSpec( dtype=quantization_config.output_activation.dtype, @@ -251,15 +251,11 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: quant_max=quantization_config.output_activation.quant_max, quant_min=quantization_config.output_activation.quant_min, observer_or_fake_quant_ctr=ConcatObserver.with_args( - # we need to know the concat node in order to hack all the input observers' data range - # since deep copy of fake tensor (node.meta["val"]) is inhibited - # we could only ship grap & node name and perform postprocess inside observer currently - **{ - "node_name": node.name, - "graph": node.graph, - } + node_name=node.name, + graph=node.graph, ), ) + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=output_qspec, @@ -295,6 +291,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: @register_annotator( [ torch.ops.aten.split_with_sizes.default, + torch.ops.aten.split_with_sizes_copy.default, torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default, ], @@ -1203,14 +1200,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: [torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name ) class PixelShuffle(GeneralOpDef): - pass + @staticmethod + def annotate(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_share_out(node, quantization_config) @register_annotator( [torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name ) class PixelUnshuffle(GeneralOpDef): - pass + @staticmethod + def annotate(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_share_out(node, quantization_config) @register_annotator( diff --git a/backends/qualcomm/quantizer/annotators/lpai_rules.py b/backends/qualcomm/quantizer/annotators/lpai_rules.py index 60cebfcc5c0..167bdc7b140 100644 --- a/backends/qualcomm/quantizer/annotators/lpai_rules.py +++ b/backends/qualcomm/quantizer/annotators/lpai_rules.py @@ -11,7 +11,6 @@ import executorch.backends.qualcomm.builders.qnn_constants as QnnConstants import torch - from executorch.backends.qualcomm.quantizer.observers.concat_observer import ( ConcatObserver, ) @@ -181,15 +180,14 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: return input_qspec_map, input_nodes = {}, node.args[0] - for input in input_nodes: - input_qspec = input.meta.get(Q_ANNOTATION_KEY, None) + for input_node in input_nodes: + assert isinstance(input_node, Node) + input_qspec = input_node.meta.get(Q_ANNOTATION_KEY, None) qspec = getattr(input_qspec, "output_qspec", None) - # keep shared qspec here for propagation the data range - # without introducing extra requantizations if isinstance(qspec, SharedQuantizationSpec): - input_qspec_map[input] = SharedQuantizationSpec(input) + input_qspec_map[input_node] = SharedQuantizationSpec(input_node) else: - input_qspec_map[input] = quantization_config.input_activation + input_qspec_map[input_node] = quantization_config.input_activation output_qspec = QuantizationSpec( dtype=quantization_config.output_activation.dtype, @@ -197,15 +195,11 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: quant_max=quantization_config.output_activation.quant_max, quant_min=quantization_config.output_activation.quant_min, observer_or_fake_quant_ctr=ConcatObserver.with_args( - # we need to know the concat node in order to hack all the input observers' data range - # since deep copy of fake tensor (node.meta["val"]) is inhibited - # we could only ship grap & node name and perform postprocess inside observer currently - **{ - "node_name": node.name, - "graph": node.graph, - } + node_name=node.name, + graph=node.graph, ), ) + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=output_qspec, @@ -223,6 +217,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: @register_annotator( [ torch.ops.aten.split_with_sizes.default, + torch.ops.aten.split_with_sizes_copy.default, torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default, ], @@ -705,14 +700,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: [torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name ) class PixelShuffle(GeneralOpDef): - pass + @staticmethod + def annotate(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_share_out(node, quantization_config) @register_annotator( [torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name ) class PixelUnshuffle(GeneralOpDef): - pass + @staticmethod + def annotate(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_share_out(node, quantization_config) @register_annotator( diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py index 1f007628e61..e6fdaa0ddcf 100644 --- a/backends/qualcomm/tests/test_passes.py +++ b/backends/qualcomm/tests/test_passes.py @@ -2,17 +2,25 @@ import torch from executorch.backends.qualcomm._passes import ( + AnnotateConcatRequant, AnnotateQuantAttrs, ConvertBmmToMatmul, ConvertMhaToSha, FoldQDQ, InsertIOQDQ, + InsertRequantize, InsertReshapeForReduceOps, RemoveRedundancy, ) +from executorch.backends.qualcomm.builders.node_visitor import q_ops +from executorch.backends.qualcomm.quantizer.observers.concat_observer import ( + ConcatObserver, +) from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype +from executorch.backends.qualcomm.quantizer.rules import Q_ANNOTATION_KEY from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset -from executorch.backends.qualcomm.tests.models import TopKandIndex +from executorch.backends.qualcomm.tests.models import Cat2, TopKandIndex +from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE from executorch.backends.qualcomm.utils.utils import ( generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, @@ -22,19 +30,25 @@ from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY from executorch.exir.dialects._ops import ops as exir_ops from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec class TestPasses(unittest.TestCase): - def _build_quantized_graph(self): - """Build a quantized graph through AnnotateQuantAttrs + FoldQDQ.""" + def _build_qdq_graph(self, module=None, sample_input=None): + """Build a quantized edge graph before Qualcomm capture passes.""" - class AddModule(torch.nn.Module): - def forward(self, x): - return x + 1 + if module is None: + + class AddModule(torch.nn.Module): + def forward(self, x): + return x + 1 - module = AddModule().eval() - sample_input = (torch.randn(1, 4),) + module = AddModule() + if sample_input is None: + sample_input = (torch.randn(1, 4),) + + module = module.eval() exported = torch.export.export(module, sample_input, strict=True).module() quantizer = QnnQuantizer() quantizer.set_default_quant_config(quant_dtype=QuantDtype.use_8a8w) @@ -48,10 +62,35 @@ def forward(self, x): ep = edge_program.exported_program() gm = ep.graph_module + return gm, ep + + def _build_quantized_graph(self, module=None, sample_input=None): + """Build a quantized graph through AnnotateQuantAttrs + FoldQDQ.""" + + gm, ep = self._build_qdq_graph(module, sample_input) gm = AnnotateQuantAttrs(ep)(gm).graph_module gm = FoldQDQ(ep)(gm).graph_module return gm, ep + def _build_cat_qdq_graph(self): + sample_input = ( + torch.tensor([[[[-10.0, 2.0], [3.0, 4.0]]]]), + torch.tensor([[[[1.0, 3.0], [8.0, 10.0]]]]), + ) + return self._build_qdq_graph(Cat2(), sample_input) + + def _get_cat_output_q_and_input_qs(self, gm): + cat_output_q = next( + node + for node in gm.graph.nodes + if node.target in q_ops + and isinstance(node.args[0], torch.fx.Node) + and node.args[0].target == exir_ops.edge.aten.cat.default + ) + cat_node = cat_output_q.args[0] + input_qs = [input_node.args[0] for input_node in cat_node.args[0]] + return cat_output_q, cat_node, input_qs + def test_insert_io_qdq_handles_dequant_encoding(self): """InsertIOQDQ should not KeyError when a node with a dequantize encoding feeds the output node (e.g. pre-quantized LLM parameters).""" @@ -102,6 +141,128 @@ def test_insert_io_qdq_no_revisit(self): # one quantize (input) and one dequantize (output) = +2 nodes. self.assertEqual(node_count_after, node_count_before + 2) + def test_skip_requantize_for_matched_cat_inputs(self): + gm, ep = self._build_cat_qdq_graph() + _, cat_node, input_qs = self._get_cat_output_q_and_input_qs(gm) + + gm = AnnotateQuantAttrs(ep)(gm).graph_module + gm = AnnotateConcatRequant(ep)(gm).graph_module + + for input_q in input_qs: + self.assertNotIn(QCOM_REQUANTIZE, input_q.args[0].meta) + + gm = FoldQDQ(ep)(gm).graph_module + gm = InsertRequantize()(gm).graph_module + + cat_node = next( + n for n in gm.graph.nodes if n.target == exir_ops.edge.aten.cat.default + ) + cat_inputs = cat_node.args[0] + to_copy_target = exir_ops.edge.aten._to_copy.default + + self.assertNotEqual(cat_inputs[0].target, to_copy_target) + self.assertNotEqual(cat_inputs[1].target, to_copy_target) + + def test_insert_requantize_for_mismatched_first_cat_input(self): + gm, ep = self._build_cat_qdq_graph() + _, cat_node, input_qs = self._get_cat_output_q_and_input_qs(gm) + + first_input_q = input_qs[0] + first_input_q.args = ( + first_input_q.args[0], + 0.05, + 0, + 0, + 255, + torch.uint8, + ) + + gm = AnnotateQuantAttrs(ep)(gm).graph_module + gm = AnnotateConcatRequant(ep)(gm).graph_module + + first_input_source = first_input_q.args[0] + second_input_source = input_qs[1].args[0] + self.assertIn(QCOM_REQUANTIZE, first_input_source.meta) + self.assertIn(cat_node.name, first_input_source.meta[QCOM_REQUANTIZE]) + self.assertNotIn(QCOM_REQUANTIZE, second_input_source.meta) + + gm = FoldQDQ(ep)(gm).graph_module + gm = InsertRequantize()(gm).graph_module + + cat_node = next( + n for n in gm.graph.nodes if n.target == exir_ops.edge.aten.cat.default + ) + cat_inputs = cat_node.args[0] + to_copy_target = exir_ops.edge.aten._to_copy.default + + self.assertEqual(cat_inputs[0].target, to_copy_target) + self.assertNotEqual(cat_inputs[1].target, to_copy_target) + + def test_insert_requantize_for_mismatched_second_cat_input(self): + gm, ep = self._build_cat_qdq_graph() + _, cat_node, input_qs = self._get_cat_output_q_and_input_qs(gm) + + second_input_q = input_qs[1] + second_input_q.args = ( + second_input_q.args[0], + 0.05, + 0, + 0, + 255, + torch.uint8, + ) + + gm = AnnotateQuantAttrs(ep)(gm).graph_module + gm = AnnotateConcatRequant(ep)(gm).graph_module + + first_input_source = input_qs[0].args[0] + second_input_source = second_input_q.args[0] + self.assertNotIn(QCOM_REQUANTIZE, first_input_source.meta) + self.assertIn(QCOM_REQUANTIZE, second_input_source.meta) + self.assertIn(cat_node.name, second_input_source.meta[QCOM_REQUANTIZE]) + + gm = FoldQDQ(ep)(gm).graph_module + gm = InsertRequantize()(gm).graph_module + + cat_node = next( + n for n in gm.graph.nodes if n.target == exir_ops.edge.aten.cat.default + ) + cat_inputs = cat_node.args[0] + to_copy_target = exir_ops.edge.aten._to_copy.default + + self.assertNotEqual(cat_inputs[0].target, to_copy_target) + self.assertEqual(cat_inputs[1].target, to_copy_target) + + def test_cat_annotation_uses_concat_observer_output_qspec(self): + sample_input = ( + torch.randn(1, 1, 4, 4), + torch.randn(1, 1, 4, 4), + ) + exported = torch.export.export( + Cat2().eval(), sample_input, strict=True + ).module() + quantizer = QnnQuantizer() + quantizer.set_default_quant_config(quant_dtype=QuantDtype.use_8a8w) + prepared = prepare_pt2e(exported, quantizer) + + cat_node = next( + n for n in prepared.graph.nodes if n.target == torch.ops.aten.cat.default + ) + second_input_node = cat_node.args[0][1] + if second_input_node not in cat_node.meta[Q_ANNOTATION_KEY].input_qspec_map: + second_input_node = second_input_node.args[0] + + output_qspec = cat_node.meta[Q_ANNOTATION_KEY].output_qspec + self.assertNotIsInstance( + output_qspec, + SharedQuantizationSpec, + ) + self.assertIs(output_qspec.observer_or_fake_quant_ctr.p.func, ConcatObserver) + self.assertNotIsInstance( + cat_node.meta[Q_ANNOTATION_KEY].input_qspec_map[second_input_node], + SharedQuantizationSpec, + ) + def test_insert_reshape_for_argmax(self): class ArgmaxModule(torch.nn.Module): def forward(self, x): diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 8fc4fd4e6a1..07d04d88c95 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -8,6 +8,7 @@ import itertools import json import logging +import operator import subprocess import sys import tempfile @@ -33,11 +34,17 @@ make_quantizer, setup_common_args_and_variables, ) +from executorch.backends.qualcomm.quantizer.observers.concat_observer import ( + ConcatObserver, +) +from executorch.backends.qualcomm.quantizer.rules import Q_ANNOTATION_KEY from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, QnnExecuTorchHtpPerformanceMode, QnnExecuTorchLpaiTargetEnv, ) + +from executorch.backends.qualcomm.tests.models import Cat2 from executorch.backends.qualcomm.tests.utils import ( convert_pt2e, generate_context_binary, @@ -72,7 +79,6 @@ to_edge_transform_and_lower_to_qnn, update_spill_fill_size, ) - from executorch.backends.qualcomm.tests.models import * # noqa: F403 import os @@ -97,6 +103,7 @@ from executorch.examples.models.wav2letter import Wav2LetterModel from executorch.exir import to_edge from executorch.exir.backend.backend_api import disable_validation +from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec class TestQNNFloatingPointOperator(TestQNN): @@ -414,13 +421,6 @@ def test_qnn_backend_conv1d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) - def test_qnn_conv1d_batch_norm(self): - modules = [Conv1dBn(), Conv1dBn(bias=False)] # noqa: F405 - sample_input = (torch.randn([1, 2048, 858]),) - for i, module in enumerate(modules): - with self.subTest(i=i): - self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_conv2d(self): modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) @@ -1688,12 +1688,16 @@ def test_qnn_backend_permute(self): def test_qnn_backend_pixel_shuffle(self): module = PixelShuffle(2) # noqa: F405 - sample_input = (torch.ones([2, 4, 3, 3]),) + sample_input = ( + torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3), + ) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_pixel_unshuffle(self): module = PixelUnshuffle(2) # noqa: F405 - sample_input = (torch.ones([2, 2, 6, 6]),) + sample_input = ( + torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6), + ) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_pow_tensor_scalar(self): @@ -2799,6 +2803,20 @@ def test_qnn_backend_cat(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cat_wide_range_inputs(self): + module = self.get_qdq_module( + Cat2(), + ( + torch.tensor([[[[-10.0, 2.0], [3.0, 4.0]]]]), + torch.tensor([[[[1.0, 3.0], [8.0, 10.0]]]]), + ), + ) + sample_input = ( + torch.tensor([[[[-10.0, 2.0], [3.0, 4.0]]]]), + torch.tensor([[[[1.0, 3.0], [8.0, 10.0]]]]), + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cdist(self): module = CDist() # noqa: F405 sample_input = ( @@ -2836,14 +2854,6 @@ def test_qnn_backend_conv1d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - def test_qnn_conv1d_batch_norm(self): - modules = [Conv1dBn(), Conv1dBn(bias=False)] # noqa: F405 - sample_input = (torch.randn([1, 2048, 858]),) - for i, module in enumerate(modules): - with self.subTest(i=i): - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_conv2d(self): modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) @@ -4229,16 +4239,213 @@ def test_qnn_backend_permute(self): def test_qnn_backend_pixel_shuffle(self): module = PixelShuffle(2) # noqa: F405 - sample_input = (torch.ones([2, 4, 3, 3]),) + sample_input = ( + torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3), + ) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_pixel_unshuffle(self): module = PixelUnshuffle(2) # noqa: F405 - sample_input = (torch.ones([2, 2, 6, 6]),) + sample_input = ( + torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6), + ) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def _prepare_module_for_qparam_assertions(self, module, sample_input): + backend = get_backend_type(self.backend) + quantizer = make_quantizer( + quant_dtype=QuantDtype.use_8a8w, + custom_annotations=(), + per_channel_conv=True, + per_channel_linear=False, + per_channel_embedding=False, + backend=backend, + soc_model=self.soc_model, + ) + return prepare_pt2e( + torch.export.export(module, sample_input, strict=True).module(), + quantizer, + ) + + def _assert_prepared_nodes_share_qparams( + self, module, sample_input, target_tokens + ) -> list[torch.fx.Node]: + prepared = self._prepare_module_for_qparam_assertions(module, sample_input) + matching_nodes = [ + node + for node in prepared.graph.nodes + if node.op == "call_function" + and any(target_token in str(node.target) for target_token in target_tokens) + ] + + self.assertGreater( + len(matching_nodes), + 0, + f"Failed to find node matching any of {target_tokens}", + ) + for node in matching_nodes: + self.assertIsInstance( + node.meta[Q_ANNOTATION_KEY].output_qspec, + SharedQuantizationSpec, + ) + + return matching_nodes + + def test_qnn_backend_pixel_shuffle_unshuffle_share_qparams(self): + test_cases = [ + ( + "pixel_shuffle", + PixelShuffle(2), # noqa: F405 + (torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),), + torch.ops.aten.pixel_shuffle.default, + ), + ( + "pixel_unshuffle", + PixelUnshuffle(2), # noqa: F405 + (torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),), + torch.ops.aten.pixel_unshuffle.default, + ), + ] + + for name, module, sample_input, target in test_cases: + with self.subTest(name=name): + prepared = self._prepare_module_for_qparam_assertions( + module, sample_input + ) + for node in prepared.graph.nodes: + if node.op == "call_function" and node.target == target: + self.assertIsInstance( + node.meta[Q_ANNOTATION_KEY].output_qspec, + SharedQuantizationSpec, + ) + break + else: + self.fail(f"Failed to find {target} in prepared graph") + + def test_qnn_backend_value_preserving_ops_share_qparams(self): + test_cases = [ + ( + "channel_shuffle", + ChannelShuffle(2), # noqa: F405 + (torch.randn(1, 4, 3, 3),), + ("aten.channel_shuffle",), + ), + ( + "permute", + Permute([0, 2, 3, 1]), # noqa: F405 + (torch.randn(2, 3, 4, 5),), + ("aten.permute",), + ), + ( + "pixel_shuffle", + PixelShuffle(2), # noqa: F405 + (torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),), + ("aten.pixel_shuffle",), + ), + ( + "pixel_unshuffle", + PixelUnshuffle(2), # noqa: F405 + (torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),), + ("aten.pixel_unshuffle",), + ), + ( + "repeat", + Repeat(), # noqa: F405 + (torch.randn(2, 2, 2, 2),), + ("aten.repeat",), + ), + ( + "expand_as", + ExpandAs(), # noqa: F405 + (torch.randn(3, 4),), + ("aten.expand",), + ), + ( + "reshape", + Reshape(), # noqa: F405 + (torch.randn(3, 4),), + ("aten.reshape", "aten.view"), + ), + ] + + for name, module, sample_input, target_tokens in test_cases: + with self.subTest(name=name): + self._assert_prepared_nodes_share_qparams( + module, sample_input, target_tokens + ) + + def test_qnn_backend_split_with_sizes_copy_share_qparams(self): + class SplitWithSizesCopy(torch.nn.Module): + def forward(self, x): + out = torch.ops.aten.split_with_sizes_copy.default(x, [2, 2], 1) + return out[0] + out[1] + + backend = get_backend_type(self.backend) + sample_input = ( + torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3), + ) + quantizer = make_quantizer( + quant_dtype=QuantDtype.use_8a8w, + custom_annotations=(), + per_channel_conv=True, + per_channel_linear=False, + per_channel_embedding=False, + backend=backend, + soc_model=self.soc_model, + ) + prepared = prepare_pt2e( + torch.export.export( + SplitWithSizesCopy(), sample_input, strict=True + ).module(), + quantizer, + ) + + getitem_count = 0 + for node in prepared.graph.nodes: + if ( + node.op == "call_function" + and node.target == operator.getitem + and node.args[0].target == torch.ops.aten.split_with_sizes_copy.default + ): + self.assertIsInstance( + node.meta[Q_ANNOTATION_KEY].output_qspec, + SharedQuantizationSpec, + ) + getitem_count += 1 + + self.assertGreater(getitem_count, 0) + + def test_qnn_backend_cat_uses_concat_observer_output_qspec(self): + sample_input = ( + torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5), + torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5), + ) + prepared = self._prepare_module_for_qparam_assertions(Cat2(), sample_input) + + for node in prepared.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.cat.default: + output_qspec = node.meta[Q_ANNOTATION_KEY].output_qspec + self.assertNotIsInstance( + output_qspec, + SharedQuantizationSpec, + ) + self.assertIs( + output_qspec.observer_or_fake_quant_ctr.p.func, + ConcatObserver, + ) + second_input_node = node.args[0][1] + if second_input_node not in node.meta[Q_ANNOTATION_KEY].input_qspec_map: + second_input_node = second_input_node.args[0] + self.assertNotIsInstance( + node.meta[Q_ANNOTATION_KEY].input_qspec_map[second_input_node], + SharedQuantizationSpec, + ) + break + else: + self.fail("Failed to find aten.cat.default in prepared graph") + def test_qnn_backend_pow_tensor_scalar(self): test_comb = [ { @@ -7371,11 +7578,6 @@ class MLLMSpecs: tok_embedding_pte_size: float decoder_pte_size: float - @dataclass(frozen=True) - class ALMSpecs(MLLMSpecs): - audio_path: str - golden_audio_feature: str - @dataclass(frozen=True) class VLMSpecs(MLLMSpecs): image_path: str @@ -7383,18 +7585,6 @@ class VLMSpecs(MLLMSpecs): # TODO: refactor to support different backends def setUp(self): - self.alm_specs = { - "granite_speech_3_3-2b": TestExampleMultimodalityScript.ALMSpecs( - max_seq_len=512, - sm8650_token_rate=5, - sm8750_token_rate=8, - encoder_pte_size=900_000_000, # 900MB - tok_embedding_pte_size=240_000_000, # 240MB - decoder_pte_size=3_000_000_000, # 3GB - audio_path="https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true", # Audio content: after his nap,... - golden_audio_feature="after his nap,", - ), - } self.vlm_specs = { "smolvlm_500m_instruct": TestExampleMultimodalityScript.VLMSpecs( max_seq_len=128, @@ -7418,96 +7608,6 @@ def setUp(self): ), } - def test_static_asr(self): - if not self.required_envs([self.model_name]): - self.skipTest("missing required envs") - - if self.enable_x86_64: - # Running on host is extremely slow for large models, so we skip this check to avoid timeouts. - # Please verify the output on the actual device instead. - self.skipTest( - "Skipping the check for the static ASR model on x86 due to long execution time." - ) - - alm_specs: TestExampleMultimodalityScript.ALMSpecs = self.alm_specs[ - self.model_name - ] - prompt = "can you transcribe the speech into a written format?" - audio_path = alm_specs.audio_path - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--soc_model", - self.soc_model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - prompt, - "--audio_path", - audio_path, - "--temperature", - "0", - "--decoder_model", - f"{self.model_name}", - "--model_mode", - "kv", - "--max_seq_len", - f"{alm_specs.max_seq_len}", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - if not self.compile_only: - model_out = msg["result"][0] - self.assertTrue( - alm_specs.golden_audio_feature in model_out.lower(), - f"Expected Output contains feature: '{alm_specs.golden_audio_feature}' Actual Output: '{model_out}'", - ) - print(f"Audio Path: {audio_path}") - print(f"Query: {prompt}") - print(f"Answer: {model_out}") - - encoder_pte_size = msg["audio_encoder_pte_size"] - tok_embedding_pte_size = msg["tok_embedding_pte_size"] - decoder_pte_size = msg["pte_size"] - self.assertLessEqual(encoder_pte_size, alm_specs.encoder_pte_size) - self.assertLessEqual( - tok_embedding_pte_size, alm_specs.tok_embedding_pte_size - ) - self.assertLessEqual(decoder_pte_size, alm_specs.decoder_pte_size) - print(f"Encoder PTE Size: {encoder_pte_size} bytes") - print(f"Token Embedding PTE Size: {tok_embedding_pte_size} bytes") - print(f"Text Decoder PTE Size: {decoder_pte_size} bytes") - - attr_name = f"{self.soc_model.lower()}_token_rate" - if not self.compile_only and hasattr(alm_specs, attr_name): - device_inference_speed = msg["inference_speed"] - expected_inference_speed = getattr(alm_specs, attr_name) - print(f"Prompt Evaluation: {device_inference_speed} tokens/second") - self.assertGreaterEqual( - device_inference_speed, expected_inference_speed - ) - def test_static_vlm(self): if not self.required_envs([self.model_name]): self.skipTest("missing required envs") @@ -7572,7 +7672,7 @@ def test_static_vlm(self): print(f"Query: {prompt}") print(f"Answer: {model_out}") if not self.enable_x86_64: - encoder_pte_size = msg["vision_encoder_pte_size"] + encoder_pte_size = msg["encoder_pte_size"] tok_embedding_pte_size = msg["tok_embedding_pte_size"] decoder_pte_size = msg["pte_size"] self.assertLessEqual(encoder_pte_size, vlm_specs.encoder_pte_size) diff --git a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp index 5b531fb27c7..7bf2c6aac71 100644 --- a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp +++ b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp @@ -59,6 +59,11 @@ DEFINE_string( "etdump.etdp", "If etdump generation is enabled an etdump will be written out to this path"); +DEFINE_bool( + enable_etdump, + true, + "Enable ETDump event tracing. Disable for cleaner latency benchmarking."); + DEFINE_bool( dump_intermediate_outputs, false, @@ -385,8 +390,11 @@ int main(int argc, char** argv) { // be used by a single thread at at time, but it can be reused. // ETDumpGen etdump_gen; + auto* event_tracer = (FLAGS_enable_etdump || FLAGS_dump_intermediate_outputs) + ? &etdump_gen + : nullptr; Result method = - program->load_method(method_name, &memory_manager, &etdump_gen); + program->load_method(method_name, &memory_manager, event_tracer); ET_CHECK_MSG( method.ok(), "Loading of method %s failed with status 0x%" PRIx32, @@ -694,7 +702,7 @@ int main(int argc, char** argv) { // Dump the etdump data containing profiling/debugging data to the specified // file. ETDumpResult result = etdump_gen.get_etdump_data(); - if (result.buf != nullptr && result.size > 0) { + if (FLAGS_enable_etdump && result.buf != nullptr && result.size > 0) { ET_LOG( Info, "Write etdump to %s, Size = %zu",