mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
_prep_run_args,tool_input copy
This commit is contained in:
parent
e90abce577
commit
3382b0d8ea
@ -91,11 +91,11 @@ def _get_annotation_description(arg_type: type) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def _get_filtered_args(
|
def _get_filtered_args(
|
||||||
inferred_model: type[BaseModel],
|
inferred_model: type[BaseModel],
|
||||||
func: Callable,
|
func: Callable,
|
||||||
*,
|
*,
|
||||||
filter_args: Sequence[str],
|
filter_args: Sequence[str],
|
||||||
include_injected: bool = True,
|
include_injected: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Get the arguments from a function's signature."""
|
"""Get the arguments from a function's signature."""
|
||||||
schema = inferred_model.model_json_schema()["properties"]
|
schema = inferred_model.model_json_schema()["properties"]
|
||||||
@ -104,13 +104,13 @@ def _get_filtered_args(
|
|||||||
k: schema[k]
|
k: schema[k]
|
||||||
for i, (k, param) in enumerate(valid_keys.items())
|
for i, (k, param) in enumerate(valid_keys.items())
|
||||||
if k not in filter_args
|
if k not in filter_args
|
||||||
and (i > 0 or param.name not in ("self", "cls"))
|
and (i > 0 or param.name not in ("self", "cls"))
|
||||||
and (include_injected or not _is_injected_arg_type(param.annotation))
|
and (include_injected or not _is_injected_arg_type(param.annotation))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _parse_python_function_docstring(
|
def _parse_python_function_docstring(
|
||||||
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
|
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
|
||||||
) -> tuple[str, dict]:
|
) -> tuple[str, dict]:
|
||||||
"""Parse the function and argument descriptions from the docstring of a function.
|
"""Parse the function and argument descriptions from the docstring of a function.
|
||||||
|
|
||||||
@ -125,7 +125,7 @@ def _parse_python_function_docstring(
|
|||||||
|
|
||||||
|
|
||||||
def _validate_docstring_args_against_annotations(
|
def _validate_docstring_args_against_annotations(
|
||||||
arg_descriptions: dict, annotations: dict
|
arg_descriptions: dict, annotations: dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Raise error if docstring arg is not in type annotations."""
|
"""Raise error if docstring arg is not in type annotations."""
|
||||||
for docstring_arg in arg_descriptions:
|
for docstring_arg in arg_descriptions:
|
||||||
@ -135,10 +135,10 @@ def _validate_docstring_args_against_annotations(
|
|||||||
|
|
||||||
|
|
||||||
def _infer_arg_descriptions(
|
def _infer_arg_descriptions(
|
||||||
fn: Callable,
|
fn: Callable,
|
||||||
*,
|
*,
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
error_on_invalid_docstring: bool = False,
|
error_on_invalid_docstring: bool = False,
|
||||||
) -> tuple[str, dict]:
|
) -> tuple[str, dict]:
|
||||||
"""Infer argument descriptions from a function's docstring."""
|
"""Infer argument descriptions from a function's docstring."""
|
||||||
if hasattr(inspect, "get_annotations"):
|
if hasattr(inspect, "get_annotations"):
|
||||||
@ -173,7 +173,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
|
|||||||
|
|
||||||
|
|
||||||
def _function_annotations_are_pydantic_v1(
|
def _function_annotations_are_pydantic_v1(
|
||||||
signature: inspect.Signature, func: Callable
|
signature: inspect.Signature, func: Callable
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Determine if all Pydantic annotations in a function signature are from V1."""
|
"""Determine if all Pydantic annotations in a function signature are from V1."""
|
||||||
any_v1_annotations = any(
|
any_v1_annotations = any(
|
||||||
@ -210,13 +210,13 @@ class _SchemaConfig:
|
|||||||
|
|
||||||
|
|
||||||
def create_schema_from_function(
|
def create_schema_from_function(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
func: Callable,
|
func: Callable,
|
||||||
*,
|
*,
|
||||||
filter_args: Optional[Sequence[str]] = None,
|
filter_args: Optional[Sequence[str]] = None,
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
error_on_invalid_docstring: bool = False,
|
error_on_invalid_docstring: bool = False,
|
||||||
include_injected: bool = True,
|
include_injected: bool = True,
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create a pydantic schema from a function's signature.
|
"""Create a pydantic schema from a function's signature.
|
||||||
|
|
||||||
@ -277,7 +277,7 @@ def create_schema_from_function(
|
|||||||
|
|
||||||
for existing_param in existing_params:
|
for existing_param in existing_params:
|
||||||
if not include_injected and _is_injected_arg_type(
|
if not include_injected and _is_injected_arg_type(
|
||||||
sig.parameters[existing_param].annotation
|
sig.parameters[existing_param].annotation
|
||||||
):
|
):
|
||||||
filter_args_.append(existing_param)
|
filter_args_.append(existing_param)
|
||||||
|
|
||||||
@ -427,10 +427,10 @@ class ChildTool(BaseTool):
|
|||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
"""Initialize the tool."""
|
"""Initialize the tool."""
|
||||||
if (
|
if (
|
||||||
"args_schema" in kwargs
|
"args_schema" in kwargs
|
||||||
and kwargs["args_schema"] is not None
|
and kwargs["args_schema"] is not None
|
||||||
and not is_basemodel_subclass(kwargs["args_schema"])
|
and not is_basemodel_subclass(kwargs["args_schema"])
|
||||||
and not isinstance(kwargs["args_schema"], dict)
|
and not isinstance(kwargs["args_schema"], dict)
|
||||||
):
|
):
|
||||||
msg = (
|
msg = (
|
||||||
"args_schema must be a subclass of pydantic BaseModel or "
|
"args_schema must be a subclass of pydantic BaseModel or "
|
||||||
@ -481,7 +481,7 @@ class ChildTool(BaseTool):
|
|||||||
# --- Runnable ---
|
# --- Runnable ---
|
||||||
|
|
||||||
def get_input_schema(
|
def get_input_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""The tool's input schema.
|
"""The tool's input schema.
|
||||||
|
|
||||||
@ -499,19 +499,19 @@ class ChildTool(BaseTool):
|
|||||||
return create_schema_from_function(self.name, self._run)
|
return create_schema_from_function(self.name, self._run)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, dict, ToolCall],
|
input: Union[str, dict, ToolCall],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||||
return self.run(tool_input, **kwargs)
|
return self.run(tool_input, **kwargs)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, dict, ToolCall],
|
input: Union[str, dict, ToolCall],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||||
return await self.arun(tool_input, **kwargs)
|
return await self.arun(tool_input, **kwargs)
|
||||||
@ -519,7 +519,7 @@ class ChildTool(BaseTool):
|
|||||||
# --- Tool ---
|
# --- Tool ---
|
||||||
|
|
||||||
def _parse_input(
|
def _parse_input(
|
||||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||||
) -> Union[str, dict[str, Any]]:
|
) -> Union[str, dict[str, Any]]:
|
||||||
"""Convert tool input to a pydantic model.
|
"""Convert tool input to a pydantic model.
|
||||||
|
|
||||||
@ -548,8 +548,8 @@ class ChildTool(BaseTool):
|
|||||||
elif issubclass(input_args, BaseModel):
|
elif issubclass(input_args, BaseModel):
|
||||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||||
if (
|
if (
|
||||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||||
and k not in tool_input
|
and k not in tool_input
|
||||||
):
|
):
|
||||||
if tool_call_id is None:
|
if tool_call_id is None:
|
||||||
msg = (
|
msg = (
|
||||||
@ -566,8 +566,8 @@ class ChildTool(BaseTool):
|
|||||||
elif issubclass(input_args, BaseModelV1):
|
elif issubclass(input_args, BaseModelV1):
|
||||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||||
if (
|
if (
|
||||||
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
||||||
and k not in tool_input
|
and k not in tool_input
|
||||||
):
|
):
|
||||||
if tool_call_id is None:
|
if tool_call_id is None:
|
||||||
msg = (
|
msg = (
|
||||||
@ -629,19 +629,19 @@ class ChildTool(BaseTool):
|
|||||||
to child implementations to enable tracing.
|
to child implementations to enable tracing.
|
||||||
"""
|
"""
|
||||||
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
||||||
"run_manager"
|
"run_manager"
|
||||||
):
|
):
|
||||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||||
|
|
||||||
def _to_args_and_kwargs(
|
def _to_args_and_kwargs(
|
||||||
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
||||||
) -> tuple[tuple, dict]:
|
) -> tuple[tuple, dict]:
|
||||||
if (
|
if (
|
||||||
self.args_schema is not None
|
self.args_schema is not None
|
||||||
and isinstance(self.args_schema, type)
|
and isinstance(self.args_schema, type)
|
||||||
and is_basemodel_subclass(self.args_schema)
|
and is_basemodel_subclass(self.args_schema)
|
||||||
and not get_fields(self.args_schema)
|
and not get_fields(self.args_schema)
|
||||||
):
|
):
|
||||||
# StructuredTool with no args
|
# StructuredTool with no args
|
||||||
return (), {}
|
return (), {}
|
||||||
@ -654,20 +654,20 @@ class ChildTool(BaseTool):
|
|||||||
return (), tool_input
|
return (), tool_input
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
tool_input: Union[str, dict[str, Any]],
|
tool_input: Union[str, dict[str, Any]],
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
start_color: Optional[str] = "green",
|
start_color: Optional[str] = "green",
|
||||||
color: Optional[str] = "green",
|
color: Optional[str] = "green",
|
||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
*,
|
*,
|
||||||
tags: Optional[list[str]] = None,
|
tags: Optional[list[str]] = None,
|
||||||
metadata: Optional[dict[str, Any]] = None,
|
metadata: Optional[dict[str, Any]] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
run_id: Optional[uuid.UUID] = None,
|
run_id: Optional[uuid.UUID] = None,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
tool_call_id: Optional[str] = None,
|
tool_call_id: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run the tool.
|
"""Run the tool.
|
||||||
|
|
||||||
@ -766,20 +766,20 @@ class ChildTool(BaseTool):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
async def arun(
|
async def arun(
|
||||||
self,
|
self,
|
||||||
tool_input: Union[str, dict],
|
tool_input: Union[str, dict],
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
start_color: Optional[str] = "green",
|
start_color: Optional[str] = "green",
|
||||||
color: Optional[str] = "green",
|
color: Optional[str] = "green",
|
||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
*,
|
*,
|
||||||
tags: Optional[list[str]] = None,
|
tags: Optional[list[str]] = None,
|
||||||
metadata: Optional[dict[str, Any]] = None,
|
metadata: Optional[dict[str, Any]] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
run_id: Optional[uuid.UUID] = None,
|
run_id: Optional[uuid.UUID] = None,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
tool_call_id: Optional[str] = None,
|
tool_call_id: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run the tool asynchronously.
|
"""Run the tool asynchronously.
|
||||||
|
|
||||||
@ -837,8 +837,6 @@ class ChildTool(BaseTool):
|
|||||||
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
||||||
)
|
)
|
||||||
if signature(func_to_check).parameters.get("run_manager"):
|
if signature(func_to_check).parameters.get("run_manager"):
|
||||||
import copy
|
|
||||||
tool_kwargs = copy.deepcopy(tool_kwargs)
|
|
||||||
tool_kwargs["run_manager"] = run_manager
|
tool_kwargs["run_manager"] = run_manager
|
||||||
if config_param := _get_runnable_config_param(func_to_check):
|
if config_param := _get_runnable_config_param(func_to_check):
|
||||||
tool_kwargs[config_param] = config
|
tool_kwargs[config_param] = config
|
||||||
@ -895,11 +893,11 @@ def _is_tool_call(x: Any) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _handle_validation_error(
|
def _handle_validation_error(
|
||||||
e: Union[ValidationError, ValidationErrorV1],
|
e: Union[ValidationError, ValidationErrorV1],
|
||||||
*,
|
*,
|
||||||
flag: Union[
|
flag: Union[
|
||||||
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
||||||
],
|
],
|
||||||
) -> str:
|
) -> str:
|
||||||
if isinstance(flag, bool):
|
if isinstance(flag, bool):
|
||||||
content = "Tool input validation error"
|
content = "Tool input validation error"
|
||||||
@ -917,9 +915,9 @@ def _handle_validation_error(
|
|||||||
|
|
||||||
|
|
||||||
def _handle_tool_error(
|
def _handle_tool_error(
|
||||||
e: ToolException,
|
e: ToolException,
|
||||||
*,
|
*,
|
||||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||||
) -> str:
|
) -> str:
|
||||||
if isinstance(flag, bool):
|
if isinstance(flag, bool):
|
||||||
content = e.args[0] if e.args else "Tool execution error"
|
content = e.args[0] if e.args else "Tool execution error"
|
||||||
@ -937,9 +935,9 @@ def _handle_tool_error(
|
|||||||
|
|
||||||
|
|
||||||
def _prep_run_args(
|
def _prep_run_args(
|
||||||
input: Union[str, dict, ToolCall],
|
input: Union[str, dict, ToolCall],
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> tuple[Union[str, dict], dict]:
|
) -> tuple[Union[str, dict], dict]:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
if _is_tool_call(input):
|
if _is_tool_call(input):
|
||||||
@ -947,7 +945,7 @@ def _prep_run_args(
|
|||||||
tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy()
|
tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy()
|
||||||
else:
|
else:
|
||||||
tool_call_id = None
|
tool_call_id = None
|
||||||
tool_input = cast(Union[str, dict], input)
|
tool_input = cast(Union[str, dict], input).copy()
|
||||||
return (
|
return (
|
||||||
tool_input,
|
tool_input,
|
||||||
dict(
|
dict(
|
||||||
@ -964,11 +962,11 @@ def _prep_run_args(
|
|||||||
|
|
||||||
|
|
||||||
def _format_output(
|
def _format_output(
|
||||||
content: Any,
|
content: Any,
|
||||||
artifact: Any,
|
artifact: Any,
|
||||||
tool_call_id: Optional[str],
|
tool_call_id: Optional[str],
|
||||||
name: str,
|
name: str,
|
||||||
status: str,
|
status: str,
|
||||||
) -> Union[ToolOutputMixin, Any]:
|
) -> Union[ToolOutputMixin, Any]:
|
||||||
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
|
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
|
||||||
return content
|
return content
|
||||||
@ -986,9 +984,9 @@ def _format_output(
|
|||||||
def _is_message_content_type(obj: Any) -> bool:
|
def _is_message_content_type(obj: Any) -> bool:
|
||||||
"""Check for OpenAI or Anthropic format tool message content."""
|
"""Check for OpenAI or Anthropic format tool message content."""
|
||||||
return (
|
return (
|
||||||
isinstance(obj, str)
|
isinstance(obj, str)
|
||||||
or isinstance(obj, list)
|
or isinstance(obj, list)
|
||||||
and all(_is_message_content_block(e) for e in obj)
|
and all(_is_message_content_block(e) for e in obj)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1051,7 +1049,7 @@ class InjectedToolCallId(InjectedToolArg):
|
|||||||
|
|
||||||
|
|
||||||
def _is_injected_arg_type(
|
def _is_injected_arg_type(
|
||||||
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
|
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
injected_type = injected_type or InjectedToolArg
|
injected_type = injected_type or InjectedToolArg
|
||||||
return any(
|
return any(
|
||||||
@ -1062,7 +1060,7 @@ def _is_injected_arg_type(
|
|||||||
|
|
||||||
|
|
||||||
def get_all_basemodel_annotations(
|
def get_all_basemodel_annotations(
|
||||||
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
||||||
) -> dict[str, type]:
|
) -> dict[str, type]:
|
||||||
# cls has no subscript: cls = FooBar
|
# cls has no subscript: cls = FooBar
|
||||||
if isinstance(cls, type):
|
if isinstance(cls, type):
|
||||||
@ -1071,8 +1069,8 @@ def get_all_basemodel_annotations(
|
|||||||
# Exclude hidden init args added by pydantic Config. For example if
|
# Exclude hidden init args added by pydantic Config. For example if
|
||||||
# BaseModel(extra="allow") then "extra_data" will part of init sig.
|
# BaseModel(extra="allow") then "extra_data" will part of init sig.
|
||||||
if (
|
if (
|
||||||
fields := getattr(cls, "model_fields", {}) # pydantic v2+
|
fields := getattr(cls, "model_fields", {}) # pydantic v2+
|
||||||
or getattr(cls, "__fields__", {}) # pydantic v1
|
or getattr(cls, "__fields__", {}) # pydantic v1
|
||||||
) and name not in fields:
|
) and name not in fields:
|
||||||
continue
|
continue
|
||||||
annotations[name] = param.annotation
|
annotations[name] = param.annotation
|
||||||
@ -1122,9 +1120,9 @@ def get_all_basemodel_annotations(
|
|||||||
|
|
||||||
|
|
||||||
def _replace_type_vars(
|
def _replace_type_vars(
|
||||||
type_: type,
|
type_: type,
|
||||||
generic_map: Optional[dict[TypeVar, type]] = None,
|
generic_map: Optional[dict[TypeVar, type]] = None,
|
||||||
default_to_bound: bool = True,
|
default_to_bound: bool = True,
|
||||||
) -> type:
|
) -> type:
|
||||||
generic_map = generic_map or {}
|
generic_map = generic_map or {}
|
||||||
if isinstance(type_, TypeVar):
|
if isinstance(type_, TypeVar):
|
||||||
|
Loading…
Reference in New Issue
Block a user