Compare commits

...

2 Commits

Author SHA1 Message Date
Bagatur
d8ae2ec1d8 warning 2024-12-06 17:00:09 -08:00
Bagatur
01ebdde0e8 openai[patch]: always filter disabled params 2024-12-06 16:57:49 -08:00

View File

@@ -700,11 +700,12 @@ class BaseChatOpenAI(BaseChatModel):
if stop is not None:
kwargs["stop"] = stop
return {
payload = {
"messages": [_convert_message_to_dict(m) for m in messages],
**self._default_params,
**kwargs,
}
return self._filter_disabled_params(**payload)
def _create_chat_result(
self,
@@ -1436,7 +1437,10 @@ class BaseChatOpenAI(BaseChatModel):
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
bind_kwargs = self._filter_disabled_params(
tool_choice=tool_name, parallel_tool_calls=False, strict=strict
tool_choice=tool_name,
parallel_tool_calls=False,
strict=strict,
_raise_warning=False,
)
llm = self.bind_tools([schema], **bind_kwargs)
@@ -1488,7 +1492,9 @@ class BaseChatOpenAI(BaseChatModel):
else:
return llm | output_parser
def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]:
def _filter_disabled_params(
self, *, _raise_warning: bool = True, **kwargs: Any
) -> Dict[str, Any]:
if not self.disabled_params:
return kwargs
filtered = {}
@@ -1497,6 +1503,9 @@ class BaseChatOpenAI(BaseChatModel):
if k in self.disabled_params and (
self.disabled_params[k] is None or v in self.disabled_params[k]
):
if _raise_warning:
msg = f"Parameter {k}: {v} is disabled and being ignored."
logger.warning(msg)
continue
# Keep param
else: