Adding bind_tools in ChatOctoAI (#26168)

The object extends langchain_community.chat_models.openai.ChatOpenAI,
which doesn't have `bind_tools` defined. I tried extending from
`langchain_openai.ChatOpenAI` in
https://github.com/langchain-ai/langchain/pull/25975, but that PR was
closed because that approach is not correct.
So adding our own `bind_tools` (for which copying the implementation from
ChatOpenAI is good enough for now) solves the tool-calling issue we
currently have.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Sebastian Cherny 2024-09-08 20:38:43 +02:00 committed by GitHub
parent 042e84170b
commit b3c7ed4913
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,9 +1,23 @@
"""OctoAI Endpoints chat wrapper. Relies heavily on ChatOpenAI.""" """OctoAI Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
from typing import Dict from typing import (
Any,
Callable,
Dict,
Literal,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import Field, SecretStr from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.chat_models.openai import ChatOpenAI
from langchain_community.utils.openai import is_openai_v1 from langchain_community.utils.openai import is_openai_v1
@ -92,3 +106,53 @@ class ChatOctoAI(ChatOpenAI):
) )
return values return values
def bind_tools(
    self,
    tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
    *,
    tool_choice: Optional[
        Union[dict, str, Literal["auto", "none", "required", "any"], bool]
    ] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
    """Bind tool-like objects to this chat model.

    Mirrors ``langchain_openai.ChatOpenAI.bind_tools``: every tool is
    converted to the OpenAI tool schema and attached to the model via
    ``bind`` so that subsequent invocations send them with the request.

    Args:
        tools: Tool definitions to bind — dicts, pydantic classes,
            callables, or ``BaseTool`` instances.
        tool_choice: Controls which tool the model calls. Either the
            name of a specific tool, one of the generic directives
            ``"auto"``/``"none"``/``"required"``/``"any"`` (``"any"``
            is normalized to ``"required"``), ``True`` (force a tool
            call), or an OpenAI-style tool-choice dict. Falsy values
            leave tool choice unset.
        strict: Forwarded to the OpenAI tool converter.
        **kwargs: Extra parameters forwarded to ``bind``.

    Returns:
        A Runnable taking the same input as this model and producing
        ``BaseMessage`` outputs, with the tools bound.

    Raises:
        ValueError: If a dict ``tool_choice`` names a tool that was not
            provided, or if ``tool_choice`` has an unsupported type.
    """
    schemas = [convert_to_openai_tool(t, strict=strict) for t in tools]
    if tool_choice:
        if isinstance(tool_choice, str):
            # A bare tool name must be wrapped in the OpenAI dict form;
            # the generic directives pass through unchanged.
            if tool_choice not in ("auto", "none", "any", "required"):
                tool_choice = {
                    "type": "function",
                    "function": {"name": tool_choice},
                }
            # OpenAI has no native "any"; other providers use it with
            # the same meaning as "required".
            if tool_choice == "any":
                tool_choice = "required"
        elif isinstance(tool_choice, bool):
            tool_choice = "required"
        elif isinstance(tool_choice, dict):
            # A dict choice must reference one of the bound tools.
            tool_names = [schema["function"]["name"] for schema in schemas]
            chosen = tool_choice["function"]["name"]
            if all(name != chosen for name in tool_names):
                raise ValueError(
                    f"Tool choice {tool_choice} was specified, but the only "
                    f"provided tools were {tool_names}."
                )
        else:
            raise ValueError(
                f"Unrecognized tool_choice type. Expected str, bool or dict. "
                f"Received: {tool_choice}"
            )
        kwargs["tool_choice"] = tool_choice
    return super().bind(tools=schemas, **kwargs)