From 98f0016fc29ce9e7abfe5972997eca46caadbafe Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 9 Apr 2025 20:39:36 +0200 Subject: [PATCH] core: Add ruff rules ARG (#30732) See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg --- .../langchain_core/_api/beta_decorator.py | 4 +- libs/core/langchain_core/_api/deprecation.py | 8 +- libs/core/langchain_core/caches.py | 4 + libs/core/langchain_core/callbacks/file.py | 8 ++ libs/core/langchain_core/callbacks/manager.py | 6 +- libs/core/langchain_core/callbacks/stdout.py | 8 ++ .../callbacks/streaming_stdout.py | 3 + libs/core/langchain_core/callbacks/usage.py | 1 + libs/core/langchain_core/indexing/api.py | 2 +- .../core/langchain_core/indexing/in_memory.py | 5 ++ .../language_models/chat_models.py | 3 +- .../language_models/fake_chat_models.py | 4 + .../langchain_core/language_models/llms.py | 1 + .../langchain_core/output_parsers/base.py | 9 +- .../output_parsers/openai_functions.py | 4 + .../output_parsers/transform.py | 4 + libs/core/langchain_core/prompts/base.py | 1 + libs/core/langchain_core/retrievers.py | 2 +- libs/core/langchain_core/runnables/base.py | 20 ++--- libs/core/langchain_core/runnables/history.py | 7 +- libs/core/langchain_core/tracers/base.py | 22 +++-- libs/core/langchain_core/tracers/context.py | 2 +- libs/core/langchain_core/tracers/core.py | 48 +++++------ .../langchain_core/tracers/event_stream.py | 12 ++- libs/core/langchain_core/tracers/langchain.py | 4 +- .../langchain_core/tracers/langchain_v1.py | 4 +- libs/core/langchain_core/utils/loading.py | 4 +- libs/core/langchain_core/utils/mustache.py | 6 +- libs/core/langchain_core/utils/pydantic.py | 2 + libs/core/langchain_core/utils/utils.py | 2 + libs/core/langchain_core/vectorstores/base.py | 4 +- .../langchain_core/vectorstores/in_memory.py | 13 +-- libs/core/pyproject.toml | 2 +- .../callbacks/test_async_callback_manager.py | 3 + .../callbacks/test_dispatch_custom_event.py | 1 + .../unit_tests/document_loaders/test_base.py | 2 + .../example_selectors/test_similarity.py | 5 ++ libs/core/tests/unit_tests/fake/callbacks.py | 33 ++++++++ .../unit_tests/fake/test_fake_chat_model.py | 3 + .../language_models/chat_models/test_base.py | 6 ++ .../language_models/chat_models/test_cache.py | 2 + .../language_models/llms/test_base.py | 5 ++ .../language_models/llms/test_cache.py | 4 + .../tests/unit_tests/messages/test_utils.py | 2 + .../output_parsers/test_base_parsers.py | 4 + .../tests/unit_tests/prompts/test_few_shot.py | 3 + .../unit_tests/prompts/test_structured.py | 2 +- .../unit_tests/runnables/test_configurable.py | 13 ++- .../unit_tests/runnables/test_context.py | 6 +- .../unit_tests/runnables/test_fallbacks.py | 22 +++-- .../unit_tests/runnables/test_history.py | 8 +- .../unit_tests/runnables/test_runnable.py | 82 +++++++++---------- .../runnables/test_runnable_events_v1.py | 16 ++-- .../runnables/test_runnable_events_v2.py | 20 +++-- .../runnables/test_tracing_interops.py | 4 +- libs/core/tests/unit_tests/test_tools.py | 24 +++--- .../unit_tests/utils/test_function_calling.py | 4 +- .../vectorstores/test_vectorstore.py | 5 ++ 58 files changed, 328 insertions(+), 180 deletions(-) diff --git a/libs/core/langchain_core/_api/beta_decorator.py b/libs/core/langchain_core/_api/beta_decorator.py index f527cc116af..c3113b68777 100644 --- a/libs/core/langchain_core/_api/beta_decorator.py +++ b/libs/core/langchain_core/_api/beta_decorator.py @@ -124,7 +124,7 @@ def beta( _name = _name or obj.__qualname__ old_doc = obj.__doc__ - def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: + def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001 """Finalize the annotation of a class.""" # Can't set new_doc on some extension objects. with contextlib.suppress(AttributeError): @@ -190,7 +190,7 @@ def beta( if _name == "": _name = set_name - def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: + def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: # noqa: ARG001 """Finalize the property.""" return _BetaProperty( fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index f6e09a78726..78897a02d9f 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -204,7 +204,7 @@ def deprecated( _name = _name or obj.__qualname__ old_doc = obj.__doc__ - def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: + def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001 """Finalize the deprecation of a class.""" # Can't set new_doc on some extension objects. with contextlib.suppress(AttributeError): @@ -234,7 +234,7 @@ def deprecated( raise ValueError(msg) old_doc = obj.description - def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: + def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001 return cast( "T", FieldInfoV1( @@ -255,7 +255,7 @@ def deprecated( raise ValueError(msg) old_doc = obj.description - def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: + def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001 return cast( "T", FieldInfoV2( @@ -315,7 +315,7 @@ def deprecated( if _name == "": _name = set_name - def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: + def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001 """Finalize the property.""" return cast( "T", diff --git a/libs/core/langchain_core/caches.py b/libs/core/langchain_core/caches.py index 4805a747595..2ccbf06bfee 100644 --- a/libs/core/langchain_core/caches.py +++ b/libs/core/langchain_core/caches.py @@ -27,6 +27,8 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any, Optional +from typing_extensions import override + from langchain_core.outputs import Generation from langchain_core.runnables import run_in_executor @@ -194,6 +196,7 @@ class InMemoryCache(BaseCache): del self._cache[next(iter(self._cache))] self._cache[(prompt, llm_string)] = return_val + @override def clear(self, **kwargs: Any) -> None: """Clear cache.""" self._cache = {} @@ -227,6 +230,7 @@ class InMemoryCache(BaseCache): """ self.update(prompt, llm_string, return_val) + @override async def aclear(self, **kwargs: Any) -> None: """Async clear cache.""" self.clear() diff --git a/libs/core/langchain_core/callbacks/file.py b/libs/core/langchain_core/callbacks/file.py index 3ee4e18e5df..4912da1ec03 100644 --- a/libs/core/langchain_core/callbacks/file.py +++ b/libs/core/langchain_core/callbacks/file.py @@ -5,6 +5,8 @@ from __future__ import annotations from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, TextIO, cast +from typing_extensions import override + from langchain_core.callbacks import BaseCallbackHandler from langchain_core.utils.input import print_text @@ -38,6 +40,7 @@ class FileCallbackHandler(BaseCallbackHandler): """Destructor to cleanup when done.""" self.file.close() + @override def on_chain_start( self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any ) -> None: @@ -61,6 +64,7 @@ class FileCallbackHandler(BaseCallbackHandler): file=self.file, ) + @override def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Print out that we finished a chain. @@ -70,6 +74,7 @@ class FileCallbackHandler(BaseCallbackHandler): """ print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file) + @override def on_agent_action( self, action: AgentAction, color: Optional[str] = None, **kwargs: Any ) -> Any: @@ -83,6 +88,7 @@ class FileCallbackHandler(BaseCallbackHandler): """ print_text(action.log, color=color or self.color, file=self.file) + @override def on_tool_end( self, output: str, @@ -109,6 +115,7 @@ class FileCallbackHandler(BaseCallbackHandler): if llm_prefix is not None: print_text(f"\n{llm_prefix}", file=self.file) + @override def on_text( self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any ) -> None: @@ -123,6 +130,7 @@ class FileCallbackHandler(BaseCallbackHandler): """ print_text(text, color=color or self.color, end=end, file=self.file) + @override def on_agent_finish( self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any ) -> None: diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 12287c330e9..163599cb0ca 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -22,7 +22,7 @@ from typing import ( from uuid import UUID from langsmith.run_helpers import get_tracing_context -from typing_extensions import Self +from typing_extensions import Self, override from langchain_core.callbacks.base import ( BaseCallbackHandler, @@ -1401,6 +1401,7 @@ class CallbackManager(BaseCallbackManager): inheritable_metadata=self.inheritable_metadata, ) + @override def on_tool_start( self, serialized: Optional[dict[str, Any]], @@ -1456,6 +1457,7 @@ class CallbackManager(BaseCallbackManager): inheritable_metadata=self.inheritable_metadata, ) + @override def on_retriever_start( self, serialized: Optional[dict[str, Any]], @@ -1927,6 +1929,7 @@ class AsyncCallbackManager(BaseCallbackManager): inheritable_metadata=self.inheritable_metadata, ) + @override async def on_tool_start( self, serialized: Optional[dict[str, Any]], @@ -2017,6 +2020,7 @@ class AsyncCallbackManager(BaseCallbackManager): metadata=self.metadata, ) + @override async def on_retriever_start( self, serialized: Optional[dict[str, Any]], diff --git a/libs/core/langchain_core/callbacks/stdout.py b/libs/core/langchain_core/callbacks/stdout.py index aadc1cc8ebb..ae99d3ec67b 100644 --- a/libs/core/langchain_core/callbacks/stdout.py +++ b/libs/core/langchain_core/callbacks/stdout.py @@ -4,6 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Optional +from typing_extensions import override + from langchain_core.callbacks.base import BaseCallbackHandler from langchain_core.utils import print_text @@ -22,6 +24,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): """ self.color = color + @override def on_chain_start( self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any ) -> None: @@ -41,6 +44,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): name = "" print(f"\n\n\033[1m> Entering new {name} chain...\033[0m") # noqa: T201 + @override def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Print out that we finished a chain. @@ -50,6 +54,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): """ print("\n\033[1m> Finished chain.\033[0m") # noqa: T201 + @override def on_agent_action( self, action: AgentAction, color: Optional[str] = None, **kwargs: Any ) -> Any: @@ -62,6 +67,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): """ print_text(action.log, color=color or self.color) + @override def on_tool_end( self, output: Any, @@ -87,6 +93,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): if llm_prefix is not None: print_text(f"\n{llm_prefix}") + @override def on_text( self, text: str, @@ -104,6 +111,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): """ print_text(text, color=color or self.color, end=end) + @override def on_agent_finish( self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any ) -> None: diff --git a/libs/core/langchain_core/callbacks/streaming_stdout.py b/libs/core/langchain_core/callbacks/streaming_stdout.py index 7f6fa579080..abed19d5bdc 100644 --- a/libs/core/langchain_core/callbacks/streaming_stdout.py +++ b/libs/core/langchain_core/callbacks/streaming_stdout.py @@ -5,6 +5,8 @@ from __future__ import annotations import sys from typing import TYPE_CHECKING, Any +from typing_extensions import override + from langchain_core.callbacks.base import BaseCallbackHandler if TYPE_CHECKING: @@ -41,6 +43,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): **kwargs (Any): Additional keyword arguments. """ + @override def on_llm_new_token(self, token: str, **kwargs: Any) -> None: """Run on new LLM token. Only available when streaming is enabled. diff --git a/libs/core/langchain_core/callbacks/usage.py b/libs/core/langchain_core/callbacks/usage.py index 486940eb98e..727167fa1de 100644 --- a/libs/core/langchain_core/callbacks/usage.py +++ b/libs/core/langchain_core/callbacks/usage.py @@ -58,6 +58,7 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler): def __repr__(self) -> str: return str(self.usage_metadata) + @override def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Collect token usage.""" # Check for usage_metadata (langchain-core >= 0.2.2) diff --git a/libs/core/langchain_core/indexing/api.py b/libs/core/langchain_core/indexing/api.py index 16ff9ed4037..c2405533582 100644 --- a/libs/core/langchain_core/indexing/api.py +++ b/libs/core/langchain_core/indexing/api.py @@ -151,7 +151,7 @@ def _get_source_id_assigner( ) -> Callable[[Document], Union[str, None]]: """Get the source id from the document.""" if source_id_key is None: - return lambda doc: None + return lambda _doc: None if isinstance(source_id_key, str): return lambda doc: doc.metadata[source_id_key] if callable(source_id_key): diff --git a/libs/core/langchain_core/indexing/in_memory.py b/libs/core/langchain_core/indexing/in_memory.py index be040b46105..34609a0370d 100644 --- a/libs/core/langchain_core/indexing/in_memory.py +++ b/libs/core/langchain_core/indexing/in_memory.py @@ -6,6 +6,7 @@ from collections.abc import Sequence from typing import Any, Optional, cast from pydantic import Field +from typing_extensions import override from langchain_core._api import beta from langchain_core.callbacks import CallbackManagerForRetrieverRun @@ -29,6 +30,7 @@ class InMemoryDocumentIndex(DocumentIndex): store: dict[str, Document] = Field(default_factory=dict) top_k: int = 4 + @override def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse: """Upsert items into the index.""" ok_ids = [] @@ -47,6 +49,7 @@ class InMemoryDocumentIndex(DocumentIndex): return UpsertResponse(succeeded=ok_ids, failed=[]) + @override def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse: """Delete by ID.""" if ids is None: @@ -64,10 +67,12 @@ class InMemoryDocumentIndex(DocumentIndex): succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[] ) + @override def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]: """Get by ids.""" return [self.store[id_] for id_ in ids if id_ in self.store] + @override def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> list[Document]: diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index d98c1ce4091..0a8a0022325 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -576,7 +576,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): # --- Custom methods --- - def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: + def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002 return {} def _get_invocation_params( @@ -1246,6 +1246,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def _llm_type(self) -> str: """Return type of chat model.""" + @override def dict(self, **kwargs: Any) -> dict: """Return a dictionary of the LLM.""" starter_dict = dict(self._identifying_params) diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py index acce2418350..a1cdd49a7b3 100644 --- a/libs/core/langchain_core/language_models/fake_chat_models.py +++ b/libs/core/langchain_core/language_models/fake_chat_models.py @@ -171,6 +171,7 @@ class FakeListChatModel(SimpleChatModel): class FakeChatModel(SimpleChatModel): """Fake Chat Model wrapper for testing purposes.""" + @override def _call( self, messages: list[BaseMessage], @@ -180,6 +181,7 @@ class FakeChatModel(SimpleChatModel): ) -> str: return "fake response" + @override async def _agenerate( self, messages: list[BaseMessage], @@ -224,6 +226,7 @@ class GenericFakeChatModel(BaseChatModel): into message chunks. """ + @override def _generate( self, messages: list[BaseMessage], @@ -346,6 +349,7 @@ class ParrotFakeChatModel(BaseChatModel): * Chat model should be usable in both sync and async tests """ + @override def _generate( self, messages: list[BaseMessage], diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 326adf95603..032192e0cad 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -1399,6 +1399,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _llm_type(self) -> str: """Return type of llm.""" + @override def dict(self, **kwargs: Any) -> dict: """Return a dictionary of the LLM.""" starter_dict = dict(self._identifying_params) diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 94055c9bee8..a14a10829f9 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -59,7 +59,7 @@ class BaseLLMOutputParser(Generic[T], ABC): Returns: Structured output. """ - return await run_in_executor(None, self.parse_result, result) + return await run_in_executor(None, self.parse_result, result, partial=partial) class BaseGenerationOutputParser( @@ -231,6 +231,7 @@ class BaseOutputParser( run_type="parser", ) + @override def parse_result(self, result: list[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. @@ -290,7 +291,11 @@ class BaseOutputParser( return await run_in_executor(None, self.parse, text) # TODO: rename 'completion' -> 'text'. - def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any: + def parse_with_prompt( + self, + completion: str, + prompt: PromptValue, # noqa: ARG002 + ) -> Any: """Parse the output of an LLM call with the input prompt for context. The prompt is largely provided in the event the OutputParser wants diff --git a/libs/core/langchain_core/output_parsers/openai_functions.py b/libs/core/langchain_core/output_parsers/openai_functions.py index 73215284016..e5fa59bcb38 100644 --- a/libs/core/langchain_core/output_parsers/openai_functions.py +++ b/libs/core/langchain_core/output_parsers/openai_functions.py @@ -7,6 +7,7 @@ from typing import Any, Optional, Union import jsonpatch # type: ignore[import] from pydantic import BaseModel, model_validator +from typing_extensions import override from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import ( @@ -23,6 +24,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): args_only: bool = True """Whether to only return the arguments to the function call.""" + @override def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. @@ -251,6 +253,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser): raise ValueError(msg) return values + @override def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. @@ -287,6 +290,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser): attr_name: str """The name of the attribute to return.""" + @override def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. diff --git a/libs/core/langchain_core/output_parsers/transform.py b/libs/core/langchain_core/output_parsers/transform.py index 253556df54b..783abedf116 100644 --- a/libs/core/langchain_core/output_parsers/transform.py +++ b/libs/core/langchain_core/output_parsers/transform.py @@ -9,6 +9,8 @@ from typing import ( Union, ) +from typing_extensions import override + from langchain_core.messages import BaseMessage, BaseMessageChunk from langchain_core.output_parsers.base import BaseOutputParser, T from langchain_core.outputs import ( @@ -48,6 +50,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]): None, self.parse_result, [Generation(text=chunk)] ) + @override def transform( self, input: Iterator[Union[str, BaseMessage]], @@ -68,6 +71,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]): input, self._transform, config, run_type="parser" ) + @override async def atransform( self, input: AsyncIterator[Union[str, BaseMessage]], diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index cf89f83b5af..d880971750f 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -127,6 +127,7 @@ class BasePromptTemplate( """Return the output type of the prompt.""" return Union[StringPromptValue, ChatPromptValueConcrete] + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index c05372027d9..e82ff04032d 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -199,7 +199,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0 ) - def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams: + def _get_ls_params(self, **_kwargs: Any) -> LangSmithRetrieverParams: """Get standard params for tracing.""" default_retriever_name = self.get_name() if default_retriever_name.startswith("Retriever"): diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index cb08448ea48..1c48e9aac4d 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -326,7 +326,8 @@ class Runnable(Generic[Input, Output], ABC): return self.get_input_schema() def get_input_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, # noqa: ARG002 ) -> type[BaseModel]: """Get a pydantic model that can be used to validate input to the Runnable. @@ -398,7 +399,8 @@ class Runnable(Generic[Input, Output], ABC): return self.get_output_schema() def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, # noqa: ARG002 ) -> type[BaseModel]: """Get a pydantic model that can be used to validate output to the Runnable. @@ -4751,11 +4753,6 @@ class RunnableLambda(Runnable[Input, Output]): ) return cast("Output", output) - def _config( - self, config: Optional[RunnableConfig], callable: Callable[..., Any] - ) -> RunnableConfig: - return ensure_config(config) - @override def invoke( self, @@ -4780,7 +4777,7 @@ class RunnableLambda(Runnable[Input, Output]): return self._call_with_config( self._invoke, input, - self._config(config, self.func), + ensure_config(config), **kwargs, ) msg = "Cannot invoke a coroutine function synchronously.Use `ainvoke` instead." @@ -4803,11 +4800,10 @@ class RunnableLambda(Runnable[Input, Output]): Returns: The output of this Runnable. """ - the_func = self.afunc if hasattr(self, "afunc") else self.func return await self._acall_with_config( self._ainvoke, input, - self._config(config, the_func), + ensure_config(config), **kwargs, ) @@ -4884,7 +4880,7 @@ class RunnableLambda(Runnable[Input, Output]): yield from self._transform_stream_with_config( input, self._transform, - self._config(config, self.func), + ensure_config(config), **kwargs, ) else: @@ -5012,7 +5008,7 @@ class RunnableLambda(Runnable[Input, Output]): async for output in self._atransform_stream_with_config( input, self._atransform, - self._config(config, self.afunc if hasattr(self, "afunc") else self.func), + ensure_config(config), **kwargs, ): yield output diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 15f2021a463..fdd9a2b0ecb 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -400,6 +400,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): def OutputType(self) -> type[Output]: return self._history_chain.OutputType + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -432,12 +433,6 @@ class RunnableWithMessageHistory(RunnableBindingBase): module_name=self.__class__.__module__, ) - def _is_not_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool: - return False - - async def _is_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool: - return True - def _get_input_messages( self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] ) -> list[BaseMessage]: diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index 38795d8e0f6..ee588606165 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -133,6 +133,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): self._on_llm_start(llm_run) return llm_run + @override def on_llm_new_token( self, token: str, @@ -161,11 +162,11 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): run_id=run_id, chunk=chunk, parent_run_id=parent_run_id, - **kwargs, ) self._on_llm_new_token(llm_run, token, chunk) return llm_run + @override def on_retry( self, retry_state: RetryCallState, @@ -188,6 +189,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): run_id=run_id, ) + @override def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: """End a trace for an LLM run. @@ -235,6 +237,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): self._on_llm_error(llm_run) return llm_run + @override def on_chain_start( self, serialized: dict[str, Any], @@ -279,6 +282,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): self._on_chain_start(chain_run) return chain_run + @override def on_chain_end( self, outputs: dict[str, Any], @@ -302,12 +306,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): outputs=outputs, run_id=run_id, inputs=inputs, - **kwargs, ) self._end_trace(chain_run) self._on_chain_end(chain_run) return chain_run + @override def on_chain_error( self, error: BaseException, @@ -331,7 +335,6 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): error=error, run_id=run_id, inputs=inputs, - **kwargs, ) self._end_trace(chain_run) self._on_chain_error(chain_run) @@ -381,6 +384,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): self._on_tool_start(tool_run) return tool_run + @override def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run: """End a trace for a tool run. @@ -395,12 +399,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): tool_run = self._complete_tool_run( output=output, run_id=run_id, - **kwargs, ) self._end_trace(tool_run) self._on_tool_end(tool_run) return tool_run + @override def on_tool_error( self, error: BaseException, @@ -467,6 +471,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): self._on_retriever_start(retrieval_run) return retrieval_run + @override def on_retriever_error( self, error: BaseException, @@ -487,12 +492,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): retrieval_run = self._errored_retrieval_run( error=error, run_id=run_id, - **kwargs, ) self._end_trace(retrieval_run) self._on_retriever_error(retrieval_run) return retrieval_run + @override def on_retriever_end( self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any ) -> Run: @@ -509,7 +514,6 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): retrieval_run = self._complete_retrieval_run( documents=documents, run_id=run_id, - **kwargs, ) self._end_trace(retrieval_run) self._on_retriever_end(retrieval_run) @@ -623,7 +627,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): run_id=run_id, chunk=chunk, parent_run_id=parent_run_id, - **kwargs, ) await self._on_llm_new_token(llm_run, token, chunk) @@ -715,7 +718,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): outputs=outputs, run_id=run_id, inputs=inputs, - **kwargs, ) tasks = [self._end_trace(chain_run), self._on_chain_end(chain_run)] await asyncio.gather(*tasks) @@ -733,7 +735,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): error=error, inputs=inputs, run_id=run_id, - **kwargs, ) tasks = [self._end_trace(chain_run), self._on_chain_error(chain_run)] await asyncio.gather(*tasks) @@ -776,7 +777,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): tool_run = self._complete_tool_run( output=output, run_id=run_id, - **kwargs, ) tasks = [self._end_trace(tool_run), self._on_tool_end(tool_run)] await asyncio.gather(*tasks) @@ -839,7 +839,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): retrieval_run = self._errored_retrieval_run( error=error, run_id=run_id, - **kwargs, ) tasks = [ self._end_trace(retrieval_run), @@ -860,7 +859,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): retrieval_run = self._complete_retrieval_run( documents=documents, run_id=run_id, - **kwargs, ) tasks = [self._end_trace(retrieval_run), self._on_retriever_end(retrieval_run)] await asyncio.gather(*tasks) diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index a89a5417ceb..6c28fcc2179 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -41,7 +41,7 @@ run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVa @contextmanager def tracing_enabled( - session_name: str = "default", + session_name: str = "default", # noqa: ARG001 ) -> Generator[TracerSessionV1, None, None]: """Throw an error because this has been replaced by tracing_v2_enabled.""" msg = ( diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py index 9e485709f83..fbb701f8bbc 100644 --- a/libs/core/langchain_core/tracers/core.py +++ b/libs/core/langchain_core/tracers/core.py @@ -231,8 +231,7 @@ class _TracerCore(ABC): token: str, run_id: UUID, chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, + parent_run_id: Optional[UUID] = None, # noqa: ARG002 ) -> Run: """Append token event to LLM run and return the run.""" llm_run = self._get_run(run_id, run_type={"llm", "chat_model"}) @@ -252,7 +251,6 @@ class _TracerCore(ABC): self, retry_state: RetryCallState, run_id: UUID, - **kwargs: Any, ) -> Run: llm_run = self._get_run(run_id) retry_d: dict[str, Any] = { @@ -369,7 +367,6 @@ class _TracerCore(ABC): outputs: dict[str, Any], run_id: UUID, inputs: Optional[dict[str, Any]] = None, - **kwargs: Any, ) -> Run: """Update a chain run with outputs and end time.""" chain_run = self._get_run(run_id) @@ -385,7 +382,6 @@ class _TracerCore(ABC): error: BaseException, inputs: Optional[dict[str, Any]], run_id: UUID, - **kwargs: Any, ) -> Run: chain_run = self._get_run(run_id) chain_run.error = self._get_stacktrace(error) @@ -439,7 +435,6 @@ class _TracerCore(ABC): self, output: dict[str, Any], run_id: UUID, - **kwargs: Any, ) -> Run: """Update a tool run with outputs and end time.""" tool_run = self._get_run(run_id, run_type="tool") @@ -452,7 +447,6 @@ class _TracerCore(ABC): self, error: BaseException, run_id: UUID, - **kwargs: Any, ) -> Run: """Update a tool run with error and end time.""" tool_run = self._get_run(run_id, run_type="tool") @@ -494,7 +488,6 @@ class _TracerCore(ABC): self, documents: Sequence[Document], run_id: UUID, - **kwargs: Any, ) -> Run: """Update a retrieval run with outputs and end time.""" retrieval_run = self._get_run(run_id, run_type="retriever") @@ -507,7 +500,6 @@ class _TracerCore(ABC): self, error: BaseException, run_id: UUID, - **kwargs: Any, ) -> Run: retrieval_run = self._get_run(run_id, run_type="retriever") retrieval_run.error = self._get_stacktrace(error) @@ -523,75 +515,75 @@ class _TracerCore(ABC): """Copy the tracer.""" return self - def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """End a trace for a run.""" return None - def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process a run upon creation.""" return None - def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process a run upon update.""" return None - def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the LLM Run upon start.""" return None def _on_llm_new_token( self, - run: Run, - token: str, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + run: Run, # noqa: ARG002 + token: str, # noqa: ARG002 + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002 ) -> Union[None, Coroutine[Any, Any, None]]: """Process new LLM token.""" return None - def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the LLM Run.""" return None - def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the LLM Run upon error.""" return None - def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Chain Run upon start.""" return None - def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Chain Run.""" return None - def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Chain Run upon error.""" return None - def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Tool Run upon start.""" return None - def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Tool Run.""" return None - def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Tool Run upon error.""" return None - def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Chat Model Run upon start.""" return None - def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Retriever Run upon start.""" return None - def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Retriever Run.""" return None - def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: + def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002 """Process the Retriever Run upon error.""" return None diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index 7106a281ebd..be87d663b51 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -15,7 +15,7 @@ from typing import ( ) from uuid import UUID, uuid4 -from typing_extensions import NotRequired, TypedDict +from typing_extensions import NotRequired, TypedDict, override from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk @@ -293,6 +293,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand self.run_map[run_id] = info self.parent_map[run_id] = parent_run_id + @override async def on_chat_model_start( self, serialized: dict[str, Any], @@ -334,6 +335,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand run_type, ) + @override async def on_llm_start( self, serialized: dict[str, Any], @@ -377,6 +379,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand run_type, ) + @override async def on_custom_event( self, name: str, @@ -399,6 +402,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand ) self._send(event, name) + @override async def on_llm_new_token( self, token: str, @@ -450,6 +454,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand run_info["run_type"], ) + @override async def on_llm_end( self, response: LLMResult, *, run_id: UUID, **kwargs: Any ) -> None: @@ -552,6 +557,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand run_type_, ) + @override async def on_chain_end( self, outputs: dict[str, Any], @@ -586,6 +592,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand run_type, ) + @override async def on_tool_start( self, serialized: dict[str, Any], @@ -627,6 +634,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand "tool", ) + @override async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None: """End a trace for a tool run.""" run_info = self.run_map.pop(run_id) @@ -654,6 +662,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand "tool", ) + @override async def on_retriever_start( self, serialized: dict[str, Any], @@ -697,6 +706,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand run_type, ) + @override async def on_retriever_end( self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any ) -> None: diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index 5914895b292..9694d04602c 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -19,6 +19,7 @@ from tenacity import ( stop_after_attempt, wait_exponential_jitter, ) +from typing_extensions import override from langchain_core.env import get_runtime_environment from langchain_core.load import dumpd @@ -252,13 +253,13 @@ class LangChainTracer(BaseTracer): run.reference_example_id = self.example_id self._persist_run_single(run) + @override def _llm_run_with_token_event( self, token: str, run_id: UUID, chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, parent_run_id: Optional[UUID] = None, - **kwargs: Any, ) -> Run: """Append token event to LLM run and return the run.""" return super()._llm_run_with_token_event( @@ -267,7 +268,6 @@ class LangChainTracer(BaseTracer): run_id, chunk=None, parent_run_id=parent_run_id, - **kwargs, ) def _on_chat_model_start(self, run: Run) -> None: diff --git a/libs/core/langchain_core/tracers/langchain_v1.py b/libs/core/langchain_core/tracers/langchain_v1.py index 504b74dbed4..63aab51eeea 100644 --- a/libs/core/langchain_core/tracers/langchain_v1.py +++ b/libs/core/langchain_core/tracers/langchain_v1.py @@ -6,7 +6,7 @@ Please use LangChainTracer instead. from typing import Any -def get_headers(*args: Any, **kwargs: Any) -> Any: +def get_headers(*args: Any, **kwargs: Any) -> Any: # noqa: ARG001 """Throw an error because this has been replaced by get_headers.""" msg = ( "get_headers for LangChainTracerV1 is no longer supported. " @@ -15,7 +15,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any: raise RuntimeError(msg) -def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802 +def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802,ARG001 """Throw an error because this has been replaced by LangChainTracer.""" msg = ( "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead." diff --git a/libs/core/langchain_core/utils/loading.py b/libs/core/langchain_core/utils/loading.py index 27db390a16d..7921f6a06c1 100644 --- a/libs/core/langchain_core/utils/loading.py +++ b/libs/core/langchain_core/utils/loading.py @@ -16,8 +16,8 @@ from langchain_core._api.deprecation import deprecated ), ) def try_load_from_hub( - *args: Any, - **kwargs: Any, + *args: Any, # noqa: ARG001 + **kwargs: Any, # noqa: ARG001 ) -> Any: """[DEPRECATED] Try to load from the old Hub.""" warnings.warn( diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index 92592dfe80f..ff0b4923d58 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -65,7 +65,11 @@ def grab_literal(template: str, l_del: str) -> tuple[str, str]: return (literal, template) -def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: +def l_sa_check( + template: str, # noqa: ARG001 + literal: str, + is_standalone: bool, +) -> bool: """Do a preliminary check to see if a tag could be a standalone. Args: diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index fcea09490bf..2c296429569 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -38,6 +38,7 @@ from pydantic.json_schema import ( JsonSchemaMode, JsonSchemaValue, ) +from typing_extensions import override if TYPE_CHECKING: from pydantic_core import core_schema @@ -233,6 +234,7 @@ class _IgnoreUnserializable(GenerateJsonSchema): https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process """ + @override def handle_invalid_for_json_schema( self, schema: core_schema.CoreSchema, error_info: str ) -> JsonSchemaValue: diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index fc43c867331..76a6d4962a1 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -13,6 +13,7 @@ from typing import Any, Callable, Optional, Union, overload from packaging.version import parse from pydantic import SecretStr from requests import HTTPError, Response +from typing_extensions import override from langchain_core.utils.pydantic import ( is_pydantic_v1_subclass, @@ -91,6 +92,7 @@ def mock_now(dt_value: datetime.datetime) -> Iterator[type]: """Mock datetime.datetime.now() with a fixed datetime.""" @classmethod + @override def now(cls, tz: Union[datetime.tzinfo, None] = None) -> "MockDateTime": # Create a copy of dt_value. return MockDateTime( diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py index 2c9615b5945..173bb6caec2 100644 --- a/libs/core/langchain_core/vectorstores/base.py +++ b/libs/core/langchain_core/vectorstores/base.py @@ -36,7 +36,7 @@ from typing import ( ) from pydantic import ConfigDict, Field, model_validator -from typing_extensions import Self +from typing_extensions import Self, override from langchain_core.embeddings import Embeddings from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams @@ -1070,6 +1070,7 @@ class VectorStoreRetriever(BaseRetriever): return ls_params + @override def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any ) -> list[Document]: @@ -1090,6 +1091,7 @@ class VectorStoreRetriever(BaseRetriever): raise ValueError(msg) return docs + @override async def _aget_relevant_documents( self, query: str, diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py index 47034aa4ec0..32e2e05cac1 100644 --- a/libs/core/langchain_core/vectorstores/in_memory.py +++ b/libs/core/langchain_core/vectorstores/in_memory.py @@ -285,7 +285,7 @@ class InMemoryVectorStore(VectorStore): since="0.2.29", removal="1.0", ) - def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse: + def upsert(self, items: Sequence[Document], /, **_kwargs: Any) -> UpsertResponse: """[DEPRECATED] Upsert documents into the store. Args: @@ -319,7 +319,7 @@ class InMemoryVectorStore(VectorStore): removal="1.0", ) async def aupsert( - self, items: Sequence[Document], /, **kwargs: Any + self, items: Sequence[Document], /, **_kwargs: Any ) -> UpsertResponse: """[DEPRECATED] Upsert documents into the store. @@ -364,7 +364,6 @@ class InMemoryVectorStore(VectorStore): embedding: list[float], k: int = 4, filter: Optional[Callable[[Document], bool]] = None, - **kwargs: Any, ) -> list[tuple[Document, float, list[float]]]: # get all docs with fixed order in list docs = list(self.store.values()) @@ -404,7 +403,7 @@ class InMemoryVectorStore(VectorStore): embedding: list[float], k: int = 4, filter: Optional[Callable[[Document], bool]] = None, - **kwargs: Any, + **_kwargs: Any, ) -> list[tuple[Document, float]]: """Search for the most similar documents to the given embedding. @@ -419,7 +418,7 @@ class InMemoryVectorStore(VectorStore): return [ (doc, similarity) for doc, similarity, _ in self._similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs + embedding=embedding, k=k, filter=filter ) ] @@ -490,12 +489,14 @@ class InMemoryVectorStore(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, + *, + filter: Optional[Callable[[Document], bool]] = None, **kwargs: Any, ) -> list[Document]: prefetch_hits = self._similarity_search_with_score_by_vector( embedding=embedding, k=fetch_k, - **kwargs, + filter=filter, ) try: diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index fdbdf16c0f1..3a5d1f6dd9a 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -98,7 +98,6 @@ ignore = [ # TODO rules "A", "ANN401", - "ARG", "BLE", "ERA", "FBT001", @@ -132,5 +131,6 @@ classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_ini "tests/unit_tests/prompts/test_chat.py" = [ "E501",] "tests/unit_tests/runnables/test_runnable.py" = [ "E501",] "tests/unit_tests/runnables/test_graph.py" = [ "E501",] +"tests/unit_tests/test_tools.py" = [ "ARG",] "tests/**" = [ "D", "S",] "scripts/**" = [ "INP", "S",] diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py index ac565e43e39..5ae0d316e6c 100644 --- a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py +++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py @@ -10,6 +10,8 @@ from contextlib import asynccontextmanager from typing import Any, Optional from uuid import UUID +from typing_extensions import override + from langchain_core.callbacks import ( AsyncCallbackHandler, AsyncCallbackManager, @@ -45,6 +47,7 @@ async def test_inline_handlers_share_parent_context() -> None: """Initialize the handler.""" self.run_inline = run_inline + @override async def on_llm_start(self, *args: Any, **kwargs: Any) -> None: """Update the callstack with the name of the callback.""" some_var.set("on_llm_start") diff --git a/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py b/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py index d1d4f27f4de..86e0aec31ff 100644 --- a/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py +++ b/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py @@ -74,6 +74,7 @@ async def test_async_custom_event_implicit_config() -> None: # a decorator for async functions @RunnableLambda # type: ignore[arg-type] async def foo(x: int, config: RunnableConfig) -> int: + assert "callbacks" in config await adispatch_custom_event("event1", {"x": x}) await adispatch_custom_event("event2", {"x": x}) return x diff --git a/libs/core/tests/unit_tests/document_loaders/test_base.py b/libs/core/tests/unit_tests/document_loaders/test_base.py index 4297b165482..87ab70e9e09 100644 --- a/libs/core/tests/unit_tests/document_loaders/test_base.py +++ b/libs/core/tests/unit_tests/document_loaders/test_base.py @@ -3,6 +3,7 @@ from collections.abc import Iterator import pytest +from typing_extensions import override from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader from langchain_core.documents import Document @@ -15,6 +16,7 @@ def test_base_blob_parser() -> None: class MyParser(BaseBlobParser): """A simple parser that returns a single document.""" + @override def lazy_parse(self, blob: Blob) -> Iterator[Document]: """Lazy parsing interface.""" yield Document( diff --git a/libs/core/tests/unit_tests/example_selectors/test_similarity.py b/libs/core/tests/unit_tests/example_selectors/test_similarity.py index 0936bb2bd72..8887d8538e2 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_similarity.py +++ b/libs/core/tests/unit_tests/example_selectors/test_similarity.py @@ -1,6 +1,8 @@ from collections.abc import Iterable from typing import Any, Optional, cast +from typing_extensions import override + from langchain_core.documents import Document from langchain_core.embeddings import Embeddings, FakeEmbeddings from langchain_core.example_selectors import ( @@ -21,6 +23,7 @@ class DummyVectorStore(VectorStore): def embeddings(self) -> Optional[Embeddings]: return self._embeddings + @override def add_texts( self, texts: Iterable[str], @@ -32,6 +35,7 @@ class DummyVectorStore(VectorStore): self.metadatas.extend(metadatas) return ["dummy_id"] + @override def similarity_search( self, query: str, k: int = 4, **kwargs: Any ) -> list[Document]: @@ -41,6 +45,7 @@ class DummyVectorStore(VectorStore): ) ] * k + @override def max_marginal_relevance_search( self, query: str, diff --git a/libs/core/tests/unit_tests/fake/callbacks.py b/libs/core/tests/unit_tests/fake/callbacks.py index 51ef0fe6c9d..ca1eed68f37 100644 --- a/libs/core/tests/unit_tests/fake/callbacks.py +++ b/libs/core/tests/unit_tests/fake/callbacks.py @@ -5,6 +5,7 @@ from typing import Any, Optional, Union from uuid import UUID from pydantic import BaseModel +from typing_extensions import override from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.messages import BaseMessage @@ -138,6 +139,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): """Whether to ignore retriever callbacks.""" return self.ignore_retriever_ + @override def on_llm_start( self, *args: Any, @@ -145,6 +147,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_llm_start_common() + @override def on_llm_new_token( self, *args: Any, @@ -152,6 +155,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_llm_new_token_common() + @override def on_llm_end( self, *args: Any, @@ -159,6 +163,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_llm_end_common() + @override def on_llm_error( self, *args: Any, @@ -166,6 +171,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_llm_error_common(*args, **kwargs) + @override def on_retry( self, *args: Any, @@ -173,6 +179,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_retry_common() + @override def on_chain_start( self, *args: Any, @@ -180,6 +187,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_chain_start_common() + @override def on_chain_end( self, *args: Any, @@ -187,6 +195,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_chain_end_common() + @override def on_chain_error( self, *args: Any, @@ -194,6 +203,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_chain_error_common() + @override def on_tool_start( self, *args: Any, @@ -201,6 +211,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_tool_start_common() + @override def on_tool_end( self, *args: Any, @@ -208,6 +219,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_tool_end_common() + @override def on_tool_error( self, *args: Any, @@ -215,6 +227,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_tool_error_common() + @override def on_agent_action( self, *args: Any, @@ -222,6 +235,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_agent_action_common() + @override def on_agent_finish( self, *args: Any, @@ -229,6 +243,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_agent_finish_common() + @override def on_text( self, *args: Any, @@ -236,6 +251,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_text_common() + @override def on_retriever_start( self, *args: Any, @@ -243,6 +259,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_retriever_start_common() + @override def on_retriever_end( self, *args: Any, @@ -250,6 +267,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_retriever_end_common() + @override def on_retriever_error( self, *args: Any, @@ -263,6 +281,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): + @override def on_chat_model_start( self, serialized: dict[str, Any], @@ -294,6 +313,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi """Whether to ignore agent callbacks.""" return self.ignore_agent_ + @override async def on_retry( self, *args: Any, @@ -301,6 +321,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> Any: self.on_retry_common() + @override async def on_llm_start( self, *args: Any, @@ -308,6 +329,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_llm_start_common() + @override async def on_llm_new_token( self, *args: Any, @@ -315,6 +337,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_llm_new_token_common() + @override async def on_llm_end( self, *args: Any, @@ -322,6 +345,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_llm_end_common() + @override async def on_llm_error( self, *args: Any, @@ -329,6 +353,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_llm_error_common(*args, **kwargs) + @override async def on_chain_start( self, *args: Any, @@ -336,6 +361,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_chain_start_common() + @override async def on_chain_end( self, *args: Any, @@ -343,6 +369,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_chain_end_common() + @override async def on_chain_error( self, *args: Any, @@ -350,6 +377,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_chain_error_common() + @override async def on_tool_start( self, *args: Any, @@ -357,6 +385,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_tool_start_common() + @override async def on_tool_end( self, *args: Any, @@ -364,6 +393,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_tool_end_common() + @override async def on_tool_error( self, *args: Any, @@ -371,6 +401,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_tool_error_common() + @override async def on_agent_action( self, *args: Any, @@ -378,6 +409,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_agent_action_common() + @override async def on_agent_finish( self, *args: Any, @@ -385,6 +417,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_agent_finish_common() + @override async def on_text( self, *args: Any, diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 5d7d4525e3f..7500e1640ac 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -4,6 +4,8 @@ from itertools import cycle from typing import Any, Optional, Union from uuid import UUID +from typing_extensions import override + from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.language_models import ( FakeListChatModel, @@ -171,6 +173,7 @@ async def test_callback_handlers() -> None: # Required to implement since this is an abstract method pass + @override async def on_llm_new_token( self, token: str, diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 79e059e0efe..328e745f7a6 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -5,6 +5,7 @@ from collections.abc import AsyncIterator, Iterator from typing import TYPE_CHECKING, Any, Literal, Optional, Union import pytest +from typing_extensions import override from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import BaseChatModel, FakeListChatModel @@ -138,6 +139,7 @@ async def test_astream_fallback_to_ainvoke() -> None: """Test astream uses appropriate implementation.""" class ModelWithGenerate(BaseChatModel): + @override def _generate( self, messages: list[BaseMessage], @@ -176,6 +178,7 @@ async def test_astream_implementation_fallback_to_stream() -> None: """Top Level call.""" raise NotImplementedError + @override def _stream( self, messages: list[BaseMessage], @@ -221,6 +224,7 @@ async def test_astream_implementation_uses_astream() -> None: """Top Level call.""" raise NotImplementedError + @override async def _astream( # type: ignore self, messages: list[BaseMessage], @@ -286,6 +290,7 @@ async def test_async_pass_run_id() -> None: class NoStreamingModel(BaseChatModel): + @override def _generate( self, messages: list[BaseMessage], @@ -301,6 +306,7 @@ class NoStreamingModel(BaseChatModel): class StreamingModel(NoStreamingModel): + @override def _stream( self, messages: list[BaseMessage], diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py index b35db669695..9705a015366 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py @@ -3,6 +3,7 @@ from typing import Any, Optional import pytest +from typing_extensions import override from langchain_core.caches import RETURN_VAL_TYPE, BaseCache from langchain_core.globals import set_llm_cache @@ -30,6 +31,7 @@ class InMemoryCache(BaseCache): """Update cache based on prompt and llm_string.""" self._cache[(prompt, llm_string)] = return_val + @override def clear(self, **kwargs: Any) -> None: """Clear cache.""" self._cache = {} diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py index a5741b2225f..b81f98d07be 100644 --- a/libs/core/tests/unit_tests/language_models/llms/test_base.py +++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py @@ -2,6 +2,7 @@ from collections.abc import AsyncIterator, Iterator from typing import Any, Optional import pytest +from typing_extensions import override from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -106,6 +107,7 @@ async def test_error_callback() -> None: """Return type of llm.""" return "failing-llm" + @override def _call( self, prompt: str, @@ -136,6 +138,7 @@ async def test_astream_fallback_to_ainvoke() -> None: """Test astream uses appropriate implementation.""" class ModelWithGenerate(BaseLLM): + @override def _generate( self, prompts: list[str], @@ -172,6 +175,7 @@ async def test_astream_implementation_fallback_to_stream() -> None: """Top Level call.""" raise NotImplementedError + @override def _stream( self, prompt: str, @@ -209,6 +213,7 @@ async def test_astream_implementation_uses_astream() -> None: """Top Level call.""" raise NotImplementedError + @override async def _astream( self, prompt: str, diff --git a/libs/core/tests/unit_tests/language_models/llms/test_cache.py b/libs/core/tests/unit_tests/language_models/llms/test_cache.py index 6894328724b..65356e66c5f 100644 --- a/libs/core/tests/unit_tests/language_models/llms/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/llms/test_cache.py @@ -1,5 +1,7 @@ from typing import Any, Optional +from typing_extensions import override + from langchain_core.caches import RETURN_VAL_TYPE, BaseCache from langchain_core.globals import set_llm_cache from langchain_core.language_models import FakeListLLM @@ -20,6 +22,7 @@ class InMemoryCache(BaseCache): """Update cache based on prompt and llm_string.""" self._cache[(prompt, llm_string)] = return_val + @override def clear(self, **kwargs: Any) -> None: """Clear cache.""" self._cache = {} @@ -74,6 +77,7 @@ class InMemoryCacheBad(BaseCache): msg = "This code should not be triggered" raise NotImplementedError(msg) + @override def clear(self, **kwargs: Any) -> None: """Clear cache.""" self._cache = {} diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index bef9f30878c..b2bf829acf4 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -6,6 +6,7 @@ from collections.abc import Sequence from typing import Any, Callable, Optional, Union import pytest +from typing_extensions import override from langchain_core.language_models.fake_chat_models import FakeChatModel from langchain_core.messages import ( @@ -660,6 +661,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None: class FakeTokenCountingModel(FakeChatModel): + @override def get_num_tokens_from_messages( self, messages: list[BaseMessage], diff --git a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py index aec31c98963..d5c7efe3983 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py +++ b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py @@ -1,5 +1,7 @@ """Module to test base parser implementations.""" +from typing_extensions import override + from langchain_core.exceptions import OutputParserException from langchain_core.language_models import GenericFakeChatModel from langchain_core.messages import AIMessage @@ -16,6 +18,7 @@ def test_base_generation_parser() -> None: class StrInvertCase(BaseGenerationOutputParser[str]): """An example parser that inverts the case of the characters in the message.""" + @override def parse_result( self, result: list[Generation], *, partial: bool = False ) -> str: @@ -59,6 +62,7 @@ def test_base_transform_output_parser() -> None: """Parse a single string into a specific format.""" raise NotImplementedError + @override def parse_result( self, result: list[Generation], *, partial: bool = False ) -> str: diff --git a/libs/core/tests/unit_tests/prompts/test_few_shot.py b/libs/core/tests/unit_tests/prompts/test_few_shot.py index 4bce10ca4f7..ded570b77da 100644 --- a/libs/core/tests/unit_tests/prompts/test_few_shot.py +++ b/libs/core/tests/unit_tests/prompts/test_few_shot.py @@ -5,6 +5,7 @@ from collections.abc import Sequence from typing import Any import pytest +from typing_extensions import override from langchain_core.example_selectors import BaseExampleSelector from langchain_core.messages import AIMessage, HumanMessage, SystemMessage @@ -383,6 +384,7 @@ class AsIsSelector(BaseExampleSelector): def add_example(self, example: dict[str, str]) -> Any: raise NotImplementedError + @override def select_examples(self, input_variables: dict[str, str]) -> list[dict]: return list(self.examples) @@ -481,6 +483,7 @@ class AsyncAsIsSelector(BaseExampleSelector): def select_examples(self, input_variables: dict[str, str]) -> list[dict]: raise NotImplementedError + @override async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]: return list(self.examples) diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index e194ac5831f..6e69936223a 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -14,7 +14,7 @@ from langchain_core.utils.pydantic import is_basemodel_subclass def _fake_runnable( - input: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_: Any + _: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any ) -> Union[BaseModel, dict]: if isclass(schema) and is_basemodel_subclass(schema): return schema(name="yo", value=value) diff --git a/libs/core/tests/unit_tests/runnables/test_configurable.py b/libs/core/tests/unit_tests/runnables/test_configurable.py index 8b2557d1a97..0c1b3b6b206 100644 --- a/libs/core/tests/unit_tests/runnables/test_configurable.py +++ b/libs/core/tests/unit_tests/runnables/test_configurable.py @@ -2,7 +2,7 @@ from typing import Any, Optional import pytest from pydantic import ConfigDict, Field, model_validator -from typing_extensions import Self +from typing_extensions import Self, override from langchain_core.runnables import ( ConfigurableField, @@ -32,6 +32,7 @@ class MyRunnable(RunnableSerializable[str, str]): self._my_hidden_property = self.my_property return self + @override def invoke( self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: @@ -41,12 +42,15 @@ class MyRunnable(RunnableSerializable[str, str]): return self.my_property def my_custom_function_w_config( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, # noqa: ARG002 ) -> str: return self.my_property def my_custom_function_w_kw_config( - self, *, config: Optional[RunnableConfig] = None + self, + *, + config: Optional[RunnableConfig] = None, # noqa: ARG002 ) -> str: return self.my_property @@ -54,6 +58,7 @@ class MyRunnable(RunnableSerializable[str, str]): class MyOtherRunnable(RunnableSerializable[str, str]): my_other_property: str + @override def invoke( self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: @@ -62,7 +67,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]): def my_other_custom_function(self) -> str: return self.my_other_property - def my_other_custom_function_w_config(self, config: RunnableConfig) -> str: + def my_other_custom_function_w_config(self, config: RunnableConfig) -> str: # noqa: ARG002 return self.my_other_property diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py index bcffca1667e..b638ac78f57 100644 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -25,7 +25,7 @@ def seq_naive_rag() -> Runnable: "What's your name?", ] - retriever = RunnableLambda(lambda x: context) + retriever = RunnableLambda(lambda _: context) prompt = PromptTemplate.from_template("{context} {question}") llm = FakeListLLM(responses=["hello"]) @@ -53,7 +53,7 @@ def seq_naive_rag_alt() -> Runnable: "What's your name?", ] - retriever = RunnableLambda(lambda x: context) + retriever = RunnableLambda(lambda _: context) prompt = PromptTemplate.from_template("{context} {question}") llm = FakeListLLM(responses=["hello"]) @@ -78,7 +78,7 @@ def seq_naive_rag_scoped() -> Runnable: "What's your name?", ] - retriever = RunnableLambda(lambda x: context) + retriever = RunnableLambda(lambda _: context) prompt = PromptTemplate.from_template("{context} {question}") llm = FakeListLLM(responses=["hello"]) diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index f28965fcf58..a3dbc74eee9 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -9,6 +9,7 @@ from typing import ( import pytest from pydantic import BaseModel from syrupy import SnapshotAssertion +from typing_extensions import override from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import ( @@ -60,7 +61,7 @@ def chain() -> Runnable: ) -def _raise_error(inputs: dict) -> str: +def _raise_error(_: dict) -> str: raise ValueError @@ -259,17 +260,17 @@ async def test_abatch() -> None: _assert_potential_error(actual, expected) -def _generate(input: Iterator) -> Iterator[str]: +def _generate(_: Iterator) -> Iterator[str]: yield from "foo bar" -def _generate_immediate_error(input: Iterator) -> Iterator[str]: +def _generate_immediate_error(_: Iterator) -> Iterator[str]: msg = "immmediate error" raise ValueError(msg) yield "" -def _generate_delayed_error(input: Iterator) -> Iterator[str]: +def _generate_delayed_error(_: Iterator) -> Iterator[str]: yield "" msg = "delayed error" raise ValueError(msg) @@ -288,18 +289,18 @@ def test_fallbacks_stream() -> None: list(runnable.stream({})) -async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]: +async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]: for c in "foo bar": yield c -async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]: +async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]: msg = "immmediate error" raise ValueError(msg) yield "" -async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]: +async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]: yield "" msg = "delayed error" raise ValueError(msg) @@ -323,6 +324,7 @@ async def test_fallbacks_astream() -> None: class FakeStructuredOutputModel(BaseChatModel): foo: int + @override def _generate( self, messages: list[BaseMessage], @@ -333,6 +335,7 @@ class FakeStructuredOutputModel(BaseChatModel): """Top Level call.""" return ChatResult(generations=[]) + @override def bind_tools( self, tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], @@ -340,10 +343,11 @@ class FakeStructuredOutputModel(BaseChatModel): ) -> Runnable[LanguageModelInput, BaseMessage]: return self.bind(tools=tools) + @override def with_structured_output( self, schema: Union[dict, type[BaseModel]], **kwargs: Any ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: - return RunnableLambda(lambda x: {"foo": self.foo}) + return RunnableLambda(lambda _: {"foo": self.foo}) @property def _llm_type(self) -> str: @@ -353,6 +357,7 @@ class FakeStructuredOutputModel(BaseChatModel): class FakeModel(BaseChatModel): bar: int + @override def _generate( self, messages: list[BaseMessage], @@ -363,6 +368,7 @@ class FakeModel(BaseChatModel): """Top Level call.""" return ChatResult(generations=[]) + @override def bind_tools( self, tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 35f6c2b3801..2a5f2a0a84f 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Union import pytest from packaging import version from pydantic import BaseModel +from typing_extensions import override from langchain_core.callbacks import ( CallbackManagerForLLMRun, @@ -39,7 +40,7 @@ def _get_get_session_history( chat_history_store = store if store is not None else {} def get_session_history( - session_id: str, **kwargs: Any + session_id: str, **_kwargs: Any ) -> InMemoryChatMessageHistory: if session_id not in chat_history_store: chat_history_store[session_id] = InMemoryChatMessageHistory() @@ -253,6 +254,7 @@ async def test_output_message_async() -> None: class LengthChatModel(BaseChatModel): """A fake chat model that returns the length of the messages passed in.""" + @override def _generate( self, messages: list[BaseMessage], @@ -856,7 +858,7 @@ def test_get_output_messages_no_value_error() -> None: def test_get_output_messages_with_value_error() -> None: illegal_bool_message = False - runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message) + runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message) store: dict = {} get_session_history = _get_get_session_history(store=store) with_history = RunnableWithMessageHistory(runnable, get_session_history) @@ -874,7 +876,7 @@ def test_get_output_messages_with_value_error() -> None: with_history.bound.invoke([HumanMessage(content="hello")], config) illegal_int_message = 123 - runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message) + runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message) with_history = RunnableWithMessageHistory(runnable, get_session_history) with pytest.raises( diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 4facc2979ab..1d4c3d95902 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -21,10 +21,11 @@ from packaging import version from pydantic import BaseModel, Field from pytest_mock import MockerFixture from syrupy import SnapshotAssertion -from typing_extensions import TypedDict +from typing_extensions import TypedDict, override from langchain_core.callbacks.manager import ( - Callbacks, + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, atrace_as_chain_group, trace_as_chain_group, ) @@ -184,6 +185,7 @@ class FakeTracer(BaseTracer): class FakeRunnable(Runnable[str, int]): + @override def invoke( self, input: str, @@ -196,6 +198,7 @@ class FakeRunnable(Runnable[str, int]): class FakeRunnableSerializable(RunnableSerializable[str, int]): hello: str = "" + @override def invoke( self, input: str, @@ -206,25 +209,15 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]): class FakeRetriever(BaseRetriever): + @override def _get_relevant_documents( - self, - query: str, - *, - callbacks: Callbacks = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> list[Document]: return [Document(page_content="foo"), Document(page_content="bar")] + @override async def _aget_relevant_documents( - self, - query: str, - *, - callbacks: Callbacks = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> list[Document]: return [Document(page_content="foo"), Document(page_content="bar")] @@ -506,7 +499,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: foo_ = RunnableLambda(foo) - assert foo_.assign(bar=lambda x: "foo").get_output_schema().model_json_schema() == { + assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == { "properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}}, "required": ["root", "bar"], "title": "RunnableAssignOutput", @@ -1782,10 +1775,10 @@ def test_with_listener_propagation(mocker: MockerFixture) -> None: @freeze_time("2023-01-01") +@pytest.mark.usefixtures("deterministic_uuids") def test_prompt_with_chat_model( mocker: MockerFixture, snapshot: SnapshotAssertion, - deterministic_uuids: MockerFixture, ) -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -1888,10 +1881,10 @@ def test_prompt_with_chat_model( @freeze_time("2023-01-01") +@pytest.mark.usefixtures("deterministic_uuids") async def test_prompt_with_chat_model_async( mocker: MockerFixture, snapshot: SnapshotAssertion, - deterministic_uuids: MockerFixture, ) -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -2519,7 +2512,7 @@ async def test_stream_log_retriever() -> None: @freeze_time("2023-01-01") async def test_stream_log_lists() -> None: - async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict]: + async def list_producer(_: AsyncIterator[Any]) -> AsyncIterator[AddableDict]: for i in range(4): yield AddableDict(alist=[str(i)]) @@ -2631,10 +2624,10 @@ async def test_prompt_with_llm_and_async_lambda( @freeze_time("2023-01-01") +@pytest.mark.usefixtures("deterministic_uuids") def test_prompt_with_chat_model_and_parser( mocker: MockerFixture, snapshot: SnapshotAssertion, - deterministic_uuids: MockerFixture, ) -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -2672,10 +2665,9 @@ def test_prompt_with_chat_model_and_parser( @freeze_time("2023-01-01") +@pytest.mark.usefixtures("deterministic_uuids") def test_combining_sequences( - mocker: MockerFixture, snapshot: SnapshotAssertion, - deterministic_uuids: MockerFixture, ) -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -3513,7 +3505,7 @@ def test_bind_bind() -> None: def test_bind_with_lambda() -> None: - def my_function(*args: Any, **kwargs: Any) -> int: + def my_function(_: Any, **kwargs: Any) -> int: return 3 + kwargs.get("n", 0) runnable = RunnableLambda(my_function).bind(n=1) @@ -3523,7 +3515,7 @@ def test_bind_with_lambda() -> None: async def test_bind_with_lambda_async() -> None: - def my_function(*args: Any, **kwargs: Any) -> int: + def my_function(_: Any, **kwargs: Any) -> int: return 3 + kwargs.get("n", 0) runnable = RunnableLambda(my_function).bind(n=1) @@ -3858,7 +3850,7 @@ def test_each(snapshot: SnapshotAssertion) -> None: def test_recursive_lambda() -> None: def _simple_recursion(x: int) -> Union[int, Runnable]: if x < 10: - return RunnableLambda(lambda *args: _simple_recursion(x + 1)) + return RunnableLambda(lambda *_: _simple_recursion(x + 1)) return x runnable = RunnableLambda(_simple_recursion) @@ -4008,7 +4000,7 @@ def test_runnable_lambda_stream() -> None: # sleep to better simulate a real stream llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) - output = list(RunnableLambda(lambda x: llm).stream("")) + output = list(RunnableLambda(lambda _: llm).stream("")) assert output == list(llm_res) @@ -4021,7 +4013,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None: llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) config: RunnableConfig = {"callbacks": [tracer]} - assert list(RunnableLambda(lambda x: llm).stream("", config=config)) == list( + assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list( llm_res ) @@ -4029,7 +4021,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None: assert tracer.runs[0].error is None assert tracer.runs[0].outputs == {"output": llm_res} - def raise_value_error(x: int) -> int: + def raise_value_error(_: int) -> int: """Raise a value error.""" msg = "x is too large" raise ValueError(msg) @@ -4076,7 +4068,7 @@ async def test_runnable_lambda_astream() -> None: _ async for _ in RunnableLambda( func=id, - afunc=awrapper(lambda x: llm), + afunc=awrapper(lambda _: llm), ).astream("") ] assert output == list(llm_res) @@ -4084,7 +4076,7 @@ async def test_runnable_lambda_astream() -> None: output = [ chunk async for chunk in cast( - "AsyncIterator[str]", RunnableLambda(lambda x: llm).astream("") + "AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("") ) ] assert output == list(llm_res) @@ -4100,14 +4092,14 @@ async def test_runnable_lambda_astream_with_callbacks() -> None: config: RunnableConfig = {"callbacks": [tracer]} assert [ - _ async for _ in RunnableLambda(lambda x: llm).astream("", config=config) + _ async for _ in RunnableLambda(lambda _: llm).astream("", config=config) ] == list(llm_res) assert len(tracer.runs) == 1 assert tracer.runs[0].error is None assert tracer.runs[0].outputs == {"output": llm_res} - def raise_value_error(x: int) -> int: + def raise_value_error(_: int) -> int: """Raise a value error.""" msg = "x is too large" raise ValueError(msg) @@ -4487,7 +4479,7 @@ def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None: def test_runnable_branch_invoke() -> None: # Test with single branch - def raise_value_error(x: int) -> int: + def raise_value_error(_: int) -> int: """Raise a value error.""" msg = "x is too large" raise ValueError(msg) @@ -4552,7 +4544,7 @@ def test_runnable_branch_invoke_callbacks() -> None: """Verify that callbacks are correctly used in invoke.""" tracer = FakeTracer() - def raise_value_error(x: int) -> int: + def raise_value_error(_: int) -> int: """Raise a value error.""" msg = "x is too large" raise ValueError(msg) @@ -4580,7 +4572,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None: """Verify that callbacks are invoked correctly in ainvoke.""" tracer = FakeTracer() - async def raise_value_error(x: int) -> int: + async def raise_value_error(_: int) -> int: """Raise a value error.""" msg = "x is too large" raise ValueError(msg) @@ -4755,13 +4747,13 @@ def test_representation_of_runnables() -> None: runnable = RunnableLambda(lambda x: x * 2) assert repr(runnable) == "RunnableLambda(lambda x: x * 2)" - def f(x: int) -> int: + def f(_: int) -> int: """Return 2.""" return 2 assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)" - async def af(x: int) -> int: + async def af(_: int) -> int: """Return 2.""" return 2 @@ -4814,7 +4806,7 @@ async def test_tool_from_runnable() -> None: def test_runnable_gen() -> None: """Test that a generator can be used as a runnable.""" - def gen(input: Iterator[Any]) -> Iterator[int]: + def gen(_: Iterator[Any]) -> Iterator[int]: yield 1 yield 2 yield 3 @@ -4835,7 +4827,7 @@ def test_runnable_gen() -> None: async def test_runnable_gen_async() -> None: """Test that a generator can be used as a runnable.""" - async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: + async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]: yield 1 yield 2 yield 3 @@ -4847,7 +4839,7 @@ async def test_runnable_gen_async() -> None: assert await arunnable.abatch([None, None]) == [6, 6] class AsyncGen: - async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]: + async def __call__(self, _: AsyncIterator[Any]) -> AsyncIterator[int]: yield 1 yield 2 yield 3 @@ -4870,7 +4862,7 @@ def test_runnable_gen_context_config() -> None: """ fake = RunnableLambda(len) - def gen(input: Iterator[Any]) -> Iterator[int]: + def gen(_: Iterator[Any]) -> Iterator[int]: yield fake.invoke("a") yield fake.invoke("aa") yield fake.invoke("aaa") @@ -4944,7 +4936,7 @@ async def test_runnable_gen_context_config_async() -> None: fake = RunnableLambda(len) - async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: + async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]: yield await fake.ainvoke("a") yield await fake.ainvoke("aa") yield await fake.ainvoke("aaa") @@ -5441,6 +5433,7 @@ def test_default_transform_with_dicts() -> None: """Test that default transform works with dicts.""" class CustomRunnable(RunnableSerializable[Input, Output]): + @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -5462,6 +5455,7 @@ async def test_default_atransform_with_dicts() -> None: """Test that default transform works with dicts.""" class CustomRunnable(RunnableSerializable[Input, Output]): + @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -5581,6 +5575,7 @@ def test_closing_iterator_doesnt_raise_error() -> None: on_chain_end_triggered = False class MyHandler(BaseCallbackHandler): + @override def on_chain_error( self, error: BaseException, @@ -5594,6 +5589,7 @@ def test_closing_iterator_doesnt_raise_error() -> None: nonlocal on_chain_error_triggered on_chain_error_triggered = True + @override def on_chain_end( self, outputs: dict[str, Any], diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 9c457b24915..d099d35c978 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -8,6 +8,7 @@ from typing import Any, cast import pytest from pydantic import BaseModel +from typing_extensions import override from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks from langchain_core.chat_history import BaseChatMessageHistory @@ -82,12 +83,12 @@ def _assert_events_equal_allow_superset_metadata(events: list, expected: list) - async def test_event_stream_with_simple_function_tool() -> None: """Test the event stream with a function and tool.""" - def foo(x: int) -> dict: + def foo(_: int) -> dict: """Foo.""" return {"x": 5} @tool - def get_docs(x: int) -> list[Document]: + def get_docs(x: int) -> list[Document]: # noqa: ARG001 """Hello Doc.""" return [Document(page_content="hello")] @@ -434,7 +435,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None: async def test_event_stream_with_lambdas_from_lambda() -> None: - as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config( + as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config( {"run_name": "my_lambda"} ) events = await _collect_events( @@ -1021,7 +1022,7 @@ async def test_event_streaming_with_tools() -> None: return "hello" @tool - def with_callbacks(callbacks: Callbacks) -> str: + def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001 """A tool that does nothing.""" return "world" @@ -1031,7 +1032,7 @@ async def test_event_streaming_with_tools() -> None: return {"x": x, "y": y} @tool - def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: + def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001 """A tool that does nothing.""" return {"x": x, "y": y} @@ -1180,6 +1181,7 @@ async def test_event_streaming_with_tools() -> None: class HardCodedRetriever(BaseRetriever): documents: list[Document] + @override def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> list[Document]: @@ -1592,10 +1594,10 @@ async def test_chain_ordering() -> None: async def test_event_stream_with_retry() -> None: """Test the event stream with a tool.""" - def success(inputs: str) -> str: + def success(_: str) -> str: return "success" - def fail(inputs: str) -> None: + def fail(_: str) -> None: """Simple func.""" msg = "fail" raise ValueError(msg) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 6f82a7a44e9..f20a06b932c 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -17,6 +17,7 @@ from typing import ( import pytest from blockbuster import BlockBuster from pydantic import BaseModel +from typing_extensions import override from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks from langchain_core.chat_history import BaseChatMessageHistory @@ -97,12 +98,12 @@ async def _collect_events( async def test_event_stream_with_simple_function_tool() -> None: """Test the event stream with a function and tool.""" - def foo(x: int) -> dict: + def foo(x: int) -> dict: # noqa: ARG001 """Foo.""" return {"x": 5} @tool - def get_docs(x: int) -> list[Document]: + def get_docs(x: int) -> list[Document]: # noqa: ARG001 """Hello Doc.""" return [Document(page_content="hello")] @@ -465,7 +466,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None: async def test_event_stream_with_lambdas_from_lambda() -> None: - as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config( + as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config( {"run_name": "my_lambda"} ) events = await _collect_events( @@ -1043,7 +1044,7 @@ async def test_event_streaming_with_tools() -> None: return "hello" @tool - def with_callbacks(callbacks: Callbacks) -> str: + def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001 """A tool that does nothing.""" return "world" @@ -1053,7 +1054,7 @@ async def test_event_streaming_with_tools() -> None: return {"x": x, "y": y} @tool - def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: + def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001 """A tool that does nothing.""" return {"x": x, "y": y} @@ -1165,6 +1166,7 @@ async def test_event_streaming_with_tools() -> None: class HardCodedRetriever(BaseRetriever): documents: list[Document] + @override def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> list[Document]: @@ -1553,10 +1555,10 @@ async def test_chain_ordering() -> None: async def test_event_stream_with_retry() -> None: """Test the event stream with a tool.""" - def success(inputs: str) -> str: + def success(_: str) -> str: return "success" - def fail(inputs: str) -> None: + def fail(_: str) -> None: """Simple func.""" msg = "fail" raise ValueError(msg) @@ -2069,6 +2071,7 @@ class StreamingRunnable(Runnable[Input, Output]): """Initialize the runnable.""" self.iterable = iterable + @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -2084,6 +2087,7 @@ class StreamingRunnable(Runnable[Input, Output]): ) -> Iterator[Output]: raise NotImplementedError + @override async def astream( self, input: Input, @@ -2323,7 +2327,7 @@ async def test_bad_parent_ids() -> None: async def test_runnable_generator() -> None: """Test async events from sync lambda.""" - async def generator(inputs: AsyncIterator[str]) -> AsyncIterator[str]: + async def generator(_: AsyncIterator[str]) -> AsyncIterator[str]: yield "1" yield "2" diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 31ad313827f..a681fd8a12d 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -375,7 +375,7 @@ class TestRunnableSequenceParallelTraceNesting: def test_sync( self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int] ) -> None: - def other_thing(a: int) -> Generator[int, None, None]: # type: ignore + def other_thing(_: int) -> Generator[int, None, None]: # type: ignore yield 1 parent = self._create_parent(other_thing) @@ -407,7 +407,7 @@ class TestRunnableSequenceParallelTraceNesting: [RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int] ], ) -> None: - async def other_thing(a: int) -> AsyncGenerator[int, None]: + async def other_thing(_: int) -> AsyncGenerator[int, None]: yield 1 parent = self._create_parent(other_thing) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 64368de0dda..f21bf705872 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2301,7 +2301,7 @@ def test_injected_arg_with_complex_type() -> None: self.value = "bar" @tool - def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: + def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: # noqa: ARG001 """Tool that has an injected tool arg.""" return foo.value @@ -2477,7 +2477,7 @@ def test_simple_tool_args_schema_dict() -> None: def test_empty_string_tool_call_id() -> None: @tool - def foo(x: int) -> str: + def foo(x: int) -> str: # noqa: ARG001 """Foo.""" return "hi" @@ -2489,7 +2489,7 @@ def test_empty_string_tool_call_id() -> None: def test_tool_decorator_description() -> None: # test basic tool @tool - def foo(x: int) -> str: + def foo(x: int) -> str: # noqa: ARG001 """Foo.""" return "hi" @@ -2501,7 +2501,7 @@ def test_tool_decorator_description() -> None: # test basic tool with description @tool(description="description") - def foo_description(x: int) -> str: + def foo_description(x: int) -> str: # noqa: ARG001 """Foo.""" return "hi" @@ -2520,7 +2520,7 @@ def test_tool_decorator_description() -> None: x: int @tool(args_schema=ArgsSchema) - def foo_args_schema(x: int) -> str: + def foo_args_schema(x: int) -> str: # noqa: ARG001 return "hi" assert foo_args_schema.description == "Bar." @@ -2532,7 +2532,7 @@ def test_tool_decorator_description() -> None: ) @tool(description="description", args_schema=ArgsSchema) - def foo_args_schema_description(x: int) -> str: + def foo_args_schema_description(x: int) -> str: # noqa: ARG001 return "hi" assert foo_args_schema_description.description == "description" @@ -2554,11 +2554,11 @@ def test_tool_decorator_description() -> None: } @tool(args_schema=args_json_schema) - def foo_args_jsons_schema(x: int) -> str: + def foo_args_jsons_schema(x: int) -> str: # noqa: ARG001 return "hi" @tool(description="description", args_schema=args_json_schema) - def foo_args_jsons_schema_with_description(x: int) -> str: + def foo_args_jsons_schema_with_description(x: int) -> str: # noqa: ARG001 return "hi" assert foo_args_jsons_schema.description == "JSON Schema." @@ -2620,10 +2620,10 @@ def test_title_property_preserved() -> None: async def test_tool_ainvoke_does_not_mutate_inputs() -> None: """Verify that the inputs are not mutated when invoking a tool asynchronously.""" - def sync_no_op(foo: int) -> str: + def sync_no_op(foo: int) -> str: # noqa: ARG001 return "good" - async def async_no_op(foo: int) -> str: + async def async_no_op(foo: int) -> str: # noqa: ARG001 return "good" tool = StructuredTool( @@ -2668,10 +2668,10 @@ async def test_tool_ainvoke_does_not_mutate_inputs() -> None: def test_tool_invoke_does_not_mutate_inputs() -> None: """Verify that the inputs are not mutated when invoking a tool synchronously.""" - def sync_no_op(foo: int) -> str: + def sync_no_op(foo: int) -> str: # noqa: ARG001 return "good" - async def async_no_op(foo: int) -> str: + async def async_no_op(foo: int) -> str: # noqa: ARG001 return "good" tool = StructuredTool( diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 5562ee5da19..7e6f41992cb 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -121,7 +121,7 @@ def dummy_structured_tool() -> StructuredTool: arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") return StructuredTool.from_function( - lambda x: None, + lambda _: None, name="dummy_function", description="Dummy function.", args_schema=Schema, @@ -143,7 +143,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool: "required": ["arg1", "arg2"], } return StructuredTool.from_function( - lambda x: None, + lambda _: None, name="dummy_function", description="Dummy function.", args_schema=args_schema, diff --git a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py index 2ebee79cf9d..c8e6df87ac9 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py +++ b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py @@ -10,6 +10,7 @@ import uuid from typing import TYPE_CHECKING, Any, Optional import pytest +from typing_extensions import override from langchain_core.documents import Document from langchain_core.embeddings import Embeddings, FakeEmbeddings @@ -25,6 +26,7 @@ class CustomAddTextsVectorstore(VectorStore): def __init__(self) -> None: self.store: dict[str, Document] = {} + @override def add_texts( self, texts: Iterable[str], @@ -51,6 +53,7 @@ class CustomAddTextsVectorstore(VectorStore): return [self.store[id] for id in ids if id in self.store] @classmethod + @override def from_texts( # type: ignore cls, texts: list[str], @@ -74,6 +77,7 @@ class CustomAddDocumentsVectorstore(VectorStore): def __init__(self) -> None: self.store: dict[str, Document] = {} + @override def add_documents( self, documents: list[Document], @@ -95,6 +99,7 @@ class CustomAddDocumentsVectorstore(VectorStore): return [self.store[id] for id in ids if id in self.store] @classmethod + @override def from_texts( # type: ignore cls, texts: list[str],