community[minor]: add tool calling for DeepInfraChat (#22745)

DeepInfra now supports tool calling for supported models.
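
For illustration, binding and calling a tool with the new API looks roughly like the sketch below (the model name and prompt are placeholders; any DeepInfra-hosted model with tool-calling support should work, and DEEPINFRA_API_TOKEN must be set):

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.chat_models import ChatDeepInfra

class GetWeather(BaseModel):
    """Get the current weather in a given location."""

    location: str = Field(..., description="City name, e.g. Istanbul")

# Placeholder model id; substitute any tool-calling-capable model on DeepInfra.
llm = ChatDeepInfra(model="meta-llama/Meta-Llama-3-70B-Instruct")
llm_with_tools = llm.bind_tools([GetWeather])

ai_msg = llm_with_tools.invoke("What is the weather like in Istanbul?")
print(ai_msg.tool_calls)
# e.g. [{"name": "GetWeather", "args": {"location": "Istanbul"}, "id": "call_abc123"}]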

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Authored by Oguz Vuruskaner on 2024-06-17 12:21:49 -07:00; committed by GitHub
parent 158701ab3c
commit dd25d08c06
3 changed files with 211 additions and 17 deletions


@@ -13,6 +13,7 @@ from typing import (
    List,
    Mapping,
    Optional,
+    Sequence,
    Tuple,
    Type,
    Union,
@@ -24,6 +25,7 @@ from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
+from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
@@ -44,15 +46,18 @@ from langchain_core.messages import (
    SystemMessage,
    SystemMessageChunk,
)
+from langchain_core.messages.tool import ToolCall
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    ChatResult,
)
-from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env
+from langchain_core.utils.function_calling import convert_to_openai_tool
# from langchain.llms.base import create_base_retry_decorator
from langchain_community.utilities.requests import Requests

logger = logging.getLogger(__name__)
@@ -78,19 +83,51 @@ def _create_retry_decorator(
)
+def _parse_tool_calling(tool_call: dict) -> ToolCall:
+    """Convert a tool call from a server response into a ToolCall object.
+
+    Args:
+        tool_call: An OpenAI-style tool call dict whose "function" entry holds
+            the tool name and a JSON-encoded "arguments" string.
+
+    Returns:
+        A ToolCall with the arguments decoded into a dict.
+    """
+    name = tool_call["function"].get("name", "")
+    args = json.loads(tool_call["function"]["arguments"])
+    id = tool_call.get("id")
+    return ToolCall(name=name, args=args, id=id)


+def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
+    """Convert a ToolCall object into an OpenAI-style tool call request dict.
+
+    Args:
+        tool_call: The ToolCall to serialize.
+
+    Returns:
+        A dict in the wire format, with the args re-encoded as a JSON string.
+    """
+    return {
+        "type": "function",
+        "function": {
+            "arguments": json.dumps(tool_call["args"]),
+            "name": tool_call["name"],
+        },
+        "id": tool_call.get("id"),
+    }
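
These two helpers translate between the OpenAI wire format, where arguments travel as a JSON-encoded string, and LangChain's ToolCall, where args is a plain dict. A round-trip sketch with hypothetical values:

raw = {
    "id": "call_abc123",
    "type": "function",
    "function": {"name": "get_weather", "arguments": '{"location": "Istanbul"}'},
}
tc = _parse_tool_calling(raw)
# -> {"name": "get_weather", "args": {"location": "Istanbul"}, "id": "call_abc123"}
_convert_to_tool_calling(tc)
# -> the original wire-format dict, with args re-encoded by json.dumps
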
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        # Fix for azure
        # Also OpenAI returns None for tool invocations
        content = _dict.get("content", "") or ""
-        if _dict.get("function_call"):
-            additional_kwargs = {"function_call": dict(_dict["function_call"])}
-        else:
-            additional_kwargs = {}
-        return AIMessage(content=content, additional_kwargs=additional_kwargs)
+        tool_calls_content = _dict.get("tool_calls", []) or []
+        tool_calls = [
+            _parse_tool_calling(tool_call) for tool_call in tool_calls_content
+        ]
+        return AIMessage(content=content, tool_calls=tool_calls)
    elif role == "system":
        return SystemMessage(content=_dict["content"])
    elif role == "function":
@@ -104,15 +141,14 @@ def _convert_delta_to_message_chunk(
) -> BaseMessageChunk:
    role = _dict.get("role")
    content = _dict.get("content") or ""
-    if _dict.get("function_call"):
-        additional_kwargs = {"function_call": dict(_dict["function_call"])}
-    else:
-        additional_kwargs = {}
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
-        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+        tool_calls = [
+            _parse_tool_calling(tool_call) for tool_call in _dict.get("tool_calls", [])
+        ]
+        return AIMessageChunk(content=content, tool_calls=tool_calls)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
@@ -129,9 +165,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
-        message_dict = {"role": "assistant", "content": message.content}
-        if "function_call" in message.additional_kwargs:
-            message_dict["function_call"] = message.additional_kwargs["function_call"]
+        tool_calls = [
+            _convert_to_tool_calling(tool_call) for tool_call in message.tool_calls
+        ]
+        message_dict = {
+            "role": "assistant",
+            "content": message.content,
+            "tool_calls": tool_calls,  # type: ignore[dict-item]
+        }
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
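
Together with _convert_dict_to_message above, this yields a round trip between AIMessage and the wire format; a sketch with hypothetical values:

from langchain_core.messages import AIMessage

msg = AIMessage(
    content="",
    tool_calls=[{"name": "get_weather", "args": {"location": "Istanbul"}, "id": "call_1"}],
)
d = _convert_message_to_dict(msg)
# d["tool_calls"][0]["function"]["arguments"] == '{"location": "Istanbul"}'
back = _convert_dict_to_message(d)
# back.tool_calls carries the same name, args, and id as msg.tool_calls
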
@@ -417,6 +458,27 @@ class ChatDeepInfra(BaseChatModel):
    def _body(self, kwargs: Any) -> Dict:
        return kwargs

+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes the model is compatible with the OpenAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)


def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
    for line in rbody:
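
bind_tools itself stays thin: each tool is converted to the OpenAI tool schema and bound as the tools request parameter. For a plain function, convert_to_openai_tool produces roughly the following (a sketch; the exact schema can vary by langchain_core version):

from langchain_core.utils.function_calling import convert_to_openai_tool

def get_weather(location: str) -> str:
    """Get the current weather in a given location."""
    return "sunny"

convert_to_openai_tool(get_weather)
# -> roughly:
# {"type": "function",
#  "function": {"name": "get_weather",
#               "description": "Get the current weather in a given location.",
#               "parameters": {"type": "object",
#                              "properties": {"location": {"type": "string"}},
#                              "required": ["location"]}}}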