Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 61b8a57

Browse files
authored
[MLIR][XeGPU] Refactor layout propagation utilities (#179016)
This PR refactors layout propagation into two distinct components: result/anchor layout setup and source layout inference from the result. For operations that require a specific result layout due to semantic or hardware constraints, the propagation logic explicitly sets up the result or anchor layout. Otherwise, it infers the source layout from the backward-propagated consumer layout. The result or anchor layout may differ from the backward-propagated consumer layout; any such discrepancies are resolved via the existing layout-conflict mechanism. **This PR introduces the following utility functions:** Source layout inference: > inferBroadcastSourceLayout() > inferMultiReductionSourceLayout() > inferBitCastSourceLayout() > inferShapeCastSourceLayout() > inferInsertStridedSliceSourceLayout() Result / anchor layout setup: > setupMultiReductionResultLayout() > setupBitCastResultLayout() > setupInsertStridedSliceResultLayout() > setupLoadMatrixAnchorLayout() > setupStoreMatrixAnchorLayout() > setupLoadGatherAnchorLayout() > setupStoreScatterAnchorLayout() Some of the subgroup-distribution-related code changes have been split out into a separate PR: https://github.com/llvm/llvm-project/pull/179018/changes.
1 parent 15a30e3 commit 61b8a57

21 files changed

Lines changed: 1901 additions & 547 deletions

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,16 +226,31 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
226226
InterfaceMethod<"Derive a new layout with sg_data, inst_data and lane_data set to 1 for the specified unit dims",
227227
"xegpu::DistributeLayoutAttr",
228228
"setUnitDimData",
229-
/*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
229+
/*args=*/(ins "const SmallVector<int64_t>": $unitDims)>,
230230
InterfaceMethod<"Derive a new layout with sg_lane and lane_layout set to 1 for the specified unit dims",
231231
"xegpu::DistributeLayoutAttr",
232232
"setUnitDimLayout",
233-
/*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
233+
/*args=*/(ins "const SmallVector<int64_t>": $unitDims)>,
234234
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
235235
indices based on the effective layout level.}],
236236
"FailureOr<SmallVector<Value>>",
237237
"delinearizeId",
238238
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
239+
InterfaceMethod<[{Derive a new layout with sg_data, inst_data and lane_data set to the
240+
specified values for the given dimension. Passing -1 for any parameter
241+
preserves its original value.}],
242+
"xegpu::DistributeLayoutAttr",
243+
"setDimData",
244+
(ins "int64_t": $dim,
245+
"int64_t": $sgData,
246+
"int64_t": $instData,
247+
"int64_t": $laneData)>,
248+
InterfaceMethod<[{Derive a new layout by collapsing dimensions.
249+
`dimGroup` specifies a group of adjacent dimensions that are collapsed into
250+
a single dimension in the derived layout.}],
251+
"xegpu::DistributeLayoutAttr",
252+
"collapseDims",
253+
(ins "SmallVector<int64_t>": $dimGroup)>,
239254
InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
240255
assigned to a level identified by linearId. The shape parameter
241256
represents the higher-level problem size. Each level may access
@@ -501,10 +516,20 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
501516
}
502517

503518
// Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
504-
DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
519+
DistributeLayoutAttr setUnitDimData(SmallVector<int64_t> unitDims) const;
505520

506521
// Set sg_layout and lane_layout to 1 for the specified unit dims.
507-
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
522+
DistributeLayoutAttr setUnitDimLayout(SmallVector<int64_t> unitDims) const;
523+
524+
// Derive a new layout with sg_data, inst_data and lane_data set to the
525+
// specified values for the given dimension. Passing -1 for any parameter
526+
// preserves its original value.
527+
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
528+
529+
// Derive a new layout by collapsing dimensions.
530+
// `dimGroup` specifies a group of adjacent dimensions
531+
// that are collapsed into a single dimension in the derived layout.
532+
DistributeLayoutAttr collapseDims(SmallVector<int64_t> dimGroup);
508533

509534
/// Delinearizes a linear ID into its multidimensional indices
510535
/// based on the effective level of the layout.
@@ -672,10 +697,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
672697
}
673698

674699
// Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
675-
DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
700+
DistributeLayoutAttr setUnitDimData(SmallVector<int64_t> unitDims) const;
676701

