From 1827bb4042eab6d22321f7cbcc44c5e0df3c9602 Mon Sep 17 00:00:00 2001
From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com>
Date: Thu, 1 Aug 2024 23:43:07 +0800
Subject: [PATCH] community[patch]: support bind_tools for ChatMlflow (#24547)
Thank you for contributing to LangChain!
- [x] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
- Example: "community: add foobar LLM"
- **Description:**
Support ChatMlflow.bind_tools method
Tested in Databricks:
- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.
- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/
Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.
If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
---------
Signed-off-by: Serena Ruan
---
docs/docs/integrations/chat/databricks.ipynb | 62 ++-
docs/docs/integrations/tools/databricks.ipynb | 6 +-
libs/community/extended_testing_deps.txt | 1 +
.../langchain_community/chat_models/mlflow.py | 217 ++++++++-
.../unit_tests/chat_models/test_mlflow.py | 423 ++++++++++++++++++
5 files changed, 689 insertions(+), 20 deletions(-)
create mode 100644 libs/community/tests/unit_tests/chat_models/test_mlflow.py
diff --git a/docs/docs/integrations/chat/databricks.ipynb b/docs/docs/integrations/chat/databricks.ipynb
index 1e6e325f928..f0612c90d1a 100644
--- a/docs/docs/integrations/chat/databricks.ipynb
+++ b/docs/docs/integrations/chat/databricks.ipynb
@@ -36,7 +36,7 @@
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
- "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
+ "| ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
"\n",
"### Supported Methods\n",
"\n",
@@ -395,6 +395,66 @@
"chat_model_external.invoke(\"How to use Databricks?\")"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Function calling on Databricks"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Databricks Function Calling is OpenAI-compatible and is only available during model serving as part of Foundation Model APIs.\n",
+ "\n",
+ "See [Databricks function calling introduction](https://docs.databricks.com/en/machine-learning/model-serving/function-calling.html#supported-models) for supported models."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain_community.chat_models.databricks import ChatDatabricks\n",
+ "\n",
+ "llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")\n",
+ "tools = [\n",
+ " {\n",
+ " \"type\": \"function\",\n",
+ " \"function\": {\n",
+ " \"name\": \"get_current_weather\",\n",
+ " \"description\": \"Get the current weather in a given location\",\n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"location\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
+ " },\n",
+ " \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ " }\n",
+ "]\n",
+ "\n",
+ "# supported tool_choice values: \"auto\", \"required\", \"none\", function name in string format,\n",
+ "# or a dictionary as {\"type\": \"function\", \"function\": {\"name\": <>}}\n",
+ "model = llm.bind_tools(tools, tool_choice=\"auto\")\n",
+ "\n",
+ "messages = [{\"role\": \"user\", \"content\": \"What is the current temperature of Chicago?\"}]\n",
+ "print(model.invoke(messages))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "See [Databricks Unity Catalog](docs/integrations/tools/databricks.ipynb) about how to use UC functions in chains."
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/docs/docs/integrations/tools/databricks.ipynb b/docs/docs/integrations/tools/databricks.ipynb
index 49e5bc63905..823ab803e1f 100644
--- a/docs/docs/integrations/tools/databricks.ipynb
+++ b/docs/docs/integrations/tools/databricks.ipynb
@@ -38,7 +38,7 @@
"metadata": {},
"outputs": [],
"source": [
- "%pip install --upgrade --quiet databricks-sdk langchain-community langchain-openai"
+ "%pip install --upgrade --quiet databricks-sdk langchain-community mlflow"
]
},
{
@@ -47,9 +47,9 @@
"metadata": {},
"outputs": [],
"source": [
- "from langchain_openai import ChatOpenAI\n",
+ "from langchain_community.chat_models.databricks import ChatDatabricks\n",
"\n",
- "llm = ChatOpenAI(model=\"gpt-3.5-turbo\")"
+ "llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")"
]
},
{
diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt
index 79ce46657ac..f94f975d057 100644
--- a/libs/community/extended_testing_deps.txt
+++ b/libs/community/extended_testing_deps.txt
@@ -91,3 +91,4 @@ vdms>=0.0.20
xata>=1.0.0a7,<2
xmltodict>=0.13.0,<0.14
nanopq==0.2.1
+mlflow[genai]>=2.14.0
diff --git a/libs/community/langchain_community/chat_models/mlflow.py b/libs/community/langchain_community/chat_models/mlflow.py
index 5872948a063..101c2cb040c 100644
--- a/libs/community/langchain_community/chat_models/mlflow.py
+++ b/libs/community/langchain_community/chat_models/mlflow.py
@@ -1,5 +1,19 @@
+import json
import logging
-from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Sequence,
+ Type,
+ Union,
+ cast,
+)
from urllib.parse import urlparse
from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -15,15 +29,27 @@ from langchain_core.messages import (
FunctionMessage,
HumanMessage,
HumanMessageChunk,
+ InvalidToolCall,
SystemMessage,
SystemMessageChunk,
+ ToolCall,
+ ToolMessage,
+ ToolMessageChunk,
+)
+from langchain_core.messages.tool import tool_call_chunk
+from langchain_core.output_parsers.openai_tools import (
+ make_invalid_tool_call,
+ parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
+ BaseModel,
Field,
PrivateAttr,
)
-from langchain_core.runnables import RunnableConfig
+from langchain_core.runnables import Runnable, RunnableConfig
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
logger = logging.getLogger(__name__)
@@ -228,11 +254,32 @@ class ChatMlflow(BaseChatModel):
@staticmethod
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
- content = _dict["content"]
+ content = cast(str, _dict.get("content"))
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
- return AIMessage(content=content)
+ content = content or ""
+ additional_kwargs: Dict = {}
+ tool_calls = []
+ invalid_tool_calls = []
+ if raw_tool_calls := _dict.get("tool_calls"):
+ additional_kwargs["tool_calls"] = raw_tool_calls
+ for raw_tool_call in raw_tool_calls:
+ try:
+ tool_calls.append(
+ parse_tool_call(raw_tool_call, return_id=True)
+ )
+ except Exception as e:
+ invalid_tool_calls.append(
+ make_invalid_tool_call(raw_tool_call, str(e))
+ )
+ return AIMessage(
+ content=content,
+ additional_kwargs=additional_kwargs,
+ id=_dict.get("id"),
+ tool_calls=tool_calls,
+ invalid_tool_calls=invalid_tool_calls,
+ )
elif role == "system":
return SystemMessage(content=content)
else:
@@ -243,13 +290,38 @@ class ChatMlflow(BaseChatModel):
_dict: Mapping[str, Any], default_role: str
) -> BaseMessageChunk:
role = _dict.get("role", default_role)
- content = _dict["content"]
+ content = _dict.get("content") or ""
if role == "user":
return HumanMessageChunk(content=content)
elif role == "assistant":
- return AIMessageChunk(content=content)
+ additional_kwargs: Dict = {}
+ tool_call_chunks = []
+ if raw_tool_calls := _dict.get("tool_calls"):
+ additional_kwargs["tool_calls"] = raw_tool_calls
+ try:
+ tool_call_chunks = [
+ tool_call_chunk(
+ name=rtc["function"].get("name"),
+ args=rtc["function"].get("arguments"),
+ id=rtc.get("id"),
+ index=rtc["index"],
+ )
+ for rtc in raw_tool_calls
+ ]
+ except KeyError:
+ pass
+ return AIMessageChunk(
+ content=content,
+ additional_kwargs=additional_kwargs,
+ id=_dict.get("id"),
+ tool_call_chunks=tool_call_chunks,
+ )
elif role == "system":
return SystemMessageChunk(content=content)
+ elif role == "tool":
+ return ToolMessageChunk(
+ content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
+ )
else:
return ChatMessageChunk(content=content, role=role)
@@ -262,14 +334,47 @@ class ChatMlflow(BaseChatModel):
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> dict:
+ message_dict = {"content": message.content}
+ if (name := message.name or message.additional_kwargs.get("name")) is not None:
+ message_dict["name"] = name
if isinstance(message, ChatMessage):
- message_dict = {"role": message.role, "content": message.content}
+ message_dict["role"] = message.role
elif isinstance(message, HumanMessage):
- message_dict = {"role": "user", "content": message.content}
+ message_dict["role"] = "user"
elif isinstance(message, AIMessage):
- message_dict = {"role": "assistant", "content": message.content}
+ message_dict["role"] = "assistant"
+ if message.tool_calls or message.invalid_tool_calls:
+ message_dict["tool_calls"] = [
+ _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+ ] + [
+ _lc_invalid_tool_call_to_openai_tool_call(tc)
+ for tc in message.invalid_tool_calls
+ ] # type: ignore[assignment]
+ elif "tool_calls" in message.additional_kwargs:
+ message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
+ tool_call_supported_props = {"id", "type", "function"}
+ message_dict["tool_calls"] = [
+ {
+ k: v
+ for k, v in tool_call.items() # type: ignore[union-attr]
+ if k in tool_call_supported_props
+ }
+ for tool_call in message_dict["tool_calls"]
+ ]
+ else:
+ pass
+ # If tool calls present, content null value should be None not empty string.
+ if "tool_calls" in message_dict:
+ message_dict["content"] = message_dict["content"] or None # type: ignore[assignment]
elif isinstance(message, SystemMessage):
- message_dict = {"role": "system", "content": message.content}
+ message_dict["role"] = "system"
+ elif isinstance(message, ToolMessage):
+ message_dict["role"] = "tool"
+ message_dict["tool_call_id"] = message.tool_call_id
+ supported_props = {"content", "role", "tool_call_id"}
+ message_dict = {
+ k: v for k, v in message_dict.items() if k in supported_props
+ }
elif isinstance(message, FunctionMessage):
raise ValueError(
"Function messages are not supported by Databricks. Please"
@@ -280,12 +385,6 @@ class ChatMlflow(BaseChatModel):
if "function_call" in message.additional_kwargs:
ChatMlflow._raise_functions_not_supported()
- if message.additional_kwargs:
- logger.warning(
- "Additional message arguments are unsupported by Databricks"
- " and will be ignored: %s",
- message.additional_kwargs,
- )
return message_dict
@staticmethod
@@ -302,3 +401,89 @@ class ChatMlflow(BaseChatModel):
usage = response.get("usage", {})
return ChatResult(generations=generations, llm_output=usage)
+
+ def bind_tools(
+ self,
+ tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+ *,
+ tool_choice: Optional[
+ Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+ ] = None,
+ **kwargs: Any,
+ ) -> Runnable[LanguageModelInput, BaseMessage]:
+ """Bind tool-like objects to this chat model.
+
+ Assumes model is compatible with OpenAI tool-calling API.
+
+ Args:
+ tools: A list of tool definitions to bind to this chat model.
+ Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+ models, callables, and BaseTools will be automatically converted to
+ their schema dictionary representation.
+ tool_choice: Which tool to require the model to call.
+ Options are:
+ name of the tool (str): calls corresponding tool;
+ "auto": automatically selects a tool (including no tool);
+ "none": model does not generate any tool calls and instead must
+ generate a standard assistant message;
+ "required": the model picks the most relevant tool in tools and
+ must generate a tool call;
+
+ or a dict of the form:
+ {"type": "function", "function": {"name": <>}}.
+ **kwargs: Any additional parameters to pass to the
+ :class:`~langchain.runnable.Runnable` constructor.
+ """
+ formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+ if tool_choice:
+ if isinstance(tool_choice, str):
+ # tool_choice is a tool/function name
+ if tool_choice not in ("auto", "none", "required"):
+ tool_choice = {
+ "type": "function",
+ "function": {"name": tool_choice},
+ }
+ elif isinstance(tool_choice, dict):
+ tool_names = [
+ formatted_tool["function"]["name"]
+ for formatted_tool in formatted_tools
+ ]
+ if not any(
+ tool_name == tool_choice["function"]["name"]
+ for tool_name in tool_names
+ ):
+ raise ValueError(
+ f"Tool choice {tool_choice} was specified, but the only "
+ f"provided tools were {tool_names}."
+ )
+ else:
+ raise ValueError(
+ f"Unrecognized tool_choice type. Expected str, bool or dict. "
+ f"Received: {tool_choice}"
+ )
+ kwargs["tool_choice"] = tool_choice
+ return super().bind(tools=formatted_tools, **kwargs)
+
+
+def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
+ return {
+ "type": "function",
+ "id": tool_call["id"],
+ "function": {
+ "name": tool_call["name"],
+ "arguments": json.dumps(tool_call["args"]),
+ },
+ }
+
+
+def _lc_invalid_tool_call_to_openai_tool_call(
+ invalid_tool_call: InvalidToolCall,
+) -> dict:
+ return {
+ "type": "function",
+ "id": invalid_tool_call["id"],
+ "function": {
+ "name": invalid_tool_call["name"],
+ "arguments": invalid_tool_call["args"],
+ },
+ }
diff --git a/libs/community/tests/unit_tests/chat_models/test_mlflow.py b/libs/community/tests/unit_tests/chat_models/test_mlflow.py
new file mode 100644
index 00000000000..d526086c490
--- /dev/null
+++ b/libs/community/tests/unit_tests/chat_models/test_mlflow.py
@@ -0,0 +1,423 @@
+import json
+from typing import Any, Dict, List
+from unittest.mock import MagicMock
+
+import pytest
+from langchain.agents import AgentExecutor, create_tool_calling_agent
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ ChatMessage,
+ ChatMessageChunk,
+ FunctionMessage,
+ HumanMessage,
+ HumanMessageChunk,
+ SystemMessage,
+ SystemMessageChunk,
+ ToolCallChunk,
+ ToolMessageChunk,
+)
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, BaseModel
+from langchain_core.tools import StructuredTool
+
+from langchain_community.chat_models.mlflow import ChatMlflow
+
+
+@pytest.fixture
+def llm() -> ChatMlflow:
+ return ChatMlflow(
+ endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
+ )
+
+
+@pytest.fixture
+def model_input() -> List[BaseMessage]:
+ data = [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"role": "user", "content": "36939 * 8922.4"},
+ ]
+ return [ChatMlflow._convert_dict_to_message(value) for value in data]
+
+
+@pytest.fixture
+def mock_prediction() -> dict:
+ return {
+ "id": "chatcmpl_id",
+ "object": "chat.completion",
+ "created": 1721875529,
+ "model": "meta-llama-3.1-70b-instruct-072424",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "To calculate the result of 36939 multiplied by 8922.4, "
+ "I get:\n\n36939 x 8922.4 = 329,511,111.6",
+ },
+ "finish_reason": "stop",
+ "logprobs": None,
+ }
+ ],
+ "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
+ }
+
+
+@pytest.fixture
+def mock_predict_stream_result() -> List[dict]:
+ return [
+ {
+ "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
+ "object": "chat.completion.chunk",
+ "created": 1721877054,
+ "model": "meta-llama-3.1-70b-instruct-072424",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"role": "assistant", "content": "36939"},
+ "finish_reason": None,
+ "logprobs": None,
+ }
+ ],
+ "usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
+ },
+ {
+ "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
+ "object": "chat.completion.chunk",
+ "created": 1721877054,
+ "model": "meta-llama-3.1-70b-instruct-072424",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"role": "assistant", "content": "x"},
+ "finish_reason": None,
+ "logprobs": None,
+ }
+ ],
+ "usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52},
+ },
+ {
+ "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
+ "object": "chat.completion.chunk",
+ "created": 1721877054,
+ "model": "meta-llama-3.1-70b-instruct-072424",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"role": "assistant", "content": "8922.4"},
+ "finish_reason": None,
+ "logprobs": None,
+ }
+ ],
+ "usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54},
+ },
+ {
+ "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
+ "object": "chat.completion.chunk",
+ "created": 1721877054,
+ "model": "meta-llama-3.1-70b-instruct-072424",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"role": "assistant", "content": " = "},
+ "finish_reason": None,
+ "logprobs": None,
+ }
+ ],
+ "usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58},
+ },
+ {
+ "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
+ "object": "chat.completion.chunk",
+ "created": 1721877054,
+ "model": "meta-llama-3.1-70b-instruct-072424",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"role": "assistant", "content": "329,511,111.6"},
+ "finish_reason": None,
+ "logprobs": None,
+ }
+ ],
+ "usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60},
+ },
+ {
+ "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
+ "object": "chat.completion.chunk",
+ "created": 1721877054,
+ "model": "meta-llama-3.1-70b-instruct-072424",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"role": "assistant", "content": ""},
+ "finish_reason": "stop",
+ "logprobs": None,
+ }
+ ],
+ "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
+ },
+ ]
+
+
+@pytest.mark.requires("mlflow")
+def test_chat_mlflow_predict(
+ llm: ChatMlflow, model_input: List[BaseMessage], mock_prediction: dict
+) -> None:
+ mock_client = MagicMock()
+ llm._client = mock_client
+
+ def mock_predict(*args: Any, **kwargs: Any) -> Any:
+ return mock_prediction
+
+ mock_client.predict = mock_predict
+ res = llm.invoke(model_input)
+ assert res.content == mock_prediction["choices"][0]["message"]["content"]
+
+
+@pytest.mark.requires("mlflow")
+def test_chat_mlflow_stream(
+ llm: ChatMlflow,
+ model_input: List[BaseMessage],
+ mock_predict_stream_result: List[dict],
+) -> None:
+ mock_client = MagicMock()
+ llm._client = mock_client
+
+ def mock_stream(*args: Any, **kwargs: Any) -> Any:
+ yield from mock_predict_stream_result
+
+ mock_client.predict_stream = mock_stream
+ for i, res in enumerate(llm.stream(model_input)):
+ assert (
+ res.content
+ == mock_predict_stream_result[i]["choices"][0]["delta"]["content"]
+ )
+
+
+@pytest.mark.requires("mlflow")
+@pytest.mark.skipif(
+ _PYDANTIC_MAJOR_VERSION < 2,
+ reason="The tool mock is not compatible with pydantic 1.x",
+)
+def test_chat_mlflow_bind_tools(
+ llm: ChatMlflow, mock_predict_stream_result: List[dict]
+) -> None:
+ mock_client = MagicMock()
+ llm._client = mock_client
+
+ def mock_stream(*args: Any, **kwargs: Any) -> Any:
+ yield from mock_predict_stream_result
+
+ mock_client.predict_stream = mock_stream
+
+ prompt = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ "You are a helpful assistant. Make sure to use tool for information.",
+ ),
+ ("placeholder", "{chat_history}"),
+ ("human", "{input}"),
+ ("placeholder", "{agent_scratchpad}"),
+ ]
+ )
+
+ def mock_func(*args: Any, **kwargs: Any) -> str:
+ return "36939 x 8922.4 = 329,511,111.6"
+
+ tools = [
+ StructuredTool(
+ name="name",
+ description="description",
+ args_schema=BaseModel,
+ func=mock_func,
+ )
+ ]
+ agent = create_tool_calling_agent(llm, tools, prompt)
+ agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) # type: ignore[arg-type]
+ result = agent_executor.invoke({"input": "36939 * 8922.4"})
+ assert result["output"] == "36939x8922.4 = 329,511,111.6"
+
+
+def test_convert_dict_to_message_human() -> None:
+ message = {"role": "user", "content": "foo"}
+ result = ChatMlflow._convert_dict_to_message(message)
+ expected_output = HumanMessage(content="foo")
+ assert result == expected_output
+
+
+def test_convert_dict_to_message_ai() -> None:
+ message = {"role": "assistant", "content": "foo"}
+ result = ChatMlflow._convert_dict_to_message(message)
+ expected_output = AIMessage(content="foo")
+ assert result == expected_output
+
+ tool_calls = [
+ {
+ "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
+ "type": "function",
+ "function": {
+ "name": "main__test__python_exec",
+ "arguments": '{"code": "result = 36939 * 8922.4" }',
+ },
+ }
+ ]
+ message_with_tools: Dict[str, Any] = {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": tool_calls,
+ }
+ result = ChatMlflow._convert_dict_to_message(message_with_tools)
+ expected_output = AIMessage(
+ content="",
+ additional_kwargs={"tool_calls": tool_calls},
+ id="call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
+ tool_calls=[
+ {
+ "name": tool_calls[0]["function"]["name"], # type: ignore[index]
+ "args": json.loads(tool_calls[0]["function"]["arguments"]), # type: ignore[index]
+ "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
+ "type": "tool_call",
+ }
+ ],
+ )
+
+
+def test_convert_dict_to_message_system() -> None:
+ message = {"role": "system", "content": "foo"}
+ result = ChatMlflow._convert_dict_to_message(message)
+ expected_output = SystemMessage(content="foo")
+ assert result == expected_output
+
+
+def test_convert_dict_to_message_chat() -> None:
+ message = {"role": "any_role", "content": "foo"}
+ result = ChatMlflow._convert_dict_to_message(message)
+ expected_output = ChatMessage(content="foo", role="any_role")
+ assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_ai() -> None:
+ delta = {"role": "assistant", "content": "foo"}
+ result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+ expected_output = AIMessageChunk(content="foo")
+ assert result == expected_output
+
+ delta_with_tools: Dict[str, Any] = {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
+ }
+ result = ChatMlflow._convert_delta_to_message_chunk(delta_with_tools, "role")
+ expected_output = AIMessageChunk(
+ content="",
+ additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
+ id=None,
+ tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
+ )
+ assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_tool() -> None:
+ delta = {
+ "role": "tool",
+ "content": "foo",
+ "tool_call_id": "tool_call_id",
+ "id": "some_id",
+ }
+ result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+ expected_output = ToolMessageChunk(
+ content="foo", id="some_id", tool_call_id="tool_call_id"
+ )
+ assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_human() -> None:
+ delta = {
+ "role": "user",
+ "content": "foo",
+ }
+ result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+ expected_output = HumanMessageChunk(content="foo")
+ assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_system() -> None:
+ delta = {
+ "role": "system",
+ "content": "foo",
+ }
+ result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+ expected_output = SystemMessageChunk(content="foo")
+ assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_chat() -> None:
+ delta = {
+ "role": "any_role",
+ "content": "foo",
+ }
+ result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+ expected_output = ChatMessageChunk(content="foo", role="any_role")
+ assert result == expected_output
+
+
+def test_convert_message_to_dict_human() -> None:
+ human_message = HumanMessage(content="foo")
+ result = ChatMlflow._convert_message_to_dict(human_message)
+ expected_output = {"role": "user", "content": "foo"}
+ assert result == expected_output
+
+
+def test_convert_message_to_dict_system() -> None:
+ system_message = SystemMessage(content="foo")
+ result = ChatMlflow._convert_message_to_dict(system_message)
+ expected_output = {"role": "system", "content": "foo"}
+ assert result == expected_output
+
+
+def test_convert_message_to_dict_ai() -> None:
+ ai_message = AIMessage(content="foo")
+ result = ChatMlflow._convert_message_to_dict(ai_message)
+ expected_output = {"role": "assistant", "content": "foo"}
+ assert result == expected_output
+
+ ai_message = AIMessage(
+ content="",
+ tool_calls=[{"name": "name", "args": {}, "id": "id", "type": "tool_call"}],
+ )
+ result = ChatMlflow._convert_message_to_dict(ai_message)
+ expected_output_with_tools: Dict[str, Any] = {
+ "content": None,
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "type": "function",
+ "id": "id",
+ "function": {"name": "name", "arguments": "{}"},
+ }
+ ],
+ }
+ assert result == expected_output_with_tools
+
+
+def test_convert_message_to_dict_tool() -> None:
+ tool_message = ToolMessageChunk(
+ content="foo", id="some_id", tool_call_id="tool_call_id"
+ )
+ result = ChatMlflow._convert_message_to_dict(tool_message)
+ expected_output = {
+ "role": "tool",
+ "content": "foo",
+ "tool_call_id": "tool_call_id",
+ }
+ assert result == expected_output
+
+
+def test_convert_message_to_dict_function() -> None:
+ with pytest.raises(ValueError):
+ ChatMlflow._convert_message_to_dict(FunctionMessage(content="", name="name"))