diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 88b1e44d15af0..35ddb906c366a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -55,7 +55,6 @@ static unsigned typeToAddressSpace(const Type *Ty) { reportFatalInternalError("Unable to convert LLVM type to SPIRVType"); } -#ifndef NDEBUG static bool storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) { switch (SC) { @@ -87,7 +86,6 @@ storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) { } llvm_unreachable("Unknown SPIRV::StorageClass enum"); } -#endif SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) : PointerSize(PointerSize), Bound(0) {} @@ -837,13 +835,31 @@ static std::string buildSpirvTypeName(const SPIRVType *Type, } case SPIRV::OpTypeStruct: { std::string TypeName = "{"; - for (uint32_t I = 2; I < Type->getNumOperands(); ++I) { + for (uint32_t I = 1; I < Type->getNumOperands(); ++I) { SPIRVType *MemberType = GR.getSPIRVTypeForVReg(Type->getOperand(I).getReg()); - TypeName = '_' + buildSpirvTypeName(MemberType, MIRBuilder, GR); + TypeName += '_' + buildSpirvTypeName(MemberType, MIRBuilder, GR); } return TypeName + "}"; } + case SPIRV::OpTypeVector: { + MachineRegisterInfo *MRI = MIRBuilder.getMRI(); + Register ElementTypeReg = Type->getOperand(1).getReg(); + auto *ElementType = MRI->getUniqueVRegDef(ElementTypeReg); + uint32_t VectorSize = GR.getScalarOrVectorComponentCount(Type); + return (buildSpirvTypeName(ElementType, MIRBuilder, GR) + Twine("[") + + Twine(VectorSize) + Twine("]")) + .str(); + } + case SPIRV::OpTypeRuntimeArray: { + MachineRegisterInfo *MRI = MIRBuilder.getMRI(); + Register ElementTypeReg = Type->getOperand(1).getReg(); + auto *ElementType = MRI->getUniqueVRegDef(ElementTypeReg); + uint32_t ArraySize = 0; + return (buildSpirvTypeName(ElementType, MIRBuilder, GR) + Twine("[") + + Twine(ArraySize) + Twine("]")) + .str(); + } default: llvm_unreachable("Trying to the the name of an unknown type."); } @@ -885,30 +901,41 @@ Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding( return VarReg; } +// TODO: Double check the calls to getOpTypeArray to make sure that `ElemType` +// is explicitly laid out when required. SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType, MachineIRBuilder &MIRBuilder, + bool ExplicitLayoutRequired, bool EmitIR) { assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && "Invalid array element type"); SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder); - + SPIRVType *ArrayType = nullptr; if (NumElems != 0) { Register NumElementsVReg = buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR); - return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { + ArrayType = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { return MIRBuilder.buildInstr(SPIRV::OpTypeArray) .addDef(createTypeVReg(MIRBuilder)) .addUse(getSPIRVTypeID(ElemType)) .addUse(NumElementsVReg); }); + } else { + ArrayType = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { + return MIRBuilder.buildInstr(SPIRV::OpTypeRuntimeArray) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(ElemType)); + }); } - return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { - return MIRBuilder.buildInstr(SPIRV::OpTypeRuntimeArray) - .addDef(createTypeVReg(MIRBuilder)) - .addUse(getSPIRVTypeID(ElemType)); - }); + if (ExplicitLayoutRequired && !isResourceType(ElemType)) { + Type *ET = const_cast(getTypeForSPIRVType(ElemType)); + addArrayStrideDecorations(ArrayType->defs().begin()->getReg(), ET, + MIRBuilder); + } + + return ArrayType; } SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty, @@ -926,7 +953,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty, SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct( const StructType *Ty, MachineIRBuilder &MIRBuilder, - SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { + SPIRV::AccessQualifier::AccessQualifier AccQual, + bool ExplicitLayoutRequired, bool EmitIR) { SmallVector FieldTypes; constexpr unsigned MaxWordCount = UINT16_MAX; const size_t NumElements = Ty->getNumElements(); @@ -940,8 +968,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct( } for (const auto &Elem : Ty->elements()) { - SPIRVType *ElemTy = - findSPIRVType(toTypedPointer(Elem), MIRBuilder, AccQual, EmitIR); + SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder, AccQual, + ExplicitLayoutRequired, EmitIR); assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid && "Invalid struct element type"); FieldTypes.push_back(getSPIRVTypeID(ElemTy)); @@ -952,18 +980,27 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct( if (Ty->isPacked()) buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {}); - return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { - auto MIBStruct = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg); - for (size_t I = 0; I < SPIRVStructNumElements; ++I) - MIBStruct.addUse(FieldTypes[I]); - for (size_t I = SPIRVStructNumElements; I < NumElements; - I += MaxNumElements) { - auto MIBCont = MIRBuilder.buildInstr(SPIRV::OpTypeStructContinuedINTEL); - for (size_t J = I; J < std::min(I + MaxNumElements, NumElements); ++J) - MIBCont.addUse(FieldTypes[I]); - } - return MIBStruct; - }); + SPIRVType *SPVType = + createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { + auto MIBStruct = + MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg); + for (size_t I = 0; I < SPIRVStructNumElements; ++I) + MIBStruct.addUse(FieldTypes[I]); + for (size_t I = SPIRVStructNumElements; I < NumElements; + I += MaxNumElements) { + auto MIBCont = + MIRBuilder.buildInstr(SPIRV::OpTypeStructContinuedINTEL); + for (size_t J = I; J < std::min(I + MaxNumElements, NumElements); ++J) + MIBCont.addUse(FieldTypes[I]); + } + return MIBStruct; + }); + + if (ExplicitLayoutRequired) + addStructOffsetDecorations(SPVType->defs().begin()->getReg(), + const_cast(Ty), MIRBuilder); + + return SPVType; } SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType( @@ -1013,22 +1050,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs( const Type *Ty, SPIRVType *RetType, const SmallVectorImpl &ArgTypes, MachineIRBuilder &MIRBuilder) { - if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF())) + if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF())) return MI; const MachineInstr *NewMI = getOpTypeFunction(RetType, ArgTypes, MIRBuilder); - add(Ty, NewMI); + add(Ty, false, NewMI); return finishCreatingSPIRVType(Ty, NewMI); } SPIRVType *SPIRVGlobalRegistry::findSPIRVType( const Type *Ty, MachineIRBuilder &MIRBuilder, - SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { + SPIRV::AccessQualifier::AccessQualifier AccQual, + bool ExplicitLayoutRequired, bool EmitIR) { Ty = adjustIntTypeByWidth(Ty); - if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF())) + // TODO: findMI needs to know if a layout is required. + if (const MachineInstr *MI = + findMI(Ty, ExplicitLayoutRequired, &MIRBuilder.getMF())) return MI; if (auto It = ForwardPointerTypes.find(Ty); It != ForwardPointerTypes.end()) return It->second; - return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR); + return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, ExplicitLayoutRequired, + EmitIR); } Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const { @@ -1062,11 +1103,13 @@ const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const { SPIRVType *SPIRVGlobalRegistry::createSPIRVType( const Type *Ty, MachineIRBuilder &MIRBuilder, - SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { + SPIRV::AccessQualifier::AccessQualifier AccQual, + bool ExplicitLayoutRequired, bool EmitIR) { if (isSpecialOpaqueType(Ty)) return getOrCreateSpecialType(Ty, MIRBuilder, AccQual); - if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF())) + if (const MachineInstr *MI = + findMI(Ty, ExplicitLayoutRequired, &MIRBuilder.getMF())) return MI; if (auto IType = dyn_cast(Ty)) { @@ -1079,27 +1122,31 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( if (Ty->isVoidTy()) return getOpTypeVoid(MIRBuilder); if (Ty->isVectorTy()) { - SPIRVType *El = findSPIRVType(cast(Ty)->getElementType(), - MIRBuilder, AccQual, EmitIR); + SPIRVType *El = + findSPIRVType(cast(Ty)->getElementType(), MIRBuilder, + AccQual, ExplicitLayoutRequired, EmitIR); return getOpTypeVector(cast(Ty)->getNumElements(), El, MIRBuilder); } if (Ty->isArrayTy()) { - SPIRVType *El = - findSPIRVType(Ty->getArrayElementType(), MIRBuilder, AccQual, EmitIR); - return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); + SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder, + AccQual, ExplicitLayoutRequired, EmitIR); + return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, + ExplicitLayoutRequired, EmitIR); } if (auto SType = dyn_cast(Ty)) { if (SType->isOpaque()) return getOpTypeOpaque(SType, MIRBuilder); - return getOpTypeStruct(SType, MIRBuilder, AccQual, EmitIR); + return getOpTypeStruct(SType, MIRBuilder, AccQual, ExplicitLayoutRequired, + EmitIR); } if (auto FType = dyn_cast(Ty)) { - SPIRVType *RetTy = - findSPIRVType(FType->getReturnType(), MIRBuilder, AccQual, EmitIR); + SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder, + AccQual, ExplicitLayoutRequired, EmitIR); SmallVector ParamTypes; for (const auto &ParamTy : FType->params()) - ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR)); + ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, + ExplicitLayoutRequired, EmitIR)); return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); } @@ -1114,44 +1161,50 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( const SPIRVSubtarget *ST = static_cast(&MIRBuilder.getMF().getSubtarget()); auto SC = addressSpaceToStorageClass(AddrSpace, *ST); - // Null pointer means we have a loop in type definitions, make and - // return corresponding OpTypeForwardPointer. - if (SpvElementType == nullptr) { - auto [It, Inserted] = ForwardPointerTypes.try_emplace(Ty); - if (Inserted) - It->second = getOpTypeForwardPointer(SC, MIRBuilder); - return It->second; + + Type *ElemTy = ::getPointeeType(Ty); + if (!ElemTy) { + ElemTy = Type::getInt8Ty(MIRBuilder.getContext()); } + // If we have forward pointer associated with this type, use its register // operand to create OpTypePointer. if (auto It = ForwardPointerTypes.find(Ty); It != ForwardPointerTypes.end()) { Register Reg = getSPIRVTypeID(It->second); + // TODO: what does getOpTypePointer do? return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg); } - return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC); + return getOrCreateSPIRVPointerType(ElemTy, MIRBuilder, SC); } SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( const Type *Ty, MachineIRBuilder &MIRBuilder, - SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { + SPIRV::AccessQualifier::AccessQualifier AccessQual, + bool ExplicitLayoutRequired, bool EmitIR) { + // TODO: Could this create a problem if one requires an explicit layout, and + // the next time it does not? if (TypesInProcessing.count(Ty) && !isPointerTyOrWrapper(Ty)) return nullptr; TypesInProcessing.insert(Ty); - SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); + SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, + ExplicitLayoutRequired, EmitIR); TypesInProcessing.erase(Ty); VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; + + // TODO: We could end up with two SPIR-V types pointing to the same llvm type. + // Is that a problem? SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty); if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer || - findMI(Ty, &MIRBuilder.getMF()) || isSpecialOpaqueType(Ty)) + findMI(Ty, false, &MIRBuilder.getMF()) || isSpecialOpaqueType(Ty)) return SpirvType; if (auto *ExtTy = dyn_cast(Ty); ExtTy && isTypedPointerWrapper(ExtTy)) add(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), SpirvType); else if (!isPointerTy(Ty)) - add(Ty, SpirvType); + add(Ty, ExplicitLayoutRequired, SpirvType); else if (isTypedPointerTy(Ty)) add(cast(Ty)->getElementType(), getPointerAddressSpace(Ty), SpirvType); @@ -1183,14 +1236,15 @@ SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg, SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( const Type *Ty, MachineIRBuilder &MIRBuilder, - SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { + SPIRV::AccessQualifier::AccessQualifier AccessQual, + bool ExplicitLayoutRequired, bool EmitIR) { const MachineFunction *MF = &MIRBuilder.getMF(); Register Reg; if (auto *ExtTy = dyn_cast(Ty); ExtTy && isTypedPointerWrapper(ExtTy)) Reg = find(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), MF); else if (!isPointerTy(Ty)) - Reg = find(Ty = adjustIntTypeByWidth(Ty), MF); + Reg = find(Ty = adjustIntTypeByWidth(Ty), ExplicitLayoutRequired, MF); else if (isTypedPointerTy(Ty)) Reg = find(cast(Ty)->getElementType(), getPointerAddressSpace(Ty), MF); @@ -1201,15 +1255,20 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( return getSPIRVTypeForVReg(Reg); TypesInProcessing.clear(); - SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); + SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, + ExplicitLayoutRequired, EmitIR); // Create normal pointer types for the corresponding OpTypeForwardPointers. for (auto &CU : ForwardPointerTypes) { + // Pointer type themselves do not require an explicit layout. The types + // they pointer to might, but that is taken care of when creating the type. + bool PtrNeedsLayout = false; const Type *Ty2 = CU.first; SPIRVType *STy2 = CU.second; - if ((Reg = find(Ty2, MF)).isValid()) + if ((Reg = find(Ty2, PtrNeedsLayout, MF)).isValid()) STy2 = getSPIRVTypeForVReg(Reg); else - STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR); + STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, PtrNeedsLayout, + EmitIR); if (Ty == Ty2) STy = STy2; } @@ -1238,6 +1297,19 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, return false; } +bool SPIRVGlobalRegistry::isResourceType(SPIRVType *Type) const { + switch (Type->getOpcode()) { + case SPIRV::OpTypeImage: + case SPIRV::OpTypeSampler: + case SPIRV::OpTypeSampledImage: + return true; + case SPIRV::OpTypeStruct: + return hasBlockDecoration(Type); + default: + return false; + } + return false; +} unsigned SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const { return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg)); @@ -1362,16 +1434,16 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType( if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF())) return MI; - // TODO(134119): The SPIRVType for `ElemType` will not have an explicit - // layout. This generates invalid SPIR-V. + bool ExplicitLayoutRequired = storageClassRequiresExplictLayout(SC); + // We need to get the SPIR-V type for the element here, so we can add the + // decoration to it. auto *T = StructType::create(ElemType); auto *BlockType = - getOrCreateSPIRVType(T, MIRBuilder, SPIRV::AccessQualifier::None, EmitIr); + getOrCreateSPIRVType(T, MIRBuilder, SPIRV::AccessQualifier::None, + ExplicitLayoutRequired, EmitIr); buildOpDecorate(BlockType->defs().begin()->getReg(), MIRBuilder, SPIRV::Decoration::Block, {}); - buildOpMemberDecorate(BlockType->defs().begin()->getReg(), MIRBuilder, - SPIRV::Decoration::Offset, 0, {0}); if (!IsWritable) { buildOpMemberDecorate(BlockType->defs().begin()->getReg(), MIRBuilder, @@ -1480,7 +1552,8 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr( MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType, const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns, uint32_t Use, bool EmitIR) { - if (const MachineInstr *MI = findMI(ExtensionType, &MIRBuilder.getMF())) + if (const MachineInstr *MI = + findMI(ExtensionType, false, &MIRBuilder.getMF())) return MI; const MachineInstr *NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { @@ -1493,26 +1566,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr( .addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, EmitIR)) .addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, EmitIR)); }); - add(ExtensionType, NewMI); + add(ExtensionType, false, NewMI); return NewMI; } SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode( const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) { - if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF())) + if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF())) return MI; const MachineInstr *NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { return MIRBuilder.buildInstr(Opcode).addDef(createTypeVReg(MIRBuilder)); }); - add(Ty, NewMI); + add(Ty, false, NewMI); return NewMI; } SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType( const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode, const ArrayRef Operands) { - if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF())) + if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF())) return MI; Register ResVReg = createTypeVReg(MIRBuilder); const MachineInstr *NewMI = @@ -1529,7 +1602,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType( } return MIB; }); - add(Ty, NewMI); + add(Ty, false, NewMI); return NewMI; } @@ -1545,7 +1618,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName( if (hasBuiltinTypePrefix(TypeStr)) return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType( TypeStr.str(), MIRBuilder.getContext()), - MIRBuilder, AQ, true); + MIRBuilder, AQ, false, true); // Parse type name in either "typeN" or "type vector[N]" format, where // N is the number of elements of the vector. @@ -1556,7 +1629,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName( // Unable to recognize SPIRV type name return nullptr; - auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ, true); + auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ, false, true); // Handle "type*" or "type* vector[N]". if (TypeStr.starts_with("*")) { @@ -1585,7 +1658,7 @@ SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineIRBuilder &MIRBuilder) { return getOrCreateSPIRVType( IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), - MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true); + MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, true); } SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy, @@ -1601,7 +1674,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth, const SPIRVInstrInfo &TII, unsigned SPIRVOPcode, Type *Ty) { - if (const MachineInstr *MI = findMI(Ty, CurMF)) + if (const MachineInstr *MI = findMI(Ty, false, CurMF)) return MI; MachineBasicBlock &DepMBB = I.getMF()->front(); MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI()); @@ -1613,7 +1686,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth, .addImm(BitWidth) .addImm(0); }); - add(Ty, NewMI); + add(Ty, false, NewMI); return finishCreatingSPIRVType(Ty, NewMI); } @@ -1654,14 +1727,14 @@ SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder, bool EmitIR) { return getOrCreateSPIRVType( IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), - MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR); + MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, EmitIR); } SPIRVType * SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I, const SPIRVInstrInfo &TII) { Type *Ty = IntegerType::get(CurMF->getFunction().getContext(), 1); - if (const MachineInstr *MI = findMI(Ty, CurMF)) + if (const MachineInstr *MI = findMI(Ty, false, CurMF)) return MI; MachineBasicBlock &DepMBB = I.getMF()->front(); MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI()); @@ -1671,7 +1744,7 @@ SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I, MIRBuilder.getDL(), TII.get(SPIRV::OpTypeBool)) .addDef(createTypeVReg(CurMF->getRegInfo())); }); - add(Ty, NewMI); + add(Ty, false, NewMI); return finishCreatingSPIRVType(Ty, NewMI); } @@ -1681,7 +1754,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( return getOrCreateSPIRVType( FixedVectorType::get(const_cast(getTypeForSPIRVType(BaseType)), NumElements), - MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR); + MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, EmitIR); } SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( @@ -1689,7 +1762,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( const SPIRVInstrInfo &TII) { Type *Ty = FixedVectorType::get( const_cast(getTypeForSPIRVType(BaseType)), NumElements); - if (const MachineInstr *MI = findMI(Ty, CurMF)) + if (const MachineInstr *MI = findMI(Ty, false, CurMF)) return MI; MachineInstr *DepMI = const_cast(BaseType); MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator()); @@ -1701,30 +1774,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( .addUse(getSPIRVTypeID(BaseType)) .addImm(NumElements); }); - add(Ty, NewMI); - return finishCreatingSPIRVType(Ty, NewMI); -} - -SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType( - SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, - const SPIRVInstrInfo &TII) { - Type *Ty = ArrayType::get(const_cast(getTypeForSPIRVType(BaseType)), - NumElements); - if (const MachineInstr *MI = findMI(Ty, CurMF)) - return MI; - SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, I, TII); - Register Len = getOrCreateConstInt(NumElements, I, SpvTypeInt32, TII); - MachineBasicBlock &DepMBB = I.getMF()->front(); - MachineIRBuilder MIRBuilder(DepMBB, getInsertPtValidEnd(&DepMBB)); - const MachineInstr *NewMI = - createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { - return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(), - MIRBuilder.getDL(), TII.get(SPIRV::OpTypeArray)) - .addDef(createTypeVReg(CurMF->getRegInfo())) - .addUse(getSPIRVTypeID(BaseType)) - .addUse(Len); - }); - add(Ty, NewMI); + add(Ty, false, NewMI); return finishCreatingSPIRVType(Ty, NewMI); } @@ -1738,8 +1788,11 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( const Type *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SC) { + // TODO: Need to check if EmitIr should always be true. SPIRVType *SpirvBaseType = getOrCreateSPIRVType( - BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true); + BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, + storageClassRequiresExplictLayout(SC), true); + assert(SpirvBaseType); return getOrCreateSPIRVPointerTypeInternal(SpirvBaseType, MIRBuilder, SC); } @@ -2006,3 +2059,33 @@ void SPIRVGlobalRegistry::updateAssignType(CallInst *AssignCI, Value *Arg, addDeducedElementType(AssignCI, ElemTy); addDeducedElementType(Arg, ElemTy); } + +void SPIRVGlobalRegistry::addStructOffsetDecorations( + Register Reg, StructType *Ty, MachineIRBuilder &MIRBuilder) { + ArrayRef Offsets = + DataLayout().getStructLayout(Ty)->getMemberOffsets(); + for (uint32_t I = 0; I < Ty->getNumElements(); ++I) { + buildOpMemberDecorate(Reg, MIRBuilder, SPIRV::Decoration::Offset, I, + {static_cast(Offsets[I])}); + } +} + +void SPIRVGlobalRegistry::addArrayStrideDecorations( + Register Reg, Type *ElementType, MachineIRBuilder &MIRBuilder) { + uint32_t SizeInBytes = DataLayout().getTypeSizeInBits(ElementType) / 8; + buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::ArrayStride, + {SizeInBytes}); +} + +bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const { + Register Def = getSPIRVTypeID(Type); + for (const MachineInstr &Use : + Type->getMF()->getRegInfo().use_instructions(Def)) { + if (Use.getOpcode() != SPIRV::OpDecorate) + continue; + + if (Use.getOperand(1).getImm() == SPIRV::Decoration::Block) + return true; + } + return false; +} diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index b05896fb7174c..7338e805956d6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -90,14 +90,14 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { // Add a new OpTypeXXX instruction without checking for duplicates. SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ, - bool EmitIR); + bool ExplicitLayoutRequired, bool EmitIR); SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier accessQual, - bool EmitIR); + bool ExplicitLayoutRequired, bool EmitIR); SPIRVType * restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccessQual, - bool EmitIR); + bool ExplicitLayoutRequired, bool EmitIR); // Internal function creating the an OpType at the correct position in the // function by tweaking the passed "MIRBuilder" insertion point and restoring @@ -298,10 +298,19 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { // EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes) // because this method may be called from InstructionSelector and we don't // want to emit extra IR instructions there. + SPIRVType *getOrCreateSPIRVType(const Type *Type, MachineInstr &I, + SPIRV::AccessQualifier::AccessQualifier AQ, + bool EmitIR) { + MachineIRBuilder MIRBuilder(I); + return getOrCreateSPIRVType(Type, MIRBuilder, AQ, EmitIR); + } + SPIRVType *getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ, - bool EmitIR); + bool EmitIR) { + return getOrCreateSPIRVType(Type, MIRBuilder, AQ, false, EmitIR); + } const Type *getTypeForSPIRVType(const SPIRVType *Ty) const { auto Res = SPIRVToLLVMType.find(Ty); @@ -364,6 +373,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool). bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const; + // Returns true if `Type` is a resource type. This could be an image type + // or a struct for a buffer decorated with the block decoration. + bool isResourceType(SPIRVType *Type) const; + // Return number of elements in a vector if the argument is associated with // a vector type. Return 1 for a scalar type, and 0 for a missing type. unsigned getScalarOrVectorComponentCount(Register VReg) const; @@ -414,6 +427,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { const Type *adjustIntTypeByWidth(const Type *Ty) const; unsigned adjustOpTypeIntWidth(unsigned Width) const; + SPIRVType *getOrCreateSPIRVType(const Type *Type, + MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier::AccessQualifier AQ, + bool ExplicitLayoutRequired, bool EmitIR); + SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder, bool IsSigned = false); @@ -425,14 +443,15 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { MachineIRBuilder &MIRBuilder); SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType, - MachineIRBuilder &MIRBuilder, bool EmitIR); + MachineIRBuilder &MIRBuilder, + bool ExplicitLayoutRequired, bool EmitIR); SPIRVType *getOpTypeOpaque(const StructType *Ty, MachineIRBuilder &MIRBuilder); SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccQual, - bool EmitIR); + bool ExplicitLayoutRequired, bool EmitIR); SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType, MachineIRBuilder &MIRBuilder, @@ -475,6 +494,12 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SC); + void addStructOffsetDecorations(Register Reg, StructType *Ty, + MachineIRBuilder &MIRBuilder); + void addArrayStrideDecorations(Register Reg, Type *ElementType, + MachineIRBuilder &MIRBuilder); + bool hasBlockDecoration(SPIRVType *Type) const; + public: Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, @@ -545,9 +570,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, const SPIRVInstrInfo &TII); - SPIRVType *getOrCreateSPIRVArrayType(SPIRVType *BaseType, - unsigned NumElements, MachineInstr &I, - const SPIRVInstrInfo &TII); // Returns a pointer to a SPIR-V pointer type with the given base type and // storage class. The base type will be translated to a SPIR-V type, and the diff --git a/llvm/lib/Target/SPIRV/SPIRVIRMapping.h b/llvm/lib/Target/SPIRV/SPIRVIRMapping.h index 9c9c099bc5fc4..a329fd5ed9d29 100644 --- a/llvm/lib/Target/SPIRV/SPIRVIRMapping.h +++ b/llvm/lib/Target/SPIRV/SPIRVIRMapping.h @@ -66,6 +66,7 @@ enum SpecialTypeKind { STK_Value, STK_MachineInstr, STK_VkBuffer, + STK_ExplictLayoutType, STK_Last = -1 }; @@ -150,6 +151,11 @@ inline IRHandle irhandle_vkbuffer(const Type *ElementType, SpecialTypeKind::STK_VkBuffer); } +inline IRHandle irhandle_explict_layout_type(const Type *Ty) { + const Type *WrpTy = unifyPtrType(Ty); + return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_ExplictLayoutType); +} + inline IRHandle handle(const Type *Ty) { const Type *WrpTy = unifyPtrType(Ty); return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_Type); @@ -163,6 +169,10 @@ inline IRHandle handle(const MachineInstr *KeyMI) { return irhandle_ptr(KeyMI, SPIRV::to_hash(KeyMI), STK_MachineInstr); } +inline bool type_has_layout_decoration(const Type *T) { + return (isa(T) || isa(T)); +} + } // namespace SPIRV // Bi-directional mappings between LLVM entities and (v-reg, machine function) @@ -238,14 +248,49 @@ class SPIRVIRMapping { return findMI(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MF); } - template bool add(const T *Obj, const MachineInstr *MI) { + bool add(const Value *V, const MachineInstr *MI) { + return add(SPIRV::handle(V), MI); + } + + bool add(const Type *T, bool RequiresExplicitLayout, const MachineInstr *MI) { + if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) { + return add(SPIRV::irhandle_explict_layout_type(T), MI); + } + return add(SPIRV::handle(T), MI); + } + + bool add(const MachineInstr *Obj, const MachineInstr *MI) { return add(SPIRV::handle(Obj), MI); } - template Register find(const T *Obj, const MachineFunction *MF) { - return find(SPIRV::handle(Obj), MF); + + Register find(const Value *V, const MachineFunction *MF) { + return find(SPIRV::handle(V), MF); + } + + Register find(const Type *T, bool RequiresExplicitLayout, + const MachineFunction *MF) { + if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) + return find(SPIRV::irhandle_explict_layout_type(T), MF); + return find(SPIRV::handle(T), MF); + } + + Register find(const MachineInstr *MI, const MachineFunction *MF) { + return find(SPIRV::handle(MI), MF); + } + + const MachineInstr *findMI(const Value *Obj, const MachineFunction *MF) { + return findMI(SPIRV::handle(Obj), MF); + } + + const MachineInstr *findMI(const Type *T, bool RequiresExplicitLayout, + const MachineFunction *MF) { + if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) + return findMI(SPIRV::irhandle_explict_layout_type(T), MF); + return findMI(SPIRV::handle(T), MF); } - template - const MachineInstr *findMI(const T *Obj, const MachineFunction *MF) { + + const MachineInstr *findMI(const MachineInstr *Obj, + const MachineFunction *MF) { return findMI(SPIRV::handle(Obj), MF); } }; diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index 216c3e26be1bf..8a873426e78d8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -25,6 +25,42 @@ using namespace llvm; +// Returns true of the types logically match, as defined in +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical. +static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2, + SPIRVGlobalRegistry &GR) { + if (Ty1->getOpcode() != Ty2->getOpcode()) + return false; + + if (Ty1->getNumOperands() != Ty2->getNumOperands()) + return false; + + if (Ty1->getOpcode() == SPIRV::OpTypeArray) { + // Array must have the same size. + if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg()) + return false; + + SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(Ty1->getOperand(1).getReg()); + SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(Ty2->getOperand(1).getReg()); + return ElemType1 == ElemType2 || + typesLogicallyMatch(ElemType1, ElemType2, GR); + } + + if (Ty1->getOpcode() == SPIRV::OpTypeStruct) { + for (unsigned I = 1; I < Ty1->getNumOperands(); I++) { + SPIRVType *ElemType1 = + GR.getSPIRVTypeForVReg(Ty1->getOperand(I).getReg()); + SPIRVType *ElemType2 = + GR.getSPIRVTypeForVReg(Ty2->getOperand(I).getReg()); + if (ElemType1 != ElemType2 && + !typesLogicallyMatch(ElemType1, ElemType2, GR)) + return false; + } + return true; + } + return false; +} + unsigned SPIRVTargetLowering::getNumRegistersForCallingConv( LLVMContext &Context, CallingConv::ID CC, EVT VT) const { // This code avoids CallLowering fail inside getVectorTypeBreakdown @@ -374,6 +410,9 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { // implies that %Op is a pointer to case SPIRV::OpLoad: // OpLoad , ptr %Op implies that %Op is a pointer to + if (enforcePtrTypeCompatibility(MI, 2, 0)) + break; + validatePtrTypes(STI, MRI, GR, MI, 2, GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg())); break; @@ -531,3 +570,58 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { ProcessedMF.insert(&MF); TargetLowering::finalizeLowering(MF); } + +// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of +// PtrOpIdx matches the type for operand OpIdx. Returns true if they already +// match or if the instruction was modified to make them match. +bool SPIRVTargetLowering::enforcePtrTypeCompatibility( + MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const { + SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); + SPIRVType *PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg()); + SPIRVType *PointeeType = GR.getPointeeType(PtrType); + SPIRVType *OpType = GR.getResultType(I.getOperand(OpIdx).getReg()); + + if (PointeeType == OpType) + return true; + + if (typesLogicallyMatch(PointeeType, OpType, GR)) { + // Apply OpCopyLogical to OpIdx. + if (I.getOperand(OpIdx).isDef() && + insertLogicalCopyOnResult(I, PointeeType)) { + return true; + } + + llvm_unreachable("Unable to add OpCopyLogical yet."); + return false; + } + + return false; +} + +bool SPIRVTargetLowering::insertLogicalCopyOnResult( + MachineInstr &I, SPIRVType *NewResultType) const { + MachineRegisterInfo *MRI = &I.getMF()->getRegInfo(); + SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); + + Register NewResultReg = + createVirtualRegister(NewResultType, &GR, MRI, *I.getMF()); + Register NewTypeReg = GR.getSPIRVTypeID(NewResultType); + + assert(std::distance(I.defs().begin(), I.defs().end()) == 1 && + "Expected only one def"); + MachineOperand &OldResult = *I.defs().begin(); + Register OldResultReg = OldResult.getReg(); + MachineOperand &OldType = *I.uses().begin(); + Register OldTypeReg = OldType.getReg(); + + OldResult.setReg(NewResultReg); + OldType.setReg(NewTypeReg); + + MachineIRBuilder MIB(*I.getNextNode()); + return MIB.buildInstr(SPIRV::OpCopyLogical) + .addDef(OldResultReg) + .addUse(OldTypeReg) + .addUse(NewResultReg) + .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(), + *STI.getRegBankInfo()); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h index eb78299b72f04..9025e6eb0842e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h @@ -71,6 +71,11 @@ class SPIRVTargetLowering : public TargetLowering { EVT ConditionVT) const override { return ConditionVT.getSimpleVT(); } + + bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx, + unsigned OpIdx) const; + bool insertLogicalCopyOnResult(MachineInstr &I, + SPIRVType *NewResultType) const; }; } // namespace llvm diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll index fc8faa7300534..f539fdefa3fa2 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll @@ -11,17 +11,18 @@ declare target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handle ; CHECK: OpDecorate [[BufferVar:%.+]] DescriptorSet 0 ; CHECK: OpDecorate [[BufferVar]] Binding 0 -; CHECK: OpDecorate [[BufferType:%.+]] Block -; CHECK: OpMemberDecorate [[BufferType]] 0 Offset 0 +; CHECK: OpMemberDecorate [[BufferType:%.+]] 0 Offset 0 +; CHECK: OpDecorate [[BufferType]] Block ; CHECK: OpMemberDecorate [[BufferType]] 0 NonWritable ; CHECK: OpDecorate [[RWBufferVar:%.+]] DescriptorSet 0 ; CHECK: OpDecorate [[RWBufferVar]] Binding 1 -; CHECK: OpDecorate [[RWBufferType:%.+]] Block -; CHECK: OpMemberDecorate [[RWBufferType]] 0 Offset 0 +; CHECK: OpDecorate [[ArrayType:%.+]] ArrayStride 4 +; CHECK: OpMemberDecorate [[RWBufferType:%.+]] 0 Offset 0 +; CHECK: OpDecorate [[RWBufferType]] Block ; CHECK: [[int:%[0-9]+]] = OpTypeInt 32 0 -; CHECK: [[ArrayType:%.+]] = OpTypeRuntimeArray +; CHECK: [[ArrayType]] = OpTypeRuntimeArray ; CHECK: [[RWBufferType]] = OpTypeStruct [[ArrayType]] ; CHECK: [[RWBufferPtrType:%.+]] = OpTypePointer StorageBuffer [[RWBufferType]] ; CHECK: [[BufferType]] = OpTypeStruct [[ArrayType]] diff --git a/llvm/test/CodeGen/SPIRV/spirv-explicit-layout.ll b/llvm/test/CodeGen/SPIRV/spirv-explicit-layout.ll new file mode 100644 index 0000000000000..7303471c9929c --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/spirv-explicit-layout.ll @@ -0,0 +1,149 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-vulkan1.3-library %s -o - -filetype=obj | spirv-val %} + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" + +; CHECK-DAG: OpName [[ScalarBlock_var:%[0-9]+]] "__resource_p_12_{_u32[0]}_0_0" +; CHECK-DAG: OpName [[buffer_var:%[0-9]+]] "__resource_p_12_{_{_{_u32_f32[3]}[10]}[0]}_0_0" +; CHECK-DAG: OpName [[array_buffer_var:%[0-9]+]] "__resource_p_12_{_{_{_u32_f32[3]}[10]}[0]}[10]_0_0" + +; CHECK-DAG: OpMemberDecorate [[ScalarBlock:%[0-9]+]] 0 Offset 0 +; CHECK-DAG: OpDecorate [[ScalarBlock]] Block +; CHECK-DAG: OpMemberDecorate [[ScalarBlock]] 0 NonWritable +; CHECK-DAG: OpMemberDecorate [[T_explicit:%[0-9]+]] 0 Offset 0 +; CHECK-DAG: OpMemberDecorate [[T_explicit]] 1 Offset 16 +; CHECK-DAG: OpDecorate [[T_array_explicit:%[0-9]+]] ArrayStride 32 +; CHECK-DAG: OpMemberDecorate [[S_explicit:%[0-9]+]] 0 Offset 0 +; CHECK-DAG: OpDecorate [[S_array_explicit:%[0-9]+]] ArrayStride 320 +; CHECK-DAG: OpMemberDecorate [[block:%[0-9]+]] 0 Offset 0 +; CHECK-DAG: OpDecorate [[block]] Block +; CHECK-DAG: OpMemberDecorate [[block]] 0 NonWritable + +; CHECK-DAG: [[float:%[0-9]+]] = OpTypeFloat 32 +; CHECK-DAG: [[v3f:%[0-9]+]] = OpTypeVector [[float]] 3 +; CHECK-DAG: [[uint:%[0-9]+]] = OpTypeInt 32 0 +; CHECK-DAG: [[T:%[0-9]+]] = OpTypeStruct [[uint]] [[v3f]] +; CHECK-DAG: [[T_explicit]] = OpTypeStruct [[uint]] [[v3f]] +%struct.T = type { i32, <3 x float> } + +; CHECK-DAG: [[zero:%[0-9]+]] = OpConstant [[uint]] 0{{$}} +; CHECK-DAG: [[one:%[0-9]+]] = OpConstant [[uint]] 1{{$}} +; CHECK-DAG: [[ten:%[0-9]+]] = OpConstant [[uint]] 10 +; CHECK-DAG: [[T_array:%[0-9]+]] = OpTypeArray [[T]] [[ten]] +; CHECK-DAG: [[S:%[0-9]+]] = OpTypeStruct [[T_array]] +; CHECK-DAG: [[T_array_explicit]] = OpTypeArray [[T_explicit]] [[ten]] +; CHECK-DAG: [[S_explicit]] = OpTypeStruct [[T_array_explicit]] +%struct.S = type { [10 x %struct.T] } + +; CHECK-DAG: [[private_S_ptr:%[0-9]+]] = OpTypePointer Private [[S]] +; CHECK-DAG: [[private_var:%[0-9]+]] = OpVariable [[private_S_ptr]] Private +@private = internal addrspace(10) global %struct.S poison + +; CHECK-DAG: [[storagebuffer_S_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[S_explicit]] +; CHECK-DAG: [[storage_buffer:%[0-9]+]] = OpVariable [[storagebuffer_S_ptr]] StorageBuffer +@storage_buffer = internal addrspace(11) global %struct.S poison + +; CHECK-DAG: [[storagebuffer_int_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[uint]] +; CHECK-DAG: [[ScalarBlock_array:%[0-9]+]] = OpTypeRuntimeArray [[uint]] +; CHECK-DAG: [[ScalarBlock]] = OpTypeStruct [[ScalarBlock_array]] +; CHECK-DAG: [[ScalarBlock_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[ScalarBlock]] +; CHECK-DAG: [[ScalarBlock_var]] = OpVariable [[ScalarBlock_ptr]] StorageBuffer + + +; CHECK-DAG: [[S_array_explicit]] = OpTypeRuntimeArray [[S_explicit]] +; CHECK-DAG: [[block]] = OpTypeStruct [[S_array_explicit]] +; CHECK-DAG: [[buffer_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[block]] +; CHECK-DAG: [[buffer_var]] = OpVariable [[buffer_ptr]] StorageBuffer + +; CHECK-DAG: [[array_buffer:%[0-9]+]] = OpTypeArray [[block]] [[ten]] +; CHECK-DAG: [[array_buffer_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[array_buffer]] +; CHECK-DAG: [[array_buffer_var]] = OpVariable [[array_buffer_ptr]] StorageBuffer + +; CHECK: OpFunction [[uint]] None +define external i32 @scalar_vulkan_buffer_load() { +; CHECK-NEXT: OpLabel +entry: +; CHECK-NEXT: [[handle:%[0-9]+]] = OpCopyObject [[ScalarBlock_ptr]] [[ScalarBlock_var]] + %handle = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 0) @llvm.spv.resource.handlefrombinding(i32 0, i32 0, i32 1, i32 0, i1 false) + +; CHECK-NEXT: [[ptr:%[0-9]+]] = OpAccessChain [[storagebuffer_int_ptr]] [[handle]] [[zero]] [[one]] + %0 = tail call noundef nonnull align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer(target("spirv.VulkanBuffer", [0 x i32], 12, 0) %handle, i32 1) + +; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[uint]] [[ptr]] Aligned 4 + %1 = load i32, ptr addrspace(11) %0, align 4 + +; CHECK-NEXT: OpReturnValue [[ld]] + ret i32 %1 + +; CHECK-NEXT: OpFunctionEnd +} + +; CHECK: OpFunction [[S]] None +define external %struct.S @private_load() { +; CHECK-NEXT: OpLabel +entry: + +; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[S]] [[private_var]] Aligned 4 + %1 = load %struct.S, ptr addrspace(10) @private, align 4 + +; CHECK-NEXT: OpReturnValue [[ld]] + ret %struct.S %1 + +; CHECK-NEXT: OpFunctionEnd +} + +; CHECK: OpFunction [[S]] None +define external %struct.S @storage_buffer_load() { +; CHECK-NEXT: OpLabel +entry: + +; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[S_explicit]] [[storage_buffer]] Aligned 4 +; CHECK-NEXT: [[copy:%[0-9]+]] = OpCopyLogical [[S]] [[ld]] + %1 = load %struct.S, ptr addrspace(11) @storage_buffer, align 4 + +; CHECK-NEXT: OpReturnValue [[copy]] + ret %struct.S %1 + +; CHECK-NEXT: OpFunctionEnd +} + +; CHECK: OpFunction [[S]] None +define external %struct.S @vulkan_buffer_load() { +; CHECK-NEXT: OpLabel +entry: +; CHECK-NEXT: [[handle:%[0-9]+]] = OpCopyObject [[buffer_ptr]] [[buffer_var]] + %handle = tail call target("spirv.VulkanBuffer", [0 x %struct.S], 12, 0) @llvm.spv.resource.handlefrombinding(i32 0, i32 0, i32 1, i32 0, i1 false) + +; CHECK-NEXT: [[ptr:%[0-9]+]] = OpAccessChain [[storagebuffer_S_ptr]] [[handle]] [[zero]] [[one]] + %0 = tail call noundef nonnull align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 0) %handle, i32 1) + +; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[S_explicit]] [[ptr]] Aligned 4 +; CHECK-NEXT: [[copy:%[0-9]+]] = OpCopyLogical [[S]] [[ld]] + %1 = load %struct.S, ptr addrspace(11) %0, align 4 + +; CHECK-NEXT: OpReturnValue [[copy]] + ret %struct.S %1 + +; CHECK-NEXT: OpFunctionEnd +} + +; CHECK: OpFunction [[S]] None +define external %struct.S @array_of_vulkan_buffers_load() { +; CHECK-NEXT: OpLabel +entry: +; CHECK-NEXT: [[h:%[0-9]+]] = OpAccessChain [[buffer_ptr]] [[array_buffer_var]] [[one]] +; CHECK-NEXT: [[handle:%[0-9]+]] = OpCopyObject [[buffer_ptr]] [[h]] + %handle = tail call target("spirv.VulkanBuffer", [0 x %struct.S], 12, 0) @llvm.spv.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 1, i1 false) + +; CHECK-NEXT: [[ptr:%[0-9]+]] = OpAccessChain [[storagebuffer_S_ptr]] [[handle]] [[zero]] [[one]] + %0 = tail call noundef nonnull align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 0) %handle, i32 1) + +; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[S_explicit]] [[ptr]] Aligned 4 +; CHECK-NEXT: [[copy:%[0-9]+]] = OpCopyLogical [[S]] [[ld]] + %1 = load %struct.S, ptr addrspace(11) %0, align 4 + +; CHECK-NEXT: OpReturnValue [[copy]] + ret %struct.S %1 + +; CHECK-NEXT: OpFunctionEnd +}