mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 06:26:12 +00:00
Accept uuids kwargs for weaviate (#4800)
# Accept uuids kwargs for weaviate. Fixes #4791.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Test Weaviate functionality."""
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Generator, Union
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -85,3 +86,28 @@ class TestWeaviateHybridSearchRetriever:
|
||||
assert output == [
|
||||
Document(page_content="foo", metadata={"page": 0}),
|
||||
]
|
||||
|
||||
@pytest.mark.vcr(ignore_localhost=True)
def test_get_relevant_documents_with_uuids(self, weaviate_url: str) -> None:
    """Check that documents added under one shared UUID yield a single result.

    Three documents are added one at a time, all with the same UUID, and the
    retriever is then expected to return exactly one document for the query.
    """
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    # Weaviate replaces the object if the UUID already exists, so every
    # document below is deliberately given the very same identifier.
    shared_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, "same-name")
    uuids = [shared_uuid for _ in texts]

    retriever = WeaviateHybridSearchRetriever(
        client=Client(weaviate_url),
        index_name=f"LangChain_{uuid4().hex}",
        text_key="text",
        attributes=["page"],
    )

    # Add the documents individually so each call reuses the shared UUID.
    for text, metadata, doc_uuid in zip(texts, metadatas, uuids):
        document = Document(page_content=text, metadata=metadata)
        retriever.add_documents([document], uuids=[doc_uuid])

    output = retriever.get_relevant_documents("foo")
    assert len(output) == 1
|
||||
|
@@ -1,6 +1,7 @@
|
||||
"""Test Weaviate functionality."""
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Generator, Union
|
||||
|
||||
import pytest
|
||||
@@ -80,6 +81,26 @@ class TestWeaviate:
|
||||
)
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
|
||||
@pytest.mark.vcr(ignore_localhost=True)
def test_similarity_search_with_uuids(
    self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
) -> None:
    """Test end to end construction and search when every text shares a UUID.

    All texts are passed to ``Weaviate.from_texts`` with the same UUID, and
    the similarity search is then expected to return exactly one document
    even though ``k=2`` is requested.
    """
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    # Weaviate replaces the object if the UUID already exists
    shared_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, "same-name")
    uuids = [shared_uuid for _ in texts]

    docsearch = Weaviate.from_texts(
        texts,
        embedding_openai,
        metadatas=metadatas,
        weaviate_url=weaviate_url,
        uuids=uuids,
    )

    hits = docsearch.similarity_search("foo", k=2)
    assert len(hits) == 1
|
||||
|
||||
@pytest.mark.vcr(ignore_localhost=True)
|
||||
def test_max_marginal_relevance_search(
|
||||
self, weaviate_url: str, embedding_openai: OpenAIEmbeddings
|
||||
@@ -181,3 +202,23 @@ class TestWeaviate:
|
||||
Document(page_content="foo"),
|
||||
Document(page_content="foo"),
|
||||
]
|
||||
|
||||
def test_add_texts_with_given_uuids(self, weaviate_url: str) -> None:
    """Check that re-adding a text under an existing UUID does not duplicate it.

    Each text gets a UUID derived from its own content. "foo" is then added a
    second time under its original UUID, and the two nearest results for the
    "foo" query vector must be "foo" followed by a different document.
    """
    texts = ["foo", "bar", "baz"]
    embedding = FakeEmbeddings()
    uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, content) for content in texts]

    docsearch = Weaviate.from_texts(
        texts,
        embedding=embedding,
        weaviate_url=weaviate_url,
        uuids=uuids,
    )

    # Weaviate replaces the object if the UUID already exists
    docsearch.add_texts(["foo"], uuids=uuids[:1])

    query_vector = embedding.embed_query("foo")
    output = docsearch.similarity_search_by_vector(query_vector, k=2)

    expected_first = Document(page_content="foo")
    assert output[0] == expected_first
    assert output[1] != Document(page_content="foo")
|
||||
|
Reference in New Issue
Block a user