community[patch]: Add function call support in Tongyi chat model. (#20119)

- [ ] **PR message**: 
- **Description:** This pr adds function calling support in Tongyi chat
model.
    - **Issue:** None
    - **Dependencies:** None
    - **Twitter handle:** None

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
Pengcheng Liu 2024-04-18 04:42:23 +08:00 committed by GitHub
parent 80679ab906
commit ecd19a9e58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 150 additions and 2 deletions

View File

@ -33,6 +33,10 @@ from langchain_core.messages import (
SystemMessage,
SystemMessageChunk,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
@ -71,8 +75,28 @@ def convert_dict_to_message(
else HumanMessage(content=content)
)
elif role == "assistant":
tool_calls = []
invalid_tool_calls = []
if "tool_calls" in _dict:
additional_kwargs = {"tool_calls": _dict["tool_calls"]}
for raw_tool_call in _dict["tool_calls"]:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
else:
additional_kwargs = {}
return (
AIMessageChunk(content=content) if is_chunk else AIMessage(content=content)
AIMessageChunk(content=content)
if is_chunk
else AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
)
elif role == "system":
return (

View File

@ -1,15 +1,37 @@
"""Test Alibaba Tongyi Chat Model."""
from typing import cast
from typing import Any, cast
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture
from langchain_community.chat_models.tongyi import ChatTongyi
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
_FUNCTIONS: Any = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
def test_initialization() -> None:
"""Test chat model initialization."""
@ -52,6 +74,23 @@ def test_model() -> None:
assert isinstance(response.content, str)
def test_functions_call_thoughts() -> None:
chat = ChatTongyi(model="qwen-plus")
prompt_tmpl = "Use the given functions to answer following question: {input}"
prompt_msgs = [
HumanMessagePromptTemplate.from_template(prompt_tmpl),
]
prompt = ChatPromptTemplate(messages=prompt_msgs)
chain = prompt | chat.bind(functions=_FUNCTIONS)
message = HumanMessage(content="What's the weather like in Shanghai today?")
response = chain.batch([{"input": message}])
assert isinstance(response[0], AIMessage)
assert "tool_calls" in response[0].additional_kwargs
def test_multiple_history() -> None:
"""Tests multiple history works."""
chat = ChatTongyi()

View File

@ -0,0 +1,85 @@
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_community.chat_models.tongyi import (
convert_dict_to_message,
convert_message_to_dict,
)
def test__convert_dict_to_message_human() -> None:
message_dict = {"role": "user", "content": "foo"}
result = convert_dict_to_message(message_dict)
expected_output = HumanMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_ai() -> None:
message_dict = {"role": "assistant", "content": "foo"}
result = convert_dict_to_message(message_dict)
expected_output = AIMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_other_role() -> None:
message_dict = {"role": "system", "content": "foo"}
result = convert_dict_to_message(message_dict)
expected_output = SystemMessage(content="foo")
assert result == expected_output
def test__convert_dict_to_message_function_call() -> None:
raw_function_calls = [
{
"function": {
"name": "get_current_weather",
"arguments": '{"location": "Boston", "unit": "fahrenheit"}',
},
"type": "function",
}
]
message_dict = {
"role": "assistant",
"content": "foo",
"tool_calls": raw_function_calls,
}
result = convert_dict_to_message(message_dict)
tool_calls = [
parse_tool_call(raw_tool_call, return_id=True)
for raw_tool_call in raw_function_calls
]
expected_output = AIMessage(
content="foo",
additional_kwargs={"tool_calls": raw_function_calls},
tool_calls=tool_calls,
invalid_tool_calls=[],
)
assert result == expected_output
def test__convert_message_to_dict_human() -> None:
message = HumanMessage(content="foo")
result = convert_message_to_dict(message)
expected_output = {"role": "user", "content": "foo"}
assert result == expected_output
def test__convert_message_to_dict_ai() -> None:
message = AIMessage(content="foo")
result = convert_message_to_dict(message)
expected_output = {"role": "assistant", "content": "foo"}
assert result == expected_output
def test__convert_message_to_dict_system() -> None:
message = SystemMessage(content="foo")
result = convert_message_to_dict(message)
expected_output = {"role": "system", "content": "foo"}
assert result == expected_output