_prep_run_args,tool_input copy

This commit is contained in:
xzq.xu 2025-03-26 22:56:32 +08:00
parent e90abce577
commit 3382b0d8ea

View File

@ -91,11 +91,11 @@ def _get_annotation_description(arg_type: type) -> str | None:
def _get_filtered_args(
inferred_model: type[BaseModel],
func: Callable,
*,
filter_args: Sequence[str],
include_injected: bool = True,
inferred_model: type[BaseModel],
func: Callable,
*,
filter_args: Sequence[str],
include_injected: bool = True,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.model_json_schema()["properties"]
@ -104,13 +104,13 @@ def _get_filtered_args(
k: schema[k]
for i, (k, param) in enumerate(valid_keys.items())
if k not in filter_args
and (i > 0 or param.name not in ("self", "cls"))
and (include_injected or not _is_injected_arg_type(param.annotation))
and (i > 0 or param.name not in ("self", "cls"))
and (include_injected or not _is_injected_arg_type(param.annotation))
}
def _parse_python_function_docstring(
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
@ -125,7 +125,7 @@ def _parse_python_function_docstring(
def _validate_docstring_args_against_annotations(
arg_descriptions: dict, annotations: dict
arg_descriptions: dict, annotations: dict
) -> None:
"""Raise error if docstring arg is not in type annotations."""
for docstring_arg in arg_descriptions:
@ -135,10 +135,10 @@ def _validate_docstring_args_against_annotations(
def _infer_arg_descriptions(
fn: Callable,
*,
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
fn: Callable,
*,
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
) -> tuple[str, dict]:
"""Infer argument descriptions from a function's docstring."""
if hasattr(inspect, "get_annotations"):
@ -173,7 +173,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
def _function_annotations_are_pydantic_v1(
signature: inspect.Signature, func: Callable
signature: inspect.Signature, func: Callable
) -> bool:
"""Determine if all Pydantic annotations in a function signature are from V1."""
any_v1_annotations = any(
@ -210,13 +210,13 @@ class _SchemaConfig:
def create_schema_from_function(
model_name: str,
func: Callable,
*,
filter_args: Optional[Sequence[str]] = None,
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
include_injected: bool = True,
model_name: str,
func: Callable,
*,
filter_args: Optional[Sequence[str]] = None,
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
include_injected: bool = True,
) -> type[BaseModel]:
"""Create a pydantic schema from a function's signature.
@ -277,7 +277,7 @@ def create_schema_from_function(
for existing_param in existing_params:
if not include_injected and _is_injected_arg_type(
sig.parameters[existing_param].annotation
sig.parameters[existing_param].annotation
):
filter_args_.append(existing_param)
@ -427,10 +427,10 @@ class ChildTool(BaseTool):
def __init__(self, **kwargs: Any) -> None:
"""Initialize the tool."""
if (
"args_schema" in kwargs
and kwargs["args_schema"] is not None
and not is_basemodel_subclass(kwargs["args_schema"])
and not isinstance(kwargs["args_schema"], dict)
"args_schema" in kwargs
and kwargs["args_schema"] is not None
and not is_basemodel_subclass(kwargs["args_schema"])
and not isinstance(kwargs["args_schema"], dict)
):
msg = (
"args_schema must be a subclass of pydantic BaseModel or "
@ -481,7 +481,7 @@ class ChildTool(BaseTool):
# --- Runnable ---
def get_input_schema(
self, config: Optional[RunnableConfig] = None
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
"""The tool's input schema.
@ -499,19 +499,19 @@ class ChildTool(BaseTool):
return create_schema_from_function(self.name, self._run)
def invoke(
self,
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
self,
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
return self.run(tool_input, **kwargs)
async def ainvoke(
self,
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
self,
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
return await self.arun(tool_input, **kwargs)
@ -519,7 +519,7 @@ class ChildTool(BaseTool):
# --- Tool ---
def _parse_input(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> Union[str, dict[str, Any]]:
"""Convert tool input to a pydantic model.
@ -548,8 +548,8 @@ class ChildTool(BaseTool):
elif issubclass(input_args, BaseModel):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
@ -566,8 +566,8 @@ class ChildTool(BaseTool):
elif issubclass(input_args, BaseModelV1):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
@ -629,19 +629,19 @@ class ChildTool(BaseTool):
to child implementations to enable tracing.
"""
if kwargs.get("run_manager") and signature(self._run).parameters.get(
"run_manager"
"run_manager"
):
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs)
def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]:
if (
self.args_schema is not None
and isinstance(self.args_schema, type)
and is_basemodel_subclass(self.args_schema)
and not get_fields(self.args_schema)
self.args_schema is not None
and isinstance(self.args_schema, type)
and is_basemodel_subclass(self.args_schema)
and not get_fields(self.args_schema)
):
# StructuredTool with no args
return (), {}
@ -654,20 +654,20 @@ class ChildTool(BaseTool):
return (), tool_input
def run(
self,
tool_input: Union[str, dict[str, Any]],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None,
**kwargs: Any,
self,
tool_input: Union[str, dict[str, Any]],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool.
@ -766,20 +766,20 @@ class ChildTool(BaseTool):
return output
async def arun(
self,
tool_input: Union[str, dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None,
**kwargs: Any,
self,
tool_input: Union[str, dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously.
@ -837,8 +837,6 @@ class ChildTool(BaseTool):
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
import copy
tool_kwargs = copy.deepcopy(tool_kwargs)
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
@ -895,11 +893,11 @@ def _is_tool_call(x: Any) -> bool:
def _handle_validation_error(
e: Union[ValidationError, ValidationErrorV1],
*,
flag: Union[
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
e: Union[ValidationError, ValidationErrorV1],
*,
flag: Union[
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> str:
if isinstance(flag, bool):
content = "Tool input validation error"
@ -917,9 +915,9 @@ def _handle_validation_error(
def _handle_tool_error(
e: ToolException,
*,
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
e: ToolException,
*,
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
) -> str:
if isinstance(flag, bool):
content = e.args[0] if e.args else "Tool execution error"
@ -937,9 +935,9 @@ def _handle_tool_error(
def _prep_run_args(
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig],
**kwargs: Any,
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig],
**kwargs: Any,
) -> tuple[Union[str, dict], dict]:
config = ensure_config(config)
if _is_tool_call(input):
@ -947,7 +945,7 @@ def _prep_run_args(
tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy()
else:
tool_call_id = None
tool_input = cast(Union[str, dict], input)
tool_input = cast(Union[str, dict], input).copy()
return (
tool_input,
dict(
@ -964,11 +962,11 @@ def _prep_run_args(
def _format_output(
content: Any,
artifact: Any,
tool_call_id: Optional[str],
name: str,
status: str,
content: Any,
artifact: Any,
tool_call_id: Optional[str],
name: str,
status: str,
) -> Union[ToolOutputMixin, Any]:
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
return content
@ -986,9 +984,9 @@ def _format_output(
def _is_message_content_type(obj: Any) -> bool:
"""Check for OpenAI or Anthropic format tool message content."""
return (
isinstance(obj, str)
or isinstance(obj, list)
and all(_is_message_content_block(e) for e in obj)
isinstance(obj, str)
or isinstance(obj, list)
and all(_is_message_content_block(e) for e in obj)
)
@ -1051,7 +1049,7 @@ class InjectedToolCallId(InjectedToolArg):
def _is_injected_arg_type(
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
) -> bool:
injected_type = injected_type or InjectedToolArg
return any(
@ -1062,7 +1060,7 @@ def _is_injected_arg_type(
def get_all_basemodel_annotations(
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
) -> dict[str, type]:
# cls has no subscript: cls = FooBar
if isinstance(cls, type):
@ -1071,8 +1069,8 @@ def get_all_basemodel_annotations(
# Exclude hidden init args added by pydantic Config. For example if
# BaseModel(extra="allow") then "extra_data" will part of init sig.
if (
fields := getattr(cls, "model_fields", {}) # pydantic v2+
or getattr(cls, "__fields__", {}) # pydantic v1
fields := getattr(cls, "model_fields", {}) # pydantic v2+
or getattr(cls, "__fields__", {}) # pydantic v1
) and name not in fields:
continue
annotations[name] = param.annotation
@ -1122,9 +1120,9 @@ def get_all_basemodel_annotations(
def _replace_type_vars(
type_: type,
generic_map: Optional[dict[TypeVar, type]] = None,
default_to_bound: bool = True,
type_: type,
generic_map: Optional[dict[TypeVar, type]] = None,
default_to_bound: bool = True,
) -> type:
generic_map = generic_map or {}
if isinstance(type_, TypeVar):
@ -1148,4 +1146,4 @@ class BaseToolkit(BaseModel, ABC):
@abstractmethod
def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit."""
"""Get the tools in the toolkit."""