[WebGPU] Support continuous decoding (RewindTo) with graph capture#2083
Conversation
Previously, RewindMask threw 'Static buffer is not supported for continuous
decoding' when graph capture was enabled. The original CUDA implementation
was disabled due to cudaMemsetAsync semantics issues.
Fix: For static mask handling (graph capture), rewind the attention mask
by zeroing positions [index, max_length) in-place on the static buffer.
This correctly resets the mask to reflect the target sequence length without
reallocating or reshaping the buffer.
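The in-place rewind described above can be sketched as follows. This is a minimal illustration on a plain CPU buffer; the real code writes through the tensor's byte span and then copies back to the device, and the function name here is a hypothetical stand-in:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch: rewind a static [batch_beam_size, max_length] attention mask in
// place. Positions [0, index) stay attended (1); positions [index, max_length)
// are cleared (0) so the mask reflects the rewound sequence length.
template <typename T>
void RewindStaticMask(std::vector<T>& mask, size_t batch_beam_size,
                      size_t max_length, size_t index) {
  if (index > max_length)
    throw std::runtime_error("RewindMask: index exceeds max_length");
  for (size_t b = 0; b < batch_beam_size; ++b) {
    T* row = mask.data() + b * max_length;
    std::fill_n(row, index, T{1});                       // attended prefix
    std::fill_n(row + index, max_length - index, T{0});  // cleared tail
  }
  // In the real implementation the mutated span is then pushed to the GPU,
  // e.g. via byte_span.CopyCpuToDevice().
}
```

Note that no reallocation happens: the buffer keeps its `[batch_beam_size, max_length]` shape, which is what graph capture requires.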
Tested with phi4-graph-prune (graph capture ON, WebGPU):
- RewindTo(0): produces identical sequences across multiple runs
- RewindTo(10): first 10 tokens preserved, generation continues coherently
Run 1 (15 tokens): 'The capital of France is Paris. Paris is known for
its rich history, culture, and landmarks such'
RewindTo(10) + 5 tokens: 'The capital of France is Paris. Paris is known
for its rich history, culture'
First 10 tokens match: True
Add a WebGPU-native UpdateAttentionMask for both static and non-static paths when batch_beam_size == 1. For a single batch with no padding, the mask is always all 1s for attended positions, so the update reduces to uploading total_length elements of 1 directly from a reusable CPU staging buffer to the GPU via CopyTensors. This works for both update_only=true (graph-capture static mask) and update_only=false (non-static growing mask), since both cases reduce to writing total_length 1s to the target buffer. Falls back to the CPU path for batch_beam_size > 1.
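The single-beam fast path above can be sketched like this. The staging-buffer name comes from the PR; the surrounding struct and the upload callback are hypothetical scaffolding standing in for the real CopyTensors device upload:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the batch_beam_size == 1 fast path: with no padding, the first
// total_length mask positions are always 1, so the update reduces to
// uploading total_length ones. A reusable CPU staging buffer
// (mask_staging_buffer_ in the PR) avoids reallocating every step; only
// newly needed positions are filled with 1s.
struct MaskUploader {
  std::vector<int32_t> staging_;  // grows monotonically, reused across steps

  template <typename UploadFn>
  void UpdateAttentionMask(size_t total_length, UploadFn gpu_upload) {
    const size_t old_size = staging_.size();
    if (total_length > old_size) {
      staging_.resize(total_length);
      std::fill(staging_.begin() + old_size, staging_.end(), 1);  // new 1s only
    }
    // Upload just the attended prefix; the same write serves both the static
    // (graph capture) and the growing non-static mask.
    gpu_upload(staging_.data(), total_length);
  }
};
```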
Pull request overview
This PR improves continuous decoding support when graph capture/static attention masks are used, by updating attention mask buffers more efficiently and correctly across backends (notably CUDA/static mask rewind handling and WebGPU mask updates).
Changes:
- Implement CPU-side rewind/update logic for static attention masks used with graph capture (write attended positions to 1 and future positions to 0, then copy back to device).
- Add a reusable CPU staging buffer and a WebGPU `UpdateAttentionMask` fast path for single-beam cases to avoid repeated allocations and redundant CPU fills.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/webgpu/interface.cpp | Adds a reusable staging buffer and implements WebGPU-side UpdateAttentionMask to upload an all-ones prefix for single-beam decoding. |
| src/models/position_inputs.cpp | Replaces the previous “unsupported”/commented logic with a CPU-written static mask rewind implementation for graph capture. |
…ment
- Fix aliasing UB: change `mask_staging_buffer_` from `vector<uint8_t>` to `vector<int32_t>`. For int64 masks, use `std::memcpy` to write `int64_t(1)` values without `reinterpret_cast` aliasing violations.
- Add bounds check in `RewindMask`: throw if `index > max_length` to prevent unsigned underflow in the `std::fill_n` size calculation.
- Use `size_t` consistently for `max_length` in `RewindMask` to avoid signed/unsigned comparison surprises.
- Add a threading-model comment on `mask_staging_buffer_`: WebGPU/Dawn is single-threaded; document this assumption for future reference.
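The aliasing fix in this commit can be illustrated in isolation. A minimal sketch, assuming the destination is a raw byte buffer backing an int64 mask (the function name is hypothetical):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Instead of reinterpret_cast-ing a byte buffer to int64_t* (undefined
// behavior under C++ strict aliasing), copy each int64_t(1) in with
// std::memcpy, which is always well-defined for object representations.
void FillOnesInt64(uint8_t* dst_bytes, size_t count) {
  const int64_t one = 1;
  for (size_t i = 0; i < count; ++i)
    std::memcpy(dst_bytes + i * sizeof(int64_t), &one, sizeof(int64_t));
}
```

Compilers typically lower such fixed-size `memcpy` calls to plain stores, so this is aliasing-safe without a performance cost.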
```cpp
@@ -381,19 +381,30 @@ void DefaultPositionInputs::InitializeSequenceLengths(std::array<int64_t, 2> sha
void DefaultPositionInputs::RewindMask(size_t index) {
```
From Copilot:
RewindMask() is incomplete and currently only does anything when use_graph_capture is enabled. In the non-graph-capture path it silently returns without rewinding the mask, which makes DefaultPositionInputs::RewindTo() effectively a no-op for the attention mask in the common “dynamic mask” mode.
Here’s a focused review of the implementation in your file:
```cpp
void DefaultPositionInputs::RewindMask(size_t index) {
  if (state_.params_->use_graph_capture) {
    // Static mask layout: [batch_beam_size, max_length]
    // Rewind to index: write 1s for [0, index), 0s for [index, max_length)
    ...
    byte_span.CopyCpuToDevice();
    return;
  }
}
```

What’s good
- Correct intent for static mask: for graph capture they keep a fixed `[batch_beam_size, max_length]` tensor and overwrite it to represent the rewound length (`[0, index)` = 1, rest 0). That matches how a “static mask” should be updated.
- Bounds check: throws if `index > max_length`.
- Handles both int32/int64: consistent with the earlier type validation.
Issues / risks

1. No behavior when `use_graph_capture == false`
   - The function just falls off the end with no action.
   - But `DefaultPositionInputs::RewindTo()` calls `RewindMask(index)` when `has_mask_input_` and `batch_size == 1` (see below), so callers will think rewind succeeded when it didn’t.
   - This is especially problematic because the dynamic mask path (`ShouldUseStaticMaskHandling() == false`) stores a mask whose second dimension tracks the current length, and rewinding should either:
     - shrink/recreate the tensor to `[batch_beam, index]`, or
     - at least ensure future updates rebuild from the rewound length.
2. Mismatch between caller condition and mask mode
   - `RewindTo()` decides to call `RewindMask()` only based on `use_graph_capture`? No: it checks `attention_mask_shape_[0] == 1`, then calls `RewindMask(index)` regardless of whether the mask handling is static/dynamic.
   - But `RewindMask()` only updates for `use_graph_capture`.
   - Net: `RewindTo(index > 0)` for batch=1 in non-graph-capture mode does nothing.
3. Potentially confusing memory pathway
   - It uses `attention_mask_->GetByteSpan()` then `CpuSpan()` and `reinterpret_cast` to `int32_t*`/`int64_t*`.
   - This can be OK if `CpuSpan()` is guaranteed to be host-accessible backing storage for that tensor (e.g., unified/pinned staging). If `CpuSpan()` is just a view into an internal CPU buffer that requires an explicit `CopyDeviceToCpu()` first, then this would be wrong. (In other parts of the file, the dynamic update path explicitly does `CopyDeviceToCpu()` before mutation.)
4. `max_len` type mixing
   - `max_len` is `size_t`, but later arithmetic uses `int` for the outer loop, and `std::fill_n` takes a `size_t` count. This is mostly fine, but keep it consistent to avoid signed/unsigned pitfalls.
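The underflow hazard behind the bounds-check point can be made concrete. A small sketch (the helper is hypothetical, but the arithmetic is exactly what the unchecked code would do):

```cpp
#include <cassert>
#include <cstddef>

// With size_t arithmetic, max_length - index wraps around to a huge value
// when index > max_length, so an unchecked std::fill_n would be asked to
// write nearly 2^64 elements. Validating first keeps the count sane.
bool SafeTailCount(size_t max_length, size_t index, size_t* out_count) {
  if (index > max_length) return false;  // would underflow below
  *out_count = max_length - index;
  return true;
}
```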
What I’d change (minimum fix)

- Make `RewindMask()` handle both static and dynamic mask modes:
  - If `ShouldUseStaticMaskHandling()` (not just `use_graph_capture`), do the current “overwrite `[0, index)` = 1, else 0” behavior with `max_length`.
  - Else (dynamic), either:
    - recreate `attention_mask_` to shape `[batch_beam_size, index]` and fill with ones (since batch_size == 1 implies no padding), then update `state_.inputs_[mask_input_index_]`, or
    - set `is_first_update_ = true` and adjust `attention_mask_shape_` so the next `Update()` will rebuild properly (but that may not be sufficient if downstream expects the mask tensor immediately after rewind).
…sistency
- Use `ShouldUseStaticMaskHandling()` instead of `use_graph_capture` to gate the static mask path, consistent with `UpdateAttentionMask`.
- Add dynamic mask fallback: set `attention_mask_shape_[1] = index` so the next `Update()` creates a correctly sized tensor. For `batch_beam_size == 1` the CPU `UpdateAttentionMask` fills the entire mask with 1s, so no data fixup is needed.
- Fix signed/unsigned type mixing: use `size_t` consistently for `max_len`, `batch_beam_size`, and loop indices in the static mask path.

Tested both paths on WebGPU:
- Static mask (phi4-graph-prune, enableGraphCapture=1): PASS
- Dynamic mask (phi4-prune, past_present_share_buffer=false): PASS
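The two-path gating this commit describes can be sketched as follows. The class here is a stand-in; the member names mirror the discussion but are not the repository's actual declarations:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch of the fix: gate the in-place overwrite on static mask handling,
// and for the dynamic path just record the rewound length so the next
// Update() rebuilds a correctly sized mask.
struct PositionInputsSketch {
  bool static_mask_handling = false;  // stands in for ShouldUseStaticMaskHandling()
  std::vector<int32_t> mask;          // [1, max_length] when static
  size_t max_length = 0;
  size_t dynamic_length = 0;          // stands in for attention_mask_shape_[1]

  void RewindMask(size_t index) {
    if (static_mask_handling) {
      if (index > max_length) throw std::runtime_error("index > max_length");
      std::fill_n(mask.begin(), index, 1);
      std::fill_n(mask.begin() + index, max_length - index, 0);
      // ...followed by a CopyCpuToDevice() in the real code
    } else {
      // Dynamic mask: the next Update() recreates the tensor at this length;
      // for batch_beam_size == 1 it is filled entirely with 1s.
      dynamic_length = index;
    }
  }
};
```

The key behavioral change versus the reviewed version is the `else` branch: the dynamic path no longer falls through silently.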
Addressed the review comments in 80be21c:
1.
2. Fixed
Tested on WebGPU (NVIDIA RTX):
Both tests: generate 10 tokens ->
Add C++ and Python tests covering:
- RewindTo(0): full rewind and regenerate, verify identical output
- Multiple sequential RewindTo: rewind to 7, 5, then 0 in succession
- RewindTo with divergent tokens: rewind, append different tokens, then rewind again and recover original output

All tests use the bundled tiny-random-gpt2-fp32 model (CPU, no GPU needed).
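The test pattern in this commit can be illustrated with a toy deterministic generator. The real tests drive the library's generator against the bundled tiny-random-gpt2-fp32 model; this stand-in only makes the rewind-and-regenerate structure concrete:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Toy deterministic "generator": each token depends only on the seed and the
// current position, so regenerating after a rewind must reproduce the same
// sequence. RewindTo(n) truncates the sequence to its first n tokens, which
// is the observable contract the PR's tests check.
struct ToyGenerator {
  std::vector<int> tokens;
  void AppendToken(int seed) {
    tokens.push_back((seed + static_cast<int>(tokens.size())) % 7);
  }
  void Generate(int n, int seed) {
    for (int i = 0; i < n; ++i) AppendToken(seed);
  }
  void RewindTo(size_t new_length) { tokens.resize(new_length); }
};
```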
This pull request introduces improvements to the handling of attention masks in both the CUDA and WebGPU backends, focusing on more efficient and correct updates of mask buffers during decoding. The main changes are the implementation of a CPU-side update for static attention masks in CUDA and the addition of a reusable staging buffer for efficient mask updates in WebGPU, with logic to avoid redundant work for single-beam cases.
CUDA backend improvements:
- Reworks `DefaultPositionInputs::RewindMask` with a CPU-side update that correctly sets attended and non-attended positions in the attention mask for each batch/beam, followed by a copy back to the device. This ensures the mask is set with 1s for attended tokens and 0s for future tokens, supporting both `int32_t` and `int64_t` types.

WebGPU backend improvements:
- Adds a reusable staging buffer (`mask_staging_buffer_`) to the `InterfaceImpl` struct for efficient attention mask updates, avoiding repeated allocations and redundant writes.
- Implements an `UpdateAttentionMask` method that efficiently updates the mask for single-beam cases by filling only new positions with 1s and copying the relevant portion to the device, falling back to CPU for multi-beam cases. This method handles the static update path and supports both `int32_t` and `int64_t` mask types.