Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-01 20:49:17 +00:00)

Harrison/opensearch logic (#3631)
Co-authored-by: engineer-matsuo <95115586+engineer-matsuo@users.noreply.github.com>

commit ab749fa1bb (parent cf384dcb7f)
@@ -35,6 +35,15 @@ def _import_bulk() -> Any:
     return bulk
 
 
+def _import_not_found_error() -> Any:
+    """Import not found error if available, otherwise raise error."""
+    try:
+        from opensearchpy.exceptions import NotFoundError
+    except ImportError:
+        raise ValueError(IMPORT_OPENSEARCH_PY_ERROR)
+    return NotFoundError
+
+
 def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
     """Get OpenSearch client from the opensearch_url, otherwise raise error."""
     try:
@@ -67,11 +76,20 @@ def _bulk_ingest_embeddings(
     metadatas: Optional[List[dict]] = None,
     vector_field: str = "vector_field",
     text_field: str = "text",
+    mapping: Dict = {},
 ) -> List[str]:
     """Bulk Ingest Embeddings into given index."""
     bulk = _import_bulk()
+    not_found_error = _import_not_found_error()
     requests = []
     ids = []
+    mapping = mapping
+
+    try:
+        client.indices.get(index=index_name)
+    except not_found_error:
+        client.indices.create(index=index_name, body=mapping)
+
     for i, text in enumerate(texts):
         metadata = metadatas[i] if metadatas else {}
         _id = str(uuid.uuid4())
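
With this hunk, _bulk_ingest_embeddings creates the target index from the supplied mapping whenever the index is missing, instead of assuming the caller created it beforehand. Below is a minimal sketch of calling the changed helpers directly, assuming the module path langchain.vectorstores.opensearch_vector_search, a reachable OpenSearch node at a placeholder URL, and toy 4-dimensional embeddings; none of these specifics come from the commit itself.

# Sketch only -- all concrete values below are illustrative.
from langchain.vectorstores.opensearch_vector_search import (
    _bulk_ingest_embeddings,
    _default_text_mapping,
    _get_opensearch_client,
)

client = _get_opensearch_client("http://localhost:9200")
texts = ["foo", "bar"]
embeddings = [[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]]

# Same defaults the diff reads via _get_kwargs_value: nmslib engine, l2
# space type, ef_search/ef_construction of 512, m of 16, "vector_field".
mapping = _default_text_mapping(4, "nmslib", "l2", 512, 512, 16, "vector_field")

# "test-index" is created from `mapping` if it does not exist yet,
# via the try/except not_found_error block added above.
ids = _bulk_ingest_embeddings(client, "test-index", embeddings, texts, mapping=mapping)
print(ids)  # one generated UUID per ingested text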
@@ -311,8 +329,19 @@ class OpenSearchVectorSearch(VectorStore):
         """
         embeddings = self.embedding_function.embed_documents(list(texts))
         _validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
-        vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
         text_field = _get_kwargs_value(kwargs, "text_field", "text")
+        dim = len(embeddings[0])
+        engine = _get_kwargs_value(kwargs, "engine", "nmslib")
+        space_type = _get_kwargs_value(kwargs, "space_type", "l2")
+        ef_search = _get_kwargs_value(kwargs, "ef_search", 512)
+        ef_construction = _get_kwargs_value(kwargs, "ef_construction", 512)
+        m = _get_kwargs_value(kwargs, "m", 16)
+        vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
+
+        mapping = _default_text_mapping(
+            dim, engine, space_type, ef_search, ef_construction, m, vector_field
+        )
+
         return _bulk_ingest_embeddings(
             self.client,
             self.index_name,
@@ -321,6 +350,7 @@ class OpenSearchVectorSearch(VectorStore):
             metadatas,
             vector_field,
             text_field,
+            mapping,
         )
 
     def similarity_search(
@@ -532,8 +562,14 @@ class OpenSearchVectorSearch(VectorStore):
 
         [kwargs.pop(key, None) for key in keys_list]
         client = _get_opensearch_client(opensearch_url, **kwargs)
-        client.indices.create(index=index_name, body=mapping)
         _bulk_ingest_embeddings(
-            client, index_name, embeddings, texts, metadatas, vector_field, text_field
+            client,
+            index_name,
+            embeddings,
+            texts,
+            metadatas,
+            vector_field,
+            text_field,
+            mapping,
         )
         return cls(opensearch_url, index_name, embedding, **kwargs)
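
Taken together, the from_texts and add_texts paths now build the default k-NN mapping from the same kwargs and create the index on demand. A rough end-to-end sketch from the caller's side, assuming a local OpenSearch node; the URL and the embedding model are placeholders (any Embeddings implementation works, and OpenAIEmbeddings needs credentials), and the tuning values shown are simply the defaults read via _get_kwargs_value.

# Sketch only -- illustrative usage, not prescribed by this commit.
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import OpenSearchVectorSearch

docsearch = OpenSearchVectorSearch.from_texts(
    ["foo", "bar", "baz"],
    OpenAIEmbeddings(),  # stand-in; any Embeddings implementation works
    opensearch_url="http://localhost:9200",
    # These kwargs feed _default_text_mapping; values are the defaults.
    engine="nmslib",
    space_type="l2",
    ef_search=512,
    ef_construction=512,
    m=16,
    vector_field="vector_field",
    text_field="text",
)

# add_texts now derives the same default mapping and creates the index
# if it is missing, so it no longer relies on a pre-created index.
docsearch.add_texts(["qux"])

docs = docsearch.similarity_search("foo", k=1)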
@@ -174,3 +174,26 @@ def test_appx_search_with_lucene_filter() -> None:
     )
     output = docsearch.similarity_search("foo", k=3, lucene_filter=lucene_filter_val)
     assert output == [Document(page_content="bar")]
+
+
+def test_opensearch_with_custom_field_name_appx_true() -> None:
+    """Test Approximate Search with custom field name appx true."""
+    text_input = ["test", "add", "text", "method"]
+    docsearch = OpenSearchVectorSearch.from_texts(
+        text_input,
+        FakeEmbeddings(),
+        opensearch_url=DEFAULT_OPENSEARCH_URL,
+        is_appx_search=True,
+    )
+    output = docsearch.similarity_search("add", k=1)
+    assert output == [Document(page_content="add")]
+
+
+def test_opensearch_with_custom_field_name_appx_false() -> None:
+    """Test Approximate Search with custom field name appx true."""
+    text_input = ["test", "add", "text", "method"]
+    docsearch = OpenSearchVectorSearch.from_texts(
+        text_input, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL
+    )
+    output = docsearch.similarity_search("add", k=1)
+    assert output == [Document(page_content="add")]