Christophe Bornet 2025-04-26 12:28:41 +02:00 committed by GitHub
commit 132e6ebab3
46 changed files with 442 additions and 381 deletions
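
Reviewer note: the theme of this commit is enabling Ruff's flake8-builtins rules for langchain-core (the pyproject.toml hunk near the end drops `"A"` from the ignore list). Parameters shadowing builtins (`input`, `id`, `filter`, `type`, `vars`, `iter`, `next`, `format`, `dir`, `callable`) are either renamed, mostly with a PEP 8 trailing underscore, or suppressed with `# noqa: A002` where the name is part of the public API. A minimal sketch (hypothetical functions, not from the diff) of what A002 flags and the two remedies used throughout:

```python
# Sketch: Ruff A002 flags an argument that shadows a builtin.
def invoke(input):  # A002: `input` shadows the input() builtin
    return input

def invoke_renamed(input_):  # remedy 1: trailing underscore, used for private code
    return input_

def invoke_public(input):  # noqa: A002  # remedy 2: keep the public name, suppress
    return input
```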

View File

@ -50,7 +50,7 @@ class LangSmithLoader(BaseLoader):
offset: int = 0,
limit: Optional[int] = None,
metadata: Optional[dict] = None,
filter: Optional[str] = None,
filter: Optional[str] = None, # noqa: A002
content_key: str = "",
format_content: Optional[Callable[..., str]] = None,
client: Optional[LangSmithClient] = None,

View File

@ -341,15 +341,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Get the output type for this runnable."""
return AnyMessage
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
if isinstance(input, str):
return StringPromptValue(text=input)
if isinstance(input, Sequence):
return ChatPromptValue(messages=convert_to_messages(input))
def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
if isinstance(model_input, PromptValue):
return model_input
if isinstance(model_input, str):
return StringPromptValue(text=model_input)
if isinstance(model_input, Sequence):
return ChatPromptValue(messages=convert_to_messages(model_input))
msg = (
f"Invalid input type {type(input)}. "
f"Invalid input type {type(model_input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg) # noqa: TRY004
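
Reviewer note (an aside, not from the diff): renaming a parameter only breaks keyword call sites, which is why private helpers like `_convert_input` are renamed freely while public entry points keep `input` plus a suppression. A hedged sketch with hypothetical functions:

```python
# Sketch: rename vs. suppress, from the caller's perspective.
def convert_new(model_input):
    return model_input

convert_new("hi")          # positional callers are unaffected by the rename
# convert_new(input="hi")  # TypeError: unexpected keyword argument 'input'
```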

View File

@ -325,15 +325,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
"""Get the input type for this runnable."""
return str
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
if isinstance(input, str):
return StringPromptValue(text=input)
if isinstance(input, Sequence):
return ChatPromptValue(messages=convert_to_messages(input))
def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
if isinstance(model_input, PromptValue):
return model_input
if isinstance(model_input, str):
return StringPromptValue(text=model_input)
if isinstance(model_input, Sequence):
return ChatPromptValue(messages=convert_to_messages(model_input))
msg = (
f"Invalid input type {type(input)}. "
f"Invalid input type {type(model_input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg) # noqa: TRY004
@ -438,7 +438,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if max_concurrency is None:
try:
llm_result = self.generate_prompt(
[self._convert_input(input) for input in inputs],
[self._convert_input(input_) for input_ in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
@ -484,7 +484,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if max_concurrency is None:
try:
llm_result = await self.agenerate_prompt(
[self._convert_input(input) for input in inputs],
[self._convert_input(input_) for input_ in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],

View File

@ -417,10 +417,10 @@ def add_ai_message_chunks(
else:
usage_metadata = None
id = None
chunk_id = None
for id_ in [left.id] + [o.id for o in others]:
if id_:
id = id_
chunk_id = id_
break
return left.__class__(
example=left.example,
@ -429,7 +429,7 @@ def add_ai_message_chunks(
tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata,
usage_metadata=usage_metadata,
id=id,
id=chunk_id,
)
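
The `id` → `chunk_id` rename above leaves the merge semantics untouched: the first truthy id among the merged chunks wins. An equivalent one-liner, as a sketch:

```python
# Equivalent to the loop above: first truthy id wins, else None.
ids = [None, "run-123", "run-456"]
chunk_id = next((id_ for id_ in ids if id_), None)
assert chunk_id == "run-123"
```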

View File

@ -11,7 +11,11 @@ class RemoveMessage(BaseMessage):
type: Literal["remove"] = "remove"
"""The type of the message (used for serialization). Defaults to "remove"."""
def __init__(self, id: str, **kwargs: Any) -> None:
def __init__(
self,
id: str, # noqa: A002
**kwargs: Any,
) -> None:
"""Create a RemoveMessage.
Args:

View File

@ -208,7 +208,12 @@ class ToolCall(TypedDict):
type: NotRequired[Literal["tool_call"]]
def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall:
def tool_call(
*,
name: str,
args: dict[str, Any],
id: Optional[str], # noqa: A002
) -> ToolCall:
"""Create a tool call.
Args:
@ -254,7 +259,7 @@ def tool_call_chunk(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
id: Optional[str] = None, # noqa: A002
index: Optional[int] = None,
) -> ToolCallChunk:
"""Create a tool call chunk.
@ -292,7 +297,7 @@ def invalid_tool_call(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
id: Optional[str] = None, # noqa: A002
error: Optional[str] = None,
) -> InvalidToolCall:
"""Create an invalid tool call.

View File

@ -212,7 +212,7 @@ def _create_message_from_message_type(
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
tool_calls: Optional[list[dict[str, Any]]] = None,
id: Optional[str] = None,
id: Optional[str] = None, # noqa: A002
**additional_kwargs: Any,
) -> BaseMessage:
"""Create a message from a message type and content string.

View File

@ -9,6 +9,7 @@ from typing import Annotated, Any, Optional, TypeVar, Union
import jsonpatch # type: ignore[import-untyped]
import pydantic
from pydantic import SkipValidation
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
@ -47,6 +48,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""The Pydantic object to use for validation.
If None, no validation is performed."""
@override
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
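
`_diff` delegates to the jsonpatch library; a hedged sketch of what it emits for two successive parses:

```python
# Sketch: JSON Patch operations between consecutive parsed outputs.
import jsonpatch

prev = {"name": "Lang"}
curr = {"name": "LangChain", "year": 2022}
print(jsonpatch.make_patch(prev, curr).patch)
# e.g. [{'op': 'replace', 'path': '/name', 'value': 'LangChain'},
#       {'op': 'add', 'path': '/year', 'value': 2022}]
```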

View File

@ -20,7 +20,10 @@ if TYPE_CHECKING:
T = TypeVar("T")
def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
def droplastn(
iter: Iterator[T], # noqa: A002
n: int,
) -> Iterator[T]:
"""Drop the last n elements of an iterator.
Args:
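
The hunk only shows the signature change; for orientation, a possible implementation matching the documented semantics (a sketch under that assumption, not the repository's code):

```python
# Sketch: buffer n items so the last n are never yielded.
from collections import deque
from collections.abc import Iterator
from typing import TypeVar

T = TypeVar("T")

def droplastn_sketch(it: Iterator[T], n: int) -> Iterator[T]:
    buffer: deque[T] = deque()
    for item in it:
        buffer.append(item)
        if len(buffer) > n:
            yield buffer.popleft()

assert list(droplastn_sketch(iter(range(5)), 2)) == [0, 1, 2]
```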
@ -66,6 +69,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
"""
raise NotImplementedError
@override
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[list[str]]:
@ -99,6 +103,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
for part in self.parse(buffer):
yield [part]
@override
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[list[str]]:

View File

@ -72,6 +72,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _type(self) -> str:
return "json_functions"
@override
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
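
Alongside the renames, many hunks in this commit add `typing_extensions.override` to private overrides; a hedged sketch of what the decorator buys:

```python
# Sketch: @override lets a type checker verify the method really overrides.
from typing_extensions import override

class Base:
    def _diff(self, prev, nxt): ...

class Child(Base):
    @override
    def _diff(self, prev, nxt): ...  # ok; a typo like `_dif` would be flagged
```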

View File

@ -30,7 +30,10 @@ if TYPE_CHECKING:
class BaseTransformOutputParser(BaseOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
def _transform(
self,
input: Iterator[Union[str, BaseMessage]], # noqa: A002
) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
@ -38,7 +41,8 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
yield self.parse_result([Generation(text=chunk)])
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
self,
input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
@ -102,7 +106,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
parsed output, or just the current parsed output.
"""
def _diff(self, prev: Optional[T], next: T) -> T:
def _diff(
self,
prev: Optional[T],
next: T, # noqa: A002
) -> T:
"""Convert parsed outputs into a diff format.
The semantics of this are up to the output parser.
@ -116,6 +124,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
"""
raise NotImplementedError
@override
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
@ -140,6 +149,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
yield parsed
prev_parsed = parsed
@override
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:

View File

@ -8,6 +8,8 @@ from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@ -234,6 +236,7 @@ class XMLOutputParser(BaseTransformOutputParser):
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
raise OutputParserException(msg, llm_output=text) from e
@override
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
@ -242,6 +245,7 @@ class XMLOutputParser(BaseTransformOutputParser):
yield from streaming_parser.parse(chunk)
streaming_parser.close()
@override
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:

View File

@ -472,16 +472,18 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
img_template = cast("_ImageTemplateParam", tmpl)["image_url"]
input_variables = []
if isinstance(img_template, str):
vars = get_template_variables(img_template, template_format)
if vars:
if len(vars) > 1:
variables = get_template_variables(
img_template, template_format
)
if variables:
if len(variables) > 1:
msg = (
"Only one format variable allowed per image"
f" template.\nGot: {vars}"
f" template.\nGot: {variables}"
f"\nFrom: {tmpl}"
)
raise ValueError(msg)
input_variables = [vars[0]]
input_variables = [variables[0]]
img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate(
input_variables=input_variables,

View File

@ -1,3 +1,4 @@
# noqa:A005
"""BasePrompt schema definition."""
from __future__ import annotations
@ -125,20 +126,20 @@ def mustache_template_vars(
Returns:
The variables from the template.
"""
vars: set[str] = set()
variables: set[str] = set()
section_depth = 0
for type, key in mustache.tokenize(template):
if type == "end":
for type_, key in mustache.tokenize(template):
if type_ == "end":
section_depth -= 1
elif (
type in ("variable", "section", "inverted section", "no escape")
type_ in ("variable", "section", "inverted section", "no escape")
and key != "."
and section_depth == 0
):
vars.add(key.split(".")[0])
if type in ("section", "inverted section"):
variables.add(key.split(".")[0])
if type_ in ("section", "inverted section"):
section_depth += 1
return vars
return variables
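
A hedged example of what `mustache_template_vars` collects, traced from the logic above (only top-level names; keys inside a section are scoped to it):

```python
# Expected behavior (sketch): only names at section_depth == 0 are collected.
template = "Hello {{name}}! {{#items}}{{label}}{{/items}}"
# mustache_template_vars(template) -> {"name", "items"}
# `label` is skipped: it occurs inside the `items` section (depth > 0).
```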
Defs = dict[str, "Defs"]
@ -158,17 +159,17 @@ def mustache_schema(
fields = {}
prefix: tuple[str, ...] = ()
section_stack: list[tuple[str, ...]] = []
for type, key in mustache.tokenize(template):
for type_, key in mustache.tokenize(template):
if key == ".":
continue
if type == "end":
if type_ == "end":
if section_stack:
prefix = section_stack.pop()
elif type in ("section", "inverted section"):
elif type_ in ("section", "inverted section"):
section_stack.append(prefix)
prefix = prefix + tuple(key.split("."))
fields[prefix] = False
elif type in ("variable", "no escape"):
elif type_ in ("variable", "no escape"):
fields[prefix + tuple(key.split("."))] = True
defs: Defs = {} # None means leaf node
while fields:

View File

@ -209,6 +209,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
@override
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> list[Document]:
@ -269,6 +270,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
)
return result
@override
async def ainvoke(
self,
input: str,

View File

@ -724,7 +724,10 @@ class Runnable(Generic[Input, Output], ABC):
@abstractmethod
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
self,
input: Input, # noqa: A002
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Output:
"""Transform a single input into an output.
@ -741,7 +744,10 @@ class Runnable(Generic[Input, Output], ABC):
"""
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
self,
input: Input, # noqa: A002
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Output:
"""Default implementation of ainvoke, calls invoke from a thread.
@ -772,14 +778,14 @@ class Runnable(Generic[Input, Output], ABC):
configs = get_config_list(config, len(inputs))
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
def invoke(input_: Input, config: RunnableConfig) -> Union[Output, Exception]:
if return_exceptions:
try:
return self.invoke(input, config, **kwargs)
return self.invoke(input_, config, **kwargs)
except Exception as e:
return e
else:
return self.invoke(input, config, **kwargs)
return self.invoke(input_, config, **kwargs)
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
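
The closure renames here change no behavior; for orientation, a hedged sketch of the `return_exceptions` contract being threaded through:

```python
# Sketch: with return_exceptions=True, failures come back in-place.
from langchain_core.runnables import RunnableLambda

r = RunnableLambda(lambda x: 1 / x)
print(r.batch([1, 0, 2], return_exceptions=True))
# -> [1.0, ZeroDivisionError('division by zero'), 0.5]
```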
@ -826,15 +832,17 @@ class Runnable(Generic[Input, Output], ABC):
configs = get_config_list(config, len(inputs))
def invoke(
i: int, input: Input, config: RunnableConfig
i: int, input_: Input, config: RunnableConfig
) -> tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
out: Union[Output, Exception] = self.invoke(
input_, config, **kwargs
)
except Exception as e:
out = e
else:
out = self.invoke(input, config, **kwargs)
out = self.invoke(input_, config, **kwargs)
return (i, out)
@ -844,8 +852,8 @@ class Runnable(Generic[Input, Output], ABC):
with get_executor_for_config(configs[0]) as executor:
futures = {
executor.submit(invoke, i, input, config)
for i, (input, config) in enumerate(zip(inputs, configs))
executor.submit(invoke, i, input_, config)
for i, (input_, config) in enumerate(zip(inputs, configs))
}
try:
@ -892,15 +900,15 @@ class Runnable(Generic[Input, Output], ABC):
configs = get_config_list(config, len(inputs))
async def ainvoke(
input: Input, config: RunnableConfig
value: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return await self.ainvoke(input, config, **kwargs)
return await self.ainvoke(value, config, **kwargs)
except Exception as e:
return e
else:
return await self.ainvoke(input, config, **kwargs)
return await self.ainvoke(value, config, **kwargs)
coros = map(ainvoke, inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
@ -960,24 +968,24 @@ class Runnable(Generic[Input, Output], ABC):
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def ainvoke_task(
i: int, input: Input, config: RunnableConfig
i: int, input_: Input, config: RunnableConfig
) -> tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = await self.ainvoke(
input, config, **kwargs
input_, config, **kwargs
)
except Exception as e:
out = e
else:
out = await self.ainvoke(input, config, **kwargs)
out = await self.ainvoke(input_, config, **kwargs)
return (i, out)
coros = [
gated_coro(semaphore, ainvoke_task(i, input, config))
gated_coro(semaphore, ainvoke_task(i, input_, config))
if semaphore
else ainvoke_task(i, input, config)
for i, (input, config) in enumerate(zip(inputs, configs))
else ainvoke_task(i, input_, config)
for i, (input_, config) in enumerate(zip(inputs, configs))
]
for coro in asyncio.as_completed(coros):
@ -985,7 +993,7 @@ class Runnable(Generic[Input, Output], ABC):
def stream(
self,
input: Input,
input: Input, # noqa: A002
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
@ -1005,7 +1013,7 @@ class Runnable(Generic[Input, Output], ABC):
async def astream(
self,
input: Input,
input: Input, # noqa: A002
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
@ -1059,7 +1067,7 @@ class Runnable(Generic[Input, Output], ABC):
async def astream_log(
self,
input: Any,
input: Any, # noqa: A002
config: Optional[RunnableConfig] = None,
*,
diff: bool = True,
@ -1130,7 +1138,7 @@ class Runnable(Generic[Input, Output], ABC):
async def astream_events(
self,
input: Any,
input: Any, # noqa: A002
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1", "v2"] = "v2",
@ -1396,7 +1404,7 @@ class Runnable(Generic[Input, Output], ABC):
def transform(
self,
input: Iterator[Input],
input: Iterator[Input], # noqa: A002
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
@ -1438,7 +1446,7 @@ class Runnable(Generic[Input, Output], ABC):
async def atransform(
self,
input: AsyncIterator[Input],
input: AsyncIterator[Input], # noqa: A002
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
@ -1903,7 +1911,7 @@ class Runnable(Generic[Input, Output], ABC):
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
input: Input,
input_: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
serialized: Optional[dict[str, Any]] = None,
@ -1917,7 +1925,7 @@ class Runnable(Generic[Input, Output], ABC):
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
serialized,
input,
input_,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
@ -1930,7 +1938,7 @@ class Runnable(Generic[Input, Output], ABC):
context.run(
call_func_with_variable_args, # type: ignore[arg-type]
func,
input,
input_,
config,
run_manager,
**kwargs,
@ -1953,7 +1961,7 @@ class Runnable(Generic[Input, Output], ABC):
Awaitable[Output],
],
],
input: Input,
input_: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
serialized: Optional[dict[str, Any]] = None,
@ -1967,7 +1975,7 @@ class Runnable(Generic[Input, Output], ABC):
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
serialized,
input,
input_,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
@ -1976,7 +1984,7 @@ class Runnable(Generic[Input, Output], ABC):
child_config = patch_config(config, callbacks=run_manager.get_child())
with set_config_context(child_config) as context:
coro = acall_func_with_variable_args(
func, input, config, run_manager, **kwargs
func, input_, config, run_manager, **kwargs
)
output: Output = await coro_with_context(coro, context)
except BaseException as e:
@ -1999,7 +2007,7 @@ class Runnable(Generic[Input, Output], ABC):
list[Union[Exception, Output]],
],
],
input: list[Input],
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
@ -2011,21 +2019,21 @@ class Runnable(Generic[Input, Output], ABC):
Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.
"""
if not input:
if not inputs:
return []
configs = get_config_list(config, len(input))
configs = get_config_list(config, len(inputs))
callback_managers = [get_callback_manager_for_config(c) for c in configs]
run_managers = [
callback_manager.on_chain_start(
None,
input,
input_,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for callback_manager, input, config in zip(
callback_managers, input, configs
for callback_manager, input_, config in zip(
callback_managers, inputs, configs
)
]
try:
@ -2036,12 +2044,12 @@ class Runnable(Generic[Input, Output], ABC):
]
if accepts_run_manager(func):
kwargs["run_manager"] = run_managers
output = func(input, **kwargs) # type: ignore[call-arg]
output = func(inputs, **kwargs) # type: ignore[call-arg]
except BaseException as e:
for run_manager in run_managers:
run_manager.on_chain_error(e)
if return_exceptions:
return cast("list[Output]", [e for _ in input])
return cast("list[Output]", [e for _ in inputs])
raise
else:
first_exception: Optional[Exception] = None
@ -2072,7 +2080,7 @@ class Runnable(Generic[Input, Output], ABC):
Awaitable[list[Union[Exception, Output]]],
],
],
input: list[Input],
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
@ -2085,22 +2093,22 @@ class Runnable(Generic[Input, Output], ABC):
with callbacks.
Use this method to implement invoke() in subclasses.
"""
if not input:
if not inputs:
return []
configs = get_config_list(config, len(input))
configs = get_config_list(config, len(inputs))
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
callback_manager.on_chain_start(
None,
input,
input_,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for callback_manager, input, config in zip(
callback_managers, input, configs
for callback_manager, input_, config in zip(
callback_managers, inputs, configs
)
)
)
@ -2112,13 +2120,13 @@ class Runnable(Generic[Input, Output], ABC):
]
if accepts_run_manager(func):
kwargs["run_manager"] = run_managers
output = await func(input, **kwargs) # type: ignore[call-arg]
output = await func(inputs, **kwargs) # type: ignore[call-arg]
except BaseException as e:
await asyncio.gather(
*(run_manager.on_chain_error(e) for run_manager in run_managers)
)
if return_exceptions:
return cast("list[Output]", [e for _ in input])
return cast("list[Output]", [e for _ in inputs])
raise
else:
first_exception: Optional[Exception] = None
@ -2136,7 +2144,7 @@ class Runnable(Generic[Input, Output], ABC):
def _transform_stream_with_config(
self,
input: Iterator[Input],
inputs: Iterator[Input],
transformer: Union[
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]],
@ -2163,7 +2171,7 @@ class Runnable(Generic[Input, Output], ABC):
from langchain_core.tracers._streaming import _StreamingCallbackHandler
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2)
input_for_tracing, input_for_transform = tee(inputs, 2)
# Start the input iterator to ensure the input Runnable starts before this one
final_input: Optional[Input] = next(input_for_tracing, None)
final_input_supported = True
@ -2237,7 +2245,7 @@ class Runnable(Generic[Input, Output], ABC):
async def _atransform_stream_with_config(
self,
input: AsyncIterator[Input],
inputs: AsyncIterator[Input],
transformer: Union[
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
Callable[
@ -2267,7 +2275,7 @@ class Runnable(Generic[Input, Output], ABC):
from langchain_core.tracers._streaming import _StreamingCallbackHandler
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = atee(input, 2)
input_for_tracing, input_for_transform = atee(inputs, 2)
# Start the input iterator to ensure the input Runnable starts before this one
final_input: Optional[Input] = await py_anext(input_for_tracing, None)
final_input_supported = True
@ -3019,6 +3027,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
input_ = input
# invoke all steps in sequence
try:
@ -3029,16 +3038,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
)
with set_config_context(config) as context:
if i == 0:
input = context.run(step.invoke, input, config, **kwargs)
input_ = context.run(step.invoke, input_, config, **kwargs)
else:
input = context.run(step.invoke, input, config)
input_ = context.run(step.invoke, input_, config)
# finish the root run
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(input)
return cast("Output", input)
run_manager.on_chain_end(input_)
return cast("Output", input_)
@override
async def ainvoke(
@ -3059,6 +3068,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
input_ = input
# invoke all steps in sequence
try:
@ -3069,17 +3079,17 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
)
with set_config_context(config) as context:
if i == 0:
part = functools.partial(step.ainvoke, input, config, **kwargs)
part = functools.partial(step.ainvoke, input_, config, **kwargs)
else:
part = functools.partial(step.ainvoke, input, config)
input = await coro_with_context(part(), context, create_task=True)
part = functools.partial(step.ainvoke, input_, config)
input_ = await coro_with_context(part(), context, create_task=True)
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(input)
return cast("Output", input)
await run_manager.on_chain_end(input_)
return cast("Output", input_)
@override
def batch(
@ -3117,11 +3127,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
run_managers = [
cm.on_chain_start(
None,
input,
input_,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
for cm, input_, config in zip(callback_managers, inputs, configs)
]
# invoke
@ -3248,11 +3258,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
*(
cm.on_chain_start(
None,
input,
input_,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
for cm, input_, config in zip(callback_managers, inputs, configs)
)
)
@ -3346,7 +3356,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def _transform(
self,
input: Iterator[Input],
inputs: Iterator[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
@ -3359,7 +3369,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
final_pipeline = cast("Iterator[Output]", input)
final_pipeline = cast("Iterator[Output]", inputs)
for idx, step in enumerate(steps):
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{idx + 1}")
@ -3373,7 +3383,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
async def _atransform(
self,
input: AsyncIterator[Input],
inputs: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
@ -3387,7 +3397,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
final_pipeline = cast("AsyncIterator[Output]", input)
final_pipeline = cast("AsyncIterator[Output]", inputs)
for idx, step in enumerate(steps):
config = patch_config(
config,
@ -3733,7 +3743,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
)
def _invoke_step(
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str
) -> Any:
child_config = patch_config(
config,
@ -3743,7 +3753,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
with set_config_context(child_config) as context:
return context.run(
step.invoke,
input,
input_,
child_config,
)
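
A hedged usage sketch of `RunnableParallel`, whose per-step helpers are being renamed above:

```python
# Sketch: every step receives the same input; results gather into a dict.
from langchain_core.runnables import RunnableLambda, RunnableParallel

par = RunnableParallel(
    double=RunnableLambda(lambda x: x * 2),
    square=RunnableLambda(lambda x: x**2),
)
print(par.invoke(3))   # -> {'double': 6, 'square': 9}
```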
@ -3785,7 +3795,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
)
async def _ainvoke_step(
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str
) -> Any:
child_config = patch_config(
config,
@ -3793,7 +3803,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
)
with set_config_context(child_config) as context:
return await coro_with_context(
step.ainvoke(input, child_config), context, create_task=True
step.ainvoke(input_, child_config), context, create_task=True
)
# gather results from all steps
@ -3823,7 +3833,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
def _transform(
self,
input: Iterator[Input],
inputs: Iterator[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Iterator[AddableDict]:
@ -3831,7 +3841,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
steps = dict(self.steps__)
# Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread.
input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
input_copies = list(safetee(inputs, len(steps), lock=threading.Lock()))
with get_executor_for_config(config) as executor:
# Create the transform() generator for each step
named_generators = [
@ -3890,7 +3900,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
async def _atransform(
self,
input: AsyncIterator[Input],
inputs: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> AsyncIterator[AddableDict]:
@ -3898,7 +3908,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
steps = dict(self.steps__)
# Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread.
input_copies = list(atee(input, len(steps), lock=asyncio.Lock()))
input_copies = list(atee(inputs, len(steps), lock=asyncio.Lock()))
# Create the transform() generator for each step
named_generators = [
(
@ -4590,7 +4600,7 @@ class RunnableLambda(Runnable[Input, Output]):
def _invoke(
self,
input: Input,
input_: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
@ -4599,7 +4609,7 @@ class RunnableLambda(Runnable[Input, Output]):
output: Optional[Output] = None
for chunk in call_func_with_variable_args(
cast("Callable[[Input], Iterator[Output]]", self.func),
input,
input_,
config,
run_manager,
**kwargs,
@ -4613,18 +4623,18 @@ class RunnableLambda(Runnable[Input, Output]):
output = chunk
else:
output = call_func_with_variable_args(
self.func, input, config, run_manager, **kwargs
self.func, input_, config, run_manager, **kwargs
)
# If the output is a Runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
msg = (
f"Recursion limit reached when invoking {self} with input {input}."
f"Recursion limit reached when invoking {self} with input {input_}."
)
raise RecursionError(msg)
output = output.invoke(
input,
input_,
patch_config(
config,
callbacks=run_manager.get_child(),
@ -4635,7 +4645,7 @@ class RunnableLambda(Runnable[Input, Output]):
async def _ainvoke(
self,
input: Input,
value: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
@ -4646,7 +4656,7 @@ class RunnableLambda(Runnable[Input, Output]):
if inspect.isgeneratorfunction(self.func):
def func(
input: Input,
value: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
@ -4654,7 +4664,7 @@ class RunnableLambda(Runnable[Input, Output]):
output: Optional[Output] = None
for chunk in call_func_with_variable_args(
cast("Callable[[Input], Iterator[Output]]", self.func),
input,
value,
config,
run_manager.get_sync(),
**kwargs,
@ -4671,13 +4681,13 @@ class RunnableLambda(Runnable[Input, Output]):
else:
def func(
input: Input,
value: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
return call_func_with_variable_args(
self.func, input, config, run_manager.get_sync(), **kwargs
self.func, value, config, run_manager.get_sync(), **kwargs
)
@wraps(func)
@ -4693,7 +4703,7 @@ class RunnableLambda(Runnable[Input, Output]):
"AsyncGenerator[Any, Any]",
acall_func_with_variable_args(
cast("Callable", afunc),
input,
value,
config,
run_manager,
**kwargs,
@ -4713,18 +4723,18 @@ class RunnableLambda(Runnable[Input, Output]):
output = chunk
else:
output = await acall_func_with_variable_args(
cast("Callable", afunc), input, config, run_manager, **kwargs
cast("Callable", afunc), value, config, run_manager, **kwargs
)
# If the output is a Runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
msg = (
f"Recursion limit reached when invoking {self} with input {input}."
f"Recursion limit reached when invoking {self} with input {value}."
)
raise RecursionError(msg)
output = await output.ainvoke(
input,
value,
patch_config(
config,
callbacks=run_manager.get_child(),
@ -4789,14 +4799,14 @@ class RunnableLambda(Runnable[Input, Output]):
def _transform(
self,
input: Iterator[Input],
chunks: Iterator[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Iterator[Output]:
final: Input
got_first_val = False
for ichunk in input:
for ichunk in chunks:
# By definitions, RunnableLambdas consume all input before emitting output.
# If the input is not addable, then we'll assume that we can
# only operate on the last chunk.
@ -4881,14 +4891,14 @@ class RunnableLambda(Runnable[Input, Output]):
async def _atransform(
self,
input: AsyncIterator[Input],
chunks: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> AsyncIterator[Output]:
final: Input
got_first_val = False
async for ichunk in input:
async for ichunk in chunks:
# By definitions, RunnableLambdas consume all input before emitting output.
# If the input is not addable, then we'll assume that we can
# only operate on the last chunk.
@ -4913,13 +4923,13 @@ class RunnableLambda(Runnable[Input, Output]):
raise TypeError(msg)
def func(
input: Input,
input_: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
return call_func_with_variable_args(
self.func, input, config, run_manager.get_sync(), **kwargs
self.func, input_, config, run_manager.get_sync(), **kwargs
)
@wraps(func)

View File

@ -196,6 +196,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
raise ValueError(msg)
return specs
@override
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
@ -254,6 +255,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
run_manager.on_chain_end(output)
return output
@override
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
@ -302,6 +304,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
await run_manager.on_chain_end(output)
return output
@override
def stream(
self,
input: Input,
@ -388,6 +391,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
raise
run_manager.on_chain_end(final_output)
@override
async def astream(
self,
input: Input,

View File

@ -401,7 +401,7 @@ def call_func_with_variable_args(
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
input: Input,
input: Input, # noqa: A002
config: RunnableConfig,
run_manager: Optional[CallbackManagerForChainRun] = None,
**kwargs: Any,
@ -438,7 +438,7 @@ def acall_func_with_variable_args(
Awaitable[Output],
],
],
input: Input,
input: Input, # noqa: A002
config: RunnableConfig,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs: Any,

View File

@ -178,16 +178,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def invoke(
prepared: tuple[Runnable[Input, Output], RunnableConfig],
input: Input,
input_: Input,
) -> Union[Output, Exception]:
bound, config = prepared
if return_exceptions:
try:
return bound.invoke(input, config, **kwargs)
return bound.invoke(input_, config, **kwargs)
except Exception as e:
return e
else:
return bound.invoke(input, config, **kwargs)
return bound.invoke(input_, config, **kwargs)
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
@ -221,16 +221,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
async def ainvoke(
prepared: tuple[Runnable[Input, Output], RunnableConfig],
input: Input,
input_: Input,
) -> Union[Output, Exception]:
bound, config = prepared
if return_exceptions:
try:
return await bound.ainvoke(input, config, **kwargs)
return await bound.ainvoke(input_, config, **kwargs)
except Exception as e:
return e
else:
return await bound.ainvoke(input, config, **kwargs)
return await bound.ainvoke(input_, config, **kwargs)
coros = map(ainvoke, prepared, inputs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)

View File

@ -269,7 +269,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
from langchain_core.callbacks.manager import CallbackManager
if self.exception_key is not None and not all(
isinstance(input, dict) for input in inputs
isinstance(input_, dict) for input_ in inputs
):
msg = (
"If 'exception_key' is specified then inputs must be dictionaries."
@ -298,11 +298,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
run_managers = [
cm.on_chain_start(
None,
input if isinstance(input, dict) else {"input": input},
input_ if isinstance(input_, dict) else {"input": input_},
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
for cm, input_, config in zip(callback_managers, inputs, configs)
]
to_return: dict[int, Any] = {}
@ -311,7 +311,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
first_to_raise = None
for runnable in self.runnables:
outputs = runnable.batch(
[input for _, input in sorted(run_again.items())],
[input_ for _, input_ in sorted(run_again.items())],
[
# each step a child run of the corresponding root run
patch_config(configs[i], callbacks=run_managers[i].get_child())
@ -320,7 +320,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return_exceptions=True,
**kwargs,
)
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
for (i, input_), output in zip(sorted(run_again.copy().items()), outputs):
if isinstance(output, BaseException) and not isinstance(
output, self.exceptions_to_handle
):
@ -331,7 +331,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
run_again.pop(i)
elif isinstance(output, self.exceptions_to_handle):
if self.exception_key:
input[self.exception_key] = output # type: ignore[index]
input_[self.exception_key] = output # type: ignore[index]
handled_exceptions[i] = output
else:
run_managers[i].on_chain_end(output)
@ -363,7 +363,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
from langchain_core.callbacks.manager import AsyncCallbackManager
if self.exception_key is not None and not all(
isinstance(input, dict) for input in inputs
isinstance(input_, dict) for input_ in inputs
):
msg = (
"If 'exception_key' is specified then inputs must be dictionaries."
@ -393,11 +393,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
*(
cm.on_chain_start(
None,
input,
input_,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
for cm, input_, config in zip(callback_managers, inputs, configs)
)
)
@ -407,7 +407,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
first_to_raise = None
for runnable in self.runnables:
outputs = await runnable.abatch(
[input for _, input in sorted(run_again.items())],
[input_ for _, input_ in sorted(run_again.items())],
[
# each step a child run of the corresponding root run
patch_config(configs[i], callbacks=run_managers[i].get_child())
@ -417,7 +417,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
**kwargs,
)
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
for (i, input_), output in zip(sorted(run_again.copy().items()), outputs):
if isinstance(output, BaseException) and not isinstance(
output, self.exceptions_to_handle
):
@ -428,7 +428,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
run_again.pop(i)
elif isinstance(output, self.exceptions_to_handle):
if self.exception_key:
input[self.exception_key] = output # type: ignore[index]
input_[self.exception_key] = output # type: ignore[index]
handled_exceptions[i] = output
else:
to_return[i] = output
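
A hedged sketch of the `exception_key` contract checked above: inputs must be dicts so the handled error can be injected (`input_[self.exception_key] = output`) before the next fallback runs:

```python
# Runnable sketch: the error lands under exception_key for the fallback.
from langchain_core.runnables import RunnableLambda

primary = RunnableLambda(lambda d: 1 / 0)                       # always fails
backup = RunnableLambda(lambda d: f"recovered: {d['error']!r}")
chain = primary.with_fallbacks([backup], exception_key="error")
print(chain.batch([{"q": "hi"}]))   # dict inputs are mandatory here
# chain.batch(["hi"])               # -> ValueError: inputs must be dictionaries
```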

View File

@ -111,7 +111,12 @@ class Node(NamedTuple):
data: Union[type[BaseModel], RunnableType, None]
metadata: Optional[dict[str, Any]]
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
def copy(
self,
*,
id: Optional[str] = None, # noqa: A002
name: Optional[str] = None,
) -> Node:
"""Return a copy of the node with optional new id and name.
Args:
@ -181,7 +186,10 @@ class MermaidDrawMethod(Enum):
API = "api" # Uses Mermaid.INK API to render the graph
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType, None]) -> str:
def node_data_str(
id: str, # noqa: A002
data: Union[type[BaseModel], RunnableType, None],
) -> str:
"""Convert the data of a node to a string.
Args:
@ -320,7 +328,7 @@ class Graph:
def add_node(
self,
data: Union[type[BaseModel], RunnableType, None],
id: Optional[str] = None,
id: Optional[str] = None, # noqa: A002
*,
metadata: Optional[dict[str, Any]] = None,
) -> Node:
@ -340,8 +348,8 @@ class Graph:
if id is not None and id in self.nodes:
msg = f"Node with id {id} already exists"
raise ValueError(msg)
id = id or self.next_id()
node = Node(id=id, data=data, metadata=metadata, name=node_data_str(id, data))
id_ = id or self.next_id()
node = Node(id=id_, data=data, metadata=metadata, name=node_data_str(id_, data))
self.nodes[node.id] = node
return node
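
A hedged usage sketch for `add_node` (module path assumed from the hunk): the new `id_` local avoids the builtin while `id` stays in the public signature.

```python
# Sketch: explicit ids are kept; omitted ids fall back to next_id().
from langchain_core.runnables.graph import Graph

graph = Graph()
start = graph.add_node(None, id="start")   # keeps "start"
auto = graph.add_node(None)                # id_ = self.next_id()
```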
@ -406,8 +414,8 @@ class Graph:
if all(is_uuid(node.id) for node in graph.nodes.values()):
prefix = ""
def prefixed(id: str) -> str:
return f"{prefix}:{id}" if prefix else id
def prefixed(id_: str) -> str:
return f"{prefix}:{id_}" if prefix else id_
# prefix each node
self.nodes.update(
@ -450,8 +458,8 @@ class Graph:
return Graph(
nodes={
_get_node_id(id): node.copy(id=_get_node_id(id))
for id, node in self.nodes.items()
_get_node_id(id_): node.copy(id=_get_node_id(id_))
for id_, node in self.nodes.items()
},
edges=[
edge.copy(

View File

@ -187,7 +187,7 @@ def _build_sugiyama_layout(
# Y
#
vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()}
vertices_ = {id_: Vertex(f" {data} ") for id_, data in vertices.items()}
edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges]
vertices_list = vertices_.values()
graph = Graph(vertices_list, edges_)

View File

@ -509,20 +509,20 @@ class RunnableWithMessageHistory(RunnableBindingBase):
)
raise ValueError(msg) # noqa: TRY004
def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
def _enter_history(self, value: Any, config: RunnableConfig) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = hist.messages.copy()
if not self.history_messages_key:
# return all messages
input_val = (
input if not self.input_messages_key else input[self.input_messages_key]
value if not self.input_messages_key else value[self.input_messages_key]
)
messages += self._get_input_messages(input_val)
return messages
async def _aenter_history(
self, input: dict[str, Any], config: RunnableConfig
self, value: dict[str, Any], config: RunnableConfig
) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = (await hist.aget_messages()).copy()
@ -530,7 +530,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
if not self.history_messages_key:
# return all messages
input_val = (
input if not self.input_messages_key else input[self.input_messages_key]
value if not self.input_messages_key else value[self.input_messages_key]
)
messages += self._get_input_messages(input_val)
return messages

View File

@ -483,19 +483,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
def _invoke(
self,
input: dict[str, Any],
value: dict[str, Any],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> dict[str, Any]:
if not isinstance(input, dict):
if not isinstance(value, dict):
msg = "The input to RunnablePassthrough.assign() must be a dict."
raise ValueError(msg) # noqa: TRY004
return {
**input,
**value,
**self.mapper.invoke(
input,
value,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
),
@ -512,19 +512,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
async def _ainvoke(
self,
input: dict[str, Any],
value: dict[str, Any],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> dict[str, Any]:
if not isinstance(input, dict):
if not isinstance(value, dict):
msg = "The input to RunnablePassthrough.assign() must be a dict."
raise ValueError(msg) # noqa: TRY004
return {
**input,
**value,
**await self.mapper.ainvoke(
input,
value,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
),
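
A hedged usage sketch of the merge contract enforced above (dict in, mapper outputs merged over it); `RunnableAssign` is usually built via `RunnablePassthrough.assign`:

```python
# Sketch: the mapper's outputs are merged over the input dict.
from langchain_core.runnables import RunnablePassthrough

chain = RunnablePassthrough.assign(total=lambda d: d["a"] + d["b"])
print(chain.invoke({"a": 1, "b": 2}))   # -> {'a': 1, 'b': 2, 'total': 3}
```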
@ -541,7 +541,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
def _transform(
self,
input: Iterator[dict[str, Any]],
values: Iterator[dict[str, Any]],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
@ -549,7 +549,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
# collect mapper keys
mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
for_passthrough, for_map = safetee(values, 2, lock=threading.Lock())
# create map output stream
map_output = self.mapper.transform(
@ -598,7 +598,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
async def _atransform(
self,
input: AsyncIterator[dict[str, Any]],
values: AsyncIterator[dict[str, Any]],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
@ -606,7 +606,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
# collect mapper keys
mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
for_passthrough, for_map = atee(values, 2, lock=asyncio.Lock())
# create map output stream
map_output = self.mapper.atransform(
for_map,
@ -731,23 +731,23 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
)
return super().get_name(suffix, name=name)
def _pick(self, input: dict[str, Any]) -> Any:
if not isinstance(input, dict):
def _pick(self, value: dict[str, Any]) -> Any:
if not isinstance(value, dict):
msg = "The input to RunnablePassthrough.assign() must be a dict."
raise ValueError(msg) # noqa: TRY004
if isinstance(self.keys, str):
return input.get(self.keys)
picked = {k: input.get(k) for k in self.keys if k in input}
return value.get(self.keys)
picked = {k: value.get(k) for k in self.keys if k in value}
if picked:
return AddableDict(picked)
return None
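
A hedged usage sketch of `_pick`'s two modes, as exposed through `RunnablePick`:

```python
# Sketch: a single key returns the bare value; a key list returns a dict.
from langchain_core.runnables import RunnablePick

print(RunnablePick("a").invoke({"a": 1, "b": 2}))          # -> 1
print(RunnablePick(["a", "c"]).invoke({"a": 1, "b": 2}))   # -> {'a': 1}
```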
def _invoke(
self,
input: dict[str, Any],
value: dict[str, Any],
) -> dict[str, Any]:
return self._pick(input)
return self._pick(value)
@override
def invoke(
@ -760,9 +760,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
async def _ainvoke(
self,
input: dict[str, Any],
value: dict[str, Any],
) -> dict[str, Any]:
return self._pick(input)
return self._pick(value)
@override
async def ainvoke(
@ -775,9 +775,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
def _transform(
self,
input: Iterator[dict[str, Any]],
chunks: Iterator[dict[str, Any]],
) -> Iterator[dict[str, Any]]:
for chunk in input:
for chunk in chunks:
picked = self._pick(chunk)
if picked is not None:
yield picked
@ -795,9 +795,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
async def _atransform(
self,
input: AsyncIterator[dict[str, Any]],
chunks: AsyncIterator[dict[str, Any]],
) -> AsyncIterator[dict[str, Any]]:
async for chunk in input:
async for chunk in chunks:
picked = self._pick(chunk)
if picked is not None:
yield picked

View File

@ -178,7 +178,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def _invoke(
self,
input: Input,
input_: Input,
run_manager: "CallbackManagerForChainRun",
config: RunnableConfig,
**kwargs: Any,
@ -186,7 +186,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
for attempt in self._sync_retrying(reraise=True):
with attempt:
result = super().invoke(
input,
input_,
self._patch_config(config, run_manager, attempt.retry_state),
**kwargs,
)
@ -202,7 +202,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
async def _ainvoke(
self,
input: Input,
input_: Input,
run_manager: "AsyncCallbackManagerForChainRun",
config: RunnableConfig,
**kwargs: Any,
@ -210,7 +210,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
async for attempt in self._async_retrying(reraise=True):
with attempt:
result = await super().ainvoke(
input,
input_,
self._patch_config(config, run_manager, attempt.retry_state),
**kwargs,
)
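
`RunnableRetry` is normally created via `.with_retry()`, which wraps `invoke`/`ainvoke` in the tenacity loops being edited above. A hedged sketch:

```python
# Sketch: up to three attempts before the final exception is reraised.
from langchain_core.runnables import RunnableLambda

flaky = RunnableLambda(lambda x: x)          # stand-in for a flaky runnable
retrying = flaky.with_retry(stop_after_attempt=3)
print(retrying.invoke("ok"))                 # -> 'ok'
```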

View File

@ -148,22 +148,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
if not inputs:
return []
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
keys = [input_["key"] for input_ in inputs]
actual_inputs = [input_["input"] for input_ in inputs]
if any(key not in self.runnables for key in keys):
msg = "One or more keys do not have a corresponding runnable"
raise ValueError(msg)
def invoke(
runnable: Runnable, input: Input, config: RunnableConfig
runnable: Runnable, input_: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return runnable.invoke(input, config, **kwargs)
return runnable.invoke(input_, config, **kwargs)
except Exception as e:
return e
else:
return runnable.invoke(input, config, **kwargs)
return runnable.invoke(input_, config, **kwargs)
runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs))
@ -185,22 +185,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
if not inputs:
return []
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
keys = [input_["key"] for input_ in inputs]
actual_inputs = [input_["input"] for input_ in inputs]
if any(key not in self.runnables for key in keys):
msg = "One or more keys do not have a corresponding runnable"
raise ValueError(msg)
async def ainvoke(
runnable: Runnable, input: Input, config: RunnableConfig
runnable: Runnable, input_: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return await runnable.ainvoke(input, config, **kwargs)
return await runnable.ainvoke(input_, config, **kwargs)
except Exception as e:
return e
else:
return await runnable.ainvoke(input, config, **kwargs)
return await runnable.ainvoke(input_, config, **kwargs)
runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs))
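
A hedged usage sketch for `RouterRunnable`: each batch element is a `RouterInput` dict carrying the routing key plus the actual payload, exactly the two fields unpacked above.

```python
# Sketch: route each input to the runnable named by its key.
from langchain_core.runnables import RouterRunnable, RunnableLambda

router = RouterRunnable({
    "double": RunnableLambda(lambda x: x * 2),
    "upper": RunnableLambda(lambda s: s.upper()),
})
print(router.batch([
    {"key": "double", "input": 3},
    {"key": "upper", "input": "hi"},
]))   # -> [6, 'HI']
```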

View File

@ -75,7 +75,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
def accepts_run_manager(callable: Callable[..., Any]) -> bool: # noqa: A002
"""Check if a callable accepts a run_manager argument.
Args:
@ -90,7 +90,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
return False
def accepts_config(callable: Callable[..., Any]) -> bool:
def accepts_config(callable: Callable[..., Any]) -> bool: # noqa: A002
"""Check if a callable accepts a config argument.
Args:
@ -105,7 +105,7 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
return False
def accepts_context(callable: Callable[..., Any]) -> bool:
def accepts_context(callable: Callable[..., Any]) -> bool: # noqa: A002
"""Check if a callable accepts a context argument.
Args:
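
The `accepts_*` helpers keep `callable` as their public parameter name, hence the suppressions. The hunks elide the bodies; a possible shape, offered only as a sketch:

```python
# Sketch: signature inspection with a False fallback for callables
# that expose no retrievable signature.
import inspect
from typing import Any, Callable

def accepts_run_manager_sketch(callable: Callable[..., Any]) -> bool:  # noqa: A002
    try:
        return "run_manager" in inspect.signature(callable).parameters
    except ValueError:
        return False
```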
@ -691,7 +691,7 @@ def get_unique_config_specs(
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
)
unique: list[ConfigurableFieldSpec] = []
for id, dupes in grouped:
for spec_id, dupes in grouped:
first = next(dupes)
others = list(dupes)
if len(others) == 0 or all(o == first for o in others):
@ -699,7 +699,7 @@ def get_unique_config_specs(
else:
msg = (
"RunnableSequence contains conflicting config specs"
f"for {id}: {[first] + others}"
f"for {spec_id}: {[first] + others}"
)
raise ValueError(msg)
return unique

View File

@ -184,7 +184,7 @@ class StructuredQuery(Expr):
def __init__(
self,
query: str,
filter: Optional[FilterDirective],
filter: Optional[FilterDirective], # noqa: A002
limit: Optional[int] = None,
**kwargs: Any,
) -> None:

View File

@ -943,17 +943,17 @@ def _handle_tool_error(
def _prep_run_args(
input: Union[str, dict, ToolCall],
value: Union[str, dict, ToolCall],
config: Optional[RunnableConfig],
**kwargs: Any,
) -> tuple[Union[str, dict], dict]:
config = ensure_config(config)
if _is_tool_call(input):
tool_call_id: Optional[str] = cast("ToolCall", input)["id"]
tool_input: Union[str, dict] = cast("ToolCall", input)["args"].copy()
if _is_tool_call(value):
tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
tool_input: Union[str, dict] = cast("ToolCall", value)["args"].copy()
else:
tool_call_id = None
tool_input = cast("Union[str, dict]", input)
tool_input = cast("Union[str, dict]", value)
return (
tool_input,
dict(

View File

@ -740,7 +740,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def _astream_events_implementation_v1(
runnable: Runnable[Input, Output],
input: Any,
value: Any,
config: Optional[RunnableConfig] = None,
*,
include_names: Optional[Sequence[str]] = None,
@ -789,7 +789,7 @@ async def _astream_events_implementation_v1(
async for log in _astream_log_implementation(
runnable,
input,
value,
config=config,
stream=stream,
diff=True,
@ -810,7 +810,7 @@ async def _astream_events_implementation_v1(
tags=root_tags,
metadata=root_metadata,
data={
"input": input,
"input": value,
},
parent_ids=[], # Not supported in v1
)
@ -924,7 +924,7 @@ async def _astream_events_implementation_v1(
async def _astream_events_implementation_v2(
runnable: Runnable[Input, Output],
input: Any,
value: Any,
config: Optional[RunnableConfig] = None,
*,
include_names: Optional[Sequence[str]] = None,
@ -972,7 +972,7 @@ async def _astream_events_implementation_v2(
async def consume_astream() -> None:
try:
# if astream also calls tap_output_aiter this will be a no-op
async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
async with aclosing(runnable.astream(value, config, **kwargs)) as stream:
async for _ in event_streamer.tap_output_aiter(run_id, stream):
# All the content will be picked up
pass
@ -993,7 +993,7 @@ async def _astream_events_implementation_v2(
# chain are not available until the entire input is consumed.
# As a temporary solution, we'll modify the input to be the input
# that was passed into the chain.
event["data"]["input"] = input
event["data"]["input"] = value
first_event_run_id = event["run_id"]
yield event
continue

View File

@ -580,7 +580,7 @@ def _get_standardized_outputs(
@overload
def _astream_log_implementation(
runnable: Runnable[Input, Output],
input: Any,
value: Any,
config: Optional[RunnableConfig] = None,
*,
stream: LogStreamCallbackHandler,
@ -593,7 +593,7 @@ def _astream_log_implementation(
@overload
def _astream_log_implementation(
runnable: Runnable[Input, Output],
input: Any,
value: Any,
config: Optional[RunnableConfig] = None,
*,
stream: LogStreamCallbackHandler,
@ -605,7 +605,7 @@ def _astream_log_implementation(
async def _astream_log_implementation(
runnable: Runnable[Input, Output],
input: Any,
value: Any,
config: Optional[RunnableConfig] = None,
*,
stream: LogStreamCallbackHandler,
@ -651,7 +651,7 @@ async def _astream_log_implementation(
prev_final_output: Optional[Output] = None
final_output: Optional[Output] = None
async for chunk in runnable.astream(input, config, **kwargs):
async for chunk in runnable.astream(value, config, **kwargs):
prev_final_output = final_output
if final_output is None:
final_output = chunk

View File

@@ -596,7 +596,7 @@ def convert_to_json_schema(
 @beta()
 def tool_example_to_messages(
-    input: str,
+    input: str,  # noqa: A002
     tool_calls: list[BaseModel],
     tool_outputs: Optional[list[str]] = None,
     *,
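
(Here the public parameter keeps its name and the rule is suppressed instead, since renaming it would break existing keyword callers. A hypothetical sketch of that trade-off; `to_messages` is illustrative and not part of the commit:)

    def to_messages(input: str, *, upper: bool = False) -> list[str]:  # noqa: A002
        # Public API: callers already pass `input=...`, so the name stays
        # and the builtin-shadowing warning is silenced in place.
        text = input.upper() if upper else input
        return [text]

    assert to_messages(input="hello") == ["hello"]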

View File

@@ -363,7 +363,7 @@ class InMemoryVectorStore(VectorStore):
         self,
         embedding: list[float],
         k: int = 4,
-        filter: Optional[Callable[[Document], bool]] = None,
+        filter: Optional[Callable[[Document], bool]] = None,  # noqa: A002
     ) -> list[tuple[Document, float, list[float]]]:
         # get all docs with fixed order in list
         docs = list(self.store.values())
@@ -402,7 +402,7 @@ class InMemoryVectorStore(VectorStore):
         self,
         embedding: list[float],
         k: int = 4,
-        filter: Optional[Callable[[Document], bool]] = None,
+        filter: Optional[Callable[[Document], bool]] = None,  # noqa: A002
         **_kwargs: Any,
     ) -> list[tuple[Document, float]]:
         """Search for the most similar documents to the given embedding.

View File

@@ -100,7 +100,6 @@ ignore = [
     "UP007",  # Doesn't play well with Pydantic in Python 3.9
     # TODO rules
-    "A",
     "ANN401",
     "BLE",
     "ERA",

View File

@@ -869,9 +869,9 @@ def create_image_data() -> str:
return "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q==" # noqa: E501
-def create_base64_image(format: str = "jpeg") -> str:
+def create_base64_image(image_format: str = "jpeg") -> str:
     data = create_image_data()
-    return f"data:image/{format};base64,{data}"
+    return f"data:image/{image_format};base64,{data}"


 def test_convert_to_openai_messages_string() -> None:

View File

@@ -639,7 +639,7 @@ def test_parse_with_different_pydantic_1_proper() -> None:
 def test_max_tokens_error(caplog: Any) -> None:
     parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
-    input = AIMessage(
+    message = AIMessage(
         content="",
         tool_calls=[
             {
@@ -651,7 +651,7 @@ def test_max_tokens_error(caplog: Any) -> None:
         response_metadata={"stop_reason": "max_tokens"},
     )
     with pytest.raises(ValidationError):
-        _ = parser.invoke(input)
+        _ = parser.invoke(message)
     assert any(
         "`max_tokens` stop reason" in msg and record.levelname == "ERROR"
         for record, msg in zip(caplog.records, caplog.messages)

View File

@@ -15,11 +15,11 @@ EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute()
 @contextmanager
-def change_directory(dir: Path) -> Iterator:
+def change_directory(dir_path: Path) -> Iterator:
     """Change the working directory to the right folder."""
     origin = Path().absolute()
     try:
-        os.chdir(dir)
+        os.chdir(dir_path)
         yield
     finally:
         os.chdir(origin)

View File

@@ -220,8 +220,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
     str_parser = StrOutputParser()
     xml_parser = XMLOutputParser()

-    def conditional_str_parser(input: str) -> Runnable:
-        if input == "a":
+    def conditional_str_parser(value: str) -> Runnable:
+        if value == "a":
             return str_parser
         return xml_parser

View File

@@ -111,9 +111,9 @@ async def test_input_messages_async() -> None:
 def test_input_dict() -> None:
     runnable = RunnableLambda(
-        lambda input: "you said: "
+        lambda params: "you said: "
         + "\n".join(
-            str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
+            str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
         )
     )
     get_session_history = _get_get_session_history()
@@ -131,9 +131,9 @@ def test_input_dict() -> None:
 async def test_input_dict_async() -> None:
     runnable = RunnableLambda(
-        lambda input: "you said: "
+        lambda params: "you said: "
         + "\n".join(
-            str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
+            str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
         )
     )
     get_session_history = _get_get_session_history()
@@ -153,10 +153,10 @@ async def test_input_dict_async() -> None:
 def test_input_dict_with_history_key() -> None:
     runnable = RunnableLambda(
-        lambda input: "you said: "
+        lambda params: "you said: "
         + "\n".join(
-            [str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
-            + [input["input"]]
+            [str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
+            + [params["input"]]
         )
     )
     get_session_history = _get_get_session_history()
@@ -175,10 +175,10 @@ def test_input_dict_with_history_key() -> None:
 async def test_input_dict_with_history_key_async() -> None:
     runnable = RunnableLambda(
-        lambda input: "you said: "
+        lambda params: "you said: "
         + "\n".join(
-            [str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
-            + [input["input"]]
+            [str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
+            + [params["input"]]
         )
     )
     get_session_history = _get_get_session_history()
@@ -197,15 +197,15 @@ async def test_input_dict_with_history_key_async() -> None:
 def test_output_message() -> None:
     runnable = RunnableLambda(
-        lambda input: AIMessage(
+        lambda params: AIMessage(
             content="you said: "
             + "\n".join(
                 [
                     str(m.content)
-                    for m in input["history"]
+                    for m in params["history"]
                     if isinstance(m, HumanMessage)
                 ]
-                + [input["input"]]
+                + [params["input"]]
             )
         )
    )
@@ -225,15 +225,15 @@ def test_output_message() -> None:
 async def test_output_message_async() -> None:
     runnable = RunnableLambda(
-        lambda input: AIMessage(
+        lambda params: AIMessage(
             content="you said: "
             + "\n".join(
                 [
                     str(m.content)
-                    for m in input["history"]
+                    for m in params["history"]
                     if isinstance(m, HumanMessage)
                 ]
-                + [input["input"]]
+                + [params["input"]]
             )
         )
     )
@@ -302,16 +302,16 @@ async def test_input_messages_output_message_async() -> None:
 def test_output_messages() -> None:
     runnable = RunnableLambda(
-        lambda input: [
+        lambda params: [
             AIMessage(
                 content="you said: "
                 + "\n".join(
                     [
                         str(m.content)
-                        for m in input["history"]
+                        for m in params["history"]
                         if isinstance(m, HumanMessage)
                     ]
-                    + [input["input"]]
+                    + [params["input"]]
                 )
             )
         ]
@@ -332,16 +332,16 @@ def test_output_messages() -> None:
 async def test_output_messages_async() -> None:
     runnable = RunnableLambda(
-        lambda input: [
+        lambda params: [
             AIMessage(
                 content="you said: "
                 + "\n".join(
                     [
                         str(m.content)
-                        for m in input["history"]
+                        for m in params["history"]
                         if isinstance(m, HumanMessage)
                     ]
-                    + [input["input"]]
+                    + [params["input"]]
                 )
             )
         ]
@@ -362,17 +362,17 @@ async def test_output_messages_async() -> None:
 def test_output_dict() -> None:
     runnable = RunnableLambda(
-        lambda input: {
+        lambda params: {
             "output": [
                 AIMessage(
                     content="you said: "
                     + "\n".join(
                         [
                             str(m.content)
-                            for m in input["history"]
+                            for m in params["history"]
                             if isinstance(m, HumanMessage)
                         ]
-                        + [input["input"]]
+                        + [params["input"]]
                     )
                 )
             ]
@@ -395,17 +395,17 @@ def test_output_dict() -> None:
 async def test_output_dict_async() -> None:
     runnable = RunnableLambda(
-        lambda input: {
+        lambda params: {
             "output": [
                 AIMessage(
                     content="you said: "
                     + "\n".join(
                         [
                             str(m.content)
-                            for m in input["history"]
+                            for m in params["history"]
                            if isinstance(m, HumanMessage)
                         ]
-                        + [input["input"]]
+                        + [params["input"]]
                     )
                 )
             ]
@@ -431,17 +431,17 @@ def test_get_input_schema_input_dict() -> None:
         input: Union[str, BaseMessage, Sequence[BaseMessage]]

     runnable = RunnableLambda(
-        lambda input: {
+        lambda params: {
             "output": [
                 AIMessage(
                     content="you said: "
                     + "\n".join(
                         [
                             str(m.content)
-                            for m in input["history"]
+                            for m in params["history"]
                             if isinstance(m, HumanMessage)
                         ]
-                        + [input["input"]]
+                        + [params["input"]]
                     )
                 )
             ]
@@ -463,17 +463,17 @@ def test_get_input_schema_input_dict() -> None:
 def test_get_output_schema() -> None:
     """Test get output schema."""
     runnable = RunnableLambda(
-        lambda input: {
+        lambda params: {
             "output": [
                 AIMessage(
                     content="you said: "
                     + "\n".join(
                         [
                             str(m.content)
-                            for m in input["history"]
+                            for m in params["history"]
                             if isinstance(m, HumanMessage)
                         ]
-                        + [input["input"]]
+                        + [params["input"]]
                     )
                 )
             ]
@@ -531,8 +531,8 @@ def test_get_input_schema_input_messages() -> None:
 def test_using_custom_config_specs() -> None:
     """Test that we can configure which keys should be passed to the session factory."""

-    def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]:
-        messages = input["messages"]
+    def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
+        messages = params["messages"]
         return [
             AIMessage(
                 content="you said: "
@@ -644,8 +644,8 @@ def test_using_custom_config_specs() -> None:
 async def test_using_custom_config_specs_async() -> None:
     """Test that we can configure which keys should be passed to the session factory."""

-    def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]:
-        messages = input["messages"]
+    def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
+        messages = params["messages"]
         return [
             AIMessage(
                 content="you said: "
@@ -757,12 +757,12 @@ async def test_using_custom_config_specs_async() -> None:
 def test_ignore_session_id() -> None:
     """Test without config."""

-    def _fake_llm(input: list[BaseMessage]) -> list[BaseMessage]:
+    def _fake_llm(messages: list[BaseMessage]) -> list[BaseMessage]:
         return [
             AIMessage(
                 content="you said: "
                 + "\n".join(
-                    str(m.content) for m in input if isinstance(m, HumanMessage)
+                    str(m.content) for m in messages if isinstance(m, HumanMessage)
                 )
             )
         ]

View File

@@ -564,8 +564,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
         "required": ["bye", "hello"],
     }

-    def get_value(input):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
-        return input["variable_name"]
+    def get_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
+        return value["variable_name"]

     assert RunnableLambda(get_value).get_input_jsonschema() == {
         "title": "get_value_input",
@@ -574,8 +574,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
         "required": ["variable_name"],
     }

-    async def aget_value(input):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
-        return (input["variable_name"], input.get("another"))
+    async def aget_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
+        return (value["variable_name"], value.get("another"))

     assert RunnableLambda(aget_value).get_input_jsonschema() == {
         "title": "aget_value_input",
@@ -587,11 +587,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
         "required": ["another", "variable_name"],
     }

-    async def aget_values(input):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
+    async def aget_values(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
         return {
-            "hello": input["variable_name"],
-            "bye": input["variable_name"],
-            "byebye": input["yo"],
+            "hello": value["variable_name"],
+            "bye": value["variable_name"],
+            "byebye": value["yo"],
         }

     assert RunnableLambda(aget_values).get_input_jsonschema() == {
@@ -613,11 +613,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
         bye: str
         byebye: int

-    async def aget_values_typed(input: InputType) -> OutputType:
+    async def aget_values_typed(value: InputType) -> OutputType:
         return {
-            "hello": input["variable_name"],
-            "bye": input["variable_name"],
-            "byebye": input["yo"],
+            "hello": value["variable_name"],
+            "bye": value["variable_name"],
+            "byebye": value["yo"],
         }

     assert _normalize_schema(
@@ -2592,8 +2592,8 @@ async def test_prompt_with_llm_and_async_lambda(
     )
     llm = FakeListLLM(responses=["foo", "bar"])

-    async def passthrough(input: Any) -> Any:
-        return input
+    async def passthrough(value: Any) -> Any:
+        return value

     chain = prompt | llm | passthrough
@@ -2946,12 +2946,12 @@ def test_higher_order_lambda_runnable(
         input={"question": lambda x: x["question"]},
     )

-    def router(input: dict[str, Any]) -> Runnable:
-        if input["key"] == "math":
+    def router(params: dict[str, Any]) -> Runnable:
+        if params["key"] == "math":
             return itemgetter("input") | math_chain
-        if input["key"] == "english":
+        if params["key"] == "english":
             return itemgetter("input") | english_chain
-        msg = f"Unknown key: {input['key']}"
+        msg = f"Unknown key: {params['key']}"
         raise ValueError(msg)

     chain: Runnable = input_map | router
@@ -3002,12 +3002,12 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
         input={"question": lambda x: x["question"]},
     )

-    def router(input: dict[str, Any]) -> Runnable:
-        if input["key"] == "math":
+    def router(value: dict[str, Any]) -> Runnable:
+        if value["key"] == "math":
             return itemgetter("input") | math_chain
-        if input["key"] == "english":
+        if value["key"] == "english":
             return itemgetter("input") | english_chain
-        msg = f"Unknown key: {input['key']}"
+        msg = f"Unknown key: {value['key']}"
         raise ValueError(msg)

     chain: Runnable = input_map | router
@@ -3024,12 +3024,12 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
     assert result2 == ["4", "2"]

     # Test ainvoke
-    async def arouter(input: dict[str, Any]) -> Runnable:
-        if input["key"] == "math":
+    async def arouter(params: dict[str, Any]) -> Runnable:
+        if params["key"] == "math":
             return itemgetter("input") | math_chain
-        if input["key"] == "english":
+        if params["key"] == "english":
             return itemgetter("input") | english_chain
-        msg = f"Unknown key: {input['key']}"
+        msg = f"Unknown key: {params['key']}"
         raise ValueError(msg)

     achain: Runnable = input_map | arouter
@@ -4125,6 +4125,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
         def __init__(self, fail_starts_with: str) -> None:
             self.fail_starts_with = fail_starts_with

+        @override
         def invoke(
             self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
         ) -> Any:
@@ -4135,15 +4136,15 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
             inputs: list[str],
         ) -> list:
             outputs: list[Any] = []
-            for input in inputs:
-                if input.startswith(self.fail_starts_with):
+            for value in inputs:
+                if value.startswith(self.fail_starts_with):
                     outputs.append(
                         ValueError(
-                            f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
+                            f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}"
                         )
                     )
                 else:
-                    outputs.append(input + "a")
+                    outputs.append(value + "a")
             return outputs

         def batch(
@@ -4264,6 +4265,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
         def __init__(self, fail_starts_with: str) -> None:
             self.fail_starts_with = fail_starts_with

+        @override
         def invoke(
             self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
         ) -> Any:
@@ -4274,15 +4276,15 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
             inputs: list[str],
         ) -> list:
             outputs: list[Any] = []
-            for input in inputs:
-                if input.startswith(self.fail_starts_with):
+            for value in inputs:
+                if value.startswith(self.fail_starts_with):
                     outputs.append(
                         ValueError(
-                            f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
+                            f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}"
                         )
                     )
                 else:
-                    outputs.append(input + "a")
+                    outputs.append(value + "a")
             return outputs

         async def abatch(
@@ -5006,10 +5008,10 @@ def test_runnable_iter_context_config() -> None:
     fake = RunnableLambda(len)

     @chain
-    def gen(input: str) -> Iterator[int]:
-        yield fake.invoke(input)
-        yield fake.invoke(input * 2)
-        yield fake.invoke(input * 3)
+    def gen(value: str) -> Iterator[int]:
+        yield fake.invoke(value)
+        yield fake.invoke(value * 2)
+        yield fake.invoke(value * 3)

     assert gen.get_input_jsonschema() == {
         "title": "gen_input",
@@ -5064,10 +5066,10 @@ async def test_runnable_iter_context_config_async() -> None:
     fake = RunnableLambda(len)

     @chain
-    async def agen(input: str) -> AsyncIterator[int]:
-        yield await fake.ainvoke(input)
-        yield await fake.ainvoke(input * 2)
-        yield await fake.ainvoke(input * 3)
+    async def agen(value: str) -> AsyncIterator[int]:
+        yield await fake.ainvoke(value)
+        yield await fake.ainvoke(value * 2)
+        yield await fake.ainvoke(value * 3)

     assert agen.get_input_jsonschema() == {
         "title": "agen_input",
@@ -5130,10 +5132,10 @@ def test_runnable_lambda_context_config() -> None:
     fake = RunnableLambda(len)

     @chain
-    def fun(input: str) -> int:
-        output = fake.invoke(input)
-        output += fake.invoke(input * 2)
-        output += fake.invoke(input * 3)
+    def fun(value: str) -> int:
+        output = fake.invoke(value)
+        output += fake.invoke(value * 2)
+        output += fake.invoke(value * 3)
         return output

     assert fun.get_input_jsonschema() == {"title": "fun_input", "type": "string"}
@@ -5186,10 +5188,10 @@ async def test_runnable_lambda_context_config_async() -> None:
     fake = RunnableLambda(len)

     @chain
-    async def afun(input: str) -> int:
-        output = await fake.ainvoke(input)
-        output += await fake.ainvoke(input * 2)
-        output += await fake.ainvoke(input * 3)
+    async def afun(value: str) -> int:
+        output = await fake.ainvoke(value)
+        output += await fake.ainvoke(value * 2)
+        output += await fake.ainvoke(value * 3)
         return output

     assert afun.get_input_jsonschema() == {"title": "afun_input", "type": "string"}
@@ -5242,12 +5244,12 @@ async def test_runnable_gen_transform() -> None:
         for i in range(length):
             yield i

-    def plus_one(input: Iterator[int]) -> Iterator[int]:
-        for i in input:
+    def plus_one(ints: Iterator[int]) -> Iterator[int]:
+        for i in ints:
             yield i + 1

-    async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]:
-        async for i in input:
+    async def aplus_one(ints: AsyncIterator[int]) -> AsyncIterator[int]:
+        async for i in ints:
             yield i + 1

     chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
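
(Several hunks above also add @override. A minimal sketch of what the decorator buys, assuming the usual typing_extensions import used on Pythons before 3.12:)

    from typing_extensions import override

    class Base:
        def invoke(self, value: str) -> str:
            return value

    class Child(Base):
        @override
        def invoke(self, value: str) -> str:
            # A type checker verifies that Base.invoke exists; a typo such
            # as `invokee` under @override would be reported instead of
            # silently defining a new, never-called method.
            return value.upper()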

View File

@@ -543,10 +543,10 @@ async def test_astream_events_from_model() -> None:
     )

     @RunnableLambda
-    def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
+    def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
-            return model.invoke(input)
-        return model.invoke(input, config)
+            return model.invoke(value)
+        return model.invoke(value, config)

     events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
     _assert_events_equal_allow_superset_metadata(
@@ -667,10 +667,10 @@ async def test_astream_events_from_model() -> None:
     )

     @RunnableLambda
-    async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
+    async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
-            return await model.ainvoke(input)
-        return await model.ainvoke(input, config)
+            return await model.ainvoke(value)
+        return await model.ainvoke(value, config)

     events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
     _assert_events_equal_allow_superset_metadata(

View File

@@ -613,10 +613,10 @@ async def test_astream_with_model_in_chain() -> None:
     )

     @RunnableLambda
-    def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
+    def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
-            return model.invoke(input)
-        return model.invoke(input, config)
+            return model.invoke(value)
+        return model.invoke(value, config)

     events = await _collect_events(i_dont_stream.astream_events("hello", version="v2"))
     _assert_events_equal_allow_superset_metadata(
@@ -721,10 +721,10 @@ async def test_astream_with_model_in_chain() -> None:
     )

     @RunnableLambda
-    async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
+    async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
-            return await model.ainvoke(input)
-        return await model.ainvoke(input, config)
+            return await model.ainvoke(value)
+        return await model.ainvoke(value, config)

     events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2"))
     _assert_events_equal_allow_superset_metadata(
@@ -2079,6 +2079,7 @@ class StreamingRunnable(Runnable[Input, Output]):
             msg = "Server side error"
             raise ValueError(msg)

+    @override
     def stream(
         self,
         input: Input,
@@ -2413,14 +2414,14 @@ async def test_break_astream_events() -> None:
         def __init__(self) -> None:
             self.reset()

-        async def __call__(self, input: Any) -> Any:
+        async def __call__(self, value: Any) -> Any:
             self.started = True
             try:
                 await asyncio.sleep(0.5)
             except asyncio.CancelledError:
                 self.cancelled = True
                 raise
-            return input
+            return value

         def reset(self) -> None:
             self.started = False
@@ -2433,11 +2434,11 @@ async def test_break_astream_events() -> None:
     outer_cancelled = False

     @chain
-    async def sequence(input: Any) -> Any:
+    async def sequence(value: Any) -> Any:
         try:
-            yield await alittlewhile(input)
-            yield await awhile(input)
-            yield await anotherwhile(input)
+            yield await alittlewhile(value)
+            yield await awhile(value)
+            yield await anotherwhile(value)
         except asyncio.CancelledError:
             nonlocal outer_cancelled
             outer_cancelled = True
@@ -2478,14 +2479,14 @@ async def test_cancel_astream_events() -> None:
         def __init__(self) -> None:
             self.reset()

-        async def __call__(self, input: Any) -> Any:
+        async def __call__(self, value: Any) -> Any:
             self.started = True
             try:
                 await asyncio.sleep(0.5)
             except asyncio.CancelledError:
                 self.cancelled = True
                 raise
-            return input
+            return value

         def reset(self) -> None:
             self.started = False
@@ -2498,11 +2499,11 @@ async def test_cancel_astream_events() -> None:
     outer_cancelled = False

     @chain
-    async def sequence(input: Any) -> Any:
+    async def sequence(value: Any) -> Any:
         try:
-            yield await alittlewhile(input)
-            yield await awhile(input)
-            yield await anotherwhile(input)
+            yield await alittlewhile(value)
+            yield await awhile(value)
+            yield await anotherwhile(value)
         except asyncio.CancelledError:
             nonlocal outer_cancelled
             outer_cancelled = True

View File

@@ -47,23 +47,23 @@ global_agent = RunnableLambda(lambda x: x * 3)
 def test_nonlocals() -> None:
     agent = RunnableLambda(lambda x: x * 2)

-    def my_func(input: str, agent: dict[str, str]) -> str:
-        return agent.get("agent_name", input)
+    def my_func(value: str, agent: dict[str, str]) -> str:
+        return agent.get("agent_name", value)

-    def my_func2(input: str) -> str:
-        return agent.get("agent_name", input)  # type: ignore[attr-defined]
+    def my_func2(value: str) -> str:
+        return agent.get("agent_name", value)  # type: ignore[attr-defined]

-    def my_func3(input: str) -> str:
-        return agent.invoke(input)
+    def my_func3(value: str) -> str:
+        return agent.invoke(value)

-    def my_func4(input: str) -> str:
-        return global_agent.invoke(input)
+    def my_func4(value: str) -> str:
+        return global_agent.invoke(value)

 def my_func5() -> tuple[Callable[[str], str], RunnableLambda]:
     global_agent = RunnableLambda(lambda x: x * 3)

-    def my_func6(input: str) -> str:
-        return global_agent.invoke(input)
+    def my_func6(value: str) -> str:
+        return global_agent.invoke(value)

     return my_func6, global_agent

View File

@@ -2,17 +2,17 @@ from langchain_core.globals import get_debug, set_debug
 def test_debug_is_settable_via_setter() -> None:
-    from langchain_core import globals
+    from langchain_core import globals as globals_
     from langchain_core.callbacks.manager import _get_debug

-    previous_value = globals._debug
+    previous_value = globals_._debug
     previous_fn_reading = _get_debug()
     assert previous_value == previous_fn_reading

     # Flip the value of the flag.
     set_debug(not previous_value)
-    new_value = globals._debug
+    new_value = globals_._debug
     new_fn_reading = _get_debug()

     try:
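
(Aliasing the import keeps the builtin globals() usable inside the test module. A minimal sketch of the collision, assuming only what the hunk above shows:)

    from langchain_core import globals as globals_

    # A bare `from langchain_core import globals` would rebind the name
    # `globals` to the module object, so a later call to globals() in this
    # file would raise TypeError: 'module' object is not callable. With
    # the alias, the builtin still returns the module namespace dict.
    print(type(globals()))  # <class 'dict'>
    print(globals_._debug)  # the private flag the test inspects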

View File

@@ -50,7 +50,7 @@ class CustomAddTextsVectorstore(VectorStore):
         return ids_

     def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
-        return [self.store[id] for id in ids if id in self.store]
+        return [self.store[id_] for id_ in ids if id_ in self.store]

     @classmethod
     @override
@@ -96,7 +96,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
         return ids_

     def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
-        return [self.store[id] for id in ids if id in self.store]
+        return [self.store[id_] for id_ in ids if id_ in self.store]

     @classmethod
     @override
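
(The trailing underscore is the usual PEP 8 escape for a name that would otherwise shadow a builtin, id() in this case. A small self-contained illustration:)

    store = {"a": 1, "b": 2}
    ids = ["a", "x", "b"]

    # Naming the loop variable `id_` instead of `id` keeps the builtin
    # id() unshadowed inside the comprehension.
    found = [store[id_] for id_ in ids if id_ in store]
    assert found == [1, 2]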

View File

@@ -1,4 +1,5 @@
 version = 1
+revision = 1
 requires-python = ">=3.9, <4.0"
 resolution-markers = [
     "python_full_version >= '3.13'",