- Notifications
You must be signed in to change notification settings - Fork 15.2k
[IR2Vec] Restrict caching only to Flow-Aware computation #162559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlgo @llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy) ChangesRemoved all the caching maps (BB, Inst) in OTOH, Flow-Aware embeddings would benefit from instruction level caching, as computing the embedding for an instruction would depend on the embeddings of other instructions in a function. So, retained instruction embedding caching logic only for Flow-Aware computation. This also necessitates an Full diff: https://github.com/llvm/llvm-project/pull/162559.diff 6 Files Affected:
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst index 965a21b8c84b8..bf3de11a2640e 100644 --- a/llvm/docs/MLGO.rst +++ b/llvm/docs/MLGO.rst @@ -508,7 +508,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance. .. code-block:: c++ - const ir2vec::Embedding &FuncVector = Emb->getFunctionVector(); + ir2vec::Embedding FuncVector = Emb->getFunctionVector(); Currently, ``Embedder`` can generate embeddings at three levels: Instructions, Basic Blocks, and Functions. Appropriate getters are provided to access the diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 81409df7337c5..8420746761c5e 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -533,21 +533,18 @@ class Embedder { /// in the IR instructions to generate the vector representation. const float OpcWeight, TypeWeight, ArgWeight; - // Utility maps - these are used to store the vector representations of - // instructions, basic blocks and functions. - mutable Embedding FuncVector; - mutable BBEmbeddingsMap BBVecMap; - mutable InstEmbeddingsMap InstVecMap; - LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab); /// Function to compute embeddings. It generates embeddings for all /// the instructions and basic blocks in the function F. - void computeEmbeddings() const; + Embedding computeEmbeddings() const; /// Function to compute the embedding for a given basic block. - /// Specific to the kind of embeddings being computed. - virtual void computeEmbeddings(const BasicBlock &BB) const = 0; + Embedding computeEmbeddings(const BasicBlock &BB) const; + + /// Function to compute the embedding for a given instruction. Specific to the + /// kind of embeddings being computed. + virtual Embedding computeEmbeddings(const Instruction &I) const = 0; public: virtual ~Embedder() = default; @@ -556,23 +553,21 @@ class Embedder { LLVM_ABI static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab); - /// Returns a map containing instructions and the corresponding embeddings for - /// the function F if it has been computed. If not, it computes the embeddings - /// for the function and returns the map. - LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const; + /// Returns the embedding for a given instruction in the function F + LLVM_ABI Embedding getInstVector(const Instruction &I) const; - /// Returns a map containing basic block and the corresponding embeddings for - /// the function F if it has been computed. If not, it computes the embeddings - /// for the function and returns the map. - LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const; + /// Returns the embedding for a given basic block in the function F + LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const; - /// Returns the embedding for a given basic block in the function F if it has - /// been computed. If not, it computes the embedding for the basic block and - /// returns it. - LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const; + /// Returns the embedding for the current function. + LLVM_ABI Embedding getFunctionVector() const; - /// Computes and returns the embedding for the current function. - LLVM_ABI const Embedding &getFunctionVector() const; + /// Invalidate embeddings if cached. The embeddings may not be relevant + /// anymore when the IR changes due to transformations. In such cases, the + /// cached embeddings should be invalidated to ensure + /// correctness/recomputation. This is a no-op for SymbolicEmbedder but + /// removes all the cached entries in FlowAwareEmbedder. + virtual void invalidateEmbeddings() {} }; /// Class for computing the Symbolic embeddings of IR2Vec. @@ -580,7 +575,7 @@ class Embedder { /// representations obtained from the Vocabulary. class LLVM_ABI SymbolicEmbedder : public Embedder { private: - void computeEmbeddings(const BasicBlock &BB) const override; + Embedding computeEmbeddings(const Instruction &I) const override; public: SymbolicEmbedder(const Function &F, const Vocabulary &Vocab) @@ -592,11 +587,17 @@ class LLVM_ABI SymbolicEmbedder : public Embedder { /// embeddings, and additionally capture the flow information in the IR. class LLVM_ABI FlowAwareEmbedder : public Embedder { private: - void computeEmbeddings(const BasicBlock &BB) const override; + // Utility map for caching - needed for flow-aware dependencies + mutable InstEmbeddingsMap InstVecMap; + + Embedding computeEmbeddings(const Instruction &I) const override; public: FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab) : Embedder(F, Vocab) {} + + /// Override to invalidate all cached instruction embeddings + void invalidateEmbeddings() override; }; } // namespace ir2vec diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 688535161d4b9..d6b1ae98b9977 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -155,8 +155,8 @@ void Embedding::print(raw_ostream &OS) const { Embedder::Embedder(const Function &F, const Vocabulary &Vocab) : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()), - OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight), - FuncVector(Embedding(Dimension)) {} + OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) { +} std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab) { @@ -169,101 +169,104 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F, return nullptr; } -const InstEmbeddingsMap &Embedder::getInstVecMap() const { - if (InstVecMap.empty()) - computeEmbeddings(); - return InstVecMap; +Embedding Embedder::getInstVector(const Instruction &I) const { + return computeEmbeddings(I); } -const BBEmbeddingsMap &Embedder::getBBVecMap() const { - if (BBVecMap.empty()) - computeEmbeddings(); - return BBVecMap; +Embedding Embedder::getBBVector(const BasicBlock &BB) const { + return computeEmbeddings(BB); } -const Embedding &Embedder::getBBVector(const BasicBlock &BB) const { - auto It = BBVecMap.find(&BB); - if (It != BBVecMap.end()) - return It->second; - computeEmbeddings(BB); - return BBVecMap[&BB]; -} - -const Embedding &Embedder::getFunctionVector() const { +Embedding Embedder::getFunctionVector() const { // Currently, we always (re)compute the embeddings for the function. // This is cheaper than caching the vector. - computeEmbeddings(); - return FuncVector; + return computeEmbeddings(); } -void Embedder::computeEmbeddings() const { +Embedding Embedder::computeEmbeddings() const { if (F.isDeclaration()) - return; + return Embedding(Dimension, 0.0); - FuncVector = Embedding(Dimension, 0.0); + Embedding FuncVector(Dimension, 0.0); // Consider only the basic blocks that are reachable from entry for (const BasicBlock *BB : depth_first(&F)) { - computeEmbeddings(*BB); - FuncVector += BBVecMap[BB]; + FuncVector += computeEmbeddings(*BB); } + return FuncVector; } -void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { +Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const { Embedding BBVector(Dimension, 0); // We consider only the non-debug and non-pseudo instructions for (const auto &I : BB.instructionsWithoutDebug()) { - Embedding ArgEmb(Dimension, 0); - for (const auto &Op : I.operands()) - ArgEmb += Vocab[*Op]; - auto InstVector = - Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; - if (const auto *IC = dyn_cast<CmpInst>(&I)) - InstVector += Vocab[IC->getPredicate()]; - InstVecMap[&I] = InstVector; - BBVector += InstVector; + BBVector += computeEmbeddings(I); } - BBVecMap[&BB] = BBVector; + return BBVector; } -void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { - Embedding BBVector(Dimension, 0); +Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const { + Embedding ArgEmb(Dimension, 0); + for (const auto &Op : I.operands()) + ArgEmb += Vocab[*Op]; + auto InstVector = + Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + if (const auto *IC = dyn_cast<CmpInst>(&I)) + InstVector += Vocab[IC->getPredicate()]; + return InstVector; +} - // We consider only the non-debug and non-pseudo instructions - for (const auto &I : BB.instructionsWithoutDebug()) { - // TODO: Handle call instructions differently. - // For now, we treat them like other instructions - Embedding ArgEmb(Dimension, 0); - for (const auto &Op : I.operands()) { - // If the operand is defined elsewhere, we use its embedding - if (const auto *DefInst = dyn_cast<Instruction>(Op)) { - auto DefIt = InstVecMap.find(DefInst); - assert(DefIt != InstVecMap.end() && - "Instruction should have been processed before its operands"); +Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const { + // If we have already computed the embedding for this instruction, return it + auto It = InstVecMap.find(&I); + if (It != InstVecMap.end()) + return It->second; + + // TODO: Handle call instructions differently. + // For now, we treat them like other instructions + Embedding ArgEmb(Dimension, 0); + for (const auto &Op : I.operands()) { + // If the operand is defined elsewhere, we use its embedding + if (const auto *DefInst = dyn_cast<Instruction>(Op)) { + auto DefIt = InstVecMap.find(DefInst); + // Fixme (#159171): Ideally we should never miss an instruction + // embedding here. + // But when we have cyclic dependencies (e.g., phi + // nodes), we might miss the embedding. In such cases, we fall back to + // using the vocabulary embedding. This can be fixed by iterating to a + // fixed-point, or by using a simple solver for the set of simultaneous + // equations. + // Another case when we might miss an instruction embedding is when + // the operand instruction is in a different basic block that has not + // been processed yet. This can be fixed by processing the basic blocks + // in a topological order. + if (DefIt != InstVecMap.end()) ArgEmb += DefIt->second; - continue; - } - // If the operand is not defined by an instruction, we use the vocabulary - else { - LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: " - << *Op << "=" << Vocab[*Op][0] << "\n"); + else ArgEmb += Vocab[*Op]; - } } - // Create the instruction vector by combining opcode, type, and arguments - // embeddings - auto InstVector = - Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; - // Add compare predicate embedding as an additional operand if applicable - if (const auto *IC = dyn_cast<CmpInst>(&I)) - InstVector += Vocab[IC->getPredicate()]; - InstVecMap[&I] = InstVector; - BBVector += InstVector; + // If the operand is not defined by an instruction, we use the + // vocabulary + else { + LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: " + << *Op << "=" << Vocab[*Op][0] << "\n"); + ArgEmb += Vocab[*Op]; + } } - BBVecMap[&BB] = BBVector; + // Create the instruction vector by combining opcode, type, and arguments + // embeddings + auto InstVector = + Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + // Add compare predicate embedding as an additional operand if applicable + if (const auto *IC = dyn_cast<CmpInst>(&I)) + InstVector += Vocab[IC->getPredicate()]; + InstVecMap[&I] = InstVector; + return InstVector; } +void FlowAwareEmbedder::invalidateEmbeddings() { InstVecMap.clear(); } + // ==----------------------------------------------------------------------===// // VocabStorage //===----------------------------------------------------------------------===// @@ -684,25 +687,19 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, Emb->getFunctionVector().print(OS); OS << "Basic block vectors:\n"; - const auto &BBMap = Emb->getBBVecMap(); for (const BasicBlock &BB : F) { - auto It = BBMap.find(&BB); - if (It != BBMap.end()) { - OS << "Basic block: " << BB.getName() << ":\n"; - It->second.print(OS); - } + auto BBVector = Emb->getBBVector(BB); + OS << "Basic block: " << BB.getName() << ":\n"; + BBVector.print(OS); } OS << "Instruction vectors:\n"; - const auto &InstMap = Emb->getInstVecMap(); for (const BasicBlock &BB : F) { for (const Instruction &I : BB) { - auto It = InstMap.find(&I); - if (It != InstMap.end()) { - OS << "Instruction: "; - I.print(OS); - It->second.print(OS); - } + auto InstVector = Emb->getInstVector(I); + OS << "Instruction: "; + I.print(OS); + InstVector.print(OS); } } } diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll index 9be0ee1c2de7a..661f2ad158b10 100644 --- a/llvm/test/Analysis/IR2Vec/unreachable.ll +++ b/llvm/test/Analysis/IR2Vec/unreachable.ll @@ -31,12 +31,4 @@ return: ; preds = %if.else, %if.then ret i32 %4 } -; CHECK: Basic block vectors: -; CHECK-NEXT: Basic block: entry: -; CHECK-NEXT: [ 816.20 825.20 834.20 ] -; CHECK-NEXT: Basic block: if.then: -; CHECK-NEXT: [ 195.00 198.00 201.00 ] -; CHECK-NEXT: Basic block: if.else: -; CHECK-NEXT: [ 195.00 198.00 201.00 ] -; CHECK-NEXT: Basic block: return: -; CHECK-NEXT: [ 95.00 97.00 99.00 ] +; CHECK: Function vector: [ 1301.20 1318.20 1335.20 ] diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index 434449c7c5117..1031932116c1e 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -253,25 +253,17 @@ class IR2VecTool { break; } case BasicBlockLevel: { - const auto &BBVecMap = Emb->getBBVecMap(); for (const BasicBlock &BB : F) { - auto It = BBVecMap.find(&BB); - if (It != BBVecMap.end()) { - OS << BB.getName() << ":"; - It->second.print(OS); - } + OS << BB.getName() << ":"; + Emb->getBBVector(BB).print(OS); } break; } case InstructionLevel: { - const auto &InstMap = Emb->getInstVecMap(); for (const BasicBlock &BB : F) { for (const Instruction &I : BB) { - auto It = InstMap.find(&I); - if (It != InstMap.end()) { - I.print(OS); - It->second.print(OS); - } + I.print(OS); + Emb->getInstVector(I).print(OS); } } break; diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 40b4aa21f2b46..d46dae1df1c7d 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -30,7 +30,9 @@ namespace { class TestableEmbedder : public Embedder { public: TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {} - void computeEmbeddings(const BasicBlock &BB) const override {} + Embedding computeEmbeddings(const Instruction &I) const override { + return Embedding(); + } }; TEST(EmbeddingTest, ConstructorsAndAccessors) { @@ -321,72 +323,6 @@ class IR2VecTestFixture : public ::testing::Test { } }; -TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) { - auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V); - ASSERT_TRUE(static_cast<bool>(Emb)); - - const auto &InstMap = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap.size(), 2u); - EXPECT_TRUE(InstMap.count(AddInst)); - EXPECT_TRUE(InstMap.count(RetInst)); - - const auto &AddEmb = InstMap.at(AddInst); - const auto &RetEmb = InstMap.at(RetInst); - EXPECT_EQ(AddEmb.size(), 2u); - EXPECT_EQ(RetEmb.size(), 2u); - - EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5))); - EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5))); -} - -TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) { - auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V); - ASSERT_TRUE(static_cast<bool>(Emb)); - - const auto &InstMap = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap.size(), 2u); - EXPECT_TRUE(InstMap.count(AddInst)); - EXPECT_TRUE(InstMap.count(RetInst)); - - EXPECT_EQ(InstMap.at(AddInst).size(), 2u); - EXPECT_EQ(InstMap.at(RetInst).size(), 2u); - - EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5))); - EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6))); -} - -TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) { - auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V); - ASSERT_TRUE(static_cast<bool>(Emb)); - - const auto &BBMap = Emb->getBBVecMap(); - - EXPECT_EQ(BBMap.size(), 1u); - EXPECT_TRUE(BBMap.count(BB)); - EXPECT_EQ(BBMap.at(BB).size(), 2u); - - // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} = - // {41.0, 41.0} - EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0))); -} - -TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) { - auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V); - ASSERT_TRUE(static_cast<bool>(Emb)); - - const auto &BBMap = Emb->getBBVecMap(); - - EXPECT_EQ(BBMap.size(), 1u); - EXPECT_TRUE(BBMap.count(BB)); - EXPECT_EQ(BBMap.at(BB).size(), 2u); - - // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} = - // {58.1, 58.1} - EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1))); -} - TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) { auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V); ASSERT_TRUE(static_cast<bool>(Emb)); @@ -445,16 +381,6 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) { EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2)); EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3)); EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3)); - - // Also check that instruction vectors remain consistent - const auto &InstMap1 = Emb->getInstVecMap(); - const auto &InstMap2 = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap1.size(), InstMap2.size()); - for (const auto &[Inst, Vec1] : InstMap1) { - ASSERT_TRUE(InstMap2.count(Inst)); - EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst))); - } } TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) { @@ -472,16 +398,6 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) { EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2)); EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3)); EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3)); - - // Also check that instruction vectors remain consistent - const auto &InstMap1 = Emb->getInstVecMap(); - const auto &InstMap2 = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap1.size(), InstMap2.size()); - for (const auto &[Inst, Vec1] : InstMap1) { - ASSERT_TRUE(InstMap2.count(Inst)); - EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst))); - } } static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes; |
…-ir2vec-fa-refactor
Removed all the caching maps (BB, Inst) in
Embedder
as we don't want to cache embeddings in general. Our earlier experiments on Symbolic embeddings show recomputation of embeddings is cheaper than cache lookups.OTOH, Flow-Aware embeddings would benefit from instruction level caching, as computing the embedding for an instruction would depend on the embeddings of other instructions in a function. So, retained instruction embedding caching logic only for Flow-Aware computation. This also necessitates an
invalidate
method that would clean up the cache when the embeddings would become invalid due to transformations.