diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h index e8e6226460ac7..b63c0883c6c15 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h @@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension { /// analyzed. DenseMap analyzedFuncOps; + /// A collection of cached SymbolTables used for faster function lookup. + mutable mlir::SymbolTableCollection symbolTable; + /// This function is called right before analyzing the given FuncOp. It /// initializes the data structures for the FuncOp in this state object. void startFunctionAnalysis(FuncOp funcOp); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index c45678f1e4b4d..1edf5c9190a3d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -76,23 +76,34 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, } /// Return the FuncOp called by `callOp`. -static FuncOp getCalledFunction(CallOpInterface callOp) { +static FuncOp getCalledFunction(CallOpInterface callOp, + mlir::SymbolTableCollection &symbolTable) { SymbolRefAttr sym = llvm::dyn_cast_if_present(callOp.getCallableForCallee()); if (!sym) return nullptr; return dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); + symbolTable.lookupNearestSymbolFrom(callOp, sym)); } -/// Get FuncAnalysisState. +/// Get or create FuncAnalysisState. static const FuncAnalysisState & -getFuncAnalysisState(const AnalysisState &state) { +getOrCreateFuncAnalysisState(const AnalysisState &state) { assert(isa(state) && "expected OneShotAnalysisState"); - auto *result = static_cast(state) - .getExtension(); - assert(result && "FuncAnalysisState does not exist"); - return *result; + + // Unfortunately, at the moment the BufferizableOpInterface methods do provide + // a const reference to the AnalysisState class, and the only way to + // dynamically add an extension is to const_cast it to a non-const reference. + // Should the const qualifier be dropped from the interface? + auto &oneShotAnalysisState = + static_cast(const_cast(state)); + + auto *result = oneShotAnalysisState.getExtension(); + + if (result) + return *result; + + return oneShotAnalysisState.addExtension(); } /// Return the state (phase) of analysis of the FuncOp. @@ -135,14 +146,14 @@ struct CallOpInterface bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); + const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); + FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable); assert(funcOp && "expected CallOp to a FuncOp"); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is read. return true; - const FuncAnalysisState &funcState = getFuncAnalysisState(state); return funcState.readBbArgs.lookup(funcOp).contains( opOperand.getOperandNumber()); } @@ -150,14 +161,14 @@ struct CallOpInterface bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); + const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); + FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable); assert(funcOp && "expected CallOp to a FuncOp"); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is written. return true; - const FuncAnalysisState &funcState = getFuncAnalysisState(state); return funcState.writtenBbArgs.lookup(funcOp).contains( opOperand.getOperandNumber()); } @@ -165,14 +176,14 @@ struct CallOpInterface AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); + const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); + FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable); assert(funcOp && "expected CallOp to a FuncOp"); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Any OpResult may be aliasing. return detail::unknownGetAliasingValues(opOperand); // Get aliasing results from state. - const FuncAnalysisState &funcState = getFuncAnalysisState(state); auto aliasingReturnVals = funcState.aliasingReturnVals.lookup(funcOp).lookup( opOperand.getOperandNumber()); @@ -199,7 +210,11 @@ struct CallOpInterface getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); + + // TODO Avoid recomputing the symbol tables every time. + mlir::SymbolTableCollection symbolTable; + + FuncOp funcOp = getCalledFunction(callOp, symbolTable); assert(funcOp && "expected CallOp to a FuncOp"); // If the callee was already bufferized, we can directly take the type from @@ -243,7 +258,11 @@ struct CallOpInterface // 2. Rewrite tensor operands as memrefs based on type of the already // bufferized callee. SmallVector newOperands; - FuncOp funcOp = getCalledFunction(callOp); + + // TODO Avoid recomputing the symbol tables every time. + mlir::SymbolTableCollection symbolTable; + + FuncOp funcOp = getCalledFunction(callOp, symbolTable); assert(funcOp && "expected CallOp to a FuncOp"); FunctionType funcType = funcOp.getFunctionType(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index edd6bcf84f460..a025da8635135 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) { } /// Return the func::FuncOp called by `callOp`. -static func::FuncOp getCalledFunction(func::CallOp callOp) { +static func::FuncOp +getCalledFunction(func::CallOp callOp, + mlir::SymbolTableCollection &symbolTable) { SymbolRefAttr sym = llvm::dyn_cast_if_present(callOp.getCallableForCallee()); if (!sym) return nullptr; return dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); + symbolTable.lookupNearestSymbolFrom(callOp, sym)); } /// Return "true" if the given function signature has tensor semantics. @@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls( DenseMap> calledBy; // For each FuncOp, the number of func::CallOp it contains. DenseMap numberCallOpsContainedInFuncOp; + + // TODO Avoid recomputing the symbol tables every time. + mlir::SymbolTableCollection symbolTable; + for (func::FuncOp funcOp : moduleOp.getOps()) { // Collect function calls and populate the caller map. numberCallOpsContainedInFuncOp[funcOp] = 0; WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { - func::FuncOp calledFunction = getCalledFunction(callOp); + func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable); assert(calledFunction && "could not retrieved called func::FuncOp"); // If the called function does not have any tensors in its signature, then // it is not necessary to bufferize the callee before the caller.