cohere[patch]: Add additional kwargs support for Cohere SDK params (#19533)

* Adds support for `additional_kwargs` in `get_cohere_chat_request`
* This functionality passes Cohere SDK-specific parameters from `BaseMessage`-based classes through to the API

Co-authored-by: Erick Friis <erick@langchain.dev>
commit 9ea2a9b0c1
parent 2763d8cbe5
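
This change adds two ways to supply retrieval documents to the Cohere chat API: binding them on the model, or attaching them to an individual message via `additional_kwargs`. A minimal sketch of both entry points (assuming `langchain-cohere` is installed and a Cohere API key is configured; the document text is illustrative):

from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_cohere import ChatCohere

# Option 1: bind ready-made document dicts on the model.
llm = ChatCohere().bind(documents=[{"text": "The sky is green."}])
print(llm.invoke("What color is the sky?").content)

# Option 2: attach Document objects to a single message via additional_kwargs;
# this commit forwards them to the Cohere SDK as {"text", "id"} dicts.
msg = HumanMessage(
    content="What color is the sky?",
    additional_kwargs={"documents": [Document(page_content="The sky is green.")]},
)
print(ChatCohere().invoke([msg]).content)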
@@ -47,6 +47,7 @@ def get_role(message: BaseMessage) -> str:
 def get_cohere_chat_request(
     messages: List[BaseMessage],
     *,
+    documents: Optional[List[Dict[str, str]]] = None,
     connectors: Optional[List[Dict[str, str]]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
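
With the new `documents` keyword in place, the request builder can be exercised directly. A minimal sketch, assuming `get_cohere_chat_request` is importable from the package's chat-models module (the module path is not shown in this diff):

from langchain_core.messages import HumanMessage
from langchain_cohere.chat_models import get_cohere_chat_request  # assumed path

req = get_cohere_chat_request(
    [HumanMessage(content="What color is the sky?")],
    documents=[{"text": "The sky is green.", "id": "doc-0"}],
)
assert req["documents"] == [{"text": "The sky is green.", "id": "doc-0"}]
assert req["prompt_truncation"] == "AUTO"  # enabled automatically when documents are present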
@@ -60,24 +61,33 @@ def get_cohere_chat_request(
     Returns:
         The request for the Cohere chat API.
     """
-    documents = (
-        None
-        if "source_documents" not in kwargs
-        else [
-            {
-                "snippet": doc.page_content,
-                "id": doc.metadata.get("id") or f"doc-{str(i)}",
-            }
-            for i, doc in enumerate(kwargs["source_documents"])
-        ]
-    )
-    kwargs.pop("source_documents", None)
-    maybe_connectors = connectors if documents is None else None
+    additional_kwargs = messages[-1].additional_kwargs
+
+    # cohere SDK will fail loudly if both connectors and documents are provided
+    if (
+        len(additional_kwargs.get("documents", [])) > 0
+        and documents
+        and len(documents) > 0
+    ):
+        raise ValueError(
+            "Received documents both as a keyword argument and as a prompt "
+            "additional keyword argument. Please choose only one option."
+        )
+
+    formatted_docs = [
+        {
+            "text": doc.page_content,
+            "id": doc.metadata.get("id") or f"doc-{str(i)}",
+        }
+        for i, doc in enumerate(additional_kwargs.get("documents", []))
+    ] or documents
+    if not formatted_docs:
+        formatted_docs = None
+
     # by enabling automatic prompt truncation, the probability of request failure is
     # reduced with minimal impact on response quality
     prompt_truncation = (
-        "AUTO" if documents is not None or connectors is not None else None
+        "AUTO" if formatted_docs is not None or connectors is not None else None
     )

     req = {
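
The guard above means documents may arrive through the keyword argument or through the last message's `additional_kwargs`, but not both; when only `additional_kwargs` carries them, each `Document` is reshaped into a `{"text", "id"}` dict with a fallback id of `doc-<index>`. A short sketch of the failure mode (same assumed import path as above):

from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_cohere.chat_models import get_cohere_chat_request  # assumed path

msg = HumanMessage(
    content="What color is the sky?",
    additional_kwargs={"documents": [Document(page_content="The sky is green.")]},
)
try:
    get_cohere_chat_request([msg], documents=[{"text": "The sky is blue."}])
except ValueError:
    pass  # documents arrived in both places, so the request builder refuses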
@@ -85,8 +95,8 @@ def get_cohere_chat_request(
         "chat_history": [
             {"role": get_role(x), "message": x.content} for x in messages[:-1]
         ],
-        "documents": documents,
-        "connectors": maybe_connectors,
+        "documents": formatted_docs,
+        "connectors": connectors,
         "prompt_truncation": prompt_truncation,
         **kwargs,
     }
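
After this hunk the request passes `connectors` through unchanged rather than gating it behind the old `maybe_connectors`, and any remaining `**kwargs` (for example, parameters set with `.bind()`) spread into the payload. An illustrative, hypothetical payload for a single-message conversation:

req = {
    "message": "What color is the sky?",
    "chat_history": [],
    "documents": [{"text": "The sky is green.", "id": "doc-0"}],
    "connectors": None,
    "prompt_truncation": "AUTO",
}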
@@ -3,7 +3,7 @@ from langchain_cohere import ChatCohere


 def test_stream() -> None:
-    """Test streaming tokens from OpenAI."""
+    """Test streaming tokens from ChatCohere."""
     llm = ChatCohere()

     for token in llm.stream("I'm Pickle Rick"):
@@ -11,7 +11,7 @@ def test_stream() -> None:


 async def test_astream() -> None:
-    """Test streaming tokens from OpenAI."""
+    """Test streaming tokens from ChatCohere."""
     llm = ChatCohere()

     async for token in llm.astream("I'm Pickle Rick"):
libs/partners/cohere/tests/integration_tests/test_rag.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""Test ChatCohere chat model."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages.human import HumanMessage
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.runnables import (
|
||||
RunnablePassthrough,
|
||||
RunnableSerializable,
|
||||
)
|
||||
|
||||
from langchain_cohere import ChatCohere
|
||||
|
||||
|
||||
def test_connectors() -> None:
|
||||
"""Test connectors parameter support from ChatCohere."""
|
||||
llm = ChatCohere().bind(connectors=[{"id": "web-search"}])
|
||||
|
||||
result = llm.invoke("Who directed dune two? reply with just the name.")
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_documents() -> None:
|
||||
"""Test documents paraneter support from ChatCohere."""
|
||||
docs = [{"text": "The sky is green."}]
|
||||
llm = ChatCohere().bind(documents=docs)
|
||||
prompt = "What color is the sky?"
|
||||
|
||||
result = llm.invoke(prompt)
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.response_metadata["documents"]) == 1
|
||||
|
||||
|
||||
def test_documents_chain() -> None:
|
||||
"""Test documents paraneter support from ChatCohere."""
|
||||
llm = ChatCohere()
|
||||
|
||||
def get_documents(_: Any) -> List[Document]:
|
||||
return [Document(page_content="The sky is green.")]
|
||||
|
||||
def format_input_msgs(input: Dict[str, Any]) -> List[HumanMessage]:
|
||||
return [
|
||||
HumanMessage(
|
||||
content=input["message"],
|
||||
additional_kwargs={
|
||||
"documents": input.get("documents", None),
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("input_msgs")])
|
||||
chain: RunnableSerializable[Any, Any] = (
|
||||
{"message": RunnablePassthrough(), "documents": get_documents}
|
||||
| RunnablePassthrough()
|
||||
| {"input_msgs": format_input_msgs}
|
||||
| prompt
|
||||
| llm
|
||||
)
|
||||
|
||||
result = chain.invoke("What color is the sky?")
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.response_metadata["documents"]) == 1
|
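
These are integration tests: running them presumably requires network access and a valid Cohere API key in the environment (conventionally `COHERE_API_KEY`). Note that the assertions only check that the supplied documents surface in `response_metadata["documents"]`, not the wording of the model's reply.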