community[minor]: Added integrations for ThirdAI's NeuralDB as a Retriever (#17334)

**Description:** Adds ThirdAI NeuralDB retriever integration. NeuralDB
is a CPU-friendly and fine-tunable text retrieval engine. We previously
added a vector store integration, but we think that it will be easier for
our customers if they can also find us under
langchain-community/retrievers.

---------

Co-authored-by: kartikTAI <129414343+kartikTAI@users.noreply.github.com>
Co-authored-by: Kartik Sarangmath <kartik@thirdai.com>
This commit is contained in:
Benito Geordie
2024-04-16 18:36:55 -05:00
committed by GitHub
parent e9fc87aab1
commit 57b226532d
8 changed files with 469 additions and 63 deletions

View File

@@ -212,6 +212,7 @@ _module_lookup = {
"YouRetriever": "langchain_community.retrievers.you",
"ZepRetriever": "langchain_community.retrievers.zep",
"ZillizRetriever": "langchain_community.retrievers.zilliz",
"NeuralDBRetriever": "langchain_community.retrievers.thirdai_neuraldb",
}

View File

@@ -0,0 +1,260 @@
from __future__ import annotations
import importlib
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
class NeuralDBRetriever(BaseRetriever):
    """Document retriever that uses ThirdAI's NeuralDB."""

    thirdai_key: SecretStr
    """ThirdAI API Key"""

    db: Any = None  #: :meta private:
    """NeuralDB instance"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        underscore_attrs_are_private = True

    @staticmethod
    def _verify_thirdai_library(thirdai_key: Optional[str] = None) -> None:
        """Check that ``thirdai`` and its NeuralDB extra are importable, then
        activate the license.

        Args:
            thirdai_key: ThirdAI API key. When falsy, the ``THIRDAI_KEY``
                environment variable is passed to ``licensing.activate``.

        Raises:
            ModuleNotFoundError: If ``thirdai`` or ``thirdai.neural_db`` cannot
                be found.
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule has been loaded, so import it explicitly.
        import importlib.util

        try:
            from thirdai import licensing

            # ``find_spec`` returns None (it does not raise) when the parent
            # package exists but the submodule is missing.
            if importlib.util.find_spec("thirdai.neural_db") is None:
                raise ImportError("thirdai.neural_db could not be found")

            licensing.activate(thirdai_key or os.getenv("THIRDAI_KEY"))
        except ImportError as e:
            raise ModuleNotFoundError(
                "Could not import thirdai python package and neuraldb dependencies. "
                "Please install it with `pip install thirdai[neural_db]`."
            ) from e

    @classmethod
    def from_scratch(
        cls,
        thirdai_key: Optional[str] = None,
        **model_kwargs: Any,
    ) -> NeuralDBRetriever:
        """
        Create a NeuralDBRetriever from scratch.

        To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
        API key, or pass ``thirdai_key`` as a named parameter.

        Args:
            thirdai_key: ThirdAI API key; falls back to ``THIRDAI_KEY``.
            **model_kwargs: Forwarded unchanged to the ``NeuralDB`` constructor.

        Example:
            .. code-block:: python

                from langchain_community.retrievers import NeuralDBRetriever

                retriever = NeuralDBRetriever.from_scratch(
                    thirdai_key="your-thirdai-key",
                )

                retriever.insert([
                    "/path/to/doc.pdf",
                    "/path/to/doc.docx",
                    "/path/to/doc.csv",
                ])

                documents = retriever.get_relevant_documents("AI-driven music therapy")
        """
        cls._verify_thirdai_library(thirdai_key)
        from thirdai import neural_db as ndb

        return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB(**model_kwargs))

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: Union[str, Path],
        thirdai_key: Optional[str] = None,
    ) -> NeuralDBRetriever:
        """
        Create a NeuralDBRetriever with a base model from a saved checkpoint

        To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
        API key, or pass ``thirdai_key`` as a named parameter.

        Args:
            checkpoint: Path of the saved NeuralDB checkpoint to load.
            thirdai_key: ThirdAI API key; falls back to ``THIRDAI_KEY``.

        Example:
            .. code-block:: python

                from langchain_community.retrievers import NeuralDBRetriever

                retriever = NeuralDBRetriever.from_checkpoint(
                    checkpoint="/path/to/checkpoint.ndb",
                    thirdai_key="your-thirdai-key",
                )

                retriever.insert([
                    "/path/to/doc.pdf",
                    "/path/to/doc.docx",
                    "/path/to/doc.csv",
                ])

                documents = retriever.get_relevant_documents("AI-driven music therapy")
        """
        cls._verify_thirdai_library(thirdai_key)
        from thirdai import neural_db as ndb

        return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB.from_checkpoint(checkpoint))

    # pre=True: this must run before field validation. ``thirdai_key`` is a
    # required, non-optional field, so a post-validator would only run after
    # pydantic had already recorded an error for a missing/None key, and the
    # ``THIRDAI_KEY`` fallback below could never take effect.
    @root_validator(pre=True)
    def validate_environments(cls, values: Dict) -> Dict:
        """Validate ThirdAI environment variables.

        Resolves ``thirdai_key`` from the supplied values or, failing that,
        from the ``THIRDAI_KEY`` environment variable, and stores it as a
        secret string.
        """
        values["thirdai_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "thirdai_key",
                "THIRDAI_KEY",
            )
        )
        return values

    def insert(
        self,
        sources: List[Any],
        train: bool = True,
        fast_mode: bool = True,
        **kwargs: Any,
    ) -> None:
        """Inserts files / document sources into the retriever.

        Args:
            sources: String paths to PDF, DOCX or CSV files, and/or NeuralDB
                document objects. String paths are converted to document
                objects before insertion.
            train: When True this means that the underlying model in the
                NeuralDB will undergo unsupervised pretraining on the inserted
                files. Defaults to True.
            fast_mode: Much faster insertion with a slight drop in performance.
                Defaults to True.
            **kwargs: Forwarded unchanged to ``NeuralDB.insert``.
        """
        sources = self._preprocess_sources(sources)
        self.db.insert(
            sources=sources,
            train=train,
            fast_approximation=fast_mode,
            **kwargs,
        )

    def _preprocess_sources(self, sources: list) -> list:
        """Checks if the provided sources are string paths. If they are, convert
        to NeuralDB document objects.

        Args:
            sources: list of either string paths to PDF, DOCX or CSV files, or
                NeuralDB document objects.

        Returns:
            A list in which every string path has been replaced by the
            matching NeuralDB document object; non-string entries are kept
            as they are. An empty input is returned unchanged.

        Raises:
            RuntimeError: If a string path has an unsupported extension.
        """
        from thirdai import neural_db as ndb

        if not sources:
            return sources

        # Extension (matched case-insensitively) -> NeuralDB document class.
        loaders = {
            ".pdf": ndb.PDF,
            ".docx": ndb.DOCX,
            ".csv": ndb.CSV,
        }

        preprocessed_sources = []
        for doc in sources:
            if not isinstance(doc, str):
                # Already a document object (or anything else the caller
                # chose to pass); hand it to NeuralDB untouched.
                preprocessed_sources.append(doc)
                continue

            lowered = doc.lower()
            loader = next(
                (cls_ for ext, cls_ in loaders.items() if lowered.endswith(ext)),
                None,
            )
            if loader is None:
                raise RuntimeError(
                    f"Could not automatically load {doc}. Only files "
                    "with .pdf, .docx, or .csv extensions can be loaded "
                    "automatically. For other formats, please use the "
                    "appropriate document object from the ThirdAI library."
                )
            preprocessed_sources.append(loader(doc))
        return preprocessed_sources

    def upvote(self, query: str, document_id: int) -> None:
        """The retriever upweights the score of a document for a specific query.
        This is useful for fine-tuning the retriever to user behavior.

        Args:
            query: text to associate with `document_id`
            document_id: id of the document to associate query with.
        """
        self.db.text_to_result(query, document_id)

    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]) -> None:
        """Given a batch of (query, document id) pairs, the retriever upweights
        the scores of the document for the corresponding queries.
        This is useful for fine-tuning the retriever to user behavior.

        Args:
            query_id_pairs: list of (query, document id) pairs. For each pair in
                this list, the model will upweight the document id for the query.
        """
        self.db.text_to_result_batch(query_id_pairs)

    def associate(self, source: str, target: str) -> None:
        """The retriever associates a source phrase with a target phrase.
        When the retriever sees the source phrase, it will also consider results
        that are relevant to the target phrase.

        Args:
            source: text to associate to `target`.
            target: text to associate `source` to.
        """
        self.db.associate(source, target)

    def associate_batch(self, text_pairs: List[Tuple[str, str]]) -> None:
        """Given a batch of (source, target) pairs, the retriever associates
        each source phrase with the corresponding target phrase.

        Args:
            text_pairs: list of (source, target) text pairs. For each pair in
                this list, the source will be associated with the target.
        """
        self.db.associate_batch(text_pairs)

    def _get_relevant_documents(
        self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Retrieve {top_k} contexts with your retriever for a given query

        Args:
            query: Query to submit to the model
            top_k: The max number of context results to retrieve. Defaults to 10.

        Returns:
            One ``Document`` per NeuralDB reference; the reference text becomes
            ``page_content`` and its other attributes go into ``metadata``.

        Raises:
            ValueError: If the search or the conversion of results fails; the
                original exception is attached as the cause.
        """
        try:
            kwargs.setdefault("top_k", 10)
            references = self.db.search(query=query, **kwargs)
            return [
                Document(
                    page_content=ref.text,
                    metadata={
                        "id": ref.id,
                        "upvote_ids": ref.upvote_ids,
                        "source": ref.source,
                        "metadata": ref.metadata,
                        "score": ref.score,
                        # NOTE(review): the meaning of the argument 1 is not
                        # visible here - confirm against the NeuralDB docs.
                        "context": ref.context(1),
                    },
                )
                for ref in references
            ]
        except Exception as e:
            raise ValueError(f"Error while retrieving documents: {e}") from e

    def save(self, path: str) -> None:
        """Saves a NeuralDB instance to disk. Can be loaded into memory by
        calling NeuralDB.from_checkpoint(path)

        Args:
            path: path on disk to save the NeuralDB instance to.
        """
        self.db.save(path)

View File

@@ -86,48 +86,6 @@ class NeuralDBVectorStore(VectorStore):
return cls(db=ndb.NeuralDB(**model_kwargs)) # type: ignore[call-arg]
@classmethod
def from_bazaar(  # type: ignore[no-untyped-def]
    cls,
    base: str,
    bazaar_cache: Optional[str] = None,
    thirdai_key: Optional[str] = None,
):
    """
    Create a NeuralDBVectorStore with a base model from the ThirdAI
    model bazaar.

    To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
    API key, or pass ``thirdai_key`` as a named parameter.

    Args:
        base: Name of the bazaar model to load, e.g. ``"General QnA"``.
        bazaar_cache: Directory used as the bazaar cache. Defaults to a
            ``model_bazaar`` directory under the current working directory.
            It is created (including missing parents) if it does not exist.
        thirdai_key: ThirdAI API key, forwarded to the library check.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import NeuralDBVectorStore

            vectorstore = NeuralDBVectorStore.from_bazaar(
                base="General QnA",
                thirdai_key="your-thirdai-key",
            )

            vectorstore.insert([
                "/path/to/doc.pdf",
                "/path/to/doc.docx",
                "/path/to/doc.csv",
            ])

            documents = vectorstore.similarity_search("AI-driven music therapy")
    """
    NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
    from thirdai import neural_db as ndb

    cache = bazaar_cache or str(Path(os.getcwd()) / "model_bazaar")
    # exist_ok avoids the check-then-create race of exists()+mkdir(), and
    # makedirs also handles a cache path whose parents do not exist yet.
    os.makedirs(cache, exist_ok=True)
    model_bazaar = ndb.Bazaar(cache)
    model_bazaar.fetch()
    return cls(db=model_bazaar.get_model(base))  # type: ignore[call-arg]
@classmethod
def from_checkpoint( # type: ignore[no-untyped-def]
cls,

View File

@@ -0,0 +1,58 @@
import os
import shutil
from typing import Generator
import pytest
from langchain_community.retrievers import NeuralDBRetriever
@pytest.fixture(scope="session")
def test_csv() -> Generator[str, None, None]:
    """Write a one-row, two-column CSV for the test session and yield its path.

    The file is created in the current working directory and deleted once
    the session is done with the fixture.
    """
    csv_path = "thirdai-test.csv"
    lines = ("column_1,column_2\n", "column one,column two\n")
    with open(csv_path, "w") as handle:
        handle.writelines(lines)
    yield csv_path
    os.remove(csv_path)
def assert_result_correctness(documents: list) -> None:
    """Assert that exactly one document came back and that its content is
    the rendered row of the test CSV."""
    expected_content = "column_1: column one\n\ncolumn_2: column two"
    assert len(documents) == 1
    first = documents[0]
    assert first.page_content == expected_content
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_scratch(test_csv: str) -> None:
    """A freshly built retriever returns the inserted CSV row for a query."""
    fresh = NeuralDBRetriever.from_scratch()
    fresh.insert([test_csv])
    assert_result_correctness(fresh.get_relevant_documents("column"))
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_checkpoint(test_csv: str) -> None:
    """A retriever saved to disk and reloaded still returns the inserted row."""
    ckpt_dir = "thirdai-test-save.ndb"

    # Start from a clean slate in case an earlier run left a checkpoint behind.
    if os.path.exists(ckpt_dir):
        shutil.rmtree(ckpt_dir)

    try:
        original = NeuralDBRetriever.from_scratch()
        original.insert([test_csv])
        original.save(ckpt_dir)

        restored = NeuralDBRetriever.from_checkpoint(ckpt_dir)
        assert_result_correctness(restored.get_relevant_documents("column"))
    finally:
        # Remove the checkpoint whether or not the assertions passed.
        if os.path.exists(ckpt_dir):
            shutil.rmtree(ckpt_dir)
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv: str) -> None:
    """Smoke-test the feedback methods: each call only needs to not raise."""
    smoke = NeuralDBRetriever.from_scratch()
    smoke.insert([test_csv])

    # Make sure they don't throw an error.
    smoke.associate("A", "B")
    smoke.associate_batch([("A", "B"), ("C", "D")])
    smoke.upvote("A", 0)
    smoke.upvote_batch([("A", 0), ("B", 0)])

View File

@@ -46,14 +46,6 @@ def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untype
shutil.rmtree(checkpoint)
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_bazaar(test_csv):  # type: ignore[no-untyped-def]
    """A vector store built from the "General QnA" bazaar model returns the
    inserted CSV row for a similarity search."""
    store = NeuralDBVectorStore.from_bazaar("General QnA")
    store.insert([test_csv])
    assert_result_correctness(store.similarity_search("column"))
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv): # type: ignore[no-untyped-def]
retriever = NeuralDBVectorStore.from_scratch()

View File

@@ -40,6 +40,7 @@ EXPECTED_ALL = [
"ZepRetriever",
"ZillizRetriever",
"DocArrayRetriever",
"NeuralDBRetriever",
]