From c421997caa8f95e57ca473815ecd10f032a0aa65 Mon Sep 17 00:00:00 2001 From: Eric Pinzur <2641606+epinzur@users.noreply.github.com> Date: Fri, 8 Nov 2024 20:04:57 +0100 Subject: [PATCH] community[patch]: Added type hinting to OpenSearch clients (#27946) Description: * When working with OpenSearchVectorSearch to make OpenSearchGraphVectorStore (coming soon), I noticed that there wasn't type hinting for the underlying OpenSearch clients. This fixes that issue. * Confirmed tests are still passing with code changes. Note that there is some additional code duplication now, but I think this approach is cleaner overall. --- .../vectorstores/opensearch_vector_search.py | 108 +++++++----------- 1 file changed, 41 insertions(+), 67 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/opensearch_vector_search.py b/libs/community/langchain_community/vectorstores/opensearch_vector_search.py index f1329659eec..3e5bb280035 100644 --- a/libs/community/langchain_community/vectorstores/opensearch_vector_search.py +++ b/libs/community/langchain_community/vectorstores/opensearch_vector_search.py @@ -2,7 +2,7 @@ from __future__ import annotations import uuid import warnings -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple import numpy as np from langchain_core.documents import Document @@ -23,57 +23,18 @@ SCRIPT_SCORING_SEARCH = "script_scoring" PAINLESS_SCRIPTING_SEARCH = "painless_scripting" MATCH_ALL_QUERY = {"match_all": {}} # type: Dict - -def _import_opensearch() -> Any: - """Import OpenSearch if available, otherwise raise error.""" - try: - from opensearchpy import OpenSearch - except ImportError: - raise ImportError(IMPORT_OPENSEARCH_PY_ERROR) - return OpenSearch +if TYPE_CHECKING: + from opensearchpy import AsyncOpenSearch, OpenSearch -def _import_async_opensearch() -> Any: - """Import AsyncOpenSearch if available, otherwise raise error.""" - try: - from opensearchpy import AsyncOpenSearch - except ImportError: - raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR) - return AsyncOpenSearch - - -def _import_bulk() -> Any: - """Import bulk if available, otherwise raise error.""" - try: - from opensearchpy.helpers import bulk - except ImportError: - raise ImportError(IMPORT_OPENSEARCH_PY_ERROR) - return bulk - - -def _import_async_bulk() -> Any: - """Import async_bulk if available, otherwise raise error.""" - try: - from opensearchpy.helpers import async_bulk - except ImportError: - raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR) - return async_bulk - - -def _import_not_found_error() -> Any: - """Import not found error if available, otherwise raise error.""" - try: - from opensearchpy.exceptions import NotFoundError - except ImportError: - raise ImportError(IMPORT_OPENSEARCH_PY_ERROR) - return NotFoundError - - -def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any: +def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> OpenSearch: """Get OpenSearch client from the opensearch_url, otherwise raise error.""" try: - opensearch = _import_opensearch() - client = opensearch(opensearch_url, **kwargs) + from opensearchpy import OpenSearch + + client = OpenSearch(opensearch_url, **kwargs) + except ImportError: + raise ImportError(IMPORT_OPENSEARCH_PY_ERROR) except ValueError as e: raise ImportError( f"OpenSearch client string provided is not in proper format. " @@ -82,11 +43,14 @@ def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any: return client -def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any: +def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> AsyncOpenSearch: """Get AsyncOpenSearch client from the opensearch_url, otherwise raise error.""" try: - async_opensearch = _import_async_opensearch() - client = async_opensearch(opensearch_url, **kwargs) + from opensearchpy import AsyncOpenSearch + + client = AsyncOpenSearch(opensearch_url, **kwargs) + except ImportError: + raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR) except ValueError as e: raise ImportError( f"AsyncOpenSearch client string provided is not in proper format. " @@ -127,7 +91,7 @@ def _is_aoss_enabled(http_auth: Any) -> bool: def _bulk_ingest_embeddings( - client: Any, + client: OpenSearch, index_name: str, embeddings: List[List[float]], texts: Iterable[str], @@ -142,16 +106,19 @@ def _bulk_ingest_embeddings( """Bulk Ingest Embeddings into given index.""" if not mapping: mapping = dict() + try: + from opensearchpy.exceptions import NotFoundError + from opensearchpy.helpers import bulk + except ImportError: + raise ImportError(IMPORT_OPENSEARCH_PY_ERROR) - bulk = _import_bulk() - not_found_error = _import_not_found_error() requests = [] return_ids = [] mapping = mapping try: client.indices.get(index=index_name) - except not_found_error: + except NotFoundError: client.indices.create(index=index_name, body=mapping) for i, text in enumerate(texts): @@ -177,7 +144,7 @@ def _bulk_ingest_embeddings( async def _abulk_ingest_embeddings( - client: Any, + client: AsyncOpenSearch, index_name: str, embeddings: List[List[float]], texts: Iterable[str], @@ -193,14 +160,18 @@ async def _abulk_ingest_embeddings( if not mapping: mapping = dict() - async_bulk = _import_async_bulk() - not_found_error = _import_not_found_error() + try: + from opensearchpy.exceptions import NotFoundError + from opensearchpy.helpers import async_bulk + except ImportError: + raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR) + requests = [] return_ids = [] try: await client.indices.get(index=index_name) - except not_found_error: + except NotFoundError: await client.indices.create(index=index_name, body=mapping) for i, text in enumerate(texts): @@ -230,7 +201,7 @@ async def _abulk_ingest_embeddings( def _default_scripting_text_mapping( dim: int, vector_field: str = "vector_field", -) -> Dict: +) -> Dict[str, Any]: """For Painless Scripting or Script Scoring,the default mapping to create index.""" return { "mappings": { @@ -249,7 +220,7 @@ def _default_text_mapping( ef_construction: int = 512, m: int = 16, vector_field: str = "vector_field", -) -> Dict: +) -> Dict[str, Any]: """For Approximate k-NN Search, this is the default mapping to create index.""" return { "settings": {"index": {"knn": True, "knn.algo_param.ef_search": ef_search}}, @@ -275,7 +246,7 @@ def _default_approximate_search_query( k: int = 4, vector_field: str = "vector_field", score_threshold: Optional[float] = 0.0, -) -> Dict: +) -> Dict[str, Any]: """For Approximate k-NN Search, this is the default query.""" return { "size": k, @@ -291,7 +262,7 @@ def _approximate_search_query_with_boolean_filter( vector_field: str = "vector_field", subquery_clause: str = "must", score_threshold: Optional[float] = 0.0, -) -> Dict: +) -> Dict[str, Any]: """For Approximate k-NN Search, with Boolean Filter.""" return { "size": k, @@ -313,7 +284,7 @@ def _approximate_search_query_with_efficient_filter( k: int = 4, vector_field: str = "vector_field", score_threshold: Optional[float] = 0.0, -) -> Dict: +) -> Dict[str, Any]: """For Approximate k-NN Search, with Efficient Filter for Lucene and Faiss Engines.""" search_query = _default_approximate_search_query( @@ -330,7 +301,7 @@ def _default_script_query( pre_filter: Optional[Dict] = None, vector_field: str = "vector_field", score_threshold: Optional[float] = 0.0, -) -> Dict: +) -> Dict[str, Any]: """For Script Scoring Search, this is the default query.""" if not pre_filter: @@ -376,7 +347,7 @@ def _default_painless_scripting_query( pre_filter: Optional[Dict] = None, vector_field: str = "vector_field", score_threshold: Optional[float] = 0.0, -) -> Dict: +) -> Dict[str, Any]: """For Painless Scripting Search, this is the default query.""" if not pre_filter: @@ -692,7 +663,10 @@ class OpenSearchVectorSearch(VectorStore): refresh_indices: Whether to refresh the index after deleting documents. Defaults to True. """ - bulk = _import_bulk() + try: + from opensearchpy.helpers import bulk + except ImportError: + raise ImportError(IMPORT_OPENSEARCH_PY_ERROR) body = []