[webgpu] Optimize AttentionPrepare #26850
base: main
Conversation
Pull request overview
This pull request optimizes the AttentionPrepare operation in the WebGPU BERT Attention operator by replacing a custom QKV preparation kernel with a more modular approach: a MatMul followed by a dedicated SplitPackedQKV kernel. The refactoring improves performance (from 751.67 ms to 128.88 ms in the phi4-vision model) by leveraging the already-optimized MatMul operation, and it improves maintainability through better separation of concerns and reuse.
Key changes:
- Replaced the custom AttentionPrepare kernel with a MatMul + SplitPackedQKV approach (see the conceptual sketch after this list)
- Moved SplitPackedQKV implementation from group_query_attention.cc to attention.cc for broader reuse
- Enhanced SplitPackedQKV with vectorization support and an additional kv_hidden_size parameter
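For context, here is a minimal CPU-side sketch of what the two-step decomposition computes, assuming `tokens = batch * seq_len`, a packed projection weight of width `q_hidden + 2 * kv_hidden`, and Q/K/V emitted in BSD layout. The names and signatures (`MatMulPackedQKV`, this `SplitPackedQKV` shape, `tokens`, `q_hidden`, `kv_hidden`) are illustrative, not the actual WebGPU EP API; the real kernels run as WGSL shaders on the GPU.

```cpp
// Conceptual CPU reference for the MatMul + SplitPackedQKV decomposition.
// Shapes, names, and the scalar loops are illustrative only.
#include <algorithm>
#include <cstddef>
#include <vector>

// Step 1: a single MatMul projects the input against a packed weight,
// producing packed QKV rows of width q_hidden + 2 * kv_hidden.
//   input:  [tokens, input_hidden]
//   weight: [input_hidden, q_hidden + 2 * kv_hidden]
std::vector<float> MatMulPackedQKV(const std::vector<float>& input,
                                   const std::vector<float>& weight,
                                   std::size_t tokens, std::size_t input_hidden,
                                   std::size_t q_hidden, std::size_t kv_hidden) {
  const std::size_t packed_hidden = q_hidden + 2 * kv_hidden;
  std::vector<float> packed(tokens * packed_hidden, 0.0f);
  for (std::size_t t = 0; t < tokens; ++t) {
    for (std::size_t o = 0; o < packed_hidden; ++o) {
      float acc = 0.0f;
      for (std::size_t i = 0; i < input_hidden; ++i) {
        acc += input[t * input_hidden + i] * weight[i * packed_hidden + o];
      }
      packed[t * packed_hidden + o] = acc;
    }
  }
  return packed;
}

// Step 2: SplitPackedQKV copies contiguous slices of each packed row into
// separate Q, K, and V buffers (still BSD layout). kv_hidden is passed
// explicitly because it may differ from q_hidden.
void SplitPackedQKV(const std::vector<float>& packed,
                    std::vector<float>& q, std::vector<float>& k, std::vector<float>& v,
                    std::size_t tokens, std::size_t q_hidden, std::size_t kv_hidden) {
  const std::size_t packed_hidden = q_hidden + 2 * kv_hidden;
  q.resize(tokens * q_hidden);
  k.resize(tokens * kv_hidden);
  v.resize(tokens * kv_hidden);
  for (std::size_t t = 0; t < tokens; ++t) {
    const float* row = packed.data() + t * packed_hidden;
    std::copy(row, row + q_hidden, q.begin() + t * q_hidden);
    std::copy(row + q_hidden, row + q_hidden + kv_hidden, k.begin() + t * kv_hidden);
    std::copy(row + q_hidden + kv_hidden, row + packed_hidden, v.begin() + t * kv_hidden);
  }
}
```

Splitting after an ordinary MatMul lets the projection ride on the existing, heavily optimized MatMul path, leaving only a cheap strided copy in the custom kernel.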
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h | Removed SplitPackedQKVProgram class declaration (moved to attention.h) |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Removed SplitPackedQKV function implementation and updated call site to include new kv_hidden_size parameter |
| onnxruntime/contrib_ops/webgpu/bert/attention_common.h | Added SplitPackedQKV function declaration for shared use across attention operators |
| onnxruntime/contrib_ops/webgpu/bert/attention.h | Added SplitPackedQKVProgram class declaration with updated uniform variables including input_size |
| onnxruntime/contrib_ops/webgpu/bert/attention.cc | Implemented the new PrepareQKV using MatMul + SplitPackedQKV, added vectorization support, and refactored to create Q/K/V in BSD format first before converting to BNSH for the non-flash attention path (see the layout sketch after this table) |
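The attention.cc row above notes that Q/K/V are now produced in BSD layout and converted to BNSH only for the non-flash attention path. Below is a minimal index-mapping sketch of that conversion with illustrative names; the real operator performs this transpose on the GPU, not with CPU loops.

```cpp
// Conceptual BSD -> BNSH layout conversion.
//   BSD  = [batch, seq_len, num_heads * head_size]
//   BNSH = [batch, num_heads, seq_len, head_size]
// Plain loops for clarity; names are illustrative, not the actual kernel.
#include <cstddef>
#include <vector>

std::vector<float> BSDToBNSH(const std::vector<float>& bsd,
                             std::size_t batch, std::size_t seq_len,
                             std::size_t num_heads, std::size_t head_size) {
  std::vector<float> bnsh(batch * num_heads * seq_len * head_size);
  for (std::size_t b = 0; b < batch; ++b) {
    for (std::size_t s = 0; s < seq_len; ++s) {
      for (std::size_t n = 0; n < num_heads; ++n) {
        for (std::size_t h = 0; h < head_size; ++h) {
          const std::size_t src = ((b * seq_len + s) * num_heads + n) * head_size + h;
          const std::size_t dst = ((b * num_heads + n) * seq_len + s) * head_size + h;
          bnsh[dst] = bsd[src];
        }
      }
    }
  }
  return bnsh;
}
```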
This pull request refactors and streamlines the computation of the Q, K, and V tensors in the WebGPU BERT Attention operator. The main changes are removing a custom QKV preparation kernel in favor of a more modular approach, a MatMul operation followed by a dedicated split kernel, and generalizing the QKV splitting logic for broader reuse. This improves maintainability, code reuse, and performance, since the MatMul op has already received many optimizations.
With this change, PrepareQKV drops from 751.67 ms to 128.88 ms in the phi4-vision model.
Before / After: profiler captures comparing PrepareQKV before (751.67 ms) and after (128.88 ms) the change.
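The "vectorization support" mentioned in the summary likely amounts to moving 4-component vectors rather than scalars in the split kernel when the hidden sizes allow it. The sketch below is a CPU stand-in for that idea, assuming q_hidden and kv_hidden are multiples of 4; it illustrates the concept and is not the actual WGSL shader.

```cpp
// CPU stand-in for a vec4-style split: each "invocation" copies 4-component
// vectors, so the kernel handles packed_hidden / 4 elements per token instead
// of packed_hidden scalars. Names and divisibility assumptions are illustrative.
#include <algorithm>
#include <array>
#include <cstddef>
#include <vector>

using Vec4 = std::array<float, 4>;

void SplitPackedQKVVec4(const std::vector<Vec4>& packed,  // tokens * (q_hidden + 2 * kv_hidden) / 4 vec4s
                        std::vector<Vec4>& q, std::vector<Vec4>& k, std::vector<Vec4>& v,
                        std::size_t tokens, std::size_t q_hidden, std::size_t kv_hidden) {
  const std::size_t q_vecs = q_hidden / 4;    // assumes q_hidden % 4 == 0
  const std::size_t kv_vecs = kv_hidden / 4;  // assumes kv_hidden % 4 == 0
  const std::size_t packed_vecs = q_vecs + 2 * kv_vecs;
  q.resize(tokens * q_vecs);
  k.resize(tokens * kv_vecs);
  v.resize(tokens * kv_vecs);
  for (std::size_t t = 0; t < tokens; ++t) {
    const Vec4* row = packed.data() + t * packed_vecs;
    std::copy(row, row + q_vecs, q.begin() + t * q_vecs);
    std::copy(row + q_vecs, row + q_vecs + kv_vecs, k.begin() + t * kv_vecs);
    std::copy(row + q_vecs + kv_vecs, row + packed_vecs, v.begin() + t * kv_vecs);
  }
}
```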