core: Add ruff rules ARG (#30732)

See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg
Christophe Bornet 2025-04-09 20:39:36 +02:00 committed by GitHub
parent 66758599a9
commit 98f0016fc2
58 changed files with 328 additions and 180 deletions
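
For context, the flake8-unused-arguments rules (ARG001-ARG005) flag function, method, and lambda parameters that are never read in the body. Below is a minimal sketch of the three remediations this commit applies, using hypothetical names (`BaseHandler` and friends are illustrative, not part of langchain-core): decorate inherited methods with `@typing_extensions.override` so ruff treats the signature as fixed by the parent class, prefix deliberately unused parameters (including lambda parameters, ARG005) with an underscore, or suppress inline with `# noqa: ARG001`/`ARG002` where the name should stay for documentation or API compatibility.

from typing import Any

from typing_extensions import override


class BaseHandler:
    def on_event(self, payload: dict[str, Any], **kwargs: Any) -> None:
        """Hook whose signature is fixed by the base-class contract."""
        raise NotImplementedError


class LoggingHandler(BaseHandler):
    # Without @override, ruff reports ARG002 (unused method argument)
    # for `kwargs`; the decorator marks the signature as inherited,
    # which exempts it from the ARG checks.
    @override
    def on_event(self, payload: dict[str, Any], **kwargs: Any) -> None:
        print(payload)


# ARG001 (unused function argument): a leading underscore marks the
# parameter as intentionally unused...
def on_done(_result: str) -> str:
    return "ok"


# ...or an inline suppression keeps the original name where it
# documents the expected call signature.
def on_error(error: str) -> str:  # noqa: ARG001
    return "recovered"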

View File

@ -124,7 +124,7 @@ def beta(
_name = _name or obj.__qualname__
old_doc = obj.__doc__
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
"""Finalize the annotation of a class."""
# Can't set new_doc on some extension objects.
with contextlib.suppress(AttributeError):
@ -190,7 +190,7 @@ def beta(
if _name == "<lambda>":
_name = set_name
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: # noqa: ARG001
"""Finalize the property."""
return _BetaProperty(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc

View File

@ -204,7 +204,7 @@ def deprecated(
_name = _name or obj.__qualname__
old_doc = obj.__doc__
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
"""Finalize the deprecation of a class."""
# Can't set new_doc on some extension objects.
with contextlib.suppress(AttributeError):
@ -234,7 +234,7 @@ def deprecated(
raise ValueError(msg)
old_doc = obj.description
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
return cast(
"T",
FieldInfoV1(
@ -255,7 +255,7 @@ def deprecated(
raise ValueError(msg)
old_doc = obj.description
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
return cast(
"T",
FieldInfoV2(
@ -315,7 +315,7 @@ def deprecated(
if _name == "<lambda>":
_name = set_name
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
"""Finalize the property."""
return cast(
"T",

View File

@ -27,6 +27,8 @@ from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional
from typing_extensions import override
from langchain_core.outputs import Generation
from langchain_core.runnables import run_in_executor
@ -194,6 +196,7 @@ class InMemoryCache(BaseCache):
del self._cache[next(iter(self._cache))]
self._cache[(prompt, llm_string)] = return_val
@override
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}
@ -227,6 +230,7 @@ class InMemoryCache(BaseCache):
"""
self.update(prompt, llm_string, return_val)
@override
async def aclear(self, **kwargs: Any) -> None:
"""Async clear cache."""
self.clear()

View File

@ -5,6 +5,8 @@ from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, TextIO, cast
from typing_extensions import override
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.utils.input import print_text
@ -38,6 +40,7 @@ class FileCallbackHandler(BaseCallbackHandler):
"""Destructor to cleanup when done."""
self.file.close()
@override
def on_chain_start(
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
@ -61,6 +64,7 @@ class FileCallbackHandler(BaseCallbackHandler):
file=self.file,
)
@override
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.
@ -70,6 +74,7 @@ class FileCallbackHandler(BaseCallbackHandler):
"""
print_text("\n\033[1m> Finished chain.\033[0m", end="\n", file=self.file)
@override
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
@ -83,6 +88,7 @@ class FileCallbackHandler(BaseCallbackHandler):
"""
print_text(action.log, color=color or self.color, file=self.file)
@override
def on_tool_end(
self,
output: str,
@ -109,6 +115,7 @@ class FileCallbackHandler(BaseCallbackHandler):
if llm_prefix is not None:
print_text(f"\n{llm_prefix}", file=self.file)
@override
def on_text(
self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any
) -> None:
@ -123,6 +130,7 @@ class FileCallbackHandler(BaseCallbackHandler):
"""
print_text(text, color=color or self.color, end=end, file=self.file)
@override
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:

View File

@ -22,7 +22,7 @@ from typing import (
from uuid import UUID
from langsmith.run_helpers import get_tracing_context
from typing_extensions import Self
from typing_extensions import Self, override
from langchain_core.callbacks.base import (
BaseCallbackHandler,
@ -1401,6 +1401,7 @@ class CallbackManager(BaseCallbackManager):
inheritable_metadata=self.inheritable_metadata,
)
@override
def on_tool_start(
self,
serialized: Optional[dict[str, Any]],
@ -1456,6 +1457,7 @@ class CallbackManager(BaseCallbackManager):
inheritable_metadata=self.inheritable_metadata,
)
@override
def on_retriever_start(
self,
serialized: Optional[dict[str, Any]],
@ -1927,6 +1929,7 @@ class AsyncCallbackManager(BaseCallbackManager):
inheritable_metadata=self.inheritable_metadata,
)
@override
async def on_tool_start(
self,
serialized: Optional[dict[str, Any]],
@ -2017,6 +2020,7 @@ class AsyncCallbackManager(BaseCallbackManager):
metadata=self.metadata,
)
@override
async def on_retriever_start(
self,
serialized: Optional[dict[str, Any]],

View File

@ -4,6 +4,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional
from typing_extensions import override
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.utils import print_text
@ -22,6 +24,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""
self.color = color
@override
def on_chain_start(
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
@ -41,6 +44,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
name = "<unknown>"
print(f"\n\n\033[1m> Entering new {name} chain...\033[0m") # noqa: T201
@override
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.
@ -50,6 +54,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""
print("\n\033[1m> Finished chain.\033[0m") # noqa: T201
@override
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
@ -62,6 +67,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""
print_text(action.log, color=color or self.color)
@override
def on_tool_end(
self,
output: Any,
@ -87,6 +93,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
if llm_prefix is not None:
print_text(f"\n{llm_prefix}")
@override
def on_text(
self,
text: str,
@ -104,6 +111,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""
print_text(text, color=color or self.color, end=end)
@override
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:

View File

@ -5,6 +5,8 @@ from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any
from typing_extensions import override
from langchain_core.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
@ -41,6 +43,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
**kwargs (Any): Additional keyword arguments.
"""
@override
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled.

View File

@ -58,6 +58,7 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
def __repr__(self) -> str:
return str(self.usage_metadata)
@override
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
# Check for usage_metadata (langchain-core >= 0.2.2)

View File

@ -151,7 +151,7 @@ def _get_source_id_assigner(
) -> Callable[[Document], Union[str, None]]:
"""Get the source id from the document."""
if source_id_key is None:
return lambda doc: None
return lambda _doc: None
if isinstance(source_id_key, str):
return lambda doc: doc.metadata[source_id_key]
if callable(source_id_key):

View File

@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import Any, Optional, cast
from pydantic import Field
from typing_extensions import override
from langchain_core._api import beta
from langchain_core.callbacks import CallbackManagerForRetrieverRun
@ -29,6 +30,7 @@ class InMemoryDocumentIndex(DocumentIndex):
store: dict[str, Document] = Field(default_factory=dict)
top_k: int = 4
@override
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
"""Upsert items into the index."""
ok_ids = []
@ -47,6 +49,7 @@ class InMemoryDocumentIndex(DocumentIndex):
return UpsertResponse(succeeded=ok_ids, failed=[])
@override
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
"""Delete by ID."""
if ids is None:
@ -64,10 +67,12 @@ class InMemoryDocumentIndex(DocumentIndex):
succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[]
)
@override
def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
"""Get by ids."""
return [self.store[id_] for id_ in ids if id_ in self.store]
@override
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:

View File

@ -576,7 +576,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# --- Custom methods ---
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002
return {}
def _get_invocation_params(
@ -1246,6 +1246,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _llm_type(self) -> str:
"""Return type of chat model."""
@override
def dict(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)

View File

@ -171,6 +171,7 @@ class FakeListChatModel(SimpleChatModel):
class FakeChatModel(SimpleChatModel):
"""Fake Chat Model wrapper for testing purposes."""
@override
def _call(
self,
messages: list[BaseMessage],
@ -180,6 +181,7 @@ class FakeChatModel(SimpleChatModel):
) -> str:
return "fake response"
@override
async def _agenerate(
self,
messages: list[BaseMessage],
@ -224,6 +226,7 @@ class GenericFakeChatModel(BaseChatModel):
into message chunks.
"""
@override
def _generate(
self,
messages: list[BaseMessage],
@ -346,6 +349,7 @@ class ParrotFakeChatModel(BaseChatModel):
* Chat model should be usable in both sync and async tests
"""
@override
def _generate(
self,
messages: list[BaseMessage],

View File

@ -1399,6 +1399,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _llm_type(self) -> str:
"""Return type of llm."""
@override
def dict(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)

View File

@ -59,7 +59,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
Returns:
Structured output.
"""
return await run_in_executor(None, self.parse_result, result)
return await run_in_executor(None, self.parse_result, result, partial=partial)
class BaseGenerationOutputParser(
@ -231,6 +231,7 @@ class BaseOutputParser(
run_type="parser",
)
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
@ -290,7 +291,11 @@ class BaseOutputParser(
return await run_in_executor(None, self.parse, text)
# TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
def parse_with_prompt(
self,
completion: str,
prompt: PromptValue, # noqa: ARG002
) -> Any:
"""Parse the output of an LLM call with the input prompt for context.
The prompt is largely provided in the event the OutputParser wants

View File

@ -7,6 +7,7 @@ from typing import Any, Optional, Union
import jsonpatch # type: ignore[import]
from pydantic import BaseModel, model_validator
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import (
@ -23,6 +24,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
args_only: bool = True
"""Whether to only return the arguments to the function call."""
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
@ -251,6 +253,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
raise ValueError(msg)
return values
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
@ -287,6 +290,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str
"""The name of the attribute to return."""
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.

View File

@ -9,6 +9,8 @@ from typing import (
Union,
)
from typing_extensions import override
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.output_parsers.base import BaseOutputParser, T
from langchain_core.outputs import (
@ -48,6 +50,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
None, self.parse_result, [Generation(text=chunk)]
)
@override
def transform(
self,
input: Iterator[Union[str, BaseMessage]],
@ -68,6 +71,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
input, self._transform, config, run_type="parser"
)
@override
async def atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],

View File

@ -127,6 +127,7 @@ class BasePromptTemplate(
"""Return the output type of the prompt."""
return Union[StringPromptValue, ChatPromptValueConcrete]
@override
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:

View File

@ -199,7 +199,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
)
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
def _get_ls_params(self, **_kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing."""
default_retriever_name = self.get_name()
if default_retriever_name.startswith("Retriever"):

View File

@ -326,7 +326,8 @@ class Runnable(Generic[Input, Output], ABC):
return self.get_input_schema()
def get_input_schema(
self, config: Optional[RunnableConfig] = None
self,
config: Optional[RunnableConfig] = None, # noqa: ARG002
) -> type[BaseModel]:
"""Get a pydantic model that can be used to validate input to the Runnable.
@ -398,7 +399,8 @@ class Runnable(Generic[Input, Output], ABC):
return self.get_output_schema()
def get_output_schema(
self, config: Optional[RunnableConfig] = None
self,
config: Optional[RunnableConfig] = None, # noqa: ARG002
) -> type[BaseModel]:
"""Get a pydantic model that can be used to validate output to the Runnable.
@ -4751,11 +4753,6 @@ class RunnableLambda(Runnable[Input, Output]):
)
return cast("Output", output)
def _config(
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
) -> RunnableConfig:
return ensure_config(config)
@override
def invoke(
self,
@ -4780,7 +4777,7 @@ class RunnableLambda(Runnable[Input, Output]):
return self._call_with_config(
self._invoke,
input,
self._config(config, self.func),
ensure_config(config),
**kwargs,
)
msg = "Cannot invoke a coroutine function synchronously.Use `ainvoke` instead."
@ -4803,11 +4800,10 @@ class RunnableLambda(Runnable[Input, Output]):
Returns:
The output of this Runnable.
"""
the_func = self.afunc if hasattr(self, "afunc") else self.func
return await self._acall_with_config(
self._ainvoke,
input,
self._config(config, the_func),
ensure_config(config),
**kwargs,
)
@ -4884,7 +4880,7 @@ class RunnableLambda(Runnable[Input, Output]):
yield from self._transform_stream_with_config(
input,
self._transform,
self._config(config, self.func),
ensure_config(config),
**kwargs,
)
else:
@ -5012,7 +5008,7 @@ class RunnableLambda(Runnable[Input, Output]):
async for output in self._atransform_stream_with_config(
input,
self._atransform,
self._config(config, self.afunc if hasattr(self, "afunc") else self.func),
ensure_config(config),
**kwargs,
):
yield output

View File

@ -400,6 +400,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def OutputType(self) -> type[Output]:
return self._history_chain.OutputType
@override
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
@ -432,12 +433,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
module_name=self.__class__.__module__,
)
def _is_not_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
return False
async def _is_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
return True
def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> list[BaseMessage]:

View File

@ -133,6 +133,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self._on_llm_start(llm_run)
return llm_run
@override
def on_llm_new_token(
self,
token: str,
@ -161,11 +162,11 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id=run_id,
chunk=chunk,
parent_run_id=parent_run_id,
**kwargs,
)
self._on_llm_new_token(llm_run, token, chunk)
return llm_run
@override
def on_retry(
self,
retry_state: RetryCallState,
@ -188,6 +189,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id=run_id,
)
@override
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for an LLM run.
@ -235,6 +237,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self._on_llm_error(llm_run)
return llm_run
@override
def on_chain_start(
self,
serialized: dict[str, Any],
@ -279,6 +282,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self._on_chain_start(chain_run)
return chain_run
@override
def on_chain_end(
self,
outputs: dict[str, Any],
@ -302,12 +306,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
outputs=outputs,
run_id=run_id,
inputs=inputs,
**kwargs,
)
self._end_trace(chain_run)
self._on_chain_end(chain_run)
return chain_run
@override
def on_chain_error(
self,
error: BaseException,
@ -331,7 +335,6 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
error=error,
run_id=run_id,
inputs=inputs,
**kwargs,
)
self._end_trace(chain_run)
self._on_chain_error(chain_run)
@ -381,6 +384,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self._on_tool_start(tool_run)
return tool_run
@override
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run.
@ -395,12 +399,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
tool_run = self._complete_tool_run(
output=output,
run_id=run_id,
**kwargs,
)
self._end_trace(tool_run)
self._on_tool_end(tool_run)
return tool_run
@override
def on_tool_error(
self,
error: BaseException,
@ -467,6 +471,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self._on_retriever_start(retrieval_run)
return retrieval_run
@override
def on_retriever_error(
self,
error: BaseException,
@ -487,12 +492,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
retrieval_run = self._errored_retrieval_run(
error=error,
run_id=run_id,
**kwargs,
)
self._end_trace(retrieval_run)
self._on_retriever_error(retrieval_run)
return retrieval_run
@override
def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> Run:
@ -509,7 +514,6 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
retrieval_run = self._complete_retrieval_run(
documents=documents,
run_id=run_id,
**kwargs,
)
self._end_trace(retrieval_run)
self._on_retriever_end(retrieval_run)
@ -623,7 +627,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
run_id=run_id,
chunk=chunk,
parent_run_id=parent_run_id,
**kwargs,
)
await self._on_llm_new_token(llm_run, token, chunk)
@ -715,7 +718,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
outputs=outputs,
run_id=run_id,
inputs=inputs,
**kwargs,
)
tasks = [self._end_trace(chain_run), self._on_chain_end(chain_run)]
await asyncio.gather(*tasks)
@ -733,7 +735,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
error=error,
inputs=inputs,
run_id=run_id,
**kwargs,
)
tasks = [self._end_trace(chain_run), self._on_chain_error(chain_run)]
await asyncio.gather(*tasks)
@ -776,7 +777,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
tool_run = self._complete_tool_run(
output=output,
run_id=run_id,
**kwargs,
)
tasks = [self._end_trace(tool_run), self._on_tool_end(tool_run)]
await asyncio.gather(*tasks)
@ -839,7 +839,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
retrieval_run = self._errored_retrieval_run(
error=error,
run_id=run_id,
**kwargs,
)
tasks = [
self._end_trace(retrieval_run),
@ -860,7 +859,6 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
retrieval_run = self._complete_retrieval_run(
documents=documents,
run_id=run_id,
**kwargs,
)
tasks = [self._end_trace(retrieval_run), self._on_retriever_end(retrieval_run)]
await asyncio.gather(*tasks)

View File

@ -41,7 +41,7 @@ run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVa
@contextmanager
def tracing_enabled(
session_name: str = "default",
session_name: str = "default", # noqa: ARG001
) -> Generator[TracerSessionV1, None, None]:
"""Throw an error because this has been replaced by tracing_v2_enabled."""
msg = (

View File

@ -231,8 +231,7 @@ class _TracerCore(ABC):
token: str,
run_id: UUID,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
parent_run_id: Optional[UUID] = None, # noqa: ARG002
) -> Run:
"""Append token event to LLM run and return the run."""
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
@ -252,7 +251,6 @@ class _TracerCore(ABC):
self,
retry_state: RetryCallState,
run_id: UUID,
**kwargs: Any,
) -> Run:
llm_run = self._get_run(run_id)
retry_d: dict[str, Any] = {
@ -369,7 +367,6 @@ class _TracerCore(ABC):
outputs: dict[str, Any],
run_id: UUID,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""Update a chain run with outputs and end time."""
chain_run = self._get_run(run_id)
@ -385,7 +382,6 @@ class _TracerCore(ABC):
error: BaseException,
inputs: Optional[dict[str, Any]],
run_id: UUID,
**kwargs: Any,
) -> Run:
chain_run = self._get_run(run_id)
chain_run.error = self._get_stacktrace(error)
@ -439,7 +435,6 @@ class _TracerCore(ABC):
self,
output: dict[str, Any],
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Update a tool run with outputs and end time."""
tool_run = self._get_run(run_id, run_type="tool")
@ -452,7 +447,6 @@ class _TracerCore(ABC):
self,
error: BaseException,
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Update a tool run with error and end time."""
tool_run = self._get_run(run_id, run_type="tool")
@ -494,7 +488,6 @@ class _TracerCore(ABC):
self,
documents: Sequence[Document],
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Update a retrieval run with outputs and end time."""
retrieval_run = self._get_run(run_id, run_type="retriever")
@ -507,7 +500,6 @@ class _TracerCore(ABC):
self,
error: BaseException,
run_id: UUID,
**kwargs: Any,
) -> Run:
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.error = self._get_stacktrace(error)
@ -523,75 +515,75 @@ class _TracerCore(ABC):
"""Copy the tracer."""
return self
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""End a trace for a run."""
return None
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process a run upon creation."""
return None
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process a run upon update."""
return None
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the LLM Run upon start."""
return None
def _on_llm_new_token(
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
run: Run, # noqa: ARG002
token: str, # noqa: ARG002
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002
) -> Union[None, Coroutine[Any, Any, None]]:
"""Process new LLM token."""
return None
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the LLM Run."""
return None
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the LLM Run upon error."""
return None
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Chain Run upon start."""
return None
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Chain Run."""
return None
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Chain Run upon error."""
return None
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Tool Run upon start."""
return None
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Tool Run."""
return None
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Tool Run upon error."""
return None
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Chat Model Run upon start."""
return None
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Retriever Run upon start."""
return None
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Retriever Run."""
return None
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # noqa: ARG002
"""Process the Retriever Run upon error."""
return None

View File

@ -15,7 +15,7 @@ from typing import (
)
from uuid import UUID, uuid4
from typing_extensions import NotRequired, TypedDict
from typing_extensions import NotRequired, TypedDict, override
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
@ -293,6 +293,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self.run_map[run_id] = info
self.parent_map[run_id] = parent_run_id
@override
async def on_chat_model_start(
self,
serialized: dict[str, Any],
@ -334,6 +335,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_type,
)
@override
async def on_llm_start(
self,
serialized: dict[str, Any],
@ -377,6 +379,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_type,
)
@override
async def on_custom_event(
self,
name: str,
@ -399,6 +402,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
)
self._send(event, name)
@override
async def on_llm_new_token(
self,
token: str,
@ -450,6 +454,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_info["run_type"],
)
@override
async def on_llm_end(
self, response: LLMResult, *, run_id: UUID, **kwargs: Any
) -> None:
@ -552,6 +557,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_type_,
)
@override
async def on_chain_end(
self,
outputs: dict[str, Any],
@ -586,6 +592,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_type,
)
@override
async def on_tool_start(
self,
serialized: dict[str, Any],
@ -627,6 +634,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"tool",
)
@override
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for a tool run."""
run_info = self.run_map.pop(run_id)
@ -654,6 +662,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"tool",
)
@override
async def on_retriever_start(
self,
serialized: dict[str, Any],
@ -697,6 +706,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_type,
)
@override
async def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> None:

View File

@ -19,6 +19,7 @@ from tenacity import (
stop_after_attempt,
wait_exponential_jitter,
)
from typing_extensions import override
from langchain_core.env import get_runtime_environment
from langchain_core.load import dumpd
@ -252,13 +253,13 @@ class LangChainTracer(BaseTracer):
run.reference_example_id = self.example_id
self._persist_run_single(run)
@override
def _llm_run_with_token_event(
self,
token: str,
run_id: UUID,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Run:
"""Append token event to LLM run and return the run."""
return super()._llm_run_with_token_event(
@ -267,7 +268,6 @@ class LangChainTracer(BaseTracer):
run_id,
chunk=None,
parent_run_id=parent_run_id,
**kwargs,
)
def _on_chat_model_start(self, run: Run) -> None:

View File

@ -6,7 +6,7 @@ Please use LangChainTracer instead.
from typing import Any
def get_headers(*args: Any, **kwargs: Any) -> Any:
def get_headers(*args: Any, **kwargs: Any) -> Any: # noqa: ARG001
"""Throw an error because this has been replaced by get_headers."""
msg = (
"get_headers for LangChainTracerV1 is no longer supported. "
@ -15,7 +15,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any:
raise RuntimeError(msg)
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802,ARG001
"""Throw an error because this has been replaced by LangChainTracer."""
msg = (
"LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."

View File

@ -16,8 +16,8 @@ from langchain_core._api.deprecation import deprecated
),
)
def try_load_from_hub(
*args: Any,
**kwargs: Any,
*args: Any, # noqa: ARG001
**kwargs: Any, # noqa: ARG001
) -> Any:
"""[DEPRECATED] Try to load from the old Hub."""
warnings.warn(

View File

@ -65,7 +65,11 @@ def grab_literal(template: str, l_del: str) -> tuple[str, str]:
return (literal, template)
def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
def l_sa_check(
template: str, # noqa: ARG001
literal: str,
is_standalone: bool,
) -> bool:
"""Do a preliminary check to see if a tag could be a standalone.
Args:

View File

@ -38,6 +38,7 @@ from pydantic.json_schema import (
JsonSchemaMode,
JsonSchemaValue,
)
from typing_extensions import override
if TYPE_CHECKING:
from pydantic_core import core_schema
@ -233,6 +234,7 @@ class _IgnoreUnserializable(GenerateJsonSchema):
https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
"""
@override
def handle_invalid_for_json_schema(
self, schema: core_schema.CoreSchema, error_info: str
) -> JsonSchemaValue:

View File

@ -13,6 +13,7 @@ from typing import Any, Callable, Optional, Union, overload
from packaging.version import parse
from pydantic import SecretStr
from requests import HTTPError, Response
from typing_extensions import override
from langchain_core.utils.pydantic import (
is_pydantic_v1_subclass,
@ -91,6 +92,7 @@ def mock_now(dt_value: datetime.datetime) -> Iterator[type]:
"""Mock datetime.datetime.now() with a fixed datetime."""
@classmethod
@override
def now(cls, tz: Union[datetime.tzinfo, None] = None) -> "MockDateTime":
# Create a copy of dt_value.
return MockDateTime(

View File

@ -36,7 +36,7 @@ from typing import (
)
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self
from typing_extensions import Self, override
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
@ -1070,6 +1070,7 @@ class VectorStoreRetriever(BaseRetriever):
return ls_params
@override
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> list[Document]:
@ -1090,6 +1091,7 @@ class VectorStoreRetriever(BaseRetriever):
raise ValueError(msg)
return docs
@override
async def _aget_relevant_documents(
self,
query: str,

View File

@ -285,7 +285,7 @@ class InMemoryVectorStore(VectorStore):
since="0.2.29",
removal="1.0",
)
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
def upsert(self, items: Sequence[Document], /, **_kwargs: Any) -> UpsertResponse:
"""[DEPRECATED] Upsert documents into the store.
Args:
@ -319,7 +319,7 @@ class InMemoryVectorStore(VectorStore):
removal="1.0",
)
async def aupsert(
self, items: Sequence[Document], /, **kwargs: Any
self, items: Sequence[Document], /, **_kwargs: Any
) -> UpsertResponse:
"""[DEPRECATED] Upsert documents into the store.
@ -364,7 +364,6 @@ class InMemoryVectorStore(VectorStore):
embedding: list[float],
k: int = 4,
filter: Optional[Callable[[Document], bool]] = None,
**kwargs: Any,
) -> list[tuple[Document, float, list[float]]]:
# get all docs with fixed order in list
docs = list(self.store.values())
@ -404,7 +403,7 @@ class InMemoryVectorStore(VectorStore):
embedding: list[float],
k: int = 4,
filter: Optional[Callable[[Document], bool]] = None,
**kwargs: Any,
**_kwargs: Any,
) -> list[tuple[Document, float]]:
"""Search for the most similar documents to the given embedding.
@ -419,7 +418,7 @@ class InMemoryVectorStore(VectorStore):
return [
(doc, similarity)
for doc, similarity, _ in self._similarity_search_with_score_by_vector(
embedding=embedding, k=k, filter=filter, **kwargs
embedding=embedding, k=k, filter=filter
)
]
@ -490,12 +489,14 @@ class InMemoryVectorStore(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
*,
filter: Optional[Callable[[Document], bool]] = None,
**kwargs: Any,
) -> list[Document]:
prefetch_hits = self._similarity_search_with_score_by_vector(
embedding=embedding,
k=fetch_k,
**kwargs,
filter=filter,
)
try:

View File

@ -98,7 +98,6 @@ ignore = [
# TODO rules
"A",
"ANN401",
"ARG",
"BLE",
"ERA",
"FBT001",
@ -132,5 +131,6 @@ classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_ini
"tests/unit_tests/prompts/test_chat.py" = [ "E501",]
"tests/unit_tests/runnables/test_runnable.py" = [ "E501",]
"tests/unit_tests/runnables/test_graph.py" = [ "E501",]
"tests/unit_tests/test_tools.py" = [ "ARG",]
"tests/**" = [ "D", "S",]
"scripts/**" = [ "INP", "S",]

View File

@ -10,6 +10,8 @@ from contextlib import asynccontextmanager
from typing import Any, Optional
from uuid import UUID
from typing_extensions import override
from langchain_core.callbacks import (
AsyncCallbackHandler,
AsyncCallbackManager,
@ -45,6 +47,7 @@ async def test_inline_handlers_share_parent_context() -> None:
"""Initialize the handler."""
self.run_inline = run_inline
@override
async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
"""Update the callstack with the name of the callback."""
some_var.set("on_llm_start")

View File

@ -74,6 +74,7 @@ async def test_async_custom_event_implicit_config() -> None:
# a decorator for async functions
@RunnableLambda # type: ignore[arg-type]
async def foo(x: int, config: RunnableConfig) -> int:
assert "callbacks" in config
await adispatch_custom_event("event1", {"x": x})
await adispatch_custom_event("event2", {"x": x})
return x

View File

@ -3,6 +3,7 @@
from collections.abc import Iterator
import pytest
from typing_extensions import override
from langchain_core.document_loaders.base import BaseBlobParser, BaseLoader
from langchain_core.documents import Document
@ -15,6 +16,7 @@ def test_base_blob_parser() -> None:
class MyParser(BaseBlobParser):
"""A simple parser that returns a single document."""
@override
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazy parsing interface."""
yield Document(

View File

@ -1,6 +1,8 @@
from collections.abc import Iterable
from typing import Any, Optional, cast
from typing_extensions import override
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings, FakeEmbeddings
from langchain_core.example_selectors import (
@ -21,6 +23,7 @@ class DummyVectorStore(VectorStore):
def embeddings(self) -> Optional[Embeddings]:
return self._embeddings
@override
def add_texts(
self,
texts: Iterable[str],
@ -32,6 +35,7 @@ class DummyVectorStore(VectorStore):
self.metadatas.extend(metadatas)
return ["dummy_id"]
@override
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
@ -41,6 +45,7 @@ class DummyVectorStore(VectorStore):
)
] * k
@override
def max_marginal_relevance_search(
self,
query: str,

View File

@ -5,6 +5,7 @@ from typing import Any, Optional, Union
from uuid import UUID
from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
@ -138,6 +139,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
@override
def on_llm_start(
self,
*args: Any,
@ -145,6 +147,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_start_common()
@override
def on_llm_new_token(
self,
*args: Any,
@ -152,6 +155,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_new_token_common()
@override
def on_llm_end(
self,
*args: Any,
@ -159,6 +163,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_end_common()
@override
def on_llm_error(
self,
*args: Any,
@ -166,6 +171,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_error_common(*args, **kwargs)
@override
def on_retry(
self,
*args: Any,
@ -173,6 +179,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retry_common()
@override
def on_chain_start(
self,
*args: Any,
@ -180,6 +187,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_chain_start_common()
@override
def on_chain_end(
self,
*args: Any,
@ -187,6 +195,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_chain_end_common()
@override
def on_chain_error(
self,
*args: Any,
@ -194,6 +203,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_chain_error_common()
@override
def on_tool_start(
self,
*args: Any,
@ -201,6 +211,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_tool_start_common()
@override
def on_tool_end(
self,
*args: Any,
@ -208,6 +219,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_tool_end_common()
@override
def on_tool_error(
self,
*args: Any,
@ -215,6 +227,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_tool_error_common()
@override
def on_agent_action(
self,
*args: Any,
@ -222,6 +235,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_agent_action_common()
@override
def on_agent_finish(
self,
*args: Any,
@ -229,6 +243,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_agent_finish_common()
@override
def on_text(
self,
*args: Any,
@ -236,6 +251,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_text_common()
@override
def on_retriever_start(
self,
*args: Any,
@ -243,6 +259,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_start_common()
@override
def on_retriever_end(
self,
*args: Any,
@ -250,6 +267,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_end_common()
@override
def on_retriever_error(
self,
*args: Any,
@ -263,6 +281,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
@override
def on_chat_model_start(
self,
serialized: dict[str, Any],
@ -294,6 +313,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@override
async def on_retry(
self,
*args: Any,
@ -301,6 +321,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> Any:
self.on_retry_common()
@override
async def on_llm_start(
self,
*args: Any,
@ -308,6 +329,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_start_common()
@override
async def on_llm_new_token(
self,
*args: Any,
@ -315,6 +337,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_new_token_common()
@override
async def on_llm_end(
self,
*args: Any,
@ -322,6 +345,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_end_common()
@override
async def on_llm_error(
self,
*args: Any,
@ -329,6 +353,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_error_common(*args, **kwargs)
@override
async def on_chain_start(
self,
*args: Any,
@ -336,6 +361,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_chain_start_common()
@override
async def on_chain_end(
self,
*args: Any,
@ -343,6 +369,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_chain_end_common()
@override
async def on_chain_error(
self,
*args: Any,
@ -350,6 +377,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_chain_error_common()
@override
async def on_tool_start(
self,
*args: Any,
@ -357,6 +385,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_tool_start_common()
@override
async def on_tool_end(
self,
*args: Any,
@ -364,6 +393,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_tool_end_common()
@override
async def on_tool_error(
self,
*args: Any,
@ -371,6 +401,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_tool_error_common()
@override
async def on_agent_action(
self,
*args: Any,
@ -378,6 +409,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_agent_action_common()
@override
async def on_agent_finish(
self,
*args: Any,
@ -385,6 +417,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_agent_finish_common()
@override
async def on_text(
self,
*args: Any,

View File

@ -4,6 +4,8 @@ from itertools import cycle
from typing import Any, Optional, Union
from uuid import UUID
from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import (
FakeListChatModel,
@ -171,6 +173,7 @@ async def test_callback_handlers() -> None:
# Required to implement since this is an abstract method
pass
@override
async def on_llm_new_token(
self,
token: str,

View File

@ -5,6 +5,7 @@ from collections.abc import AsyncIterator, Iterator
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import pytest
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, FakeListChatModel
@ -138,6 +139,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
"""Test astream uses appropriate implementation."""
class ModelWithGenerate(BaseChatModel):
@override
def _generate(
self,
messages: list[BaseMessage],
@ -176,6 +178,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
"""Top Level call."""
raise NotImplementedError
@override
def _stream(
self,
messages: list[BaseMessage],
@ -221,6 +224,7 @@ async def test_astream_implementation_uses_astream() -> None:
"""Top Level call."""
raise NotImplementedError
@override
async def _astream( # type: ignore
self,
messages: list[BaseMessage],
@ -286,6 +290,7 @@ async def test_async_pass_run_id() -> None:
class NoStreamingModel(BaseChatModel):
@override
def _generate(
self,
messages: list[BaseMessage],
@ -301,6 +306,7 @@ class NoStreamingModel(BaseChatModel):
class StreamingModel(NoStreamingModel):
@override
def _stream(
self,
messages: list[BaseMessage],

View File

@ -3,6 +3,7 @@
from typing import Any, Optional
import pytest
from typing_extensions import override
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.globals import set_llm_cache
@ -30,6 +31,7 @@ class InMemoryCache(BaseCache):
"""Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val
@override
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}

View File

@ -2,6 +2,7 @@ from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional
import pytest
from typing_extensions import override
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@ -106,6 +107,7 @@ async def test_error_callback() -> None:
"""Return type of llm."""
return "failing-llm"
@override
def _call(
self,
prompt: str,
@ -136,6 +138,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
"""Test astream uses appropriate implementation."""
class ModelWithGenerate(BaseLLM):
@override
def _generate(
self,
prompts: list[str],
@ -172,6 +175,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
"""Top Level call."""
raise NotImplementedError
@override
def _stream(
self,
prompt: str,
@ -209,6 +213,7 @@ async def test_astream_implementation_uses_astream() -> None:
"""Top Level call."""
raise NotImplementedError
@override
async def _astream(
self,
prompt: str,

View File

@ -1,5 +1,7 @@
from typing import Any, Optional
from typing_extensions import override
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.globals import set_llm_cache
from langchain_core.language_models import FakeListLLM
@ -20,6 +22,7 @@ class InMemoryCache(BaseCache):
"""Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val
@override
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}
@ -74,6 +77,7 @@ class InMemoryCacheBad(BaseCache):
msg = "This code should not be triggered"
raise NotImplementedError(msg)
@override
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}

View File

@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import pytest
from typing_extensions import override
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
@ -660,6 +661,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
class FakeTokenCountingModel(FakeChatModel):
@override
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],

View File

@ -1,5 +1,7 @@
"""Module to test base parser implementations."""
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
@ -16,6 +18,7 @@ def test_base_generation_parser() -> None:
class StrInvertCase(BaseGenerationOutputParser[str]):
"""An example parser that inverts the case of the characters in the message."""
@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
) -> str:
@ -59,6 +62,7 @@ def test_base_transform_output_parser() -> None:
"""Parse a single string into a specific format."""
raise NotImplementedError
@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
) -> str:

View File

@ -5,6 +5,7 @@ from collections.abc import Sequence
from typing import Any
import pytest
from typing_extensions import override
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@ -383,6 +384,7 @@ class AsIsSelector(BaseExampleSelector):
def add_example(self, example: dict[str, str]) -> Any:
raise NotImplementedError
@override
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
return list(self.examples)
@ -481,6 +483,7 @@ class AsyncAsIsSelector(BaseExampleSelector):
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
raise NotImplementedError
@override
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
return list(self.examples)

View File

@ -14,7 +14,7 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
def _fake_runnable(
input: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_: Any
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
) -> Union[BaseModel, dict]:
if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=value)

View File

@ -2,7 +2,7 @@ from typing import Any, Optional
import pytest
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self
from typing_extensions import Self, override
from langchain_core.runnables import (
ConfigurableField,
@ -32,6 +32,7 @@ class MyRunnable(RunnableSerializable[str, str]):
self._my_hidden_property = self.my_property
return self
@override
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
@ -41,12 +42,15 @@ class MyRunnable(RunnableSerializable[str, str]):
return self.my_property
def my_custom_function_w_config(
self, config: Optional[RunnableConfig] = None
self,
config: Optional[RunnableConfig] = None, # noqa: ARG002
) -> str:
return self.my_property
def my_custom_function_w_kw_config(
self, *, config: Optional[RunnableConfig] = None
self,
*,
config: Optional[RunnableConfig] = None, # noqa: ARG002
) -> str:
return self.my_property
@ -54,6 +58,7 @@ class MyRunnable(RunnableSerializable[str, str]):
class MyOtherRunnable(RunnableSerializable[str, str]):
my_other_property: str
@override
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
@ -62,7 +67,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
def my_other_custom_function(self) -> str:
return self.my_other_property
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str:
def my_other_custom_function_w_config(self, config: RunnableConfig) -> str: # noqa: ARG002
return self.my_other_property

View File

@ -25,7 +25,7 @@ def seq_naive_rag() -> Runnable:
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
retriever = RunnableLambda(lambda _: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
@ -53,7 +53,7 @@ def seq_naive_rag_alt() -> Runnable:
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
retriever = RunnableLambda(lambda _: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])
@ -78,7 +78,7 @@ def seq_naive_rag_scoped() -> Runnable:
"What's your name?",
]
retriever = RunnableLambda(lambda x: context)
retriever = RunnableLambda(lambda _: context)
prompt = PromptTemplate.from_template("{context} {question}")
llm = FakeListLLM(responses=["hello"])

View File

@ -9,6 +9,7 @@ from typing import (
import pytest
from pydantic import BaseModel
from syrupy import SnapshotAssertion
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import (
@ -60,7 +61,7 @@ def chain() -> Runnable:
)
def _raise_error(inputs: dict) -> str:
def _raise_error(_: dict) -> str:
raise ValueError
@ -259,17 +260,17 @@ async def test_abatch() -> None:
_assert_potential_error(actual, expected)
def _generate(input: Iterator) -> Iterator[str]:
def _generate(_: Iterator) -> Iterator[str]:
yield from "foo bar"
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
def _generate_immediate_error(_: Iterator) -> Iterator[str]:
msg = "immmediate error"
raise ValueError(msg)
yield ""
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
def _generate_delayed_error(_: Iterator) -> Iterator[str]:
yield ""
msg = "delayed error"
raise ValueError(msg)
@ -288,18 +289,18 @@ def test_fallbacks_stream() -> None:
list(runnable.stream({}))
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]:
for c in "foo bar":
yield c
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]:
msg = "immmediate error"
raise ValueError(msg)
yield ""
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]:
yield ""
msg = "delayed error"
raise ValueError(msg)
@ -323,6 +324,7 @@ async def test_fallbacks_astream() -> None:
class FakeStructuredOutputModel(BaseChatModel):
foo: int
@override
def _generate(
self,
messages: list[BaseMessage],
@ -333,6 +335,7 @@ class FakeStructuredOutputModel(BaseChatModel):
"""Top Level call."""
return ChatResult(generations=[])
@override
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
@ -340,10 +343,11 @@ class FakeStructuredOutputModel(BaseChatModel):
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(tools=tools)
@override
def with_structured_output(
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return RunnableLambda(lambda x: {"foo": self.foo})
return RunnableLambda(lambda _: {"foo": self.foo})
@property
def _llm_type(self) -> str:
@ -353,6 +357,7 @@ class FakeStructuredOutputModel(BaseChatModel):
class FakeModel(BaseChatModel):
bar: int
@override
def _generate(
self,
messages: list[BaseMessage],
@ -363,6 +368,7 @@ class FakeModel(BaseChatModel):
"""Top Level call."""
return ChatResult(generations=[])
@override
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],

View File

@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Union
import pytest
from packaging import version
from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
@ -39,7 +40,7 @@ def _get_get_session_history(
chat_history_store = store if store is not None else {}
def get_session_history(
session_id: str, **kwargs: Any
session_id: str, **_kwargs: Any
) -> InMemoryChatMessageHistory:
if session_id not in chat_history_store:
chat_history_store[session_id] = InMemoryChatMessageHistory()
@ -253,6 +254,7 @@ async def test_output_message_async() -> None:
class LengthChatModel(BaseChatModel):
"""A fake chat model that returns the length of the messages passed in."""
@override
def _generate(
self,
messages: list[BaseMessage],
@ -856,7 +858,7 @@ def test_get_output_messages_no_value_error() -> None:
def test_get_output_messages_with_value_error() -> None:
illegal_bool_message = False
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message)
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
store: dict = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
@ -874,7 +876,7 @@ def test_get_output_messages_with_value_error() -> None:
with_history.bound.invoke([HumanMessage(content="hello")], config)
illegal_int_message = 123
runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_int_message)
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_int_message)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
with pytest.raises(

View File

@ -21,10 +21,11 @@ from packaging import version
from pydantic import BaseModel, Field
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from typing_extensions import TypedDict
from typing_extensions import TypedDict, override
from langchain_core.callbacks.manager import (
Callbacks,
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
atrace_as_chain_group,
trace_as_chain_group,
)
@ -184,6 +185,7 @@ class FakeTracer(BaseTracer):
class FakeRunnable(Runnable[str, int]):
@override
def invoke(
self,
input: str,
@ -196,6 +198,7 @@ class FakeRunnable(Runnable[str, int]):
class FakeRunnableSerializable(RunnableSerializable[str, int]):
hello: str = ""
@override
def invoke(
self,
input: str,
@ -206,25 +209,15 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]):
class FakeRetriever(BaseRetriever):
@override
def _get_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
return [Document(page_content="foo"), Document(page_content="bar")]
@override
async def _aget_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> list[Document]:
return [Document(page_content="foo"), Document(page_content="bar")]
@ -506,7 +499,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
foo_ = RunnableLambda(foo)
assert foo_.assign(bar=lambda x: "foo").get_output_schema().model_json_schema() == {
assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {
"properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
"required": ["root", "bar"],
"title": "RunnableAssignOutput",
@ -1782,10 +1775,10 @@ def test_with_listener_propagation(mocker: MockerFixture) -> None:
@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_prompt_with_chat_model(
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@ -1888,10 +1881,10 @@ def test_prompt_with_chat_model(
@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
async def test_prompt_with_chat_model_async(
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@ -2519,7 +2512,7 @@ async def test_stream_log_retriever() -> None:
@freeze_time("2023-01-01")
async def test_stream_log_lists() -> None:
async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
async def list_producer(_: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
for i in range(4):
yield AddableDict(alist=[str(i)])
@ -2631,10 +2624,10 @@ async def test_prompt_with_llm_and_async_lambda(
@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_prompt_with_chat_model_and_parser(
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@ -2672,10 +2665,9 @@ def test_prompt_with_chat_model_and_parser(
@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_combining_sequences(
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@ -3513,7 +3505,7 @@ def test_bind_bind() -> None:
def test_bind_with_lambda() -> None:
def my_function(*args: Any, **kwargs: Any) -> int:
def my_function(_: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0)
runnable = RunnableLambda(my_function).bind(n=1)
@ -3523,7 +3515,7 @@ def test_bind_with_lambda() -> None:
async def test_bind_with_lambda_async() -> None:
def my_function(*args: Any, **kwargs: Any) -> int:
def my_function(_: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0)
runnable = RunnableLambda(my_function).bind(n=1)
@ -3858,7 +3850,7 @@ def test_each(snapshot: SnapshotAssertion) -> None:
def test_recursive_lambda() -> None:
def _simple_recursion(x: int) -> Union[int, Runnable]:
if x < 10:
return RunnableLambda(lambda *args: _simple_recursion(x + 1))
return RunnableLambda(lambda *_: _simple_recursion(x + 1))
return x
runnable = RunnableLambda(_simple_recursion)
@ -4008,7 +4000,7 @@ def test_runnable_lambda_stream() -> None:
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
output = list(RunnableLambda(lambda x: llm).stream(""))
output = list(RunnableLambda(lambda _: llm).stream(""))
assert output == list(llm_res)
@ -4021,7 +4013,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
config: RunnableConfig = {"callbacks": [tracer]}
assert list(RunnableLambda(lambda x: llm).stream("", config=config)) == list(
assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
llm_res
)
@ -4029,7 +4021,7 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": llm_res}
def raise_value_error(x: int) -> int:
def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@ -4076,7 +4068,7 @@ async def test_runnable_lambda_astream() -> None:
_
async for _ in RunnableLambda(
func=id,
afunc=awrapper(lambda x: llm),
afunc=awrapper(lambda _: llm),
).astream("")
]
assert output == list(llm_res)
@ -4084,7 +4076,7 @@ async def test_runnable_lambda_astream() -> None:
output = [
chunk
async for chunk in cast(
"AsyncIterator[str]", RunnableLambda(lambda x: llm).astream("")
"AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
)
]
assert output == list(llm_res)
@ -4100,14 +4092,14 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
config: RunnableConfig = {"callbacks": [tracer]}
assert [
_ async for _ in RunnableLambda(lambda x: llm).astream("", config=config)
_ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
] == list(llm_res)
assert len(tracer.runs) == 1
assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": llm_res}
def raise_value_error(x: int) -> int:
def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@ -4487,7 +4479,7 @@ def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
def test_runnable_branch_invoke() -> None:
# Test with single branch
def raise_value_error(x: int) -> int:
def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@ -4552,7 +4544,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
"""Verify that callbacks are correctly used in invoke."""
tracer = FakeTracer()
def raise_value_error(x: int) -> int:
def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@ -4580,7 +4572,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
"""Verify that callbacks are invoked correctly in ainvoke."""
tracer = FakeTracer()
async def raise_value_error(x: int) -> int:
async def raise_value_error(_: int) -> int:
"""Raise a value error."""
msg = "x is too large"
raise ValueError(msg)
@ -4755,13 +4747,13 @@ def test_representation_of_runnables() -> None:
runnable = RunnableLambda(lambda x: x * 2)
assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
def f(x: int) -> int:
def f(_: int) -> int:
"""Return 2."""
return 2
assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)"
async def af(x: int) -> int:
async def af(_: int) -> int:
"""Return 2."""
return 2
@ -4814,7 +4806,7 @@ async def test_tool_from_runnable() -> None:
def test_runnable_gen() -> None:
"""Test that a generator can be used as a runnable."""
def gen(input: Iterator[Any]) -> Iterator[int]:
def gen(_: Iterator[Any]) -> Iterator[int]:
yield 1
yield 2
yield 3
@ -4835,7 +4827,7 @@ def test_runnable_gen() -> None:
async def test_runnable_gen_async() -> None:
"""Test that a generator can be used as a runnable."""
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
yield 1
yield 2
yield 3
@ -4847,7 +4839,7 @@ async def test_runnable_gen_async() -> None:
assert await arunnable.abatch([None, None]) == [6, 6]
class AsyncGen:
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
async def __call__(self, _: AsyncIterator[Any]) -> AsyncIterator[int]:
yield 1
yield 2
yield 3
@ -4870,7 +4862,7 @@ def test_runnable_gen_context_config() -> None:
"""
fake = RunnableLambda(len)
def gen(input: Iterator[Any]) -> Iterator[int]:
def gen(_: Iterator[Any]) -> Iterator[int]:
yield fake.invoke("a")
yield fake.invoke("aa")
yield fake.invoke("aaa")
@ -4944,7 +4936,7 @@ async def test_runnable_gen_context_config_async() -> None:
fake = RunnableLambda(len)
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
yield await fake.ainvoke("a")
yield await fake.ainvoke("aa")
yield await fake.ainvoke("aaa")
@ -5441,6 +5433,7 @@ def test_default_transform_with_dicts() -> None:
"""Test that default transform works with dicts."""
class CustomRunnable(RunnableSerializable[Input, Output]):
@override
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
@ -5462,6 +5455,7 @@ async def test_default_atransform_with_dicts() -> None:
"""Test that default transform works with dicts."""
class CustomRunnable(RunnableSerializable[Input, Output]):
@override
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
@ -5581,6 +5575,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
on_chain_end_triggered = False
class MyHandler(BaseCallbackHandler):
@override
def on_chain_error(
self,
error: BaseException,
@ -5594,6 +5589,7 @@ def test_closing_iterator_doesnt_raise_error() -> None:
nonlocal on_chain_error_triggered
on_chain_error_triggered = True
@override
def on_chain_end(
self,
outputs: dict[str, Any],

View File

@ -8,6 +8,7 @@ from typing import Any, cast
import pytest
from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
@ -82,12 +83,12 @@ def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -
async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool."""
def foo(x: int) -> dict:
def foo(_: int) -> dict:
"""Foo."""
return {"x": 5}
@tool
def get_docs(x: int) -> list[Document]:
def get_docs(x: int) -> list[Document]: # noqa: ARG001
"""Hello Doc."""
return [Document(page_content="hello")]
@ -434,7 +435,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"}
)
events = await _collect_events(
@ -1021,7 +1022,7 @@ async def test_event_streaming_with_tools() -> None:
return "hello"
@tool
def with_callbacks(callbacks: Callbacks) -> str:
def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
"""A tool that does nothing."""
return "world"
@ -1031,7 +1032,7 @@ async def test_event_streaming_with_tools() -> None:
return {"x": x, "y": y}
@tool
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
"""A tool that does nothing."""
return {"x": x, "y": y}
@ -1180,6 +1181,7 @@ async def test_event_streaming_with_tools() -> None:
class HardCodedRetriever(BaseRetriever):
documents: list[Document]
@override
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
@ -1592,10 +1594,10 @@ async def test_chain_ordering() -> None:
async def test_event_stream_with_retry() -> None:
"""Test the event stream with a tool."""
def success(inputs: str) -> str:
def success(_: str) -> str:
return "success"
def fail(inputs: str) -> None:
def fail(_: str) -> None:
"""Simple func."""
msg = "fail"
raise ValueError(msg)

View File

@ -17,6 +17,7 @@ from typing import (
import pytest
from blockbuster import BlockBuster
from pydantic import BaseModel
from typing_extensions import override
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
@ -97,12 +98,12 @@ async def _collect_events(
async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool."""
def foo(x: int) -> dict:
def foo(x: int) -> dict: # noqa: ARG001
"""Foo."""
return {"x": 5}
@tool
def get_docs(x: int) -> list[Document]:
def get_docs(x: int) -> list[Document]: # noqa: ARG001
"""Hello Doc."""
return [Document(page_content="hello")]
@ -465,7 +466,7 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
as_lambdas = RunnableLambda(lambda _: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"}
)
events = await _collect_events(
@ -1043,7 +1044,7 @@ async def test_event_streaming_with_tools() -> None:
return "hello"
@tool
def with_callbacks(callbacks: Callbacks) -> str:
def with_callbacks(callbacks: Callbacks) -> str: # noqa: ARG001
"""A tool that does nothing."""
return "world"
@ -1053,7 +1054,7 @@ async def test_event_streaming_with_tools() -> None:
return {"x": x, "y": y}
@tool
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: # noqa: ARG001
"""A tool that does nothing."""
return {"x": x, "y": y}
@ -1165,6 +1166,7 @@ async def test_event_streaming_with_tools() -> None:
class HardCodedRetriever(BaseRetriever):
documents: list[Document]
@override
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
@ -1553,10 +1555,10 @@ async def test_chain_ordering() -> None:
async def test_event_stream_with_retry() -> None:
"""Test the event stream with a tool."""
def success(inputs: str) -> str:
def success(_: str) -> str:
return "success"
def fail(inputs: str) -> None:
def fail(_: str) -> None:
"""Simple func."""
msg = "fail"
raise ValueError(msg)
@ -2069,6 +2071,7 @@ class StreamingRunnable(Runnable[Input, Output]):
"""Initialize the runnable."""
self.iterable = iterable
@override
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
@ -2084,6 +2087,7 @@ class StreamingRunnable(Runnable[Input, Output]):
) -> Iterator[Output]:
raise NotImplementedError
@override
async def astream(
self,
input: Input,
@ -2323,7 +2327,7 @@ async def test_bad_parent_ids() -> None:
async def test_runnable_generator() -> None:
"""Test async events from sync lambda."""
async def generator(inputs: AsyncIterator[str]) -> AsyncIterator[str]:
async def generator(_: AsyncIterator[str]) -> AsyncIterator[str]:
yield "1"
yield "2"

View File

@ -375,7 +375,7 @@ class TestRunnableSequenceParallelTraceNesting:
def test_sync(
self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int]
) -> None:
def other_thing(a: int) -> Generator[int, None, None]: # type: ignore
def other_thing(_: int) -> Generator[int, None, None]: # type: ignore
yield 1
parent = self._create_parent(other_thing)
@ -407,7 +407,7 @@ class TestRunnableSequenceParallelTraceNesting:
[RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int]
],
) -> None:
async def other_thing(a: int) -> AsyncGenerator[int, None]:
async def other_thing(_: int) -> AsyncGenerator[int, None]:
yield 1
parent = self._create_parent(other_thing)

View File

@ -2301,7 +2301,7 @@ def test_injected_arg_with_complex_type() -> None:
self.value = "bar"
@tool
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: # noqa: ARG001
"""Tool that has an injected tool arg."""
return foo.value
@ -2477,7 +2477,7 @@ def test_simple_tool_args_schema_dict() -> None:
def test_empty_string_tool_call_id() -> None:
@tool
def foo(x: int) -> str:
def foo(x: int) -> str: # noqa: ARG001
"""Foo."""
return "hi"
@ -2489,7 +2489,7 @@ def test_empty_string_tool_call_id() -> None:
def test_tool_decorator_description() -> None:
# test basic tool
@tool
def foo(x: int) -> str:
def foo(x: int) -> str: # noqa: ARG001
"""Foo."""
return "hi"
@ -2501,7 +2501,7 @@ def test_tool_decorator_description() -> None:
# test basic tool with description
@tool(description="description")
def foo_description(x: int) -> str:
def foo_description(x: int) -> str: # noqa: ARG001
"""Foo."""
return "hi"
@ -2520,7 +2520,7 @@ def test_tool_decorator_description() -> None:
x: int
@tool(args_schema=ArgsSchema)
def foo_args_schema(x: int) -> str:
def foo_args_schema(x: int) -> str: # noqa: ARG001
return "hi"
assert foo_args_schema.description == "Bar."
@ -2532,7 +2532,7 @@ def test_tool_decorator_description() -> None:
)
@tool(description="description", args_schema=ArgsSchema)
def foo_args_schema_description(x: int) -> str:
def foo_args_schema_description(x: int) -> str: # noqa: ARG001
return "hi"
assert foo_args_schema_description.description == "description"
@ -2554,11 +2554,11 @@ def test_tool_decorator_description() -> None:
}
@tool(args_schema=args_json_schema)
def foo_args_jsons_schema(x: int) -> str:
def foo_args_jsons_schema(x: int) -> str: # noqa: ARG001
return "hi"
@tool(description="description", args_schema=args_json_schema)
def foo_args_jsons_schema_with_description(x: int) -> str:
def foo_args_jsons_schema_with_description(x: int) -> str: # noqa: ARG001
return "hi"
assert foo_args_jsons_schema.description == "JSON Schema."
@ -2620,10 +2620,10 @@ def test_title_property_preserved() -> None:
async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
"""Verify that the inputs are not mutated when invoking a tool asynchronously."""
def sync_no_op(foo: int) -> str:
def sync_no_op(foo: int) -> str: # noqa: ARG001
return "good"
async def async_no_op(foo: int) -> str:
async def async_no_op(foo: int) -> str: # noqa: ARG001
return "good"
tool = StructuredTool(
@ -2668,10 +2668,10 @@ async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
def test_tool_invoke_does_not_mutate_inputs() -> None:
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
def sync_no_op(foo: int) -> str:
def sync_no_op(foo: int) -> str: # noqa: ARG001
return "good"
async def async_no_op(foo: int) -> str:
async def async_no_op(foo: int) -> str: # noqa: ARG001
return "good"
tool = StructuredTool(

View File

@ -121,7 +121,7 @@ def dummy_structured_tool() -> StructuredTool:
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
return StructuredTool.from_function(
lambda x: None,
lambda _: None,
name="dummy_function",
description="Dummy function.",
args_schema=Schema,
@ -143,7 +143,7 @@ def dummy_structured_tool_args_schema_dict() -> StructuredTool:
"required": ["arg1", "arg2"],
}
return StructuredTool.from_function(
lambda x: None,
lambda _: None,
name="dummy_function",
description="Dummy function.",
args_schema=args_schema,

View File

@ -10,6 +10,7 @@ import uuid
from typing import TYPE_CHECKING, Any, Optional
import pytest
from typing_extensions import override
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings, FakeEmbeddings
@ -25,6 +26,7 @@ class CustomAddTextsVectorstore(VectorStore):
def __init__(self) -> None:
self.store: dict[str, Document] = {}
@override
def add_texts(
self,
texts: Iterable[str],
@ -51,6 +53,7 @@ class CustomAddTextsVectorstore(VectorStore):
return [self.store[id] for id in ids if id in self.store]
@classmethod
@override
def from_texts( # type: ignore
cls,
texts: list[str],
@ -74,6 +77,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
def __init__(self) -> None:
self.store: dict[str, Document] = {}
@override
def add_documents(
self,
documents: list[Document],
@ -95,6 +99,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
return [self.store[id] for id in ids if id in self.store]
@classmethod
@override
def from_texts( # type: ignore
cls,
texts: list[str],