Mirror of https://github.com/hwchase17/langchain.git
community[minor]: Rememberizer retriever (#20052)
**Description:** This pull request adds an integration with the Rememberizer API through a custom retriever. It lets LangChain applications query data that users have loaded and synced from Dropbox, Google Drive, Slack, or their hard drive into a Rememberizer vector database. A query sends a text chunk generated within LangChain and returns a collection of semantically relevant user data for inclusion in LLM prompts; grounding prompts in user knowledge can dramatically improve AI applications. The integration also gives users access to general-purpose vectorized data such as Reddit channel discussions and US patents.

**Issue:** N/A
**Dependencies:** N/A
**Twitter handle:** https://twitter.com/Rememberizer
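For readers evaluating the PR, here is a minimal usage sketch (not part of the diff). It assumes the community package is installed and that the API key is supplied via the `REMEMBERIZER_API_KEY` environment variable; the query string is illustrative.

```python
import os

from langchain_community.retrievers import RememberizerRetriever

# Assumption: the key comes from the environment; it can also be passed as
# rememberizer_api_key=... when constructing the retriever.
os.environ["REMEMBERIZER_API_KEY"] = "<your-rememberizer-api-key>"

retriever = RememberizerRetriever(top_k_results=5)
docs = retriever.get_relevant_documents("notes from last week's planning meeting")
for doc in docs:
    # page_content is the matched chunk; metadata carries the source document info.
    print(doc.metadata.get("name"), "->", doc.page_content[:80])
```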
libs/community/langchain_community/retrievers/__init__.py

@@ -202,6 +202,7 @@ _module_lookup = {
     "PineconeHybridSearchRetriever": "langchain_community.retrievers.pinecone_hybrid_search",  # noqa: E501
     "PubMedRetriever": "langchain_community.retrievers.pubmed",
     "QdrantSparseVectorRetriever": "langchain_community.retrievers.qdrant_sparse_vector_retriever",  # noqa: E501
+    "RememberizerRetriever": "langchain_community.retrievers.rememberizer",
     "RemoteLangChainRetriever": "langchain_community.retrievers.remote_retriever",
     "SVMRetriever": "langchain_community.retrievers.svm",
     "TFIDFRetriever": "langchain_community.retrievers.tfidf",
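For context, `_module_lookup` backs this package's lazy `__getattr__`, so the single entry added above is what makes the top-level import resolve (the utilities registration later in the diff works the same way). A quick check, assuming the package is installed:

```python
from langchain_community.retrievers import RememberizerRetriever

print(RememberizerRetriever)  # resolved lazily via _module_lookup
```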
libs/community/langchain_community/retrievers/rememberizer.py (new file)

@@ -0,0 +1,20 @@
from typing import List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain_community.utilities.rememberizer import RememberizerAPIWrapper


class RememberizerRetriever(BaseRetriever, RememberizerAPIWrapper):
    """`Rememberizer` retriever.

    It wraps load() as get_relevant_documents().
    It uses all RememberizerAPIWrapper arguments without any change.
    """

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return self.load(query=query)
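To show where the retrieved chunks end up, here is a hedged LCEL sketch that stuffs them into a prompt. `ChatOpenAI`, the prompt text, and the question are assumptions for illustration, not part of this PR.

```python
import os

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI  # assumption: any chat model runnable works here

from langchain_community.retrievers import RememberizerRetriever

os.environ.setdefault("REMEMBERIZER_API_KEY", "<your-rememberizer-api-key>")

retriever = RememberizerRetriever(top_k_results=5)
prompt = ChatPromptTemplate.from_template(
    "Answer using only this context:\n\n{context}\n\nQuestion: {question}"
)


def format_docs(docs):
    # Join the matched chunks into one context block for the prompt.
    return "\n\n".join(doc.page_content for doc in docs)


chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | ChatOpenAI()
    | StrOutputParser()
)
print(chain.invoke("What did we decide about the Q3 roadmap?"))
```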
libs/community/langchain_community/utilities/__init__.py

@@ -265,6 +265,7 @@ _module_lookup = {
     "PowerBIDataset": "langchain_community.utilities.powerbi",
     "PubMedAPIWrapper": "langchain_community.utilities.pubmed",
     "PythonREPL": "langchain_community.utilities.python",
+    "RememberizerAPIWrapper": "langchain_community.utilities.rememberizer",
     "Requests": "langchain_community.utilities.requests",
     "RequestsWrapper": "langchain_community.utilities.requests",
     "RivaASR": "langchain_community.utilities.nvidia_riva",
libs/community/langchain_community/utilities/rememberizer.py (new file)
@@ -0,0 +1,48 @@
"""Wrapper for Rememberizer APIs."""
from typing import Dict, List, Optional

import requests
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env


class RememberizerAPIWrapper(BaseModel):
    """Wrapper for Rememberizer APIs."""

    top_k_results: int = 10
    rememberizer_api_key: Optional[str] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set in the environment."""
        rememberizer_api_key = get_from_dict_or_env(
            values, "rememberizer_api_key", "REMEMBERIZER_API_KEY"
        )
        values["rememberizer_api_key"] = rememberizer_api_key

        return values

    def search(self, query: str) -> List[dict]:
        """Search for a query in the Rememberizer API."""
        url = f"https://api.rememberizer.ai/api/v1/documents/search?q={query}&n={self.top_k_results}"
        response = requests.get(url, headers={"x-api-key": self.rememberizer_api_key})
        data = response.json()

        if response.status_code != 200:
            raise ValueError(f"API Error: {data}")

        matched_chunks = data.get("matched_chunks", [])
        return matched_chunks

    def load(self, query: str) -> List[Document]:
        """Wrap each matched chunk in a Document (content plus source-document metadata)."""
        matched_chunks = self.search(query)
        docs = []
        for matched_chunk in matched_chunks:
            docs.append(
                Document(
                    page_content=matched_chunk["matched_content"],
                    metadata=matched_chunk["document"],
                )
            )
        return docs
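As a quick sanity check of the wrapper on its own (a sketch; the API key value and queries are placeholders): `search()` returns the raw `matched_chunks` payload, while `load()` converts each chunk into a `Document`.

```python
from langchain_community.utilities import RememberizerAPIWrapper

wrapper = RememberizerAPIWrapper(
    rememberizer_api_key="<your-rememberizer-api-key>",  # or set REMEMBERIZER_API_KEY
    top_k_results=3,
)

chunks = wrapper.search("quarterly planning notes")  # raw API chunks (dicts)
docs = wrapper.load("quarterly planning notes")      # LangChain Documents
print(docs[0].page_content)
print(docs[0].metadata)  # the Rememberizer source-document record
```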
libs/community/tests/unit_tests/retrievers/test_imports.py

@@ -29,6 +29,7 @@ EXPECTED_ALL = [
     "PubMedRetriever",
     "QdrantSparseVectorRetriever",
     "RemoteLangChainRetriever",
+    "RememberizerRetriever",
     "SVMRetriever",
     "TavilySearchAPIRetriever",
     "TFIDFRetriever",
libs/community/tests/unit_tests/utilities/test_imports.py

@@ -42,6 +42,7 @@ EXPECTED_ALL = [
     "PythonREPL",
     "Requests",
     "RequestsWrapper",
+    "RememberizerAPIWrapper",
     "SQLDatabase",
     "SceneXplainAPIWrapper",
     "SearchApiAPIWrapper",
libs/community/tests/unit_tests/utilities/test_rememberizer.py (new file)

@@ -0,0 +1,75 @@
import unittest
from typing import Any
from unittest.mock import patch

import responses

from langchain_community.utilities import RememberizerAPIWrapper


class TestRememberizerAPIWrapper(unittest.TestCase):
    @responses.activate
    def test_search_successful(self) -> None:
        responses.add(
            responses.GET,
            "https://api.rememberizer.ai/api/v1/documents/search?q=test&n=10",
            json={
                "matched_chunks": [
                    {
                        "chunk_id": "chunk",
                        "matched_content": "content",
                        "document": {"id": "id", "name": "name"},
                    }
                ]
            },
        )
        wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", top_k_results=10)
        result = wrapper.search("test")
        self.assertEqual(
            result,
            [
                {
                    "chunk_id": "chunk",
                    "matched_content": "content",
                    "document": {"id": "id", "name": "name"},
                }
            ],
        )

    @responses.activate
    def test_search_fail(self) -> None:
        responses.add(
            responses.GET,
            "https://api.rememberizer.ai/api/v1/documents/search?q=test&n=10",
            status=400,
            json={"detail": "Incorrect authentication credentials."},
        )
        wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", top_k_results=10)
        with self.assertRaises(ValueError) as e:
            wrapper.search("test")
        self.assertEqual(
            str(e.exception),
            "API Error: {'detail': 'Incorrect authentication credentials.'}",
        )

    @patch("langchain_community.utilities.rememberizer.RememberizerAPIWrapper.search")
    def test_load(self, mock_search: Any) -> None:
        mock_search.return_value = [
            {
                "chunk_id": "chunk1",
                "matched_content": "content1",
                "document": {"id": "id1", "name": "name1"},
            },
            {
                "chunk_id": "chunk2",
                "matched_content": "content2",
                "document": {"id": "id2", "name": "name2"},
            },
        ]
        wrapper = RememberizerAPIWrapper(rememberizer_api_key="dummy_key", top_k_results=10)
        result = wrapper.load("test")
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].page_content, "content1")
        self.assertEqual(result[0].metadata, {"id": "id1", "name": "name1"})
        self.assertEqual(result[1].page_content, "content2")
        self.assertEqual(result[1].metadata, {"id": "id2", "name": "name2"})
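The same `responses`-based stubbing could exercise the retriever end to end; a sketch (the test name and assertions are illustrative and not part of this PR):

```python
import responses

from langchain_community.retrievers import RememberizerRetriever


@responses.activate
def test_retriever_returns_documents() -> None:
    responses.add(
        responses.GET,
        "https://api.rememberizer.ai/api/v1/documents/search?q=test&n=10",
        json={
            "matched_chunks": [
                {
                    "chunk_id": "chunk",
                    "matched_content": "content",
                    "document": {"id": "id", "name": "name"},
                }
            ]
        },
    )
    retriever = RememberizerRetriever(rememberizer_api_key="dummy_key")
    docs = retriever.get_relevant_documents("test")
    assert docs[0].page_content == "content"
    assert docs[0].metadata == {"id": "id", "name": "name"}
```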