mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-14 22:17:15 +00:00
core: Add ruff rules ARG (#30732)
See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg
This commit is contained in:
committed by
GitHub
parent
66758599a9
commit
98f0016fc2
@@ -10,6 +10,8 @@ from contextlib import asynccontextmanager
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackHandler,
|
||||
AsyncCallbackManager,
|
||||
@@ -45,6 +47,7 @@ async def test_inline_handlers_share_parent_context() -> None:
|
||||
"""Initialize the handler."""
|
||||
self.run_inline = run_inline
|
||||
|
||||
@override
|
||||
async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Update the callstack with the name of the callback."""
|
||||
some_var.set("on_llm_start")
|
||||
|
@@ -74,6 +74,7 @@ async def test_async_custom_event_implicit_config() -> None:
|
||||
# a decorator for async functions
|
||||
@RunnableLambda # type: ignore[arg-type]
|
||||
async def foo(x: int, config: RunnableConfig) -> int:
|
||||
assert "callbacks" in config
|
||||
await adispatch_custom_event("event1", {"x": x})
|
||||
await adispatch_custom_event("event2", {"x": x})
|
||||
return x
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
@@ -15,6 +16,7 @@ def test_base_blob_parser() -> None:
|
||||
class MyParser(BaseBlobParser):
|
||||
"""A simple parser that returns a single document."""
|
||||
|
||||
@override
|
||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazy parsing interface."""
|
||||
yield Document(
|
||||
|
@@ -1,6 +1,8 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings, FakeEmbeddings
|
||||
from langchain_core.example_selectors import (
|
||||
@@ -21,6 +23,7 @@ class DummyVectorStore(VectorStore):
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embeddings
|
||||
|
||||
@override
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@@ -32,6 +35,7 @@ class DummyVectorStore(VectorStore):
|
||||
self.metadatas.extend(metadatas)
|
||||
return ["dummy_id"]
|
||||
|
||||
@override
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
@@ -41,6 +45,7 @@ class DummyVectorStore(VectorStore):
|
||||
)
|
||||
] * k
|
||||
|
||||
@override
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
|
@@ -5,6 +5,7 @@ from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
@@ -138,6 +139,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return self.ignore_retriever_
|
||||
|
||||
@override
|
||||
def on_llm_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -145,6 +147,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_llm_start_common()
|
||||
|
||||
@override
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -152,6 +155,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
@override
|
||||
def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -159,6 +163,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_llm_end_common()
|
||||
|
||||
@override
|
||||
def on_llm_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -166,6 +171,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_llm_error_common(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def on_retry(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -173,6 +179,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retry_common()
|
||||
|
||||
@override
|
||||
def on_chain_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -180,6 +187,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_chain_start_common()
|
||||
|
||||
@override
|
||||
def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -187,6 +195,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_chain_end_common()
|
||||
|
||||
@override
|
||||
def on_chain_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -194,6 +203,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_chain_error_common()
|
||||
|
||||
@override
|
||||
def on_tool_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -201,6 +211,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_tool_start_common()
|
||||
|
||||
@override
|
||||
def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -208,6 +219,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_tool_end_common()
|
||||
|
||||
@override
|
||||
def on_tool_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -215,6 +227,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_tool_error_common()
|
||||
|
||||
@override
|
||||
def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -222,6 +235,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_agent_action_common()
|
||||
|
||||
@override
|
||||
def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -229,6 +243,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
@override
|
||||
def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -236,6 +251,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_text_common()
|
||||
|
||||
@override
|
||||
def on_retriever_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -243,6 +259,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retriever_start_common()
|
||||
|
||||
@override
|
||||
def on_retriever_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -250,6 +267,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retriever_end_common()
|
||||
|
||||
@override
|
||||
def on_retriever_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -263,6 +281,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
|
||||
|
||||
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
||||
@override
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
@@ -294,6 +313,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
@override
|
||||
async def on_retry(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -301,6 +321,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> Any:
|
||||
self.on_retry_common()
|
||||
|
||||
@override
|
||||
async def on_llm_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -308,6 +329,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_llm_start_common()
|
||||
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -315,6 +337,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
@override
|
||||
async def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -322,6 +345,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_llm_end_common()
|
||||
|
||||
@override
|
||||
async def on_llm_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -329,6 +353,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_llm_error_common(*args, **kwargs)
|
||||
|
||||
@override
|
||||
async def on_chain_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -336,6 +361,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_chain_start_common()
|
||||
|
||||
@override
|
||||
async def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -343,6 +369,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_chain_end_common()
|
||||
|
||||
@override
|
||||
async def on_chain_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -350,6 +377,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_chain_error_common()
|
||||
|
||||
@override
|
||||
async def on_tool_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -357,6 +385,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_tool_start_common()
|
||||
|
||||
@override
|
||||
async def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -364,6 +393,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_tool_end_common()
|
||||
|
||||
@override
|
||||
async def on_tool_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -371,6 +401,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_tool_error_common()
|
||||
|
||||
@override
|
||||
async def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -378,6 +409,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_agent_action_common()
|
||||
|
||||
@override
|
||||
async def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -385,6 +417,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
@override
|
||||
async def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
|
@@ -4,6 +4,8 @@ from itertools import cycle
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler
|
||||
from langchain_core.language_models import (
|
||||
FakeListChatModel,
|
||||
@@ -171,6 +173,7 @@ async def test_callback_handlers() -> None:
|
||||
# Required to implement since this is an abstract method
|
||||
pass
|
||||
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
|
@@ -5,6 +5,7 @@ from collections.abc import AsyncIterator, Iterator
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel, FakeListChatModel
|
||||
@@ -138,6 +139,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
|
||||
"""Test astream uses appropriate implementation."""
|
||||
|
||||
class ModelWithGenerate(BaseChatModel):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@@ -176,6 +178,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
|
||||
"""Top Level call."""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@@ -221,6 +224,7 @@ async def test_astream_implementation_uses_astream() -> None:
|
||||
"""Top Level call."""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def _astream( # type: ignore
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@@ -286,6 +290,7 @@ async def test_async_pass_run_id() -> None:
|
||||
|
||||
|
||||
class NoStreamingModel(BaseChatModel):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@@ -301,6 +306,7 @@ class NoStreamingModel(BaseChatModel):
|
||||
|
||||
|
||||
class StreamingModel(NoStreamingModel):
|
||||
@override
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
||||
from langchain_core.globals import set_llm_cache
|
||||
@@ -30,6 +31,7 @@ class InMemoryCache(BaseCache):
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
self._cache[(prompt, llm_string)] = return_val
|
||||
|
||||
@override
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache."""
|
||||
self._cache = {}
|
||||
|
@@ -2,6 +2,7 @@ from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@@ -106,6 +107,7 @@ async def test_error_callback() -> None:
|
||||
"""Return type of llm."""
|
||||
return "failing-llm"
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -136,6 +138,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
|
||||
"""Test astream uses appropriate implementation."""
|
||||
|
||||
class ModelWithGenerate(BaseLLM):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
@@ -172,6 +175,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
|
||||
"""Top Level call."""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -209,6 +213,7 @@ async def test_astream_implementation_uses_astream() -> None:
|
||||
"""Top Level call."""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
|
@@ -1,5 +1,7 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
||||
from langchain_core.globals import set_llm_cache
|
||||
from langchain_core.language_models import FakeListLLM
|
||||
@@ -20,6 +22,7 @@ class InMemoryCache(BaseCache):
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
self._cache[(prompt, llm_string)] = return_val
|
||||
|
||||
@override
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache."""
|
||||
self._cache = {}
|
||||
@@ -74,6 +77,7 @@ class InMemoryCacheBad(BaseCache):
|
||||
msg = "This code should not be triggered"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@override
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache."""
|
||||
self._cache = {}
|
||||
|
@@ -6,6 +6,7 @@ from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
||||
from langchain_core.messages import (
|
||||
@@ -660,6 +661,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
|
||||
|
||||
|
||||
class FakeTokenCountingModel(FakeChatModel):
|
||||
@override
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
|
@@ -1,5 +1,7 @@
|
||||
"""Module to test base parser implementations."""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
@@ -16,6 +18,7 @@ def test_base_generation_parser() -> None:
|
||||
class StrInvertCase(BaseGenerationOutputParser[str]):
|
||||
"""An example parser that inverts the case of the characters in the message."""
|
||||
|
||||
@override
|
||||
def parse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
) -> str:
|
||||
@@ -59,6 +62,7 @@ def test_base_transform_output_parser() -> None:
|
||||
"""Parse a single string into a specific format."""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def parse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
) -> str:
|
||||
|
@@ -5,6 +5,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
@@ -383,6 +384,7 @@ class AsIsSelector(BaseExampleSelector):
|
||||
def add_example(self, example: dict[str, str]) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
return list(self.examples)
|
||||
|
||||
@@ -481,6 +483,7 @@ class AsyncAsIsSelector(BaseExampleSelector):
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
return list(self.examples)
|
||||
|
||||
|
@@ -14,7 +14,7 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
def _fake_runnable(
|
||||
input: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_: Any
|
||||
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
|
||||
) -> Union[BaseModel, dict]:
|
||||
if isclass(schema) and is_basemodel_subclass(schema):
|
||||
return schema(name="yo", value=value)
|
||||
|
@@ -2,7 +2,7 @@ from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core.runnables import (
|
||||
ConfigurableField,
|
||||
@@ -32,6 +32,7 @@ class MyRunnable(RunnableSerializable[str, str]):
|
||||
self._my_hidden_property = self.my_property
|
||||
return self
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
@@ -41,12 +42,15 @@ class MyRunnable(RunnableSerializable[str, str]):
|
||||
return self.my_property
|
||||
|
||||
def my_custom_function_w_config(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
self,
|
||||
config: Optional[RunnableConfig] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
return self.my_property
|
||||
|
||||
def my_custom_function_w_kw_config(
|
||||
self, *, config: Optional[RunnableConfig] = None
|
||||
self,
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
return self.my_property
|
||||
|
||||
@@ -54,6 +58,7 @@ class MyRunnable(RunnableSerializable[str, str]):
|
||||
class MyOtherRunnable(RunnableSerializable[str, str]):
|
||||
my_other_property: str
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
@@ -62,7 +67,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
|
||||
def my_other_custom_function(self) -> str:
|
||||
return self.my_other_property
|
||||
|
||||
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str:
|
||||
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str: # noqa: ARG002
|
||||
return self.my_other_property
|
||||
|
||||
|
||||
|
@@ -25,7 +25,7 @@ def seq_naive_rag() -> Runnable:
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
retriever = RunnableLambda(lambda _: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
@@ -53,7 +53,7 @@ def seq_naive_rag_alt() -> Runnable:
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
retriever = RunnableLambda(lambda _: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
@@ -78,7 +78,7 @@ def seq_naive_rag_scoped() -> Runnable:
|
||||
"What's your name?",
|
||||
]
|
||||
|
||||
retriever = RunnableLambda(lambda x: context)
|
||||
retriever = RunnableLambda(lambda _: context)
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
|
||||
|
@@ -9,6 +9,7 @@ from typing import (
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from syrupy import SnapshotAssertion
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import (
|
||||
@@ -60,7 +61,7 @@ def chain() -> Runnable:
|
||||
)
|
||||
|
||||
|
||||
def _raise_error(inputs: dict) -> str:
|
||||
def _raise_error(_: dict) -> str:
|
||||
raise ValueError
|
||||
|
||||
|
||||
@@ -259,17 +260,17 @@ async def test_abatch() -> None:
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
|
||||
def _generate(input: Iterator) -> Iterator[str]:
|
||||
def _generate(_: Iterator) -> Iterator[str]:
|
||||
yield from "foo bar"
|
||||
|
||||
|
||||
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
|
||||
def _generate_immediate_error(_: Iterator) -> Iterator[str]:
|
||||
msg = "immmediate error"
|
||||
raise ValueError(msg)
|
||||
yield ""
|
||||
|
||||
|
||||
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
|
||||
def _generate_delayed_error(_: Iterator) -> Iterator[str]:
|
||||
yield ""
|
||||
msg = "delayed error"
|
||||
raise ValueError(msg)
|
||||
@@ -288,18 +289,18 @@ def test_fallbacks_stream() -> None:
|
||||
list(runnable.stream({}))
|
||||
|
||||
|
||||
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]:
|
||||
for c in "foo bar":
|
||||
yield c
|
||||
|
||||
|
||||
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]:
|
||||
msg = "immmediate error"
|
||||
raise ValueError(msg)
|
||||
yield ""
|
||||
|
||||
|
||||
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]:
|
||||
yield ""
|
||||
msg = "delayed error"
|
||||
raise ValueError(msg)
|
||||
@@ -323,6 +324,7 @@ async def test_fallbacks_astream() -> None:
|
||||
class FakeStructuredOutputModel(BaseChatModel):
|
||||
foo: int
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@@ -333,6 +335,7 @@ class FakeStructuredOutputModel(BaseChatModel):
|
||||
"""Top Level call."""
|
||||
return ChatResult(generations=[])
|
||||
|
||||
@override
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
@@ -340,10 +343,11 @@ class FakeStructuredOutputModel(BaseChatModel):
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
return self.bind(tools=tools)
|
||||
|
||||
@override
|
||||
def with_structured_output(
|
||||
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
return RunnableLambda(lambda x: {"foo": self.foo})
|
||||
return RunnableLambda(lambda _: {"foo": self.foo})
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
@@ -353,6 +357,7 @@ class FakeStructuredOutputModel(BaseChatModel):
|
||||
class FakeModel(BaseChatModel):
|
||||
bar: int
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@@ -363,6 +368,7 @@ class FakeModel(BaseChatModel):
|
||||
"""Top Level call."""
|
||||
return ChatResult(generations=[])
|
||||
|
||||
@override
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
|
@@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Union
|
||||
import pytest
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -39,7 +40,7 @@ def _get_get_session_history(
|
||||
chat_history_store = store if store is not None else {}
|
||||
|
||||
def get_session_history(
|
||||
session_id: str, **kwargs: Any
|
||||
session_id: str, **_kwargs: Any
|
||||
) -> InMemoryChatMessageHistory:
|
||||
if session_id not in chat_history_store:
|
||||
chat_history_store[session_id] = InMemoryChatMessageHistory()
|
||||
@@ -253,6 +254,7 @@ async def test_output_message_async() -> None:
|
||||
class LengthChatModel(BaseChatModel):
|
||||
"""A fake chat model that returns the length of the messages passed in."""
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@@ -856,7 +858,7 @@ def test_get_output_messages_no_value_error() -> None:
|
||||
|
||||
def test_get_output_messages_with_value_error() -> None:
|
||||
illegal_bool_message = False
|
||||
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message)
|
||||
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
|
||||
store: dict = {}
|
||||
get_session_history = _get_get_session_history(store=store)
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||
@@ -874,7 +876,7 @@ def test_get_output_messages_with_value_error() -> None:
|
||||
with_history.bound.invoke([HumanMessage(content="hello")], config)
|
||||
|
||||
illegal_int_message = 123
|
||||
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
|
||||
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
|
||||
with_history = RunnableWithMessageHistory(runnable, get_session_history)
|
||||
|
||||
with pytest.raises(
|
||||
|
@@ -21,10 +21,11 @@ from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy import SnapshotAssertion
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import TypedDict, override
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
Callbacks,
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
atrace_as_chain_group,
|
||||
trace_as_chain_group,
|
||||
)
|
||||
@@ -184,6 +185,7 @@ class FakeTracer(BaseTracer):
|
||||
|
||||
|
||||
class FakeRunnable(Runnable[str, int]):
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: str,
|
||||
@@ -196,6 +198,7 @@ class FakeRunnable(Runnable[str, int]):
|
||||
class FakeRunnableSerializable(RunnableSerializable[str, int]):
|
||||
hello: str = ""
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: str,
|
||||
@@ -206,25 +209,15 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]):
|
||||
|
||||
|
||||
class FakeRetriever(BaseRetriever):
|
||||
@override
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> list[Document]:
|
||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||
|
||||
@override
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> list[Document]:
|
||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||
|
||||
@@ -506,7 +499,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
|
||||
foo_ = RunnableLambda(foo)
|
||||
|
||||
assert foo_.assign(bar=lambda x: "foo").get_output_schema().model_json_schema() == {
|
||||
assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {
|
||||
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
|
||||
"required": ["root", "bar"],
|
||||
"title": "RunnableAssignOutput",
|
||||
@@ -1782,10 +1775,10 @@ def test_with_listener_propagation(mocker: MockerFixture) -> None:
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@pytest.mark.usefixtures("deterministic_uuids")
|
||||
def test_prompt_with_chat_model(
|
||||
mocker: MockerFixture,
|
||||
snapshot: SnapshotAssertion,
|
||||
deterministic_uuids: MockerFixture,
|
||||
) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
@@ -1888,10 +1881,10 @@ def test_prompt_with_chat_model(
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@pytest.mark.usefixtures("deterministic_uuids")
|
||||
async def test_prompt_with_chat_model_async(
|
||||
mocker: MockerFixture,
|
||||
snapshot: SnapshotAssertion,
|
||||
deterministic_uuids: MockerFixture,
|
||||
) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
@@ -2519,7 +2512,7 @@ async def test_stream_log_retriever() -> None:
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_stream_log_lists() -> None:
|
||||
async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
|
||||
async def list_producer(_: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
|
||||
for i in range(4):
|
||||
yield AddableDict(alist=[str(i)])
|
||||
|
||||
@@ -2631,10 +2624,10 @@ async def test_prompt_with_llm_and_async_lambda(
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@pytest.mark.usefixtures("deterministic_uuids")
|
||||
def test_prompt_with_chat_model_and_parser(
|
||||
mocker: MockerFixture,
|
||||
snapshot: SnapshotAssertion,
|
||||
deterministic_uuids: MockerFixture,
|
||||
) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
@@ -2672,10 +2665,9 @@ def test_prompt_with_chat_model_and_parser(
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@pytest.mark.usefixtures("deterministic_uuids")
|
||||
def test_combining_sequences(
|
||||
mocker: MockerFixture,
|
||||
snapshot: SnapshotAssertion,
|
||||
deterministic_uuids: MockerFixture,
|
||||
) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
@@ -3513,7 +3505,7 @@ def test_bind_bind() -> None:
|
||||
|
||||
|
||||
def test_bind_with_lambda() -> None:
|
||||
def my_function(*args: Any, **kwargs: Any) -> int:
|
||||
def my_function(_: Any, **kwargs: Any) -> int:
|
||||
return 3 + kwargs.get("n", 0)
|
||||
|
||||
runnable = RunnableLambda(my_function).bind(n=1)
|
||||
@@ -3523,7 +3515,7 @@ def test_bind_with_lambda() -> None:
|
||||
|
||||
|
||||
async def test_bind_with_lambda_async() -> None:
|
||||
def my_function(*args: Any, **kwargs: Any) -> int:
|
||||
def my_function(_: Any, **kwargs: Any) -> int:
|
||||
return 3 + kwargs.get("n", 0)
|
||||
|
||||
runnable = RunnableLambda(my_function).bind(n=1)
|
||||
@@ -3858,7 +3850,7 @@ def test_each(snapshot: SnapshotAssertion) -> None:
|
||||
def test_recursive_lambda() -> None:
|
||||
def _simple_recursion(x: int) -> Union[int, Runnable]:
|
||||
if x < 10:
|
||||
return RunnableLambda(lambda *args: _simple_recursion(x + 1))
|
||||
return RunnableLambda(lambda *_: _simple_recursion(x + 1))
|
||||
return x
|
||||
|
||||
runnable = RunnableLambda(_simple_recursion)
|
||||
@@ -4008,7 +4000,7 @@ def test_runnable_lambda_stream() -> None:
|
||||
# sleep to better simulate a real stream
|
||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||
|
||||
output = list(RunnableLambda(lambda x: llm).stream(""))
|
||||
output = list(RunnableLambda(lambda _: llm).stream(""))
|
||||
assert output == list(llm_res)
|
||||
|
||||
|
||||
@@ -4021,7 +4013,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
|
||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||
config: RunnableConfig = {"callbacks": [tracer]}
|
||||
|
||||
assert list(RunnableLambda(lambda x: llm).stream("", config=config)) == list(
|
||||
assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
|
||||
llm_res
|
||||
)
|
||||
|
||||
@@ -4029,7 +4021,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
|
||||
assert tracer.runs[0].error is None
|
||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||
|
||||
def raise_value_error(x: int) -> int:
|
||||
def raise_value_error(_: int) -> int:
|
||||
"""Raise a value error."""
|
||||
msg = "x is too large"
|
||||
raise ValueError(msg)
|
||||
@@ -4076,7 +4068,7 @@ async def test_runnable_lambda_astream() -> None:
|
||||
_
|
||||
async for _ in RunnableLambda(
|
||||
func=id,
|
||||
afunc=awrapper(lambda x: llm),
|
||||
afunc=awrapper(lambda _: llm),
|
||||
).astream("")
|
||||
]
|
||||
assert output == list(llm_res)
|
||||
@@ -4084,7 +4076,7 @@ async def test_runnable_lambda_astream() -> None:
|
||||
output = [
|
||||
chunk
|
||||
async for chunk in cast(
|
||||
"AsyncIterator[str]", RunnableLambda(lambda x: llm).astream("")
|
||||
"AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
|
||||
)
|
||||
]
|
||||
assert output == list(llm_res)
|
||||
@@ -4100,14 +4092,14 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
|
||||
config: RunnableConfig = {"callbacks": [tracer]}
|
||||
|
||||
assert [
|
||||
_ async for _ in RunnableLambda(lambda x: llm).astream("", config=config)
|
||||
_ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
|
||||
] == list(llm_res)
|
||||
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].error is None
|
||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||
|
||||
def raise_value_error(x: int) -> int:
|
||||
def raise_value_error(_: int) -> int:
|
||||
"""Raise a value error."""
|
||||
msg = "x is too large"
|
||||
raise ValueError(msg)
|
||||
@@ -4487,7 +4479,7 @@ def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
|
||||
|
||||
def test_runnable_branch_invoke() -> None:
|
||||
# Test with single branch
|
||||
def raise_value_error(x: int) -> int:
|
||||
def raise_value_error(_: int) -> int:
|
||||
"""Raise a value error."""
|
||||
msg = "x is too large"
|
||||
raise ValueError(msg)
|
||||
@@ -4552,7 +4544,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
|
||||
"""Verify that callbacks are correctly used in invoke."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
def raise_value_error(x: int) -> int:
|
||||
def raise_value_error(_: int) -> int:
|
||||
"""Raise a value error."""
|
||||
msg = "x is too large"
|
||||
raise ValueError(msg)
|
||||
@@ -4580,7 +4572,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
|
||||
"""Verify that callbacks are invoked correctly in ainvoke."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
async def raise_value_error(x: int) -> int:
|
||||
async def raise_value_error(_: int) -> int:
|
||||
"""Raise a value error."""
|
||||
msg = "x is too large"
|
||||
raise ValueError(msg)
|
||||
@@ -4755,13 +4747,13 @@ def test_representation_of_runnables() -> None:
|
||||
runnable = RunnableLambda(lambda x: x * 2)
|
||||
assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
|
||||
|
||||
def f(x: int) -> int:
|
||||
def f(_: int) -> int:
|
||||
"""Return 2."""
|
||||
return 2
|
||||
|
||||
assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)"
|
||||
|
||||
async def af(x: int) -> int:
|
||||
async def af(_: int) -> int:
|
||||
"""Return 2."""
|
||||
return 2
|
||||
|
||||
@@ -4814,7 +4806,7 @@ async def test_tool_from_runnable() -> None:
|
||||
def test_runnable_gen() -> None:
|
||||
"""Test that a generator can be used as a runnable."""
|
||||
|
||||
def gen(input: Iterator[Any]) -> Iterator[int]:
|
||||
def gen(_: Iterator[Any]) -> Iterator[int]:
|
||||
yield 1
|
||||
yield 2
|
||||
yield 3
|
||||
@@ -4835,7 +4827,7 @@ def test_runnable_gen() -> None:
|
||||
async def test_runnable_gen_async() -> None:
|
||||
"""Test that a generator can be used as a runnable."""
|
||||
|
||||
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||
async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||
yield 1
|
||||
yield 2
|
||||
yield 3
|
||||
@@ -4847,7 +4839,7 @@ async def test_runnable_gen_async() -> None:
|
||||
assert await arunnable.abatch([None, None]) == [6, 6]
|
||||
|
||||
class AsyncGen:
|
||||
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||
async def __call__(self, _: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||
yield 1
|
||||
yield 2
|
||||
yield 3
|
||||
@@ -4870,7 +4862,7 @@ def test_runnable_gen_context_config() -> None:
|
||||
"""
|
||||
fake = RunnableLambda(len)
|
||||
|
||||
def gen(input: Iterator[Any]) -> Iterator[int]:
|
||||
def gen(_: Iterator[Any]) -> Iterator[int]:
|
||||
yield fake.invoke("a")
|
||||
yield fake.invoke("aa")
|
||||
yield fake.invoke("aaa")
|
||||
@@ -4944,7 +4936,7 @@ async def test_runnable_gen_context_config_async() -> None:
|
||||
|
||||
fake = RunnableLambda(len)
|
||||
|
||||
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||
async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||
yield await fake.ainvoke("a")
|
||||
yield await fake.ainvoke("aa")
|
||||
yield await fake.ainvoke("aaa")
|
||||
@@ -5441,6 +5433,7 @@ def test_default_transform_with_dicts() -> None:
|
||||
"""Test that default transform works with dicts."""
|
||||
|
||||
class CustomRunnable(RunnableSerializable[Input, Output]):
|
||||
@override
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
@@ -5462,6 +5455,7 @@ async def test_default_atransform_with_dicts() -> None:
|
||||
"""Test that default transform works with dicts."""
|
||||
|
||||
class CustomRunnable(RunnableSerializable[Input, Output]):
|
||||
@override
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
@@ -5581,6 +5575,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
|
||||
on_chain_end_triggered = False
|
||||
|
||||
class MyHandler(BaseCallbackHandler):
|
||||
@override
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
@@ -5594,6 +5589,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
|
||||
nonlocal on_chain_error_triggered
|
||||
on_chain_error_triggered = True
|
||||
|
||||
@override
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: dict[str, Any],
|
||||
|
@@ -8,6 +8,7 @@ from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
@@ -82,12 +83,12 @@ def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -
|
||||
async def test_event_stream_with_simple_function_tool() -> None:
|
||||
"""Test the event stream with a function and tool."""
|
||||
|
||||
def foo(x: int) -> dict:
|
||||
def foo(_: int) -> dict:
|
||||
"""Foo."""
|
||||
return {"x": 5}
|
||||
|
||||
@tool
|
||||
def get_docs(x: int) -> list[Document]:
|
||||
def get_docs(x: int) -> list[Document]: # noqa: ARG001
|
||||
"""Hello Doc."""
|
||||
return [Document(page_content="hello")]
|
||||
|
||||
@@ -434,7 +435,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
|
||||
|
||||
|
||||
async def test_event_stream_with_lambdas_from_lambda() -> None:
|
||||
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
|
||||
as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
|
||||
{"run_name": "my_lambda"}
|
||||
)
|
||||
events = await _collect_events(
|
||||
@@ -1021,7 +1022,7 @@ async def test_event_streaming_with_tools() -> None:
|
||||
return "hello"
|
||||
|
||||
@tool
|
||||
def with_callbacks(callbacks: Callbacks) -> str:
|
||||
def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
|
||||
"""A tool that does nothing."""
|
||||
return "world"
|
||||
|
||||
@@ -1031,7 +1032,7 @@ async def test_event_streaming_with_tools() -> None:
|
||||
return {"x": x, "y": y}
|
||||
|
||||
@tool
|
||||
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
|
||||
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
|
||||
"""A tool that does nothing."""
|
||||
return {"x": x, "y": y}
|
||||
|
||||
@@ -1180,6 +1181,7 @@ async def test_event_streaming_with_tools() -> None:
|
||||
class HardCodedRetriever(BaseRetriever):
|
||||
documents: list[Document]
|
||||
|
||||
@override
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> list[Document]:
|
||||
@@ -1592,10 +1594,10 @@ async def test_chain_ordering() -> None:
|
||||
async def test_event_stream_with_retry() -> None:
|
||||
"""Test the event stream with a tool."""
|
||||
|
||||
def success(inputs: str) -> str:
|
||||
def success(_: str) -> str:
|
||||
return "success"
|
||||
|
||||
def fail(inputs: str) -> None:
|
||||
def fail(_: str) -> None:
|
||||
"""Simple func."""
|
||||
msg = "fail"
|
||||
raise ValueError(msg)
|
||||
|
@@ -17,6 +17,7 @@ from typing import (
|
||||
import pytest
|
||||
from blockbuster import BlockBuster
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
@@ -97,12 +98,12 @@ async def _collect_events(
|
||||
async def test_event_stream_with_simple_function_tool() -> None:
|
||||
"""Test the event stream with a function and tool."""
|
||||
|
||||
def foo(x: int) -> dict:
|
||||
def foo(x: int) -> dict: # noqa: ARG001
|
||||
"""Foo."""
|
||||
return {"x": 5}
|
||||
|
||||
@tool
|
||||
def get_docs(x: int) -> list[Document]:
|
||||
def get_docs(x: int) -> list[Document]: # noqa: ARG001
|
||||
"""Hello Doc."""
|
||||
return [Document(page_content="hello")]
|
||||
|
||||
@@ -465,7 +466,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
|
||||
|
||||
|
||||
async def test_event_stream_with_lambdas_from_lambda() -> None:
|
||||
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
|
||||
as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
|
||||
{"run_name": "my_lambda"}
|
||||
)
|
||||
events = await _collect_events(
|
||||
@@ -1043,7 +1044,7 @@ async def test_event_streaming_with_tools() -> None:
|
||||
return "hello"
|
||||
|
||||
@tool
|
||||
def with_callbacks(callbacks: Callbacks) -> str:
|
||||
def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
|
||||
"""A tool that does nothing."""
|
||||
return "world"
|
||||
|
||||
@@ -1053,7 +1054,7 @@ async def test_event_streaming_with_tools() -> None:
|
||||
return {"x": x, "y": y}
|
||||
|
||||
@tool
|
||||
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
|
||||
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
|
||||
"""A tool that does nothing."""
|
||||
return {"x": x, "y": y}
|
||||
|
||||
@@ -1165,6 +1166,7 @@ async def test_event_streaming_with_tools() -> None:
|
||||
class HardCodedRetriever(BaseRetriever):
|
||||
documents: list[Document]
|
||||
|
||||
@override
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> list[Document]:
|
||||
@@ -1553,10 +1555,10 @@ async def test_chain_ordering() -> None:
|
||||
async def test_event_stream_with_retry() -> None:
|
||||
"""Test the event stream with a tool."""
|
||||
|
||||
def success(inputs: str) -> str:
|
||||
def success(_: str) -> str:
|
||||
return "success"
|
||||
|
||||
def fail(inputs: str) -> None:
|
||||
def fail(_: str) -> None:
|
||||
"""Simple func."""
|
||||
msg = "fail"
|
||||
raise ValueError(msg)
|
||||
@@ -2069,6 +2071,7 @@ class StreamingRunnable(Runnable[Input, Output]):
|
||||
"""Initialize the runnable."""
|
||||
self.iterable = iterable
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
@@ -2084,6 +2087,7 @@ class StreamingRunnable(Runnable[Input, Output]):
|
||||
) -> Iterator[Output]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
@@ -2323,7 +2327,7 @@ async def test_bad_parent_ids() -> None:
|
||||
async def test_runnable_generator() -> None:
|
||||
"""Test async events from sync lambda."""
|
||||
|
||||
async def generator(inputs: AsyncIterator[str]) -> AsyncIterator[str]:
|
||||
async def generator(_: AsyncIterator[str]) -> AsyncIterator[str]:
|
||||
yield "1"
|
||||
yield "2"
|
||||
|
||||
|
@@ -375,7 +375,7 @@ class TestRunnableSequenceParallelTraceNesting:
|
||||
def test_sync(
|
||||
self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int]
|
||||
) -> None:
|
||||
def other_thing(a: int) -> Generator[int, None, None]: # type: ignore
|
||||
def other_thing(_: int) -> Generator[int, None, None]: # type: ignore
|
||||
yield 1
|
||||
|
||||
parent = self._create_parent(other_thing)
|
||||
@@ -407,7 +407,7 @@ class TestRunnableSequenceParallelTraceNesting:
|
||||
[RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int]
|
||||
],
|
||||
) -> None:
|
||||
async def other_thing(a: int) -> AsyncGenerator[int, None]:
|
||||
async def other_thing(_: int) -> AsyncGenerator[int, None]:
|
||||
yield 1
|
||||
|
||||
parent = self._create_parent(other_thing)
|
||||
|
@@ -2301,7 +2301,7 @@ def test_injected_arg_with_complex_type() -> None:
|
||||
self.value = "bar"
|
||||
|
||||
@tool
|
||||
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
|
||||
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: # noqa: ARG001
|
||||
"""Tool that has an injected tool arg."""
|
||||
return foo.value
|
||||
|
||||
@@ -2477,7 +2477,7 @@ def test_simple_tool_args_schema_dict() -> None:
|
||||
|
||||
def test_empty_string_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int) -> str:
|
||||
def foo(x: int) -> str: # noqa: ARG001
|
||||
"""Foo."""
|
||||
return "hi"
|
||||
|
||||
@@ -2489,7 +2489,7 @@ def test_empty_string_tool_call_id() -> None:
|
||||
def test_tool_decorator_description() -> None:
|
||||
# test basic tool
|
||||
@tool
|
||||
def foo(x: int) -> str:
|
||||
def foo(x: int) -> str: # noqa: ARG001
|
||||
"""Foo."""
|
||||
return "hi"
|
||||
|
||||
@@ -2501,7 +2501,7 @@ def test_tool_decorator_description() -> None:
|
||||
|
||||
# test basic tool with description
|
||||
@tool(description="description")
|
||||
def foo_description(x: int) -> str:
|
||||
def foo_description(x: int) -> str: # noqa: ARG001
|
||||
"""Foo."""
|
||||
return "hi"
|
||||
|
||||
@@ -2520,7 +2520,7 @@ def test_tool_decorator_description() -> None:
|
||||
x: int
|
||||
|
||||
@tool(args_schema=ArgsSchema)
|
||||
def foo_args_schema(x: int) -> str:
|
||||
def foo_args_schema(x: int) -> str: # noqa: ARG001
|
||||
return "hi"
|
||||
|
||||
assert foo_args_schema.description == "Bar."
|
||||
@@ -2532,7 +2532,7 @@ def test_tool_decorator_description() -> None:
|
||||
)
|
||||
|
||||
@tool(description="description", args_schema=ArgsSchema)
|
||||
def foo_args_schema_description(x: int) -> str:
|
||||
def foo_args_schema_description(x: int) -> str: # noqa: ARG001
|
||||
return "hi"
|
||||
|
||||
assert foo_args_schema_description.description == "description"
|
||||
@@ -2554,11 +2554,11 @@ def test_tool_decorator_description() -> None:
|
||||
}
|
||||
|
||||
@tool(args_schema=args_json_schema)
|
||||
def foo_args_jsons_schema(x: int) -> str:
|
||||
def foo_args_jsons_schema(x: int) -> str: # noqa: ARG001
|
||||
return "hi"
|
||||
|
||||
@tool(description="description", args_schema=args_json_schema)
|
||||
def foo_args_jsons_schema_with_description(x: int) -> str:
|
||||
def foo_args_jsons_schema_with_description(x: int) -> str: # noqa: ARG001
|
||||
return "hi"
|
||||
|
||||
assert foo_args_jsons_schema.description == "JSON Schema."
|
||||
@@ -2620,10 +2620,10 @@ def test_title_property_preserved() -> None:
|
||||
async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
|
||||
"""Verify that the inputs are not mutated when invoking a tool asynchronously."""
|
||||
|
||||
def sync_no_op(foo: int) -> str:
|
||||
def sync_no_op(foo: int) -> str: # noqa: ARG001
|
||||
return "good"
|
||||
|
||||
async def async_no_op(foo: int) -> str:
|
||||
async def async_no_op(foo: int) -> str: # noqa: ARG001
|
||||
return "good"
|
||||
|
||||
tool = StructuredTool(
|
||||
@@ -2668,10 +2668,10 @@ async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
|
||||
def test_tool_invoke_does_not_mutate_inputs() -> None:
|
||||
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
|
||||
|
||||
def sync_no_op(foo: int) -> str:
|
||||
def sync_no_op(foo: int) -> str: # noqa: ARG001
|
||||
return "good"
|
||||
|
||||
async def async_no_op(foo: int) -> str:
|
||||
async def async_no_op(foo: int) -> str: # noqa: ARG001
|
||||
return "good"
|
||||
|
||||
tool = StructuredTool(
|
||||
|
@@ -121,7 +121,7 @@ def dummy_structured_tool() -> StructuredTool:
|
||||
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
||||
|
||||
return StructuredTool.from_function(
|
||||
lambda x: None,
|
||||
lambda _: None,
|
||||
name="dummy_function",
|
||||
description="Dummy function.",
|
||||
args_schema=Schema,
|
||||
@@ -143,7 +143,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
|
||||
"required": ["arg1", "arg2"],
|
||||
}
|
||||
return StructuredTool.from_function(
|
||||
lambda x: None,
|
||||
lambda _: None,
|
||||
name="dummy_function",
|
||||
description="Dummy function.",
|
||||
args_schema=args_schema,
|
||||
|
@@ -10,6 +10,7 @@ import uuid
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings, FakeEmbeddings
|
||||
@@ -25,6 +26,7 @@ class CustomAddTextsVectorstore(VectorStore):
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, Document] = {}
|
||||
|
||||
@override
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@@ -51,6 +53,7 @@ class CustomAddTextsVectorstore(VectorStore):
|
||||
return [self.store[id] for id in ids if id in self.store]
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_texts( # type: ignore
|
||||
cls,
|
||||
texts: list[str],
|
||||
@@ -74,6 +77,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, Document] = {}
|
||||
|
||||
@override
|
||||
def add_documents(
|
||||
self,
|
||||
documents: list[Document],
|
||||
@@ -95,6 +99,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
|
||||
return [self.store[id] for id in ids if id in self.store]
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_texts( # type: ignore
|
||||
cls,
|
||||
texts: list[str],
|
||||
|
Reference in New Issue
Block a user