
[xnnpack] Conv2d on valid C=1 channels_last input fails with static tensor resize mismatch #19153

@wuyii8941

Description


🐛 Describe the bug


ExecuTorch XNNPACK lowers a minimal Conv2d(1, 5, kernel_size=1) model successfully, but fails at runtime when the input tensor uses PyTorch channels_last memory format and has a singleton channel dimension.

No reduction, flatten, or portable fallback op is required to trigger the failure.

Reproducer

import importlib.metadata

import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(3001)
        self.conv = torch.nn.Conv2d(1, 5, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


print("torch:", torch.__version__)
print("executorch:", importlib.metadata.version("executorch"))

model = Model().eval()
x = torch.randn(1, 1, 7, 5).to(memory_format=torch.channels_last)  # singleton channel dim

print("shape:", tuple(x.shape))
print("stride:", tuple(x.stride()))
print("is_contiguous:", x.is_contiguous())
print("is_channels_last:", x.is_contiguous(memory_format=torch.channels_last))

ref = model(x)
ep = export(model, (x,), strict=True)
et = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

mod = _load_for_executorch_from_buffer(et.buffer)
out = mod.run_method("forward", (x,))[0]  # fails here at runtime with the error below
print("max diff:", (out - ref).abs().max().item())

Actual Behavior

Lowering succeeds, but runtime execution fails:

[tensor_impl.cpp:105] Attempted to resize a static tensor.
Expected shape (1, 5, 7, 5), but received (1, 5, 1, 5).
[XNNExecutor.cpp:239] Failed to resize output tensor for XNNExecutor
[method.cpp:1432] CALL_DELEGATE execute failed at instruction 0: 0x10
RuntimeError: Failed to execute method forward, error: 0x10

Expected Behavior

The program should either:

  1. run successfully and match the PyTorch eager output, or
  2. fail during XNNPACK lowering / validation with a clear unsupported-layout diagnostic.

Control Case

Changing only the input's memory layout to the default contiguous strides makes the same model pass:

x = torch.randn(1, 1, 7, 5)

Observed result:

runtime: ok
output shape: (1, 5, 7, 5)
max_diff: 0.0

Notes

The failing input is valid in PyTorch:

shape: (1, 1, 7, 5)
stride: (35, 1, 5, 1)
is_contiguous: True
is_channels_last: True

This is an ambiguous layout case: because the channel dimension is 1, PyTorch reports the same tensor as both contiguous and channels_last. The exported/lowered program nevertheless succeeds, and the XNNPACK runtime then appears to resize the output as if the height dimension were 1:

expected: (1, 5, 7, 5)
received: (1, 5, 1, 5)

Local scouting suggests the failure occurs when C=1, H>1, and W>1. Cases with H=1 or W=1 passed, and C=2 channels_last controls passed.
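The stride arithmetic behind the ambiguity can be sketched without torch (helper names here are mine, not from any library). It shows that for C=1 the channels_last strides are indistinguishable from contiguous strides once size-1 dims are ignored, while a C=2 shape is unambiguous. Note this alone does not predict the exact failing set, since H=1/W=1 cases are also ambiguous yet pass:

```python
def contiguous_strides(shape):
    # Standard row-major (NCHW-contiguous) strides.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

def channels_last_strides(shape):
    # NHWC-ordered strides for a 4-D NCHW shape.
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)

def indistinguishable(shape, a, b):
    # A size-1 dim places no constraint on its stride, mirroring how
    # PyTorch's is_contiguous() checks skip singleton dimensions.
    return all(s == 1 or x == y for s, x, y in zip(shape, a, b))

for shape in [(1, 1, 7, 5), (1, 2, 7, 5)]:
    cl = channels_last_strides(shape)
    cf = contiguous_strides(shape)
    print(shape, cl, cf, indistinguishable(shape, cl, cf))
# (1, 1, 7, 5): channels_last (35, 1, 5, 1) vs contiguous (35, 35, 5, 1) -> indistinguishable
# (1, 2, 7, 5): channels_last (70, 1, 10, 2) vs contiguous (70, 35, 5, 1) -> distinguishable
```

The (35, 1, 5, 1) strides match the values printed by the reproducer above.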

Versions


  • torch: 2.11.0+cu130
  • executorch: 1.2.0
  • Python: 3.11
  • Platform: Linux x86_64
  • Backend: XNNPACK
  • API path:
    • torch.export.export
    • executorch.exir.to_edge_transform_and_lower(..., partitioner=[XnnpackPartitioner()])
    • .to_executorch()
    • Python runtime _load_for_executorch_from_buffer
