mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
Add message to documents (#12552)
This adds the response message as a document to the rag retriever so users can choose to use this. Also drops document limit. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
5f38770161
commit
b346d4a455
@ -166,6 +166,16 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(delta)
|
||||
|
||||
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
|
||||
"""Get the generation info from cohere API response."""
|
||||
return {
|
||||
"documents": response.documents,
|
||||
"citations": response.citations,
|
||||
"search_results": response.search_results,
|
||||
"search_queries": response.search_queries,
|
||||
"token_count": response.token_count,
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
@ -185,7 +195,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
message = AIMessage(content=response.text)
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = {"documents": response.documents}
|
||||
generation_info = self._get_generation_info(response)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
@ -211,7 +221,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
message = AIMessage(content=response.text)
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = {"documents": response.documents}
|
||||
generation_info = self._get_generation_info(response)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
|
@ -15,17 +15,27 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def _get_docs(response: Any) -> List[Document]:
|
||||
return [
|
||||
docs = [
|
||||
Document(page_content=doc["snippet"], metadata=doc)
|
||||
for doc in response.generation_info["documents"]
|
||||
]
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=response.message.content,
|
||||
metadata={
|
||||
"type": "model_response",
|
||||
"citations": response.generation_info["citations"],
|
||||
"search_results": response.generation_info["search_results"],
|
||||
"search_queries": response.generation_info["search_queries"],
|
||||
"token_count": response.generation_info["token_count"],
|
||||
},
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
class CohereRagRetriever(BaseRetriever):
|
||||
"""`ChatGPT plugin` retriever."""
|
||||
|
||||
top_k: int = 3
|
||||
"""Number of documents to return."""
|
||||
"""Cohere Chat API with RAG."""
|
||||
|
||||
connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
|
||||
"""
|
||||
@ -55,7 +65,7 @@ class CohereRagRetriever(BaseRetriever):
|
||||
callbacks=run_manager.get_child(),
|
||||
**kwargs,
|
||||
).generations[0][0]
|
||||
return _get_docs(res)[: self.top_k]
|
||||
return _get_docs(res)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
@ -73,4 +83,4 @@ class CohereRagRetriever(BaseRetriever):
|
||||
**kwargs,
|
||||
)
|
||||
).generations[0][0]
|
||||
return _get_docs(res)[: self.top_k]
|
||||
return _get_docs(res)
|
||||
|
Loading…
Reference in New Issue
Block a user