From 2aba9ab47ec0a91eb124723f70ba05360fae7eed Mon Sep 17 00:00:00 2001
From: Leonid Kuligin
Date: Thu, 12 Oct 2023 04:08:53 +0200
Subject: [PATCH] Retriever based on GCP DocAI Warehouse (#11400)

- **Description:** implements a retriever on top of DocAI Warehouse (to
  interact with existing enterprise documents)
  https://cloud.google.com/document-ai-warehouse?hl=en
- **Issue:** new functionality

@baskaryan

---------

Co-authored-by: Bagatur
---
 docs/docs/integrations/platforms/google.mdx   |  17 +++
 .../langchain/retrievers/__init__.py          |   4 +
 .../google_cloud_documentai_warehouse.py      | 118 ++++++++++++++++++
 .../test_google_docai_warehouse_retriever.py  |  25 ++++
 4 files changed, 164 insertions(+)
 create mode 100644 libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py
 create mode 100644 libs/langchain/tests/integration_tests/retrievers/test_google_docai_warehouse_retriever.py

diff --git a/docs/docs/integrations/platforms/google.mdx b/docs/docs/integrations/platforms/google.mdx
index f39c266a6c1..bcfb812686e 100644
--- a/docs/docs/integrations/platforms/google.mdx
+++ b/docs/docs/integrations/platforms/google.mdx
@@ -152,6 +152,23 @@ See a [usage example](/docs/integrations/retrievers/google_vertex_ai_search).
 from langchain.retrievers import GoogleVertexAISearchRetriever
 ```
 
+### Document AI Warehouse
+> [Google Cloud Document AI Warehouse](https://cloud.google.com/document-ai-warehouse)
+> allows enterprises to search, store, govern, and manage documents and their AI-extracted
+> data and metadata in a single platform.
+> Documents should be uploaded outside of LangChain.
+
+```python
+from langchain.retrievers import GoogleDocumentAIWarehouseRetriever
+docai_wh_retriever = GoogleDocumentAIWarehouseRetriever(
+    project_number=...
+)
+query = ...
+documents = docai_wh_retriever.get_relevant_documents(
+    query, user_ldap=...
+)
+```
+
 ## Tools
 
 ### Google Search
diff --git a/libs/langchain/langchain/retrievers/__init__.py b/libs/langchain/langchain/retrievers/__init__.py
index 9abdb145fd2..e1c3fa3da86 100644
--- a/libs/langchain/langchain/retrievers/__init__.py
+++ b/libs/langchain/langchain/retrievers/__init__.py
@@ -28,6 +28,9 @@ from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
 from langchain.retrievers.docarray import DocArrayRetriever
 from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
 from langchain.retrievers.ensemble import EnsembleRetriever
+from langchain.retrievers.google_cloud_documentai_warehouse import (
+    GoogleDocumentAIWarehouseRetriever,
+)
 from langchain.retrievers.google_cloud_enterprise_search import (
     GoogleCloudEnterpriseSearchRetriever,
 )
@@ -74,6 +77,7 @@ __all__ = [
     "ContextualCompressionRetriever",
     "ChaindeskRetriever",
     "ElasticSearchBM25Retriever",
+    "GoogleDocumentAIWarehouseRetriever",
     "GoogleCloudEnterpriseSearchRetriever",
     "GoogleVertexAISearchRetriever",
     "KayAiRetriever",
diff --git a/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py b/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py
new file mode 100644
index 00000000000..760f8362daa
--- /dev/null
+++ b/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py
@@ -0,0 +1,118 @@
+"""Retriever wrapper for Google Cloud Document AI Warehouse."""
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
+from langchain.docstore.document import Document
+from langchain.pydantic_v1 import root_validator
+from langchain.schema import BaseRetriever
+from langchain.utils import get_from_dict_or_env
+
+if TYPE_CHECKING:
+    from google.cloud.contentwarehouse_v1 import (
+        DocumentServiceClient,
+        RequestMetadata,
+        SearchDocumentsRequest,
+    )
+    from google.cloud.contentwarehouse_v1.services.document_service.pagers import (
+        SearchDocumentsPager,
+    )
+
+
+class GoogleDocumentAIWarehouseRetriever(BaseRetriever):
+    """A retriever based on Document AI Warehouse.
+
+    Documents should be created and uploaded in a separate flow;
+    this retriever only uses the provided Document AI Warehouse
+    schema_id to search for relevant documents.
+
+    More info: https://cloud.google.com/document-ai-warehouse.
+    """
+
+    location: str = "us"
+    "GCP location where Document AI Warehouse is located."
+    project_number: str
+    "GCP project number; should contain digits only."
+    schema_id: Optional[str] = None
+    "DocAI Warehouse schema to query against. If nothing is provided, all documents "
+    "in the project will be searched."
+    qa_size_limit: int = 5
+    "The limit on the number of documents returned."
+    client: "DocumentServiceClient" = None  #: :meta private:
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validates the environment."""
+        try:  # noqa: F401
+            from google.cloud.contentwarehouse_v1 import (
+                DocumentServiceClient,
+            )
+        except ImportError as exc:
+            raise ImportError(
+                "google.cloud.contentwarehouse is not installed. "
+ "Please install it with pip install google-cloud-contentwarehouse" + ) from exc + + values["project_number"] = get_from_dict_or_env( + values, "project_number", "PROJECT_NUMBER" + ) + values["client"] = DocumentServiceClient() + return values + + def _prepare_request_metadata(self, user_ldap: str) -> "RequestMetadata": + from google.cloud.contentwarehouse_v1 import RequestMetadata, UserInfo + + user_info = UserInfo(id=f"user:{user_ldap}") + return RequestMetadata(user_info=user_info) + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any + ) -> List[Document]: + request = self._prepare_search_request(query, **kwargs) + response = self.client.search_documents(request=request) + return self._parse_search_response(response=response) + + def _prepare_search_request( + self, query: str, **kwargs: Any + ) -> "SearchDocumentsRequest": + from google.cloud.contentwarehouse_v1 import ( + DocumentQuery, + SearchDocumentsRequest, + ) + + try: + user_ldap = kwargs["user_ldap"] + except KeyError: + raise ValueError("Argument user_ldap should be provided!") + + request_metadata = self._prepare_request_metadata(user_ldap=user_ldap) + schemas = [] + if self.schema_id: + schemas.append( + self.client.document_schema_path( + project=self.project_number, + location=self.location, + document_schema=self.schema_id, + ) + ) + return SearchDocumentsRequest( + parent=self.client.common_location_path(self.project_number, self.location), + request_metadata=request_metadata, + document_query=DocumentQuery( + query=query, is_nl_query=True, document_schema_names=schemas + ), + qa_size_limit=self.qa_size_limit, + ) + + def _parse_search_response( + self, response: "SearchDocumentsPager" + ) -> List[Document]: + documents = [] + for doc in response.matching_documents: + metadata = { + "title": doc.document.title, + "source": doc.document.raw_document_path, + } + documents.append( + Document(page_content=doc.search_text_snippet, metadata=metadata) + ) + return documents diff --git a/libs/langchain/tests/integration_tests/retrievers/test_google_docai_warehoure_retriever.py b/libs/langchain/tests/integration_tests/retrievers/test_google_docai_warehoure_retriever.py new file mode 100644 index 00000000000..490c22e3668 --- /dev/null +++ b/libs/langchain/tests/integration_tests/retrievers/test_google_docai_warehoure_retriever.py @@ -0,0 +1,25 @@ +"""Test Google Cloud Document AI Warehouse retriever.""" +import os + +from langchain.retrievers import GoogleDocumentAIWarehouseRetriever +from langchain.schema import Document + + +def test_google_documentai_warehoure_retriever() -> None: + """In order to run this test, you should provide a project_id and user_ldap. + + Example: + export USER_LDAP=... + export PROJECT_NUMBER=... + """ + project_number = os.environ["PROJECT_NUMBER"] + user_ldap = os.environ["USER_LDAP"] + docai_wh_retriever = GoogleDocumentAIWarehouseRetriever( + project_number=project_number + ) + documents = docai_wh_retriever.get_relevant_documents( + "What are Alphabet's Other Bets?", user_ldap=user_ldap + ) + assert len(documents) > 0 + for doc in documents: + assert isinstance(doc, Document)