diff --git a/libs/langchain/langchain/vectorstores/opensearch_vector_search.py b/libs/langchain/langchain/vectorstores/opensearch_vector_search.py index bb3508eaa1c..19241e476e6 100644 --- a/libs/langchain/langchain/vectorstores/opensearch_vector_search.py +++ b/libs/langchain/langchain/vectorstores/opensearch_vector_search.py @@ -306,13 +306,6 @@ def _default_painless_scripting_query( } -def _get_kwargs_value(kwargs: Any, key: str, default_value: Any) -> Any: - """Get the value of the key if present. Else get the default_value.""" - if key in kwargs: - return kwargs.get(key) - return default_value - - class OpenSearchVectorSearch(VectorStore): """`Amazon OpenSearch Vector Engine` vector store. @@ -338,10 +331,10 @@ class OpenSearchVectorSearch(VectorStore): """Initialize with necessary components.""" self.embedding_function = embedding_function self.index_name = index_name - http_auth = _get_kwargs_value(kwargs, "http_auth", None) + http_auth = kwargs.get("http_auth") self.is_aoss = _is_aoss_enabled(http_auth=http_auth) self.client = _get_opensearch_client(opensearch_url, **kwargs) - self.engine = _get_kwargs_value(kwargs, "engine", None) + self.engine = kwargs.get("engine") @property def embeddings(self) -> Embeddings: @@ -357,16 +350,16 @@ class OpenSearchVectorSearch(VectorStore): **kwargs: Any, ) -> List[str]: _validate_embeddings_and_bulk_size(len(embeddings), bulk_size) - index_name = _get_kwargs_value(kwargs, "index_name", self.index_name) - text_field = _get_kwargs_value(kwargs, "text_field", "text") + index_name = kwargs.get("index_name", self.index_name) + text_field = kwargs.get("text_field", "text") dim = len(embeddings[0]) - engine = _get_kwargs_value(kwargs, "engine", "nmslib") - space_type = _get_kwargs_value(kwargs, "space_type", "l2") - ef_search = _get_kwargs_value(kwargs, "ef_search", 512) - ef_construction = _get_kwargs_value(kwargs, "ef_construction", 512) - m = _get_kwargs_value(kwargs, "m", 16) - vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field") - max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024) + engine = kwargs.get("engine", "nmslib") + space_type = kwargs.get("space_type", "l2") + ef_search = kwargs.get("ef_search", 512) + ef_construction = kwargs.get("ef_construction", 512) + m = kwargs.get("m", 16) + vector_field = kwargs.get("vector_field", "vector_field") + max_chunk_bytes = kwargs.get("max_chunk_bytes", 1 * 1024 * 1024) _validate_aoss_with_engines(self.is_aoss, engine) @@ -542,8 +535,8 @@ class OpenSearchVectorSearch(VectorStore): same as `similarity_search` """ - text_field = _get_kwargs_value(kwargs, "text_field", "text") - metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata") + text_field = kwargs.get("text_field", "text") + metadata_field = kwargs.get("metadata_field", "metadata") hits = self._raw_similarity_search_with_score(query=query, k=k, **kwargs) @@ -581,10 +574,10 @@ class OpenSearchVectorSearch(VectorStore): same as `similarity_search` """ embedding = self.embedding_function.embed_query(query) - search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search") - vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field") - index_name = _get_kwargs_value(kwargs, "index_name", self.index_name) - filter = _get_kwargs_value(kwargs, "filter", {}) + search_type = kwargs.get("search_type", "approximate_search") + vector_field = kwargs.get("vector_field", "vector_field") + index_name = kwargs.get("index_name", self.index_name) + filter = kwargs.get("filter", {}) if ( self.is_aoss @@ -597,11 +590,11 @@ class OpenSearchVectorSearch(VectorStore): ) if search_type == "approximate_search": - boolean_filter = _get_kwargs_value(kwargs, "boolean_filter", {}) - subquery_clause = _get_kwargs_value(kwargs, "subquery_clause", "must") - efficient_filter = _get_kwargs_value(kwargs, "efficient_filter", {}) + boolean_filter = kwargs.get("boolean_filter", {}) + subquery_clause = kwargs.get("subquery_clause", "must") + efficient_filter = kwargs.get("efficient_filter", {}) # `lucene_filter` is deprecated, added for Backwards Compatibility - lucene_filter = _get_kwargs_value(kwargs, "lucene_filter", {}) + lucene_filter = kwargs.get("lucene_filter", {}) if boolean_filter != {} and efficient_filter != {}: raise ValueError( @@ -657,14 +650,14 @@ class OpenSearchVectorSearch(VectorStore): embedding, k=k, vector_field=vector_field ) elif search_type == SCRIPT_SCORING_SEARCH: - space_type = _get_kwargs_value(kwargs, "space_type", "l2") - pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY) + space_type = kwargs.get("space_type", "l2") + pre_filter = kwargs.get("pre_filter", MATCH_ALL_QUERY) search_query = _default_script_query( embedding, k, space_type, pre_filter, vector_field ) elif search_type == PAINLESS_SCRIPTING_SEARCH: - space_type = _get_kwargs_value(kwargs, "space_type", "l2Squared") - pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY) + space_type = kwargs.get("space_type", "l2Squared") + pre_filter = kwargs.get("pre_filter", MATCH_ALL_QUERY) search_query = _default_painless_scripting_query( embedding, k, space_type, pre_filter, vector_field ) @@ -701,9 +694,9 @@ class OpenSearchVectorSearch(VectorStore): List of Documents selected by maximal marginal relevance. """ - vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field") - text_field = _get_kwargs_value(kwargs, "text_field", "text") - metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata") + vector_field = kwargs.get("vector_field", "vector_field") + text_field = kwargs.get("text_field", "text") + metadata_field = kwargs.get("metadata_field", "metadata") # Get embedding of the user query embedding = self.embedding_function.embed_query(query) @@ -874,11 +867,11 @@ class OpenSearchVectorSearch(VectorStore): index_name = get_from_dict_or_env( kwargs, "index_name", "OPENSEARCH_INDEX_NAME", default=uuid.uuid4().hex ) - is_appx_search = _get_kwargs_value(kwargs, "is_appx_search", True) - vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field") - text_field = _get_kwargs_value(kwargs, "text_field", "text") - max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024) - http_auth = _get_kwargs_value(kwargs, "http_auth", None) + is_appx_search = kwargs.get("is_appx_search", True) + vector_field = kwargs.get("vector_field", "vector_field") + text_field = kwargs.get("text_field", "text") + max_chunk_bytes = kwargs.get("max_chunk_bytes", 1 * 1024 * 1024) + http_auth = kwargs.get("http_auth") is_aoss = _is_aoss_enabled(http_auth=http_auth) engine = None @@ -889,11 +882,11 @@ class OpenSearchVectorSearch(VectorStore): ) if is_appx_search: - engine = _get_kwargs_value(kwargs, "engine", "nmslib") - space_type = _get_kwargs_value(kwargs, "space_type", "l2") - ef_search = _get_kwargs_value(kwargs, "ef_search", 512) - ef_construction = _get_kwargs_value(kwargs, "ef_construction", 512) - m = _get_kwargs_value(kwargs, "m", 16) + engine = kwargs.get("engine", "nmslib") + space_type = kwargs.get("space_type", "l2") + ef_search = kwargs.get("ef_search", 512) + ef_construction = kwargs.get("ef_construction", 512) + m = kwargs.get("m", 16) _validate_aoss_with_engines(is_aoss, engine)