cohere[patch]: Add additional kwargs support for Cohere SDK params (#19533)

* Adds support for `additional_kwargs` in `get_cohere_chat_request`
* Passes Cohere SDK-specific parameters from `BaseMessage`-based classes
  through to the API (see the sketch below)
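
For example, a minimal sketch of the new path (not part of this commit; assumes a valid `COHERE_API_KEY` in the environment):

```python
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage

from langchain_cohere import ChatCohere

llm = ChatCohere()

# Documents attached to a message's additional_kwargs are now forwarded to
# the Cohere chat API as grounding documents.
msg = HumanMessage(
    content="What color is the sky?",
    additional_kwargs={"documents": [Document(page_content="The sky is green.")]},
)
result = llm.invoke([msg])
print(result.content)
```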

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Author: Giannis
Date: 2024-03-26 14:30:37 -04:00 (committed by GitHub)
Parent: 2763d8cbe5
Commit: 9ea2a9b0c1
3 changed files with 91 additions and 18 deletions


@@ -47,6 +47,7 @@ def get_role(message: BaseMessage) -> str:
 def get_cohere_chat_request(
     messages: List[BaseMessage],
     *,
+    documents: Optional[List[Dict[str, str]]] = None,
     connectors: Optional[List[Dict[str, str]]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
@@ -60,24 +61,33 @@ def get_cohere_chat_request(
     Returns:
         The request for the Cohere chat API.
     """
-    documents = (
-        None
-        if "source_documents" not in kwargs
-        else [
-            {
-                "snippet": doc.page_content,
-                "id": doc.metadata.get("id") or f"doc-{str(i)}",
-            }
-            for i, doc in enumerate(kwargs["source_documents"])
-        ]
-    )
-    kwargs.pop("source_documents", None)
-    maybe_connectors = connectors if documents is None else None
+    additional_kwargs = messages[-1].additional_kwargs
+
+    # cohere SDK will fail loudly if both connectors and documents are provided
+    if (
+        len(additional_kwargs.get("documents", [])) > 0
+        and documents
+        and len(documents) > 0
+    ):
+        raise ValueError(
+            "Received documents both as a keyword argument and in the prompt's "
+            "additional keyword arguments. Please choose only one option."
+        )
+
+    formatted_docs = [
+        {
+            "text": doc.page_content,
+            "id": doc.metadata.get("id") or f"doc-{str(i)}",
+        }
+        for i, doc in enumerate(additional_kwargs.get("documents", []))
+    ] or documents
+    if not formatted_docs:
+        formatted_docs = None

     # by enabling automatic prompt truncation, the probability of request failure is
     # reduced with minimal impact on response quality
     prompt_truncation = (
-        "AUTO" if documents is not None or connectors is not None else None
+        "AUTO" if formatted_docs is not None or connectors is not None else None
     )

     req = {
@@ -85,8 +95,8 @@ def get_cohere_chat_request(
         "chat_history": [
             {"role": get_role(x), "message": x.content} for x in messages[:-1]
         ],
-        "documents": documents,
-        "connectors": maybe_connectors,
+        "documents": formatted_docs,
+        "connectors": connectors,
         "prompt_truncation": prompt_truncation,
         **kwargs,
     }
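
For reviewers, a minimal sketch of the resulting behavior (a hypothetical snippet inferred from the diff above, assuming `get_cohere_chat_request` is importable from `langchain_cohere.chat_models`):

```python
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage

from langchain_cohere.chat_models import get_cohere_chat_request  # assumed module path

# Documents riding on the last message's additional_kwargs are reformatted
# into Cohere's {"text": ..., "id": ...} shape.
msg = HumanMessage(
    content="What color is the sky?",
    additional_kwargs={"documents": [Document(page_content="The sky is green.")]},
)
request = get_cohere_chat_request([msg])
assert request["documents"] == [{"text": "The sky is green.", "id": "doc-0"}]
assert request["prompt_truncation"] == "AUTO"

# Supplying documents through both paths now raises the new ValueError.
try:
    get_cohere_chat_request([msg], documents=[{"text": "The sky is blue."}])
except ValueError:
    pass
```

Binding `documents=...` on the model, as the new integration tests below do, feeds the same `documents` keyword argument.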


@@ -3,7 +3,7 @@ from langchain_cohere import ChatCohere


 def test_stream() -> None:
-    """Test streaming tokens from OpenAI."""
+    """Test streaming tokens from ChatCohere."""
     llm = ChatCohere()

     for token in llm.stream("I'm Pickle Rick"):
@@ -11,7 +11,7 @@ def test_stream() -> None:


 async def test_astream() -> None:
-    """Test streaming tokens from OpenAI."""
+    """Test streaming tokens from ChatCohere."""
     llm = ChatCohere()

     async for token in llm.astream("I'm Pickle Rick"):


@@ -0,0 +1,63 @@
"""Test ChatCohere chat model."""
from typing import Any, Dict, List

from langchain_core.documents import Document
from langchain_core.messages.human import HumanMessage
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
    RunnablePassthrough,
    RunnableSerializable,
)

from langchain_cohere import ChatCohere


def test_connectors() -> None:
    """Test connectors parameter support from ChatCohere."""
    llm = ChatCohere().bind(connectors=[{"id": "web-search"}])

    result = llm.invoke("Who directed dune two? reply with just the name.")
    assert isinstance(result.content, str)


def test_documents() -> None:
    """Test documents parameter support from ChatCohere."""
    docs = [{"text": "The sky is green."}]
    llm = ChatCohere().bind(documents=docs)
    prompt = "What color is the sky?"

    result = llm.invoke(prompt)
    assert isinstance(result.content, str)
    assert len(result.response_metadata["documents"]) == 1


def test_documents_chain() -> None:
    """Test documents parameter support within a runnable chain."""
    llm = ChatCohere()

    def get_documents(_: Any) -> List[Document]:
        return [Document(page_content="The sky is green.")]

    def format_input_msgs(input: Dict[str, Any]) -> List[HumanMessage]:
        return [
            HumanMessage(
                content=input["message"],
                additional_kwargs={
                    "documents": input.get("documents", None),
                },
            )
        ]

    prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("input_msgs")])
    chain: RunnableSerializable[Any, Any] = (
        {"message": RunnablePassthrough(), "documents": get_documents}
        | RunnablePassthrough()
        | {"input_msgs": format_input_msgs}
        | prompt
        | llm
    )

    result = chain.invoke("What color is the sky?")
    assert isinstance(result.content, str)
    assert len(result.response_metadata["documents"]) == 1