mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-19 00:58:32 +00:00
feat: Google Vertex AI Search Retriever - Add support for Website Data Stores (#11736)
- Only works for data stores with Advanced Website Indexing: https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features
- Minor restructuring
- Follow-up to #10513
- Remove outdated docs (re-added in https://github.com/langchain-ai/langchain/pull/11620)
- Move legacy class into new py file to clean up the directory
- Shouldn't cause backwards-compatibility issues, as the import works the same way for users
This commit is contained in:
@@ -32,10 +32,8 @@ from langchain.retrievers.ensemble import EnsembleRetriever
|
||||
from langchain.retrievers.google_cloud_documentai_warehouse import (
|
||||
GoogleDocumentAIWarehouseRetriever,
|
||||
)
|
||||
from langchain.retrievers.google_cloud_enterprise_search import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
)
|
||||
from langchain.retrievers.google_vertex_ai_search import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
GoogleVertexAIMultiTurnSearchRetriever,
|
||||
GoogleVertexAISearchRetriever,
|
||||
)
|
||||
|
@@ -1,22 +0,0 @@
|
||||
"""Retriever wrapper for Google Vertex AI Search.
|
||||
DEPRECATED: Maintained for backwards compatibility.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever
|
||||
|
||||
|
||||
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
    """`Google Vertex Search API` retriever alias for backwards compatibility.

    DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
    """

    def __init__(self, **data: Any):
        """Emit a ``DeprecationWarning`` and forward ``data`` to the parent.

        Args:
            data: Keyword arguments passed unchanged to
                ``GoogleVertexAISearchRetriever.__init__``.
        """
        import warnings

        warnings.warn(
            "GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever",  # noqa: E501
            DeprecationWarning,
            # Attribute the warning to the code instantiating this class, not
            # to this constructor, so the caller's file/line is reported and
            # the default warning filters can surface it in user scripts.
            stacklevel=2,
        )

        super().__init__(**data)
|
@@ -25,10 +25,21 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
|
||||
"""Vertex AI Search data store ID."""
|
||||
location_id: str = "global"
|
||||
"""Vertex AI Search data store location."""
|
||||
serving_config_id: str = "default_config"
|
||||
"""Vertex AI Search serving config ID."""
|
||||
credentials: Any = None
|
||||
"""The default custom credentials (google.auth.credentials.Credentials) to use
|
||||
when making API calls. If not provided, credentials will be ascertained from
|
||||
the environment."""
|
||||
engine_data_type: int = Field(default=0, ge=0, le=2)
|
||||
""" Defines the Vertex AI Search data type
|
||||
0 - Unstructured data
|
||||
1 - Structured data
|
||||
2 - Website data (with Advanced Website Indexing)
|
||||
"""
|
||||
|
||||
_serving_config: str
|
||||
"""Full path of serving config."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@@ -144,6 +155,47 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
|
||||
|
||||
return documents
|
||||
|
||||
def _convert_website_search_response(
    self, results: Sequence[SearchResult]
) -> List[Document]:
    """Converts a sequence of search results to a list of LangChain documents.

    One ``Document`` is produced per entry under ``extractive_answers`` in a
    result's ``derived_struct_data``. Results without ``derived_struct_data``
    or without ``extractive_answers`` are skipped.

    Args:
        results: Search results returned by the Vertex AI Search API.

    Returns:
        The converted documents; an empty list when no extractive answers
        were found (a warning is logged in that case).
    """
    import logging

    from google.protobuf.json_format import MessageToDict

    chunk_type = "extractive_answers"
    documents: List[Document] = []

    for result in results:
        document_dict = MessageToDict(
            result.document._pb, preserving_proto_field_name=True
        )
        derived_struct_data = document_dict.get("derived_struct_data")
        if not derived_struct_data:
            continue

        if chunk_type not in derived_struct_data:
            continue

        doc_metadata = document_dict.get("struct_data", {})
        doc_metadata["id"] = document_dict["id"]
        doc_metadata["source"] = derived_struct_data.get("link", "")

        for chunk in derived_struct_data[chunk_type]:
            documents.append(
                Document(
                    page_content=chunk.get("content", ""),
                    # Give every document its own metadata dict so that
                    # mutating one document's metadata cannot leak into the
                    # other chunks of the same search result.
                    metadata=dict(doc_metadata),
                )
            )

    if not documents:
        # Library code should not write to stdout; report through logging so
        # applications can route or silence the message.
        logging.getLogger(__name__).warning(
            "No %s could be found.\n"
            "Make sure that your data store is using Advanced Website Indexing.\n"
            "https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing",  # noqa: E501
            chunk_type,
        )

    return documents
