mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 20:05:58 +00:00
Merge 1b6790af76
into 04a899ebe3
This commit is contained in:
commit
132e6ebab3
@ -50,7 +50,7 @@ class LangSmithLoader(BaseLoader):
|
||||
offset: int = 0,
|
||||
limit: Optional[int] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
filter: Optional[str] = None,
|
||||
filter: Optional[str] = None, # noqa: A002
|
||||
content_key: str = "",
|
||||
format_content: Optional[Callable[..., str]] = None,
|
||||
client: Optional[LangSmithClient] = None,
|
||||
|
@ -341,15 +341,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
"""Get the output type for this runnable."""
|
||||
return AnyMessage
|
||||
|
||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
||||
if isinstance(input, PromptValue):
|
||||
return input
|
||||
if isinstance(input, str):
|
||||
return StringPromptValue(text=input)
|
||||
if isinstance(input, Sequence):
|
||||
return ChatPromptValue(messages=convert_to_messages(input))
|
||||
def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
|
||||
if isinstance(model_input, PromptValue):
|
||||
return model_input
|
||||
if isinstance(model_input, str):
|
||||
return StringPromptValue(text=model_input)
|
||||
if isinstance(model_input, Sequence):
|
||||
return ChatPromptValue(messages=convert_to_messages(model_input))
|
||||
msg = (
|
||||
f"Invalid input type {type(input)}. "
|
||||
f"Invalid input type {type(model_input)}. "
|
||||
"Must be a PromptValue, str, or list of BaseMessages."
|
||||
)
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
@ -325,15 +325,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
"""Get the input type for this runnable."""
|
||||
return str
|
||||
|
||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
||||
if isinstance(input, PromptValue):
|
||||
return input
|
||||
if isinstance(input, str):
|
||||
return StringPromptValue(text=input)
|
||||
if isinstance(input, Sequence):
|
||||
return ChatPromptValue(messages=convert_to_messages(input))
|
||||
def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
|
||||
if isinstance(model_input, PromptValue):
|
||||
return model_input
|
||||
if isinstance(model_input, str):
|
||||
return StringPromptValue(text=model_input)
|
||||
if isinstance(model_input, Sequence):
|
||||
return ChatPromptValue(messages=convert_to_messages(model_input))
|
||||
msg = (
|
||||
f"Invalid input type {type(input)}. "
|
||||
f"Invalid input type {type(model_input)}. "
|
||||
"Must be a PromptValue, str, or list of BaseMessages."
|
||||
)
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
@ -438,7 +438,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
if max_concurrency is None:
|
||||
try:
|
||||
llm_result = self.generate_prompt(
|
||||
[self._convert_input(input) for input in inputs],
|
||||
[self._convert_input(input_) for input_ in inputs],
|
||||
callbacks=[c.get("callbacks") for c in config],
|
||||
tags=[c.get("tags") for c in config],
|
||||
metadata=[c.get("metadata") for c in config],
|
||||
@ -484,7 +484,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
if max_concurrency is None:
|
||||
try:
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input) for input in inputs],
|
||||
[self._convert_input(input_) for input_ in inputs],
|
||||
callbacks=[c.get("callbacks") for c in config],
|
||||
tags=[c.get("tags") for c in config],
|
||||
metadata=[c.get("metadata") for c in config],
|
||||
|
@ -417,10 +417,10 @@ def add_ai_message_chunks(
|
||||
else:
|
||||
usage_metadata = None
|
||||
|
||||
id = None
|
||||
chunk_id = None
|
||||
for id_ in [left.id] + [o.id for o in others]:
|
||||
if id_:
|
||||
id = id_
|
||||
chunk_id = id_
|
||||
break
|
||||
return left.__class__(
|
||||
example=left.example,
|
||||
@ -429,7 +429,7 @@ def add_ai_message_chunks(
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
response_metadata=response_metadata,
|
||||
usage_metadata=usage_metadata,
|
||||
id=id,
|
||||
id=chunk_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -11,7 +11,11 @@ class RemoveMessage(BaseMessage):
|
||||
type: Literal["remove"] = "remove"
|
||||
"""The type of the message (used for serialization). Defaults to "remove"."""
|
||||
|
||||
def __init__(self, id: str, **kwargs: Any) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
id: str, # noqa: A002
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a RemoveMessage.
|
||||
|
||||
Args:
|
||||
|
@ -208,7 +208,12 @@ class ToolCall(TypedDict):
|
||||
type: NotRequired[Literal["tool_call"]]
|
||||
|
||||
|
||||
def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall:
|
||||
def tool_call(
|
||||
*,
|
||||
name: str,
|
||||
args: dict[str, Any],
|
||||
id: Optional[str], # noqa: A002
|
||||
) -> ToolCall:
|
||||
"""Create a tool call.
|
||||
|
||||
Args:
|
||||
@ -254,7 +259,7 @@ def tool_call_chunk(
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
args: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
index: Optional[int] = None,
|
||||
) -> ToolCallChunk:
|
||||
"""Create a tool call chunk.
|
||||
@ -292,7 +297,7 @@ def invalid_tool_call(
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
args: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
error: Optional[str] = None,
|
||||
) -> InvalidToolCall:
|
||||
"""Create an invalid tool call.
|
||||
|
@ -212,7 +212,7 @@ def _create_message_from_message_type(
|
||||
name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
tool_calls: Optional[list[dict[str, Any]]] = None,
|
||||
id: Optional[str] = None,
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
**additional_kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Create a message from a message type and content string.
|
||||
|
@ -9,6 +9,7 @@ from typing import Annotated, Any, Optional, TypeVar, Union
|
||||
import jsonpatch # type: ignore[import-untyped]
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||
@ -47,6 +48,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
"""The Pydantic object to use for validation.
|
||||
If None, no validation is performed."""
|
||||
|
||||
@override
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
|
@ -20,7 +20,10 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
|
||||
def droplastn(
|
||||
iter: Iterator[T], # noqa: A002
|
||||
n: int,
|
||||
) -> Iterator[T]:
|
||||
"""Drop the last n elements of an iterator.
|
||||
|
||||
Args:
|
||||
@ -66,6 +69,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage]]
|
||||
) -> Iterator[list[str]]:
|
||||
@ -99,6 +103,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
for part in self.parse(buffer):
|
||||
yield [part]
|
||||
|
||||
@override
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[list[str]]:
|
||||
|
@ -72,6 +72,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
def _type(self) -> str:
|
||||
return "json_functions"
|
||||
|
||||
@override
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
|
@ -30,7 +30,10 @@ if TYPE_CHECKING:
|
||||
class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
"""Base class for an output parser that can handle streaming input."""
|
||||
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[Union[str, BaseMessage]], # noqa: A002
|
||||
) -> Iterator[T]:
|
||||
for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
yield self.parse_result([ChatGeneration(message=chunk)])
|
||||
@ -38,7 +41,8 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
yield self.parse_result([Generation(text=chunk)])
|
||||
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
self,
|
||||
input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002
|
||||
) -> AsyncIterator[T]:
|
||||
async for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
@ -102,7 +106,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
parsed output, or just the current parsed output.
|
||||
"""
|
||||
|
||||
def _diff(self, prev: Optional[T], next: T) -> T:
|
||||
def _diff(
|
||||
self,
|
||||
prev: Optional[T],
|
||||
next: T, # noqa: A002
|
||||
) -> T:
|
||||
"""Convert parsed outputs into a diff format.
|
||||
|
||||
The semantics of this are up to the output parser.
|
||||
@ -116,6 +124,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||
prev_parsed = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||
@ -140,6 +149,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
yield parsed
|
||||
prev_parsed = parsed
|
||||
|
||||
@override
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[T]:
|
||||
|
@ -8,6 +8,8 @@ from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from xml.etree.ElementTree import TreeBuilder
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
@ -234,6 +236,7 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
|
||||
raise OutputParserException(msg, llm_output=text) from e
|
||||
|
||||
@override
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage]]
|
||||
) -> Iterator[AddableDict]:
|
||||
@ -242,6 +245,7 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
yield from streaming_parser.parse(chunk)
|
||||
streaming_parser.close()
|
||||
|
||||
@override
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[AddableDict]:
|
||||
|
@ -472,16 +472,18 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
img_template = cast("_ImageTemplateParam", tmpl)["image_url"]
|
||||
input_variables = []
|
||||
if isinstance(img_template, str):
|
||||
vars = get_template_variables(img_template, template_format)
|
||||
if vars:
|
||||
if len(vars) > 1:
|
||||
variables = get_template_variables(
|
||||
img_template, template_format
|
||||
)
|
||||
if variables:
|
||||
if len(variables) > 1:
|
||||
msg = (
|
||||
"Only one format variable allowed per image"
|
||||
f" template.\nGot: {vars}"
|
||||
f" template.\nGot: {variables}"
|
||||
f"\nFrom: {tmpl}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
input_variables = [vars[0]]
|
||||
input_variables = [variables[0]]
|
||||
img_template = {"url": img_template}
|
||||
img_template_obj = ImagePromptTemplate(
|
||||
input_variables=input_variables,
|
||||
|
@ -1,3 +1,4 @@
|
||||
# noqa:A005
|
||||
"""BasePrompt schema definition."""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -125,20 +126,20 @@ def mustache_template_vars(
|
||||
Returns:
|
||||
The variables from the template.
|
||||
"""
|
||||
vars: set[str] = set()
|
||||
variables: set[str] = set()
|
||||
section_depth = 0
|
||||
for type, key in mustache.tokenize(template):
|
||||
if type == "end":
|
||||
for type_, key in mustache.tokenize(template):
|
||||
if type_ == "end":
|
||||
section_depth -= 1
|
||||
elif (
|
||||
type in ("variable", "section", "inverted section", "no escape")
|
||||
type_ in ("variable", "section", "inverted section", "no escape")
|
||||
and key != "."
|
||||
and section_depth == 0
|
||||
):
|
||||
vars.add(key.split(".")[0])
|
||||
if type in ("section", "inverted section"):
|
||||
variables.add(key.split(".")[0])
|
||||
if type_ in ("section", "inverted section"):
|
||||
section_depth += 1
|
||||
return vars
|
||||
return variables
|
||||
|
||||
|
||||
Defs = dict[str, "Defs"]
|
||||
@ -158,17 +159,17 @@ def mustache_schema(
|
||||
fields = {}
|
||||
prefix: tuple[str, ...] = ()
|
||||
section_stack: list[tuple[str, ...]] = []
|
||||
for type, key in mustache.tokenize(template):
|
||||
for type_, key in mustache.tokenize(template):
|
||||
if key == ".":
|
||||
continue
|
||||
if type == "end":
|
||||
if type_ == "end":
|
||||
if section_stack:
|
||||
prefix = section_stack.pop()
|
||||
elif type in ("section", "inverted section"):
|
||||
elif type_ in ("section", "inverted section"):
|
||||
section_stack.append(prefix)
|
||||
prefix = prefix + tuple(key.split("."))
|
||||
fields[prefix] = False
|
||||
elif type in ("variable", "no escape"):
|
||||
elif type_ in ("variable", "no escape"):
|
||||
fields[prefix + tuple(key.split("."))] = True
|
||||
defs: Defs = {} # None means leaf node
|
||||
while fields:
|
||||
|
@ -209,6 +209,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
|
||||
return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
@ -269,6 +270,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
)
|
||||
return result
|
||||
|
||||
@override
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: str,
|
||||
|
@ -724,7 +724,10 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
@abstractmethod
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
self,
|
||||
input: Input, # noqa: A002
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
"""Transform a single input into an output.
|
||||
|
||||
@ -741,7 +744,10 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
self,
|
||||
input: Input, # noqa: A002
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
"""Default implementation of ainvoke, calls invoke from a thread.
|
||||
|
||||
@ -772,14 +778,14 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
|
||||
def invoke(input_: Input, config: RunnableConfig) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return self.invoke(input, config, **kwargs)
|
||||
return self.invoke(input_, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return self.invoke(input, config, **kwargs)
|
||||
return self.invoke(input_, config, **kwargs)
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
if len(inputs) == 1:
|
||||
@ -826,15 +832,17 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
def invoke(
|
||||
i: int, input: Input, config: RunnableConfig
|
||||
i: int, input_: Input, config: RunnableConfig
|
||||
) -> tuple[int, Union[Output, Exception]]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
|
||||
out: Union[Output, Exception] = self.invoke(
|
||||
input_, config, **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
out = e
|
||||
else:
|
||||
out = self.invoke(input, config, **kwargs)
|
||||
out = self.invoke(input_, config, **kwargs)
|
||||
|
||||
return (i, out)
|
||||
|
||||
@ -844,8 +852,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
futures = {
|
||||
executor.submit(invoke, i, input, config)
|
||||
for i, (input, config) in enumerate(zip(inputs, configs))
|
||||
executor.submit(invoke, i, input_, config)
|
||||
for i, (input_, config) in enumerate(zip(inputs, configs))
|
||||
}
|
||||
|
||||
try:
|
||||
@ -892,15 +900,15 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
async def ainvoke(
|
||||
input: Input, config: RunnableConfig
|
||||
value: Input, config: RunnableConfig
|
||||
) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return await self.ainvoke(input, config, **kwargs)
|
||||
return await self.ainvoke(value, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return await self.ainvoke(input, config, **kwargs)
|
||||
return await self.ainvoke(value, config, **kwargs)
|
||||
|
||||
coros = map(ainvoke, inputs, configs)
|
||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||
@ -960,24 +968,24 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
||||
|
||||
async def ainvoke_task(
|
||||
i: int, input: Input, config: RunnableConfig
|
||||
i: int, input_: Input, config: RunnableConfig
|
||||
) -> tuple[int, Union[Output, Exception]]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
out: Union[Output, Exception] = await self.ainvoke(
|
||||
input, config, **kwargs
|
||||
input_, config, **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
out = e
|
||||
else:
|
||||
out = await self.ainvoke(input, config, **kwargs)
|
||||
out = await self.ainvoke(input_, config, **kwargs)
|
||||
return (i, out)
|
||||
|
||||
coros = [
|
||||
gated_coro(semaphore, ainvoke_task(i, input, config))
|
||||
gated_coro(semaphore, ainvoke_task(i, input_, config))
|
||||
if semaphore
|
||||
else ainvoke_task(i, input, config)
|
||||
for i, (input, config) in enumerate(zip(inputs, configs))
|
||||
else ainvoke_task(i, input_, config)
|
||||
for i, (input_, config) in enumerate(zip(inputs, configs))
|
||||
]
|
||||
|
||||
for coro in asyncio.as_completed(coros):
|
||||
@ -985,7 +993,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
input: Input, # noqa: A002
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
@ -1005,7 +1013,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
input: Input, # noqa: A002
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
@ -1059,7 +1067,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
async def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
input: Any, # noqa: A002
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
diff: bool = True,
|
||||
@ -1130,7 +1138,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
async def astream_events(
|
||||
self,
|
||||
input: Any,
|
||||
input: Any, # noqa: A002
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
version: Literal["v1", "v2"] = "v2",
|
||||
@ -1396,7 +1404,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
input: Iterator[Input], # noqa: A002
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
@ -1438,7 +1446,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
input: AsyncIterator[Input], # noqa: A002
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
@ -1903,7 +1911,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||
],
|
||||
input: Input,
|
||||
input_: Input,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
serialized: Optional[dict[str, Any]] = None,
|
||||
@ -1917,7 +1925,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
serialized,
|
||||
input,
|
||||
input_,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
@ -1930,7 +1938,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
context.run(
|
||||
call_func_with_variable_args, # type: ignore[arg-type]
|
||||
func,
|
||||
input,
|
||||
input_,
|
||||
config,
|
||||
run_manager,
|
||||
**kwargs,
|
||||
@ -1953,7 +1961,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Awaitable[Output],
|
||||
],
|
||||
],
|
||||
input: Input,
|
||||
input_: Input,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
serialized: Optional[dict[str, Any]] = None,
|
||||
@ -1967,7 +1975,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
serialized,
|
||||
input,
|
||||
input_,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
@ -1976,7 +1984,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
with set_config_context(child_config) as context:
|
||||
coro = acall_func_with_variable_args(
|
||||
func, input, config, run_manager, **kwargs
|
||||
func, input_, config, run_manager, **kwargs
|
||||
)
|
||||
output: Output = await coro_with_context(coro, context)
|
||||
except BaseException as e:
|
||||
@ -1999,7 +2007,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
list[Union[Exception, Output]],
|
||||
],
|
||||
],
|
||||
input: list[Input],
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
@ -2011,21 +2019,21 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses.
|
||||
"""
|
||||
if not input:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(input))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [get_callback_manager_for_config(c) for c in configs]
|
||||
run_managers = [
|
||||
callback_manager.on_chain_start(
|
||||
None,
|
||||
input,
|
||||
input_,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for callback_manager, input, config in zip(
|
||||
callback_managers, input, configs
|
||||
for callback_manager, input_, config in zip(
|
||||
callback_managers, inputs, configs
|
||||
)
|
||||
]
|
||||
try:
|
||||
@ -2036,12 +2044,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
]
|
||||
if accepts_run_manager(func):
|
||||
kwargs["run_manager"] = run_managers
|
||||
output = func(input, **kwargs) # type: ignore[call-arg]
|
||||
output = func(inputs, **kwargs) # type: ignore[call-arg]
|
||||
except BaseException as e:
|
||||
for run_manager in run_managers:
|
||||
run_manager.on_chain_error(e)
|
||||
if return_exceptions:
|
||||
return cast("list[Output]", [e for _ in input])
|
||||
return cast("list[Output]", [e for _ in inputs])
|
||||
raise
|
||||
else:
|
||||
first_exception: Optional[Exception] = None
|
||||
@ -2072,7 +2080,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Awaitable[list[Union[Exception, Output]]],
|
||||
],
|
||||
],
|
||||
input: list[Input],
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
@ -2085,22 +2093,22 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
with callbacks.
|
||||
Use this method to implement invoke() in subclasses.
|
||||
"""
|
||||
if not input:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(input))
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
|
||||
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
callback_manager.on_chain_start(
|
||||
None,
|
||||
input,
|
||||
input_,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for callback_manager, input, config in zip(
|
||||
callback_managers, input, configs
|
||||
for callback_manager, input_, config in zip(
|
||||
callback_managers, inputs, configs
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -2112,13 +2120,13 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
]
|
||||
if accepts_run_manager(func):
|
||||
kwargs["run_manager"] = run_managers
|
||||
output = await func(input, **kwargs) # type: ignore[call-arg]
|
||||
output = await func(inputs, **kwargs) # type: ignore[call-arg]
|
||||
except BaseException as e:
|
||||
await asyncio.gather(
|
||||
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
||||
)
|
||||
if return_exceptions:
|
||||
return cast("list[Output]", [e for _ in input])
|
||||
return cast("list[Output]", [e for _ in inputs])
|
||||
raise
|
||||
else:
|
||||
first_exception: Optional[Exception] = None
|
||||
@ -2136,7 +2144,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def _transform_stream_with_config(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
inputs: Iterator[Input],
|
||||
transformer: Union[
|
||||
Callable[[Iterator[Input]], Iterator[Output]],
|
||||
Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]],
|
||||
@ -2163,7 +2171,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
|
||||
# tee the input so we can iterate over it twice
|
||||
input_for_tracing, input_for_transform = tee(input, 2)
|
||||
input_for_tracing, input_for_transform = tee(inputs, 2)
|
||||
# Start the input iterator to ensure the input Runnable starts before this one
|
||||
final_input: Optional[Input] = next(input_for_tracing, None)
|
||||
final_input_supported = True
|
||||
@ -2237,7 +2245,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
async def _atransform_stream_with_config(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
inputs: AsyncIterator[Input],
|
||||
transformer: Union[
|
||||
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
||||
Callable[
|
||||
@ -2267,7 +2275,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
|
||||
# tee the input so we can iterate over it twice
|
||||
input_for_tracing, input_for_transform = atee(input, 2)
|
||||
input_for_tracing, input_for_transform = atee(inputs, 2)
|
||||
# Start the input iterator to ensure the input Runnable starts before this one
|
||||
final_input: Optional[Input] = await py_anext(input_for_tracing, None)
|
||||
final_input_supported = True
|
||||
@ -3019,6 +3027,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
input_ = input
|
||||
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
@ -3029,16 +3038,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
with set_config_context(config) as context:
|
||||
if i == 0:
|
||||
input = context.run(step.invoke, input, config, **kwargs)
|
||||
input_ = context.run(step.invoke, input_, config, **kwargs)
|
||||
else:
|
||||
input = context.run(step.invoke, input, config)
|
||||
input_ = context.run(step.invoke, input_, config)
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(input)
|
||||
return cast("Output", input)
|
||||
run_manager.on_chain_end(input_)
|
||||
return cast("Output", input_)
|
||||
|
||||
@override
|
||||
async def ainvoke(
|
||||
@ -3059,6 +3068,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
input_ = input
|
||||
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
@ -3069,17 +3079,17 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
with set_config_context(config) as context:
|
||||
if i == 0:
|
||||
part = functools.partial(step.ainvoke, input, config, **kwargs)
|
||||
part = functools.partial(step.ainvoke, input_, config, **kwargs)
|
||||
else:
|
||||
part = functools.partial(step.ainvoke, input, config)
|
||||
input = await coro_with_context(part(), context, create_task=True)
|
||||
part = functools.partial(step.ainvoke, input_, config)
|
||||
input_ = await coro_with_context(part(), context, create_task=True)
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(input)
|
||||
return cast("Output", input)
|
||||
await run_manager.on_chain_end(input_)
|
||||
return cast("Output", input_)
|
||||
|
||||
@override
|
||||
def batch(
|
||||
@ -3117,11 +3127,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
run_managers = [
|
||||
cm.on_chain_start(
|
||||
None,
|
||||
input,
|
||||
input_,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
for cm, input_, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
|
||||
# invoke
|
||||
@ -3248,11 +3258,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
*(
|
||||
cm.on_chain_start(
|
||||
None,
|
||||
input,
|
||||
input_,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
for cm, input_, config in zip(callback_managers, inputs, configs)
|
||||
)
|
||||
)
|
||||
|
||||
@ -3346,7 +3356,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
inputs: Iterator[Input],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
@ -3359,7 +3369,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
final_pipeline = cast("Iterator[Output]", input)
|
||||
final_pipeline = cast("Iterator[Output]", inputs)
|
||||
for idx, step in enumerate(steps):
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{idx + 1}")
|
||||
@ -3373,7 +3383,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
inputs: AsyncIterator[Input],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
@ -3387,7 +3397,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
final_pipeline = cast("AsyncIterator[Output]", input)
|
||||
final_pipeline = cast("AsyncIterator[Output]", inputs)
|
||||
for idx, step in enumerate(steps):
|
||||
config = patch_config(
|
||||
config,
|
||||
@ -3733,7 +3743,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
)
|
||||
|
||||
def _invoke_step(
|
||||
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
|
||||
step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str
|
||||
) -> Any:
|
||||
child_config = patch_config(
|
||||
config,
|
||||
@ -3743,7 +3753,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
with set_config_context(child_config) as context:
|
||||
return context.run(
|
||||
step.invoke,
|
||||
input,
|
||||
input_,
|
||||
child_config,
|
||||
)
|
||||
|
||||
@ -3785,7 +3795,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
)
|
||||
|
||||
async def _ainvoke_step(
|
||||
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
|
||||
step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str
|
||||
) -> Any:
|
||||
child_config = patch_config(
|
||||
config,
|
||||
@ -3793,7 +3803,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
)
|
||||
with set_config_context(child_config) as context:
|
||||
return await coro_with_context(
|
||||
step.ainvoke(input, child_config), context, create_task=True
|
||||
step.ainvoke(input_, child_config), context, create_task=True
|
||||
)
|
||||
|
||||
# gather results from all steps
|
||||
@ -3823,7 +3833,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
inputs: Iterator[Input],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Iterator[AddableDict]:
|
||||
@ -3831,7 +3841,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
steps = dict(self.steps__)
|
||||
# Each step gets a copy of the input iterator,
|
||||
# which is consumed in parallel in a separate thread.
|
||||
input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
|
||||
input_copies = list(safetee(inputs, len(steps), lock=threading.Lock()))
|
||||
with get_executor_for_config(config) as executor:
|
||||
# Create the transform() generator for each step
|
||||
named_generators = [
|
||||
@ -3890,7 +3900,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
inputs: AsyncIterator[Input],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> AsyncIterator[AddableDict]:
|
||||
@ -3898,7 +3908,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
steps = dict(self.steps__)
|
||||
# Each step gets a copy of the input iterator,
|
||||
# which is consumed in parallel in a separate thread.
|
||||
input_copies = list(atee(input, len(steps), lock=asyncio.Lock()))
|
||||
input_copies = list(atee(inputs, len(steps), lock=asyncio.Lock()))
|
||||
# Create the transform() generator for each step
|
||||
named_generators = [
|
||||
(
|
||||
@ -4590,7 +4600,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
input: Input,
|
||||
input_: Input,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
@ -4599,7 +4609,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
output: Optional[Output] = None
|
||||
for chunk in call_func_with_variable_args(
|
||||
cast("Callable[[Input], Iterator[Output]]", self.func),
|
||||
input,
|
||||
input_,
|
||||
config,
|
||||
run_manager,
|
||||
**kwargs,
|
||||
@ -4613,18 +4623,18 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
output = chunk
|
||||
else:
|
||||
output = call_func_with_variable_args(
|
||||
self.func, input, config, run_manager, **kwargs
|
||||
self.func, input_, config, run_manager, **kwargs
|
||||
)
|
||||
# If the output is a Runnable, invoke it
|
||||
if isinstance(output, Runnable):
|
||||
recursion_limit = config["recursion_limit"]
|
||||
if recursion_limit <= 0:
|
||||
msg = (
|
||||
f"Recursion limit reached when invoking {self} with input {input}."
|
||||
f"Recursion limit reached when invoking {self} with input {input_}."
|
||||
)
|
||||
raise RecursionError(msg)
|
||||
output = output.invoke(
|
||||
input,
|
||||
input_,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
@ -4635,7 +4645,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
value: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
@ -4646,7 +4656,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
if inspect.isgeneratorfunction(self.func):
|
||||
|
||||
def func(
|
||||
input: Input,
|
||||
value: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
@ -4654,7 +4664,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
output: Optional[Output] = None
|
||||
for chunk in call_func_with_variable_args(
|
||||
cast("Callable[[Input], Iterator[Output]]", self.func),
|
||||
input,
|
||||
value,
|
||||
config,
|
||||
run_manager.get_sync(),
|
||||
**kwargs,
|
||||
@ -4671,13 +4681,13 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
else:
|
||||
|
||||
def func(
|
||||
input: Input,
|
||||
value: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
return call_func_with_variable_args(
|
||||
self.func, input, config, run_manager.get_sync(), **kwargs
|
||||
self.func, value, config, run_manager.get_sync(), **kwargs
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
@ -4693,7 +4703,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
"AsyncGenerator[Any, Any]",
|
||||
acall_func_with_variable_args(
|
||||
cast("Callable", afunc),
|
||||
input,
|
||||
value,
|
||||
config,
|
||||
run_manager,
|
||||
**kwargs,
|
||||
@ -4713,18 +4723,18 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
output = chunk
|
||||
else:
|
||||
output = await acall_func_with_variable_args(
|
||||
cast("Callable", afunc), input, config, run_manager, **kwargs
|
||||
cast("Callable", afunc), value, config, run_manager, **kwargs
|
||||
)
|
||||
# If the output is a Runnable, invoke it
|
||||
if isinstance(output, Runnable):
|
||||
recursion_limit = config["recursion_limit"]
|
||||
if recursion_limit <= 0:
|
||||
msg = (
|
||||
f"Recursion limit reached when invoking {self} with input {input}."
|
||||
f"Recursion limit reached when invoking {self} with input {value}."
|
||||
)
|
||||
raise RecursionError(msg)
|
||||
output = await output.ainvoke(
|
||||
input,
|
||||
value,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
@ -4789,14 +4799,14 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
chunks: Iterator[Input],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Output]:
|
||||
final: Input
|
||||
got_first_val = False
|
||||
for ichunk in input:
|
||||
for ichunk in chunks:
|
||||
# By definitions, RunnableLambdas consume all input before emitting output.
|
||||
# If the input is not addable, then we'll assume that we can
|
||||
# only operate on the last chunk.
|
||||
@ -4881,14 +4891,14 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
chunks: AsyncIterator[Input],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Output]:
|
||||
final: Input
|
||||
got_first_val = False
|
||||
async for ichunk in input:
|
||||
async for ichunk in chunks:
|
||||
# By definitions, RunnableLambdas consume all input before emitting output.
|
||||
# If the input is not addable, then we'll assume that we can
|
||||
# only operate on the last chunk.
|
||||
@ -4913,13 +4923,13 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
raise TypeError(msg)
|
||||
|
||||
def func(
|
||||
input: Input,
|
||||
input_: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
return call_func_with_variable_args(
|
||||
self.func, input, config, run_manager.get_sync(), **kwargs
|
||||
self.func, input_, config, run_manager.get_sync(), **kwargs
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
|
@ -196,6 +196,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
raise ValueError(msg)
|
||||
return specs
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
@ -254,6 +255,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
run_manager.on_chain_end(output)
|
||||
return output
|
||||
|
||||
@override
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
@ -302,6 +304,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
await run_manager.on_chain_end(output)
|
||||
return output
|
||||
|
||||
@override
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
@ -388,6 +391,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
raise
|
||||
run_manager.on_chain_end(final_output)
|
||||
|
||||
@override
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
|
@ -401,7 +401,7 @@ def call_func_with_variable_args(
|
||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||
],
|
||||
input: Input,
|
||||
input: Input, # noqa: A002
|
||||
config: RunnableConfig,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
**kwargs: Any,
|
||||
@ -438,7 +438,7 @@ def acall_func_with_variable_args(
|
||||
Awaitable[Output],
|
||||
],
|
||||
],
|
||||
input: Input,
|
||||
input: Input, # noqa: A002
|
||||
config: RunnableConfig,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
**kwargs: Any,
|
||||
|
@ -178,16 +178,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
def invoke(
|
||||
prepared: tuple[Runnable[Input, Output], RunnableConfig],
|
||||
input: Input,
|
||||
input_: Input,
|
||||
) -> Union[Output, Exception]:
|
||||
bound, config = prepared
|
||||
if return_exceptions:
|
||||
try:
|
||||
return bound.invoke(input, config, **kwargs)
|
||||
return bound.invoke(input_, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return bound.invoke(input, config, **kwargs)
|
||||
return bound.invoke(input_, config, **kwargs)
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
if len(inputs) == 1:
|
||||
@ -221,16 +221,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
async def ainvoke(
|
||||
prepared: tuple[Runnable[Input, Output], RunnableConfig],
|
||||
input: Input,
|
||||
input_: Input,
|
||||
) -> Union[Output, Exception]:
|
||||
bound, config = prepared
|
||||
if return_exceptions:
|
||||
try:
|
||||
return await bound.ainvoke(input, config, **kwargs)
|
||||
return await bound.ainvoke(input_, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return await bound.ainvoke(input, config, **kwargs)
|
||||
return await bound.ainvoke(input_, config, **kwargs)
|
||||
|
||||
coros = map(ainvoke, prepared, inputs)
|
||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||
|
@ -269,7 +269,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
if self.exception_key is not None and not all(
|
||||
isinstance(input, dict) for input in inputs
|
||||
isinstance(input_, dict) for input_ in inputs
|
||||
):
|
||||
msg = (
|
||||
"If 'exception_key' is specified then inputs must be dictionaries."
|
||||
@ -298,11 +298,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
run_managers = [
|
||||
cm.on_chain_start(
|
||||
None,
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
input_ if isinstance(input_, dict) else {"input": input_},
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
for cm, input_, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
|
||||
to_return: dict[int, Any] = {}
|
||||
@ -311,7 +311,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
first_to_raise = None
|
||||
for runnable in self.runnables:
|
||||
outputs = runnable.batch(
|
||||
[input for _, input in sorted(run_again.items())],
|
||||
[input_ for _, input_ in sorted(run_again.items())],
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
||||
@ -320,7 +320,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
return_exceptions=True,
|
||||
**kwargs,
|
||||
)
|
||||
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
|
||||
for (i, input_), output in zip(sorted(run_again.copy().items()), outputs):
|
||||
if isinstance(output, BaseException) and not isinstance(
|
||||
output, self.exceptions_to_handle
|
||||
):
|
||||
@ -331,7 +331,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
run_again.pop(i)
|
||||
elif isinstance(output, self.exceptions_to_handle):
|
||||
if self.exception_key:
|
||||
input[self.exception_key] = output # type: ignore[index]
|
||||
input_[self.exception_key] = output # type: ignore[index]
|
||||
handled_exceptions[i] = output
|
||||
else:
|
||||
run_managers[i].on_chain_end(output)
|
||||
@ -363,7 +363,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
if self.exception_key is not None and not all(
|
||||
isinstance(input, dict) for input in inputs
|
||||
isinstance(input_, dict) for input_ in inputs
|
||||
):
|
||||
msg = (
|
||||
"If 'exception_key' is specified then inputs must be dictionaries."
|
||||
@ -393,11 +393,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
*(
|
||||
cm.on_chain_start(
|
||||
None,
|
||||
input,
|
||||
input_,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
for cm, input_, config in zip(callback_managers, inputs, configs)
|
||||
)
|
||||
)
|
||||
|
||||
@ -407,7 +407,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
first_to_raise = None
|
||||
for runnable in self.runnables:
|
||||
outputs = await runnable.abatch(
|
||||
[input for _, input in sorted(run_again.items())],
|
||||
[input_ for _, input_ in sorted(run_again.items())],
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
||||
@ -417,7 +417,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
|
||||
for (i, input_), output in zip(sorted(run_again.copy().items()), outputs):
|
||||
if isinstance(output, BaseException) and not isinstance(
|
||||
output, self.exceptions_to_handle
|
||||
):
|
||||
@ -428,7 +428,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
run_again.pop(i)
|
||||
elif isinstance(output, self.exceptions_to_handle):
|
||||
if self.exception_key:
|
||||
input[self.exception_key] = output # type: ignore[index]
|
||||
input_[self.exception_key] = output # type: ignore[index]
|
||||
handled_exceptions[i] = output
|
||||
else:
|
||||
to_return[i] = output
|
||||
|
@ -111,7 +111,12 @@ class Node(NamedTuple):
|
||||
data: Union[type[BaseModel], RunnableType, None]
|
||||
metadata: Optional[dict[str, Any]]
|
||||
|
||||
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Return a copy of the node with optional new id and name.
|
||||
|
||||
Args:
|
||||
@ -181,7 +186,10 @@ class MermaidDrawMethod(Enum):
|
||||
API = "api" # Uses Mermaid.INK API to render the graph
|
||||
|
||||
|
||||
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType, None]) -> str:
|
||||
def node_data_str(
|
||||
id: str, # noqa: A002
|
||||
data: Union[type[BaseModel], RunnableType, None],
|
||||
) -> str:
|
||||
"""Convert the data of a node to a string.
|
||||
|
||||
Args:
|
||||
@ -320,7 +328,7 @@ class Graph:
|
||||
def add_node(
|
||||
self,
|
||||
data: Union[type[BaseModel], RunnableType, None],
|
||||
id: Optional[str] = None,
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
*,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> Node:
|
||||
@ -340,8 +348,8 @@ class Graph:
|
||||
if id is not None and id in self.nodes:
|
||||
msg = f"Node with id {id} already exists"
|
||||
raise ValueError(msg)
|
||||
id = id or self.next_id()
|
||||
node = Node(id=id, data=data, metadata=metadata, name=node_data_str(id, data))
|
||||
id_ = id or self.next_id()
|
||||
node = Node(id=id_, data=data, metadata=metadata, name=node_data_str(id_, data))
|
||||
self.nodes[node.id] = node
|
||||
return node
|
||||
|
||||
@ -406,8 +414,8 @@ class Graph:
|
||||
if all(is_uuid(node.id) for node in graph.nodes.values()):
|
||||
prefix = ""
|
||||
|
||||
def prefixed(id: str) -> str:
|
||||
return f"{prefix}:{id}" if prefix else id
|
||||
def prefixed(id_: str) -> str:
|
||||
return f"{prefix}:{id_}" if prefix else id_
|
||||
|
||||
# prefix each node
|
||||
self.nodes.update(
|
||||
@ -450,8 +458,8 @@ class Graph:
|
||||
|
||||
return Graph(
|
||||
nodes={
|
||||
_get_node_id(id): node.copy(id=_get_node_id(id))
|
||||
for id, node in self.nodes.items()
|
||||
_get_node_id(id_): node.copy(id=_get_node_id(id_))
|
||||
for id_, node in self.nodes.items()
|
||||
},
|
||||
edges=[
|
||||
edge.copy(
|
||||
|
@ -187,7 +187,7 @@ def _build_sugiyama_layout(
|
||||
# Y
|
||||
#
|
||||
|
||||
vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()}
|
||||
vertices_ = {id_: Vertex(f" {data} ") for id_, data in vertices.items()}
|
||||
edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges]
|
||||
vertices_list = vertices_.values()
|
||||
graph = Graph(vertices_list, edges_)
|
||||
|
@ -509,20 +509,20 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
)
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
||||
def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
|
||||
def _enter_history(self, value: Any, config: RunnableConfig) -> list[BaseMessage]:
|
||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||
messages = hist.messages.copy()
|
||||
|
||||
if not self.history_messages_key:
|
||||
# return all messages
|
||||
input_val = (
|
||||
input if not self.input_messages_key else input[self.input_messages_key]
|
||||
value if not self.input_messages_key else value[self.input_messages_key]
|
||||
)
|
||||
messages += self._get_input_messages(input_val)
|
||||
return messages
|
||||
|
||||
async def _aenter_history(
|
||||
self, input: dict[str, Any], config: RunnableConfig
|
||||
self, value: dict[str, Any], config: RunnableConfig
|
||||
) -> list[BaseMessage]:
|
||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||
messages = (await hist.aget_messages()).copy()
|
||||
@ -530,7 +530,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
if not self.history_messages_key:
|
||||
# return all messages
|
||||
input_val = (
|
||||
input if not self.input_messages_key else input[self.input_messages_key]
|
||||
value if not self.input_messages_key else value[self.input_messages_key]
|
||||
)
|
||||
messages += self._get_input_messages(input_val)
|
||||
return messages
|
||||
|
@ -483,19 +483,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
input: dict[str, Any],
|
||||
value: dict[str, Any],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
if not isinstance(input, dict):
|
||||
if not isinstance(value, dict):
|
||||
msg = "The input to RunnablePassthrough.assign() must be a dict."
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
||||
return {
|
||||
**input,
|
||||
**value,
|
||||
**self.mapper.invoke(
|
||||
input,
|
||||
value,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
),
|
||||
@ -512,19 +512,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: dict[str, Any],
|
||||
value: dict[str, Any],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
if not isinstance(input, dict):
|
||||
if not isinstance(value, dict):
|
||||
msg = "The input to RunnablePassthrough.assign() must be a dict."
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
||||
return {
|
||||
**input,
|
||||
**value,
|
||||
**await self.mapper.ainvoke(
|
||||
input,
|
||||
value,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
),
|
||||
@ -541,7 +541,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[dict[str, Any]],
|
||||
values: Iterator[dict[str, Any]],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
@ -549,7 +549,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
# collect mapper keys
|
||||
mapper_keys = set(self.mapper.steps__.keys())
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
|
||||
for_passthrough, for_map = safetee(values, 2, lock=threading.Lock())
|
||||
|
||||
# create map output stream
|
||||
map_output = self.mapper.transform(
|
||||
@ -598,7 +598,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[dict[str, Any]],
|
||||
values: AsyncIterator[dict[str, Any]],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
@ -606,7 +606,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
# collect mapper keys
|
||||
mapper_keys = set(self.mapper.steps__.keys())
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
|
||||
for_passthrough, for_map = atee(values, 2, lock=asyncio.Lock())
|
||||
# create map output stream
|
||||
map_output = self.mapper.atransform(
|
||||
for_map,
|
||||
@ -731,23 +731,23 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
)
|
||||
return super().get_name(suffix, name=name)
|
||||
|
||||
def _pick(self, input: dict[str, Any]) -> Any:
|
||||
if not isinstance(input, dict):
|
||||
def _pick(self, value: dict[str, Any]) -> Any:
|
||||
if not isinstance(value, dict):
|
||||
msg = "The input to RunnablePassthrough.assign() must be a dict."
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
||||
if isinstance(self.keys, str):
|
||||
return input.get(self.keys)
|
||||
picked = {k: input.get(k) for k in self.keys if k in input}
|
||||
return value.get(self.keys)
|
||||
picked = {k: value.get(k) for k in self.keys if k in value}
|
||||
if picked:
|
||||
return AddableDict(picked)
|
||||
return None
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
input: dict[str, Any],
|
||||
value: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return self._pick(input)
|
||||
return self._pick(value)
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
@ -760,9 +760,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: dict[str, Any],
|
||||
value: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return self._pick(input)
|
||||
return self._pick(value)
|
||||
|
||||
@override
|
||||
async def ainvoke(
|
||||
@ -775,9 +775,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[dict[str, Any]],
|
||||
chunks: Iterator[dict[str, Any]],
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
for chunk in input:
|
||||
for chunk in chunks:
|
||||
picked = self._pick(chunk)
|
||||
if picked is not None:
|
||||
yield picked
|
||||
@ -795,9 +795,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[dict[str, Any]],
|
||||
chunks: AsyncIterator[dict[str, Any]],
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async for chunk in input:
|
||||
async for chunk in chunks:
|
||||
picked = self._pick(chunk)
|
||||
if picked is not None:
|
||||
yield picked
|
||||
|
@ -178,7 +178,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
input: Input,
|
||||
input_: Input,
|
||||
run_manager: "CallbackManagerForChainRun",
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
@ -186,7 +186,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
for attempt in self._sync_retrying(reraise=True):
|
||||
with attempt:
|
||||
result = super().invoke(
|
||||
input,
|
||||
input_,
|
||||
self._patch_config(config, run_manager, attempt.retry_state),
|
||||
**kwargs,
|
||||
)
|
||||
@ -202,7 +202,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
input_: Input,
|
||||
run_manager: "AsyncCallbackManagerForChainRun",
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
@ -210,7 +210,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
async for attempt in self._async_retrying(reraise=True):
|
||||
with attempt:
|
||||
result = await super().ainvoke(
|
||||
input,
|
||||
input_,
|
||||
self._patch_config(config, run_manager, attempt.retry_state),
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -148,22 +148,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
keys = [input_["key"] for input_ in inputs]
|
||||
actual_inputs = [input_["input"] for input_ in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
msg = "One or more keys do not have a corresponding runnable"
|
||||
raise ValueError(msg)
|
||||
|
||||
def invoke(
|
||||
runnable: Runnable, input: Input, config: RunnableConfig
|
||||
runnable: Runnable, input_: Input, config: RunnableConfig
|
||||
) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return runnable.invoke(input, config, **kwargs)
|
||||
return runnable.invoke(input_, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return runnable.invoke(input, config, **kwargs)
|
||||
return runnable.invoke(input_, config, **kwargs)
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = get_config_list(config, len(inputs))
|
||||
@ -185,22 +185,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
keys = [input_["key"] for input_ in inputs]
|
||||
actual_inputs = [input_["input"] for input_ in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
msg = "One or more keys do not have a corresponding runnable"
|
||||
raise ValueError(msg)
|
||||
|
||||
async def ainvoke(
|
||||
runnable: Runnable, input: Input, config: RunnableConfig
|
||||
runnable: Runnable, input_: Input, config: RunnableConfig
|
||||
) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return await runnable.ainvoke(input, config, **kwargs)
|
||||
return await runnable.ainvoke(input_, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return await runnable.ainvoke(input, config, **kwargs)
|
||||
return await runnable.ainvoke(input_, config, **kwargs)
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
@ -75,7 +75,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
|
||||
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
|
||||
|
||||
|
||||
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
||||
def accepts_run_manager(callable: Callable[..., Any]) -> bool: # noqa: A002
|
||||
"""Check if a callable accepts a run_manager argument.
|
||||
|
||||
Args:
|
||||
@ -90,7 +90,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def accepts_config(callable: Callable[..., Any]) -> bool:
|
||||
def accepts_config(callable: Callable[..., Any]) -> bool: # noqa: A002
|
||||
"""Check if a callable accepts a config argument.
|
||||
|
||||
Args:
|
||||
@ -105,7 +105,7 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def accepts_context(callable: Callable[..., Any]) -> bool:
|
||||
def accepts_context(callable: Callable[..., Any]) -> bool: # noqa: A002
|
||||
"""Check if a callable accepts a context argument.
|
||||
|
||||
Args:
|
||||
@ -691,7 +691,7 @@ def get_unique_config_specs(
|
||||
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
|
||||
)
|
||||
unique: list[ConfigurableFieldSpec] = []
|
||||
for id, dupes in grouped:
|
||||
for spec_id, dupes in grouped:
|
||||
first = next(dupes)
|
||||
others = list(dupes)
|
||||
if len(others) == 0 or all(o == first for o in others):
|
||||
@ -699,7 +699,7 @@ def get_unique_config_specs(
|
||||
else:
|
||||
msg = (
|
||||
"RunnableSequence contains conflicting config specs"
|
||||
f"for {id}: {[first] + others}"
|
||||
f"for {spec_id}: {[first] + others}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return unique
|
||||
|
@ -184,7 +184,7 @@ class StructuredQuery(Expr):
|
||||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
filter: Optional[FilterDirective],
|
||||
filter: Optional[FilterDirective], # noqa: A002
|
||||
limit: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
@ -943,17 +943,17 @@ def _handle_tool_error(
|
||||
|
||||
|
||||
def _prep_run_args(
|
||||
input: Union[str, dict, ToolCall],
|
||||
value: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig],
|
||||
**kwargs: Any,
|
||||
) -> tuple[Union[str, dict], dict]:
|
||||
config = ensure_config(config)
|
||||
if _is_tool_call(input):
|
||||
tool_call_id: Optional[str] = cast("ToolCall", input)["id"]
|
||||
tool_input: Union[str, dict] = cast("ToolCall", input)["args"].copy()
|
||||
if _is_tool_call(value):
|
||||
tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
|
||||
tool_input: Union[str, dict] = cast("ToolCall", value)["args"].copy()
|
||||
else:
|
||||
tool_call_id = None
|
||||
tool_input = cast("Union[str, dict]", input)
|
||||
tool_input = cast("Union[str, dict]", value)
|
||||
return (
|
||||
tool_input,
|
||||
dict(
|
||||
|
@ -740,7 +740,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
async def _astream_events_implementation_v1(
|
||||
runnable: Runnable[Input, Output],
|
||||
input: Any,
|
||||
value: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
@ -789,7 +789,7 @@ async def _astream_events_implementation_v1(
|
||||
|
||||
async for log in _astream_log_implementation(
|
||||
runnable,
|
||||
input,
|
||||
value,
|
||||
config=config,
|
||||
stream=stream,
|
||||
diff=True,
|
||||
@ -810,7 +810,7 @@ async def _astream_events_implementation_v1(
|
||||
tags=root_tags,
|
||||
metadata=root_metadata,
|
||||
data={
|
||||
"input": input,
|
||||
"input": value,
|
||||
},
|
||||
parent_ids=[], # Not supported in v1
|
||||
)
|
||||
@ -924,7 +924,7 @@ async def _astream_events_implementation_v1(
|
||||
|
||||
async def _astream_events_implementation_v2(
|
||||
runnable: Runnable[Input, Output],
|
||||
input: Any,
|
||||
value: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
@ -972,7 +972,7 @@ async def _astream_events_implementation_v2(
|
||||
async def consume_astream() -> None:
|
||||
try:
|
||||
# if astream also calls tap_output_aiter this will be a no-op
|
||||
async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
|
||||
async with aclosing(runnable.astream(value, config, **kwargs)) as stream:
|
||||
async for _ in event_streamer.tap_output_aiter(run_id, stream):
|
||||
# All the content will be picked up
|
||||
pass
|
||||
@ -993,7 +993,7 @@ async def _astream_events_implementation_v2(
|
||||
# chain are not available until the entire input is consumed.
|
||||
# As a temporary solution, we'll modify the input to be the input
|
||||
# that was passed into the chain.
|
||||
event["data"]["input"] = input
|
||||
event["data"]["input"] = value
|
||||
first_event_run_id = event["run_id"]
|
||||
yield event
|
||||
continue
|
||||
|
@ -580,7 +580,7 @@ def _get_standardized_outputs(
|
||||
@overload
|
||||
def _astream_log_implementation(
|
||||
runnable: Runnable[Input, Output],
|
||||
input: Any,
|
||||
value: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stream: LogStreamCallbackHandler,
|
||||
@ -593,7 +593,7 @@ def _astream_log_implementation(
|
||||
@overload
|
||||
def _astream_log_implementation(
|
||||
runnable: Runnable[Input, Output],
|
||||
input: Any,
|
||||
value: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stream: LogStreamCallbackHandler,
|
||||
@ -605,7 +605,7 @@ def _astream_log_implementation(
|
||||
|
||||
async def _astream_log_implementation(
|
||||
runnable: Runnable[Input, Output],
|
||||
input: Any,
|
||||
value: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stream: LogStreamCallbackHandler,
|
||||
@ -651,7 +651,7 @@ async def _astream_log_implementation(
|
||||
prev_final_output: Optional[Output] = None
|
||||
final_output: Optional[Output] = None
|
||||
|
||||
async for chunk in runnable.astream(input, config, **kwargs):
|
||||
async for chunk in runnable.astream(value, config, **kwargs):
|
||||
prev_final_output = final_output
|
||||
if final_output is None:
|
||||
final_output = chunk
|
||||
|
@ -596,7 +596,7 @@ def convert_to_json_schema(
|
||||
|
||||
@beta()
|
||||
def tool_example_to_messages(
|
||||
input: str,
|
||||
input: str, # noqa: A002
|
||||
tool_calls: list[BaseModel],
|
||||
tool_outputs: Optional[list[str]] = None,
|
||||
*,
|
||||
|
@ -363,7 +363,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Callable[[Document], bool]] = None,
|
||||
filter: Optional[Callable[[Document], bool]] = None, # noqa: A002
|
||||
) -> list[tuple[Document, float, list[float]]]:
|
||||
# get all docs with fixed order in list
|
||||
docs = list(self.store.values())
|
||||
@ -402,7 +402,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Callable[[Document], bool]] = None,
|
||||
filter: Optional[Callable[[Document], bool]] = None, # noqa: A002
|
||||
**_kwargs: Any,
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""Search for the most similar documents to the given embedding.
|
||||
|
@ -100,7 +100,6 @@ ignore = [
|
||||
"UP007", # Doesn't play well with Pydantic in Python 3.9
|
||||
|
||||
# TODO rules
|
||||
"A",
|
||||
"ANN401",
|
||||
"BLE",
|
||||
"ERA",
|
||||
|
@ -869,9 +869,9 @@ def create_image_data() -> str:
|
||||
return "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q==" # noqa: E501
|
||||
|
||||
|
||||
def create_base64_image(format: str = "jpeg") -> str:
|
||||
def create_base64_image(image_format: str = "jpeg") -> str:
|
||||
data = create_image_data()
|
||||
return f"data:image/{format};base64,{data}"
|
||||
return f"data:image/{image_format};base64,{data}"
|
||||
|
||||
|
||||
def test_convert_to_openai_messages_string() -> None:
|
||||
|
@ -639,7 +639,7 @@ def test_parse_with_different_pydantic_1_proper() -> None:
|
||||
|
||||
def test_max_tokens_error(caplog: Any) -> None:
|
||||
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
||||
input = AIMessage(
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
@ -651,7 +651,7 @@ def test_max_tokens_error(caplog: Any) -> None:
|
||||
response_metadata={"stop_reason": "max_tokens"},
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
_ = parser.invoke(input)
|
||||
_ = parser.invoke(message)
|
||||
assert any(
|
||||
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
|
||||
for record, msg in zip(caplog.records, caplog.messages)
|
||||
|
@ -15,11 +15,11 @@ EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def change_directory(dir: Path) -> Iterator:
|
||||
def change_directory(dir_path: Path) -> Iterator:
|
||||
"""Change the working directory to the right folder."""
|
||||
origin = Path().absolute()
|
||||
try:
|
||||
os.chdir(dir)
|
||||
os.chdir(dir_path)
|
||||
yield
|
||||
finally:
|
||||
os.chdir(origin)
|
||||
|
@ -220,8 +220,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
||||
str_parser = StrOutputParser()
|
||||
xml_parser = XMLOutputParser()
|
||||
|
||||
def conditional_str_parser(input: str) -> Runnable:
|
||||
if input == "a":
|
||||
def conditional_str_parser(value: str) -> Runnable:
|
||||
if value == "a":
|
||||
return str_parser
|
||||
return xml_parser
|
||||
|
||||
|
@ -111,9 +111,9 @@ async def test_input_messages_async() -> None:
|
||||
|
||||
def test_input_dict() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: "you said: "
|
||||
lambda params: "you said: "
|
||||
+ "\n".join(
|
||||
str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
|
||||
str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
|
||||
)
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
@ -131,9 +131,9 @@ def test_input_dict() -> None:
|
||||
|
||||
async def test_input_dict_async() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: "you said: "
|
||||
lambda params: "you said: "
|
||||
+ "\n".join(
|
||||
str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
|
||||
str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
|
||||
)
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
@ -153,10 +153,10 @@ async def test_input_dict_async() -> None:
|
||||
|
||||
def test_input_dict_with_history_key() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: "you said: "
|
||||
lambda params: "you said: "
|
||||
+ "\n".join(
|
||||
[str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
|
||||
+ [input["input"]]
|
||||
[str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
@ -175,10 +175,10 @@ def test_input_dict_with_history_key() -> None:
|
||||
|
||||
async def test_input_dict_with_history_key_async() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: "you said: "
|
||||
lambda params: "you said: "
|
||||
+ "\n".join(
|
||||
[str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
|
||||
+ [input["input"]]
|
||||
[str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
@ -197,15 +197,15 @@ async def test_input_dict_with_history_key_async() -> None:
|
||||
|
||||
def test_output_message() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: AIMessage(
|
||||
lambda params: AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
for m in params["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -225,15 +225,15 @@ def test_output_message() -> None:
|
||||
|
||||
async def test_output_message_async() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: AIMessage(
|
||||
lambda params: AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
for m in params["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -302,16 +302,16 @@ async def test_input_messages_output_message_async() -> None:
|
||||
|
||||
def test_output_messages() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: [
|
||||
lambda params: [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
for m in params["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
]
|
||||
@ -332,16 +332,16 @@ def test_output_messages() -> None:
|
||||
|
||||
async def test_output_messages_async() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: [
|
||||
lambda params: [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
for m in params["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
]
|
||||
@ -362,17 +362,17 @@ async def test_output_messages_async() -> None:
|
||||
|
||||
def test_output_dict() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: {
|
||||
lambda params: {
|
||||
"output": [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
for m in params["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
]
|
||||
@ -395,17 +395,17 @@ def test_output_dict() -> None:
|
||||
|
||||
async def test_output_dict_async() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: {
|
||||
lambda params: {
|
||||
"output": [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
for m in params["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
]
|
||||
@ -431,17 +431,17 @@ def test_get_input_schema_input_dict() -> None:
|
||||
input: Union[str, BaseMessage, Sequence[BaseMessage]]
|
||||
|
||||
runnable = RunnableLambda(
|
||||
lambda input: {
|
||||
lambda params: {
|
||||
"output": [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
for m in params["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
]
|
||||
@ -463,17 +463,17 @@ def test_get_input_schema_input_dict() -> None:
|
||||
def test_get_output_schema() -> None:
|
||||
"""Test get output schema."""
|
||||
runnable = RunnableLambda(
|
||||
lambda input: {
|
||||
lambda params: {
|
||||
"output": [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
[
|
||||
str(m.content)
|
||||
for m in input["history"]
|
||||
for m in params["history"]
|
||||
if isinstance(m, HumanMessage)
|
||||
]
|
||||
+ [input["input"]]
|
||||
+ [params["input"]]
|
||||
)
|
||||
)
|
||||
]
|
||||
@ -531,8 +531,8 @@ def test_get_input_schema_input_messages() -> None:
|
||||
def test_using_custom_config_specs() -> None:
|
||||
"""Test that we can configure which keys should be passed to the session factory."""
|
||||
|
||||
def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]:
|
||||
messages = input["messages"]
|
||||
def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
|
||||
messages = params["messages"]
|
||||
return [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
@ -644,8 +644,8 @@ def test_using_custom_config_specs() -> None:
|
||||
async def test_using_custom_config_specs_async() -> None:
|
||||
"""Test that we can configure which keys should be passed to the session factory."""
|
||||
|
||||
def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]:
|
||||
messages = input["messages"]
|
||||
def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
|
||||
messages = params["messages"]
|
||||
return [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
@ -757,12 +757,12 @@ async def test_using_custom_config_specs_async() -> None:
|
||||
def test_ignore_session_id() -> None:
|
||||
"""Test without config."""
|
||||
|
||||
def _fake_llm(input: list[BaseMessage]) -> list[BaseMessage]:
|
||||
def _fake_llm(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
return [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
str(m.content) for m in input if isinstance(m, HumanMessage)
|
||||
str(m.content) for m in messages if isinstance(m, HumanMessage)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
@ -564,8 +564,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"required": ["bye", "hello"],
|
||||
}
|
||||
|
||||
def get_value(input): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
||||
return input["variable_name"]
|
||||
def get_value(value): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
||||
return value["variable_name"]
|
||||
|
||||
assert RunnableLambda(get_value).get_input_jsonschema() == {
|
||||
"title": "get_value_input",
|
||||
@ -574,8 +574,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"required": ["variable_name"],
|
||||
}
|
||||
|
||||
async def aget_value(input): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
||||
return (input["variable_name"], input.get("another"))
|
||||
async def aget_value(value): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
||||
return (value["variable_name"], value.get("another"))
|
||||
|
||||
assert RunnableLambda(aget_value).get_input_jsonschema() == {
|
||||
"title": "aget_value_input",
|
||||
@ -587,11 +587,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"required": ["another", "variable_name"],
|
||||
}
|
||||
|
||||
async def aget_values(input): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
||||
async def aget_values(value): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
||||
return {
|
||||
"hello": input["variable_name"],
|
||||
"bye": input["variable_name"],
|
||||
"byebye": input["yo"],
|
||||
"hello": value["variable_name"],
|
||||
"bye": value["variable_name"],
|
||||
"byebye": value["yo"],
|
||||
}
|
||||
|
||||
assert RunnableLambda(aget_values).get_input_jsonschema() == {
|
||||
@ -613,11 +613,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
bye: str
|
||||
byebye: int
|
||||
|
||||
async def aget_values_typed(input: InputType) -> OutputType:
|
||||
async def aget_values_typed(value: InputType) -> OutputType:
|
||||
return {
|
||||
"hello": input["variable_name"],
|
||||
"bye": input["variable_name"],
|
||||
"byebye": input["yo"],
|
||||
"hello": value["variable_name"],
|
||||
"bye": value["variable_name"],
|
||||
"byebye": value["yo"],
|
||||
}
|
||||
|
||||
assert _normalize_schema(
|
||||
@ -2592,8 +2592,8 @@ async def test_prompt_with_llm_and_async_lambda(
|
||||
)
|
||||
llm = FakeListLLM(responses=["foo", "bar"])
|
||||
|
||||
async def passthrough(input: Any) -> Any:
|
||||
return input
|
||||
async def passthrough(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
chain = prompt | llm | passthrough
|
||||
|
||||
@ -2946,12 +2946,12 @@ def test_higher_order_lambda_runnable(
|
||||
input={"question": lambda x: x["question"]},
|
||||
)
|
||||
|
||||
def router(input: dict[str, Any]) -> Runnable:
|
||||
if input["key"] == "math":
|
||||
def router(params: dict[str, Any]) -> Runnable:
|
||||
if params["key"] == "math":
|
||||
return itemgetter("input") | math_chain
|
||||
if input["key"] == "english":
|
||||
if params["key"] == "english":
|
||||
return itemgetter("input") | english_chain
|
||||
msg = f"Unknown key: {input['key']}"
|
||||
msg = f"Unknown key: {params['key']}"
|
||||
raise ValueError(msg)
|
||||
|
||||
chain: Runnable = input_map | router
|
||||
@ -3002,12 +3002,12 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
|
||||
input={"question": lambda x: x["question"]},
|
||||
)
|
||||
|
||||
def router(input: dict[str, Any]) -> Runnable:
|
||||
if input["key"] == "math":
|
||||
def router(value: dict[str, Any]) -> Runnable:
|
||||
if value["key"] == "math":
|
||||
return itemgetter("input") | math_chain
|
||||
if input["key"] == "english":
|
||||
if value["key"] == "english":
|
||||
return itemgetter("input") | english_chain
|
||||
msg = f"Unknown key: {input['key']}"
|
||||
msg = f"Unknown key: {value['key']}"
|
||||
raise ValueError(msg)
|
||||
|
||||
chain: Runnable = input_map | router
|
||||
@ -3024,12 +3024,12 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
|
||||
assert result2 == ["4", "2"]
|
||||
|
||||
# Test ainvoke
|
||||
async def arouter(input: dict[str, Any]) -> Runnable:
|
||||
if input["key"] == "math":
|
||||
async def arouter(params: dict[str, Any]) -> Runnable:
|
||||
if params["key"] == "math":
|
||||
return itemgetter("input") | math_chain
|
||||
if input["key"] == "english":
|
||||
if params["key"] == "english":
|
||||
return itemgetter("input") | english_chain
|
||||
msg = f"Unknown key: {input['key']}"
|
||||
msg = f"Unknown key: {params['key']}"
|
||||
raise ValueError(msg)
|
||||
|
||||
achain: Runnable = input_map | arouter
|
||||
@ -4125,6 +4125,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
def __init__(self, fail_starts_with: str) -> None:
|
||||
self.fail_starts_with = fail_starts_with
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
@ -4135,15 +4136,15 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
inputs: list[str],
|
||||
) -> list:
|
||||
outputs: list[Any] = []
|
||||
for input in inputs:
|
||||
if input.startswith(self.fail_starts_with):
|
||||
for value in inputs:
|
||||
if value.startswith(self.fail_starts_with):
|
||||
outputs.append(
|
||||
ValueError(
|
||||
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
|
||||
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
outputs.append(input + "a")
|
||||
outputs.append(value + "a")
|
||||
return outputs
|
||||
|
||||
def batch(
|
||||
@ -4264,6 +4265,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
def __init__(self, fail_starts_with: str) -> None:
|
||||
self.fail_starts_with = fail_starts_with
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
@ -4274,15 +4276,15 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
inputs: list[str],
|
||||
) -> list:
|
||||
outputs: list[Any] = []
|
||||
for input in inputs:
|
||||
if input.startswith(self.fail_starts_with):
|
||||
for value in inputs:
|
||||
if value.startswith(self.fail_starts_with):
|
||||
outputs.append(
|
||||
ValueError(
|
||||
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
|
||||
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
outputs.append(input + "a")
|
||||
outputs.append(value + "a")
|
||||
return outputs
|
||||
|
||||
async def abatch(
|
||||
@ -5006,10 +5008,10 @@ def test_runnable_iter_context_config() -> None:
|
||||
fake = RunnableLambda(len)
|
||||
|
||||
@chain
|
||||
def gen(input: str) -> Iterator[int]:
|
||||
yield fake.invoke(input)
|
||||
yield fake.invoke(input * 2)
|
||||
yield fake.invoke(input * 3)
|
||||
def gen(value: str) -> Iterator[int]:
|
||||
yield fake.invoke(value)
|
||||
yield fake.invoke(value * 2)
|
||||
yield fake.invoke(value * 3)
|
||||
|
||||
assert gen.get_input_jsonschema() == {
|
||||
"title": "gen_input",
|
||||
@ -5064,10 +5066,10 @@ async def test_runnable_iter_context_config_async() -> None:
|
||||
fake = RunnableLambda(len)
|
||||
|
||||
@chain
|
||||
async def agen(input: str) -> AsyncIterator[int]:
|
||||
yield await fake.ainvoke(input)
|
||||
yield await fake.ainvoke(input * 2)
|
||||
yield await fake.ainvoke(input * 3)
|
||||
async def agen(value: str) -> AsyncIterator[int]:
|
||||
yield await fake.ainvoke(value)
|
||||
yield await fake.ainvoke(value * 2)
|
||||
yield await fake.ainvoke(value * 3)
|
||||
|
||||
assert agen.get_input_jsonschema() == {
|
||||
"title": "agen_input",
|
||||
@ -5130,10 +5132,10 @@ def test_runnable_lambda_context_config() -> None:
|
||||
fake = RunnableLambda(len)
|
||||
|
||||
@chain
|
||||
def fun(input: str) -> int:
|
||||
output = fake.invoke(input)
|
||||
output += fake.invoke(input * 2)
|
||||
output += fake.invoke(input * 3)
|
||||
def fun(value: str) -> int:
|
||||
output = fake.invoke(value)
|
||||
output += fake.invoke(value * 2)
|
||||
output += fake.invoke(value * 3)
|
||||
return output
|
||||
|
||||
assert fun.get_input_jsonschema() == {"title": "fun_input", "type": "string"}
|
||||
@ -5186,10 +5188,10 @@ async def test_runnable_lambda_context_config_async() -> None:
|
||||
fake = RunnableLambda(len)
|
||||
|
||||
@chain
|
||||
async def afun(input: str) -> int:
|
||||
output = await fake.ainvoke(input)
|
||||
output += await fake.ainvoke(input * 2)
|
||||
output += await fake.ainvoke(input * 3)
|
||||
async def afun(value: str) -> int:
|
||||
output = await fake.ainvoke(value)
|
||||
output += await fake.ainvoke(value * 2)
|
||||
output += await fake.ainvoke(value * 3)
|
||||
return output
|
||||
|
||||
assert afun.get_input_jsonschema() == {"title": "afun_input", "type": "string"}
|
||||
@ -5242,12 +5244,12 @@ async def test_runnable_gen_transform() -> None:
|
||||
for i in range(length):
|
||||
yield i
|
||||
|
||||
def plus_one(input: Iterator[int]) -> Iterator[int]:
|
||||
for i in input:
|
||||
def plus_one(ints: Iterator[int]) -> Iterator[int]:
|
||||
for i in ints:
|
||||
yield i + 1
|
||||
|
||||
async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]:
|
||||
async for i in input:
|
||||
async def aplus_one(ints: AsyncIterator[int]) -> AsyncIterator[int]:
|
||||
async for i in ints:
|
||||
yield i + 1
|
||||
|
||||
chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
|
||||
|
@ -543,10 +543,10 @@ async def test_astream_events_from_model() -> None:
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
|
||||
def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
|
||||
if sys.version_info >= (3, 11):
|
||||
return model.invoke(input)
|
||||
return model.invoke(input, config)
|
||||
return model.invoke(value)
|
||||
return model.invoke(value, config)
|
||||
|
||||
events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
|
||||
_assert_events_equal_allow_superset_metadata(
|
||||
@ -667,10 +667,10 @@ async def test_astream_events_from_model() -> None:
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
|
||||
async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
|
||||
if sys.version_info >= (3, 11):
|
||||
return await model.ainvoke(input)
|
||||
return await model.ainvoke(input, config)
|
||||
return await model.ainvoke(value)
|
||||
return await model.ainvoke(value, config)
|
||||
|
||||
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
|
||||
_assert_events_equal_allow_superset_metadata(
|
||||
|
@ -613,10 +613,10 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
|
||||
def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
|
||||
if sys.version_info >= (3, 11):
|
||||
return model.invoke(input)
|
||||
return model.invoke(input, config)
|
||||
return model.invoke(value)
|
||||
return model.invoke(value, config)
|
||||
|
||||
events = await _collect_events(i_dont_stream.astream_events("hello", version="v2"))
|
||||
_assert_events_equal_allow_superset_metadata(
|
||||
@ -721,10 +721,10 @@ async def test_astream_with_model_in_chain() -> None:
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
|
||||
async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
|
||||
if sys.version_info >= (3, 11):
|
||||
return await model.ainvoke(input)
|
||||
return await model.ainvoke(input, config)
|
||||
return await model.ainvoke(value)
|
||||
return await model.ainvoke(value, config)
|
||||
|
||||
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2"))
|
||||
_assert_events_equal_allow_superset_metadata(
|
||||
@ -2079,6 +2079,7 @@ class StreamingRunnable(Runnable[Input, Output]):
|
||||
msg = "Server side error"
|
||||
raise ValueError(msg)
|
||||
|
||||
@override
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
@ -2413,14 +2414,14 @@ async def test_break_astream_events() -> None:
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
async def __call__(self, input: Any) -> Any:
|
||||
async def __call__(self, value: Any) -> Any:
|
||||
self.started = True
|
||||
try:
|
||||
await asyncio.sleep(0.5)
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
return input
|
||||
return value
|
||||
|
||||
def reset(self) -> None:
|
||||
self.started = False
|
||||
@ -2433,11 +2434,11 @@ async def test_break_astream_events() -> None:
|
||||
outer_cancelled = False
|
||||
|
||||
@chain
|
||||
async def sequence(input: Any) -> Any:
|
||||
async def sequence(value: Any) -> Any:
|
||||
try:
|
||||
yield await alittlewhile(input)
|
||||
yield await awhile(input)
|
||||
yield await anotherwhile(input)
|
||||
yield await alittlewhile(value)
|
||||
yield await awhile(value)
|
||||
yield await anotherwhile(value)
|
||||
except asyncio.CancelledError:
|
||||
nonlocal outer_cancelled
|
||||
outer_cancelled = True
|
||||
@ -2478,14 +2479,14 @@ async def test_cancel_astream_events() -> None:
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
async def __call__(self, input: Any) -> Any:
|
||||
async def __call__(self, value: Any) -> Any:
|
||||
self.started = True
|
||||
try:
|
||||
await asyncio.sleep(0.5)
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
return input
|
||||
return value
|
||||
|
||||
def reset(self) -> None:
|
||||
self.started = False
|
||||
@ -2498,11 +2499,11 @@ async def test_cancel_astream_events() -> None:
|
||||
outer_cancelled = False
|
||||
|
||||
@chain
|
||||
async def sequence(input: Any) -> Any:
|
||||
async def sequence(value: Any) -> Any:
|
||||
try:
|
||||
yield await alittlewhile(input)
|
||||
yield await awhile(input)
|
||||
yield await anotherwhile(input)
|
||||
yield await alittlewhile(value)
|
||||
yield await awhile(value)
|
||||
yield await anotherwhile(value)
|
||||
except asyncio.CancelledError:
|
||||
nonlocal outer_cancelled
|
||||
outer_cancelled = True
|
||||
|
@ -47,23 +47,23 @@ global_agent = RunnableLambda(lambda x: x * 3)
|
||||
def test_nonlocals() -> None:
|
||||
agent = RunnableLambda(lambda x: x * 2)
|
||||
|
||||
def my_func(input: str, agent: dict[str, str]) -> str:
|
||||
return agent.get("agent_name", input)
|
||||
def my_func(value: str, agent: dict[str, str]) -> str:
|
||||
return agent.get("agent_name", value)
|
||||
|
||||
def my_func2(input: str) -> str:
|
||||
return agent.get("agent_name", input) # type: ignore[attr-defined]
|
||||
def my_func2(value: str) -> str:
|
||||
return agent.get("agent_name", value) # type: ignore[attr-defined]
|
||||
|
||||
def my_func3(input: str) -> str:
|
||||
return agent.invoke(input)
|
||||
def my_func3(value: str) -> str:
|
||||
return agent.invoke(value)
|
||||
|
||||
def my_func4(input: str) -> str:
|
||||
return global_agent.invoke(input)
|
||||
def my_func4(value: str) -> str:
|
||||
return global_agent.invoke(value)
|
||||
|
||||
def my_func5() -> tuple[Callable[[str], str], RunnableLambda]:
|
||||
global_agent = RunnableLambda(lambda x: x * 3)
|
||||
|
||||
def my_func6(input: str) -> str:
|
||||
return global_agent.invoke(input)
|
||||
def my_func6(value: str) -> str:
|
||||
return global_agent.invoke(value)
|
||||
|
||||
return my_func6, global_agent
|
||||
|
||||
|
@ -2,17 +2,17 @@ from langchain_core.globals import get_debug, set_debug
|
||||
|
||||
|
||||
def test_debug_is_settable_via_setter() -> None:
|
||||
from langchain_core import globals
|
||||
from langchain_core import globals as globals_
|
||||
from langchain_core.callbacks.manager import _get_debug
|
||||
|
||||
previous_value = globals._debug
|
||||
previous_value = globals_._debug
|
||||
previous_fn_reading = _get_debug()
|
||||
assert previous_value == previous_fn_reading
|
||||
|
||||
# Flip the value of the flag.
|
||||
set_debug(not previous_value)
|
||||
|
||||
new_value = globals._debug
|
||||
new_value = globals_._debug
|
||||
new_fn_reading = _get_debug()
|
||||
|
||||
try:
|
||||
|
@ -50,7 +50,7 @@ class CustomAddTextsVectorstore(VectorStore):
|
||||
return ids_
|
||||
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||
return [self.store[id] for id in ids if id in self.store]
|
||||
return [self.store[id_] for id_ in ids if id_ in self.store]
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@ -96,7 +96,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
|
||||
return ids_
|
||||
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||
return [self.store[id] for id in ids if id in self.store]
|
||||
return [self.store[id_] for id_ in ids if id_ in self.store]
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
|
@ -1,4 +1,5 @@
|
||||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.9, <4.0"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.13'",
|
||||
|
Loading…
Reference in New Issue
Block a user