Add Cohere retrieval augmented generation to retrievers (#11483)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Author: billytrend-cohere
Date: 2023-10-17 19:51:04 +01:00
Committed by: GitHub
Parent: 0a24ac7388
Commit: f4742dce50
4 changed files with 359 additions and 18 deletions

File: langchain/chat_models/cohere.py

@@ -32,6 +32,44 @@ def get_role(message: BaseMessage) -> str:
     raise ValueError(f"Got unknown type {message}")


+def get_cohere_chat_request(
+    messages: List[BaseMessage],
+    *,
+    connectors: Optional[List[Dict[str, str]]] = None,
+    **kwargs: Any,
+) -> Dict[str, Any]:
+    documents = (
+        None
+        if "source_documents" not in kwargs
+        else [
+            {
+                "snippet": doc.page_content,
+                "id": doc.metadata.get("id") or f"doc-{str(i)}",
+            }
+            for i, doc in enumerate(kwargs["source_documents"])
+        ]
+    )
+    kwargs.pop("source_documents", None)
+    maybe_connectors = connectors if documents is None else None
+
+    # by enabling automatic prompt truncation, the probability of request failure is
+    # reduced with minimal impact on response quality
+    prompt_truncation = (
+        "AUTO" if documents is not None or connectors is not None else None
+    )
+
+    return {
+        "message": messages[0].content,
+        "chat_history": [
+            {"role": get_role(x), "message": x.content} for x in messages[1:]
+        ],
+        "documents": documents,
+        "connectors": maybe_connectors,
+        "prompt_truncation": prompt_truncation,
+        **kwargs,
+    }
+
+
 class ChatCohere(BaseChatModel, BaseCohere):
     """`Cohere` chat large language models.
@@ -73,18 +111,6 @@ class ChatCohere(BaseChatModel, BaseCohere):
         """Get the identifying parameters."""
         return {**{"model": self.model}, **self._default_params}

-    def get_cohere_chat_request(
-        self, messages: List[BaseMessage], **kwargs: Any
-    ) -> Dict[str, Any]:
-        return {
-            "message": messages[0].content,
-            "chat_history": [
-                {"role": get_role(x), "message": x.content} for x in messages[1:]
-            ],
-            **self._default_params,
-            **kwargs,
-        }
-
     def _stream(
         self,
         messages: List[BaseMessage],
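
The former instance method is replaced by the module-level get_cohere_chat_request helper above. As a minimal sketch of what that helper returns, assuming it is imported from langchain.chat_models.cohere (where this diff defines it), with made-up documents and query:

    from langchain.chat_models.cohere import get_cohere_chat_request
    from langchain.schema import Document, HumanMessage

    # Hypothetical retrieved documents standing in for real context.
    docs = [
        Document(page_content="Ottawa is the capital of Canada.", metadata={"id": "wiki-1"}),
        Document(page_content="Canada has ten provinces."),  # no id -> falls back to "doc-1"
    ]

    request = get_cohere_chat_request(
        [HumanMessage(content="What is the capital of Canada?")],
        source_documents=docs,
    )
    assert request["documents"][1]["id"] == "doc-1"
    assert request["connectors"] is None  # inline documents suppress connectors
    assert request["prompt_truncation"] == "AUTO"

Note that a request carries either inline documents or connector IDs, never both, and prompt truncation is enabled whenever either is present.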
@@ -92,7 +118,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        request = self.get_cohere_chat_request(messages, **kwargs)
+        request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
         stream = self.client.chat(**request, stream=True)

         for data in stream:
@@ -109,7 +135,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        request = self.get_cohere_chat_request(messages, **kwargs)
+        request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
        stream = await self.async_client.chat(**request, stream=True)

         async for data in stream:
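
Both _stream and _astream now route through the shared helper, so _default_params (model, temperature, and so on) are folded into every request. A minimal streaming sketch, assuming a valid COHERE_API_KEY in the environment and an illustrative model name:

    from langchain.chat_models import ChatCohere
    from langchain.schema import HumanMessage

    chat = ChatCohere(model="command", temperature=0)
    # Chunks arrive incrementally as the Cohere API streams tokens back.
    for chunk in chat.stream([HumanMessage(content="Say hello")]):
        print(chunk.content, end="", flush=True)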
@@ -132,11 +158,18 @@ class ChatCohere(BaseChatModel, BaseCohere):
             )
             return _generate_from_stream(stream_iter)

-        request = self.get_cohere_chat_request(messages, **kwargs)
+        request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
         response = self.client.chat(**request)
         message = AIMessage(content=response.text)
-        return ChatResult(generations=[ChatGeneration(message=message)])
+        generation_info = None
+        if hasattr(response, "documents"):
+            generation_info = {"documents": response.documents}
+        return ChatResult(
+            generations=[
+                ChatGeneration(message=message, generation_info=generation_info)
+            ]
+        )

     async def _agenerate(
         self,
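
With this change, any documents the Cohere API returns (for example, from a connector) surface in generation_info instead of being dropped; the new retriever below depends on exactly this. A sketch of reading them back, assuming a valid COHERE_API_KEY and that web-search connector documents carry "id" and "snippet" fields:

    from langchain.chat_models import ChatCohere
    from langchain.schema import HumanMessage

    chat = ChatCohere()
    result = chat.generate(
        [[HumanMessage(content="What is retrieval augmented generation?")]],
        connectors=[{"id": "web-search"}],  # forwarded to get_cohere_chat_request
    )
    generation = result.generations[0][0]
    if generation.generation_info is not None:
        for doc in generation.generation_info["documents"]:
            print(doc["id"], doc["snippet"][:80])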
@@ -151,11 +184,18 @@ class ChatCohere(BaseChatModel, BaseCohere):
             )
             return await _agenerate_from_stream(stream_iter)

-        request = self.get_cohere_chat_request(messages, **kwargs)
+        request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
         response = self.client.chat(**request, stream=False)
         message = AIMessage(content=response.text)
-        return ChatResult(generations=[ChatGeneration(message=message)])
+        generation_info = None
+        if hasattr(response, "documents"):
+            generation_info = {"documents": response.documents}
+        return ChatResult(
+            generations=[
+                ChatGeneration(message=message, generation_info=generation_info)
+            ]
+        )

     def get_num_tokens(self, text: str) -> int:
         """Calculate number of tokens."""

File: langchain/retrievers/__init__.py

@@ -24,6 +24,7 @@ from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetr
 from langchain.retrievers.bm25 import BM25Retriever
 from langchain.retrievers.chaindesk import ChaindeskRetriever
 from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
+from langchain.retrievers.cohere_rag_retriever import CohereRagRetriever
 from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
 from langchain.retrievers.docarray import DocArrayRetriever
 from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
@@ -77,6 +78,7 @@ __all__ = [
     "ChatGPTPluginRetriever",
     "ContextualCompressionRetriever",
     "ChaindeskRetriever",
+    "CohereRagRetriever",
     "ElasticSearchBM25Retriever",
     "GoogleDocumentAIWarehouseRetriever",
     "GoogleCloudEnterpriseSearchRetriever",
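
The re-export means both import paths resolve to the same class; a quick check:

    from langchain.retrievers import CohereRagRetriever
    from langchain.retrievers.cohere_rag_retriever import CohereRagRetriever as _Direct

    assert CohereRagRetriever is _Direct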

File: langchain/retrievers/cohere_rag_retriever.py (new file)

@@ -0,0 +1,76 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, Document, HumanMessage
if TYPE_CHECKING:
from langchain.schema.messages import BaseMessage
def _get_docs(response: Any) -> List[Document]:
return [
Document(page_content=doc["snippet"], metadata=doc)
for doc in response.generation_info["documents"]
]
class CohereRagRetriever(BaseRetriever):
"""`ChatGPT plugin` retriever."""
top_k: int = 3
"""Number of documents to return."""
connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
"""
When specified, the model's reply will be enriched with information found by
querying each of the connectors (RAG). These will be returned as langchain
documents.
Currently only accepts {"id": "web-search"}.
"""
llm: BaseChatModel
"""Cohere ChatModel to use."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
"""Allow arbitrary types."""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = self.llm.generate(
messages,
connectors=self.connectors,
callbacks=run_manager.get_child(),
**kwargs,
).generations[0][0]
return _get_docs(res)[: self.top_k]
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = (
await self.llm.agenerate(
messages,
connectors=self.connectors,
callbacks=run_manager.get_child(),
**kwargs,
)
).generations[0][0]
return _get_docs(res)[: self.top_k]
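
An end-to-end sketch of the new retriever, sync and async, assuming a valid COHERE_API_KEY; the default web-search connector is used, and the "title" metadata key is only an assumption about what that connector returns:

    import asyncio

    from langchain.chat_models import ChatCohere
    from langchain.retrievers import CohereRagRetriever

    retriever = CohereRagRetriever(llm=ChatCohere())

    # Sync: each Document's page_content is the snippet; the raw payload is metadata.
    for doc in retriever.get_relevant_documents("What does Cohere's RAG API do?"):
        print(doc.metadata.get("title"), "->", doc.page_content[:80])

    # Async: same behavior via _aget_relevant_documents.
    async def main() -> None:
        docs = await retriever.aget_relevant_documents("What does Cohere's RAG API do?")
        print(f"{len(docs)} documents (capped at top_k={retriever.top_k})")

    asyncio.run(main())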