_prep_run_args,tool_input copy

This commit is contained in:
xzq.xu 2025-03-26 22:56:32 +08:00
parent e90abce577
commit 3382b0d8ea

View File

@ -91,11 +91,11 @@ def _get_annotation_description(arg_type: type) -> str | None:
def _get_filtered_args( def _get_filtered_args(
inferred_model: type[BaseModel], inferred_model: type[BaseModel],
func: Callable, func: Callable,
*, *,
filter_args: Sequence[str], filter_args: Sequence[str],
include_injected: bool = True, include_injected: bool = True,
) -> dict: ) -> dict:
"""Get the arguments from a function's signature.""" """Get the arguments from a function's signature."""
schema = inferred_model.model_json_schema()["properties"] schema = inferred_model.model_json_schema()["properties"]
@ -104,13 +104,13 @@ def _get_filtered_args(
k: schema[k] k: schema[k]
for i, (k, param) in enumerate(valid_keys.items()) for i, (k, param) in enumerate(valid_keys.items())
if k not in filter_args if k not in filter_args
and (i > 0 or param.name not in ("self", "cls")) and (i > 0 or param.name not in ("self", "cls"))
and (include_injected or not _is_injected_arg_type(param.annotation)) and (include_injected or not _is_injected_arg_type(param.annotation))
} }
def _parse_python_function_docstring( def _parse_python_function_docstring(
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
) -> tuple[str, dict]: ) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function. """Parse the function and argument descriptions from the docstring of a function.
@ -125,7 +125,7 @@ def _parse_python_function_docstring(
def _validate_docstring_args_against_annotations( def _validate_docstring_args_against_annotations(
arg_descriptions: dict, annotations: dict arg_descriptions: dict, annotations: dict
) -> None: ) -> None:
"""Raise error if docstring arg is not in type annotations.""" """Raise error if docstring arg is not in type annotations."""
for docstring_arg in arg_descriptions: for docstring_arg in arg_descriptions:
@ -135,10 +135,10 @@ def _validate_docstring_args_against_annotations(
def _infer_arg_descriptions( def _infer_arg_descriptions(
fn: Callable, fn: Callable,
*, *,
parse_docstring: bool = False, parse_docstring: bool = False,
error_on_invalid_docstring: bool = False, error_on_invalid_docstring: bool = False,
) -> tuple[str, dict]: ) -> tuple[str, dict]:
"""Infer argument descriptions from a function's docstring.""" """Infer argument descriptions from a function's docstring."""
if hasattr(inspect, "get_annotations"): if hasattr(inspect, "get_annotations"):
@ -173,7 +173,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
def _function_annotations_are_pydantic_v1( def _function_annotations_are_pydantic_v1(
signature: inspect.Signature, func: Callable signature: inspect.Signature, func: Callable
) -> bool: ) -> bool:
"""Determine if all Pydantic annotations in a function signature are from V1.""" """Determine if all Pydantic annotations in a function signature are from V1."""
any_v1_annotations = any( any_v1_annotations = any(
@ -210,13 +210,13 @@ class _SchemaConfig:
def create_schema_from_function( def create_schema_from_function(
model_name: str, model_name: str,
func: Callable, func: Callable,
*, *,
filter_args: Optional[Sequence[str]] = None, filter_args: Optional[Sequence[str]] = None,
parse_docstring: bool = False, parse_docstring: bool = False,
error_on_invalid_docstring: bool = False, error_on_invalid_docstring: bool = False,
include_injected: bool = True, include_injected: bool = True,
) -> type[BaseModel]: ) -> type[BaseModel]:
"""Create a pydantic schema from a function's signature. """Create a pydantic schema from a function's signature.
@ -277,7 +277,7 @@ def create_schema_from_function(
for existing_param in existing_params: for existing_param in existing_params:
if not include_injected and _is_injected_arg_type( if not include_injected and _is_injected_arg_type(
sig.parameters[existing_param].annotation sig.parameters[existing_param].annotation
): ):
filter_args_.append(existing_param) filter_args_.append(existing_param)
@ -427,10 +427,10 @@ class ChildTool(BaseTool):
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
"""Initialize the tool.""" """Initialize the tool."""
if ( if (
"args_schema" in kwargs "args_schema" in kwargs
and kwargs["args_schema"] is not None and kwargs["args_schema"] is not None
and not is_basemodel_subclass(kwargs["args_schema"]) and not is_basemodel_subclass(kwargs["args_schema"])
and not isinstance(kwargs["args_schema"], dict) and not isinstance(kwargs["args_schema"], dict)
): ):
msg = ( msg = (
"args_schema must be a subclass of pydantic BaseModel or " "args_schema must be a subclass of pydantic BaseModel or "
@ -481,7 +481,7 @@ class ChildTool(BaseTool):
# --- Runnable --- # --- Runnable ---
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]: ) -> type[BaseModel]:
"""The tool's input schema. """The tool's input schema.
@ -499,19 +499,19 @@ class ChildTool(BaseTool):
return create_schema_from_function(self.name, self._run) return create_schema_from_function(self.name, self._run)
def invoke( def invoke(
self, self,
input: Union[str, dict, ToolCall], input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
tool_input, kwargs = _prep_run_args(input, config, **kwargs) tool_input, kwargs = _prep_run_args(input, config, **kwargs)
return self.run(tool_input, **kwargs) return self.run(tool_input, **kwargs)
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, dict, ToolCall], input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
tool_input, kwargs = _prep_run_args(input, config, **kwargs) tool_input, kwargs = _prep_run_args(input, config, **kwargs)
return await self.arun(tool_input, **kwargs) return await self.arun(tool_input, **kwargs)
@ -519,7 +519,7 @@ class ChildTool(BaseTool):
# --- Tool --- # --- Tool ---
def _parse_input( def _parse_input(
self, tool_input: Union[str, dict], tool_call_id: Optional[str] self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> Union[str, dict[str, Any]]: ) -> Union[str, dict[str, Any]]:
"""Convert tool input to a pydantic model. """Convert tool input to a pydantic model.
@ -548,8 +548,8 @@ class ChildTool(BaseTool):
elif issubclass(input_args, BaseModel): elif issubclass(input_args, BaseModel):
for k, v in get_all_basemodel_annotations(input_args).items(): for k, v in get_all_basemodel_annotations(input_args).items():
if ( if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId) _is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input and k not in tool_input
): ):
if tool_call_id is None: if tool_call_id is None:
msg = ( msg = (
@ -566,8 +566,8 @@ class ChildTool(BaseTool):
elif issubclass(input_args, BaseModelV1): elif issubclass(input_args, BaseModelV1):
for k, v in get_all_basemodel_annotations(input_args).items(): for k, v in get_all_basemodel_annotations(input_args).items():
if ( if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId) _is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input and k not in tool_input
): ):
if tool_call_id is None: if tool_call_id is None:
msg = ( msg = (
@ -629,19 +629,19 @@ class ChildTool(BaseTool):
to child implementations to enable tracing. to child implementations to enable tracing.
""" """
if kwargs.get("run_manager") and signature(self._run).parameters.get( if kwargs.get("run_manager") and signature(self._run).parameters.get(
"run_manager" "run_manager"
): ):
kwargs["run_manager"] = kwargs["run_manager"].get_sync() kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs) return await run_in_executor(None, self._run, *args, **kwargs)
def _to_args_and_kwargs( def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str] self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]: ) -> tuple[tuple, dict]:
if ( if (
self.args_schema is not None self.args_schema is not None
and isinstance(self.args_schema, type) and isinstance(self.args_schema, type)
and is_basemodel_subclass(self.args_schema) and is_basemodel_subclass(self.args_schema)
and not get_fields(self.args_schema) and not get_fields(self.args_schema)
): ):
# StructuredTool with no args # StructuredTool with no args
return (), {} return (), {}
@ -654,20 +654,20 @@ class ChildTool(BaseTool):
return (), tool_input return (), tool_input
def run( def run(
self, self,
tool_input: Union[str, dict[str, Any]], tool_input: Union[str, dict[str, Any]],
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None, run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None, tool_call_id: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool. """Run the tool.
@ -766,20 +766,20 @@ class ChildTool(BaseTool):
return output return output
async def arun( async def arun(
self, self,
tool_input: Union[str, dict], tool_input: Union[str, dict],
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None, run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None, tool_call_id: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the tool asynchronously. """Run the tool asynchronously.
@ -837,8 +837,6 @@ class ChildTool(BaseTool):
self._run if self.__class__._arun is BaseTool._arun else self._arun self._run if self.__class__._arun is BaseTool._arun else self._arun
) )
if signature(func_to_check).parameters.get("run_manager"): if signature(func_to_check).parameters.get("run_manager"):
import copy
tool_kwargs = copy.deepcopy(tool_kwargs)
tool_kwargs["run_manager"] = run_manager tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check): if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config tool_kwargs[config_param] = config
@ -895,11 +893,11 @@ def _is_tool_call(x: Any) -> bool:
def _handle_validation_error( def _handle_validation_error(
e: Union[ValidationError, ValidationErrorV1], e: Union[ValidationError, ValidationErrorV1],
*, *,
flag: Union[ flag: Union[
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str] Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
], ],
) -> str: ) -> str:
if isinstance(flag, bool): if isinstance(flag, bool):
content = "Tool input validation error" content = "Tool input validation error"
@ -917,9 +915,9 @@ def _handle_validation_error(
def _handle_tool_error( def _handle_tool_error(
e: ToolException, e: ToolException,
*, *,
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
) -> str: ) -> str:
if isinstance(flag, bool): if isinstance(flag, bool):
content = e.args[0] if e.args else "Tool execution error" content = e.args[0] if e.args else "Tool execution error"
@ -937,9 +935,9 @@ def _handle_tool_error(
def _prep_run_args( def _prep_run_args(
input: Union[str, dict, ToolCall], input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
**kwargs: Any, **kwargs: Any,
) -> tuple[Union[str, dict], dict]: ) -> tuple[Union[str, dict], dict]:
config = ensure_config(config) config = ensure_config(config)
if _is_tool_call(input): if _is_tool_call(input):
@ -947,7 +945,7 @@ def _prep_run_args(
tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy() tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy()
else: else:
tool_call_id = None tool_call_id = None
tool_input = cast(Union[str, dict], input) tool_input = cast(Union[str, dict], input).copy()
return ( return (
tool_input, tool_input,
dict( dict(
@ -964,11 +962,11 @@ def _prep_run_args(
def _format_output( def _format_output(
content: Any, content: Any,
artifact: Any, artifact: Any,
tool_call_id: Optional[str], tool_call_id: Optional[str],
name: str, name: str,
status: str, status: str,
) -> Union[ToolOutputMixin, Any]: ) -> Union[ToolOutputMixin, Any]:
if isinstance(content, ToolOutputMixin) or tool_call_id is None: if isinstance(content, ToolOutputMixin) or tool_call_id is None:
return content return content
@ -986,9 +984,9 @@ def _format_output(
def _is_message_content_type(obj: Any) -> bool: def _is_message_content_type(obj: Any) -> bool:
"""Check for OpenAI or Anthropic format tool message content.""" """Check for OpenAI or Anthropic format tool message content."""
return ( return (
isinstance(obj, str) isinstance(obj, str)
or isinstance(obj, list) or isinstance(obj, list)
and all(_is_message_content_block(e) for e in obj) and all(_is_message_content_block(e) for e in obj)
) )
@ -1051,7 +1049,7 @@ class InjectedToolCallId(InjectedToolArg):
def _is_injected_arg_type( def _is_injected_arg_type(
type_: type, injected_type: Optional[type[InjectedToolArg]] = None type_: type, injected_type: Optional[type[InjectedToolArg]] = None
) -> bool: ) -> bool:
injected_type = injected_type or InjectedToolArg injected_type = injected_type or InjectedToolArg
return any( return any(
@ -1062,7 +1060,7 @@ def _is_injected_arg_type(
def get_all_basemodel_annotations( def get_all_basemodel_annotations(
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
) -> dict[str, type]: ) -> dict[str, type]:
# cls has no subscript: cls = FooBar # cls has no subscript: cls = FooBar
if isinstance(cls, type): if isinstance(cls, type):
@ -1071,8 +1069,8 @@ def get_all_basemodel_annotations(
# Exclude hidden init args added by pydantic Config. For example if # Exclude hidden init args added by pydantic Config. For example if
# BaseModel(extra="allow") then "extra_data" will part of init sig. # BaseModel(extra="allow") then "extra_data" will part of init sig.
if ( if (
fields := getattr(cls, "model_fields", {}) # pydantic v2+ fields := getattr(cls, "model_fields", {}) # pydantic v2+
or getattr(cls, "__fields__", {}) # pydantic v1 or getattr(cls, "__fields__", {}) # pydantic v1
) and name not in fields: ) and name not in fields:
continue continue
annotations[name] = param.annotation annotations[name] = param.annotation
@ -1122,9 +1120,9 @@ def get_all_basemodel_annotations(
def _replace_type_vars( def _replace_type_vars(
type_: type, type_: type,
generic_map: Optional[dict[TypeVar, type]] = None, generic_map: Optional[dict[TypeVar, type]] = None,
default_to_bound: bool = True, default_to_bound: bool = True,
) -> type: ) -> type:
generic_map = generic_map or {} generic_map = generic_map or {}
if isinstance(type_, TypeVar): if isinstance(type_, TypeVar):