Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from .annotate_avg_pool1d import AnnotateAvgPool1D
from .annotate_concat_requant import AnnotateConcatRequant
from .annotate_quant_attrs import AnnotateQuantAttrs
from .annotate_stack import AnnotateStack
from .annotate_unbind import AnnotateUnbind
Expand Down Expand Up @@ -60,6 +61,7 @@

__all__ = [
AnnotateAvgPool1D,
AnnotateConcatRequant,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
Expand Down
97 changes: 97 additions & 0 deletions backends/qualcomm/_passes/annotate_concat_requant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict

import torch
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
from executorch.backends.qualcomm.utils.constants import (
QCOM_DTYPE,
QCOM_ENCODING,
QCOM_QUANT_MAX,
QCOM_QUANT_MIN,
QCOM_REQUANTIZE,
QCOM_SCALE,
QCOM_ZERO_POINT,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import get_quant_attrs


EDGE_CAT_OPS = {
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.concat.default,
}


class AnnotateConcatRequant(ExportPass):
    """
    Record explicit requantization needs for concat inputs whose concrete
    post-calibration qparams do not match concat's output domain.

    For each quantize node consuming a concat, walk the concat's dequantize
    inputs back to their source quantize nodes; when a source's quant attrs
    differ from the concat output's, store the output attrs under
    QCOM_REQUANTIZE in the source node's meta (keyed by the concat node name)
    so later lowering stages can insert the requantize.
    """

    def __init__(
        self,
        edge_program: torch.export.ExportedProgram,
        skip_advanced_requant: bool = False,
    ):
        # Python 3 idiom: zero-argument super().
        super().__init__()
        self.edge_program = edge_program
        # When True, only a dtype mismatch triggers requantization.
        self.skip_advanced_requant = skip_advanced_requant

    def _is_requant_needed(
        self, src_attrs: Dict[str, Any], dst_attrs: Dict[str, Any]
    ) -> bool:
        """Return True when src/dst quant attrs differ enough to need requant."""
        if self.skip_advanced_requant:
            return src_attrs[QCOM_DTYPE] != dst_attrs[QCOM_DTYPE]

        return any(
            src_attrs[attr] != dst_attrs[attr]
            for attr in (
                QCOM_SCALE,
                QCOM_ZERO_POINT,
                QCOM_QUANT_MIN,
                QCOM_QUANT_MAX,
                QCOM_DTYPE,
            )
        )

    def _annotate_concat_input_requant(self, quant_node: torch.fx.Node) -> None:
        """Annotate requant attrs on sources feeding the concat behind quant_node."""
        cat_node = quant_node.args[0]
        if cat_node.target not in EDGE_CAT_OPS:
            return

        output_q_attrs = get_quant_attrs(self.edge_program, quant_node)
        # cat's first argument is the list of tensors being concatenated.
        for input_node in cat_node.args[0]:
            if input_node.target not in dq_ops:
                continue

            source_q_node = input_node.args[0]
            if source_q_node.target not in q_ops:
                continue

            source_q_attrs = get_quant_attrs(self.edge_program, source_q_node)
            if not self._is_requant_needed(source_q_attrs, output_q_attrs):
                continue

            source_node = source_q_node.args[0]
            if not isinstance(source_node, torch.fx.Node):
                continue

            # Requantize into the concat output's domain, but keep the source's
            # encoding so the quantization scheme stays consistent.
            requant_attrs = output_q_attrs.copy()
            requant_attrs[QCOM_ENCODING] = source_q_attrs[QCOM_ENCODING]
            # setdefault returns the (possibly fresh) dict, so record in one step.
            source_node.meta.setdefault(QCOM_REQUANTIZE, {})[
                cat_node.name
            ] = requant_attrs

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Scan for quantize(concat(...)) patterns and annotate their inputs."""
        for node in graph_module.graph.nodes:
            if (
                node.target in q_ops
                and isinstance(node.args[0], torch.fx.Node)
                and node.args[0].target in EDGE_CAT_OPS
            ):
                self._annotate_concat_input_requant(node)
        return PassResult(graph_module, True)
38 changes: 16 additions & 22 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ def _find_last_dq_nodes(self, node: torch.fx.node.Node) -> torch.fx.node.Node:

return last_dq_nodes

def _is_requant_needed(self, src_attrs: Dict[str, Any], dst_attrs: Dict[str, Any]):
    """Return True when src/dst quant attrs mismatch and requant is required.

    In skip_advanced_requant mode only a dtype mismatch counts; otherwise a
    difference in any of scale, zero-point, qmin, qmax, or dtype triggers
    requantization.
    """
    if self.skip_advanced_requant:
        compared_keys = (QCOM_DTYPE,)
    else:
        compared_keys = (
            QCOM_SCALE,
            QCOM_ZERO_POINT,
            QCOM_QUANT_MIN,
            QCOM_QUANT_MAX,
            QCOM_DTYPE,
        )
    return any(src_attrs[key] != dst_attrs[key] for key in compared_keys)

def _annotate_requant(self, n):
# Record requant attributes:
# node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
Expand All @@ -96,28 +111,7 @@ def _annotate_requant(self, n):
# that has multiple outputs that requires quant attributes.

# Determine if requantization is needed based on configuration and attribute mismatch.
is_requant_needed = False
if self.skip_advanced_requant:
# In skip_advanced_requant mode, only consider requant if dtypes differ.
if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
is_requant_needed = True
else:
# In full requant mode, consider requant if any key attribute differs.
# This aims to improve accuracy by adjusting scale, zero_point, etc.
# Users can disable this if it causes regressions.
if any(
q_attrs[attr] != dq_attrs[attr]
for attr in [
QCOM_SCALE,
QCOM_ZERO_POINT,
QCOM_QUANT_MIN,
QCOM_QUANT_MAX,
QCOM_DTYPE,
]
):
is_requant_needed = True

if is_requant_needed:
if self._is_requant_needed(q_attrs, dq_attrs):
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
user_node = list(dq_node.users)[0]
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from executorch.backends.qualcomm._passes import (
AnnotateAvgPool1D,
AnnotateConcatRequant,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
Expand Down Expand Up @@ -99,6 +100,7 @@ def get_capture_program_passes():
default_passes_and_setting = [
(AnnotateAvgPool1D, True),
(AnnotateQuantAttrs, True),
(AnnotateConcatRequant, True),
(AnnotateStack, True),
(AnnotateUnbind, True),
(ConvertBmmToMatmul, False),
Expand Down
10 changes: 9 additions & 1 deletion backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def get_passes_dependency_for_capture_program():
"""
from executorch.backends.qualcomm._passes import (
AnnotateAvgPool1D,
AnnotateConcatRequant,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
Expand Down Expand Up @@ -89,6 +90,7 @@ def get_passes_dependency_for_capture_program():

return {
AnnotateAvgPool1D: [RemoveRedundancy],
AnnotateConcatRequant: [AnnotateQuantAttrs],
AnnotateQuantAttrs: [
ConvertBmmToMatmul,
RecomposePixelUnshuffle,
Expand All @@ -108,9 +110,15 @@ def get_passes_dependency_for_capture_program():
DecomposeTrunc: [RemoveRedundancy],
ExpandBroadcastTensorShape: [FoldQDQ],
FixedLinearKeepDim: [FoldQDQ],
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],
FoldQDQ: [
AnnotateConcatRequant,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
],
I64toI32: [RemoveRedundancy],
LayoutTransform: [
AnnotateConcatRequant,
AnnotateQuantAttrs,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
Expand Down
37 changes: 21 additions & 16 deletions backends/qualcomm/quantizer/annotators/htp_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import executorch.backends.qualcomm.builders.qnn_constants as QnnConstants
import torch

from executorch.backends.qualcomm.quantizer.observers.concat_observer import (
ConcatObserver,
)
Expand Down Expand Up @@ -235,31 +234,28 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
return

input_qspec_map, input_nodes = {}, node.args[0]
for input in input_nodes:
input_qspec = input.meta.get(Q_ANNOTATION_KEY, None)
for input_node in input_nodes:
assert isinstance(input_node, Node)
input_qspec = input_node.meta.get(Q_ANNOTATION_KEY, None)
qspec = getattr(input_qspec, "output_qspec", None)
# keep shared qspec here for propagation the data range
# without introducing extra requantizations
# Preserve shared upstream qspecs, but derive concat's output domain
# from the merged output range to avoid clipping wider branches.
if isinstance(qspec, SharedQuantizationSpec):
input_qspec_map[input] = SharedQuantizationSpec(input)
input_qspec_map[input_node] = SharedQuantizationSpec(input_node)
else:
input_qspec_map[input] = quantization_config.input_activation
input_qspec_map[input_node] = quantization_config.input_activation

output_qspec = QuantizationSpec(
dtype=quantization_config.output_activation.dtype,
qscheme=quantization_config.output_activation.qscheme,
quant_max=quantization_config.output_activation.quant_max,
quant_min=quantization_config.output_activation.quant_min,
observer_or_fake_quant_ctr=ConcatObserver.with_args(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this change is reverting what this PR is doing: #15162.
The reason #15162 was introduced is that input[0] could not cover the entire range of values for the concat output, so a lot of output values were clipped.

If you have 2 input tensors like:
sample_input = ( torch.tensor([[[[-10.0, 2.0], [3.0, 4.0]]]]), torch.tensor([[[[1.0, 3.0], [8.0, 10.0]]]]), )
and after it goes through cat operation, you will be getting the wrong value with this PR.
[tensor([[[[-9.9798, 1.9849], [ 2.9774, 4.0250], [ 1.0476, 3.0325], [ 4.0802, 4.0802]]]])]
I have a demo PR to reproduce this error, please have a look:
#19182

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @winskuo-quic for the detailed review. I agree that it might have worked for this model but might not work when the ranges are skewed, as in your example. Let me revert the cat op to ConcatObserver and test the accuracy.

# we need to know the concat node in order to adjust all the input observers' data ranges;
# since deep copying the fake tensor (node.meta["val"]) is prohibited,
# we can currently only ship the graph & node name and perform postprocessing inside the observer
**{
"node_name": node.name,
"graph": node.graph,
}
node_name=node.name,
graph=node.graph,
),
)

node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_qspec,
Expand Down Expand Up @@ -295,6 +291,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
@register_annotator(
[
torch.ops.aten.split_with_sizes.default,
torch.ops.aten.split_with_sizes_copy.default,
torch.ops.aten.split.Tensor,
torch.ops.aten.chunk.default,
],
Expand Down Expand Up @@ -1203,14 +1200,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
[torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name
)
class PixelShuffle(GeneralOpDef):
    @staticmethod
    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
        # Try in/out observer sharing first; if the node still lacks an
        # annotation afterwards, apply the single-input shared-output fallback.
        annotate_in_out_obs_sharing_op(node, quantization_config)
        if _is_annotated([node]):
            return
        annotate_single_in_share_out(node, quantization_config)


@register_annotator(
[torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name
)
class PixelUnshuffle(GeneralOpDef):
    @staticmethod
    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
        # Try in/out observer sharing first; if the node still lacks an
        # annotation afterwards, apply the single-input shared-output fallback.
        annotate_in_out_obs_sharing_op(node, quantization_config)
        if _is_annotated([node]):
            return
        annotate_single_in_share_out(node, quantization_config)


@register_annotator(
Expand Down
35 changes: 19 additions & 16 deletions backends/qualcomm/quantizer/annotators/lpai_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import executorch.backends.qualcomm.builders.qnn_constants as QnnConstants
import torch

from executorch.backends.qualcomm.quantizer.observers.concat_observer import (
ConcatObserver,
)
Expand Down Expand Up @@ -181,31 +180,26 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
return

input_qspec_map, input_nodes = {}, node.args[0]
for input in input_nodes:
input_qspec = input.meta.get(Q_ANNOTATION_KEY, None)
for input_node in input_nodes:
assert isinstance(input_node, Node)
input_qspec = input_node.meta.get(Q_ANNOTATION_KEY, None)
qspec = getattr(input_qspec, "output_qspec", None)
# keep shared qspec here for propagation the data range
# without introducing extra requantizations
if isinstance(qspec, SharedQuantizationSpec):
input_qspec_map[input] = SharedQuantizationSpec(input)
input_qspec_map[input_node] = SharedQuantizationSpec(input_node)
else:
input_qspec_map[input] = quantization_config.input_activation
input_qspec_map[input_node] = quantization_config.input_activation

output_qspec = QuantizationSpec(
dtype=quantization_config.output_activation.dtype,
qscheme=quantization_config.output_activation.qscheme,
quant_max=quantization_config.output_activation.quant_max,
quant_min=quantization_config.output_activation.quant_min,
observer_or_fake_quant_ctr=ConcatObserver.with_args(
# we need to know the concat node in order to adjust all the input observers' data ranges;
# since deep copying the fake tensor (node.meta["val"]) is prohibited,
# we can currently only ship the graph & node name and perform postprocessing inside the observer
**{
"node_name": node.name,
"graph": node.graph,
}
node_name=node.name,
graph=node.graph,
),
)

node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_qspec,
Expand All @@ -223,6 +217,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
@register_annotator(
[
torch.ops.aten.split_with_sizes.default,
torch.ops.aten.split_with_sizes_copy.default,
torch.ops.aten.split.Tensor,
torch.ops.aten.chunk.default,
],
Expand Down Expand Up @@ -705,14 +700,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
[torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name
)
class PixelShuffle(GeneralOpDef):
    @staticmethod
    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
        # Try in/out observer sharing first; if the node still lacks an
        # annotation afterwards, apply the single-input shared-output fallback.
        annotate_in_out_obs_sharing_op(node, quantization_config)
        if _is_annotated([node]):
            return
        annotate_single_in_share_out(node, quantization_config)


@register_annotator(
[torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name
)
class PixelUnshuffle(GeneralOpDef):
    @staticmethod
    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
        # Try in/out observer sharing first; if the node still lacks an
        # annotation afterwards, apply the single-input shared-output fallback.
        annotate_in_out_obs_sharing_op(node, quantization_config)
        if _is_annotated([node]):
            return
        annotate_single_in_share_out(node, quantization_config)


@register_annotator(
Expand Down
Loading
Loading