Compare commits

...

4 Commits

Author SHA1 Message Date
Sydney Runkle
f2ef21c1a4 boom 2026-01-16 10:33:20 -05:00
Sydney Runkle
bab649f124 fixing 2026-01-16 10:27:13 -05:00
Sydney Runkle
2b14c85d2b lint 2026-01-16 10:18:05 -05:00
Sydney Runkle
cb7f9c9ac2 initial fix 2026-01-16 10:15:40 -05:00
2 changed files with 91 additions and 94 deletions

View File

@@ -688,8 +688,15 @@ class ChildTool(BaseTool):
if isinstance(input_args, dict):
return tool_input
if issubclass(input_args, BaseModel):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
# Identify injected arg keys to exclude from model_dump.
# Injected args don't need to be validated and we don't assume
# they're serializable, so we exclude them.
annotations = get_all_basemodel_annotations(input_args)
injected_keys: set[str] = set()
for k, v in annotations.items():
if _is_injected_arg_type(v):
injected_keys.add(k)
# Check for InjectedToolCallId specifically
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
msg = (
@@ -702,10 +709,17 @@ class ChildTool(BaseTool):
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
result_dict = result.model_dump(exclude=injected_keys)
elif issubclass(input_args, BaseModelV1):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
# Identify injected arg keys to exclude from dict().
# Injected args don't need to be validated and we don't assume
# they're serializable, so we exclude them.
annotations = get_all_basemodel_annotations(input_args)
injected_keys = set()
for k, v in annotations.items():
if _is_injected_arg_type(v):
injected_keys.add(k)
# Check for InjectedToolCallId specifically
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
msg = (
@@ -718,7 +732,7 @@ class ChildTool(BaseTool):
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.parse_obj(tool_input)
result_dict = result.dict()
result_dict = result.dict(exclude=injected_keys)
else:
msg = (
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
@@ -748,7 +762,10 @@ class ChildTool(BaseTool):
if has_default:
validated_input[k] = getattr(result, k)
for k in self._injected_args_keys:
# Add injected args from both function signature and schema.
# These were excluded from model_dump() to avoid serialization warnings.
all_injected_keys = injected_keys | self._injected_args_keys
for k in all_injected_keys:
if k in tool_input:
validated_input[k] = tool_input[k]
elif k == "tool_call_id":

View File

@@ -5,6 +5,7 @@ import json
import sys
import textwrap
import threading
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
@@ -3184,98 +3185,28 @@ def test_filter_tool_runtime_directly_injected_arg() -> None:
assert "runtime" not in captured
# Custom directly injected arg type (similar to ToolRuntime)
class _CustomRuntime(_DirectlyInjectedToolArg):
"""Custom runtime info injected at tool call time."""
def test_filter_injected_args_from_callback_inputs() -> None:
"""Test that injected args from function signature are filtered from callbacks."""
def __init__(self, data: dict[str, Any]) -> None:
self.data = data
@tool
def my_tool(
query: str, limit: int, runtime: Annotated[Any, InjectedToolArg()]
) -> str:
"""Tool with injected runtime arg.
Args:
query: The search query.
limit: Max results.
runtime: Injected runtime info.
"""
return f"Query: {query}, Limit: {limit}"
# Schema that does NOT include the injected arg
class _ToolArgsSchemaNoRuntime(BaseModel):
"""Schema with only the non-injected args."""
query: str
limit: int
def _tool_func_directly_injected(
query: str, limit: int, runtime: _CustomRuntime
) -> str:
"""Tool with directly injected runtime not in schema.
Args:
query: The search query.
limit: Max results.
runtime: Custom runtime (directly injected, not in schema).
"""
return f"Query: {query}, Limit: {limit}"
def _tool_func_annotated_injected(
query: str, limit: int, runtime: Annotated[Any, InjectedToolArg()]
) -> str:
"""Tool with Annotated injected runtime not in schema.
Args:
query: The search query.
limit: Max results.
runtime: Custom runtime (annotated as injected, not in schema).
"""
return f"Query: {query}, Limit: {limit}"
@pytest.mark.parametrize(
("tool_func", "runtime_value", "description"),
[
pytest.param(
_tool_func_directly_injected,
_CustomRuntime(data={"foo": "bar"}),
"directly injected (_DirectlyInjectedToolArg subclass)",
id="directly_injected",
),
pytest.param(
_tool_func_annotated_injected,
{"foo": "bar"},
"annotated injected (Annotated[Any, InjectedToolArg()])",
id="annotated_injected",
),
],
)
def test_filter_injected_args_not_in_schema(
tool_func: Callable[..., str], runtime_value: Any, description: str
) -> None:
"""Test filtering injected args that are in function signature but not in schema.
This tests the case where an injected argument (like ToolRuntime) is in the
function signature but is not present in the args_schema. The fix ensures
we check _injected_args_keys from the function signature, not just the schema.
Args:
tool_func: The tool function with an injected arg.
runtime_value: The value to pass for the runtime arg.
description: Description of the injection style being tested.
"""
# Create StructuredTool with explicit args_schema that excludes runtime
custom_tool = StructuredTool.from_function(
func=tool_func,
name="custom_tool",
description=f"Tool with {description} arg not in schema",
args_schema=_ToolArgsSchemaNoRuntime,
)
# Verify _injected_args_keys contains 'runtime'
assert "runtime" in custom_tool._injected_args_keys
assert "runtime" in my_tool._injected_args_keys
handler = CallbackHandlerWithInputCapture(captured_inputs=[])
result = custom_tool.invoke(
{
"query": "test",
"limit": 5,
"runtime": runtime_value,
},
result = my_tool.invoke(
{"query": "test", "limit": 5, "runtime": {"foo": "bar"}},
config={"callbacks": [handler]},
)
@@ -3283,7 +3214,7 @@ def test_filter_injected_args_not_in_schema(
assert handler.tool_starts == 1
assert len(handler.captured_inputs) == 1
# Verify that runtime is filtered out even though it's not in args_schema
# Verify that runtime is filtered from callback inputs
captured = handler.captured_inputs[0]
assert captured is not None
assert captured == {"query": "test", "limit": 5}
@@ -3610,3 +3541,52 @@ def test_tool_args_schema_falsy_defaults() -> None:
# Invoke with only required argument - falsy defaults should be applied
result = config_tool.invoke({"name": "test"})
assert result == "name=test, enabled=False, count=0, prefix=''"
def test_injected_args_no_pydantic_warnings_during_parse_input() -> None:
"""Test that no Pydantic warnings are triggered for injected args.
When injected args are in the schema but have type mismatches with the actual
values, model_dump() should not serialize them to avoid Pydantic warnings.
Regression test for https://github.com/langchain-ai/langchain/issues/34770
"""
class InjectedContext:
"""A context object that will be injected."""
def __init__(self, data: dict[str, Any]) -> None:
self.data = data
class ToolArgsWithInjected(BaseModel):
"""Tool args schema with an injected arg."""
model_config = ConfigDict(arbitrary_types_allowed=True)
query: str = Field(..., description="The search query")
# Annotated with a specific type, but actual value may differ
context: Annotated[InjectedContext, InjectedToolArg()] = Field(
..., description="Injected context"
)
@tool("search_with_context", args_schema=ToolArgsWithInjected)
def search_with_context(query: str, context: InjectedContext) -> str:
"""Search with injected context.
Args:
query: The search query.
context: The injected context.
"""
return f"Query: {query}, Context: {context.data}"
# Create an injected context
injected_context = InjectedContext(data={"key": "value"})
# Invoke the tool - should not trigger any Pydantic warnings
with warnings.catch_warnings():
warnings.simplefilter("error") # Turn warnings into errors
result = search_with_context.invoke(
{"query": "test", "context": injected_context}
)
assert result == "Query: test, Context: {'key': 'value'}"