diff --git a/libs/community/langchain_community/chat_models/baichuan.py b/libs/community/langchain_community/chat_models/baichuan.py index 97c0fbd2781..05e8496c210 100644 --- a/libs/community/langchain_community/chat_models/baichuan.py +++ b/libs/community/langchain_community/chat_models/baichuan.py @@ -1,13 +1,26 @@ import json import logging from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Sequence, + Type, + Union, +) import requests from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, agenerate_from_stream, @@ -24,14 +37,27 @@ from langchain_core.messages import ( HumanMessageChunk, SystemMessage, SystemMessageChunk, + ToolMessage, +) +from langchain_core.output_parsers.openai_tools import ( + make_invalid_tool_call, + parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool from langchain_core.utils import ( convert_to_secret_str, get_from_dict_or_env, get_pydantic_field_names, ) +from langchain_core.utils.function_calling import convert_to_openai_tool + +from langchain_community.chat_models.llamacpp import ( + _lc_invalid_tool_call_to_openai_tool_call, + _lc_tool_call_to_openai_tool_call, +) logger = logging.getLogger(__name__) @@ -40,14 +66,33 @@ DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1/chat/completions" def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict: Dict[str, Any] + content = message.content if isinstance(message, ChatMessage): - message_dict = {"role": message.role, "content": message.content} + message_dict = {"role": message.role, "content": content} elif isinstance(message, HumanMessage): - message_dict = {"role": "user", "content": message.content} + message_dict = {"role": "user", "content": content} elif isinstance(message, AIMessage): - message_dict = {"role": "assistant", "content": message.content} + message_dict = {"role": "assistant", "content": content} + if "tool_calls" in message.additional_kwargs: + message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] + + elif message.tool_calls or message.invalid_tool_calls: + message_dict["tool_calls"] = [ + _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls + ] + [ + _lc_invalid_tool_call_to_openai_tool_call(tc) + for tc in message.invalid_tool_calls + ] + elif isinstance(message, ToolMessage): + message_dict = { + "role": "tool", + "tool_call_id": message.tool_call_id, + "content": content, + "name": message.name or message.additional_kwargs.get("name"), + } + elif isinstance(message, SystemMessage): - message_dict = {"role": "system", "content": message.content} + message_dict = {"role": "system", "content": content} else: raise TypeError(f"Got unknown type {message}") @@ -56,14 +101,43 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: role = _dict["role"] + content = _dict.get("content", "") if role == "user": - return HumanMessage(content=_dict["content"]) + return HumanMessage(content=content) elif role == "assistant": - return AIMessage(content=_dict.get("content", "") or "") + tool_calls = [] + invalid_tool_calls = [] + additional_kwargs = {} + + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + for raw_tool_call in raw_tool_calls: + try: + tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) + except Exception as e: + invalid_tool_calls.append( + make_invalid_tool_call(raw_tool_call, str(e)) + ) + + return AIMessage( + content=content, + additional_kwargs=additional_kwargs, + tool_calls=tool_calls, # type: ignore[arg-type] + invalid_tool_calls=invalid_tool_calls, + ) + elif role == "tool": + additional_kwargs = {} + if "name" in _dict: + additional_kwargs["name"] = _dict["name"] + return ToolMessage( + content=content, + tool_call_id=_dict.get("tool_call_id"), # type: ignore[arg-type] + additional_kwargs=additional_kwargs, + ) elif role == "system": - return SystemMessage(content=_dict.get("content", "")) + return SystemMessage(content=content) else: - return ChatMessage(content=_dict["content"], role=role) + return ChatMessage(content=content, role=role) def _convert_delta_to_message_chunk( @@ -226,6 +300,24 @@ class ChatBaichuan(BaseChatModel): }, id='run-952509ed-9154-4ff9-b187-e616d7ddfbba-0' ) + Tool calling: + + .. code-block:: python + class get_current_weather(BaseModel): + '''Get current weather.''' + + location: str = Field('City or province, such as Shanghai') + + + llm_with_tools = ChatBaichuan(model='Baichuan3-Turbo').bind_tools([get_current_weather]) + llm_with_tools.invoke('How is the weather today?') + + .. code-block:: python + + [{'name': 'get_current_weather', + 'args': {'location': 'New York'}, + 'id': '3951017OF8doB0A', + 'type': 'tool_call'}] Response metadata .. code-block:: python @@ -486,6 +578,7 @@ class ChatBaichuan(BaseChatModel): model = parameters.pop("model") with_search_enhance = parameters.pop("with_search_enhance", False) stream = parameters.pop("stream", False) + tools = parameters.pop("tools", []) payload = { "model": model, @@ -495,7 +588,9 @@ class ChatBaichuan(BaseChatModel): "temperature": temperature, "with_search_enhance": with_search_enhance, "stream": stream, + "tools": tools, } + return payload def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]: # type: ignore[no-untyped-def] @@ -526,3 +621,23 @@ class ChatBaichuan(BaseChatModel): @property def _llm_type(self) -> str: return "baichuan-chat" + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. + + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a dictionary, pydantic model, callable, or BaseTool. + Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/community/tests/unit_tests/chat_models/test_baichuan.py b/libs/community/tests/unit_tests/chat_models/test_baichuan.py index 4dd44dd1811..31d588274a7 100644 --- a/libs/community/tests/unit_tests/chat_models/test_baichuan.py +++ b/libs/community/tests/unit_tests/chat_models/test_baichuan.py @@ -8,6 +8,7 @@ from langchain_core.messages import ( HumanMessage, HumanMessageChunk, SystemMessage, + ToolMessage, ) from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch @@ -58,6 +59,18 @@ def test__convert_message_to_dict_system() -> None: assert result == expected_output +def test__convert_message_to_dict_tool() -> None: + message = ToolMessage(name="foo", content="bar", tool_call_id="abc123") + result = _convert_message_to_dict(message) + expected_output = { + "name": "foo", + "content": "bar", + "tool_call_id": "abc123", + "role": "tool", + } + assert result == expected_output + + def test__convert_message_to_dict_function() -> None: message = FunctionMessage(name="foo", content="bar") with pytest.raises(TypeError) as e: