From ecd19a9e58ca67d02b330ae150576fd96601cc4a Mon Sep 17 00:00:00 2001 From: Pengcheng Liu Date: Thu, 18 Apr 2024 04:42:23 +0800 Subject: [PATCH] community[patch]: Add function call support in Tongyi chat model. (#20119) - [ ] **PR message**: - **Description:** This pr adds function calling support in Tongyi chat model. - **Issue:** None - **Dependencies:** None - **Twitter handle:** None Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- .../langchain_community/chat_models/tongyi.py | 26 +++++- .../chat_models/test_tongyi.py | 41 ++++++++- .../unit_tests/chat_models/test_tongyi.py | 85 +++++++++++++++++++ 3 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 libs/community/tests/unit_tests/chat_models/test_tongyi.py diff --git a/libs/community/langchain_community/chat_models/tongyi.py b/libs/community/langchain_community/chat_models/tongyi.py index 01ab29abb56..943cace9733 100644 --- a/libs/community/langchain_community/chat_models/tongyi.py +++ b/libs/community/langchain_community/chat_models/tongyi.py @@ -33,6 +33,10 @@ from langchain_core.messages import ( SystemMessage, SystemMessageChunk, ) +from langchain_core.output_parsers.openai_tools import ( + make_invalid_tool_call, + parse_tool_call, +) from langchain_core.outputs import ( ChatGeneration, ChatGenerationChunk, @@ -71,8 +75,28 @@ def convert_dict_to_message( else HumanMessage(content=content) ) elif role == "assistant": + tool_calls = [] + invalid_tool_calls = [] + if "tool_calls" in _dict: + additional_kwargs = {"tool_calls": _dict["tool_calls"]} + for raw_tool_call in _dict["tool_calls"]: + try: + tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) + except Exception as e: + invalid_tool_calls.append( + make_invalid_tool_call(raw_tool_call, str(e)) + ) + else: + additional_kwargs = {} return ( - AIMessageChunk(content=content) if is_chunk else AIMessage(content=content) + AIMessageChunk(content=content) + if is_chunk + else AIMessage( + content=content, + additional_kwargs=additional_kwargs, + tool_calls=tool_calls, + invalid_tool_calls=invalid_tool_calls, + ) ) elif role == "system": return ( diff --git a/libs/community/tests/integration_tests/chat_models/test_tongyi.py b/libs/community/tests/integration_tests/chat_models/test_tongyi.py index 475db4315c0..73591bb4e3d 100644 --- a/libs/community/tests/integration_tests/chat_models/test_tongyi.py +++ b/libs/community/tests/integration_tests/chat_models/test_tongyi.py @@ -1,15 +1,37 @@ """Test Alibaba Tongyi Chat Model.""" -from typing import cast +from typing import Any, cast from langchain_core.callbacks import CallbackManager from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.outputs import ChatGeneration, LLMResult +from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture from langchain_community.chat_models.tongyi import ChatTongyi from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +_FUNCTIONS: Any = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } +] + def test_initialization() -> None: """Test chat model initialization.""" @@ -52,6 +74,23 @@ def test_model() -> None: assert isinstance(response.content, str) +def test_functions_call_thoughts() -> None: + chat = ChatTongyi(model="qwen-plus") + + prompt_tmpl = "Use the given functions to answer following question: {input}" + prompt_msgs = [ + HumanMessagePromptTemplate.from_template(prompt_tmpl), + ] + prompt = ChatPromptTemplate(messages=prompt_msgs) + + chain = prompt | chat.bind(functions=_FUNCTIONS) + + message = HumanMessage(content="What's the weather like in Shanghai today?") + response = chain.batch([{"input": message}]) + assert isinstance(response[0], AIMessage) + assert "tool_calls" in response[0].additional_kwargs + + def test_multiple_history() -> None: """Tests multiple history works.""" chat = ChatTongyi() diff --git a/libs/community/tests/unit_tests/chat_models/test_tongyi.py b/libs/community/tests/unit_tests/chat_models/test_tongyi.py new file mode 100644 index 00000000000..62421b3e610 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_tongyi.py @@ -0,0 +1,85 @@ +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.output_parsers.openai_tools import ( + parse_tool_call, +) + +from langchain_community.chat_models.tongyi import ( + convert_dict_to_message, + convert_message_to_dict, +) + + +def test__convert_dict_to_message_human() -> None: + message_dict = {"role": "user", "content": "foo"} + result = convert_dict_to_message(message_dict) + expected_output = HumanMessage(content="foo") + assert result == expected_output + + +def test__convert_dict_to_message_ai() -> None: + message_dict = {"role": "assistant", "content": "foo"} + result = convert_dict_to_message(message_dict) + expected_output = AIMessage(content="foo") + assert result == expected_output + + +def test__convert_dict_to_message_other_role() -> None: + message_dict = {"role": "system", "content": "foo"} + result = convert_dict_to_message(message_dict) + expected_output = SystemMessage(content="foo") + assert result == expected_output + + +def test__convert_dict_to_message_function_call() -> None: + raw_function_calls = [ + { + "function": { + "name": "get_current_weather", + "arguments": '{"location": "Boston", "unit": "fahrenheit"}', + }, + "type": "function", + } + ] + message_dict = { + "role": "assistant", + "content": "foo", + "tool_calls": raw_function_calls, + } + result = convert_dict_to_message(message_dict) + + tool_calls = [ + parse_tool_call(raw_tool_call, return_id=True) + for raw_tool_call in raw_function_calls + ] + expected_output = AIMessage( + content="foo", + additional_kwargs={"tool_calls": raw_function_calls}, + tool_calls=tool_calls, + invalid_tool_calls=[], + ) + assert result == expected_output + + +def test__convert_message_to_dict_human() -> None: + message = HumanMessage(content="foo") + result = convert_message_to_dict(message) + expected_output = {"role": "user", "content": "foo"} + assert result == expected_output + + +def test__convert_message_to_dict_ai() -> None: + message = AIMessage(content="foo") + result = convert_message_to_dict(message) + expected_output = {"role": "assistant", "content": "foo"} + assert result == expected_output + + +def test__convert_message_to_dict_system() -> None: + message = SystemMessage(content="foo") + result = convert_message_to_dict(message) + expected_output = {"role": "system", "content": "foo"} + assert result == expected_output