mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 20:49:17 +00:00
_prep_run_args,tool_input copy
This commit is contained in:
parent
e90abce577
commit
3382b0d8ea
@ -91,11 +91,11 @@ def _get_annotation_description(arg_type: type) -> str | None:
|
||||
|
||||
|
||||
def _get_filtered_args(
|
||||
inferred_model: type[BaseModel],
|
||||
func: Callable,
|
||||
*,
|
||||
filter_args: Sequence[str],
|
||||
include_injected: bool = True,
|
||||
inferred_model: type[BaseModel],
|
||||
func: Callable,
|
||||
*,
|
||||
filter_args: Sequence[str],
|
||||
include_injected: bool = True,
|
||||
) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
schema = inferred_model.model_json_schema()["properties"]
|
||||
@ -104,13 +104,13 @@ def _get_filtered_args(
|
||||
k: schema[k]
|
||||
for i, (k, param) in enumerate(valid_keys.items())
|
||||
if k not in filter_args
|
||||
and (i > 0 or param.name not in ("self", "cls"))
|
||||
and (include_injected or not _is_injected_arg_type(param.annotation))
|
||||
and (i > 0 or param.name not in ("self", "cls"))
|
||||
and (include_injected or not _is_injected_arg_type(param.annotation))
|
||||
}
|
||||
|
||||
|
||||
def _parse_python_function_docstring(
|
||||
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
|
||||
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
|
||||
) -> tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
@ -125,7 +125,7 @@ def _parse_python_function_docstring(
|
||||
|
||||
|
||||
def _validate_docstring_args_against_annotations(
|
||||
arg_descriptions: dict, annotations: dict
|
||||
arg_descriptions: dict, annotations: dict
|
||||
) -> None:
|
||||
"""Raise error if docstring arg is not in type annotations."""
|
||||
for docstring_arg in arg_descriptions:
|
||||
@ -135,10 +135,10 @@ def _validate_docstring_args_against_annotations(
|
||||
|
||||
|
||||
def _infer_arg_descriptions(
|
||||
fn: Callable,
|
||||
*,
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
fn: Callable,
|
||||
*,
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
) -> tuple[str, dict]:
|
||||
"""Infer argument descriptions from a function's docstring."""
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
@ -173,7 +173,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
|
||||
|
||||
|
||||
def _function_annotations_are_pydantic_v1(
|
||||
signature: inspect.Signature, func: Callable
|
||||
signature: inspect.Signature, func: Callable
|
||||
) -> bool:
|
||||
"""Determine if all Pydantic annotations in a function signature are from V1."""
|
||||
any_v1_annotations = any(
|
||||
@ -210,13 +210,13 @@ class _SchemaConfig:
|
||||
|
||||
|
||||
def create_schema_from_function(
|
||||
model_name: str,
|
||||
func: Callable,
|
||||
*,
|
||||
filter_args: Optional[Sequence[str]] = None,
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
include_injected: bool = True,
|
||||
model_name: str,
|
||||
func: Callable,
|
||||
*,
|
||||
filter_args: Optional[Sequence[str]] = None,
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
include_injected: bool = True,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature.
|
||||
|
||||
@ -277,7 +277,7 @@ def create_schema_from_function(
|
||||
|
||||
for existing_param in existing_params:
|
||||
if not include_injected and _is_injected_arg_type(
|
||||
sig.parameters[existing_param].annotation
|
||||
sig.parameters[existing_param].annotation
|
||||
):
|
||||
filter_args_.append(existing_param)
|
||||
|
||||
@ -427,10 +427,10 @@ class ChildTool(BaseTool):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the tool."""
|
||||
if (
|
||||
"args_schema" in kwargs
|
||||
and kwargs["args_schema"] is not None
|
||||
and not is_basemodel_subclass(kwargs["args_schema"])
|
||||
and not isinstance(kwargs["args_schema"], dict)
|
||||
"args_schema" in kwargs
|
||||
and kwargs["args_schema"] is not None
|
||||
and not is_basemodel_subclass(kwargs["args_schema"])
|
||||
and not isinstance(kwargs["args_schema"], dict)
|
||||
):
|
||||
msg = (
|
||||
"args_schema must be a subclass of pydantic BaseModel or "
|
||||
@ -481,7 +481,7 @@ class ChildTool(BaseTool):
|
||||
# --- Runnable ---
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> type[BaseModel]:
|
||||
"""The tool's input schema.
|
||||
|
||||
@ -499,19 +499,19 @@ class ChildTool(BaseTool):
|
||||
return create_schema_from_function(self.name, self._run)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||
return self.run(tool_input, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||
return await self.arun(tool_input, **kwargs)
|
||||
@ -519,7 +519,7 @@ class ChildTool(BaseTool):
|
||||
# --- Tool ---
|
||||
|
||||
def _parse_input(
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
"""Convert tool input to a pydantic model.
|
||||
|
||||
@ -548,8 +548,8 @@ class ChildTool(BaseTool):
|
||||
elif issubclass(input_args, BaseModel):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if (
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
and k not in tool_input
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
and k not in tool_input
|
||||
):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
@ -566,8 +566,8 @@ class ChildTool(BaseTool):
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if (
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
and k not in tool_input
|
||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||
and k not in tool_input
|
||||
):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
@ -629,19 +629,19 @@ class ChildTool(BaseTool):
|
||||
to child implementations to enable tracing.
|
||||
"""
|
||||
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
||||
"run_manager"
|
||||
"run_manager"
|
||||
):
|
||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||
|
||||
def _to_args_and_kwargs(
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||
) -> tuple[tuple, dict]:
|
||||
if (
|
||||
self.args_schema is not None
|
||||
and isinstance(self.args_schema, type)
|
||||
and is_basemodel_subclass(self.args_schema)
|
||||
and not get_fields(self.args_schema)
|
||||
self.args_schema is not None
|
||||
and isinstance(self.args_schema, type)
|
||||
and is_basemodel_subclass(self.args_schema)
|
||||
and not get_fields(self.args_schema)
|
||||
):
|
||||
# StructuredTool with no args
|
||||
return (), {}
|
||||
@ -654,20 +654,20 @@ class ChildTool(BaseTool):
|
||||
return (), tool_input
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: Union[str, dict[str, Any]],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
tool_input: Union[str, dict[str, Any]],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool.
|
||||
|
||||
@ -766,20 +766,20 @@ class ChildTool(BaseTool):
|
||||
return output
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
tool_input: Union[str, dict],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
tool_input: Union[str, dict],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously.
|
||||
|
||||
@ -837,8 +837,6 @@ class ChildTool(BaseTool):
|
||||
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
||||
)
|
||||
if signature(func_to_check).parameters.get("run_manager"):
|
||||
import copy
|
||||
tool_kwargs = copy.deepcopy(tool_kwargs)
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
if config_param := _get_runnable_config_param(func_to_check):
|
||||
tool_kwargs[config_param] = config
|
||||
@ -895,11 +893,11 @@ def _is_tool_call(x: Any) -> bool:
|
||||
|
||||
|
||||
def _handle_validation_error(
|
||||
e: Union[ValidationError, ValidationErrorV1],
|
||||
*,
|
||||
flag: Union[
|
||||
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
||||
],
|
||||
e: Union[ValidationError, ValidationErrorV1],
|
||||
*,
|
||||
flag: Union[
|
||||
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
||||
],
|
||||
) -> str:
|
||||
if isinstance(flag, bool):
|
||||
content = "Tool input validation error"
|
||||
@ -917,9 +915,9 @@ def _handle_validation_error(
|
||||
|
||||
|
||||
def _handle_tool_error(
|
||||
e: ToolException,
|
||||
*,
|
||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||
e: ToolException,
|
||||
*,
|
||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||
) -> str:
|
||||
if isinstance(flag, bool):
|
||||
content = e.args[0] if e.args else "Tool execution error"
|
||||
@ -937,9 +935,9 @@ def _handle_tool_error(
|
||||
|
||||
|
||||
def _prep_run_args(
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig],
|
||||
**kwargs: Any,
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig],
|
||||
**kwargs: Any,
|
||||
) -> tuple[Union[str, dict], dict]:
|
||||
config = ensure_config(config)
|
||||
if _is_tool_call(input):
|
||||
@ -947,7 +945,7 @@ def _prep_run_args(
|
||||
tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy()
|
||||
else:
|
||||
tool_call_id = None
|
||||
tool_input = cast(Union[str, dict], input)
|
||||
tool_input = cast(Union[str, dict], input).copy()
|
||||
return (
|
||||
tool_input,
|
||||
dict(
|
||||
@ -964,11 +962,11 @@ def _prep_run_args(
|
||||
|
||||
|
||||
def _format_output(
|
||||
content: Any,
|
||||
artifact: Any,
|
||||
tool_call_id: Optional[str],
|
||||
name: str,
|
||||
status: str,
|
||||
content: Any,
|
||||
artifact: Any,
|
||||
tool_call_id: Optional[str],
|
||||
name: str,
|
||||
status: str,
|
||||
) -> Union[ToolOutputMixin, Any]:
|
||||
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
|
||||
return content
|
||||
@ -986,9 +984,9 @@ def _format_output(
|
||||
def _is_message_content_type(obj: Any) -> bool:
|
||||
"""Check for OpenAI or Anthropic format tool message content."""
|
||||
return (
|
||||
isinstance(obj, str)
|
||||
or isinstance(obj, list)
|
||||
and all(_is_message_content_block(e) for e in obj)
|
||||
isinstance(obj, str)
|
||||
or isinstance(obj, list)
|
||||
and all(_is_message_content_block(e) for e in obj)
|
||||
)
|
||||
|
||||
|
||||
@ -1051,7 +1049,7 @@ class InjectedToolCallId(InjectedToolArg):
|
||||
|
||||
|
||||
def _is_injected_arg_type(
|
||||
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
|
||||
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
|
||||
) -> bool:
|
||||
injected_type = injected_type or InjectedToolArg
|
||||
return any(
|
||||
@ -1062,7 +1060,7 @@ def _is_injected_arg_type(
|
||||
|
||||
|
||||
def get_all_basemodel_annotations(
|
||||
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
||||
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
||||
) -> dict[str, type]:
|
||||
# cls has no subscript: cls = FooBar
|
||||
if isinstance(cls, type):
|
||||
@ -1071,8 +1069,8 @@ def get_all_basemodel_annotations(
|
||||
# Exclude hidden init args added by pydantic Config. For example if
|
||||
# BaseModel(extra="allow") then "extra_data" will part of init sig.
|
||||
if (
|
||||
fields := getattr(cls, "model_fields", {}) # pydantic v2+
|
||||
or getattr(cls, "__fields__", {}) # pydantic v1
|
||||
fields := getattr(cls, "model_fields", {}) # pydantic v2+
|
||||
or getattr(cls, "__fields__", {}) # pydantic v1
|
||||
) and name not in fields:
|
||||
continue
|
||||
annotations[name] = param.annotation
|
||||
@ -1122,9 +1120,9 @@ def get_all_basemodel_annotations(
|
||||
|
||||
|
||||
def _replace_type_vars(
|
||||
type_: type,
|
||||
generic_map: Optional[dict[TypeVar, type]] = None,
|
||||
default_to_bound: bool = True,
|
||||
type_: type,
|
||||
generic_map: Optional[dict[TypeVar, type]] = None,
|
||||
default_to_bound: bool = True,
|
||||
) -> type:
|
||||
generic_map = generic_map or {}
|
||||
if isinstance(type_, TypeVar):
|
||||
@ -1148,4 +1146,4 @@ class BaseToolkit(BaseModel, ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> list[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
"""Get the tools in the toolkit."""
|
Loading…
Reference in New Issue
Block a user