mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-20 01:54:14 +00:00
cohere[patch]: misc fixs tool use agent and cohere chat (#19705)
Bug fixes in this PR: * allows for other params such as "message" not just the input param to the prompt for the cohere tools agent * fixes to documents kwarg from messages * fixes to tool_calls API call --------- Co-authored-by: Harry M <127103098+harry-cohere@users.noreply.github.com>
This commit is contained in:
237
libs/partners/cohere/docs/cohere_agent.ipynb
Normal file
237
libs/partners/cohere/docs/cohere_agent.ipynb
Normal file
@@ -0,0 +1,237 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"sidebar_position: 0\n",
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Cohere Tools\n",
|
||||
"\n",
|
||||
"The following notebook goes over how to use the Cohere tools agent:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Prerequisites for this notebook:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already satisfied: langchain in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.13)\n",
|
||||
"Requirement already satisfied: langchain-cohere in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.0rc2)\n",
|
||||
"Requirement already satisfied: PyYAML>=5.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (6.0.1)\n",
|
||||
"Requirement already satisfied: SQLAlchemy<3,>=1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.0.27)\n",
|
||||
"Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (3.9.3)\n",
|
||||
"Requirement already satisfied: dataclasses-json<0.7,>=0.5.7 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.6.4)\n",
|
||||
"Requirement already satisfied: jsonpatch<2.0,>=1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.33)\n",
|
||||
"Requirement already satisfied: langchain-community<0.1,>=0.0.29 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.29)\n",
|
||||
"Requirement already satisfied: langchain-core<0.2.0,>=0.1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.35)\n",
|
||||
"Requirement already satisfied: langchain-text-splitters<0.1,>=0.0.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.1)\n",
|
||||
"Requirement already satisfied: langsmith<0.2.0,>=0.1.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.31)\n",
|
||||
"Requirement already satisfied: numpy<2,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.24.4)\n",
|
||||
"Requirement already satisfied: pydantic<3,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.6.4)\n",
|
||||
"Requirement already satisfied: requests<3,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.31.0)\n",
|
||||
"Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (8.2.3)\n",
|
||||
"Requirement already satisfied: cohere<6.0.0,>=5.1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-cohere) (5.1.4)\n",
|
||||
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
|
||||
"Requirement already satisfied: attrs>=17.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.2.0)\n",
|
||||
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.1)\n",
|
||||
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.5)\n",
|
||||
"Requirement already satisfied: yarl<2.0,>=1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.4)\n",
|
||||
"Requirement already satisfied: httpx>=0.21.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (0.27.0)\n",
|
||||
"Requirement already satisfied: typing_extensions>=4.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (4.10.0)\n",
|
||||
"Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (3.20.2)\n",
|
||||
"Requirement already satisfied: typing-inspect<1,>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (0.9.0)\n",
|
||||
"Requirement already satisfied: jsonpointer>=1.9 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from jsonpatch<2.0,>=1.33->langchain) (2.4)\n",
|
||||
"Requirement already satisfied: packaging<24.0,>=23.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.33->langchain) (23.2)\n",
|
||||
"Requirement already satisfied: orjson<4.0.0,>=3.9.14 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langsmith<0.2.0,>=0.1.17->langchain) (3.9.15)\n",
|
||||
"Requirement already satisfied: annotated-types>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (0.6.0)\n",
|
||||
"Requirement already satisfied: pydantic-core==2.16.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (2.16.3)\n",
|
||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.3.2)\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.6)\n",
|
||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2.2.1)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2024.2.2)\n",
|
||||
"Requirement already satisfied: anyio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (4.3.0)\n",
|
||||
"Requirement already satisfied: httpcore==1.* in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.0.4)\n",
|
||||
"Requirement already satisfied: sniffio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.3.1)\n",
|
||||
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (0.14.0)\n",
|
||||
"Requirement already satisfied: mypy-extensions>=0.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain) (1.0.0)\n",
|
||||
"Requirement already satisfied: wikipedia in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (1.4.0)\n",
|
||||
"Requirement already satisfied: beautifulsoup4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (4.12.3)\n",
|
||||
"Requirement already satisfied: requests<3.0.0,>=2.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (2.31.0)\n",
|
||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.3.2)\n",
|
||||
"Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.6)\n",
|
||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2.2.1)\n",
|
||||
"Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2024.2.2)\n",
|
||||
"Requirement already satisfied: soupsieve>1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from beautifulsoup4->wikipedia) (2.5)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# install package\n",
|
||||
"!pip install langchain langchain-cohere\n",
|
||||
"!pip install wikipedia"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import AgentExecutor\n",
|
||||
"from langchain.retrievers import WikipediaRetriever\n",
|
||||
"from langchain.tools.retriever import create_retriever_tool\n",
|
||||
"from langchain_cohere import create_cohere_tools_agent\n",
|
||||
"from langchain_cohere.chat_models import ChatCohere\n",
|
||||
"from langchain_core.prompts import ChatPromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next we create the prompt template and cohere model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create the prompt\n",
|
||||
"prompt = ChatPromptTemplate.from_template(\n",
|
||||
" \"Write all output in capital letters. {input}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create the Cohere chat model\n",
|
||||
"chat = ChatCohere(cohere_api_key=\"API_KEY\", model=\"command-r\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this example we use a Wikipedia retrieval tool "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = WikipediaRetriever()\n",
|
||||
"retriever_tool = create_retriever_tool(\n",
|
||||
" retriever,\n",
|
||||
" \"wikipedia\",\n",
|
||||
" \"Search for information on Wikipedia\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, create the cohere tool agent and call with the input"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mwikipedia\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'input': 'Who founded Cohere?',\n",
|
||||
" 'text': 'COHERE WAS FOUNDED BY AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.',\n",
|
||||
" 'additional_info': {'documents': [{'answer': '',\n",
|
||||
" 'id': 'wikipedia:0:0',\n",
|
||||
" 'tool_name': 'wikipedia'}],\n",
|
||||
" 'citations': [ChatCitation(start=22, end=63, text='AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.', document_ids=['wikipedia:0:0'])],\n",
|
||||
" 'search_results': None,\n",
|
||||
" 'search_queries': None,\n",
|
||||
" 'is_search_required': None,\n",
|
||||
" 'generation_id': '3b7e96be-8aad-4fa0-9ae3-7a38e800c289',\n",
|
||||
" 'token_count': {'prompt_tokens': 740,\n",
|
||||
" 'response_tokens': 27,\n",
|
||||
" 'total_tokens': 767,\n",
|
||||
" 'billed_tokens': 48}}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent = create_cohere_tools_agent(\n",
|
||||
" llm=chat,\n",
|
||||
" tools=[retriever_tool],\n",
|
||||
" prompt=prompt,\n",
|
||||
")\n",
|
||||
"agent_executor = AgentExecutor(agent=agent, tools=[retriever_tool], verbose=True)\n",
|
||||
"agent_executor.invoke({\"input\": \"Who founded Cohere?\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@@ -81,25 +81,22 @@ def get_cohere_chat_request(
|
||||
additional_kwargs = messages[-1].additional_kwargs
|
||||
|
||||
# cohere SDK will fail loudly if both connectors and documents are provided
|
||||
if (
|
||||
len(additional_kwargs.get("documents", [])) > 0
|
||||
and documents
|
||||
and len(documents) > 0
|
||||
):
|
||||
if additional_kwargs.get("documents", []) and documents and len(documents) > 0:
|
||||
raise ValueError(
|
||||
"Received documents both as a keyword argument and as an prompt additional"
|
||||
"keywword argument. Please choose only one option."
|
||||
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
|
||||
)
|
||||
|
||||
formatted_docs: Optional[List[Dict[str, Any]]] = None
|
||||
if additional_kwargs.get("documents"):
|
||||
formatted_docs = [
|
||||
{
|
||||
"text": doc.page_content,
|
||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
||||
}
|
||||
for i, doc in enumerate(additional_kwargs.get("documents", []))
|
||||
] or documents
|
||||
if not formatted_docs:
|
||||
formatted_docs = None
|
||||
]
|
||||
elif documents:
|
||||
formatted_docs = documents
|
||||
|
||||
# by enabling automatic prompt truncation, the probability of request failure is
|
||||
# reduced with minimal impact on response quality
|
||||
|
@@ -1,6 +1,12 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Sequence, Tuple, Type, Union
|
||||
|
||||
from cohere.types import Tool, ToolParameterDefinitionsValue
|
||||
from cohere.types import (
|
||||
ChatRequestToolResultsItem,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolParameterDefinitionsValue,
|
||||
)
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
@@ -30,9 +36,7 @@ def create_cohere_tools_agent(
|
||||
RunnablePassthrough.assign(
|
||||
# Intermediate steps are in tool results.
|
||||
# Edit below to change the prompt parameters.
|
||||
input=lambda x: prompt.format_messages(
|
||||
input=x["input"], agent_scratchpad=[]
|
||||
),
|
||||
input=lambda x: prompt.format_messages(**x, agent_scratchpad=[]),
|
||||
tools=lambda x: _format_to_cohere_tools(tools),
|
||||
tool_results=lambda x: _format_to_cohere_tools_messages(
|
||||
x["intermediate_steps"]
|
||||
@@ -52,20 +56,35 @@ def _format_to_cohere_tools(
|
||||
|
||||
def _format_to_cohere_tools_messages(
|
||||
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
||||
) -> list:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert (AgentAction, tool output) tuples into tool messages."""
|
||||
if len(intermediate_steps) == 0:
|
||||
return []
|
||||
tool_results = []
|
||||
for agent_action, observation in intermediate_steps:
|
||||
# agent_action.tool_input can be a dict, serialised dict, or string.
|
||||
# Cohere API only accepts a dict.
|
||||
tool_call_parameters: Dict[str, Any]
|
||||
if isinstance(agent_action.tool_input, dict):
|
||||
# tool_input is a dict, use as-is.
|
||||
tool_call_parameters = agent_action.tool_input
|
||||
else:
|
||||
try:
|
||||
# tool_input is serialised dict.
|
||||
tool_call_parameters = json.loads(agent_action.tool_input)
|
||||
if not isinstance(tool_call_parameters, dict):
|
||||
raise ValueError()
|
||||
except ValueError:
|
||||
# tool_input is a string, last ditch attempt at having something useful.
|
||||
tool_call_parameters = {"input": agent_action.tool_input}
|
||||
tool_results.append(
|
||||
{
|
||||
"call": {
|
||||
"name": agent_action.tool,
|
||||
"parameters": agent_action.tool_input,
|
||||
},
|
||||
"outputs": [{"answer": observation}],
|
||||
}
|
||||
ChatRequestToolResultsItem(
|
||||
call=ToolCall(
|
||||
name=agent_action.tool,
|
||||
parameters=tool_call_parameters,
|
||||
),
|
||||
outputs=[{"answer": observation}],
|
||||
).dict()
|
||||
)
|
||||
|
||||
return tool_results
|
||||
@@ -143,7 +162,7 @@ class _CohereToolsAgentOutputParser(
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
if not isinstance(result[0], ChatGeneration):
|
||||
raise ValueError(f"Expected ChatGeneration, got {type(result)}")
|
||||
if result[0].message.additional_kwargs["tool_calls"]:
|
||||
if "tool_calls" in result[0].message.additional_kwargs:
|
||||
actions = []
|
||||
for tool in result[0].message.additional_kwargs["tool_calls"]:
|
||||
function = tool.get("function", {})
|
||||
|
@@ -1,9 +1,14 @@
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import pytest
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.tools import BaseModel, BaseTool, Field
|
||||
|
||||
from langchain_cohere.cohere_agent import _format_to_cohere_tools
|
||||
from langchain_cohere.cohere_agent import (
|
||||
_format_to_cohere_tools,
|
||||
_format_to_cohere_tools_messages,
|
||||
)
|
||||
|
||||
expected_test_tool_definition = {
|
||||
"description": "test_tool description",
|
||||
@@ -80,3 +85,56 @@ def test_format_to_cohere_tools(
|
||||
actual = _format_to_cohere_tools([tool])
|
||||
|
||||
assert [expected_test_tool_definition] == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"intermediate_step,expected",
|
||||
[
|
||||
pytest.param(
|
||||
(
|
||||
AgentAction(tool="tool_name", tool_input={"arg1": "value1"}, log=""),
|
||||
"result",
|
||||
),
|
||||
{
|
||||
"call": {"name": "tool_name", "parameters": {"arg1": "value1"}},
|
||||
"outputs": [{"answer": "result"}],
|
||||
},
|
||||
id="tool_input as dict",
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
AgentAction(
|
||||
tool="tool_name", tool_input=json.dumps({"arg1": "value1"}), log=""
|
||||
),
|
||||
"result",
|
||||
),
|
||||
{
|
||||
"call": {"name": "tool_name", "parameters": {"arg1": "value1"}},
|
||||
"outputs": [{"answer": "result"}],
|
||||
},
|
||||
id="tool_input as serialized dict",
|
||||
),
|
||||
pytest.param(
|
||||
(AgentAction(tool="tool_name", tool_input="foo", log=""), "result"),
|
||||
{
|
||||
"call": {"name": "tool_name", "parameters": {"input": "foo"}},
|
||||
"outputs": [{"answer": "result"}],
|
||||
},
|
||||
id="tool_input as string",
|
||||
),
|
||||
pytest.param(
|
||||
(AgentAction(tool="tool_name", tool_input="['foo']", log=""), "result"),
|
||||
{
|
||||
"call": {"name": "tool_name", "parameters": {"input": "['foo']"}},
|
||||
"outputs": [{"answer": "result"}],
|
||||
},
|
||||
id="tool_input unrelated JSON",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_format_to_cohere_tools_messages(
|
||||
intermediate_step: Tuple[AgentAction, str], expected: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
actual = _format_to_cohere_tools_messages(intermediate_steps=[intermediate_step])
|
||||
|
||||
assert [expected] == actual
|
||||
|
Reference in New Issue
Block a user