Compare commits

...

1 Commit

Author SHA1 Message Date
Bagatur
617a477fbb RFC: rich retrieval results 2024-01-22 12:46:33 -08:00
2 changed files with 111 additions and 19 deletions

View File

@@ -3,10 +3,12 @@ from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, overload
from langchain_core.documents import Document
from langchain_core.load import Serializable
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import (
Runnable,
RunnableConfig,
@@ -22,8 +24,19 @@ if TYPE_CHECKING:
Callbacks,
)
class DocumentResult(Serializable):
    """A single retrieved document paired with metadata about that hit.

    ``VectorStoreRetriever`` fills ``metadata`` with per-document values such
    as ``{"score": ...}`` or ``{"relevance_score": ...}``; when nothing is
    supplied it defaults to an empty dict.
    """

    # The retrieved document itself.
    document: Document
    # Free-form per-document metadata; a fresh dict is created per instance.
    metadata: dict = Field(default_factory=dict)
class RetrievalResult(Serializable):
    """The outcome of one retrieval call.

    Holds the per-document results together with metadata that applies to
    the call as a whole (``VectorStoreRetriever`` records the query, the
    search type and the search kwargs there).
    """

    # One entry per retrieved document, each with its own metadata.
    documents: List[DocumentResult]
    # Free-form call-level metadata; a fresh dict is created per instance.
    metadata: dict = Field(default_factory=dict)
# Retrievers take the query as a plain string.
RetrieverInput = str
# A retriever yields either a bare list of documents or a rich
# RetrievalResult. (The earlier ``List[Document]``-only assignment was
# immediately overwritten by this one, so it is dropped.)
RetrieverOutput = Union[List[Document], RetrievalResult]
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
@@ -39,14 +52,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
class TFIDFRetriever(BaseRetriever, BaseModel):
vectorizer: Any
docs: List[Document]
docs: RetrieverOutput
tfidf_array: Any
k: int = 4
class Config:
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(self, query: str) -> RetrieverOutput:
from sklearn.metrics.pairwise import cosine_similarity
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
@@ -116,7 +129,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
) -> RetrieverOutput:
config = ensure_config(config)
return self.get_relevant_documents(
input,
@@ -131,7 +144,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> List[Document]:
) -> RetrieverOutput:
config = ensure_config(config)
return await self.aget_relevant_documents(
input,
@@ -153,9 +166,21 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
List of relevant documents
"""
def _get_relevant_results(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> RetrievalResult:
"""Get results relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
RetrievalResult
"""
raise NotImplementedError()
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
) -> RetrieverOutput:
"""Asynchronously get documents relevant to a query.
Args:
query: String to find relevant documents for
@@ -170,16 +195,45 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
run_manager=run_manager.get_sync(),
)
@overload
def get_relevant_documents(
    self,
    query: str,
    *,
    docs_only: Literal[True] = True,
    callbacks: Callbacks = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    run_name: Optional[str] = None,
    **kwargs: Any,
) -> List[Document]:
    # Typing-only stub: with docs_only=True (the default) the call is typed
    # as returning a plain list of documents.
    ...
@overload
def get_relevant_documents(
    self,
    query: str,
    *,
    docs_only: Literal[False],
    callbacks: Callbacks = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    run_name: Optional[str] = None,
    **kwargs: Any,
) -> RetrievalResult:
    # Typing-only stub: an explicit docs_only=False is typed as returning
    # the rich RetrievalResult.
    ...
def get_relevant_documents(
self,
query: str,
*,
docs_only: bool = True,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> RetrieverOutput:
"""Retrieve documents relevant to a query.
Args:
query: string to find relevant documents for
@@ -211,19 +265,27 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = self._get_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
if docs_only:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result: RetrieverOutput = self._get_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
else:
result = self._get_relevant_documents(query, **_kwargs)
else:
result = self._get_relevant_documents(query, **_kwargs)
result = self._get_relevant_results(query, **_kwargs)
except Exception as e:
run_manager.on_retriever_error(e)
raise e
else:
callback_result = (
result
if isinstance(result, list)
else [doc_res.document for doc_res in result.documents]
)
run_manager.on_retriever_end(
result,
callback_result,
**kwargs,
)
return result
@@ -237,7 +299,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
) -> RetrieverOutput:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for

View File

@@ -21,7 +21,7 @@ from typing import (
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.retrievers import BaseRetriever, DocumentResult, RetrievalResult
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
@@ -650,22 +650,52 @@ class VectorStoreRetriever(BaseRetriever):
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
result = self._get_relevant_results(query, run_manager=run_manager)
return [doc_result.document for doc_result in result.documents]
def _get_relevant_results(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> RetrievalResult:
result_metadata = {
"query": query,
"search_type": self.search_type,
"search_kwargs": self.search_kwargs,
}
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
docs_and_scores = self.vectorstore.similarity_search_with_score(
query, **self.search_kwargs
)
result = RetrievalResult(
documents=[
DocumentResult(document=doc, metadata={"score": score})
for doc, score in docs_and_scores
],
metadata=result_metadata,
)
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
docs = [doc for doc, _ in docs_and_similarities]
result = RetrievalResult(
documents=[
DocumentResult(document=doc, metadata={"relevance_score": score})
for doc, score in docs_and_similarities
],
metadata=result_metadata,
)
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
query, **self.search_kwargs
)
result = RetrievalResult(
documents=[DocumentResult(document=doc) for doc in docs],
metadata=result_metadata,
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
return result
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun