
Conversation

@catcor01

  • Add a dedicated DecomposeTransformerEncoder pass to expand encoder ops into primitive Torch patterns.
  • Extend shared lowering helpers (ReduceOpVariants.cpp, Utils.h) so the new pass can reuse reduction utilities during decomposition.
  • Register the pass in the Torch Transform pipeline so it runs as part of the decomposition flow (see the sketch after this list).
  • Expand e2e coverage with new transformer encoder tests to validate the lowering path.

Change-Id: I6bcda53569cf7b06df4cb97c624bbf512d8fecb7
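
A rough sketch of the pipeline hook described above; the pass factory name and exact insertion point are assumptions based on this summary, not the patch itself:

```cpp
// Sketch only: wire the new decomposition pass into the Torch lowering
// pipeline. createDecomposeTransformerEncoderPass is an assumed factory name.
void addTorchDecompositionPasses(OpPassManager &pm) {
  // Expand the fused encoder op into primitive Torch ops early, so the rest
  // of the decomposition flow only sees ops it already knows how to lower.
  pm.addNestedPass<func::FuncOp>(
      Torch::createDecomposeTransformerEncoderPass());
}
```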

@catcor01 catcor01 force-pushed the transformer_encoder branch 2 times, most recently from 2e95de5 to 0de7090 on November 26, 2025 09:29
@catcor01 catcor01 force-pushed the transformer_encoder branch from 0de7090 to e446682 on November 26, 2025 09:37

@Lallapallooza left a comment


Thanks for the patch! I left comments on a few correctness/pipeline-contract issues and some nits/cleanup.

RewritePatternSet &patterns, const llvm::StringSet<> &legalOpsSet) {
MLIRContext *context = patterns.getContext();
DecomposeAtenTransformerEncoderLayerFwd pattern(context);
auto opName = pattern.getRootKind();

populateTransformerEncoderPatterns(patterns, legalOpsSet) must respect the same "legal ops" contract as the other patterns in this pass. As-is, the transformer rewrite is a torch.operator pattern (semantic op name lives in an attribute), so legality gating must be done by inspecting the operator name attribute, not by getRootKind().
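
A minimal sketch of what name-based gating could look like, assuming the legal-ops set stores names with the torch. prefix stripped (as the other patterns appear to expect) and using an illustrative spelling for the fused op's name:

```cpp
// Sketch only: gate the torch.operator-based rewrite on the semantic operator
// name instead of getRootKind(), so it honours the same legal-ops contract as
// the other patterns registered by this pass.
void populateTransformerEncoderPatterns(RewritePatternSet &patterns,
                                        const llvm::StringSet<> &legalOpsSet) {
  MLIRContext *context = patterns.getContext();
  // Assumed spelling of the fused op's name as it would appear in the set.
  constexpr llvm::StringLiteral kEncoderFwdName =
      "aten._transformer_encoder_layer_fwd";
  if (legalOpsSet.contains(kEncoderFwdName))
    return; // The caller asked for this op to stay legal; skip the rewrite.
  patterns.add<DecomposeAtenTransformerEncoderLayerFwd>(context);
}
```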


#pragma once is missing

namespace Torch {

inline bool isTransformerEncoderOperatorName(llvm::StringRef name) {
if (!name.consume_front("torch."))

Could we use kTorchOpPrefix here?
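
For example, the header could look roughly like this; it also folds in the missing #pragma once noted above. The constant is spelled out only to keep the sketch self-contained (the real patch should reuse the shared kTorchOpPrefix definition), and the name comparison is illustrative:

```cpp
#pragma once

#include "llvm/ADT/StringRef.h"

namespace Torch {

// In the tree this should reuse the existing shared constant rather than
// redefining it; it is written out here only so the sketch stands on its own.
constexpr llvm::StringLiteral kTorchOpPrefix = "torch.";

inline bool isTransformerEncoderOperatorName(llvm::StringRef name) {
  if (!name.consume_front(kTorchOpPrefix))
    return false;
  // Illustrative remainder check; the real helper compares against the fused
  // encoder op's name.
  return name == "aten._transformer_encoder_layer_fwd";
}

} // namespace Torch
```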

}

bool isSpecializedOperation(Torch::OperatorOp op) { return true; }
bool isSpecializedOperation(Torch::OperatorOp op) {

Previously, isSpecializedOperation effectively treated every torch.operator as illegal unless it was explicitly handled. Now it returns true only for flash-attn and false for almost everything else, which risks letting unexpected torch.operator ops leak further into the pipeline.
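
One conservative option (a sketch, not the patch's code; accessor spellings may differ) is to keep "unknown torch.operator means specialized/illegal" as the default and only opt out the names the pipeline knowingly handles later:

```cpp
// Sketch only: unexpected torch.operator ops stay "specialized" (illegal) so
// they still fail loudly, and only explicitly known names are let through.
bool isSpecializedOperation(Torch::OperatorOp op) {
  llvm::StringRef name = op.getNameAttr().getValue(); // accessor may differ
  // Ops that deliberately flow further down the pipeline for decomposition.
  if (isTransformerEncoderOperatorName(name))
    return false;
  // Everything else, including the flash-attn variants, remains specialized.
  return true;
}
```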


Utils.h is an extremely generic filename and easy to confuse with torch-mlir/Dialect/Torch/Utils/Utils.h. Please rename it to something domain-specific (e.g. TransformerEncoderUtils.h).

"StdCorrectionLargeInputModule_basic",
"TupleModule_basic",
"ThresholdStaticModule_basic",
"TransformerEncoderModule_basic",

Could we add a reason why it fails?

"TrilIndicesNegativeOffsetModule_basic",
"TriuIndicesAllZerosModule_basic",
"TriuIndicesModule_basic",
"TransformerEncoderModule_basic",

Could we add a reason why it fails?

void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
MLIRContext *context);

void populateTransformerEncoderPatterns(RewritePatternSet &patterns,

populateTransformerEncoderPatterns is being added to the public passes header, which is consistent with the other populate*Patterns declarations there, but please confirm this is intended API surface (rather than a private helper), since it's tightly coupled to DecomposeComplexOps.


This standalone Python script feels redundant and adds extra maintenance burden. I'd prefer deleting the file and relying on the lit test plus the pt1 e2e transformer coverage.

Value bias) -> FailureOr<Value> {
auto inputTensorType = cast<ValueTensorType>(input.getType());
Value normalizedShape = createIntList(rewriter, loc, {embedDim});
Value cudnnEnable = createBoolConstant(rewriter, loc, true);

buildLayerNorm always sets cudnn_enable=true. Please confirm this matches the semantics of _transformer_encoder_layer_fwd for the CPU path; if the fused op can carry cudnn_enable=false, this rewrite could diverge.
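
If the flag is recoverable from the matched fused op, a sketch of threading it through instead of hardcoding it (fusedOp, the operand index, and the constant matching are illustrative assumptions, not the patch's code):

```cpp
// Sketch only: recover cudnn_enable from the matched torch.operator when it
// is present as a constant bool operand, otherwise keep the current default.
bool cudnnEnabled = true; // assumed fallback matching the current behaviour
if (fusedOp->getNumOperands() > kCudnnEnableOperandIdx)
  matchPattern(fusedOp->getOperand(kCudnnEnableOperandIdx),
               m_TorchConstantBool(&cudnnEnabled));
Value cudnnEnable = createBoolConstant(rewriter, loc, cudnnEnabled);
```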
