Add message to documents (#12552)

This adds the response message as a document to the RAG retriever so
users can choose to use it. It also drops the document limit.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
billytrend-cohere 2023-11-09 16:30:48 +01:00 committed by GitHub
parent 5f38770161
commit b346d4a455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 9 deletions

View File

@@ -166,6 +166,16 @@ class ChatCohere(BaseChatModel, BaseCohere):
if run_manager: if run_manager:
await run_manager.on_llm_new_token(delta) await run_manager.on_llm_new_token(delta)
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
"""Get the generation info from cohere API response."""
return {
"documents": response.documents,
"citations": response.citations,
"search_results": response.search_results,
"search_queries": response.search_queries,
"token_count": response.token_count,
}
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: List[BaseMessage],
@@ -185,7 +195,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
message = AIMessage(content=response.text) message = AIMessage(content=response.text)
generation_info = None generation_info = None
if hasattr(response, "documents"): if hasattr(response, "documents"):
generation_info = {"documents": response.documents} generation_info = self._get_generation_info(response)
return ChatResult( return ChatResult(
generations=[ generations=[
ChatGeneration(message=message, generation_info=generation_info) ChatGeneration(message=message, generation_info=generation_info)
@@ -211,7 +221,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
message = AIMessage(content=response.text) message = AIMessage(content=response.text)
generation_info = None generation_info = None
if hasattr(response, "documents"): if hasattr(response, "documents"):
generation_info = {"documents": response.documents} generation_info = self._get_generation_info(response)
return ChatResult( return ChatResult(
generations=[ generations=[
ChatGeneration(message=message, generation_info=generation_info) ChatGeneration(message=message, generation_info=generation_info)

View File

@@ -15,17 +15,27 @@ if TYPE_CHECKING:
def _get_docs(response: Any) -> List[Document]:
    """Convert a Cohere RAG chat generation into a list of Documents.

    Each retrieved document becomes one ``Document`` (snippet as content,
    full record as metadata); the model's own response message is appended
    as a final ``Document`` carrying citation/search metadata so callers
    can use the answer alongside its sources.
    """
    info = response.generation_info
    docs = [
        Document(page_content=doc["snippet"], metadata=doc)
        for doc in info["documents"]
    ]
    # The model response itself, tagged so it is distinguishable from sources.
    answer_metadata = {
        "type": "model_response",
        "citations": info["citations"],
        "search_results": info["search_results"],
        "search_queries": info["search_queries"],
        "token_count": info["token_count"],
    }
    docs.append(
        Document(page_content=response.message.content, metadata=answer_metadata)
    )
    return docs
class CohereRagRetriever(BaseRetriever): class CohereRagRetriever(BaseRetriever):
"""`ChatGPT plugin` retriever.""" """Cohere Chat API with RAG."""
top_k: int = 3
"""Number of documents to return."""
connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}]) connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
""" """
@@ -55,7 +65,7 @@ class CohereRagRetriever(BaseRetriever):
callbacks=run_manager.get_child(), callbacks=run_manager.get_child(),
**kwargs, **kwargs,
).generations[0][0] ).generations[0][0]
return _get_docs(res)[: self.top_k] return _get_docs(res)
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
@@ -73,4 +83,4 @@ class CohereRagRetriever(BaseRetriever):
**kwargs, **kwargs,
) )
).generations[0][0] ).generations[0][0]
return _get_docs(res)[: self.top_k] return _get_docs(res)