diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 1a8e8737f7fed..71aaee931d543 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData); +/// Returns the number of successor blocks of the block. +MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block); + +/// Returns `pos`-th successor of the block. +MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block, + intptr_t pos); + +/// Returns the number of predecessor blocks of the block. +MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block); + +/// Returns `pos`-th predecessor of the block. +MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block, + intptr_t pos); + //===----------------------------------------------------------------------===// // Value API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index cbd35f2974ae9..c12f036352b30 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2626,6 +2626,84 @@ class PyOpSuccessors : public Sliceable { PyOperationRef operation; }; +/// A list of block successors. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) successor list is +/// associated with the operation and block whose successors these are, and thus +/// extends the lifetime of this operation and block. +class PyBlockSuccessors : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockSuccessors"; + + PyBlockSuccessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumSuccessors(block.get()) + : length, + step), + operation(operation), block(block) {} + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumSuccessors(block.get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos); + return PyBlock(operation, block); + } + + PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyBlockSuccessors(block, operation, startIndex, length, step); + } + + PyOperationRef operation; + PyBlock block; +}; + +/// A list of block predecessors. The (returned) predecessor list is +/// associated with the operation and block whose predecessors these are, and +/// thus extends the lifetime of this operation and block. +class PyBlockPredecessors : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockPredecessors"; + + PyBlockPredecessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumPredecessors(block.get()) + : length, + step), + operation(operation), block(block) {} + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumPredecessors(block.get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos); + return PyBlock(operation, block); + } + + PyBlockPredecessors slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyBlockPredecessors(block, operation, startIndex, length, step); + } + + PyOperationRef operation; + PyBlock block; +}; + /// A list of operation attributes. Can be indexed by name, producing /// attributes, or by index, producing named attributes. class PyOpAttributeMap { @@ -3655,7 +3733,19 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("operation"), "Appends an operation to this block. If the operation is currently " - "in another block, it will be moved."); + "in another block, it will be moved.") + .def_prop_ro( + "successors", + [](PyBlock &self) { + return PyBlockSuccessors(self, self.getParentOperation()); + }, + "Returns the list of Block successors.") + .def_prop_ro( + "predecessors", + [](PyBlock &self) { + return PyBlockPredecessors(self, self.getParentOperation()); + }, + "Returns the list of Block predecessors."); //---------------------------------------------------------------------------- // Mapping of PyInsertionPoint. @@ -4099,6 +4189,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyBlockArgumentList::bind(m); PyBlockIterator::bind(m); PyBlockList::bind(m); + PyBlockSuccessors::bind(m); + PyBlockPredecessors::bind(m); PyOperationIterator::bind(m); PyOperationList::bind(m); PyOpAttributeMap::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index e0e386d55ede1..fbc66bcf5c2d0 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, unwrap(block)->print(stream); } +intptr_t mlirBlockGetNumSuccessors(MlirBlock block) { + return static_cast(unwrap(block)->getNumSuccessors()); +} + +MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) { + return wrap(unwrap(block)->getSuccessor(static_cast(pos))); +} + +intptr_t mlirBlockGetNumPredecessors(MlirBlock block) { + Block *b = unwrap(block); + return static_cast(std::distance(b->pred_begin(), b->pred_end())); +} + +MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) { + Block *b = unwrap(block); + Block::pred_iterator it = b->pred_begin(); + std::advance(it, pos); + return wrap(*it); +} + //===----------------------------------------------------------------------===// // Value API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 68da79f69cc0a..ed7ee58c97930 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -2440,6 +2440,52 @@ void testDiagnostics(void) { mlirContextDestroy(ctx); } +int testBlockPredecessorsSuccessors(MlirContext ctx) { + // CHECK-LABEL: @testBlockPredecessorsSuccessors + fprintf(stderr, "@testBlockPredecessorsSuccessors\n"); + + const char *moduleString = "module {\n" + " func.func @test(%arg0: i32, %arg1: i16) {\n" + " cf.br ^bb1(%arg1 : i16)\n" + " ^bb1(%0: i16): // pred: ^bb0\n" + " cf.br ^bb2(%arg0 : i32)\n" + " ^bb2(%1: i32): // pred: ^bb1\n" + " return\n" + " }\n" + "}\n"; + + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + + MlirOperation moduleOp = mlirModuleGetOperation(module); + MlirRegion moduleRegion = mlirOperationGetRegion(moduleOp, 0); + MlirBlock moduleBlock = mlirRegionGetFirstBlock(moduleRegion); + MlirOperation function = mlirBlockGetFirstOperation(moduleBlock); + MlirRegion funcRegion = mlirOperationGetRegion(function, 0); + MlirBlock entryBlock = mlirRegionGetFirstBlock(funcRegion); + MlirBlock middleBlock = mlirBlockGetNextInRegion(entryBlock); + MlirBlock successorBlock = mlirBlockGetNextInRegion(middleBlock); + + assert(mlirBlockGetNumPredecessors(entryBlock) == 0); + + assert(mlirBlockGetNumSuccessors(entryBlock) == 1); + assert(mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0))); + assert(mlirBlockGetNumPredecessors(middleBlock) == 1); + assert(mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0))); + + assert(mlirBlockGetNumSuccessors(middleBlock) == 1); + assert(mlirBlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0))); + assert(mlirBlockGetNumPredecessors(successorBlock) == 1); + assert( + mlirBlockEqual(middleBlock, mlirBlockGetPredecessor(successorBlock, 0))); + + assert(mlirBlockGetNumSuccessors(successorBlock) == 0); + + mlirModuleDestroy(module); + + return 0; +} + int main(void) { MlirContext ctx = mlirContextCreate(); registerAllUpstreamDialects(ctx); @@ -2486,6 +2532,9 @@ int main(void) { testExplicitThreadPools(); testDiagnostics(); + if (testBlockPredecessorsSuccessors(ctx)) + return 17; + // CHECK: DESTROY MAIN CONTEXT // CHECK: reportResourceDelete: resource_i64_blob fprintf(stderr, "DESTROY MAIN CONTEXT\n"); diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py index 70ccaeeb5435b..ced5fce434728 100644 --- a/mlir/test/python/ir/blocks.py +++ b/mlir/test/python/ir/blocks.py @@ -1,12 +1,11 @@ # RUN: %PYTHON %s | FileCheck %s import gc -import io -import itertools -from mlir.ir import * + from mlir.dialects import builtin from mlir.dialects import cf from mlir.dialects import func +from mlir.ir import * def run(f): @@ -54,10 +53,25 @@ def testBlockCreation(): with InsertionPoint(middle_block) as middle_ip: assert middle_ip.block == middle_block cf.BranchOp([i32_arg], dest=successor_block) + module.print(enable_debug_info=True) # Ensure region back references are coherent. assert entry_block.region == middle_block.region == successor_block.region + assert len(entry_block.predecessors) == 0 + + assert len(entry_block.successors) == 1 + assert middle_block == entry_block.successors[0] + assert len(middle_block.predecessors) == 1 + assert entry_block == middle_block.predecessors[0] + + assert len(middle_block.successors) == 1 + assert successor_block == middle_block.successors[0] + assert len(successor_block.predecessors) == 1 + assert middle_block == successor_block.predecessors[0] + + assert len(successor_block.successors) == 0 + # CHECK-LABEL: TEST: testBlockCreationArgLocs @run