mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
community[patch]: Added type hinting to OpenSearch clients (#27946)
Description: * When working with OpenSearchVectorSearch to make OpenSearchGraphVectorStore (coming soon), I noticed that there wasn't type hinting for the underlying OpenSearch clients. This fixes that issue. * Confirmed tests are still passing with code changes. Note that there is some additional code duplication now, but I think this approach is cleaner overall.
This commit is contained in:
parent
4c2392e55c
commit
c421997caa
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
@ -23,57 +23,18 @@ SCRIPT_SCORING_SEARCH = "script_scoring"
|
|||||||
PAINLESS_SCRIPTING_SEARCH = "painless_scripting"
|
PAINLESS_SCRIPTING_SEARCH = "painless_scripting"
|
||||||
MATCH_ALL_QUERY = {"match_all": {}} # type: Dict
|
MATCH_ALL_QUERY = {"match_all": {}} # type: Dict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
def _import_opensearch() -> Any:
|
from opensearchpy import AsyncOpenSearch, OpenSearch
|
||||||
"""Import OpenSearch if available, otherwise raise error."""
|
|
||||||
try:
|
|
||||||
from opensearchpy import OpenSearch
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
|
||||||
return OpenSearch
|
|
||||||
|
|
||||||
|
|
||||||
def _import_async_opensearch() -> Any:
|
def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> OpenSearch:
|
||||||
"""Import AsyncOpenSearch if available, otherwise raise error."""
|
|
||||||
try:
|
|
||||||
from opensearchpy import AsyncOpenSearch
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
|
|
||||||
return AsyncOpenSearch
|
|
||||||
|
|
||||||
|
|
||||||
def _import_bulk() -> Any:
|
|
||||||
"""Import bulk if available, otherwise raise error."""
|
|
||||||
try:
|
|
||||||
from opensearchpy.helpers import bulk
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
|
||||||
return bulk
|
|
||||||
|
|
||||||
|
|
||||||
def _import_async_bulk() -> Any:
|
|
||||||
"""Import async_bulk if available, otherwise raise error."""
|
|
||||||
try:
|
|
||||||
from opensearchpy.helpers import async_bulk
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
|
|
||||||
return async_bulk
|
|
||||||
|
|
||||||
|
|
||||||
def _import_not_found_error() -> Any:
|
|
||||||
"""Import not found error if available, otherwise raise error."""
|
|
||||||
try:
|
|
||||||
from opensearchpy.exceptions import NotFoundError
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
|
||||||
return NotFoundError
|
|
||||||
|
|
||||||
|
|
||||||
def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
|
|
||||||
"""Get OpenSearch client from the opensearch_url, otherwise raise error."""
|
"""Get OpenSearch client from the opensearch_url, otherwise raise error."""
|
||||||
try:
|
try:
|
||||||
opensearch = _import_opensearch()
|
from opensearchpy import OpenSearch
|
||||||
client = opensearch(opensearch_url, **kwargs)
|
|
||||||
|
client = OpenSearch(opensearch_url, **kwargs)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"OpenSearch client string provided is not in proper format. "
|
f"OpenSearch client string provided is not in proper format. "
|
||||||
@ -82,11 +43,14 @@ def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
|
|||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
|
def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> AsyncOpenSearch:
|
||||||
"""Get AsyncOpenSearch client from the opensearch_url, otherwise raise error."""
|
"""Get AsyncOpenSearch client from the opensearch_url, otherwise raise error."""
|
||||||
try:
|
try:
|
||||||
async_opensearch = _import_async_opensearch()
|
from opensearchpy import AsyncOpenSearch
|
||||||
client = async_opensearch(opensearch_url, **kwargs)
|
|
||||||
|
client = AsyncOpenSearch(opensearch_url, **kwargs)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"AsyncOpenSearch client string provided is not in proper format. "
|
f"AsyncOpenSearch client string provided is not in proper format. "
|
||||||
@ -127,7 +91,7 @@ def _is_aoss_enabled(http_auth: Any) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _bulk_ingest_embeddings(
|
def _bulk_ingest_embeddings(
|
||||||
client: Any,
|
client: OpenSearch,
|
||||||
index_name: str,
|
index_name: str,
|
||||||
embeddings: List[List[float]],
|
embeddings: List[List[float]],
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -142,16 +106,19 @@ def _bulk_ingest_embeddings(
|
|||||||
"""Bulk Ingest Embeddings into given index."""
|
"""Bulk Ingest Embeddings into given index."""
|
||||||
if not mapping:
|
if not mapping:
|
||||||
mapping = dict()
|
mapping = dict()
|
||||||
|
try:
|
||||||
|
from opensearchpy.exceptions import NotFoundError
|
||||||
|
from opensearchpy.helpers import bulk
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
||||||
|
|
||||||
bulk = _import_bulk()
|
|
||||||
not_found_error = _import_not_found_error()
|
|
||||||
requests = []
|
requests = []
|
||||||
return_ids = []
|
return_ids = []
|
||||||
mapping = mapping
|
mapping = mapping
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client.indices.get(index=index_name)
|
client.indices.get(index=index_name)
|
||||||
except not_found_error:
|
except NotFoundError:
|
||||||
client.indices.create(index=index_name, body=mapping)
|
client.indices.create(index=index_name, body=mapping)
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
@ -177,7 +144,7 @@ def _bulk_ingest_embeddings(
|
|||||||
|
|
||||||
|
|
||||||
async def _abulk_ingest_embeddings(
|
async def _abulk_ingest_embeddings(
|
||||||
client: Any,
|
client: AsyncOpenSearch,
|
||||||
index_name: str,
|
index_name: str,
|
||||||
embeddings: List[List[float]],
|
embeddings: List[List[float]],
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -193,14 +160,18 @@ async def _abulk_ingest_embeddings(
|
|||||||
if not mapping:
|
if not mapping:
|
||||||
mapping = dict()
|
mapping = dict()
|
||||||
|
|
||||||
async_bulk = _import_async_bulk()
|
try:
|
||||||
not_found_error = _import_not_found_error()
|
from opensearchpy.exceptions import NotFoundError
|
||||||
|
from opensearchpy.helpers import async_bulk
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
|
||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
return_ids = []
|
return_ids = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await client.indices.get(index=index_name)
|
await client.indices.get(index=index_name)
|
||||||
except not_found_error:
|
except NotFoundError:
|
||||||
await client.indices.create(index=index_name, body=mapping)
|
await client.indices.create(index=index_name, body=mapping)
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
@ -230,7 +201,7 @@ async def _abulk_ingest_embeddings(
|
|||||||
def _default_scripting_text_mapping(
|
def _default_scripting_text_mapping(
|
||||||
dim: int,
|
dim: int,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
) -> Dict:
|
) -> Dict[str, Any]:
|
||||||
"""For Painless Scripting or Script Scoring,the default mapping to create index."""
|
"""For Painless Scripting or Script Scoring,the default mapping to create index."""
|
||||||
return {
|
return {
|
||||||
"mappings": {
|
"mappings": {
|
||||||
@ -249,7 +220,7 @@ def _default_text_mapping(
|
|||||||
ef_construction: int = 512,
|
ef_construction: int = 512,
|
||||||
m: int = 16,
|
m: int = 16,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
) -> Dict:
|
) -> Dict[str, Any]:
|
||||||
"""For Approximate k-NN Search, this is the default mapping to create index."""
|
"""For Approximate k-NN Search, this is the default mapping to create index."""
|
||||||
return {
|
return {
|
||||||
"settings": {"index": {"knn": True, "knn.algo_param.ef_search": ef_search}},
|
"settings": {"index": {"knn": True, "knn.algo_param.ef_search": ef_search}},
|
||||||
@ -275,7 +246,7 @@ def _default_approximate_search_query(
|
|||||||
k: int = 4,
|
k: int = 4,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
score_threshold: Optional[float] = 0.0,
|
score_threshold: Optional[float] = 0.0,
|
||||||
) -> Dict:
|
) -> Dict[str, Any]:
|
||||||
"""For Approximate k-NN Search, this is the default query."""
|
"""For Approximate k-NN Search, this is the default query."""
|
||||||
return {
|
return {
|
||||||
"size": k,
|
"size": k,
|
||||||
@ -291,7 +262,7 @@ def _approximate_search_query_with_boolean_filter(
|
|||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
subquery_clause: str = "must",
|
subquery_clause: str = "must",
|
||||||
score_threshold: Optional[float] = 0.0,
|
score_threshold: Optional[float] = 0.0,
|
||||||
) -> Dict:
|
) -> Dict[str, Any]:
|
||||||
"""For Approximate k-NN Search, with Boolean Filter."""
|
"""For Approximate k-NN Search, with Boolean Filter."""
|
||||||
return {
|
return {
|
||||||
"size": k,
|
"size": k,
|
||||||
@ -313,7 +284,7 @@ def _approximate_search_query_with_efficient_filter(
|
|||||||
k: int = 4,
|
k: int = 4,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
score_threshold: Optional[float] = 0.0,
|
score_threshold: Optional[float] = 0.0,
|
||||||
) -> Dict:
|
) -> Dict[str, Any]:
|
||||||
"""For Approximate k-NN Search, with Efficient Filter for Lucene and
|
"""For Approximate k-NN Search, with Efficient Filter for Lucene and
|
||||||
Faiss Engines."""
|
Faiss Engines."""
|
||||||
search_query = _default_approximate_search_query(
|
search_query = _default_approximate_search_query(
|
||||||
@ -330,7 +301,7 @@ def _default_script_query(
|
|||||||
pre_filter: Optional[Dict] = None,
|
pre_filter: Optional[Dict] = None,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
score_threshold: Optional[float] = 0.0,
|
score_threshold: Optional[float] = 0.0,
|
||||||
) -> Dict:
|
) -> Dict[str, Any]:
|
||||||
"""For Script Scoring Search, this is the default query."""
|
"""For Script Scoring Search, this is the default query."""
|
||||||
|
|
||||||
if not pre_filter:
|
if not pre_filter:
|
||||||
@ -376,7 +347,7 @@ def _default_painless_scripting_query(
|
|||||||
pre_filter: Optional[Dict] = None,
|
pre_filter: Optional[Dict] = None,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
score_threshold: Optional[float] = 0.0,
|
score_threshold: Optional[float] = 0.0,
|
||||||
) -> Dict:
|
) -> Dict[str, Any]:
|
||||||
"""For Painless Scripting Search, this is the default query."""
|
"""For Painless Scripting Search, this is the default query."""
|
||||||
|
|
||||||
if not pre_filter:
|
if not pre_filter:
|
||||||
@ -692,7 +663,10 @@ class OpenSearchVectorSearch(VectorStore):
|
|||||||
refresh_indices: Whether to refresh the index
|
refresh_indices: Whether to refresh the index
|
||||||
after deleting documents. Defaults to True.
|
after deleting documents. Defaults to True.
|
||||||
"""
|
"""
|
||||||
bulk = _import_bulk()
|
try:
|
||||||
|
from opensearchpy.helpers import bulk
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
||||||
|
|
||||||
body = []
|
body = []
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user