mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
community[patch]: Add function call support in Tongyi chat model. (#20119)
- [ ] **PR message**: - **Description:** This pr adds function calling support in Tongyi chat model. - **Issue:** None - **Dependencies:** None - **Twitter handle:** None Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
80679ab906
commit
ecd19a9e58
@ -33,6 +33,10 @@ from langchain_core.messages import (
|
|||||||
SystemMessage,
|
SystemMessage,
|
||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
)
|
)
|
||||||
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
|
make_invalid_tool_call,
|
||||||
|
parse_tool_call,
|
||||||
|
)
|
||||||
from langchain_core.outputs import (
|
from langchain_core.outputs import (
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
ChatGenerationChunk,
|
ChatGenerationChunk,
|
||||||
@ -71,8 +75,28 @@ def convert_dict_to_message(
|
|||||||
else HumanMessage(content=content)
|
else HumanMessage(content=content)
|
||||||
)
|
)
|
||||||
elif role == "assistant":
|
elif role == "assistant":
|
||||||
|
tool_calls = []
|
||||||
|
invalid_tool_calls = []
|
||||||
|
if "tool_calls" in _dict:
|
||||||
|
additional_kwargs = {"tool_calls": _dict["tool_calls"]}
|
||||||
|
for raw_tool_call in _dict["tool_calls"]:
|
||||||
|
try:
|
||||||
|
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
|
||||||
|
except Exception as e:
|
||||||
|
invalid_tool_calls.append(
|
||||||
|
make_invalid_tool_call(raw_tool_call, str(e))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
additional_kwargs = {}
|
||||||
return (
|
return (
|
||||||
AIMessageChunk(content=content) if is_chunk else AIMessage(content=content)
|
AIMessageChunk(content=content)
|
||||||
|
if is_chunk
|
||||||
|
else AIMessage(
|
||||||
|
content=content,
|
||||||
|
additional_kwargs=additional_kwargs,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
invalid_tool_calls=invalid_tool_calls,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif role == "system":
|
elif role == "system":
|
||||||
return (
|
return (
|
||||||
|
@ -1,15 +1,37 @@
|
|||||||
"""Test Alibaba Tongyi Chat Model."""
|
"""Test Alibaba Tongyi Chat Model."""
|
||||||
from typing import cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManager
|
from langchain_core.callbacks import CallbackManager
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import SecretStr
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
from pytest import CaptureFixture
|
from pytest import CaptureFixture
|
||||||
|
|
||||||
from langchain_community.chat_models.tongyi import ChatTongyi
|
from langchain_community.chat_models.tongyi import ChatTongyi
|
||||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
|
_FUNCTIONS: Any = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_initialization() -> None:
|
def test_initialization() -> None:
|
||||||
"""Test chat model initialization."""
|
"""Test chat model initialization."""
|
||||||
@ -52,6 +74,23 @@ def test_model() -> None:
|
|||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_functions_call_thoughts() -> None:
|
||||||
|
chat = ChatTongyi(model="qwen-plus")
|
||||||
|
|
||||||
|
prompt_tmpl = "Use the given functions to answer following question: {input}"
|
||||||
|
prompt_msgs = [
|
||||||
|
HumanMessagePromptTemplate.from_template(prompt_tmpl),
|
||||||
|
]
|
||||||
|
prompt = ChatPromptTemplate(messages=prompt_msgs)
|
||||||
|
|
||||||
|
chain = prompt | chat.bind(functions=_FUNCTIONS)
|
||||||
|
|
||||||
|
message = HumanMessage(content="What's the weather like in Shanghai today?")
|
||||||
|
response = chain.batch([{"input": message}])
|
||||||
|
assert isinstance(response[0], AIMessage)
|
||||||
|
assert "tool_calls" in response[0].additional_kwargs
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_history() -> None:
|
def test_multiple_history() -> None:
|
||||||
"""Tests multiple history works."""
|
"""Tests multiple history works."""
|
||||||
chat = ChatTongyi()
|
chat = ChatTongyi()
|
||||||
|
85
libs/community/tests/unit_tests/chat_models/test_tongyi.py
Normal file
85
libs/community/tests/unit_tests/chat_models/test_tongyi.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
|
parse_tool_call,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain_community.chat_models.tongyi import (
|
||||||
|
convert_dict_to_message,
|
||||||
|
convert_message_to_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_human() -> None:
|
||||||
|
message_dict = {"role": "user", "content": "foo"}
|
||||||
|
result = convert_dict_to_message(message_dict)
|
||||||
|
expected_output = HumanMessage(content="foo")
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_ai() -> None:
|
||||||
|
message_dict = {"role": "assistant", "content": "foo"}
|
||||||
|
result = convert_dict_to_message(message_dict)
|
||||||
|
expected_output = AIMessage(content="foo")
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_other_role() -> None:
|
||||||
|
message_dict = {"role": "system", "content": "foo"}
|
||||||
|
result = convert_dict_to_message(message_dict)
|
||||||
|
expected_output = SystemMessage(content="foo")
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_function_call() -> None:
|
||||||
|
raw_function_calls = [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"arguments": '{"location": "Boston", "unit": "fahrenheit"}',
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
message_dict = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "foo",
|
||||||
|
"tool_calls": raw_function_calls,
|
||||||
|
}
|
||||||
|
result = convert_dict_to_message(message_dict)
|
||||||
|
|
||||||
|
tool_calls = [
|
||||||
|
parse_tool_call(raw_tool_call, return_id=True)
|
||||||
|
for raw_tool_call in raw_function_calls
|
||||||
|
]
|
||||||
|
expected_output = AIMessage(
|
||||||
|
content="foo",
|
||||||
|
additional_kwargs={"tool_calls": raw_function_calls},
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
invalid_tool_calls=[],
|
||||||
|
)
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_message_to_dict_human() -> None:
|
||||||
|
message = HumanMessage(content="foo")
|
||||||
|
result = convert_message_to_dict(message)
|
||||||
|
expected_output = {"role": "user", "content": "foo"}
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_message_to_dict_ai() -> None:
|
||||||
|
message = AIMessage(content="foo")
|
||||||
|
result = convert_message_to_dict(message)
|
||||||
|
expected_output = {"role": "assistant", "content": "foo"}
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_message_to_dict_system() -> None:
|
||||||
|
message = SystemMessage(content="foo")
|
||||||
|
result = convert_message_to_dict(message)
|
||||||
|
expected_output = {"role": "system", "content": "foo"}
|
||||||
|
assert result == expected_output
|
Loading…
Reference in New Issue
Block a user