mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 00:48:45 +00:00
cohere[patch]: Add additional kwargs support for Cohere SDK params (#19533)
* Adds support for `additional_kwargs` in `get_cohere_chat_request` * This functionality passes in Cohere SDK specific parameters from `BaseMessage` based classes to the API --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
2763d8cbe5
commit
9ea2a9b0c1
@ -47,6 +47,7 @@ def get_role(message: BaseMessage) -> str:
|
|||||||
def get_cohere_chat_request(
|
def get_cohere_chat_request(
|
||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
*,
|
*,
|
||||||
|
documents: Optional[List[Dict[str, str]]] = None,
|
||||||
connectors: Optional[List[Dict[str, str]]] = None,
|
connectors: Optional[List[Dict[str, str]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@ -60,24 +61,33 @@ def get_cohere_chat_request(
|
|||||||
Returns:
|
Returns:
|
||||||
The request for the Cohere chat API.
|
The request for the Cohere chat API.
|
||||||
"""
|
"""
|
||||||
documents = (
|
additional_kwargs = messages[-1].additional_kwargs
|
||||||
None
|
|
||||||
if "source_documents" not in kwargs
|
# cohere SDK will fail loudly if both connectors and documents are provided
|
||||||
else [
|
if (
|
||||||
{
|
len(additional_kwargs.get("documents", [])) > 0
|
||||||
"snippet": doc.page_content,
|
and documents
|
||||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
and len(documents) > 0
|
||||||
}
|
):
|
||||||
for i, doc in enumerate(kwargs["source_documents"])
|
raise ValueError(
|
||||||
]
|
"Received documents both as a keyword argument and as an prompt additional"
|
||||||
)
|
"keywword argument. Please choose only one option."
|
||||||
kwargs.pop("source_documents", None)
|
)
|
||||||
maybe_connectors = connectors if documents is None else None
|
|
||||||
|
formatted_docs = [
|
||||||
|
{
|
||||||
|
"text": doc.page_content,
|
||||||
|
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
||||||
|
}
|
||||||
|
for i, doc in enumerate(additional_kwargs.get("documents", []))
|
||||||
|
] or documents
|
||||||
|
if not formatted_docs:
|
||||||
|
formatted_docs = None
|
||||||
|
|
||||||
# by enabling automatic prompt truncation, the probability of request failure is
|
# by enabling automatic prompt truncation, the probability of request failure is
|
||||||
# reduced with minimal impact on response quality
|
# reduced with minimal impact on response quality
|
||||||
prompt_truncation = (
|
prompt_truncation = (
|
||||||
"AUTO" if documents is not None or connectors is not None else None
|
"AUTO" if formatted_docs is not None or connectors is not None else None
|
||||||
)
|
)
|
||||||
|
|
||||||
req = {
|
req = {
|
||||||
@ -85,8 +95,8 @@ def get_cohere_chat_request(
|
|||||||
"chat_history": [
|
"chat_history": [
|
||||||
{"role": get_role(x), "message": x.content} for x in messages[:-1]
|
{"role": get_role(x), "message": x.content} for x in messages[:-1]
|
||||||
],
|
],
|
||||||
"documents": documents,
|
"documents": formatted_docs,
|
||||||
"connectors": maybe_connectors,
|
"connectors": connectors,
|
||||||
"prompt_truncation": prompt_truncation,
|
"prompt_truncation": prompt_truncation,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ from langchain_cohere import ChatCohere
|
|||||||
|
|
||||||
|
|
||||||
def test_stream() -> None:
|
def test_stream() -> None:
|
||||||
"""Test streaming tokens from OpenAI."""
|
"""Test streaming tokens from ChatCohere."""
|
||||||
llm = ChatCohere()
|
llm = ChatCohere()
|
||||||
|
|
||||||
for token in llm.stream("I'm Pickle Rick"):
|
for token in llm.stream("I'm Pickle Rick"):
|
||||||
@ -11,7 +11,7 @@ def test_stream() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_astream() -> None:
|
async def test_astream() -> None:
|
||||||
"""Test streaming tokens from OpenAI."""
|
"""Test streaming tokens from ChatCohere."""
|
||||||
llm = ChatCohere()
|
llm = ChatCohere()
|
||||||
|
|
||||||
async for token in llm.astream("I'm Pickle Rick"):
|
async for token in llm.astream("I'm Pickle Rick"):
|
||||||
|
63
libs/partners/cohere/tests/integration_tests/test_rag.py
Normal file
63
libs/partners/cohere/tests/integration_tests/test_rag.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
"""Test ChatCohere chat model."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.messages.human import HumanMessage
|
||||||
|
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
from langchain_core.runnables import (
|
||||||
|
RunnablePassthrough,
|
||||||
|
RunnableSerializable,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain_cohere import ChatCohere
|
||||||
|
|
||||||
|
|
||||||
|
def test_connectors() -> None:
|
||||||
|
"""Test connectors parameter support from ChatCohere."""
|
||||||
|
llm = ChatCohere().bind(connectors=[{"id": "web-search"}])
|
||||||
|
|
||||||
|
result = llm.invoke("Who directed dune two? reply with just the name.")
|
||||||
|
assert isinstance(result.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_documents() -> None:
|
||||||
|
"""Test documents paraneter support from ChatCohere."""
|
||||||
|
docs = [{"text": "The sky is green."}]
|
||||||
|
llm = ChatCohere().bind(documents=docs)
|
||||||
|
prompt = "What color is the sky?"
|
||||||
|
|
||||||
|
result = llm.invoke(prompt)
|
||||||
|
assert isinstance(result.content, str)
|
||||||
|
assert len(result.response_metadata["documents"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_documents_chain() -> None:
|
||||||
|
"""Test documents paraneter support from ChatCohere."""
|
||||||
|
llm = ChatCohere()
|
||||||
|
|
||||||
|
def get_documents(_: Any) -> List[Document]:
|
||||||
|
return [Document(page_content="The sky is green.")]
|
||||||
|
|
||||||
|
def format_input_msgs(input: Dict[str, Any]) -> List[HumanMessage]:
|
||||||
|
return [
|
||||||
|
HumanMessage(
|
||||||
|
content=input["message"],
|
||||||
|
additional_kwargs={
|
||||||
|
"documents": input.get("documents", None),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("input_msgs")])
|
||||||
|
chain: RunnableSerializable[Any, Any] = (
|
||||||
|
{"message": RunnablePassthrough(), "documents": get_documents}
|
||||||
|
| RunnablePassthrough()
|
||||||
|
| {"input_msgs": format_input_msgs}
|
||||||
|
| prompt
|
||||||
|
| llm
|
||||||
|
)
|
||||||
|
|
||||||
|
result = chain.invoke("What color is the sky?")
|
||||||
|
assert isinstance(result.content, str)
|
||||||
|
assert len(result.response_metadata["documents"]) == 1
|
Loading…
Reference in New Issue
Block a user