mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
core: Add ruff rules FBT001 and FBT002 (#30695)
Add ruff rules [FBT001](https://docs.astral.sh/ruff/rules/boolean-type-hint-positional-argument/) and [FBT002](https://docs.astral.sh/ruff/rules/boolean-default-value-positional-argument/). Mostly `noqa`s to not introduce breaking changes and possible non-breaking fixes have already been done in a [previous PR](https://github.com/langchain-ai/langchain/pull/29424). These rules will prevent new violations to happen.
This commit is contained in:
parent
2803a48661
commit
913c896598
@ -988,7 +988,11 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
"""Whether the callback manager is async."""
|
"""Whether the callback manager is async."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
def add_handler(
|
||||||
|
self,
|
||||||
|
handler: BaseCallbackHandler,
|
||||||
|
inherit: bool = True, # noqa: FBT001,FBT002
|
||||||
|
) -> None:
|
||||||
"""Add a handler to the callback manager.
|
"""Add a handler to the callback manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1012,7 +1016,9 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
self.inheritable_handlers.remove(handler)
|
self.inheritable_handlers.remove(handler)
|
||||||
|
|
||||||
def set_handlers(
|
def set_handlers(
|
||||||
self, handlers: list[BaseCallbackHandler], inherit: bool = True
|
self,
|
||||||
|
handlers: list[BaseCallbackHandler],
|
||||||
|
inherit: bool = True, # noqa: FBT001,FBT002
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set handlers as the only handlers on the callback manager.
|
"""Set handlers as the only handlers on the callback manager.
|
||||||
|
|
||||||
@ -1025,7 +1031,11 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
self.add_handler(handler, inherit=inherit)
|
self.add_handler(handler, inherit=inherit)
|
||||||
|
|
||||||
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
def set_handler(
|
||||||
|
self,
|
||||||
|
handler: BaseCallbackHandler,
|
||||||
|
inherit: bool = True, # noqa: FBT001,FBT002
|
||||||
|
) -> None:
|
||||||
"""Set handler as the only handler on the callback manager.
|
"""Set handler as the only handler on the callback manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1034,7 +1044,11 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
"""
|
"""
|
||||||
self.set_handlers([handler], inherit=inherit)
|
self.set_handlers([handler], inherit=inherit)
|
||||||
|
|
||||||
def add_tags(self, tags: list[str], inherit: bool = True) -> None:
|
def add_tags(
|
||||||
|
self,
|
||||||
|
tags: list[str],
|
||||||
|
inherit: bool = True, # noqa: FBT001,FBT002
|
||||||
|
) -> None:
|
||||||
"""Add tags to the callback manager.
|
"""Add tags to the callback manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1058,7 +1072,11 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
self.tags.remove(tag)
|
self.tags.remove(tag)
|
||||||
self.inheritable_tags.remove(tag)
|
self.inheritable_tags.remove(tag)
|
||||||
|
|
||||||
def add_metadata(self, metadata: dict[str, Any], inherit: bool = True) -> None:
|
def add_metadata(
|
||||||
|
self,
|
||||||
|
metadata: dict[str, Any],
|
||||||
|
inherit: bool = True, # noqa: FBT001,FBT002
|
||||||
|
) -> None:
|
||||||
"""Add metadata to the callback manager.
|
"""Add metadata to the callback manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1547,7 +1547,7 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
cls,
|
cls,
|
||||||
inheritable_callbacks: Callbacks = None,
|
inheritable_callbacks: Callbacks = None,
|
||||||
local_callbacks: Callbacks = None,
|
local_callbacks: Callbacks = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False, # noqa: FBT001,FBT002
|
||||||
inheritable_tags: Optional[list[str]] = None,
|
inheritable_tags: Optional[list[str]] = None,
|
||||||
local_tags: Optional[list[str]] = None,
|
local_tags: Optional[list[str]] = None,
|
||||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||||
@ -2073,7 +2073,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
cls,
|
cls,
|
||||||
inheritable_callbacks: Callbacks = None,
|
inheritable_callbacks: Callbacks = None,
|
||||||
local_callbacks: Callbacks = None,
|
local_callbacks: Callbacks = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False, # noqa: FBT001,FBT002
|
||||||
inheritable_tags: Optional[list[str]] = None,
|
inheritable_tags: Optional[list[str]] = None,
|
||||||
local_tags: Optional[list[str]] = None,
|
local_tags: Optional[list[str]] = None,
|
||||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||||
|
@ -26,7 +26,7 @@ class OutputParserException(ValueError, LangChainException): # noqa: N818
|
|||||||
error: Any,
|
error: Any,
|
||||||
observation: Optional[str] = None,
|
observation: Optional[str] = None,
|
||||||
llm_output: Optional[str] = None,
|
llm_output: Optional[str] = None,
|
||||||
send_to_llm: bool = False,
|
send_to_llm: bool = False, # noqa: FBT001,FBT002
|
||||||
):
|
):
|
||||||
"""Create an OutputParserException.
|
"""Create an OutputParserException.
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ _debug: bool = False
|
|||||||
_llm_cache: Optional["BaseCache"] = None
|
_llm_cache: Optional["BaseCache"] = None
|
||||||
|
|
||||||
|
|
||||||
def set_verbose(value: bool) -> None:
|
def set_verbose(value: bool) -> None: # noqa: FBT001
|
||||||
"""Set a new value for the `verbose` global setting.
|
"""Set a new value for the `verbose` global setting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -89,7 +89,7 @@ def get_verbose() -> bool:
|
|||||||
return _verbose or old_verbose
|
return _verbose or old_verbose
|
||||||
|
|
||||||
|
|
||||||
def set_debug(value: bool) -> None:
|
def set_debug(value: bool) -> None: # noqa: FBT001
|
||||||
"""Set a new value for the `debug` global setting.
|
"""Set a new value for the `debug` global setting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -52,7 +52,7 @@ class Reviver:
|
|||||||
self,
|
self,
|
||||||
secrets_map: Optional[dict[str, str]] = None,
|
secrets_map: Optional[dict[str, str]] = None,
|
||||||
valid_namespaces: Optional[list[str]] = None,
|
valid_namespaces: Optional[list[str]] = None,
|
||||||
secrets_from_env: bool = True,
|
secrets_from_env: bool = True, # noqa: FBT001,FBT002
|
||||||
additional_import_mappings: Optional[
|
additional_import_mappings: Optional[
|
||||||
dict[tuple[str, ...], tuple[str, ...]]
|
dict[tuple[str, ...], tuple[str, ...]]
|
||||||
] = None,
|
] = None,
|
||||||
|
@ -235,6 +235,7 @@ class AIMessage(BaseMessage):
|
|||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@override
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
"""Return a pretty representation of the message.
|
"""Return a pretty representation of the message.
|
||||||
|
|
||||||
|
@ -122,7 +122,10 @@ class BaseMessage(Serializable):
|
|||||||
prompt = ChatPromptTemplate(messages=[self]) # type: ignore[call-arg]
|
prompt = ChatPromptTemplate(messages=[self]) # type: ignore[call-arg]
|
||||||
return prompt + other
|
return prompt + other
|
||||||
|
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(
|
||||||
|
self,
|
||||||
|
html: bool = False, # noqa: FBT001,FBT002
|
||||||
|
) -> str:
|
||||||
"""Get a pretty representation of the message.
|
"""Get a pretty representation of the message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -102,7 +102,10 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
|||||||
List of input variables.
|
List of input variables.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(
|
||||||
|
self,
|
||||||
|
html: bool = False, # noqa: FBT001,FBT002
|
||||||
|
) -> str:
|
||||||
"""Human-readable representation.
|
"""Human-readable representation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -270,6 +273,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
"""
|
"""
|
||||||
return [self.variable_name] if not self.optional else []
|
return [self.variable_name] if not self.optional else []
|
||||||
|
|
||||||
|
@override
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
"""Human-readable representation.
|
"""Human-readable representation.
|
||||||
|
|
||||||
@ -406,6 +410,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
|||||||
"""
|
"""
|
||||||
return self.prompt.input_variables
|
return self.prompt.input_variables
|
||||||
|
|
||||||
|
@override
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
"""Human-readable representation.
|
"""Human-readable representation.
|
||||||
|
|
||||||
@ -675,6 +680,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
content=content, additional_kwargs=self.additional_kwargs
|
content=content, additional_kwargs=self.additional_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
"""Human-readable representation.
|
"""Human-readable representation.
|
||||||
|
|
||||||
@ -777,7 +783,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
"""Async format kwargs into a list of messages."""
|
"""Async format kwargs into a list of messages."""
|
||||||
return self.format_messages(**kwargs)
|
return self.format_messages(**kwargs)
|
||||||
|
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(
|
||||||
|
self,
|
||||||
|
html: bool = False, # noqa: FBT001,FBT002
|
||||||
|
) -> str:
|
||||||
"""Human-readable representation.
|
"""Human-readable representation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1331,6 +1340,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
"""Human-readable representation.
|
"""Human-readable representation.
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.example_selectors import BaseExampleSelector
|
from langchain_core.example_selectors import BaseExampleSelector
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
@ -453,6 +454,7 @@ class FewShotChatMessagePromptTemplate(
|
|||||||
messages = await self.aformat_messages(**kwargs)
|
messages = await self.aformat_messages(**kwargs)
|
||||||
return get_buffer_string(messages)
|
return get_buffer_string(messages)
|
||||||
|
|
||||||
|
@override
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
"""Return a pretty representation of the prompt template.
|
"""Return a pretty representation of the prompt template.
|
||||||
|
|
||||||
|
@ -133,7 +133,10 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
|||||||
"""
|
"""
|
||||||
return await run_in_executor(None, self.format, **kwargs)
|
return await run_in_executor(None, self.format, **kwargs)
|
||||||
|
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(
|
||||||
|
self,
|
||||||
|
html: bool = False, # noqa: FBT001,FBT002
|
||||||
|
) -> str:
|
||||||
"""Return a pretty representation of the prompt.
|
"""Return a pretty representation of the prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -293,7 +293,10 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
"""
|
"""
|
||||||
return StringPromptValue(text=await self.aformat(**kwargs))
|
return StringPromptValue(text=await self.aformat(**kwargs))
|
||||||
|
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(
|
||||||
|
self,
|
||||||
|
html: bool = False, # noqa: FBT001,FBT002
|
||||||
|
) -> str:
|
||||||
"""Get a pretty representation of the prompt.
|
"""Get a pretty representation of the prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -359,7 +359,7 @@ class Graph:
|
|||||||
source: Node,
|
source: Node,
|
||||||
target: Node,
|
target: Node,
|
||||||
data: Optional[Stringifiable] = None,
|
data: Optional[Stringifiable] = None,
|
||||||
conditional: bool = False,
|
conditional: bool = False, # noqa: FBT001,FBT002
|
||||||
) -> Edge:
|
) -> Edge:
|
||||||
"""Add an edge to the graph and return it.
|
"""Add an edge to the graph and return it.
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ class PngDrawer:
|
|||||||
source: str,
|
source: str,
|
||||||
target: str,
|
target: str,
|
||||||
label: Optional[str] = None,
|
label: Optional[str] = None,
|
||||||
conditional: bool = False,
|
conditional: bool = False, # noqa: FBT001,FBT002
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Adds an edge to the graph.
|
"""Adds an edge to the graph.
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ class Tool(BaseTool):
|
|||||||
func: Optional[Callable],
|
func: Optional[Callable],
|
||||||
name: str, # We keep these required to support backwards compatibility
|
name: str, # We keep these required to support backwards compatibility
|
||||||
description: str,
|
description: str,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False, # noqa: FBT001,FBT002
|
||||||
args_schema: Optional[ArgsSchema] = None,
|
args_schema: Optional[ArgsSchema] = None,
|
||||||
coroutine: Optional[
|
coroutine: Optional[
|
||||||
Callable[..., Awaitable[Any]]
|
Callable[..., Awaitable[Any]]
|
||||||
|
@ -122,9 +122,9 @@ class StructuredTool(BaseTool):
|
|||||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
|
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False, # noqa: FBT001,FBT002
|
||||||
args_schema: Optional[ArgsSchema] = None,
|
args_schema: Optional[ArgsSchema] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True, # noqa: FBT001,FBT002
|
||||||
*,
|
*,
|
||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
|
@ -186,7 +186,7 @@ _configure_hooks: list[
|
|||||||
|
|
||||||
def register_configure_hook(
|
def register_configure_hook(
|
||||||
context_var: ContextVar[Optional[Any]],
|
context_var: ContextVar[Optional[Any]],
|
||||||
inheritable: bool,
|
inheritable: bool, # noqa: FBT001
|
||||||
handle_class: Optional[type[BaseCallbackHandler]] = None,
|
handle_class: Optional[type[BaseCallbackHandler]] = None,
|
||||||
env_var: Optional[str] = None,
|
env_var: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -62,7 +62,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
evaluators: Sequence[langsmith.RunEvaluator],
|
evaluators: Sequence[langsmith.RunEvaluator],
|
||||||
client: Optional[langsmith.Client] = None,
|
client: Optional[langsmith.Client] = None,
|
||||||
example_id: Optional[Union[UUID, str]] = None,
|
example_id: Optional[Union[UUID, str]] = None,
|
||||||
skip_unfinished: bool = True,
|
skip_unfinished: bool = True, # noqa: FBT001,FBT002
|
||||||
project_name: Optional[str] = "evaluators",
|
project_name: Optional[str] = "evaluators",
|
||||||
max_concurrency: Optional[int] = None,
|
max_concurrency: Optional[int] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
@ -68,7 +68,7 @@ def grab_literal(template: str, l_del: str) -> tuple[str, str]:
|
|||||||
def l_sa_check(
|
def l_sa_check(
|
||||||
template: str, # noqa: ARG001
|
template: str, # noqa: ARG001
|
||||||
literal: str,
|
literal: str,
|
||||||
is_standalone: bool,
|
is_standalone: bool, # noqa: FBT001
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Do a preliminary check to see if a tag could be a standalone.
|
"""Do a preliminary check to see if a tag could be a standalone.
|
||||||
|
|
||||||
@ -91,7 +91,11 @@ def l_sa_check(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
|
def r_sa_check(
|
||||||
|
template: str,
|
||||||
|
tag_type: str,
|
||||||
|
is_standalone: bool, # noqa: FBT001
|
||||||
|
) -> bool:
|
||||||
"""Do a final check to see if a tag could be a standalone.
|
"""Do a final check to see if a tag could be a standalone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -422,8 +426,8 @@ def render(
|
|||||||
def_ldel: str = "{{",
|
def_ldel: str = "{{",
|
||||||
def_rdel: str = "}}",
|
def_rdel: str = "}}",
|
||||||
scopes: Optional[Scopes] = None,
|
scopes: Optional[Scopes] = None,
|
||||||
warn: bool = False,
|
warn: bool = False, # noqa: FBT001,FBT002
|
||||||
keep: bool = False,
|
keep: bool = False, # noqa: FBT001,FBT002
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Render a mustache template.
|
"""Render a mustache template.
|
||||||
|
|
||||||
|
@ -420,7 +420,7 @@ def _create_root_model(
|
|||||||
|
|
||||||
def schema(
|
def schema(
|
||||||
cls: type[BaseModel],
|
cls: type[BaseModel],
|
||||||
by_alias: bool = True,
|
by_alias: bool = True, # noqa: FBT001,FBT002
|
||||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
# Complains about schema not being defined in superclass
|
# Complains about schema not being defined in superclass
|
||||||
@ -432,7 +432,7 @@ def _create_root_model(
|
|||||||
|
|
||||||
def model_json_schema(
|
def model_json_schema(
|
||||||
cls: type[BaseModel],
|
cls: type[BaseModel],
|
||||||
by_alias: bool = True,
|
by_alias: bool = True, # noqa: FBT001,FBT002
|
||||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||||
mode: JsonSchemaMode = "validation",
|
mode: JsonSchemaMode = "validation",
|
||||||
|
@ -100,8 +100,6 @@ ignore = [
|
|||||||
"ANN401",
|
"ANN401",
|
||||||
"BLE",
|
"BLE",
|
||||||
"ERA",
|
"ERA",
|
||||||
"FBT001",
|
|
||||||
"FBT002",
|
|
||||||
"PLR2004",
|
"PLR2004",
|
||||||
"RUF",
|
"RUF",
|
||||||
"SLF",
|
"SLF",
|
||||||
|
@ -105,7 +105,7 @@ def _remove_additionalproperties_from_untyped_dicts(schema: dict) -> dict[str, A
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _remove_dict_additional_props(
|
def _remove_dict_additional_props(
|
||||||
obj: Union[dict[str, Any], list[Any]], inside_properties: bool = False
|
obj: Union[dict[str, Any], list[Any]], *, inside_properties: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
if (
|
if (
|
||||||
@ -120,11 +120,13 @@ def _remove_additionalproperties_from_untyped_dicts(schema: dict) -> dict[str, A
|
|||||||
# We are "inside_properties" if the *current* key is "properties",
|
# We are "inside_properties" if the *current* key is "properties",
|
||||||
# or if we were already inside properties in the caller.
|
# or if we were already inside properties in the caller.
|
||||||
next_inside_properties = inside_properties or (key == "properties")
|
next_inside_properties = inside_properties or (key == "properties")
|
||||||
_remove_dict_additional_props(value, next_inside_properties)
|
_remove_dict_additional_props(
|
||||||
|
value, inside_properties=next_inside_properties
|
||||||
|
)
|
||||||
|
|
||||||
elif isinstance(obj, list):
|
elif isinstance(obj, list):
|
||||||
for item in obj:
|
for item in obj:
|
||||||
_remove_dict_additional_props(item, inside_properties)
|
_remove_dict_additional_props(item, inside_properties=inside_properties)
|
||||||
|
|
||||||
_remove_dict_additional_props(schema, inside_properties=False)
|
_remove_dict_additional_props(schema, inside_properties=False)
|
||||||
return schema
|
return schema
|
||||||
|
@ -119,10 +119,10 @@ class _MockStructuredTool(BaseTool):
|
|||||||
args_schema: type[BaseModel] = _MockSchema
|
args_schema: type[BaseModel] = _MockSchema
|
||||||
description: str = "A Structured Tool"
|
description: str = "A Structured Tool"
|
||||||
|
|
||||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
def _run(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
async def _arun(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@ -146,11 +146,13 @@ def test_misannotated_base_tool_raises_error() -> None:
|
|||||||
args_schema: BaseModel = _MockSchema # type: ignore[assignment]
|
args_schema: BaseModel = _MockSchema # type: ignore[assignment]
|
||||||
description: str = "A Structured Tool"
|
description: str = "A Structured Tool"
|
||||||
|
|
||||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
def _run(
|
||||||
|
self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
|
) -> str:
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -163,11 +165,11 @@ def test_forward_ref_annotated_base_tool_accepted() -> None:
|
|||||||
args_schema: "type[BaseModel]" = _MockSchema
|
args_schema: "type[BaseModel]" = _MockSchema
|
||||||
description: str = "A Structured Tool"
|
description: str = "A Structured Tool"
|
||||||
|
|
||||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
def _run(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -180,11 +182,11 @@ def test_subclass_annotated_base_tool_accepted() -> None:
|
|||||||
args_schema: type[_MockSchema] = _MockSchema
|
args_schema: type[_MockSchema] = _MockSchema
|
||||||
description: str = "A Structured Tool"
|
description: str = "A Structured Tool"
|
||||||
|
|
||||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
def _run(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -197,14 +199,14 @@ def test_decorator_with_specified_schema() -> None:
|
|||||||
"""Test that manually specified schemata are passed through to the tool."""
|
"""Test that manually specified schemata are passed through to the tool."""
|
||||||
|
|
||||||
@tool(args_schema=_MockSchema)
|
@tool(args_schema=_MockSchema)
|
||||||
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
def tool_func(*, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
assert isinstance(tool_func, BaseTool)
|
assert isinstance(tool_func, BaseTool)
|
||||||
assert tool_func.args_schema == _MockSchema
|
assert tool_func.args_schema == _MockSchema
|
||||||
|
|
||||||
@tool(args_schema=cast("ArgsSchema", _MockSchemaV1))
|
@tool(args_schema=cast("ArgsSchema", _MockSchemaV1))
|
||||||
def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
def tool_func_v1(*, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
assert isinstance(tool_func_v1, BaseTool)
|
assert isinstance(tool_func_v1, BaseTool)
|
||||||
@ -216,7 +218,7 @@ def test_decorated_function_schema_equivalent() -> None:
|
|||||||
|
|
||||||
@tool
|
@tool
|
||||||
def structured_tool_input(
|
def structured_tool_input(
|
||||||
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
*, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return the arguments directly."""
|
"""Return the arguments directly."""
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
@ -1397,14 +1399,17 @@ class _MockStructuredToolWithRawOutput(BaseTool):
|
|||||||
response_format: Literal["content_and_artifact"] = "content_and_artifact"
|
response_format: Literal["content_and_artifact"] = "content_and_artifact"
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
self,
|
||||||
|
arg1: int,
|
||||||
|
arg2: bool, # noqa: FBT001
|
||||||
|
arg3: Optional[dict] = None,
|
||||||
) -> tuple[str, dict]:
|
) -> tuple[str, dict]:
|
||||||
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
||||||
|
|
||||||
|
|
||||||
@tool("structured_api", response_format="content_and_artifact")
|
@tool("structured_api", response_format="content_and_artifact")
|
||||||
def _mock_structured_tool_with_artifact(
|
def _mock_structured_tool_with_artifact(
|
||||||
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
*, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
) -> tuple[str, dict]:
|
) -> tuple[str, dict]:
|
||||||
"""A Structured Tool."""
|
"""A Structured Tool."""
|
||||||
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
||||||
|
Loading…
Reference in New Issue
Block a user