format + typing

This commit is contained in:
Mason Daugherty 2025-07-16 10:54:27 -04:00
parent 587f3b0090
commit 300299d0d0
No known key found for this signature in database
2 changed files with 10 additions and 8 deletions

View File

@ -689,9 +689,7 @@ class ChildTool(BaseTool):
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}" f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
) )
raise NotImplementedError(msg) raise NotImplementedError(msg)
return { return {k: getattr(result, k) for k, v in result_dict.items()}
k: getattr(result, k) for k, v in result_dict.items()
}
return tool_input return tool_input
@model_validator(mode="before") @model_validator(mode="before")

View File

@ -634,7 +634,7 @@ def test_named_tool_decorator_return_direct() -> None:
"""Test functionality when arguments and return direct are provided as input.""" """Test functionality when arguments and return direct are provided as input."""
@tool("search", return_direct=True) @tool("search", return_direct=True)
def search_api(query: str, *args: Any) -> str: def search_api(query: str) -> str:
"""Search the API for the query.""" """Search the API for the query."""
return "API result" return "API result"
@ -2549,24 +2549,28 @@ def test_tool_args_schema_with_pydantic_validator() -> None:
x: NestedArgsSchema x: NestedArgsSchema
@model_validator(mode="before") @model_validator(mode="before")
def wrap_in_x(cls, data: dict) -> dict: def wrap_in_x(cls, data: Any) -> Any: # noqa: N805
if not isinstance(data, dict):
return {"x": data}
if "x" not in data: if "x" not in data:
return {"x": data} return {"x": data}
return data return data
@tool(args_schema=ArgsSchema) @tool(args_schema=ArgsSchema)
def foo(**args) -> ArgsSchema: def foo(**args: Any) -> ArgsSchema:
"""Bar.""" """Bar."""
return ArgsSchema.model_validate(args) return ArgsSchema.model_validate(args)
# Test case where validator is identity function # Test case where validator is identity function
valid_inputs = {"x": {"y": 5}} valid_inputs = {"x": {"y": 5}}
assert foo.invoke(valid_inputs) == ArgsSchema.model_validate(valid_inputs) assert foo.invoke(valid_inputs) == ArgsSchema.model_validate(valid_inputs)
# Test case where validator wraps input in "x" # Test case where validator wraps input in "x"
invalid_inputs = {"y": 5} invalid_inputs = {"y": 5}
assert foo.invoke(invalid_inputs) == ArgsSchema.model_validate({"x": invalid_inputs}) assert foo.invoke(invalid_inputs) == ArgsSchema.model_validate(
{"x": invalid_inputs}
)
def test_title_property_preserved() -> None: def test_title_property_preserved() -> None: