[torch.compile][PyTorch] Prepare linear for torch compile #2967

Merged
pggPL merged 12 commits into NVIDIA:main from pggPL:prepare_linear_for_torch_compile
May 11, 2026

Conversation

pggPL (Collaborator) commented May 7, 2026

Description

Refactor of transformer_engine/pytorch/module/linear.py to lift the
Linear module into a shape that can be wrapped in a
torch.library.custom_op (and matching backward op) in a follow-up PR.

Type of change

  • Code refactoring

Changes

  • Pack forward/backward state into LinearFwdArgs / LinearBwdArgs
    dataclasses.
    _linear_forward_impl, _linear_setup_ctx,
    _linear_backward and the _Linear.{forward,backward} autograd
    methods all take a single structured argument instead of 25+
    positional ones. A custom op requires a fully-declared signature on
    both sides; the previous pattern of writing arbitrary
    ctx.something = ... attributes scattered throughout forward made
    it impossible to tell from the call site what state backward
    actually consumes. The dataclasses make the read/write contract
    explicit and grep-able. Concretely, things that used to be re-queried
    from tensor objects (input.requires_grad, weight.requires_grad,
    bias.requires_grad) are now captured up front as
    input_requires_grad / weight_requires_grad /
    bias_requires_grad and consumed as requires_dgrad /
    requires_wgrad in backward — backward no longer has to assume the
    Python tensor objects survive the op boundary.
  • Move prepare_for_saving / ctx.save_for_backward to the autograd
    boundary.
    _linear_forward_impl returns the raw tensors it
    produced (along with tensors_to_save_from_forward aliases) and
    _linear_setup_ctx returns the raw merged tensor list it wants
    saved; _Linear.forward is the only place that actually calls
    prepare_for_saving(*tensors_to_save_from_setup) and
    ctx.save_for_backward(...). This shape fell out of the compile-path
    experiments: under torch.library.register_autograd, the
    setup_context callback is the only legal place to call
    save_for_backward, and the helper has to hand back tensors rather
    than mutate the autograd ctx itself. The same contract is now used in
    eager mode, so a single helper serves both paths (see the sketch after
    this list).
  • Deduplicate saved tensors that alias forward inputs. Save-slots
    that would alias inp / weight / bias are emitted as None and
    reconstructed in _linear_setup_ctx from the original refs. An
    opaque custom op cannot return aliases of its inputs (the tracer
    has no way to reason about the aliasing).
  • Minimize the ctx_attrs blob plumbed from forward to backward
    setup.
    Anything that can be re-derived from LinearFwdArgs
    (weight_quantizer, is_fsdp2, owns_input) is recomputed in
    _linear_setup_ctx. The compile path needs the fake forward impl
    to return a structurally identical ctx_attrs, so a smaller surface
    is a smaller cross-impl contract.
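
A minimal, generic illustration of the custom-op constraint referenced above (plain PyTorch, not TE code; the op name, shapes, and math are stand-ins): under torch.library, save_for_backward may only be called from the setup_context callback, which is why the forward helpers hand tensors back instead of mutating an autograd ctx.

import torch

@torch.library.custom_op("demo::linear_fwd", mutates_args=())
def linear_fwd(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return inp @ weight.t()

@linear_fwd.register_fake
def _(inp, weight):
    # Shape-only implementation used under fake tensors / tracing.
    return inp.new_empty(inp.shape[0], weight.shape[0])

def setup_context(ctx, inputs, output):
    inp, weight = inputs
    ctx.save_for_backward(inp, weight)  # the only legal place on this path

def backward(ctx, grad_out):
    inp, weight = ctx.saved_tensors
    return grad_out @ weight, grad_out.t() @ inp

torch.library.register_autograd("demo::linear_fwd", backward, setup_context=setup_context)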

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 10 commits May 8, 2026 00:17
Three small refactors that make the module easier to reason about
and pave the way for the dataclass / saved-tensor refactors:

- Add a TensorOrQuantized type alias (Union[Tensor, QuantizedTensorStorage])
  used pervasively in helper signatures.
- Hoist the conditional bias argument into a local linear_bias_tensor
  variable instead of an inline expression at the linear_fn() call site.
- Only forward self.wgrad_store into the autograd Function when it is
  actually active (delay_wgrad_compute() is True); pass None otherwise so
  the autograd graph does not carry an unused Python object.

Pure rename / hoisting; no behavioural change.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
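
A small sketch of the alias and the wgrad_store gating this commit describes; QuantizedTensorStorage and delay_wgrad_compute are taken from the commit text, while the stand-in class and the helper below are illustrative only, not the module's code.

from typing import Optional, Union

import torch

class QuantizedTensorStorage:  # stand-in so the snippet is self-contained
    pass

TensorOrQuantized = Union[torch.Tensor, QuantizedTensorStorage]

def pick_wgrad_store(module) -> Optional[object]:
    # Only forward wgrad_store into the autograd Function when it is actually
    # active; otherwise pass None so the graph carries no unused Python object.
    return module.wgrad_store if module.wgrad_store.delay_wgrad_compute() else None
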
Replace the loosely typed ``non_tensor_args`` tuple and the ad-hoc
``ctx.<attr>`` plumbing with two dataclasses, ``LinearFwdArgs`` and
``LinearBwdArgs``, that act as the single argument to every helper
in the forward/backward pipeline.

What changes:

* ``LinearFwdArgs`` carries the (positional) tensors ``weight``, ``inp``
  and ``bias`` plus all quantizers, ``requires_grad`` flags, the cached
  ``weight_workspace`` and every former ``non_tensor_args`` knob.
  ``_Linear.forward`` still takes ``weight/inp/bias`` as positional
  Tensor inputs so autograd tracks them, then immediately re-attaches
  them to ``fwd_args`` so every downstream helper has a single-argument
  signature.
* ``LinearBwdArgs`` mirrors that on the backward side: it owns the
  saved tensors (``inputmat``, ``weight_fp8``, ``saved_weight``,
  ``bias``), the per-call quantizers, every flag previously stored
  directly on ``ctx`` and a ``setup_saved_tensors(saved_tensors,
  tensor_objects)`` helper that rehydrates the saved-tensor fields.
* ``ctx.backward_objects = bwd_args`` is now the single attribute the
  autograd context needs (besides ``saved_tensors``/``tensor_objects``).
* ``weight_workspace`` is no longer a positional Tensor arg of the
  autograd Function; it is read from ``fwd_args.weight_workspace`` and
  the freshly produced workspace is returned alongside ``out`` so the
  module can refresh its cache without autograd tracking the cache.
* ``prepare_for_saving`` now lives at the autograd boundary in
  ``_Linear.forward``; ``_linear_setup_ctx`` only returns the merged
  list of tensors that should be saved.
* ``grad_output_preprocess`` is invoked with ``bwd_args`` directly
  (it is duck-typed on the same attribute names) so backward never
  reaches into ``ctx.<attr>`` for non-tensor state.

Behaviour preserved (verified numerically against ``torch.nn.Linear``
and on FP8 + workspace-cache paths).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
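
An abbreviated sketch of the two containers described above; the real dataclasses carry many more fields (quantizers, parallelism, userbuffers, FSDP and wgrad-scheduling knobs), and the rehydration body below is a simplified stand-in for the actual helper.

from dataclasses import dataclass
from typing import Any, Optional

import torch

@dataclass
class LinearFwdArgs:
    # Tensors re-attached by _Linear.forward so downstream helpers see one argument.
    weight: Optional[torch.Tensor] = None
    inp: Optional[torch.Tensor] = None
    bias: Optional[torch.Tensor] = None
    # requires_grad flags captured up front instead of re-queried from tensors.
    input_requires_grad: bool = False
    weight_requires_grad: bool = False
    bias_requires_grad: bool = False
    # Cached workspace; no longer a positional Tensor arg of the autograd Function.
    weight_workspace: Optional[torch.Tensor] = None

@dataclass
class LinearBwdArgs:
    # Saved tensors rehydrated from the autograd ctx in backward.
    inputmat: Optional[Any] = None
    saved_weight: Optional[Any] = None
    bias: Optional[torch.Tensor] = None
    # Flags previously stored directly on ctx.
    requires_dgrad: bool = False
    requires_wgrad: bool = False

    def setup_saved_tensors(self, saved_tensors, tensor_objects):
        # Simplified: the real helper also restores quantized tensors from the
        # serialized tensor_objects before assigning the fields.
        self.inputmat, self.saved_weight, self.bias = saved_tensors[:3]
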
When ``saved_inputmat is inp``, ``wt_save is weight`` or ``bias`` is the
exact bias passed in, there is no point asking ``prepare_for_saving`` to
serialize the same Python object twice. Make ``_linear_forward_impl``
emit ``None`` in those slots (and a parallel ``saved_tensor_aliases``
tuple in ``ctx_attrs`` describing which slot points where), and have
``_linear_setup_ctx`` rebuild the tuple with the original references
before handing it to ``prepare_for_saving``.

Saves a Python ref per alias in eager and, more importantly, keeps the
forward helper from "returning" a tensor that aliases its own inputs --
a pattern ``torch.compile`` would otherwise need to reason about when
the helper is wrapped in an opaque op.

Numerically equivalent (validated against ``torch.nn.Linear`` and on a
multi-iteration FP8 path with workspace caching).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
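
A self-contained sketch of the alias bookkeeping described above (the helper names and exact slot layout are illustrative, not the module's code):

def dedup_saved_tensors(saved_inputmat, wt_save, saved_bias, inp, weight, bias):
    # Emit None for save slots that alias the forward inputs, plus an alias map
    # (the saved_tensor_aliases tuple carried in ctx_attrs).
    aliases = (
        "inp" if saved_inputmat is inp else None,
        "weight" if wt_save is weight else None,
        "bias" if saved_bias is bias else None,
    )
    slots = tuple(
        None if a else t for a, t in zip(aliases, (saved_inputmat, wt_save, saved_bias))
    )
    return slots, aliases

def rebuild_saved_tensors(slots, aliases, inp, weight, bias):
    # _linear_setup_ctx-style reconstruction from the original references,
    # done before handing the tuple to prepare_for_saving.
    originals = {"inp": inp, "weight": weight, "bias": bias}
    return tuple(originals[a] if a else s for s, a in zip(slots, aliases))
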
Follow-up cleanups on top of the dataclass refactor:

* Sort ``LinearFwdArgs`` / ``LinearBwdArgs`` fields into labelled groups
  (tensors, requires_grad flags, quantizers, dtype/numerical config,
  parallelism, userbuffers, FSDP, wgrad scheduling, misc) and mirror that
  ordering in their construction sites.
* Add ``slots=True`` to both dataclasses so typos in
  ``fwd_args.X`` / ``bwd_args.X`` raise ``AttributeError`` immediately
  instead of silently creating a new attribute.
* Inline single-use ``args.X`` aliases in ``_linear_forward_impl``
  (``weight_workspace``, ``fp8_calibration``, ``tp_size``,
  ``tensor_parallel``, ``cache_weight``, ``skip_fp8_weight_update``,
  ``custom``, ``backward_input_needs_gather``) so the prelude only keeps
  aliases that are actually reused.
* Shrink ``ctx_attrs`` to ``{fsdp_shapes, saved_tensor_aliases}``:
  ``weight_quantizer`` is re-derived in ``_linear_setup_ctx`` from
  ``fwd_args.weight`` (matching the resolution done in forward),
  ``is_fsdp2`` already lives on ``fwd_args``, and ``owns_input`` is
  equivalent to ``saved_tensor_aliases[0] != "inp"``.
* Replace ``setup_saved_tensors(saved_tensors, tensor_objects)`` with
  ``setup_saved_tensors(ctx)`` backed by ``restore_from_func_ctx``,
  matching ``layernorm_mlp`` / ``layernorm_linear`` /
  ``grouped_linear`` and dropping the manual
  ``ctx.tensor_objects = None`` cleanup.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
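
A tiny, generic illustration of the slots=True motivation (plain Python, not the TE dataclasses):

from dataclasses import dataclass

@dataclass(slots=True)
class Args:
    tp_size: int = 1

a = Args()
try:
    a.tp_sizee = 8  # typo: with slots=True this raises immediately
except AttributeError as err:
    print(err)       # instead of silently creating a new attribute
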
After packing the Linear backward state into ``LinearBwdArgs`` the
attributes the test was reading (``backward_override``, ``fp8``,
``grad_output_quantizer``, ``reduce_and_update_bwd_fp8_tensors``) no
longer live directly on ``grad_fn``. Read them from
``grad_fn.backward_objects`` when present, falling back to ``grad_fn``
for the linear-like modules that have not been refactored yet
(``layernorm_linear``, ``ops_linear``).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
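
One way the lookup described in this commit can be written (the helper name and the `or` fallback are illustrative; only grad_fn.backward_objects and the flat-ctx fallback come from the commit text):

def read_backward_attr(grad_fn, name):
    # Prefer the LinearBwdArgs container on refactored modules; fall back to the
    # flat grad_fn attributes for layernorm_linear / ops_linear. The `or` also
    # covers the case where backward_objects exists but has been set to None.
    holder = getattr(grad_fn, "backward_objects", None) or grad_fn
    return getattr(holder, name)
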
Restore the one-line class docstrings dropped during the field
reorganization so pylint stops warning about C0115.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
pggPL marked this pull request as ready for review May 8, 2026 11:42
greptile-apps bot (Contributor) commented May 8, 2026

Greptile Summary

This PR is a large refactor of transformer_engine/pytorch/module/linear.py in preparation for torch.compile support. It packs all forward/backward state into LinearFwdArgs / LinearBwdArgs dataclasses, moves prepare_for_saving / ctx.save_for_backward to the autograd boundary, deduplicates saved tensors that alias forward inputs via an alias map, and shrinks the ctx_attrs blob crossing the forward/backward boundary.

  • LinearFwdArgs / LinearBwdArgs dataclasses replace the ~25-element positional-argument tuples, making the forward→backward data contract explicit and static; the _Linear.forward signature shrinks to (weight, inp, bias, fwd_args).
  • Saved-tensor deduplication uses a saved_tensor_aliases tuple to mark which save slots alias forward inputs; _linear_setup_ctx reconstructs those refs instead of saving them, which is required for the opaque-custom-op tracer.
  • ctx.backward_objects stores the populated LinearBwdArgs on the autograd ctx so backward can retrieve it; the reference is nulled after use to prevent tensor lifetime extension under retain_graph.

Confidence Score: 4/5

The refactor is internally consistent and the new dataclass contract is well-structured, but the backward path has a known crash when the same graph node is traversed a second time.

The forward→backward data handoff via LinearFwdArgs/LinearBwdArgs is clean and the alias-deduplication logic is correct. However, ctx.backward_objects is nulled after the first backward call; a second backward on the same node (retain_graph=True) immediately dereferences None and raises AttributeError, which is a real behavioral regression from the prior flat-ctx approach.

transformer_engine/pytorch/module/linear.py — specifically the _Linear.backward method and the ctx.backward_objects lifecycle under retain_graph.
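
A generic PyTorch illustration of the scenario above (not TE code): a second backward over the same node requires retain_graph=True on the first call and re-enters every node's backward, which is where a nulled ctx.backward_objects would be read.

import torch

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
loss = (x @ w.t()).sum()
loss.backward(retain_graph=True)  # first traversal succeeds
loss.backward()                   # second traversal re-invokes each node's backward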

Important Files Changed

Filename | Overview
transformer_engine/pytorch/module/linear.py | Core refactor file: introduces the LinearFwdArgs/LinearBwdArgs dataclasses, restructures the forward/backward plumbing, and changes the weight_quantizer derivation in _linear_setup_ctx to use weight._quantizer for QuantizedTensor weights. The retain_graph second-backward crash (ctx.backward_objects set to None after the first backward) remains present.
tests/pytorch/test_backward_override.py | Test helper updated to look up backward state via grad_fn.backward_objects (LinearBwdArgs) with a fallback to grad_fn; the getattr default does not trigger when backward_objects exists but is None, so a post-backward call yields a confusing error instead of the fallback.

Sequence Diagram

sequenceDiagram
    participant L as Linear.forward
    participant FA as LinearFwdArgs
    participant FI as _linear_forward_impl
    participant SC as _linear_setup_ctx
    participant BA as LinearBwdArgs
    participant CTX as autograd ctx
    participant BW as _Linear.backward
    participant BI as _linear_backward

    L->>FA: construct(weight, inp, bias, ...)
    L->>FI: _linear_forward_impl(fwd_args)
    FI-->>L: out, new_wks, tensors_to_save_from_forward, ctx_attrs
    L->>BA: LinearBwdArgs()
    L->>SC: _linear_setup_ctx(bwd_args, fwd_args, out, ctx_attrs, tensors)
    SC-->>BA: populate all backward fields
    SC-->>L: (saved_inputmat, wt_save, saved_weight, saved_bias)
    L->>CTX: "prepare_for_saving(*tensors)"
    L->>CTX: ctx.save_for_backward(...)
    L->>CTX: "ctx.backward_objects = bwd_args"

    BW->>CTX: "bwd_args = ctx.backward_objects"
    BW->>BA: "bwd_args.grad_output = grad_output"
    BW->>BA: bwd_args.setup_saved_tensors(ctx)
    BW->>BI: _linear_backward(bwd_args)
    BI-->>BW: (wgrad, dgrad, grad_bias)
    BW->>CTX: "ctx.backward_objects = None"


Saved tensors, quantizers, weakrefs and main_grad closures referenced
from LinearBwdArgs survived until ctx GC, extending peak GPU memory
under retain_graph=True. Null out ctx.backward_objects right after
_linear_backward so they are released as soon as backward returns.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL (Collaborator, Author) commented May 8, 2026

/te-ci pytorch L1

timmoon10 (Collaborator) left a comment

LGTM. Most of my comments are just me talking through the design.

Comment on lines +1325 to +1328
# Drop all references held by bwd_args (saved tensors, quantizers, weakrefs,
# main_grad closure) so they don't outlive backward via ctx under retain_graph.
ctx.backward_objects = None
del bwd_args

If we destroy the cached state, we should also mark backward with function.once_differentiable.
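
A generic sketch of the decorator mentioned above (plain PyTorch, not the TE backward): once_differentiable marks the backward as not differentiable itself, so attempting a double backward through it fails loudly instead of silently relying on the destroyed state.

import torch
from torch.autograd.function import once_differentiable

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.scale = 2.0
        return x * 2.0

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        scale, ctx.scale = ctx.scale, None  # drop the cached state after use
        return grad_out * scale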



@dataclass(slots=True)
class LinearFwdArgs:

This design is nicer than exposing a long list of positional args or passing in a tuple of args, but it is less nice than exposing kwargs. However, it has some advantages for this case:

  • torch.compile infrastructure will only need to handle one non-tensor arg to the autograd function.
  • Autograd functions do some processing on each arg, so minimizing the number of non-tensor args reduces CPU overhead.
  • _Linear.forward calls _linear_forward_impl and it doesn't need to be aware of the impl specifics.



@dataclass(slots=True)
class LinearBwdArgs:

In typical usage, LinearBwdArgs is somewhat redundant with the autograd context class. But we need it so that the forward-context-saving and backward are less entangled with autograd, which will allow code reuse when we implement an alternate torch.compile code path.
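
A hedged sketch of that reuse argument: because backward state lives in a plain container rather than on the autograd ctx, one autograd-agnostic helper can serve both the eager Function and a future custom-op registration. Everything below except the LinearBwdArgs name is an illustrative stand-in, not TE code.

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class LinearBwdArgs:
    saved_input: Optional[torch.Tensor] = None
    saved_weight: Optional[torch.Tensor] = None
    grad_output: Optional[torch.Tensor] = None

def shared_linear_backward(bwd_args: LinearBwdArgs):
    # Autograd-agnostic core shared by both code paths.
    dgrad = bwd_args.grad_output @ bwd_args.saved_weight
    wgrad = bwd_args.grad_output.t() @ bwd_args.saved_input
    return dgrad, wgrad

class EagerLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        return shared_linear_backward(
            LinearBwdArgs(saved_input=inp, saved_weight=weight, grad_output=grad_output)
        )

# A torch.library custom-op backward (the follow-up compile path) would build the
# same LinearBwdArgs from its own ctx and call shared_linear_backward unchanged.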

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL (Collaborator, Author) commented May 11, 2026

/te-ci pytorch L1

pggPL merged commit 282b4fb into NVIDIA:main May 11, 2026
25 of 31 checks passed