mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-21 10:31:23 +00:00
fix(core): ensure InjectedToolCallId
always overrides LLM-generated values (#32766)
This commit is contained in:
@@ -659,10 +659,7 @@ class ChildTool(BaseTool):
|
|||||||
return tool_input
|
return tool_input
|
||||||
if issubclass(input_args, BaseModel):
|
if issubclass(input_args, BaseModel):
|
||||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||||
if (
|
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
|
||||||
and k not in tool_input
|
|
||||||
):
|
|
||||||
if tool_call_id is None:
|
if tool_call_id is None:
|
||||||
msg = (
|
msg = (
|
||||||
"When tool includes an InjectedToolCallId "
|
"When tool includes an InjectedToolCallId "
|
||||||
@@ -677,10 +674,7 @@ class ChildTool(BaseTool):
|
|||||||
result_dict = result.model_dump()
|
result_dict = result.model_dump()
|
||||||
elif issubclass(input_args, BaseModelV1):
|
elif issubclass(input_args, BaseModelV1):
|
||||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||||
if (
|
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
|
||||||
and k not in tool_input
|
|
||||||
):
|
|
||||||
if tool_call_id is None:
|
if tool_call_id is None:
|
||||||
msg = (
|
msg = (
|
||||||
"When tool includes an InjectedToolCallId "
|
"When tool includes an InjectedToolCallId "
|
||||||
|
@@ -2349,6 +2349,28 @@ def test_tool_injected_tool_call_id() -> None:
|
|||||||
) == ToolMessage(0, tool_call_id="bar") # type: ignore[arg-type]
|
) == ToolMessage(0, tool_call_id="bar") # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_injected_tool_call_id_override_llm_generated() -> None:
|
||||||
|
"""Test that InjectedToolCallId overrides LLM-generated values."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
|
||||||
|
"""Foo."""
|
||||||
|
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# Test that when LLM generates the tool_call_id, it gets overridden
|
||||||
|
result = foo.invoke(
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"args": {"x": 0, "tool_call_id": "fake_llm_id"}, # LLM generated this
|
||||||
|
"name": "foo",
|
||||||
|
"id": "real_tool_call_id", # This should be used instead
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# The tool should receive the real tool call ID, not the LLM-generated one
|
||||||
|
assert result == ToolMessage(0, tool_call_id="real_tool_call_id") # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
def test_tool_uninjected_tool_call_id() -> None:
|
def test_tool_uninjected_tool_call_id() -> None:
|
||||||
@tool
|
@tool
|
||||||
def foo(x: int, tool_call_id: str) -> ToolMessage:
|
def foo(x: int, tool_call_id: str) -> ToolMessage:
|
||||||
|
Reference in New Issue
Block a user