This commit is contained in:
Christophe Bornet 2025-04-26 12:28:41 +02:00 committed by GitHub
commit 132e6ebab3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 442 additions and 381 deletions

View File

@ -50,7 +50,7 @@ class LangSmithLoader(BaseLoader):
offset: int = 0, offset: int = 0,
limit: Optional[int] = None, limit: Optional[int] = None,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
filter: Optional[str] = None, filter: Optional[str] = None, # noqa: A002
content_key: str = "", content_key: str = "",
format_content: Optional[Callable[..., str]] = None, format_content: Optional[Callable[..., str]] = None,
client: Optional[LangSmithClient] = None, client: Optional[LangSmithClient] = None,

View File

@ -341,15 +341,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Get the output type for this runnable.""" """Get the output type for this runnable."""
return AnyMessage return AnyMessage
def _convert_input(self, input: LanguageModelInput) -> PromptValue: def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue): if isinstance(model_input, PromptValue):
return input return model_input
if isinstance(input, str): if isinstance(model_input, str):
return StringPromptValue(text=input) return StringPromptValue(text=model_input)
if isinstance(input, Sequence): if isinstance(model_input, Sequence):
return ChatPromptValue(messages=convert_to_messages(input)) return ChatPromptValue(messages=convert_to_messages(model_input))
msg = ( msg = (
f"Invalid input type {type(input)}. " f"Invalid input type {type(model_input)}. "
"Must be a PromptValue, str, or list of BaseMessages." "Must be a PromptValue, str, or list of BaseMessages."
) )
raise ValueError(msg) # noqa: TRY004 raise ValueError(msg) # noqa: TRY004

View File

@ -325,15 +325,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
"""Get the input type for this runnable.""" """Get the input type for this runnable."""
return str return str
def _convert_input(self, input: LanguageModelInput) -> PromptValue: def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue): if isinstance(model_input, PromptValue):
return input return model_input
if isinstance(input, str): if isinstance(model_input, str):
return StringPromptValue(text=input) return StringPromptValue(text=model_input)
if isinstance(input, Sequence): if isinstance(model_input, Sequence):
return ChatPromptValue(messages=convert_to_messages(input)) return ChatPromptValue(messages=convert_to_messages(model_input))
msg = ( msg = (
f"Invalid input type {type(input)}. " f"Invalid input type {type(model_input)}. "
"Must be a PromptValue, str, or list of BaseMessages." "Must be a PromptValue, str, or list of BaseMessages."
) )
raise ValueError(msg) # noqa: TRY004 raise ValueError(msg) # noqa: TRY004
@ -438,7 +438,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if max_concurrency is None: if max_concurrency is None:
try: try:
llm_result = self.generate_prompt( llm_result = self.generate_prompt(
[self._convert_input(input) for input in inputs], [self._convert_input(input_) for input_ in inputs],
callbacks=[c.get("callbacks") for c in config], callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config], tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config], metadata=[c.get("metadata") for c in config],
@ -484,7 +484,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if max_concurrency is None: if max_concurrency is None:
try: try:
llm_result = await self.agenerate_prompt( llm_result = await self.agenerate_prompt(
[self._convert_input(input) for input in inputs], [self._convert_input(input_) for input_ in inputs],
callbacks=[c.get("callbacks") for c in config], callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config], tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config], metadata=[c.get("metadata") for c in config],

View File

@ -417,10 +417,10 @@ def add_ai_message_chunks(
else: else:
usage_metadata = None usage_metadata = None
id = None chunk_id = None
for id_ in [left.id] + [o.id for o in others]: for id_ in [left.id] + [o.id for o in others]:
if id_: if id_:
id = id_ chunk_id = id_
break break
return left.__class__( return left.__class__(
example=left.example, example=left.example,
@ -429,7 +429,7 @@ def add_ai_message_chunks(
tool_call_chunks=tool_call_chunks, tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata, response_metadata=response_metadata,
usage_metadata=usage_metadata, usage_metadata=usage_metadata,
id=id, id=chunk_id,
) )

View File

@ -11,7 +11,11 @@ class RemoveMessage(BaseMessage):
type: Literal["remove"] = "remove" type: Literal["remove"] = "remove"
"""The type of the message (used for serialization). Defaults to "remove".""" """The type of the message (used for serialization). Defaults to "remove"."""
def __init__(self, id: str, **kwargs: Any) -> None: def __init__(
self,
id: str, # noqa: A002
**kwargs: Any,
) -> None:
"""Create a RemoveMessage. """Create a RemoveMessage.
Args: Args:

View File

@ -208,7 +208,12 @@ class ToolCall(TypedDict):
type: NotRequired[Literal["tool_call"]] type: NotRequired[Literal["tool_call"]]
def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall: def tool_call(
*,
name: str,
args: dict[str, Any],
id: Optional[str], # noqa: A002
) -> ToolCall:
"""Create a tool call. """Create a tool call.
Args: Args:
@ -254,7 +259,7 @@ def tool_call_chunk(
*, *,
name: Optional[str] = None, name: Optional[str] = None,
args: Optional[str] = None, args: Optional[str] = None,
id: Optional[str] = None, id: Optional[str] = None, # noqa: A002
index: Optional[int] = None, index: Optional[int] = None,
) -> ToolCallChunk: ) -> ToolCallChunk:
"""Create a tool call chunk. """Create a tool call chunk.
@ -292,7 +297,7 @@ def invalid_tool_call(
*, *,
name: Optional[str] = None, name: Optional[str] = None,
args: Optional[str] = None, args: Optional[str] = None,
id: Optional[str] = None, id: Optional[str] = None, # noqa: A002
error: Optional[str] = None, error: Optional[str] = None,
) -> InvalidToolCall: ) -> InvalidToolCall:
"""Create an invalid tool call. """Create an invalid tool call.

View File

@ -212,7 +212,7 @@ def _create_message_from_message_type(
name: Optional[str] = None, name: Optional[str] = None,
tool_call_id: Optional[str] = None, tool_call_id: Optional[str] = None,
tool_calls: Optional[list[dict[str, Any]]] = None, tool_calls: Optional[list[dict[str, Any]]] = None,
id: Optional[str] = None, id: Optional[str] = None, # noqa: A002
**additional_kwargs: Any, **additional_kwargs: Any,
) -> BaseMessage: ) -> BaseMessage:
"""Create a message from a message type and content string. """Create a message from a message type and content string.

View File

@ -9,6 +9,7 @@ from typing import Annotated, Any, Optional, TypeVar, Union
import jsonpatch # type: ignore[import-untyped] import jsonpatch # type: ignore[import-untyped]
import pydantic import pydantic
from pydantic import SkipValidation from pydantic import SkipValidation
from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
@ -47,6 +48,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""The Pydantic object to use for validation. """The Pydantic object to use for validation.
If None, no validation is performed.""" If None, no validation is performed."""
@override
def _diff(self, prev: Optional[Any], next: Any) -> Any: def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch return jsonpatch.make_patch(prev, next).patch

View File

@ -20,7 +20,10 @@ if TYPE_CHECKING:
T = TypeVar("T") T = TypeVar("T")
def droplastn(iter: Iterator[T], n: int) -> Iterator[T]: def droplastn(
iter: Iterator[T], # noqa: A002
n: int,
) -> Iterator[T]:
"""Drop the last n elements of an iterator. """Drop the last n elements of an iterator.
Args: Args:
@ -66,6 +69,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
""" """
raise NotImplementedError raise NotImplementedError
@override
def _transform( def _transform(
self, input: Iterator[Union[str, BaseMessage]] self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[list[str]]: ) -> Iterator[list[str]]:
@ -99,6 +103,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
for part in self.parse(buffer): for part in self.parse(buffer):
yield [part] yield [part]
@override
async def _atransform( async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]] self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[list[str]]: ) -> AsyncIterator[list[str]]:

View File

@ -72,6 +72,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _type(self) -> str: def _type(self) -> str:
return "json_functions" return "json_functions"
@override
def _diff(self, prev: Optional[Any], next: Any) -> Any: def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch return jsonpatch.make_patch(prev, next).patch

View File

@ -30,7 +30,10 @@ if TYPE_CHECKING:
class BaseTransformOutputParser(BaseOutputParser[T]): class BaseTransformOutputParser(BaseOutputParser[T]):
"""Base class for an output parser that can handle streaming input.""" """Base class for an output parser that can handle streaming input."""
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]: def _transform(
self,
input: Iterator[Union[str, BaseMessage]], # noqa: A002
) -> Iterator[T]:
for chunk in input: for chunk in input:
if isinstance(chunk, BaseMessage): if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)]) yield self.parse_result([ChatGeneration(message=chunk)])
@ -38,7 +41,8 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
yield self.parse_result([Generation(text=chunk)]) yield self.parse_result([Generation(text=chunk)])
async def _atransform( async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]] self,
input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
async for chunk in input: async for chunk in input:
if isinstance(chunk, BaseMessage): if isinstance(chunk, BaseMessage):
@ -102,7 +106,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
parsed output, or just the current parsed output. parsed output, or just the current parsed output.
""" """
def _diff(self, prev: Optional[T], next: T) -> T: def _diff(
self,
prev: Optional[T],
next: T, # noqa: A002
) -> T:
"""Convert parsed outputs into a diff format. """Convert parsed outputs into a diff format.
The semantics of this are up to the output parser. The semantics of this are up to the output parser.
@ -116,6 +124,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
""" """
raise NotImplementedError raise NotImplementedError
@override
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None prev_parsed = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
@ -140,6 +149,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
yield parsed yield parsed
prev_parsed = parsed prev_parsed = parsed
@override
async def _atransform( async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]] self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:

View File

@ -8,6 +8,8 @@ from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder from xml.etree.ElementTree import TreeBuilder
from typing_extensions import override
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser from langchain_core.output_parsers.transform import BaseTransformOutputParser
@ -234,6 +236,7 @@ class XMLOutputParser(BaseTransformOutputParser):
msg = f"Failed to parse XML format from completion {text}. Got: {e}" msg = f"Failed to parse XML format from completion {text}. Got: {e}"
raise OutputParserException(msg, llm_output=text) from e raise OutputParserException(msg, llm_output=text) from e
@override
def _transform( def _transform(
self, input: Iterator[Union[str, BaseMessage]] self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]: ) -> Iterator[AddableDict]:
@ -242,6 +245,7 @@ class XMLOutputParser(BaseTransformOutputParser):
yield from streaming_parser.parse(chunk) yield from streaming_parser.parse(chunk)
streaming_parser.close() streaming_parser.close()
@override
async def _atransform( async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]] self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]: ) -> AsyncIterator[AddableDict]:

View File

