From 15be439719183df0c323d92791bdd5dca8b5590b Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 15 May 2024 13:08:52 -0700 Subject: [PATCH] Harrison/move flashrank rerank (#21448) third party integration, should be in community --- .../document_compressors/__init__.py | 6 +- .../document_compressors/flashrank_rerank.py | 76 ++++++++++++++++ .../document_compressors/test_imports.py | 7 +- .../document_compressors/__init__.py | 19 +++- .../document_compressors/flashrank_rerank.py | 89 ++++--------------- 5 files changed, 122 insertions(+), 75 deletions(-) create mode 100644 libs/community/langchain_community/document_compressors/flashrank_rerank.py diff --git a/libs/community/langchain_community/document_compressors/__init__.py b/libs/community/langchain_community/document_compressors/__init__.py index 1dffb0a032d..b3241a9b114 100644 --- a/libs/community/langchain_community/document_compressors/__init__.py +++ b/libs/community/langchain_community/document_compressors/__init__.py @@ -2,6 +2,9 @@ import importlib from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from langchain_community.document_compressors.flashrank_rerank import ( + FlashrankRerank, + ) from langchain_community.document_compressors.jina_rerank import ( JinaRerank, # noqa: F401 ) @@ -12,12 +15,13 @@ if TYPE_CHECKING: OpenVINOReranker, ) -__all__ = ["LLMLinguaCompressor", "OpenVINOReranker"] +__all__ = ["LLMLinguaCompressor", "OpenVINOReranker", "FlashrankRerank"] _module_lookup = { "LLMLinguaCompressor": "langchain_community.document_compressors.llmlingua_filter", "OpenVINOReranker": "langchain_community.document_compressors.openvino_rerank", "JinaRerank": "langchain_community.document_compressors.jina_rerank", + "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank", } diff --git a/libs/community/langchain_community/document_compressors/flashrank_rerank.py b/libs/community/langchain_community/document_compressors/flashrank_rerank.py new file mode 100644 index 00000000000..dd3307b43e6 --- /dev/null +++ b/libs/community/langchain_community/document_compressors/flashrank_rerank.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Sequence + +from langchain_core.callbacks.manager import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +from langchain_core.pydantic_v1 import Extra, root_validator + +if TYPE_CHECKING: + from flashrank import Ranker, RerankRequest +else: + # Avoid pydantic annotation issues when actually instantiating + # while keeping this import optional + try: + from flashrank import Ranker, RerankRequest + except ImportError: + pass + +DEFAULT_MODEL_NAME = "ms-marco-MultiBERT-L-12" + + +class FlashrankRerank(BaseDocumentCompressor): + """Document compressor using Flashrank interface.""" + + client: Ranker + """Flashrank client to use for compressing documents""" + top_n: int = 3 + """Number of documents to return.""" + model: Optional[str] = None + """Model to use for reranking.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator(pre=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + from flashrank import Ranker + except ImportError: + raise ImportError( + "Could not import flashrank python package. " + "Please install it with `pip install flashrank`." + ) + + values["model"] = values.get("model", DEFAULT_MODEL_NAME) + values["client"] = Ranker(model_name=values["model"]) + return values + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, + ) -> Sequence[Document]: + passages = [ + {"id": i, "text": doc.page_content, "meta": doc.metadata} + for i, doc in enumerate(documents) + ] + + rerank_request = RerankRequest(query=query, passages=passages) + rerank_response = self.client.rerank(rerank_request)[: self.top_n] + final_results = [] + + for r in rerank_response: + metadata = r["meta"] + metadata["relevance_score"] = r["score"] + doc = Document( + page_content=r["text"], + metadata=metadata, + ) + final_results.append(doc) + return final_results diff --git a/libs/community/tests/unit_tests/document_compressors/test_imports.py b/libs/community/tests/unit_tests/document_compressors/test_imports.py index d6451cdd88e..c0b08904bf9 100644 --- a/libs/community/tests/unit_tests/document_compressors/test_imports.py +++ b/libs/community/tests/unit_tests/document_compressors/test_imports.py @@ -1,6 +1,11 @@ from langchain_community.document_compressors import __all__, _module_lookup -EXPECTED_ALL = ["LLMLinguaCompressor", "OpenVINOReranker", "JinaRerank"] +EXPECTED_ALL = [ + "LLMLinguaCompressor", + "OpenVINOReranker", + "JinaRerank", + "FlashrankRerank", +] def test_all_imports() -> None: diff --git a/libs/langchain/langchain/retrievers/document_compressors/__init__.py b/libs/langchain/langchain/retrievers/document_compressors/__init__.py index 6a32c453a3c..03f66977113 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/__init__.py +++ b/libs/langchain/langchain/retrievers/document_compressors/__init__.py @@ -1,3 +1,6 @@ +import importlib +from typing import Any + from langchain.retrievers.document_compressors.base import DocumentCompressorPipeline from langchain.retrievers.document_compressors.chain_extract import ( LLMChainExtractor, @@ -12,7 +15,18 @@ from langchain.retrievers.document_compressors.cross_encoder_rerank import ( from langchain.retrievers.document_compressors.embeddings_filter import ( EmbeddingsFilter, ) -from langchain.retrievers.document_compressors.flashrank_rerank import FlashrankRerank + +_module_lookup = { + "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank", +} + + +def __getattr__(name: str) -> Any: + if name in _module_lookup: + module = importlib.import_module(_module_lookup[name]) + return getattr(module, name) + raise AttributeError(f"module {__name__} has no attribute {name}") + __all__ = [ "DocumentCompressorPipeline", @@ -21,5 +35,4 @@ __all__ = [ "LLMChainFilter", "CohereRerank", "CrossEncoderReranker", - "FlashrankRerank", -] +] + list(_module_lookup.keys()) diff --git a/libs/langchain/langchain/retrievers/document_compressors/flashrank_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/flashrank_rerank.py index f89cfa344f6..f2196fa6250 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/flashrank_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/flashrank_rerank.py @@ -1,78 +1,27 @@ -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import TYPE_CHECKING, Dict, Optional, Sequence - -from langchain_core.callbacks.manager import Callbacks -from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Extra, root_validator - -from langchain.retrievers.document_compressors.base import BaseDocumentCompressor +from langchain._api import create_importer if TYPE_CHECKING: - from flashrank import Ranker, RerankRequest -else: - # Avoid pydantic annotation issues when actually instantiating - # while keeping this import optional - try: - from flashrank import Ranker, RerankRequest - except ImportError: - pass + from langchain_community.document_compressors.flashrank_rerank import ( + FlashrankRerank, + ) -DEFAULT_MODEL_NAME = "ms-marco-MultiBERT-L-12" +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank" +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class FlashrankRerank(BaseDocumentCompressor): - """Document compressor using Flashrank interface.""" +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - client: Ranker - """Flashrank client to use for compressing documents""" - top_n: int = 3 - """Number of documents to return.""" - model: Optional[str] = None - """Model to use for reranking.""" - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @root_validator(pre=True) - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - try: - from flashrank import Ranker - except ImportError: - raise ImportError( - "Could not import flashrank python package. " - "Please install it with `pip install flashrank`." - ) - - values["model"] = values.get("model", DEFAULT_MODEL_NAME) - values["client"] = Ranker(model_name=values["model"]) - return values - - def compress_documents( - self, - documents: Sequence[Document], - query: str, - callbacks: Optional[Callbacks] = None, - ) -> Sequence[Document]: - passages = [ - {"id": i, "text": doc.page_content, "meta": doc.metadata} - for i, doc in enumerate(documents) - ] - - rerank_request = RerankRequest(query=query, passages=passages) - rerank_response = self.client.rerank(rerank_request)[: self.top_n] - final_results = [] - - for r in rerank_response: - metadata = r["meta"] - metadata["relevance_score"] = r["score"] - doc = Document( - page_content=r["text"], - metadata=metadata, - ) - final_results.append(doc) - return final_results +__all__ = [ + "FlashrankRerank", +]