[DAGCombine] Simplify partial_reduce_*mla with constant. #138289
Conversation
@llvm/pr-subscribers-llvm-selectiondag

Author: Sander de Smalen (sdesmalen-arm)

Changes

partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
  -> partial_reduce_*mla(acc, x, C)

Full diff: https://github.com/llvm/llvm-project/pull/138289.diff

2 Files Affected:
- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
- llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
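For context before the diff, a minimal standalone C++ sketch (not part of the patch) of the "constant fits in the narrow element type" check that gates the new fold. The helper name fitsInNarrowType is made up for illustration, and plain 8-bit/32-bit integers stand in for the APInt trunc/zext/sext round-trip performed in visitPARTIAL_REDUCE_MLA.

#include <cstdint>
#include <cstdio>

// Returns true if the 32-bit splat constant C survives a round-trip through
// the narrow 8-bit element type when re-extended with the same signedness as
// the other multiplicand (mirrors the CTrunc.zext(LHSBits)/CTrunc.sext(LHSBits)
// comparisons in the combine).
static bool fitsInNarrowType(int32_t C, bool IsSigned) {
  int8_t Trunc = static_cast<int8_t>(C); // C.trunc(8)
  int32_t Reext = IsSigned
                      ? static_cast<int32_t>(Trunc)                        // sext back to 32 bits
                      : static_cast<int32_t>(static_cast<uint8_t>(Trunc)); // zext back to 32 bits
  return Reext == C;
}

int main() {
  // Mirrors the tests in the diff: -1 fits for a sign-extended operand,
  // 255 fits for a zero-extended operand, and 256 fits for neither, so no
  // dot instruction can be formed for it.
  std::printf("sext, C=-1:  %d\n", fitsInNarrowType(-1, /*IsSigned=*/true));   // 1
  std::printf("zext, C=255: %d\n", fitsInNarrowType(255, /*IsSigned=*/false)); // 1
  std::printf("sext, C=256: %d\n", fitsInNarrowType(256, /*IsSigned=*/true));  // 0
  std::printf("zext, C=256: %d\n", fitsInNarrowType(256, /*IsSigned=*/false)); // 0
  return 0;
}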
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..345cb4f9fb6ee 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12612,47 +12612,63 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
+// partial_reduce_*mla(acc, mul(zext(a), zext(b)))
+// -> partial_reduce_umla(acc, a, b)
+//
+// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, C)
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
SDLoc DL(N);
-
+ auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt ConstantOne;
+ APInt C;
if (Op1->getOpcode() != ISD::MUL ||
- !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
return SDValue();
SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);
unsigned LHSOpcode = LHS->getOpcode();
- unsigned RHSOpcode = RHS->getOpcode();
- if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
+ if (!ISD::isExtOpcode(LHSOpcode))
return SDValue();
SDValue LHSExtOp = LHS->getOperand(0);
- SDValue RHSExtOp = RHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
- if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
- return SDValue();
- // Only perform the DAG combine if there is custom lowering provided by the
- // target
- auto *Context = DAG.getContext();
+ // Only perform these combines if the target supports folding
+ // the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+ unsigned NewOpcode =
+ ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+ // partial_reduce_*mla(acc, mul(zext(x), splat(C)), splat(1))
+ // -> partial_reduce_umla(acc, x, C)
+ if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
+ APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
+ unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
+ if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
+ (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
+ return SDValue();
+
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+ DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+ }
+
+ unsigned RHSOpcode = RHS->getOpcode();
+ if (!ISD::isExtOpcode(RHSOpcode))
+ return SDValue();
+
+ SDValue RHSExtOp = RHS->getOperand(0);
+ if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+ return SDValue();
// For a 2-stage extend the signedness of both of the extends must be the
// same. This is so the node can be folded into only a signed or unsigned
@@ -12663,8 +12679,6 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();
- unsigned NewOpcode =
- ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
RHSExtOp);
}
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 039cac01008b8..5326bccbbc3d5 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1139,7 +1139,6 @@ entry:
ret <vscale x 2 x i16> %partial.reduce
}
-
define <vscale x 4 x i64> @partial_reduce_only_split_acc(<vscale x 4 x i64> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: partial_reduce_only_split_acc:
; CHECK: // %bb.0: // %entry
@@ -1178,3 +1177,145 @@ entry:
<vscale x 4 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 4 x i64> %partial.reduce
}
+
+define <vscale x 4 x i32> @sdot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sunpklo z2.h, z1.b
+; CHECK-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEXT: sunpklo z3.s, z2.h
+; CHECK-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEXT: sub z0.s, z0.s, z3.s
+; CHECK-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEXT: sub z0.s, z0.s, z2.s
+; CHECK-NEXT: sub z0.s, z0.s, z3.s
+; CHECK-NEXT: sub z0.s, z0.s, z1.s
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 -1)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm_does_not_fit:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sunpklo z2.h, z1.b
+; CHECK-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEXT: sunpklo z3.s, z2.h
+; CHECK-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEXT: sunpklo z4.s, z1.h
+; CHECK-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEXT: lsl z4.s, z4.s, #8
+; CHECK-NEXT: lsl z2.s, z2.s, #8
+; CHECK-NEXT: lsl z3.s, z3.s, #8
+; CHECK-NEXT: lsl z1.s, z1.s, #8
+; CHECK-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEXT: add z2.s, z2.s, z4.s
+; CHECK-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm_does_not_fit:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT: add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uunpklo z3.h, z1.b
+; CHECK-NEXT: mov z2.s, #255 // =0xff
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEXT: uunpklo z4.s, z3.h
+; CHECK-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEXT: mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: mla z0.s, p0/m, z3.s, z2.s
+; CHECK-NEXT: mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 255)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm_does_not_fit:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uunpklo z2.h, z1.b
+; CHECK-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: lsl z4.s, z4.s, #8
+; CHECK-NEXT: lsl z2.s, z2.s, #8
+; CHECK-NEXT: lsl z3.s, z3.s, #8
+; CHECK-NEXT: lsl z1.s, z1.s, #8
+; CHECK-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEXT: add z2.s, z2.s, z4.s
+; CHECK-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm_does_not_fit:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT: lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT: add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
  -> partial_reduce_*mla(acc, x, C)

Full diff: https://github.com/llvm/llvm-project/pull/138289.diff
partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) -> partial_reduce_*mla(acc, x, C) (f734038 to 6c9e8ec)
LGTM (bar a nit pick)
partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) -> partial_reduce_*mla(acc, x, C)