From 558191198f3f4f6001c651e3df583e0dfb79a9c5 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 1 Apr 2025 19:40:12 +0200 Subject: [PATCH] core: Add ruff rule FBT003 (boolean-trap) (#29424) See https://docs.astral.sh/ruff/rules/boolean-positional-value-in-call/#boolean-positional-value-in-call-fbt003 This PR also fixes some FBT001/002 in private methods but does not enforce these rules globally atm. Co-authored-by: Eugene Yurtsev --- libs/core/langchain_core/_api/deprecation.py | 7 ++++-- libs/core/langchain_core/callbacks/manager.py | 25 ++++++++++--------- libs/core/langchain_core/callbacks/usage.py | 2 +- .../langchain_core/language_models/llms.py | 18 ++++++++++--- libs/core/langchain_core/tools/base.py | 8 +++--- libs/core/langchain_core/tracers/context.py | 4 +-- libs/core/langchain_core/tracers/langchain.py | 2 +- libs/core/langchain_core/utils/mustache.py | 1 + libs/core/pyproject.toml | 3 ++- .../callbacks/test_async_callback_manager.py | 4 +-- .../output_parsers/test_openai_tools.py | 22 ++++++++-------- .../runnables/test_runnable_events_v2.py | 2 +- .../runnables/test_tracing_interops.py | 2 +- libs/core/tests/unit_tests/test_tools.py | 6 ++--- .../unit_tests/utils/test_function_calling.py | 2 +- 15 files changed, 63 insertions(+), 45 deletions(-) diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 52ce92075e6..a92dd1457da 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -44,10 +44,11 @@ T = TypeVar("T", bound=Union[type, Callable[..., Any], Any]) def _validate_deprecation_params( - pending: bool, removal: str, alternative: str, alternative_import: str, + *, + pending: bool, ) -> None: """Validate the deprecation parameters.""" if pending and removal: @@ -134,7 +135,9 @@ def deprecated( def the_function_to_deprecate(): pass """ - _validate_deprecation_params(pending, removal, alternative, alternative_import) + 
_validate_deprecation_params( + removal, alternative, alternative_import, pending=pending + ) def deprecate( obj: T, diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index f7268c17e39..b907d71601b 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -575,7 +575,7 @@ class ParentRunManager(RunManager): manager.add_tags(self.inheritable_tags) manager.add_metadata(self.inheritable_metadata) if tag is not None: - manager.add_tags([tag], False) + manager.add_tags([tag], inherit=False) return manager @@ -656,7 +656,7 @@ class AsyncParentRunManager(AsyncRunManager): manager.add_tags(self.inheritable_tags) manager.add_metadata(self.inheritable_metadata) if tag is not None: - manager.add_tags([tag], False) + manager.add_tags([tag], inherit=False) return manager @@ -1578,11 +1578,11 @@ class CallbackManager(BaseCallbackManager): cls, inheritable_callbacks, local_callbacks, - verbose, inheritable_tags, local_tags, inheritable_metadata, local_metadata, + verbose=verbose, ) @@ -2102,11 +2102,11 @@ class AsyncCallbackManager(BaseCallbackManager): cls, inheritable_callbacks, local_callbacks, - verbose, inheritable_tags, local_tags, inheritable_metadata, local_metadata, + verbose=verbose, ) @@ -2251,11 +2251,12 @@ def _configure( callback_manager_cls: type[T], inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, - verbose: bool = False, inheritable_tags: Optional[list[str]] = None, local_tags: Optional[list[str]] = None, inheritable_metadata: Optional[dict[str, Any]] = None, local_metadata: Optional[dict[str, Any]] = None, + *, + verbose: bool = False, ) -> T: """Configure the callback manager. 
@@ -2329,13 +2330,13 @@ def _configure( else (local_callbacks.handlers if local_callbacks else []) ) for handler in local_handlers_: - callback_manager.add_handler(handler, False) + callback_manager.add_handler(handler, inherit=False) if inheritable_tags or local_tags: callback_manager.add_tags(inheritable_tags or []) - callback_manager.add_tags(local_tags or [], False) + callback_manager.add_tags(local_tags or [], inherit=False) if inheritable_metadata or local_metadata: callback_manager.add_metadata(inheritable_metadata or {}) - callback_manager.add_metadata(local_metadata or {}, False) + callback_manager.add_metadata(local_metadata or {}, inherit=False) if tracing_metadata: callback_manager.add_metadata(tracing_metadata.copy()) if tracing_tags: @@ -2370,18 +2371,18 @@ def _configure( if debug: pass else: - callback_manager.add_handler(StdOutCallbackHandler(), False) + callback_manager.add_handler(StdOutCallbackHandler(), inherit=False) if debug and not any( isinstance(handler, ConsoleCallbackHandler) for handler in callback_manager.handlers ): - callback_manager.add_handler(ConsoleCallbackHandler(), True) + callback_manager.add_handler(ConsoleCallbackHandler()) if tracing_v2_enabled_ and not any( isinstance(handler, LangChainTracer) for handler in callback_manager.handlers ): if tracer_v2: - callback_manager.add_handler(tracer_v2, True) + callback_manager.add_handler(tracer_v2) else: try: handler = LangChainTracer( @@ -2393,7 +2394,7 @@ def _configure( ), tags=tracing_tags, ) - callback_manager.add_handler(handler, True) + callback_manager.add_handler(handler) except Exception as e: logger.warning( "Unable to load requested LangChainTracer." 
diff --git a/libs/core/langchain_core/callbacks/usage.py b/libs/core/langchain_core/callbacks/usage.py index 2e7b493f0d0..486940eb98e 100644 --- a/libs/core/langchain_core/callbacks/usage.py +++ b/libs/core/langchain_core/callbacks/usage.py @@ -135,7 +135,7 @@ def get_usage_metadata_callback( usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = ( ContextVar(name, default=None) ) - register_configure_hook(usage_metadata_callback_var, True) + register_configure_hook(usage_metadata_callback_var, inheritable=True) cb = UsageMetadataCallbackHandler() usage_metadata_callback_var.set(cb) yield cb diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index e15d17e831d..5ceb6bc431f 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -786,6 +786,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): prompts: list[str], stop: Optional[list[str]], run_managers: list[CallbackManagerForLLMRun], + *, new_arg_supported: bool, **kwargs: Any, ) -> LLMResult: @@ -973,7 +974,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): ) ] output = self._generate_helper( - prompts, stop, run_managers, bool(new_arg_supported), **kwargs + prompts, + stop, + run_managers, + new_arg_supported=bool(new_arg_supported), + **kwargs, ) return output if len(missing_prompts) > 0: @@ -989,7 +994,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): for idx in missing_prompt_idxs ] new_results = self._generate_helper( - missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs + missing_prompts, + stop, + run_managers, + new_arg_supported=bool(new_arg_supported), + **kwargs, ) llm_output = update_cache( self.cache, @@ -1031,6 +1040,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): prompts: list[str], stop: Optional[list[str]], run_managers: list[AsyncCallbackManagerForLLMRun], + *, new_arg_supported: bool, **kwargs: Any, ) -> LLMResult: @@ -1226,7 
+1236,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): prompts, stop, run_managers, # type: ignore[arg-type] - bool(new_arg_supported), + new_arg_supported=bool(new_arg_supported), **kwargs, # type: ignore[arg-type] ) return output @@ -1249,7 +1259,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): missing_prompts, stop, run_managers, # type: ignore[arg-type] - bool(new_arg_supported), + new_arg_supported=bool(new_arg_supported), **kwargs, # type: ignore[arg-type] ) llm_output = await aupdate_cache( diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index dc9edc45571..bea485ad247 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -113,7 +113,7 @@ def _get_filtered_args( def _parse_python_function_docstring( - function: Callable, annotations: dict, error_on_invalid_docstring: bool = False + function: Callable, annotations: dict, *, error_on_invalid_docstring: bool = False ) -> tuple[str, dict]: """Parse the function and argument descriptions from the docstring of a function. 
@@ -1134,7 +1134,7 @@ def get_all_basemodel_annotations( generic_map = dict(zip(generic_type_vars, get_args(parent))) for field in getattr(parent_origin, "__annotations__", {}): annotations[field] = _replace_type_vars( - annotations[field], generic_map, default_to_bound + annotations[field], generic_map, default_to_bound=default_to_bound ) return { @@ -1146,6 +1146,7 @@ def get_all_basemodel_annotations( def _replace_type_vars( type_: type, generic_map: Optional[dict[TypeVar, type]] = None, + *, default_to_bound: bool = True, ) -> type: generic_map = generic_map or {} @@ -1158,7 +1159,8 @@ def _replace_type_vars( return type_ elif (origin := get_origin(type_)) and (args := get_args(type_)): new_args = tuple( - _replace_type_vars(arg, generic_map, default_to_bound) for arg in args + _replace_type_vars(arg, generic_map, default_to_bound=default_to_bound) + for arg in args ) return _py_38_safe_origin(origin)[new_args] # type: ignore[index] else: diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index 06520c9fe9d..a89a5417ceb 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -136,7 +136,7 @@ def _get_trace_callbacks( isinstance(handler, LangChainTracer) for handler in callback_manager.handlers ): - callback_manager.add_handler(tracer, True) + callback_manager.add_handler(tracer) # If it already has a LangChainTracer, we don't need to add another one. # this would likely mess up the trace hierarchy. 
cb = callback_manager @@ -219,4 +219,4 @@ def register_configure_hook( ) -register_configure_hook(run_collector_var, False) +register_configure_hook(run_collector_var, inheritable=False) diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index 5cc3301fe7e..5914895b292 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -66,7 +66,7 @@ def _get_executor() -> ThreadPoolExecutor: return _EXECUTOR -def _run_to_dict(run: Run, exclude_inputs: bool = False) -> dict: +def _run_to_dict(run: Run, *, exclude_inputs: bool = False) -> dict: # TODO: Update once langsmith moves to Pydantic V2 and we can swap run.dict for # run.model_dump with warnings.catch_warnings(): diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index 3f631efe0fb..c1bdcc86d50 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -332,6 +332,7 @@ def _html_escape(string: str) -> str: def _get_key( key: str, scopes: Scopes, + *, warn: bool, keep: bool, def_ldel: str, diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 5d69524fee5..f0fc237f4d9 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -97,7 +97,8 @@ ignore = [ "BLE", "ERA", "DTZ", - "FBT", + "FBT001", + "FBT002", "FIX", "PGH", "PLC", diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py index 38350f9d82f..ac565e43e39 100644 --- a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py +++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py @@ -41,7 +41,7 @@ async def test_inline_handlers_share_parent_context() -> None: called. 
""" - def __init__(self, run_inline: bool) -> None: + def __init__(self, *, run_inline: bool) -> None: """Initialize the handler.""" self.run_inline = run_inline @@ -91,7 +91,7 @@ async def test_inline_handlers_share_parent_context_multiple() -> None: counter_var.reset(token) class StatefulAsyncCallbackHandler(AsyncCallbackHandler): - def __init__(self, name: str, run_inline: bool = True): + def __init__(self, name: str, *, run_inline: bool = True): self.name = name self.run_inline = run_inline diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index d4940bab2bb..e5fe0f3076c 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -362,7 +362,7 @@ EXPECTED_STREAMED_JSON = [ ] -def _get_iter(use_tool_calls: bool = False) -> Any: +def _get_iter(*, use_tool_calls: bool = False) -> Any: if use_tool_calls: list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS else: @@ -374,7 +374,7 @@ def _get_iter(use_tool_calls: bool = False) -> Any: return input_iter -def _get_aiter(use_tool_calls: bool = False) -> Any: +def _get_aiter(*, use_tool_calls: bool = False) -> Any: if use_tool_calls: list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS else: @@ -389,7 +389,7 @@ def _get_aiter(use_tool_calls: bool = False) -> Any: def test_partial_json_output_parser() -> None: for use_tool_calls in [False, True]: - input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputToolsParser() actual = list(chain.stream(None)) @@ -402,7 +402,7 @@ def test_partial_json_output_parser() -> None: async def test_partial_json_output_parser_async() -> None: for use_tool_calls in [False, True]: - input_iter = _get_aiter(use_tool_calls) + input_iter = _get_aiter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputToolsParser() actual = [p async for p in 
chain.astream(None)] @@ -415,7 +415,7 @@ async def test_partial_json_output_parser_async() -> None: def test_partial_json_output_parser_return_id() -> None: for use_tool_calls in [False, True]: - input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputToolsParser(return_id=True) actual = list(chain.stream(None)) @@ -434,7 +434,7 @@ def test_partial_json_output_parser_return_id() -> None: def test_partial_json_output_key_parser() -> None: for use_tool_calls in [False, True]: - input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") actual = list(chain.stream(None)) @@ -444,7 +444,7 @@ def test_partial_json_output_key_parser() -> None: async def test_partial_json_output_parser_key_async() -> None: for use_tool_calls in [False, True]: - input_iter = _get_aiter(use_tool_calls) + input_iter = _get_aiter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") @@ -455,7 +455,7 @@ async def test_partial_json_output_parser_key_async() -> None: def test_partial_json_output_key_parser_first_only() -> None: for use_tool_calls in [False, True]: - input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputKeyToolsParser( key_name="NameCollector", first_tool_only=True @@ -466,7 +466,7 @@ def test_partial_json_output_key_parser_first_only() -> None: async def test_partial_json_output_parser_key_async_first_only() -> None: for use_tool_calls in [False, True]: - input_iter = _get_aiter(use_tool_calls) + input_iter = _get_aiter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputKeyToolsParser( key_name="NameCollector", first_tool_only=True @@ -507,7 +507,7 @@ EXPECTED_STREAMED_PYDANTIC = [ def test_partial_pydantic_output_parser() -> None: for use_tool_calls in [False, True]: - 
input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | PydanticToolsParser( tools=[NameCollector], first_tool_only=True @@ -519,7 +519,7 @@ def test_partial_pydantic_output_parser() -> None: async def test_partial_pydantic_output_parser_async() -> None: for use_tool_calls in [False, True]: - input_iter = _get_aiter(use_tool_calls) + input_iter = _get_aiter(use_tool_calls=use_tool_calls) chain = input_iter | PydanticToolsParser( tools=[NameCollector], first_tool_only=True diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index b0fab483dc0..5267159fe3f 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -80,7 +80,7 @@ async def _as_async_iterator(iterable: list) -> AsyncIterator: async def _collect_events( - events: AsyncIterator[StreamEvent], with_nulled_ids: bool = True + events: AsyncIterator[StreamEvent], *, with_nulled_ids: bool = True ) -> list[StreamEvent]: """Collect the events and remove the run ids.""" materialized_events = [event async for event in events] diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 486970d78e8..c083dd03f12 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -221,7 +221,7 @@ async def test_config_traceable_async_handoff() -> None: @pytest.mark.parametrize("enabled", [None, True, False]) @pytest.mark.parametrize("env", ["", "true"]) def test_tracing_enable_disable( - mock_get_client: MagicMock, enabled: bool, env: str + mock_get_client: MagicMock, *, enabled: bool, env: str ) -> None: mock_session = MagicMock() mock_client_ = Client( diff --git a/libs/core/tests/unit_tests/test_tools.py 
b/libs/core/tests/unit_tests/test_tools.py index 2e0a2e050c4..ba187e65b71 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2002,7 +2002,7 @@ invalid_tool_result_blocks = [ *([[block, False] for block in invalid_tool_result_blocks]), ], ) -def test__is_message_content_block(obj: Any, expected: bool) -> None: +def test__is_message_content_block(obj: Any, *, expected: bool) -> None: assert _is_message_content_block(obj) is expected @@ -2014,13 +2014,13 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None: (invalid_tool_result_blocks, False), ], ) -def test__is_message_content_type(obj: Any, expected: bool) -> None: +def test__is_message_content_type(obj: Any, *, expected: bool) -> None: assert _is_message_content_type(obj) is expected @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") @pytest.mark.parametrize("use_v1_namespace", [True, False]) -def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None: +def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None: A = TypeVar("A") if use_v1_namespace: diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 5c7fb4b4eeb..6e32c2d0b00 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -741,7 +741,7 @@ def test_tool_outputs() -> None: @pytest.mark.parametrize("use_extension_typed_dict", [True, False]) @pytest.mark.parametrize("use_extension_annotated", [True, False]) def test__convert_typed_dict_to_openai_function( - use_extension_typed_dict: bool, use_extension_annotated: bool + *, use_extension_typed_dict: bool, use_extension_annotated: bool ) -> None: typed_dict = ExtensionsTypedDict if use_extension_typed_dict else TypingTypedDict annotated = ExtensionsAnnotated if use_extension_annotated else TypingAnnotated