mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-16 10:16:10 +00:00
Compare commits
5 Commits
langchain-
...
sr/clean-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4657c06606 | ||
|
|
f2ef21c1a4 | ||
|
|
bab649f124 | ||
|
|
2b14c85d2b | ||
|
|
cb7f9c9ac2 |
@@ -601,10 +601,34 @@ class ChildTool(BaseTool):
|
||||
self.name, full_schema, fields, fn_description=self.description
|
||||
)
|
||||
|
||||
@functools.cached_property
def _injected_arg_info(self) -> tuple[frozenset[str], str | None]:
    """Scan the ``args_schema`` annotations for injected arguments.

    Returns:
        A ``(injected_keys, tool_call_id_key)`` pair. ``injected_keys`` holds
        the name of every annotation accepted by ``_is_injected_arg_type``.
        ``tool_call_id_key`` is the name of the annotation that also matches
        ``InjectedToolCallId``, or ``None`` when no annotation does. If more
        than one annotation matches, the last one iterated is reported.
    """
    schema = self.args_schema
    # Nothing to inspect when there is no schema or the schema is a dict.
    if schema is None or isinstance(schema, dict):
        return _EMPTY_SET, None

    injected_names: set[str] = set()
    call_id_name: str | None = None
    for name, annotation in get_all_basemodel_annotations(schema).items():
        if not _is_injected_arg_type(annotation):
            continue
        injected_names.add(name)
        # The tool-call-id check only runs for annotations that already
        # passed the general injected-argument check above.
        if _is_injected_arg_type(annotation, injected_type=InjectedToolCallId):
            call_id_name = name
    return frozenset(injected_names), call_id_name
|
||||
|
||||
@functools.cached_property
def _injected_args_keys(self) -> frozenset[str]:
    """Get injected argument keys from the args_schema annotations.

    Returns:
        The first element of ``_injected_arg_info``, i.e. the set of
        injected argument names.
    """
    # Delegate to ``_injected_arg_info`` so the key set and the tool-call-id
    # key are produced by one scan. The previous unconditional
    # ``return _EMPTY_SET`` made this delegation unreachable.
    return self._injected_arg_info[0]
