core: Add ruff rules ARG (#30732)

See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg
This commit is contained in:
Christophe Bornet
2025-04-09 20:39:36 +02:00
committed by GitHub
parent 66758599a9
commit 98f0016fc2
58 changed files with 328 additions and 180 deletions

View File

@@ -10,6 +10,8 @@ from contextlib import asynccontextmanager
from typing import Any, Optional
from uuid import UUID
from typing_extensions import override
from langchain_core.callbacks import (
AsyncCallbackHandler,
AsyncCallbackManager,
@@ -45,6 +47,7 @@ async def test_inline_handlers_share_parent_context() -> None:
"""Initialize the handler."""
self.run_inline = run_inline
@override
async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
"""Update the callstack with the name of the callback."""
some_var.set("on_llm_start")

View File

@@ -74,6 +74,7 @@ async def test_async_custom_event_implicit_config() -> None:
# a decorator for async functions
@RunnableLambda # type: ignore[arg-type]
async def foo(x: int, config: RunnableConfig) -> int:
assert "callbacks" in config
await adispatch_custom_event("event1", {"x": x})
await adispatch_custom_event("event2", {"x": x})
return x

View File

@@ -3,6 +3,7 @@
from collections.abc import Iterator
import pytest
from typing_extensions import override
from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader
from langchain_core.documents import Document
@@ -15,6 +16,7 @@ def test_base_blob_parser() -> None:
class MyParser(BaseBlobParser):
"""A simple parser that returns a single document."""
@override
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazy parsing interface."""
yield Document(

View File

@@ -1,6 +1,8 @@
from collections.abc import Iterable
from typing import Any, Optional, cast
from typing_extensions import override
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings, FakeEmbeddings
from langchain_core.example_selectors import (
@@ -21,6 +23,7 @@ class DummyVectorStore(VectorStore):
def embeddings(self) -> Optional[Embeddings]:
return self._embeddings
@override
def add_texts(
self,
texts: Iterable[str],
@@ -32,6 +35,7 @@ class DummyVectorStore(VectorStore):
self.metadatas.extend(metadatas)
return ["dummy_id"]
@override
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
@@ -41,6 +45,7 @@ class DummyVectorStore(VectorStore):
)
] * k
@override
def max_marginal_relevance_search(
self,
query: str,

View File

@@ -5,6 +5,7 @@ from typing import Any, Optional, Union
from uuid import UUID
from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
@@ -138,6 +139,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
@override
def on_llm_start(
self,
*args: Any,
@@ -145,6 +147,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_start_common()
@override
def on_llm_new_token(
self,
*args: Any,
@@ -152,6 +155,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_new_token_common()
@override
def on_llm_end(
self,
*args: Any,
@@ -159,6 +163,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_end_common()
@override
def on_llm_error(
self,
*args: Any,
@@ -166,6 +171,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_error_common(*args, **kwargs)
@override
def on_retry(
self,
*args: Any,
@@ -173,6 +179,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retry_common()
@override
def on_chain_start(
self,
*args: Any,
@@ -180,6 +187,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_chain_start_common()
@override
def on_chain_end(
self,
*args: Any,
@@ -187,6 +195,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_chain_end_common()
@override
def on_chain_error(
self,
*args: Any,
@@ -194,6 +203,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_chain_error_common()
@override
def on_tool_start(
self,
*args: Any,
@@ -201,6 +211,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_tool_start_common()
@override
def on_tool_end(
self,
*args: Any,
@@ -208,6 +219,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_tool_end_common()
@override
def on_tool_error(
self,
*args: Any,
@@ -215,6 +227,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_tool_error_common()
@override
def on_agent_action(
self,
*args: Any,
@@ -222,6 +235,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_agent_action_common()
@override
def on_agent_finish(
self,
*args: Any,
@@ -229,6 +243,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_agent_finish_common()
@override
def on_text(
self,
*args: Any,
@@ -236,6 +251,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_text_common()
@override
def on_retriever_start(
self,
*args: Any,
@@ -243,6 +259,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_start_common()
@override
def on_retriever_end(
self,
*args: Any,
@@ -250,6 +267,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_end_common()
@override
def on_retriever_error(
self,
*args: Any,
@@ -263,6 +281,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
@override
def on_chat_model_start(
self,
serialized: dict[str, Any],
@@ -294,6 +313,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@override
async def on_retry(
self,
*args: Any,
@@ -301,6 +321,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> Any:
self.on_retry_common()
@override
async def on_llm_start(
self,
*args: Any,
@@ -308,6 +329,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_start_common()
@override
async def on_llm_new_token(
self,
*args: Any,
@@ -315,6 +337,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_new_token_common()
@override
async def on_llm_end(
self,
*args: Any,
@@ -322,6 +345,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_end_common()
@override
async def on_llm_error(
self,
*args: Any,
@@ -329,6 +353,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_error_common(*args, **kwargs)
@override
async def on_chain_start(
self,
*args: Any,
@@ -336,6 +361,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_chain_start_common()
@override
async def on_chain_end(
self,
*args: Any,
@@ -343,6 +369,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_chain_end_common()
@override
async def on_chain_error(
self,
*args: Any,
@@ -350,6 +377,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_chain_error_common()
@override
async def on_tool_start(
self,
*args: Any,
@@ -357,6 +385,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_tool_start_common()
@override
async def on_tool_end(
self,
*args: Any,
@@ -364,6 +393,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_tool_end_common()
@override
async def on_tool_error(
self,
*args: Any,
@@ -371,6 +401,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_tool_error_common()
@override
async def on_agent_action(
self,
*args: Any,
@@ -378,6 +409,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_agent_action_common()
@override
async def on_agent_finish(
self,
*args: Any,
@@ -385,6 +417,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_agent_finish_common()
@override
async def on_text(
self,
*args: Any,

View File

@@ -4,6 +4,8 @@ from itertools import cycle
from typing import Any, Optional, Union
from uuid import UUID
from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import (
FakeListChatModel,
@@ -171,6 +173,7 @@ async def test_callback_handlers() -> None:
# Required to implement since this is an abstract method
pass
@override
async def on_llm_new_token(
self,
token: str,

View File

@@ -5,6 +5,7 @@ from collections.abc import AsyncIterator, Iterator
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import pytest
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, FakeListChatModel
@@ -138,6 +139,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
"""Test astream uses appropriate implementation."""
class ModelWithGenerate(BaseChatModel):
@override
def _generate(
self,
messages: list[BaseMessage],
@@ -176,6 +178,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
"""Top Level call."""
raise NotImplementedError
@override
def _stream(
self,
messages: list[BaseMessage],
@@ -221,6 +224,7 @@ async def test_astream_implementation_uses_astream() -> None:
"""Top Level call."""
raise NotImplementedError
@override
async def _astream( # type: ignore
self,
messages: list[BaseMessage],
@@ -286,6 +290,7 @@ async def test_async_pass_run_id() -> None:
class NoStreamingModel(BaseChatModel):
@override
def _generate(
self,
messages: list[BaseMessage],
@@ -301,6 +306,7 @@ class NoStreamingModel(BaseChatModel):
class StreamingModel(NoStreamingModel):
@override
def _stream(
self,
messages: list[BaseMessage],

View File

@@ -3,6 +3,7 @@
from typing import Any, Optional
import pytest
from typing_extensions import override
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.globals import set_llm_cache
@@ -30,6 +31,7 @@ class InMemoryCache(BaseCache):
"""Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val
@override
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}

View File

@@ -2,6 +2,7 @@ from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional
import pytest
from typing_extensions import override
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@@ -106,6 +107,7 @@ async def test_error_callback() -> None:
"""Return type of llm."""
return "failing-llm"
@override
def _call(
self,
prompt: str,
@@ -136,6 +138,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
"""Test astream uses appropriate implementation."""
class ModelWithGenerate(BaseLLM):
@override
def _generate(
self,
prompts: list[str],
@@ -172,6 +175,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
"""Top Level call."""
raise NotImplementedError
@override
def _stream(
self,
prompt: str,
@@ -209,6 +213,7 @@ async def test_astream_implementation_uses_astream() -> None:
"""Top Level call."""
raise NotImplementedError
@override
async def _astream(
self,
prompt: str,

View File

@@ -1,5 +1,7 @@
from typing import Any, Optional
from typing_extensions import override
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.globals import set_llm_cache
from langchain_core.language_models import FakeListLLM
@@ -20,6 +22,7 @@ class InMemoryCache(BaseCache):
"""Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val
@override
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}
@@ -74,6 +77,7 @@ class InMemoryCacheBad(BaseCache):
msg = "This code should not be triggered"
raise NotImplementedError(msg)
@override
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}

View File

@@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import pytest
from typing_extensions import override
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
@@ -660,6 +661,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
class FakeTokenCountingModel(FakeChatModel):
@override
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],

View File

@@ -1,5 +1,7 @@
"""Module to test base parser implementations."""
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
@@ -16,6 +18,7 @@ def test_base_generation_parser() -> None:
class StrInvertCase(BaseGenerationOutputParser[str]):
"""An example parser that inverts the case of the characters in the message."""
@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
) -> str:
@@ -59,6 +62,7 @@ def test_base_transform_output_parser() -> None:
"""Parse a single string into a specific format."""
raise NotImplementedError
@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
) -> str:

View File

@@ -5,6 +5,7 @@ from collections.abc import Sequence
from typing import Any
import pytest
from typing_extensions import override
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@@ -383,6 +384,7 @@ class AsIsSelector(BaseExampleSelector):
def add_example(self, example: dict[str, str]) -> Any:
raise NotImplementedError
@override
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
return list(self.examples)
@@ -481,6 +483,7 @@ class AsyncAsIsSelector(BaseExampleSelector):
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
raise NotImplementedError
@override
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
return list(self.examples)

View File

@@ -14,7 +14,7 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
def _fake_runnable(
input: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_: Any
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
) -> Union[BaseModel, dict]:
if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=value)

View File

@@ -2,7 +2,7 @@ from typing import Any, Optional
import pytest
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self
from typing_extensions import Self, override
from langchain_core.runnables import (
ConfigurableField,
@@ -32,6 +32,7 @@ class MyRunnable(RunnableSerializable[str, str]):
self._my_hidden_property = self.my_property
return self
@override
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
@@ -41,12 +42,15 @@ class MyRunnable(RunnableSerializable[str, str]):
return self.my_property
def my_custom_function_w_config(
self, config: Optional[RunnableConfig] = None
self,
config: Optional[RunnableConfig] = None, # noqa: ARG002
) -> str:
return self.my_property
def my_custom_function_w_kw_config(
self, *, config: Optional[RunnableConfig] = None
self,
*,
config: Optional[RunnableConfig] = None, # noqa: ARG002
) -> str:
return self.my_property
@@ -54,6 +58,7 @@ class MyRunnable(RunnableSerializable[str, str]):
class MyOtherRunnable(RunnableSerializable[str, str]):
my_other_property: str
@override
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
@@ -62,7 +67,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
def my_other_custom_function(self) -> str:
return self.my_other_property
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str:
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str: # noqa: ARG002
return self.my_other_property

View File

@@ -25,7 +25,7 @@ def seq_naive_rag() -> Runnable:
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
retriever = RunnableLambda(lambda _: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
@@ -53,7 +53,7 @@ def seq_naive_rag_alt() -> Runnable:
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
retriever = RunnableLambda(lambda _: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
@@ -78,7 +78,7 @@ def seq_naive_rag_scoped() -> Runnable:
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
retriever = RunnableLambda(lambda _: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])

View File

@@ -9,6 +9,7 @@ from typing import (
import pytest
from pydantic import BaseModel
from syrupy import SnapshotAssertion
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import (
@@ -60,7 +61,7 @@ def chain() -> Runnable:
)
def _raise_error(inputs: dict) -> str:
def _raise_error(_: dict) -> str:
raise ValueError
@@ -259,17 +260,17 @@ async def test_abatch() -> None:
_assert_potential_error(actual, expected)
def _generate(input: Iterator) -> Iterator[str]:
def _generate(_: Iterator) -> Iterator[str]:
yield from "foo bar"
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
def _generate_immediate_error(_: Iterator) -> Iterator[str]:
msg = "immmediate error"
raise ValueError(msg)
yield ""
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
def _generate_delayed_error(_: Iterator) -> Iterator[str]:
yield ""
msg = "delayed error"
raise ValueError(msg)
@@ -288,18 +289,18 @@ def test_fallbacks_stream() -> None:
list(runnable.stream({}))
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]:
for c in "foo bar":
yield c
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]:
msg = "immmediate error"
raise ValueError(msg)
yield ""
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]:
yield ""
msg = "delayed error"
raise ValueError(msg)
@@ -323,6 +324,7 @@ async def test_fallbacks_astream() -> None:
class FakeStructuredOutputModel(BaseChatModel):
foo: int
@override
def _generate(
self,
messages: list[BaseMessage],
@@ -333,6 +335,7 @@ class FakeStructuredOutputModel(BaseChatModel):
"""Top Level call."""
return ChatResult(generations=[])
@override
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
@@ -340,10 +343,11 @@ class FakeStructuredOutputModel(BaseChatModel):
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(tools=tools)
@override
def with_structured_output(
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return RunnableLambda(lambda x: {"foo": self.foo})
return RunnableLambda(lambda _: {"foo": self.foo})
@property
def _llm_type(self) -> str:
@@ -353,6 +357,7 @@ class FakeStructuredOutputModel(BaseChatModel):
class FakeModel(BaseChatModel):
bar: int
@override
def _generate(
self,
messages: list[BaseMessage],
@@ -363,6 +368,7 @@ class FakeModel(BaseChatModel):
"""Top Level call."""
return ChatResult(generations=[])
@override
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],

View File

@@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Union
import pytest
from packaging import version
from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
@@ -39,7 +40,7 @@ def _get_get_session_history(
chat_history_store = store if store is not None else {}
def get_session_history(
session_id: str, **kwargs: Any
session_id: str, **_kwargs: Any
) -> InMemoryChatMessageHistory:
if session_id not in chat_history_store:
chat_history_store[session_id] = InMemoryChatMessageHistory()
@@ -253,6 +254,7 @@ async def test_output_message_async() -> None:
class LengthChatModel(BaseChatModel):
"""A fake chat model that returns the length of the messages passed in."""
@override
def _generate(
self,
messages: list[BaseMessage],
@@ -856,7 +858,7 @@ def test_get_output_messages_no_value_error() -> None:
def test_get_output_messages_with_value_error() -> None:
illegal_bool_message = False
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message)
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
store: dict = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
@@ -874,7 +876,7 @@ def test_get_output_messages_with_value_error() -> None:
with_history.bound.invoke([HumanMessage(content="hello")], config)
illegal_int_message = 123
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
with pytest.raises(

View File

@@ -21,10 +21,11 @@ from packaging import version
from pydantic import BaseModel, Field
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from typing_extensions import TypedDict
from typing_extensions import TypedDict, override
from langchain_core.callbacks.manager import (
Callbacks,
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
atrace_as_chain_group,
trace_as_chain_group,
)
@@ -184,6 +185,7 @@ class FakeTracer(BaseTracer):
class FakeRunnable(Runnable[str, int]):
@override
def invoke(
self,
input: str,
@@ -196,6 +198,7 @@ class FakeRunnable(Runnable[str, int]):
class FakeRunnableSerializable(RunnableSerializable[str, int]):
hello: str = ""
@override
def invoke(
self,
input: str,
@@ -206,25 +209,15 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]):
class FakeRetriever(BaseRetriever):
@override
def _get_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
return [Document(page_content="foo"), Document(page_content="bar")]
@override
async def _aget_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> list[Document]:
return [Document(page_content="foo"), Document(page_content="bar")]
@@ -506,7 +499,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
foo_ = RunnableLambda(foo)
assert foo_.assign(bar=lambda x: "foo").get_output_schema().model_json_schema() == {
assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
"required": ["root", "bar"],
"title": "RunnableAssignOutput",
@@ -1782,10 +1775,10 @@ def test_with_listener_propagation(mocker: MockerFixture) -> None:
@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_prompt_with_chat_model(
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -1888,10 +1881,10 @@ def test_prompt_with_chat_model(
@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
async def test_prompt_with_chat_model_async(
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -2519,7 +2512,7 @@ async def test_stream_log_retriever() -> None:
@freeze_time("2023-01-01")
async def test_stream_log_lists() -> None:
async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
async def list_producer(_: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
for i in range(4):
yield AddableDict(alist=[str(i)])
@@ -2631,10 +2624,10 @@ async def test_prompt_with_llm_and_async_lambda(
@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_prompt_with_chat_model_and_parser(
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -2672,10 +2665,9 @@ def test_prompt_with_chat_model_and_parser(
@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_combining_sequences(
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -3513,7 +3505,7 @@ def test_bind_bind() -> None:
def test_bind_with_lambda() -> None:
def my_function(*args: Any, **kwargs: Any) -> int:
def my_function(_: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0)
runnable = RunnableLambda(my_function).bind(n=1)
@@ -3523,7 +3515,7 @@ def test_bind_with_lambda() -> None:
async def test_bind_with_lambda_async() -> None:
def my_function(*args: Any, **kwargs: Any) -> int:
def my_function(_: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0)
runnable = RunnableLambda(my_function).bind(n=1)
@@ -3858,7 +3850,7 @@ def test_each(snapshot: SnapshotAssertion) -> None:
def test_recursive_lambda() -> None:
def _simple_recursion(x: int) -> Union[int, Runnable]:
if x < 10:
return RunnableLambda(lambda *args: _simple_recursion(x + 1))
return RunnableLambda(lambda *_: _simple_recursion(x + 1))
return x
runnable = RunnableLambda(_simple_recursion)
@@ -4008,7 +4000,7 @@ def test_runnable_lambda_stream() -> None:
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
output = list(RunnableLambda(lambda x: llm).stream(""))
output = list(RunnableLambda(lambda _: llm).stream(""))
assert output == list(llm_res)
@@ -4021,7 +4013,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
config: RunnableConfig = {"callbacks": [tracer]}
assert list(RunnableLambda(lambda x: llm).stream("", config=config)) == list(
assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
llm_res
)
@@ -4029,7 +4021,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": llm_res}
def raise_value_error(x: int) -> int:
def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@@ -4076,7 +4068,7 @@ async def test_runnable_lambda_astream() -> None:
_
async for _ in RunnableLambda(
func=id,
afunc=awrapper(lambda x: llm),
afunc=awrapper(lambda _: llm),
).astream("")
]
assert output == list(llm_res)
@@ -4084,7 +4076,7 @@ async def test_runnable_lambda_astream() -> None:
output = [
chunk
async for chunk in cast(
"AsyncIterator[str]", RunnableLambda(lambda x: llm).astream("")
"AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
)
]
assert output == list(llm_res)
@@ -4100,14 +4092,14 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
config: RunnableConfig = {"callbacks": [tracer]}
assert [
_ async for _ in RunnableLambda(lambda x: llm).astream("", config=config)
_ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
] == list(llm_res)
assert len(tracer.runs) == 1
assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": llm_res}
def raise_value_error(x: int) -> int:
def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@@ -4487,7 +4479,7 @@ def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
def test_runnable_branch_invoke() -> None:
# Test with single branch
def raise_value_error(x: int) -> int:
def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@@ -4552,7 +4544,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
"""Verify that callbacks are correctly used in invoke."""
tracer = FakeTracer()
def raise_value_error(x: int) -> int:
def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@@ -4580,7 +4572,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
"""Verify that callbacks are invoked correctly in ainvoke."""
tracer = FakeTracer()
async def raise_value_error(x: int) -> int:
async def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@@ -4755,13 +4747,13 @@ def test_representation_of_runnables() -> None:
runnable = RunnableLambda(lambda x: x * 2)
assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
def f(x: int) -> int:
def f(_: int) -> int:
"""Return 2."""
return 2
assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)"
async def af(x: int) -> int:
async def af(_: int) -> int:
"""Return 2."""
return 2
@@ -4814,7 +4806,7 @@ async def test_tool_from_runnable() -> None:
def test_runnable_gen() -> None:
"""Test that a generator can be used as a runnable."""
def gen(input: Iterator[Any]) -> Iterator[int]:
def gen(_: Iterator[Any]) -> Iterator[int]:
yield 1
yield 2
yield 3
@@ -4835,7 +4827,7 @@ def test_runnable_gen() -> None:
async def test_runnable_gen_async() -> None:
"""Test that a generator can be used as a runnable."""
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
yield 1
yield 2
yield 3
@@ -4847,7 +4839,7 @@ async def test_runnable_gen_async() -> None:
assert await arunnable.abatch([None, None]) == [6, 6]
class AsyncGen:
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
async def __call__(self, _: AsyncIterator[Any]) -> AsyncIterator[int]:
yield 1
yield 2
yield 3
@@ -4870,7 +4862,7 @@ def test_runnable_gen_context_config() -> None:
"""
fake = RunnableLambda(len)
def gen(input: Iterator[Any]) -> Iterator[int]:
def gen(_: Iterator[Any]) -> Iterator[int]:
yield fake.invoke("a")
yield fake.invoke("aa")
yield fake.invoke("aaa")
@@ -4944,7 +4936,7 @@ async def test_runnable_gen_context_config_async() -> None:
fake = RunnableLambda(len)
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
yield await fake.ainvoke("a")
yield await fake.ainvoke("aa")
yield await fake.ainvoke("aaa")
@@ -5441,6 +5433,7 @@ def test_default_transform_with_dicts() -> None:
"""Test that default transform works with dicts."""
class CustomRunnable(RunnableSerializable[Input, Output]):
@override
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
@@ -5462,6 +5455,7 @@ async def test_default_atransform_with_dicts() -> None:
"""Test that default transform works with dicts."""
class CustomRunnable(RunnableSerializable[Input, Output]):
@override
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
@@ -5581,6 +5575,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
on_chain_end_triggered = False
class MyHandler(BaseCallbackHandler):
@override
def on_chain_error(
self,
error: BaseException,
@@ -5594,6 +5589,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
nonlocal on_chain_error_triggered
on_chain_error_triggered = True
@override
def on_chain_end(
self,
outputs: dict[str, Any],

View File

@@ -8,6 +8,7 @@ from typing import Any, cast
import pytest
from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
@@ -82,12 +83,12 @@ def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -
async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool."""
def foo(x: int) -> dict:
def foo(_: int) -> dict:
"""Foo."""
return {"x": 5}
@tool
def get_docs(x: int) -> list[Document]:
def get_docs(x: int) -> list[Document]: # noqa: ARG001
"""Hello Doc."""
return [Document(page_content="hello")]
@@ -434,7 +435,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"}
)
events = await _collect_events(
@@ -1021,7 +1022,7 @@ async def test_event_streaming_with_tools() -> None:
return "hello"
@tool
def with_callbacks(callbacks: Callbacks) -> str:
def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
"""A tool that does nothing."""
return "world"
@@ -1031,7 +1032,7 @@ async def test_event_streaming_with_tools() -> None:
return {"x": x, "y": y}
@tool
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
"""A tool that does nothing."""
return {"x": x, "y": y}
@@ -1180,6 +1181,7 @@ async def test_event_streaming_with_tools() -> None:
class HardCodedRetriever(BaseRetriever):
documents: list[Document]
@override
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
@@ -1592,10 +1594,10 @@ async def test_chain_ordering() -> None:
async def test_event_stream_with_retry() -> None:
"""Test the event stream with a tool."""
def success(inputs: str) -> str:
def success(_: str) -> str:
return "success"
def fail(inputs: str) -> None:
def fail(_: str) -> None:
"""Simple func."""
msg = "fail"
raise ValueError(msg)

View File

@@ -17,6 +17,7 @@ from typing import (
import pytest
from blockbuster import BlockBuster
from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
@@ -97,12 +98,12 @@ async def _collect_events(
async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool."""
def foo(x: int) -> dict:
def foo(x: int) -> dict: # noqa: ARG001
"""Foo."""
return {"x": 5}
@tool
def get_docs(x: int) -> list[Document]:
def get_docs(x: int) -> list[Document]: # noqa: ARG001
"""Hello Doc."""
return [Document(page_content="hello")]
@@ -465,7 +466,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"}
)
events = await _collect_events(
@@ -1043,7 +1044,7 @@ async def test_event_streaming_with_tools() -> None:
return "hello"
@tool
def with_callbacks(callbacks: Callbacks) -> str:
def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
"""A tool that does nothing."""
return "world"
@@ -1053,7 +1054,7 @@ async def test_event_streaming_with_tools() -> None:
return {"x": x, "y": y}
@tool
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
"""A tool that does nothing."""
return {"x": x, "y": y}
@@ -1165,6 +1166,7 @@ async def test_event_streaming_with_tools() -> None:
class HardCodedRetriever(BaseRetriever):
documents: list[Document]
@override
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
@@ -1553,10 +1555,10 @@ async def test_chain_ordering() -> None:
async def test_event_stream_with_retry() -> None:
"""Test the event stream with a tool."""
def success(inputs: str) -> str:
def success(_: str) -> str:
return "success"
def fail(inputs: str) -> None:
def fail(_: str) -> None:
"""Simple func."""
msg = "fail"
raise ValueError(msg)
@@ -2069,6 +2071,7 @@ class StreamingRunnable(Runnable[Input, Output]):
"""Initialize the runnable."""
self.iterable = iterable
@override
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
@@ -2084,6 +2087,7 @@ class StreamingRunnable(Runnable[Input, Output]):
) -> Iterator[Output]:
raise NotImplementedError
@override
async def astream(
self,
input: Input,
@@ -2323,7 +2327,7 @@ async def test_bad_parent_ids() -> None:
async def test_runnable_generator() -> None:
"""Test async events from sync lambda."""
async def generator(inputs: AsyncIterator[str]) -> AsyncIterator[str]:
async def generator(_: AsyncIterator[str]) -> AsyncIterator[str]:
yield "1"
yield "2"

View File

@@ -375,7 +375,7 @@ class TestRunnableSequenceParallelTraceNesting:
def test_sync(
self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int]
) -> None:
def other_thing(a: int) -> Generator[int, None, None]: # type: ignore
def other_thing(_: int) -> Generator[int, None, None]: # type: ignore
yield 1
parent = self._create_parent(other_thing)
@@ -407,7 +407,7 @@ class TestRunnableSequenceParallelTraceNesting:
[RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int]
],
) -> None:
async def other_thing(a: int) -> AsyncGenerator[int, None]:
async def other_thing(_: int) -> AsyncGenerator[int, None]:
yield 1
parent = self._create_parent(other_thing)

View File

@@ -2301,7 +2301,7 @@ def test_injected_arg_with_complex_type() -> None:
self.value = "bar"
@tool
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: # noqa: ARG001
"""Tool that has an injected tool arg."""
return foo.value
@@ -2477,7 +2477,7 @@ def test_simple_tool_args_schema_dict() -> None:
def test_empty_string_tool_call_id() -> None:
@tool
def foo(x: int) -> str:
def foo(x: int) -> str: # noqa: ARG001
"""Foo."""
return "hi"
@@ -2489,7 +2489,7 @@ def test_empty_string_tool_call_id() -> None:
def test_tool_decorator_description() -> None:
# test basic tool
@tool
def foo(x: int) -> str:
def foo(x: int) -> str: # noqa: ARG001
"""Foo."""
return "hi"
@@ -2501,7 +2501,7 @@ def test_tool_decorator_description() -> None:
# test basic tool with description
@tool(description="description")
def foo_description(x: int) -> str:
def foo_description(x: int) -> str: # noqa: ARG001
"""Foo."""
return "hi"
@@ -2520,7 +2520,7 @@ def test_tool_decorator_description() -> None:
x: int
@tool(args_schema=ArgsSchema)
def foo_args_schema(x: int) -> str:
def foo_args_schema(x: int) -> str: # noqa: ARG001
return "hi"
assert foo_args_schema.description == "Bar."
@@ -2532,7 +2532,7 @@ def test_tool_decorator_description() -> None:
)
@tool(description="description", args_schema=ArgsSchema)
def foo_args_schema_description(x: int) -> str:
def foo_args_schema_description(x: int) -> str: # noqa: ARG001
return "hi"
assert foo_args_schema_description.description == "description"
@@ -2554,11 +2554,11 @@ def test_tool_decorator_description() -> None:
}
@tool(args_schema=args_json_schema)
def foo_args_jsons_schema(x: int) -> str:
def foo_args_jsons_schema(x: int) -> str: # noqa: ARG001
return "hi"
@tool(description="description", args_schema=args_json_schema)
def foo_args_jsons_schema_with_description(x: int) -> str:
def foo_args_jsons_schema_with_description(x: int) -> str: # noqa: ARG001
return "hi"
assert foo_args_jsons_schema.description == "JSON Schema."
@@ -2620,10 +2620,10 @@ def test_title_property_preserved() -> None:
async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
"""Verify that the inputs are not mutated when invoking a tool asynchronously."""
def sync_no_op(foo: int) -> str:
def sync_no_op(foo: int) -> str: # noqa: ARG001
return "good"
async def async_no_op(foo: int) -> str:
async def async_no_op(foo: int) -> str: # noqa: ARG001
return "good"
tool = StructuredTool(
@@ -2668,10 +2668,10 @@ async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
def test_tool_invoke_does_not_mutate_inputs() -> None:
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
def sync_no_op(foo: int) -> str:
def sync_no_op(foo: int) -> str: # noqa: ARG001
return "good"
async def async_no_op(foo: int) -> str:
async def async_no_op(foo: int) -> str: # noqa: ARG001
return "good"
tool = StructuredTool(

View File

@@ -121,7 +121,7 @@ def dummy_structured_tool() -> StructuredTool:
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
return StructuredTool.from_function(
lambda x: None,
lambda _: None,
name="dummy_function",
description="Dummy function.",
args_schema=Schema,
@@ -143,7 +143,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
"required": ["arg1", "arg2"],
}
return StructuredTool.from_function(
lambda x: None,
lambda _: None,
name="dummy_function",
description="Dummy function.",
args_schema=args_schema,

View File

@@ -10,6 +10,7 @@ import uuid
from typing import TYPE_CHECKING, Any, Optional
import pytest
from typing_extensions import override
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings, FakeEmbeddings
@@ -25,6 +26,7 @@ class CustomAddTextsVectorstore(VectorStore):
def __init__(self) -> None:
self.store: dict[str, Document] = {}
@override
def add_texts(
self,
texts: Iterable[str],
@@ -51,6 +53,7 @@ class CustomAddTextsVectorstore(VectorStore):
return [self.store[id] for id in ids if id in self.store]
@classmethod
@override
def from_texts( # type: ignore
cls,
texts: list[str],
@@ -74,6 +77,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
def __init__(self) -> None:
self.store: dict[str, Document] = {}
@override
def add_documents(
self,
documents: list[Document],
@@ -95,6 +99,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
return [self.store[id] for id in ids if id in self.store]
@classmethod
@override
def from_texts( # type: ignore
cls,
texts: list[str],