diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index a8669415e87..b73511ec7d2 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -91,11 +91,11 @@ def _get_annotation_description(arg_type: type) -> str | None: def _get_filtered_args( - inferred_model: type[BaseModel], - func: Callable, - *, - filter_args: Sequence[str], - include_injected: bool = True, + inferred_model: type[BaseModel], + func: Callable, + *, + filter_args: Sequence[str], + include_injected: bool = True, ) -> dict: """Get the arguments from a function's signature.""" schema = inferred_model.model_json_schema()["properties"] @@ -104,13 +104,13 @@ def _get_filtered_args( k: schema[k] for i, (k, param) in enumerate(valid_keys.items()) if k not in filter_args - and (i > 0 or param.name not in ("self", "cls")) - and (include_injected or not _is_injected_arg_type(param.annotation)) + and (i > 0 or param.name not in ("self", "cls")) + and (include_injected or not _is_injected_arg_type(param.annotation)) } def _parse_python_function_docstring( - function: Callable, annotations: dict, error_on_invalid_docstring: bool = False + function: Callable, annotations: dict, error_on_invalid_docstring: bool = False ) -> tuple[str, dict]: """Parse the function and argument descriptions from the docstring of a function. @@ -125,7 +125,7 @@ def _parse_python_function_docstring( def _validate_docstring_args_against_annotations( - arg_descriptions: dict, annotations: dict + arg_descriptions: dict, annotations: dict ) -> None: """Raise error if docstring arg is not in type annotations.""" for docstring_arg in arg_descriptions: @@ -135,10 +135,10 @@ def _validate_docstring_args_against_annotations( def _infer_arg_descriptions( - fn: Callable, - *, - parse_docstring: bool = False, - error_on_invalid_docstring: bool = False, + fn: Callable, + *, + parse_docstring: bool = False, + error_on_invalid_docstring: bool = False, ) -> tuple[str, dict]: """Infer argument descriptions from a function's docstring.""" if hasattr(inspect, "get_annotations"): @@ -173,7 +173,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo def _function_annotations_are_pydantic_v1( - signature: inspect.Signature, func: Callable + signature: inspect.Signature, func: Callable ) -> bool: """Determine if all Pydantic annotations in a function signature are from V1.""" any_v1_annotations = any( @@ -210,13 +210,13 @@ class _SchemaConfig: def create_schema_from_function( - model_name: str, - func: Callable, - *, - filter_args: Optional[Sequence[str]] = None, - parse_docstring: bool = False, - error_on_invalid_docstring: bool = False, - include_injected: bool = True, + model_name: str, + func: Callable, + *, + filter_args: Optional[Sequence[str]] = None, + parse_docstring: bool = False, + error_on_invalid_docstring: bool = False, + include_injected: bool = True, ) -> type[BaseModel]: """Create a pydantic schema from a function's signature. @@ -277,7 +277,7 @@ def create_schema_from_function( for existing_param in existing_params: if not include_injected and _is_injected_arg_type( - sig.parameters[existing_param].annotation + sig.parameters[existing_param].annotation ): filter_args_.append(existing_param) @@ -427,10 +427,10 @@ class ChildTool(BaseTool): def __init__(self, **kwargs: Any) -> None: """Initialize the tool.""" if ( - "args_schema" in kwargs - and kwargs["args_schema"] is not None - and not is_basemodel_subclass(kwargs["args_schema"]) - and not isinstance(kwargs["args_schema"], dict) + "args_schema" in kwargs + and kwargs["args_schema"] is not None + and not is_basemodel_subclass(kwargs["args_schema"]) + and not isinstance(kwargs["args_schema"], dict) ): msg = ( "args_schema must be a subclass of pydantic BaseModel or " @@ -481,7 +481,7 @@ class ChildTool(BaseTool): # --- Runnable --- def get_input_schema( - self, config: Optional[RunnableConfig] = None + self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: """The tool's input schema. @@ -499,19 +499,19 @@ class ChildTool(BaseTool): return create_schema_from_function(self.name, self._run) def invoke( - self, - input: Union[str, dict, ToolCall], - config: Optional[RunnableConfig] = None, - **kwargs: Any, + self, + input: Union[str, dict, ToolCall], + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> Any: tool_input, kwargs = _prep_run_args(input, config, **kwargs) return self.run(tool_input, **kwargs) async def ainvoke( - self, - input: Union[str, dict, ToolCall], - config: Optional[RunnableConfig] = None, - **kwargs: Any, + self, + input: Union[str, dict, ToolCall], + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> Any: tool_input, kwargs = _prep_run_args(input, config, **kwargs) return await self.arun(tool_input, **kwargs) @@ -519,7 +519,7 @@ class ChildTool(BaseTool): # --- Tool --- def _parse_input( - self, tool_input: Union[str, dict], tool_call_id: Optional[str] + self, tool_input: Union[str, dict], tool_call_id: Optional[str] ) -> Union[str, dict[str, Any]]: """Convert tool input to a pydantic model. @@ -548,8 +548,8 @@ class ChildTool(BaseTool): elif issubclass(input_args, BaseModel): for k, v in get_all_basemodel_annotations(input_args).items(): if ( - _is_injected_arg_type(v, injected_type=InjectedToolCallId) - and k not in tool_input + _is_injected_arg_type(v, injected_type=InjectedToolCallId) + and k not in tool_input ): if tool_call_id is None: msg = ( @@ -566,8 +566,8 @@ class ChildTool(BaseTool): elif issubclass(input_args, BaseModelV1): for k, v in get_all_basemodel_annotations(input_args).items(): if ( - _is_injected_arg_type(v, injected_type=InjectedToolCallId) - and k not in tool_input + _is_injected_arg_type(v, injected_type=InjectedToolCallId) + and k not in tool_input ): if tool_call_id is None: msg = ( @@ -629,19 +629,19 @@ class ChildTool(BaseTool): to child implementations to enable tracing. """ if kwargs.get("run_manager") and signature(self._run).parameters.get( - "run_manager" + "run_manager" ): kwargs["run_manager"] = kwargs["run_manager"].get_sync() return await run_in_executor(None, self._run, *args, **kwargs) def _to_args_and_kwargs( - self, tool_input: Union[str, dict], tool_call_id: Optional[str] + self, tool_input: Union[str, dict], tool_call_id: Optional[str] ) -> tuple[tuple, dict]: if ( - self.args_schema is not None - and isinstance(self.args_schema, type) - and is_basemodel_subclass(self.args_schema) - and not get_fields(self.args_schema) + self.args_schema is not None + and isinstance(self.args_schema, type) + and is_basemodel_subclass(self.args_schema) + and not get_fields(self.args_schema) ): # StructuredTool with no args return (), {} @@ -654,20 +654,20 @@ class ChildTool(BaseTool): return (), tool_input def run( - self, - tool_input: Union[str, dict[str, Any]], - verbose: Optional[bool] = None, - start_color: Optional[str] = "green", - color: Optional[str] = "green", - callbacks: Callbacks = None, - *, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - run_name: Optional[str] = None, - run_id: Optional[uuid.UUID] = None, - config: Optional[RunnableConfig] = None, - tool_call_id: Optional[str] = None, - **kwargs: Any, + self, + tool_input: Union[str, dict[str, Any]], + verbose: Optional[bool] = None, + start_color: Optional[str] = "green", + color: Optional[str] = "green", + callbacks: Callbacks = None, + *, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + run_name: Optional[str] = None, + run_id: Optional[uuid.UUID] = None, + config: Optional[RunnableConfig] = None, + tool_call_id: Optional[str] = None, + **kwargs: Any, ) -> Any: """Run the tool. @@ -766,20 +766,20 @@ class ChildTool(BaseTool): return output async def arun( - self, - tool_input: Union[str, dict], - verbose: Optional[bool] = None, - start_color: Optional[str] = "green", - color: Optional[str] = "green", - callbacks: Callbacks = None, - *, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - run_name: Optional[str] = None, - run_id: Optional[uuid.UUID] = None, - config: Optional[RunnableConfig] = None, - tool_call_id: Optional[str] = None, - **kwargs: Any, + self, + tool_input: Union[str, dict], + verbose: Optional[bool] = None, + start_color: Optional[str] = "green", + color: Optional[str] = "green", + callbacks: Callbacks = None, + *, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + run_name: Optional[str] = None, + run_id: Optional[uuid.UUID] = None, + config: Optional[RunnableConfig] = None, + tool_call_id: Optional[str] = None, + **kwargs: Any, ) -> Any: """Run the tool asynchronously. @@ -837,8 +837,6 @@ class ChildTool(BaseTool): self._run if self.__class__._arun is BaseTool._arun else self._arun ) if signature(func_to_check).parameters.get("run_manager"): - import copy - tool_kwargs = copy.deepcopy(tool_kwargs) tool_kwargs["run_manager"] = run_manager if config_param := _get_runnable_config_param(func_to_check): tool_kwargs[config_param] = config @@ -895,11 +893,11 @@ def _is_tool_call(x: Any) -> bool: def _handle_validation_error( - e: Union[ValidationError, ValidationErrorV1], - *, - flag: Union[ - Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str] - ], + e: Union[ValidationError, ValidationErrorV1], + *, + flag: Union[ + Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str] + ], ) -> str: if isinstance(flag, bool): content = "Tool input validation error" @@ -917,9 +915,9 @@ def _handle_validation_error( def _handle_tool_error( - e: ToolException, - *, - flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], + e: ToolException, + *, + flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], ) -> str: if isinstance(flag, bool): content = e.args[0] if e.args else "Tool execution error" @@ -937,9 +935,9 @@ def _handle_tool_error( def _prep_run_args( - input: Union[str, dict, ToolCall], - config: Optional[RunnableConfig], - **kwargs: Any, + input: Union[str, dict, ToolCall], + config: Optional[RunnableConfig], + **kwargs: Any, ) -> tuple[Union[str, dict], dict]: config = ensure_config(config) if _is_tool_call(input): @@ -947,7 +945,7 @@ def _prep_run_args( tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy() else: tool_call_id = None - tool_input = cast(Union[str, dict], input) + tool_input = cast(Union[str, dict], input).copy() return ( tool_input, dict( @@ -964,11 +962,11 @@ def _prep_run_args( def _format_output( - content: Any, - artifact: Any, - tool_call_id: Optional[str], - name: str, - status: str, + content: Any, + artifact: Any, + tool_call_id: Optional[str], + name: str, + status: str, ) -> Union[ToolOutputMixin, Any]: if isinstance(content, ToolOutputMixin) or tool_call_id is None: return content @@ -986,9 +984,9 @@ def _format_output( def _is_message_content_type(obj: Any) -> bool: """Check for OpenAI or Anthropic format tool message content.""" return ( - isinstance(obj, str) - or isinstance(obj, list) - and all(_is_message_content_block(e) for e in obj) + isinstance(obj, str) + or isinstance(obj, list) + and all(_is_message_content_block(e) for e in obj) ) @@ -1051,7 +1049,7 @@ class InjectedToolCallId(InjectedToolArg): def _is_injected_arg_type( - type_: type, injected_type: Optional[type[InjectedToolArg]] = None + type_: type, injected_type: Optional[type[InjectedToolArg]] = None ) -> bool: injected_type = injected_type or InjectedToolArg return any( @@ -1062,7 +1060,7 @@ def _is_injected_arg_type( def get_all_basemodel_annotations( - cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True + cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True ) -> dict[str, type]: # cls has no subscript: cls = FooBar if isinstance(cls, type): @@ -1071,8 +1069,8 @@ def get_all_basemodel_annotations( # Exclude hidden init args added by pydantic Config. For example if # BaseModel(extra="allow") then "extra_data" will part of init sig. if ( - fields := getattr(cls, "model_fields", {}) # pydantic v2+ - or getattr(cls, "__fields__", {}) # pydantic v1 + fields := getattr(cls, "model_fields", {}) # pydantic v2+ + or getattr(cls, "__fields__", {}) # pydantic v1 ) and name not in fields: continue annotations[name] = param.annotation @@ -1122,9 +1120,9 @@ def get_all_basemodel_annotations( def _replace_type_vars( - type_: type, - generic_map: Optional[dict[TypeVar, type]] = None, - default_to_bound: bool = True, + type_: type, + generic_map: Optional[dict[TypeVar, type]] = None, + default_to_bound: bool = True, ) -> type: generic_map = generic_map or {} if isinstance(type_, TypeVar): @@ -1148,4 +1146,4 @@ class BaseToolkit(BaseModel, ABC): @abstractmethod def get_tools(self) -> list[BaseTool]: - """Get the tools in the toolkit.""" + """Get the tools in the toolkit.""" \ No newline at end of file