|
||||
|
||||
@functools.cached_property
|
||||
def _injected_tool_call_id_key(self) -> str | None:
|
||||
"""Get the key for InjectedToolCallId argument, if any."""
|
||||
return self._injected_arg_info[1]
|
||||
|
||||
# --- Runnable ---
|
||||
|
||||
@@ -666,106 +690,85 @@ class ChildTool(BaseTool):
|
||||
"""
|
||||
input_args = self.args_schema
|
||||
|
||||
if isinstance(tool_input, str):
|
||||
if input_args is not None:
|
||||
if isinstance(input_args, dict):
|
||||
msg = (
|
||||
"String tool inputs are not allowed when "
|
||||
"using tools with JSON schema args_schema."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
key_ = next(iter(get_fields(input_args).keys()))
|
||||
if issubclass(input_args, BaseModel):
|
||||
input_args.model_validate({key_: tool_input})
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
input_args.parse_obj({key_: tool_input})
|
||||
else:
|
||||
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
|
||||
raise TypeError(msg)
|
||||
# No schema - return input as-is
|
||||
if input_args is None:
|
||||
return tool_input
|
||||
|
||||
if input_args is not None:
|
||||
if isinstance(input_args, dict):
|
||||
return tool_input
|
||||
if issubclass(input_args, BaseModel):
|
||||
# Check args_schema for InjectedToolCallId
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.model_validate(tool_input)
|
||||
result_dict = result.model_dump()
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
# Check args_schema for InjectedToolCallId
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.parse_obj(tool_input)
|
||||
result_dict = result.dict()
|
||||
else:
|
||||
# JSON schema dict - string input not allowed, dict passes through
|
||||
if isinstance(input_args, dict):
|
||||
if isinstance(tool_input, str):
|
||||
msg = (
|
||||
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
||||
"String tool inputs are not allowed when "
|
||||
"using tools with JSON schema args_schema."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
return tool_input
|
||||
|
||||
# Include fields from tool_input, plus fields with explicit defaults.
|
||||
# This applies Pydantic defaults (like Field(default=1)) while excluding
|
||||
# synthetic "args"/"kwargs" fields that Pydantic creates for *args/**kwargs.
|
||||
field_info = get_fields(input_args)
|
||||
validated_input = {}
|
||||
for k in result_dict:
|
||||
if k in tool_input:
|
||||
# Field was provided in input - include it (validated)
|
||||
# Must be a Pydantic BaseModel at this point
|
||||
if not issubclass(input_args, (BaseModel, BaseModelV1)):
|
||||
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
|
||||
raise TypeError(msg)
|
||||
|
||||
is_v2 = issubclass(input_args, BaseModel)
|
||||
|
||||
# String input - validate and return as-is
|
||||
if isinstance(tool_input, str):
|
||||
key_ = next(iter(get_fields(input_args).keys()))
|
||||
if is_v2:
|
||||
input_args.model_validate({key_: tool_input})
|
||||
else:
|
||||
cast("type[BaseModelV1]", input_args).parse_obj({key_: tool_input})
|
||||
return tool_input
|
||||
|
||||
# Dict input - full validation flow
|
||||
injected_keys = self._injected_args_keys
|
||||
|
||||
# Inject tool_call_id for InjectedToolCallId field
|
||||
if (tc_key := self._injected_tool_call_id_key) is not None:
|
||||
if tool_call_id is not None:
|
||||
# Real tool_call_id from ToolCall overrides any LLM-generated value
|
||||
tool_input[tc_key] = tool_call_id
|
||||
elif tc_key not in tool_input:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId argument, tool "
|
||||
"must always be invoked with a full model ToolCall of the "
|
||||
"form: {'args': {...}, 'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Validate and dump to dict (excluding injected keys to avoid warnings)
|
||||
result: BaseModel | BaseModelV1
|
||||
if is_v2:
|
||||
result = input_args.model_validate(tool_input)
|
||||
result_dict = result.model_dump(exclude=set(injected_keys))
|
||||
else:
|
||||
result = cast("type[BaseModelV1]", input_args).parse_obj(tool_input)
|
||||
result_dict = result.dict(exclude=injected_keys)
|
||||
|
||||
# Build validated_input: include provided fields + fields with defaults
|
||||
field_info = get_fields(input_args)
|
||||
validated_input: dict[str, Any] = {}
|
||||
for k in result_dict:
|
||||
if k in tool_input:
|
||||
validated_input[k] = getattr(result, k)
|
||||
elif k in field_info and k not in ("args", "kwargs"):
|
||||
# Include fields with explicit defaults (not synthetic *args/**kwargs)
|
||||
fi = field_info[k]
|
||||
has_default = (
|
||||
not fi.is_required()
|
||||
if hasattr(fi, "is_required")
|
||||
else not getattr(fi, "required", True)
|
||||
)
|
||||
if has_default:
|
||||
validated_input[k] = getattr(result, k)
|
||||
elif k in field_info and k not in ("args", "kwargs"):
|
||||
# Check if field has an explicit default defined in the schema.
|
||||
# Exclude "args"/"kwargs" as these are synthetic fields for variadic
|
||||
# parameters that should not be passed as keyword arguments.
|
||||
fi = field_info[k]
|
||||
# Pydantic v2 uses is_required() method, v1 uses required attribute
|
||||
has_default = (
|
||||
not fi.is_required()
|
||||
if hasattr(fi, "is_required")
|
||||
else not getattr(fi, "required", True)
|
||||
)
|
||||
if has_default:
|
||||
validated_input[k] = getattr(result, k)
|
||||
|
||||
for k in self._injected_args_keys:
|
||||
if k in tool_input:
|
||||
validated_input[k] = tool_input[k]
|
||||
elif k == "tool_call_id":
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
validated_input[k] = tool_call_id
|
||||
# Add back injected args (were excluded from model_dump)
|
||||
for k in injected_keys:
|
||||
if k in tool_input:
|
||||
validated_input[k] = tool_input[k]
|
||||
|
||||
return validated_input
|
||||
|
||||
return tool_input
|
||||
return validated_input
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
|
||||
@@ -23,10 +23,10 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||
from langchain_core.tools.base import (
|
||||
_EMPTY_SET,
|
||||
FILTERED_ARGS,
|
||||
ArgsSchema,
|
||||
BaseTool,
|
||||
InjectedToolCallId,
|
||||
_get_runnable_config_param,
|
||||
_is_injected_arg_type,
|
||||
create_schema_from_function,
|
||||
@@ -246,15 +246,24 @@ class StructuredTool(BaseTool):
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def _injected_args_keys(self) -> frozenset[str]:
|
||||
def _injected_arg_info(self) -> tuple[frozenset[str], str | None]:
|
||||
# Combine injected args from schema (via super) and function signature
|
||||
schema_keys, schema_tc_key = super()._injected_arg_info
|
||||
fn = self.func or self.coroutine
|
||||
if fn is None:
|
||||
return _EMPTY_SET
|
||||
return frozenset(
|
||||
k
|
||||
for k, v in signature(fn).parameters.items()
|
||||
if _is_injected_arg_type(v.annotation)
|
||||
)
|
||||
return schema_keys, schema_tc_key
|
||||
|
||||
func_keys: set[str] = set()
|
||||
func_tc_key: str | None = None
|
||||
for k, v in signature(fn).parameters.items():
|
||||
if _is_injected_arg_type(v.annotation):
|
||||
func_keys.add(k)
|
||||
if _is_injected_arg_type(
|
||||
v.annotation, injected_type=InjectedToolCallId
|
||||
):
|
||||
func_tc_key = k
|
||||
|
||||
return schema_keys | func_keys, schema_tc_key or func_tc_key
|
||||
|
||||
|
||||
def _filter_schema_args(func: Callable) -> list[str]:
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import sys
|
||||
import textwrap
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
@@ -3184,98 +3185,28 @@ def test_filter_tool_runtime_directly_injected_arg() -> None:
|
||||
assert "runtime" not in captured
|
||||
|
||||
|
||||
# Custom directly injected arg type (similar to ToolRuntime)
|
||||
class _CustomRuntime(_DirectlyInjectedToolArg):
|
||||
"""Custom runtime info injected at tool call time."""
|
||||
def test_filter_injected_args_from_callback_inputs() -> None:
|
||||
"""Test that injected args from function signature are filtered from callbacks."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]) -> None:
|
||||
self.data = data
|
||||
@tool
|
||||
def my_tool(
|
||||
query: str, limit: int, runtime: Annotated[Any, InjectedToolArg()]
|
||||
) -> str:
|
||||
"""Tool with injected runtime arg.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Max results.
|
||||
runtime: Injected runtime info.
|
||||
"""
|
||||
return f"Query: {query}, Limit: {limit}"
|
||||
|
||||
# Schema that does NOT include the injected arg
|
||||
class _ToolArgsSchemaNoRuntime(BaseModel):
|
||||
"""Schema with only the non-injected args."""
|
||||
|
||||
query: str
|
||||
limit: int
|
||||
|
||||
|
||||
def _tool_func_directly_injected(
|
||||
query: str, limit: int, runtime: _CustomRuntime
|
||||
) -> str:
|
||||
"""Tool with directly injected runtime not in schema.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Max results.
|
||||
runtime: Custom runtime (directly injected, not in schema).
|
||||
"""
|
||||
return f"Query: {query}, Limit: {limit}"
|
||||
|
||||
|
||||
def _tool_func_annotated_injected(
|
||||
query: str, limit: int, runtime: Annotated[Any, InjectedToolArg()]
|
||||
) -> str:
|
||||
"""Tool with Annotated injected runtime not in schema.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Max results.
|
||||
runtime: Custom runtime (annotated as injected, not in schema).
|
||||
"""
|
||||
return f"Query: {query}, Limit: {limit}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tool_func", "runtime_value", "description"),
|
||||
[
|
||||
pytest.param(
|
||||
_tool_func_directly_injected,
|
||||
_CustomRuntime(data={"foo": "bar"}),
|
||||
"directly injected (_DirectlyInjectedToolArg subclass)",
|
||||
id="directly_injected",
|
||||
),
|
||||
pytest.param(
|
||||
_tool_func_annotated_injected,
|
||||
{"foo": "bar"},
|
||||
"annotated injected (Annotated[Any, InjectedToolArg()])",
|
||||
id="annotated_injected",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_filter_injected_args_not_in_schema(
|
||||
tool_func: Callable[..., str], runtime_value: Any, description: str
|
||||
) -> None:
|
||||
"""Test filtering injected args that are in function signature but not in schema.
|
||||
|
||||
This tests the case where an injected argument (like ToolRuntime) is in the
|
||||
function signature but is not present in the args_schema. The fix ensures
|
||||
we check _injected_args_keys from the function signature, not just the schema.
|
||||
|
||||
Args:
|
||||
tool_func: The tool function with an injected arg.
|
||||
runtime_value: The value to pass for the runtime arg.
|
||||
description: Description of the injection style being tested.
|
||||
"""
|
||||
# Create StructuredTool with explicit args_schema that excludes runtime
|
||||
custom_tool = StructuredTool.from_function(
|
||||
func=tool_func,
|
||||
name="custom_tool",
|
||||
description=f"Tool with {description} arg not in schema",
|
||||
args_schema=_ToolArgsSchemaNoRuntime,
|
||||
)
|
||||
|
||||
# Verify _injected_args_keys contains 'runtime'
|
||||
assert "runtime" in custom_tool._injected_args_keys
|
||||
assert "runtime" in my_tool._injected_args_keys
|
||||
|
||||
handler = CallbackHandlerWithInputCapture(captured_inputs=[])
|
||||
|
||||
result = custom_tool.invoke(
|
||||
{
|
||||
"query": "test",
|
||||
"limit": 5,
|
||||
"runtime": runtime_value,
|
||||
},
|
||||
result = my_tool.invoke(
|
||||
{"query": "test", "limit": 5, "runtime": {"foo": "bar"}},
|
||||
config={"callbacks": [handler]},
|
||||
)
|
||||
|
||||
@@ -3283,7 +3214,7 @@ def test_filter_injected_args_not_in_schema(
|
||||
assert handler.tool_starts == 1
|
||||
assert len(handler.captured_inputs) == 1
|
||||
|
||||
# Verify that runtime is filtered out even though it's not in args_schema
|
||||
# Verify that runtime is filtered from callback inputs
|
||||
captured = handler.captured_inputs[0]
|
||||
assert captured is not None
|
||||
assert captured == {"query": "test", "limit": 5}
|
||||
@@ -3610,3 +3541,52 @@ def test_tool_args_schema_falsy_defaults() -> None:
|
||||
# Invoke with only required argument - falsy defaults should be applied
|
||||
result = config_tool.invoke({"name": "test"})
|
||||
assert result == "name=test, enabled=False, count=0, prefix=''"
|
||||
|
||||
|
||||
def test_injected_args_no_pydantic_warnings_during_parse_input() -> None:
    """Invoking a tool with an injected schema field must not emit warnings.

    The schema declares ``context`` as an injected argument typed as a plain
    Python class. The call runs with warnings promoted to errors, so any
    warning raised during the invocation fails the test.

    Regression test for https://github.com/langchain-ai/langchain/issues/34770
    """

    class InjectedContext:
        """A context object that will be injected."""

        def __init__(self, data: dict[str, Any]) -> None:
            self.data = data

    class ToolArgsWithInjected(BaseModel):
        """Tool args schema with an injected arg."""

        model_config = ConfigDict(arbitrary_types_allowed=True)

        query: str = Field(..., description="The search query")
        # Annotated with a specific type, but actual value may differ
        context: Annotated[InjectedContext, InjectedToolArg()] = Field(
            ..., description="Injected context"
        )

    @tool("search_with_context", args_schema=ToolArgsWithInjected)
    def search_with_context(query: str, context: InjectedContext) -> str:
        """Search with injected context.

        Args:
            query: The search query.
            context: The injected context.
        """
        return f"Query: {query}, Context: {context.data}"

    # Build the full input up front; only the invocation runs under the
    # strict warning filter.
    payload = {"query": "test", "context": InjectedContext(data={"key": "value"})}

    with warnings.catch_warnings():
        # Promote every warning to an exception for the duration of the call.
        warnings.simplefilter("error")
        outcome = search_with_context.invoke(payload)

    assert outcome == "Query: test, Context: {'key': 'value'}"
|
||||
|
||||
Reference in New Issue
Block a user