Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-26 16:43:35 +00:00)
[OpenSearch]: Add AOSS Support to OpenSearch (#8256)
### Description

This PR includes the following changes:

- Adds AOSS (Amazon OpenSearch Service Serverless) support to OpenSearch. Please refer to the documentation on how to use it.
- While creating an index, AOSS only supports Approximate Search with the `nmslib` and `faiss` engines. During search, only Approximate Search and Script Scoring (on doc values) are supported.
- Adds support for `efficient_filter`, which can be used with the `faiss` and `lucene` engines.
- Deprecates `lucene_filter`; please use `efficient_filter` with the lucene engine instead.

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
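To illustrate the migration called out above, here is a minimal sketch of switching a search call from the deprecated `lucene_filter` to `efficient_filter`. It assumes an existing `OpenSearchVectorSearch` instance named `docsearch` built with the `lucene` or `faiss` engine, and the filter clause shown is a placeholder you would replace with your own metadata query.

```python
# Placeholder filter clause; swap in a query that matches your own fields.
filter = {"bool": {"must": [{"term": {"text": "smallest"}}]}}

# Deprecated style (kept for backwards compatibility, now emits a warning):
# docs = docsearch.similarity_search(
#     "What is feature selection", lucene_filter=filter, k=200
# )

# Preferred style after this change; works with both the lucene and faiss engines.
docs = docsearch.similarity_search(
    "What is feature selection",
    efficient_filter=filter,
    k=200,
)
```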
This commit is contained in:
parent 7a00f17033
commit 9cbefcc56c
@@ -315,6 +315,101 @@
    "    metadata_field=\"message_metadata\",\n",
    ")"
   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Using AOSS (Amazon OpenSearch Service Serverless)"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%% md\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# This is just an example to show how to use AOSS with faiss engine and efficient_filter, you need to set proper values.\n",
+    "\n",
+    "service = 'aoss' # must set the service as 'aoss'\n",
+    "region = 'us-east-2'\n",
+    "credentials = boto3.Session(aws_access_key_id='xxxxxx',aws_secret_access_key='xxxxx').get_credentials()\n",
+    "awsauth = AWS4Auth('xxxxx', 'xxxxxx', region,service, session_token=credentials.token)\n",
+    "\n",
+    "docsearch = OpenSearchVectorSearch.from_documents(\n",
+    "    docs,\n",
+    "    embeddings,\n",
+    "    opensearch_url=\"host url\",\n",
+    "    http_auth=awsauth,\n",
+    "    timeout = 300,\n",
+    "    use_ssl = True,\n",
+    "    verify_certs = True,\n",
+    "    connection_class = RequestsHttpConnection,\n",
+    "    index_name=\"test-index-using-aoss\",\n",
+    "    engine=\"faiss\",\n",
+    ")\n",
+    "\n",
+    "docs = docsearch.similarity_search(\n",
+    "    \"What is feature selection\",\n",
+    "    efficient_filter=filter,\n",
+    "    k=200,\n",
+    ")"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Using AOS (Amazon OpenSearch Service)"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# This is just an example to show how to use AOS , you need to set proper values.\n",
+    "\n",
+    "service = 'es' # must set the service as 'es'\n",
+    "region = 'us-east-2'\n",
+    "credentials = boto3.Session(aws_access_key_id='xxxxxx',aws_secret_access_key='xxxxx').get_credentials()\n",
+    "awsauth = AWS4Auth('xxxxx', 'xxxxxx', region,service, session_token=credentials.token)\n",
+    "\n",
+    "docsearch = OpenSearchVectorSearch.from_documents(\n",
+    "    docs,\n",
+    "    embeddings,\n",
+    "    opensearch_url=\"host url\",\n",
+    "    http_auth=awsauth,\n",
+    "    timeout = 300,\n",
+    "    use_ssl = True,\n",
+    "    verify_certs = True,\n",
+    "    connection_class = RequestsHttpConnection,\n",
+    "    index_name=\"test-index\",\n",
+    ")\n",
+    "\n",
+    "docs = docsearch.similarity_search(\n",
+    "    \"What is feature selection\",\n",
+    "    k=200,\n",
+    ")"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
  }
 ],
 "metadata": {
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import uuid
+import warnings
 from typing import Any, Dict, Iterable, List, Optional, Tuple

 import numpy as np
@@ -71,6 +72,26 @@ def _validate_embeddings_and_bulk_size(embeddings_length: int, bulk_size: int) -
         )


+def _validate_aoss_with_engines(is_aoss: bool, engine: str) -> None:
+    """Validate AOSS with the engine."""
+    if is_aoss and engine != "nmslib" and engine != "faiss":
+        raise ValueError(
+            "Amazon OpenSearch Service Serverless only "
+            "supports `nmslib` or `faiss` engines"
+        )
+
+
+def _is_aoss_enabled(http_auth: Any) -> bool:
+    """Check if the service is http_auth is set as `aoss`."""
+    if (
+        http_auth is not None
+        and http_auth.service is not None
+        and http_auth.service == "aoss"
+    ):
+        return True
+    return False
+
+
 def _bulk_ingest_embeddings(
     client: Any,
     index_name: str,
@@ -82,6 +103,7 @@ def _bulk_ingest_embeddings(
     text_field: str = "text",
     mapping: Optional[Dict] = None,
     max_chunk_bytes: Optional[int] = 1 * 1024 * 1024,
+    is_aoss: bool = False,
 ) -> List[str]:
     """Bulk Ingest Embeddings into given index."""
     if not mapping:
@@ -107,12 +129,16 @@ def _bulk_ingest_embeddings(
             vector_field: embeddings[i],
             text_field: text,
             "metadata": metadata,
-            "_id": _id,
         }
+        if is_aoss:
+            request["id"] = _id
+        else:
+            request["_id"] = _id
         requests.append(request)
         return_ids.append(_id)
     bulk(client, requests, max_chunk_bytes=max_chunk_bytes)
-    client.indices.refresh(index=index_name)
+    if not is_aoss:
+        client.indices.refresh(index=index_name)
     return return_ids


@@ -192,17 +218,18 @@ def _approximate_search_query_with_boolean_filter(
     }


-def _approximate_search_query_with_lucene_filter(
+def _approximate_search_query_with_efficient_filter(
     query_vector: List[float],
-    lucene_filter: Dict,
+    efficient_filter: Dict,
     k: int = 4,
     vector_field: str = "vector_field",
 ) -> Dict:
-    """For Approximate k-NN Search, with Lucene Filter."""
+    """For Approximate k-NN Search, with Efficient Filter for Lucene and
+    Faiss Engines."""
     search_query = _default_approximate_search_query(
         query_vector, k=k, vector_field=vector_field
     )
-    search_query["query"]["knn"][vector_field]["filter"] = lucene_filter
+    search_query["query"]["knn"][vector_field]["filter"] = efficient_filter
     return search_query


@@ -309,11 +336,13 @@ class OpenSearchVectorSearch(VectorStore):
         opensearch_url: str,
         index_name: str,
         embedding_function: Embeddings,
+        is_aoss: bool,
         **kwargs: Any,
     ):
         """Initialize with necessary components."""
         self.embedding_function = embedding_function
         self.index_name = index_name
+        self.is_aoss = is_aoss
         self.client = _get_opensearch_client(opensearch_url, **kwargs)

     @property
@@ -358,6 +387,8 @@ class OpenSearchVectorSearch(VectorStore):
         vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
         max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024)

+        _validate_aoss_with_engines(self.is_aoss, engine)
+
         mapping = _default_text_mapping(
             dim, engine, space_type, ef_search, ef_construction, m, vector_field
         )
@@ -373,6 +404,7 @@ class OpenSearchVectorSearch(VectorStore):
             text_field=text_field,
             mapping=mapping,
             max_chunk_bytes=max_chunk_bytes,
+            is_aoss=self.is_aoss,
         )

     def similarity_search(
@@ -404,14 +436,18 @@ class OpenSearchVectorSearch(VectorStore):
         Optional Args for Approximate Search:
             search_type: "approximate_search"; default: "approximate_search"

-            boolean_filter: A Boolean filter consists of a Boolean query that
-            contains a k-NN query and a filter.
+            boolean_filter: A Boolean filter is a post filter consists of a Boolean
+            query that contains a k-NN query and a filter.

             subquery_clause: Query clause on the knn vector field; default: "must"

             lucene_filter: the Lucene algorithm decides whether to perform an exact
             k-NN search with pre-filtering or an approximate search with modified
-            post-filtering.
+            post-filtering. (deprecated, use `efficient_filter`)
+
+            efficient_filter: the Lucene Engine or Faiss Engine decides whether to
+            perform an exact k-NN search with pre-filtering or an approximate search
+            with modified post-filtering.

         Optional Args for Script Scoring Search:
             search_type: "script_scoring"; default: "approximate_search"
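To make the docstring options above concrete, a small sketch of the two filter keywords for approximate search; `docsearch` is assumed to be an existing `OpenSearchVectorSearch` instance and the filter body is a placeholder term query. Passing both keywords in the same call is rejected, as the new tests further below verify.

```python
# Placeholder metadata filter; replace the field and value with your own.
filter_clause = {"bool": {"must": [{"term": {"text": "bar"}}]}}

# boolean_filter: a post filter, i.e. a Boolean query wrapping the k-NN query and the filter.
docs = docsearch.similarity_search("foo", k=3, boolean_filter=filter_clause)

# efficient_filter: the lucene or faiss engine chooses between an exact k-NN search
# with pre-filtering and an approximate search with modified post-filtering.
docs = docsearch.similarity_search("foo", k=3, efficient_filter=filter_clause)
```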
@@ -494,15 +530,41 @@ class OpenSearchVectorSearch(VectorStore):
         search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
         vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")

+        if (
+            self.is_aoss
+            and search_type != "approximate_search"
+            and search_type != SCRIPT_SCORING_SEARCH
+        ):
+            raise ValueError(
+                "Amazon OpenSearch Service Serverless only "
+                "supports `approximate_search` and `script_scoring`"
+            )
+
         if search_type == "approximate_search":
             boolean_filter = _get_kwargs_value(kwargs, "boolean_filter", {})
             subquery_clause = _get_kwargs_value(kwargs, "subquery_clause", "must")
+            efficient_filter = _get_kwargs_value(kwargs, "efficient_filter", {})
+            # `lucene_filter` is deprecated, added for Backwards Compatibility
             lucene_filter = _get_kwargs_value(kwargs, "lucene_filter", {})
-            if boolean_filter != {} and lucene_filter != {}:
+
+            if boolean_filter != {} and efficient_filter != {}:
                 raise ValueError(
-                    "Both `boolean_filter` and `lucene_filter` are provided which "
+                    "Both `boolean_filter` and `efficient_filter` are provided which "
                     "is invalid"
                 )
+
+            if lucene_filter != {} and efficient_filter != {}:
+                raise ValueError(
+                    "Both `lucene_filter` and `efficient_filter` are provided which "
+                    "is invalid. `lucene_filter` is deprecated"
+                )
+
+            if lucene_filter != {} and boolean_filter != {}:
+                raise ValueError(
+                    "Both `lucene_filter` and `boolean_filter` are provided which "
+                    "is invalid. `lucene_filter` is deprecated"
+                )
+
             if boolean_filter != {}:
                 search_query = _approximate_search_query_with_boolean_filter(
                     embedding,
@@ -511,8 +573,16 @@ class OpenSearchVectorSearch(VectorStore):
                     vector_field=vector_field,
                     subquery_clause=subquery_clause,
                 )
+            elif efficient_filter != {}:
+                search_query = _approximate_search_query_with_efficient_filter(
+                    embedding, efficient_filter, k=k, vector_field=vector_field
+                )
             elif lucene_filter != {}:
-                search_query = _approximate_search_query_with_lucene_filter(
+                warnings.warn(
+                    "`lucene_filter` is deprecated. Please use the keyword argument"
+                    " `efficient_filter`"
+                )
+                search_query = _approximate_search_query_with_efficient_filter(
                     embedding, lucene_filter, k=k, vector_field=vector_field
                 )
             else:
@@ -659,6 +729,7 @@ class OpenSearchVectorSearch(VectorStore):
             "ef_construction",
             "m",
             "max_chunk_bytes",
+            "is_aoss",
         ]
         embeddings = embedding.embed_documents(texts)
         _validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
@@ -672,6 +743,15 @@ class OpenSearchVectorSearch(VectorStore):
         vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
         text_field = _get_kwargs_value(kwargs, "text_field", "text")
         max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024)
+        http_auth = _get_kwargs_value(kwargs, "http_auth", None)
+        is_aoss = _is_aoss_enabled(http_auth=http_auth)
+
+        if is_aoss and not is_appx_search:
+            raise ValueError(
+                "Amazon OpenSearch Service Serverless only "
+                "supports `approximate_search`"
+            )
+
         if is_appx_search:
             engine = _get_kwargs_value(kwargs, "engine", "nmslib")
             space_type = _get_kwargs_value(kwargs, "space_type", "l2")
@@ -679,6 +759,8 @@ class OpenSearchVectorSearch(VectorStore):
             ef_construction = _get_kwargs_value(kwargs, "ef_construction", 512)
             m = _get_kwargs_value(kwargs, "m", 16)

+            _validate_aoss_with_engines(is_aoss, engine)
+
             mapping = _default_text_mapping(
                 dim, engine, space_type, ef_search, ef_construction, m, vector_field
             )
@@ -697,5 +779,6 @@ class OpenSearchVectorSearch(VectorStore):
             text_field=text_field,
             mapping=mapping,
             max_chunk_bytes=max_chunk_bytes,
+            is_aoss=is_aoss,
         )
-        return cls(opensearch_url, index_name, embedding, **kwargs)
+        return cls(opensearch_url, index_name, embedding, is_aoss, **kwargs)
@@ -1,6 +1,8 @@
 """Test OpenSearch functionality."""

+import boto3
 import pytest
+from opensearchpy import AWSV4SignerAuth

 from langchain.docstore.document import Document
 from langchain.vectorstores.opensearch_vector_search import (
@@ -213,3 +215,95 @@ def test_opensearch_with_custom_field_name_appx_false() -> None:
     )
     output = docsearch.similarity_search("add", k=1)
     assert output == [Document(page_content="add")]
+
+
+def test_opensearch_serverless_with_scripting_search_indexing_throws_error() -> None:
+    """Test to validate indexing using Serverless without Approximate Search."""
+    region = "test-region"
+    service = "aoss"
+    credentials = boto3.Session().get_credentials()
+    auth = AWSV4SignerAuth(credentials, region, service)
+    with pytest.raises(ValueError):
+        OpenSearchVectorSearch.from_texts(
+            texts,
+            FakeEmbeddings(),
+            opensearch_url=DEFAULT_OPENSEARCH_URL,
+            is_appx_search=False,
+            http_auth=auth,
+        )
+
+
+def test_opensearch_serverless_with_lucene_engine_throws_error() -> None:
+    """Test to validate indexing using lucene engine with Serverless."""
+    region = "test-region"
+    service = "aoss"
+    credentials = boto3.Session().get_credentials()
+    auth = AWSV4SignerAuth(credentials, region, service)
+    with pytest.raises(ValueError):
+        OpenSearchVectorSearch.from_texts(
+            texts,
+            FakeEmbeddings(),
+            opensearch_url=DEFAULT_OPENSEARCH_URL,
+            engine="lucene",
+            http_auth=auth,
+        )
+
+
+def test_appx_search_with_efficient_and_bool_filter_throws_error() -> None:
+    """Test Approximate Search with Efficient and Bool Filter throws Error."""
+    efficient_filter_val = {"bool": {"must": [{"term": {"text": "baz"}}]}}
+    boolean_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
+    docsearch = OpenSearchVectorSearch.from_texts(
+        texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="lucene"
+    )
+    with pytest.raises(ValueError):
+        docsearch.similarity_search(
+            "foo",
+            k=3,
+            efficient_filter=efficient_filter_val,
+            boolean_filter=boolean_filter_val,
+        )
+
+
+def test_appx_search_with_efficient_and_lucene_filter_throws_error() -> None:
+    """Test Approximate Search with Efficient and Lucene Filter throws Error."""
+    efficient_filter_val = {"bool": {"must": [{"term": {"text": "baz"}}]}}
+    lucene_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
+    docsearch = OpenSearchVectorSearch.from_texts(
+        texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="lucene"
+    )
+    with pytest.raises(ValueError):
+        docsearch.similarity_search(
+            "foo",
+            k=3,
+            efficient_filter=efficient_filter_val,
+            lucene_filter=lucene_filter_val,
+        )
+
+
+def test_appx_search_with_boolean_and_lucene_filter_throws_error() -> None:
+    """Test Approximate Search with Boolean and Lucene Filter throws Error."""
+    boolean_filter_val = {"bool": {"must": [{"term": {"text": "baz"}}]}}
+    lucene_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
+    docsearch = OpenSearchVectorSearch.from_texts(
+        texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="lucene"
+    )
+    with pytest.raises(ValueError):
+        docsearch.similarity_search(
+            "foo",
+            k=3,
+            boolean_filter=boolean_filter_val,
+            lucene_filter=lucene_filter_val,
+        )
+
+
+def test_appx_search_with_faiss_efficient_filter() -> None:
+    """Test Approximate Search with Faiss Efficient Filter."""
+    efficient_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
+    docsearch = OpenSearchVectorSearch.from_texts(
+        texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="faiss"
+    )
+    output = docsearch.similarity_search(
+        "foo", k=3, efficient_filter=efficient_filter_val
+    )
+    assert output == [Document(page_content="bar")]