Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 72fc0c3

Browse files
[mlir][acc] Improve implicit deviceptr detection for alias (#195934)
The ACCImplicitData automatically is able to use deviceptr clause when variable is detected as being device data. However, it was missing check for own `acc declare deviceptr` attribute.
1 parent 8a86aab commit 72fc0c3

6 files changed

Lines changed: 148 additions & 23 deletions

File tree

flang/test/Transforms/OpenACC/acc-implicit-data.fir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,37 @@ func.func private @_FortranAAllocatableSetBounds(!fir.ref<!fir.box<none>>, i32,
397397

398398
// -----
399399

400+
// Test argument mapped with deviceptr but used not via data mapping.
401+
func.func @test_fir_declare_deviceptr_arg_in_parallel(%arg0: !fir.ref<!fir.array<10xf64>>) {
402+
%c10 = arith.constant 10 : index
403+
%shape = fir.shape %c10 : (index) -> !fir.shape<1>
404+
%arr_decl = fir.declare %arg0(%shape) {acc.declare = #acc.declare<dataClause = acc_deviceptr>, uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xf64>>
405+
%arr_box = fir.embox %arr_decl(%shape) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
406+
%devptr = acc.deviceptr var(%arr_box : !fir.box<!fir.array<10xf64>>) -> !fir.box<!fir.array<10xf64>> {name = "a"}
407+
%token = acc.declare_enter dataOperands(%devptr : !fir.box<!fir.array<10xf64>>)
408+
acc.parallel {
409+
%addr = fir.box_addr %arr_box : (!fir.box<!fir.array<10xf64>>) -> !fir.ref<!fir.array<10xf64>>
410+
%elem = fir.array_coor %arr_decl(%shape) %c10 : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>, index) -> !fir.ref<f64>
411+
acc.yield
412+
}
413+
acc.declare_exit token(%token) dataOperands(%devptr : !fir.box<!fir.array<10xf64>>)
414+
return
415+
}
416+
417+
// CHECK-LABEL: func.func @test_fir_declare_deviceptr_arg_in_parallel
418+
// CHECK: %[[DECL:.*]] = fir.declare %{{.*}}{{.*}}{acc.declare = #acc.declare<dataClause = acc_deviceptr>{{.*}}
419+
// CHECK: %[[BOX:.*]] = fir.embox %[[DECL]]
420+
// CHECK: %[[DEVPTR:.*]] = acc.deviceptr var(%[[BOX]] : !fir.box<!fir.array<10xf64>>) -> !fir.box<!fir.array<10xf64>> {name = "a"}
421+
// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVPTR]] : !fir.box<!fir.array<10xf64>>)
422+
// CHECK: %[[IMPLICIT_BOX:.*]] = acc.deviceptr var(%[[BOX]] : !fir.box<!fir.array<10xf64>>) -> !fir.box<!fir.array<10xf64>> {implicit = true, name = "a"}
423+
// CHECK: %[[IMPLICIT_REF:.*]] = acc.deviceptr varPtr(%[[DECL]] : !fir.ref<!fir.array<10xf64>>) -> !fir.ref<!fir.array<10xf64>> {implicit = true, name = "a"}
424+
// CHECK: acc.parallel dataOperands(%[[IMPLICIT_BOX]], %[[IMPLICIT_REF]] : !fir.box<!fir.array<10xf64>>, !fir.ref<!fir.array<10xf64>>) {
425+
// CHECK: fir.box_addr %[[IMPLICIT_BOX]] : (!fir.box<!fir.array<10xf64>>) -> !fir.ref<!fir.array<10xf64>>
426+
// CHECK: fir.array_coor %[[IMPLICIT_REF]]
427+
// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVPTR]] : !fir.box<!fir.array<10xf64>>)
428+
429+
// -----
430+
400431
// Test that acc.serial inside acc.data deviceptr generates implicit deviceptr
401432
// (not copyin) when the deviceptr clause variable is derived from the ref used
402433
// by the serial construct (here, via fir.embox wrapping the declared ref).

mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol,
7878

7979
/// Check if a value represents device data.
8080
/// This checks if the value represents device data via the
81-
/// MappableType, PointerLikeType, and GlobalVariableOpInterface interfaces.
81+
/// MappableType, PointerLikeType, and GlobalVariableOpInterface interfaces,
82+
/// and whether the defining operation carries `acc.declare` with the deviceptr
83+
/// clause.
8284
/// \param val The value to check
8385
/// \return true if the value is device data, false otherwise
8486
bool isDeviceValue(mlir::Value val);

mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,13 @@ Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
312312
// Only accept clauses that guarantee that the alias is present.
313313
if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
314314
acc::DevicePtrOp>(dataClauseOp))
315-
if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust())
315+
if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust()) {
316+
LLVM_DEBUG(llvm::dbgs()
317+
<< "Using existing data clause:\n\t" << *dataClauseOp
318+
<< "\n\tas reference when processing var:\n\t" << var
319+
<< "\n";);
316320
return dataClauseOp;
321+
}
317322
}
318323
}
319324
return nullptr;
@@ -452,6 +457,15 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate(
452457
typeCategory, acc::VariableTypeCategory::aggregate);
453458
Location loc = computeConstructOp->getLoc();
454459

