From c262cef1fbdecd3e805bba04a85018989edc5fb3 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Thu, 25 Apr 2024 17:32:32 -0400 Subject: [PATCH] update SelfQueryRetriever --- libs/core/langchain_core/load/mapping.py | 12 +++++ .../langchain/retrievers/self_query/base.py | 50 +++++++++++++++---- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index 417d54e35e5..190dc9decb6 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -157,6 +157,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { "base", "Document", ), + ("langchain", "schema", "document_search_hit", "DocumentSearchHit"): ( + "langchain_core", + "documents", + "base", + "DocumentSearchHit", + ), ("langchain", "output_parsers", "fix", "OutputFixingParser"): ( "langchain", "output_parsers", @@ -666,6 +672,12 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { "base", "Document", ), + ("langchain_core", "documents", "base", "DocumentSearchHit"): ( + "langchain_core", + "documents", + "base", + "DocumentSearchHit", + ), ("langchain_core", "prompts", "chat", "AIMessagePromptTemplate"): ( "langchain_core", "prompts", diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 9d1e79eb61c..ba8dd964f1b 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -33,7 +33,7 @@ from langchain_core.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain_core.documents import Document +from langchain_core.documents import Document, DocumentSearchHit from langchain_core.language_models import BaseLanguageModel from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever @@ -192,19 +192,43 @@ class SelfQueryRetriever(BaseRetriever): return new_query, search_kwargs def _get_docs_with_query( - self, query: str, search_kwargs: Dict[str, Any] + self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False ) -> List[Document]: - docs = self.vectorstore.search(query, self.search_type, **search_kwargs) + if include_score: + docs_and_scores = self.vectorstore.similarity_search_with_score( + query, **search_kwargs + ) + return [ + DocumentSearchHit(page_content=doc.page_content, score=score) + for doc, score in docs_and_scores + ] + else: + docs = self.vectorstore.search(query, self.search_type, **search_kwargs) return docs async def _aget_docs_with_query( - self, query: str, search_kwargs: Dict[str, Any] + self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False ) -> List[Document]: - docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs) + if include_score: + docs_and_scores = await self.vectorstore.asimilarity_search_with_score( + query, **search_kwargs + ) + return [ + DocumentSearchHit(page_content=doc.page_content, score=score) + for doc, score in docs_and_scores + ] + else: + docs = await self.vectorstore.asearch( + query, self.search_type, **search_kwargs + ) return docs def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + include_score: bool = False, ) -> List[Document]: """Get documents relevant for a query. @@ -220,11 +244,17 @@ class SelfQueryRetriever(BaseRetriever): if self.verbose: logger.info(f"Generated Query: {structured_query}") new_query, search_kwargs = self._prepare_query(query, structured_query) - docs = self._get_docs_with_query(new_query, search_kwargs) + docs = self._get_docs_with_query( + new_query, search_kwargs, include_score=include_score + ) return docs async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, + include_score: bool = False, ) -> List[Document]: """Get documents relevant for a query. @@ -240,7 +270,9 @@ class SelfQueryRetriever(BaseRetriever): if self.verbose: logger.info(f"Generated Query: {structured_query}") new_query, search_kwargs = self._prepare_query(query, structured_query) - docs = await self._aget_docs_with_query(new_query, search_kwargs) + docs = await self._aget_docs_with_query( + new_query, search_kwargs, include_score=include_score + ) return docs @classmethod