fix: Update Google Cloud Enterprise Search to Vertex AI Search (#10513)

- Description: Google Cloud Enterprise Search was renamed to Vertex AI
Search
-
https://cloud.google.com/blog/products/ai-machine-learning/vertex-ai-search-and-conversation-is-now-generally-available
- This PR updates the documentation and Retriever class to use the new
terminology.
- Changed retriever class from `GoogleCloudEnterpriseSearchRetriever` to
`GoogleVertexAISearchRetriever`
- Updated documentation to specify that `extractive_segments` requires
the new [Enterprise
edition](https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#enterprise-features)
to be enabled.
  - Fixed spelling errors in documentation.
- Changed the Retriever parameter from `search_engine_id` to
`data_store_id`
- When this retriever was originally implemented, there was no
distinction between a data store and search engine, but now these have
been split.
- Fixed an issue blocking some users where the `api_endpoint` could not be set.
This commit is contained in:
Holt Skinner
2023-10-05 12:47:47 -05:00
committed by GitHub
parent 1d678f805f
commit 9f73fec057
11 changed files with 694 additions and 577 deletions

View File

@@ -30,6 +30,9 @@ from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.retrievers.google_vertex_ai_search import (
GoogleVertexAISearchRetriever,
)
from langchain.retrievers.kay import KayAiRetriever
from langchain.retrievers.kendra import AmazonKendraRetriever
from langchain.retrievers.knn import KNNRetriever
@@ -70,6 +73,7 @@ __all__ = [
"ChaindeskRetriever",
"ElasticSearchBM25Retriever",
"GoogleCloudEnterpriseSearchRetriever",
"GoogleVertexAISearchRetriever",
"KayAiRetriever",
"KNNRetriever",
"LlamaIndexGraphRetriever",

View File

@@ -1,275 +1,22 @@
"""Retriever wrapper for Google Cloud Enterprise Search on Gen App Builder."""
from __future__ import annotations
"""Retriever wrapper for Google Vertex AI Search.
DEPRECATED: Maintained for backwards compatibility.
"""
from typing import Any
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:
from google.cloud.discoveryengine_v1beta import (
SearchRequest,
SearchResult,
SearchServiceClient,
)
from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever
class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
"""`Google Cloud Enterprise Search API` retriever.
For a detailed explanation of the Enterprise Search concepts
and configuration parameters, refer to the product documentation.
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
"""`Google Vertex Search API` retriever alias for backwards compatibility.
DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
"""
project_id: str
"""Google Cloud Project ID."""
search_engine_id: str
"""Enterprise Search engine ID."""
serving_config_id: str = "default_config"
"""Enterprise Search serving config ID."""
location_id: str = "global"
"""Enterprise Search engine location."""
filter: Optional[str] = None
"""Filter expression."""
get_extractive_answers: bool = False
"""If True return Extractive Answers, otherwise return Extractive Segments."""
max_documents: int = Field(default=5, ge=1, le=100)
"""The maximum number of documents to return."""
max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
"""The maximum number of extractive answers returned in each search result.
At most 5 answers will be returned for each SearchResult.
"""
max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
"""The maximum number of extractive segments returned in each search result.
Currently one segment will be returned for each SearchResult.
"""
query_expansion_condition: int = Field(default=1, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified query expansion condition. In this case, server behavior defaults
to disabled
1 - Disabled query expansion. Only the exact search query is used, even if
SearchResponse.total_size is zero.
2 - Automatic query expansion built by the Search API.
"""
spell_correction_mode: int = Field(default=2, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified spell correction mode. In this case, server behavior defaults
to auto.
1 - Suggestion only. Search API will try to find a spell suggestion if there is any
and put in the `SearchResponse.corrected_query`.
The spell suggestion will not be used as the search query.
2 - Automatic spell correction built by the Search API.
Search will be based on the corrected query if found.
"""
credentials: Any = None
"""The default custom credentials (google.auth.credentials.Credentials) to use
when making API calls. If not provided, credentials will be ascertained from
the environment."""
def __init__(self, **data: Any):
import warnings
# TODO: Add extra data type handling for type website
engine_data_type: int = Field(default=0, ge=0, le=1)
""" Defines the enterprise search data type
0 - Unstructured data
1 - Structured data
"""
_client: SearchServiceClient
_serving_config: str
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
underscore_attrs_are_private = True
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validates the environment."""
try:
from google.cloud import discoveryengine_v1beta # noqa: F401
except ImportError as exc:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine"
) from exc
try:
from google.api_core.exceptions import InvalidArgument # noqa: F401
except ImportError as exc:
raise ImportError(
"google.api_core.exceptions is not installed. "
"Please install it with pip install google-api-core"
) from exc
values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")
values["search_engine_id"] = get_from_dict_or_env(
values, "search_engine_id", "SEARCH_ENGINE_ID"
warnings.warn(
"GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever", # noqa: E501
DeprecationWarning,
)
return values
def __init__(self, **data: Any) -> None:
"""Initializes private fields."""
try:
from google.cloud.discoveryengine_v1beta import SearchServiceClient
except ImportError:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine"
)
super().__init__(**data)
self._client = SearchServiceClient(credentials=self.credentials)
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.search_engine_id,
serving_config=self.serving_config_id,
)
def _convert_unstructured_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
derived_struct_data = document_dict.get("derived_struct_data")
if not derived_struct_data:
continue
doc_metadata = document_dict.get("struct_data", {})
doc_metadata["id"] = document_dict["id"]
chunk_type = (
"extractive_answers"
if self.get_extractive_answers
else "extractive_segments"
)
if chunk_type not in derived_struct_data:
continue
for chunk in derived_struct_data[chunk_type]:
doc_metadata["source"] = derived_struct_data.get("link", "")
if chunk_type == "extractive_answers":
doc_metadata["source"] += f":{chunk.get('pageNumber', '')}"
documents.append(
Document(
page_content=chunk.get("content", ""), metadata=doc_metadata
)
)
return documents
def _convert_structured_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
import json
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
documents.append(
Document(
page_content=json.dumps(document_dict.get("struct_data", {})),
metadata={"id": document_dict["id"], "name": document_dict["name"]},
)
)
return documents
def _create_search_request(self, query: str) -> SearchRequest:
"""Prepares a SearchRequest object."""
from google.cloud.discoveryengine_v1beta import SearchRequest
query_expansion_spec = SearchRequest.QueryExpansionSpec(
condition=self.query_expansion_condition,
)
spell_correction_spec = SearchRequest.SpellCorrectionSpec(
mode=self.spell_correction_mode
)
if self.engine_data_type == 0:
if self.get_extractive_answers:
extractive_content_spec = (
SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_answer_count=self.max_extractive_answer_count,
)
)
else:
extractive_content_spec = (
SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_segment_count=self.max_extractive_segment_count,
)
)
content_search_spec = SearchRequest.ContentSearchSpec(
extractive_content_spec=extractive_content_spec
)
elif self.engine_data_type == 1:
content_search_spec = None
else:
# TODO: Add extra data type handling for type website
raise NotImplementedError(
"Only engine data type 0 (Unstructured) or 1 (Structured)"
+ " are supported currently."
+ f" Got {self.engine_data_type}"
)
return SearchRequest(
query=query,
filter=self.filter,
serving_config=self._serving_config,
page_size=self.max_documents,
content_search_spec=content_search_spec,
query_expansion_spec=query_expansion_spec,
spell_correction_spec=spell_correction_spec,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
from google.api_core.exceptions import InvalidArgument
search_request = self._create_search_request(query)
try:
response = self._client.search(search_request)
except InvalidArgument as e:
raise type(e)(
e.message + " This might be due to engine_data_type not set correctly."
)
if self.engine_data_type == 0:
documents = self._convert_unstructured_search_response(response.results)
elif self.engine_data_type == 1:
documents = self._convert_structured_search_response(response.results)
else:
# TODO: Add extra data type handling for type website
raise NotImplementedError(
"Only engine data type 0 (Unstructured) or 1 (Structured)"
+ " are supported currently."
+ f" Got {self.engine_data_type}"
)
return documents

View File

@@ -0,0 +1,314 @@
"""Retriever wrapper for Google Vertex AI Search."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:
from google.cloud.discoveryengine_v1beta import (
SearchRequest,
SearchResult,
SearchServiceClient,
)
class GoogleVertexAISearchRetriever(BaseRetriever):
    """`Google Vertex AI Search` retriever.

    For a detailed explanation of the Vertex AI Search concepts
    and configuration parameters, refer to the product documentation.
    https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
    """

    project_id: str
    """Google Cloud Project ID."""
    data_store_id: str
    """Vertex AI Search data store ID."""
    serving_config_id: str = "default_config"
    """Vertex AI Search serving config ID."""
    location_id: str = "global"
    """Vertex AI Search data store location."""
    filter: Optional[str] = None
    """Filter expression."""
    get_extractive_answers: bool = False
    """If True return Extractive Answers, otherwise return Extractive Segments."""
    max_documents: int = Field(default=5, ge=1, le=100)
    """The maximum number of documents to return."""
    max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
    """The maximum number of extractive answers returned in each search result.
    At most 5 answers will be returned for each SearchResult.
    """
    max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
    """The maximum number of extractive segments returned in each search result.
    Currently one segment will be returned for each SearchResult.
    """
    query_expansion_condition: int = Field(default=1, ge=0, le=2)
    """Specification to determine under which conditions query expansion should occur.
    0 - Unspecified query expansion condition. In this case, server behavior defaults
        to disabled
    1 - Disabled query expansion. Only the exact search query is used, even if
        SearchResponse.total_size is zero.
    2 - Automatic query expansion built by the Search API.
    """
    spell_correction_mode: int = Field(default=2, ge=0, le=2)
    """Specification to determine under which conditions spell correction should occur.
    0 - Unspecified spell correction mode. In this case, server behavior defaults
        to auto.
    1 - Suggestion only. Search API will try to find a spell suggestion if there is any
        and put in the `SearchResponse.corrected_query`.
        The spell suggestion will not be used as the search query.
    2 - Automatic spell correction built by the Search API.
        Search will be based on the corrected query if found.
    """
    credentials: Any = None
    """The default custom credentials (google.auth.credentials.Credentials) to use
    when making API calls. If not provided, credentials will be ascertained from
    the environment."""

    # TODO: Add extra data type handling for type website
    engine_data_type: int = Field(default=0, ge=0, le=1)
    """ Defines the Vertex AI Search data type
    0 - Unstructured data
    1 - Structured data
    """

    _client: SearchServiceClient
    _serving_config: str

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore
        arbitrary_types_allowed = True
        underscore_attrs_are_private = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validates the environment.

        Checks that the required Google Cloud client libraries are importable
        and resolves `project_id` and `data_store_id` from the supplied values
        or the corresponding environment variables. Accepts the deprecated
        `search_engine_id` (or SEARCH_ENGINE_ID env var) as an alias for
        `data_store_id`, emitting a DeprecationWarning.
        """
        try:
            from google.cloud import discoveryengine_v1beta  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "google.cloud.discoveryengine is not installed."
                "Please install it with pip install google-cloud-discoveryengine"
            ) from exc
        try:
            from google.api_core.exceptions import InvalidArgument  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "google.api_core.exceptions is not installed. "
                "Please install it with pip install google-api-core"
            ) from exc

        values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")

        try:
            # For backwards compatibility
            search_engine_id = get_from_dict_or_env(
                values, "search_engine_id", "SEARCH_ENGINE_ID"
            )

            if search_engine_id:
                import warnings

                warnings.warn(
                    "The `search_engine_id` parameter is deprecated. Use `data_store_id` instead.",  # noqa: E501
                    DeprecationWarning,
                )
                values["data_store_id"] = search_engine_id
        except ValueError:
            # Was a bare `except: pass`, which also swallowed SystemExit and
            # KeyboardInterrupt. get_from_dict_or_env raises ValueError when the
            # key is absent from both `values` and the environment — that is the
            # only condition we intend to ignore here.
            pass

        values["data_store_id"] = get_from_dict_or_env(
            values, "data_store_id", "DATA_STORE_ID"
        )

        return values

    def __init__(self, **data: Any) -> None:
        """Initializes private fields (API client and serving config path)."""
        try:
            from google.cloud.discoveryengine_v1beta import SearchServiceClient
        except ImportError as exc:
            raise ImportError(
                "google.cloud.discoveryengine is not installed."
                "Please install it with pip install google-cloud-discoveryengine"
            ) from exc

        try:
            from google.api_core.client_options import ClientOptions
        except ImportError as exc:
            raise ImportError(
                "google.api_core.client_options is not installed."
                "Please install it with pip install google-api-core"
            ) from exc

        super().__init__(**data)

        # Non-global data stores must be reached through a regional endpoint.
        # For more information, refer to:
        # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
        api_endpoint = (
            "discoveryengine.googleapis.com"
            if self.location_id == "global"
            else f"{self.location_id}-discoveryengine.googleapis.com"
        )

        self._client = SearchServiceClient(
            credentials=self.credentials,
            client_options=ClientOptions(api_endpoint=api_endpoint),
        )

        self._serving_config = self._client.serving_config_path(
            project=self.project_id,
            location=self.location_id,
            data_store=self.data_store_id,
            serving_config=self.serving_config_id,
        )

    def _convert_unstructured_search_response(
        self, results: Sequence[SearchResult]
    ) -> List[Document]:
        """Converts a sequence of search results to a list of LangChain documents.

        Each extractive answer/segment chunk becomes one Document; results
        without derived structured data or without the requested chunk type
        are skipped.
        """
        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []

        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )
            derived_struct_data = document_dict.get("derived_struct_data")
            if not derived_struct_data:
                continue

            doc_metadata = document_dict.get("struct_data", {})
            doc_metadata["id"] = document_dict["id"]

            chunk_type = (
                "extractive_answers"
                if self.get_extractive_answers
                else "extractive_segments"
            )

            if chunk_type not in derived_struct_data:
                continue

            for chunk in derived_struct_data[chunk_type]:
                doc_metadata["source"] = derived_struct_data.get("link", "")

                if chunk_type == "extractive_answers":
                    # Answers carry a page number; append it to the source link.
                    doc_metadata["source"] += f":{chunk.get('pageNumber', '')}"

                documents.append(
                    Document(
                        page_content=chunk.get("content", ""), metadata=doc_metadata
                    )
                )

        return documents

    def _convert_structured_search_response(
        self, results: Sequence[SearchResult]
    ) -> List[Document]:
        """Converts a sequence of search results to a list of LangChain documents.

        The whole `struct_data` payload is JSON-serialized into page_content.
        """
        import json

        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []

        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )

            documents.append(
                Document(
                    page_content=json.dumps(document_dict.get("struct_data", {})),
                    metadata={"id": document_dict["id"], "name": document_dict["name"]},
                )
            )

        return documents

    def _create_search_request(self, query: str) -> SearchRequest:
        """Prepares a SearchRequest object from the retriever configuration.

        Raises:
            NotImplementedError: if `engine_data_type` is not 0 or 1.
        """
        from google.cloud.discoveryengine_v1beta import SearchRequest

        query_expansion_spec = SearchRequest.QueryExpansionSpec(
            condition=self.query_expansion_condition,
        )

        spell_correction_spec = SearchRequest.SpellCorrectionSpec(
            mode=self.spell_correction_mode
        )

        if self.engine_data_type == 0:
            if self.get_extractive_answers:
                extractive_content_spec = (
                    SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                        max_extractive_answer_count=self.max_extractive_answer_count,
                    )
                )
            else:
                extractive_content_spec = (
                    SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                        max_extractive_segment_count=self.max_extractive_segment_count,
                    )
                )
            content_search_spec = SearchRequest.ContentSearchSpec(
                extractive_content_spec=extractive_content_spec
            )
        elif self.engine_data_type == 1:
            # Structured data stores take no content search spec.
            content_search_spec = None
        else:
            # TODO: Add extra data type handling for type website
            raise NotImplementedError(
                "Only engine data type 0 (Unstructured) or 1 (Structured)"
                + " are supported currently."
                + f" Got {self.engine_data_type}"
            )

        return SearchRequest(
            query=query,
            filter=self.filter,
            serving_config=self._serving_config,
            page_size=self.max_documents,
            content_search_spec=content_search_spec,
            query_expansion_spec=query_expansion_spec,
            spell_correction_spec=spell_correction_spec,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query.

        Raises:
            InvalidArgument: re-raised from the API with a hint about
                `engine_data_type` mismatches.
            NotImplementedError: if `engine_data_type` is not 0 or 1.
        """
        from google.api_core.exceptions import InvalidArgument

        search_request = self._create_search_request(query)

        try:
            response = self._client.search(search_request)
        except InvalidArgument as exc:
            # Chain the original exception so the traceback is preserved.
            raise type(exc)(
                exc.message
                + " This might be due to engine_data_type not set correctly."
            ) from exc

        if self.engine_data_type == 0:
            documents = self._convert_unstructured_search_response(response.results)
        elif self.engine_data_type == 1:
            documents = self._convert_structured_search_response(response.results)
        else:
            # TODO: Add extra data type handling for type website
            raise NotImplementedError(
                "Only engine data type 0 (Unstructured) or 1 (Structured)"
                + " are supported currently."
                + f" Got {self.engine_data_type}"
            )

        return documents

View File

@@ -1,32 +0,0 @@
"""Test Google Cloud Enterprise Search retriever.
You need to create a Gen App Builder search app and populate it
with data to run the integration tests.
Follow the instructions in the example notebook:
google_cloud_enterprise_search.ipynb
to set up the app and configure authentication.
Set the following environment variables before the tests:
PROJECT_ID - set to your Google Cloud project ID
SEARCH_ENGINE_ID - the ID of the search engine to use for the test
"""
import pytest
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.schema import Document
@pytest.mark.requires("google_api_core")
def test_google_cloud_enterprise_search_get_relevant_documents() -> None:
    """Test the get_relevant_documents() method."""
    # NOTE(review): integration test — requires a live Gen App Builder search
    # app plus the PROJECT_ID and SEARCH_ENGINE_ID environment variables
    # (picked up by the retriever's root validator; see module docstring).
    retriever = GoogleCloudEnterpriseSearchRetriever()
    documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?")
    assert len(documents) > 0
    for doc in documents:
        # Every result must be a LangChain Document with non-empty content and
        # the `id`/`source` metadata keys populated by the retriever.
        assert isinstance(doc, Document)
        assert doc.page_content
        assert doc.metadata["id"]
        assert doc.metadata["source"]

View File

@@ -0,0 +1,61 @@
"""Test Google Vertex AI Search retriever.
You need to create a Vertex AI Search app and populate it
with data to run the integration tests.
Follow the instructions in the example notebook:
google_vertex_ai_search.ipynb
to set up the app and configure authentication.
Set the following environment variables before the tests:
PROJECT_ID - set to your Google Cloud project ID
DATA_STORE_ID - the ID of the search engine to use for the test
"""
import os
import pytest
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever
from langchain.schema import Document
@pytest.mark.requires("google_api_core")
def test_google_vertex_ai_search_get_relevant_documents() -> None:
    """Smoke-test get_relevant_documents() against a live data store.

    Asserts that at least one well-formed Document comes back for a
    sample query.
    """
    retriever = GoogleVertexAISearchRetriever()

    documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?")

    assert documents
    for document in documents:
        assert isinstance(document, Document)
        assert document.page_content
        assert document.metadata["id"]
        assert document.metadata["source"]
@pytest.mark.requires("google_api_core")
def test_google_vertex_ai_search_enterprise_search_deprecation() -> None:
    """Verify the deprecation path of GoogleCloudEnterpriseSearchRetriever.

    Both the class itself and its legacy `search_engine_id` parameter must
    emit DeprecationWarnings, while the retriever keeps working.
    """
    # Instantiating the old class name warns but succeeds.
    with pytest.warns(
        DeprecationWarning,
        match="GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever",  # noqa: E501
    ):
        retriever = GoogleCloudEnterpriseSearchRetriever()

    # The legacy SEARCH_ENGINE_ID env var is honored with its own warning.
    os.environ["SEARCH_ENGINE_ID"] = os.getenv("DATA_STORE_ID", "data_store_id")
    with pytest.warns(
        DeprecationWarning,
        match="The `search_engine_id` parameter is deprecated. Use `data_store_id` instead.",  # noqa: E501
    ):
        retriever = GoogleCloudEnterpriseSearchRetriever()

    # Check that mapped methods still work.
    documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?")

    assert documents
    for document in documents:
        assert isinstance(document, Document)
        assert document.page_content
        assert document.metadata["id"]
        assert document.metadata["source"]