mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 04:38:26 +00:00
added a multiturn search based on Vertex AI Search (#11885)
- **Description:** Added a retriever based on multi-turn Vertex AI Search - **Twitter handle:** lkuligin
This commit is contained in:
parent
38ed55245f
commit
d269dd2e2f
@ -161,7 +161,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.retrievers import GoogleVertexAISearchRetriever\n",
|
"from langchain.retrievers import GoogleVertexAISearchRetriever, GoogleVertexAIMultiTurnSearchRetriever\n",
|
||||||
"\n",
|
"\n",
|
||||||
"PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n",
|
"PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n",
|
||||||
"LOCATION_ID = \"<YOUR LOCATION>\" # Set to your data store location\n",
|
"LOCATION_ID = \"<YOUR LOCATION>\" # Set to your data store location\n",
|
||||||
@ -247,6 +247,37 @@
|
|||||||
"for doc in result:\n",
|
"for doc in result:\n",
|
||||||
" print(doc)"
|
" print(doc)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Configure and use the retriever for multi-turn search"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Search with follow-ups is [based](https://cloud.google.com/generative-ai-app-builder/docs/multi-turn-search) on generative AI models and it is different from the regular unstructured data search."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = GoogleVertexAIMultiTurnSearchRetriever(\n",
|
||||||
|
" project_id=PROJECT_ID,\n",
|
||||||
|
" location_id=LOCATION_ID,\n",
|
||||||
|
" data_store_id=DATA_STORE_ID\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"result = retriever.get_relevant_documents(query)\n",
|
||||||
|
"for doc in result:\n",
|
||||||
|
" print(doc)"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -35,6 +35,7 @@ from langchain.retrievers.google_cloud_enterprise_search import (
|
|||||||
GoogleCloudEnterpriseSearchRetriever,
|
GoogleCloudEnterpriseSearchRetriever,
|
||||||
)
|
)
|
||||||
from langchain.retrievers.google_vertex_ai_search import (
|
from langchain.retrievers.google_vertex_ai_search import (
|
||||||
|
GoogleVertexAIMultiTurnSearchRetriever,
|
||||||
GoogleVertexAISearchRetriever,
|
GoogleVertexAISearchRetriever,
|
||||||
)
|
)
|
||||||
from langchain.retrievers.kay import KayAiRetriever
|
from langchain.retrievers.kay import KayAiRetriever
|
||||||
@ -79,6 +80,7 @@ __all__ = [
|
|||||||
"ElasticSearchBM25Retriever",
|
"ElasticSearchBM25Retriever",
|
||||||
"GoogleDocumentAIWarehouseRetriever",
|
"GoogleDocumentAIWarehouseRetriever",
|
||||||
"GoogleCloudEnterpriseSearchRetriever",
|
"GoogleCloudEnterpriseSearchRetriever",
|
||||||
|
"GoogleVertexAIMultiTurnSearchRetriever",
|
||||||
"GoogleVertexAISearchRetriever",
|
"GoogleVertexAISearchRetriever",
|
||||||
"KayAiRetriever",
|
"KayAiRetriever",
|
||||||
"KNNRetriever",
|
"KNNRetriever",
|
||||||
|
@ -4,88 +4,32 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||||
from langchain.pydantic_v1 import Extra, Field, root_validator
|
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from google.api_core.client_options import ClientOptions
|
||||||
from google.cloud.discoveryengine_v1beta import (
|
from google.cloud.discoveryengine_v1beta import (
|
||||||
|
ConversationalSearchServiceClient,
|
||||||
SearchRequest,
|
SearchRequest,
|
||||||
SearchResult,
|
SearchResult,
|
||||||
SearchServiceClient,
|
SearchServiceClient,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GoogleVertexAISearchRetriever(BaseRetriever):
|
class _BaseGoogleVertexAISearchRetriever(BaseModel):
|
||||||
"""`Google Vertex AI Search` retriever.
|
|
||||||
|
|
||||||
For a detailed explanation of the Vertex AI Search concepts
|
|
||||||
and configuration parameters, refer to the product documentation.
|
|
||||||
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
|
|
||||||
"""
|
|
||||||
|
|
||||||
project_id: str
|
project_id: str
|
||||||
"""Google Cloud Project ID."""
|
"""Google Cloud Project ID."""
|
||||||
data_store_id: str
|
data_store_id: str
|
||||||
"""Vertex AI Search data store ID."""
|
"""Vertex AI Search data store ID."""
|
||||||
serving_config_id: str = "default_config"
|
|
||||||
"""Vertex AI Search serving config ID."""
|
|
||||||
location_id: str = "global"
|
location_id: str = "global"
|
||||||
"""Vertex AI Search data store location."""
|
"""Vertex AI Search data store location."""
|
||||||
filter: Optional[str] = None
|
|
||||||
"""Filter expression."""
|
|
||||||
get_extractive_answers: bool = False
|
|
||||||
"""If True return Extractive Answers, otherwise return Extractive Segments."""
|
|
||||||
max_documents: int = Field(default=5, ge=1, le=100)
|
|
||||||
"""The maximum number of documents to return."""
|
|
||||||
max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
|
|
||||||
"""The maximum number of extractive answers returned in each search result.
|
|
||||||
At most 5 answers will be returned for each SearchResult.
|
|
||||||
"""
|
|
||||||
max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
|
|
||||||
"""The maximum number of extractive segments returned in each search result.
|
|
||||||
Currently one segment will be returned for each SearchResult.
|
|
||||||
"""
|
|
||||||
query_expansion_condition: int = Field(default=1, ge=0, le=2)
|
|
||||||
"""Specification to determine under which conditions query expansion should occur.
|
|
||||||
0 - Unspecified query expansion condition. In this case, server behavior defaults
|
|
||||||
to disabled
|
|
||||||
1 - Disabled query expansion. Only the exact search query is used, even if
|
|
||||||
SearchResponse.total_size is zero.
|
|
||||||
2 - Automatic query expansion built by the Search API.
|
|
||||||
"""
|
|
||||||
spell_correction_mode: int = Field(default=2, ge=0, le=2)
|
|
||||||
"""Specification to determine under which conditions spell correction should occur.
|
|
||||||
0 - Unspecified spell correction mode. In this case, server behavior defaults
|
|
||||||
to auto.
|
|
||||||
1 - Suggestion only. Search API will try to find a spell suggestion if there is any
|
|
||||||
and put in the `SearchResponse.corrected_query`.
|
|
||||||
The spell suggestion will not be used as the search query.
|
|
||||||
2 - Automatic spell correction built by the Search API.
|
|
||||||
Search will be based on the corrected query if found.
|
|
||||||
"""
|
|
||||||
credentials: Any = None
|
credentials: Any = None
|
||||||
"""The default custom credentials (google.auth.credentials.Credentials) to use
|
"""The default custom credentials (google.auth.credentials.Credentials) to use
|
||||||
when making API calls. If not provided, credentials will be ascertained from
|
when making API calls. If not provided, credentials will be ascertained from
|
||||||
the environment."""
|
the environment."""
|
||||||
|
|
||||||
# TODO: Add extra data type handling for type website
|
|
||||||
engine_data_type: int = Field(default=0, ge=0, le=1)
|
|
||||||
""" Defines the Vertex AI Search data type
|
|
||||||
0 - Unstructured data
|
|
||||||
1 - Structured data
|
|
||||||
"""
|
|
||||||
|
|
||||||
_client: SearchServiceClient
|
|
||||||
_serving_config: str
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.ignore
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
underscore_attrs_are_private = True
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validates the environment."""
|
"""Validates the environment."""
|
||||||
@ -94,9 +38,9 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
|
|||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"google.cloud.discoveryengine is not installed."
|
"google.cloud.discoveryengine is not installed."
|
||||||
"Please install it with pip install google-cloud-discoveryengine"
|
"Please install it with pip install "
|
||||||
|
"google-cloud-discoveryengine>=0.11.0"
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from google.api_core.exceptions import InvalidArgument # noqa: F401
|
from google.api_core.exceptions import InvalidArgument # noqa: F401
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
@ -130,87 +74,16 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
|
|||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def __init__(self, **data: Any) -> None:
|
@property
|
||||||
"""Initializes private fields."""
|
def client_options(self) -> "ClientOptions":
|
||||||
try:
|
|
||||||
from google.cloud.discoveryengine_v1beta import SearchServiceClient
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"google.cloud.discoveryengine is not installed."
|
|
||||||
"Please install it with pip install google-cloud-discoveryengine"
|
|
||||||
) from exc
|
|
||||||
try:
|
|
||||||
from google.api_core.client_options import ClientOptions
|
from google.api_core.client_options import ClientOptions
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"google.api_core.client_options is not installed."
|
|
||||||
"Please install it with pip install google-api-core"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
super().__init__(**data)
|
return ClientOptions(
|
||||||
|
api_endpoint=f"{self.location_id}-discoveryengine.googleapis.com"
|
||||||
# For more information, refer to:
|
if self.location_id != "global"
|
||||||
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
|
else None
|
||||||
api_endpoint = (
|
|
||||||
"discoveryengine.googleapis.com"
|
|
||||||
if self.location_id == "global"
|
|
||||||
else f"{self.location_id}-discoveryengine.googleapis.com"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._client = SearchServiceClient(
|
|
||||||
credentials=self.credentials,
|
|
||||||
client_options=ClientOptions(api_endpoint=api_endpoint),
|
|
||||||
)
|
|
||||||
|
|
||||||
self._serving_config = self._client.serving_config_path(
|
|
||||||
project=self.project_id,
|
|
||||||
location=self.location_id,
|
|
||||||
data_store=self.data_store_id,
|
|
||||||
serving_config=self.serving_config_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _convert_unstructured_search_response(
|
|
||||||
self, results: Sequence[SearchResult]
|
|
||||||
) -> List[Document]:
|
|
||||||
"""Converts a sequence of search results to a list of LangChain documents."""
|
|
||||||
from google.protobuf.json_format import MessageToDict
|
|
||||||
|
|
||||||
documents: List[Document] = []
|
|
||||||
|
|
||||||
for result in results:
|
|
||||||
document_dict = MessageToDict(
|
|
||||||
result.document._pb, preserving_proto_field_name=True
|
|
||||||
)
|
|
||||||
derived_struct_data = document_dict.get("derived_struct_data")
|
|
||||||
if not derived_struct_data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
doc_metadata = document_dict.get("struct_data", {})
|
|
||||||
doc_metadata["id"] = document_dict["id"]
|
|
||||||
|
|
||||||
chunk_type = (
|
|
||||||
"extractive_answers"
|
|
||||||
if self.get_extractive_answers
|
|
||||||
else "extractive_segments"
|
|
||||||
)
|
|
||||||
|
|
||||||
if chunk_type not in derived_struct_data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for chunk in derived_struct_data[chunk_type]:
|
|
||||||
doc_metadata["source"] = derived_struct_data.get("link", "")
|
|
||||||
|
|
||||||
if chunk_type == "extractive_answers":
|
|
||||||
doc_metadata["source"] += f":{chunk.get('pageNumber', '')}"
|
|
||||||
|
|
||||||
documents.append(
|
|
||||||
Document(
|
|
||||||
page_content=chunk.get("content", ""), metadata=doc_metadata
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return documents
|
|
||||||
|
|
||||||
def _convert_structured_search_response(
|
def _convert_structured_search_response(
|
||||||
self, results: Sequence[SearchResult]
|
self, results: Sequence[SearchResult]
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
@ -235,6 +108,128 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
|
|||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
def _convert_unstructured_search_response(
|
||||||
|
self, results: Sequence[SearchResult], chunk_type: str
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Converts a sequence of search results to a list of LangChain documents."""
|
||||||
|
from google.protobuf.json_format import MessageToDict
|
||||||
|
|
||||||
|
documents: List[Document] = []
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
document_dict = MessageToDict(
|
||||||
|
result.document._pb, preserving_proto_field_name=True
|
||||||
|
)
|
||||||
|
derived_struct_data = document_dict.get("derived_struct_data")
|
||||||
|
if not derived_struct_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc_metadata = document_dict.get("struct_data", {})
|
||||||
|
doc_metadata["id"] = document_dict["id"]
|
||||||
|
|
||||||
|
if chunk_type not in derived_struct_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for chunk in derived_struct_data[chunk_type]:
|
||||||
|
doc_metadata["source"] = derived_struct_data.get("link", "")
|
||||||
|
|
||||||
|
if chunk_type == "extractive_answers":
|
||||||
|
doc_metadata["source"] += f":{chunk.get('pageNumber', '')}"
|
||||||
|
|
||||||
|
documents.append(
|
||||||
|
Document(
|
||||||
|
page_content=chunk.get("content", ""), metadata=doc_metadata
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return documents
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
|
||||||
|
"""`Google Vertex AI Search` retriever.
|
||||||
|
|
||||||
|
For a detailed explanation of the Vertex AI Search concepts
|
||||||
|
and configuration parameters, refer to the product documentation.
|
||||||
|
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
|
||||||
|
"""
|
||||||
|
|
||||||
|
serving_config_id: str = "default_config"
|
||||||
|
"""Vertex AI Search serving config ID."""
|
||||||
|
filter: Optional[str] = None
|
||||||
|
"""Filter expression."""
|
||||||
|
get_extractive_answers: bool = False
|
||||||
|
"""If True return Extractive Answers, otherwise return Extractive Segments."""
|
||||||
|
max_documents: int = Field(default=5, ge=1, le=100)
|
||||||
|
"""The maximum number of documents to return."""
|
||||||
|
max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
|
||||||
|
"""The maximum number of extractive answers returned in each search result.
|
||||||
|
At most 5 answers will be returned for each SearchResult.
|
||||||
|
"""
|
||||||
|
max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
|
||||||
|
"""The maximum number of extractive segments returned in each search result.
|
||||||
|
Currently one segment will be returned for each SearchResult.
|
||||||
|
"""
|
||||||
|
query_expansion_condition: int = Field(default=1, ge=0, le=2)
|
||||||
|
"""Specification to determine under which conditions query expansion should occur.
|
||||||
|
0 - Unspecified query expansion condition. In this case, server behavior defaults
|
||||||
|
to disabled
|
||||||
|
1 - Disabled query expansion. Only the exact search query is used, even if
|
||||||
|
SearchResponse.total_size is zero.
|
||||||
|
2 - Automatic query expansion built by the Search API.
|
||||||
|
"""
|
||||||
|
spell_correction_mode: int = Field(default=2, ge=0, le=2)
|
||||||
|
"""Specification to determine under which conditions spell correction should occur.
|
||||||
|
0 - Unspecified spell correction mode. In this case, server behavior defaults
|
||||||
|
to auto.
|
||||||
|
1 - Suggestion only. Search API will try to find a spell suggestion if there is any
|
||||||
|
and put in the `SearchResponse.corrected_query`.
|
||||||
|
The spell suggestion will not be used as the search query.
|
||||||
|
2 - Automatic spell correction built by the Search API.
|
||||||
|
Search will be based on the corrected query if found.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: Add extra data type handling for type website
|
||||||
|
engine_data_type: int = Field(default=0, ge=0, le=1)
|
||||||
|
""" Defines the Vertex AI Search data type
|
||||||
|
0 - Unstructured data
|
||||||
|
1 - Structured data
|
||||||
|
"""
|
||||||
|
|
||||||
|
_client: SearchServiceClient
|
||||||
|
_serving_config: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.ignore
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
underscore_attrs_are_private = True
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
"""Initializes private fields."""
|
||||||
|
try:
|
||||||
|
from google.cloud.discoveryengine_v1beta import SearchServiceClient
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"google.cloud.discoveryengine is not installed."
|
||||||
|
"Please install it with pip install google-cloud-discoveryengine"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
# For more information, refer to:
|
||||||
|
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
|
||||||
|
self._client = SearchServiceClient(
|
||||||
|
credentials=self.credentials, client_options=self.client_options
|
||||||
|
)
|
||||||
|
|
||||||
|
self._serving_config = self._client.serving_config_path(
|
||||||
|
project=self.project_id,
|
||||||
|
location=self.location_id,
|
||||||
|
data_store=self.data_store_id,
|
||||||
|
serving_config=self.serving_config_id,
|
||||||
|
)
|
||||||
|
|
||||||
def _create_search_request(self, query: str) -> SearchRequest:
|
def _create_search_request(self, query: str) -> SearchRequest:
|
||||||
"""Prepares a SearchRequest object."""
|
"""Prepares a SearchRequest object."""
|
||||||
from google.cloud.discoveryengine_v1beta import SearchRequest
|
from google.cloud.discoveryengine_v1beta import SearchRequest
|
||||||
@ -300,7 +295,14 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.engine_data_type == 0:
|
if self.engine_data_type == 0:
|
||||||
documents = self._convert_unstructured_search_response(response.results)
|
chunk_type = (
|
||||||
|
"extractive_answers"
|
||||||
|
if self.get_extractive_answers
|
||||||
|
else "extractive_segments"
|
||||||
|
)
|
||||||
|
documents = self._convert_unstructured_search_response(
|
||||||
|
response.results, chunk_type
|
||||||
|
)
|
||||||
elif self.engine_data_type == 1:
|
elif self.engine_data_type == 1:
|
||||||
documents = self._convert_structured_search_response(response.results)
|
documents = self._convert_structured_search_response(response.results)
|
||||||
else:
|
else:
|
||||||
@ -312,3 +314,46 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleVertexAIMultiTurnSearchRetriever(
|
||||||
|
BaseRetriever, _BaseGoogleVertexAISearchRetriever
|
||||||
|
):
|
||||||
|
_client: ConversationalSearchServiceClient
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.ignore
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
underscore_attrs_are_private = True
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
from google.cloud.discoveryengine_v1beta import (
|
||||||
|
ConversationalSearchServiceClient,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client = ConversationalSearchServiceClient(
|
||||||
|
credentials=self.credentials, client_options=self.client_options
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Get documents relevant for a query."""
|
||||||
|
from google.cloud.discoveryengine_v1beta import (
|
||||||
|
ConverseConversationRequest,
|
||||||
|
TextInput,
|
||||||
|
)
|
||||||
|
|
||||||
|
request = ConverseConversationRequest(
|
||||||
|
name=self._client.conversation_path(
|
||||||
|
self.project_id, self.location_id, self.data_store_id, "-"
|
||||||
|
),
|
||||||
|
query=TextInput(input=query),
|
||||||
|
)
|
||||||
|
response = self._client.converse_conversation(request)
|
||||||
|
return self._convert_unstructured_search_response(
|
||||||
|
response.search_results, "extractive_answers"
|
||||||
|
)
|
||||||
|
@ -7,8 +7,8 @@ google_vertex_ai_search.ipynb
|
|||||||
to set up the app and configure authentication.
|
to set up the app and configure authentication.
|
||||||
|
|
||||||
Set the following environment variables before the tests:
|
Set the following environment variables before the tests:
|
||||||
PROJECT_ID - set to your Google Cloud project ID
|
export PROJECT_ID=... - set to your Google Cloud project ID
|
||||||
DATA_STORE_ID - the ID of the search engine to use for the test
|
export DATA_STORE_ID=... - the ID of the search engine to use for the test
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -18,7 +18,10 @@ import pytest
|
|||||||
from langchain.retrievers.google_cloud_enterprise_search import (
|
from langchain.retrievers.google_cloud_enterprise_search import (
|
||||||
GoogleCloudEnterpriseSearchRetriever,
|
GoogleCloudEnterpriseSearchRetriever,
|
||||||
)
|
)
|
||||||
from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever
|
from langchain.retrievers.google_vertex_ai_search import (
|
||||||
|
GoogleVertexAIMultiTurnSearchRetriever,
|
||||||
|
GoogleVertexAISearchRetriever,
|
||||||
|
)
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
|
|
||||||
|
|
||||||
@ -35,6 +38,19 @@ def test_google_vertex_ai_search_get_relevant_documents() -> None:
|
|||||||
assert doc.metadata["source"]
|
assert doc.metadata["source"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("google_api_core")
|
||||||
|
def test_google_vertex_ai_multiturnsearch_get_relevant_documents() -> None:
|
||||||
|
"""Test the get_relevant_documents() method."""
|
||||||
|
retriever = GoogleVertexAIMultiTurnSearchRetriever()
|
||||||
|
documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?")
|
||||||
|
assert len(documents) > 0
|
||||||
|
for doc in documents:
|
||||||
|
assert isinstance(doc, Document)
|
||||||
|
assert doc.page_content
|
||||||
|
assert doc.metadata["id"]
|
||||||
|
assert doc.metadata["source"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("google_api_core")
|
@pytest.mark.requires("google_api_core")
|
||||||
def test_google_vertex_ai_search_enterprise_search_deprecation() -> None:
|
def test_google_vertex_ai_search_enterprise_search_deprecation() -> None:
|
||||||
"""Test the deprecation of GoogleCloudEnterpriseSearchRetriever."""
|
"""Test the deprecation of GoogleCloudEnterpriseSearchRetriever."""
|
||||||
|
Loading…
Reference in New Issue
Block a user