460+
if (acc::isDeviceValue(var)) {
461+
// If the variable is device data, use deviceptr clause.
462+
LLVM_DEBUG(llvm::dbgs() << "Using deviceptr clause because variable is "
463+
"device data\n");
464+
return acc::DevicePtrOp::create(builder, loc, var,
465+
/*structured=*/true, /*implicit=*/true,
466+
accSupport.getVariableName(var));
467+
}
468+
455469
Operation *op = nullptr;
456470
op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp,
457471
dominatingDataClauses);
@@ -476,13 +490,6 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate(
476490
acc::getBounds(op));
477491
}
478492

479-
if (acc::isDeviceValue(var)) {
480-
// If the variable is device data, use deviceptr clause.
481-
return acc::DevicePtrOp::create(builder, loc, var,
482-
/*structured=*/true, /*implicit=*/true,
483-
accSupport.getVariableName(var));
484-
}
485-
486493
if (isScalar) {
487494
if (enableImplicitReductionCopy &&
488495
acc::isOnlyUsedByReductionClauses(var,

mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -245,23 +245,33 @@ bool mlir::acc::isDeviceValue(mlir::Value val) {
245245
if (pointerLikeTy.isDeviceData(val))
246246
return true;
247247

248+
mlir::Operation *defOp = val.getDefiningOp();
249+
if (!defOp)
250+
return false;
251+
252+
// `acc.declare` with deviceptr marks data that is already associated with
253+
// the device.
254+
if (auto declareAttr = defOp->getAttrOfType<mlir::acc::DeclareAttr>(
255+
mlir::acc::getDeclareAttrName()))
256+
if (declareAttr.getDataClause().getValue() ==
257+
mlir::acc::DataClause::acc_deviceptr)
258+
return true;
259+
248260
// Handle operations that access a partial entity - check if the base entity
249261
// is device data.
250-
if (auto *defOp = val.getDefiningOp()) {
251-
if (auto partialAccess =
252-
dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) {
253-
if (mlir::Value base = partialAccess.getBaseEntity())
254-
return isDeviceValue(base);
255-
}
262+
if (auto partialAccess =
263+
dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) {
264+
if (mlir::Value base = partialAccess.getBaseEntity())
265+
return isDeviceValue(base);
266+
}
256267

257-
// Handle address_of - check if the referenced global is device data.
258-
if (auto addrOfIface =
259-
dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
260-
auto symbol = addrOfIface.getSymbol();
261-
if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
262-
mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
263-
return global.isDeviceData();
264-
}
268+
// Handle address_of - check if the referenced global is device data.
269+
if (auto addrOfIface =
270+
dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
271+
auto symbol = addrOfIface.getSymbol();
272+
if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
273+
mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
274+
return global.isDeviceData();
265275
}
266276

267277
return false;

mlir/test/Dialect/OpenACC/acc-implicit-data.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,31 @@ func.func @test_device_global_in_parallel() {
259259
// CHECK: acc.deviceptr varPtr({{.*}} : memref<10xf32, #gpu.address_space<global>>) -> memref<10xf32, #gpu.address_space<global>> {implicit = true, name = ""}
260260
// CHECK-NOT: acc.copyin
261261
// CHECK-NOT: acc.copyout
262+
263+
// -----
264+
265+
// Test memref.view tagged with acc.declare deviceptr and used directly in region.
266+
func.func @test_declare_deviceptr_arg_in_parallel(%arg0: memref<?xi8>) {
267+
%c0 = arith.constant 0 : index
268+
%view = memref.view %arg0[%c0][] {acc.declare = #acc.declare<dataClause = acc_deviceptr>} : memref<?xi8> to memref<10xf32>
269+
%devptr = acc.deviceptr varPtr(%view : memref<10xf32>) -> memref<10xf32> {name = "arg0"}
270+
%token = acc.declare_enter dataOperands(%devptr : memref<10xf32>)
271+
acc.parallel {
272+
%c0_1 = arith.constant 0 : index
273+
%load = memref.load %arg0[%c0_1] : memref<?xi8>
274+
acc.yield
275+
}
276+
acc.declare_exit token(%token) dataOperands(%devptr : memref<10xf32>)
277+
return
278+
}
279+
280+
// CHECK-LABEL: func.func @test_declare_deviceptr_arg_in_parallel
281+
// CHECK: %[[VIEW:.*]] = memref.view %{{.*}}[{{.*}}][] {acc.declare = #acc.declare<dataClause = acc_deviceptr>} : memref<?xi8> to memref<10xf32>
282+
// CHECK: %[[DEVPTR:.*]] = acc.deviceptr varPtr(%[[VIEW]] : memref<10xf32>) -> memref<10xf32> {name = "arg0"}
283+
// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVPTR]] : memref<10xf32>)
284+
// CHECK: %[[IMPLICIT_DEVPTR:.*]] = acc.deviceptr varPtr(%{{.*}} : memref<?xi8>) -> memref<?xi8> {implicit = true, name = ""}
285+
// CHECK: acc.parallel dataOperands(%[[IMPLICIT_DEVPTR]] : memref<?xi8>) {
286+
// CHECK: memref.load %[[IMPLICIT_DEVPTR]][{{.*}}] : memref<?xi8>
287+
// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVPTR]] : memref<10xf32>)
288+
// CHECK-NOT: acc.copyin
289+
// CHECK-NOT: acc.copyout

mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,6 +1398,33 @@ TEST_F(OpenACCUtilsTest, getDominatingDataClausesEmpty) {
13981398
// isDeviceValue Tests
13991399
//===----------------------------------------------------------------------===//
14001400

1401+
namespace {
1402+
static Value memrefViewFromBlockArgWithDeclare(OpBuilder &builder, Location loc,
1403+
MLIRContext *ctx,
1404+
DataClause clause,
1405+
ModuleOp module,
1406+
StringRef funcName) {
1407+
OpBuilder::InsertionGuard guard(builder);
1408+
builder.setInsertionPointToStart(module.getBody());
1409+
1410+
auto i8BufTy = MemRefType::get({40}, builder.getI8Type());
1411+
auto viewTy = MemRefType::get({10}, builder.getI32Type());
1412+
auto funcType = builder.getFunctionType({i8BufTy}, {});
1413+
func::FuncOp funcOp = func::FuncOp::create(builder, loc, funcName, funcType);
1414+
Block *entry = funcOp.addEntryBlock();
1415+
1416+
builder.setInsertionPointToStart(entry);
1417+
Value buf = entry->getArgument(0);
1418+
Value c0 = arith::ConstantIndexOp::create(builder, loc, 0);
1419+
memref::ViewOp viewOp =
1420+
memref::ViewOp::create(builder, loc, viewTy, buf, c0, ValueRange{});
1421+
viewOp->setAttr(getDeclareAttrName(),
1422+
DeclareAttr::get(ctx, DataClauseAttr::get(ctx, clause)));
1423+
func::ReturnOp::create(builder, loc);
1424+
return viewOp.getResult();
1425+
}
1426+
} // namespace
1427+
14011428
TEST_F(OpenACCUtilsTest, isDeviceValueMemrefGlobalAddressSpace) {
14021429
// Test that a memref with GPU global address space is considered device data
14031430
auto gpuAddressSpace =
@@ -1525,6 +1552,26 @@ TEST_F(OpenACCUtilsTest, isDeviceValueGlobalWithoutGPUAddressSpace) {
15251552
EXPECT_FALSE(isDeviceValue(val));
15261553
}
15271554

1555+
TEST_F(OpenACCUtilsTest, isDeviceValueAccDeclareDeviceptr) {
1556+
OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
1557+
OpBuilder::InsertionGuard guard(b);
1558+
b.setInsertionPointToStart(module->getBody());
1559+
Value val = memrefViewFromBlockArgWithDeclare(
1560+
b, loc, &context, DataClause::acc_deviceptr, module.get(),
1561+
"test_memref_view_declare_devptr");
1562+
EXPECT_TRUE(isDeviceValue(val));
1563+
}
1564+
1565+
TEST_F(OpenACCUtilsTest, isDeviceValueAccDeclareNonDeviceptr) {
1566+
OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
1567+
OpBuilder::InsertionGuard guard(b);
1568+
b.setInsertionPointToStart(module->getBody());
1569+
Value val = memrefViewFromBlockArgWithDeclare(
1570+
b, loc, &context, DataClause::acc_copyin, module.get(),
1571+
"test_memref_view_declare_copyin");
1572+
EXPECT_FALSE(isDeviceValue(val));
1573+
}
1574+
15281575
//===----------------------------------------------------------------------===//
15291576
// isValidValueUse Tests
15301577
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)