propagate to tool calls

This commit is contained in:
Chester Curme 2025-03-17 14:12:31 -04:00
parent 192035f8c0
commit f333720aad

View File

@ -490,6 +490,15 @@ class BaseChatOpenAI(BaseChatModel):
.. versionadded:: 0.2.14
"""
reasoning_summary: Optional[str] = None
"""Generate a summary of the reasoning performed by the model (Responses API).
Reasoning models only, like OpenAI o1 and o3-mini.
Currently supported values are ``"concise"`` or ``"detailed"``.
.. versionadded:: 0.3.9
"""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
@ -541,6 +550,13 @@ class BaseChatOpenAI(BaseChatModel):
If not specified then will be inferred based on invocation params.
.. versionadded:: 0.3.9
"""
truncation: Optional[str] = None
"""Truncation strategy (Responses API). Can be ``"auto"`` or ``"disabled"``
(default). If ``"auto"``, model may drop input items from the middle of the
message sequence to fit the context window.
.. versionadded:: 0.3.9
"""
@ -652,6 +668,8 @@ class BaseChatOpenAI(BaseChatModel):
"n": self.n,
"temperature": self.temperature,
"reasoning_effort": self.reasoning_effort,
"reasoning_summary": self.reasoning_summary,
"truncation": self.truncation,
}
params = {
@ -2877,8 +2895,13 @@ def _construct_responses_api_payload(
for legacy_token_param in ["max_tokens", "max_completion_tokens"]:
if legacy_token_param in payload:
payload["max_output_tokens"] = payload.pop(legacy_token_param)
if "reasoning_effort" in payload:
payload["reasoning"] = {"effort": payload.pop("reasoning_effort")}
if "reasoning_effort" in payload or "reasoning_summary" in payload:
reasoning = {}
if "reasoning_effort" in payload:
reasoning["effort"] = payload.pop("reasoning_effort")
if "reasoning_summary" in payload:
reasoning["generate_summary"] = payload.pop("reasoning_summary")
payload["reasoning"] = reasoning
payload["input"] = _construct_responses_api_input(messages)
if tools := payload.pop("tools", None):
@ -3127,6 +3150,17 @@ def _construct_lc_result_from_responses_api(
if _FUNCTION_CALL_IDS_MAP_KEY not in additional_kwargs:
additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY] = {}
additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY][output.call_id] = output.id
elif output.type == "computer_call":
tool_call = {
"type": "tool_call",
"name": output.type,
"args": output.action.model_dump(),
"id": output.call_id,
}
tool_calls.append(tool_call)
if _FUNCTION_CALL_IDS_MAP_KEY not in additional_kwargs:
additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY] = {}
additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY][output.call_id] = output.id
elif output.type == "reasoning":
additional_kwargs["reasoning"] = output.model_dump(
exclude_none=True, mode="json"