community[patch]: support bind_tools for ChatMlflow (#24547)
Thank you for contributing to LangChain!

- [x] **PR title**: "package: description"
  - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"
- **Description:** Support the `ChatMlflow.bind_tools` method. Tested in Databricks:
  <img width="836" alt="image" src="https://github.com/user-attachments/assets/fa28ef50-0110-4698-8eda-4faf6f0b9ef8">
- [x] **Add tests and docs**: If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access,
  2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory.
- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
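For context, a minimal usage sketch of the new capability (the endpoint name and tool schema below are illustrative assumptions in the style of this PR's docs example, not a definitive API reference):

# Sketch: bind an OpenAI-style tool definition to ChatMlflow and invoke it.
from langchain_community.chat_models.mlflow import ChatMlflow

llm = ChatMlflow(
    endpoint="databricks-meta-llama-3-70b-instruct",  # assumed serving endpoint
    target_uri="databricks",
)
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
            },
        },
    }
]
# tool_choice accepts "auto", "required", "none", a tool name, or a
# {"type": "function", "function": {"name": ...}} dict.
model = llm.bind_tools(tools, tool_choice="auto")
response = model.invoke("What is the current temperature of Chicago?")
print(response.tool_calls)  # parsed tool calls, if the model emitted any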
@@ -36,7 +36,7 @@
     "### Model features\n",
     "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
     "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
-    "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
+    "| ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
     "\n",
     "### Supported Methods\n",
     "\n",
@@ -395,6 +395,66 @@
     "chat_model_external.invoke(\"How to use Databricks?\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Function calling on Databricks"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Databricks Function Calling is OpenAI-compatible and is only available during model serving as part of Foundation Model APIs.\n",
+    "\n",
+    "See [Databricks function calling introduction](https://docs.databricks.com/en/machine-learning/model-serving/function-calling.html#supported-models) for supported models."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.chat_models.databricks import ChatDatabricks\n",
+    "\n",
+    "llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")\n",
+    "tools = [\n",
+    "    {\n",
+    "        \"type\": \"function\",\n",
+    "        \"function\": {\n",
+    "            \"name\": \"get_current_weather\",\n",
+    "            \"description\": \"Get the current weather in a given location\",\n",
+    "            \"parameters\": {\n",
+    "                \"type\": \"object\",\n",
+    "                \"properties\": {\n",
+    "                    \"location\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
+    "                    },\n",
+    "                    \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
+    "                },\n",
+    "            },\n",
+    "        },\n",
+    "    }\n",
+    "]\n",
+    "\n",
+    "# supported tool_choice values: \"auto\", \"required\", \"none\", function name in string format,\n",
+    "# or a dictionary as {\"type\": \"function\", \"function\": {\"name\": <<tool_name>>}}\n",
+    "model = llm.bind_tools(tools, tool_choice=\"auto\")\n",
+    "\n",
+    "messages = [{\"role\": \"user\", \"content\": \"What is the current temperature of Chicago?\"}]\n",
+    "print(model.invoke(messages))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "See [Databricks Unity Catalog](docs/integrations/tools/databricks.ipynb) about how to use UC functions in chains."
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -38,7 +38,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install --upgrade --quiet databricks-sdk langchain-community langchain-openai"
+    "%pip install --upgrade --quiet databricks-sdk langchain-community mlflow"
    ]
   },
   {
@@ -47,9 +47,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain_openai import ChatOpenAI\n",
+    "from langchain_community.chat_models.databricks import ChatDatabricks\n",
     "\n",
-    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\")"
+    "llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")"
    ]
   },
   {
@@ -91,3 +91,4 @@ vdms>=0.0.20
 xata>=1.0.0a7,<2
 xmltodict>=0.13.0,<0.14
 nanopq==0.2.1
+mlflow[genai]>=2.14.0
libs/community/langchain_community/chat_models/mlflow.py

@@ -1,5 +1,19 @@
+import json
 import logging
-from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+    cast,
+)
 from urllib.parse import urlparse
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -15,15 +29,27 @@ from langchain_core.messages import (
     FunctionMessage,
     HumanMessage,
     HumanMessageChunk,
+    InvalidToolCall,
     SystemMessage,
     SystemMessageChunk,
+    ToolCall,
+    ToolMessage,
+    ToolMessageChunk,
+)
+from langchain_core.messages.tool import tool_call_chunk
+from langchain_core.output_parsers.openai_tools import (
+    make_invalid_tool_call,
+    parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import (
+    BaseModel,
     Field,
     PrivateAttr,
 )
-from langchain_core.runnables import RunnableConfig
+from langchain_core.runnables import Runnable, RunnableConfig
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 logger = logging.getLogger(__name__)
@@ -228,11 +254,32 @@ class ChatMlflow(BaseChatModel):
     @staticmethod
     def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
         role = _dict["role"]
-        content = _dict["content"]
+        content = cast(str, _dict.get("content"))
         if role == "user":
             return HumanMessage(content=content)
         elif role == "assistant":
-            return AIMessage(content=content)
+            content = content or ""
+            additional_kwargs: Dict = {}
+            tool_calls = []
+            invalid_tool_calls = []
+            if raw_tool_calls := _dict.get("tool_calls"):
+                additional_kwargs["tool_calls"] = raw_tool_calls
+                for raw_tool_call in raw_tool_calls:
+                    try:
+                        tool_calls.append(
+                            parse_tool_call(raw_tool_call, return_id=True)
+                        )
+                    except Exception as e:
+                        invalid_tool_calls.append(
+                            make_invalid_tool_call(raw_tool_call, str(e))
+                        )
+            return AIMessage(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                id=_dict.get("id"),
+                tool_calls=tool_calls,
+                invalid_tool_calls=invalid_tool_calls,
+            )
         elif role == "system":
             return SystemMessage(content=content)
         else:
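To make the new conversion concrete, here is a small sketch (the payload shape mirrors this PR's unit tests; the id and arguments are illustrative):

# Sketch: an assistant dict carrying tool_calls is parsed into an AIMessage.
from langchain_community.chat_models.mlflow import ChatMlflow

raw = {
    "role": "assistant",
    "content": None,
    "id": "call_1",  # illustrative id
    "tool_calls": [
        {
            "id": "call_1",
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": '{"location": "Chicago"}',
            },
        }
    ],
}
msg = ChatMlflow._convert_dict_to_message(raw)
assert msg.content == ""  # null assistant content is normalized to ""
assert msg.tool_calls[0]["name"] == "get_current_weather"
assert msg.tool_calls[0]["args"] == {"location": "Chicago"}  # JSON-decoded by parse_tool_call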
@@ -243,13 +290,38 @@
         _dict: Mapping[str, Any], default_role: str
     ) -> BaseMessageChunk:
         role = _dict.get("role", default_role)
-        content = _dict["content"]
+        content = _dict.get("content") or ""
         if role == "user":
             return HumanMessageChunk(content=content)
         elif role == "assistant":
-            return AIMessageChunk(content=content)
+            additional_kwargs: Dict = {}
+            tool_call_chunks = []
+            if raw_tool_calls := _dict.get("tool_calls"):
+                additional_kwargs["tool_calls"] = raw_tool_calls
+                try:
+                    tool_call_chunks = [
+                        tool_call_chunk(
+                            name=rtc["function"].get("name"),
+                            args=rtc["function"].get("arguments"),
+                            id=rtc.get("id"),
+                            index=rtc["index"],
+                        )
+                        for rtc in raw_tool_calls
+                    ]
+                except KeyError:
+                    pass
+            return AIMessageChunk(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                id=_dict.get("id"),
+                tool_call_chunks=tool_call_chunks,
+            )
         elif role == "system":
             return SystemMessageChunk(content=content)
+        elif role == "tool":
+            return ToolMessageChunk(
+                content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
+            )
         else:
             return ChatMessageChunk(content=content, role=role)
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
message_dict = {"content": message.content}
|
||||||
|
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
||||||
|
message_dict["name"] = name
|
||||||
if isinstance(message, ChatMessage):
|
if isinstance(message, ChatMessage):
|
||||||
message_dict = {"role": message.role, "content": message.content}
|
message_dict["role"] = message.role
|
||||||
elif isinstance(message, HumanMessage):
|
elif isinstance(message, HumanMessage):
|
||||||
message_dict = {"role": "user", "content": message.content}
|
message_dict["role"] = "user"
|
||||||
elif isinstance(message, AIMessage):
|
elif isinstance(message, AIMessage):
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
message_dict["role"] = "assistant"
|
||||||
|
if message.tool_calls or message.invalid_tool_calls:
|
||||||
|
message_dict["tool_calls"] = [
|
||||||
|
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
|
||||||
|
] + [
|
||||||
|
_lc_invalid_tool_call_to_openai_tool_call(tc)
|
||||||
|
for tc in message.invalid_tool_calls
|
||||||
|
] # type: ignore[assignment]
|
||||||
|
elif "tool_calls" in message.additional_kwargs:
|
||||||
|
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||||
|
tool_call_supported_props = {"id", "type", "function"}
|
||||||
|
message_dict["tool_calls"] = [
|
||||||
|
{
|
||||||
|
k: v
|
||||||
|
for k, v in tool_call.items() # type: ignore[union-attr]
|
||||||
|
if k in tool_call_supported_props
|
||||||
|
}
|
||||||
|
for tool_call in message_dict["tool_calls"]
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
# If tool calls present, content null value should be None not empty string.
|
||||||
|
if "tool_calls" in message_dict:
|
||||||
|
message_dict["content"] = message_dict["content"] or None # type: ignore[assignment]
|
||||||
elif isinstance(message, SystemMessage):
|
elif isinstance(message, SystemMessage):
|
||||||
message_dict = {"role": "system", "content": message.content}
|
message_dict["role"] = "system"
|
||||||
|
elif isinstance(message, ToolMessage):
|
||||||
|
message_dict["role"] = "tool"
|
||||||
|
message_dict["tool_call_id"] = message.tool_call_id
|
||||||
|
supported_props = {"content", "role", "tool_call_id"}
|
||||||
|
message_dict = {
|
||||||
|
k: v for k, v in message_dict.items() if k in supported_props
|
||||||
|
}
|
||||||
elif isinstance(message, FunctionMessage):
|
elif isinstance(message, FunctionMessage):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Function messages are not supported by Databricks. Please"
|
"Function messages are not supported by Databricks. Please"
|
||||||
@@ -280,12 +385,6 @@ class ChatMlflow(BaseChatModel):
 
         if "function_call" in message.additional_kwargs:
             ChatMlflow._raise_functions_not_supported()
-        if message.additional_kwargs:
-            logger.warning(
-                "Additional message arguments are unsupported by Databricks"
-                " and will be ignored: %s",
-                message.additional_kwargs,
-            )
         return message_dict
 
     @staticmethod
@@ -302,3 +401,89 @@ class ChatMlflow(BaseChatModel):
 
         usage = response.get("usage", {})
         return ChatResult(generations=generations, llm_output=usage)
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        *,
+        tool_choice: Optional[
+            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+        ] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model is compatible with OpenAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            tool_choice: Which tool to require the model to call.
+                Options are:
+                name of the tool (str): calls corresponding tool;
+                "auto": automatically selects a tool (including no tool);
+                "none": model does not generate any tool calls and instead must
+                    generate a standard assistant message;
+                "required": the model picks the most relevant tool in tools and
+                    must generate a tool call;
+
+                or a dict of the form:
+                {"type": "function", "function": {"name": <<tool_name>>}}.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        if tool_choice:
+            if isinstance(tool_choice, str):
+                # tool_choice is a tool/function name
+                if tool_choice not in ("auto", "none", "required"):
+                    tool_choice = {
+                        "type": "function",
+                        "function": {"name": tool_choice},
+                    }
+            elif isinstance(tool_choice, dict):
+                tool_names = [
+                    formatted_tool["function"]["name"]
+                    for formatted_tool in formatted_tools
+                ]
+                if not any(
+                    tool_name == tool_choice["function"]["name"]
+                    for tool_name in tool_names
+                ):
+                    raise ValueError(
+                        f"Tool choice {tool_choice} was specified, but the only "
+                        f"provided tools were {tool_names}."
+                    )
+            else:
+                raise ValueError(
+                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
+                    f"Received: {tool_choice}"
+                )
+            kwargs["tool_choice"] = tool_choice
+        return super().bind(tools=formatted_tools, **kwargs)
+
+
+def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
+    return {
+        "type": "function",
+        "id": tool_call["id"],
+        "function": {
+            "name": tool_call["name"],
+            "arguments": json.dumps(tool_call["args"]),
+        },
+    }
+
+
+def _lc_invalid_tool_call_to_openai_tool_call(
+    invalid_tool_call: InvalidToolCall,
+) -> dict:
+    return {
+        "type": "function",
+        "id": invalid_tool_call["id"],
+        "function": {
+            "name": invalid_tool_call["name"],
+            "arguments": invalid_tool_call["args"],
+        },
+    }
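The tool_choice normalization above can be exercised without a network call, since binding only rewrites the kwargs carried by the returned runnable (a sketch; the endpoint value is an assumption, and the RunnableBinding's kwargs are inspected directly):

# Sketch: a bare tool name is rewritten into the OpenAI dict form before binding.
from langchain_community.chat_models.mlflow import ChatMlflow

llm = ChatMlflow(
    endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
)
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool
            "parameters": {"type": "object", "properties": {}},
        },
    }
]
bound = llm.bind_tools(tools, tool_choice="get_current_weather")
assert bound.kwargs["tool_choice"] == {
    "type": "function",
    "function": {"name": "get_current_weather"},
}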
libs/community/tests/unit_tests/chat_models/test_mlflow.py (new file, 423 lines)
@@ -0,0 +1,423 @@
import json
from typing import Any, Dict, List
from unittest.mock import MagicMock

import pytest
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolCallChunk,
    ToolMessageChunk,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, BaseModel
from langchain_core.tools import StructuredTool

from langchain_community.chat_models.mlflow import ChatMlflow


@pytest.fixture
def llm() -> ChatMlflow:
    return ChatMlflow(
        endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
    )


@pytest.fixture
def model_input() -> List[BaseMessage]:
    data = [
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {"role": "user", "content": "36939 * 8922.4"},
    ]
    return [ChatMlflow._convert_dict_to_message(value) for value in data]


@pytest.fixture
def mock_prediction() -> dict:
    return {
        "id": "chatcmpl_id",
        "object": "chat.completion",
        "created": 1721875529,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "To calculate the result of 36939 multiplied by 8922.4, "
                    "I get:\n\n36939 x 8922.4 = 329,511,111.6",
                },
                "finish_reason": "stop",
                "logprobs": None,
            }
        ],
        "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
    }


@pytest.fixture
def mock_predict_stream_result() -> List[dict]:
    return [
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "36939"},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "x"},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "8922.4"},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": " = "},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "329,511,111.6"},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": ""},
                    "finish_reason": "stop",
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
        },
    ]


@pytest.mark.requires("mlflow")
def test_chat_mlflow_predict(
    llm: ChatMlflow, model_input: List[BaseMessage], mock_prediction: dict
) -> None:
    mock_client = MagicMock()
    llm._client = mock_client

    def mock_predict(*args: Any, **kwargs: Any) -> Any:
        return mock_prediction

    mock_client.predict = mock_predict
    res = llm.invoke(model_input)
    assert res.content == mock_prediction["choices"][0]["message"]["content"]


@pytest.mark.requires("mlflow")
def test_chat_mlflow_stream(
    llm: ChatMlflow,
    model_input: List[BaseMessage],
    mock_predict_stream_result: List[dict],
) -> None:
    mock_client = MagicMock()
    llm._client = mock_client

    def mock_stream(*args: Any, **kwargs: Any) -> Any:
        yield from mock_predict_stream_result

    mock_client.predict_stream = mock_stream
    for i, res in enumerate(llm.stream(model_input)):
        assert (
            res.content
            == mock_predict_stream_result[i]["choices"][0]["delta"]["content"]
        )


@pytest.mark.requires("mlflow")
@pytest.mark.skipif(
    _PYDANTIC_MAJOR_VERSION < 2,
    reason="The tool mock is not compatible with pydantic 1.x",
)
def test_chat_mlflow_bind_tools(
    llm: ChatMlflow, mock_predict_stream_result: List[dict]
) -> None:
    mock_client = MagicMock()
    llm._client = mock_client

    def mock_stream(*args: Any, **kwargs: Any) -> Any:
        yield from mock_predict_stream_result

    mock_client.predict_stream = mock_stream

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful assistant. Make sure to use tool for information.",
            ),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )

    def mock_func(*args: Any, **kwargs: Any) -> str:
        return "36939 x 8922.4 = 329,511,111.6"

    tools = [
        StructuredTool(
            name="name",
            description="description",
            args_schema=BaseModel,
            func=mock_func,
        )
    ]
    agent = create_tool_calling_agent(llm, tools, prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)  # type: ignore[arg-type]
    result = agent_executor.invoke({"input": "36939 * 8922.4"})
    assert result["output"] == "36939x8922.4 = 329,511,111.6"


def test_convert_dict_to_message_human() -> None:
    message = {"role": "user", "content": "foo"}
    result = ChatMlflow._convert_dict_to_message(message)
    expected_output = HumanMessage(content="foo")
    assert result == expected_output


def test_convert_dict_to_message_ai() -> None:
    message = {"role": "assistant", "content": "foo"}
    result = ChatMlflow._convert_dict_to_message(message)
    expected_output = AIMessage(content="foo")
    assert result == expected_output

    tool_calls = [
        {
            "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
            "type": "function",
            "function": {
                "name": "main__test__python_exec",
                "arguments": '{"code": "result = 36939 * 8922.4" }',
            },
        }
    ]
    message_with_tools: Dict[str, Any] = {
        "role": "assistant",
        "content": None,
        "tool_calls": tool_calls,
        "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
    }
    result = ChatMlflow._convert_dict_to_message(message_with_tools)
    expected_output = AIMessage(
        content="",
        additional_kwargs={"tool_calls": tool_calls},
        id="call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
        tool_calls=[
            {
                "name": tool_calls[0]["function"]["name"],  # type: ignore[index]
                "args": json.loads(tool_calls[0]["function"]["arguments"]),  # type: ignore[index]
                "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
                "type": "tool_call",
            }
        ],
    )
    assert result == expected_output


def test_convert_dict_to_message_system() -> None:
    message = {"role": "system", "content": "foo"}
    result = ChatMlflow._convert_dict_to_message(message)
    expected_output = SystemMessage(content="foo")
    assert result == expected_output


def test_convert_dict_to_message_chat() -> None:
    message = {"role": "any_role", "content": "foo"}
    result = ChatMlflow._convert_dict_to_message(message)
    expected_output = ChatMessage(content="foo", role="any_role")
    assert result == expected_output


def test_convert_delta_to_message_chunk_ai() -> None:
    delta = {"role": "assistant", "content": "foo"}
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = AIMessageChunk(content="foo")
    assert result == expected_output

    delta_with_tools: Dict[str, Any] = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta_with_tools, "role")
    expected_output = AIMessageChunk(
        content="",
        additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
        id=None,
        tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
    )
    assert result == expected_output


def test_convert_delta_to_message_chunk_tool() -> None:
    delta = {
        "role": "tool",
        "content": "foo",
        "tool_call_id": "tool_call_id",
        "id": "some_id",
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = ToolMessageChunk(
        content="foo", id="some_id", tool_call_id="tool_call_id"
    )
    assert result == expected_output


def test_convert_delta_to_message_chunk_human() -> None:
    delta = {
        "role": "user",
        "content": "foo",
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = HumanMessageChunk(content="foo")
    assert result == expected_output


def test_convert_delta_to_message_chunk_system() -> None:
    delta = {
        "role": "system",
        "content": "foo",
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = SystemMessageChunk(content="foo")
    assert result == expected_output


def test_convert_delta_to_message_chunk_chat() -> None:
    delta = {
        "role": "any_role",
        "content": "foo",
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = ChatMessageChunk(content="foo", role="any_role")
    assert result == expected_output


def test_convert_message_to_dict_human() -> None:
    human_message = HumanMessage(content="foo")
    result = ChatMlflow._convert_message_to_dict(human_message)
    expected_output = {"role": "user", "content": "foo"}
    assert result == expected_output


def test_convert_message_to_dict_system() -> None:
    system_message = SystemMessage(content="foo")
    result = ChatMlflow._convert_message_to_dict(system_message)
    expected_output = {"role": "system", "content": "foo"}
    assert result == expected_output


def test_convert_message_to_dict_ai() -> None:
    ai_message = AIMessage(content="foo")
    result = ChatMlflow._convert_message_to_dict(ai_message)
    expected_output = {"role": "assistant", "content": "foo"}
    assert result == expected_output

    ai_message = AIMessage(
        content="",
        tool_calls=[{"name": "name", "args": {}, "id": "id", "type": "tool_call"}],
    )
    result = ChatMlflow._convert_message_to_dict(ai_message)
    expected_output_with_tools: Dict[str, Any] = {
        "content": None,
        "role": "assistant",
        "tool_calls": [
            {
                "type": "function",
                "id": "id",
                "function": {"name": "name", "arguments": "{}"},
            }
        ],
    }
    assert result == expected_output_with_tools


def test_convert_message_to_dict_tool() -> None:
    tool_message = ToolMessageChunk(
        content="foo", id="some_id", tool_call_id="tool_call_id"
    )
    result = ChatMlflow._convert_message_to_dict(tool_message)
    expected_output = {
        "role": "tool",
        "content": "foo",
        "tool_call_id": "tool_call_id",
    }
    assert result == expected_output


def test_convert_message_to_dict_function() -> None:
    with pytest.raises(ValueError):
        ChatMlflow._convert_message_to_dict(FunctionMessage(content="", name="name"))