mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-04 04:28:58 +00:00
community[minor]: add tool calling for DeepInfraChat (#22745)
DeepInfra now supports tool calling for supported models. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -13,6 +13,7 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
@@ -24,6 +25,7 @@ from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
@@ -44,15 +46,18 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
# from langchain.llms.base import create_base_retry_decorator
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -78,19 +83,51 @@ def _create_retry_decorator(
|
||||
)
|
||||
|
||||
|
||||
def _parse_tool_calling(tool_call: dict) -> ToolCall:
|
||||
"""
|
||||
Convert a tool calling response from server to a ToolCall object.
|
||||
Args:
|
||||
tool_call:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
name = tool_call.get("name", "")
|
||||
args = json.loads(tool_call["function"]["arguments"])
|
||||
id = tool_call.get("id")
|
||||
return ToolCall(name=name, args=args, id=id)
|
||||
|
||||
|
||||
def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a ToolCall object to a tool calling request for server.
|
||||
Args:
|
||||
tool_call:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"arguments": json.dumps(tool_call["args"]),
|
||||
"name": tool_call["name"],
|
||||
},
|
||||
"id": tool_call.get("id"),
|
||||
}
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
# Fix for azure
|
||||
# Also OpenAI returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
tool_calls_content = _dict.get("tool_calls", []) or []
|
||||
tool_calls = [
|
||||
_parse_tool_calling(tool_call) for tool_call in tool_calls_content
|
||||
]
|
||||
return AIMessage(content=content, tool_calls=tool_calls)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
elif role == "function":
|
||||
@@ -104,15 +141,14 @@ def _convert_delta_to_message_chunk(
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
tool_calls = [
|
||||
_parse_tool_calling(tool_call) for tool_call in _dict.get("tool_calls", [])
|
||||
]
|
||||
return AIMessageChunk(content=content, tool_calls=tool_calls)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
@@ -129,9 +165,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
tool_calls = [
|
||||
_convert_to_tool_calling(tool_call) for tool_call in message.tool_calls
|
||||
]
|
||||
message_dict = {
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
"tool_calls": tool_calls, # type: ignore[dict-item]
|
||||
}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
@@ -417,6 +458,27 @@ class ChatDeepInfra(BaseChatModel):
|
||||
def _body(self, kwargs: Any) -> Dict:
|
||||
return kwargs
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Assumes model is compatible with OpenAI tool-calling API.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
|
||||
def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
|
||||
for line in rbody:
|
||||
|
Reference in New Issue
Block a user