Compare commits

...

5 Commits

Author SHA1 Message Date
Mason Daugherty
1dcebc7101 Merge branch 'master' into eugene/update_tools_schema 2025-08-11 18:23:47 -04:00
Eugene Yurtsev
fd532c692d x 2025-06-20 13:17:12 -04:00
Eugene Yurtsev
97db95b426 x 2025-06-20 13:16:40 -04:00
Eugene Yurtsev
b9ebae37a4 x 2025-06-20 13:14:10 -04:00
Eugene Yurtsev
d3930f9906 x 2025-06-20 12:41:21 -04:00
2 changed files with 104 additions and 18 deletions

View File

@@ -392,6 +392,10 @@ class ToolException(Exception): # noqa: N818
ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
"""Used to define the schema for tool arguments.
This schema should not include injected arguments like `InjectedToolCallId`.
"""
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
@@ -581,11 +585,17 @@ class ChildTool(BaseTool):
) -> type[BaseModel]:
"""The tool's input schema.
All required arguments will be included as part of this schema, including
any injected arguments like `InjectedToolCallId`.
A schema without injected arguments can be obtained using the
`tool_call_schema` property instead.
Args:
config: The configuration for the tool.
Returns:
The input schema for the tool.
The full input schema for the tool, including injected arguments.
"""
if self.args_schema is not None:
if isinstance(self.args_schema, dict):
@@ -632,29 +642,29 @@ class ChildTool(BaseTool):
InjectedToolCallId is required but not provided.
NotImplementedError: If args_schema is not a supported type.
"""
input_args = self.args_schema
args_schema = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
if isinstance(input_args, dict):
if args_schema is not None:
if isinstance(args_schema, dict):
msg = (
"String tool inputs are not allowed when "
"using tools with JSON schema args_schema."
)
raise ValueError(msg)
key_ = next(iter(get_fields(input_args).keys()))
if issubclass(input_args, BaseModel):
input_args.model_validate({key_: tool_input})
elif issubclass(input_args, BaseModelV1):
input_args.parse_obj({key_: tool_input})
key_ = next(iter(get_fields(args_schema).keys()))
if issubclass(args_schema, BaseModel):
args_schema.model_validate({key_: tool_input})
elif issubclass(args_schema, BaseModelV1):
args_schema.parse_obj({key_: tool_input})
else:
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
msg = f"args_schema must be a Pydantic BaseModel, got {args_schema}"
raise TypeError(msg)
return tool_input
if input_args is not None:
if isinstance(input_args, dict):
if args_schema is not None:
if isinstance(args_schema, dict):
return tool_input
if issubclass(input_args, BaseModel):
for k, v in get_all_basemodel_annotations(input_args).items():
if issubclass(args_schema, BaseModel):
for k, v in get_all_basemodel_annotations(args_schema).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
@@ -669,10 +679,10 @@ class ChildTool(BaseTool):
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.model_validate(tool_input)
result = args_schema.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
for k, v in get_all_basemodel_annotations(input_args).items():
elif issubclass(args_schema, BaseModelV1):
for k, v in get_all_basemodel_annotations(args_schema).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
@@ -687,7 +697,7 @@ class ChildTool(BaseTool):
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.parse_obj(tool_input)
result = args_schema.parse_obj(tool_input)
result_dict = result.dict()
else:
msg = (

View File

@@ -2705,3 +2705,79 @@ def test_tool_args_schema_with_annotated_type() -> None:
"type": "array",
}
}
def test_args_schema_without_injected_arguments() -> None:
"""Test that injected args work correctly in tool schemas."""
class CustomSchema(BaseModel):
"""Schema for LLM inputs."""
text: str = Field(description="Text to process")
# InjectedToolCallId should NOT need to be here
@tool(
"broken_tool",
description="Tool that should work with automatic injection",
parse_docstring=True,
args_schema=CustomSchema,
)
def broken_tool(
text: str,
tool_call_id: Annotated[str, InjectedToolCallId],
) -> str:
"""Tool with injected arguments."""
return f"Processed '{text}' (call_id: {tool_call_id})"
# Test tool call structure
tool_call = {
"name": "test_tool",
"args": {"text": "test data"},
"id": "test_123",
"type": "tool_call",
}
# The full input schema should include tool_call_id for validation
schema = _schema(broken_tool.get_input_schema())
assert schema == {
"description": "Tool with injected arguments.",
"properties": {
"text": {
"title": "Text",
"type": "string",
"description": "Text to process",
},
# Tool call ID **is** required in this case.
"tool_call_id": {
"title": "Tool Call Id",
"type": "string",
},
},
"required": [
"text",
"tool_call_id",
],
"title": "CustomSchema",
"type": "object",
}
result = broken_tool.invoke(tool_call)
assert isinstance(result, ToolMessage)
assert result.content == "Processed 'test data' (call_id: test_123)"
# The tool schema does not include injectable arguments that are hidden from the
# LLM.
tool_call_schema = _schema(broken_tool.tool_call_schema)
assert tool_call_schema == {
"description": "Tool that should work with automatic injection",
"properties": {
"text": {
"description": "Text to process",
"title": "Text",
"type": "string",
}
},
"required": ["text"],
"title": "broken_tool",
"type": "object",
}