
[WebGPU] Support continuous decoding (RewindTo) with graph capture#2083

Open
qjia7 wants to merge 5 commits into main from continuous-decoding-graph-capture

Conversation

Contributor

@qjia7 qjia7 commented Apr 13, 2026

This pull request introduces improvements to the handling of attention masks in both the CUDA and WebGPU backends, focusing on more efficient and correct updates of mask buffers during decoding. The main changes are the implementation of a CPU-side update for static attention masks in CUDA and the addition of a reusable staging buffer for efficient mask updates in WebGPU, with logic to avoid redundant work for single-beam cases.

CUDA backend improvements:

  • Replaced the previous (commented-out and incorrect) CUDA memory set logic in DefaultPositionInputs::RewindMask with a CPU-side update that correctly sets attended and non-attended positions in the attention mask for each batch/beam, followed by a copy back to the device. This ensures the mask is set with 1s for attended tokens and 0s for future tokens, supporting both int32_t and int64_t types.
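The rewind described above can be sketched as follows. This is a minimal, illustrative version of the CPU-side update, not the actual onnxruntime-genai code: the names and the `std::vector` host buffer are assumptions, and the real implementation copies the buffer back to the device afterward.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: the static mask has shape [batch_beam_size, max_length]. Rewinding
// to `index` writes 1s for attended positions [0, index) and 0s for future
// positions [index, max_length) in every batch/beam row. Works for both
// int32_t and int64_t mask types via the template parameter.
template <typename T>
void RewindStaticMask(std::vector<T>& mask, std::size_t batch_beam_size,
                      std::size_t max_length, std::size_t index) {
  for (std::size_t b = 0; b < batch_beam_size; ++b) {
    T* row = mask.data() + b * max_length;
    for (std::size_t i = 0; i < index; ++i) row[i] = T{1};           // attended
    for (std::size_t i = index; i < max_length; ++i) row[i] = T{0};  // future
  }
}
```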

WebGPU backend improvements:

  • Added a reusable CPU staging buffer (mask_staging_buffer_) to the InterfaceImpl struct for efficient attention mask updates, avoiding repeated allocations and redundant writes.
  • Implemented the UpdateAttentionMask method to efficiently update the mask for single-beam cases by filling only the new positions with 1s and copying the relevant portion to the device, falling back to the CPU path for multi-beam cases. The method handles the static update path and supports both int32_t and int64_t mask types.
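The staging-buffer idea can be illustrated with a small sketch. `MaskUpdater` and `StagedOnes` are hypothetical names standing in for the `InterfaceImpl` member and the upload path; the real code hands the prefix to `CopyTensors` for the device upload.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the single-beam fast path: for batch size 1 with no padding,
// the mask is always all 1s for attended positions, so the host side only
// needs a reusable buffer of 1s. Growing with resize(n, 1) fills just the
// newly added slots; previously written slots already hold 1, so repeated
// calls do no redundant work. (WebGPU/Dawn here is single-threaded, so the
// shared buffer needs no locking.)
struct MaskUpdater {
  std::vector<int32_t> mask_staging_buffer_;  // reused across decode steps

  const int32_t* StagedOnes(std::size_t total_length) {
    if (total_length > mask_staging_buffer_.size())
      mask_staging_buffer_.resize(total_length, 1);  // only new slots filled
    return mask_staging_buffer_.data();  // first total_length elements are 1s
  }
};
```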

qjia7 added 2 commits April 13, 2026 21:42
Previously, RewindMask threw 'Static buffer is not supported for continuous
decoding' when graph capture was enabled. The original CUDA implementation
was disabled due to cudaMemsetAsync semantics issues.

Fix: For static mask handling (graph capture), rewind the attention mask
by zeroing positions [index, max_length) in-place on the static buffer.
This correctly resets the mask to reflect the target sequence length without
reallocating or reshaping the buffer.

Tested with phi4-graph-prune (graph capture ON, WebGPU):
- RewindTo(0): produces identical sequences across multiple runs
- RewindTo(10): first 10 tokens preserved, generation continues coherently
  Run 1 (15 tokens): 'The capital of France is Paris. Paris is known for
    its rich history, culture, and landmarks such'
  RewindTo(10) + 5 tokens: 'The capital of France is Paris. Paris is known
    for its rich history, culture'
  First 10 tokens match: True
Add WebGPU-native UpdateAttentionMask for both static and non-static paths
when batch_beam_size == 1. For single-batch no-padding, the mask is always
all 1s for attended positions. Upload total_length elements of 1s directly
from a reusable CPU staging buffer to GPU via CopyTensors.

Works for both update_only=true (graph capture static mask) and
update_only=false (non-static growing mask) since both reduce to writing
total_length 1s to the target buffer. Falls back to CPU for batch_beam_size > 1.
Copilot AI review requested due to automatic review settings April 13, 2026 14:44
Contributor

Copilot AI left a comment


Pull request overview

This PR improves continuous decoding support when graph capture/static attention masks are used, by updating attention mask buffers more efficiently and correctly across backends (notably CUDA/static mask rewind handling and WebGPU mask updates).

Changes:

  • Implement CPU-side rewind/update logic for static attention masks used with graph capture (write attended positions to 1 and future positions to 0, then copy back to device).
  • Add a reusable CPU staging buffer and a WebGPU UpdateAttentionMask fast path for single-beam cases to avoid repeated allocations and redundant CPU fills.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

Reviewed files:

  • src/webgpu/interface.cpp: Adds a reusable staging buffer and implements WebGPU-side UpdateAttentionMask to upload an all-ones prefix for single-beam decoding.
  • src/models/position_inputs.cpp: Replaces the previous “unsupported”/commented logic with a CPU-written static mask rewind implementation for graph capture.

…ment

- Fix aliasing UB: change mask_staging_buffer_ from vector<uint8_t> to
  vector<int32_t>. For int64 masks, use std::memcpy to write int64_t(1)
  values without reinterpret_cast aliasing violations.
- Add bounds check in RewindMask: throw if index > max_length to prevent
  unsigned underflow in std::fill_n size calculation.
- Use size_t consistently for max_length in RewindMask to avoid
  signed/unsigned comparison surprises.
- Add threading model comment on mask_staging_buffer_: WebGPU/Dawn is
  single-threaded, document this assumption for future reference.
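The aliasing fix in this commit can be sketched as below. `FillOnesInt64` is an illustrative stand-in, not the actual function name: it shows writing `int64_t` values into `int32_t`-typed storage with `std::memcpy`, which is well-defined for trivially copyable types, instead of a `reinterpret_cast` that would violate strict aliasing.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Sketch: the staging buffer stores int32_t, but when the model's mask type
// is int64_t we must not write through a reinterpret_cast'd int64_t*.
// Copying each value byte-wise with std::memcpy avoids the aliasing UB.
void FillOnesInt64(std::vector<int32_t>& staging, std::size_t count) {
  staging.resize(count * (sizeof(int64_t) / sizeof(int32_t)));  // 2 slots/elem
  const int64_t one = 1;
  unsigned char* bytes = reinterpret_cast<unsigned char*>(staging.data());
  for (std::size_t i = 0; i < count; ++i)
    std::memcpy(bytes + i * sizeof(int64_t), &one, sizeof(int64_t));
}
```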
@@ -381,19 +381,30 @@ void DefaultPositionInputs::InitializeSequenceLengths(std::array<int64_t, 2> sha

void DefaultPositionInputs::RewindMask(size_t index) {
Contributor


From Copilot:


RewindMask() is incomplete and currently only does anything when use_graph_capture is enabled. In the non-graph-capture path it silently returns without rewinding the mask, which makes DefaultPositionInputs::RewindTo() effectively a no-op for the attention mask in the common “dynamic mask” mode.

Here’s a focused review of the implementation in your file:

void DefaultPositionInputs::RewindMask(size_t index) {
  if (state_.params_->use_graph_capture) {
    // Static mask layout: [batch_beam_size, max_length]
    // Rewind to index: write 1s for [0, index), 0s for [index, max_length)
    ...
    byte_span.CopyCpuToDevice();
    return;
  }
}

What’s good

  • Correct intent for static mask: for graph capture they keep a fixed [batch_beam_size, max_length] tensor and overwrite it to represent the rewound length ([0, index) = 1, rest 0). That matches how a “static mask” should be updated.
  • Bounds check: throws if index > max_length.
  • Handles both int32/int64: consistent with the earlier type_ validation.

Issues / risks

  1. No behavior when use_graph_capture == false

    • The function just falls off the end with no action.
    • But DefaultPositionInputs::RewindTo() calls RewindMask(index) when has_mask_input_ and batch_size == 1 (see below), so callers will think rewind succeeded when it didn’t.
    • This is especially problematic because the dynamic mask path (ShouldUseStaticMaskHandling() == false) stores a mask whose second dimension tracks the current length, and rewinding should either:
      • shrink/recreate the tensor to [batch_beam, index], or
      • at least ensure future updates rebuild from the rewound length.
  2. Mismatch between caller condition and mask mode

    • RewindTo() decides to call RewindMask() only based on use_graph_capture? No—it checks attention_mask_shape_[0] == 1, then calls RewindMask(index) regardless of whether the mask handling is static/dynamic.
    • But RewindMask() only updates for use_graph_capture.
    • Net: RewindTo(index>0) for batch=1 in non-graph-capture mode does nothing.
  3. Potentially confusing memory pathway

    • It uses attention_mask_->GetByteSpan() then CpuSpan() and reinterpret_cast to int32_t*/int64_t*.
    • This can be OK if CpuSpan() is guaranteed to be host-accessible backing storage for that tensor (e.g., unified/pinned staging). If CpuSpan() is just a view into an internal CPU buffer that requires an explicit CopyDeviceToCpu() first, then this would be wrong. (In other parts of the file, the dynamic update path explicitly does CopyDeviceToCpu() before mutation.)
  4. max_len type mixing

    • max_len is size_t, but later arithmetic uses int for the outer loop and std::fill_n takes a size_t count. This is mostly fine, but keep it consistent to avoid signed/unsigned pitfalls.

What I’d change (minimum fix)

  • Make RewindMask() handle both static and dynamic mask modes:
    • If ShouldUseStaticMaskHandling() (not just use_graph_capture), do the current “overwrite [0,index)=1 else 0” behavior with max_length.
    • Else (dynamic), either:
      1. Recreate attention_mask_ to shape [batch_beam_size, index] and fill with ones (since batch_size==1 implies no padding), then update state_.inputs_[mask_input_index_], or
      2. Set is_first_update_ = true and adjust attention_mask_shape_ so the next Update() will rebuild properly (but that may not be sufficient if downstream expects the mask tensor immediately after rewind).

…sistency

- Use ShouldUseStaticMaskHandling() instead of use_graph_capture to gate
  the static mask path, consistent with UpdateAttentionMask.
- Add dynamic mask fallback: set attention_mask_shape_[1] = index so the
  next Update() creates a correctly sized tensor. For batch_beam_size == 1
  the CPU UpdateAttentionMask fills the entire mask with 1s, so no data
  fixup is needed.
- Fix signed/unsigned type mixing: use size_t consistently for max_len,
  batch_beam_size, and loop indices in the static mask path.

Tested both paths on WebGPU:
- Static mask (phi4-graph-prune, enableGraphCapture=1): PASS
- Dynamic mask (phi4-prune, past_present_share_buffer=false): PASS
Contributor Author

qjia7 commented Apr 15, 2026

Addressed the review comments in 80be21c:

1. RewindMask() now handles both static and dynamic mask paths

  • Static path (ShouldUseStaticMaskHandling() — graph capture or NvTensorRtRtx with shared buffers): overwrites the fixed [batch_beam_size, max_length] tensor with 1s for [0, index) and 0s for [index, max_length), then copies to device. Same as before but now gated on ShouldUseStaticMaskHandling() instead of just use_graph_capture, consistent with UpdateAttentionMask().

  • Dynamic path (everything else): sets attention_mask_shape_[1] = index. The next Update() call will:

    1. CreateNextAttentionMaskTensor(total_length) — creates a correctly sized tensor using the adjusted shape
    2. CPU UpdateAttentionMask for batch_beam_size == 1 — fills the entire new mask with 1s (ignores old data)
    3. Replaces attention_mask_ with attention_mask_next_

    No tensor recreation or is_first_update_ reset needed — RewindTo already restricts to batch_size == 1 (no padding), so a shape-only adjustment is sufficient.
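The two-path control flow described above can be condensed into a sketch. The struct and members below are illustrative stand-ins for `DefaultPositionInputs` (e.g. `dynamic_mask_length` plays the role of `attention_mask_shape_[1]`), not the real class.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch of RewindMask after the review: the static path overwrites the
// fixed [batch_beam_size, max_length] mask in place; the dynamic path only
// shrinks the recorded length, so the next Update() recreates a correctly
// sized tensor and (for batch size 1) refills it entirely with 1s.
struct PositionInputsSketch {
  bool static_mask_handling{};       // stands in for ShouldUseStaticMaskHandling()
  std::size_t batch_beam_size{};
  std::size_t max_length{};
  std::vector<int32_t> mask;         // host view of the static mask
  std::size_t dynamic_mask_length{}; // stands in for attention_mask_shape_[1]

  void RewindMask(std::size_t index) {
    if (static_mask_handling) {
      if (index > max_length)        // bounds check prevents unsigned underflow
        throw std::runtime_error("RewindMask: index > max_length");
      for (std::size_t b = 0; b < batch_beam_size; ++b)
        for (std::size_t i = 0; i < max_length; ++i)
          mask[b * max_length + i] = i < index ? 1 : 0;
      // real code copies the host buffer back to the device here
    } else {
      dynamic_mask_length = index;   // next Update() rebuilds the tensor
    }
  }
};
```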

2. Fixed max_len type mixing: batch_beam_size, max_len, and loop indices now all use size_t consistently.

Tested on WebGPU (NVIDIA RTX):

  • Static mask (phi4-graph-prune, enableGraphCapture=1): ✅
  • Dynamic mask (phi4-prune, past_present_share_buffer=false): ✅

Both tests: generate 10 tokens -> rewind_to(prompt_len + 5) -> generate 5 more tokens. Output is consistent and coherent after rewind.

@qjia7 qjia7 requested a review from kunal-vaishnavi April 15, 2026 08:26
Add C++ and Python tests covering:
- RewindTo(0): full rewind and regenerate, verify identical output
- Multiple sequential RewindTo: rewind to 7, 5, then 0 in succession
- RewindTo with divergent tokens: rewind, append different tokens,
  then rewind again and recover original output

All tests use the bundled tiny-random-gpt2-fp32 model (CPU, no GPU needed).
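The invariant these tests check can be stated on plain token sequences, independent of any model: rewinding to n keeps exactly the first n tokens, so a deterministic regeneration from the same prefix must reproduce the original continuation. A hypothetical stand-in, not the actual test code:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch: RewindTo(n) conceptually truncates the generated sequence to its
// first n tokens; everything after the rewind point is discarded.
std::vector<int> RewindTokens(const std::vector<int>& tokens, std::size_t n) {
  return std::vector<int>(tokens.begin(),
                          tokens.begin() + static_cast<std::ptrdiff_t>(n));
}
```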
