community[patch]: Extend Baichuan model with tool support (#24529)

**Description:** Expanded the chat model functionality to support tools
in the 'baichuan.py' file. Updated module imports and added tool object
handling in message conversions. Additional changes include the
implementation of tool binding and related unit tests. The alterations
offer enhanced model capabilities by enabling interaction with tool-like
objects.

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
nobbbbby 2024-07-26 14:20:44 +08:00 committed by GitHub
parent ee399e3ec5
commit 4f3b4fc7fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 138 additions and 10 deletions

View File

@ -1,13 +1,26 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
@ -24,14 +37,27 @@ from langchain_core.messages import (
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_community.chat_models.llamacpp import (
_lc_invalid_tool_call_to_openai_tool_call,
_lc_tool_call_to_openai_tool_call,
)
logger = logging.getLogger(__name__)
@ -40,14 +66,33 @@ DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1/chat/completions"
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a Baichuan API request dict.

    Args:
        message: The LangChain message to convert.

    Returns:
        A dict with at least ``role`` and ``content`` keys, plus
        ``tool_calls`` / ``tool_call_id`` / ``name`` where applicable.

    Raises:
        TypeError: If the message type is not supported.
    """
    message_dict: Dict[str, Any]
    content = message.content
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": content}
        if "tool_calls" in message.additional_kwargs:
            # Raw tool calls already in OpenAI wire format take precedence.
            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
        elif message.tool_calls or message.invalid_tool_calls:
            # Convert parsed (and unparseable) tool calls back to the
            # OpenAI wire format the Baichuan API expects.
            message_dict["tool_calls"] = [
                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
            ] + [
                _lc_invalid_tool_call_to_openai_tool_call(tc)
                for tc in message.invalid_tool_calls
            ]
    elif isinstance(message, ToolMessage):
        message_dict = {
            "role": "tool",
            "tool_call_id": message.tool_call_id,
            # Fall back to additional_kwargs when the name attribute is unset.
            "name": message.name or message.additional_kwargs.get("name"),
        }
        message_dict["content"] = content
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": content}
    else:
        raise TypeError(f"Got unknown type {message}")
@ -56,14 +101,43 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a Baichuan API response dict to a LangChain message.

    Args:
        _dict: A message dict from the API, carrying a ``role`` key and
            optionally ``content``, ``tool_calls``, ``tool_call_id``, ``name``.

    Returns:
        The LangChain message corresponding to the dict's ``role``.
    """
    role = _dict["role"]
    content = _dict.get("content", "")
    if role == "user":
        return HumanMessage(content=content)
    elif role == "assistant":
        tool_calls = []
        invalid_tool_calls = []
        additional_kwargs = {}
        if raw_tool_calls := _dict.get("tool_calls"):
            # Keep the raw payload for round-tripping back to the API.
            additional_kwargs["tool_calls"] = raw_tool_calls
            for raw_tool_call in raw_tool_calls:
                try:
                    tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
                except Exception as e:
                    # Malformed tool calls are preserved rather than dropped.
                    invalid_tool_calls.append(
                        make_invalid_tool_call(raw_tool_call, str(e))
                    )
        return AIMessage(
            content=content,
            additional_kwargs=additional_kwargs,
            tool_calls=tool_calls,  # type: ignore[arg-type]
            invalid_tool_calls=invalid_tool_calls,
        )
    elif role == "tool":
        additional_kwargs = {}
        if "name" in _dict:
            additional_kwargs["name"] = _dict["name"]
        return ToolMessage(
            content=content,
            tool_call_id=_dict.get("tool_call_id"),  # type: ignore[arg-type]
            additional_kwargs=additional_kwargs,
        )
    elif role == "system":
        return SystemMessage(content=content)
    else:
        return ChatMessage(content=content, role=role)
def _convert_delta_to_message_chunk(
@ -226,6 +300,24 @@ class ChatBaichuan(BaseChatModel):
},
id='run-952509ed-9154-4ff9-b187-e616d7ddfbba-0'
)
Tool calling:
.. code-block:: python
class get_current_weather(BaseModel):
'''Get current weather.'''
location: str = Field('City or province, such as Shanghai')
llm_with_tools = ChatBaichuan(model='Baichuan3-Turbo').bind_tools([get_current_weather])
llm_with_tools.invoke('How is the weather today?')
.. code-block:: python
[{'name': 'get_current_weather',
'args': {'location': 'New York'},
'id': '3951017OF8doB0A',
'type': 'tool_call'}]
Response metadata
.. code-block:: python
@ -486,6 +578,7 @@ class ChatBaichuan(BaseChatModel):
model = parameters.pop("model")
with_search_enhance = parameters.pop("with_search_enhance", False)
stream = parameters.pop("stream", False)
tools = parameters.pop("tools", [])
payload = {
"model": model,
@ -495,7 +588,9 @@ class ChatBaichuan(BaseChatModel):
"temperature": temperature,
"with_search_enhance": with_search_enhance,
"stream": stream,
"tools": tools,
}
return payload
def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]: # type: ignore[no-untyped-def]
@ -526,3 +621,23 @@ class ChatBaichuan(BaseChatModel):
@property
def _llm_type(self) -> str:
return "baichuan-chat"
def bind_tools(
    self,
    tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
    **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
    """Bind tool-like objects to this chat model.

    Args:
        tools: A list of tool definitions to bind to this chat model.
            Can be a dictionary, pydantic model, callable, or BaseTool.
            Pydantic models, callables, and BaseTools will be automatically
            converted to their OpenAI tool schema dictionary representation.
        **kwargs: Any additional parameters to pass to
            :meth:`~langchain_core.runnables.Runnable.bind`.

    Returns:
        A Runnable that sends the formatted tool schemas with every request.
    """
    # Normalize every tool-like object to the OpenAI tool schema that the
    # Baichuan API understands before binding it as a default call kwarg.
    formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
    return super().bind(tools=formatted_tools, **kwargs)

View File

@ -8,6 +8,7 @@ from langchain_core.messages import (
HumanMessage,
HumanMessageChunk,
SystemMessage,
ToolMessage,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
@ -58,6 +59,18 @@ def test__convert_message_to_dict_system() -> None:
assert result == expected_output
def test__convert_message_to_dict_tool() -> None:
    """A ToolMessage converts to a role-'tool' dict carrying id and name."""
    tool_message = ToolMessage(name="foo", content="bar", tool_call_id="abc123")
    converted = _convert_message_to_dict(tool_message)
    assert converted == {
        "role": "tool",
        "name": "foo",
        "content": "bar",
        "tool_call_id": "abc123",
    }
def test__convert_message_to_dict_function() -> None:
message = FunctionMessage(name="foo", content="bar")
with pytest.raises(TypeError) as e: