mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00
community[patch]: Extend Baichuan model with tool support (#24529)
**Description:** Expanded the chat model functionality to support tools in the 'baichuan.py' file. Updated module imports and added tool object handling in message conversions. Additional changes include the implementation of tool binding and related unit tests. The alterations offer enhanced model capabilities by enabling interaction with tool-like objects. --------- Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
ee399e3ec5
commit
4f3b4fc7fe
@ -1,13 +1,26 @@
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
@ -24,14 +37,27 @@ from langchain_core.messages import (
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
make_invalid_tool_call,
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_community.chat_models.llamacpp import (
|
||||
_lc_invalid_tool_call_to_openai_tool_call,
|
||||
_lc_tool_call_to_openai_tool_call,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -40,14 +66,33 @@ DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
content = message.content
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
message_dict = {"role": message.role, "content": content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
message_dict = {"role": "user", "content": content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
message_dict = {"role": "assistant", "content": content}
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||
|
||||
elif message.tool_calls or message.invalid_tool_calls:
|
||||
message_dict["tool_calls"] = [
|
||||
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
|
||||
] + [
|
||||
_lc_invalid_tool_call_to_openai_tool_call(tc)
|
||||
for tc in message.invalid_tool_calls
|
||||
]
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"content": content,
|
||||
"name": message.name or message.additional_kwargs.get("name"),
|
||||
}
|
||||
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
message_dict = {"role": "system", "content": content}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
|
||||
@ -56,14 +101,43 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
content = _dict.get("content", "")
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
return HumanMessage(content=content)
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict.get("content", "") or "")
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
additional_kwargs = {}
|
||||
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
try:
|
||||
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
|
||||
except Exception as e:
|
||||
invalid_tool_calls.append(
|
||||
make_invalid_tool_call(raw_tool_call, str(e))
|
||||
)
|
||||
|
||||
return AIMessage(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_calls=tool_calls, # type: ignore[arg-type]
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
)
|
||||
elif role == "tool":
|
||||
additional_kwargs = {}
|
||||
if "name" in _dict:
|
||||
additional_kwargs["name"] = _dict["name"]
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=_dict.get("tool_call_id"), # type: ignore[arg-type]
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict.get("content", ""))
|
||||
return SystemMessage(content=content)
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
return ChatMessage(content=content, role=role)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
@ -226,6 +300,24 @@ class ChatBaichuan(BaseChatModel):
|
||||
},
|
||||
id='run-952509ed-9154-4ff9-b187-e616d7ddfbba-0'
|
||||
)
|
||||
Tool calling:
|
||||
|
||||
.. code-block:: python
|
||||
class get_current_weather(BaseModel):
|
||||
'''Get current weather.'''
|
||||
|
||||
location: str = Field('City or province, such as Shanghai')
|
||||
|
||||
|
||||
llm_with_tools = ChatBaichuan(model='Baichuan3-Turbo').bind_tools([get_current_weather])
|
||||
llm_with_tools.invoke('How is the weather today?')
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[{'name': 'get_current_weather',
|
||||
'args': {'location': 'New York'},
|
||||
'id': '3951017OF8doB0A',
|
||||
'type': 'tool_call'}]
|
||||
|
||||
Response metadata
|
||||
.. code-block:: python
|
||||
@ -486,6 +578,7 @@ class ChatBaichuan(BaseChatModel):
|
||||
model = parameters.pop("model")
|
||||
with_search_enhance = parameters.pop("with_search_enhance", False)
|
||||
stream = parameters.pop("stream", False)
|
||||
tools = parameters.pop("tools", [])
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
@ -495,7 +588,9 @@ class ChatBaichuan(BaseChatModel):
|
||||
"temperature": temperature,
|
||||
"with_search_enhance": with_search_enhance,
|
||||
"stream": stream,
|
||||
"tools": tools,
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]: # type: ignore[no-untyped-def]
|
||||
@ -526,3 +621,23 @@ class ChatBaichuan(BaseChatModel):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "baichuan-chat"
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool.
|
||||
Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
@ -58,6 +59,18 @@ def test__convert_message_to_dict_system() -> None:
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test__convert_message_to_dict_tool() -> None:
|
||||
message = ToolMessage(name="foo", content="bar", tool_call_id="abc123")
|
||||
result = _convert_message_to_dict(message)
|
||||
expected_output = {
|
||||
"name": "foo",
|
||||
"content": "bar",
|
||||
"tool_call_id": "abc123",
|
||||
"role": "tool",
|
||||
}
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test__convert_message_to_dict_function() -> None:
|
||||
message = FunctionMessage(name="foo", content="bar")
|
||||
with pytest.raises(TypeError) as e:
|
||||
|
Loading…
Reference in New Issue
Block a user