mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-29 12:25:37 +00:00
Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com> Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com> Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: ccurme <chester.curme@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com> Co-authored-by: ZhangShenao <15201440436@163.com> Co-authored-by: Friso H. Kingma <fhkingma@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Morgante Pell <morgantep@google.com>
424 lines
13 KiB
Python
424 lines
13 KiB
Python
import json
|
|
from typing import Any, Dict, List
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from langchain.agents import AgentExecutor, create_tool_calling_agent
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
AIMessageChunk,
|
|
BaseMessage,
|
|
ChatMessage,
|
|
ChatMessageChunk,
|
|
FunctionMessage,
|
|
HumanMessage,
|
|
HumanMessageChunk,
|
|
SystemMessage,
|
|
SystemMessageChunk,
|
|
ToolCallChunk,
|
|
ToolMessageChunk,
|
|
)
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.tools import StructuredTool
|
|
from pydantic import BaseModel
|
|
|
|
from langchain_community.chat_models.mlflow import ChatMlflow
|
|
|
|
|
|
@pytest.fixture
def llm() -> ChatMlflow:
    """Build a ChatMlflow model for the tests.

    The Databricks endpoint is never actually contacted: each test swaps
    ``llm._client`` for a mock before invoking the model.
    """
    model = ChatMlflow(
        endpoint="databricks-meta-llama-3-70b-instruct",
        target_uri="databricks",
    )
    return model
|
|
|
|
|
|
@pytest.fixture
def model_input() -> List[BaseMessage]:
    """A two-message conversation (system + user) shared by the tests."""
    system_msg = {
        "role": "system",
        "content": "You are a helpful assistant.",
    }
    user_msg = {"role": "user", "content": "36939 * 8922.4"}
    return [
        ChatMlflow._convert_dict_to_message(system_msg),
        ChatMlflow._convert_dict_to_message(user_msg),
    ]
|
|
|
|
|
|
@pytest.fixture
def mock_prediction() -> dict:
    """A canned non-streaming chat-completion payload (OpenAI wire shape)."""
    answer = (
        "To calculate the result of 36939 multiplied by 8922.4, "
        "I get:\n\n36939 x 8922.4 = 329,511,111.6"
    )
    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": answer},
        "finish_reason": "stop",
        "logprobs": None,
    }
    return {
        "id": "chatcmpl_id",
        "object": "chat.completion",
        "created": 1721875529,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [choice],
        "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
    }
|
|
|
|
|
|
@pytest.fixture
def mock_predict_stream_result() -> List[dict]:
    """Canned streaming chunks whose delta contents spell out the answer.

    Each chunk differs only in its delta content, its completion-token
    count, and (for the last chunk) the finish reason, so the six payloads
    are generated from a small table instead of being written out longhand.
    """
    # (delta content, completion_tokens, finish_reason) per chunk, in order.
    pieces = [
        ("36939", 20, None),
        ("x", 22, None),
        ("8922.4", 24, None),
        (" = ", 28, None),
        ("329,511,111.6", 30, None),
        ("", 36, "stop"),
    ]
    return [
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": content},
                    "finish_reason": finish_reason,
                    "logprobs": None,
                }
            ],
            "usage": {
                "prompt_tokens": 30,
                "completion_tokens": tokens,
                "total_tokens": 30 + tokens,
            },
        }
        for content, tokens, finish_reason in pieces
    ]
|
|
|
|
|
|
@pytest.mark.requires("mlflow")
def test_chat_mlflow_predict(
    llm: ChatMlflow, model_input: List[BaseMessage], mock_prediction: dict
) -> None:
    """invoke() should surface the assistant message content verbatim."""
    stub_client = MagicMock()
    stub_client.predict = MagicMock(return_value=mock_prediction)
    llm._client = stub_client

    response = llm.invoke(model_input)
    expected = mock_prediction["choices"][0]["message"]["content"]
    assert response.content == expected
|
|
|
|
|
|
@pytest.mark.requires("mlflow")
def test_chat_mlflow_stream(
    llm: ChatMlflow,
    model_input: List[BaseMessage],
    mock_predict_stream_result: List[dict],
) -> None:
    """stream() should yield one chunk per mocked delta, content intact."""
    stub_client = MagicMock()
    stub_client.predict_stream = lambda *args, **kwargs: iter(
        mock_predict_stream_result
    )
    llm._client = stub_client

    for chunk, raw in zip(llm.stream(model_input), mock_predict_stream_result):
        assert chunk.content == raw["choices"][0]["delta"]["content"]
|
|
|
|
|
|
@pytest.mark.requires("mlflow")
def test_chat_mlflow_bind_tools(
    llm: ChatMlflow, mock_predict_stream_result: List[dict]
) -> None:
    """A tool-calling agent built on ChatMlflow should run end to end.

    The mocked stream carries plain-text deltas, so the agent finishes in a
    single step and its output is the concatenation of those deltas.
    """
    stub_client = MagicMock()
    stub_client.predict_stream = lambda *args, **kwargs: iter(
        mock_predict_stream_result
    )
    llm._client = stub_client

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful assistant. Make sure to use tool for information.",
            ),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )

    class ArgsSchema(BaseModel):
        # Arguments the stub tool advertises; their values are never used.
        x: int
        y: int

    def multiply_stub(x: int, y: int) -> str:
        # Canned result: nothing is computed in this test.
        return "36939 x 8922.4 = 329,511,111.6"

    tools = [
        StructuredTool(
            name="name",
            description="description",
            args_schema=ArgsSchema,
            func=multiply_stub,
        )
    ]

    agent = create_tool_calling_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools, verbose=True)  # type: ignore[arg-type]
    result = executor.invoke({"input": "36939 * 8922.4"})
    assert result["output"] == "36939x8922.4 = 329,511,111.6"
|
|
|
|
|
|
def test_convert_dict_to_message_human() -> None:
    """A 'user' role dict converts to a HumanMessage."""
    converted = ChatMlflow._convert_dict_to_message(
        {"role": "user", "content": "foo"}
    )
    assert converted == HumanMessage(content="foo")
|
|
|
|
|
|
def test_convert_dict_to_message_ai() -> None:
    """An 'assistant' role dict converts to an AIMessage.

    Covers both the plain-content case and the tool-call case, where the
    OpenAI-style ``tool_calls`` entries must be parsed into LangChain
    ``tool_calls`` (name, JSON-decoded args, id) and echoed verbatim in
    ``additional_kwargs``.
    """
    message = {"role": "assistant", "content": "foo"}
    result = ChatMlflow._convert_dict_to_message(message)
    expected_output = AIMessage(content="foo")
    assert result == expected_output

    tool_calls = [
        {
            "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
            "type": "function",
            "function": {
                "name": "main__test__python_exec",
                "arguments": '{"code": "result = 36939 * 8922.4" }',
            },
        }
    ]
    message_with_tools: Dict[str, Any] = {
        "role": "assistant",
        "content": None,
        "tool_calls": tool_calls,
    }
    result = ChatMlflow._convert_dict_to_message(message_with_tools)
    expected_output = AIMessage(
        content="",
        additional_kwargs={"tool_calls": tool_calls},
        id="call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
        tool_calls=[
            {
                "name": tool_calls[0]["function"]["name"],  # type: ignore[index]
                "args": json.loads(tool_calls[0]["function"]["arguments"]),  # type: ignore[index]
                "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
                "type": "tool_call",
            }
        ],
    )
    # Bug fix: the original built ``expected_output`` for the tool-call case
    # but never compared it against ``result``, so this path was untested.
    assert result == expected_output
|
|
|
|
|
|
def test_convert_dict_to_message_system() -> None:
    """A 'system' role dict converts to a SystemMessage."""
    converted = ChatMlflow._convert_dict_to_message(
        {"role": "system", "content": "foo"}
    )
    assert converted == SystemMessage(content="foo")
|
|
|
|
|
|
def test_convert_dict_to_message_chat() -> None:
    """An unrecognized role falls back to a generic ChatMessage."""
    converted = ChatMlflow._convert_dict_to_message(
        {"role": "any_role", "content": "foo"}
    )
    assert converted == ChatMessage(content="foo", role="any_role")