|
||||
|
||||
|
||||
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
|
||||
"""`Google Vertex AI Search` retriever.
|
||||
@@ -153,8 +205,6 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
|
||||
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
|
||||
"""
|
||||
|
||||
serving_config_id: str = "default_config"
|
||||
"""Vertex AI Search serving config ID."""
|
||||
filter: Optional[str] = None
|
||||
"""Filter expression."""
|
||||
get_extractive_answers: bool = False
|
||||
@@ -188,15 +238,7 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
|
||||
Search will be based on the corrected query if found.
|
||||
"""
|
||||
|
||||
# TODO: Add extra data type handling for type website
|
||||
engine_data_type: int = Field(default=0, ge=0, le=1)
|
||||
""" Defines the Vertex AI Search data type
|
||||
0 - Unstructured data
|
||||
1 - Structured data
|
||||
"""
|
||||
|
||||
_client: SearchServiceClient
|
||||
_serving_config: str
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -260,11 +302,16 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
|
||||
)
|
||||
elif self.engine_data_type == 1:
|
||||
content_search_spec = None
|
||||
elif self.engine_data_type == 2:
|
||||
content_search_spec = SearchRequest.ContentSearchSpec(
|
||||
extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
|
||||
max_extractive_answer_count=self.max_extractive_answer_count,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# TODO: Add extra data type handling for type website
|
||||
raise NotImplementedError(
|
||||
"Only engine data type 0 (Unstructured) or 1 (Structured)"
|
||||
+ " are supported currently."
|
||||
"Only data store type 0 (Unstructured), 1 (Structured),"
|
||||
"or 2 (Website with Advanced Indexing) are supported currently."
|
||||
+ f" Got {self.engine_data_type}"
|
||||
)
|
||||
|
||||
@@ -305,11 +352,12 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
|
||||
)
|
||||
elif self.engine_data_type == 1:
|
||||
documents = self._convert_structured_search_response(response.results)
|
||||
elif self.engine_data_type == 2:
|
||||
documents = self._convert_website_search_response(response.results)
|
||||
else:
|
||||
# TODO: Add extra data type handling for type website
|
||||
raise NotImplementedError(
|
||||
"Only engine data type 0 (Unstructured) or 1 (Structured)"
|
||||
+ " are supported currently."
|
||||
"Only data store type 0 (Unstructured), 1 (Structured),"
|
||||
"or 2 (Website with Advanced Indexing) are supported currently."
|
||||
+ f" Got {self.engine_data_type}"
|
||||
)
|
||||
|
||||
@@ -321,6 +369,9 @@ class GoogleVertexAIMultiTurnSearchRetriever(
|
||||
):
|
||||
"""`Google Vertex AI Search` retriever for multi-turn conversations."""
|
||||
|
||||
conversation_id: str = "-"
|
||||
"""Vertex AI Search Conversation ID."""
|
||||
|
||||
_client: ConversationalSearchServiceClient
|
||||
|
||||
class Config:
|
||||
@@ -340,6 +391,20 @@ class GoogleVertexAIMultiTurnSearchRetriever(
|
||||
credentials=self.credentials, client_options=self.client_options
|
||||
)
|
||||
|
||||
self._serving_config = self._client.serving_config_path(
|
||||
project=self.project_id,
|
||||
location=self.location_id,
|
||||
data_store=self.data_store_id,
|
||||
serving_config=self.serving_config_id,
|
||||
)
|
||||
|
||||
if self.engine_data_type == 1:
|
||||
raise NotImplementedError(
|
||||
"Data store type 1 (Structured)"
|
||||
"is not currently supported for multi-turn search."
|
||||
+ f" Got {self.engine_data_type}"
|
||||
)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
@@ -351,11 +416,35 @@ class GoogleVertexAIMultiTurnSearchRetriever(
|
||||
|
||||
request = ConverseConversationRequest(
|
||||
name=self._client.conversation_path(
|
||||
self.project_id, self.location_id, self.data_store_id, "-"
|
||||
self.project_id,
|
||||
self.location_id,
|
||||
self.data_store_id,
|
||||
self.conversation_id,
|
||||
),
|
||||
serving_config=self._serving_config,
|
||||
query=TextInput(input=query),
|
||||
)
|
||||
response = self._client.converse_conversation(request)
|
||||
|
||||
if self.engine_data_type == 2:
|
||||
return self._convert_website_search_response(response.search_results)
|
||||
|
||||
return self._convert_unstructured_search_response(
|
||||
response.search_results, "extractive_answers"
|
||||
)
|
||||
|
||||
|
||||
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
    """`Google Vertex Search API` retriever alias for backwards compatibility.

    DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
    """

    def __init__(self, **data: Any):
        """Emit a ``DeprecationWarning`` and forward ``data`` to the parent.

        Args:
            data: Keyword arguments passed unchanged to
                ``GoogleVertexAISearchRetriever.__init__``.
        """
        import warnings

        warnings.warn(
            "GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever",  # noqa: E501
            DeprecationWarning,
            # Attribute the warning to the code instantiating this class, not
            # to this constructor, so the caller's file/line is reported and
            # the default warning filters can surface it in user scripts.
            stacklevel=2,
        )

        super().__init__(**data)
|
||||
|
@@ -15,10 +15,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.retrievers.google_cloud_enterprise_search import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
)
|
||||
from langchain.retrievers.google_vertex_ai_search import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
GoogleVertexAIMultiTurnSearchRetriever,
|
||||
GoogleVertexAISearchRetriever,
|
||||
)
|
||||
|
Reference in New Issue
Block a user