677702
// Set sg_layout and lane_layout to 1 for the specified unit dims.
678-
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
703+
DistributeLayoutAttr setUnitDimLayout(SmallVector<int64_t> unitDims) const;
704+
705+
// Derive a new layout with sg_data, inst_data and lane_data set to the
706+
// specified values for the given dimension. Passing -1 for any parameter
707+
// preserves its original value.
708+
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
709+
710+
// Derive a new layout by collapsing dimensions.
711+
// `dimGroup` specifies a group of adjacent dimensions
712+
// that are collapsed into a single dimension in the derived layout.
713+
DistributeLayoutAttr collapseDims(SmallVector<int64_t> dimGroup);
679714

680715
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
681716
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,6 @@ void populateXeGPUSgToWiDistributeTypeConversionAndLegality(
103103
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
104104
const UnrollOptions &options);
105105

106-
enum class LayoutKind { Lane, InstData, Subgroup };
107-
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
108-
LayoutKind layoutKind, bool printOnly = false);
109-
110-
LogicalResult resolveLayoutConflicts(Operation *target);
111-
112106
} // namespace xegpu
113107
} // namespace mlir
114108

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
//===- XeGPULayoutImpl.h - Layout utility functions ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
10+
#define MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
11+
12+
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
13+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
14+
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
15+
#include "mlir/IR/BuiltinTypes.h"
16+
#include "mlir/IR/OpDefinition.h"
17+
18+
namespace mlir {
19+
20+
class VectorType;
21+
class OpOperand;
22+
class OpResult;
23+
class OpBuilder;
24+
class ValueRange;
25+
class TypeConverter;
26+
class OpFoldResult;
27+
28+
namespace xegpu {
29+
class DistributeLayoutAttr;
30+
class LayoutAttr;
31+
class TensorDescType;
32+
} // namespace xegpu
33+
34+
namespace xegpu {
35+
36+
enum class LayoutKind { Lane, InstData, Subgroup };
37+
38+
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
39+
LayoutKind layoutKind, bool printOnly = false);
40+
41+
LogicalResult resolveLayoutConflicts(Operation *target);
42+
43+
/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
44+
/// OpResult of the given operation. If the operation contains regions, it is
45+
/// also applied recursively to the contained operations.
46+
/// TODO: To be replaced by recoverTemporaryLayouts()
47+
void recoverTemporaryLayoutsDeprecated(Operation *op);
48+
49+
/// Attach layout attributes to all vector-type operands of operations within
50+
/// the given operation's nested region. Reports an error if any vector operand
51+
/// lacks a layout attribute.
52+
bool recoverTemporaryLayouts(Operation *rootOp);
53+
54+
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
55+
template <typename T,
56+
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
57+
std::is_same_v<T, OpResult>>>
58+
void removeLayoutAttr(const T &operandOrResult);
59+
60+
/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
61+
/// given operation if they exist. If the operation contains regions, it is also
62+
/// applied recursively to the contained operations
63+
void removeLayoutAttrs(Operation *op);
64+
65+
/// Updates the NamedAttribute sequence by dropping sg-layout and
66+
/// sg-data information from any DistributeLayoutAttr found.
67+
SmallVector<NamedAttribute>
68+
dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs);
69+
70+
/// Updates the NamedAttribute sequence by dropping inst-data information from
71+
/// any DistributeLayoutAttr found.
72+
SmallVector<NamedAttribute> dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs);
73+
74+
/// Infers the source layout attribute for a broadcast operation given the
75+
/// result layout attribute, result shape, and source shape.
76+
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout,
77+
ArrayRef<int64_t> resShape,
78+
ArrayRef<int64_t> srcShape);
79+
80+
/// Infers the source layout attribute for a reduction operation given the
81+
/// result layout attribute and reduced dims.
82+
DistributeLayoutAttr
83+
inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout,
84+
SmallVector<int64_t> reduceDims);
85+
86+
/// Infers the source layout attribute for a bitcast operation given the
87+
/// result layout attribute, result element type bitwidth, and source element
88+
/// type bitwidth.
89+
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout,
90+
int resElemTyBitWidth,
91+
int srcElemTyBitWidth);
92+
93+
/// Infers the source layout attribute for a shape cast operation given the
94+
/// result layout attribute, result shape, and source shape.
95+
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
96+
ArrayRef<int64_t> resShape,
97+
ArrayRef<int64_t> srcShape);
98+
99+
/// Infers the source layout attribute for an insert strided slice operation
100+
/// given the result layout attribute, result shape, and source shape. Removes
101+
/// leading dimensions from the result layout to match the source shape size.
102+
DistributeLayoutAttr
103+
inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
104+
ArrayRef<int64_t> resShape,
105+
ArrayRef<int64_t> srcShape);
106+
107+
/// Sets up layout for reduction operations by creating a SliceAttr for the
108+
/// result.
109+
///
110+
/// This function first attempts to construct a source layout that, when
111+
/// sliced along reduction dimensions, produces a result layout compatible
112+
/// with the consumer's preferred layout. This minimizes data redistribution
113+
/// overhead. The SliceAttr for the result is then created based on the
114+
/// derived source layout and the specified reduction dimensions.
115+
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind,
116+
VectorType srcVectorTy,
117+
DistributeLayoutAttr consumerLayout,
118+
SmallVector<int64_t> reductionDims,
119+
const uArch::uArch *uArch);
120+
121+
/// Setup the result layout attribute for a bitcast operation based on element
122+
/// type bitwidths. This ensures the source layout can always be derived from
123+
/// the result layout.
124+
///
125+
/// When casting from a narrower to a wider element type (srcElemTyBitWidth <
126+
/// resElemTyBitWidth), the result layout's innermost dimension data sizes
127+
/// (inst_data, lane_data) are scaled up by the bitwidth ratio. This maintains
128+
/// the invariant that the source layout can be recovered by adjusting the
129+
/// result layout based on bitwidth ratio of input vs output.
130+
DistributeLayoutAttr setupBitCastResultLayout(
131+
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
132+
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
133+
134+
/// Sets up the result layout for an insert strided slice operation.
135+
/// Creates a result layout based on the specified layout kind (InstData or
136+
/// Lane).
137+
DistributeLayoutAttr setupInsertStridedSliceResultLayout(
138+
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
139+
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
140+
141+
/// Sets up the anchor layout for a load gather operation.
142+
DistributeLayoutAttr
143+
setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
144+
int chunkSize, DistributeLayoutAttr consumerLayout,
145+
const uArch::uArch *uArch);
146+
147+
/// Sets up the anchor layout for a load matrix operation.
148+
DistributeLayoutAttr
149+
setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
150+
DistributeLayoutAttr consumerLayout,
151+
const uArch::uArch *uArch);
152+
153+
/// Sets up the anchor layout for a store scatter operation.
154+
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind,
155+
VectorType vectorTy,
156+
int chunkSize,
157+
const uArch::uArch *uArch);
158+
159+
/// Sets up the anchor layout for a store matrix operation.
160+
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
161+
VectorType vectorTy,
162+
const uArch::uArch *uArch);
163+
164+
} // namespace xegpu
165+
166+
} // namespace mlir
167+
168+
#endif // MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,6 @@ template <typename T>
137137
int getLargestDivisor(T dim, ArrayRef<T> candidates,
138138
ArrayRef<T> candidateMultiples = {});
139139

