Compare commits

...

5 Commits

Author SHA1 Message Date
Sydney Runkle
4657c06606 cleanup 2026-01-16 11:33:48 -05:00
Sydney Runkle
f2ef21c1a4 boom 2026-01-16 10:33:20 -05:00
Sydney Runkle
bab649f124 fixing 2026-01-16 10:27:13 -05:00
Sydney Runkle
2b14c85d2b lint 2026-01-16 10:18:05 -05:00
Sydney Runkle
cb7f9c9ac2 initial fix 2026-01-16 10:15:40 -05:00
3 changed files with 181 additions and 189 deletions

View File

@@ -601,10 +601,34 @@ class ChildTool(BaseTool):
self.name, full_schema, fields, fn_description=self.description
)
@functools.cached_property
def _injected_arg_info(self) -> tuple[frozenset[str], str | None]:
"""Get injected argument info from the args_schema annotations.
Returns:
Tuple of (all_injected_keys, tool_call_id_key).
"""
if self.args_schema is None or isinstance(self.args_schema, dict):
return _EMPTY_SET, None
annotations = get_all_basemodel_annotations(self.args_schema)
injected: set[str] = set()
tool_call_id_key: str | None = None
for k, v in annotations.items():
if _is_injected_arg_type(v):
injected.add(k)
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
tool_call_id_key = k
return frozenset(injected), tool_call_id_key
@functools.cached_property
def _injected_args_keys(self) -> frozenset[str]:
# base implementation doesn't manage injected args
return _EMPTY_SET
"""Get injected argument keys from the args_schema annotations."""
return self._injected_arg_info[0]
@functools.cached_property
def _injected_tool_call_id_key(self) -> str | None:
"""Get the key for InjectedToolCallId argument, if any."""
return self._injected_arg_info[1]
# --- Runnable ---
@@ -666,106 +690,85 @@ class ChildTool(BaseTool):
"""
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
if isinstance(input_args, dict):
msg = (
"String tool inputs are not allowed when "
"using tools with JSON schema args_schema."
)
raise ValueError(msg)
key_ = next(iter(get_fields(input_args).keys()))
if issubclass(input_args, BaseModel):
input_args.model_validate({key_: tool_input})
elif issubclass(input_args, BaseModelV1):
input_args.parse_obj({key_: tool_input})
else:
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
raise TypeError(msg)
# No schema - return input as-is
if input_args is None:
return tool_input
if input_args is not None:
if isinstance(input_args, dict):
return tool_input
if issubclass(input_args, BaseModel):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.parse_obj(tool_input)
result_dict = result.dict()
else:
# JSON schema dict - string input not allowed, dict passes through
if isinstance(input_args, dict):
if isinstance(tool_input, str):
msg = (
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
"String tool inputs are not allowed when "
"using tools with JSON schema args_schema."
)
raise NotImplementedError(msg)
raise ValueError(msg) # noqa: TRY004
return tool_input
# Include fields from tool_input, plus fields with explicit defaults.
# This applies Pydantic defaults (like Field(default=1)) while excluding
# synthetic "args"/"kwargs" fields that Pydantic creates for *args/**kwargs.
field_info = get_fields(input_args)
validated_input = {}
for k in result_dict:
if k in tool_input:
# Field was provided in input - include it (validated)
# Must be a Pydantic BaseModel at this point
if not issubclass(input_args, (BaseModel, BaseModelV1)):
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
raise TypeError(msg)
is_v2 = issubclass(input_args, BaseModel)
# String input - validate and return as-is
if isinstance(tool_input, str):
key_ = next(iter(get_fields(input_args).keys()))
if is_v2:
input_args.model_validate({key_: tool_input})
else:
cast("type[BaseModelV1]", input_args).parse_obj({key_: tool_input})
return tool_input
# Dict input - full validation flow
injected_keys = self._injected_args_keys
# Inject tool_call_id for InjectedToolCallId field
if (tc_key := self._injected_tool_call_id_key) is not None:
if tool_call_id is not None:
# Real tool_call_id from ToolCall overrides any LLM-generated value
tool_input[tc_key] = tool_call_id
elif tc_key not in tool_input:
msg = (
"When tool includes an InjectedToolCallId argument, tool "
"must always be invoked with a full model ToolCall of the "
"form: {'args': {...}, 'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
# Validate and dump to dict (excluding injected keys to avoid warnings)
result: BaseModel | BaseModelV1
if is_v2:
result = input_args.model_validate(tool_input)
result_dict = result.model_dump(exclude=set(injected_keys))
else:
result = cast("type[BaseModelV1]", input_args).parse_obj(tool_input)
result_dict = result.dict(exclude=injected_keys)
# Build validated_input: include provided fields + fields with defaults
field_info = get_fields(input_args)
validated_input: dict[str, Any] = {}
for k in result_dict:
if k in tool_input:
validated_input[k] = getattr(result, k)
elif k in field_info and k not in ("args", "kwargs"):
# Include fields with explicit defaults (not synthetic *args/**kwargs)
fi = field_info[k]
has_default = (
not fi.is_required()
if hasattr(fi, "is_required")
else not getattr(fi, "required", True)
)
if has_default:
validated_input[k] = getattr(result, k)
elif k in field_info and k not in ("args", "kwargs"):
# Check if field has an explicit default defined in the schema.
# Exclude "args"/"kwargs" as these are synthetic fields for variadic
# parameters that should not be passed as keyword arguments.
fi = field_info[k]
# Pydantic v2 uses is_required() method, v1 uses required attribute
has_default = (
not fi.is_required()
if hasattr(fi, "is_required")
else not getattr(fi, "required", True)
)
if has_default:
validated_input[k] = getattr(result, k)
for k in self._injected_args_keys:
if k in tool_input:
validated_input[k] = tool_input[k]
elif k == "tool_call_id":
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
validated_input[k] = tool_call_id
# Add back injected args (were excluded from model_dump)
for k in injected_keys:
if k in tool_input:
validated_input[k] = tool_input[k]
return validated_input
return tool_input
return validated_input
@abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> Any:

View File

@@ -23,10 +23,10 @@ from langchain_core.callbacks import (
)
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
_EMPTY_SET,
FILTERED_ARGS,
ArgsSchema,
BaseTool,
InjectedToolCallId,
_get_runnable_config_param,
_is_injected_arg_type,
create_schema_from_function,
@@ -246,15 +246,24 @@ class StructuredTool(BaseTool):
)
@functools.cached_property
def _injected_args_keys(self) -> frozenset[str]:
def _injected_arg_info(self) -> tuple[frozenset[str], str | None]:
# Combine injected args from schema (via super) and function signature
schema_keys, schema_tc_key = super()._injected_arg_info
fn = self.func or self.coroutine
if fn is None:
return _EMPTY_SET
return frozenset(
k
for k, v in signature(fn).parameters.items()
if _is_injected_arg_type(v.annotation)
)
return schema_keys, schema_tc_key
func_keys: set[str] = set()
func_tc_key: str | None = None
for k, v in signature(fn).parameters.items():
if _is_injected_arg_type(v.annotation):
func_keys.add(k)
if _is_injected_arg_type(
v.annotation, injected_type=InjectedToolCallId
):
func_tc_key = k
return schema_keys | func_keys, schema_tc_key or func_tc_key
def _filter_schema_args(func: Callable) -> list[str]:

View File

@@ -5,6 +5,7 @@ import json
import sys
import textwrap
import threading
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
@@ -3184,98 +3185,28 @@ def test_filter_tool_runtime_directly_injected_arg() -> None:
assert "runtime" not in captured
# Custom directly injected arg type (similar to ToolRuntime)
class _CustomRuntime(_DirectlyInjectedToolArg):
"""Custom runtime info injected at tool call time."""
def test_filter_injected_args_from_callback_inputs() -> None:
"""Test that injected args from function signature are filtered from callbacks."""
def __init__(self, data: dict[str, Any]) -> None:
self.data = data
@tool
def my_tool(
query: str, limit: int, runtime: Annotated[Any, InjectedToolArg()]
) -> str:
"""Tool with injected runtime arg.
Args:
query: The search query.
limit: Max results.
runtime: Injected runtime info.
"""
return f"Query: {query}, Limit: {limit}"
# Schema that does NOT include the injected arg
class _ToolArgsSchemaNoRuntime(BaseModel):
"""Schema with only the non-injected args."""
query: str
limit: int
def _tool_func_directly_injected(
query: str, limit: int, runtime: _CustomRuntime
) -> str:
"""Tool with directly injected runtime not in schema.
Args:
query: The search query.
limit: Max results.
runtime: Custom runtime (directly injected, not in schema).
"""
return f"Query: {query}, Limit: {limit}"
def _tool_func_annotated_injected(
query: str, limit: int, runtime: Annotated[Any, InjectedToolArg()]
) -> str:
"""Tool with Annotated injected runtime not in schema.
Args:
query: The search query.
limit: Max results.
runtime: Custom runtime (annotated as injected, not in schema).
"""
return f"Query: {query}, Limit: {limit}"
@pytest.mark.parametrize(
("tool_func", "runtime_value", "description"),
[
pytest.param(
_tool_func_directly_injected,
_CustomRuntime(data={"foo": "bar"}),
"directly injected (_DirectlyInjectedToolArg subclass)",
id="directly_injected",
),
pytest.param(
_tool_func_annotated_injected,
{"foo": "bar"},
"annotated injected (Annotated[Any, InjectedToolArg()])",
id="annotated_injected",
),
],
)
def test_filter_injected_args_not_in_schema(
tool_func: Callable[..., str], runtime_value: Any, description: str
) -> None:
"""Test filtering injected args that are in function signature but not in schema.
This tests the case where an injected argument (like ToolRuntime) is in the
function signature but is not present in the args_schema. The fix ensures
we check _injected_args_keys from the function signature, not just the schema.
Args:
tool_func: The tool function with an injected arg.
runtime_value: The value to pass for the runtime arg.
description: Description of the injection style being tested.
"""
# Create StructuredTool with explicit args_schema that excludes runtime
custom_tool = StructuredTool.from_function(
func=tool_func,
name="custom_tool",
description=f"Tool with {description} arg not in schema",
args_schema=_ToolArgsSchemaNoRuntime,
)
# Verify _injected_args_keys contains 'runtime'
assert "runtime" in custom_tool._injected_args_keys
assert "runtime" in my_tool._injected_args_keys
handler = CallbackHandlerWithInputCapture(captured_inputs=[])
result = custom_tool.invoke(
{
"query": "test",
"limit": 5,
"runtime": runtime_value,
},
result = my_tool.invoke(
{"query": "test", "limit": 5, "runtime": {"foo": "bar"}},
config={"callbacks": [handler]},
)
@@ -3283,7 +3214,7 @@ def test_filter_injected_args_not_in_schema(
assert handler.tool_starts == 1
assert len(handler.captured_inputs) == 1
# Verify that runtime is filtered out even though it's not in args_schema
# Verify that runtime is filtered from callback inputs
captured = handler.captured_inputs[0]
assert captured is not None
assert captured == {"query": "test", "limit": 5}
@@ -3610,3 +3541,52 @@ def test_tool_args_schema_falsy_defaults() -> None:
# Invoke with only required argument - falsy defaults should be applied
result = config_tool.invoke({"name": "test"})
assert result == "name=test, enabled=False, count=0, prefix=''"
def test_injected_args_no_pydantic_warnings_during_parse_input() -> None:
"""Test that no Pydantic warnings are triggered for injected args.
When injected args are in the schema but have type mismatches with the actual
values, model_dump() should not serialize them to avoid Pydantic warnings.
Regression test for https://github.com/langchain-ai/langchain/issues/34770
"""
class InjectedContext:
"""A context object that will be injected."""
def __init__(self, data: dict[str, Any]) -> None:
self.data = data
class ToolArgsWithInjected(BaseModel):
"""Tool args schema with an injected arg."""
model_config = ConfigDict(arbitrary_types_allowed=True)
query: str = Field(..., description="The search query")
# Annotated with a specific type, but actual value may differ
context: Annotated[InjectedContext, InjectedToolArg()] = Field(
..., description="Injected context"
)
@tool("search_with_context", args_schema=ToolArgsWithInjected)
def search_with_context(query: str, context: InjectedContext) -> str:
"""Search with injected context.
Args:
query: The search query.
context: The injected context.
"""
return f"Query: {query}, Context: {context.data}"
# Create an injected context
injected_context = InjectedContext(data={"key": "value"})
# Invoke the tool - should not trigger any Pydantic warnings
with warnings.catch_warnings():
warnings.simplefilter("error") # Turn warnings into errors
result = search_with_context.invoke(
{"query": "test", "context": injected_context}
)
assert result == "Query: test, Context: {'key': 'value'}"