community[patch]: Add Blended Search Support to GoogleVertexAISearchRetriever (#19082)

https://cloud.google.com/generative-ai-app-builder/docs/create-data-store-es#multi-data-stores

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
Holt Skinner 2024-03-15 17:39:31 -05:00 committed by GitHub
parent 0ddfe7fc9d
commit cee03630d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 77 additions and 41 deletions

View File

@ -30,7 +30,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"%pip install --upgrade --quiet google-cloud-discoveryengine" "%pip install --upgrade --quiet google-cloud-discoveryengine"
] ]
}, },
{ {
@ -115,10 +115,12 @@
" - `global` (default)\n", " - `global` (default)\n",
" - `us`\n", " - `us`\n",
" - `eu`\n", " - `eu`\n",
"- `data_store_id` - The ID of the data store you want to use.\n",
" - Note: This was called `search_engine_id` in previous versions of the retriever.\n",
"\n", "\n",
"The `project_id` and `data_store_id` parameters can be provided explicitly in the retriever's constructor or through the environment variables - `PROJECT_ID` and `DATA_STORE_ID`.\n", "One of:\n",
"- `search_engine_id` - The ID of the search app you want to use. (Required for Blended Search)\n",
"- `data_store_id` - The ID of the data store you want to use.\n",
"\n",
"The `project_id`, `search_engine_id` and `data_store_id` parameters can be provided explicitly in the retriever's constructor or through the environment variables - `PROJECT_ID`, `SEARCH_ENGINE_ID` and `DATA_STORE_ID`.\n",
"\n", "\n",
"You can also configure a number of optional parameters, including:\n", "You can also configure a number of optional parameters, including:\n",
"\n", "\n",
@ -137,17 +139,17 @@
"- `engine_data_type` - Defines the Vertex AI Search data type\n", "- `engine_data_type` - Defines the Vertex AI Search data type\n",
" - `0` - Unstructured data\n", " - `0` - Unstructured data\n",
" - `1` - Structured data\n", " - `1` - Structured data\n",
" - `2` - Website data with [Advanced Website Indexing](https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing)\n", " - `2` - Website data\n",
" - `3` - [Blended search](https://cloud.google.com/generative-ai-app-builder/docs/create-data-store-es#multi-data-stores)\n",
"\n", "\n",
"### Migration guide for `GoogleCloudEnterpriseSearchRetriever`\n", "### Migration guide for `GoogleCloudEnterpriseSearchRetriever`\n",
"\n", "\n",
"In previous versions, this retriever was called `GoogleCloudEnterpriseSearchRetriever`. Some backwards-incompatible changes had to be made to the retriever after the General Availability launch due to changes in the product behavior.\n", "In previous versions, this retriever was called `GoogleCloudEnterpriseSearchRetriever`.\n",
"\n", "\n",
"To update to the new retriever, make the following changes:\n", "To update to the new retriever, make the following changes:\n",
"\n", "\n",
"- Change the import from: `from langchain.retrievers import GoogleCloudEnterpriseSearchRetriever` -> `from langchain.retrievers import GoogleVertexAISearchRetriever`.\n", "- Change the import from: `from langchain.retrievers import GoogleCloudEnterpriseSearchRetriever` -> `from langchain.retrievers import GoogleVertexAISearchRetriever`.\n",
"- Change all class references from `GoogleCloudEnterpriseSearchRetriever` -> `GoogleVertexAISearchRetriever`.\n", "- Change all class references from `GoogleCloudEnterpriseSearchRetriever` -> `GoogleVertexAISearchRetriever`.\n"
"- Upon class initialization, change the `search_engine_id` parameter name to `data_store_id`.\n"
] ]
}, },
{ {
@ -170,6 +172,7 @@
"\n", "\n",
"PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n", "PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n",
"LOCATION_ID = \"<YOUR LOCATION>\" # Set to your data store location\n", "LOCATION_ID = \"<YOUR LOCATION>\" # Set to your data store location\n",
"SEARCH_ENGINE_ID = \"<YOUR SEARCH APP ID>\" # Set to your search app ID\n",
"DATA_STORE_ID = \"<YOUR DATA STORE ID>\" # Set to your data store ID" "DATA_STORE_ID = \"<YOUR DATA STORE ID>\" # Set to your data store ID"
] ]
}, },
@ -281,6 +284,32 @@
" print(doc)" " print(doc)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retriever for **blended** data\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever = GoogleVertexAISearchRetriever(\n",
" project_id=PROJECT_ID,\n",
" location_id=LOCATION_ID,\n",
" search_engine_id=SEARCH_ENGINE_ID,\n",
" max_documents=3,\n",
" engine_data_type=3,\n",
")\n",
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@ -322,7 +351,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.12" "version": "3.11.0"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -25,8 +25,10 @@ if TYPE_CHECKING:
class _BaseGoogleVertexAISearchRetriever(BaseModel): class _BaseGoogleVertexAISearchRetriever(BaseModel):
project_id: str project_id: str
"""Google Cloud Project ID.""" """Google Cloud Project ID."""
data_store_id: str data_store_id: Optional[str] = None
"""Vertex AI Search data store ID.""" """Vertex AI Search data store ID."""
search_engine_id: Optional[str] = None
"""Vertex AI Search app ID."""
location_id: str = "global" location_id: str = "global"
"""Vertex AI Search data store location.""" """Vertex AI Search data store location."""
serving_config_id: str = "default_config" serving_config_id: str = "default_config"
@ -35,11 +37,12 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
"""The default custom credentials (google.auth.credentials.Credentials) to use """The default custom credentials (google.auth.credentials.Credentials) to use
when making API calls. If not provided, credentials will be ascertained from when making API calls. If not provided, credentials will be ascertained from
the environment.""" the environment."""
engine_data_type: int = Field(default=0, ge=0, le=2) engine_data_type: int = Field(default=0, ge=0, le=3)
""" Defines the Vertex AI Search data type """ Defines the Vertex AI Search app data type
0 - Unstructured data 0 - Unstructured data
1 - Structured data 1 - Structured data
2 - Website data 2 - Website data
3 - Blended search
""" """
@root_validator(pre=True) @root_validator(pre=True)
@ -51,7 +54,7 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
raise ImportError( raise ImportError(
"google.cloud.discoveryengine is not installed." "google.cloud.discoveryengine is not installed."
"Please install it with pip install " "Please install it with pip install "
"google-cloud-discoveryengine>=0.11.0" "google-cloud-discoveryengine>=0.11.10"
) from exc ) from exc
try: try:
from google.api_core.exceptions import InvalidArgument # noqa: F401 from google.api_core.exceptions import InvalidArgument # noqa: F401
@ -64,26 +67,15 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID") values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")
try: try:
# For backwards compatibility values["data_store_id"] = get_from_dict_or_env(
search_engine_id = get_from_dict_or_env( values, "data_store_id", "DATA_STORE_ID"
)
values["search_engine_id"] = get_from_dict_or_env(
values, "search_engine_id", "SEARCH_ENGINE_ID" values, "search_engine_id", "SEARCH_ENGINE_ID"
) )
except Exception:
if search_engine_id:
import warnings
warnings.warn(
"The `search_engine_id` parameter is deprecated. Use `data_store_id` instead.", # noqa: E501
DeprecationWarning,
)
values["data_store_id"] = search_engine_id
except: # noqa: E722
pass pass
values["data_store_id"] = get_from_dict_or_env(
values, "data_store_id", "DATA_STORE_ID"
)
return values return values
@property @property
@ -273,12 +265,24 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
client_info=get_client_info(module="vertex-ai-search"), client_info=get_client_info(module="vertex-ai-search"),
) )
self._serving_config = self._client.serving_config_path( if self.engine_data_type == 3 and not self.search_engine_id:
project=self.project_id, raise ValueError(
location=self.location_id, "search_engine_id must be specified for blended search apps."
data_store=self.data_store_id, )
serving_config=self.serving_config_id,
) if self.search_engine_id:
self._serving_config = f"projects/{self.project_id}/locations/{self.location_id}/collections/default_collection/engines/{self.search_engine_id}/servingConfigs/default_config" # noqa: E501
elif self.data_store_id:
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.data_store_id,
serving_config=self.serving_config_id,
)
else:
raise ValueError(
"Either data_store_id or search_engine_id must be specified."
)
def _create_search_request(self, query: str) -> SearchRequest: def _create_search_request(self, query: str) -> SearchRequest:
"""Prepares a SearchRequest object.""" """Prepares a SearchRequest object."""
@ -310,7 +314,7 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
) )
elif self.engine_data_type == 1: elif self.engine_data_type == 1:
content_search_spec = None content_search_spec = None
elif self.engine_data_type == 2: elif self.engine_data_type in (2, 3):
content_search_spec = SearchRequest.ContentSearchSpec( content_search_spec = SearchRequest.ContentSearchSpec(
extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec( extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_answer_count=self.max_extractive_answer_count, max_extractive_answer_count=self.max_extractive_answer_count,
@ -322,7 +326,7 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only data store type 0 (Unstructured), 1 (Structured)," "Only data store type 0 (Unstructured), 1 (Structured),"
"or 2 (Website) are supported currently." "2 (Website), or 3 (Blended) are supported currently."
+ f" Got {self.engine_data_type}" + f" Got {self.engine_data_type}"
) )
@ -363,7 +367,7 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
) )
elif self.engine_data_type == 1: elif self.engine_data_type == 1:
documents = self._convert_structured_search_response(response.results) documents = self._convert_structured_search_response(response.results)
elif self.engine_data_type == 2: elif self.engine_data_type in (2, 3):
chunk_type = ( chunk_type = (
"extractive_answers" if self.get_extractive_answers else "snippets" "extractive_answers" if self.get_extractive_answers else "snippets"
) )
@ -373,7 +377,7 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only data store type 0 (Unstructured), 1 (Structured)," "Only data store type 0 (Unstructured), 1 (Structured),"
"or 2 (Website) are supported currently." "2 (Website), or 3 (Blended) are supported currently."
+ f" Got {self.engine_data_type}" + f" Got {self.engine_data_type}"
) )
@ -410,6 +414,9 @@ class GoogleVertexAIMultiTurnSearchRetriever(
client_info=get_client_info(module="vertex-ai-search"), client_info=get_client_info(module="vertex-ai-search"),
) )
if not self.data_store_id:
raise ValueError("data_store_id is required for MultiTurnSearchRetriever.")
self._serving_config = self._client.serving_config_path( self._serving_config = self._client.serving_config_path(
project=self.project_id, project=self.project_id,
location=self.location_id, location=self.location_id,
@ -417,9 +424,9 @@ class GoogleVertexAIMultiTurnSearchRetriever(
serving_config=self.serving_config_id, serving_config=self.serving_config_id,
) )
if self.engine_data_type == 1: if self.engine_data_type == 1 or self.engine_data_type == 3:
raise NotImplementedError( raise NotImplementedError(
"Data store type 1 (Structured)" "Data store type 1 (Structured) and 3 (Blended)"
"is not currently supported for multi-turn search." "is not currently supported for multi-turn search."
+ f" Got {self.engine_data_type}" + f" Got {self.engine_data_type}"
) )