Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
b03a616cf5 poc 2023-09-19 23:10:16 -07:00

View File

@@ -3,6 +3,7 @@ from typing import List
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, BaseStore, Document
from langchain.schema.runnable import RunnableLambda
from langchain.vectorstores import VectorStore
@@ -18,6 +19,17 @@ class MultiVectorRetriever(BaseRetriever):
search_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass to the search function."""
def _similarity_search(self, query: str) -> List[Document]:
"""Search for similar documents to a query.
Args:
args: A dictionary with the following keys:
query: The query to search for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
return self.vectorstore.similarity_search(query, **self.search_kwargs)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
@@ -28,7 +40,9 @@ class MultiVectorRetriever(BaseRetriever):
Returns:
List of relevant documents
"""
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
sub_docs = RunnableLambda(self._similarity_search).invoke(
query, {"callbacks": run_manager.get_child()}
)
# We do this to maintain the order of the ids that are returned
ids = []
for d in sub_docs: