community(sparkllm): Add function call support in Sparkllm chat model. (#20607)

- **Description:** Add function call support to the SparkLLM chat model.
Related documentation:
https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E
- @baskaryan

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Guangdong Liu 2024-08-29 22:38:39 +08:00 committed by GitHub
parent 37f5ba416e
commit fcf9230257
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 186 additions and 10 deletions

View File

@ -8,7 +8,7 @@ import threading
from datetime import datetime
from queue import Queue
from time import mktime
from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Type
from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Type, cast
from urllib.parse import urlencode, urlparse, urlunparse
from wsgiref.handlers import format_date_time
@ -26,9 +26,15 @@ from langchain_core.messages import (
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
ToolMessageChunk,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import (
ChatGeneration,
@ -48,13 +54,24 @@ SPARK_API_URL = "wss://spark-api.xf-yun.com/v3.5/chat"
SPARK_LLM_DOMAIN = "generalv3.5"
def _convert_message_to_dict(message: BaseMessage) -> dict:
def convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
else:
@ -63,14 +80,35 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
return message_dict
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
msg_role = _dict["role"]
msg_content = _dict["content"]
if msg_role == "user":
return HumanMessage(content=msg_content)
elif msg_role == "assistant":
invalid_tool_calls = []
additional_kwargs: Dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in _dict["tool_calls"]:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
else:
additional_kwargs = {}
content = msg_content or ""
return AIMessage(content=content)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif msg_role == "system":
return SystemMessage(content=msg_content)
else:
@ -80,12 +118,24 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
msg_role = _dict["role"]
msg_content = _dict.get("content", "")
msg_role = cast(str, _dict.get("role"))
msg_content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
if msg_role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=msg_content)
elif msg_role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=msg_content)
return AIMessageChunk(content=msg_content, additional_kwargs=additional_kwargs)
elif msg_role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=msg_content, name=_dict["name"])
elif msg_role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=msg_content, tool_call_id=_dict["tool_call_id"])
elif msg_role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=msg_content, role=msg_role)
else:
@ -335,7 +385,7 @@ class ChatSparkLLM(BaseChatModel):
default_chunk_class = AIMessageChunk
self.client.arun(
[_convert_message_to_dict(m) for m in messages],
[convert_message_to_dict(m) for m in messages],
self.spark_user_id,
self.model_kwargs,
streaming=True,
@ -365,7 +415,7 @@ class ChatSparkLLM(BaseChatModel):
return generate_from_stream(stream_iter)
self.client.arun(
[_convert_message_to_dict(m) for m in messages],
[convert_message_to_dict(m) for m in messages],
self.spark_user_id,
self.model_kwargs,
False,
@ -378,7 +428,7 @@ class ChatSparkLLM(BaseChatModel):
if "data" not in content:
continue
completion = content["data"]
message = _convert_dict_to_message(completion)
message = convert_dict_to_message(completion)
generations = [ChatGeneration(message=message)]
return ChatResult(generations=generations, llm_output=llm_output)

View File

@ -1,7 +1,48 @@
from typing import Any
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_community.chat_models.sparkllm import ChatSparkLLM
# Tool specification bound to the chat model by the function-call test in this
# file: a single "get_current_weather" function whose parameters are a required
# "location" string and a "unit" string restricted to celsius/fahrenheit.
_FUNCTIONS: Any = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
# Only "location" must be supplied; "unit" may be omitted.
"required": ["location"],
},
},
}
]
def test_functions_call_thoughts() -> None:
    """Check that binding function specs makes the model answer with a tool call.

    The question is handed to the prompt as a plain string: the template
    interpolates ``{input}`` textually, so a ``HumanMessage`` object would be
    rendered through its string representation instead of contributing just
    the question text.
    """
    chat = ChatSparkLLM(timeout=30)

    prompt_tmpl = "Use the given functions to answer following question: {input}"
    prompt = ChatPromptTemplate(
        messages=[HumanMessagePromptTemplate.from_template(prompt_tmpl)]
    )
    chain = prompt | chat.bind(functions=_FUNCTIONS)

    question = "What's the weather like in Shanghai today?"
    response = chain.batch([{"input": question}])

    # The model is expected to reply with a tool invocation, which the
    # converter surfaces under ``additional_kwargs["tool_calls"]``.
    assert isinstance(response[0], AIMessage)
    assert "tool_calls" in response[0].additional_kwargs
def test_initialization() -> None:
"""Test chat model initialization."""

View File

@ -0,0 +1,85 @@
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_community.chat_models.sparkllm import (
convert_dict_to_message,
convert_message_to_dict,
)
def test__convert_dict_to_message_human() -> None:
    """A ``user``-role dict must convert to an equal ``HumanMessage``."""
    converted = convert_dict_to_message({"role": "user", "content": "foo"})
    assert converted == HumanMessage(content="foo")
def test__convert_dict_to_message_ai() -> None:
    """An ``assistant``-role dict must convert to an equal ``AIMessage``."""
    converted = convert_dict_to_message({"role": "assistant", "content": "foo"})
    assert converted == AIMessage(content="foo")
def test__convert_dict_to_message_other_role() -> None:
    """A ``system``-role dict must convert to an equal ``SystemMessage``."""
    converted = convert_dict_to_message({"role": "system", "content": "foo"})
    assert converted == SystemMessage(content="foo")
def test__convert_dict_to_message_function_call() -> None:
    """An ``assistant`` dict with ``tool_calls`` converts to a matching ``AIMessage``.

    The expected message keeps the raw calls under
    ``additional_kwargs["tool_calls"]``, carries their ``parse_tool_call``
    results in ``tool_calls``, and has no invalid tool calls.
    """
    weather_call = {
        "function": {
            "name": "get_current_weather",
            "arguments": '{"location": "Boston", "unit": "fahrenheit"}',
        },
        "type": "function",
    }
    raw_calls = [weather_call]

    converted = convert_dict_to_message(
        {"role": "assistant", "content": "foo", "tool_calls": raw_calls}
    )

    expected = AIMessage(
        content="foo",
        additional_kwargs={"tool_calls": raw_calls},
        tool_calls=[parse_tool_call(call, return_id=True) for call in raw_calls],
        invalid_tool_calls=[],
    )
    assert converted == expected
def test__convert_message_to_dict_human() -> None:
    """A ``HumanMessage`` must serialize to a ``user``-role dict."""
    serialized = convert_message_to_dict(HumanMessage(content="foo"))
    assert serialized == {"role": "user", "content": "foo"}
def test__convert_message_to_dict_ai() -> None:
    """An ``AIMessage`` must serialize to an ``assistant``-role dict."""
    serialized = convert_message_to_dict(AIMessage(content="foo"))
    assert serialized == {"role": "assistant", "content": "foo"}
def test__convert_message_to_dict_system() -> None:
    """A ``SystemMessage`` must serialize to a ``system``-role dict."""
    serialized = convert_message_to_dict(SystemMessage(content="foo"))
    assert serialized == {"role": "system", "content": "foo"}