diff --git a/docs/extras/modules/data_connection/retrievers/integrations/bm25.ipynb b/docs/extras/modules/data_connection/retrievers/integrations/bm25.ipynb new file mode 100644 index 00000000000..ad2c5e27abe --- /dev/null +++ b/docs/extras/modules/data_connection/retrievers/integrations/bm25.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ab66dd43", + "metadata": {}, + "source": [ + "# BM25\n", + "\n", + "[BM25](https://en.wikipedia.org/wiki/Okapi_BM25), also known as Okapi BM25, is a ranking function used in information retrieval systems to estimate the relevance of documents to a given search query.\n", + "\n", + "This notebook goes over how to use a retriever that uses BM25 under the hood, via the [`rank_bm25`](https://github.com/dorianbrown/rank_bm25) package.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a801b57c", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install rank_bm25" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "393ac030", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspaces/langchain/.venv/lib/python3.10/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.10) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from langchain.retrievers import BM25Retriever" + ] + }, + { + "cell_type": "markdown", + "id": "aaf80e7f", + "metadata": {}, + "source": [ + "## Create a New Retriever with Texts" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "98b1c017", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "retriever = BM25Retriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"])" + ] + }, + { + "cell_type": "markdown", + "id": "c016b266", + "metadata": {}, + "source": [ + "## Create a New Retriever with Documents\n", + "\n", + "You can also create a new retriever directly from a list of documents." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "53af4f00", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.schema import Document\n", + "\n", + "retriever = BM25Retriever.from_documents(\n", + " [\n", + " Document(page_content=\"foo\"),\n", + " Document(page_content=\"bar\"),\n", + " Document(page_content=\"world\"),\n", + " Document(page_content=\"hello\"),\n", + " Document(page_content=\"foo bar\"),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "08437fa2", + "metadata": {}, + "source": [ + "## Use Retriever\n", + "\n", + "We can now use the retriever!"
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c0455218", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "result = retriever.get_relevant_documents(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7dfa5c29", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='foo', metadata={}),\n", + " Document(page_content='foo bar', metadata={}),\n", + " Document(page_content='hello', metadata={}),\n", + " Document(page_content='world', metadata={})]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "997aaa8d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index 87ef899fdd7..9d665407f60 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -1,5 +1,6 @@ from langchain.retrievers.arxiv import ArxivRetriever from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever +from langchain.retrievers.bm25 import BM25Retriever from langchain.retrievers.chaindesk import ChaindeskRetriever from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever from langchain.retrievers.contextual_compression import ContextualCompressionRetriever @@ -51,6 +52,7 @@ __all__ = [ "SVMRetriever", "SelfQueryRetriever", "TFIDFRetriever", + "BM25Retriever", "TimeWeightedVectorStoreRetriever", "VespaRetriever", "WeaviateHybridSearchRetriever", diff --git a/langchain/retrievers/bm25.py b/langchain/retrievers/bm25.py new file mode 100644 index 00000000000..4487654140e --- /dev/null +++ b/langchain/retrievers/bm25.py @@ -0,0 +1,86 @@ +""" +BM25 retriever without Elasticsearch. +""" + + +from __future__ import annotations + +from typing import Any, Callable, Dict, Iterable, List, Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) +from langchain.schema import BaseRetriever, Document + + +def default_preprocessing_func(text: str) -> List[str]: + return text.split() + + +class BM25Retriever(BaseRetriever): + vectorizer: Any + docs: List[Document] + k: int = 4 + preprocess_func: Callable[[str], List[str]] = default_preprocessing_func + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @classmethod + def from_texts( + cls, + texts: Iterable[str], + metadatas: Optional[Iterable[dict]] = None, + bm25_params: Optional[Dict[str, Any]] = None, + preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, + **kwargs: Any, + ) -> BM25Retriever: + try: + from rank_bm25 import BM25Okapi + except ImportError: + raise ImportError( + "Could not import rank_bm25, please install with `pip install " + "rank_bm25`."
+ ) + + texts_processed = [preprocess_func(t) for t in texts] + bm25_params = bm25_params or {} + vectorizer = BM25Okapi(texts_processed, **bm25_params) + metadatas = metadatas or ({} for _ in texts) + docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)] + return cls( + vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs + ) + + @classmethod + def from_documents( + cls, + documents: Iterable[Document], + *, + bm25_params: Optional[Dict[str, Any]] = None, + preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, + **kwargs: Any, + ) -> BM25Retriever: + texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) + return cls.from_texts( + texts=texts, + bm25_params=bm25_params, + metadatas=metadatas, + preprocess_func=preprocess_func, + **kwargs, + ) + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + processed_query = self.preprocess_func(query) + return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k) + return return_docs + + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + raise NotImplementedError diff --git a/poetry.lock b/poetry.lock index 437eb410d36..3561062e374 100644 --- a/poetry.lock +++ b/poetry.lock @@ -641,12 +641,16 @@ category = "main" optional = true python-versions = ">=3.7" files = [ + {file = "awadb-0.3.6-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:d90318d2d388aa1bb740b0b7e641cb7da00e6ab5700ce97564163c88a1927ed4"}, {file = "awadb-0.3.6-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6154f73aab9996aefe8c8f8bf754f7182d109d6b60302c9f31666c7f50cc7aca"}, {file = "awadb-0.3.6-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:9d7e9dff353517595ecc8c9395a2367acdcfc83c68a64dd4785c8d366eed3f40"}, + {file = "awadb-0.3.6-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:6f6d10d1e885fa1d64eeb8ffda2de470c3a7508d57a9489213b8649bcddcd31e"}, {file = "awadb-0.3.6-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:475af75d2ffbbe970999d93fbabdf7281797390c66fe852f6a6989e706b90c94"}, {file = "awadb-0.3.6-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:304be1de63daec1555f0fe9de9a18cdf16a467687a35a6ccf3405cd400fefb48"}, {file = "awadb-0.3.6-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:176cc27d1afc4aad758515d5f8fb435f555c9ba827a9e84d6f28b1c6ac568965"}, + {file = "awadb-0.3.6-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:36138b754c990143d0314fd7a9293c96f7ba549860244bda728e3f51b73e0f6e"}, {file = "awadb-0.3.6-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:211d7f6b0f7c3c3d7518d424f0f3dfac5f45f9e5d7bbf397fdae861ff0dc46fd"}, + {file = "awadb-0.3.6-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:b1f9e9a7ba2fa58bce55fcca784d5b3e159712962aaee2156f6317c5993f4277"}, {file = "awadb-0.3.6-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:b935ab4ffaa3bcbcc9a381fce91ace5940143b527ffdb467dd4bc630cd94afab"}, ] @@ -9219,6 +9223,24 @@ packaging = "*" [package.extras] test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] +[[package]] +name = "rank-bm25" +version = "0.2.2" +description = "Various BM25 algorithms for document ranking" +category = "main" +optional = true +python-versions = "*" +files = [ + {file = "rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae"}, + {file = "rank_bm25-0.2.2.tar.gz", hash = 
"sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d"}, +] + +[package.dependencies] +numpy = "*" + +[package.extras] +dev = ["pytest"] + [[package]] name = "rapidfuzz" version = "3.1.1" @@ -12826,7 +12848,7 @@ clarifai = ["clarifai"] cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "zep-python"] +extended-testing = ["atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "mwparserfromhell", "mwxml", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "zep-python"] javascript = ["esprima"] llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers"] openai = ["openai", "tiktoken"] @@ -12836,4 +12858,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "eb68f4732f6230c96b48e5f6ba87494cdbd0f710886c644ecf714610e684c98b" +content-hash = "9115ed1af430453f1ae4a188df7c45933a53491c9d600f438c22d289048bf4a9" diff --git a/pyproject.toml b/pyproject.toml index dd979aa2a26..fe86a3f5d49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,6 +121,7 @@ rdflib = {version = "^6.3.2", optional = true} sympy = {version = "^1.12", optional = true} rapidfuzz = {version = "^3.1.1", optional = true} langsmith = "^0.0.7" +rank-bm25 = {version = "^0.2.2", optional = true} [tool.poetry.group.docs.dependencies] autodoc_pydantic = "^1.8.0" @@ -361,6 +362,7 @@ extended_testing = [ "openai", "sympy", "rapidfuzz", + "rank_bm25", ] [[tool.poetry.source]] diff --git a/tests/unit_tests/retrievers/test_bm25.py b/tests/unit_tests/retrievers/test_bm25.py new file mode 100644 index 00000000000..f021d708e7e --- /dev/null +++ b/tests/unit_tests/retrievers/test_bm25.py @@ -0,0 +1,34 @@ +import pytest + +from langchain.retrievers.bm25 import BM25Retriever +from langchain.schema import Document + + +@pytest.mark.requires("rank_bm25") +def test_from_texts() -> None: + input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."] + bm25_retriever = BM25Retriever.from_texts(texts=input_texts) + assert len(bm25_retriever.docs) == 3 + assert bm25_retriever.vectorizer.doc_len == [4, 5, 4] + + +@pytest.mark.requires("rank_bm25") +def test_from_texts_with_bm25_params() -> None: + input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."] + bm25_retriever = BM25Retriever.from_texts( + texts=input_texts, bm25_params={"epsilon": 10} + ) + # should count only multiple words (have, pan) + assert bm25_retriever.vectorizer.epsilon == 10 + + +@pytest.mark.requires("rank_bm25") +def test_from_documents() -> None: + input_docs = [ + Document(page_content="I have a pen."), + Document(page_content="Do you have a pen?"), + Document(page_content="I have a bag."), + ] + bm25_retriever = BM25Retriever.from_documents(documents=input_docs) + assert len(bm25_retriever.docs) == 3 + assert bm25_retriever.vectorizer.doc_len == [4, 
5, 4]
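
Usage sketch (not part of the diff above): the snippet below exercises the new BM25Retriever end to end, including the k, bm25_params, and preprocess_func options that the notebook does not demonstrate. It assumes langchain with this change and rank_bm25 are installed; the lower-casing tokenizer lowercase_tokens is a hypothetical helper added for illustration, and the k1/b values are just BM25Okapi's documented defaults spelled out explicitly.

from typing import List

from langchain.retrievers import BM25Retriever
from langchain.schema import Document


def lowercase_tokens(text: str) -> List[str]:
    # Hypothetical preprocess_func: lower-case before whitespace-splitting
    # (the default default_preprocessing_func only splits).
    return text.lower().split()


docs = [
    Document(page_content="I have a pen."),
    Document(page_content="Do you have a pen?"),
    Document(page_content="I have a bag."),
]

# bm25_params is forwarded to rank_bm25.BM25Okapi; k caps how many documents are returned.
retriever = BM25Retriever.from_documents(
    docs,
    bm25_params={"k1": 1.5, "b": 0.75},
    preprocess_func=lowercase_tokens,
    k=2,
)

for doc in retriever.get_relevant_documents("do you have a pen"):
    print(doc.page_content)

Because the query is run through the same preprocess_func as the indexed texts, a custom tokenizer only has to be supplied once at construction time.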