diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h new file mode 100644 index 0000000000000..91de43fdf60ca --- /dev/null +++ b/llvm/include/llvm/ProfileData/DataAccessProf.h @@ -0,0 +1,167 @@ +//===- DataAccessProf.h - Data access profile format support ---------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains support to construct and use data access profiles. +// +// For the original RFC of this pass please see +// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744 +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_PROFILEDATA_DATAACCESSPROF_H_ +#define LLVM_PROFILEDATA_DATAACCESSPROF_H_ + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfoVariant.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/StringSaver.h" + +#include +#include + +namespace llvm { + +namespace data_access_prof { +// The location of data in the source code. +struct DataLocation { + // The filename where the data is located. + StringRef FileName; + // The line number in the source code. + uint32_t Line; +}; + +// The data access profiles for a symbol. +struct DataAccessProfRecord { + DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount, + bool IsStringLiteral) + : SymbolID(SymbolID), AccessCount(AccessCount), + IsStringLiteral(IsStringLiteral) {} + + // Represents a data symbol. The semantic comes in two forms: a symbol index + // for symbol name if `IsStringLiteral` is false, or the hash of a string + // content if `IsStringLiteral` is true. For most of the symbolizable static + // data, the mangled symbol names remain stable relative to the source code + // and therefore used to identify symbols across binary releases. String + // literals have unstable name patterns like `.str.N[.llvm.hash]`, so we use + // the content hash instead. This is a required field. + uint64_t SymbolID; + + // The access count of symbol. Required. + uint64_t AccessCount; + + // True iff this is a record for string literal (symbols with name pattern + // `.str.*` in the symbol table). Required. + bool IsStringLiteral; + + // The locations of data in the source code. Optional. + llvm::SmallVector Locations; +}; + +/// Encapsulates the data access profile data and the methods to operate on it. +/// This class provides profile look-up, serialization and deserialization. +class DataAccessProfData { +public: + // SymbolID is either a string representing symbol name if the symbol has + // stable mangled name relative to source code, or a uint64_t representing the + // content hash of a string literal (with unstable name patterns like + // `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object. + using SymbolHandle = std::variant; + using StringToIndexMap = llvm::MapVector; + + DataAccessProfData() : Saver(Allocator) {} + + /// Serialize profile data to the output stream. + /// Storage layout: + /// - Serialized strings. + /// - The encoded hashes. + /// - Records. + Error serialize(ProfOStream &OS) const; + + /// Deserialize this class from the given buffer. + Error deserialize(const unsigned char *&Ptr); + + /// Returns a pointer of profile record for \p SymbolID, or nullptr if there + /// isn't a record. Internally, this function will canonicalize the symbol + /// name before the lookup. + const DataAccessProfRecord *getProfileRecord(const SymbolHandle SymID) const; + + /// Returns true if \p SymID is seen in profiled binaries and cold. + bool isKnownColdSymbol(const SymbolHandle SymID) const; + + /// Methods to set symbolized data access profile. Returns error if duplicated + /// symbol names or content hashes are seen. The user of this class should + /// aggregate counters that correspond to the same symbol name or with the + /// same string literal hash before calling 'set*' methods. + Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount); + /// Similar to the method above, for records with \p Locations representing + /// the `filename:line` where this symbol shows up. Note because of linker's + /// merge of identical symbols (e.g., unnamed_addr string literals), one + /// symbol is likely to have multiple locations. + Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount, + ArrayRef Locations); + Error addKnownSymbolWithoutSamples(SymbolHandle SymbolID); + + /// Returns an iterable StringRef for strings in the order they are added. + /// Each string may be a symbol name or a file name. + auto getStrings() const { + return llvm::make_first_range(StrToIndexMap.getArrayRef()); + } + + /// Returns array reference for various internal data structures. + auto getRecords() const { return Records.getArrayRef(); } + ArrayRef getKnownColdSymbols() const { + return KnownColdSymbols.getArrayRef(); + } + ArrayRef getKnownColdHashes() const { + return KnownColdHashes.getArrayRef(); + } + +private: + /// Serialize the symbol strings into the output stream. + Error serializeSymbolsAndFilenames(ProfOStream &OS) const; + + /// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the + /// start of the next payload. + Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr, + const uint64_t NumSampledSymbols, + const uint64_t NumColdKnownSymbols); + + /// Decode the records and increment \p Ptr to the start of the next payload. + Error deserializeRecords(const unsigned char *&Ptr); + + /// A helper function to compute a storage index for \p SymbolID. + uint64_t getEncodedIndex(const SymbolHandle SymbolID) const; + + // Keeps owned copies of the input strings. + // NOTE: Keep `Saver` initialized before other class members that reference + // its string copies and destructed after they are destructed. + llvm::BumpPtrAllocator Allocator; + llvm::UniqueStringSaver Saver; + + // `Records` stores the records. + MapVector Records; + + // Use MapVector to keep input order of strings for serialization and + // deserialization. + StringToIndexMap StrToIndexMap; + llvm::SetVector KnownColdHashes; + llvm::SetVector KnownColdSymbols; +}; + +} // namespace data_access_prof +} // namespace llvm + +#endif // LLVM_PROFILEDATA_DATAACCESSPROF_H_ diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h index 2d011c89f27cb..33b93ea0a558a 100644 --- a/llvm/include/llvm/ProfileData/InstrProf.h +++ b/llvm/include/llvm/ProfileData/InstrProf.h @@ -357,6 +357,13 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName); /// the duplicated profile variables for Comdat functions. bool needsComdatForCounter(const GlobalObject &GV, const Module &M); +/// \c NameStrings is a string composed of one or more possibly encoded +/// sub-strings. The substrings are separated by `\01` (returned by +/// InstrProf.h:getInstrProfNameSeparator). This method decodes the string and +/// calls `NameCallback` for each substring. +Error readAndDecodeStrings(StringRef NameStrings, + std::function NameCallback); + /// An enum describing the attributes of an instrumented profile. enum class InstrProfKind { Unknown = 0x0, @@ -493,6 +500,11 @@ class InstrProfSymtab { public: using AddrHashMap = std::vector>; + // Returns the canonial name of the given PGOName. In a canonical name, all + // suffixes that begins with "." except ".__uniq." are stripped. + // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`. + static StringRef getCanonicalName(StringRef PGOName); + private: using AddrIntervalMap = IntervalMap>; @@ -528,11 +540,6 @@ class InstrProfSymtab { static StringRef getExternalSymbol() { return "** External Symbol **"; } - // Returns the canonial name of the given PGOName. In a canonical name, all - // suffixes that begins with "." except ".__uniq." are stripped. - // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`. - static StringRef getCanonicalName(StringRef PGOName); - // Add the function into the symbol table, by creating the following // map entries: // name-set = {PGOFuncName} union {getCanonicalName(PGOFuncName)} diff --git a/llvm/lib/ProfileData/CMakeLists.txt b/llvm/lib/ProfileData/CMakeLists.txt index eb7c2a3c1a28a..67a69d7761b2c 100644 --- a/llvm/lib/ProfileData/CMakeLists.txt +++ b/llvm/lib/ProfileData/CMakeLists.txt @@ -1,4 +1,5 @@ add_llvm_component_library(LLVMProfileData + DataAccessProf.cpp GCOV.cpp IndexedMemProfData.cpp InstrProf.cpp diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp new file mode 100644 index 0000000000000..d7e67f5f09cbe --- /dev/null +++ b/llvm/lib/ProfileData/DataAccessProf.cpp @@ -0,0 +1,263 @@ +#include "llvm/ProfileData/DataAccessProf.h" +#include "llvm/ADT/DenseMapInfoVariant.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Compression.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/Errc.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/StringSaver.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace llvm { +namespace data_access_prof { + +// If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise, +// creates an owned copy of `Str`, adds a map entry for it and returns the +// iterator. +static MapVector::iterator +saveStringToMap(MapVector &Map, + llvm::UniqueStringSaver &Saver, StringRef Str) { + auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size()); + return Iter; +} + +// Returns the canonical name or error. +static Expected getCanonicalName(StringRef Name) { + if (Name.empty()) + return make_error("Empty symbol name", + llvm::errc::invalid_argument); + return InstrProfSymtab::getCanonicalName(Name); +} + +const DataAccessProfRecord * +DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const { + auto Key = SymbolID; + if (std::holds_alternative(SymbolID)) { + auto NameOrErr = getCanonicalName(std::get(SymbolID)); + // If name canonicalization fails, suppress the error inside. + if (!NameOrErr) { + assert( + std::get(SymbolID).empty() && + "Name canonicalization only fails when stringified string is empty."); + return nullptr; + } + Key = *NameOrErr; + } + + auto It = Records.find(Key); + if (It != Records.end()) + return &It->second; + + return nullptr; +} + +bool DataAccessProfData::isKnownColdSymbol(const SymbolHandle SymID) const { + if (std::holds_alternative(SymID)) + return KnownColdHashes.contains(std::get(SymID)); + return KnownColdSymbols.contains(std::get(SymID)); +} + +Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol, + uint64_t AccessCount) { + uint64_t RecordID = -1; + bool IsStringLiteral = false; + SymbolHandle Key; + if (std::holds_alternative(Symbol)) { + RecordID = std::get(Symbol); + Key = RecordID; + IsStringLiteral = true; + } else { + auto CanonicalName = getCanonicalName(std::get(Symbol)); + if (!CanonicalName) + return CanonicalName.takeError(); + std::tie(Key, RecordID) = + *saveStringToMap(StrToIndexMap, Saver, *CanonicalName); + IsStringLiteral = false; + } + + auto [Iter, Inserted] = + Records.try_emplace(Key, RecordID, AccessCount, IsStringLiteral); + if (!Inserted) + return make_error("Duplicate symbol or string literal added. " + "User of DataAccessProfData should " + "aggregate count for the same symbol. ", + llvm::errc::invalid_argument); + + return Error::success(); +} + +Error DataAccessProfData::setDataAccessProfile( + SymbolHandle SymbolID, uint64_t AccessCount, + ArrayRef Locations) { + if (Error E = setDataAccessProfile(SymbolID, AccessCount)) + return E; + + auto &Record = Records.back().second; + for (const auto &Location : Locations) + Record.Locations.push_back( + {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first, + Location.Line}); + + return Error::success(); +} + +Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolHandle SymbolID) { + if (std::holds_alternative(SymbolID)) { + KnownColdHashes.insert(std::get(SymbolID)); + return Error::success(); + } + auto CanonicalName = getCanonicalName(std::get(SymbolID)); + if (!CanonicalName) + return CanonicalName.takeError(); + KnownColdSymbols.insert(*CanonicalName); + return Error::success(); +} + +Error DataAccessProfData::deserialize(const unsigned char *&Ptr) { + uint64_t NumSampledSymbols = + support::endian::readNext(Ptr); + uint64_t NumColdKnownSymbols = + support::endian::readNext(Ptr); + if (Error E = deserializeSymbolsAndFilenames(Ptr, NumSampledSymbols, + NumColdKnownSymbols)) + return E; + + uint64_t Num = + support::endian::readNext(Ptr); + for (uint64_t I = 0; I < Num; ++I) + KnownColdHashes.insert( + support::endian::readNext(Ptr)); + + return deserializeRecords(Ptr); +} + +Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const { + OS.write(StrToIndexMap.size()); + OS.write(KnownColdSymbols.size()); + + std::vector Strs; + Strs.reserve(StrToIndexMap.size() + KnownColdSymbols.size()); + for (const auto &Str : StrToIndexMap) + Strs.push_back(Str.first.str()); + for (const auto &Str : KnownColdSymbols) + Strs.push_back(Str.str()); + + std::string CompressedStrings; + if (!Strs.empty()) + if (Error E = collectGlobalObjectNameStrings( + Strs, compression::zlib::isAvailable(), CompressedStrings)) + return E; + const uint64_t CompressedStringLen = CompressedStrings.length(); + // Record the length of compressed string. + OS.write(CompressedStringLen); + // Write the chars in compressed strings. + for (char C : CompressedStrings) + OS.writeByte(static_cast(C)); + // Pad up to a multiple of 8. + // InstrProfReader could read bytes according to 'CompressedStringLen'. + const uint64_t PaddedLength = alignTo(CompressedStringLen, 8); + for (uint64_t K = CompressedStringLen; K < PaddedLength; K++) + OS.writeByte(0); + return Error::success(); +} + +uint64_t +DataAccessProfData::getEncodedIndex(const SymbolHandle SymbolID) const { + if (std::holds_alternative(SymbolID)) + return std::get(SymbolID); + + auto Iter = StrToIndexMap.find(std::get(SymbolID)); + assert(Iter != StrToIndexMap.end() && + "String literals not found in StrToIndexMap"); + return Iter->second; +} + +Error DataAccessProfData::serialize(ProfOStream &OS) const { + if (Error E = serializeSymbolsAndFilenames(OS)) + return E; + OS.write(KnownColdHashes.size()); + for (const auto &Hash : KnownColdHashes) + OS.write(Hash); + OS.write((uint64_t)(Records.size())); + for (const auto &[Key, Rec] : Records) { + OS.write(getEncodedIndex(Rec.SymbolID)); + OS.writeByte(Rec.IsStringLiteral); + OS.write(Rec.AccessCount); + OS.write(Rec.Locations.size()); + for (const auto &Loc : Rec.Locations) { + OS.write(getEncodedIndex(Loc.FileName)); + OS.write32(Loc.Line); + } + } + return Error::success(); +} + +Error DataAccessProfData::deserializeSymbolsAndFilenames( + const unsigned char *&Ptr, const uint64_t NumSampledSymbols, + const uint64_t NumColdKnownSymbols) { + uint64_t Len = + support::endian::readNext(Ptr); + + // The first NumSampledSymbols strings are symbols with samples, and next + // NumColdKnownSymbols strings are known cold symbols. + uint64_t StringCnt = 0; + std::function addName = [&](StringRef Name) { + if (StringCnt < NumSampledSymbols) + saveStringToMap(StrToIndexMap, Saver, Name); + else + KnownColdSymbols.insert(Saver.save(Name)); + ++StringCnt; + return Error::success(); + }; + if (Error E = + readAndDecodeStrings(StringRef((const char *)Ptr, Len), addName)) + return E; + + Ptr += alignTo(Len, 8); + return Error::success(); +} + +Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) { + SmallVector Strings = llvm::to_vector(getStrings()); + + uint64_t NumRecords = + support::endian::readNext(Ptr); + + for (uint64_t I = 0; I < NumRecords; ++I) { + uint64_t ID = + support::endian::readNext(Ptr); + + bool IsStringLiteral = + support::endian::readNext(Ptr); + + uint64_t AccessCount = + support::endian::readNext(Ptr); + + SymbolHandle SymbolID; + if (IsStringLiteral) + SymbolID = ID; + else + SymbolID = Strings[ID]; + if (Error E = setDataAccessProfile(SymbolID, AccessCount)) + return E; + + auto &Record = Records.back().second; + + uint64_t NumLocations = + support::endian::readNext(Ptr); + + Record.Locations.reserve(NumLocations); + for (uint64_t J = 0; J < NumLocations; ++J) { + uint64_t FileNameIndex = + support::endian::readNext(Ptr); + uint32_t Line = + support::endian::readNext(Ptr); + Record.Locations.push_back({Strings[FileNameIndex], Line}); + } + } + return Error::success(); +} +} // namespace data_access_prof +} // namespace llvm diff --git a/llvm/lib/ProfileData/InstrProf.cpp b/llvm/lib/ProfileData/InstrProf.cpp index 76e8ca6a67590..368e3535fe905 100644 --- a/llvm/lib/ProfileData/InstrProf.cpp +++ b/llvm/lib/ProfileData/InstrProf.cpp @@ -572,12 +572,8 @@ Error InstrProfSymtab::addVTableWithName(GlobalVariable &VTable, return Error::success(); } -/// \c NameStrings is a string composed of one of more possibly encoded -/// sub-strings. The substrings are separated by 0 or more zero bytes. This -/// method decodes the string and calls `NameCallback` for each substring. -static Error -readAndDecodeStrings(StringRef NameStrings, - std::function NameCallback) { +Error readAndDecodeStrings(StringRef NameStrings, + std::function NameCallback) { const uint8_t *P = NameStrings.bytes_begin(); const uint8_t *EndP = NameStrings.bytes_end(); while (P < EndP) { diff --git a/llvm/unittests/ProfileData/CMakeLists.txt b/llvm/unittests/ProfileData/CMakeLists.txt index 0a7f7da085950..29b9cb751dabe 100644 --- a/llvm/unittests/ProfileData/CMakeLists.txt +++ b/llvm/unittests/ProfileData/CMakeLists.txt @@ -10,6 +10,7 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(ProfileDataTests BPFunctionNodeTest.cpp CoverageMappingTest.cpp + DataAccessProfTest.cpp InstrProfDataTest.cpp InstrProfTest.cpp ItaniumManglingCanonicalizerTest.cpp diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp new file mode 100644 index 0000000000000..50c4af49fe76b --- /dev/null +++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp @@ -0,0 +1,181 @@ + +//===- unittests/Support/DataAccessProfTest.cpp +//----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ProfileData/DataAccessProf.h" +#include "llvm/Support/raw_ostream.h" +#include "gmock/gmock-more-matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace llvm { +namespace data_access_prof { +namespace { + +using ::llvm::StringRef; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::IsEmpty; + +static std::string ErrorToString(Error E) { + std::string ErrMsg; + llvm::raw_string_ostream OS(ErrMsg); + llvm::logAllUnhandledErrors(std::move(E), OS); + return ErrMsg; +} + +// Test the various scenarios when DataAccessProfData should return error on +// invalid input. +TEST(MemProf, DataAccessProfileError) { + // Returns error if the input symbol name is empty. + DataAccessProfData Data; + EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)), + HasSubstr("Empty symbol name")); + + // Returns error when the same symbol gets added twice. + ASSERT_FALSE(Data.setDataAccessProfile("foo", 100)); + EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)), + HasSubstr("Duplicate symbol or string literal added")); + + // Returns error when the same string content hash gets added twice. + ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000)); + EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)), + HasSubstr("Duplicate symbol or string literal added")); +} + +// Test the following operations on DataAccessProfData: +// - Profile record look up. +// - Serialization and de-serialization. +TEST(MemProf, DataAccessProfile) { + DataAccessProfData Data; + + // In the bool conversion, Error is true if it's in a failure state and false + // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error. + ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100)); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789)); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2")); + ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123, + { + DataLocation{"file2", 3}, + })); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1")); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678)); + ASSERT_FALSE(Data.setDataAccessProfile( + (uint64_t)135246, 1000, + {DataLocation{"file1", 1}, DataLocation{"file2", 2}})); + + { + // Test that symbol names and file names are stored in the input order. + EXPECT_THAT(llvm::to_vector(Data.getStrings()), + ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); + EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1")); + EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678)); + + // Look up profiles. + EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789)); + EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678)); + EXPECT_TRUE(Data.isKnownColdSymbol("sym2")); + EXPECT_TRUE(Data.isKnownColdSymbol("sym1")); + + EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr); + EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr); + + EXPECT_THAT( + *Data.getProfileRecord("foo.llvm.123"), + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), + testing::Field(&DataAccessProfRecord::AccessCount, 100), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field(&DataAccessProfRecord::Locations, + testing::IsEmpty()))); + EXPECT_THAT( + *Data.getProfileRecord("bar.__uniq.321"), + AllOf( + testing::Field(&DataAccessProfRecord::SymbolID, 1), + testing::Field(&DataAccessProfRecord::AccessCount, 123), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field(&DataAccessProfRecord::Locations, + ElementsAre(AllOf( + testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 3)))))); + EXPECT_THAT( + *Data.getProfileRecord((uint64_t)135246), + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246), + testing::Field(&DataAccessProfRecord::AccessCount, 1000), + testing::Field(&DataAccessProfRecord::IsStringLiteral, true), + testing::Field( + &DataAccessProfRecord::Locations, + ElementsAre( + AllOf(testing::Field(&DataLocation::FileName, "file1"), + testing::Field(&DataLocation::Line, 1)), + AllOf(testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 2)))))); + } + + // Tests serialization and de-serialization. + DataAccessProfData deserializedData; + { + std::string serializedData; + llvm::raw_string_ostream OS(serializedData); + llvm::ProfOStream POS(OS); + + EXPECT_FALSE(Data.serialize(POS)); + + const unsigned char *p = + reinterpret_cast(serializedData.data()); + ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()), + testing::IsEmpty()); + EXPECT_FALSE(deserializedData.deserialize(p)); + + EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()), + ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); + EXPECT_THAT(deserializedData.getKnownColdSymbols(), + ElementsAre("sym2", "sym1")); + EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678)); + + // Look up profiles after deserialization. + EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789)); + EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678)); + EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2")); + EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1")); + + auto Records = + llvm::to_vector(llvm::make_second_range(deserializedData.getRecords())); + + EXPECT_THAT( + Records, + ElementsAre( + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), + testing::Field(&DataAccessProfRecord::AccessCount, 100), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field(&DataAccessProfRecord::Locations, + testing::IsEmpty())), + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1), + testing::Field(&DataAccessProfRecord::AccessCount, 123), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field( + &DataAccessProfRecord::Locations, + ElementsAre(AllOf( + testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 3))))), + AllOf( + testing::Field(&DataAccessProfRecord::SymbolID, 135246), + testing::Field(&DataAccessProfRecord::AccessCount, 1000), + testing::Field(&DataAccessProfRecord::IsStringLiteral, true), + testing::Field( + &DataAccessProfRecord::Locations, + ElementsAre( + AllOf(testing::Field(&DataLocation::FileName, "file1"), + testing::Field(&DataLocation::Line, 1)), + AllOf(testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 2))))))); + } +} +} // namespace +} // namespace data_access_prof +} // namespace llvm