diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py
index dcc26a5357f..adfeed5517d 100644
--- a/libs/community/langchain_community/chat_models/sparkllm.py
+++ b/libs/community/langchain_community/chat_models/sparkllm.py
@@ -8,7 +8,7 @@ import threading
 from datetime import datetime
 from queue import Queue
 from time import mktime
-from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Type
+from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Type, cast
 from urllib.parse import urlencode, urlparse, urlunparse
 from wsgiref.handlers import format_date_time
 
@@ -26,9 +26,15 @@ from langchain_core.messages import (
     BaseMessageChunk,
     ChatMessage,
     ChatMessageChunk,
+    FunctionMessageChunk,
     HumanMessage,
     HumanMessageChunk,
     SystemMessage,
+    ToolMessageChunk,
+)
+from langchain_core.output_parsers.openai_tools import (
+    make_invalid_tool_call,
+    parse_tool_call,
 )
 from langchain_core.outputs import (
     ChatGeneration,
@@ -48,13 +54,24 @@ SPARK_API_URL = "wss://spark-api.xf-yun.com/v3.5/chat"
 SPARK_LLM_DOMAIN = "generalv3.5"
 
 
-def _convert_message_to_dict(message: BaseMessage) -> dict:
+def convert_message_to_dict(message: BaseMessage) -> dict:
+    message_dict: Dict[str, Any]
     if isinstance(message, ChatMessage):
         message_dict = {"role": "user", "content": message.content}
     elif isinstance(message, HumanMessage):
         message_dict = {"role": "user", "content": message.content}
     elif isinstance(message, AIMessage):
         message_dict = {"role": "assistant", "content": message.content}
+        if "function_call" in message.additional_kwargs:
+            message_dict["function_call"] = message.additional_kwargs["function_call"]
+            # If function call only, content is None not empty string
+            if message_dict["content"] == "":
+                message_dict["content"] = None
+        if "tool_calls" in message.additional_kwargs:
+            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
+            # If tool calls only, content is None not empty string
+            if message_dict["content"] == "":
+                message_dict["content"] = None
     elif isinstance(message, SystemMessage):
         message_dict = {"role": "system", "content": message.content}
     else:
@@ -63,14 +80,35 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict
 
 
-def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     msg_role = _dict["role"]
     msg_content = _dict["content"]
     if msg_role == "user":
         return HumanMessage(content=msg_content)
     elif msg_role == "assistant":
+        invalid_tool_calls = []
+        additional_kwargs: Dict = {}
+        if function_call := _dict.get("function_call"):
+            additional_kwargs["function_call"] = dict(function_call)
+        tool_calls = []
+        if raw_tool_calls := _dict.get("tool_calls"):
+            additional_kwargs["tool_calls"] = raw_tool_calls
+            for raw_tool_call in _dict["tool_calls"]:
+                try:
+                    tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
+                except Exception as e:
+                    invalid_tool_calls.append(
+                        make_invalid_tool_call(raw_tool_call, str(e))
+                    )
+        else:
+            additional_kwargs = {}
         content = msg_content or ""
-        return AIMessage(content=content)
+        return AIMessage(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            tool_calls=tool_calls,
+            invalid_tool_calls=invalid_tool_calls,
+        )
     elif msg_role == "system":
         return SystemMessage(content=msg_content)
     else:
@@ -80,12 +118,24 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
 def _convert_delta_to_message_chunk(
     _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
-    msg_role = _dict["role"]
-    msg_content = _dict.get("content", "")
+    msg_role = cast(str, _dict.get("role"))
+    msg_content = cast(str, _dict.get("content") or "")
+    additional_kwargs: Dict = {}
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    if _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = _dict["tool_calls"]
     if msg_role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=msg_content)
     elif msg_role == "assistant" or default_class == AIMessageChunk:
-        return AIMessageChunk(content=msg_content)
+        return AIMessageChunk(content=msg_content, additional_kwargs=additional_kwargs)
+    elif msg_role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=msg_content, name=_dict["name"])
+    elif msg_role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(content=msg_content, tool_call_id=_dict["tool_call_id"])
    elif msg_role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=msg_content, role=msg_role)
     else:
@@ -335,7 +385,7 @@ class ChatSparkLLM(BaseChatModel):
         default_chunk_class = AIMessageChunk
 
         self.client.arun(
-            [_convert_message_to_dict(m) for m in messages],
+            [convert_message_to_dict(m) for m in messages],
             self.spark_user_id,
             self.model_kwargs,
             streaming=True,
@@ -365,7 +415,7 @@ class ChatSparkLLM(BaseChatModel):
             return generate_from_stream(stream_iter)
 
         self.client.arun(
-            [_convert_message_to_dict(m) for m in messages],
+            [convert_message_to_dict(m) for m in messages],
             self.spark_user_id,
             self.model_kwargs,
             False,
@@ -378,7 +428,7 @@ class ChatSparkLLM(BaseChatModel):
             if "data" not in content:
                 continue
             completion = content["data"]
-        message = _convert_dict_to_message(completion)
+        message = convert_dict_to_message(completion)
         generations = [ChatGeneration(message=message)]
         return ChatResult(generations=generations, llm_output=llm_output)
 
diff --git a/libs/community/tests/integration_tests/chat_models/test_sparkllm.py b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py
index ae94a8a3e60..5eb241b66a5 100644
--- a/libs/community/tests/integration_tests/chat_models/test_sparkllm.py
+++ b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py
@@ -1,7 +1,48 @@
+from typing import Any
+
 from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
+from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 
 from langchain_community.chat_models.sparkllm import ChatSparkLLM
 
+_FUNCTIONS: Any = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                },
+                "required": ["location"],
+            },
+        },
+    }
+]
+
+
+def test_functions_call_thoughts() -> None:
+    chat = ChatSparkLLM(timeout=30)
+
+    prompt_tmpl = "Use the given functions to answer following question: {input}"
+    prompt_msgs = [
+        HumanMessagePromptTemplate.from_template(prompt_tmpl),
+    ]
+    prompt = ChatPromptTemplate(messages=prompt_msgs)
+
+    chain = prompt | chat.bind(functions=_FUNCTIONS)
+
+    message = HumanMessage(content="What's the weather like in Shanghai today?")
+    response = chain.batch([{"input": message}])
+    assert isinstance(response[0], AIMessage)
+    assert "tool_calls" in response[0].additional_kwargs
+
 
 def test_initialization() -> None:
     """Test chat model initialization."""
diff --git a/libs/community/tests/unit_tests/chat_models/test_sparkllm.py b/libs/community/tests/unit_tests/chat_models/test_sparkllm.py
new file mode 100644
index 00000000000..6d7e4cf6aa8
--- /dev/null
+++ b/libs/community/tests/unit_tests/chat_models/test_sparkllm.py
@@ -0,0 +1,85 @@
+from langchain_core.messages import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain_core.output_parsers.openai_tools import (
+    parse_tool_call,
+)
+
+from langchain_community.chat_models.sparkllm import (
+    convert_dict_to_message,
+    convert_message_to_dict,
+)
+
+
+def test__convert_dict_to_message_human() -> None:
+    message_dict = {"role": "user", "content": "foo"}
+    result = convert_dict_to_message(message_dict)
+    expected_output = HumanMessage(content="foo")
+    assert result == expected_output
+
+
+def test__convert_dict_to_message_ai() -> None:
+    message_dict = {"role": "assistant", "content": "foo"}
+    result = convert_dict_to_message(message_dict)
+    expected_output = AIMessage(content="foo")
+    assert result == expected_output
+
+
+def test__convert_dict_to_message_other_role() -> None:
+    message_dict = {"role": "system", "content": "foo"}
+    result = convert_dict_to_message(message_dict)
+    expected_output = SystemMessage(content="foo")
+    assert result == expected_output
+
+
+def test__convert_dict_to_message_function_call() -> None:
+    raw_function_calls = [
+        {
+            "function": {
+                "name": "get_current_weather",
+                "arguments": '{"location": "Boston", "unit": "fahrenheit"}',
+            },
+            "type": "function",
+        }
+    ]
+    message_dict = {
+        "role": "assistant",
+        "content": "foo",
+        "tool_calls": raw_function_calls,
+    }
+    result = convert_dict_to_message(message_dict)
+
+    tool_calls = [
+        parse_tool_call(raw_tool_call, return_id=True)
+        for raw_tool_call in raw_function_calls
+    ]
+    expected_output = AIMessage(
+        content="foo",
+        additional_kwargs={"tool_calls": raw_function_calls},
+        tool_calls=tool_calls,
+        invalid_tool_calls=[],
+    )
+    assert result == expected_output
+
+
+def test__convert_message_to_dict_human() -> None:
+    message = HumanMessage(content="foo")
+    result = convert_message_to_dict(message)
+    expected_output = {"role": "user", "content": "foo"}
+    assert result == expected_output
+
+
+def test__convert_message_to_dict_ai() -> None:
+    message = AIMessage(content="foo")
+    result = convert_message_to_dict(message)
+    expected_output = {"role": "assistant", "content": "foo"}
+    assert result == expected_output
+
+
+def test__convert_message_to_dict_system() -> None:
+    message = SystemMessage(content="foo")
+    result = convert_message_to_dict(message)
+    expected_output = {"role": "system", "content": "foo"}
+    assert result == expected_output
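
For reviewers, a minimal sketch of what this change enables once the diff above is applied. The payload below is a hypothetical assistant response in the OpenAI-style tool-call shape the updated converters parse; convert_dict_to_message and convert_message_to_dict are the newly public helpers, while _convert_delta_to_message_chunk stays module-private.

from langchain_core.messages import AIMessage, AIMessageChunk

from langchain_community.chat_models.sparkllm import (
    _convert_delta_to_message_chunk,
    convert_dict_to_message,
    convert_message_to_dict,
)

# Hypothetical assistant payload in the OpenAI-style tool-call shape.
raw = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": '{"location": "Shanghai"}',
            },
        }
    ],
}

# Tool calls are parsed onto AIMessage.tool_calls and mirrored verbatim
# into additional_kwargs.
msg = convert_dict_to_message(raw)
assert isinstance(msg, AIMessage)
assert msg.tool_calls[0]["name"] == "get_current_weather"

# Round-tripping a tool-call-only message nulls out the empty content.
assert convert_message_to_dict(msg)["content"] is None

# Streaming deltas keep the raw tool calls in additional_kwargs.
chunk = _convert_delta_to_message_chunk(raw, AIMessageChunk)
assert chunk.additional_kwargs["tool_calls"] == raw["tool_calls"]

Dropping the leading underscore from the two converters is what lets the new unit-test file import them directly instead of reaching into private API.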