added a multiturn search based on Vertex AI Search (#11885)

Replace this entire comment with:
- **Description:** Added a retriever based on multi-turn Vertex AI
Search
  - **Twitter handle:** lkuligin
This commit is contained in:
Leonid Kuligin 2023-10-17 02:05:12 +02:00 committed by GitHub
parent 38ed55245f
commit d269dd2e2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 239 additions and 145 deletions

View File

@ -161,7 +161,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.retrievers import GoogleVertexAISearchRetriever\n", "from langchain.retrievers import GoogleVertexAISearchRetriever, GoogleVertexAIMultiTurnSearchRetriever\n",
"\n", "\n",
"PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n", "PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n",
"LOCATION_ID = \"<YOUR LOCATION>\" # Set to your data store location\n", "LOCATION_ID = \"<YOUR LOCATION>\" # Set to your data store location\n",
@ -247,6 +247,37 @@
"for doc in result:\n", "for doc in result:\n",
" print(doc)" " print(doc)"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retrieve for multi-turn search"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Search with follow-ups is [based](https://cloud.google.com/generative-ai-app-builder/docs/multi-turn-search) on generative AI models and it is different from the regular unstructured data search."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever = GoogleVertexAIMultiTurnSearchRetriever(\n",
" project_id=PROJECT_ID,\n",
" location_id=LOCATION_ID,\n",
" data_store_id=DATA_STORE_ID\n",
")\n",
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
]
} }
], ],
"metadata": { "metadata": {

View File

@ -35,6 +35,7 @@ from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever, GoogleCloudEnterpriseSearchRetriever,
) )
from langchain.retrievers.google_vertex_ai_search import ( from langchain.retrievers.google_vertex_ai_search import (
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever, GoogleVertexAISearchRetriever,
) )
from langchain.retrievers.kay import KayAiRetriever from langchain.retrievers.kay import KayAiRetriever
@ -79,6 +80,7 @@ __all__ = [
"ElasticSearchBM25Retriever", "ElasticSearchBM25Retriever",
"GoogleDocumentAIWarehouseRetriever", "GoogleDocumentAIWarehouseRetriever",
"GoogleCloudEnterpriseSearchRetriever", "GoogleCloudEnterpriseSearchRetriever",
"GoogleVertexAIMultiTurnSearchRetriever",
"GoogleVertexAISearchRetriever", "GoogleVertexAISearchRetriever",
"KayAiRetriever", "KayAiRetriever",
"KNNRetriever", "KNNRetriever",

View File

@ -4,88 +4,32 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING: if TYPE_CHECKING:
from google.api_core.client_options import ClientOptions
from google.cloud.discoveryengine_v1beta import ( from google.cloud.discoveryengine_v1beta import (
ConversationalSearchServiceClient,
SearchRequest, SearchRequest,
SearchResult, SearchResult,
SearchServiceClient, SearchServiceClient,
) )
class GoogleVertexAISearchRetriever(BaseRetriever): class _BaseGoogleVertexAISearchRetriever(BaseModel):
"""`Google Vertex AI Search` retriever.
For a detailed explanation of the Vertex AI Search concepts
and configuration parameters, refer to the product documentation.
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
"""
project_id: str project_id: str
"""Google Cloud Project ID.""" """Google Cloud Project ID."""
data_store_id: str data_store_id: str
"""Vertex AI Search data store ID.""" """Vertex AI Search data store ID."""
serving_config_id: str = "default_config"
"""Vertex AI Search serving config ID."""
location_id: str = "global" location_id: str = "global"
"""Vertex AI Search data store location.""" """Vertex AI Search data store location."""
filter: Optional[str] = None
"""Filter expression."""
get_extractive_answers: bool = False
"""If True return Extractive Answers, otherwise return Extractive Segments."""
max_documents: int = Field(default=5, ge=1, le=100)
"""The maximum number of documents to return."""
max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
"""The maximum number of extractive answers returned in each search result.
At most 5 answers will be returned for each SearchResult.
"""
max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
"""The maximum number of extractive segments returned in each search result.
Currently one segment will be returned for each SearchResult.
"""
query_expansion_condition: int = Field(default=1, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified query expansion condition. In this case, server behavior defaults
to disabled
1 - Disabled query expansion. Only the exact search query is used, even if
SearchResponse.total_size is zero.
2 - Automatic query expansion built by the Search API.
"""
spell_correction_mode: int = Field(default=2, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified spell correction mode. In this case, server behavior defaults
to auto.
1 - Suggestion only. Search API will try to find a spell suggestion if there is any
and put in the `SearchResponse.corrected_query`.
The spell suggestion will not be used as the search query.
2 - Automatic spell correction built by the Search API.
Search will be based on the corrected query if found.
"""
credentials: Any = None credentials: Any = None
"""The default custom credentials (google.auth.credentials.Credentials) to use """The default custom credentials (google.auth.credentials.Credentials) to use
when making API calls. If not provided, credentials will be ascertained from when making API calls. If not provided, credentials will be ascertained from
the environment.""" the environment."""
# TODO: Add extra data type handling for type website
engine_data_type: int = Field(default=0, ge=0, le=1)
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
"""
_client: SearchServiceClient
_serving_config: str
class Config:
"""Configuration for this pydantic object."""
extra = Extra.ignore
arbitrary_types_allowed = True
underscore_attrs_are_private = True
@root_validator(pre=True) @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validates the environment.""" """Validates the environment."""
@ -94,9 +38,9 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
except ImportError as exc: except ImportError as exc:
raise ImportError( raise ImportError(
"google.cloud.discoveryengine is not installed." "google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine" "Please install it with pip install "
"google-cloud-discoveryengine>=0.11.0"
) from exc ) from exc
try: try:
from google.api_core.exceptions import InvalidArgument # noqa: F401 from google.api_core.exceptions import InvalidArgument # noqa: F401
except ImportError as exc: except ImportError as exc:
@ -130,87 +74,16 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
return values return values
def __init__(self, **data: Any) -> None: @property
"""Initializes private fields.""" def client_options(self) -> "ClientOptions":
try: from google.api_core.client_options import ClientOptions
from google.cloud.discoveryengine_v1beta import SearchServiceClient
except ImportError as exc:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine"
) from exc
try:
from google.api_core.client_options import ClientOptions
except ImportError as exc:
raise ImportError(
"google.api_core.client_options is not installed."
"Please install it with pip install google-api-core"
) from exc
super().__init__(**data) return ClientOptions(
api_endpoint=f"{self.location_id}-discoveryengine.googleapis.com"
# For more information, refer to: if self.location_id != "global"
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store else None
api_endpoint = (
"discoveryengine.googleapis.com"
if self.location_id == "global"
else f"{self.location_id}-discoveryengine.googleapis.com"
) )
self._client = SearchServiceClient(
credentials=self.credentials,
client_options=ClientOptions(api_endpoint=api_endpoint),
)
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.data_store_id,
serving_config=self.serving_config_id,
)
def _convert_unstructured_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
derived_struct_data = document_dict.get("derived_struct_data")
if not derived_struct_data:
continue
doc_metadata = document_dict.get("struct_data", {})
doc_metadata["id"] = document_dict["id"]
chunk_type = (
"extractive_answers"
if self.get_extractive_answers
else "extractive_segments"
)
if chunk_type not in derived_struct_data:
continue
for chunk in derived_struct_data[chunk_type]:
doc_metadata["source"] = derived_struct_data.get("link", "")
if chunk_type == "extractive_answers":
doc_metadata["source"] += f":{chunk.get('pageNumber', '')}"
documents.append(
Document(
page_content=chunk.get("content", ""), metadata=doc_metadata
)
)
return documents
def _convert_structured_search_response( def _convert_structured_search_response(
self, results: Sequence[SearchResult] self, results: Sequence[SearchResult]
) -> List[Document]: ) -> List[Document]:
@ -235,6 +108,128 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
return documents return documents
def _convert_unstructured_search_response(
self, results: Sequence[SearchResult], chunk_type: str
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
derived_struct_data = document_dict.get("derived_struct_data")
if not derived_struct_data:
continue
doc_metadata = document_dict.get("struct_data", {})
doc_metadata["id"] = document_dict["id"]
if chunk_type not in derived_struct_data:
continue
for chunk in derived_struct_data[chunk_type]:
doc_metadata["source"] = derived_struct_data.get("link", "")
if chunk_type == "extractive_answers":
doc_metadata["source"] += f":{chunk.get('pageNumber', '')}"
documents.append(
Document(
page_content=chunk.get("content", ""), metadata=doc_metadata
)
)
return documents
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
"""`Google Vertex AI Search` retriever.
For a detailed explanation of the Vertex AI Search concepts
and configuration parameters, refer to the product documentation.
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
"""
serving_config_id: str = "default_config"
"""Vertex AI Search serving config ID."""
filter: Optional[str] = None
"""Filter expression."""
get_extractive_answers: bool = False
"""If True return Extractive Answers, otherwise return Extractive Segments."""
max_documents: int = Field(default=5, ge=1, le=100)
"""The maximum number of documents to return."""
max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
"""The maximum number of extractive answers returned in each search result.
At most 5 answers will be returned for each SearchResult.
"""
max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
"""The maximum number of extractive segments returned in each search result.
Currently one segment will be returned for each SearchResult.
"""
query_expansion_condition: int = Field(default=1, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified query expansion condition. In this case, server behavior defaults
to disabled
1 - Disabled query expansion. Only the exact search query is used, even if
SearchResponse.total_size is zero.
2 - Automatic query expansion built by the Search API.
"""
spell_correction_mode: int = Field(default=2, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified spell correction mode. In this case, server behavior defaults
to auto.
1 - Suggestion only. Search API will try to find a spell suggestion if there is any
and put in the `SearchResponse.corrected_query`.
The spell suggestion will not be used as the search query.
2 - Automatic spell correction built by the Search API.
Search will be based on the corrected query if found.
"""
# TODO: Add extra data type handling for type website
engine_data_type: int = Field(default=0, ge=0, le=1)
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
"""
_client: SearchServiceClient
_serving_config: str
class Config:
"""Configuration for this pydantic object."""
extra = Extra.ignore
arbitrary_types_allowed = True
underscore_attrs_are_private = True
def __init__(self, **kwargs: Any) -> None:
"""Initializes private fields."""
try:
from google.cloud.discoveryengine_v1beta import SearchServiceClient
except ImportError as exc:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine"
) from exc
super().__init__(**kwargs)
# For more information, refer to:
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
self._client = SearchServiceClient(
credentials=self.credentials, client_options=self.client_options
)
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.data_store_id,
serving_config=self.serving_config_id,
)
def _create_search_request(self, query: str) -> SearchRequest: def _create_search_request(self, query: str) -> SearchRequest:
"""Prepares a SearchRequest object.""" """Prepares a SearchRequest object."""
from google.cloud.discoveryengine_v1beta import SearchRequest from google.cloud.discoveryengine_v1beta import SearchRequest
@ -300,7 +295,14 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
) )
if self.engine_data_type == 0: if self.engine_data_type == 0:
documents = self._convert_unstructured_search_response(response.results) chunk_type = (
"extractive_answers"
if self.get_extractive_answers
else "extractive_segments"
)
documents = self._convert_unstructured_search_response(
response.results, chunk_type
)
elif self.engine_data_type == 1: elif self.engine_data_type == 1:
documents = self._convert_structured_search_response(response.results) documents = self._convert_structured_search_response(response.results)
else: else:
@ -312,3 +314,46 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
) )
return documents return documents
class GoogleVertexAIMultiTurnSearchRetriever(
BaseRetriever, _BaseGoogleVertexAISearchRetriever
):
_client: ConversationalSearchServiceClient
class Config:
"""Configuration for this pydantic object."""
extra = Extra.ignore
arbitrary_types_allowed = True
underscore_attrs_are_private = True
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
from google.cloud.discoveryengine_v1beta import (
ConversationalSearchServiceClient,
)
self._client = ConversationalSearchServiceClient(
credentials=self.credentials, client_options=self.client_options
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
from google.cloud.discoveryengine_v1beta import (
ConverseConversationRequest,
TextInput,
)
request = ConverseConversationRequest(
name=self._client.conversation_path(
self.project_id, self.location_id, self.data_store_id, "-"
),
query=TextInput(input=query),
)
response = self._client.converse_conversation(request)
return self._convert_unstructured_search_response(
response.search_results, "extractive_answers"
)

View File

@ -7,8 +7,8 @@ google_vertex_ai_search.ipynb
to set up the app and configure authentication. to set up the app and configure authentication.
Set the following environment variables before the tests: Set the following environment variables before the tests:
PROJECT_ID - set to your Google Cloud project ID export PROJECT_ID=... - set to your Google Cloud project ID
DATA_STORE_ID - the ID of the search engine to use for the test export DATA_STORE_ID=... - the ID of the search engine to use for the test
""" """
import os import os
@ -18,7 +18,10 @@ import pytest
from langchain.retrievers.google_cloud_enterprise_search import ( from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever, GoogleCloudEnterpriseSearchRetriever,
) )
from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever from langchain.retrievers.google_vertex_ai_search import (
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
)
from langchain.schema import Document from langchain.schema import Document
@ -35,6 +38,19 @@ def test_google_vertex_ai_search_get_relevant_documents() -> None:
assert doc.metadata["source"] assert doc.metadata["source"]
@pytest.mark.requires("google_api_core")
def test_google_vertex_ai_multiturnsearch_get_relevant_documents() -> None:
"""Test the get_relevant_documents() method."""
retriever = GoogleVertexAIMultiTurnSearchRetriever()
documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?")
assert len(documents) > 0
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
assert doc.metadata["id"]
assert doc.metadata["source"]
@pytest.mark.requires("google_api_core") @pytest.mark.requires("google_api_core")
def test_google_vertex_ai_search_enterprise_search_deprecation() -> None: def test_google_vertex_ai_search_enterprise_search_deprecation() -> None:
"""Test the deprecation of GoogleCloudEnterpriseSearchRetriever.""" """Test the deprecation of GoogleCloudEnterpriseSearchRetriever."""