mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-01 21:35:34 +00:00
Merge 1b6790af76
into 04a899ebe3
This commit is contained in:
commit
132e6ebab3
@ -50,7 +50,7 @@ class LangSmithLoader(BaseLoader):
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
filter: Optional[str] = None,
|
filter: Optional[str] = None, # noqa: A002
|
||||||
content_key: str = "",
|
content_key: str = "",
|
||||||
format_content: Optional[Callable[..., str]] = None,
|
format_content: Optional[Callable[..., str]] = None,
|
||||||
client: Optional[LangSmithClient] = None,
|
client: Optional[LangSmithClient] = None,
|
||||||
|
@ -341,15 +341,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
"""Get the output type for this runnable."""
|
"""Get the output type for this runnable."""
|
||||||
return AnyMessage
|
return AnyMessage
|
||||||
|
|
||||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
|
||||||
if isinstance(input, PromptValue):
|
if isinstance(model_input, PromptValue):
|
||||||
return input
|
return model_input
|
||||||
if isinstance(input, str):
|
if isinstance(model_input, str):
|
||||||
return StringPromptValue(text=input)
|
return StringPromptValue(text=model_input)
|
||||||
if isinstance(input, Sequence):
|
if isinstance(model_input, Sequence):
|
||||||
return ChatPromptValue(messages=convert_to_messages(input))
|
return ChatPromptValue(messages=convert_to_messages(model_input))
|
||||||
msg = (
|
msg = (
|
||||||
f"Invalid input type {type(input)}. "
|
f"Invalid input type {type(model_input)}. "
|
||||||
"Must be a PromptValue, str, or list of BaseMessages."
|
"Must be a PromptValue, str, or list of BaseMessages."
|
||||||
)
|
)
|
||||||
raise ValueError(msg) # noqa: TRY004
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
@ -325,15 +325,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
"""Get the input type for this runnable."""
|
"""Get the input type for this runnable."""
|
||||||
return str
|
return str
|
||||||
|
|
||||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
def _convert_input(self, model_input: LanguageModelInput) -> PromptValue:
|
||||||
if isinstance(input, PromptValue):
|
if isinstance(model_input, PromptValue):
|
||||||
return input
|
return model_input
|
||||||
if isinstance(input, str):
|
if isinstance(model_input, str):
|
||||||
return StringPromptValue(text=input)
|
return StringPromptValue(text=model_input)
|
||||||
if isinstance(input, Sequence):
|
if isinstance(model_input, Sequence):
|
||||||
return ChatPromptValue(messages=convert_to_messages(input))
|
return ChatPromptValue(messages=convert_to_messages(model_input))
|
||||||
msg = (
|
msg = (
|
||||||
f"Invalid input type {type(input)}. "
|
f"Invalid input type {type(model_input)}. "
|
||||||
"Must be a PromptValue, str, or list of BaseMessages."
|
"Must be a PromptValue, str, or list of BaseMessages."
|
||||||
)
|
)
|
||||||
raise ValueError(msg) # noqa: TRY004
|
raise ValueError(msg) # noqa: TRY004
|
||||||
@ -438,7 +438,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
if max_concurrency is None:
|
if max_concurrency is None:
|
||||||
try:
|
try:
|
||||||
llm_result = self.generate_prompt(
|
llm_result = self.generate_prompt(
|
||||||
[self._convert_input(input) for input in inputs],
|
[self._convert_input(input_) for input_ in inputs],
|
||||||
callbacks=[c.get("callbacks") for c in config],
|
callbacks=[c.get("callbacks") for c in config],
|
||||||
tags=[c.get("tags") for c in config],
|
tags=[c.get("tags") for c in config],
|
||||||
metadata=[c.get("metadata") for c in config],
|
metadata=[c.get("metadata") for c in config],
|
||||||
@ -484,7 +484,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
if max_concurrency is None:
|
if max_concurrency is None:
|
||||||
try:
|
try:
|
||||||
llm_result = await self.agenerate_prompt(
|
llm_result = await self.agenerate_prompt(
|
||||||
[self._convert_input(input) for input in inputs],
|
[self._convert_input(input_) for input_ in inputs],
|
||||||
callbacks=[c.get("callbacks") for c in config],
|
callbacks=[c.get("callbacks") for c in config],
|
||||||
tags=[c.get("tags") for c in config],
|
tags=[c.get("tags") for c in config],
|
||||||
metadata=[c.get("metadata") for c in config],
|
metadata=[c.get("metadata") for c in config],
|
||||||
|
@ -417,10 +417,10 @@ def add_ai_message_chunks(
|
|||||||
else:
|
else:
|
||||||
usage_metadata = None
|
usage_metadata = None
|
||||||
|
|
||||||
id = None
|
chunk_id = None
|
||||||
for id_ in [left.id] + [o.id for o in others]:
|
for id_ in [left.id] + [o.id for o in others]:
|
||||||
if id_:
|
if id_:
|
||||||
id = id_
|
chunk_id = id_
|
||||||
break
|
break
|
||||||
return left.__class__(
|
return left.__class__(
|
||||||
example=left.example,
|
example=left.example,
|
||||||
@ -429,7 +429,7 @@ def add_ai_message_chunks(
|
|||||||
tool_call_chunks=tool_call_chunks,
|
tool_call_chunks=tool_call_chunks,
|
||||||
response_metadata=response_metadata,
|
response_metadata=response_metadata,
|
||||||
usage_metadata=usage_metadata,
|
usage_metadata=usage_metadata,
|
||||||
id=id,
|
id=chunk_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,7 +11,11 @@ class RemoveMessage(BaseMessage):
|
|||||||
type: Literal["remove"] = "remove"
|
type: Literal["remove"] = "remove"
|
||||||
"""The type of the message (used for serialization). Defaults to "remove"."""
|
"""The type of the message (used for serialization). Defaults to "remove"."""
|
||||||
|
|
||||||
def __init__(self, id: str, **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str, # noqa: A002
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
"""Create a RemoveMessage.
|
"""Create a RemoveMessage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -208,7 +208,12 @@ class ToolCall(TypedDict):
|
|||||||
type: NotRequired[Literal["tool_call"]]
|
type: NotRequired[Literal["tool_call"]]
|
||||||
|
|
||||||
|
|
||||||
def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall:
|
def tool_call(
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
args: dict[str, Any],
|
||||||
|
id: Optional[str], # noqa: A002
|
||||||
|
) -> ToolCall:
|
||||||
"""Create a tool call.
|
"""Create a tool call.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -254,7 +259,7 @@ def tool_call_chunk(
|
|||||||
*,
|
*,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
args: Optional[str] = None,
|
args: Optional[str] = None,
|
||||||
id: Optional[str] = None,
|
id: Optional[str] = None, # noqa: A002
|
||||||
index: Optional[int] = None,
|
index: Optional[int] = None,
|
||||||
) -> ToolCallChunk:
|
) -> ToolCallChunk:
|
||||||
"""Create a tool call chunk.
|
"""Create a tool call chunk.
|
||||||
@ -292,7 +297,7 @@ def invalid_tool_call(
|
|||||||
*,
|
*,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
args: Optional[str] = None,
|
args: Optional[str] = None,
|
||||||
id: Optional[str] = None,
|
id: Optional[str] = None, # noqa: A002
|
||||||
error: Optional[str] = None,
|
error: Optional[str] = None,
|
||||||
) -> InvalidToolCall:
|
) -> InvalidToolCall:
|
||||||
"""Create an invalid tool call.
|
"""Create an invalid tool call.
|
||||||
|
@ -212,7 +212,7 @@ def _create_message_from_message_type(
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
tool_call_id: Optional[str] = None,
|
tool_call_id: Optional[str] = None,
|
||||||
tool_calls: Optional[list[dict[str, Any]]] = None,
|
tool_calls: Optional[list[dict[str, Any]]] = None,
|
||||||
id: Optional[str] = None,
|
id: Optional[str] = None, # noqa: A002
|
||||||
**additional_kwargs: Any,
|
**additional_kwargs: Any,
|
||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
"""Create a message from a message type and content string.
|
"""Create a message from a message type and content string.
|
||||||
|
@ -9,6 +9,7 @@ from typing import Annotated, Any, Optional, TypeVar, Union
|
|||||||
import jsonpatch # type: ignore[import-untyped]
|
import jsonpatch # type: ignore[import-untyped]
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import SkipValidation
|
from pydantic import SkipValidation
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||||
@ -47,6 +48,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
"""The Pydantic object to use for validation.
|
"""The Pydantic object to use for validation.
|
||||||
If None, no validation is performed."""
|
If None, no validation is performed."""
|
||||||
|
|
||||||
|
@override
|
||||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||||
return jsonpatch.make_patch(prev, next).patch
|
return jsonpatch.make_patch(prev, next).patch
|
||||||
|
|
||||||
|
@ -20,7 +20,10 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
|
def droplastn(
|
||||||
|
iter: Iterator[T], # noqa: A002
|
||||||
|
n: int,
|
||||||
|
) -> Iterator[T]:
|
||||||
"""Drop the last n elements of an iterator.
|
"""Drop the last n elements of an iterator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -66,6 +69,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
def _transform(
|
def _transform(
|
||||||
self, input: Iterator[Union[str, BaseMessage]]
|
self, input: Iterator[Union[str, BaseMessage]]
|
||||||
) -> Iterator[list[str]]:
|
) -> Iterator[list[str]]:
|
||||||
@ -99,6 +103,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|||||||
for part in self.parse(buffer):
|
for part in self.parse(buffer):
|
||||||
yield [part]
|
yield [part]
|
||||||
|
|
||||||
|
@override
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||||
) -> AsyncIterator[list[str]]:
|
) -> AsyncIterator[list[str]]:
|
||||||
|
@ -72,6 +72,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
def _type(self) -> str:
|
def _type(self) -> str:
|
||||||
return "json_functions"
|
return "json_functions"
|
||||||
|
|
||||||
|
@override
|
||||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||||
return jsonpatch.make_patch(prev, next).patch
|
return jsonpatch.make_patch(prev, next).patch
|
||||||
|
|
||||||
|
@ -30,7 +30,10 @@ if TYPE_CHECKING:
|
|||||||
class BaseTransformOutputParser(BaseOutputParser[T]):
|
class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||||
"""Base class for an output parser that can handle streaming input."""
|
"""Base class for an output parser that can handle streaming input."""
|
||||||
|
|
||||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
|
def _transform(
|
||||||
|
self,
|
||||||
|
input: Iterator[Union[str, BaseMessage]], # noqa: A002
|
||||||
|
) -> Iterator[T]:
|
||||||
for chunk in input:
|
for chunk in input:
|
||||||
if isinstance(chunk, BaseMessage):
|
if isinstance(chunk, BaseMessage):
|
||||||
yield self.parse_result([ChatGeneration(message=chunk)])
|
yield self.parse_result([ChatGeneration(message=chunk)])
|
||||||
@ -38,7 +41,8 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
yield self.parse_result([Generation(text=chunk)])
|
yield self.parse_result([Generation(text=chunk)])
|
||||||
|
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
self,
|
||||||
|
input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
async for chunk in input:
|
async for chunk in input:
|
||||||
if isinstance(chunk, BaseMessage):
|
if isinstance(chunk, BaseMessage):
|
||||||
@ -102,7 +106,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
parsed output, or just the current parsed output.
|
parsed output, or just the current parsed output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _diff(self, prev: Optional[T], next: T) -> T:
|
def _diff(
|
||||||
|
self,
|
||||||
|
prev: Optional[T],
|
||||||
|
next: T, # noqa: A002
|
||||||
|
) -> T:
|
||||||
"""Convert parsed outputs into a diff format.
|
"""Convert parsed outputs into a diff format.
|
||||||
|
|
||||||
The semantics of this are up to the output parser.
|
The semantics of this are up to the output parser.
|
||||||
@ -116,6 +124,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||||
prev_parsed = None
|
prev_parsed = None
|
||||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||||
@ -140,6 +149,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
yield parsed
|
yield parsed
|
||||||
prev_parsed = parsed
|
prev_parsed = parsed
|
||||||
|
|
||||||
|
@override
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
|
@ -8,6 +8,8 @@ from collections.abc import AsyncIterator, Iterator
|
|||||||
from typing import Any, Literal, Optional, Union
|
from typing import Any, Literal, Optional, Union
|
||||||
from xml.etree.ElementTree import TreeBuilder
|
from xml.etree.ElementTree import TreeBuilder
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||||
@ -234,6 +236,7 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
|
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
|
||||||
raise OutputParserException(msg, llm_output=text) from e
|
raise OutputParserException(msg, llm_output=text) from e
|
||||||
|
|
||||||
|
@override
|
||||||
def _transform(
|
def _transform(
|
||||||
self, input: Iterator[Union[str, BaseMessage]]
|
self, input: Iterator[Union[str, BaseMessage]]
|
||||||
) -> Iterator[AddableDict]:
|
) -> Iterator[AddableDict]:
|
||||||
@ -242,6 +245,7 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
yield from streaming_parser.parse(chunk)
|
yield from streaming_parser.parse(chunk)
|
||||||
streaming_parser.close()
|
streaming_parser.close()
|
||||||
|
|
||||||
|
@override
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||||
) -> AsyncIterator[AddableDict]:
|
) -> AsyncIterator[AddableDict]:
|
||||||
|
@ -472,16 +472,18 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
img_template = cast("_ImageTemplateParam", tmpl)["image_url"]
|
img_template = cast("_ImageTemplateParam", tmpl)["image_url"]
|
||||||
input_variables = []
|
input_variables = []
|
||||||
if isinstance(img_template, str):
|
if isinstance(img_template, str):
|
||||||
vars = get_template_variables(img_template, template_format)
|
variables = get_template_variables(
|
||||||
if vars:
|
img_template, template_format
|
||||||
if len(vars) > 1:
|
)
|
||||||
|
if variables:
|
||||||
|
if len(variables) > 1:
|
||||||
msg = (
|
msg = (
|
||||||
"Only one format variable allowed per image"
|
"Only one format variable allowed per image"
|
||||||
f" template.\nGot: {vars}"
|
f" template.\nGot: {variables}"
|
||||||
f"\nFrom: {tmpl}"
|
f"\nFrom: {tmpl}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
input_variables = [vars[0]]
|
input_variables = [variables[0]]
|
||||||
img_template = {"url": img_template}
|
img_template = {"url": img_template}
|
||||||
img_template_obj = ImagePromptTemplate(
|
img_template_obj = ImagePromptTemplate(
|
||||||
input_variables=input_variables,
|
input_variables=input_variables,
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
# noqa:A005
|
||||||
"""BasePrompt schema definition."""
|
"""BasePrompt schema definition."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -125,20 +126,20 @@ def mustache_template_vars(
|
|||||||
Returns:
|
Returns:
|
||||||
The variables from the template.
|
The variables from the template.
|
||||||
"""
|
"""
|
||||||
vars: set[str] = set()
|
variables: set[str] = set()
|
||||||
section_depth = 0
|
section_depth = 0
|
||||||
for type, key in mustache.tokenize(template):
|
for type_, key in mustache.tokenize(template):
|
||||||
if type == "end":
|
if type_ == "end":
|
||||||
section_depth -= 1
|
section_depth -= 1
|
||||||
elif (
|
elif (
|
||||||
type in ("variable", "section", "inverted section", "no escape")
|
type_ in ("variable", "section", "inverted section", "no escape")
|
||||||
and key != "."
|
and key != "."
|
||||||
and section_depth == 0
|
and section_depth == 0
|
||||||
):
|
):
|
||||||
vars.add(key.split(".")[0])
|
variables.add(key.split(".")[0])
|
||||||
if type in ("section", "inverted section"):
|
if type_ in ("section", "inverted section"):
|
||||||
section_depth += 1
|
section_depth += 1
|
||||||
return vars
|
return variables
|
||||||
|
|
||||||
|
|
||||||
Defs = dict[str, "Defs"]
|
Defs = dict[str, "Defs"]
|
||||||
@ -158,17 +159,17 @@ def mustache_schema(
|
|||||||
fields = {}
|
fields = {}
|
||||||
prefix: tuple[str, ...] = ()
|
prefix: tuple[str, ...] = ()
|
||||||
section_stack: list[tuple[str, ...]] = []
|
section_stack: list[tuple[str, ...]] = []
|
||||||
for type, key in mustache.tokenize(template):
|
for type_, key in mustache.tokenize(template):
|
||||||
if key == ".":
|
if key == ".":
|
||||||
continue
|
continue
|
||||||
if type == "end":
|
if type_ == "end":
|
||||||
if section_stack:
|
if section_stack:
|
||||||
prefix = section_stack.pop()
|
prefix = section_stack.pop()
|
||||||
elif type in ("section", "inverted section"):
|
elif type_ in ("section", "inverted section"):
|
||||||
section_stack.append(prefix)
|
section_stack.append(prefix)
|
||||||
prefix = prefix + tuple(key.split("."))
|
prefix = prefix + tuple(key.split("."))
|
||||||
fields[prefix] = False
|
fields[prefix] = False
|
||||||
elif type in ("variable", "no escape"):
|
elif type_ in ("variable", "no escape"):
|
||||||
fields[prefix + tuple(key.split("."))] = True
|
fields[prefix + tuple(key.split("."))] = True
|
||||||
defs: Defs = {} # None means leaf node
|
defs: Defs = {} # None means leaf node
|
||||||
while fields:
|
while fields:
|
||||||
|
@ -209,6 +209,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
|
|
||||||
return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
|
return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
@ -269,6 +270,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@override
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: str,
|
input: str,
|
||||||
|
@ -724,7 +724,10 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self,
|
||||||
|
input: Input, # noqa: A002
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Output:
|
) -> Output:
|
||||||
"""Transform a single input into an output.
|
"""Transform a single input into an output.
|
||||||
|
|
||||||
@ -741,7 +744,10 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self,
|
||||||
|
input: Input, # noqa: A002
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Output:
|
) -> Output:
|
||||||
"""Default implementation of ainvoke, calls invoke from a thread.
|
"""Default implementation of ainvoke, calls invoke from a thread.
|
||||||
|
|
||||||
@ -772,14 +778,14 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
|
def invoke(input_: Input, config: RunnableConfig) -> Union[Output, Exception]:
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
return self.invoke(input, config, **kwargs)
|
return self.invoke(input_, config, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return e
|
return e
|
||||||
else:
|
else:
|
||||||
return self.invoke(input, config, **kwargs)
|
return self.invoke(input_, config, **kwargs)
|
||||||
|
|
||||||
# If there's only one input, don't bother with the executor
|
# If there's only one input, don't bother with the executor
|
||||||
if len(inputs) == 1:
|
if len(inputs) == 1:
|
||||||
@ -826,15 +832,17 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
i: int, input: Input, config: RunnableConfig
|
i: int, input_: Input, config: RunnableConfig
|
||||||
) -> tuple[int, Union[Output, Exception]]:
|
) -> tuple[int, Union[Output, Exception]]:
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
|
out: Union[Output, Exception] = self.invoke(
|
||||||
|
input_, config, **kwargs
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
out = e
|
out = e
|
||||||
else:
|
else:
|
||||||
out = self.invoke(input, config, **kwargs)
|
out = self.invoke(input_, config, **kwargs)
|
||||||
|
|
||||||
return (i, out)
|
return (i, out)
|
||||||
|
|
||||||
@ -844,8 +852,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
with get_executor_for_config(configs[0]) as executor:
|
with get_executor_for_config(configs[0]) as executor:
|
||||||
futures = {
|
futures = {
|
||||||
executor.submit(invoke, i, input, config)
|
executor.submit(invoke, i, input_, config)
|
||||||
for i, (input, config) in enumerate(zip(inputs, configs))
|
for i, (input_, config) in enumerate(zip(inputs, configs))
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -892,15 +900,15 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
input: Input, config: RunnableConfig
|
value: Input, config: RunnableConfig
|
||||||
) -> Union[Output, Exception]:
|
) -> Union[Output, Exception]:
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
return await self.ainvoke(input, config, **kwargs)
|
return await self.ainvoke(value, config, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return e
|
return e
|
||||||
else:
|
else:
|
||||||
return await self.ainvoke(input, config, **kwargs)
|
return await self.ainvoke(value, config, **kwargs)
|
||||||
|
|
||||||
coros = map(ainvoke, inputs, configs)
|
coros = map(ainvoke, inputs, configs)
|
||||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||||
@ -960,24 +968,24 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
||||||
|
|
||||||
async def ainvoke_task(
|
async def ainvoke_task(
|
||||||
i: int, input: Input, config: RunnableConfig
|
i: int, input_: Input, config: RunnableConfig
|
||||||
) -> tuple[int, Union[Output, Exception]]:
|
) -> tuple[int, Union[Output, Exception]]:
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
out: Union[Output, Exception] = await self.ainvoke(
|
out: Union[Output, Exception] = await self.ainvoke(
|
||||||
input, config, **kwargs
|
input_, config, **kwargs
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
out = e
|
out = e
|
||||||
else:
|
else:
|
||||||
out = await self.ainvoke(input, config, **kwargs)
|
out = await self.ainvoke(input_, config, **kwargs)
|
||||||
return (i, out)
|
return (i, out)
|
||||||
|
|
||||||
coros = [
|
coros = [
|
||||||
gated_coro(semaphore, ainvoke_task(i, input, config))
|
gated_coro(semaphore, ainvoke_task(i, input_, config))
|
||||||
if semaphore
|
if semaphore
|
||||||
else ainvoke_task(i, input, config)
|
else ainvoke_task(i, input_, config)
|
||||||
for i, (input, config) in enumerate(zip(inputs, configs))
|
for i, (input_, config) in enumerate(zip(inputs, configs))
|
||||||
]
|
]
|
||||||
|
|
||||||
for coro in asyncio.as_completed(coros):
|
for coro in asyncio.as_completed(coros):
|
||||||
@ -985,7 +993,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input, # noqa: A002
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
@ -1005,7 +1013,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
async def astream(
|
async def astream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input, # noqa: A002
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
@ -1059,7 +1067,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
async def astream_log(
|
async def astream_log(
|
||||||
self,
|
self,
|
||||||
input: Any,
|
input: Any, # noqa: A002
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
*,
|
*,
|
||||||
diff: bool = True,
|
diff: bool = True,
|
||||||
@ -1130,7 +1138,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
async def astream_events(
|
async def astream_events(
|
||||||
self,
|
self,
|
||||||
input: Any,
|
input: Any, # noqa: A002
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
*,
|
*,
|
||||||
version: Literal["v1", "v2"] = "v2",
|
version: Literal["v1", "v2"] = "v2",
|
||||||
@ -1396,7 +1404,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Input],
|
input: Iterator[Input], # noqa: A002
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
@ -1438,7 +1446,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
input: AsyncIterator[Input], # noqa: A002
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
@ -1903,7 +1911,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||||
],
|
],
|
||||||
input: Input,
|
input_: Input,
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
run_type: Optional[str] = None,
|
run_type: Optional[str] = None,
|
||||||
serialized: Optional[dict[str, Any]] = None,
|
serialized: Optional[dict[str, Any]] = None,
|
||||||
@ -1917,7 +1925,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
callback_manager = get_callback_manager_for_config(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
serialized,
|
serialized,
|
||||||
input,
|
input_,
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
@ -1930,7 +1938,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
context.run(
|
context.run(
|
||||||
call_func_with_variable_args, # type: ignore[arg-type]
|
call_func_with_variable_args, # type: ignore[arg-type]
|
||||||
func,
|
func,
|
||||||
input,
|
input_,
|
||||||
config,
|
config,
|
||||||
run_manager,
|
run_manager,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -1953,7 +1961,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Awaitable[Output],
|
Awaitable[Output],
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
input: Input,
|
input_: Input,
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
run_type: Optional[str] = None,
|
run_type: Optional[str] = None,
|
||||||
serialized: Optional[dict[str, Any]] = None,
|
serialized: Optional[dict[str, Any]] = None,
|
||||||
@ -1967,7 +1975,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
callback_manager = get_async_callback_manager_for_config(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
serialized,
|
serialized,
|
||||||
input,
|
input_,
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
@ -1976,7 +1984,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
with set_config_context(child_config) as context:
|
with set_config_context(child_config) as context:
|
||||||
coro = acall_func_with_variable_args(
|
coro = acall_func_with_variable_args(
|
||||||
func, input, config, run_manager, **kwargs
|
func, input_, config, run_manager, **kwargs
|
||||||
)
|
)
|
||||||
output: Output = await coro_with_context(coro, context)
|
output: Output = await coro_with_context(coro, context)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
@ -1999,7 +2007,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
list[Union[Exception, Output]],
|
list[Union[Exception, Output]],
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
input: list[Input],
|
inputs: list[Input],
|
||||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||||
*,
|
*,
|
||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
@ -2011,21 +2019,21 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Helper method to transform an Input value to an Output value,
|
Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement invoke() in subclasses.
|
with callbacks. Use this method to implement invoke() in subclasses.
|
||||||
"""
|
"""
|
||||||
if not input:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
configs = get_config_list(config, len(input))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [get_callback_manager_for_config(c) for c in configs]
|
callback_managers = [get_callback_manager_for_config(c) for c in configs]
|
||||||
run_managers = [
|
run_managers = [
|
||||||
callback_manager.on_chain_start(
|
callback_manager.on_chain_start(
|
||||||
None,
|
None,
|
||||||
input,
|
input_,
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
)
|
)
|
||||||
for callback_manager, input, config in zip(
|
for callback_manager, input_, config in zip(
|
||||||
callback_managers, input, configs
|
callback_managers, inputs, configs
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
try:
|
try:
|
||||||
@ -2036,12 +2044,12 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
]
|
]
|
||||||
if accepts_run_manager(func):
|
if accepts_run_manager(func):
|
||||||
kwargs["run_manager"] = run_managers
|
kwargs["run_manager"] = run_managers
|
||||||
output = func(input, **kwargs) # type: ignore[call-arg]
|
output = func(inputs, **kwargs) # type: ignore[call-arg]
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
for run_manager in run_managers:
|
for run_manager in run_managers:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
return cast("list[Output]", [e for _ in input])
|
return cast("list[Output]", [e for _ in inputs])
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
first_exception: Optional[Exception] = None
|
first_exception: Optional[Exception] = None
|
||||||
@ -2072,7 +2080,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Awaitable[list[Union[Exception, Output]]],
|
Awaitable[list[Union[Exception, Output]]],
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
input: list[Input],
|
inputs: list[Input],
|
||||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||||
*,
|
*,
|
||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
@ -2085,22 +2093,22 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
with callbacks.
|
with callbacks.
|
||||||
Use this method to implement invoke() in subclasses.
|
Use this method to implement invoke() in subclasses.
|
||||||
"""
|
"""
|
||||||
if not input:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
configs = get_config_list(config, len(input))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
|
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
|
||||||
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
callback_manager.on_chain_start(
|
callback_manager.on_chain_start(
|
||||||
None,
|
None,
|
||||||
input,
|
input_,
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
)
|
)
|
||||||
for callback_manager, input, config in zip(
|
for callback_manager, input_, config in zip(
|
||||||
callback_managers, input, configs
|
callback_managers, inputs, configs
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -2112,13 +2120,13 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
]
|
]
|
||||||
if accepts_run_manager(func):
|
if accepts_run_manager(func):
|
||||||
kwargs["run_manager"] = run_managers
|
kwargs["run_manager"] = run_managers
|
||||||
output = await func(input, **kwargs) # type: ignore[call-arg]
|
output = await func(inputs, **kwargs) # type: ignore[call-arg]
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
||||||
)
|
)
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
return cast("list[Output]", [e for _ in input])
|
return cast("list[Output]", [e for _ in inputs])
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
first_exception: Optional[Exception] = None
|
first_exception: Optional[Exception] = None
|
||||||
@ -2136,7 +2144,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
def _transform_stream_with_config(
|
def _transform_stream_with_config(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Input],
|
inputs: Iterator[Input],
|
||||||
transformer: Union[
|
transformer: Union[
|
||||||
Callable[[Iterator[Input]], Iterator[Output]],
|
Callable[[Iterator[Input]], Iterator[Output]],
|
||||||
Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]],
|
Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]],
|
||||||
@ -2163,7 +2171,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||||
|
|
||||||
# tee the input so we can iterate over it twice
|
# tee the input so we can iterate over it twice
|
||||||
input_for_tracing, input_for_transform = tee(input, 2)
|
input_for_tracing, input_for_transform = tee(inputs, 2)
|
||||||
# Start the input iterator to ensure the input Runnable starts before this one
|
# Start the input iterator to ensure the input Runnable starts before this one
|
||||||
final_input: Optional[Input] = next(input_for_tracing, None)
|
final_input: Optional[Input] = next(input_for_tracing, None)
|
||||||
final_input_supported = True
|
final_input_supported = True
|
||||||
@ -2237,7 +2245,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
async def _atransform_stream_with_config(
|
async def _atransform_stream_with_config(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
inputs: AsyncIterator[Input],
|
||||||
transformer: Union[
|
transformer: Union[
|
||||||
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
||||||
Callable[
|
Callable[
|
||||||
@ -2267,7 +2275,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||||
|
|
||||||
# tee the input so we can iterate over it twice
|
# tee the input so we can iterate over it twice
|
||||||
input_for_tracing, input_for_transform = atee(input, 2)
|
input_for_tracing, input_for_transform = atee(inputs, 2)
|
||||||
# Start the input iterator to ensure the input Runnable starts before this one
|
# Start the input iterator to ensure the input Runnable starts before this one
|
||||||
final_input: Optional[Input] = await py_anext(input_for_tracing, None)
|
final_input: Optional[Input] = await py_anext(input_for_tracing, None)
|
||||||
final_input_supported = True
|
final_input_supported = True
|
||||||
@ -3019,6 +3027,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
)
|
)
|
||||||
|
input_ = input
|
||||||
|
|
||||||
# invoke all steps in sequence
|
# invoke all steps in sequence
|
||||||
try:
|
try:
|
||||||
@ -3029,16 +3038,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
with set_config_context(config) as context:
|
with set_config_context(config) as context:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
input = context.run(step.invoke, input, config, **kwargs)
|
input_ = context.run(step.invoke, input_, config, **kwargs)
|
||||||
else:
|
else:
|
||||||
input = context.run(step.invoke, input, config)
|
input_ = context.run(step.invoke, input_, config)
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
run_manager.on_chain_end(input)
|
run_manager.on_chain_end(input_)
|
||||||
return cast("Output", input)
|
return cast("Output", input_)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
@ -3059,6 +3068,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
)
|
)
|
||||||
|
input_ = input
|
||||||
|
|
||||||
# invoke all steps in sequence
|
# invoke all steps in sequence
|
||||||
try:
|
try:
|
||||||
@ -3069,17 +3079,17 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
with set_config_context(config) as context:
|
with set_config_context(config) as context:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
part = functools.partial(step.ainvoke, input, config, **kwargs)
|
part = functools.partial(step.ainvoke, input_, config, **kwargs)
|
||||||
else:
|
else:
|
||||||
part = functools.partial(step.ainvoke, input, config)
|
part = functools.partial(step.ainvoke, input_, config)
|
||||||
input = await coro_with_context(part(), context, create_task=True)
|
input_ = await coro_with_context(part(), context, create_task=True)
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
await run_manager.on_chain_end(input)
|
await run_manager.on_chain_end(input_)
|
||||||
return cast("Output", input)
|
return cast("Output", input_)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def batch(
|
def batch(
|
||||||
@ -3117,11 +3127,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
run_managers = [
|
run_managers = [
|
||||||
cm.on_chain_start(
|
cm.on_chain_start(
|
||||||
None,
|
None,
|
||||||
input,
|
input_,
|
||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
)
|
)
|
||||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
for cm, input_, config in zip(callback_managers, inputs, configs)
|
||||||
]
|
]
|
||||||
|
|
||||||
# invoke
|
# invoke
|
||||||
@ -3248,11 +3258,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
*(
|
*(
|
||||||
cm.on_chain_start(
|
cm.on_chain_start(
|
||||||
None,
|
None,
|
||||||
input,
|
input_,
|
||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
)
|
)
|
||||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
for cm, input_, config in zip(callback_managers, inputs, configs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -3346,7 +3356,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
def _transform(
|
def _transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Input],
|
inputs: Iterator[Input],
|
||||||
run_manager: CallbackManagerForChainRun,
|
run_manager: CallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -3359,7 +3369,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
# transform the input stream of each step with the next
|
# transform the input stream of each step with the next
|
||||||
# steps that don't natively support transforming an input stream will
|
# steps that don't natively support transforming an input stream will
|
||||||
# buffer input in memory until all available, and then start emitting output
|
# buffer input in memory until all available, and then start emitting output
|
||||||
final_pipeline = cast("Iterator[Output]", input)
|
final_pipeline = cast("Iterator[Output]", inputs)
|
||||||
for idx, step in enumerate(steps):
|
for idx, step in enumerate(steps):
|
||||||
config = patch_config(
|
config = patch_config(
|
||||||
config, callbacks=run_manager.get_child(f"seq:step:{idx + 1}")
|
config, callbacks=run_manager.get_child(f"seq:step:{idx + 1}")
|
||||||
@ -3373,7 +3383,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
inputs: AsyncIterator[Input],
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -3387,7 +3397,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
# transform the input stream of each step with the next
|
# transform the input stream of each step with the next
|
||||||
# steps that don't natively support transforming an input stream will
|
# steps that don't natively support transforming an input stream will
|
||||||
# buffer input in memory until all available, and then start emitting output
|
# buffer input in memory until all available, and then start emitting output
|
||||||
final_pipeline = cast("AsyncIterator[Output]", input)
|
final_pipeline = cast("AsyncIterator[Output]", inputs)
|
||||||
for idx, step in enumerate(steps):
|
for idx, step in enumerate(steps):
|
||||||
config = patch_config(
|
config = patch_config(
|
||||||
config,
|
config,
|
||||||
@ -3733,7 +3743,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _invoke_step(
|
def _invoke_step(
|
||||||
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
|
step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str
|
||||||
) -> Any:
|
) -> Any:
|
||||||
child_config = patch_config(
|
child_config = patch_config(
|
||||||
config,
|
config,
|
||||||
@ -3743,7 +3753,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
with set_config_context(child_config) as context:
|
with set_config_context(child_config) as context:
|
||||||
return context.run(
|
return context.run(
|
||||||
step.invoke,
|
step.invoke,
|
||||||
input,
|
input_,
|
||||||
child_config,
|
child_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -3785,7 +3795,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _ainvoke_step(
|
async def _ainvoke_step(
|
||||||
step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
|
step: Runnable[Input, Any], input_: Input, config: RunnableConfig, key: str
|
||||||
) -> Any:
|
) -> Any:
|
||||||
child_config = patch_config(
|
child_config = patch_config(
|
||||||
config,
|
config,
|
||||||
@ -3793,7 +3803,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
)
|
)
|
||||||
with set_config_context(child_config) as context:
|
with set_config_context(child_config) as context:
|
||||||
return await coro_with_context(
|
return await coro_with_context(
|
||||||
step.ainvoke(input, child_config), context, create_task=True
|
step.ainvoke(input_, child_config), context, create_task=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# gather results from all steps
|
# gather results from all steps
|
||||||
@ -3823,7 +3833,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
|
|
||||||
def _transform(
|
def _transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Input],
|
inputs: Iterator[Input],
|
||||||
run_manager: CallbackManagerForChainRun,
|
run_manager: CallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
) -> Iterator[AddableDict]:
|
) -> Iterator[AddableDict]:
|
||||||
@ -3831,7 +3841,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
steps = dict(self.steps__)
|
steps = dict(self.steps__)
|
||||||
# Each step gets a copy of the input iterator,
|
# Each step gets a copy of the input iterator,
|
||||||
# which is consumed in parallel in a separate thread.
|
# which is consumed in parallel in a separate thread.
|
||||||
input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
|
input_copies = list(safetee(inputs, len(steps), lock=threading.Lock()))
|
||||||
with get_executor_for_config(config) as executor:
|
with get_executor_for_config(config) as executor:
|
||||||
# Create the transform() generator for each step
|
# Create the transform() generator for each step
|
||||||
named_generators = [
|
named_generators = [
|
||||||
@ -3890,7 +3900,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
|
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
inputs: AsyncIterator[Input],
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
) -> AsyncIterator[AddableDict]:
|
) -> AsyncIterator[AddableDict]:
|
||||||
@ -3898,7 +3908,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
|||||||
steps = dict(self.steps__)
|
steps = dict(self.steps__)
|
||||||
# Each step gets a copy of the input iterator,
|
# Each step gets a copy of the input iterator,
|
||||||
# which is consumed in parallel in a separate thread.
|
# which is consumed in parallel in a separate thread.
|
||||||
input_copies = list(atee(input, len(steps), lock=asyncio.Lock()))
|
input_copies = list(atee(inputs, len(steps), lock=asyncio.Lock()))
|
||||||
# Create the transform() generator for each step
|
# Create the transform() generator for each step
|
||||||
named_generators = [
|
named_generators = [
|
||||||
(
|
(
|
||||||
@ -4590,7 +4600,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input_: Input,
|
||||||
run_manager: CallbackManagerForChainRun,
|
run_manager: CallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -4599,7 +4609,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
output: Optional[Output] = None
|
output: Optional[Output] = None
|
||||||
for chunk in call_func_with_variable_args(
|
for chunk in call_func_with_variable_args(
|
||||||
cast("Callable[[Input], Iterator[Output]]", self.func),
|
cast("Callable[[Input], Iterator[Output]]", self.func),
|
||||||
input,
|
input_,
|
||||||
config,
|
config,
|
||||||
run_manager,
|
run_manager,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -4613,18 +4623,18 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
output = chunk
|
output = chunk
|
||||||
else:
|
else:
|
||||||
output = call_func_with_variable_args(
|
output = call_func_with_variable_args(
|
||||||
self.func, input, config, run_manager, **kwargs
|
self.func, input_, config, run_manager, **kwargs
|
||||||
)
|
)
|
||||||
# If the output is a Runnable, invoke it
|
# If the output is a Runnable, invoke it
|
||||||
if isinstance(output, Runnable):
|
if isinstance(output, Runnable):
|
||||||
recursion_limit = config["recursion_limit"]
|
recursion_limit = config["recursion_limit"]
|
||||||
if recursion_limit <= 0:
|
if recursion_limit <= 0:
|
||||||
msg = (
|
msg = (
|
||||||
f"Recursion limit reached when invoking {self} with input {input}."
|
f"Recursion limit reached when invoking {self} with input {input_}."
|
||||||
)
|
)
|
||||||
raise RecursionError(msg)
|
raise RecursionError(msg)
|
||||||
output = output.invoke(
|
output = output.invoke(
|
||||||
input,
|
input_,
|
||||||
patch_config(
|
patch_config(
|
||||||
config,
|
config,
|
||||||
callbacks=run_manager.get_child(),
|
callbacks=run_manager.get_child(),
|
||||||
@ -4635,7 +4645,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
async def _ainvoke(
|
async def _ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
value: Input,
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -4646,7 +4656,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
if inspect.isgeneratorfunction(self.func):
|
if inspect.isgeneratorfunction(self.func):
|
||||||
|
|
||||||
def func(
|
def func(
|
||||||
input: Input,
|
value: Input,
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -4654,7 +4664,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
output: Optional[Output] = None
|
output: Optional[Output] = None
|
||||||
for chunk in call_func_with_variable_args(
|
for chunk in call_func_with_variable_args(
|
||||||
cast("Callable[[Input], Iterator[Output]]", self.func),
|
cast("Callable[[Input], Iterator[Output]]", self.func),
|
||||||
input,
|
value,
|
||||||
config,
|
config,
|
||||||
run_manager.get_sync(),
|
run_manager.get_sync(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -4671,13 +4681,13 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
def func(
|
def func(
|
||||||
input: Input,
|
value: Input,
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Output:
|
) -> Output:
|
||||||
return call_func_with_variable_args(
|
return call_func_with_variable_args(
|
||||||
self.func, input, config, run_manager.get_sync(), **kwargs
|
self.func, value, config, run_manager.get_sync(), **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
@ -4693,7 +4703,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
"AsyncGenerator[Any, Any]",
|
"AsyncGenerator[Any, Any]",
|
||||||
acall_func_with_variable_args(
|
acall_func_with_variable_args(
|
||||||
cast("Callable", afunc),
|
cast("Callable", afunc),
|
||||||
input,
|
value,
|
||||||
config,
|
config,
|
||||||
run_manager,
|
run_manager,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -4713,18 +4723,18 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
output = chunk
|
output = chunk
|
||||||
else:
|
else:
|
||||||
output = await acall_func_with_variable_args(
|
output = await acall_func_with_variable_args(
|
||||||
cast("Callable", afunc), input, config, run_manager, **kwargs
|
cast("Callable", afunc), value, config, run_manager, **kwargs
|
||||||
)
|
)
|
||||||
# If the output is a Runnable, invoke it
|
# If the output is a Runnable, invoke it
|
||||||
if isinstance(output, Runnable):
|
if isinstance(output, Runnable):
|
||||||
recursion_limit = config["recursion_limit"]
|
recursion_limit = config["recursion_limit"]
|
||||||
if recursion_limit <= 0:
|
if recursion_limit <= 0:
|
||||||
msg = (
|
msg = (
|
||||||
f"Recursion limit reached when invoking {self} with input {input}."
|
f"Recursion limit reached when invoking {self} with input {value}."
|
||||||
)
|
)
|
||||||
raise RecursionError(msg)
|
raise RecursionError(msg)
|
||||||
output = await output.ainvoke(
|
output = await output.ainvoke(
|
||||||
input,
|
value,
|
||||||
patch_config(
|
patch_config(
|
||||||
config,
|
config,
|
||||||
callbacks=run_manager.get_child(),
|
callbacks=run_manager.get_child(),
|
||||||
@ -4789,14 +4799,14 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
def _transform(
|
def _transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Input],
|
chunks: Iterator[Input],
|
||||||
run_manager: CallbackManagerForChainRun,
|
run_manager: CallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
final: Input
|
final: Input
|
||||||
got_first_val = False
|
got_first_val = False
|
||||||
for ichunk in input:
|
for ichunk in chunks:
|
||||||
# By definitions, RunnableLambdas consume all input before emitting output.
|
# By definitions, RunnableLambdas consume all input before emitting output.
|
||||||
# If the input is not addable, then we'll assume that we can
|
# If the input is not addable, then we'll assume that we can
|
||||||
# only operate on the last chunk.
|
# only operate on the last chunk.
|
||||||
@ -4881,14 +4891,14 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
chunks: AsyncIterator[Input],
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
final: Input
|
final: Input
|
||||||
got_first_val = False
|
got_first_val = False
|
||||||
async for ichunk in input:
|
async for ichunk in chunks:
|
||||||
# By definitions, RunnableLambdas consume all input before emitting output.
|
# By definitions, RunnableLambdas consume all input before emitting output.
|
||||||
# If the input is not addable, then we'll assume that we can
|
# If the input is not addable, then we'll assume that we can
|
||||||
# only operate on the last chunk.
|
# only operate on the last chunk.
|
||||||
@ -4913,13 +4923,13 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
def func(
|
def func(
|
||||||
input: Input,
|
input_: Input,
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Output:
|
) -> Output:
|
||||||
return call_func_with_variable_args(
|
return call_func_with_variable_args(
|
||||||
self.func, input, config, run_manager.get_sync(), **kwargs
|
self.func, input_, config, run_manager.get_sync(), **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
|
@ -196,6 +196,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return specs
|
return specs
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
@ -254,6 +255,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
run_manager.on_chain_end(output)
|
run_manager.on_chain_end(output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@override
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
@ -302,6 +304,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
await run_manager.on_chain_end(output)
|
await run_manager.on_chain_end(output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@override
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
@ -388,6 +391,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
raise
|
raise
|
||||||
run_manager.on_chain_end(final_output)
|
run_manager.on_chain_end(final_output)
|
||||||
|
|
||||||
|
@override
|
||||||
async def astream(
|
async def astream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
|
@ -401,7 +401,7 @@ def call_func_with_variable_args(
|
|||||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||||
],
|
],
|
||||||
input: Input,
|
input: Input, # noqa: A002
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -438,7 +438,7 @@ def acall_func_with_variable_args(
|
|||||||
Awaitable[Output],
|
Awaitable[Output],
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
input: Input,
|
input: Input, # noqa: A002
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
@ -178,16 +178,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
prepared: tuple[Runnable[Input, Output], RunnableConfig],
|
prepared: tuple[Runnable[Input, Output], RunnableConfig],
|
||||||
input: Input,
|
input_: Input,
|
||||||
) -> Union[Output, Exception]:
|
) -> Union[Output, Exception]:
|
||||||
bound, config = prepared
|
bound, config = prepared
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
return bound.invoke(input, config, **kwargs)
|
return bound.invoke(input_, config, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return e
|
return e
|
||||||
else:
|
else:
|
||||||
return bound.invoke(input, config, **kwargs)
|
return bound.invoke(input_, config, **kwargs)
|
||||||
|
|
||||||
# If there's only one input, don't bother with the executor
|
# If there's only one input, don't bother with the executor
|
||||||
if len(inputs) == 1:
|
if len(inputs) == 1:
|
||||||
@ -221,16 +221,16 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
prepared: tuple[Runnable[Input, Output], RunnableConfig],
|
prepared: tuple[Runnable[Input, Output], RunnableConfig],
|
||||||
input: Input,
|
input_: Input,
|
||||||
) -> Union[Output, Exception]:
|
) -> Union[Output, Exception]:
|
||||||
bound, config = prepared
|
bound, config = prepared
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
return await bound.ainvoke(input, config, **kwargs)
|
return await bound.ainvoke(input_, config, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return e
|
return e
|
||||||
else:
|
else:
|
||||||
return await bound.ainvoke(input, config, **kwargs)
|
return await bound.ainvoke(input_, config, **kwargs)
|
||||||
|
|
||||||
coros = map(ainvoke, prepared, inputs)
|
coros = map(ainvoke, prepared, inputs)
|
||||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||||
|
@ -269,7 +269,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
from langchain_core.callbacks.manager import CallbackManager
|
from langchain_core.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
if self.exception_key is not None and not all(
|
if self.exception_key is not None and not all(
|
||||||
isinstance(input, dict) for input in inputs
|
isinstance(input_, dict) for input_ in inputs
|
||||||
):
|
):
|
||||||
msg = (
|
msg = (
|
||||||
"If 'exception_key' is specified then inputs must be dictionaries."
|
"If 'exception_key' is specified then inputs must be dictionaries."
|
||||||
@ -298,11 +298,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
run_managers = [
|
run_managers = [
|
||||||
cm.on_chain_start(
|
cm.on_chain_start(
|
||||||
None,
|
None,
|
||||||
input if isinstance(input, dict) else {"input": input},
|
input_ if isinstance(input_, dict) else {"input": input_},
|
||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
)
|
)
|
||||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
for cm, input_, config in zip(callback_managers, inputs, configs)
|
||||||
]
|
]
|
||||||
|
|
||||||
to_return: dict[int, Any] = {}
|
to_return: dict[int, Any] = {}
|
||||||
@ -311,7 +311,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
first_to_raise = None
|
first_to_raise = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
outputs = runnable.batch(
|
outputs = runnable.batch(
|
||||||
[input for _, input in sorted(run_again.items())],
|
[input_ for _, input_ in sorted(run_again.items())],
|
||||||
[
|
[
|
||||||
# each step a child run of the corresponding root run
|
# each step a child run of the corresponding root run
|
||||||
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
||||||
@ -320,7 +320,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
|
for (i, input_), output in zip(sorted(run_again.copy().items()), outputs):
|
||||||
if isinstance(output, BaseException) and not isinstance(
|
if isinstance(output, BaseException) and not isinstance(
|
||||||
output, self.exceptions_to_handle
|
output, self.exceptions_to_handle
|
||||||
):
|
):
|
||||||
@ -331,7 +331,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
run_again.pop(i)
|
run_again.pop(i)
|
||||||
elif isinstance(output, self.exceptions_to_handle):
|
elif isinstance(output, self.exceptions_to_handle):
|
||||||
if self.exception_key:
|
if self.exception_key:
|
||||||
input[self.exception_key] = output # type: ignore[index]
|
input_[self.exception_key] = output # type: ignore[index]
|
||||||
handled_exceptions[i] = output
|
handled_exceptions[i] = output
|
||||||
else:
|
else:
|
||||||
run_managers[i].on_chain_end(output)
|
run_managers[i].on_chain_end(output)
|
||||||
@ -363,7 +363,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
if self.exception_key is not None and not all(
|
if self.exception_key is not None and not all(
|
||||||
isinstance(input, dict) for input in inputs
|
isinstance(input_, dict) for input_ in inputs
|
||||||
):
|
):
|
||||||
msg = (
|
msg = (
|
||||||
"If 'exception_key' is specified then inputs must be dictionaries."
|
"If 'exception_key' is specified then inputs must be dictionaries."
|
||||||
@ -393,11 +393,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
*(
|
*(
|
||||||
cm.on_chain_start(
|
cm.on_chain_start(
|
||||||
None,
|
None,
|
||||||
input,
|
input_,
|
||||||
name=config.get("run_name") or self.get_name(),
|
name=config.get("run_name") or self.get_name(),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
)
|
)
|
||||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
for cm, input_, config in zip(callback_managers, inputs, configs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -407,7 +407,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
first_to_raise = None
|
first_to_raise = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
outputs = await runnable.abatch(
|
outputs = await runnable.abatch(
|
||||||
[input for _, input in sorted(run_again.items())],
|
[input_ for _, input_ in sorted(run_again.items())],
|
||||||
[
|
[
|
||||||
# each step a child run of the corresponding root run
|
# each step a child run of the corresponding root run
|
||||||
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
||||||
@ -417,7 +417,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
|
for (i, input_), output in zip(sorted(run_again.copy().items()), outputs):
|
||||||
if isinstance(output, BaseException) and not isinstance(
|
if isinstance(output, BaseException) and not isinstance(
|
||||||
output, self.exceptions_to_handle
|
output, self.exceptions_to_handle
|
||||||
):
|
):
|
||||||
@ -428,7 +428,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
run_again.pop(i)
|
run_again.pop(i)
|
||||||
elif isinstance(output, self.exceptions_to_handle):
|
elif isinstance(output, self.exceptions_to_handle):
|
||||||
if self.exception_key:
|
if self.exception_key:
|
||||||
input[self.exception_key] = output # type: ignore[index]
|
input_[self.exception_key] = output # type: ignore[index]
|
||||||
handled_exceptions[i] = output
|
handled_exceptions[i] = output
|
||||||
else:
|
else:
|
||||||
to_return[i] = output
|
to_return[i] = output
|
||||||
|
@ -111,7 +111,12 @@ class Node(NamedTuple):
|
|||||||
data: Union[type[BaseModel], RunnableType, None]
|
data: Union[type[BaseModel], RunnableType, None]
|
||||||
metadata: Optional[dict[str, Any]]
|
metadata: Optional[dict[str, Any]]
|
||||||
|
|
||||||
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
def copy(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
id: Optional[str] = None, # noqa: A002
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> Node:
|
||||||
"""Return a copy of the node with optional new id and name.
|
"""Return a copy of the node with optional new id and name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -181,7 +186,10 @@ class MermaidDrawMethod(Enum):
|
|||||||
API = "api" # Uses Mermaid.INK API to render the graph
|
API = "api" # Uses Mermaid.INK API to render the graph
|
||||||
|
|
||||||
|
|
||||||
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType, None]) -> str:
|
def node_data_str(
|
||||||
|
id: str, # noqa: A002
|
||||||
|
data: Union[type[BaseModel], RunnableType, None],
|
||||||
|
) -> str:
|
||||||
"""Convert the data of a node to a string.
|
"""Convert the data of a node to a string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -320,7 +328,7 @@ class Graph:
|
|||||||
def add_node(
|
def add_node(
|
||||||
self,
|
self,
|
||||||
data: Union[type[BaseModel], RunnableType, None],
|
data: Union[type[BaseModel], RunnableType, None],
|
||||||
id: Optional[str] = None,
|
id: Optional[str] = None, # noqa: A002
|
||||||
*,
|
*,
|
||||||
metadata: Optional[dict[str, Any]] = None,
|
metadata: Optional[dict[str, Any]] = None,
|
||||||
) -> Node:
|
) -> Node:
|
||||||
@ -340,8 +348,8 @@ class Graph:
|
|||||||
if id is not None and id in self.nodes:
|
if id is not None and id in self.nodes:
|
||||||
msg = f"Node with id {id} already exists"
|
msg = f"Node with id {id} already exists"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
id = id or self.next_id()
|
id_ = id or self.next_id()
|
||||||
node = Node(id=id, data=data, metadata=metadata, name=node_data_str(id, data))
|
node = Node(id=id_, data=data, metadata=metadata, name=node_data_str(id_, data))
|
||||||
self.nodes[node.id] = node
|
self.nodes[node.id] = node
|
||||||
return node
|
return node
|
||||||
|
|
||||||
@ -406,8 +414,8 @@ class Graph:
|
|||||||
if all(is_uuid(node.id) for node in graph.nodes.values()):
|
if all(is_uuid(node.id) for node in graph.nodes.values()):
|
||||||
prefix = ""
|
prefix = ""
|
||||||
|
|
||||||
def prefixed(id: str) -> str:
|
def prefixed(id_: str) -> str:
|
||||||
return f"{prefix}:{id}" if prefix else id
|
return f"{prefix}:{id_}" if prefix else id_
|
||||||
|
|
||||||
# prefix each node
|
# prefix each node
|
||||||
self.nodes.update(
|
self.nodes.update(
|
||||||
@ -450,8 +458,8 @@ class Graph:
|
|||||||
|
|
||||||
return Graph(
|
return Graph(
|
||||||
nodes={
|
nodes={
|
||||||
_get_node_id(id): node.copy(id=_get_node_id(id))
|
_get_node_id(id_): node.copy(id=_get_node_id(id_))
|
||||||
for id, node in self.nodes.items()
|
for id_, node in self.nodes.items()
|
||||||
},
|
},
|
||||||
edges=[
|
edges=[
|
||||||
edge.copy(
|
edge.copy(
|
||||||
|
@ -187,7 +187,7 @@ def _build_sugiyama_layout(
|
|||||||
# Y
|
# Y
|
||||||
#
|
#
|
||||||
|
|
||||||
vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()}
|
vertices_ = {id_: Vertex(f" {data} ") for id_, data in vertices.items()}
|
||||||
edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges]
|
edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges]
|
||||||
vertices_list = vertices_.values()
|
vertices_list = vertices_.values()
|
||||||
graph = Graph(vertices_list, edges_)
|
graph = Graph(vertices_list, edges_)
|
||||||
|
@ -509,20 +509,20 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg) # noqa: TRY004
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
|
||||||
def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
|
def _enter_history(self, value: Any, config: RunnableConfig) -> list[BaseMessage]:
|
||||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||||
messages = hist.messages.copy()
|
messages = hist.messages.copy()
|
||||||
|
|
||||||
if not self.history_messages_key:
|
if not self.history_messages_key:
|
||||||
# return all messages
|
# return all messages
|
||||||
input_val = (
|
input_val = (
|
||||||
input if not self.input_messages_key else input[self.input_messages_key]
|
value if not self.input_messages_key else value[self.input_messages_key]
|
||||||
)
|
)
|
||||||
messages += self._get_input_messages(input_val)
|
messages += self._get_input_messages(input_val)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def _aenter_history(
|
async def _aenter_history(
|
||||||
self, input: dict[str, Any], config: RunnableConfig
|
self, value: dict[str, Any], config: RunnableConfig
|
||||||
) -> list[BaseMessage]:
|
) -> list[BaseMessage]:
|
||||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||||
messages = (await hist.aget_messages()).copy()
|
messages = (await hist.aget_messages()).copy()
|
||||||
@ -530,7 +530,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
if not self.history_messages_key:
|
if not self.history_messages_key:
|
||||||
# return all messages
|
# return all messages
|
||||||
input_val = (
|
input_val = (
|
||||||
input if not self.input_messages_key else input[self.input_messages_key]
|
value if not self.input_messages_key else value[self.input_messages_key]
|
||||||
)
|
)
|
||||||
messages += self._get_input_messages(input_val)
|
messages += self._get_input_messages(input_val)
|
||||||
return messages
|
return messages
|
||||||
|
@ -483,19 +483,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
self,
|
self,
|
||||||
input: dict[str, Any],
|
value: dict[str, Any],
|
||||||
run_manager: CallbackManagerForChainRun,
|
run_manager: CallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if not isinstance(input, dict):
|
if not isinstance(value, dict):
|
||||||
msg = "The input to RunnablePassthrough.assign() must be a dict."
|
msg = "The input to RunnablePassthrough.assign() must be a dict."
|
||||||
raise ValueError(msg) # noqa: TRY004
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
|
||||||
return {
|
return {
|
||||||
**input,
|
**value,
|
||||||
**self.mapper.invoke(
|
**self.mapper.invoke(
|
||||||
input,
|
value,
|
||||||
patch_config(config, callbacks=run_manager.get_child()),
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
),
|
),
|
||||||
@ -512,19 +512,19 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
|
|
||||||
async def _ainvoke(
|
async def _ainvoke(
|
||||||
self,
|
self,
|
||||||
input: dict[str, Any],
|
value: dict[str, Any],
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if not isinstance(input, dict):
|
if not isinstance(value, dict):
|
||||||
msg = "The input to RunnablePassthrough.assign() must be a dict."
|
msg = "The input to RunnablePassthrough.assign() must be a dict."
|
||||||
raise ValueError(msg) # noqa: TRY004
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
|
||||||
return {
|
return {
|
||||||
**input,
|
**value,
|
||||||
**await self.mapper.ainvoke(
|
**await self.mapper.ainvoke(
|
||||||
input,
|
value,
|
||||||
patch_config(config, callbacks=run_manager.get_child()),
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
),
|
),
|
||||||
@ -541,7 +541,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
|
|
||||||
def _transform(
|
def _transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[dict[str, Any]],
|
values: Iterator[dict[str, Any]],
|
||||||
run_manager: CallbackManagerForChainRun,
|
run_manager: CallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -549,7 +549,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
# collect mapper keys
|
# collect mapper keys
|
||||||
mapper_keys = set(self.mapper.steps__.keys())
|
mapper_keys = set(self.mapper.steps__.keys())
|
||||||
# create two streams, one for the map and one for the passthrough
|
# create two streams, one for the map and one for the passthrough
|
||||||
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
|
for_passthrough, for_map = safetee(values, 2, lock=threading.Lock())
|
||||||
|
|
||||||
# create map output stream
|
# create map output stream
|
||||||
map_output = self.mapper.transform(
|
map_output = self.mapper.transform(
|
||||||
@ -598,7 +598,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
|
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[dict[str, Any]],
|
values: AsyncIterator[dict[str, Any]],
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -606,7 +606,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
# collect mapper keys
|
# collect mapper keys
|
||||||
mapper_keys = set(self.mapper.steps__.keys())
|
mapper_keys = set(self.mapper.steps__.keys())
|
||||||
# create two streams, one for the map and one for the passthrough
|
# create two streams, one for the map and one for the passthrough
|
||||||
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
|
for_passthrough, for_map = atee(values, 2, lock=asyncio.Lock())
|
||||||
# create map output stream
|
# create map output stream
|
||||||
map_output = self.mapper.atransform(
|
map_output = self.mapper.atransform(
|
||||||
for_map,
|
for_map,
|
||||||
@ -731,23 +731,23 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
)
|
)
|
||||||
return super().get_name(suffix, name=name)
|
return super().get_name(suffix, name=name)
|
||||||
|
|
||||||
def _pick(self, input: dict[str, Any]) -> Any:
|
def _pick(self, value: dict[str, Any]) -> Any:
|
||||||
if not isinstance(input, dict):
|
if not isinstance(value, dict):
|
||||||
msg = "The input to RunnablePassthrough.assign() must be a dict."
|
msg = "The input to RunnablePassthrough.assign() must be a dict."
|
||||||
raise ValueError(msg) # noqa: TRY004
|
raise ValueError(msg) # noqa: TRY004
|
||||||
|
|
||||||
if isinstance(self.keys, str):
|
if isinstance(self.keys, str):
|
||||||
return input.get(self.keys)
|
return value.get(self.keys)
|
||||||
picked = {k: input.get(k) for k in self.keys if k in input}
|
picked = {k: value.get(k) for k in self.keys if k in value}
|
||||||
if picked:
|
if picked:
|
||||||
return AddableDict(picked)
|
return AddableDict(picked)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
self,
|
self,
|
||||||
input: dict[str, Any],
|
value: dict[str, Any],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return self._pick(input)
|
return self._pick(value)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
@ -760,9 +760,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
|
|
||||||
async def _ainvoke(
|
async def _ainvoke(
|
||||||
self,
|
self,
|
||||||
input: dict[str, Any],
|
value: dict[str, Any],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return self._pick(input)
|
return self._pick(value)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
@ -775,9 +775,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
|
|
||||||
def _transform(
|
def _transform(
|
||||||
self,
|
self,
|
||||||
input: Iterator[dict[str, Any]],
|
chunks: Iterator[dict[str, Any]],
|
||||||
) -> Iterator[dict[str, Any]]:
|
) -> Iterator[dict[str, Any]]:
|
||||||
for chunk in input:
|
for chunk in chunks:
|
||||||
picked = self._pick(chunk)
|
picked = self._pick(chunk)
|
||||||
if picked is not None:
|
if picked is not None:
|
||||||
yield picked
|
yield picked
|
||||||
@ -795,9 +795,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
|||||||
|
|
||||||
async def _atransform(
|
async def _atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[dict[str, Any]],
|
chunks: AsyncIterator[dict[str, Any]],
|
||||||
) -> AsyncIterator[dict[str, Any]]:
|
) -> AsyncIterator[dict[str, Any]]:
|
||||||
async for chunk in input:
|
async for chunk in chunks:
|
||||||
picked = self._pick(chunk)
|
picked = self._pick(chunk)
|
||||||
if picked is not None:
|
if picked is not None:
|
||||||
yield picked
|
yield picked
|
||||||
|
@ -178,7 +178,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input_: Input,
|
||||||
run_manager: "CallbackManagerForChainRun",
|
run_manager: "CallbackManagerForChainRun",
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -186,7 +186,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
for attempt in self._sync_retrying(reraise=True):
|
for attempt in self._sync_retrying(reraise=True):
|
||||||
with attempt:
|
with attempt:
|
||||||
result = super().invoke(
|
result = super().invoke(
|
||||||
input,
|
input_,
|
||||||
self._patch_config(config, run_manager, attempt.retry_state),
|
self._patch_config(config, run_manager, attempt.retry_state),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@ -202,7 +202,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
|
|
||||||
async def _ainvoke(
|
async def _ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input_: Input,
|
||||||
run_manager: "AsyncCallbackManagerForChainRun",
|
run_manager: "AsyncCallbackManagerForChainRun",
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -210,7 +210,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
async for attempt in self._async_retrying(reraise=True):
|
async for attempt in self._async_retrying(reraise=True):
|
||||||
with attempt:
|
with attempt:
|
||||||
result = await super().ainvoke(
|
result = await super().ainvoke(
|
||||||
input,
|
input_,
|
||||||
self._patch_config(config, run_manager, attempt.retry_state),
|
self._patch_config(config, run_manager, attempt.retry_state),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
@ -148,22 +148,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
|||||||
if not inputs:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
keys = [input["key"] for input in inputs]
|
keys = [input_["key"] for input_ in inputs]
|
||||||
actual_inputs = [input["input"] for input in inputs]
|
actual_inputs = [input_["input"] for input_ in inputs]
|
||||||
if any(key not in self.runnables for key in keys):
|
if any(key not in self.runnables for key in keys):
|
||||||
msg = "One or more keys do not have a corresponding runnable"
|
msg = "One or more keys do not have a corresponding runnable"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
runnable: Runnable, input: Input, config: RunnableConfig
|
runnable: Runnable, input_: Input, config: RunnableConfig
|
||||||
) -> Union[Output, Exception]:
|
) -> Union[Output, Exception]:
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
return runnable.invoke(input, config, **kwargs)
|
return runnable.invoke(input_, config, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return e
|
return e
|
||||||
else:
|
else:
|
||||||
return runnable.invoke(input, config, **kwargs)
|
return runnable.invoke(input_, config, **kwargs)
|
||||||
|
|
||||||
runnables = [self.runnables[key] for key in keys]
|
runnables = [self.runnables[key] for key in keys]
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
@ -185,22 +185,22 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
|||||||
if not inputs:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
keys = [input["key"] for input in inputs]
|
keys = [input_["key"] for input_ in inputs]
|
||||||
actual_inputs = [input["input"] for input in inputs]
|
actual_inputs = [input_["input"] for input_ in inputs]
|
||||||
if any(key not in self.runnables for key in keys):
|
if any(key not in self.runnables for key in keys):
|
||||||
msg = "One or more keys do not have a corresponding runnable"
|
msg = "One or more keys do not have a corresponding runnable"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
runnable: Runnable, input: Input, config: RunnableConfig
|
runnable: Runnable, input_: Input, config: RunnableConfig
|
||||||
) -> Union[Output, Exception]:
|
) -> Union[Output, Exception]:
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
return await runnable.ainvoke(input, config, **kwargs)
|
return await runnable.ainvoke(input_, config, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return e
|
return e
|
||||||
else:
|
else:
|
||||||
return await runnable.ainvoke(input, config, **kwargs)
|
return await runnable.ainvoke(input_, config, **kwargs)
|
||||||
|
|
||||||
runnables = [self.runnables[key] for key in keys]
|
runnables = [self.runnables[key] for key in keys]
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
|
@ -75,7 +75,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
|
|||||||
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
|
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
|
||||||
|
|
||||||
|
|
||||||
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
def accepts_run_manager(callable: Callable[..., Any]) -> bool: # noqa: A002
|
||||||
"""Check if a callable accepts a run_manager argument.
|
"""Check if a callable accepts a run_manager argument.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -90,7 +90,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def accepts_config(callable: Callable[..., Any]) -> bool:
|
def accepts_config(callable: Callable[..., Any]) -> bool: # noqa: A002
|
||||||
"""Check if a callable accepts a config argument.
|
"""Check if a callable accepts a config argument.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -105,7 +105,7 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def accepts_context(callable: Callable[..., Any]) -> bool:
|
def accepts_context(callable: Callable[..., Any]) -> bool: # noqa: A002
|
||||||
"""Check if a callable accepts a context argument.
|
"""Check if a callable accepts a context argument.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -691,7 +691,7 @@ def get_unique_config_specs(
|
|||||||
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
|
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
|
||||||
)
|
)
|
||||||
unique: list[ConfigurableFieldSpec] = []
|
unique: list[ConfigurableFieldSpec] = []
|
||||||
for id, dupes in grouped:
|
for spec_id, dupes in grouped:
|
||||||
first = next(dupes)
|
first = next(dupes)
|
||||||
others = list(dupes)
|
others = list(dupes)
|
||||||
if len(others) == 0 or all(o == first for o in others):
|
if len(others) == 0 or all(o == first for o in others):
|
||||||
@ -699,7 +699,7 @@ def get_unique_config_specs(
|
|||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
"RunnableSequence contains conflicting config specs"
|
"RunnableSequence contains conflicting config specs"
|
||||||
f"for {id}: {[first] + others}"
|
f"for {spec_id}: {[first] + others}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return unique
|
return unique
|
||||||
|
@ -184,7 +184,7 @@ class StructuredQuery(Expr):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
filter: Optional[FilterDirective],
|
filter: Optional[FilterDirective], # noqa: A002
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -943,17 +943,17 @@ def _handle_tool_error(
|
|||||||
|
|
||||||
|
|
||||||
def _prep_run_args(
|
def _prep_run_args(
|
||||||
input: Union[str, dict, ToolCall],
|
value: Union[str, dict, ToolCall],
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> tuple[Union[str, dict], dict]:
|
) -> tuple[Union[str, dict], dict]:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
if _is_tool_call(input):
|
if _is_tool_call(value):
|
||||||
tool_call_id: Optional[str] = cast("ToolCall", input)["id"]
|
tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
|
||||||
tool_input: Union[str, dict] = cast("ToolCall", input)["args"].copy()
|
tool_input: Union[str, dict] = cast("ToolCall", value)["args"].copy()
|
||||||
else:
|
else:
|
||||||
tool_call_id = None
|
tool_call_id = None
|
||||||
tool_input = cast("Union[str, dict]", input)
|
tool_input = cast("Union[str, dict]", value)
|
||||||
return (
|
return (
|
||||||
tool_input,
|
tool_input,
|
||||||
dict(
|
dict(
|
||||||
|
@ -740,7 +740,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
|||||||
|
|
||||||
async def _astream_events_implementation_v1(
|
async def _astream_events_implementation_v1(
|
||||||
runnable: Runnable[Input, Output],
|
runnable: Runnable[Input, Output],
|
||||||
input: Any,
|
value: Any,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
*,
|
*,
|
||||||
include_names: Optional[Sequence[str]] = None,
|
include_names: Optional[Sequence[str]] = None,
|
||||||
@ -789,7 +789,7 @@ async def _astream_events_implementation_v1(
|
|||||||
|
|
||||||
async for log in _astream_log_implementation(
|
async for log in _astream_log_implementation(
|
||||||
runnable,
|
runnable,
|
||||||
input,
|
value,
|
||||||
config=config,
|
config=config,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
diff=True,
|
diff=True,
|
||||||
@ -810,7 +810,7 @@ async def _astream_events_implementation_v1(
|
|||||||
tags=root_tags,
|
tags=root_tags,
|
||||||
metadata=root_metadata,
|
metadata=root_metadata,
|
||||||
data={
|
data={
|
||||||
"input": input,
|
"input": value,
|
||||||
},
|
},
|
||||||
parent_ids=[], # Not supported in v1
|
parent_ids=[], # Not supported in v1
|
||||||
)
|
)
|
||||||
@ -924,7 +924,7 @@ async def _astream_events_implementation_v1(
|
|||||||
|
|
||||||
async def _astream_events_implementation_v2(
|
async def _astream_events_implementation_v2(
|
||||||
runnable: Runnable[Input, Output],
|
runnable: Runnable[Input, Output],
|
||||||
input: Any,
|
value: Any,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
*,
|
*,
|
||||||
include_names: Optional[Sequence[str]] = None,
|
include_names: Optional[Sequence[str]] = None,
|
||||||
@ -972,7 +972,7 @@ async def _astream_events_implementation_v2(
|
|||||||
async def consume_astream() -> None:
|
async def consume_astream() -> None:
|
||||||
try:
|
try:
|
||||||
# if astream also calls tap_output_aiter this will be a no-op
|
# if astream also calls tap_output_aiter this will be a no-op
|
||||||
async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
|
async with aclosing(runnable.astream(value, config, **kwargs)) as stream:
|
||||||
async for _ in event_streamer.tap_output_aiter(run_id, stream):
|
async for _ in event_streamer.tap_output_aiter(run_id, stream):
|
||||||
# All the content will be picked up
|
# All the content will be picked up
|
||||||
pass
|
pass
|
||||||
@ -993,7 +993,7 @@ async def _astream_events_implementation_v2(
|
|||||||
# chain are not available until the entire input is consumed.
|
# chain are not available until the entire input is consumed.
|
||||||
# As a temporary solution, we'll modify the input to be the input
|
# As a temporary solution, we'll modify the input to be the input
|
||||||
# that was passed into the chain.
|
# that was passed into the chain.
|
||||||
event["data"]["input"] = input
|
event["data"]["input"] = value
|
||||||
first_event_run_id = event["run_id"]
|
first_event_run_id = event["run_id"]
|
||||||
yield event
|
yield event
|
||||||
continue
|
continue
|
||||||
|
@ -580,7 +580,7 @@ def _get_standardized_outputs(
|
|||||||
@overload
|
@overload
|
||||||
def _astream_log_implementation(
|
def _astream_log_implementation(
|
||||||
runnable: Runnable[Input, Output],
|
runnable: Runnable[Input, Output],
|
||||||
input: Any,
|
value: Any,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
*,
|
*,
|
||||||
stream: LogStreamCallbackHandler,
|
stream: LogStreamCallbackHandler,
|
||||||
@ -593,7 +593,7 @@ def _astream_log_implementation(
|
|||||||
@overload
|
@overload
|
||||||
def _astream_log_implementation(
|
def _astream_log_implementation(
|
||||||
runnable: Runnable[Input, Output],
|
runnable: Runnable[Input, Output],
|
||||||
input: Any,
|
value: Any,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
*,
|
*,
|
||||||
stream: LogStreamCallbackHandler,
|
stream: LogStreamCallbackHandler,
|
||||||
@ -605,7 +605,7 @@ def _astream_log_implementation(
|
|||||||
|
|
||||||
async def _astream_log_implementation(
|
async def _astream_log_implementation(
|
||||||
runnable: Runnable[Input, Output],
|
runnable: Runnable[Input, Output],
|
||||||
input: Any,
|
value: Any,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
*,
|
*,
|
||||||
stream: LogStreamCallbackHandler,
|
stream: LogStreamCallbackHandler,
|
||||||
@ -651,7 +651,7 @@ async def _astream_log_implementation(
|
|||||||
prev_final_output: Optional[Output] = None
|
prev_final_output: Optional[Output] = None
|
||||||
final_output: Optional[Output] = None
|
final_output: Optional[Output] = None
|
||||||
|
|
||||||
async for chunk in runnable.astream(input, config, **kwargs):
|
async for chunk in runnable.astream(value, config, **kwargs):
|
||||||
prev_final_output = final_output
|
prev_final_output = final_output
|
||||||
if final_output is None:
|
if final_output is None:
|
||||||
final_output = chunk
|
final_output = chunk
|
||||||
|
@ -596,7 +596,7 @@ def convert_to_json_schema(
|
|||||||
|
|
||||||
@beta()
|
@beta()
|
||||||
def tool_example_to_messages(
|
def tool_example_to_messages(
|
||||||
input: str,
|
input: str, # noqa: A002
|
||||||
tool_calls: list[BaseModel],
|
tool_calls: list[BaseModel],
|
||||||
tool_outputs: Optional[list[str]] = None,
|
tool_outputs: Optional[list[str]] = None,
|
||||||
*,
|
*,
|
||||||
|
@ -363,7 +363,7 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
self,
|
self,
|
||||||
embedding: list[float],
|
embedding: list[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
filter: Optional[Callable[[Document], bool]] = None,
|
filter: Optional[Callable[[Document], bool]] = None, # noqa: A002
|
||||||
) -> list[tuple[Document, float, list[float]]]:
|
) -> list[tuple[Document, float, list[float]]]:
|
||||||
# get all docs with fixed order in list
|
# get all docs with fixed order in list
|
||||||
docs = list(self.store.values())
|
docs = list(self.store.values())
|
||||||
@ -402,7 +402,7 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
self,
|
self,
|
||||||
embedding: list[float],
|
embedding: list[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
filter: Optional[Callable[[Document], bool]] = None,
|
filter: Optional[Callable[[Document], bool]] = None, # noqa: A002
|
||||||
**_kwargs: Any,
|
**_kwargs: Any,
|
||||||
) -> list[tuple[Document, float]]:
|
) -> list[tuple[Document, float]]:
|
||||||
"""Search for the most similar documents to the given embedding.
|
"""Search for the most similar documents to the given embedding.
|
||||||
|
@ -100,7 +100,6 @@ ignore = [
|
|||||||
"UP007", # Doesn't play well with Pydantic in Python 3.9
|
"UP007", # Doesn't play well with Pydantic in Python 3.9
|
||||||
|
|
||||||
# TODO rules
|
# TODO rules
|
||||||
"A",
|
|
||||||
"ANN401",
|
"ANN401",
|
||||||
"BLE",
|
"BLE",
|
||||||
"ERA",
|
"ERA",
|
||||||
|
@ -869,9 +869,9 @@ def create_image_data() -> str:
|
|||||||
return "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q==" # noqa: E501
|
return "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q==" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
def create_base64_image(format: str = "jpeg") -> str:
|
def create_base64_image(image_format: str = "jpeg") -> str:
|
||||||
data = create_image_data()
|
data = create_image_data()
|
||||||
return f"data:image/{format};base64,{data}"
|
return f"data:image/{image_format};base64,{data}"
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_openai_messages_string() -> None:
|
def test_convert_to_openai_messages_string() -> None:
|
||||||
|
@ -639,7 +639,7 @@ def test_parse_with_different_pydantic_1_proper() -> None:
|
|||||||
|
|
||||||
def test_max_tokens_error(caplog: Any) -> None:
|
def test_max_tokens_error(caplog: Any) -> None:
|
||||||
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
||||||
input = AIMessage(
|
message = AIMessage(
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
{
|
{
|
||||||
@ -651,7 +651,7 @@ def test_max_tokens_error(caplog: Any) -> None:
|
|||||||
response_metadata={"stop_reason": "max_tokens"},
|
response_metadata={"stop_reason": "max_tokens"},
|
||||||
)
|
)
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
_ = parser.invoke(input)
|
_ = parser.invoke(message)
|
||||||
assert any(
|
assert any(
|
||||||
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
|
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
|
||||||
for record, msg in zip(caplog.records, caplog.messages)
|
for record, msg in zip(caplog.records, caplog.messages)
|
||||||
|
@ -15,11 +15,11 @@ EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute()
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def change_directory(dir: Path) -> Iterator:
|
def change_directory(dir_path: Path) -> Iterator:
|
||||||
"""Change the working directory to the right folder."""
|
"""Change the working directory to the right folder."""
|
||||||
origin = Path().absolute()
|
origin = Path().absolute()
|
||||||
try:
|
try:
|
||||||
os.chdir(dir)
|
os.chdir(dir_path)
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
os.chdir(origin)
|
os.chdir(origin)
|
||||||
|
@ -220,8 +220,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
|||||||
str_parser = StrOutputParser()
|
str_parser = StrOutputParser()
|
||||||
xml_parser = XMLOutputParser()
|
xml_parser = XMLOutputParser()
|
||||||
|
|
||||||
def conditional_str_parser(input: str) -> Runnable:
|
def conditional_str_parser(value: str) -> Runnable:
|
||||||
if input == "a":
|
if value == "a":
|
||||||
return str_parser
|
return str_parser
|
||||||
return xml_parser
|
return xml_parser
|
||||||
|
|
||||||
|
@ -111,9 +111,9 @@ async def test_input_messages_async() -> None:
|
|||||||
|
|
||||||
def test_input_dict() -> None:
|
def test_input_dict() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: "you said: "
|
lambda params: "you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
|
str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
get_session_history = _get_get_session_history()
|
get_session_history = _get_get_session_history()
|
||||||
@ -131,9 +131,9 @@ def test_input_dict() -> None:
|
|||||||
|
|
||||||
async def test_input_dict_async() -> None:
|
async def test_input_dict_async() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: "you said: "
|
lambda params: "you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
|
str(m.content) for m in params["messages"] if isinstance(m, HumanMessage)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
get_session_history = _get_get_session_history()
|
get_session_history = _get_get_session_history()
|
||||||
@ -153,10 +153,10 @@ async def test_input_dict_async() -> None:
|
|||||||
|
|
||||||
def test_input_dict_with_history_key() -> None:
|
def test_input_dict_with_history_key() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: "you said: "
|
lambda params: "you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
|
[str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
get_session_history = _get_get_session_history()
|
get_session_history = _get_get_session_history()
|
||||||
@ -175,10 +175,10 @@ def test_input_dict_with_history_key() -> None:
|
|||||||
|
|
||||||
async def test_input_dict_with_history_key_async() -> None:
|
async def test_input_dict_with_history_key_async() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: "you said: "
|
lambda params: "you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
|
[str(m.content) for m in params["history"] if isinstance(m, HumanMessage)]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
get_session_history = _get_get_session_history()
|
get_session_history = _get_get_session_history()
|
||||||
@ -197,15 +197,15 @@ async def test_input_dict_with_history_key_async() -> None:
|
|||||||
|
|
||||||
def test_output_message() -> None:
|
def test_output_message() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: AIMessage(
|
lambda params: AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[
|
[
|
||||||
str(m.content)
|
str(m.content)
|
||||||
for m in input["history"]
|
for m in params["history"]
|
||||||
if isinstance(m, HumanMessage)
|
if isinstance(m, HumanMessage)
|
||||||
]
|
]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -225,15 +225,15 @@ def test_output_message() -> None:
|
|||||||
|
|
||||||
async def test_output_message_async() -> None:
|
async def test_output_message_async() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: AIMessage(
|
lambda params: AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[
|
[
|
||||||
str(m.content)
|
str(m.content)
|
||||||
for m in input["history"]
|
for m in params["history"]
|
||||||
if isinstance(m, HumanMessage)
|
if isinstance(m, HumanMessage)
|
||||||
]
|
]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -302,16 +302,16 @@ async def test_input_messages_output_message_async() -> None:
|
|||||||
|
|
||||||
def test_output_messages() -> None:
|
def test_output_messages() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: [
|
lambda params: [
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[
|
[
|
||||||
str(m.content)
|
str(m.content)
|
||||||
for m in input["history"]
|
for m in params["history"]
|
||||||
if isinstance(m, HumanMessage)
|
if isinstance(m, HumanMessage)
|
||||||
]
|
]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -332,16 +332,16 @@ def test_output_messages() -> None:
|
|||||||
|
|
||||||
async def test_output_messages_async() -> None:
|
async def test_output_messages_async() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: [
|
lambda params: [
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[
|
[
|
||||||
str(m.content)
|
str(m.content)
|
||||||
for m in input["history"]
|
for m in params["history"]
|
||||||
if isinstance(m, HumanMessage)
|
if isinstance(m, HumanMessage)
|
||||||
]
|
]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -362,17 +362,17 @@ async def test_output_messages_async() -> None:
|
|||||||
|
|
||||||
def test_output_dict() -> None:
|
def test_output_dict() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: {
|
lambda params: {
|
||||||
"output": [
|
"output": [
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[
|
[
|
||||||
str(m.content)
|
str(m.content)
|
||||||
for m in input["history"]
|
for m in params["history"]
|
||||||
if isinstance(m, HumanMessage)
|
if isinstance(m, HumanMessage)
|
||||||
]
|
]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -395,17 +395,17 @@ def test_output_dict() -> None:
|
|||||||
|
|
||||||
async def test_output_dict_async() -> None:
|
async def test_output_dict_async() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: {
|
lambda params: {
|
||||||
"output": [
|
"output": [
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[
|
[
|
||||||
str(m.content)
|
str(m.content)
|
||||||
for m in input["history"]
|
for m in params["history"]
|
||||||
if isinstance(m, HumanMessage)
|
if isinstance(m, HumanMessage)
|
||||||
]
|
]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -431,17 +431,17 @@ def test_get_input_schema_input_dict() -> None:
|
|||||||
input: Union[str, BaseMessage, Sequence[BaseMessage]]
|
input: Union[str, BaseMessage, Sequence[BaseMessage]]
|
||||||
|
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: {
|
lambda params: {
|
||||||
"output": [
|
"output": [
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[
|
[
|
||||||
str(m.content)
|
str(m.content)
|
||||||
for m in input["history"]
|
for m in params["history"]
|
||||||
if isinstance(m, HumanMessage)
|
if isinstance(m, HumanMessage)
|
||||||
]
|
]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -463,17 +463,17 @@ def test_get_input_schema_input_dict() -> None:
|
|||||||
def test_get_output_schema() -> None:
|
def test_get_output_schema() -> None:
|
||||||
"""Test get output schema."""
|
"""Test get output schema."""
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: {
|
lambda params: {
|
||||||
"output": [
|
"output": [
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
[
|
[
|
||||||
str(m.content)
|
str(m.content)
|
||||||
for m in input["history"]
|
for m in params["history"]
|
||||||
if isinstance(m, HumanMessage)
|
if isinstance(m, HumanMessage)
|
||||||
]
|
]
|
||||||
+ [input["input"]]
|
+ [params["input"]]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -531,8 +531,8 @@ def test_get_input_schema_input_messages() -> None:
|
|||||||
def test_using_custom_config_specs() -> None:
|
def test_using_custom_config_specs() -> None:
|
||||||
"""Test that we can configure which keys should be passed to the session factory."""
|
"""Test that we can configure which keys should be passed to the session factory."""
|
||||||
|
|
||||||
def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]:
|
def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
|
||||||
messages = input["messages"]
|
messages = params["messages"]
|
||||||
return [
|
return [
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
@ -644,8 +644,8 @@ def test_using_custom_config_specs() -> None:
|
|||||||
async def test_using_custom_config_specs_async() -> None:
|
async def test_using_custom_config_specs_async() -> None:
|
||||||
"""Test that we can configure which keys should be passed to the session factory."""
|
"""Test that we can configure which keys should be passed to the session factory."""
|
||||||
|
|
||||||
def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]:
|
def _fake_llm(params: dict[str, Any]) -> list[BaseMessage]:
|
||||||
messages = input["messages"]
|
messages = params["messages"]
|
||||||
return [
|
return [
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
@ -757,12 +757,12 @@ async def test_using_custom_config_specs_async() -> None:
|
|||||||
def test_ignore_session_id() -> None:
|
def test_ignore_session_id() -> None:
|
||||||
"""Test without config."""
|
"""Test without config."""
|
||||||
|
|
||||||
def _fake_llm(input: list[BaseMessage]) -> list[BaseMessage]:
|
def _fake_llm(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||||
return [
|
return [
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="you said: "
|
content="you said: "
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
str(m.content) for m in input if isinstance(m, HumanMessage)
|
str(m.content) for m in messages if isinstance(m, HumanMessage)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
@ -564,8 +564,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
"required": ["bye", "hello"],
|
"required": ["bye", "hello"],
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_value(input): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
def get_value(value): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
||||||
return input["variable_name"]
|
return value["variable_name"]
|
||||||
|
|
||||||
assert RunnableLambda(get_value).get_input_jsonschema() == {
|
assert RunnableLambda(get_value).get_input_jsonschema() == {
|
||||||
"title": "get_value_input",
|
"title": "get_value_input",
|
||||||
@ -574,8 +574,8 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
"required": ["variable_name"],
|
"required": ["variable_name"],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def aget_value(input): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
async def aget_value(value): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
||||||
return (input["variable_name"], input.get("another"))
|
return (value["variable_name"], value.get("another"))
|
||||||
|
|
||||||
assert RunnableLambda(aget_value).get_input_jsonschema() == {
|
assert RunnableLambda(aget_value).get_input_jsonschema() == {
|
||||||
"title": "aget_value_input",
|
"title": "aget_value_input",
|
||||||
@ -587,11 +587,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
"required": ["another", "variable_name"],
|
"required": ["another", "variable_name"],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def aget_values(input): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
async def aget_values(value): # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
|
||||||
return {
|
return {
|
||||||
"hello": input["variable_name"],
|
"hello": value["variable_name"],
|
||||||
"bye": input["variable_name"],
|
"bye": value["variable_name"],
|
||||||
"byebye": input["yo"],
|
"byebye": value["yo"],
|
||||||
}
|
}
|
||||||
|
|
||||||
assert RunnableLambda(aget_values).get_input_jsonschema() == {
|
assert RunnableLambda(aget_values).get_input_jsonschema() == {
|
||||||
@ -613,11 +613,11 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
bye: str
|
bye: str
|
||||||
byebye: int
|
byebye: int
|
||||||
|
|
||||||
async def aget_values_typed(input: InputType) -> OutputType:
|
async def aget_values_typed(value: InputType) -> OutputType:
|
||||||
return {
|
return {
|
||||||
"hello": input["variable_name"],
|
"hello": value["variable_name"],
|
||||||
"bye": input["variable_name"],
|
"bye": value["variable_name"],
|
||||||
"byebye": input["yo"],
|
"byebye": value["yo"],
|
||||||
}
|
}
|
||||||
|
|
||||||
assert _normalize_schema(
|
assert _normalize_schema(
|
||||||
@ -2592,8 +2592,8 @@ async def test_prompt_with_llm_and_async_lambda(
|
|||||||
)
|
)
|
||||||
llm = FakeListLLM(responses=["foo", "bar"])
|
llm = FakeListLLM(responses=["foo", "bar"])
|
||||||
|
|
||||||
async def passthrough(input: Any) -> Any:
|
async def passthrough(value: Any) -> Any:
|
||||||
return input
|
return value
|
||||||
|
|
||||||
chain = prompt | llm | passthrough
|
chain = prompt | llm | passthrough
|
||||||
|
|
||||||
@ -2946,12 +2946,12 @@ def test_higher_order_lambda_runnable(
|
|||||||
input={"question": lambda x: x["question"]},
|
input={"question": lambda x: x["question"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
def router(input: dict[str, Any]) -> Runnable:
|
def router(params: dict[str, Any]) -> Runnable:
|
||||||
if input["key"] == "math":
|
if params["key"] == "math":
|
||||||
return itemgetter("input") | math_chain
|
return itemgetter("input") | math_chain
|
||||||
if input["key"] == "english":
|
if params["key"] == "english":
|
||||||
return itemgetter("input") | english_chain
|
return itemgetter("input") | english_chain
|
||||||
msg = f"Unknown key: {input['key']}"
|
msg = f"Unknown key: {params['key']}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
chain: Runnable = input_map | router
|
chain: Runnable = input_map | router
|
||||||
@ -3002,12 +3002,12 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
|
|||||||
input={"question": lambda x: x["question"]},
|
input={"question": lambda x: x["question"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
def router(input: dict[str, Any]) -> Runnable:
|
def router(value: dict[str, Any]) -> Runnable:
|
||||||
if input["key"] == "math":
|
if value["key"] == "math":
|
||||||
return itemgetter("input") | math_chain
|
return itemgetter("input") | math_chain
|
||||||
if input["key"] == "english":
|
if value["key"] == "english":
|
||||||
return itemgetter("input") | english_chain
|
return itemgetter("input") | english_chain
|
||||||
msg = f"Unknown key: {input['key']}"
|
msg = f"Unknown key: {value['key']}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
chain: Runnable = input_map | router
|
chain: Runnable = input_map | router
|
||||||
@ -3024,12 +3024,12 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
|
|||||||
assert result2 == ["4", "2"]
|
assert result2 == ["4", "2"]
|
||||||
|
|
||||||
# Test ainvoke
|
# Test ainvoke
|
||||||
async def arouter(input: dict[str, Any]) -> Runnable:
|
async def arouter(params: dict[str, Any]) -> Runnable:
|
||||||
if input["key"] == "math":
|
if params["key"] == "math":
|
||||||
return itemgetter("input") | math_chain
|
return itemgetter("input") | math_chain
|
||||||
if input["key"] == "english":
|
if params["key"] == "english":
|
||||||
return itemgetter("input") | english_chain
|
return itemgetter("input") | english_chain
|
||||||
msg = f"Unknown key: {input['key']}"
|
msg = f"Unknown key: {params['key']}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
achain: Runnable = input_map | arouter
|
achain: Runnable = input_map | arouter
|
||||||
@ -4125,6 +4125,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
def __init__(self, fail_starts_with: str) -> None:
|
def __init__(self, fail_starts_with: str) -> None:
|
||||||
self.fail_starts_with = fail_starts_with
|
self.fail_starts_with = fail_starts_with
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -4135,15 +4136,15 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
inputs: list[str],
|
inputs: list[str],
|
||||||
) -> list:
|
) -> list:
|
||||||
outputs: list[Any] = []
|
outputs: list[Any] = []
|
||||||
for input in inputs:
|
for value in inputs:
|
||||||
if input.startswith(self.fail_starts_with):
|
if value.startswith(self.fail_starts_with):
|
||||||
outputs.append(
|
outputs.append(
|
||||||
ValueError(
|
ValueError(
|
||||||
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
|
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs.append(input + "a")
|
outputs.append(value + "a")
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def batch(
|
def batch(
|
||||||
@ -4264,6 +4265,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
def __init__(self, fail_starts_with: str) -> None:
|
def __init__(self, fail_starts_with: str) -> None:
|
||||||
self.fail_starts_with = fail_starts_with
|
self.fail_starts_with = fail_starts_with
|
||||||
|
|
||||||
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -4274,15 +4276,15 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
inputs: list[str],
|
inputs: list[str],
|
||||||
) -> list:
|
) -> list:
|
||||||
outputs: list[Any] = []
|
outputs: list[Any] = []
|
||||||
for input in inputs:
|
for value in inputs:
|
||||||
if input.startswith(self.fail_starts_with):
|
if value.startswith(self.fail_starts_with):
|
||||||
outputs.append(
|
outputs.append(
|
||||||
ValueError(
|
ValueError(
|
||||||
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {input}"
|
f"ControlledExceptionRunnable({self.fail_starts_with}) fail for {value}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs.append(input + "a")
|
outputs.append(value + "a")
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
async def abatch(
|
async def abatch(
|
||||||
@ -5006,10 +5008,10 @@ def test_runnable_iter_context_config() -> None:
|
|||||||
fake = RunnableLambda(len)
|
fake = RunnableLambda(len)
|
||||||
|
|
||||||
@chain
|
@chain
|
||||||
def gen(input: str) -> Iterator[int]:
|
def gen(value: str) -> Iterator[int]:
|
||||||
yield fake.invoke(input)
|
yield fake.invoke(value)
|
||||||
yield fake.invoke(input * 2)
|
yield fake.invoke(value * 2)
|
||||||
yield fake.invoke(input * 3)
|
yield fake.invoke(value * 3)
|
||||||
|
|
||||||
assert gen.get_input_jsonschema() == {
|
assert gen.get_input_jsonschema() == {
|
||||||
"title": "gen_input",
|
"title": "gen_input",
|
||||||
@ -5064,10 +5066,10 @@ async def test_runnable_iter_context_config_async() -> None:
|
|||||||
fake = RunnableLambda(len)
|
fake = RunnableLambda(len)
|
||||||
|
|
||||||
@chain
|
@chain
|
||||||
async def agen(input: str) -> AsyncIterator[int]:
|
async def agen(value: str) -> AsyncIterator[int]:
|
||||||
yield await fake.ainvoke(input)
|
yield await fake.ainvoke(value)
|
||||||
yield await fake.ainvoke(input * 2)
|
yield await fake.ainvoke(value * 2)
|
||||||
yield await fake.ainvoke(input * 3)
|
yield await fake.ainvoke(value * 3)
|
||||||
|
|
||||||
assert agen.get_input_jsonschema() == {
|
assert agen.get_input_jsonschema() == {
|
||||||
"title": "agen_input",
|
"title": "agen_input",
|
||||||
@ -5130,10 +5132,10 @@ def test_runnable_lambda_context_config() -> None:
|
|||||||
fake = RunnableLambda(len)
|
fake = RunnableLambda(len)
|
||||||
|
|
||||||
@chain
|
@chain
|
||||||
def fun(input: str) -> int:
|
def fun(value: str) -> int:
|
||||||
output = fake.invoke(input)
|
output = fake.invoke(value)
|
||||||
output += fake.invoke(input * 2)
|
output += fake.invoke(value * 2)
|
||||||
output += fake.invoke(input * 3)
|
output += fake.invoke(value * 3)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
assert fun.get_input_jsonschema() == {"title": "fun_input", "type": "string"}
|
assert fun.get_input_jsonschema() == {"title": "fun_input", "type": "string"}
|
||||||
@ -5186,10 +5188,10 @@ async def test_runnable_lambda_context_config_async() -> None:
|
|||||||
fake = RunnableLambda(len)
|
fake = RunnableLambda(len)
|
||||||
|
|
||||||
@chain
|
@chain
|
||||||
async def afun(input: str) -> int:
|
async def afun(value: str) -> int:
|
||||||
output = await fake.ainvoke(input)
|
output = await fake.ainvoke(value)
|
||||||
output += await fake.ainvoke(input * 2)
|
output += await fake.ainvoke(value * 2)
|
||||||
output += await fake.ainvoke(input * 3)
|
output += await fake.ainvoke(value * 3)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
assert afun.get_input_jsonschema() == {"title": "afun_input", "type": "string"}
|
assert afun.get_input_jsonschema() == {"title": "afun_input", "type": "string"}
|
||||||
@ -5242,12 +5244,12 @@ async def test_runnable_gen_transform() -> None:
|
|||||||
for i in range(length):
|
for i in range(length):
|
||||||
yield i
|
yield i
|
||||||
|
|
||||||
def plus_one(input: Iterator[int]) -> Iterator[int]:
|
def plus_one(ints: Iterator[int]) -> Iterator[int]:
|
||||||
for i in input:
|
for i in ints:
|
||||||
yield i + 1
|
yield i + 1
|
||||||
|
|
||||||
async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]:
|
async def aplus_one(ints: AsyncIterator[int]) -> AsyncIterator[int]:
|
||||||
async for i in input:
|
async for i in ints:
|
||||||
yield i + 1
|
yield i + 1
|
||||||
|
|
||||||
chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
|
chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
|
||||||
|
@ -543,10 +543,10 @@ async def test_astream_events_from_model() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@RunnableLambda
|
@RunnableLambda
|
||||||
def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
|
def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
|
||||||
if sys.version_info >= (3, 11):
|
if sys.version_info >= (3, 11):
|
||||||
return model.invoke(input)
|
return model.invoke(value)
|
||||||
return model.invoke(input, config)
|
return model.invoke(value, config)
|
||||||
|
|
||||||
events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
|
events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
|
||||||
_assert_events_equal_allow_superset_metadata(
|
_assert_events_equal_allow_superset_metadata(
|
||||||
@ -667,10 +667,10 @@ async def test_astream_events_from_model() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@RunnableLambda
|
@RunnableLambda
|
||||||
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
|
async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
|
||||||
if sys.version_info >= (3, 11):
|
if sys.version_info >= (3, 11):
|
||||||
return await model.ainvoke(input)
|
return await model.ainvoke(value)
|
||||||
return await model.ainvoke(input, config)
|
return await model.ainvoke(value, config)
|
||||||
|
|
||||||
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
|
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
|
||||||
_assert_events_equal_allow_superset_metadata(
|
_assert_events_equal_allow_superset_metadata(
|
||||||
|
@ -613,10 +613,10 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@RunnableLambda
|
@RunnableLambda
|
||||||
def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
|
def i_dont_stream(value: Any, config: RunnableConfig) -> Any:
|
||||||
if sys.version_info >= (3, 11):
|
if sys.version_info >= (3, 11):
|
||||||
return model.invoke(input)
|
return model.invoke(value)
|
||||||
return model.invoke(input, config)
|
return model.invoke(value, config)
|
||||||
|
|
||||||
events = await _collect_events(i_dont_stream.astream_events("hello", version="v2"))
|
events = await _collect_events(i_dont_stream.astream_events("hello", version="v2"))
|
||||||
_assert_events_equal_allow_superset_metadata(
|
_assert_events_equal_allow_superset_metadata(
|
||||||
@ -721,10 +721,10 @@ async def test_astream_with_model_in_chain() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@RunnableLambda
|
@RunnableLambda
|
||||||
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
|
async def ai_dont_stream(value: Any, config: RunnableConfig) -> Any:
|
||||||
if sys.version_info >= (3, 11):
|
if sys.version_info >= (3, 11):
|
||||||
return await model.ainvoke(input)
|
return await model.ainvoke(value)
|
||||||
return await model.ainvoke(input, config)
|
return await model.ainvoke(value, config)
|
||||||
|
|
||||||
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2"))
|
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2"))
|
||||||
_assert_events_equal_allow_superset_metadata(
|
_assert_events_equal_allow_superset_metadata(
|
||||||
@ -2079,6 +2079,7 @@ class StreamingRunnable(Runnable[Input, Output]):
|
|||||||
msg = "Server side error"
|
msg = "Server side error"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@override
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
@ -2413,14 +2414,14 @@ async def test_break_astream_events() -> None:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
async def __call__(self, input: Any) -> Any:
|
async def __call__(self, value: Any) -> Any:
|
||||||
self.started = True
|
self.started = True
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
self.cancelled = True
|
self.cancelled = True
|
||||||
raise
|
raise
|
||||||
return input
|
return value
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.started = False
|
self.started = False
|
||||||
@ -2433,11 +2434,11 @@ async def test_break_astream_events() -> None:
|
|||||||
outer_cancelled = False
|
outer_cancelled = False
|
||||||
|
|
||||||
@chain
|
@chain
|
||||||
async def sequence(input: Any) -> Any:
|
async def sequence(value: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
yield await alittlewhile(input)
|
yield await alittlewhile(value)
|
||||||
yield await awhile(input)
|
yield await awhile(value)
|
||||||
yield await anotherwhile(input)
|
yield await anotherwhile(value)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
nonlocal outer_cancelled
|
nonlocal outer_cancelled
|
||||||
outer_cancelled = True
|
outer_cancelled = True
|
||||||
@ -2478,14 +2479,14 @@ async def test_cancel_astream_events() -> None:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
async def __call__(self, input: Any) -> Any:
|
async def __call__(self, value: Any) -> Any:
|
||||||
self.started = True
|
self.started = True
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
self.cancelled = True
|
self.cancelled = True
|
||||||
raise
|
raise
|
||||||
return input
|
return value
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.started = False
|
self.started = False
|
||||||
@ -2498,11 +2499,11 @@ async def test_cancel_astream_events() -> None:
|
|||||||
outer_cancelled = False
|
outer_cancelled = False
|
||||||
|
|
||||||
@chain
|
@chain
|
||||||
async def sequence(input: Any) -> Any:
|
async def sequence(value: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
yield await alittlewhile(input)
|
yield await alittlewhile(value)
|
||||||
yield await awhile(input)
|
yield await awhile(value)
|
||||||
yield await anotherwhile(input)
|
yield await anotherwhile(value)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
nonlocal outer_cancelled
|
nonlocal outer_cancelled
|
||||||
outer_cancelled = True
|
outer_cancelled = True
|
||||||
|
@ -47,23 +47,23 @@ global_agent = RunnableLambda(lambda x: x * 3)
|
|||||||
def test_nonlocals() -> None:
|
def test_nonlocals() -> None:
|
||||||
agent = RunnableLambda(lambda x: x * 2)
|
agent = RunnableLambda(lambda x: x * 2)
|
||||||
|
|
||||||
def my_func(input: str, agent: dict[str, str]) -> str:
|
def my_func(value: str, agent: dict[str, str]) -> str:
|
||||||
return agent.get("agent_name", input)
|
return agent.get("agent_name", value)
|
||||||
|
|
||||||
def my_func2(input: str) -> str:
|
def my_func2(value: str) -> str:
|
||||||
return agent.get("agent_name", input) # type: ignore[attr-defined]
|
return agent.get("agent_name", value) # type: ignore[attr-defined]
|
||||||
|
|
||||||
def my_func3(input: str) -> str:
|
def my_func3(value: str) -> str:
|
||||||
return agent.invoke(input)
|
return agent.invoke(value)
|
||||||
|
|
||||||
def my_func4(input: str) -> str:
|
def my_func4(value: str) -> str:
|
||||||
return global_agent.invoke(input)
|
return global_agent.invoke(value)
|
||||||
|
|
||||||
def my_func5() -> tuple[Callable[[str], str], RunnableLambda]:
|
def my_func5() -> tuple[Callable[[str], str], RunnableLambda]:
|
||||||
global_agent = RunnableLambda(lambda x: x * 3)
|
global_agent = RunnableLambda(lambda x: x * 3)
|
||||||
|
|
||||||
def my_func6(input: str) -> str:
|
def my_func6(value: str) -> str:
|
||||||
return global_agent.invoke(input)
|
return global_agent.invoke(value)
|
||||||
|
|
||||||
return my_func6, global_agent
|
return my_func6, global_agent
|
||||||
|
|
||||||
|
@ -2,17 +2,17 @@ from langchain_core.globals import get_debug, set_debug
|
|||||||
|
|
||||||
|
|
||||||
def test_debug_is_settable_via_setter() -> None:
|
def test_debug_is_settable_via_setter() -> None:
|
||||||
from langchain_core import globals
|
from langchain_core import globals as globals_
|
||||||
from langchain_core.callbacks.manager import _get_debug
|
from langchain_core.callbacks.manager import _get_debug
|
||||||
|
|
||||||
previous_value = globals._debug
|
previous_value = globals_._debug
|
||||||
previous_fn_reading = _get_debug()
|
previous_fn_reading = _get_debug()
|
||||||
assert previous_value == previous_fn_reading
|
assert previous_value == previous_fn_reading
|
||||||
|
|
||||||
# Flip the value of the flag.
|
# Flip the value of the flag.
|
||||||
set_debug(not previous_value)
|
set_debug(not previous_value)
|
||||||
|
|
||||||
new_value = globals._debug
|
new_value = globals_._debug
|
||||||
new_fn_reading = _get_debug()
|
new_fn_reading = _get_debug()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -50,7 +50,7 @@ class CustomAddTextsVectorstore(VectorStore):
|
|||||||
return ids_
|
return ids_
|
||||||
|
|
||||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||||
return [self.store[id] for id in ids if id in self.store]
|
return [self.store[id_] for id_ in ids if id_ in self.store]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@override
|
@override
|
||||||
@ -96,7 +96,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
|
|||||||
return ids_
|
return ids_
|
||||||
|
|
||||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||||
return [self.store[id] for id in ids if id in self.store]
|
return [self.store[id_] for id_ in ids if id_ in self.store]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@override
|
@override
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
version = 1
|
version = 1
|
||||||
|
revision = 1
|
||||||
requires-python = ">=3.9, <4.0"
|
requires-python = ">=3.9, <4.0"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"python_full_version >= '3.13'",
|
"python_full_version >= '3.13'",
|
||||||
|
Loading…
Reference in New Issue
Block a user