[TORCH] Transformer encoder decomposition #4381
base: main
Conversation
2e95de5 to 0de7090
- Add a dedicated DecomposeTransformerEncoder pass to expand encoder ops into primitive Torch patterns.
- Extend shared lowering helpers (ReduceOpVariants.cpp, Utils.h) so the new pass can reuse reduction utilities during decomposition.
- Register the pass in the Torch Transform pipeline so it runs as part of the decomposition flow.
- Expand e2e coverage with new transformer encoder tests to validate the lowering path.

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I6bcda53569cf7b06df4cb97c624bbf512d8fecb7
0de7090 to e446682
Lallapallooza left a comment
Thanks for the patch! I left comments on a few correctness/pipeline-contract issues and some nits/cleanup.
    RewritePatternSet &patterns, const llvm::StringSet<> &legalOpsSet) {
  MLIRContext *context = patterns.getContext();
  DecomposeAtenTransformerEncoderLayerFwd pattern(context);
  auto opName = pattern.getRootKind();
populateTransformerEncoderPatterns(patterns, legalOpsSet) must respect the same "legal ops" contract as the other patterns in this pass. As-is, the transformer rewrite is a torch.operator pattern (the semantic op name lives in an attribute), so legality gating must be done by inspecting the operator name attribute, not by getRootKind().
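A rough sketch of that gating, assuming the legal-ops entries are spelled without the torch. prefix and that the intended contract is "skip the pattern when the op is listed as legal" (the operator-name string below is illustrative, not taken from the patch):

void populateTransformerEncoderPatterns(RewritePatternSet &patterns,
                                        const llvm::StringSet<> &legalOpsSet) {
  MLIRContext *context = patterns.getContext();
  // The rewrite targets torch.operator; the semantic name lives in the op's
  // "name" attribute, so gate on that string rather than on getRootKind().
  constexpr llvm::StringLiteral kEncoderOpName =
      "aten._transformer_encoder_layer_fwd"; // assumed spelling
  if (legalOpsSet.contains(kEncoderOpName))
    return; // the op was requested to stay legal, so don't decompose it
  patterns.add<DecomposeAtenTransformerEncoderLayerFwd>(context);
}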
#pragma once is missing
namespace Torch {

inline bool isTransformerEncoderOperatorName(llvm::StringRef name) {
  if (!name.consume_front("torch."))
Could we use kTorchOpPrefix here?
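Roughly what that would look like, assuming kTorchOpPrefix expands to the "torch." string the code currently spells out (the suffix matched below is illustrative):

inline bool isTransformerEncoderOperatorName(llvm::StringRef name) {
  if (!name.consume_front(kTorchOpPrefix)) // shared prefix constant
    return false;
  return name == "aten._transformer_encoder_layer_fwd"; // assumed suffix
}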
}

bool isSpecializedOperation(Torch::OperatorOp op) { return true; }
bool isSpecializedOperation(Torch::OperatorOp op) {
Previously, isSpecializedOperation effectively treated torch.operator as illegal unless explicitly handled. Now it returns true only for flash-attn and false for almost everything else, which risks letting unexpected torch.operator ops leak further into the pipeline.
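One possible reading of that concern, sketched only: keep the old conservative default and carve out just the operators this patch explicitly decomposes, reusing the isTransformerEncoderOperatorName helper quoted above and the generic "name" attribute accessor.

bool isSpecializedOperation(Torch::OperatorOp op) {
  llvm::StringRef name =
      op->getAttrOfType<StringAttr>("name").getValue();
  // Only operators this patch knows how to decompose opt out of being
  // treated as specialized.
  if (isTransformerEncoderOperatorName(name))
    return false;
  // Everything else keeps the previous answer (true), so unknown
  // torch.operator ops are still caught instead of leaking down the pipeline.
  return true;
}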
Utils.h is an extremely generic filename and easy to confuse with torch-mlir/Dialect/Torch/Utils/Utils.h. Please rename it to something domain-specific (e.g. TransformerEncoderUtils.h).
| "StdCorrectionLargeInputModule_basic", | ||
| "TupleModule_basic", | ||
| "ThresholdStaticModule_basic", | ||
| "TransformerEncoderModule_basic", |
Could we add a reason why it fails?
| "TrilIndicesNegativeOffsetModule_basic", | ||
| "TriuIndicesAllZerosModule_basic", | ||
| "TriuIndicesModule_basic", | ||
| "TransformerEncoderModule_basic", |
Could we add a reason why it fails?
void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
                                               MLIRContext *context);

void populateTransformerEncoderPatterns(RewritePatternSet &patterns,
populateTransformerEncoderPatterns is being added to the public passes header, which is consistent with the other populate*Pattern(s) declarations there, but please confirm this is intended API surface (vs. a private helper), since it's tightly coupled to DecomposeComplexOps.
This standalone Python script feels redundant and adds extra maintenance. I'd prefer deleting this file and relying on the lit test plus the pt1 e2e transformer coverage.
                          Value bias) -> FailureOr<Value> {
  auto inputTensorType = cast<ValueTensorType>(input.getType());
  Value normalizedShape = createIntList(rewriter, loc, {embedDim});
  Value cudnnEnable = createBoolConstant(rewriter, loc, true);
buildLayerNorm always sets cudnn_enable=true. Please confirm this matches the semantics of _transformer_encoder_layer_fwd for the CPU path; if the fused op can produce cudnn_enable=false, this rewrite could diverge.
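A minimal sketch of the alternative, parameterizing the flag instead of hard-coding it (createIntList/createBoolConstant come from the snippet above; createFloatConstant, layerNormEps, and the extra enableCudnn parameter are hypothetical):

auto buildLayerNorm = [&](Value input, Value weight, Value bias,
                          bool enableCudnn) -> FailureOr<Value> {
  auto inputTensorType = cast<ValueTensorType>(input.getType());
  Value normalizedShape = createIntList(rewriter, loc, {embedDim});
  Value cudnnEnable = createBoolConstant(rewriter, loc, enableCudnn);
  Value eps = createFloatConstant(rewriter, loc, layerNormEps);
  // aten.layer_norm takes cudnn_enable as its last operand, so the caller can
  // now mirror whatever the fused op's CPU-path semantics require.
  return rewriter
      .create<AtenLayerNormOp>(loc, inputTensorType, input, normalizedShape,
                               weight, bias, eps, cudnnEnable)
      .getResult();
};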