mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 21:47:12 +00:00
Add Cohere retrieval augmented generation to retrievers (#11483)
Add Cohere retrieval augmented generation to retrievers --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
committed by
GitHub
parent
0a24ac7388
commit
f4742dce50
@@ -32,6 +32,44 @@ def get_role(message: BaseMessage) -> str:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
def get_cohere_chat_request(
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
connectors: Optional[List[Dict[str, str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
documents = (
|
||||
None
|
||||
if "source_documents" not in kwargs
|
||||
else [
|
||||
{
|
||||
"snippet": doc.page_content,
|
||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
||||
}
|
||||
for i, doc in enumerate(kwargs["source_documents"])
|
||||
]
|
||||
)
|
||||
kwargs.pop("source_documents", None)
|
||||
maybe_connectors = connectors if documents is None else None
|
||||
|
||||
# by enabling automatic prompt truncation, the probability of request failure is
|
||||
# reduced with minimal impact on response quality
|
||||
prompt_truncation = (
|
||||
"AUTO" if documents is not None or connectors is not None else None
|
||||
)
|
||||
|
||||
return {
|
||||
"message": messages[0].content,
|
||||
"chat_history": [
|
||||
{"role": get_role(x), "message": x.content} for x in messages[1:]
|
||||
],
|
||||
"documents": documents,
|
||||
"connectors": maybe_connectors,
|
||||
"prompt_truncation": prompt_truncation,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
class ChatCohere(BaseChatModel, BaseCohere):
|
||||
"""`Cohere` chat large language models.
|
||||
|
||||
@@ -73,18 +111,6 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model": self.model}, **self._default_params}
|
||||
|
||||
def get_cohere_chat_request(
|
||||
self, messages: List[BaseMessage], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"message": messages[0].content,
|
||||
"chat_history": [
|
||||
{"role": get_role(x), "message": x.content} for x in messages[1:]
|
||||
],
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
@@ -92,7 +118,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
request = self.get_cohere_chat_request(messages, **kwargs)
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
stream = self.client.chat(**request, stream=True)
|
||||
|
||||
for data in stream:
|
||||
@@ -109,7 +135,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
request = self.get_cohere_chat_request(messages, **kwargs)
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
stream = await self.async_client.chat(**request, stream=True)
|
||||
|
||||
async for data in stream:
|
||||
@@ -132,11 +158,18 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
)
|
||||
return _generate_from_stream(stream_iter)
|
||||
|
||||
request = self.get_cohere_chat_request(messages, **kwargs)
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
response = self.client.chat(**request)
|
||||
|
||||
message = AIMessage(content=response.text)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = {"documents": response.documents}
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
]
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
@@ -151,11 +184,18 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
)
|
||||
return await _agenerate_from_stream(stream_iter)
|
||||
|
||||
request = self.get_cohere_chat_request(messages, **kwargs)
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
response = self.client.chat(**request, stream=False)
|
||||
|
||||
message = AIMessage(content=response.text)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = {"documents": response.documents}
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
]
|
||||
)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate number of tokens."""
|
||||
|
@@ -24,6 +24,7 @@ from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetr
|
||||
from langchain.retrievers.bm25 import BM25Retriever
|
||||
from langchain.retrievers.chaindesk import ChaindeskRetriever
|
||||
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||
from langchain.retrievers.cohere_rag_retriever import CohereRagRetriever
|
||||
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
||||
from langchain.retrievers.docarray import DocArrayRetriever
|
||||
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
|
||||
@@ -77,6 +78,7 @@ __all__ = [
|
||||
"ChatGPTPluginRetriever",
|
||||
"ContextualCompressionRetriever",
|
||||
"ChaindeskRetriever",
|
||||
"CohereRagRetriever",
|
||||
"ElasticSearchBM25Retriever",
|
||||
"GoogleDocumentAIWarehouseRetriever",
|
||||
"GoogleCloudEnterpriseSearchRetriever",
|
||||
|
76
libs/langchain/langchain/retrievers/cohere_rag_retriever.py
Normal file
76
libs/langchain/langchain/retrievers/cohere_rag_retriever.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.schema import BaseRetriever, Document, HumanMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
|
||||
def _get_docs(response: Any) -> List[Document]:
|
||||
return [
|
||||
Document(page_content=doc["snippet"], metadata=doc)
|
||||
for doc in response.generation_info["documents"]
|
||||
]
|
||||
|
||||
|
||||
class CohereRagRetriever(BaseRetriever):
|
||||
"""`ChatGPT plugin` retriever."""
|
||||
|
||||
top_k: int = 3
|
||||
"""Number of documents to return."""
|
||||
|
||||
connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
|
||||
"""
|
||||
When specified, the model's reply will be enriched with information found by
|
||||
querying each of the connectors (RAG). These will be returned as langchain
|
||||
documents.
|
||||
|
||||
Currently only accepts {"id": "web-search"}.
|
||||
"""
|
||||
|
||||
llm: BaseChatModel
|
||||
"""Cohere ChatModel to use."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
"""Allow arbitrary types."""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
|
||||
res = self.llm.generate(
|
||||
messages,
|
||||
connectors=self.connectors,
|
||||
callbacks=run_manager.get_child(),
|
||||
**kwargs,
|
||||
).generations[0][0]
|
||||
return _get_docs(res)[: self.top_k]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
|
||||
res = (
|
||||
await self.llm.agenerate(
|
||||
messages,
|
||||
connectors=self.connectors,
|
||||
callbacks=run_manager.get_child(),
|
||||
**kwargs,
|
||||
)
|
||||
).generations[0][0]
|
||||
return _get_docs(res)[: self.top_k]
|
Reference in New Issue
Block a user