diff --git a/10-Retriever/03-EnsembleRetriever.ipynb b/10-Retriever/03-EnsembleRetriever.ipynb new file mode 100644 index 000000000..4fd5ec270 --- /dev/null +++ b/10-Retriever/03-EnsembleRetriever.ipynb @@ -0,0 +1,451 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e2817c9e", + "metadata": {}, + "source": [ + "# Ensemble Retriever\n", + "\n", + "- Author: [3dkids](https://github.com/3dkids)\n", + "- Peer Review: [r14minji](https://github.com/r14minji), [jeongkpa](https://github.com/jeongkpa)\n", + "- This is a part of [LangChain Open Tutorial](https://github.com/LangChain-OpenTutorial/LangChain-OpenTutorial)\n", + "\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1a9N74AS8BTPuO5IWdlvAm1AWwTRP9nCH?usp=sharing) [![Open in LangChain Academy](https://cdn.prod.website-files.com/65b8cd72835ceeacd4449a53/66e9eba12c7b7688aa3dbb5e_LCA-badge-green.svg)](https://academy.langchain.com/courses/take/intro-to-langgraph/lessons/58239937-lesson-2-sub-graphs)\n", + "\n", + "## Overview\n", + "\n", + "This notebook explores the creation and use of an EnsembleRetriever in LangChain to improve information retrieval by combining multiple retrieval methods.
\n", + "The EnsembleRetriever integrates the strengths of sparse and dense retrieval algorithms, using weights and runtime configurations for tailored performance.
\n", + "\n", + "**Key Features**\n", + "1. integrate multiple searchers: take different types of searchers as input and combine results.\n", + "2. result re-ranking: uses the [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) algorithm to re-rank results.\n", + "3. hybrid search: mainly uses a combination of `sparse retriever` (e.g. BM25) and `dense retriever` (e.g. embedding similarity).\n", + "\n", + "**Advantages**\n", + "- Sparse retriever: effective for keyword-based searches\n", + "- Dense retriever: effective for semantic similarity-based searches\n", + "\n", + "Due to these complementary characteristics, `EnsembleRetriever` can provide improved performance in a variety of search scenarios.\n", + "\n", + "For more information, please refer to the [LangChain official documentation](https://python.langchain.com/api_reference/langchain/retrievers.html)\n", + "\n", + "\n", + "\n", + "### Table of Contents\n", + "\n", + "- [Overview](#overview)\n", + "- [Environement Setup](#environment-setup)\n", + "- [Creating and Configuring Ensemble Retrievers](#creating-and-configuring-ensemble-retrievers)\n", + "- [Query Execution](#query-execution)\n", + "- [Change runtime config](#change-runtime-config)\n", + "\n", + "\n", + "### References\n", + "\n", + "- [LangChain: EnsembleRetriever](https://python.langchain.com/api_reference/langchain/retrievers/langchain.retrievers.ensemble.EnsembleRetriever.html#ensembleretriever)\n", + "- [LangChain: BM25Retriever](https://python.langchain.com/api_reference/community/retrievers/langchain_community.retrievers.bm25.BM25Retriever.html)\n", + "- [LangChain: ConfigurableField](https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.utils.ConfigurableField.html)\n", + "----" + ] + }, + { + "cell_type": "markdown", + "id": "76826f0c", + "metadata": {}, + "source": [ + "## Environment Setup\n", + "\n", + "Set up the environment. You may refer to [Environment Setup](https://wikidocs.net/257836) for more details.\n", + "\n", + "**[Note]**\n", + "- `langchain-opentutorial` is a package that provides a set of easy-to-use environment setup, useful functions and utilities for tutorials. \n", + "- You can checkout the [`langchain-opentutorial`](https://github.com/LangChain-OpenTutorial/langchain-opentutorial-pypi) for more details." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "45778437", + "metadata": {}, + "outputs": [], + "source": [ + "%%capture --no-stderr\n", + "!pip install langchain-opentutorial" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b44160ee", + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "from langchain_opentutorial import package\n", + "\n", + "package.install(\n", + " [\n", + " \"langchain_core\", # Core functionality of LangChain\n", + " \"langchain_community\", # Community-supported integrations\n", + " \"langchain_openai\", # OpenAI integration for embeddings and models\n", + " \"rank_bm25\", # BM25 ranking algorithm for information retrieval\n", + " ],\n", + " verbose=False, # Suppress detailed installation logs\n", + " upgrade=False, # Do not upgrade packages if already installed\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4d757a49", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Environment variables have been set successfully.\n" + ] + } + ], + "source": [ + "# Set environment variables\n", + "from langchain_opentutorial import set_env\n", + "\n", + "set_env(\n", + " {\n", + " \"OPENAI_API_KEY\": \"\",\n", + " \"LANGCHAIN_API_KEY\": \"\",\n", + " \"LANGCHAIN_TRACING_V2\": \"true\",\n", + " \"LANGCHAIN_ENDPOINT\": \"https://api.smith.langchain.com\",\n", + " \"LANGCHAIN_PROJECT\": \"Conversation-With-History\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "19224b47", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Configuration file to manage API keys as environment variables\n", + "from dotenv import load_dotenv\n", + "\n", + "# Load API key information\n", + "load_dotenv(override=True)" + ] + }, + { + "cell_type": "markdown", + "id": "405f3d44", + "metadata": {}, + "source": [ + "## Creating and Configuring Ensemble Retrievers\n", + "Initializing an ensemble retriever\n", + "Ensemble retrievers combine two discovery mechanisms\n", + "\n", + "- Sparse search: Uses BM25Retriever for keyword-based matching.\n", + "- Dense search: Uses FAISS with OpenAI embedding for semantic similarity.\n", + "\n", + "- Initialize `EnsembleRetriever` to combine the `BM25Retriever` and `FAISS` searchers. Set the weights for each searcher." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c021c46f", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import BM25Retriever, EnsembleRetriever\n", + "from langchain.vectorstores import FAISS\n", + "from langchain_openai import OpenAIEmbeddings\n", + "\n", + "# list sample documents\n", + "doc_list = [\n", + " \"I like apples\",\n", + " \"I like apple company\",\n", + " \"I like apple's iphone\",\n", + " \"Apple is my favorite company\",\n", + " \"I like apple's ipad\",\n", + " \"I like apple's macbook\",\n", + "]\n", + "\n", + "# Initialize the bm25 retriever and faiss retriever.\n", + "bm25_retriever = BM25Retriever.from_texts(\n", + " doc_list,\n", + ")\n", + "bm25_retriever.k = 1 # Set the number of search results for BM25Retriever to 1.\n", + "\n", + "embedding = OpenAIEmbeddings() # Enable OpenAI embedding.\n", + "\n", + "faiss_vectorstore = FAISS.from_texts(\n", + " doc_list,\n", + " embedding,\n", + ")\n", + "faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={\"k\": 1})\n", + "\n", + "# Initialize the ensemble retriever.\n", + "ensemble_retriever = EnsembleRetriever(\n", + " retrievers=[bm25_retriever, faiss_retriever],\n", + " weights=[0.7, 0.3],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "835eb07f", + "metadata": {}, + "source": [ + "## Query Execution\n", + "Perform retrieval for a given query using ensemble_retriever and compare results across retrievers.\n", + "- Call the `get_relevant_documents()` method of the `ensemble_retriever` object to retrieve relevant documents.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cfacdd8d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Ensemble Retriever]\n", + "Content: Apple is my favorite company\n", + "\n", + "Content: I like apples\n", + "\n", + "[BM25 Retriever]\n", + "Content: Apple is my favorite company\n", + "\n", + "[FAISS Retriever]\n", + "Content: I like apples\n", + "\n" + ] + } + ], + "source": [ + "# Get the search results document.\n", + "query = \"my favorite fruit is apple\"\n", + "ensemble_result = ensemble_retriever.invoke(query)\n", + "bm25_result = bm25_retriever.invoke(query)\n", + "faiss_result = faiss_retriever.invoke(query)\n", + "\n", + "# Output the fetched documents.\n", + "print(\"[Ensemble Retriever]\")\n", + "for doc in ensemble_result:\n", + " print(f\"Content: {doc.page_content}\")\n", + " print()\n", + "\n", + "print(\"[BM25 Retriever]\")\n", + "for doc in bm25_result:\n", + " print(f\"Content: {doc.page_content}\")\n", + " print()\n", + "\n", + "print(\"[FAISS Retriever]\")\n", + "for doc in faiss_result:\n", + " print(f\"Content: {doc.page_content}\")\n", + " print()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "dac4523b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Ensemble Retriever]\n", + "Content: Apple is my favorite company\n", + "\n", + "Content: I like apple's iphone\n", + "\n", + "[BM25 Retriever]\n", + "Content: Apple is my favorite company\n", + "\n", + "[FAISS Retriever]\n", + "Content: I like apple's iphone\n", + "\n" + ] + } + ], + "source": [ + "# Get the search results document.\n", + "query = \"Apple company makes my favorite iphone\"\n", + "ensemble_result = ensemble_retriever.invoke(query)\n", + "bm25_result = bm25_retriever.invoke(query)\n", + "faiss_result = faiss_retriever.invoke(query)\n", + "\n", + "# Output the fetched documents.\n", + "print(\"[Ensemble Retriever]\")\n", + "for doc in ensemble_result:\n", + " print(f\"Content: {doc.page_content}\")\n", + " print()\n", + "\n", + "print(\"[BM25 Retriever]\")\n", + "for doc in bm25_result:\n", + " print(f\"Content: {doc.page_content}\")\n", + " print()\n", + "\n", + "print(\"[FAISS Retriever]\")\n", + "for doc in faiss_result:\n", + " print(f\"Content: {doc.page_content}\")\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "id": "c9d4a5a2", + "metadata": {}, + "source": [ + "## Change runtime config\n", + "\n", + "You can also change the properties of a retriever at runtime. This is possible using the `ConfigurableField` class." + ] + }, + { + "cell_type": "markdown", + "id": "aae1c63e", + "metadata": {}, + "source": [ + "- Define the `weights` parameter as a `ConfigurableField` object.\n", + " - Set the field's ID to “ensemble_weights”.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "84a34e4b", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.runnables import ConfigurableField\n", + "\n", + "ensemble_retriever = EnsembleRetriever(\n", + " # Set the list of retrievers. Here we use bm25_retriever and faiss_retriever.\n", + " retrievers=[bm25_retriever, faiss_retriever],\n", + ").configurable_fields(\n", + " weights=ConfigurableField(\n", + " # Set a unique identifier for the search parameter.\n", + " id=\"ensemble_weights\",\n", + " # Set a name for the search parameter.\n", + " name=\"Ensemble Weights\",\n", + " # Write a description of the search parameters.\n", + " description=\"Ensemble Weights\",\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bb0e4fb1", + "metadata": {}, + "source": [ + "- Specify the search settings via the `config` parameter when searching.\n", + " - Set the weight of the `ensemble_weights` option to [1, 0] so that **all search results are weighted more heavily toward BM25 retriever**." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1bee1000", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(metadata={}, page_content='Apple is my favorite company'),\n", + " Document(id='6280c2a3-b58f-474e-aeb6-d480bb44d49e', metadata={}, page_content='I like apples')]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config = {\"configurable\": {\"ensemble_weights\": [1, 0]}}\n", + "\n", + "# Use the config parameter to specify search settings.\n", + "docs = ensemble_retriever.invoke(\"my favorite fruit is apple\", config=config)\n", + "docs # Print the search result, docs." + ] + }, + { + "cell_type": "markdown", + "id": "ffdad0be", + "metadata": {}, + "source": [ + "This time, we want all search results to be weighted **more heavily in favor of the FAISS retriever**." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "5d95922b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(id='6280c2a3-b58f-474e-aeb6-d480bb44d49e', metadata={}, page_content='I like apples'),\n", + " Document(metadata={}, page_content='Apple is my favorite company')]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config = {\"configurable\": {\"ensemble_weights\": [0, 1]}}\n", + "\n", + "# Use the config parameter to specify search settings.\n", + "docs = ensemble_retriever.invoke(\"my favorite fruit is apple\", config=config)\n", + "docs # Print the search result, docs." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langchain-opentutorial-UZZrLuk2-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}