mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 14:31:55 +00:00
community[minor]: Added integrations for ThirdAI's NeuralDB as a Retriever (#17334)
**Description:** Adds ThirdAI NeuralDB retriever integration. NeuralDB is a CPU-friendly and fine-tunable text retrieval engine. We previously added a vector store integration, but we think that it will be easier for our customers if they can also find us under langchain-community/retrievers. --------- Co-authored-by: kartikTAI <129414343+kartikTAI@users.noreply.github.com> Co-authored-by: Kartik Sarangmath <kartik@thirdai.com>
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
import os
|
||||
import shutil
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.retrievers import NeuralDBRetriever
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def test_csv() -> Generator[str, None, None]:
    """Provide a small CSV file (one header row, one data row) to the tests.

    The file is written into the current working directory once per test
    session and removed again when the fixture is torn down.

    Yields:
        The relative path of the generated CSV file.
    """
    csv_path = "thirdai-test.csv"
    rows = (
        "column_1,column_2\n",
        "column one,column two\n",
    )
    with open(csv_path, "w") as handle:
        handle.writelines(rows)
    yield csv_path
    # Teardown: delete the file created above.
    os.remove(csv_path)
|
||||
|
||||
|
||||
def assert_result_correctness(documents: list) -> None:
    """Assert that a query returned exactly the single expected document.

    Args:
        documents: Query results; each item must expose ``page_content``.

    Raises:
        AssertionError: If the result count is not one, or the content of the
            only result differs from the rendered CSV row.
    """
    expected_content = "column_1: column one\n\ncolumn_2: column two"
    assert len(documents) == 1
    first_document = documents[0]
    assert first_document.page_content == expected_content
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_scratch(test_csv: str) -> None:
    """Create a retriever via ``from_scratch``, insert the CSV, and query it."""
    scratch_retriever = NeuralDBRetriever.from_scratch()
    scratch_retriever.insert([test_csv])
    assert_result_correctness(scratch_retriever.get_relevant_documents("column"))
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_checkpoint(test_csv: str) -> None:
    """Save a populated retriever, reload it from the checkpoint, and query it.

    Any checkpoint left at the target path is removed before the test body
    runs, and the path is cleaned up again afterwards even if an assertion
    or library call fails.
    """
    checkpoint = "thirdai-test-save.ndb"

    def _discard_checkpoint() -> None:
        # Only remove the tree when something is actually there.
        if os.path.exists(checkpoint):
            shutil.rmtree(checkpoint)

    _discard_checkpoint()
    try:
        original = NeuralDBRetriever.from_scratch()
        original.insert([test_csv])
        original.save(checkpoint)

        restored = NeuralDBRetriever.from_checkpoint(checkpoint)
        assert_result_correctness(restored.get_relevant_documents("column"))
    finally:
        _discard_checkpoint()
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv: str) -> None:
    """Call the associate/upvote methods and require only that none raises."""
    db = NeuralDBRetriever.from_scratch()
    db.insert([test_csv])

    association_pairs = [("A", "B"), ("C", "D")]
    upvote_pairs = [("A", 0), ("B", 0)]

    # No result assertions here: completing without an exception is the check.
    db.associate("A", "B")
    db.associate_batch(association_pairs)
    db.upvote("A", 0)
    db.upvote_batch(upvote_pairs)
|
@@ -46,14 +46,6 @@ def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untype
|
||||
shutil.rmtree(checkpoint)
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_bazaar(test_csv):  # type: ignore[no-untyped-def]
    """Obtain a store via ``from_bazaar("General QnA")``, insert the CSV, search it."""
    bazaar_store = NeuralDBVectorStore.from_bazaar("General QnA")
    bazaar_store.insert([test_csv])
    assert_result_correctness(bazaar_store.similarity_search("column"))
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
|
||||
def test_neuraldb_retriever_other_methods(test_csv): # type: ignore[no-untyped-def]
|
||||
retriever = NeuralDBVectorStore.from_scratch()
|
||||
|
@@ -40,6 +40,7 @@ EXPECTED_ALL = [
|
||||
"ZepRetriever",
|
||||
"ZillizRetriever",
|
||||
"DocArrayRetriever",
|
||||
"NeuralDBRetriever",
|
||||
]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user