@ -472,16 +472,18 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
img_template = cast("_ImageTemplateParam", tmpl)["image_url"] img_template = cast("_ImageTemplateParam", tmpl)["image_url"]
input_variables = [] input_variables = []
if isinstance(img_template, str): if isinstance(img_template, str):
vars = get_template_variables(img_template, template_format) variables = get_template_variables(
if vars: img_template, template_format
if len(vars) > 1: )
if variables:
if len(variables) > 1:
msg = ( msg = (
"Only one format variable allowed per image" "Only one format variable allowed per image"
f" template.\nGot: {vars}" f" template.\nGot: {variables}"
f"\nFrom: {tmpl}" f"\nFrom: {tmpl}"
) )
raise ValueError(msg) raise ValueError(msg)
input_variables = [vars[0]] input_variables = [variables[0]]
img_template = {"url": img_template} img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate( img_template_obj = ImagePromptTemplate(
input_variables=input_variables, input_variables=input_variables,

View File

@ -1,3 +1,4 @@
# noqa:A005
"""BasePrompt schema definition.""" """BasePrompt schema definition."""
from __future__ import annotations from __future__ import annotations
@ -125,20 +126,20 @@ def mustache_template_vars(
Returns: Returns:
The variables from the template. The variables from the template.
""" """
vars: set[str] = set() variables: set[str] = set()
section_depth = 0 section_depth = 0
for type, key in mustache.tokenize(template): for type_, key in mustache.tokenize(template):
if type == "end": if type_ == "end":
section_depth -= 1 section_depth -= 1
elif ( elif (
type in ("variable", "section", "inverted section", "no escape") type_ in ("variable", "section", "inverted section", "no escape")
and key != "." and key != "."
and section_depth == 0 and section_depth == 0
): ):
vars.add(key.split(".")[0]) variables.add(key.split(".")[0])
if type in ("section", "inverted section"): if type_ in ("section", "inverted section"):
section_depth += 1 section_depth += 1
return vars return variables
Defs = dict[str, "Defs"] Defs = dict[str, "Defs"]
@ -158,17 +159,17 @@ def mustache_schema(
fields = {} fields = {}
prefix: tuple[str, ...] = () prefix: tuple[str, ...] = ()
section_stack: list[tuple[str, ...]] = [] section_stack: list[tuple[str, ...]] = []
for type, key in mustache.tokenize(template): for type_, key in mustache.tokenize(template):
if key == ".": if key == ".":
continue continue
if type == "end": if type_ == "end":
if section_stack: if section_stack:
prefix = section_stack.pop() prefix = section_stack.pop()
elif type in ("section", "inverted section"): elif type_ in ("section", "inverted section"):
section_stack.append(prefix) section_stack.append(prefix)
prefix = prefix + tuple(key.split(".")) prefix = prefix + tuple(key.split("."))
fields[prefix] = False fields[prefix] = False
elif type in ("variable", "no escape"): elif type_ in ("variable", "no escape"):
fields[prefix + tuple(key.split("."))] = True fields[prefix + tuple(key.split("."))] = True
defs: Defs = {} # None means leaf node defs: Defs = {} # None means leaf node
while fields: while fields:

View File

@ -209,6 +209,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name) return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
@override
def invoke( def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> list[Document]: ) -> list[Document]:
@ -269,6 +270,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
) )
return result return result
@override
async def ainvoke( async def ainvoke(
self, self,
input: str, input: str,

View File

@ -724,7 +724,10 @@ class Runnable(Generic[Input, Output], ABC):
@abstractmethod @abstractmethod
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self,
input: Input, # noqa: A002
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Output: ) -> Output:
"""Transform a single input into an output. """Transform a single input into an output.
@ -741,7 +744,10 @@ class Runnable(Generic[Input, Output], ABC):
""" """
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self,
input: Input, # noqa: A002
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Output: ) -> Output:
"""Default implementation of ainvoke, calls invoke from a thread. """Default implementation of ainvoke, calls invoke from a thread.
@ -772,14 +778,14 @@ class Runnable(Generic[Input, Output], ABC):
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]: def invoke(input_: Input, config: RunnableConfig) -> Union[Output, Exception]:
if return_exceptions: if return_exceptions:
try: try:
return self.invoke(input, config, **kwargs) return self.invoke(input_, config, **kwargs)
except Exception as e: except Exception as e:
return e return e
else: else:
return self.invoke(input, config, **kwargs) return self.invoke(input_, config, **kwargs)
# If there's only one input, don't bother with the executor # If there's only one input, don't bother with the executor
if len(inputs) == 1: if len(inputs) == 1:
@ -826,15 +832,17 @@ class Runnable(Generic[Input, Output], ABC):
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
def invoke( def invoke(
i: int, input: Input, config: RunnableConfig i: int, input_: Input, config: RunnableConfig
) -> tuple[int, Union[Output, Exception]]: ) -> tuple[int, Union[Output, Exception]]:
if return_exceptions: if return_exceptions:
try: try:
out: Union[Output, Exception] = self.invoke(input, config, **kwargs) out: Union[Output, Exception] = self.invoke(
input_, config, **kwargs
)
except Exception as e: except Exception as e:
out = e out = e
else: else:
out = self.invoke(input, config, **kwargs) out = self.invoke(input_, config, **kwargs)
return (i, out) return (i, out)
@ -844,8 +852,8 @@ class Runnable(Generic[Input, Output], ABC):
with get_executor_for_config(configs[0]) as executor: with get_executor_for_config(configs[0]) as executor:
futures = { futures = {
executor.submit(invoke, i, input, config) executor.submit(invoke, i, input_, config)
for i, (input, config) in enumerate(zip(inputs, configs)) for i, (input_, config) in enumerate(zip(inputs, configs))
} }
try: try:
@ -892,15 +900,15 @@ class Runnable(Generic[Input, Output], ABC):
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
async def ainvoke( async def ainvoke(
input: Input, config: RunnableConfig value: Input, config: RunnableConfig
) -> Union[Output, Exception]: ) -> Union[Output, Exception]:
if return_exceptions: if return_exceptions:
try: try:
return await self.ainvoke(input, config, **kwargs) return await self.ainvoke(value, config, **kwargs)
except Exception as e: except Exception as e:
return e return e
else: else:
return await self.ainvoke(input, config, **kwargs) return await self.ainvoke(value, config, **kwargs)
coros = map(ainvoke, inputs, configs) coros = map(ainvoke, inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
@ -960,24 +968,24 @@ class Runnable(Generic[Input, Output], ABC):
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def ainvoke_task( async def ainvoke_task(
i: int, input: Input, config: RunnableConfig i: int, input_: Input, config: RunnableConfig
) -> tuple[int, Union[Output, Exception]]: ) -> tuple[int, Union[Output, Exception]]:
if return_exceptions: if return_exceptions:
try: try:
out: Union[Output, Exception] = await self.ainvoke( out: Union[Output, Exception] = await self.ainvoke(
input, config, **kwargs input_, config, **kwargs
) )
except Exception as e: except Exception as e:
out = e out = e
else: else:
out = await self.ainvoke(input, config, **kwargs) out = await self.ainvoke(input_, config, **kwargs)
return (i, out) return (i, out)
coros = [ coros = [
gated_coro(semaphore, ainvoke_task(i, input, config)) gated_coro(semaphore, ainvoke_task(i, input_, config))
if semaphore if semaphore
else ainvoke_task(i, input, config) else ainvoke_task(i, input_, config)
for i, (input, config) in enumerate(zip(inputs, configs)) for i, (input_, config) in enumerate(zip(inputs, configs))
] ]
for coro in asyncio.as_completed(coros): for coro in asyncio.as_completed(coros):
@ -985,7 +993,7 @@ class Runnable(Generic[Input, Output], ABC):
def stream( def stream(
self, self,
input: Input, input: Input, # noqa: A002
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
@ -1005,7 +1013,7 @@ class Runnable(Generic[Input, Output], ABC):
async def astream( async def astream(
self, self,
input: Input, input: Input, # noqa: A002
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
@ -1059,7 +1067,7 @@ class Runnable(Generic[Input, Output], ABC):
async def astream_log( async def astream_log(
self, self,
input: Any, input: Any, # noqa: A002
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
diff: bool = True, diff: bool = True,
@ -1130,7 +1138,7 @@ class Runnable(Generic[Input, Output], ABC):
async def astream_events( async def astream_events(
self, self,
input: Any, input: Any, # noqa: A002
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
version: Literal["v1", "v2"] = "v2", version: Literal["v1", "v2"] = "v2",
@ -1396,7 +1404,7 @@ class Runnable(Generic[Input, Output], ABC):
def transform( def transform(
self, self,
input: Iterator[Input], input: Iterator[Input], # noqa: A002
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
@ -1438,7 +1446,7 @@ class Runnable(Generic[Input, Output], ABC):
async def atransform( async def atransform(
self, self,
input: AsyncIterator[Input], input: AsyncIterator[Input], # noqa: A002
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
@ -1903,7 +1911,7 @@ class Runnable(Generic[Input, Output], ABC):
Callable[[Input, CallbackManagerForChainRun], Output], Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
], ],
input: Input, input_: Input,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
serialized: Optional[dict[str, Any]] = None, serialized: Optional[dict[str, Any]] = None,
@ -1917,7 +1925,7 @@ class Runnable(Generic[Input, Output], ABC):
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
serialized, serialized,
input, input_,
run_type=run_type, run_type=run_type,
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
@ -1930,7 +1938,7 @@ class Runnable(Generic[Input, Output], ABC):
context.run( context.run(
call_func_with_variable_args, # type: ignore[arg-type] call_func_with_variable_args, # type: ignore[arg-type]
func, func,
input, input_,
config, config,
run_manager, run_manager,
**kwargs, **kwargs,
@ -1953,7 +1961,7 @@ class Runnable(Generic[Input, Output], ABC):
Awaitable[Output], Awaitable[Output],
], ],
], ],
input: Input, input_: Input,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
serialized: Optional[dict[str, Any]] = None, serialized: Optional[dict[str, Any]] = None,
@ -1967,7 +1975,7 @@ class Runnable(Generic[Input, Output], ABC):
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
serialized, serialized,
input, input_,
run_type=run_type, run_type=run_type,
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
@ -1976,7 +1984,7 @@ class Runnable(Generic[Input, Output], ABC):
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
with set_config_context(child_config) as context: with set_config_context(child_config) as context:
coro = acall_func_with_variable_args( coro = acall_func_with_variable_args(
func, input, config, run_manager, **kwargs func, input_, config, run_manager, **kwargs
) )
output: Output = await coro_with_context(coro, context) output: Output = await coro_with_context(coro, context)
except BaseException as e: except BaseException as e:
@ -1999,7 +2007,7 @@ class Runnable(Generic[Input, Output], ABC):
list[Union[Exception, Output]], list[Union[Exception, Output]],
], ],
], ],
input: list[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
@ -2011,21 +2019,21 @@ class Runnable(Generic[Input, Output], ABC):
Helper method to transform an Input value to an Output value, Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses. with callbacks. Use this method to implement invoke() in subclasses.
""" """
if not input: if not inputs:
return [] return []
configs = get_config_list(config, len(input)) configs = get_config_list(config, len(inputs))
callback_managers = [get_callback_manager_for_config(c) for c in configs] callback_managers = [get_callback_manager_for_config(c) for c in configs]
run_managers = [ run_managers = [
callback_manager.on_chain_start( callback_manager.on_chain_start(
None, None,
input, input_,
run_type=run_type, run_type=run_type,
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
for callback_manager, input, config in zip( for callback_manager, input_, config in zip(
callback_managers, input, configs callback_managers, inputs, configs
) )
] ]
try: try:
@ -2036,12 +2044,12 @@ class Runnable(Generic[Input, Output], ABC):
] ]
if accepts_run_manager(func): if accepts_run_manager(func):
kwargs["run_manager"] = run_managers kwargs["run_manager"] = run_managers
output = func(input, **kwargs) # type: ignore[call-arg] output = func(inputs, **kwargs) # type: ignore[call-arg]
except BaseException as e: except BaseException as e:
for run_manager in run_managers: for run_manager in run_managers:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
if return_exceptions: if return_exceptions:
return cast("list[Output]", [e for _ in input]) return cast("list[Output]", [e for _ in inputs])
raise raise
else: else:
first_exception: Optional[Exception] = None first_exception: Optional[Exception] = None
@ -2072,7 +2080,7 @@ class Runnable(Generic[Input, Output], ABC):
Awaitable[list[Union[Exception, Output]]], Awaitable[list[Union[Exception, Output]]],
], ],
], ],
input: list[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
@ -2085,22 +2093,22 @@ class Runnable(Generic[Input, Output], ABC):
with callbacks. with callbacks.
Use this method to implement invoke() in subclasses. Use this method to implement invoke() in subclasses.
""" """
if not input: if not inputs:
return [] return []
configs = get_config_list(config, len(input)) configs = get_config_list(config, len(inputs))
callback_managers = [get_async_callback_manager_for_config(c) for c in configs] callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather( run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*( *(
callback_manager.on_chain_start( callback_manager.on_chain_start(
None, None,
input, input_,
run_type=run_type, run_type=run_type,
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
for callback_manager, input, config in zip( for callback_manager, input_, config in zip(
callback_managers, input, configs callback_managers, inputs, configs
) )
) )
) )
@ -2112,13 +2120,13 @@ class Runnable(Generic[Input, Output], ABC):
] ]
if accepts_run_manager(func): if accepts_run_manager(func):
kwargs["run_manager"] = run_managers kwargs["run_manager"] = run_managers
output = await func(input, **kwargs) # type: ignore[call-arg] output = await func(inputs, **kwargs) # type: ignore[call-arg]
except BaseException as e: except BaseException as e:
await asyncio.gather( await asyncio.gather(
*(run_manager.on_chain_error(e) for run_manager in run_managers) *(run_manager.on_chain_error(e) for run_manager in run_managers)
) )
if return_exceptions: if return_exceptions:
return cast("list[Output]", [e for _ in input]) return cast("list[Output]", [e for _ in inputs])
raise raise
else: else:
first_exception: Optional[Exception] = None first_exception: Optional[Exception] = None
@ -2136,7 +2144,7 @@ class Runnable(Generic[Input, Output], ABC):
def _transform_stream_with_config( def _transform_stream_with_config(
self, self,
input: Iterator[Input], inputs: Iterator[Input],
transformer: Union[ transformer: Union[
Callable[[Iterator[Input]], Iterator[Output]], Callable[[Iterator[Input]], Iterator[Output]],
Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]], Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]],
@ -2163,7 +2171,7 @@ class Runnable(Generic[Input, Output], ABC):
from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers._streaming import _StreamingCallbackHandler
# tee the input so we can iterate over it twice # tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2) input_for_tracing, input_for_transform = tee(inputs, 2)
# Start the input iterator to ensure the input Runnable starts before this one # Start the input iterator to ensure the input Runnable starts before this one
final_input: Optional[Input] = next(input_for_tracing, None) final_input: Optional[Input] = next(input_for_tracing, None)
final_input_supported = True final_input_supported = True
@ -2237,7 +2245,7 @@ class Runnable(Generic[Input, Output], ABC):
async def _atransform_stream_with_config( async def _atransform_stream_with_config(
self, self,
input: AsyncIterator[Input], inputs: AsyncIterator[Input],
transformer: Union[ transformer: Union[
Callable[[AsyncIterator[Input]], AsyncIterator[Output]], Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
Callable[ Callable[
@ -2267,7 +2275,7 @@ class Runnable(Generic[Input, Output], ABC):
from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers._streaming import _StreamingCallbackHandler
# tee the input so we can iterate over it twice # tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = atee(input, 2) input_for_tracing, input_for_transform = atee(inputs, 2)
# Start the input iterator to ensure the input Runnable starts before this one # Start the input iterator to ensure the input Runnable starts before this one
final_input: Optional[Input] = await py_anext(input_for_tracing, None) final_input: Optional[Input] = await py_anext(input_for_tracing, None)
final_input_supported = True final_input_supported = True
@ -3019,6 +3027,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
input_ = input
# invoke all steps in sequence # invoke all steps in sequence
try: try:
@ -3029,16 +3038,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) )
with set_config_context(config) as context: with set_config_context(config) as context:
if i == 0: if i == 0:
input = context.run(step.invoke, input, config, **kwargs) input_ = context.run(step.invoke, input_, config, **kwargs)
else: else:
input = context.run(step.invoke, input, config) input_ = context.run(step.invoke, input_, config)
# finish the root run # finish the root run
except BaseException as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise
else: else:
run_manager.on_chain_end(input) run_manager.on_chain_end(input_)
return cast("Output", input) return cast("Output", input_)
@override @override
async def ainvoke( async def ainvoke(
@ -3059,6 +3068,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
input_ = input
# invoke all steps in sequence # invoke all steps in sequence
try: try:
@ -3069,17 +3079,17 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) )
with set_config_context(config) as context: with set_config_context(config) as context:
if i == 0: if i == 0:
part = functools.partial(step.ainvoke, input, config, **kwargs) part = functools.partial(step.ainvoke, input_, config, **kwargs)
else: else:
part = functools.partial(step.ainvoke, input, config) part = functools.partial(step.ainvoke, input_, config)
input = await coro_with_context(part(), context, create_task=True) input_ = await coro_with_context(part(), context, create_task=True)
# finish the root run # finish the root run
except BaseException as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise raise
else: else:
await run_manager.on_chain_end(input) await run_manager.on_chain_end(input_)
return cast("Output", input) return cast("Output", input_)
@override @override
def batch( def batch(
@ -3117,11 +3127,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
run_managers = [ run_managers = [
cm.on_chain_start( cm.on_chain_start(
None, None,
input, input_,
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
for cm, input, config in zip(callback_managers, inputs, configs) for cm, input_, config in zip(callback_managers, inputs, configs)
] ]
# invoke # invoke
@ -3248,11 +3258,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
*( *(
cm.on_chain_start( cm.on_chain_start(
None, None,
input, input_,
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
for cm, input, config in zip(callback_managers, inputs, configs) for cm, input_, config in zip(callback_managers, inputs, configs)
) )
) )
@ -3346,7 +3356,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def _transform( def _transform(
self, self,
input: Iterator[Input], inputs: Iterator[Input],
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
@ -3359,7 +3369,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# transform the input stream of each step with the next # transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will # steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output # buffer input in memory until all available, and then start emitting output
final_pipeline = cast("Iterator[Output]", input) final_pipeline = cast("Iterator[Output]", inputs)
for idx, step in enumerate(steps): for idx, step in enumerate(steps):
config = patch_config( config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{idx + 1}") config, callbacks=run_manager.get_child(f"seq:step:{idx + 1}")
@ -3373,7 +3383,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
async def _atransform( async def _atransform(
self, self,
input: AsyncIterator[Input], inputs: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
@ -3387,7 +3397,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# transform the input stream of each step with the next # transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will # steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output # buffer input in memory until all available, and then start emitting output
final_pipeline = cast("AsyncIterator[Output]", input) final_pipeline = cast("AsyncIterator[Output]", inputs)
for idx, step in enumerate(steps): for idx, step in enumerate(steps):
config = patch_config( config = patch_config(
config, config,
@ -3733,7 +3743,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
) )
def _invoke_step( def _invoke_step(
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str
) -> Any: ) -> Any:
child_config = patch_config( child_config = patch_config(
config, config,
@ -3743,7 +3753,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
with set_config_context(child_config) as context: with set_config_context(child_config) as context:
return context.run( return context.run(
step.invoke, step.invoke,
input, input_,
child_config, child_config,
) )
@ -3785,7 +3795,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
) )
async def _ainvoke_step( async def _ainvoke_step(
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str
) -> Any: ) -> Any:
child_config = patch_config( child_config = patch_config(
config, config,
@ -3793,7 +3803,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
) )
with set_config_context(child_config) as context: with set_config_context(child_config) as context:
return await coro_with_context( return await coro_with_context(
step.ainvoke(input, child_config), context, create_task=True step.ainvoke(input_, child_config), context, create_task=True
) )
# gather results from all steps # gather results from all steps
@ -3823,7 +3833,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
def _transform( def _transform(
self, self,
input: Iterator[Input], inputs: Iterator[Input],
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
) -> Iterator[AddableDict]: ) -> Iterator[AddableDict]:
@ -3831,7 +3841,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
steps = dict(self.steps__) steps = dict(self.steps__)
# Each step gets a copy of the input iterator, # Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread. # which is consumed in parallel in a separate thread.
input_copies = list(safetee(input, len(steps), lock=threading.Lock())) input_copies = list(safetee(inputs, len(steps), lock=threading.Lock()))
with get_executor_for_config(config) as executor: with get_executor_for_config(config) as executor:
# Create the transform() generator for each step # Create the transform() generator for each step
named_generators = [ named_generators = [
@ -3890,7 +3900,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
async def _atransform( async def _atransform(
self, self,
input: AsyncIterator[Input], inputs: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
) -> AsyncIterator[AddableDict]: ) -> AsyncIterator[AddableDict]:
@ -3898,7 +3908,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
steps = dict(self.steps__) steps = dict(self.steps__)
# Each step gets a copy of the input iterator, # Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread. # which is consumed in parallel in a separate thread.
input_copies = list(atee(input, len(steps), lock=asyncio.Lock())) input_copies = list(atee(inputs, len(steps), lock=asyncio.Lock()))
# Create the transform() generator for each step # Create the transform() generator for each step
named_generators = [ named_generators = [
( (
@ -4590,7 +4600,7 @@ class RunnableLambda(Runnable[Input, Output]):
def _invoke( def _invoke(
self, self,
input: Input, input_: Input,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
@ -4599,7 +4609,7 @@ class RunnableLambda(Runnable[Input, Output]):
output: Optional[Output] = None output: Optional[Output] = None
for chunk in call_func_with_variable_args( for chunk in call_func_with_variable_args(
cast("Callable[[Input], Iterator[Output]]", self.func), cast("Callable[[Input], Iterator[Output]]", self.func),
input, input_,
config, config,
run_manager, run_manager,
**kwargs, **kwargs,
@ -4613,18 +4623,18 @@ class RunnableLambda(Runnable[Input, Output]):
output = chunk output = chunk
else: else:
output = call_func_with_variable_args( output = call_func_with_variable_args(
self.func, input, config, run_manager, **kwargs self.func, input_, config, run_manager, **kwargs
) )
# If the output is a Runnable, invoke it # If the output is a Runnable, invoke it
if isinstance(output, Runnable): if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"] recursion_limit = config["recursion_limit"]
if recursion_limit <= 0: if recursion_limit <= 0:
msg = ( msg = (
f"Recursion limit reached when invoking {self} with input {input}." f"Recursion limit reached when invoking {self} with input {input_}."
) )
raise RecursionError(msg) raise RecursionError(msg)
output = output.invoke( output = output.invoke(
input, input_,
patch_config( patch_config(
config, config,
callbacks=run_manager.get_child(), callbacks=run_manager.get_child(),
@ -4635,7 +4645,7 @@ class RunnableLambda(Runnable[Input, Output]):
async def _ainvoke( async def _ainvoke(
self, self,
input: Input, value: Input,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
@ -4646,7 +4656,7 @@ class RunnableLambda(Runnable[Input, Output]):
if inspect.isgeneratorfunction(self.func): if inspect.isgeneratorfunction(self.func):
def func( def func(
input: Input, value: Input,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
@ -4654,7 +4664,7 @@ class RunnableLambda(Runnable[Input, Output]):
output: Optional[Output] = None output: Optional[Output] = None
for chunk in call_func_with_variable_args( for chunk in call_func_with_variable_args(
cast("Callable[[Input], Iterator[Output]]", self.func), cast("Callable[[Input], Iterator[Output]]", self.func),
input, value,
config, config,
run_manager.get_sync(), run_manager.get_sync(),
**kwargs, **kwargs,
@ -4671,13 +4681,13 @@ class RunnableLambda(Runnable[Input, Output]):
else: else:
def func( def func(
input: Input, value: Input,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> Output: ) -> Output:
return call_func_with_variable_args( return call_func_with_variable_args(
self.func, input, config, run_manager.get_sync(), **kwargs self.func, value, config, run_manager.get_sync(), **kwargs
) )
@wraps(func) @wraps(func)
@ -4693,7 +4703,7 @@ class RunnableLambda(Runnable[Input, Output]):
"AsyncGenerator[Any, Any]", "AsyncGenerator[Any, Any]",
acall_func_with_variable_args( acall_func_with_variable_args(
cast("Callable", afunc), cast("Callable", afunc),
input, value,
config, config,
run_manager, run_manager,
**kwargs, **kwargs,
@ -4713,18 +4723,18 @@ class RunnableLambda(Runnable[Input, Output]):
output = chunk output = chunk
else: else:
output = await acall_func_with_variable_args( output = await acall_func_with_variable_args(
cast("Callable", afunc), input, config, run_manager, **kwargs cast("Callable", afunc), value, config, run_manager, **kwargs
) )
# If the output is a Runnable, invoke it # If the output is a Runnable, invoke it
if isinstance(output, Runnable): if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"] recursion_limit = config["recursion_limit"]
if recursion_limit <= 0: if recursion_limit <= 0:
msg = ( msg = (
f"Recursion limit reached when invoking {self} with input {input}." f"Recursion limit reached when invoking {self} with input {value}."
) )
raise RecursionError(msg) raise RecursionError(msg)
output = await output.ainvoke( output = await output.ainvoke(
input, value,
patch_config( patch_config(
config, config,
callbacks=run_manager.get_child(), callbacks=run_manager.get_child(),
@ -4789,14 +4799,14 @@ class RunnableLambda(Runnable[Input, Output]):
def _transform( def _transform(
self, self,
input: Iterator[Input], chunks: Iterator[Input],
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Output]: ) -> Iterator[Output]:
final: Input final: Input
got_first_val = False got_first_val = False
for ichunk in input: for ichunk in chunks:
# By definitions, RunnableLambdas consume all input before emitting output. # By definitions, RunnableLambdas consume all input before emitting output.
# If the input is not addable, then we'll assume that we can # If the input is not addable, then we'll assume that we can
# only operate on the last chunk. # only operate on the last chunk.
@ -4881,14 +4891,14 @@ class RunnableLambda(Runnable[Input, Output]):
async def _atransform( async def _atransform(
self, self,
input: AsyncIterator[Input], chunks: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
final: Input final: Input
got_first_val = False got_first_val = False
async for ichunk in input: async for ichunk in chunks:
# By definitions, RunnableLambdas consume all input before emitting output. # By definitions, RunnableLambdas consume all input before emitting output.
# If the input is not addable, then we'll assume that we can # If the input is not addable, then we'll assume that we can
# only operate on the last chunk. # only operate on the last chunk.
@ -4913,13 +4923,13 @@ class RunnableLambda(Runnable[Input, Output]):
raise TypeError(msg) raise TypeError(msg)
def func( def func(
input: Input, input_: Input,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> Output: ) -> Output:
return call_func_with_variable_args( return call_func_with_variable_args(
self.func, input, config, run_manager.get_sync(), **kwargs self.func, input_, config, run_manager.get_sync(), **kwargs
) )
@wraps(func) @wraps(func)

View File

@ -196,6 +196,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
raise ValueError(msg) raise ValueError(msg)
return specs return specs
@override
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
@ -254,6 +255,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
run_manager.on_chain_end(output) run_manager.on_chain_end(output)
return output return output
@override
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
@ -302,6 +304,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
await run_manager.on_chain_end(output) await run_manager.on_chain_end(output)
return output return output
@override
def stream( def stream(
self, self,
input: Input, input: Input,
@ -388,6 +391,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
raise raise
run_manager.on_chain_end(final_output) run_manager.on_chain_end(final_output)
@override
async def astream( async def astream(
self, self,
input: Input, input: Input,

View File

@ -401,7 +401,7 @@ def call_func_with_variable_args(
Callable[[Input, CallbackManagerForChainRun], Output], Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
], ],
input: Input, input: Input, # noqa: A002
config: RunnableConfig, config: RunnableConfig,
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
**kwargs: Any, **kwargs: Any,
@ -438,7 +438,7 @@ def acall_func_with_variable_args(
Awaitable[Output], Awaitable[Output],
], ],
], ],
input: Input, input: Input, # noqa: A002
config: RunnableConfig, config: RunnableConfig,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs: Any, **kwargs: Any,

View File

@ -178,16 +178,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def invoke( def invoke(
prepared: tuple[Runnable[Input, Output], RunnableConfig], prepared: tuple[Runnable[Input, Output], RunnableConfig],
input: Input, input_: Input,
) -> Union[Output, Exception]: ) -> Union[Output, Exception]:
bound, config = prepared bound, config = prepared
if return_exceptions: if return_exceptions:
try: try:
return bound.invoke(input, config, **kwargs) return bound.invoke(input_, config, **kwargs)
except Exception as e: except Exception as e:
return e return e
else: else:
return bound.invoke(input, config, **kwargs) return bound.invoke(input_, config, **kwargs)
# If there's only one input, don't bother with the executor # If there's only one input, don't bother with the executor
if len(inputs) == 1: if len(inputs) == 1:
@ -221,16 +221,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
async def ainvoke( async def ainvoke(
prepared: tuple[Runnable[Input, Output], RunnableConfig], prepared: tuple[Runnable[Input, Output], RunnableConfig],
input: Input, input_: Input,
) -> Union[Output, Exception]: ) -> Union[Output, Exception]:
bound, config = prepared bound, config = prepared
if return_exceptions: if return_exceptions:
try: try:
return await bound.ainvoke(input, config, **kwargs) return await bound.ainvoke(input_, config, **kwargs)
except Exception as e: except Exception as e:
return e return e
else: else:
return await bound.ainvoke(input, config, **kwargs) return await bound.ainvoke(input_, config, **kwargs)
coros = map(ainvoke, prepared, inputs) coros = map(ainvoke, prepared, inputs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)

View File

@ -269,7 +269,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.manager import CallbackManager
if self.exception_key is not None and not all( if self.exception_key is not None and not all(
isinstance(input, dict) for input in inputs isinstance(input_, dict) for input_ in inputs
): ):
msg = ( msg = (
"If 'exception_key' is specified then inputs must be dictionaries." "If 'exception_key' is specified then inputs must be dictionaries."
@ -298,11 +298,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
run_managers = [ run_managers = [
cm.on_chain_start( cm.on_chain_start(
None, None,
input if isinstance(input, dict) else {"input": input}, input_ if isinstance(input_, dict) else {"input": input_},
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
for cm, input, config in zip(callback_managers, inputs, configs) for cm, input_, config in zip(callback_managers, inputs, configs)
] ]
to_return: dict[int, Any] = {} to_return: dict[int, Any] = {}
@ -311,7 +311,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
first_to_raise = None first_to_raise = None
for runnable in self.runnables: for runnable in self.runnables:
outputs = runnable.batch( outputs = runnable.batch(
[input for _, input in sorted(run_again.items())], [input_ for _, input_ in sorted(run_again.items())],
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
patch_config(configs[i], callbacks=run_managers[i].get_child()) patch_config(configs[i], callbacks=run_managers[i].get_child())
@ -320,7 +320,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return_exceptions=True, return_exceptions=True,
**kwargs, **kwargs,
) )
for (i, input), output in zip(sorted(run_again.copy().items()), outputs): for (i, input_), output in zip(sorted(run_again.copy().items()), outputs):
if isinstance(output, BaseException) and not isinstance( if isinstance(output, BaseException) and not isinstance(
output, self.exceptions_to_handle output, self.exceptions_to_handle
): ):
@ -331,7 +331,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
run_again.pop(i) run_again.pop(i)
elif isinstance(output, self.exceptions_to_handle): elif isinstance(output, self.exceptions_to_handle):
if self.exception_key: if self.exception_key:
input[self.exception_key] = output # type: ignore[index] input_[self.exception_key] = output # type: ignore[index]
handled_exceptions[i] = output handled_exceptions[i] = output
else: else:
run_managers[i].on_chain_end(output) run_managers[i].on_chain_end(output)
@ -363,7 +363,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
from langchain_core.callbacks.manager import AsyncCallbackManager from langchain_core.callbacks.manager import AsyncCallbackManager
if self.exception_key is not None and not all( if self.exception_key is not None and not all(
isinstance(input, dict) for input in inputs isinstance(input_, dict) for input_ in inputs
): ):
msg = ( msg = (
"If 'exception_key' is specified then inputs must be dictionaries." "If 'exception_key' is specified then inputs must be dictionaries."
@ -393,11 +393,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
*( *(
cm.on_chain_start( cm.on_chain_start(
None, None,
input, input_,
name=config.get("run_name") or self.get_name(), name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
) )
for cm, input, config in zip(callback_managers, inputs, configs) for cm, input_, config in zip(callback_managers, inputs, configs)
) )
) )
@ -407,7 +407,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
first_to_raise = None first_to_raise = None
for runnable in self.runnables: for runnable in self.runnables:
outputs = await runnable.abatch( outputs = await runnable.abatch(
[input for _, input in sorted(run_again.items())], [input_ for _, input_ in sorted(run_again.items())],
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
patch_config(configs[i], callbacks=run_managers[i].get_child()) patch_config(configs[i], callbacks=run_managers[i].get_child())
@ -417,7 +417,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
**kwargs, **kwargs,
) )
for (i, input), output in zip(sorted(run_again.copy().items()), outputs): for (i, input_), output in zip(sorted(run_again.copy().items()), outputs):
if isinstance(output, BaseException) and not isinstance( if isinstance(output, BaseException) and not isinstance(
output, self.exceptions_to_handle output, self.exceptions_to_handle
): ):
@ -428,7 +428,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
run_again.pop(i) run_again.pop(i)
elif isinstance(output, self.exceptions_to_handle): elif isinstance(output, self.exceptions_to_handle):
if self.exception_key: if self.exception_key:
input[self.exception_key] = output # type: ignore[index] input_[self.exception_key] = output # type: ignore[index]
handled_exceptions[i] = output handled_exceptions[i] = output
else: else:
to_return[i] = output to_return[i] = output

View File

@ -111,7 +111,12 @@ class Node(NamedTuple):
data: Union[type[BaseModel], RunnableType, None] data: Union[type[BaseModel], RunnableType, None]
metadata: Optional[dict[str, Any]] metadata: Optional[dict[str, Any]]
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node: def copy(
self,
*,
id: Optional[str] = None, # noqa: A002
name: Optional[str] = None,
) -> Node:
"""Return a copy of the node with optional new id and name. """Return a copy of the node with optional new id and name.
Args: Args:
@ -181,7 +186,10 @@ class MermaidDrawMethod(Enum):
API = "api" # Uses Mermaid.INK API to render the graph API = "api" # Uses Mermaid.INK API to render the graph
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType, None]) -> str: def node_data_str(
id: str, # noqa: A002
data: Union[type[BaseModel], RunnableType, None],
) -> str:
"""Convert the data of a node to a string. """Convert the data of a node to a string.
Args: Args:
@ -320,7 +328,7 @@ class Graph:
def add_node( def add_node(
self, self,
data: Union[type[BaseModel], RunnableType, None], data: Union[type[BaseModel], RunnableType, None],
id: Optional[str] = None, id: Optional[str] = None, # noqa: A002
*, *,
metadata: Optional[dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
) -> Node: ) -> Node:
@ -340,8 +348,8 @@ class Graph:
if id is not None and id in self.nodes: if id is not None and id in self.nodes:
msg = f"Node with id {id} already exists" msg = f"Node with id {id} already exists"
raise ValueError(msg) raise ValueError(msg)
id = id or self.next_id() id_ = id or self.next_id()
node = Node(id=id, data=data, metadata=metadata, name=node_data_str(id, data)) node = Node(id=id_, data=data, metadata=metadata, name=node_data_str(id_, data))
self.nodes[node.id] = node self.nodes[node.id] = node
return node return node
@ -406,8 +414,8 @@ class Graph:
if all(is_uuid(node.id) for node in graph.nodes.values()): if all(is_uuid(node.id) for node in graph.nodes.values()):
prefix = "" prefix = ""
def prefixed(id: str) -> str: def prefixed(id_: str) -> str:
return f"{prefix}:{id}" if prefix else id return f"{prefix}:{id_}" if prefix else id_
# prefix each node # prefix each node
self.nodes.update( self.nodes.update(
@ -450,8 +458,8 @@ class Graph:
return Graph( return Graph(
nodes={ nodes={
_get_node_id(id): node.copy(id=_get_node_id(id)) _get_node_id(id_): node.copy(id=_get_node_id(id_))
for id, node in self.nodes.items() for id_, node in self.nodes.items()
}, },
edges=[ edges=[
edge.copy( edge.copy(

View File

@ -187,7 +187,7 @@ def _build_sugiyama_layout(
# Y # Y
# #
vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()} vertices_ = {id_: Vertex(f" {data} ") for id_, data in vertices.items()}
edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges] edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges]
vertices_list = vertices_.values() vertices_list = vertices_.values()
graph = Graph(vertices_list, edges_) graph = Graph(vertices_list, edges_)

View File

@ -509,20 +509,20 @@ class RunnableWithMessageHistory(RunnableBindingBase):
) )
raise ValueError(msg) # noqa: TRY004 raise ValueError(msg) # noqa: TRY004
def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]: def _enter_history(self, value: Any, config: RunnableConfig) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"] hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = hist.messages.copy() messages = hist.messages.copy()
if not self.history_messages_key: if not self.history_messages_key:
# return all messages # return all messages
input_val = ( input_val = (
input if not self.input_messages_key else input[self.input_messages_key] value if not self.input_messages_key else value[self.input_messages_key]
) )
messages += self._get_input_messages(input_val) messages += self._get_input_messages(input_val)
return messages return messages
async def _aenter_history( async def _aenter_history(
self, input: dict[str, Any], config: RunnableConfig self, value: dict[str, Any], config: RunnableConfig
) -> list[BaseMessage]: ) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"] hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = (await hist.aget_messages()).copy() messages = (await hist.aget_messages()).copy()
@ -530,7 +530,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
if not self.history_messages_key: if not self.history_messages_key:
# return all messages # return all messages
input_val = ( input_val = (
input if not self.input_messages_key else input[self.input_messages_key] value if not self.input_messages_key else value[self.input_messages_key]
) )
messages += self._get_input_messages(input_val) messages += self._get_input_messages(input_val)
return messages return messages

View File

@ -483,19 +483,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
def _invoke( def _invoke(
self, self,
input: dict[str, Any], value: dict[str, Any],
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> dict[str, Any]: ) -> dict[str, Any]:
if not isinstance(input, dict): if not isinstance(value, dict):
msg = "The input to RunnablePassthrough.assign() must be a dict." msg = "The input to RunnablePassthrough.assign() must be a dict."
raise ValueError(msg) # noqa: TRY004 raise ValueError(msg) # noqa: TRY004
return { return {
**input, **value,
**self.mapper.invoke( **self.mapper.invoke(
input, value,
patch_config(config, callbacks=run_manager.get_child()), patch_config(config, callbacks=run_manager.get_child()),
**kwargs, **kwargs,
), ),
@ -512,19 +512,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
async def _ainvoke( async def _ainvoke(
self, self,
input: dict[str, Any], value: dict[str, Any],
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> dict[str, Any]: ) -> dict[str, Any]:
if not isinstance(input, dict): if not isinstance(value, dict):
msg = "The input to RunnablePassthrough.assign() must be a dict." msg = "The input to RunnablePassthrough.assign() must be a dict."
raise ValueError(msg) # noqa: TRY004 raise ValueError(msg) # noqa: TRY004
return { return {
**input, **value,
**await self.mapper.ainvoke( **await self.mapper.ainvoke(
input, value,
patch_config(config, callbacks=run_manager.get_child()), patch_config(config, callbacks=run_manager.get_child()),
**kwargs, **kwargs,
), ),
@ -541,7 +541,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
def _transform( def _transform(
self, self,
input: Iterator[dict[str, Any]], values: Iterator[dict[str, Any]],
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
@ -549,7 +549,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
# collect mapper keys # collect mapper keys
mapper_keys = set(self.mapper.steps__.keys()) mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough # create two streams, one for the map and one for the passthrough
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock()) for_passthrough, for_map = safetee(values, 2, lock=threading.Lock())
# create map output stream # create map output stream
map_output = self.mapper.transform( map_output = self.mapper.transform(
@ -598,7 +598,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
async def _atransform( async def _atransform(
self, self,
input: AsyncIterator[dict[str, Any]], values: AsyncIterator[dict[str, Any]],
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
@ -606,7 +606,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
# collect mapper keys # collect mapper keys
mapper_keys = set(self.mapper.steps__.keys()) mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough # create two streams, one for the map and one for the passthrough
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock()) for_passthrough, for_map = atee(values, 2, lock=asyncio.Lock())
# create map output stream # create map output stream
map_output = self.mapper.atransform( map_output = self.mapper.atransform(
for_map, for_map,
@ -731,23 +731,23 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
) )
return super().get_name(suffix, name=name) return super().get_name(suffix, name=name)
def _pick(self, input: dict[str, Any]) -> Any: def _pick(self, value: dict[str, Any]) -> Any:
if not isinstance(input, dict): if not isinstance(value, dict):
msg = "The input to RunnablePassthrough.assign() must be a dict." msg = "The input to RunnablePassthrough.assign() must be a dict."
raise ValueError(msg) # noqa: TRY004 raise ValueError(msg) # noqa: TRY004
if isinstance(self.keys, str): if isinstance(self.keys, str):
return input.get(self.keys) return value.get(self.keys)
picked = {k: input.get(k) for k in self.keys if k in input} picked = {k: value.get(k) for k in self.keys if k in value}
if picked: if picked:
return AddableDict(picked) return AddableDict(picked)
return None return None
def _invoke( def _invoke(
self, self,
input: dict[str, Any], value: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
return self._pick(input) return self._pick(value)
@override @override
def invoke( def invoke(
@ -760,9 +760,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
async def _ainvoke( async def _ainvoke(
self, self,
input: dict[str, Any], value: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
return self._pick(input) return self._pick(value)
@override @override
async def ainvoke( async def ainvoke(
@ -775,9 +775,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
def _transform( def _transform(
self, self,
input: Iterator[dict[str, Any]], chunks: Iterator[dict[str, Any]],
) -> Iterator[dict[str, Any]]: ) -> Iterator[dict[str, Any]]:
for chunk in input: for chunk in chunks:
picked = self._pick(chunk) picked = self._pick(chunk)
if picked is not None: if picked is not None:
yield picked yield picked
@ -795,9 +795,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
async def _atransform( async def _atransform(
self, self,
input: AsyncIterator[dict[str, Any]], chunks: AsyncIterator[dict[str, Any]],
) -> AsyncIterator[dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
async for chunk in input: async for chunk in chunks:
picked = self._pick(chunk) picked = self._pick(chunk)
if picked is not None: if picked is not None:
yield picked yield picked

View File

@ -178,7 +178,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def _invoke( def _invoke(
self, self,
input: Input, input_: Input,
run_manager: "CallbackManagerForChainRun", run_manager: "CallbackManagerForChainRun",
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
@ -186,7 +186,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
for attempt in self._sync_retrying(reraise=True): for attempt in self._sync_retrying(reraise=True):
with attempt: with attempt:
result = super().invoke( result = super().invoke(
input, input_,
self._patch_config(config, run_manager, attempt.retry_state), self._patch_config(config, run_manager, attempt.retry_state),
**kwargs, **kwargs,
) )
@ -202,7 +202,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
async def _ainvoke( async def _ainvoke(
self, self,
input: Input, input_: Input,
run_manager: "AsyncCallbackManagerForChainRun", run_manager: "AsyncCallbackManagerForChainRun",
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
@ -210,7 +210,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
async for attempt in self._async_retrying(reraise=True): async for attempt in self._async_retrying(reraise=True):
with attempt: with attempt:
result = await super().ainvoke( result = await super().ainvoke(
input, input_,
self._patch_config(config, run_manager, attempt.retry_state), self._patch_config(config, run_manager, attempt.retry_state),
**kwargs, **kwargs,
) )

View File

@ -148,22 +148,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
if not inputs: if not inputs:
return [] return []
keys = [input["key"] for input in inputs] keys = [input_["key"] for input_ in inputs]
actual_inputs = [input["input"] for input in inputs] actual_inputs = [input_["input"] for input_ in inputs]
if any(key not in self.runnables for key in keys): if any(key not in self.runnables for key in keys):
msg = "One or more keys do not have a corresponding runnable" msg = "One or more keys do not have a corresponding runnable"
raise ValueError(msg) raise ValueError(msg)
def invoke( def invoke(
runnable: Runnable, input: Input, config: RunnableConfig runnable: Runnable, input_: Input, config: RunnableConfig
) -> Union[Output, Exception]: ) -> Union[Output, Exception]:
if return_exceptions: if return_exceptions:
try: try:
return runnable.invoke(input, config, **kwargs) return runnable.invoke(input_, config, **kwargs)
except Exception as e: except Exception as e:
return e return e
else: else:
return runnable.invoke(input, config, **kwargs) return runnable.invoke(input_, config, **kwargs)
runnables = [self.runnables[key] for key in keys] runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
@ -185,22 +185,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
if not inputs: if not inputs:
return [] return []
keys = [input["key"] for input in inputs] keys = [input_["key"] for input_ in inputs]
actual_inputs = [input["input"] for input in inputs] actual_inputs = [input_["input"] for input_ in inputs]
if any(key not in self.runnables for key in keys): if any(key not in self.runnables for key in keys):
msg = "One or more keys do not have a corresponding runnable" msg = "One or more keys do not have a corresponding runnable"
raise ValueError(msg) raise ValueError(msg)
async def ainvoke( async def ainvoke(
runnable: Runnable, input: Input, config: RunnableConfig runnable: Runnable, input_: Input, config: RunnableConfig
) -> Union[Output, Exception]: ) -> Union[Output, Exception]:
if return_exceptions: if return_exceptions:
try: try:
return await runnable.ainvoke(input, config, **kwargs) return await runnable.ainvoke(input_, config, **kwargs)
except Exception as e: except Exception as e:
return e return e
else: else:
return await runnable.ainvoke(input, config, **kwargs) return await runnable.ainvoke(input_, config, **kwargs)
runnables = [self.runnables[key] for key in keys] runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))

View File

@ -75,7 +75,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros)) return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
def accepts_run_manager(callable: Callable[..., Any]) -> bool: def accepts_run_manager(callable: Callable[..., Any]) -> bool: # noqa: A002
"""Check if a callable accepts a run_manager argument. """Check if a callable accepts a run_manager argument.
Args: Args:
@ -90,7 +90,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
return False return False
def accepts_config(callable: Callable[..., Any]) -> bool: def accepts_config(callable: Callable[..., Any]) -> bool: # noqa: A002
"""Check if a callable accepts a config argument. """Check if a callable accepts a config argument.
Args: Args:
@ -105,7 +105,7 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
return False return False
def accepts_context(callable: Callable[..., Any]) -> bool: def accepts_context(callable: Callable[..., Any]) -> bool: # noqa: A002
"""Check if a callable accepts a context argument. """Check if a callable accepts a context argument.
Args: Args:
@ -691,7 +691,7 @@ def get_unique_config_specs(
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
) )
unique: list[ConfigurableFieldSpec] = [] unique: list[ConfigurableFieldSpec] = []
for id, dupes in grouped: for spec_id, dupes in grouped:
first = next(dupes) first = next(dupes)
others = list(dupes) others = list(dupes)
if len(others) == 0 or all(o == first for o in others): if len(others) == 0 or all(o == first for o in others):
@ -699,7 +699,7 @@ def get_unique_config_specs(
else: else:
msg = ( msg = (
"RunnableSequence contains conflicting config specs" "RunnableSequence contains conflicting config specs"
f"for {id}: {[first] + others}" f"for {spec_id}: {[first] + others}"
) )
raise ValueError(msg) raise ValueError(msg)
return unique return unique

View File

@ -184,7 +184,7 @@ class StructuredQuery(Expr):
def __init__( def __init__(
self, self,
query: str, query: str,
filter: Optional[FilterDirective], filter: Optional[FilterDirective], # noqa: A002
limit: Optional[int] = None, limit: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:

View File

@ -943,17 +943,17 @@ def _handle_tool_error(
def _prep_run_args( def _prep_run_args(
input: Union[str, dict, ToolCall], value: Union[str, dict, ToolCall],
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
**kwargs: Any, **kwargs: Any,
) -> tuple[Union[str, dict], dict]: ) -> tuple[Union[str, dict], dict]:
config = ensure_config(config) config = ensure_config(config)
if _is_tool_call(input): if _is_tool_call(value):
tool_call_id: Optional[str] = cast("ToolCall", input)["id"] tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
tool_input: Union[str, dict] = cast("ToolCall", input)["args"].copy() tool_input: Union[str, dict] = cast("ToolCall", value)["args"].copy()
else: else:
tool_call_id = None tool_call_id = None
tool_input = cast("Union[str, dict]", input) tool_input = cast("Union[str, dict]", value)
return ( return (
tool_input, tool_input,
dict( dict(

View File

@ -740,7 +740,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def _astream_events_implementation_v1( async def _astream_events_implementation_v1(
runnable: Runnable[Input, Output], runnable: Runnable[Input, Output],
input: Any, value: Any,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
include_names: Optional[Sequence[str]] = None, include_names: Optional[Sequence[str]] = None,
@ -789,7 +789,7 @@ async def _astream_events_implementation_v1(
async for log in _astream_log_implementation( async for log in _astream_log_implementation(
runnable, runnable,
input, value,
config=config, config=config,
stream=stream, stream=stream,
diff=True, diff=True,
@ -810,7 +810,7 @@ async def _astream_events_implementation_v1(
tags=root_tags, tags=root_tags,
metadata=root_metadata, metadata=root_metadata,
data={ data={
"input": input, "input": value,
}, },
parent_ids=[], # Not supported in v1 parent_ids=[], # Not supported in v1
) )
@ -924,7 +924,7 @@ async def _astream_events_implementation_v1(
async def _astream_events_implementation_v2( async def _astream_events_implementation_v2(
runnable: Runnable[Input, Output], runnable: Runnable[Input, Output],
input: Any, value: Any,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
include_names: Optional[Sequence[str]] = None, include_names: Optional[Sequence[str]] = None,
@ -972,7 +972,7 @@ async def _astream_events_implementation_v2(
async def consume_astream() -> None: async def consume_astream() -> None:
try: try:
# if astream also calls tap_output_aiter this will be a no-op # if astream also calls tap_output_aiter this will be a no-op
async with aclosing(runnable.astream(input, config, **kwargs)) as stream: async with aclosing(runnable.astream(value, config, **kwargs)) as stream:
async for _ in event_streamer.tap_output_aiter(run_id, stream): async for _ in event_streamer.tap_output_aiter(run_id, stream):
# All the content will be picked up # All the content will be picked up
pass pass
@ -993,7 +993,7 @@ async def _astream_events_implementation_v2(
# chain are not available until the entire input is consumed. # chain are not available until the entire input is consumed.
# As a temporary solution, we'll modify the input to be the input # As a temporary solution, we'll modify the input to be the input
# that was passed into the chain. # that was passed into the chain.
event["data"]["input"] = input event["data"]["input"] = value
first_event_run_id = event["run_id"] first_event_run_id = event["run_id"]
yield event yield event
continue continue

View File

@ -580,7 +580,7 @@ def _get_standardized_outputs(
@overload @overload
def _astream_log_implementation( def _astream_log_implementation(
runnable: Runnable[Input, Output], runnable: Runnable[Input, Output],
input: Any, value: Any,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
stream: LogStreamCallbackHandler, stream: LogStreamCallbackHandler,
@ -593,7 +593,7 @@ def _astream_log_implementation(
@overload @overload
def _astream_log_implementation( def _astream_log_implementation(
runnable: Runnable[Input, Output], runnable: Runnable[Input, Output],
input: Any, value: Any,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
stream: LogStreamCallbackHandler, stream: LogStreamCallbackHandler,
@ -605,7 +605,7 @@ def _astream_log_implementation(
async def _astream_log_implementation( async def _astream_log_implementation(
runnable: Runnable[Input, Output], runnable: Runnable[Input, Output],
input: Any, value: Any,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
stream: LogStreamCallbackHandler, stream: LogStreamCallbackHandler,
@ -651,7 +651,7 @@ async def _astream_log_implementation(
prev_final_output: Optional[Output] = None prev_final_output: Optional[Output] = None
final_output: Optional[Output] = None final_output: Optional[Output] = None
async for chunk in runnable.astream(input, config, **kwargs): async for chunk in runnable.astream(value, config, **kwargs):
prev_final_output = final_output prev_final_output = final_output
if final_output is None: if final_output is None:
final_output = chunk final_output = chunk

View File

@ -596,7 +596,7 @@ def convert_to_json_schema(
@beta() @beta()
def tool_example_to_messages( def tool_example_to_messages(
input: str, input: str, # noqa: A002
tool_calls: list[BaseModel], tool_calls: list[BaseModel],
tool_outputs: Optional[list[str]] = None, tool_outputs: Optional[list[str]] = None,
*, *,

View File

@ -363,7 +363,7 @@ class InMemoryVectorStore(VectorStore):
self, self,
embedding: list[float], embedding: list[float],
k: int = 4, k: int = 4,
filter: Optional[Callable[[Document], bool]] = None, filter: Optional[Callable[[Document], bool]] = None, # noqa: A002
) -> list[tuple[Document, float, list[float]]]: ) -> list[tuple[Document, float, list[float]]]:
# get all docs with fixed order in list # get all docs with fixed order in list
docs = list(self.store.values()) docs = list(self.store.values())
@ -402,7 +402,7 @@ class InMemoryVectorStore(VectorStore):
self, self,
embedding: list[float], embedding: list[float],
k: int = 4, k: int = 4,
filter: Optional[Callable[[Document], bool]] = None, filter: Optional[Callable[[Document], bool]] = None, # noqa: A002
**_kwargs: Any, **_kwargs: Any,
) -> list[tuple[Document, float]]: ) -> list[tuple[Document, float]]:
"""Search for the most similar documents to the given embedding. """Search for the most similar documents to the given embedding.

View File

@ -100,7 +100,6 @@ ignore = [
"UP007", # Doesn't play well with Pydantic in Python 3.9 "UP007", # Doesn't play well with Pydantic in Python 3.9
# TODO rules # TODO rules
"A",
"ANN401", "ANN401",
"BLE", "BLE",
"ERA", "ERA",

View File

@ -869,9 +869,9 @@ def create_image_data() -> str:
return "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q==" # noqa: E501 return "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q==" # noqa: E501
def create_base64_image(format: str = "jpeg") -> str: def create_base64_image(image_format: str = "jpeg") -> str:
data = create_image_data() data = create_image_data()
return f"data:image/{format};base64,{data}" return f"data:image/{image_format};base64,{data}"
def test_convert_to_openai_messages_string() -> None: def test_convert_to_openai_messages_string() -> None:

View File

@ -639,7 +639,7 @@ def test_parse_with_different_pydantic_1_proper() -> None:
def test_max_tokens_error(caplog: Any) -> None: def test_max_tokens_error(caplog: Any) -> None:
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True) parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
input = AIMessage( message = AIMessage(
content="", content="",
tool_calls=[ tool_calls=[
{ {
@ -651,7 +651,7 @@ def test_max_tokens_error(caplog: Any) -> None:
response_metadata={"stop_reason": "max_tokens"}, response_metadata={"stop_reason": "max_tokens"},
) )
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
_ = parser.invoke(input) _ = parser.invoke(message)
assert any( assert any(
"`max_tokens` stop reason" in msg and record.levelname == "ERROR" "`max_tokens` stop reason" in msg and record.levelname == "ERROR"
for record, msg in zip(caplog.records, caplog.messages) for record, msg in zip(caplog.records, caplog.messages)

View File

@ -15,11 +15,11 @@ EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute()
@contextmanager @contextmanager
def change_directory(dir: Path) -> Iterator: def change_directory(dir_path: Path) -> Iterator:
"""Change the working directory to the right folder.""" """Change the working directory to the right folder."""
origin = Path().absolute() origin = Path().absolute()
try: try:
os.chdir(dir) os.chdir(dir_path)
yield yield
finally: finally:
os.chdir(origin) os.chdir(origin)

View File

@ -220,8 +220,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
str_parser = StrOutputParser() str_parser = StrOutputParser()
xml_parser = XMLOutputParser() xml_parser = XMLOutputParser()
def conditional_str_parser(input: str) -> Runnable: def conditional_str_parser(value: str) -> Runnable:
if input == "a": if value == "a":
return str_parser return str_parser
return xml_parser return xml_parser

View File

@ -111,9 +111,9 @@ async def test_input_messages_async() -> None:
def test_input_dict() -> None: def test_input_dict() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: "you said: " lambda params: "you said: "
+ "\n".join( + "\n".join(
str(m.content) for m in input["messages"] if isinstance(m, HumanMessage) str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
) )
) )
get_session_history = _get_get_session_history() get_session_history = _get_get_session_history()
@ -131,9 +131,9 @@ def test_input_dict() -> None:
async def test_input_dict_async() -> None: async def test_input_dict_async() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: "you said: " lambda params: "you said: "
+ "\n".join( + "\n".join(
str(m.content) for m in input["messages"] if isinstance(m, HumanMessage) str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
) )
) )
get_session_history = _get_get_session_history() get_session_history = _get_get_session_history()
@ -153,10 +153,10 @@ async def test_input_dict_async() -> None:
def test_input_dict_with_history_key() -> None: def test_input_dict_with_history_key() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: "you said: " lambda params: "you said: "
+ "\n".join( + "\n".join(
[str(m.content) for m in input["history"] if isinstance(m, HumanMessage)] [str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
+ [input["input"]] + [params["input"]]
) )
) )
get_session_history = _get_get_session_history() get_session_history = _get_get_session_history()
@ -175,10 +175,10 @@ def test_input_dict_with_history_key() -> None:
async def test_input_dict_with_history_key_async() -> None: async def test_input_dict_with_history_key_async() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: "you said: " lambda params: "you said: "
+ "\n".join( + "\n".join(
[str(m.content) for m in input["history"] if isinstance(m, HumanMessage)] [str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
+ [input["input"]] + [params["input"]]
) )
) )
get_session_history = _get_get_session_history() get_session_history = _get_get_session_history()
@ -197,15 +197,15 @@ async def test_input_dict_with_history_key_async() -> None:
def test_output_message() -> None: def test_output_message() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: AIMessage( lambda params: AIMessage(
content="you said: " content="you said: "
+ "\n".join( + "\n".join(
[ [
str(m.content) str(m.content)
for m in input["history"] for m in params["history"]
if isinstance(m, HumanMessage) if isinstance(m, HumanMessage)
] ]
+ [input["input"]] + [params["input"]]
) )
) )
) )
@ -225,15 +225,15 @@ def test_output_message() -> None:
async def test_output_message_async() -> None: async def test_output_message_async() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: AIMessage( lambda params: AIMessage(
content="you said: " content="you said: "
+ "\n".join( + "\n".join(
[ [
str(m.content) str(m.content)
for m in input["history"] for m in params["history"]
if isinstance(m, HumanMessage) if isinstance(m, HumanMessage)
] ]
+ [input["input"]] + [params["input"]]
) )
) )
) )
@ -302,16 +302,16 @@ async def test_input_messages_output_message_async() -> None:
def test_output_messages() -> None: def test_output_messages() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: [ lambda params: [
AIMessage( AIMessage(
content="you said: " content="you said: "
+ "\n".join( + "\n".join(
[ [
str(m.content) str(m.content)
for m in input["history"] for m in params["history"]
if isinstance(m, HumanMessage) if isinstance(m, HumanMessage)
] ]
+ [input["input"]] + [params["input"]]
) )
) )
] ]
@ -332,16 +332,16 @@ def test_output_messages() -> None:
async def test_output_messages_async() -> None: async def test_output_messages_async() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: [ lambda params: [
AIMessage( AIMessage(
content="you said: " content="you said: "
+ "\n".join( + "\n".join(
[ [
str(m.content) str(m.content)
for m in input["history"] for m in params["history"]
if isinstance(m, HumanMessage) if isinstance(m, HumanMessage)
] ]
+ [input["input"]] + [params["input"]]
) )
) )
] ]
@ -362,17 +362,17 @@ async def test_output_messages_async() -> None:
def test_output_dict() -> None: def test_output_dict() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: { lambda params: {
"output": [ "output": [
AIMessage( AIMessage(
content="you said: " content="you said: "
+ "\n".join( + "\n".join(
[ [
str(m.content) str(m.content)
for m in input["history"] for m in params["history"]
if isinstance(m, HumanMessage) if isinstance(m, HumanMessage)
] ]
+ [input["input"]] + [params["input"]]
) )
) )
] ]
@ -395,17 +395,17 @@ def test_output_dict() -> None:
async def test_output_dict_async() -> None: async def test_output_dict_async() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: { lambda params: {
"output": [ "output": [
AIMessage( AIMessage(
content="you said: " content="you said: "
+ "\n".join( + "\n".join(
[ [
str(m.content) str(m.content)
for m in input["history"] for m in params["history"]
if isinstance(m, HumanMessage) if isinstance(m, HumanMessage)
] ]
+ [input["input"]] + [params["input"]]
) )
) )
] ]
@ -431,17 +431,17 @@ def test_get_input_schema_input_dict() -> None:
input: Union[str, BaseMessage, Sequence[BaseMessage]] input: Union[str, BaseMessage, Sequence[BaseMessage]]
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: { lambda params: {
"output": [ "output": [
AIMessage( AIMessage(
content="you said: " content="you said: "
+ "\n".join( + "\n".join(
[ [
str(m.content) str(m.content)
for m in input["history"] for m in params["history"]
if isinstance(m, HumanMessage) if isinstance(m, HumanMessage)
] ]
+ [input["input"]] + [params["input"]]
) )
) )
] ]
@ -463,17 +463,17 @@ def test_get_input_schema_input_dict() -> None:
def test_get_output_schema() -> None: def test_get_output_schema() -> None:
"""Test get output schema.""" """Test get output schema."""
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: { lambda params: {
"output": [ "output": [
AIMessage( AIMessage(
content="you said: " content="you said: "
+ "\n".join( + "\n".join(
[ [
str(m.content) str(m.content)
for m in input["history"] for m in params["history"]
if isinstance(m, HumanMessage) if isinstance(m, HumanMessage)
] ]
+ [input["input"]] + [params["input"]]
) )
) )
] ]
@ -531,8 +531,8 @@ def test_get_input_schema_input_messages() -> None:
def test_using_custom_config_specs() -> None: def test_using_custom_config_specs() -> None:
"""Test that we can configure which keys should be passed to the session factory.""" """Test that we can configure which keys should be passed to the session factory."""
def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]: def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
messages = input["messages"] messages = params["messages"]
return [ return [
AIMessage( AIMessage(
content="you said: " content="you said: "
@ -644,8 +644,8 @@ def test_using_custom_config_specs() -> None:
async def test_using_custom_config_specs_async() -> None: async def test_using_custom_config_specs_async() -> None:
"""Test that we can configure which keys should be passed to the session factory.""" """Test that we can configure which keys should be passed to the session factory."""
def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]: def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
messages = input["messages"] messages = params["messages"]
return [ return [
AIMessage( AIMessage(
content="you said: " content="you said: "
@ -757,12 +757,12 @@ async def test_using_custom_config_specs_async() -> None:
def test_ignore_session_id() -> None: def test_ignore_session_id() -> None:
"""Test without config.""" """Test without config."""
def _fake_llm(input: list[BaseMessage]) -> list[BaseMessage]: def _fake_llm(messages: list[BaseMessage]) -> list[BaseMessage]:
return [ return [
AIMessage( AIMessage(
content="you said: " content="you said: "
+ "\n".join( + "\n".join(
str(m.content) for m in input if isinstance(m, HumanMessage) str(m.content) for m in messages if isinstance(m, HumanMessage)
) )
) )
] ]

View File

@ -564,8 +564,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
"required": ["bye", "hello"], "required": ["bye", "hello"],
} }
def get_value(input): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202 def get_value(value): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
return input["variable_name"] return value["variable_name"]
assert RunnableLambda(get_value).get_input_jsonschema() == { assert RunnableLambda(get_value).get_input_jsonschema() == {
"title": "get_value_input", "title": "get_value_input",
@ -574,8 +574,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
"required": ["variable_name"], "required": ["variable_name"],
} }
async def aget_value(input): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202 async def aget_value(value): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
return (input["variable_name"], input.get("another")) return (value["variable_name"], value.get("another"))
assert RunnableLambda(aget_value).get_input_jsonschema() == { assert RunnableLambda(aget_value).get_input_jsonschema() == {
"title": "aget_value_input", "title": "aget_value_input",
@ -587,11 +587,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
"required": ["another", "variable_name"], "required": ["another", "variable_name"],
} }
async def aget_values(input): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202 async def aget_values(value): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
return { return {
"hello": input["variable_name"], "hello": value["variable_name"],
"bye": input["variable_name"], "bye": value["variable_name"],
"byebye": input["yo"], "byebye": value["yo"],
} }
assert RunnableLambda(aget_values).get_input_jsonschema() == { assert RunnableLambda(aget_values).get_input_jsonschema() == {
@ -613,11 +613,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
bye: str bye: str
byebye: int byebye: int
async def aget_values_typed(input: InputType) -> OutputType: async def aget_values_typed(value: InputType) -> OutputType:
return { return {
"hello": input["variable_name"], "hello": value["variable_name"],
"bye": input["variable_name"], "bye": value["variable_name"],
"byebye": input["yo"], "byebye": value["yo"],
} }
assert _normalize_schema( assert _normalize_schema(
@ -2592,8 +2592,8 @@ async def test_prompt_with_llm_and_async_lambda(
) )
llm = FakeListLLM(responses=["foo", "bar"]) llm = FakeListLLM(responses=["foo", "bar"])
async def passthrough(input: Any) -> Any: async def passthrough(value: Any) -> Any:
return input return value
chain = prompt | llm | passthrough chain = prompt | llm | passthrough
@ -2946,12 +2946,12 @@ def test_higher_order_lambda_runnable(
input={"question": lambda x: x["question"]}, input={"question": lambda x: x["question"]},
) )
def router(input: dict[str, Any]) -> Runnable: def router(params: dict[str, Any]) -> Runnable:
if input["key"] == "math": if params["key"] == "math":
return itemgetter("input") | math_chain return itemgetter("input") | math_chain
if input["key"] == "english": if params["key"] == "english":
return itemgetter("input") | english_chain return itemgetter("input") | english_chain
msg = f"Unknown key: {input['key']}" msg = f"Unknown key: {params['key']}"
raise ValueError(msg) raise ValueError(msg)
chain: Runnable = input_map | router chain: Runnable = input_map | router
@ -3002,12 +3002,12 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
input={"question": lambda x: x["question"]}, input={"question": lambda x: x["question"]},
) )
def router(input: dict[str, Any]) -> Runnable: def router(value: dict[str, Any]) -> Runnable:
if input["key"] == "math": if value["key"] == "math":
return itemgetter("input") | math_chain return itemgetter("input") | math_chain
if input["key"] == "english": if value["key"] == "english":
return itemgetter("input") | english_chain return itemgetter("input") | english_chain
msg = f"Unknown key: {input['key']}" msg = f"Unknown key: {value['key']}"
raise ValueError(msg) raise ValueError(msg)
chain: Runnable = input_map | router chain: Runnable = input_map | router
@ -3024,12 +3024,12 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
assert result2 == ["4", "2"] assert result2 == ["4", "2"]
# Test ainvoke # Test ainvoke
async def arouter(input: dict[str, Any]) -> Runnable: async def arouter(params: dict[str, Any]) -> Runnable:
if input["key"] == "math": if params["key"] == "math":
return itemgetter("input") | math_chain return itemgetter("input") | math_chain
if input["key"] == "english": if params["key"] == "english":
return itemgetter("input") | english_chain return itemgetter("input") | english_chain
msg = f"Unknown key: {input['key']}" msg = f"Unknown key: {params['key']}"
raise ValueError(msg) raise ValueError(msg)
achain: Runnable = input_map | arouter achain: Runnable = input_map | arouter
@ -4125,6 +4125,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
def __init__(self, fail_starts_with: str) -> None: def __init__(self, fail_starts_with: str) -> None:
self.fail_starts_with = fail_starts_with self.fail_starts_with = fail_starts_with
@override
def invoke( def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any: ) -> Any:
@ -4135,15 +4136,15 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
inputs: list[str], inputs: list[str],
) -> list: ) -> list:
outputs: list[Any] = [] outputs: list[Any] = []
for input in inputs: for value in inputs:
if input.startswith(self.fail_starts_with): if value.startswith(self.fail_starts_with):
outputs.append( outputs.append(
ValueError( ValueError(
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}" f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}"
) )
) )
else: else:
outputs.append(input + "a") outputs.append(value + "a")
return outputs return outputs
def batch( def batch(
@ -4264,6 +4265,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
def __init__(self, fail_starts_with: str) -> None: def __init__(self, fail_starts_with: str) -> None:
self.fail_starts_with = fail_starts_with self.fail_starts_with = fail_starts_with
@override
def invoke( def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any: ) -> Any:
@ -4274,15 +4276,15 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
inputs: list[str], inputs: list[str],
) -> list: ) -> list:
outputs: list[Any] = [] outputs: list[Any] = []
for input in inputs: for value in inputs:
if input.startswith(self.fail_starts_with): if value.startswith(self.fail_starts_with):
outputs.append( outputs.append(
ValueError( ValueError(
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}" f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}"
) )
) )
else: else:
outputs.append(input + "a") outputs.append(value + "a")
return outputs return outputs
async def abatch( async def abatch(
@ -5006,10 +5008,10 @@ def test_runnable_iter_context_config() -> None:
fake = RunnableLambda(len) fake = RunnableLambda(len)
@chain @chain
def gen(input: str) -> Iterator[int]: def gen(value: str) -> Iterator[int]:
yield fake.invoke(input) yield fake.invoke(value)
yield fake.invoke(input * 2) yield fake.invoke(value * 2)
yield fake.invoke(input * 3) yield fake.invoke(value * 3)
assert gen.get_input_jsonschema() == { assert gen.get_input_jsonschema() == {
"title": "gen_input", "title": "gen_input",
@ -5064,10 +5066,10 @@ async def test_runnable_iter_context_config_async() -> None:
fake = RunnableLambda(len) fake = RunnableLambda(len)
@chain @chain
async def agen(input: str) -> AsyncIterator[int]: async def agen(value: str) -> AsyncIterator[int]:
yield await fake.ainvoke(input) yield await fake.ainvoke(value)
yield await fake.ainvoke(input * 2) yield await fake.ainvoke(value * 2)
yield await fake.ainvoke(input * 3) yield await fake.ainvoke(value * 3)
assert agen.get_input_jsonschema() == { assert agen.get_input_jsonschema() == {
"title": "agen_input", "title": "agen_input",
@ -5130,10 +5132,10 @@ def test_runnable_lambda_context_config() -> None:
fake = RunnableLambda(len) fake = RunnableLambda(len)
@chain @chain
def fun(input: str) -> int: def fun(value: str) -> int:
output = fake.invoke(input) output = fake.invoke(value)
output += fake.invoke(input * 2) output += fake.invoke(value * 2)
output += fake.invoke(input * 3) output += fake.invoke(value * 3)
return output return output
assert fun.get_input_jsonschema() == {"title": "fun_input", "type": "string"} assert fun.get_input_jsonschema() == {"title": "fun_input", "type": "string"}
@ -5186,10 +5188,10 @@ async def test_runnable_lambda_context_config_async() -> None:
fake = RunnableLambda(len) fake = RunnableLambda(len)
@chain @chain
async def afun(input: str) -> int: async def afun(value: str) -> int:
output = await fake.ainvoke(input) output = await fake.ainvoke(value)
output += await fake.ainvoke(input * 2) output += await fake.ainvoke(value * 2)
output += await fake.ainvoke(input * 3) output += await fake.ainvoke(value * 3)
return output return output
assert afun.get_input_jsonschema() == {"title": "afun_input", "type": "string"} assert afun.get_input_jsonschema() == {"title": "afun_input", "type": "string"}
@ -5242,12 +5244,12 @@ async def test_runnable_gen_transform() -> None:
for i in range(length): for i in range(length):
yield i yield i
def plus_one(input: Iterator[int]) -> Iterator[int]: def plus_one(ints: Iterator[int]) -> Iterator[int]:
for i in input: for i in ints:
yield i + 1 yield i + 1
async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]: async def aplus_one(ints: AsyncIterator[int]) -> AsyncIterator[int]:
async for i in input: async for i in ints:
yield i + 1 yield i + 1
chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one

View File

@ -543,10 +543,10 @@ async def test_astream_events_from_model() -> None:
) )
@RunnableLambda @RunnableLambda
def i_dont_stream(input: Any, config: RunnableConfig) -> Any: def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
return model.invoke(input) return model.invoke(value)
return model.invoke(input, config) return model.invoke(value, config)
events = await _collect_events(i_dont_stream.astream_events("hello", version="v1")) events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
_assert_events_equal_allow_superset_metadata( _assert_events_equal_allow_superset_metadata(
@ -667,10 +667,10 @@ async def test_astream_events_from_model() -> None:
) )
@RunnableLambda @RunnableLambda
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
return await model.ainvoke(input) return await model.ainvoke(value)
return await model.ainvoke(input, config) return await model.ainvoke(value, config)
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1")) events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
_assert_events_equal_allow_superset_metadata( _assert_events_equal_allow_superset_metadata(

View File

@ -613,10 +613,10 @@ async def test_astream_with_model_in_chain() -> None:
) )
@RunnableLambda @RunnableLambda
def i_dont_stream(input: Any, config: RunnableConfig) -> Any: def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
return model.invoke(input) return model.invoke(value)
return model.invoke(input, config) return model.invoke(value, config)
events = await _collect_events(i_dont_stream.astream_events("hello", version="v2")) events = await _collect_events(i_dont_stream.astream_events("hello", version="v2"))
_assert_events_equal_allow_superset_metadata( _assert_events_equal_allow_superset_metadata(
@ -721,10 +721,10 @@ async def test_astream_with_model_in_chain() -> None:
) )
@RunnableLambda @RunnableLambda
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
return await model.ainvoke(input) return await model.ainvoke(value)
return await model.ainvoke(input, config) return await model.ainvoke(value, config)
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2")) events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2"))
_assert_events_equal_allow_superset_metadata( _assert_events_equal_allow_superset_metadata(
@ -2079,6 +2079,7 @@ class StreamingRunnable(Runnable[Input, Output]):
msg = "Server side error" msg = "Server side error"
raise ValueError(msg) raise ValueError(msg)
@override
def stream( def stream(
self, self,
input: Input, input: Input,
@ -2413,14 +2414,14 @@ async def test_break_astream_events() -> None:
def __init__(self) -> None: def __init__(self) -> None:
self.reset() self.reset()
async def __call__(self, input: Any) -> Any: async def __call__(self, value: Any) -> Any:
self.started = True self.started = True
try: try:
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
except asyncio.CancelledError: except asyncio.CancelledError:
self.cancelled = True self.cancelled = True
raise raise
return input return value
def reset(self) -> None: def reset(self) -> None:
self.started = False self.started = False
@ -2433,11 +2434,11 @@ async def test_break_astream_events() -> None:
outer_cancelled = False outer_cancelled = False
@chain @chain
async def sequence(input: Any) -> Any: async def sequence(value: Any) -> Any:
try: try:
yield await alittlewhile(input) yield await alittlewhile(value)
yield await awhile(input) yield await awhile(value)
yield await anotherwhile(input) yield await anotherwhile(value)
except asyncio.CancelledError: except asyncio.CancelledError:
nonlocal outer_cancelled nonlocal outer_cancelled
outer_cancelled = True outer_cancelled = True
@ -2478,14 +2479,14 @@ async def test_cancel_astream_events() -> None:
def __init__(self) -> None: def __init__(self) -> None:
self.reset() self.reset()
async def __call__(self, input: Any) -> Any: async def __call__(self, value: Any) -> Any:
self.started = True self.started = True
try: try:
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
except asyncio.CancelledError: except asyncio.CancelledError:
self.cancelled = True self.cancelled = True
raise raise
return input return value
def reset(self) -> None: def reset(self) -> None:
self.started = False self.started = False
@ -2498,11 +2499,11 @@ async def test_cancel_astream_events() -> None:
outer_cancelled = False outer_cancelled = False
@chain @chain
async def sequence(input: Any) -> Any: async def sequence(value: Any) -> Any:
try: try:
yield await alittlewhile(input) yield await alittlewhile(value)
yield await awhile(input) yield await awhile(value)
yield await anotherwhile(input) yield await anotherwhile(value)
except asyncio.CancelledError: except asyncio.CancelledError:
nonlocal outer_cancelled nonlocal outer_cancelled
outer_cancelled = True outer_cancelled = True

View File

@ -47,23 +47,23 @@ global_agent = RunnableLambda(lambda x: x * 3)
def test_nonlocals() -> None: def test_nonlocals() -> None:
agent = RunnableLambda(lambda x: x * 2) agent = RunnableLambda(lambda x: x * 2)
def my_func(input: str, agent: dict[str, str]) -> str: def my_func(value: str, agent: dict[str, str]) -> str:
return agent.get("agent_name", input) return agent.get("agent_name", value)
def my_func2(input: str) -> str: def my_func2(value: str) -> str:
return agent.get("agent_name", input) # type: ignore[attr-defined] return agent.get("agent_name", value) # type: ignore[attr-defined]
def my_func3(input: str) -> str: def my_func3(value: str) -> str:
return agent.invoke(input) return agent.invoke(value)
def my_func4(input: str) -> str: def my_func4(value: str) -> str:
return global_agent.invoke(input) return global_agent.invoke(value)
def my_func5() -> tuple[Callable[[str], str], RunnableLambda]: def my_func5() -> tuple[Callable[[str], str], RunnableLambda]:
global_agent = RunnableLambda(lambda x: x * 3) global_agent = RunnableLambda(lambda x: x * 3)
def my_func6(input: str) -> str: def my_func6(value: str) -> str:
return global_agent.invoke(input) return global_agent.invoke(value)
return my_func6, global_agent return my_func6, global_agent

View File

@ -2,17 +2,17 @@ from langchain_core.globals import get_debug, set_debug
def test_debug_is_settable_via_setter() -> None: def test_debug_is_settable_via_setter() -> None:
from langchain_core import globals from langchain_core import globals as globals_
from langchain_core.callbacks.manager import _get_debug from langchain_core.callbacks.manager import _get_debug
previous_value = globals._debug previous_value = globals_._debug
previous_fn_reading = _get_debug() previous_fn_reading = _get_debug()
assert previous_value == previous_fn_reading assert previous_value == previous_fn_reading
# Flip the value of the flag. # Flip the value of the flag.
set_debug(not previous_value) set_debug(not previous_value)
new_value = globals._debug new_value = globals_._debug
new_fn_reading = _get_debug() new_fn_reading = _get_debug()
try: try:

View File

@ -50,7 +50,7 @@ class CustomAddTextsVectorstore(VectorStore):
return ids_ return ids_
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
return [self.store[id] for id in ids if id in self.store] return [self.store[id_] for id_ in ids if id_ in self.store]
@classmethod @classmethod
@override @override
@ -96,7 +96,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
return ids_ return ids_
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
return [self.store[id] for id in ids if id in self.store] return [self.store[id_] for id_ in ids if id_ in self.store]
@classmethod @classmethod
@override @override

View File

@ -1,4 +1,5 @@
version = 1 version = 1
revision = 1
requires-python = ">=3.9, <4.0" requires-python = ">=3.9, <4.0"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.13'", "python_full_version >= '3.13'",