From b346d4a455fb3854a5c6960ed9ae77b798e080ba Mon Sep 17 00:00:00 2001
From: billytrend-cohere <144115527+billytrend-cohere@users.noreply.github.com>
Date: Thu, 9 Nov 2023 16:30:48 +0100
Subject: [PATCH] Add message to documents (#12552)

This adds the response message as a document to the rag retriever so
users can choose to use this.

Also drops document limit.

---------

Co-authored-by: Bagatur
---
 .../langchain/langchain/chat_models/cohere.py | 14 +++++++++--
 .../retrievers/cohere_rag_retriever.py        | 24 +++++++++++++------
 2 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/cohere.py b/libs/langchain/langchain/chat_models/cohere.py
index 19eaffc685a..4b09e06affd 100644
--- a/libs/langchain/langchain/chat_models/cohere.py
+++ b/libs/langchain/langchain/chat_models/cohere.py
@@ -166,6 +166,16 @@ class ChatCohere(BaseChatModel, BaseCohere):
                 if run_manager:
                     await run_manager.on_llm_new_token(delta)
 
+    def _get_generation_info(self, response: Any) -> Dict[str, Any]:
+        """Get the generation info from cohere API response."""
+        return {
+            "documents": response.documents,
+            "citations": response.citations,
+            "search_results": response.search_results,
+            "search_queries": response.search_queries,
+            "token_count": response.token_count,
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -185,7 +195,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
         message = AIMessage(content=response.text)
         generation_info = None
         if hasattr(response, "documents"):
-            generation_info = {"documents": response.documents}
+            generation_info = self._get_generation_info(response)
         return ChatResult(
             generations=[
                 ChatGeneration(message=message, generation_info=generation_info)
@@ -211,7 +221,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
         message = AIMessage(content=response.text)
         generation_info = None
         if hasattr(response, "documents"):
-            generation_info = {"documents": response.documents}
+            generation_info = self._get_generation_info(response)
         return ChatResult(
             generations=[
                 ChatGeneration(message=message, generation_info=generation_info)
diff --git a/libs/langchain/langchain/retrievers/cohere_rag_retriever.py b/libs/langchain/langchain/retrievers/cohere_rag_retriever.py
index 1fd88c093ec..9d79adee69f 100644
--- a/libs/langchain/langchain/retrievers/cohere_rag_retriever.py
+++ b/libs/langchain/langchain/retrievers/cohere_rag_retriever.py
@@ -15,17 +15,27 @@ if TYPE_CHECKING:
 
 
 def _get_docs(response: Any) -> List[Document]:
-    return [
+    docs = [
         Document(page_content=doc["snippet"], metadata=doc)
         for doc in response.generation_info["documents"]
     ]
+    docs.append(
+        Document(
+            page_content=response.message.content,
+            metadata={
+                "type": "model_response",
+                "citations": response.generation_info["citations"],
+                "search_results": response.generation_info["search_results"],
+                "search_queries": response.generation_info["search_queries"],
+                "token_count": response.generation_info["token_count"],
+            },
+        )
+    )
+    return docs
 
 
 class CohereRagRetriever(BaseRetriever):
-    """`ChatGPT plugin` retriever."""
-
-    top_k: int = 3
-    """Number of documents to return."""
+    """Cohere Chat API with RAG."""
 
     connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
     """
@@ -55,7 +65,7 @@ class CohereRagRetriever(BaseRetriever):
             callbacks=run_manager.get_child(),
             **kwargs,
         ).generations[0][0]
-        return _get_docs(res)[: self.top_k]
+        return _get_docs(res)
 
     async def _aget_relevant_documents(
         self,
@@ -73,4 +83,4 @@ class CohereRagRetriever(BaseRetriever):
                 **kwargs,
             )
         ).generations[0][0]
-        return _get_docs(res)[: self.top_k]
+        return _get_docs(res)