feat: Google Vertex AI Search Retriever - Add support for Website Data Stores (#11736)

- Only works for Data stores with Advanced Website Indexing
-
https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features
- Minor restructuring - Follow up to #10513
- Remove outdated docs (readded in
https://github.com/langchain-ai/langchain/pull/11620)
  - Move legacy class into new py file to clean up the directory
- Shouldn't cause backwards compatibility issues as the import works the
same way for users
This commit is contained in:
Holt Skinner
2023-10-19 01:41:48 -05:00
committed by GitHub
parent 4b6fdd7bf0
commit 2661dc94f3
7 changed files with 149 additions and 333 deletions

View File

@@ -32,10 +32,8 @@ from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.retrievers.google_cloud_documentai_warehouse import (
GoogleDocumentAIWarehouseRetriever,
)
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.retrievers.google_vertex_ai_search import (
GoogleCloudEnterpriseSearchRetriever,
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
)

View File

@@ -1,22 +0,0 @@
"""Retriever wrapper for Google Vertex AI Search.
DEPRECATED: Maintained for backwards compatibility.
"""
from typing import Any
from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
"""`Google Vertex Search API` retriever alias for backwards compatibility.
DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
"""
def __init__(self, **data: Any):
import warnings
warnings.warn(
"GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever", # noqa: E501
DeprecationWarning,
)
super().__init__(**data)

View File

@@ -25,10 +25,21 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
"""Vertex AI Search data store ID."""
location_id: str = "global"
"""Vertex AI Search data store location."""
serving_config_id: str = "default_config"
"""Vertex AI Search serving config ID."""
credentials: Any = None
"""The default custom credentials (google.auth.credentials.Credentials) to use
when making API calls. If not provided, credentials will be ascertained from
the environment."""
engine_data_type: int = Field(default=0, ge=0, le=2)
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
2 - Website data (with Advanced Website Indexing)
"""
_serving_config: str
"""Full path of serving config."""
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
@@ -144,6 +155,47 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
return documents
def _convert_website_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
derived_struct_data = document_dict.get("derived_struct_data")
if not derived_struct_data:
continue
doc_metadata = document_dict.get("struct_data", {})
doc_metadata["id"] = document_dict["id"]
doc_metadata["source"] = derived_struct_data.get("link", "")
chunk_type = "extractive_answers"
if chunk_type not in derived_struct_data:
continue
for chunk in derived_struct_data[chunk_type]:
documents.append(
Document(
page_content=chunk.get("content", ""), metadata=doc_metadata
)
)
if not documents:
print(
f"No {chunk_type} could be found.\n"
"Make sure that your data store is using Advanced Website Indexing.\n"
"https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing" # noqa: E501
)
return documents
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
"""`Google Vertex AI Search` retriever.
@@ -153,8 +205,6 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
"""
serving_config_id: str = "default_config"
"""Vertex AI Search serving config ID."""
filter: Optional[str] = None
"""Filter expression."""
get_extractive_answers: bool = False
@@ -188,15 +238,7 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
Search will be based on the corrected query if found.
"""
# TODO: Add extra data type handling for type website
engine_data_type: int = Field(default=0, ge=0, le=1)
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
"""
_client: SearchServiceClient
_serving_config: str
class Config:
"""Configuration for this pydantic object."""
@@ -260,11 +302,16 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
)
elif self.engine_data_type == 1:
content_search_spec = None
elif self.engine_data_type == 2:
content_search_spec = SearchRequest.ContentSearchSpec(
extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_answer_count=self.max_extractive_answer_count,
)
)
else:
# TODO: Add extra data type handling for type website
raise NotImplementedError(
"Only engine data type 0 (Unstructured) or 1 (Structured)"
+ " are supported currently."
"Only data store type 0 (Unstructured), 1 (Structured),"
"or 2 (Website with Advanced Indexing) are supported currently."
+ f" Got {self.engine_data_type}"
)
@@ -305,11 +352,12 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
)
elif self.engine_data_type == 1:
documents = self._convert_structured_search_response(response.results)
elif self.engine_data_type == 2:
documents = self._convert_website_search_response(response.results)
else:
# TODO: Add extra data type handling for type website
raise NotImplementedError(
"Only engine data type 0 (Unstructured) or 1 (Structured)"
+ " are supported currently."
"Only data store type 0 (Unstructured), 1 (Structured),"
"or 2 (Website with Advanced Indexing) are supported currently."
+ f" Got {self.engine_data_type}"
)
@@ -321,6 +369,9 @@ class GoogleVertexAIMultiTurnSearchRetriever(
):
"""`Google Vertex AI Search` retriever for multi-turn conversations."""
conversation_id: str = "-"
"""Vertex AI Search Conversation ID."""
_client: ConversationalSearchServiceClient
class Config:
@@ -340,6 +391,20 @@ class GoogleVertexAIMultiTurnSearchRetriever(
credentials=self.credentials, client_options=self.client_options
)
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.data_store_id,
serving_config=self.serving_config_id,
)
if self.engine_data_type == 1:
raise NotImplementedError(
"Data store type 1 (Structured)"
"is not currently supported for multi-turn search."
+ f" Got {self.engine_data_type}"
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
@@ -351,11 +416,35 @@ class GoogleVertexAIMultiTurnSearchRetriever(
request = ConverseConversationRequest(
name=self._client.conversation_path(
self.project_id, self.location_id, self.data_store_id, "-"
self.project_id,
self.location_id,
self.data_store_id,
self.conversation_id,
),
serving_config=self._serving_config,
query=TextInput(input=query),
)
response = self._client.converse_conversation(request)
if self.engine_data_type == 2:
return self._convert_website_search_response(response.search_results)
return self._convert_unstructured_search_response(
response.search_results, "extractive_answers"
)
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
"""`Google Vertex Search API` retriever alias for backwards compatibility.
DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
"""
def __init__(self, **data: Any):
import warnings
warnings.warn(
"GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever", # noqa: E501
DeprecationWarning,
)
super().__init__(**data)

View File

@@ -15,10 +15,8 @@ import os
import pytest
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.retrievers.google_vertex_ai_search import (
GoogleCloudEnterpriseSearchRetriever,
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
)