mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 13:27:36 +00:00
community(sparkllm): Add function call support in Sparkllm chat model. (#20607)
- **Description:** Add function call support in Sparkllm chat model. Related documents https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E - @baskaryan --------- Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
37f5ba416e
commit
fcf9230257
@ -8,7 +8,7 @@ import threading
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from time import mktime
|
from time import mktime
|
||||||
from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Type
|
from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Type, cast
|
||||||
from urllib.parse import urlencode, urlparse, urlunparse
|
from urllib.parse import urlencode, urlparse, urlunparse
|
||||||
from wsgiref.handlers import format_date_time
|
from wsgiref.handlers import format_date_time
|
||||||
|
|
||||||
@ -26,9 +26,15 @@ from langchain_core.messages import (
|
|||||||
BaseMessageChunk,
|
BaseMessageChunk,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatMessageChunk,
|
ChatMessageChunk,
|
||||||
|
FunctionMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
|
ToolMessageChunk,
|
||||||
|
)
|
||||||
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
|
make_invalid_tool_call,
|
||||||
|
parse_tool_call,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import (
|
from langchain_core.outputs import (
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
@ -48,13 +54,24 @@ SPARK_API_URL = "wss://spark-api.xf-yun.com/v3.5/chat"
|
|||||||
SPARK_LLM_DOMAIN = "generalv3.5"
|
SPARK_LLM_DOMAIN = "generalv3.5"
|
||||||
|
|
||||||
|
|
||||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
message_dict: Dict[str, Any]
|
||||||
if isinstance(message, ChatMessage):
|
if isinstance(message, ChatMessage):
|
||||||
message_dict = {"role": "user", "content": message.content}
|
message_dict = {"role": "user", "content": message.content}
|
||||||
elif isinstance(message, HumanMessage):
|
elif isinstance(message, HumanMessage):
|
||||||
message_dict = {"role": "user", "content": message.content}
|
message_dict = {"role": "user", "content": message.content}
|
||||||
elif isinstance(message, AIMessage):
|
elif isinstance(message, AIMessage):
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
|
if "function_call" in message.additional_kwargs:
|
||||||
|
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||||
|
# If function call only, content is None not empty string
|
||||||
|
if message_dict["content"] == "":
|
||||||
|
message_dict["content"] = None
|
||||||
|
if "tool_calls" in message.additional_kwargs:
|
||||||
|
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||||
|
# If tool calls only, content is None not empty string
|
||||||
|
if message_dict["content"] == "":
|
||||||
|
message_dict["content"] = None
|
||||||
elif isinstance(message, SystemMessage):
|
elif isinstance(message, SystemMessage):
|
||||||
message_dict = {"role": "system", "content": message.content}
|
message_dict = {"role": "system", "content": message.content}
|
||||||
else:
|
else:
|
||||||
@ -63,14 +80,35 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||||
msg_role = _dict["role"]
|
msg_role = _dict["role"]
|
||||||
msg_content = _dict["content"]
|
msg_content = _dict["content"]
|
||||||
if msg_role == "user":
|
if msg_role == "user":
|
||||||
return HumanMessage(content=msg_content)
|
return HumanMessage(content=msg_content)
|
||||||
elif msg_role == "assistant":
|
elif msg_role == "assistant":
|
||||||
|
invalid_tool_calls = []
|
||||||
|
additional_kwargs: Dict = {}
|
||||||
|
if function_call := _dict.get("function_call"):
|
||||||
|
additional_kwargs["function_call"] = dict(function_call)
|
||||||
|
tool_calls = []
|
||||||
|
if raw_tool_calls := _dict.get("tool_calls"):
|
||||||
|
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||||
|
for raw_tool_call in _dict["tool_calls"]:
|
||||||
|
try:
|
||||||
|
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
|
||||||
|
except Exception as e:
|
||||||
|
invalid_tool_calls.append(
|
||||||
|
make_invalid_tool_call(raw_tool_call, str(e))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
additional_kwargs = {}
|
||||||
content = msg_content or ""
|
content = msg_content or ""
|
||||||
return AIMessage(content=content)
|
return AIMessage(
|
||||||
|
content=content,
|
||||||
|
additional_kwargs=additional_kwargs,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
invalid_tool_calls=invalid_tool_calls,
|
||||||
|
)
|
||||||
elif msg_role == "system":
|
elif msg_role == "system":
|
||||||
return SystemMessage(content=msg_content)
|
return SystemMessage(content=msg_content)
|
||||||
else:
|
else:
|
||||||
@ -80,12 +118,24 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
def _convert_delta_to_message_chunk(
|
def _convert_delta_to_message_chunk(
|
||||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||||
) -> BaseMessageChunk:
|
) -> BaseMessageChunk:
|
||||||
msg_role = _dict["role"]
|
msg_role = cast(str, _dict.get("role"))
|
||||||
msg_content = _dict.get("content", "")
|
msg_content = cast(str, _dict.get("content") or "")
|
||||||
|
additional_kwargs: Dict = {}
|
||||||
|
if _dict.get("function_call"):
|
||||||
|
function_call = dict(_dict["function_call"])
|
||||||
|
if "name" in function_call and function_call["name"] is None:
|
||||||
|
function_call["name"] = ""
|
||||||
|
additional_kwargs["function_call"] = function_call
|
||||||
|
if _dict.get("tool_calls"):
|
||||||
|
additional_kwargs["tool_calls"] = _dict["tool_calls"]
|
||||||
if msg_role == "user" or default_class == HumanMessageChunk:
|
if msg_role == "user" or default_class == HumanMessageChunk:
|
||||||
return HumanMessageChunk(content=msg_content)
|
return HumanMessageChunk(content=msg_content)
|
||||||
elif msg_role == "assistant" or default_class == AIMessageChunk:
|
elif msg_role == "assistant" or default_class == AIMessageChunk:
|
||||||
return AIMessageChunk(content=msg_content)
|
return AIMessageChunk(content=msg_content, additional_kwargs=additional_kwargs)
|
||||||
|
elif msg_role == "function" or default_class == FunctionMessageChunk:
|
||||||
|
return FunctionMessageChunk(content=msg_content, name=_dict["name"])
|
||||||
|
elif msg_role == "tool" or default_class == ToolMessageChunk:
|
||||||
|
return ToolMessageChunk(content=msg_content, tool_call_id=_dict["tool_call_id"])
|
||||||
elif msg_role or default_class == ChatMessageChunk:
|
elif msg_role or default_class == ChatMessageChunk:
|
||||||
return ChatMessageChunk(content=msg_content, role=msg_role)
|
return ChatMessageChunk(content=msg_content, role=msg_role)
|
||||||
else:
|
else:
|
||||||
@ -335,7 +385,7 @@ class ChatSparkLLM(BaseChatModel):
|
|||||||
default_chunk_class = AIMessageChunk
|
default_chunk_class = AIMessageChunk
|
||||||
|
|
||||||
self.client.arun(
|
self.client.arun(
|
||||||
[_convert_message_to_dict(m) for m in messages],
|
[convert_message_to_dict(m) for m in messages],
|
||||||
self.spark_user_id,
|
self.spark_user_id,
|
||||||
self.model_kwargs,
|
self.model_kwargs,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
@ -365,7 +415,7 @@ class ChatSparkLLM(BaseChatModel):
|
|||||||
return generate_from_stream(stream_iter)
|
return generate_from_stream(stream_iter)
|
||||||
|
|
||||||
self.client.arun(
|
self.client.arun(
|
||||||
[_convert_message_to_dict(m) for m in messages],
|
[convert_message_to_dict(m) for m in messages],
|
||||||
self.spark_user_id,
|
self.spark_user_id,
|
||||||
self.model_kwargs,
|
self.model_kwargs,
|
||||||
False,
|
False,
|
||||||
@ -378,7 +428,7 @@ class ChatSparkLLM(BaseChatModel):
|
|||||||
if "data" not in content:
|
if "data" not in content:
|
||||||
continue
|
continue
|
||||||
completion = content["data"]
|
completion = content["data"]
|
||||||
message = _convert_dict_to_message(completion)
|
message = convert_dict_to_message(completion)
|
||||||
generations = [ChatGeneration(message=message)]
|
generations = [ChatGeneration(message=message)]
|
||||||
return ChatResult(generations=generations, llm_output=llm_output)
|
return ChatResult(generations=generations, llm_output=llm_output)
|
||||||
|
|
||||||
|
@ -1,7 +1,48 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||||
|
|
||||||
from langchain_community.chat_models.sparkllm import ChatSparkLLM
|
from langchain_community.chat_models.sparkllm import ChatSparkLLM
|
||||||
|
|
||||||
|
_FUNCTIONS: Any = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_functions_call_thoughts() -> None:
|
||||||
|
chat = ChatSparkLLM(timeout=30)
|
||||||
|
|
||||||
|
prompt_tmpl = "Use the given functions to answer following question: {input}"
|
||||||
|
prompt_msgs = [
|
||||||
|
HumanMessagePromptTemplate.from_template(prompt_tmpl),
|
||||||
|
]
|
||||||
|
prompt = ChatPromptTemplate(messages=prompt_msgs)
|
||||||
|
|
||||||
|
chain = prompt | chat.bind(functions=_FUNCTIONS)
|
||||||
|
|
||||||
|
message = HumanMessage(content="What's the weather like in Shanghai today?")
|
||||||
|
response = chain.batch([{"input": message}])
|
||||||
|
assert isinstance(response[0], AIMessage)
|
||||||
|
assert "tool_calls" in response[0].additional_kwargs
|
||||||
|
|
||||||
|
|
||||||
def test_initialization() -> None:
|
def test_initialization() -> None:
|
||||||
"""Test chat model initialization."""
|
"""Test chat model initialization."""
|
||||||
|
85
libs/community/tests/unit_tests/chat_models/test_sparkllm.py
Normal file
85
libs/community/tests/unit_tests/chat_models/test_sparkllm.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
|
parse_tool_call,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain_community.chat_models.sparkllm import (
|
||||||
|
convert_dict_to_message,
|
||||||
|
convert_message_to_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_human() -> None:
|
||||||
|
message_dict = {"role": "user", "content": "foo"}
|
||||||
|
result = convert_dict_to_message(message_dict)
|
||||||
|
expected_output = HumanMessage(content="foo")
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_ai() -> None:
|
||||||
|
message_dict = {"role": "assistant", "content": "foo"}
|
||||||
|
result = convert_dict_to_message(message_dict)
|
||||||
|
expected_output = AIMessage(content="foo")
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_other_role() -> None:
|
||||||
|
message_dict = {"role": "system", "content": "foo"}
|
||||||
|
result = convert_dict_to_message(message_dict)
|
||||||
|
expected_output = SystemMessage(content="foo")
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_function_call() -> None:
|
||||||
|
raw_function_calls = [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"arguments": '{"location": "Boston", "unit": "fahrenheit"}',
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
message_dict = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "foo",
|
||||||
|
"tool_calls": raw_function_calls,
|
||||||
|
}
|
||||||
|
result = convert_dict_to_message(message_dict)
|
||||||
|
|
||||||
|
tool_calls = [
|
||||||
|
parse_tool_call(raw_tool_call, return_id=True)
|
||||||
|
for raw_tool_call in raw_function_calls
|
||||||
|
]
|
||||||
|
expected_output = AIMessage(
|
||||||
|
content="foo",
|
||||||
|
additional_kwargs={"tool_calls": raw_function_calls},
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
invalid_tool_calls=[],
|
||||||
|
)
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_message_to_dict_human() -> None:
|
||||||
|
message = HumanMessage(content="foo")
|
||||||
|
result = convert_message_to_dict(message)
|
||||||
|
expected_output = {"role": "user", "content": "foo"}
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_message_to_dict_ai() -> None:
|
||||||
|
message = AIMessage(content="foo")
|
||||||
|
result = convert_message_to_dict(message)
|
||||||
|
expected_output = {"role": "assistant", "content": "foo"}
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_message_to_dict_system() -> None:
|
||||||
|
message = SystemMessage(content="foo")
|
||||||
|
result = convert_message_to_dict(message)
|
||||||
|
expected_output = {"role": "system", "content": "foo"}
|
||||||
|
assert result == expected_output
|
Loading…
Reference in New Issue
Block a user