140-
/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
141-
std::string getTemporaryLayoutName(const OpOperand &operand);
142-
143-
/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
144-
std::string getTemporaryLayoutName(const OpResult result);
145-
146140
/// Retrieves the DistributeLayoutAttr associated with a given Value. For
147141
/// TensorDescType values, the DistributeLayoutAttr is extracted from the
148142
/// TensorDescType itself. For other values, it is obtained from the attributes
@@ -155,26 +149,6 @@ DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
155149
/// found, it will check the operand itself and its defining op.
156150
DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
157151

158-
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
159-
template <typename T,
160-
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
161-
std::is_same_v<T, OpResult>>>
162-
void removeLayoutAttr(const T &operandOrResult);
163-
164-
/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
165-
/// given operation if they exist. If the operation contains regions, it is also
166-
/// applied recursively to the contained operations
167-
void removeLayoutAttrs(Operation *op);
168-
169-
/// Updates the NamedAttribute sequence by dropping sg-layout and
170-
/// sg-data information from any DistributeLayoutAttr found.
171-
SmallVector<NamedAttribute>
172-
dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs);
173-
174-
/// Updates the NamedAttribute sequence by dropping inst-data information from
175-
/// any DistributeLayoutAttr found.
176-
SmallVector<NamedAttribute> dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs);
177-
178152
/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
179153
/// user should use setAnchorLayout instead
180154
void setDistributeLayoutAttr(const OpResult &Result,
@@ -185,6 +159,12 @@ void setDistributeLayoutAttr(const OpResult &Result,
185159
void setDistributeLayoutAttr(const OpOperand &opr,
186160
const DistributeLayoutAttr layout);
187161

162+
/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
163+
std::string getTemporaryLayoutName(const OpOperand &operand);
164+
165+
/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
166+
std::string getTemporaryLayoutName(const OpResult result);
167+
188168
/// get and set distribute layout attribute for non-anchor operations
189169
/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
190170
template <typename T,
@@ -198,17 +178,6 @@ template <typename T,
198178
void setTemporaryLayout(const T &operandOrResult,
199179
const DistributeLayoutAttr layout);
200180

201-
/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
202-
/// OpResult of of the given operation. If the operation contains regions, it is
203-
/// also applied recursively to the contained operations operation.
204-
/// TODO: To be replaced by recoverTemporaryLayouts()
205-
void recoverTemporaryLayoutsDeprecated(Operation *op);
206-
207-
/// Attach layout attributes to all vector-type operands of operations within
208-
/// the given operation's region. Reports an error if any vector operand lacks
209-
/// a layout attribute.
210-
bool recoverTemporaryLayouts(Operation *rootOp);
211-
212181
/// Helper function to check if the layout is packed. Layout is packed if it is
213182
/// 2D and lane_data[0] != 1 (data packed from col dimension).
214183
/// TODO: Move to target info.
@@ -217,6 +186,15 @@ bool requirePacked(const LayoutAttr layout);
217186
/// Helper function to check if the layout requires a transpose effect.
218187
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch);
219188

189+
// Check if dst shape is an expansion of src shape by inserting unit dimensions.
190+
bool matchUnitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
191+
SmallVector<int64_t> &expandedUnitDims);
192+
193+
// Checks if dst shape is an expansion of src shape where each dimension in src
194+
// is split into one or more consecutive dimensions in dst
195+
bool matchSplitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
196+
SmallVector<SmallVector<int64_t>> &splitDimGroups);
197+
220198
} // namespace xegpu
221199

