core: Add ruff rule FBT003 (boolean-trap) (#29424)

See
https://docs.astral.sh/ruff/rules/boolean-positional-value-in-call/#boolean-positional-value-in-call-fbt003
This PR also fixes some FBT001/FBT002 violations in private methods but does not
enforce those rules globally at the moment.

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet 2025-04-01 19:40:12 +02:00 committed by GitHub
parent 4f8ea13cea
commit 558191198f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 63 additions and 45 deletions

View File

@ -44,10 +44,11 @@ T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])
def _validate_deprecation_params(
pending: bool,
removal: str,
alternative: str,
alternative_import: str,
*,
pending: bool,
) -> None:
"""Validate the deprecation parameters."""
if pending and removal:
@ -134,7 +135,9 @@ def deprecated(
def the_function_to_deprecate():
pass
"""
_validate_deprecation_params(pending, removal, alternative, alternative_import)
_validate_deprecation_params(
removal, alternative, alternative_import, pending=pending
)
def deprecate(
obj: T,

View File

@ -575,7 +575,7 @@ class ParentRunManager(RunManager):
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
manager.add_tags([tag], inherit=False)
return manager
@ -656,7 +656,7 @@ class AsyncParentRunManager(AsyncRunManager):
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
manager.add_tags([tag], inherit=False)
return manager
@ -1578,11 +1578,11 @@ class CallbackManager(BaseCallbackManager):
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
inheritable_metadata,
local_metadata,
verbose=verbose,
)
@ -2102,11 +2102,11 @@ class AsyncCallbackManager(BaseCallbackManager):
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
inheritable_metadata,
local_metadata,
verbose=verbose,
)
@ -2251,11 +2251,12 @@ def _configure(
callback_manager_cls: type[T],
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
inheritable_tags: Optional[list[str]] = None,
local_tags: Optional[list[str]] = None,
inheritable_metadata: Optional[dict[str, Any]] = None,
local_metadata: Optional[dict[str, Any]] = None,
*,
verbose: bool = False,
) -> T:
"""Configure the callback manager.
@ -2329,13 +2330,13 @@ def _configure(
else (local_callbacks.handlers if local_callbacks else [])
)
for handler in local_handlers_:
callback_manager.add_handler(handler, False)
callback_manager.add_handler(handler, inherit=False)
if inheritable_tags or local_tags:
callback_manager.add_tags(inheritable_tags or [])
callback_manager.add_tags(local_tags or [], False)
callback_manager.add_tags(local_tags or [], inherit=False)
if inheritable_metadata or local_metadata:
callback_manager.add_metadata(inheritable_metadata or {})
callback_manager.add_metadata(local_metadata or {}, False)
callback_manager.add_metadata(local_metadata or {}, inherit=False)
if tracing_metadata:
callback_manager.add_metadata(tracing_metadata.copy())
if tracing_tags:
@ -2370,18 +2371,18 @@ def _configure(
if debug:
pass
else:
callback_manager.add_handler(StdOutCallbackHandler(), False)
callback_manager.add_handler(StdOutCallbackHandler(), inherit=False)
if debug and not any(
isinstance(handler, ConsoleCallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(ConsoleCallbackHandler(), True)
callback_manager.add_handler(ConsoleCallbackHandler())
if tracing_v2_enabled_ and not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
if tracer_v2:
callback_manager.add_handler(tracer_v2, True)
callback_manager.add_handler(tracer_v2)
else:
try:
handler = LangChainTracer(
@ -2393,7 +2394,7 @@ def _configure(
),
tags=tracing_tags,
)
callback_manager.add_handler(handler, True)
callback_manager.add_handler(handler)
except Exception as e:
logger.warning(
"Unable to load requested LangChainTracer."

View File

@ -135,7 +135,7 @@ def get_usage_metadata_callback(
usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
ContextVar(name, default=None)
)
register_configure_hook(usage_metadata_callback_var, True)
register_configure_hook(usage_metadata_callback_var, inheritable=True)
cb = UsageMetadataCallbackHandler()
usage_metadata_callback_var.set(cb)
yield cb

View File

@ -786,6 +786,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
prompts: list[str],
stop: Optional[list[str]],
run_managers: list[CallbackManagerForLLMRun],
*,
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
@ -973,7 +974,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
]
output = self._generate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
prompts,
stop,
run_managers,
new_arg_supported=bool(new_arg_supported),
**kwargs,
)
return output
if len(missing_prompts) > 0:
@ -989,7 +994,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
for idx in missing_prompt_idxs
]
new_results = self._generate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
missing_prompts,
stop,
run_managers,
new_arg_supported=bool(new_arg_supported),
**kwargs,
)
llm_output = update_cache(
self.cache,
@ -1031,6 +1040,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
prompts: list[str],
stop: Optional[list[str]],
run_managers: list[AsyncCallbackManagerForLLMRun],
*,
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
@ -1226,7 +1236,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
prompts,
stop,
run_managers, # type: ignore[arg-type]
bool(new_arg_supported),
new_arg_supported=bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
return output
@ -1249,7 +1259,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
missing_prompts,
stop,
run_managers, # type: ignore[arg-type]
bool(new_arg_supported),
new_arg_supported=bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
llm_output = await aupdate_cache(

View File

@ -113,7 +113,7 @@ def _get_filtered_args(
def _parse_python_function_docstring(
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
function: Callable, annotations: dict, *, error_on_invalid_docstring: bool = False
) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
@ -1134,7 +1134,7 @@ def get_all_basemodel_annotations(
generic_map = dict(zip(generic_type_vars, get_args(parent)))
for field in getattr(parent_origin, "__annotations__", {}):
annotations[field] = _replace_type_vars(
annotations[field], generic_map, default_to_bound
annotations[field], generic_map, default_to_bound=default_to_bound
)
return {
@ -1146,6 +1146,7 @@ def get_all_basemodel_annotations(
def _replace_type_vars(
type_: type,
generic_map: Optional[dict[TypeVar, type]] = None,
*,
default_to_bound: bool = True,
) -> type:
generic_map = generic_map or {}
@ -1158,7 +1159,8 @@ def _replace_type_vars(
return type_
elif (origin := get_origin(type_)) and (args := get_args(type_)):
new_args = tuple(
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
_replace_type_vars(arg, generic_map, default_to_bound=default_to_bound)
for arg in args
)
return _py_38_safe_origin(origin)[new_args] # type: ignore[index]
else:

View File

@ -136,7 +136,7 @@ def _get_trace_callbacks(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
callback_manager.add_handler(tracer, True)
callback_manager.add_handler(tracer)
# If it already has a LangChainTracer, we don't need to add another one.
# this would likely mess up the trace hierarchy.
cb = callback_manager
@ -219,4 +219,4 @@ def register_configure_hook(
)
register_configure_hook(run_collector_var, False)
register_configure_hook(run_collector_var, inheritable=False)

View File

@ -66,7 +66,7 @@ def _get_executor() -> ThreadPoolExecutor:
return _EXECUTOR
def _run_to_dict(run: Run, exclude_inputs: bool = False) -> dict:
def _run_to_dict(run: Run, *, exclude_inputs: bool = False) -> dict:
# TODO: Update once langsmith moves to Pydantic V2 and we can swap run.dict for
# run.model_dump
with warnings.catch_warnings():

View File

@ -332,6 +332,7 @@ def _html_escape(string: str) -> str:
def _get_key(
key: str,
scopes: Scopes,
*,
warn: bool,
keep: bool,
def_ldel: str,

View File

@ -97,7 +97,8 @@ ignore = [
"BLE",
"ERA",
"DTZ",
"FBT",
"FBT001",
"FBT002",
"FIX",
"PGH",
"PLC",

View File

@ -41,7 +41,7 @@ async def test_inline_handlers_share_parent_context() -> None:
called.
"""
def __init__(self, run_inline: bool) -> None:
def __init__(self, *, run_inline: bool) -> None:
"""Initialize the handler."""
self.run_inline = run_inline
@ -91,7 +91,7 @@ async def test_inline_handlers_share_parent_context_multiple() -> None:
counter_var.reset(token)
class StatefulAsyncCallbackHandler(AsyncCallbackHandler):
def __init__(self, name: str, run_inline: bool = True):
def __init__(self, name: str, *, run_inline: bool = True):
self.name = name
self.run_inline = run_inline

View File

@ -362,7 +362,7 @@ EXPECTED_STREAMED_JSON = [
]
def _get_iter(use_tool_calls: bool = False) -> Any:
def _get_iter(*, use_tool_calls: bool = False) -> Any:
if use_tool_calls:
list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS
else:
@ -374,7 +374,7 @@ def _get_iter(use_tool_calls: bool = False) -> Any:
return input_iter
def _get_aiter(use_tool_calls: bool = False) -> Any:
def _get_aiter(*, use_tool_calls: bool = False) -> Any:
if use_tool_calls:
list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS
else:
@ -389,7 +389,7 @@ def _get_aiter(use_tool_calls: bool = False) -> Any:
def test_partial_json_output_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputToolsParser()
actual = list(chain.stream(None))
@ -402,7 +402,7 @@ def test_partial_json_output_parser() -> None:
async def test_partial_json_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputToolsParser()
actual = [p async for p in chain.astream(None)]
@ -415,7 +415,7 @@ async def test_partial_json_output_parser_async() -> None:
def test_partial_json_output_parser_return_id() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputToolsParser(return_id=True)
actual = list(chain.stream(None))
@ -434,7 +434,7 @@ def test_partial_json_output_parser_return_id() -> None:
def test_partial_json_output_key_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
actual = list(chain.stream(None))
@ -444,7 +444,7 @@ def test_partial_json_output_key_parser() -> None:
async def test_partial_json_output_parser_key_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
@ -455,7 +455,7 @@ async def test_partial_json_output_parser_key_async() -> None:
def test_partial_json_output_key_parser_first_only() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
@ -466,7 +466,7 @@ def test_partial_json_output_key_parser_first_only() -> None:
async def test_partial_json_output_parser_key_async_first_only() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
@ -507,7 +507,7 @@ EXPECTED_STREAMED_PYDANTIC = [
def test_partial_pydantic_output_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
@ -519,7 +519,7 @@ def test_partial_pydantic_output_parser() -> None:
async def test_partial_pydantic_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True

View File

@ -80,7 +80,7 @@ async def _as_async_iterator(iterable: list) -> AsyncIterator:
async def _collect_events(
events: AsyncIterator[StreamEvent], with_nulled_ids: bool = True
events: AsyncIterator[StreamEvent], *, with_nulled_ids: bool = True
) -> list[StreamEvent]:
"""Collect the events and remove the run ids."""
materialized_events = [event async for event in events]

View File

@ -221,7 +221,7 @@ async def test_config_traceable_async_handoff() -> None:
@pytest.mark.parametrize("enabled", [None, True, False])
@pytest.mark.parametrize("env", ["", "true"])
def test_tracing_enable_disable(
mock_get_client: MagicMock, enabled: bool, env: str
mock_get_client: MagicMock, *, enabled: bool, env: str
) -> None:
mock_session = MagicMock()
mock_client_ = Client(

View File

@ -2002,7 +2002,7 @@ invalid_tool_result_blocks = [
*([[block, False] for block in invalid_tool_result_blocks]),
],
)
def test__is_message_content_block(obj: Any, expected: bool) -> None:
def test__is_message_content_block(obj: Any, *, expected: bool) -> None:
assert _is_message_content_block(obj) is expected
@ -2014,13 +2014,13 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
(invalid_tool_result_blocks, False),
],
)
def test__is_message_content_type(obj: Any, expected: bool) -> None:
def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
assert _is_message_content_type(obj) is expected
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
@pytest.mark.parametrize("use_v1_namespace", [True, False])
def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None:
A = TypeVar("A")
if use_v1_namespace:

View File

@ -741,7 +741,7 @@ def test_tool_outputs() -> None:
@pytest.mark.parametrize("use_extension_typed_dict", [True, False])
@pytest.mark.parametrize("use_extension_annotated", [True, False])
def test__convert_typed_dict_to_openai_function(
use_extension_typed_dict: bool, use_extension_annotated: bool
*, use_extension_typed_dict: bool, use_extension_annotated: bool
) -> None:
typed_dict = ExtensionsTypedDict if use_extension_typed_dict else TypingTypedDict
annotated = TypingAnnotated if use_extension_annotated else TypingAnnotated