diff --git a/libs/core/langchain_core/document_loaders/langsmith.py b/libs/core/langchain_core/document_loaders/langsmith.py index 259260f718b..57cac1347c5 100644 --- a/libs/core/langchain_core/document_loaders/langsmith.py +++ b/libs/core/langchain_core/document_loaders/langsmith.py @@ -50,7 +50,7 @@ class LangSmithLoader(BaseLoader): offset: int = 0, limit: Optional[int] = None, metadata: Optional[dict] = None, - filter: Optional[str] = None, + filter: Optional[str] = None, # noqa: A002 content_key: str = "", format_content: Optional[Callable[..., str]] = None, client: Optional[LangSmithClient] = None, diff --git a/libs/core/langchain_core/language_models/_utils.py b/libs/core/langchain_core/language_models/_utils.py index 942b6ebf890..8b2a6776ab0 100644 --- a/libs/core/langchain_core/language_models/_utils.py +++ b/libs/core/langchain_core/language_models/_utils.py @@ -92,13 +92,13 @@ def _convert_openai_format_to_data_block(block: dict) -> dict: if block["type"] == "input_audio": data = block["input_audio"].get("data") - format = block["input_audio"].get("format") - if data and format: + audio_format = block["input_audio"].get("format") + if data and audio_format: return { "type": "audio", "source_type": "base64", "data": data, - "mime_type": f"audio/{format}", + "mime_type": f"audio/{audio_format}", } return block diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 714ff42ccd4..a0e975271f6 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -343,15 +343,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): """Get the output type for this runnable.""" return AnyMessage - def _convert_input(self, input: LanguageModelInput) -> PromptValue: - if isinstance(input, PromptValue): - return input - if isinstance(input, str): - return StringPromptValue(text=input) - if isinstance(input, Sequence): - return ChatPromptValue(messages=convert_to_messages(input)) + def _convert_input(self, model_input: LanguageModelInput) -> PromptValue: + if isinstance(model_input, PromptValue): + return model_input + if isinstance(model_input, str): + return StringPromptValue(text=model_input) + if isinstance(model_input, Sequence): + return ChatPromptValue(messages=convert_to_messages(model_input)) msg = ( - f"Invalid input type {type(input)}. " + f"Invalid input type {type(model_input)}. " "Must be a PromptValue, str, or list of BaseMessages." 
) raise ValueError(msg) diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 353cd917707..070314e0ba6 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -327,15 +327,15 @@ class BaseLLM(BaseLanguageModel[str], ABC): """Get the input type for this runnable.""" return str - def _convert_input(self, input: LanguageModelInput) -> PromptValue: - if isinstance(input, PromptValue): - return input - if isinstance(input, str): - return StringPromptValue(text=input) - if isinstance(input, Sequence): - return ChatPromptValue(messages=convert_to_messages(input)) + def _convert_input(self, model_input: LanguageModelInput) -> PromptValue: + if isinstance(model_input, PromptValue): + return model_input + if isinstance(model_input, str): + return StringPromptValue(text=model_input) + if isinstance(model_input, Sequence): + return ChatPromptValue(messages=convert_to_messages(model_input)) msg = ( - f"Invalid input type {type(input)}. " + f"Invalid input type {type(model_input)}. " "Must be a PromptValue, str, or list of BaseMessages." ) raise ValueError(msg) @@ -440,7 +440,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): if max_concurrency is None: try: llm_result = self.generate_prompt( - [self._convert_input(input) for input in inputs], + [self._convert_input(input_) for input_ in inputs], callbacks=[c.get("callbacks") for c in config], tags=[c.get("tags") for c in config], metadata=[c.get("metadata") for c in config], @@ -486,7 +486,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): if max_concurrency is None: try: llm_result = await self.agenerate_prompt( - [self._convert_input(input) for input in inputs], + [self._convert_input(input_) for input_ in inputs], callbacks=[c.get("callbacks") for c in config], tags=[c.get("tags") for c in config], metadata=[c.get("metadata") for c in config], diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index 8b944929fb1..8641f976a6e 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -421,18 +421,18 @@ def add_ai_message_chunks( else: usage_metadata = None - id = None + chunk_id = None candidates = [left.id] + [o.id for o in others] # first pass: pick the first non-run-* id for id_ in candidates: if id_ and not id_.startswith(_LC_ID_PREFIX): - id = id_ + chunk_id = id_ break else: # second pass: no provider-assigned id found, just take the first non-null for id_ in candidates: if id_: - id = id_ + chunk_id = id_ break return left.__class__( @@ -442,7 +442,7 @@ def add_ai_message_chunks( tool_call_chunks=tool_call_chunks, response_metadata=response_metadata, usage_metadata=usage_metadata, - id=id, + id=chunk_id, ) diff --git a/libs/core/langchain_core/messages/content_blocks.py b/libs/core/langchain_core/messages/content_blocks.py index cbd9bfa2f00..83a66fb123a 100644 --- a/libs/core/langchain_core/messages/content_blocks.py +++ b/libs/core/langchain_core/messages/content_blocks.py @@ -140,10 +140,10 @@ def convert_to_openai_data_block(block: dict) -> dict: elif block["type"] == "audio": if block["source_type"] == "base64": - format = block["mime_type"].split("/")[-1] + audio_format = block["mime_type"].split("/")[-1] formatted_block = { "type": "input_audio", - "input_audio": {"data": block["data"], "format": format}, + "input_audio": {"data": block["data"], "format": audio_format}, } else: error_msg = "source_type base64 is required for audio 
blocks." diff --git a/libs/core/langchain_core/messages/modifier.py b/libs/core/langchain_core/messages/modifier.py index 9cfc89c5760..08b7e79b69c 100644 --- a/libs/core/langchain_core/messages/modifier.py +++ b/libs/core/langchain_core/messages/modifier.py @@ -11,7 +11,11 @@ class RemoveMessage(BaseMessage): type: Literal["remove"] = "remove" """The type of the message (used for serialization). Defaults to "remove".""" - def __init__(self, id: str, **kwargs: Any) -> None: + def __init__( + self, + id: str, # noqa: A002 + **kwargs: Any, + ) -> None: """Create a RemoveMessage. Args: diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 42c9018e151..553006e06e1 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -208,7 +208,12 @@ class ToolCall(TypedDict): type: NotRequired[Literal["tool_call"]] -def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall: +def tool_call( + *, + name: str, + args: dict[str, Any], + id: Optional[str], # noqa: A002 +) -> ToolCall: """Create a tool call. Args: @@ -254,7 +259,7 @@ def tool_call_chunk( *, name: Optional[str] = None, args: Optional[str] = None, - id: Optional[str] = None, + id: Optional[str] = None, # noqa: A002 index: Optional[int] = None, ) -> ToolCallChunk: """Create a tool call chunk. @@ -292,7 +297,7 @@ def invalid_tool_call( *, name: Optional[str] = None, args: Optional[str] = None, - id: Optional[str] = None, + id: Optional[str] = None, # noqa: A002 error: Optional[str] = None, ) -> InvalidToolCall: """Create an invalid tool call. diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index d99c0f164d3..adad9c9b6ac 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -212,7 +212,7 @@ def _create_message_from_message_type( name: Optional[str] = None, tool_call_id: Optional[str] = None, tool_calls: Optional[list[dict[str, Any]]] = None, - id: Optional[str] = None, + id: Optional[str] = None, # noqa: A002 **additional_kwargs: Any, ) -> BaseMessage: """Create a message from a message type and content string. diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index b8756489bb9..021c7a03bd0 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -9,6 +9,7 @@ from typing import Annotated, Any, Optional, TypeVar, Union import jsonpatch # type: ignore[import-untyped] import pydantic from pydantic import SkipValidation +from typing_extensions import override from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS @@ -47,6 +48,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): """The Pydantic object to use for validation. 
If None, no validation is performed.""" + @override def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index 3dba7e3af5c..d60afe3554d 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -20,7 +20,10 @@ if TYPE_CHECKING: T = TypeVar("T") -def droplastn(iter: Iterator[T], n: int) -> Iterator[T]: +def droplastn( + iter: Iterator[T], # noqa: A002 + n: int, +) -> Iterator[T]: """Drop the last n elements of an iterator. Args: @@ -66,6 +69,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]): """ raise NotImplementedError + @override def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[list[str]]: @@ -99,6 +103,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]): for part in self.parse(buffer): yield [part] + @override async def _atransform( self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[list[str]]: diff --git a/libs/core/langchain_core/output_parsers/openai_functions.py b/libs/core/langchain_core/output_parsers/openai_functions.py index f71b2c8044b..c656fc5cbb7 100644 --- a/libs/core/langchain_core/output_parsers/openai_functions.py +++ b/libs/core/langchain_core/output_parsers/openai_functions.py @@ -72,6 +72,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): def _type(self) -> str: return "json_functions" + @override def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch diff --git a/libs/core/langchain_core/output_parsers/transform.py b/libs/core/langchain_core/output_parsers/transform.py index 783abedf116..96111f9a4c2 100644 --- a/libs/core/langchain_core/output_parsers/transform.py +++ b/libs/core/langchain_core/output_parsers/transform.py @@ -30,7 +30,10 @@ if TYPE_CHECKING: class BaseTransformOutputParser(BaseOutputParser[T]): """Base class for an output parser that can handle streaming input.""" - def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]: + def _transform( + self, + input: Iterator[Union[str, BaseMessage]], # noqa: A002 + ) -> Iterator[T]: for chunk in input: if isinstance(chunk, BaseMessage): yield self.parse_result([ChatGeneration(message=chunk)]) @@ -38,7 +41,8 @@ class BaseTransformOutputParser(BaseOutputParser[T]): yield self.parse_result([Generation(text=chunk)]) async def _atransform( - self, input: AsyncIterator[Union[str, BaseMessage]] + self, + input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002 ) -> AsyncIterator[T]: async for chunk in input: if isinstance(chunk, BaseMessage): @@ -102,7 +106,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): parsed output, or just the current parsed output. """ - def _diff(self, prev: Optional[T], next: T) -> T: + def _diff( + self, + prev: Optional[T], + next: T, # noqa: A002 + ) -> T: """Convert parsed outputs into a diff format. The semantics of this are up to the output parser. 
@@ -116,6 +124,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): """ raise NotImplementedError + @override def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: prev_parsed = None acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None @@ -140,6 +149,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): yield parsed prev_parsed = parsed + @override async def _atransform( self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[T]: diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 9bbe2ef0814..c5a9527eb04 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -8,6 +8,8 @@ from collections.abc import AsyncIterator, Iterator from typing import Any, Literal, Optional, Union from xml.etree.ElementTree import TreeBuilder +from typing_extensions import override + from langchain_core.exceptions import OutputParserException from langchain_core.messages import BaseMessage from langchain_core.output_parsers.transform import BaseTransformOutputParser @@ -234,6 +236,7 @@ class XMLOutputParser(BaseTransformOutputParser): msg = f"Failed to parse XML format from completion {text}. Got: {e}" raise OutputParserException(msg, llm_output=text) from e + @override def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[AddableDict]: @@ -242,6 +245,7 @@ class XMLOutputParser(BaseTransformOutputParser): yield from streaming_parser.parse(chunk) streaming_parser.close() + @override async def _atransform( self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[AddableDict]: diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 46759942832..633465bb500 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -471,16 +471,18 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): img_template = cast("_ImageTemplateParam", tmpl)["image_url"] input_variables = [] if isinstance(img_template, str): - vars = get_template_variables(img_template, template_format) - if vars: - if len(vars) > 1: + variables = get_template_variables( + img_template, template_format + ) + if variables: + if len(variables) > 1: msg = ( "Only one format variable allowed per image" - f" template.\nGot: {vars}" + f" template.\nGot: {variables}" f"\nFrom: {tmpl}" ) raise ValueError(msg) - input_variables = [vars[0]] + input_variables = [variables[0]] img_template = {"url": img_template} img_template_obj = ImagePromptTemplate( input_variables=input_variables, diff --git a/libs/core/langchain_core/prompts/dict.py b/libs/core/langchain_core/prompts/dict.py index 0ccdf7a64a0..79882fc47f5 100644 --- a/libs/core/langchain_core/prompts/dict.py +++ b/libs/core/langchain_core/prompts/dict.py @@ -4,6 +4,8 @@ import warnings from functools import cached_property from typing import Any, Literal, Optional +from typing_extensions import override + from langchain_core.load import dumpd from langchain_core.prompts.string import ( DEFAULT_FORMATTER_MAPPING, @@ -36,6 +38,7 @@ class DictPromptTemplate(RunnableSerializable[dict, dict]): """Format the prompt with the inputs.""" return self.format(**kwargs) + @override def invoke( self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> dict: diff --git a/libs/core/langchain_core/prompts/string.py 
b/libs/core/langchain_core/prompts/string.py index e63d54dbbde..084dd09f1f5 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -125,20 +125,20 @@ def mustache_template_vars( Returns: The variables from the template. """ - vars: set[str] = set() + variables: set[str] = set() section_depth = 0 - for type, key in mustache.tokenize(template): - if type == "end": + for type_, key in mustache.tokenize(template): + if type_ == "end": section_depth -= 1 elif ( - type in ("variable", "section", "inverted section", "no escape") + type_ in ("variable", "section", "inverted section", "no escape") and key != "." and section_depth == 0 ): - vars.add(key.split(".")[0]) - if type in ("section", "inverted section"): + variables.add(key.split(".")[0]) + if type_ in ("section", "inverted section"): section_depth += 1 - return vars + return variables Defs = dict[str, "Defs"] @@ -158,17 +158,17 @@ def mustache_schema( fields = {} prefix: tuple[str, ...] = () section_stack: list[tuple[str, ...]] = [] - for type, key in mustache.tokenize(template): + for type_, key in mustache.tokenize(template): if key == ".": continue - if type == "end": + if type_ == "end": if section_stack: prefix = section_stack.pop() - elif type in ("section", "inverted section"): + elif type_ in ("section", "inverted section"): section_stack.append(prefix) prefix = prefix + tuple(key.split(".")) fields[prefix] = False - elif type in ("variable", "no escape"): + elif type_ in ("variable", "no escape"): fields[prefix + tuple(key.split("."))] = True defs: Defs = {} # None means leaf node while fields: diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 3456cc3e02c..557599a5ac4 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -209,6 +209,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name) + @override def invoke( self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> list[Document]: @@ -269,6 +270,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): ) return result + @override async def ainvoke( self, input: str, diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index c4a176d1843..de19c5f5915 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -724,7 +724,10 @@ class Runnable(Generic[Input, Output], ABC): @abstractmethod def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, + input: Input, # noqa: A002 + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> Output: """Transform a single input into an output. @@ -741,7 +744,10 @@ class Runnable(Generic[Input, Output], ABC): """ async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, + input: Input, # noqa: A002 + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> Output: """Default implementation of ainvoke, calls invoke from a thread. 
@@ -772,14 +778,14 @@ class Runnable(Generic[Input, Output], ABC): configs = get_config_list(config, len(inputs)) - def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]: + def invoke(input_: Input, config: RunnableConfig) -> Union[Output, Exception]: if return_exceptions: try: - return self.invoke(input, config, **kwargs) + return self.invoke(input_, config, **kwargs) except Exception as e: return e else: - return self.invoke(input, config, **kwargs) + return self.invoke(input_, config, **kwargs) # If there's only one input, don't bother with the executor if len(inputs) == 1: @@ -826,15 +832,17 @@ class Runnable(Generic[Input, Output], ABC): configs = get_config_list(config, len(inputs)) def invoke( - i: int, input: Input, config: RunnableConfig + i: int, input_: Input, config: RunnableConfig ) -> tuple[int, Union[Output, Exception]]: if return_exceptions: try: - out: Union[Output, Exception] = self.invoke(input, config, **kwargs) + out: Union[Output, Exception] = self.invoke( + input_, config, **kwargs + ) except Exception as e: out = e else: - out = self.invoke(input, config, **kwargs) + out = self.invoke(input_, config, **kwargs) return (i, out) @@ -844,8 +852,8 @@ class Runnable(Generic[Input, Output], ABC): with get_executor_for_config(configs[0]) as executor: futures = { - executor.submit(invoke, i, input, config) - for i, (input, config) in enumerate(zip(inputs, configs)) + executor.submit(invoke, i, input_, config) + for i, (input_, config) in enumerate(zip(inputs, configs)) } try: @@ -892,15 +900,15 @@ class Runnable(Generic[Input, Output], ABC): configs = get_config_list(config, len(inputs)) async def ainvoke( - input: Input, config: RunnableConfig + value: Input, config: RunnableConfig ) -> Union[Output, Exception]: if return_exceptions: try: - return await self.ainvoke(input, config, **kwargs) + return await self.ainvoke(value, config, **kwargs) except Exception as e: return e else: - return await self.ainvoke(input, config, **kwargs) + return await self.ainvoke(value, config, **kwargs) coros = map(ainvoke, inputs, configs) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) @@ -960,24 +968,24 @@ class Runnable(Generic[Input, Output], ABC): semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None async def ainvoke_task( - i: int, input: Input, config: RunnableConfig + i: int, input_: Input, config: RunnableConfig ) -> tuple[int, Union[Output, Exception]]: if return_exceptions: try: out: Union[Output, Exception] = await self.ainvoke( - input, config, **kwargs + input_, config, **kwargs ) except Exception as e: out = e else: - out = await self.ainvoke(input, config, **kwargs) + out = await self.ainvoke(input_, config, **kwargs) return (i, out) coros = [ - gated_coro(semaphore, ainvoke_task(i, input, config)) + gated_coro(semaphore, ainvoke_task(i, input_, config)) if semaphore - else ainvoke_task(i, input, config) - for i, (input, config) in enumerate(zip(inputs, configs)) + else ainvoke_task(i, input_, config) + for i, (input_, config) in enumerate(zip(inputs, configs)) ] for coro in asyncio.as_completed(coros): @@ -985,7 +993,7 @@ class Runnable(Generic[Input, Output], ABC): def stream( self, - input: Input, + input: Input, # noqa: A002 config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: @@ -1005,7 +1013,7 @@ class Runnable(Generic[Input, Output], ABC): async def astream( self, - input: Input, + input: Input, # noqa: A002 config: Optional[RunnableConfig] = None, **kwargs: 
Optional[Any], ) -> AsyncIterator[Output]: @@ -1059,7 +1067,7 @@ class Runnable(Generic[Input, Output], ABC): async def astream_log( self, - input: Any, + input: Any, # noqa: A002 config: Optional[RunnableConfig] = None, *, diff: bool = True, @@ -1130,7 +1138,7 @@ class Runnable(Generic[Input, Output], ABC): async def astream_events( self, - input: Any, + input: Any, # noqa: A002 config: Optional[RunnableConfig] = None, *, version: Literal["v1", "v2"] = "v2", @@ -1396,7 +1404,7 @@ class Runnable(Generic[Input, Output], ABC): def transform( self, - input: Iterator[Input], + input: Iterator[Input], # noqa: A002 config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: @@ -1438,7 +1446,7 @@ class Runnable(Generic[Input, Output], ABC): async def atransform( self, - input: AsyncIterator[Input], + input: AsyncIterator[Input], # noqa: A002 config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Output]: @@ -1903,7 +1911,7 @@ class Runnable(Generic[Input, Output], ABC): Callable[[Input, CallbackManagerForChainRun], Output], Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], ], - input: Input, + input_: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, serialized: Optional[dict[str, Any]] = None, @@ -1917,7 +1925,7 @@ class Runnable(Generic[Input, Output], ABC): callback_manager = get_callback_manager_for_config(config) run_manager = callback_manager.on_chain_start( serialized, - input, + input_, run_type=run_type, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), @@ -1930,7 +1938,7 @@ class Runnable(Generic[Input, Output], ABC): context.run( call_func_with_variable_args, # type: ignore[arg-type] func, - input, + input_, config, run_manager, **kwargs, @@ -1953,7 +1961,7 @@ class Runnable(Generic[Input, Output], ABC): Awaitable[Output], ], ], - input: Input, + input_: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, serialized: Optional[dict[str, Any]] = None, @@ -1967,7 +1975,7 @@ class Runnable(Generic[Input, Output], ABC): callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( serialized, - input, + input_, run_type=run_type, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), @@ -1976,7 +1984,7 @@ class Runnable(Generic[Input, Output], ABC): child_config = patch_config(config, callbacks=run_manager.get_child()) with set_config_context(child_config) as context: coro = acall_func_with_variable_args( - func, input, config, run_manager, **kwargs + func, input_, config, run_manager, **kwargs ) output: Output = await coro_with_context(coro, context) except BaseException as e: @@ -1999,7 +2007,7 @@ class Runnable(Generic[Input, Output], ABC): list[Union[Exception, Output]], ], ], - input: list[Input], + inputs: list[Input], config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, @@ -2011,21 +2019,21 @@ class Runnable(Generic[Input, Output], ABC): Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses. 
""" - if not input: + if not inputs: return [] - configs = get_config_list(config, len(input)) + configs = get_config_list(config, len(inputs)) callback_managers = [get_callback_manager_for_config(c) for c in configs] run_managers = [ callback_manager.on_chain_start( None, - input, + input_, run_type=run_type, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for callback_manager, input, config in zip( - callback_managers, input, configs + for callback_manager, input_, config in zip( + callback_managers, inputs, configs ) ] try: @@ -2036,12 +2044,12 @@ class Runnable(Generic[Input, Output], ABC): ] if accepts_run_manager(func): kwargs["run_manager"] = run_managers - output = func(input, **kwargs) # type: ignore[call-arg] + output = func(inputs, **kwargs) # type: ignore[call-arg] except BaseException as e: for run_manager in run_managers: run_manager.on_chain_error(e) if return_exceptions: - return cast("list[Output]", [e for _ in input]) + return cast("list[Output]", [e for _ in inputs]) raise else: first_exception: Optional[Exception] = None @@ -2072,7 +2080,7 @@ class Runnable(Generic[Input, Output], ABC): Awaitable[list[Union[Exception, Output]]], ], ], - input: list[Input], + inputs: list[Input], config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, @@ -2085,22 +2093,22 @@ class Runnable(Generic[Input, Output], ABC): with callbacks. Use this method to implement invoke() in subclasses. """ - if not input: + if not inputs: return [] - configs = get_config_list(config, len(input)) + configs = get_config_list(config, len(inputs)) callback_managers = [get_async_callback_manager_for_config(c) for c in configs] run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather( *( callback_manager.on_chain_start( None, - input, + input_, run_type=run_type, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for callback_manager, input, config in zip( - callback_managers, input, configs + for callback_manager, input_, config in zip( + callback_managers, inputs, configs ) ) ) @@ -2112,13 +2120,13 @@ class Runnable(Generic[Input, Output], ABC): ] if accepts_run_manager(func): kwargs["run_manager"] = run_managers - output = await func(input, **kwargs) # type: ignore[call-arg] + output = await func(inputs, **kwargs) # type: ignore[call-arg] except BaseException as e: await asyncio.gather( *(run_manager.on_chain_error(e) for run_manager in run_managers) ) if return_exceptions: - return cast("list[Output]", [e for _ in input]) + return cast("list[Output]", [e for _ in inputs]) raise else: first_exception: Optional[Exception] = None @@ -2136,7 +2144,7 @@ class Runnable(Generic[Input, Output], ABC): def _transform_stream_with_config( self, - input: Iterator[Input], + inputs: Iterator[Input], transformer: Union[ Callable[[Iterator[Input]], Iterator[Output]], Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]], @@ -2163,7 +2171,7 @@ class Runnable(Generic[Input, Output], ABC): from langchain_core.tracers._streaming import _StreamingCallbackHandler # tee the input so we can iterate over it twice - input_for_tracing, input_for_transform = tee(input, 2) + input_for_tracing, input_for_transform = tee(inputs, 2) # Start the input iterator to ensure the input Runnable starts before this one final_input: Optional[Input] = next(input_for_tracing, None) final_input_supported = True @@ -2237,7 +2245,7 @@ class Runnable(Generic[Input, Output], ABC): async def 
_atransform_stream_with_config( self, - input: AsyncIterator[Input], + inputs: AsyncIterator[Input], transformer: Union[ Callable[[AsyncIterator[Input]], AsyncIterator[Output]], Callable[ @@ -2267,7 +2275,7 @@ class Runnable(Generic[Input, Output], ABC): from langchain_core.tracers._streaming import _StreamingCallbackHandler # tee the input so we can iterate over it twice - input_for_tracing, input_for_transform = atee(input, 2) + input_for_tracing, input_for_transform = atee(inputs, 2) # Start the input iterator to ensure the input Runnable starts before this one final_input: Optional[Input] = await py_anext(input_for_tracing, None) final_input_supported = True @@ -3019,6 +3027,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) + input_ = input # invoke all steps in sequence try: @@ -3029,16 +3038,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]): ) with set_config_context(config) as context: if i == 0: - input = context.run(step.invoke, input, config, **kwargs) + input_ = context.run(step.invoke, input_, config, **kwargs) else: - input = context.run(step.invoke, input, config) + input_ = context.run(step.invoke, input_, config) # finish the root run except BaseException as e: run_manager.on_chain_error(e) raise else: - run_manager.on_chain_end(input) - return cast("Output", input) + run_manager.on_chain_end(input_) + return cast("Output", input_) @override async def ainvoke( @@ -3059,6 +3068,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) + input_ = input # invoke all steps in sequence try: @@ -3069,17 +3079,17 @@ class RunnableSequence(RunnableSerializable[Input, Output]): ) with set_config_context(config) as context: if i == 0: - part = functools.partial(step.ainvoke, input, config, **kwargs) + part = functools.partial(step.ainvoke, input_, config, **kwargs) else: - part = functools.partial(step.ainvoke, input, config) - input = await coro_with_context(part(), context, create_task=True) + part = functools.partial(step.ainvoke, input_, config) + input_ = await coro_with_context(part(), context, create_task=True) # finish the root run except BaseException as e: await run_manager.on_chain_error(e) raise else: - await run_manager.on_chain_end(input) - return cast("Output", input) + await run_manager.on_chain_end(input_) + return cast("Output", input_) @override def batch( @@ -3117,11 +3127,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]): run_managers = [ cm.on_chain_start( None, - input, + input_, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for cm, input, config in zip(callback_managers, inputs, configs) + for cm, input_, config in zip(callback_managers, inputs, configs) ] # invoke @@ -3248,11 +3258,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]): *( cm.on_chain_start( None, - input, + input_, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for cm, input, config in zip(callback_managers, inputs, configs) + for cm, input_, config in zip(callback_managers, inputs, configs) ) ) @@ -3346,7 +3356,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def _transform( self, - input: Iterator[Input], + inputs: Iterator[Input], run_manager: CallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, @@ -3359,7 +3369,7 @@ class 
RunnableSequence(RunnableSerializable[Input, Output]): # transform the input stream of each step with the next # steps that don't natively support transforming an input stream will # buffer input in memory until all available, and then start emitting output - final_pipeline = cast("Iterator[Output]", input) + final_pipeline = cast("Iterator[Output]", inputs) for idx, step in enumerate(steps): config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{idx + 1}") @@ -3373,7 +3383,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): async def _atransform( self, - input: AsyncIterator[Input], + inputs: AsyncIterator[Input], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, @@ -3387,7 +3397,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): # transform the input stream of each step with the next # steps that don't natively support transforming an input stream will # buffer input in memory until all available, and then start emitting output - final_pipeline = cast("AsyncIterator[Output]", input) + final_pipeline = cast("AsyncIterator[Output]", inputs) for idx, step in enumerate(steps): config = patch_config( config, @@ -3733,7 +3743,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): ) def _invoke_step( - step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str + step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str ) -> Any: child_config = patch_config( config, @@ -3743,7 +3753,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): with set_config_context(child_config) as context: return context.run( step.invoke, - input, + input_, child_config, ) @@ -3785,7 +3795,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): ) async def _ainvoke_step( - step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str + step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str ) -> Any: child_config = patch_config( config, @@ -3793,7 +3803,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): ) with set_config_context(child_config) as context: return await coro_with_context( - step.ainvoke(input, child_config), context, create_task=True + step.ainvoke(input_, child_config), context, create_task=True ) # gather results from all steps @@ -3823,7 +3833,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): def _transform( self, - input: Iterator[Input], + inputs: Iterator[Input], run_manager: CallbackManagerForChainRun, config: RunnableConfig, ) -> Iterator[AddableDict]: @@ -3831,7 +3841,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): steps = dict(self.steps__) # Each step gets a copy of the input iterator, # which is consumed in parallel in a separate thread. 
- input_copies = list(safetee(input, len(steps), lock=threading.Lock())) + input_copies = list(safetee(inputs, len(steps), lock=threading.Lock())) with get_executor_for_config(config) as executor: # Create the transform() generator for each step named_generators = [ @@ -3890,7 +3900,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): async def _atransform( self, - input: AsyncIterator[Input], + inputs: AsyncIterator[Input], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, ) -> AsyncIterator[AddableDict]: @@ -3898,7 +3908,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): steps = dict(self.steps__) # Each step gets a copy of the input iterator, # which is consumed in parallel in a separate thread. - input_copies = list(atee(input, len(steps), lock=asyncio.Lock())) + input_copies = list(atee(inputs, len(steps), lock=asyncio.Lock())) # Create the transform() generator for each step named_generators = [ ( @@ -4590,7 +4600,7 @@ class RunnableLambda(Runnable[Input, Output]): def _invoke( self, - input: Input, + input_: Input, run_manager: CallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, @@ -4599,7 +4609,7 @@ class RunnableLambda(Runnable[Input, Output]): output: Optional[Output] = None for chunk in call_func_with_variable_args( cast("Callable[[Input], Iterator[Output]]", self.func), - input, + input_, config, run_manager, **kwargs, @@ -4613,18 +4623,18 @@ class RunnableLambda(Runnable[Input, Output]): output = chunk else: output = call_func_with_variable_args( - self.func, input, config, run_manager, **kwargs + self.func, input_, config, run_manager, **kwargs ) # If the output is a Runnable, invoke it if isinstance(output, Runnable): recursion_limit = config["recursion_limit"] if recursion_limit <= 0: msg = ( - f"Recursion limit reached when invoking {self} with input {input}." + f"Recursion limit reached when invoking {self} with input {input_}." 
) raise RecursionError(msg) output = output.invoke( - input, + input_, patch_config( config, callbacks=run_manager.get_child(), @@ -4635,7 +4645,7 @@ class RunnableLambda(Runnable[Input, Output]): async def _ainvoke( self, - input: Input, + value: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, @@ -4646,7 +4656,7 @@ class RunnableLambda(Runnable[Input, Output]): if inspect.isgeneratorfunction(self.func): def func( - input: Input, + value: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, @@ -4654,7 +4664,7 @@ class RunnableLambda(Runnable[Input, Output]): output: Optional[Output] = None for chunk in call_func_with_variable_args( cast("Callable[[Input], Iterator[Output]]", self.func), - input, + value, config, run_manager.get_sync(), **kwargs, @@ -4671,13 +4681,13 @@ class RunnableLambda(Runnable[Input, Output]): else: def func( - input: Input, + value: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, ) -> Output: return call_func_with_variable_args( - self.func, input, config, run_manager.get_sync(), **kwargs + self.func, value, config, run_manager.get_sync(), **kwargs ) @wraps(func) @@ -4693,7 +4703,7 @@ class RunnableLambda(Runnable[Input, Output]): "AsyncGenerator[Any, Any]", acall_func_with_variable_args( cast("Callable", afunc), - input, + value, config, run_manager, **kwargs, @@ -4713,18 +4723,18 @@ class RunnableLambda(Runnable[Input, Output]): output = chunk else: output = await acall_func_with_variable_args( - cast("Callable", afunc), input, config, run_manager, **kwargs + cast("Callable", afunc), value, config, run_manager, **kwargs ) # If the output is a Runnable, invoke it if isinstance(output, Runnable): recursion_limit = config["recursion_limit"] if recursion_limit <= 0: msg = ( - f"Recursion limit reached when invoking {self} with input {input}." + f"Recursion limit reached when invoking {self} with input {value}." ) raise RecursionError(msg) output = await output.ainvoke( - input, + value, patch_config( config, callbacks=run_manager.get_child(), @@ -4789,14 +4799,14 @@ class RunnableLambda(Runnable[Input, Output]): def _transform( self, - input: Iterator[Input], + chunks: Iterator[Input], run_manager: CallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, ) -> Iterator[Output]: final: Input got_first_val = False - for ichunk in input: + for ichunk in chunks: # By definitions, RunnableLambdas consume all input before emitting output. # If the input is not addable, then we'll assume that we can # only operate on the last chunk. @@ -4881,14 +4891,14 @@ class RunnableLambda(Runnable[Input, Output]): async def _atransform( self, - input: AsyncIterator[Input], + chunks: AsyncIterator[Input], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, ) -> AsyncIterator[Output]: final: Input got_first_val = False - async for ichunk in input: + async for ichunk in chunks: # By definitions, RunnableLambdas consume all input before emitting output. # If the input is not addable, then we'll assume that we can # only operate on the last chunk. 
@@ -4913,13 +4923,13 @@ class RunnableLambda(Runnable[Input, Output]): raise TypeError(msg) def func( - input: Input, + input_: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, ) -> Output: return call_func_with_variable_args( - self.func, input, config, run_manager.get_sync(), **kwargs + self.func, input_, config, run_manager.get_sync(), **kwargs ) @wraps(func) diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index 686426e9e79..b34f312e0a9 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -196,6 +196,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): raise ValueError(msg) return specs + @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -254,6 +255,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): run_manager.on_chain_end(output) return output + @override async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -302,6 +304,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): await run_manager.on_chain_end(output) return output + @override def stream( self, input: Input, @@ -388,6 +391,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): raise run_manager.on_chain_end(final_output) + @override async def astream( self, input: Input, diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 6eae4f23b06..f1bf2f65006 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -401,7 +401,7 @@ def call_func_with_variable_args( Callable[[Input, CallbackManagerForChainRun], Output], Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], ], - input: Input, + input: Input, # noqa: A002 config: RunnableConfig, run_manager: Optional[CallbackManagerForChainRun] = None, **kwargs: Any, @@ -438,7 +438,7 @@ def acall_func_with_variable_args( Awaitable[Output], ], ], - input: Input, + input: Input, # noqa: A002 config: RunnableConfig, run_manager: Optional[AsyncCallbackManagerForChainRun] = None, **kwargs: Any, diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index 61386a391ca..4f2963353f8 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -178,16 +178,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def invoke( prepared: tuple[Runnable[Input, Output], RunnableConfig], - input: Input, + input_: Input, ) -> Union[Output, Exception]: bound, config = prepared if return_exceptions: try: - return bound.invoke(input, config, **kwargs) + return bound.invoke(input_, config, **kwargs) except Exception as e: return e else: - return bound.invoke(input, config, **kwargs) + return bound.invoke(input_, config, **kwargs) # If there's only one input, don't bother with the executor if len(inputs) == 1: @@ -221,16 +221,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): async def ainvoke( prepared: tuple[Runnable[Input, Output], RunnableConfig], - input: Input, + input_: Input, ) -> Union[Output, Exception]: bound, config = prepared if return_exceptions: try: - return await bound.ainvoke(input, config, **kwargs) + return await bound.ainvoke(input_, config, **kwargs) except Exception as e: return e else: - return await 
bound.ainvoke(input, config, **kwargs) + return await bound.ainvoke(input_, config, **kwargs) coros = map(ainvoke, prepared, inputs) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index 2150b00beba..5ea6bcc40ce 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -269,7 +269,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): from langchain_core.callbacks.manager import CallbackManager if self.exception_key is not None and not all( - isinstance(input, dict) for input in inputs + isinstance(input_, dict) for input_ in inputs ): msg = ( "If 'exception_key' is specified then inputs must be dictionaries." @@ -298,11 +298,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): run_managers = [ cm.on_chain_start( None, - input if isinstance(input, dict) else {"input": input}, + input_ if isinstance(input_, dict) else {"input": input_}, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for cm, input, config in zip(callback_managers, inputs, configs) + for cm, input_, config in zip(callback_managers, inputs, configs) ] to_return: dict[int, Any] = {} @@ -311,7 +311,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): first_to_raise = None for runnable in self.runnables: outputs = runnable.batch( - [input for _, input in sorted(run_again.items())], + [input_ for _, input_ in sorted(run_again.items())], [ # each step a child run of the corresponding root run patch_config(configs[i], callbacks=run_managers[i].get_child()) @@ -320,7 +320,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): return_exceptions=True, **kwargs, ) - for (i, input), output in zip(sorted(run_again.copy().items()), outputs): + for (i, input_), output in zip(sorted(run_again.copy().items()), outputs): if isinstance(output, BaseException) and not isinstance( output, self.exceptions_to_handle ): @@ -331,7 +331,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): run_again.pop(i) elif isinstance(output, self.exceptions_to_handle): if self.exception_key: - input[self.exception_key] = output # type: ignore[index] + input_[self.exception_key] = output # type: ignore[index] handled_exceptions[i] = output else: run_managers[i].on_chain_end(output) @@ -363,7 +363,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): from langchain_core.callbacks.manager import AsyncCallbackManager if self.exception_key is not None and not all( - isinstance(input, dict) for input in inputs + isinstance(input_, dict) for input_ in inputs ): msg = ( "If 'exception_key' is specified then inputs must be dictionaries." 
@@ -393,11 +393,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): *( cm.on_chain_start( None, - input, + input_, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for cm, input, config in zip(callback_managers, inputs, configs) + for cm, input_, config in zip(callback_managers, inputs, configs) ) ) @@ -407,7 +407,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): first_to_raise = None for runnable in self.runnables: outputs = await runnable.abatch( - [input for _, input in sorted(run_again.items())], + [input_ for _, input_ in sorted(run_again.items())], [ # each step a child run of the corresponding root run patch_config(configs[i], callbacks=run_managers[i].get_child()) @@ -417,7 +417,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): **kwargs, ) - for (i, input), output in zip(sorted(run_again.copy().items()), outputs): + for (i, input_), output in zip(sorted(run_again.copy().items()), outputs): if isinstance(output, BaseException) and not isinstance( output, self.exceptions_to_handle ): @@ -428,7 +428,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): run_again.pop(i) elif isinstance(output, self.exceptions_to_handle): if self.exception_key: - input[self.exception_key] = output # type: ignore[index] + input_[self.exception_key] = output # type: ignore[index] handled_exceptions[i] = output else: to_return[i] = output diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 323dcdd2d48..daee67b516a 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -111,7 +111,12 @@ class Node(NamedTuple): data: Union[type[BaseModel], RunnableType, None] metadata: Optional[dict[str, Any]] - def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node: + def copy( + self, + *, + id: Optional[str] = None, # noqa: A002 + name: Optional[str] = None, + ) -> Node: """Return a copy of the node with optional new id and name. Args: @@ -181,7 +186,10 @@ class MermaidDrawMethod(Enum): API = "api" # Uses Mermaid.INK API to render the graph -def node_data_str(id: str, data: Union[type[BaseModel], RunnableType, None]) -> str: +def node_data_str( + id: str, # noqa: A002 + data: Union[type[BaseModel], RunnableType, None], +) -> str: """Convert the data of a node to a string. 
Args: @@ -320,7 +328,7 @@ class Graph: def add_node( self, data: Union[type[BaseModel], RunnableType, None], - id: Optional[str] = None, + id: Optional[str] = None, # noqa: A002 *, metadata: Optional[dict[str, Any]] = None, ) -> Node: @@ -340,8 +348,8 @@ class Graph: if id is not None and id in self.nodes: msg = f"Node with id {id} already exists" raise ValueError(msg) - id = id or self.next_id() - node = Node(id=id, data=data, metadata=metadata, name=node_data_str(id, data)) + id_ = id or self.next_id() + node = Node(id=id_, data=data, metadata=metadata, name=node_data_str(id_, data)) self.nodes[node.id] = node return node @@ -406,8 +414,8 @@ class Graph: if all(is_uuid(node.id) for node in graph.nodes.values()): prefix = "" - def prefixed(id: str) -> str: - return f"{prefix}:{id}" if prefix else id + def prefixed(id_: str) -> str: + return f"{prefix}:{id_}" if prefix else id_ # prefix each node self.nodes.update( @@ -450,8 +458,8 @@ class Graph: return Graph( nodes={ - _get_node_id(id): node.copy(id=_get_node_id(id)) - for id, node in self.nodes.items() + _get_node_id(id_): node.copy(id=_get_node_id(id_)) + for id_, node in self.nodes.items() }, edges=[ edge.copy( diff --git a/libs/core/langchain_core/runnables/graph_ascii.py b/libs/core/langchain_core/runnables/graph_ascii.py index c224e4c3846..162fcd40bd6 100644 --- a/libs/core/langchain_core/runnables/graph_ascii.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -192,7 +192,7 @@ def _build_sugiyama_layout( # Y # - vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()} + vertices_ = {id_: Vertex(f" {data} ") for id_, data in vertices.items()} edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges] vertices_list = vertices_.values() graph = Graph(vertices_list, edges_) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index d261e574464..3a467811b4c 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -509,20 +509,20 @@ class RunnableWithMessageHistory(RunnableBindingBase): ) raise ValueError(msg) - def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]: + def _enter_history(self, value: Any, config: RunnableConfig) -> list[BaseMessage]: hist: BaseChatMessageHistory = config["configurable"]["message_history"] messages = hist.messages.copy() if not self.history_messages_key: # return all messages input_val = ( - input if not self.input_messages_key else input[self.input_messages_key] + value if not self.input_messages_key else value[self.input_messages_key] ) messages += self._get_input_messages(input_val) return messages async def _aenter_history( - self, input: dict[str, Any], config: RunnableConfig + self, value: dict[str, Any], config: RunnableConfig ) -> list[BaseMessage]: hist: BaseChatMessageHistory = config["configurable"]["message_history"] messages = (await hist.aget_messages()).copy() @@ -530,7 +530,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): if not self.history_messages_key: # return all messages input_val = ( - input if not self.input_messages_key else input[self.input_messages_key] + value if not self.input_messages_key else value[self.input_messages_key] ) messages += self._get_input_messages(input_val) return messages diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 098ce791af9..530a80e2c64 100644 --- 
a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -483,19 +483,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): def _invoke( self, - input: dict[str, Any], + value: dict[str, Any], run_manager: CallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, ) -> dict[str, Any]: - if not isinstance(input, dict): + if not isinstance(value, dict): msg = "The input to RunnablePassthrough.assign() must be a dict." raise ValueError(msg) # noqa: TRY004 return { - **input, + **value, **self.mapper.invoke( - input, + value, patch_config(config, callbacks=run_manager.get_child()), **kwargs, ), @@ -512,19 +512,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def _ainvoke( self, - input: dict[str, Any], + value: dict[str, Any], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, ) -> dict[str, Any]: - if not isinstance(input, dict): + if not isinstance(value, dict): msg = "The input to RunnablePassthrough.assign() must be a dict." raise ValueError(msg) # noqa: TRY004 return { - **input, + **value, **await self.mapper.ainvoke( - input, + value, patch_config(config, callbacks=run_manager.get_child()), **kwargs, ), @@ -541,7 +541,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): def _transform( self, - input: Iterator[dict[str, Any]], + values: Iterator[dict[str, Any]], run_manager: CallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, @@ -549,7 +549,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): # collect mapper keys mapper_keys = set(self.mapper.steps__.keys()) # create two streams, one for the map and one for the passthrough - for_passthrough, for_map = safetee(input, 2, lock=threading.Lock()) + for_passthrough, for_map = safetee(values, 2, lock=threading.Lock()) # create map output stream map_output = self.mapper.transform( @@ -598,7 +598,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def _atransform( self, - input: AsyncIterator[dict[str, Any]], + values: AsyncIterator[dict[str, Any]], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, @@ -606,7 +606,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): # collect mapper keys mapper_keys = set(self.mapper.steps__.keys()) # create two streams, one for the map and one for the passthrough - for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock()) + for_passthrough, for_map = atee(values, 2, lock=asyncio.Lock()) # create map output stream map_output = self.mapper.atransform( for_map, @@ -731,23 +731,23 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): ) return super().get_name(suffix, name=name) - def _pick(self, input: dict[str, Any]) -> Any: - if not isinstance(input, dict): + def _pick(self, value: dict[str, Any]) -> Any: + if not isinstance(value, dict): msg = "The input to RunnablePassthrough.assign() must be a dict." 
raise ValueError(msg) # noqa: TRY004 if isinstance(self.keys, str): - return input.get(self.keys) - picked = {k: input.get(k) for k in self.keys if k in input} + return value.get(self.keys) + picked = {k: value.get(k) for k in self.keys if k in value} if picked: return AddableDict(picked) return None def _invoke( self, - input: dict[str, Any], + value: dict[str, Any], ) -> dict[str, Any]: - return self._pick(input) + return self._pick(value) @override def invoke( @@ -760,9 +760,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def _ainvoke( self, - input: dict[str, Any], + value: dict[str, Any], ) -> dict[str, Any]: - return self._pick(input) + return self._pick(value) @override async def ainvoke( @@ -775,9 +775,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): def _transform( self, - input: Iterator[dict[str, Any]], + chunks: Iterator[dict[str, Any]], ) -> Iterator[dict[str, Any]]: - for chunk in input: + for chunk in chunks: picked = self._pick(chunk) if picked is not None: yield picked @@ -795,9 +795,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def _atransform( self, - input: AsyncIterator[dict[str, Any]], + chunks: AsyncIterator[dict[str, Any]], ) -> AsyncIterator[dict[str, Any]]: - async for chunk in input: + async for chunk in chunks: picked = self._pick(chunk) if picked is not None: yield picked diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index 3ceb396ff28..9a72f749d72 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -178,7 +178,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): def _invoke( self, - input: Input, + input_: Input, run_manager: "CallbackManagerForChainRun", config: RunnableConfig, **kwargs: Any, @@ -186,7 +186,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): for attempt in self._sync_retrying(reraise=True): with attempt: result = super().invoke( - input, + input_, self._patch_config(config, run_manager, attempt.retry_state), **kwargs, ) @@ -202,7 +202,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): async def _ainvoke( self, - input: Input, + input_: Input, run_manager: "AsyncCallbackManagerForChainRun", config: RunnableConfig, **kwargs: Any, @@ -210,7 +210,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): async for attempt in self._async_retrying(reraise=True): with attempt: result = await super().ainvoke( - input, + input_, self._patch_config(config, run_manager, attempt.retry_state), **kwargs, ) diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index 192c6239cef..931bac392b6 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -148,22 +148,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): if not inputs: return [] - keys = [input["key"] for input in inputs] - actual_inputs = [input["input"] for input in inputs] + keys = [input_["key"] for input_ in inputs] + actual_inputs = [input_["input"] for input_ in inputs] if any(key not in self.runnables for key in keys): msg = "One or more keys do not have a corresponding runnable" raise ValueError(msg) def invoke( - runnable: Runnable, input: Input, config: RunnableConfig + runnable: Runnable, input_: Input, config: RunnableConfig ) -> Union[Output, Exception]: if return_exceptions: try: - return runnable.invoke(input, 
diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py
index 192c6239cef..931bac392b6 100644
--- a/libs/core/langchain_core/runnables/router.py
+++ b/libs/core/langchain_core/runnables/router.py
@@ -148,22 +148,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
         if not inputs:
             return []

-        keys = [input["key"] for input in inputs]
-        actual_inputs = [input["input"] for input in inputs]
+        keys = [input_["key"] for input_ in inputs]
+        actual_inputs = [input_["input"] for input_ in inputs]
         if any(key not in self.runnables for key in keys):
             msg = "One or more keys do not have a corresponding runnable"
             raise ValueError(msg)

         def invoke(
-            runnable: Runnable, input: Input, config: RunnableConfig
+            runnable: Runnable, input_: Input, config: RunnableConfig
         ) -> Union[Output, Exception]:
             if return_exceptions:
                 try:
-                    return runnable.invoke(input, config, **kwargs)
+                    return runnable.invoke(input_, config, **kwargs)
                 except Exception as e:
                     return e
             else:
-                return runnable.invoke(input, config, **kwargs)
+                return runnable.invoke(input_, config, **kwargs)

         runnables = [self.runnables[key] for key in keys]
         configs = get_config_list(config, len(inputs))
@@ -185,22 +185,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
         if not inputs:
             return []

-        keys = [input["key"] for input in inputs]
-        actual_inputs = [input["input"] for input in inputs]
+        keys = [input_["key"] for input_ in inputs]
+        actual_inputs = [input_["input"] for input_ in inputs]
         if any(key not in self.runnables for key in keys):
             msg = "One or more keys do not have a corresponding runnable"
             raise ValueError(msg)

         async def ainvoke(
-            runnable: Runnable, input: Input, config: RunnableConfig
+            runnable: Runnable, input_: Input, config: RunnableConfig
         ) -> Union[Output, Exception]:
             if return_exceptions:
                 try:
-                    return await runnable.ainvoke(input, config, **kwargs)
+                    return await runnable.ainvoke(input_, config, **kwargs)
                 except Exception as e:
                     return e
             else:
-                return await runnable.ainvoke(input, config, **kwargs)
+                return await runnable.ainvoke(input_, config, **kwargs)

         runnables = [self.runnables[key] for key in keys]
         configs = get_config_list(config, len(inputs))
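RouterRunnable, whose batch bodies are renamed above, dispatches on the "key" field of each input; roughly:

    from langchain_core.runnables import RouterRunnable, RunnableLambda

    router = RouterRunnable(
        runnables={
            "double": RunnableLambda(lambda x: x * 2),
            "shout": RunnableLambda(lambda x: str(x).upper()),
        }
    )
    # Each element carries the routing key plus the payload for that runnable.
    assert router.invoke({"key": "double", "input": 4}) == 8
    assert router.batch(
        [{"key": "double", "input": 2}, {"key": "shout", "input": "hi"}]
    ) == [4, "HI"]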
diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py
index 98fe236dcbf..7e5b68d01d0 100644
--- a/libs/core/langchain_core/runnables/utils.py
+++ b/libs/core/langchain_core/runnables/utils.py
@@ -75,7 +75,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
     return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))

-def accepts_run_manager(callable: Callable[..., Any]) -> bool:
+def accepts_run_manager(callable: Callable[..., Any]) -> bool:  # noqa: A002
     """Check if a callable accepts a run_manager argument.

     Args:
@@ -90,7 +90,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
         return False

-def accepts_config(callable: Callable[..., Any]) -> bool:
+def accepts_config(callable: Callable[..., Any]) -> bool:  # noqa: A002
     """Check if a callable accepts a config argument.

     Args:
@@ -105,7 +105,7 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
         return False

-def accepts_context(callable: Callable[..., Any]) -> bool:
+def accepts_context(callable: Callable[..., Any]) -> bool:  # noqa: A002
     """Check if a callable accepts a context argument.

     Args:
@@ -691,7 +691,7 @@ def get_unique_config_specs(
         sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
     )
     unique: list[ConfigurableFieldSpec] = []
-    for id, dupes in grouped:
+    for spec_id, dupes in grouped:
         first = next(dupes)
         others = list(dupes)
         if len(others) == 0 or all(o == first for o in others):
@@ -699,7 +699,7 @@ def get_unique_config_specs(
         else:
             msg = (
                 "RunnableSequence contains conflicting config specs"
-                f"for {id}: {[first, *others]}"
+                f"for {spec_id}: {[first, *others]}"
             )
             raise ValueError(msg)
     return unique
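The three accepts_* helpers above keep the parameter name `callable` (suppressed with noqa: A002) because it is part of their public signature. Their shared logic is a plain signature check, roughly equivalent to this standalone sketch:

    import inspect
    from typing import Any, Callable

    def accepts_kwarg(func: Callable[..., Any], name: str) -> bool:
        # True when the callable declares a parameter with the given name.
        try:
            return name in inspect.signature(func).parameters
        except ValueError:  # some builtins expose no signature
            return False

    def f(x: int, *, config: dict) -> int:
        return x

    assert accepts_kwarg(f, "config")
    assert not accepts_kwarg(f, "run_manager")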
diff --git a/libs/core/langchain_core/structured_query.py b/libs/core/langchain_core/structured_query.py
index 746174e5d31..5a1a1eb9667 100644
--- a/libs/core/langchain_core/structured_query.py
+++ b/libs/core/langchain_core/structured_query.py
@@ -184,7 +184,7 @@ class StructuredQuery(Expr):
     def __init__(
         self,
         query: str,
-        filter: Optional[FilterDirective],
+        filter: Optional[FilterDirective],  # noqa: A002
         limit: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py
index 554181d1b72..f30ef85fdb5 100644
--- a/libs/core/langchain_core/tools/base.py
+++ b/libs/core/langchain_core/tools/base.py
@@ -943,17 +943,17 @@ def _handle_tool_error(

 def _prep_run_args(
-    input: Union[str, dict, ToolCall],
+    value: Union[str, dict, ToolCall],
     config: Optional[RunnableConfig],
     **kwargs: Any,
 ) -> tuple[Union[str, dict], dict]:
     config = ensure_config(config)
-    if _is_tool_call(input):
-        tool_call_id: Optional[str] = cast("ToolCall", input)["id"]
-        tool_input: Union[str, dict] = cast("ToolCall", input)["args"].copy()
+    if _is_tool_call(value):
+        tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
+        tool_input: Union[str, dict] = cast("ToolCall", value)["args"].copy()
     else:
         tool_call_id = None
-        tool_input = cast("Union[str, dict]", input)
+        tool_input = cast("Union[str, dict]", value)
     return (
         tool_input,
         dict(
diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py
index 1b908df5027..dc81c9d01cd 100644
--- a/libs/core/langchain_core/tracers/event_stream.py
+++ b/libs/core/langchain_core/tracers/event_stream.py
@@ -740,7 +740,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand

 async def _astream_events_implementation_v1(
     runnable: Runnable[Input, Output],
-    input: Any,
+    value: Any,
     config: Optional[RunnableConfig] = None,
     *,
     include_names: Optional[Sequence[str]] = None,
@@ -789,7 +789,7 @@ async def _astream_events_implementation_v1(

     async for log in _astream_log_implementation(
         runnable,
-        input,
+        value,
         config=config,
         stream=stream,
         diff=True,
@@ -810,7 +810,7 @@ async def _astream_events_implementation_v1(
                     tags=root_tags,
                     metadata=root_metadata,
                     data={
-                        "input": input,
+                        "input": value,
                     },
                     parent_ids=[],  # Not supported in v1
                 )
@@ -924,7 +924,7 @@ async def _astream_events_implementation_v1(

 async def _astream_events_implementation_v2(
     runnable: Runnable[Input, Output],
-    input: Any,
+    value: Any,
     config: Optional[RunnableConfig] = None,
     *,
     include_names: Optional[Sequence[str]] = None,
@@ -972,7 +972,7 @@ async def _astream_events_implementation_v2(
     async def consume_astream() -> None:
         try:
             # if astream also calls tap_output_aiter this will be a no-op
-            async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
+            async with aclosing(runnable.astream(value, config, **kwargs)) as stream:
                 async for _ in event_streamer.tap_output_aiter(run_id, stream):
                     # All the content will be picked up
                     pass
@@ -993,7 +993,7 @@ async def _astream_events_implementation_v2(
                 # chain are not available until the entire input is consumed.
                 # As a temporary solution, we'll modify the input to be the input
                 # that was passed into the chain.
-                event["data"]["input"] = input
+                event["data"]["input"] = value
                 first_event_run_id = event["run_id"]
                 yield event
                 continue
diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py
index 6f6aad7e9d7..259f266330c 100644
--- a/libs/core/langchain_core/tracers/log_stream.py
+++ b/libs/core/langchain_core/tracers/log_stream.py
@@ -580,7 +580,7 @@ def _get_standardized_outputs(
 @overload
 def _astream_log_implementation(
     runnable: Runnable[Input, Output],
-    input: Any,
+    value: Any,
     config: Optional[RunnableConfig] = None,
     *,
     stream: LogStreamCallbackHandler,
@@ -593,7 +593,7 @@ def _astream_log_implementation(
 @overload
 def _astream_log_implementation(
     runnable: Runnable[Input, Output],
-    input: Any,
+    value: Any,
     config: Optional[RunnableConfig] = None,
     *,
     stream: LogStreamCallbackHandler,
@@ -605,7 +605,7 @@ def _astream_log_implementation(

 async def _astream_log_implementation(
     runnable: Runnable[Input, Output],
-    input: Any,
+    value: Any,
     config: Optional[RunnableConfig] = None,
     *,
     stream: LogStreamCallbackHandler,
@@ -651,7 +651,7 @@ async def _astream_log_implementation(
             prev_final_output: Optional[Output] = None
             final_output: Optional[Output] = None

-            async for chunk in runnable.astream(input, config, **kwargs):
+            async for chunk in runnable.astream(value, config, **kwargs):
                 prev_final_output = final_output
                 if final_output is None:
                     final_output = chunk
diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py
index bfc70df1715..c664ec03e3d 100644
--- a/libs/core/langchain_core/utils/function_calling.py
+++ b/libs/core/langchain_core/utils/function_calling.py
@@ -596,7 +596,7 @@ def convert_to_json_schema(
 @beta()
 def tool_example_to_messages(
-    input: str,
+    input: str,  # noqa: A002
     tool_calls: list[BaseModel],
     tool_outputs: Optional[list[str]] = None,
     *,
diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py
index edf62effb43..383ac1d0a79 100644
--- a/libs/core/langchain_core/vectorstores/in_memory.py
+++ b/libs/core/langchain_core/vectorstores/in_memory.py
@@ -363,7 +363,7 @@ class InMemoryVectorStore(VectorStore):
         self,
         embedding: list[float],
         k: int = 4,
-        filter: Optional[Callable[[Document], bool]] = None,
+        filter: Optional[Callable[[Document], bool]] = None,  # noqa: A002
     ) -> list[tuple[Document, float, list[float]]]:
         # get all docs with fixed order in list
         docs = list(self.store.values())
@@ -402,7 +402,7 @@ class InMemoryVectorStore(VectorStore):
         self,
         embedding: list[float],
         k: int = 4,
-        filter: Optional[Callable[[Document], bool]] = None,
+        filter: Optional[Callable[[Document], bool]] = None,  # noqa: A002
         **_kwargs: Any,
     ) -> list[tuple[Document, float]]:
         """Search for the most similar documents to the given embedding.
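The pyproject.toml hunk below removes the blanket ignore for Ruff's flake8-builtins ("A") rules, which is what motivates every rename in this patch. In miniature, the pattern those rules flag:

    # Shadowing a builtin disables it for the rest of the scope (Ruff A002).
    def first_match(filter: str, items: list[str]) -> list[str]:
        # The builtin filter() is unreachable here; `filter` is now a str.
        return [item for item in items if filter in item]

    # The fixes used throughout this patch: rename to something descriptive,
    # add a trailing underscore, or keep a published parameter name and
    # silence the rule locally with `# noqa: A002`.
    def first_match_renamed(substring: str, items: list[str]) -> list[str]:
        return [item for item in items if substring in item]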
diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml
index d31e641d5fe..8f243bb1178 100644
--- a/libs/core/pyproject.toml
+++ b/libs/core/pyproject.toml
@@ -100,7 +100,6 @@ ignore = [
   "UP007", # Doesn't play well with Pydantic in Python 3.9

   # TODO rules
-  "A",
   "ANN401",
   "BLE",
   "ERA",
diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py
index 218f94dc2e8..f7201ca48a9 100644
--- a/libs/core/tests/unit_tests/messages/test_utils.py
+++ b/libs/core/tests/unit_tests/messages/test_utils.py
@@ -869,9 +869,9 @@ def create_image_data() -> str:
     return "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q=="  # noqa: E501

-def create_base64_image(format: str = "jpeg") -> str:
+def create_base64_image(image_format: str = "jpeg") -> str:
     data = create_image_data()
-    return f"data:image/{format};base64,{data}"
+    return f"data:image/{image_format};base64,{data}"

 def test_convert_to_openai_messages_string() -> None:
diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py
index 992e5c48c55..dbb905e7f6b 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py
@@ -639,7 +639,7 @@ def test_parse_with_different_pydantic_1_proper() -> None:
 def test_max_tokens_error(caplog: Any) -> None:
     parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
-    input = AIMessage(
+    message = AIMessage(
         content="",
         tool_calls=[
             {
@@ -651,7 +651,7 @@ def test_max_tokens_error(caplog: Any) -> None:
         response_metadata={"stop_reason": "max_tokens"},
     )
     with pytest.raises(ValidationError):
-        _ = parser.invoke(input)
+        _ = parser.invoke(message)
     assert any(
         "`max_tokens` stop reason" in msg and record.levelname == "ERROR"
         for record, msg in zip(caplog.records, caplog.messages)
diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py
index 2a98a1e95ce..76d75314775 100644
--- a/libs/core/tests/unit_tests/prompts/test_loading.py
+++ b/libs/core/tests/unit_tests/prompts/test_loading.py
@@ -15,11 +15,11 @@ EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute()

 @contextmanager
-def change_directory(dir: Path) -> Iterator:
+def change_directory(dir_path: Path) -> Iterator:
     """Change the working directory to the right folder."""
     origin = Path().absolute()
     try:
-        os.chdir(dir)
+        os.chdir(dir_path)
         yield
     finally:
         os.chdir(origin)
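For orientation, the change_directory helper edited above is used like any other contextmanager; a usage sketch with an illustrative path:

    from pathlib import Path

    with change_directory(Path("/tmp")):
        print(Path().absolute())  # relative paths now resolve against /tmp
    # The original working directory is restored on exit, even on error.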
diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py
index 387644b8b07..45e1cb1b7f9 100644
--- a/libs/core/tests/unit_tests/runnables/test_graph.py
+++ b/libs/core/tests/unit_tests/runnables/test_graph.py
@@ -220,8 +220,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
     str_parser = StrOutputParser()
     xml_parser = XMLOutputParser()

-    def conditional_str_parser(input: str) -> Runnable:
-        if input == "a":
+    def conditional_str_parser(value: str) -> Runnable:
+        if value == "a":
             return str_parser
         return xml_parser
diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py
index 885a52b428f..1807d8b3d5c 100644
--- a/libs/core/tests/unit_tests/runnables/test_history.py
+++ b/libs/core/tests/unit_tests/runnables/test_history.py
@@ -111,9 +111,9 @@ async def test_input_messages_async() -> None:

 def test_input_dict() -> None:
     runnable = RunnableLambda(
-        lambda input: "you said: "
+        lambda params: "you said: "
         + "\n".join(
-            str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
+            str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
         )
     )
     get_session_history = _get_get_session_history()
@@ -131,9 +131,9 @@ def test_input_dict() -> None:

 async def test_input_dict_async() -> None:
     runnable = RunnableLambda(
-        lambda input: "you said: "
+        lambda params: "you said: "
         + "\n".join(
-            str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
+            str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
         )
     )
     get_session_history = _get_get_session_history()
@@ -153,10 +153,10 @@ async def test_input_dict_async() -> None:

 def test_input_dict_with_history_key() -> None:
     runnable = RunnableLambda(
-        lambda input: "you said: "
+        lambda params: "you said: "
         + "\n".join(
-            [str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
-            + [input["input"]]
+            [str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
+            + [params["input"]]
         )
     )
     get_session_history = _get_get_session_history()
@@ -175,10 +175,10 @@ def test_input_dict_with_history_key() -> None:

 async def test_input_dict_with_history_key_async() -> None:
     runnable = RunnableLambda(
-        lambda input: "you said: "
+        lambda params: "you said: "
         + "\n".join(
-            [str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
-            + [input["input"]]
+            [str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
+            + [params["input"]]
         )
     )
     get_session_history = _get_get_session_history()
@@ -197,15 +197,15 @@ async def test_input_dict_with_history_key_async() -> None:

 def test_output_message() -> None:
     runnable = RunnableLambda(
-        lambda input: AIMessage(
+        lambda params: AIMessage(
             content="you said: "
             + "\n".join(
                 [
                     str(m.content)
-                    for m in input["history"]
+                    for m in params["history"]
                     if isinstance(m, HumanMessage)
                 ]
-                + [input["input"]]
+                + [params["input"]]
             )
         )
     )
@@ -225,15 +225,15 @@ def test_output_message() -> None:

 async def test_output_message_async() -> None:
     runnable = RunnableLambda(
-        lambda input: AIMessage(
+        lambda params: AIMessage(
             content="you said: "
             + "\n".join(
                 [
                     str(m.content)
-                    for m in input["history"]
+                    for m in params["history"]
                     if isinstance(m, HumanMessage)
                 ]
-                + [input["input"]]
+                + [params["input"]]
             )
         )
     )
@@ -302,16 +302,16 @@ async def test_input_messages_output_message_async() -> None:

 def test_output_messages() -> None:
     runnable = RunnableLambda(
-        lambda input: [
+        lambda params: [
             AIMessage(
                 content="you said: "
                 + "\n".join(
                     [
                         str(m.content)
-                        for m in input["history"]
+                        for m in params["history"]
                         if isinstance(m, HumanMessage)
                     ]
-                    + [input["input"]]
+                    + [params["input"]]
                 )
             )
         ]
@@ -332,16 +332,16 @@ def test_output_messages() -> None:

 async def test_output_messages_async() -> None:
     runnable = RunnableLambda(
-        lambda input: [
+        lambda params: [
             AIMessage(
                 content="you said: "
                 + "\n".join(
                     [
                         str(m.content)
-                        for m in input["history"]
+                        for m in params["history"]
                         if isinstance(m, HumanMessage)
                     ]
-                    + [input["input"]]
+                    + [params["input"]]
                 )
             )
         ]
@@ -362,17 +362,17 @@ async def test_output_messages_async() -> None:

 def test_output_dict() -> None:
     runnable = RunnableLambda(
-        lambda input: {
+        lambda params: {
             "output": [
                 AIMessage(
                     content="you said: "
                     + "\n".join(
                         [
                             str(m.content)
-                            for m in input["history"]
+                            for m in params["history"]
                             if isinstance(m, HumanMessage)
                         ]
-                        + [input["input"]]
+                        + [params["input"]]
                     )
                 )
             ]
@@ -395,17 +395,17 @@ def test_output_dict() -> None:

 async def test_output_dict_async() -> None:
     runnable = RunnableLambda(
-        lambda input: {
+        lambda params: {
             "output": [
                 AIMessage(
                     content="you said: "
                     + "\n".join(
                         [
                             str(m.content)
-                            for m in input["history"]
+                            for m in params["history"]
                             if isinstance(m, HumanMessage)
                         ]
-                        + [input["input"]]
+                        + [params["input"]]
                     )
                 )
             ]
@@ -431,17 +431,17 @@ def test_get_input_schema_input_dict() -> None:
         input: Union[str, BaseMessage, Sequence[BaseMessage]]

     runnable = RunnableLambda(
-        lambda input: {
+        lambda params: {
             "output": [
                 AIMessage(
                     content="you said: "
                     + "\n".join(
                         [
                             str(m.content)
-                            for m in input["history"]
+                            for m in params["history"]
                             if isinstance(m, HumanMessage)
                         ]
-                        + [input["input"]]
+                        + [params["input"]]
                     )
                 )
             ]
@@ -463,17 +463,17 @@ def test_get_input_schema_input_dict() -> None:

 def test_get_output_schema() -> None:
     """Test get output schema."""
     runnable = RunnableLambda(
-        lambda input: {
+        lambda params: {
             "output": [
                 AIMessage(
                     content="you said: "
                     + "\n".join(
                         [
                             str(m.content)
-                            for m in input["history"]
+                            for m in params["history"]
                             if isinstance(m, HumanMessage)
                         ]
-                        + [input["input"]]
+                        + [params["input"]]
                     )
                 )
             ]
@@ -531,8 +531,8 @@ def test_get_input_schema_input_messages() -> None:
 def test_using_custom_config_specs() -> None:
     """Test that we can configure which keys should be passed to the session factory."""

-    def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]:
-        messages = input["messages"]
+    def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
+        messages = params["messages"]
         return [
             AIMessage(
                 content="you said: "
@@ -644,8 +644,8 @@ def test_using_custom_config_specs() -> None:
 async def test_using_custom_config_specs_async() -> None:
     """Test that we can configure which keys should be passed to the session factory."""

-    def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]:
-        messages = input["messages"]
+    def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
+        messages = params["messages"]
         return [
             AIMessage(
                 content="you said: "
@@ -757,12 +757,12 @@ async def test_using_custom_config_specs_async() -> None:
 def test_ignore_session_id() -> None:
     """Test without config."""

-    def _fake_llm(input: list[BaseMessage]) -> list[BaseMessage]:
+    def _fake_llm(messages: list[BaseMessage]) -> list[BaseMessage]:
         return [
             AIMessage(
                 content="you said: "
                 + "\n".join(
-                    str(m.content) for m in input if isinstance(m, HumanMessage)
+                    str(m.content) for m in messages if isinstance(m, HumanMessage)
                 )
             )
         ]
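The test_history.py hunks above all share one setup: a RunnableLambda wrapped in RunnableWithMessageHistory. A compressed sketch of that arrangement (the in-memory store is illustrative):

    from langchain_core.chat_history import InMemoryChatMessageHistory
    from langchain_core.messages import AIMessage, HumanMessage
    from langchain_core.runnables import RunnableLambda
    from langchain_core.runnables.history import RunnableWithMessageHistory

    store: dict[str, InMemoryChatMessageHistory] = {}

    def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
        # One chat history object per session id.
        return store.setdefault(session_id, InMemoryChatMessageHistory())

    echo = RunnableLambda(
        lambda messages: AIMessage(
            content="you said: "
            + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
        )
    )
    with_history = RunnableWithMessageHistory(echo, get_session_history)
    reply = with_history.invoke(
        [HumanMessage(content="hello")],
        config={"configurable": {"session_id": "1"}},
    )
    assert reply.content == "you said: hello"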
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py
index b8d4e46c1de..a4b616964ed 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable.py
@@ -564,8 +564,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
         "required": ["bye", "hello"],
     }

-    def get_value(input):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
-        return input["variable_name"]
+    def get_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
+        return value["variable_name"]

     assert RunnableLambda(get_value).get_input_jsonschema() == {
         "title": "get_value_input",
@@ -574,8 +574,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
         "required": ["variable_name"],
     }

-    async def aget_value(input):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
-        return (input["variable_name"], input.get("another"))
+    async def aget_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
+        return (value["variable_name"], value.get("another"))

     assert RunnableLambda(aget_value).get_input_jsonschema() == {
         "title": "aget_value_input",
@@ -587,11 +587,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
         "required": ["another", "variable_name"],
     }

-    async def aget_values(input):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
+    async def aget_values(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
         return {
-            "hello": input["variable_name"],
-            "bye": input["variable_name"],
-            "byebye": input["yo"],
+            "hello": value["variable_name"],
+            "bye": value["variable_name"],
+            "byebye": value["yo"],
         }

     assert RunnableLambda(aget_values).get_input_jsonschema() == {
@@ -613,11 +613,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
         bye: str
         byebye: int

-    async def aget_values_typed(input: InputType) -> OutputType:
+    async def aget_values_typed(value: InputType) -> OutputType:
         return {
-            "hello": input["variable_name"],
-            "bye": input["variable_name"],
-            "byebye": input["yo"],
+            "hello": value["variable_name"],
+            "bye": value["variable_name"],
+            "byebye": value["yo"],
         }

     assert _normalize_schema(
@@ -2592,8 +2592,8 @@ async def test_prompt_with_llm_and_async_lambda(
     )
     llm = FakeListLLM(responses=["foo", "bar"])

-    async def passthrough(input: Any) -> Any:
-        return input
+    async def passthrough(value: Any) -> Any:
+        return value

     chain = prompt | llm | passthrough
@@ -2946,12 +2946,12 @@ def test_higher_order_lambda_runnable(
         input={"question": lambda x: x["question"]},
     )

-    def router(input: dict[str, Any]) -> Runnable:
-        if input["key"] == "math":
+    def router(params: dict[str, Any]) -> Runnable:
+        if params["key"] == "math":
             return itemgetter("input") | math_chain
-        if input["key"] == "english":
+        if params["key"] == "english":
             return itemgetter("input") | english_chain
-        msg = f"Unknown key: {input['key']}"
+        msg = f"Unknown key: {params['key']}"
         raise ValueError(msg)

     chain: Runnable = input_map | router
@@ -3002,12 +3002,12 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
         input={"question": lambda x: x["question"]},
    )

-    def router(input: dict[str, Any]) -> Runnable:
-        if input["key"] == "math":
+    def router(value: dict[str, Any]) -> Runnable:
+        if value["key"] == "math":
             return itemgetter("input") | math_chain
-        if input["key"] == "english":
+        if value["key"] == "english":
             return itemgetter("input") | english_chain
-        msg = f"Unknown key: {input['key']}"
+        msg = f"Unknown key: {value['key']}"
         raise ValueError(msg)

     chain: Runnable = input_map | router
== "math": + async def arouter(params: dict[str, Any]) -> Runnable: + if params["key"] == "math": return itemgetter("input") | math_chain - if input["key"] == "english": + if params["key"] == "english": return itemgetter("input") | english_chain - msg = f"Unknown key: {input['key']}" + msg = f"Unknown key: {params['key']}" raise ValueError(msg) achain: Runnable = input_map | arouter @@ -4125,6 +4125,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None: def __init__(self, fail_starts_with: str) -> None: self.fail_starts_with = fail_starts_with + @override def invoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: @@ -4135,15 +4136,15 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None: inputs: list[str], ) -> list: outputs: list[Any] = [] - for input in inputs: - if input.startswith(self.fail_starts_with): + for value in inputs: + if value.startswith(self.fail_starts_with): outputs.append( ValueError( - f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}" + f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}" ) ) else: - outputs.append(input + "a") + outputs.append(value + "a") return outputs def batch( @@ -4264,6 +4265,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: def __init__(self, fail_starts_with: str) -> None: self.fail_starts_with = fail_starts_with + @override def invoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: @@ -4274,15 +4276,15 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: inputs: list[str], ) -> list: outputs: list[Any] = [] - for input in inputs: - if input.startswith(self.fail_starts_with): + for value in inputs: + if value.startswith(self.fail_starts_with): outputs.append( ValueError( - f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}" + f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}" ) ) else: - outputs.append(input + "a") + outputs.append(value + "a") return outputs async def abatch( @@ -5006,10 +5008,10 @@ def test_runnable_iter_context_config() -> None: fake = RunnableLambda(len) @chain - def gen(input: str) -> Iterator[int]: - yield fake.invoke(input) - yield fake.invoke(input * 2) - yield fake.invoke(input * 3) + def gen(value: str) -> Iterator[int]: + yield fake.invoke(value) + yield fake.invoke(value * 2) + yield fake.invoke(value * 3) assert gen.get_input_jsonschema() == { "title": "gen_input", @@ -5064,10 +5066,10 @@ async def test_runnable_iter_context_config_async() -> None: fake = RunnableLambda(len) @chain - async def agen(input: str) -> AsyncIterator[int]: - yield await fake.ainvoke(input) - yield await fake.ainvoke(input * 2) - yield await fake.ainvoke(input * 3) + async def agen(value: str) -> AsyncIterator[int]: + yield await fake.ainvoke(value) + yield await fake.ainvoke(value * 2) + yield await fake.ainvoke(value * 3) assert agen.get_input_jsonschema() == { "title": "agen_input", @@ -5130,10 +5132,10 @@ def test_runnable_lambda_context_config() -> None: fake = RunnableLambda(len) @chain - def fun(input: str) -> int: - output = fake.invoke(input) - output += fake.invoke(input * 2) - output += fake.invoke(input * 3) + def fun(value: str) -> int: + output = fake.invoke(value) + output += fake.invoke(value * 2) + output += fake.invoke(value * 3) return output assert fun.get_input_jsonschema() == {"title": "fun_input", "type": "string"} @@ -5186,10 +5188,10 @@ async def 
@@ -5186,10 +5188,10 @@ async def test_runnable_lambda_context_config_async() -> None:
     fake = RunnableLambda(len)

     @chain
-    async def afun(input: str) -> int:
-        output = await fake.ainvoke(input)
-        output += await fake.ainvoke(input * 2)
-        output += await fake.ainvoke(input * 3)
+    async def afun(value: str) -> int:
+        output = await fake.ainvoke(value)
+        output += await fake.ainvoke(value * 2)
+        output += await fake.ainvoke(value * 3)
         return output

     assert afun.get_input_jsonschema() == {"title": "afun_input", "type": "string"}
@@ -5242,12 +5244,12 @@ async def test_runnable_gen_transform() -> None:
         for i in range(length):
             yield i

-    def plus_one(input: Iterator[int]) -> Iterator[int]:
-        for i in input:
+    def plus_one(ints: Iterator[int]) -> Iterator[int]:
+        for i in ints:
             yield i + 1

-    async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]:
-        async for i in input:
+    async def aplus_one(ints: AsyncIterator[int]) -> AsyncIterator[int]:
+        async for i in ints:
             yield i + 1

     chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
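The remaining hunks touch the astream_events test suites. For context, the surface under test looks roughly like this (output abbreviated):

    import asyncio

    from langchain_core.runnables import RunnableLambda

    async def main() -> None:
        doubler = RunnableLambda(lambda x: x * 2)
        async for event in doubler.astream_events(3, version="v2"):
            # Events such as on_chain_start / on_chain_stream / on_chain_end
            # carry the run's input and output under event["data"].
            print(event["event"], event["data"])

    asyncio.run(main())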
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py
index a046070e7eb..327f96224b3 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py
@@ -543,10 +543,10 @@ async def test_astream_events_from_model() -> None:
     )

     @RunnableLambda
-    def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
+    def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
-            return model.invoke(input)
-        return model.invoke(input, config)
+            return model.invoke(value)
+        return model.invoke(value, config)

     events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
     _assert_events_equal_allow_superset_metadata(
@@ -667,10 +667,10 @@ async def test_astream_events_from_model() -> None:
     )

     @RunnableLambda
-    async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
+    async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
-            return await model.ainvoke(input)
-        return await model.ainvoke(input, config)
+            return await model.ainvoke(value)
+        return await model.ainvoke(value, config)

     events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
     _assert_events_equal_allow_superset_metadata(
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
index 000bad74f8d..be384668edb 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
@@ -613,10 +613,10 @@ async def test_astream_with_model_in_chain() -> None:
     )

     @RunnableLambda
-    def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
+    def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
-            return model.invoke(input)
-        return model.invoke(input, config)
+            return model.invoke(value)
+        return model.invoke(value, config)

     events = await _collect_events(i_dont_stream.astream_events("hello", version="v2"))
     _assert_events_equal_allow_superset_metadata(
@@ -721,10 +721,10 @@ async def test_astream_with_model_in_chain() -> None:
     )

     @RunnableLambda
-    async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
+    async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
-            return await model.ainvoke(input)
-        return await model.ainvoke(input, config)
+            return await model.ainvoke(value)
+        return await model.ainvoke(value, config)

     events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2"))
     _assert_events_equal_allow_superset_metadata(
@@ -2079,6 +2079,7 @@ class StreamingRunnable(Runnable[Input, Output]):
             msg = "Server side error"
             raise ValueError(msg)

+    @override
     def stream(
         self,
         input: Input,
@@ -2413,14 +2414,14 @@ async def test_break_astream_events() -> None:
         def __init__(self) -> None:
             self.reset()

-        async def __call__(self, input: Any) -> Any:
+        async def __call__(self, value: Any) -> Any:
             self.started = True
             try:
                 await asyncio.sleep(0.5)
             except asyncio.CancelledError:
                 self.cancelled = True
                 raise
-            return input
+            return value

         def reset(self) -> None:
             self.started = False
@@ -2433,11 +2434,11 @@ async def test_break_astream_events() -> None:
     outer_cancelled = False

     @chain
-    async def sequence(input: Any) -> Any:
+    async def sequence(value: Any) -> Any:
         try:
-            yield await alittlewhile(input)
-            yield await awhile(input)
-            yield await anotherwhile(input)
+            yield await alittlewhile(value)
+            yield await awhile(value)
+            yield await anotherwhile(value)
         except asyncio.CancelledError:
             nonlocal outer_cancelled
             outer_cancelled = True
@@ -2478,14 +2479,14 @@ async def test_cancel_astream_events() -> None:
         def __init__(self) -> None:
             self.reset()

-        async def __call__(self, input: Any) -> Any:
+        async def __call__(self, value: Any) -> Any:
             self.started = True
             try:
                 await asyncio.sleep(0.5)
             except asyncio.CancelledError:
                 self.cancelled = True
                 raise
-            return input
+            return value

         def reset(self) -> None:
             self.started = False
@@ -2498,11 +2499,11 @@ async def test_cancel_astream_events() -> None:
     outer_cancelled = False

     @chain
-    async def sequence(input: Any) -> Any:
+    async def sequence(value: Any) -> Any:
         try:
-            yield await alittlewhile(input)
-            yield await awhile(input)
-            yield await anotherwhile(input)
+            yield await alittlewhile(value)
+            yield await awhile(value)
+            yield await anotherwhile(value)
         except asyncio.CancelledError:
             nonlocal outer_cancelled
             outer_cancelled = True
diff --git a/libs/core/tests/unit_tests/runnables/test_utils.py b/libs/core/tests/unit_tests/runnables/test_utils.py
index d58a535c272..af2b603ec35 100644
--- a/libs/core/tests/unit_tests/runnables/test_utils.py
+++ b/libs/core/tests/unit_tests/runnables/test_utils.py
@@ -47,23 +47,23 @@ global_agent = RunnableLambda(lambda x: x * 3)
 def test_nonlocals() -> None:
     agent = RunnableLambda(lambda x: x * 2)

-    def my_func(input: str, agent: dict[str, str]) -> str:
-        return agent.get("agent_name", input)
+    def my_func(value: str, agent: dict[str, str]) -> str:
+        return agent.get("agent_name", value)

-    def my_func2(input: str) -> str:
-        return agent.get("agent_name", input)  # type: ignore[attr-defined]
+    def my_func2(value: str) -> str:
+        return agent.get("agent_name", value)  # type: ignore[attr-defined]

-    def my_func3(input: str) -> str:
-        return agent.invoke(input)
+    def my_func3(value: str) -> str:
+        return agent.invoke(value)

-    def my_func4(input: str) -> str:
-        return global_agent.invoke(input)
+    def my_func4(value: str) -> str:
+        return global_agent.invoke(value)

     def my_func5() -> tuple[Callable[[str], str], RunnableLambda]:
         global_agent = RunnableLambda(lambda x: x * 3)

-        def my_func6(input: str) -> str:
-            return global_agent.invoke(input)
+        def my_func6(value: str) -> str:
+            return global_agent.invoke(value)

         return my_func6, global_agent
diff --git a/libs/core/tests/unit_tests/test_globals.py b/libs/core/tests/unit_tests/test_globals.py
index 760b9d63c59..a0b2457bb48 100644
--- a/libs/core/tests/unit_tests/test_globals.py
+++ b/libs/core/tests/unit_tests/test_globals.py
@@ -2,17 +2,17 @@ from langchain_core.globals import get_debug, set_debug

 def test_debug_is_settable_via_setter() -> None:
-    from langchain_core import globals
+    from langchain_core import globals as globals_
     from langchain_core.callbacks.manager import _get_debug

-    previous_value = globals._debug
+    previous_value = globals_._debug
     previous_fn_reading = _get_debug()
     assert previous_value == previous_fn_reading

     # Flip the value of the flag.
     set_debug(not previous_value)
-    new_value = globals._debug
+    new_value = globals_._debug
     new_fn_reading = _get_debug()

     try:
diff --git a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py
index a0269331124..172e441055f 100644
--- a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py
+++ b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py
@@ -50,7 +50,7 @@ class CustomAddTextsVectorstore(VectorStore):
         return ids_

     def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
-        return [self.store[id] for id in ids if id in self.store]
+        return [self.store[id_] for id_ in ids if id_ in self.store]

     @classmethod
     @override
@@ -96,7 +96,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
         return ids_

     def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
-        return [self.store[id] for id in ids if id in self.store]
+        return [self.store[id_] for id_ in ids if id_ in self.store]

     @classmethod
     @override
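Taken together, the renames in this patch follow the usual PEP 8 guidance for collisions with builtins: prefer a descriptive name, fall back to a single trailing underscore, and reserve # noqa: A002 for parameters whose names are part of a published API. The final hunks above are the pattern in miniature:

    # `id` would shadow the builtin id(); `id_` keeps both usable.
    def get_by_ids(ids: list[str], store: dict[str, str]) -> list[str]:
        return [store[id_] for id_ in ids if id_ in store]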