This commit is contained in:
Chester Curme 2025-02-26 14:14:31 -05:00
parent 9cd20080fc
commit 6ae43bec21
2 changed files with 90 additions and 40 deletions

View File

@ -605,6 +605,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
params = {**params, **kwargs}
return str(sorted(params.items()))
def _format_params_for_provider(self, params: dict) -> dict:
return params
def generate(
self,
messages: list[list[BaseMessage]],
@ -641,6 +644,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
kwargs = self._format_params_for_provider(**kwargs)
structured_output_format = kwargs.pop("structured_output_format", None)
if structured_output_format:
try:

View File

@ -1024,6 +1024,90 @@ class BaseChatOpenAI(BaseChatModel):
encoding = tiktoken.get_encoding(model)
return model, encoding
def _format_tool_params(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
strict: Optional[bool] = None,
parallel_tool_calls: Optional[bool] = None,
**kwargs: Any,
) -> dict:
if parallel_tool_calls is not None: # Keep parameter explicit in signature
kwargs["parallel_tool_calls"] = parallel_tool_calls
formatted_tools = [
convert_to_openai_tool(tool, strict=strict) for tool in tools
]
kwargs["tools"] = formatted_tools
if tool_choice:
if isinstance(tool_choice, str):
# tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "any", "required"):
tool_choice = {
"type": "function",
"function": {"name": tool_choice},
}
# 'any' is not natively supported by OpenAI API.
# We support 'any' since other models use this instead of 'required'.
if tool_choice == "any":
tool_choice = "required"
elif isinstance(tool_choice, bool):
tool_choice = "required"
elif isinstance(tool_choice, dict):
tool_names = [
formatted_tool["function"]["name"]
for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"]
for tool_name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tools were {tool_names}."
)
else:
raise ValueError(
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
return kwargs
def _format_params_for_provider(
self,
*,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]] = [],
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
strict: Optional[bool] = None,
parallel_tool_calls: Optional[bool] = None,
response_format: Optional[_DictOrPydanticClass] = None,
**kwargs: Any,
) -> dict:
if tools:
kwargs = self._format_tool_params(
tools=tools,
tool_choice=tool_choice,
strict=strict,
parallel_tool_calls=parallel_tool_calls,
**kwargs,
)
if response_format:
kwargs["response_format"] = _convert_to_openai_response_format(
response_format, strict=strict,
)
kwargs["structured_output_format"] = {
"kwargs": {"method": "json_schema"},
"schema": convert_to_openai_tool(response_format),
}
return kwargs
def get_token_ids(self, text: str) -> List[int]:
"""Get the tokens present in the text with tiktoken package."""
if self.custom_get_token_ids is not None:
@ -1229,46 +1313,7 @@ class BaseChatOpenAI(BaseChatModel):
Support for ``strict`` argument added.
""" # noqa: E501
if parallel_tool_calls is not None:
kwargs["parallel_tool_calls"] = parallel_tool_calls
formatted_tools = [
convert_to_openai_tool(tool, strict=strict) for tool in tools
]
if tool_choice:
if isinstance(tool_choice, str):
# tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "any", "required"):
tool_choice = {
"type": "function",
"function": {"name": tool_choice},
}
# 'any' is not natively supported by OpenAI API.
# We support 'any' since other models use this instead of 'required'.
if tool_choice == "any":
tool_choice = "required"
elif isinstance(tool_choice, bool):
tool_choice = "required"
elif isinstance(tool_choice, dict):
tool_names = [
formatted_tool["function"]["name"]
for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"]
for tool_name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tools were {tool_names}."
)
else:
raise ValueError(
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
return super().bind(**kwargs)
def with_structured_output(
self,