diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td index 45ccf30644920..6c0fe363d5551 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td @@ -876,4 +876,32 @@ def UWTableKindEnum : LLVM_EnumAttr< let cppNamespace = "::mlir::LLVM::uwtable"; } +//===----------------------------------------------------------------------===// +// GEPNoWrapFlags +//===----------------------------------------------------------------------===// + +// These values must match llvm::GEPNoWrapFlags ones. +// See llvm/include/llvm/IR/GEPNoWrapFlags.h. +// Since inbounds implies nusw, create an inboundsFlag that represents the +// concept of raw inbounds with no nusw implication and the actual inbounds +// literal will be captured as the combination of inboundsFlag and nusw. + +def GEPNone : I32BitEnumCaseNone<"none">; +def GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">; +def GEPNusw : I32BitEnumCaseBit<"nusw", 1>; +def GEPNuw : I32BitEnumCaseBit<"nuw", 2>; +def GEPInbounds : BitEnumCaseGroup<"inbounds", [GEPInboundsFlag, GEPNusw]>; + +def GEPNoWrapFlags : I32BitEnum< + "GEPNoWrapFlags", + "::mlir::LLVM::GEPNoWrapFlags", + [GEPNone, GEPInboundsFlag, GEPNusw, GEPNuw, GEPInbounds]> { + let cppNamespace = "::mlir::LLVM"; + let printBitEnumPrimaryGroups = 1; +} + +def GEPNoWrapFlagsProp : EnumProp { + let defaultValue = interfaceType # "::none"; +} + #endif // LLVMIR_ENUMS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 5745d370f7268..5315e3994b33d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -291,7 +291,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, Variadic>:$dynamicIndices, DenseI32ArrayAttr:$rawConstantIndices, TypeAttr:$elem_type, - UnitAttr:$inbounds); + GEPNoWrapFlagsProp:$noWrapFlags); let results = (outs LLVM_ScalarOrVectorOf:$res); let skipDefaultBuilders = 1; @@ -303,8 +303,12 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, as indices. In the case of indexing within a structure, it is required to either use constant indices directly, or supply a constant SSA value. - An optional 'inbounds' attribute specifies the low-level pointer arithmetic + The no-wrap flags can be used to specify the low-level pointer arithmetic overflow behavior that LLVM uses after lowering the operation to LLVM IR. + Valid options include 'inbounds' (pointer arithmetic must be within object + bounds), 'nusw' (no unsigned signed wrap), and 'nuw' (no unsigned wrap). + Note that 'inbounds' implies 'nusw' which is ensured by the enum + definition. The flags can be set individually or in combination. Examples: @@ -323,10 +327,12 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, let builders = [ OpBuilder<(ins "Type":$resultType, "Type":$elementType, "Value":$basePtr, - "ValueRange":$indices, CArg<"bool", "false">:$inbounds, + "ValueRange":$indices, + CArg<"GEPNoWrapFlags", "GEPNoWrapFlags::none">:$noWrapFlags, CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "Type":$resultType, "Type":$elementType, "Value":$basePtr, - "ArrayRef":$indices, CArg<"bool", "false">:$inbounds, + "ArrayRef":$indices, + CArg<"GEPNoWrapFlags", "GEPNoWrapFlags::none">:$noWrapFlags, CArg<"ArrayRef", "{}">:$attributes)>, ]; let llvmBuilder = [{ @@ -343,10 +349,13 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, } Type baseElementType = op.getElemType(); llvm::Type *elementType = moduleTranslation.convertType(baseElementType); - $res = builder.CreateGEP(elementType, $base, indices, "", $inbounds); + $res = builder.CreateGEP(elementType, $base, indices, "", + llvm::GEPNoWrapFlags::fromRaw( + static_cast( + op.getNoWrapFlags()))); }]; let assemblyFormat = [{ - (`inbounds` $inbounds^)? + ($noWrapFlags^)? $base `[` custom($dynamicIndices, $rawConstantIndices) `]` attr-dict `:` functional-type(operands, results) `,` $elem_type }]; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 26c3ef1e8b8bf..d0ac39f538a1f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -673,29 +673,29 @@ static void destructureIndices(Type currType, ArrayRef indices, void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, Type elementType, Value basePtr, ArrayRef indices, - bool inbounds, ArrayRef attributes) { + GEPNoWrapFlags noWrapFlags, + ArrayRef attributes) { SmallVector rawConstantIndices; SmallVector dynamicIndices; destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices); result.addTypes(resultType); result.addAttributes(attributes); - result.addAttribute(getRawConstantIndicesAttrName(result.name), - builder.getDenseI32ArrayAttr(rawConstantIndices)); - if (inbounds) { - result.addAttribute(getInboundsAttrName(result.name), - builder.getUnitAttr()); - } - result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); + result.getOrAddProperties().rawConstantIndices = + builder.getDenseI32ArrayAttr(rawConstantIndices); + result.getOrAddProperties().noWrapFlags = noWrapFlags; + result.getOrAddProperties().elem_type = + TypeAttr::get(elementType); result.addOperands(basePtr); result.addOperands(dynamicIndices); } void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, Type elementType, Value basePtr, ValueRange indices, - bool inbounds, ArrayRef attributes) { + GEPNoWrapFlags noWrapFlags, + ArrayRef attributes) { build(builder, result, resultType, elementType, basePtr, - SmallVector(indices), inbounds, attributes); + SmallVector(indices), noWrapFlags, attributes); } static ParseResult @@ -794,6 +794,9 @@ LogicalResult LLVM::GEPOp::verify() { return emitOpError("expected as many dynamic indices as specified in '") << getRawConstantIndicesAttrName().getValue() << "'"; + if (getNoWrapFlags() == GEPNoWrapFlags::inboundsFlag) + return emitOpError("'inbounds_flag' cannot be used directly."); + return verifyStructIndices(getElemType(), getIndices(), [&] { return emitOpError(); }); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 8640ef28a9e56..bc451f8b028fc 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -891,7 +891,7 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot, auto byteType = IntegerType::get(builder.getContext(), 8); auto newPtr = builder.createOrFold( getLoc(), getResult().getType(), byteType, newSlot.ptr, - ArrayRef(accessInfo->subslotOffset), getInbounds()); + ArrayRef(accessInfo->subslotOffset), getNoWrapFlags()); getResult().replaceAllUsesWith(newPtr); return DeletionKind::Delete; } diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 3f80002c15ebb..c350addeeb8bc 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1984,8 +1984,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { } Type type = convertType(inst->getType()); - auto gepOp = builder.create(loc, type, sourceElementType, *basePtr, - indices, gepInst->isInBounds()); + auto gepOp = builder.create( + loc, type, sourceElementType, *basePtr, indices, + static_cast(gepInst->getNoWrapFlags().getRaw())); mapValue(inst, gepOp); return success(); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index a3cd9572933ae..d7c47fcd98441 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1803,3 +1803,11 @@ llvm.func @t1() -> !llvm.ptr { ^bb1: llvm.return %0 : !llvm.ptr } + +// ----- + +llvm.func @gep_inbounds_flag_usage(%ptr: !llvm.ptr, %idx: i64) { + // expected-error@+1 {{'inbounds_flag' cannot be used directly}} + llvm.getelementptr inbounds_flag %ptr[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + llvm.return +} diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 93916f621630d..f30c8f2b16808 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -236,6 +236,16 @@ llvm.func @gep(%ptr: !llvm.ptr, %idx: i64, %ptr2: !llvm.ptr) { llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)> // CHECK: llvm.getelementptr inbounds %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> llvm.getelementptr inbounds %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + // CHECK: llvm.getelementptr inbounds|nuw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + llvm.getelementptr inbounds | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + // CHECK: llvm.getelementptr inbounds %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + llvm.getelementptr inbounds | nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + // CHECK: llvm.getelementptr nusw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + llvm.getelementptr nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + // CHECK: llvm.getelementptr nusw|nuw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + llvm.getelementptr nusw | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + // CHECK: llvm.getelementptr nuw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + llvm.getelementptr nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> llvm.return } diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll index c294e1b34f9bb..2098d85c18c3f 100644 --- a/mlir/test/Target/LLVMIR/Import/instructions.ll +++ b/mlir/test/Target/LLVMIR/Import/instructions.ll @@ -557,6 +557,25 @@ define void @gep_static_idx(ptr %ptr) { ; // ----- +; CHECK-LABEL: @gep_no_wrap_flags +; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]] +define void @gep_no_wrap_flags(ptr %ptr) { + ; CHECK: %[[IDX:.+]] = llvm.mlir.constant(7 : i32) + ; CHECK: llvm.getelementptr inbounds %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + %1 = getelementptr inbounds float, ptr %ptr, i32 7 + ; CHECK: llvm.getelementptr nusw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + %2 = getelementptr nusw float, ptr %ptr, i32 7 + ; CHECK: llvm.getelementptr nuw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + %3 = getelementptr nuw float, ptr %ptr, i32 7 + ; CHECK: llvm.getelementptr nusw|nuw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + %4 = getelementptr nusw nuw float, ptr %ptr, i32 7 + ; CHECK: llvm.getelementptr inbounds|nuw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + %5 = getelementptr inbounds nuw float, ptr %ptr, i32 7 + ret void +} + +; // ----- + ; CHECK: @varargs(...) declare void @varargs(...) diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 74fa327809864..4a2447263ae68 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1057,6 +1057,14 @@ llvm.func @gep(%ptr: !llvm.ptr, %idx: i64, llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)> // CHECK: = getelementptr inbounds { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}} llvm.getelementptr inbounds %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + // CHECK: = getelementptr inbounds nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}} + llvm.getelementptr inbounds | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + // CHECK: = getelementptr nusw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}} + llvm.getelementptr nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + // CHECK: = getelementptr nusw nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}} + llvm.getelementptr nusw | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> + // CHECK: = getelementptr nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}} + llvm.getelementptr nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> llvm.return }