community[minor]: Added integrations for ThirdAI's NeuralDB as a Retriever (#17334)

**Description:** Adds ThirdAI NeuralDB retriever integration. NeuralDB
is a CPU-friendly and fine-tunable text retrieval engine. We previously
added a vector store integration but we think that it will be easier for
our customers if they can also find us under under
langchain-community/retrievers.

---------

Co-authored-by: kartikTAI <129414343+kartikTAI@users.noreply.github.com>
Co-authored-by: Kartik Sarangmath <kartik@thirdai.com>
This commit is contained in:
Benito Geordie
2024-04-16 18:36:55 -05:00
committed by GitHub
parent e9fc87aab1
commit 57b226532d
8 changed files with 469 additions and 63 deletions

View File

@@ -0,0 +1,58 @@
import os
import shutil
from typing import Generator
import pytest
from langchain_community.retrievers import NeuralDBRetriever
@pytest.fixture(scope="session")
def test_csv() -> Generator[str, None, None]:
csv = "thirdai-test.csv"
with open(csv, "w") as o:
o.write("column_1,column_2\n")
o.write("column one,column two\n")
yield csv
os.remove(csv)
def assert_result_correctness(documents: list) -> None:
assert len(documents) == 1
assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two"
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_scratch(test_csv: str) -> None:
retriever = NeuralDBRetriever.from_scratch()
retriever.insert([test_csv])
documents = retriever.get_relevant_documents("column")
assert_result_correctness(documents)
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_checkpoint(test_csv: str) -> None:
checkpoint = "thirdai-test-save.ndb"
if os.path.exists(checkpoint):
shutil.rmtree(checkpoint)
try:
retriever = NeuralDBRetriever.from_scratch()
retriever.insert([test_csv])
retriever.save(checkpoint)
loaded_retriever = NeuralDBRetriever.from_checkpoint(checkpoint)
documents = loaded_retriever.get_relevant_documents("column")
assert_result_correctness(documents)
finally:
if os.path.exists(checkpoint):
shutil.rmtree(checkpoint)
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv: str) -> None:
retriever = NeuralDBRetriever.from_scratch()
retriever.insert([test_csv])
# Make sure they don't throw an error.
retriever.associate("A", "B")
retriever.associate_batch([("A", "B"), ("C", "D")])
retriever.upvote("A", 0)
retriever.upvote_batch([("A", 0), ("B", 0)])

View File

@@ -46,14 +46,6 @@ def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untype
shutil.rmtree(checkpoint)
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_bazaar(test_csv): # type: ignore[no-untyped-def]
retriever = NeuralDBVectorStore.from_bazaar("General QnA")
retriever.insert([test_csv])
documents = retriever.similarity_search("column")
assert_result_correctness(documents)
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv): # type: ignore[no-untyped-def]
retriever = NeuralDBVectorStore.from_scratch()

View File

@@ -40,6 +40,7 @@ EXPECTED_ALL = [
"ZepRetriever",
"ZillizRetriever",
"DocArrayRetriever",
"NeuralDBRetriever",
]