222200
} // namespace mlir

mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,19 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
216216
};
217217

218218
struct SpirvLoadGatherInstruction : public LoadGatherInstructionInterface {
219-
int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
220-
return 16;
221-
}
219+
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override { return 16; }
222220
};
223221

224222
struct SpirvStoreScatterInstruction : public StoreScatterInstructionInterface {
225-
int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
226-
return 16;
227-
}
223+
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override { return 16; }
224+
};
225+
226+
/// Load-matrix instruction descriptor; overrides the single query exposed by
/// LoadMatrixInstructionInterface that is visible in this hunk.
struct LoadMatrixInstruction : public LoadMatrixInstructionInterface {
227+
/// Returns the constant 16 for every `bitWidth`; the argument is not used.
/// NOTE(review): the unit of the returned size (bytes vs. elements) is not
/// visible here — confirm against LoadMatrixInstructionInterface.
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override { return 16; }
228+
};
229+
230+
struct StoreMatrixInstruction : public StoreMatrixInstructionInterface {
231+
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override { return 16; }
228232
};
229233

230234
//===----------------------------------------------------------------------===//
@@ -239,9 +243,11 @@ struct PVCuArch final : public Xe2Plus {
239243
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
240244
static const SpirvStoreScatterInstruction storeScatterInst;
241245
static const SpirvLoadGatherInstruction loadGatherInst;
242-
static const Instruction *arr[] = {&dpasInst, &loadNdInst,
243-
&storeNdInst, &prefetchNdInst,
244-
&storeScatterInst, &loadGatherInst};
246+
static const StoreMatrixInstruction storeMatrixInst;
247+
static const LoadMatrixInstruction loadMatrixInst;
248+
static const Instruction *arr[] = {
249+
&dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst,
250+
&storeScatterInst, &loadGatherInst, &storeMatrixInst, &loadMatrixInst};
245251
return arr;
246252
}
247253

0 commit comments

Comments
 (0)