From 1827bb4042eab6d22321f7cbcc44c5e0df3c9602 Mon Sep 17 00:00:00 2001
From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com>
Date: Thu, 1 Aug 2024 23:43:07 +0800
Subject: [PATCH] community[patch]: support bind_tools for ChatMlflow (#24547)

Thank you for contributing to LangChain!

- [x] **PR title**: "package: description"
  - Where "package" is whichever of langchain, community, core, experimental,
    etc. is being modified. Use "docs: ..." for purely docs changes,
    "templates: ..." for template changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"

- **Description:** Support the `ChatMlflow.bind_tools` method. Tested in
  Databricks.
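
A minimal usage sketch of the new method (the endpoint name and the
weather-tool schema below are illustrative, not part of this patch):

```python
from langchain_community.chat_models.mlflow import ChatMlflow

llm = ChatMlflow(
    endpoint="databricks-meta-llama-3-70b-instruct",  # illustrative endpoint
    target_uri="databricks",
)
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    }
                },
            },
        },
    }
]
# tool_choice accepts "auto", "none", "required", a tool name, or a dict
# of the form {"type": "function", "function": {"name": ...}}.
model = llm.bind_tools(tools, tool_choice="auto")
print(model.invoke("What is the current temperature of Chicago?"))
```
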
- [x] **Add tests and docs**: If you're adding a new integration, please
  include
  1. a test for the integration, preferably unit tests that do not rely on
     network access,
  2. an example notebook showing its use. It lives in the
     `docs/docs/integrations` directory.

- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from
  the root of the package(s) you've modified. See contribution guidelines for
  more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones)
  unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: Serena Ruan
---
 docs/docs/integrations/chat/databricks.ipynb  |  62 ++-
 docs/docs/integrations/tools/databricks.ipynb |   6 +-
 libs/community/extended_testing_deps.txt      |   1 +
 .../langchain_community/chat_models/mlflow.py | 217 ++++++++-
 .../unit_tests/chat_models/test_mlflow.py     | 423 ++++++++++++++++++
 5 files changed, 689 insertions(+), 20 deletions(-)
 create mode 100644 libs/community/tests/unit_tests/chat_models/test_mlflow.py

diff --git a/docs/docs/integrations/chat/databricks.ipynb b/docs/docs/integrations/chat/databricks.ipynb
index 1e6e325f928..f0612c90d1a 100644
--- a/docs/docs/integrations/chat/databricks.ipynb
+++ b/docs/docs/integrations/chat/databricks.ipynb
@@ -36,7 +36,7 @@
     "### Model features\n",
     "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
     "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
-    "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
+    "| ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
     "\n",
     "### Supported Methods\n",
     "\n",
@@ -395,6 +395,66 @@
    "chat_model_external.invoke(\"How to use Databricks?\")"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "## Function calling on Databricks"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "Databricks Function Calling is OpenAI-compatible and is only available during model serving as part of Foundation Model APIs.\n",
+   "\n",
+   "See [Databricks function calling introduction](https://docs.databricks.com/en/machine-learning/model-serving/function-calling.html#supported-models) for supported models."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from langchain_community.chat_models.databricks import ChatDatabricks\n",
+   "\n",
+   "llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")\n",
+   "tools = [\n",
+   "    {\n",
+   "        \"type\": \"function\",\n",
+   "        \"function\": {\n",
+   "            \"name\": \"get_current_weather\",\n",
+   "            \"description\": \"Get the current weather in a given location\",\n",
+   "            \"parameters\": {\n",
+   "                \"type\": \"object\",\n",
+   "                \"properties\": {\n",
+   "                    \"location\": {\n",
+   "                        \"type\": \"string\",\n",
+   "                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
+   "                    },\n",
+   "                    \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
+   "                },\n",
+   "            },\n",
+   "        },\n",
+   "    }\n",
+   "]\n",
+   "\n",
+   "# supported tool_choice values: \"auto\", \"required\", \"none\", function name in string format,\n",
+   "# or a dictionary as {\"type\": \"function\", \"function\": {\"name\": <>}}\n",
+   "model = llm.bind_tools(tools, tool_choice=\"auto\")\n",
+   "\n",
+   "messages = [{\"role\": \"user\", \"content\": \"What is the current temperature of Chicago?\"}]\n",
+   "print(model.invoke(messages))"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "See [Databricks Unity Catalog](docs/integrations/tools/databricks.ipynb) for how to use UC functions in chains."
+  ]
+ },
 {
  "cell_type": "markdown",
  "metadata": {},
diff --git a/docs/docs/integrations/tools/databricks.ipynb b/docs/docs/integrations/tools/databricks.ipynb
index 49e5bc63905..823ab803e1f 100644
--- a/docs/docs/integrations/tools/databricks.ipynb
+++ b/docs/docs/integrations/tools/databricks.ipynb
@@ -38,7 +38,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install --upgrade --quiet databricks-sdk langchain-community langchain-openai"
+    "%pip install --upgrade --quiet databricks-sdk langchain-community mlflow"
    ]
   },
@@ -47,9 +47,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain_openai import ChatOpenAI\n",
+    "from langchain_community.chat_models.databricks import ChatDatabricks\n",
     "\n",
-    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\")"
+    "llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")"
    ]
   },
diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt
index 79ce46657ac..f94f975d057 100644
--- a/libs/community/extended_testing_deps.txt
+++ b/libs/community/extended_testing_deps.txt
@@ -91,3 +91,4 @@ vdms>=0.0.20
 xata>=1.0.0a7,<2
 xmltodict>=0.13.0,<0.14
 nanopq==0.2.1
+mlflow[genai]>=2.14.0
diff --git a/libs/community/langchain_community/chat_models/mlflow.py b/libs/community/langchain_community/chat_models/mlflow.py
index 5872948a063..101c2cb040c 100644
--- a/libs/community/langchain_community/chat_models/mlflow.py
+++ b/libs/community/langchain_community/chat_models/mlflow.py
@@ -1,5 +1,19 @@
+import json
 import logging
-from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+    cast,
+)
 from urllib.parse import urlparse
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -15,15 +29,27 @@ from langchain_core.messages import (
     FunctionMessage,
     HumanMessage,
     HumanMessageChunk,
+    InvalidToolCall,
     SystemMessage,
     SystemMessageChunk,
+    ToolCall,
+    ToolMessage,
+    ToolMessageChunk,
+)
+from langchain_core.messages.tool import tool_call_chunk
+from langchain_core.output_parsers.openai_tools import (
+    make_invalid_tool_call,
+    parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import (
+    BaseModel,
     Field,
     PrivateAttr,
 )
-from langchain_core.runnables import RunnableConfig
+from langchain_core.runnables import Runnable, RunnableConfig
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 logger = logging.getLogger(__name__)
 
@@ -228,11 +254,32 @@ class ChatMlflow(BaseChatModel):
     @staticmethod
     def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
         role = _dict["role"]
-        content = _dict["content"]
+        content = cast(str, _dict.get("content"))
         if role == "user":
             return HumanMessage(content=content)
         elif role == "assistant":
-            return AIMessage(content=content)
+            content = content or ""
+            additional_kwargs: Dict = {}
+            tool_calls = []
+            invalid_tool_calls = []
+            if raw_tool_calls := _dict.get("tool_calls"):
+                additional_kwargs["tool_calls"] = raw_tool_calls
+                for raw_tool_call in raw_tool_calls:
+                    try:
+                        tool_calls.append(
+                            parse_tool_call(raw_tool_call, return_id=True)
+                        )
+                    except Exception as e:
+                        invalid_tool_calls.append(
+                            make_invalid_tool_call(raw_tool_call, str(e))
+                        )
+            return AIMessage(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                id=_dict.get("id"),
+                tool_calls=tool_calls,
+                invalid_tool_calls=invalid_tool_calls,
+            )
         elif role == "system":
             return SystemMessage(content=content)
         else:
@@ -243,13 +290,38 @@ class ChatMlflow(BaseChatModel):
         _dict: Mapping[str, Any], default_role: str
     ) -> BaseMessageChunk:
         role = _dict.get("role", default_role)
-        content = _dict["content"]
+        content = _dict.get("content") or ""
         if role == "user":
             return HumanMessageChunk(content=content)
         elif role == "assistant":
-            return AIMessageChunk(content=content)
+            additional_kwargs: Dict = {}
+            tool_call_chunks = []
+            if raw_tool_calls := _dict.get("tool_calls"):
+                additional_kwargs["tool_calls"] = raw_tool_calls
+                try:
+                    tool_call_chunks = [
+                        tool_call_chunk(
+                            name=rtc["function"].get("name"),
+                            args=rtc["function"].get("arguments"),
+                            id=rtc.get("id"),
+                            index=rtc["index"],
+                        )
+                        for rtc in raw_tool_calls
+                    ]
+                except KeyError:
+                    pass
+            return AIMessageChunk(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                id=_dict.get("id"),
+                tool_call_chunks=tool_call_chunks,
+            )
         elif role == "system":
             return SystemMessageChunk(content=content)
+        elif role == "tool":
+            return ToolMessageChunk(
+                content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
+            )
         else:
             return ChatMessageChunk(content=content, role=role)
 
@@ -262,14 +334,47 @@ class ChatMlflow(BaseChatModel):
 
     @staticmethod
     def _convert_message_to_dict(message: BaseMessage) -> dict:
+        message_dict = {"content": message.content}
+        if (name := message.name or message.additional_kwargs.get("name")) is not None:
+            message_dict["name"] = name
         if isinstance(message, ChatMessage):
-            message_dict = {"role": message.role, "content": message.content}
+            message_dict["role"] = message.role
         elif isinstance(message, HumanMessage):
-            message_dict = {"role": "user", "content": message.content}
+            message_dict["role"] = "user"
        elif isinstance(message, AIMessage):
-            message_dict = {"role": "assistant", "content": message.content}
+            message_dict["role"] = "assistant"
+            if message.tool_calls or message.invalid_tool_calls:
+                message_dict["tool_calls"] = [
+                    _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+                ] + [
+                    _lc_invalid_tool_call_to_openai_tool_call(tc)
+                    for tc in message.invalid_tool_calls
+                ]  # type: ignore[assignment]
+            elif "tool_calls" in message.additional_kwargs:
+                message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
+                tool_call_supported_props = {"id", "type", "function"}
+                message_dict["tool_calls"] = [
+                    {
+                        k: v
+                        for k, v in tool_call.items()  # type: ignore[union-attr]
+                        if k in tool_call_supported_props
+                    }
+                    for tool_call in message_dict["tool_calls"]
+                ]
+            else:
+                pass
+            # If tool calls present, content null value should be None not empty string.
+            if "tool_calls" in message_dict:
+                message_dict["content"] = message_dict["content"] or None  # type: ignore[assignment]
         elif isinstance(message, SystemMessage):
-            message_dict = {"role": "system", "content": message.content}
+            message_dict["role"] = "system"
+        elif isinstance(message, ToolMessage):
+            message_dict["role"] = "tool"
+            message_dict["tool_call_id"] = message.tool_call_id
+            supported_props = {"content", "role", "tool_call_id"}
+            message_dict = {
+                k: v for k, v in message_dict.items() if k in supported_props
+            }
         elif isinstance(message, FunctionMessage):
             raise ValueError(
                 "Function messages are not supported by Databricks. Please"
@@ -280,12 +385,6 @@ class ChatMlflow(BaseChatModel):
 
         if "function_call" in message.additional_kwargs:
             ChatMlflow._raise_functions_not_supported()
-        if message.additional_kwargs:
-            logger.warning(
-                "Additional message arguments are unsupported by Databricks"
-                " and will be ignored: %s",
-                message.additional_kwargs,
-            )
         return message_dict
 
     @staticmethod
@@ -302,3 +401,89 @@ class ChatMlflow(BaseChatModel):
         usage = response.get("usage", {})
 
         return ChatResult(generations=generations, llm_output=usage)
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        *,
+        tool_choice: Optional[
+            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+        ] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model is compatible with OpenAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            tool_choice: Which tool to require the model to call.
+                Options are:
+                name of the tool (str): calls corresponding tool;
+                "auto": automatically selects a tool (including no tool);
+                "none": model does not generate any tool calls and instead must
+                    generate a standard assistant message;
+                "required": the model picks the most relevant tool in tools and
+                    must generate a tool call;
+
+                or a dict of the form:
+                {"type": "function", "function": {"name": <>}}.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+ """ + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + if tool_choice: + if isinstance(tool_choice, str): + # tool_choice is a tool/function name + if tool_choice not in ("auto", "none", "required"): + tool_choice = { + "type": "function", + "function": {"name": tool_choice}, + } + elif isinstance(tool_choice, dict): + tool_names = [ + formatted_tool["function"]["name"] + for formatted_tool in formatted_tools + ] + if not any( + tool_name == tool_choice["function"]["name"] + for tool_name in tool_names + ): + raise ValueError( + f"Tool choice {tool_choice} was specified, but the only " + f"provided tools were {tool_names}." + ) + else: + raise ValueError( + f"Unrecognized tool_choice type. Expected str, bool or dict. " + f"Received: {tool_choice}" + ) + kwargs["tool_choice"] = tool_choice + return super().bind(tools=formatted_tools, **kwargs) + + +def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict: + return { + "type": "function", + "id": tool_call["id"], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["args"]), + }, + } + + +def _lc_invalid_tool_call_to_openai_tool_call( + invalid_tool_call: InvalidToolCall, +) -> dict: + return { + "type": "function", + "id": invalid_tool_call["id"], + "function": { + "name": invalid_tool_call["name"], + "arguments": invalid_tool_call["args"], + }, + } diff --git a/libs/community/tests/unit_tests/chat_models/test_mlflow.py b/libs/community/tests/unit_tests/chat_models/test_mlflow.py new file mode 100644 index 00000000000..d526086c490 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_mlflow.py @@ -0,0 +1,423 @@ +import json +from typing import Any, Dict, List +from unittest.mock import MagicMock + +import pytest +from langchain.agents import AgentExecutor, create_tool_calling_agent +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + ChatMessageChunk, + FunctionMessage, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, + ToolCallChunk, + ToolMessageChunk, +) +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, BaseModel +from langchain_core.tools import StructuredTool + +from langchain_community.chat_models.mlflow import ChatMlflow + + +@pytest.fixture +def llm() -> ChatMlflow: + return ChatMlflow( + endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks" + ) + + +@pytest.fixture +def model_input() -> List[BaseMessage]: + data = [ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + {"role": "user", "content": "36939 * 8922.4"}, + ] + return [ChatMlflow._convert_dict_to_message(value) for value in data] + + +@pytest.fixture +def mock_prediction() -> dict: + return { + "id": "chatcmpl_id", + "object": "chat.completion", + "created": 1721875529, + "model": "meta-llama-3.1-70b-instruct-072424", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "To calculate the result of 36939 multiplied by 8922.4, " + "I get:\n\n36939 x 8922.4 = 329,511,111.6", + }, + "finish_reason": "stop", + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66}, + } + + +@pytest.fixture +def mock_predict_stream_result() -> List[dict]: + return [ + { + "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a", + "object": "chat.completion.chunk", + "created": 1721877054, + "model": "meta-llama-3.1-70b-instruct-072424", + "choices": [ + { + 
"index": 0, + "delta": {"role": "assistant", "content": "36939"}, + "finish_reason": None, + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50}, + }, + { + "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a", + "object": "chat.completion.chunk", + "created": 1721877054, + "model": "meta-llama-3.1-70b-instruct-072424", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": "x"}, + "finish_reason": None, + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52}, + }, + { + "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a", + "object": "chat.completion.chunk", + "created": 1721877054, + "model": "meta-llama-3.1-70b-instruct-072424", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": "8922.4"}, + "finish_reason": None, + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54}, + }, + { + "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a", + "object": "chat.completion.chunk", + "created": 1721877054, + "model": "meta-llama-3.1-70b-instruct-072424", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": " = "}, + "finish_reason": None, + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58}, + }, + { + "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a", + "object": "chat.completion.chunk", + "created": 1721877054, + "model": "meta-llama-3.1-70b-instruct-072424", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": "329,511,111.6"}, + "finish_reason": None, + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60}, + }, + { + "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a", + "object": "chat.completion.chunk", + "created": 1721877054, + "model": "meta-llama-3.1-70b-instruct-072424", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": "stop", + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66}, + }, + ] + + +@pytest.mark.requires("mlflow") +def test_chat_mlflow_predict( + llm: ChatMlflow, model_input: List[BaseMessage], mock_prediction: dict +) -> None: + mock_client = MagicMock() + llm._client = mock_client + + def mock_predict(*args: Any, **kwargs: Any) -> Any: + return mock_prediction + + mock_client.predict = mock_predict + res = llm.invoke(model_input) + assert res.content == mock_prediction["choices"][0]["message"]["content"] + + +@pytest.mark.requires("mlflow") +def test_chat_mlflow_stream( + llm: ChatMlflow, + model_input: List[BaseMessage], + mock_predict_stream_result: List[dict], +) -> None: + mock_client = MagicMock() + llm._client = mock_client + + def mock_stream(*args: Any, **kwargs: Any) -> Any: + yield from mock_predict_stream_result + + mock_client.predict_stream = mock_stream + for i, res in enumerate(llm.stream(model_input)): + assert ( + res.content + == mock_predict_stream_result[i]["choices"][0]["delta"]["content"] + ) + + +@pytest.mark.requires("mlflow") +@pytest.mark.skipif( + _PYDANTIC_MAJOR_VERSION < 2, + reason="The tool mock is not compatible with pydantic 1.x", +) +def test_chat_mlflow_bind_tools( + llm: ChatMlflow, mock_predict_stream_result: List[dict] +) -> None: + mock_client = MagicMock() + llm._client = mock_client + + def mock_stream(*args: Any, **kwargs: Any) -> Any: + yield from 
+
+    mock_client.predict_stream = mock_stream
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                "You are a helpful assistant. Make sure to use tool for information.",
+            ),
+            ("placeholder", "{chat_history}"),
+            ("human", "{input}"),
+            ("placeholder", "{agent_scratchpad}"),
+        ]
+    )
+
+    def mock_func(*args: Any, **kwargs: Any) -> str:
+        return "36939 x 8922.4 = 329,511,111.6"
+
+    tools = [
+        StructuredTool(
+            name="name",
+            description="description",
+            args_schema=BaseModel,
+            func=mock_func,
+        )
+    ]
+    agent = create_tool_calling_agent(llm, tools, prompt)
+    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)  # type: ignore[arg-type]
+    result = agent_executor.invoke({"input": "36939 * 8922.4"})
+    assert result["output"] == "36939x8922.4 = 329,511,111.6"
+
+
+def test_convert_dict_to_message_human() -> None:
+    message = {"role": "user", "content": "foo"}
+    result = ChatMlflow._convert_dict_to_message(message)
+    expected_output = HumanMessage(content="foo")
+    assert result == expected_output
+
+
+def test_convert_dict_to_message_ai() -> None:
+    message = {"role": "assistant", "content": "foo"}
+    result = ChatMlflow._convert_dict_to_message(message)
+    expected_output = AIMessage(content="foo")
+    assert result == expected_output
+
+    tool_calls = [
+        {
+            "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
+            "type": "function",
+            "function": {
+                "name": "main__test__python_exec",
+                "arguments": '{"code": "result = 36939 * 8922.4" }',
+            },
+        }
+    ]
+    message_with_tools: Dict[str, Any] = {
+        "role": "assistant",
+        "content": None,
+        "tool_calls": tool_calls,
+        "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
+    }
+    result = ChatMlflow._convert_dict_to_message(message_with_tools)
+    expected_output = AIMessage(
+        content="",
+        additional_kwargs={"tool_calls": tool_calls},
+        id="call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
+        tool_calls=[
+            {
+                "name": tool_calls[0]["function"]["name"],  # type: ignore[index]
+                "args": json.loads(tool_calls[0]["function"]["arguments"]),  # type: ignore[index]
+                "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
+                "type": "tool_call",
+            }
+        ],
+    )
+    assert result == expected_output
+
+
+def test_convert_dict_to_message_system() -> None:
+    message = {"role": "system", "content": "foo"}
+    result = ChatMlflow._convert_dict_to_message(message)
+    expected_output = SystemMessage(content="foo")
+    assert result == expected_output
+
+
+def test_convert_dict_to_message_chat() -> None:
+    message = {"role": "any_role", "content": "foo"}
+    result = ChatMlflow._convert_dict_to_message(message)
+    expected_output = ChatMessage(content="foo", role="any_role")
+    assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_ai() -> None:
+    delta = {"role": "assistant", "content": "foo"}
+    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+    expected_output = AIMessageChunk(content="foo")
+    assert result == expected_output
+
+    delta_with_tools: Dict[str, Any] = {
+        "role": "assistant",
+        "content": None,
+        "tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
+    }
+    result = ChatMlflow._convert_delta_to_message_chunk(delta_with_tools, "role")
+    expected_output = AIMessageChunk(
+        content="",
+        additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
+        id=None,
+        tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
+    )
+    assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_tool() -> None:
+    delta = {
+        "role": "tool",
+        "content": "foo",
+        "tool_call_id": "tool_call_id",
+        "id": "some_id",
+    }
+    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+    expected_output = ToolMessageChunk(
+        content="foo", id="some_id", tool_call_id="tool_call_id"
+    )
+    assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_human() -> None:
+    delta = {
+        "role": "user",
+        "content": "foo",
+    }
+    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+    expected_output = HumanMessageChunk(content="foo")
+    assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_system() -> None:
+    delta = {
+        "role": "system",
+        "content": "foo",
+    }
+    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+    expected_output = SystemMessageChunk(content="foo")
+    assert result == expected_output
+
+
+def test_convert_delta_to_message_chunk_chat() -> None:
+    delta = {
+        "role": "any_role",
+        "content": "foo",
+    }
+    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
+    expected_output = ChatMessageChunk(content="foo", role="any_role")
+    assert result == expected_output
+
+
+def test_convert_message_to_dict_human() -> None:
+    human_message = HumanMessage(content="foo")
+    result = ChatMlflow._convert_message_to_dict(human_message)
+    expected_output = {"role": "user", "content": "foo"}
+    assert result == expected_output
+
+
+def test_convert_message_to_dict_system() -> None:
+    system_message = SystemMessage(content="foo")
+    result = ChatMlflow._convert_message_to_dict(system_message)
+    expected_output = {"role": "system", "content": "foo"}
+    assert result == expected_output
+
+
+def test_convert_message_to_dict_ai() -> None:
+    ai_message = AIMessage(content="foo")
+    result = ChatMlflow._convert_message_to_dict(ai_message)
+    expected_output = {"role": "assistant", "content": "foo"}
+    assert result == expected_output
+
+    ai_message = AIMessage(
+        content="",
+        tool_calls=[{"name": "name", "args": {}, "id": "id", "type": "tool_call"}],
+    )
+    result = ChatMlflow._convert_message_to_dict(ai_message)
+    expected_output_with_tools: Dict[str, Any] = {
+        "content": None,
+        "role": "assistant",
+        "tool_calls": [
+            {
+                "type": "function",
+                "id": "id",
+                "function": {"name": "name", "arguments": "{}"},
+            }
+        ],
+    }
+    assert result == expected_output_with_tools
+
+
+def test_convert_message_to_dict_tool() -> None:
+    tool_message = ToolMessageChunk(
+        content="foo", id="some_id", tool_call_id="tool_call_id"
+    )
+    result = ChatMlflow._convert_message_to_dict(tool_message)
+    expected_output = {
+        "role": "tool",
+        "content": "foo",
+        "tool_call_id": "tool_call_id",
+    }
+    assert result == expected_output
+
+
+def test_convert_message_to_dict_function() -> None:
+    with pytest.raises(ValueError):
+        ChatMlflow._convert_message_to_dict(FunctionMessage(content="", name="name"))
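
Appendix (not part of the patch): a small, self-contained sketch of the tool-call
round-trip the new conversion helpers perform. It assumes this patch is applied,
uses only the static methods exercised by the unit tests above, needs no network
access or mlflow runtime, and the tool name and call id are illustrative:

```python
import json

from langchain_core.messages import AIMessage

from langchain_community.chat_models.mlflow import ChatMlflow

# An assistant message carrying a tool call (as produced via bind_tools)
# converts to the OpenAI-style wire format added by this patch.
ai_message = AIMessage(
    content="",
    tool_calls=[
        {
            "name": "get_current_weather",  # illustrative tool name
            "args": {"location": "Chicago"},
            "id": "call_1",  # illustrative call id
            "type": "tool_call",
        }
    ],
)
wire = ChatMlflow._convert_message_to_dict(ai_message)
# When tool calls are present, content is None and arguments are JSON-encoded.
assert wire["content"] is None
assert json.loads(wire["tool_calls"][0]["function"]["arguments"]) == {
    "location": "Chicago"
}

# The server-side response dict converts back into an AIMessage with
# parsed tool calls.
restored = ChatMlflow._convert_dict_to_message(
    {"role": "assistant", "content": None, "tool_calls": wire["tool_calls"]}
)
assert restored.tool_calls[0]["args"] == {"location": "Chicago"}
```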