|
|
|
|
|
|
def test_convert_delta_to_message_chunk_ai() -> None:
    """Assistant deltas become AIMessageChunk; tool-call deltas are parsed.

    Partial tool-call fragments must surface both raw (additional_kwargs)
    and parsed (tool_call_chunks).
    """
    plain_delta = {"role": "assistant", "content": "foo"}
    chunk = ChatMlflow._convert_delta_to_message_chunk(plain_delta, "default_role")
    assert chunk == AIMessageChunk(content="foo")

    delta_with_tools: Dict[str, Any] = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
    }
    chunk = ChatMlflow._convert_delta_to_message_chunk(delta_with_tools, "role")
    assert chunk == AIMessageChunk(
        content="",
        additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
        id=None,
        tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
    )
|
|
|
|
|
|
def test_convert_delta_to_message_chunk_tool() -> None:
    """A 'tool' delta becomes a ToolMessageChunk carrying id and call id."""
    delta = {
        "role": "tool",
        "content": "foo",
        "tool_call_id": "tool_call_id",
        "id": "some_id",
    }
    chunk = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected = ToolMessageChunk(
        content="foo", id="some_id", tool_call_id="tool_call_id"
    )
    assert chunk == expected
|
|
|
|
|
|
def test_convert_delta_to_message_chunk_human() -> None:
    """A 'user' delta becomes a HumanMessageChunk."""
    chunk = ChatMlflow._convert_delta_to_message_chunk(
        {"role": "user", "content": "foo"}, "default_role"
    )
    assert chunk == HumanMessageChunk(content="foo")
|
|
|
|
|
|
def test_convert_delta_to_message_chunk_system() -> None:
    """A 'system' delta becomes a SystemMessageChunk."""
    chunk = ChatMlflow._convert_delta_to_message_chunk(
        {"role": "system", "content": "foo"}, "default_role"
    )
    assert chunk == SystemMessageChunk(content="foo")
|
|
|
|
|
|
def test_convert_delta_to_message_chunk_chat() -> None:
    """An unrecognized delta role falls back to a ChatMessageChunk."""
    chunk = ChatMlflow._convert_delta_to_message_chunk(
        {"role": "any_role", "content": "foo"}, "default_role"
    )
    assert chunk == ChatMessageChunk(content="foo", role="any_role")
|
|
|
|
|
|
def test_convert_message_to_dict_human() -> None:
    """HumanMessage serializes to a 'user' role dict."""
    serialized = ChatMlflow._convert_message_to_dict(HumanMessage(content="foo"))
    assert serialized == {"role": "user", "content": "foo"}
|
|
|
|
|
|
def test_convert_message_to_dict_system() -> None:
    """SystemMessage serializes to a 'system' role dict."""
    serialized = ChatMlflow._convert_message_to_dict(SystemMessage(content="foo"))
    assert serialized == {"role": "system", "content": "foo"}
|
|
|
|
|
|
def test_convert_message_to_dict_ai() -> None:
    """AIMessage serializes to 'assistant'; tool calls take the OpenAI shape.

    With tool calls present, content serializes to None and each call is
    emitted as a ``function`` entry with JSON-encoded arguments.
    """
    plain = AIMessage(content="foo")
    assert ChatMlflow._convert_message_to_dict(plain) == {
        "role": "assistant",
        "content": "foo",
    }

    with_tools = AIMessage(
        content="",
        tool_calls=[{"name": "name", "args": {}, "id": "id", "type": "tool_call"}],
    )
    serialized = ChatMlflow._convert_message_to_dict(with_tools)
    expected: Dict[str, Any] = {
        "content": None,
        "role": "assistant",
        "tool_calls": [
            {
                "type": "function",
                "id": "id",
                "function": {"name": "name", "arguments": "{}"},
            }
        ],
    }
    assert serialized == expected
|
|
|
|
|
|
def test_convert_message_to_dict_tool() -> None:
    """ToolMessageChunk serializes to a 'tool' dict; the message id is dropped."""
    chunk = ToolMessageChunk(
        content="foo", id="some_id", tool_call_id="tool_call_id"
    )
    serialized = ChatMlflow._convert_message_to_dict(chunk)
    assert serialized == {
        "role": "tool",
        "content": "foo",
        "tool_call_id": "tool_call_id",
    }
|
|
|
|
|
|
def test_convert_message_to_dict_function() -> None:
    """FunctionMessage is unsupported and must raise ValueError."""
    function_message = FunctionMessage(content="", name="name")
    with pytest.raises(ValueError):
        ChatMlflow._convert_message_to_dict(function_message)
|