core: Add ruff rules ARG (#30732)

See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg
Christophe Bornet, 2025-04-09 20:39:36 +02:00, committed by GitHub
parent 66758599a9
commit 98f0016fc2
58 changed files with 328 additions and 180 deletions
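
For context, the flake8-unused-arguments rules flag parameters that are declared but never read: ARG001 for functions, ARG002 for methods, and ARG005 for lambdas. The diff below applies three remediations, sketched here with hypothetical names (none of these appear in the commit):

    from typing import Any

    from typing_extensions import override


    # 1. Rename: a leading underscore marks the parameter as intentionally
    # unused (ruff's default dummy-name pattern exempts it).
    def source_id(_doc: Any) -> None:
        return None


    # 2. Suppress in place: keep the name (callers may pass it as a keyword)
    # and silence the rule for that one line.
    def finalize(wrapper: Any, new_doc: str) -> Any:  # noqa: ARG001
        return wrapper


    class BaseHandler:
        def on_event(self, payload: dict[str, Any], **kwargs: Any) -> None:
            """Hook whose signature is part of the interface contract."""


    class PrintingHandler(BaseHandler):
        # 3. Mark as an override: ruff skips the ARG checks on methods
        # decorated with @override, whose signatures are dictated by the
        # base class.
        @override
        def on_event(self, payload: dict[str, Any], **kwargs: Any) -> None:
            print(payload)

The bulk of the diff is the third pattern: callback, tracer, and parser methods implement fixed interfaces, so the unused parameters cannot simply be dropped.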

View File

@@ -124,7 +124,7 @@ def beta(
         _name = _name or obj.__qualname__
         old_doc = obj.__doc__

-        def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
+        def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:  # noqa: ARG001
             """Finalize the annotation of a class."""
             # Can't set new_doc on some extension objects.
             with contextlib.suppress(AttributeError):
@@ -190,7 +190,7 @@ def beta(
         if _name == "<lambda>":
             _name = set_name

-        def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
+        def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:  # noqa: ARG001
             """Finalize the property."""
             return _BetaProperty(
                 fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc

View File

@@ -204,7 +204,7 @@ def deprecated(
         _name = _name or obj.__qualname__
         old_doc = obj.__doc__

-        def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
+        def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:  # noqa: ARG001
             """Finalize the deprecation of a class."""
             # Can't set new_doc on some extension objects.
             with contextlib.suppress(AttributeError):
@@ -234,7 +234,7 @@ def deprecated(
                 raise ValueError(msg)
             old_doc = obj.description

-            def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
+            def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:  # noqa: ARG001
                 return cast(
                     "T",
                     FieldInfoV1(
@@ -255,7 +255,7 @@ def deprecated(
                 raise ValueError(msg)
             old_doc = obj.description

-            def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
+            def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:  # noqa: ARG001
                 return cast(
                     "T",
                     FieldInfoV2(
@@ -315,7 +315,7 @@ def deprecated(
         if _name == "<lambda>":
             _name = set_name

-        def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
+        def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:  # noqa: ARG001
             """Finalize the property."""
             return cast(
                 "T",

View File

@@ -27,6 +27,8 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from typing import Any, Optional

+from typing_extensions import override
+
 from langchain_core.outputs import Generation
 from langchain_core.runnables import run_in_executor
@@ -194,6 +196,7 @@ class InMemoryCache(BaseCache):
             del self._cache[next(iter(self._cache))]
         self._cache[(prompt, llm_string)] = return_val

+    @override
     def clear(self, **kwargs: Any) -> None:
         """Clear cache."""
         self._cache = {}
@@ -227,6 +230,7 @@ class InMemoryCache(BaseCache):
         """
         self.update(prompt, llm_string, return_val)

+    @override
     async def aclear(self, **kwargs: Any) -> None:
         """Async clear cache."""
         self.clear()

View File

@@ -5,6 +5,8 @@ from __future__ import annotations
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, TextIO, cast

+from typing_extensions import override
+
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.utils.input import print_text
@@ -38,6 +40,7 @@ class FileCallbackHandler(BaseCallbackHandler):
         """Destructor to cleanup when done."""
         self.file.close()

+    @override
     def on_chain_start(
         self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
     ) -> None:
@@ -61,6 +64,7 @@ class FileCallbackHandler(BaseCallbackHandler):
             file=self.file,
         )

+    @override
     def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain.
@@ -70,6 +74,7 @@ class FileCallbackHandler(BaseCallbackHandler):
         """
         print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file)

+    @override
     def on_agent_action(
         self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
     ) -> Any:
@@ -83,6 +88,7 @@ class FileCallbackHandler(BaseCallbackHandler):
         """
         print_text(action.log, color=color or self.color, file=self.file)

+    @override
     def on_tool_end(
         self,
         output: str,
@@ -109,6 +115,7 @@ class FileCallbackHandler(BaseCallbackHandler):
         if llm_prefix is not None:
             print_text(f"\n{llm_prefix}", file=self.file)

+    @override
     def on_text(
         self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any
     ) -> None:
@@ -123,6 +130,7 @@ class FileCallbackHandler(BaseCallbackHandler):
         """
         print_text(text, color=color or self.color, end=end, file=self.file)

+    @override
     def on_agent_finish(
         self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
     ) -> None:

View File

@@ -22,7 +22,7 @@ from typing import (
 )
 from uuid import UUID

 from langsmith.run_helpers import get_tracing_context
-from typing_extensions import Self
+from typing_extensions import Self, override

 from langchain_core.callbacks.base import (
     BaseCallbackHandler,
@@ -1401,6 +1401,7 @@ class CallbackManager(BaseCallbackManager):
             inheritable_metadata=self.inheritable_metadata,
         )

+    @override
     def on_tool_start(
         self,
         serialized: Optional[dict[str, Any]],
@@ -1456,6 +1457,7 @@ class CallbackManager(BaseCallbackManager):
             inheritable_metadata=self.inheritable_metadata,
         )

+    @override
     def on_retriever_start(
         self,
         serialized: Optional[dict[str, Any]],
@@ -1927,6 +1929,7 @@ class AsyncCallbackManager(BaseCallbackManager):
             inheritable_metadata=self.inheritable_metadata,
         )

+    @override
     async def on_tool_start(
         self,
         serialized: Optional[dict[str, Any]],
@@ -2017,6 +2020,7 @@ class AsyncCallbackManager(BaseCallbackManager):
             metadata=self.metadata,
         )

+    @override
     async def on_retriever_start(
         self,
         serialized: Optional[dict[str, Any]],

View File

@@ -4,6 +4,8 @@ from __future__ import annotations

 from typing import TYPE_CHECKING, Any, Optional

+from typing_extensions import override
+
 from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.utils import print_text
@@ -22,6 +24,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         """
         self.color = color

+    @override
     def on_chain_start(
         self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
     ) -> None:
@@ -41,6 +44,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
             name = "<unknown>"
         print(f"\n\n\033[1m> Entering new {name} chain...\033[0m")  # noqa: T201

+    @override
     def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain.
@@ -50,6 +54,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         """
         print("\n\033[1m> Finished chain.\033[0m")  # noqa: T201

+    @override
     def on_agent_action(
         self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
     ) -> Any:
@@ -62,6 +67,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         """
         print_text(action.log, color=color or self.color)

+    @override
     def on_tool_end(
         self,
         output: Any,
@@ -87,6 +93,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         if llm_prefix is not None:
             print_text(f"\n{llm_prefix}")

+    @override
     def on_text(
         self,
         text: str,
@@ -104,6 +111,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         """
         print_text(text, color=color or self.color, end=end)

+    @override
     def on_agent_finish(
         self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
     ) -> None:

View File

@@ -5,6 +5,8 @@ from __future__ import annotations

 import sys
 from typing import TYPE_CHECKING, Any

+from typing_extensions import override
+
 from langchain_core.callbacks.base import BaseCallbackHandler

 if TYPE_CHECKING:
@@ -41,6 +43,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
             **kwargs (Any): Additional keyword arguments.
         """

+    @override
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         """Run on new LLM token. Only available when streaming is enabled.

View File

@@ -58,6 +58,7 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
     def __repr__(self) -> str:
         return str(self.usage_metadata)

+    @override
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Collect token usage."""
         # Check for usage_metadata (langchain-core >= 0.2.2)

View File

@@ -151,7 +151,7 @@ def _get_source_id_assigner(
 ) -> Callable[[Document], Union[str, None]]:
     """Get the source id from the document."""
     if source_id_key is None:
-        return lambda doc: None
+        return lambda _doc: None
     if isinstance(source_id_key, str):
         return lambda doc: doc.metadata[source_id_key]
     if callable(source_id_key):

View File

@@ -6,6 +6,7 @@ from collections.abc import Sequence
 from typing import Any, Optional, cast

 from pydantic import Field
+from typing_extensions import override

 from langchain_core._api import beta
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
@@ -29,6 +30,7 @@ class InMemoryDocumentIndex(DocumentIndex):
     store: dict[str, Document] = Field(default_factory=dict)
     top_k: int = 4

+    @override
     def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
         """Upsert items into the index."""
         ok_ids = []
@@ -47,6 +49,7 @@ class InMemoryDocumentIndex(DocumentIndex):

         return UpsertResponse(succeeded=ok_ids, failed=[])

+    @override
     def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
         """Delete by ID."""
         if ids is None:
@@ -64,10 +67,12 @@ class InMemoryDocumentIndex(DocumentIndex):
             succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[]
         )

+    @override
     def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
         """Get by ids."""
         return [self.store[id_] for id_ in ids if id_ in self.store]

+    @override
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> list[Document]:

View File

@@ -576,7 +576,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

     # --- Custom methods ---

-    def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
+    def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:  # noqa: ARG002
         return {}

     def _get_invocation_params(
@@ -1246,6 +1246,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     def _llm_type(self) -> str:
         """Return type of chat model."""

+    @override
     def dict(self, **kwargs: Any) -> dict:
         """Return a dictionary of the LLM."""
         starter_dict = dict(self._identifying_params)

View File

@@ -171,6 +171,7 @@ class FakeListChatModel(SimpleChatModel):
 class FakeChatModel(SimpleChatModel):
     """Fake Chat Model wrapper for testing purposes."""

+    @override
     def _call(
         self,
         messages: list[BaseMessage],
@@ -180,6 +181,7 @@ class FakeChatModel(SimpleChatModel):
     ) -> str:
         return "fake response"

+    @override
     async def _agenerate(
         self,
         messages: list[BaseMessage],
@@ -224,6 +226,7 @@ class GenericFakeChatModel(BaseChatModel):
     into message chunks.
     """

+    @override
     def _generate(
         self,
         messages: list[BaseMessage],
@@ -346,6 +349,7 @@ class ParrotFakeChatModel(BaseChatModel):
     * Chat model should be usable in both sync and async tests
     """

+    @override
     def _generate(
         self,
         messages: list[BaseMessage],

View File

@@ -1399,6 +1399,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def _llm_type(self) -> str:
         """Return type of llm."""

+    @override
     def dict(self, **kwargs: Any) -> dict:
         """Return a dictionary of the LLM."""
         starter_dict = dict(self._identifying_params)

View File

@@ -59,7 +59,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
         Returns:
             Structured output.
         """
-        return await run_in_executor(None, self.parse_result, result)
+        return await run_in_executor(None, self.parse_result, result, partial=partial)


 class BaseGenerationOutputParser(
@@ -231,6 +231,7 @@ class BaseOutputParser(
             run_type="parser",
         )

+    @override
     def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
         """Parse a list of candidate model Generations into a specific format.
@@ -290,7 +291,11 @@ class BaseOutputParser(
         return await run_in_executor(None, self.parse, text)

     # TODO: rename 'completion' -> 'text'.
-    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
+    def parse_with_prompt(
+        self,
+        completion: str,
+        prompt: PromptValue,  # noqa: ARG002
+    ) -> Any:
         """Parse the output of an LLM call with the input prompt for context.

         The prompt is largely provided in the event the OutputParser wants

View File

@@ -7,6 +7,7 @@ from typing import Any, Optional, Union

 import jsonpatch  # type: ignore[import]
 from pydantic import BaseModel, model_validator
+from typing_extensions import override

 from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import (
@@ -23,6 +24,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
     args_only: bool = True
     """Whether to only return the arguments to the function call."""

+    @override
     def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
         """Parse the result of an LLM call to a JSON object.
@@ -251,6 +253,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
             raise ValueError(msg)
         return values

+    @override
     def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
         """Parse the result of an LLM call to a JSON object.
@@ -287,6 +290,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
     attr_name: str
     """The name of the attribute to return."""

+    @override
     def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
         """Parse the result of an LLM call to a JSON object.

View File

@@ -9,6 +9,8 @@ from typing import (
     Union,
 )

+from typing_extensions import override
+
 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.output_parsers.base import BaseOutputParser, T
 from langchain_core.outputs import (
@@ -48,6 +50,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
                 None, self.parse_result, [Generation(text=chunk)]
             )

+    @override
     def transform(
         self,
         input: Iterator[Union[str, BaseMessage]],
@@ -68,6 +71,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
             input, self._transform, config, run_type="parser"
         )

+    @override
     async def atransform(
         self,
         input: AsyncIterator[Union[str, BaseMessage]],

View File

@@ -127,6 +127,7 @@ class BasePromptTemplate(
         """Return the output type of the prompt."""
         return Union[StringPromptValue, ChatPromptValueConcrete]

+    @override
     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
     ) -> type[BaseModel]:

View File

@@ -199,7 +199,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
             len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
         )

-    def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
+    def _get_ls_params(self, **_kwargs: Any) -> LangSmithRetrieverParams:
         """Get standard params for tracing."""
         default_retriever_name = self.get_name()
         if default_retriever_name.startswith("Retriever"):

View File

@@ -326,7 +326,8 @@ class Runnable(Generic[Input, Output], ABC):
         return self.get_input_schema()

     def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
+        self,
+        config: Optional[RunnableConfig] = None,  # noqa: ARG002
     ) -> type[BaseModel]:
         """Get a pydantic model that can be used to validate input to the Runnable.
@@ -398,7 +399,8 @@ class Runnable(Generic[Input, Output], ABC):
         return self.get_output_schema()

     def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
+        self,
+        config: Optional[RunnableConfig] = None,  # noqa: ARG002
     ) -> type[BaseModel]:
         """Get a pydantic model that can be used to validate output to the Runnable.
@@ -4751,11 +4753,6 @@ class RunnableLambda(Runnable[Input, Output]):
             )
         return cast("Output", output)

-    def _config(
-        self, config: Optional[RunnableConfig], callable: Callable[..., Any]
-    ) -> RunnableConfig:
-        return ensure_config(config)
-
     @override
     def invoke(
         self,
@@ -4780,7 +4777,7 @@ class RunnableLambda(Runnable[Input, Output]):
             return self._call_with_config(
                 self._invoke,
                 input,
-                self._config(config, self.func),
+                ensure_config(config),
                 **kwargs,
             )
         msg = "Cannot invoke a coroutine function synchronously.Use `ainvoke` instead."
@@ -4803,11 +4800,10 @@ class RunnableLambda(Runnable[Input, Output]):
         Returns:
             The output of this Runnable.
         """
-        the_func = self.afunc if hasattr(self, "afunc") else self.func
         return await self._acall_with_config(
             self._ainvoke,
             input,
-            self._config(config, the_func),
+            ensure_config(config),
             **kwargs,
         )
@@ -4884,7 +4880,7 @@ class RunnableLambda(Runnable[Input, Output]):
             yield from self._transform_stream_with_config(
                 input,
                 self._transform,
-                self._config(config, self.func),
+                ensure_config(config),
                 **kwargs,
             )
         else:
@@ -5012,7 +5008,7 @@ class RunnableLambda(Runnable[Input, Output]):
         async for output in self._atransform_stream_with_config(
             input,
             self._atransform,
-            self._config(config, self.afunc if hasattr(self, "afunc") else self.func),
+            ensure_config(config),
             **kwargs,
         ):
             yield output

View File

@@ -400,6 +400,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     def OutputType(self) -> type[Output]:
         return self._history_chain.OutputType

+    @override
     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
     ) -> type[BaseModel]:
@@ -432,12 +433,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
             module_name=self.__class__.__module__,
         )

-    def _is_not_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
-        return False
-
-    async def _is_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
-        return True
-
     def _get_input_messages(
         self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
     ) -> list[BaseMessage]:

View File

@@ -133,6 +133,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
         self._on_llm_start(llm_run)
         return llm_run

+    @override
     def on_llm_new_token(
         self,
         token: str,
@@ -161,11 +162,11 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
             run_id=run_id,
             chunk=chunk,
             parent_run_id=parent_run_id,
-            **kwargs,
         )
         self._on_llm_new_token(llm_run, token, chunk)
         return llm_run

+    @override
     def on_retry(
         self,
         retry_state: RetryCallState,
@@ -188,6 +189,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
             run_id=run_id,
         )

+    @override
     def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
         """End a trace for an LLM run.
@@ -235,6 +237,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
         self._on_llm_error(llm_run)
         return llm_run

+    @override
     def on_chain_start(
         self,
         serialized: dict[str, Any],
@@ -279,6 +282,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
         self._on_chain_start(chain_run)
         return chain_run

+    @override
     def on_chain_end(
         self,
         outputs: dict[str, Any],
@@ -302,12 +306,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
             outputs=outputs,
             run_id=run_id,
             inputs=inputs,
-            **kwargs,
         )
         self._end_trace(chain_run)
         self._on_chain_end(chain_run)
         return chain_run

+    @override
     def on_chain_error(
         self,
         error: BaseException,
@@ -331,7 +335,6 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
             error=error,
             run_id=run_id,
             inputs=inputs,
-            **kwargs,
         )
         self._end_trace(chain_run)
         self._on_chain_error(chain_run)
@@ -381,6 +384,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
         self._on_tool_start(tool_run)
         return tool_run

+    @override
     def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
         """End a trace for a tool run.
@@ -395,12 +399,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
         tool_run = self._complete_tool_run(
             output=output,
             run_id=run_id,
-            **kwargs,
         )
         self._end_trace(tool_run)
         self._on_tool_end(tool_run)
         return tool_run

+    @override
     def on_tool_error(
         self,
         error: BaseException,
@@ -467,6 +471,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
         self._on_retriever_start(retrieval_run)
         return retrieval_run

+    @override
     def on_retriever_error(
         self,
         error: BaseException,
@@ -487,12 +492,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
         retrieval_run = self._errored_retrieval_run(
             error=error,
             run_id=run_id,
-            **kwargs,
         )
         self._end_trace(retrieval_run)
         self._on_retriever_error(retrieval_run)
         return retrieval_run

+    @override
     def on_retriever_end(
         self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
     ) -> Run:
@@ -509,7 +514,6 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
         retrieval_run = self._complete_retrieval_run(
             documents=documents,
             run_id=run_id,
-            **kwargs,
         )
         self._end_trace(retrieval_run)
         self._on_retriever_end(retrieval_run)
@@ -623,7 +627,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
             run_id=run_id,
             chunk=chunk,
             parent_run_id=parent_run_id,
-            **kwargs,
         )
         await self._on_llm_new_token(llm_run, token, chunk)
@@ -715,7 +718,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
             outputs=outputs,
             run_id=run_id,
             inputs=inputs,
-            **kwargs,
         )
         tasks = [self._end_trace(chain_run), self._on_chain_end(chain_run)]
         await asyncio.gather(*tasks)
@@ -733,7 +735,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
             error=error,
             inputs=inputs,
             run_id=run_id,
-            **kwargs,
         )
         tasks = [self._end_trace(chain_run), self._on_chain_error(chain_run)]
         await asyncio.gather(*tasks)
@@ -776,7 +777,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
         tool_run = self._complete_tool_run(
             output=output,
             run_id=run_id,
-            **kwargs,
         )
         tasks = [self._end_trace(tool_run), self._on_tool_end(tool_run)]
         await asyncio.gather(*tasks)
@@ -839,7 +839,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
         retrieval_run = self._errored_retrieval_run(
             error=error,
             run_id=run_id,
-            **kwargs,
         )
         tasks = [
             self._end_trace(retrieval_run),
@@ -860,7 +859,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
         retrieval_run = self._complete_retrieval_run(
             documents=documents,
             run_id=run_id,
-            **kwargs,
         )
         tasks = [self._end_trace(retrieval_run), self._on_retriever_end(retrieval_run)]
         await asyncio.gather(*tasks)

View File

@@ -41,7 +41,7 @@ run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVa

 @contextmanager
 def tracing_enabled(
-    session_name: str = "default",
+    session_name: str = "default",  # noqa: ARG001
 ) -> Generator[TracerSessionV1, None, None]:
     """Throw an error because this has been replaced by tracing_v2_enabled."""
     msg = (

View File

@@ -231,8 +231,7 @@ class _TracerCore(ABC):
         token: str,
         run_id: UUID,
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
-        parent_run_id: Optional[UUID] = None,
-        **kwargs: Any,
+        parent_run_id: Optional[UUID] = None,  # noqa: ARG002
     ) -> Run:
         """Append token event to LLM run and return the run."""
         llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
@@ -252,7 +251,6 @@ class _TracerCore(ABC):
         self,
         retry_state: RetryCallState,
         run_id: UUID,
-        **kwargs: Any,
     ) -> Run:
         llm_run = self._get_run(run_id)
         retry_d: dict[str, Any] = {
@@ -369,7 +367,6 @@ class _TracerCore(ABC):
         outputs: dict[str, Any],
         run_id: UUID,
         inputs: Optional[dict[str, Any]] = None,
-        **kwargs: Any,
     ) -> Run:
         """Update a chain run with outputs and end time."""
         chain_run = self._get_run(run_id)
@@ -385,7 +382,6 @@ class _TracerCore(ABC):
         error: BaseException,
         inputs: Optional[dict[str, Any]],
         run_id: UUID,
-        **kwargs: Any,
     ) -> Run:
         chain_run = self._get_run(run_id)
         chain_run.error = self._get_stacktrace(error)
@@ -439,7 +435,6 @@ class _TracerCore(ABC):
         self,
         output: dict[str, Any],
         run_id: UUID,
-        **kwargs: Any,
     ) -> Run:
         """Update a tool run with outputs and end time."""
         tool_run = self._get_run(run_id, run_type="tool")
@@ -452,7 +447,6 @@ class _TracerCore(ABC):
         self,
         error: BaseException,
         run_id: UUID,
-        **kwargs: Any,
     ) -> Run:
         """Update a tool run with error and end time."""
         tool_run = self._get_run(run_id, run_type="tool")
@@ -494,7 +488,6 @@ class _TracerCore(ABC):
         self,
         documents: Sequence[Document],
         run_id: UUID,
-        **kwargs: Any,
     ) -> Run:
         """Update a retrieval run with outputs and end time."""
         retrieval_run = self._get_run(run_id, run_type="retriever")
@@ -507,7 +500,6 @@ class _TracerCore(ABC):
         self,
         error: BaseException,
         run_id: UUID,
-        **kwargs: Any,
     ) -> Run:
         retrieval_run = self._get_run(run_id, run_type="retriever")
         retrieval_run.error = self._get_stacktrace(error)
@@ -523,75 +515,75 @@ class _TracerCore(ABC):
         """Copy the tracer."""
         return self

-    def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """End a trace for a run."""
         return None

-    def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process a run upon creation."""
         return None

-    def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process a run upon update."""
         return None

-    def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the LLM Run upon start."""
         return None

     def _on_llm_new_token(
         self,
-        run: Run,
-        token: str,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
+        run: Run,  # noqa: ARG002
+        token: str,  # noqa: ARG002
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],  # noqa: ARG002
     ) -> Union[None, Coroutine[Any, Any, None]]:
         """Process new LLM token."""
         return None

-    def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the LLM Run."""
         return None

-    def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the LLM Run upon error."""
         return None

-    def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Chain Run upon start."""
         return None

-    def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Chain Run."""
         return None

-    def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Chain Run upon error."""
         return None

-    def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Tool Run upon start."""
         return None

-    def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Tool Run."""
         return None

-    def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Tool Run upon error."""
         return None

-    def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Chat Model Run upon start."""
         return None

-    def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Retriever Run upon start."""
         return None

-    def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Retriever Run."""
         return None

-    def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
+    def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:  # noqa: ARG002
         """Process the Retriever Run upon error."""
         return None
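
A note on the hunk above: these default hooks keep their `run` parameter and add `# noqa: ARG002` rather than renaming it with an underscore, presumably because subclasses override them and do consume the argument, so the signature should keep its documented name. A reduced, hypothetical illustration:

    from typing import Any, Optional


    class TracerHooks:
        def _on_run_update(self, run: Any) -> Optional[Any]:  # noqa: ARG002
            """Default hook: a no-op, but it fixes the signature for subclasses."""
            return None


    class AuditTracer(TracerHooks):
        def _on_run_update(self, run: Any) -> Optional[Any]:
            # A concrete tracer really does use the run it is handed.
            print(f"run updated: {run}")
            return None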

View File

@@ -15,7 +15,7 @@ from typing import (
 )
 from uuid import UUID, uuid4

-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import NotRequired, TypedDict, override

 from langchain_core.callbacks.base import AsyncCallbackHandler
 from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
@@ -293,6 +293,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
         self.run_map[run_id] = info
         self.parent_map[run_id] = parent_run_id

+    @override
     async def on_chat_model_start(
         self,
         serialized: dict[str, Any],
@@ -334,6 +335,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
             run_type,
         )

+    @override
     async def on_llm_start(
         self,
         serialized: dict[str, Any],
@@ -377,6 +379,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
             run_type,
         )

+    @override
     async def on_custom_event(
         self,
         name: str,
@@ -399,6 +402,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
         )
         self._send(event, name)

+    @override
     async def on_llm_new_token(
         self,
         token: str,
@@ -450,6 +454,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
             run_info["run_type"],
         )

+    @override
     async def on_llm_end(
         self, response: LLMResult, *, run_id: UUID, **kwargs: Any
     ) -> None:
@@ -552,6 +557,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
             run_type_,
         )

+    @override
     async def on_chain_end(
         self,
         outputs: dict[str, Any],
@@ -586,6 +592,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
             run_type,
         )

+    @override
     async def on_tool_start(
         self,
         serialized: dict[str, Any],
@@ -627,6 +634,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
             "tool",
         )

+    @override
     async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
         """End a trace for a tool run."""
         run_info = self.run_map.pop(run_id)
@@ -654,6 +662,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
             "tool",
         )

+    @override
     async def on_retriever_start(
         self,
         serialized: dict[str, Any],
@@ -697,6 +706,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
             run_type,
         )

+    @override
     async def on_retriever_end(
         self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
     ) -> None:

View File

@@ -19,6 +19,7 @@ from tenacity import (
     stop_after_attempt,
     wait_exponential_jitter,
 )
+from typing_extensions import override

 from langchain_core.env import get_runtime_environment
 from langchain_core.load import dumpd
@@ -252,13 +253,13 @@ class LangChainTracer(BaseTracer):
             run.reference_example_id = self.example_id
         self._persist_run_single(run)

+    @override
     def _llm_run_with_token_event(
         self,
         token: str,
         run_id: UUID,
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         parent_run_id: Optional[UUID] = None,
-        **kwargs: Any,
     ) -> Run:
         """Append token event to LLM run and return the run."""
         return super()._llm_run_with_token_event(
@@ -267,7 +268,6 @@ class LangChainTracer(BaseTracer):
             run_id,
             chunk=None,
             parent_run_id=parent_run_id,
-            **kwargs,
         )

     def _on_chat_model_start(self, run: Run) -> None:

View File

@@ -6,7 +6,7 @@ Please use LangChainTracer instead.
 from typing import Any


-def get_headers(*args: Any, **kwargs: Any) -> Any:
+def get_headers(*args: Any, **kwargs: Any) -> Any:  # noqa: ARG001
     """Throw an error because this has been replaced by get_headers."""
     msg = (
         "get_headers for LangChainTracerV1 is no longer supported. "
@@ -15,7 +15,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any:
     raise RuntimeError(msg)


-def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any:  # noqa: N802
+def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any:  # noqa: N802,ARG001
     """Throw an error because this has been replaced by LangChainTracer."""
     msg = (
         "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."

View File

@@ -16,8 +16,8 @@ from langchain_core._api.deprecation import deprecated
     ),
 )
 def try_load_from_hub(
-    *args: Any,
-    **kwargs: Any,
+    *args: Any,  # noqa: ARG001
+    **kwargs: Any,  # noqa: ARG001
 ) -> Any:
     """[DEPRECATED] Try to load from the old Hub."""
     warnings.warn(

View File

@@ -65,7 +65,11 @@ def grab_literal(template: str, l_del: str) -> tuple[str, str]:
     return (literal, template)


-def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
+def l_sa_check(
+    template: str,  # noqa: ARG001
+    literal: str,
+    is_standalone: bool,
+) -> bool:
     """Do a preliminary check to see if a tag could be a standalone.

     Args:

View File

@@ -38,6 +38,7 @@ from pydantic.json_schema import (
     JsonSchemaMode,
     JsonSchemaValue,
 )
+from typing_extensions import override

 if TYPE_CHECKING:
     from pydantic_core import core_schema
@@ -233,6 +234,7 @@ class _IgnoreUnserializable(GenerateJsonSchema):
     https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
     """

+    @override
     def handle_invalid_for_json_schema(
         self, schema: core_schema.CoreSchema, error_info: str
     ) -> JsonSchemaValue:

View File

@@ -13,6 +13,7 @@ from typing import Any, Callable, Optional, Union, overload
 from packaging.version import parse
 from pydantic import SecretStr
 from requests import HTTPError, Response
+from typing_extensions import override

 from langchain_core.utils.pydantic import (
     is_pydantic_v1_subclass,
@@ -91,6 +92,7 @@ def mock_now(dt_value: datetime.datetime) -> Iterator[type]:
         """Mock datetime.datetime.now() with a fixed datetime."""

         @classmethod
+        @override
         def now(cls, tz: Union[datetime.tzinfo, None] = None) -> "MockDateTime":
             # Create a copy of dt_value.
             return MockDateTime(

View File

@@ -36,7 +36,7 @@ from typing import (
 )

 from pydantic import ConfigDict, Field, model_validator
-from typing_extensions import Self
+from typing_extensions import Self, override

 from langchain_core.embeddings import Embeddings
 from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
@@ -1070,6 +1070,7 @@ class VectorStoreRetriever(BaseRetriever):

         return ls_params

+    @override
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
     ) -> list[Document]:
@@ -1090,6 +1091,7 @@ class VectorStoreRetriever(BaseRetriever):
             raise ValueError(msg)
         return docs

+    @override
     async def _aget_relevant_documents(
         self,
         query: str,

View File

@@ -285,7 +285,7 @@ class InMemoryVectorStore(VectorStore):
         since="0.2.29",
         removal="1.0",
     )
-    def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
+    def upsert(self, items: Sequence[Document], /, **_kwargs: Any) -> UpsertResponse:
         """[DEPRECATED] Upsert documents into the store.

         Args:
@@ -319,7 +319,7 @@ class InMemoryVectorStore(VectorStore):
         removal="1.0",
     )
     async def aupsert(
-        self, items: Sequence[Document], /, **kwargs: Any
+        self, items: Sequence[Document], /, **_kwargs: Any
     ) -> UpsertResponse:
         """[DEPRECATED] Upsert documents into the store.
@@ -364,7 +364,6 @@ class InMemoryVectorStore(VectorStore):
         embedding: list[float],
         k: int = 4,
         filter: Optional[Callable[[Document], bool]] = None,
-        **kwargs: Any,
     ) -> list[tuple[Document, float, list[float]]]:
         # get all docs with fixed order in list
         docs = list(self.store.values())
@@ -404,7 +403,7 @@ class InMemoryVectorStore(VectorStore):
         embedding: list[float],
         k: int = 4,
         filter: Optional[Callable[[Document], bool]] = None,
-        **kwargs: Any,
+        **_kwargs: Any,
     ) -> list[tuple[Document, float]]:
         """Search for the most similar documents to the given embedding.
@@ -419,7 +418,7 @@ class InMemoryVectorStore(VectorStore):
         return [
             (doc, similarity)
             for doc, similarity, _ in self._similarity_search_with_score_by_vector(
-                embedding=embedding, k=k, filter=filter, **kwargs
+                embedding=embedding, k=k, filter=filter
             )
         ]
@@ -490,12 +489,14 @@ class InMemoryVectorStore(VectorStore):
         k: int = 4,
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
+        *,
+        filter: Optional[Callable[[Document], bool]] = None,
         **kwargs: Any,
     ) -> list[Document]:
         prefetch_hits = self._similarity_search_with_score_by_vector(
             embedding=embedding,
             k=fetch_k,
-            **kwargs,
+            filter=filter,
         )

         try:
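
The last hunk above also tightens an API: `filter` becomes an explicit keyword-only parameter (the bare `*`) instead of riding along in `**kwargs`. A small sketch of the pattern, with hypothetical names:

    from typing import Callable, Optional


    def top_k(
        items: list[int],
        k: int = 4,
        *,
        # Keyword-only: callers must spell out key=..., so unrelated keyword
        # arguments can no longer be swallowed and forwarded silently.
        key: Optional[Callable[[int], int]] = None,
    ) -> list[int]:
        return sorted(items, key=key, reverse=True)[:k]


    top_k([3, 1, 4, 1, 5], 3, key=abs)  # -> [5, 4, 3]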

View File

@@ -98,7 +98,6 @@ ignore = [
     # TODO rules
     "A",
     "ANN401",
-    "ARG",
     "BLE",
     "ERA",
     "FBT001",
@@ -132,5 +131,6 @@ classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_ini
 "tests/unit_tests/prompts/test_chat.py" = [ "E501",]
 "tests/unit_tests/runnables/test_runnable.py" = [ "E501",]
 "tests/unit_tests/runnables/test_graph.py" = [ "E501",]
+"tests/unit_tests/test_tools.py" = [ "ARG",]
 "tests/**" = [ "D", "S",]
 "scripts/**" = [ "INP", "S",]

View File

@@ -10,6 +10,8 @@ from contextlib import asynccontextmanager
 from typing import Any, Optional
 from uuid import UUID

+from typing_extensions import override
+
 from langchain_core.callbacks import (
     AsyncCallbackHandler,
     AsyncCallbackManager,
@@ -45,6 +47,7 @@ async def test_inline_handlers_share_parent_context() -> None:
             """Initialize the handler."""
             self.run_inline = run_inline

+        @override
         async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
             """Update the callstack with the name of the callback."""
             some_var.set("on_llm_start")

View File

@@ -74,6 +74,7 @@ async def test_async_custom_event_implicit_config() -> None:
     # a decorator for async functions
     @RunnableLambda  # type: ignore[arg-type]
     async def foo(x: int, config: RunnableConfig) -> int:
+        assert "callbacks" in config
         await adispatch_custom_event("event1", {"x": x})
         await adispatch_custom_event("event2", {"x": x})
         return x

View File

@@ -3,6 +3,7 @@
 from collections.abc import Iterator

 import pytest
+from typing_extensions import override

 from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader
 from langchain_core.documents import Document
@@ -15,6 +16,7 @@ def test_base_blob_parser() -> None:
     class MyParser(BaseBlobParser):
         """A simple parser that returns a single document."""

+        @override
         def lazy_parse(self, blob: Blob) -> Iterator[Document]:
             """Lazy parsing interface."""
             yield Document(


@ -1,6 +1,8 @@
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Optional, cast from typing import Any, Optional, cast
from typing_extensions import override
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings, FakeEmbeddings from langchain_core.embeddings import Embeddings, FakeEmbeddings
from langchain_core.example_selectors import ( from langchain_core.example_selectors import (
@ -21,6 +23,7 @@ class DummyVectorStore(VectorStore):
def embeddings(self) -> Optional[Embeddings]: def embeddings(self) -> Optional[Embeddings]:
return self._embeddings return self._embeddings
@override
def add_texts( def add_texts(
self, self,
texts: Iterable[str], texts: Iterable[str],
@ -32,6 +35,7 @@ class DummyVectorStore(VectorStore):
self.metadatas.extend(metadatas) self.metadatas.extend(metadatas)
return ["dummy_id"] return ["dummy_id"]
@override
def similarity_search( def similarity_search(
self, query: str, k: int = 4, **kwargs: Any self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]: ) -> list[Document]:
@ -41,6 +45,7 @@ class DummyVectorStore(VectorStore):
) )
] * k ] * k
@override
def max_marginal_relevance_search( def max_marginal_relevance_search(
self, self,
query: str, query: str,


@ -5,6 +5,7 @@ from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
@ -138,6 +139,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Whether to ignore retriever callbacks.""" """Whether to ignore retriever callbacks."""
return self.ignore_retriever_ return self.ignore_retriever_
@override
def on_llm_start( def on_llm_start(
self, self,
*args: Any, *args: Any,
@ -145,6 +147,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_llm_start_common() self.on_llm_start_common()
@override
def on_llm_new_token( def on_llm_new_token(
self, self,
*args: Any, *args: Any,
@ -152,6 +155,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_llm_new_token_common() self.on_llm_new_token_common()
@override
def on_llm_end( def on_llm_end(
self, self,
*args: Any, *args: Any,
@ -159,6 +163,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_llm_end_common() self.on_llm_end_common()
@override
def on_llm_error( def on_llm_error(
self, self,
*args: Any, *args: Any,
@ -166,6 +171,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_llm_error_common(*args, **kwargs) self.on_llm_error_common(*args, **kwargs)
@override
def on_retry( def on_retry(
self, self,
*args: Any, *args: Any,
@ -173,6 +179,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_retry_common() self.on_retry_common()
@override
def on_chain_start( def on_chain_start(
self, self,
*args: Any, *args: Any,
@ -180,6 +187,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_chain_start_common() self.on_chain_start_common()
@override
def on_chain_end( def on_chain_end(
self, self,
*args: Any, *args: Any,
@ -187,6 +195,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_chain_end_common() self.on_chain_end_common()
@override
def on_chain_error( def on_chain_error(
self, self,
*args: Any, *args: Any,
@ -194,6 +203,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_chain_error_common() self.on_chain_error_common()
@override
def on_tool_start( def on_tool_start(
self, self,
*args: Any, *args: Any,
@ -201,6 +211,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_tool_start_common() self.on_tool_start_common()
@override
def on_tool_end( def on_tool_end(
self, self,
*args: Any, *args: Any,
@ -208,6 +219,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_tool_end_common() self.on_tool_end_common()
@override
def on_tool_error( def on_tool_error(
self, self,
*args: Any, *args: Any,
@ -215,6 +227,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_tool_error_common() self.on_tool_error_common()
@override
def on_agent_action( def on_agent_action(
self, self,
*args: Any, *args: Any,
@ -222,6 +235,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_agent_action_common() self.on_agent_action_common()
@override
def on_agent_finish( def on_agent_finish(
self, self,
*args: Any, *args: Any,
@ -229,6 +243,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_agent_finish_common() self.on_agent_finish_common()
@override
def on_text( def on_text(
self, self,
*args: Any, *args: Any,
@ -236,6 +251,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_text_common() self.on_text_common()
@override
def on_retriever_start( def on_retriever_start(
self, self,
*args: Any, *args: Any,
@ -243,6 +259,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_retriever_start_common() self.on_retriever_start_common()
@override
def on_retriever_end( def on_retriever_end(
self, self,
*args: Any, *args: Any,
@ -250,6 +267,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_retriever_end_common() self.on_retriever_end_common()
@override
def on_retriever_error( def on_retriever_error(
self, self,
*args: Any, *args: Any,
@ -263,6 +281,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
@override
def on_chat_model_start( def on_chat_model_start(
self, self,
serialized: dict[str, Any], serialized: dict[str, Any],
@ -294,6 +313,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
"""Whether to ignore agent callbacks.""" """Whether to ignore agent callbacks."""
return self.ignore_agent_ return self.ignore_agent_
@override
async def on_retry( async def on_retry(
self, self,
*args: Any, *args: Any,
@ -301,6 +321,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> Any: ) -> Any:
self.on_retry_common() self.on_retry_common()
@override
async def on_llm_start( async def on_llm_start(
self, self,
*args: Any, *args: Any,
@ -308,6 +329,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_llm_start_common() self.on_llm_start_common()
@override
async def on_llm_new_token( async def on_llm_new_token(
self, self,
*args: Any, *args: Any,
@ -315,6 +337,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_llm_new_token_common() self.on_llm_new_token_common()
@override
async def on_llm_end( async def on_llm_end(
self, self,
*args: Any, *args: Any,
@ -322,6 +345,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_llm_end_common() self.on_llm_end_common()
@override
async def on_llm_error( async def on_llm_error(
self, self,
*args: Any, *args: Any,
@ -329,6 +353,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_llm_error_common(*args, **kwargs) self.on_llm_error_common(*args, **kwargs)
@override
async def on_chain_start( async def on_chain_start(
self, self,
*args: Any, *args: Any,
@ -336,6 +361,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_chain_start_common() self.on_chain_start_common()
@override
async def on_chain_end( async def on_chain_end(
self, self,
*args: Any, *args: Any,
@ -343,6 +369,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_chain_end_common() self.on_chain_end_common()
@override
async def on_chain_error( async def on_chain_error(
self, self,
*args: Any, *args: Any,
@ -350,6 +377,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_chain_error_common() self.on_chain_error_common()
@override
async def on_tool_start( async def on_tool_start(
self, self,
*args: Any, *args: Any,
@ -357,6 +385,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_tool_start_common() self.on_tool_start_common()
@override
async def on_tool_end( async def on_tool_end(
self, self,
*args: Any, *args: Any,
@ -364,6 +393,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_tool_end_common() self.on_tool_end_common()
@override
async def on_tool_error( async def on_tool_error(
self, self,
*args: Any, *args: Any,
@ -371,6 +401,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_tool_error_common() self.on_tool_error_common()
@override
async def on_agent_action( async def on_agent_action(
self, self,
*args: Any, *args: Any,
@ -378,6 +409,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_agent_action_common() self.on_agent_action_common()
@override
async def on_agent_finish( async def on_agent_finish(
self, self,
*args: Any, *args: Any,
@ -385,6 +417,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_agent_finish_common() self.on_agent_finish_common()
@override
async def on_text( async def on_text(
self, self,
*args: Any, *args: Any,


@ -4,6 +4,8 @@ from itertools import cycle
from typing import Any, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import ( from langchain_core.language_models import (
FakeListChatModel, FakeListChatModel,
@ -171,6 +173,7 @@ async def test_callback_handlers() -> None:
# Required to implement since this is an abstract method # Required to implement since this is an abstract method
pass pass
@override
async def on_llm_new_token( async def on_llm_new_token(
self, self,
token: str, token: str,


@ -5,6 +5,7 @@ from collections.abc import AsyncIterator, Iterator
from typing import TYPE_CHECKING, Any, Literal, Optional, Union from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import pytest import pytest
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, FakeListChatModel from langchain_core.language_models import BaseChatModel, FakeListChatModel
@ -138,6 +139,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
"""Test astream uses appropriate implementation.""" """Test astream uses appropriate implementation."""
class ModelWithGenerate(BaseChatModel): class ModelWithGenerate(BaseChatModel):
@override
def _generate( def _generate(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -176,6 +178,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
"""Top Level call.""" """Top Level call."""
raise NotImplementedError raise NotImplementedError
@override
def _stream( def _stream(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -221,6 +224,7 @@ async def test_astream_implementation_uses_astream() -> None:
"""Top Level call.""" """Top Level call."""
raise NotImplementedError raise NotImplementedError
@override
async def _astream( # type: ignore async def _astream( # type: ignore
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -286,6 +290,7 @@ async def test_async_pass_run_id() -> None:
class NoStreamingModel(BaseChatModel): class NoStreamingModel(BaseChatModel):
@override
def _generate( def _generate(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -301,6 +306,7 @@ class NoStreamingModel(BaseChatModel):
class StreamingModel(NoStreamingModel): class StreamingModel(NoStreamingModel):
@override
def _stream( def _stream(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],


@ -3,6 +3,7 @@
from typing import Any, Optional from typing import Any, Optional
import pytest import pytest
from typing_extensions import override
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.globals import set_llm_cache from langchain_core.globals import set_llm_cache
@ -30,6 +31,7 @@ class InMemoryCache(BaseCache):
"""Update cache based on prompt and llm_string.""" """Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val self._cache[(prompt, llm_string)] = return_val
@override
def clear(self, **kwargs: Any) -> None: def clear(self, **kwargs: Any) -> None:
"""Clear cache.""" """Clear cache."""
self._cache = {} self._cache = {}


@ -2,6 +2,7 @@ from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional from typing import Any, Optional
import pytest import pytest
from typing_extensions import override
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -106,6 +107,7 @@ async def test_error_callback() -> None:
"""Return type of llm.""" """Return type of llm."""
return "failing-llm" return "failing-llm"
@override
def _call( def _call(
self, self,
prompt: str, prompt: str,
@ -136,6 +138,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
"""Test astream uses appropriate implementation.""" """Test astream uses appropriate implementation."""
class ModelWithGenerate(BaseLLM): class ModelWithGenerate(BaseLLM):
@override
def _generate( def _generate(
self, self,
prompts: list[str], prompts: list[str],
@ -172,6 +175,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
"""Top Level call.""" """Top Level call."""
raise NotImplementedError raise NotImplementedError
@override
def _stream( def _stream(
self, self,
prompt: str, prompt: str,
@ -209,6 +213,7 @@ async def test_astream_implementation_uses_astream() -> None:
"""Top Level call.""" """Top Level call."""
raise NotImplementedError raise NotImplementedError
@override
async def _astream( async def _astream(
self, self,
prompt: str, prompt: str,


@ -1,5 +1,7 @@
from typing import Any, Optional from typing import Any, Optional
from typing_extensions import override
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.globals import set_llm_cache from langchain_core.globals import set_llm_cache
from langchain_core.language_models import FakeListLLM from langchain_core.language_models import FakeListLLM
@ -20,6 +22,7 @@ class InMemoryCache(BaseCache):
"""Update cache based on prompt and llm_string.""" """Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val self._cache[(prompt, llm_string)] = return_val
@override
def clear(self, **kwargs: Any) -> None: def clear(self, **kwargs: Any) -> None:
"""Clear cache.""" """Clear cache."""
self._cache = {} self._cache = {}
@ -74,6 +77,7 @@ class InMemoryCacheBad(BaseCache):
msg = "This code should not be triggered" msg = "This code should not be triggered"
raise NotImplementedError(msg) raise NotImplementedError(msg)
@override
def clear(self, **kwargs: Any) -> None: def clear(self, **kwargs: Any) -> None:
"""Clear cache.""" """Clear cache."""
self._cache = {} self._cache = {}


@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import pytest import pytest
from typing_extensions import override
from langchain_core.language_models.fake_chat_models import FakeChatModel from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import ( from langchain_core.messages import (
@ -660,6 +661,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
class FakeTokenCountingModel(FakeChatModel): class FakeTokenCountingModel(FakeChatModel):
@override
def get_num_tokens_from_messages( def get_num_tokens_from_messages(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],


@ -1,5 +1,7 @@
"""Module to test base parser implementations.""" """Module to test base parser implementations."""
from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
@ -16,6 +18,7 @@ def test_base_generation_parser() -> None:
class StrInvertCase(BaseGenerationOutputParser[str]): class StrInvertCase(BaseGenerationOutputParser[str]):
"""An example parser that inverts the case of the characters in the message.""" """An example parser that inverts the case of the characters in the message."""
@override
def parse_result( def parse_result(
self, result: list[Generation], *, partial: bool = False self, result: list[Generation], *, partial: bool = False
) -> str: ) -> str:
@ -59,6 +62,7 @@ def test_base_transform_output_parser() -> None:
"""Parse a single string into a specific format.""" """Parse a single string into a specific format."""
raise NotImplementedError raise NotImplementedError
@override
def parse_result( def parse_result(
self, result: list[Generation], *, partial: bool = False self, result: list[Generation], *, partial: bool = False
) -> str: ) -> str:


@ -5,6 +5,7 @@ from collections.abc import Sequence
from typing import Any from typing import Any
import pytest import pytest
from typing_extensions import override
from langchain_core.example_selectors import BaseExampleSelector from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@ -383,6 +384,7 @@ class AsIsSelector(BaseExampleSelector):
def add_example(self, example: dict[str, str]) -> Any: def add_example(self, example: dict[str, str]) -> Any:
raise NotImplementedError raise NotImplementedError
@override
def select_examples(self, input_variables: dict[str, str]) -> list[dict]: def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
return list(self.examples) return list(self.examples)
@ -481,6 +483,7 @@ class AsyncAsIsSelector(BaseExampleSelector):
def select_examples(self, input_variables: dict[str, str]) -> list[dict]: def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
raise NotImplementedError raise NotImplementedError
@override
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]: async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
return list(self.examples) return list(self.examples)


@ -14,7 +14,7 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
def _fake_runnable( def _fake_runnable(
input: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_: Any _: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
) -> Union[BaseModel, dict]: ) -> Union[BaseModel, dict]:
if isclass(schema) and is_basemodel_subclass(schema): if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=value) return schema(name="yo", value=value)


@ -2,7 +2,7 @@ from typing import Any, Optional
import pytest import pytest
from pydantic import ConfigDict, Field, model_validator from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self from typing_extensions import Self, override
from langchain_core.runnables import ( from langchain_core.runnables import (
ConfigurableField, ConfigurableField,
@ -32,6 +32,7 @@ class MyRunnable(RunnableSerializable[str, str]):
self._my_hidden_property = self.my_property self._my_hidden_property = self.my_property
return self return self
@override
def invoke( def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any: ) -> Any:
@ -41,12 +42,15 @@ class MyRunnable(RunnableSerializable[str, str]):
return self.my_property return self.my_property
def my_custom_function_w_config( def my_custom_function_w_config(
self, config: Optional[RunnableConfig] = None self,
config: Optional[RunnableConfig] = None, # noqa: ARG002
) -> str: ) -> str:
return self.my_property return self.my_property
def my_custom_function_w_kw_config( def my_custom_function_w_kw_config(
self, *, config: Optional[RunnableConfig] = None self,
*,
config: Optional[RunnableConfig] = None, # noqa: ARG002
) -> str: ) -> str:
return self.my_property return self.my_property
@ -54,6 +58,7 @@ class MyRunnable(RunnableSerializable[str, str]):
class MyOtherRunnable(RunnableSerializable[str, str]): class MyOtherRunnable(RunnableSerializable[str, str]):
my_other_property: str my_other_property: str
@override
def invoke( def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any: ) -> Any:
@ -62,7 +67,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
def my_other_custom_function(self) -> str: def my_other_custom_function(self) -> str:
return self.my_other_property return self.my_other_property
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str: def my_other_custom_function_w_config(self, config: RunnableConfig) -> str: # noqa: ARG002
return self.my_other_property return self.my_other_property
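These helpers keep a config parameter the body never reads because LangChain injects a RunnableConfig into wrapped callables by inspecting the signature for a parameter literally named config; the name itself is load-bearing, so the rule is suppressed rather than the parameter renamed. A simplified sketch of that name-based check (the real accepts_config lives in langchain_core.runnables.config):

    import inspect
    from typing import Any, Callable, Optional

    def accepts_config(callable_: Callable[..., Any]) -> bool:
        # simplified stand-in for LangChain's signature inspection
        return "config" in inspect.signature(callable_).parameters

    def with_config(config: Optional[dict] = None) -> str:  # noqa: ARG001
        return "value"

    def without_config() -> str:
        return "value"

    assert accepts_config(with_config)
    assert not accepts_config(without_config)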


@ -25,7 +25,7 @@ def seq_naive_rag() -> Runnable:
"What's your name?", "What's your name?",
] ]
retriever = RunnableLambda(lambda x: context) retriever = RunnableLambda(lambda _: context)
prompt = PromptTemplate.from_template("{context} {question}") prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"]) llm = FakeListLLM(responses=["hello"])
@ -53,7 +53,7 @@ def seq_naive_rag_alt() -> Runnable:
"What's your name?", "What's your name?",
] ]
retriever = RunnableLambda(lambda x: context) retriever = RunnableLambda(lambda _: context)
prompt = PromptTemplate.from_template("{context} {question}") prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"]) llm = FakeListLLM(responses=["hello"])
@ -78,7 +78,7 @@ def seq_naive_rag_scoped() -> Runnable:
"What's your name?", "What's your name?",
] ]
retriever = RunnableLambda(lambda x: context) retriever = RunnableLambda(lambda _: context)
prompt = PromptTemplate.from_template("{context} {question}") prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"]) llm = FakeListLLM(responses=["hello"])
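ARG005 covers lambdas, and renaming the ignored parameter to _ matches ruff's default dummy-variable pattern while keeping the retriever stub a one-liner. In isolation:

    from langchain_core.runnables import RunnableLambda

    context = ["Bob's name is Bob.", "What's your name?"]
    retriever = RunnableLambda(lambda _: context)  # input is ignored by design

    print(retriever.invoke("who?"))  # -> the canned context, whatever the input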


@ -9,6 +9,7 @@ from typing import (
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from syrupy import SnapshotAssertion from syrupy import SnapshotAssertion
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import ( from langchain_core.language_models import (
@ -60,7 +61,7 @@ def chain() -> Runnable:
) )
def _raise_error(inputs: dict) -> str: def _raise_error(_: dict) -> str:
raise ValueError raise ValueError
@ -259,17 +260,17 @@ async def test_abatch() -> None:
_assert_potential_error(actual, expected) _assert_potential_error(actual, expected)
def _generate(input: Iterator) -> Iterator[str]: def _generate(_: Iterator) -> Iterator[str]:
yield from "foo bar" yield from "foo bar"
def _generate_immediate_error(input: Iterator) -> Iterator[str]: def _generate_immediate_error(_: Iterator) -> Iterator[str]:
msg = "immmediate error" msg = "immmediate error"
raise ValueError(msg) raise ValueError(msg)
yield "" yield ""
def _generate_delayed_error(input: Iterator) -> Iterator[str]: def _generate_delayed_error(_: Iterator) -> Iterator[str]:
yield "" yield ""
msg = "delayed error" msg = "delayed error"
raise ValueError(msg) raise ValueError(msg)
@ -288,18 +289,18 @@ def test_fallbacks_stream() -> None:
list(runnable.stream({})) list(runnable.stream({}))
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]: async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]:
for c in "foo bar": for c in "foo bar":
yield c yield c
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]: async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]:
msg = "immmediate error" msg = "immmediate error"
raise ValueError(msg) raise ValueError(msg)
yield "" yield ""
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]: async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]:
yield "" yield ""
msg = "delayed error" msg = "delayed error"
raise ValueError(msg) raise ValueError(msg)
@ -323,6 +324,7 @@ async def test_fallbacks_astream() -> None:
class FakeStructuredOutputModel(BaseChatModel): class FakeStructuredOutputModel(BaseChatModel):
foo: int foo: int
@override
def _generate( def _generate(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -333,6 +335,7 @@ class FakeStructuredOutputModel(BaseChatModel):
"""Top Level call.""" """Top Level call."""
return ChatResult(generations=[]) return ChatResult(generations=[])
@override
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
@ -340,10 +343,11 @@ class FakeStructuredOutputModel(BaseChatModel):
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(tools=tools) return self.bind(tools=tools)
@override
def with_structured_output( def with_structured_output(
self, schema: Union[dict, type[BaseModel]], **kwargs: Any self, schema: Union[dict, type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return RunnableLambda(lambda x: {"foo": self.foo}) return RunnableLambda(lambda _: {"foo": self.foo})
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
@ -353,6 +357,7 @@ class FakeStructuredOutputModel(BaseChatModel):
class FakeModel(BaseChatModel): class FakeModel(BaseChatModel):
bar: int bar: int
@override
def _generate( def _generate(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -363,6 +368,7 @@ class FakeModel(BaseChatModel):
"""Top Level call.""" """Top Level call."""
return ChatResult(generations=[]) return ChatResult(generations=[])
@override
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],


@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Union
import pytest import pytest
from packaging import version from packaging import version
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks import ( from langchain_core.callbacks import (
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
@ -39,7 +40,7 @@ def _get_get_session_history(
chat_history_store = store if store is not None else {} chat_history_store = store if store is not None else {}
def get_session_history( def get_session_history(
session_id: str, **kwargs: Any session_id: str, **_kwargs: Any
) -> InMemoryChatMessageHistory: ) -> InMemoryChatMessageHistory:
if session_id not in chat_history_store: if session_id not in chat_history_store:
chat_history_store[session_id] = InMemoryChatMessageHistory() chat_history_store[session_id] = InMemoryChatMessageHistory()
@ -253,6 +254,7 @@ async def test_output_message_async() -> None:
class LengthChatModel(BaseChatModel): class LengthChatModel(BaseChatModel):
"""A fake chat model that returns the length of the messages passed in.""" """A fake chat model that returns the length of the messages passed in."""
@override
def _generate( def _generate(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -856,7 +858,7 @@ def test_get_output_messages_no_value_error() -> None:
def test_get_output_messages_with_value_error() -> None: def test_get_output_messages_with_value_error() -> None:
illegal_bool_message = False illegal_bool_message = False
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message) runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
store: dict = {} store: dict = {}
get_session_history = _get_get_session_history(store=store) get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history) with_history = RunnableWithMessageHistory(runnable, get_session_history)
@ -874,7 +876,7 @@ def test_get_output_messages_with_value_error() -> None:
with_history.bound.invoke([HumanMessage(content="hello")], config) with_history.bound.invoke([HumanMessage(content="hello")], config)
illegal_int_message = 123 illegal_int_message = 123
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message) runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
with_history = RunnableWithMessageHistory(runnable, get_session_history) with_history = RunnableWithMessageHistory(runnable, get_session_history)
with pytest.raises( with pytest.raises(


@ -21,10 +21,11 @@ from packaging import version
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion from syrupy import SnapshotAssertion
from typing_extensions import TypedDict from typing_extensions import TypedDict, override
from langchain_core.callbacks.manager import ( from langchain_core.callbacks.manager import (
Callbacks, AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
atrace_as_chain_group, atrace_as_chain_group,
trace_as_chain_group, trace_as_chain_group,
) )
@ -184,6 +185,7 @@ class FakeTracer(BaseTracer):
class FakeRunnable(Runnable[str, int]): class FakeRunnable(Runnable[str, int]):
@override
def invoke( def invoke(
self, self,
input: str, input: str,
@ -196,6 +198,7 @@ class FakeRunnable(Runnable[str, int]):
class FakeRunnableSerializable(RunnableSerializable[str, int]): class FakeRunnableSerializable(RunnableSerializable[str, int]):
hello: str = "" hello: str = ""
@override
def invoke( def invoke(
self, self,
input: str, input: str,
@ -206,25 +209,15 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]):
class FakeRetriever(BaseRetriever): class FakeRetriever(BaseRetriever):
@override
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> list[Document]: ) -> list[Document]:
return [Document(page_content="foo"), Document(page_content="bar")] return [Document(page_content="foo"), Document(page_content="bar")]
@override
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> list[Document]: ) -> list[Document]:
return [Document(page_content="foo"), Document(page_content="bar")] return [Document(page_content="foo"), Document(page_content="bar")]
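The cleanest fix available is deleting parameters that were never part of the real interface: the abstract _get_relevant_documents takes only query plus a keyword-only run_manager, so the fake retriever's extra callbacks/tags/metadata/**kwargs simply disappear. A self-contained version of the corrected fake:

    from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
    from langchain_core.documents import Document
    from langchain_core.retrievers import BaseRetriever
    from typing_extensions import override

    class TwoDocRetriever(BaseRetriever):
        @override
        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> list[Document]:
            return [Document(page_content="foo"), Document(page_content="bar")]

    print(TwoDocRetriever().invoke("anything"))  # -> the two documents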
@ -506,7 +499,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
foo_ = RunnableLambda(foo) foo_ = RunnableLambda(foo)
assert foo_.assign(bar=lambda x: "foo").get_output_schema().model_json_schema() == { assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}}, "properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
"required": ["root", "bar"], "required": ["root", "bar"],
"title": "RunnableAssignOutput", "title": "RunnableAssignOutput",
@ -1782,10 +1775,10 @@ def test_with_listener_propagation(mocker: MockerFixture) -> None:
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_prompt_with_chat_model( def test_prompt_with_chat_model(
mocker: MockerFixture, mocker: MockerFixture,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None: ) -> None:
prompt = ( prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.") SystemMessagePromptTemplate.from_template("You are a nice assistant.")
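A fixture requested only for its side effect used to appear as an unused parameter; @pytest.mark.usefixtures activates it without binding a name, leaving nothing for ARG001 to flag. A minimal sketch with a stand-in fixture:

    import pytest

    @pytest.fixture
    def deterministic_uuids() -> None:
        """Stand-in for the real fixture, which patches uuid generation."""

    @pytest.mark.usefixtures("deterministic_uuids")
    def test_prompt_with_chat_model() -> None:
        # the fixture runs, but no unused argument appears in the signature
        assert True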
@ -1888,10 +1881,10 @@ def test_prompt_with_chat_model(
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
async def test_prompt_with_chat_model_async( async def test_prompt_with_chat_model_async(
mocker: MockerFixture, mocker: MockerFixture,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None: ) -> None:
prompt = ( prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.") SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@ -2519,7 +2512,7 @@ async def test_stream_log_retriever() -> None:
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
async def test_stream_log_lists() -> None: async def test_stream_log_lists() -> None:
async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict]: async def list_producer(_: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
for i in range(4): for i in range(4):
yield AddableDict(alist=[str(i)]) yield AddableDict(alist=[str(i)])
@ -2631,10 +2624,10 @@ async def test_prompt_with_llm_and_async_lambda(
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_prompt_with_chat_model_and_parser( def test_prompt_with_chat_model_and_parser(
mocker: MockerFixture, mocker: MockerFixture,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None: ) -> None:
prompt = ( prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.") SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@ -2672,10 +2665,9 @@ def test_prompt_with_chat_model_and_parser(
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_combining_sequences( def test_combining_sequences(
mocker: MockerFixture,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None: ) -> None:
prompt = ( prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.") SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@ -3513,7 +3505,7 @@ def test_bind_bind() -> None:
def test_bind_with_lambda() -> None: def test_bind_with_lambda() -> None:
def my_function(*args: Any, **kwargs: Any) -> int: def my_function(_: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0) return 3 + kwargs.get("n", 0)
runnable = RunnableLambda(my_function).bind(n=1) runnable = RunnableLambda(my_function).bind(n=1)
@ -3523,7 +3515,7 @@ def test_bind_with_lambda() -> None:
async def test_bind_with_lambda_async() -> None: async def test_bind_with_lambda_async() -> None:
def my_function(*args: Any, **kwargs: Any) -> int: def my_function(_: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0) return 3 + kwargs.get("n", 0)
runnable = RunnableLambda(my_function).bind(n=1) runnable = RunnableLambda(my_function).bind(n=1)
@ -3858,7 +3850,7 @@ def test_each(snapshot: SnapshotAssertion) -> None:
def test_recursive_lambda() -> None: def test_recursive_lambda() -> None:
def _simple_recursion(x: int) -> Union[int, Runnable]: def _simple_recursion(x: int) -> Union[int, Runnable]:
if x < 10: if x < 10:
return RunnableLambda(lambda *args: _simple_recursion(x + 1)) return RunnableLambda(lambda *_: _simple_recursion(x + 1))
return x return x
runnable = RunnableLambda(_simple_recursion) runnable = RunnableLambda(_simple_recursion)
@ -4008,7 +4000,7 @@ def test_runnable_lambda_stream() -> None:
# sleep to better simulate a real stream # sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
output = list(RunnableLambda(lambda x: llm).stream("")) output = list(RunnableLambda(lambda _: llm).stream(""))
assert output == list(llm_res) assert output == list(llm_res)
@ -4021,7 +4013,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
config: RunnableConfig = {"callbacks": [tracer]} config: RunnableConfig = {"callbacks": [tracer]}
assert list(RunnableLambda(lambda x: llm).stream("", config=config)) == list( assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
llm_res llm_res
) )
@ -4029,7 +4021,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
assert tracer.runs[0].error is None assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": llm_res} assert tracer.runs[0].outputs == {"output": llm_res}
def raise_value_error(x: int) -> int: def raise_value_error(_: int) -> int:
"""Raise a value error.""" """Raise a value error."""
msg = "x is too large" msg = "x is too large"
raise ValueError(msg) raise ValueError(msg)
@ -4076,7 +4068,7 @@ async def test_runnable_lambda_astream() -> None:
_ _
async for _ in RunnableLambda( async for _ in RunnableLambda(
func=id, func=id,
afunc=awrapper(lambda x: llm), afunc=awrapper(lambda _: llm),
).astream("") ).astream("")
] ]
assert output == list(llm_res) assert output == list(llm_res)
@ -4084,7 +4076,7 @@ async def test_runnable_lambda_astream() -> None:
output = [ output = [
chunk chunk
async for chunk in cast( async for chunk in cast(
"AsyncIterator[str]", RunnableLambda(lambda x: llm).astream("") "AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
) )
] ]
assert output == list(llm_res) assert output == list(llm_res)
@ -4100,14 +4092,14 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
config: RunnableConfig = {"callbacks": [tracer]} config: RunnableConfig = {"callbacks": [tracer]}
assert [ assert [
_ async for _ in RunnableLambda(lambda x: llm).astream("", config=config) _ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
] == list(llm_res) ] == list(llm_res)
assert len(tracer.runs) == 1 assert len(tracer.runs) == 1
assert tracer.runs[0].error is None assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": llm_res} assert tracer.runs[0].outputs == {"output": llm_res}
def raise_value_error(x: int) -> int: def raise_value_error(_: int) -> int:
"""Raise a value error.""" """Raise a value error."""
msg = "x is too large" msg = "x is too large"
raise ValueError(msg) raise ValueError(msg)
@ -4487,7 +4479,7 @@ def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
def test_runnable_branch_invoke() -> None: def test_runnable_branch_invoke() -> None:
# Test with single branch # Test with single branch
def raise_value_error(x: int) -> int: def raise_value_error(_: int) -> int:
"""Raise a value error.""" """Raise a value error."""
msg = "x is too large" msg = "x is too large"
raise ValueError(msg) raise ValueError(msg)
@ -4552,7 +4544,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
"""Verify that callbacks are correctly used in invoke.""" """Verify that callbacks are correctly used in invoke."""
tracer = FakeTracer() tracer = FakeTracer()
def raise_value_error(x: int) -> int: def raise_value_error(_: int) -> int:
"""Raise a value error.""" """Raise a value error."""
msg = "x is too large" msg = "x is too large"
raise ValueError(msg) raise ValueError(msg)
@ -4580,7 +4572,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
"""Verify that callbacks are invoked correctly in ainvoke.""" """Verify that callbacks are invoked correctly in ainvoke."""
tracer = FakeTracer() tracer = FakeTracer()
async def raise_value_error(x: int) -> int: async def raise_value_error(_: int) -> int:
"""Raise a value error.""" """Raise a value error."""
msg = "x is too large" msg = "x is too large"
raise ValueError(msg) raise ValueError(msg)
@ -4755,13 +4747,13 @@ def test_representation_of_runnables() -> None:
runnable = RunnableLambda(lambda x: x * 2) runnable = RunnableLambda(lambda x: x * 2)
assert repr(runnable) == "RunnableLambda(lambda x: x * 2)" assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
def f(x: int) -> int: def f(_: int) -> int:
"""Return 2.""" """Return 2."""
return 2 return 2
assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)" assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)"
async def af(x: int) -> int: async def af(_: int) -> int:
"""Return 2.""" """Return 2."""
return 2 return 2
@ -4814,7 +4806,7 @@ async def test_tool_from_runnable() -> None:
def test_runnable_gen() -> None: def test_runnable_gen() -> None:
"""Test that a generator can be used as a runnable.""" """Test that a generator can be used as a runnable."""
def gen(input: Iterator[Any]) -> Iterator[int]: def gen(_: Iterator[Any]) -> Iterator[int]:
yield 1 yield 1
yield 2 yield 2
yield 3 yield 3
@ -4835,7 +4827,7 @@ def test_runnable_gen() -> None:
async def test_runnable_gen_async() -> None: async def test_runnable_gen_async() -> None:
"""Test that a generator can be used as a runnable.""" """Test that a generator can be used as a runnable."""
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
yield 1 yield 1
yield 2 yield 2
yield 3 yield 3
@ -4847,7 +4839,7 @@ async def test_runnable_gen_async() -> None:
assert await arunnable.abatch([None, None]) == [6, 6] assert await arunnable.abatch([None, None]) == [6, 6]
class AsyncGen: class AsyncGen:
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]: async def __call__(self, _: AsyncIterator[Any]) -> AsyncIterator[int]:
yield 1 yield 1
yield 2 yield 2
yield 3 yield 3
@ -4870,7 +4862,7 @@ def test_runnable_gen_context_config() -> None:
""" """
fake = RunnableLambda(len) fake = RunnableLambda(len)
def gen(input: Iterator[Any]) -> Iterator[int]: def gen(_: Iterator[Any]) -> Iterator[int]:
yield fake.invoke("a") yield fake.invoke("a")
yield fake.invoke("aa") yield fake.invoke("aa")
yield fake.invoke("aaa") yield fake.invoke("aaa")
@ -4944,7 +4936,7 @@ async def test_runnable_gen_context_config_async() -> None:
fake = RunnableLambda(len) fake = RunnableLambda(len)
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
yield await fake.ainvoke("a") yield await fake.ainvoke("a")
yield await fake.ainvoke("aa") yield await fake.ainvoke("aa")
yield await fake.ainvoke("aaa") yield await fake.ainvoke("aaa")
@ -5441,6 +5433,7 @@ def test_default_transform_with_dicts() -> None:
"""Test that default transform works with dicts.""" """Test that default transform works with dicts."""
class CustomRunnable(RunnableSerializable[Input, Output]): class CustomRunnable(RunnableSerializable[Input, Output]):
@override
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
@ -5462,6 +5455,7 @@ async def test_default_atransform_with_dicts() -> None:
"""Test that default transform works with dicts.""" """Test that default transform works with dicts."""
class CustomRunnable(RunnableSerializable[Input, Output]): class CustomRunnable(RunnableSerializable[Input, Output]):
@override
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
@ -5581,6 +5575,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
on_chain_end_triggered = False on_chain_end_triggered = False
class MyHandler(BaseCallbackHandler): class MyHandler(BaseCallbackHandler):
@override
def on_chain_error( def on_chain_error(
self, self,
error: BaseException, error: BaseException,
@ -5594,6 +5589,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
nonlocal on_chain_error_triggered nonlocal on_chain_error_triggered
on_chain_error_triggered = True on_chain_error_triggered = True
@override
def on_chain_end( def on_chain_end(
self, self,
outputs: dict[str, Any], outputs: dict[str, Any],


@ -8,6 +8,7 @@ from typing import Any, cast
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
@ -82,12 +83,12 @@ def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -
async def test_event_stream_with_simple_function_tool() -> None: async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool.""" """Test the event stream with a function and tool."""
def foo(x: int) -> dict: def foo(_: int) -> dict:
"""Foo.""" """Foo."""
return {"x": 5} return {"x": 5}
@tool @tool
def get_docs(x: int) -> list[Document]: def get_docs(x: int) -> list[Document]: # noqa: ARG001
"""Hello Doc.""" """Hello Doc."""
return [Document(page_content="hello")] return [Document(page_content="hello")]
@ -434,7 +435,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
async def test_event_stream_with_lambdas_from_lambda() -> None: async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config( as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"} {"run_name": "my_lambda"}
) )
events = await _collect_events( events = await _collect_events(
@ -1021,7 +1022,7 @@ async def test_event_streaming_with_tools() -> None:
return "hello" return "hello"
@tool @tool
def with_callbacks(callbacks: Callbacks) -> str: def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
"""A tool that does nothing.""" """A tool that does nothing."""
return "world" return "world"
@ -1031,7 +1032,7 @@ async def test_event_streaming_with_tools() -> None:
return {"x": x, "y": y} return {"x": x, "y": y}
@tool @tool
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
"""A tool that does nothing.""" """A tool that does nothing."""
return {"x": x, "y": y} return {"x": x, "y": y}
@ -1180,6 +1181,7 @@ async def test_event_streaming_with_tools() -> None:
class HardCodedRetriever(BaseRetriever): class HardCodedRetriever(BaseRetriever):
documents: list[Document] documents: list[Document]
@override
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]: ) -> list[Document]:
@ -1592,10 +1594,10 @@ async def test_chain_ordering() -> None:
async def test_event_stream_with_retry() -> None: async def test_event_stream_with_retry() -> None:
"""Test the event stream with a tool.""" """Test the event stream with a tool."""
def success(inputs: str) -> str: def success(_: str) -> str:
return "success" return "success"
def fail(inputs: str) -> None: def fail(_: str) -> None:
"""Simple func.""" """Simple func."""
msg = "fail" msg = "fail"
raise ValueError(msg) raise ValueError(msg)


@ -17,6 +17,7 @@ from typing import (
import pytest import pytest
from blockbuster import BlockBuster from blockbuster import BlockBuster
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
@ -97,12 +98,12 @@ async def _collect_events(
async def test_event_stream_with_simple_function_tool() -> None: async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool.""" """Test the event stream with a function and tool."""
def foo(x: int) -> dict: def foo(x: int) -> dict: # noqa: ARG001
"""Foo.""" """Foo."""
return {"x": 5} return {"x": 5}
@tool @tool
def get_docs(x: int) -> list[Document]: def get_docs(x: int) -> list[Document]: # noqa: ARG001
"""Hello Doc.""" """Hello Doc."""
return [Document(page_content="hello")] return [Document(page_content="hello")]
@ -465,7 +466,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
async def test_event_stream_with_lambdas_from_lambda() -> None: async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config( as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"} {"run_name": "my_lambda"}
) )
events = await _collect_events( events = await _collect_events(
@ -1043,7 +1044,7 @@ async def test_event_streaming_with_tools() -> None:
return "hello" return "hello"
@tool @tool
def with_callbacks(callbacks: Callbacks) -> str: def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
"""A tool that does nothing.""" """A tool that does nothing."""
return "world" return "world"
@ -1053,7 +1054,7 @@ async def test_event_streaming_with_tools() -> None:
return {"x": x, "y": y} return {"x": x, "y": y}
@tool @tool
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
"""A tool that does nothing.""" """A tool that does nothing."""
return {"x": x, "y": y} return {"x": x, "y": y}
@ -1165,6 +1166,7 @@ async def test_event_streaming_with_tools() -> None:
class HardCodedRetriever(BaseRetriever): class HardCodedRetriever(BaseRetriever):
documents: list[Document] documents: list[Document]
@override
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]: ) -> list[Document]:
@ -1553,10 +1555,10 @@ async def test_chain_ordering() -> None:
async def test_event_stream_with_retry() -> None: async def test_event_stream_with_retry() -> None:
"""Test the event stream with a tool.""" """Test the event stream with a tool."""
def success(inputs: str) -> str: def success(_: str) -> str:
return "success" return "success"
def fail(inputs: str) -> None: def fail(_: str) -> None:
"""Simple func.""" """Simple func."""
msg = "fail" msg = "fail"
raise ValueError(msg) raise ValueError(msg)
@ -2069,6 +2071,7 @@ class StreamingRunnable(Runnable[Input, Output]):
"""Initialize the runnable.""" """Initialize the runnable."""
self.iterable = iterable self.iterable = iterable
@override
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
@ -2084,6 +2087,7 @@ class StreamingRunnable(Runnable[Input, Output]):
) -> Iterator[Output]: ) -> Iterator[Output]:
raise NotImplementedError raise NotImplementedError
@override
async def astream( async def astream(
self, self,
input: Input, input: Input,
@ -2323,7 +2327,7 @@ async def test_bad_parent_ids() -> None:
async def test_runnable_generator() -> None: async def test_runnable_generator() -> None:
"""Test async events from sync lambda.""" """Test async events from sync lambda."""
async def generator(inputs: AsyncIterator[str]) -> AsyncIterator[str]: async def generator(_: AsyncIterator[str]) -> AsyncIterator[str]:
yield "1" yield "1"
yield "2" yield "2"


@ -375,7 +375,7 @@ class TestRunnableSequenceParallelTraceNesting:
def test_sync( def test_sync(
self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int] self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int]
) -> None: ) -> None:
def other_thing(a: int) -> Generator[int, None, None]: # type: ignore def other_thing(_: int) -> Generator[int, None, None]: # type: ignore
yield 1 yield 1
parent = self._create_parent(other_thing) parent = self._create_parent(other_thing)
@ -407,7 +407,7 @@ class TestRunnableSequenceParallelTraceNesting:
[RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int] [RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int]
], ],
) -> None: ) -> None:
async def other_thing(a: int) -> AsyncGenerator[int, None]: async def other_thing(_: int) -> AsyncGenerator[int, None]:
yield 1 yield 1
parent = self._create_parent(other_thing) parent = self._create_parent(other_thing)


@ -2301,7 +2301,7 @@ def test_injected_arg_with_complex_type() -> None:
self.value = "bar" self.value = "bar"
@tool @tool
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: # noqa: ARG001
"""Tool that has an injected tool arg.""" """Tool that has an injected tool arg."""
return foo.value return foo.value
@@ -2477,7 +2477,7 @@ def test_simple_tool_args_schema_dict() -> None:
 def test_empty_string_tool_call_id() -> None:
     @tool
-    def foo(x: int) -> str:
+    def foo(x: int) -> str:  # noqa: ARG001
         """Foo."""
         return "hi"
@@ -2489,7 +2489,7 @@ def test_empty_string_tool_call_id() -> None:
 def test_tool_decorator_description() -> None:
     # test basic tool
     @tool
-    def foo(x: int) -> str:
+    def foo(x: int) -> str:  # noqa: ARG001
         """Foo."""
         return "hi"
@@ -2501,7 +2501,7 @@ def test_tool_decorator_description() -> None:
     # test basic tool with description
     @tool(description="description")
-    def foo_description(x: int) -> str:
+    def foo_description(x: int) -> str:  # noqa: ARG001
         """Foo."""
         return "hi"
@@ -2520,7 +2520,7 @@ def test_tool_decorator_description() -> None:
         x: int
     @tool(args_schema=ArgsSchema)
-    def foo_args_schema(x: int) -> str:
+    def foo_args_schema(x: int) -> str:  # noqa: ARG001
         return "hi"
     assert foo_args_schema.description == "Bar."
@@ -2532,7 +2532,7 @@ def test_tool_decorator_description() -> None:
     )
     @tool(description="description", args_schema=ArgsSchema)
-    def foo_args_schema_description(x: int) -> str:
+    def foo_args_schema_description(x: int) -> str:  # noqa: ARG001
         return "hi"
     assert foo_args_schema_description.description == "description"
@@ -2554,11 +2554,11 @@ def test_tool_decorator_description() -> None:
     }
     @tool(args_schema=args_json_schema)
-    def foo_args_jsons_schema(x: int) -> str:
+    def foo_args_jsons_schema(x: int) -> str:  # noqa: ARG001
         return "hi"
     @tool(description="description", args_schema=args_json_schema)
-    def foo_args_jsons_schema_with_description(x: int) -> str:
+    def foo_args_jsons_schema_with_description(x: int) -> str:  # noqa: ARG001
         return "hi"
     assert foo_args_jsons_schema.description == "JSON Schema."
@@ -2620,10 +2620,10 @@ def test_title_property_preserved() -> None:
 async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
     """Verify that the inputs are not mutated when invoking a tool asynchronously."""
-    def sync_no_op(foo: int) -> str:
+    def sync_no_op(foo: int) -> str:  # noqa: ARG001
         return "good"
-    async def async_no_op(foo: int) -> str:
+    async def async_no_op(foo: int) -> str:  # noqa: ARG001
         return "good"
     tool = StructuredTool(
@@ -2668,10 +2668,10 @@ async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
 def test_tool_invoke_does_not_mutate_inputs() -> None:
     """Verify that the inputs are not mutated when invoking a tool synchronously."""
-    def sync_no_op(foo: int) -> str:
+    def sync_no_op(foo: int) -> str:  # noqa: ARG001
         return "good"
-    async def async_no_op(foo: int) -> str:
+    async def async_no_op(foo: int) -> str:  # noqa: ARG001
         return "good"
     tool = StructuredTool(
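
`StructuredTool` pairs a sync `func` with an async `coroutine`, and both callables must accept the arguments declared by the schema even when, as in these mutation tests, the bodies ignore them; that is why the edits suppress ARG001 rather than rename `foo`. A rough, runnable sketch under that assumption (the model and names here are invented):

from pydantic import BaseModel
from langchain_core.tools import StructuredTool

class NoOpArgs(BaseModel):
    foo: int

def sync_no_op(foo: int) -> str:  # noqa: ARG001 - `foo` mirrors the schema field
    return "good"

async def async_no_op(foo: int) -> str:  # noqa: ARG001
    return "good"

tool = StructuredTool(
    name="no_op",
    description="Does nothing with its input.",
    args_schema=NoOpArgs,
    func=sync_no_op,
    coroutine=async_no_op,
)
print(tool.invoke({"foo": 1}))  # -> good
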

View File

@@ -121,7 +121,7 @@ def dummy_structured_tool() -> StructuredTool:
         arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
     return StructuredTool.from_function(
-        lambda x: None,
+        lambda _: None,
         name="dummy_function",
         description="Dummy function.",
         args_schema=Schema,
@@ -143,7 +143,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
         "required": ["arg1", "arg2"],
     }
     return StructuredTool.from_function(
-        lambda x: None,
+        lambda _: None,
         name="dummy_function",
         description="Dummy function.",
         args_schema=args_schema,
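
Unused lambda parameters have their own code, ARG005 (unused-lambda-argument); since `StructuredTool.from_function` takes the argument schema from `args_schema` rather than from the lambda, the parameter can safely become `_`. A minimal sketch:

# ARG005 flags `lambda x: None` because `x` is never read.
# The rename keeps the one-argument arity that from_function expects:
callback = lambda _: None  # noqa: E731 - assignment only for illustration
print(callback("anything"))  # -> None
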

View File

@@ -10,6 +10,7 @@ import uuid
 from typing import TYPE_CHECKING, Any, Optional
 import pytest
+from typing_extensions import override
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings, FakeEmbeddings
@@ -25,6 +26,7 @@ class CustomAddTextsVectorstore(VectorStore):
     def __init__(self) -> None:
         self.store: dict[str, Document] = {}
+    @override
     def add_texts(
         self,
         texts: Iterable[str],
@@ -51,6 +53,7 @@ class CustomAddTextsVectorstore(VectorStore):
         return [self.store[id] for id in ids if id in self.store]
     @classmethod
+    @override
     def from_texts(  # type: ignore
         cls,
         texts: list[str],
@@ -74,6 +77,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
     def __init__(self) -> None:
         self.store: dict[str, Document] = {}
+    @override
     def add_documents(
         self,
         documents: list[Document],
@@ -95,6 +99,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
         return [self.store[id] for id in ids if id in self.store]
     @classmethod
+    @override
     def from_texts(  # type: ignore
         cls,
         texts: list[str],
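
The `@override` additions in this last file (and in the `StreamingRunnable` hunks earlier) take a third route: `typing_extensions.override` declares that the method's signature is dictated by the base class, and ruff's unused-argument checks treat decorated overrides as exempt, so inherited parameters such as `**kwargs` need neither renaming nor a `# noqa`. A toy sketch of the pattern, with a stand-in base class rather than the real `VectorStore`:

from typing import Any
from typing_extensions import override

class Store:
    def add_texts(self, texts: list[str], **kwargs: Any) -> list[str]:
        raise NotImplementedError

class InMemoryStore(Store):
    @override
    def add_texts(self, texts: list[str], **kwargs: Any) -> list[str]:
        # `kwargs` is unused, but ARG002 does not fire: @override marks the
        # signature as inherited, so the extra parameters are intentional.
        return list(texts)

print(InMemoryStore().add_texts(["a", "b"]))  # -> ['a', 'b']
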