mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 12:58:59 +00:00
core[patch]: docstrings runnables
update (#24161)
Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
parent
14ba1d4b45
commit
aa3e3cfa40
File diff suppressed because it is too large
Load Diff
@ -48,6 +48,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
If no condition evaluates to True, the default branch is run on the input.
|
If no condition evaluates to True, the default branch is run on the input.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
branches: A list of (condition, Runnable) pairs.
|
||||||
|
default: A Runnable to run if no condition is met.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -82,7 +86,18 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
RunnableLike, # To accommodate the default branch
|
RunnableLike, # To accommodate the default branch
|
||||||
],
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""A Runnable that runs one of two branches based on a condition."""
|
"""A Runnable that runs one of two branches based on a condition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*branches: A list of (condition, Runnable) pairs.
|
||||||
|
Defaults a Runnable to run if no condition is met.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the number of branches is less than 2.
|
||||||
|
TypeError: If the default branch is not Runnable, Callable or Mapping.
|
||||||
|
TypeError: If a branch is not a tuple or list.
|
||||||
|
ValueError: If a branch is not of length 2.
|
||||||
|
"""
|
||||||
if len(branches) < 2:
|
if len(branches) < 2:
|
||||||
raise ValueError("RunnableBranch requires at least two branches")
|
raise ValueError("RunnableBranch requires at least two branches")
|
||||||
|
|
||||||
@ -93,7 +108,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
(Runnable, Callable, Mapping), # type: ignore[arg-type]
|
(Runnable, Callable, Mapping), # type: ignore[arg-type]
|
||||||
):
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"RunnableBranch default must be runnable, callable or mapping."
|
"RunnableBranch default must be Runnable, callable or mapping."
|
||||||
)
|
)
|
||||||
|
|
||||||
default_ = cast(
|
default_ = cast(
|
||||||
@ -176,7 +191,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
"""First evaluates the condition, then delegate to true or false branch."""
|
"""First evaluates the condition, then delegate to true or false branch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: The input to the Runnable.
|
||||||
|
config: The configuration for the Runnable. Defaults to None.
|
||||||
|
**kwargs: Additional keyword arguments to pass to the Runnable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output of the branch that was run.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
|
||||||
|
"""
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
@ -277,7 +304,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
"""First evaluates the condition,
|
"""First evaluates the condition,
|
||||||
then delegate to true or false branch."""
|
then delegate to true or false branch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: The input to the Runnable.
|
||||||
|
config: The configuration for the Runnable. Defaults to None.
|
||||||
|
**kwargs: Additional keyword arguments to pass to the Runnable.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The output of the branch that was run.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BaseException: If an error occurs during the execution of the Runnable.
|
||||||
|
"""
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
@ -352,7 +391,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
"""First evaluates the condition,
|
"""First evaluates the condition,
|
||||||
then delegate to true or false branch."""
|
then delegate to true or false branch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: The input to the Runnable.
|
||||||
|
config: The configuration for the Runnable. Defaults to None.
|
||||||
|
**kwargs: Additional keyword arguments to pass to the Runnable.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The output of the branch that was run.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BaseException: If an error occurs during the execution of the Runnable.
|
||||||
|
"""
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_async_callback_manager_for_config(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
@ -111,7 +111,7 @@ var_child_runnable_config = ContextVar(
|
|||||||
|
|
||||||
|
|
||||||
def _set_config_context(config: RunnableConfig) -> None:
|
def _set_config_context(config: RunnableConfig) -> None:
|
||||||
"""Set the child runnable config + tracing context
|
"""Set the child Runnable config + tracing context
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (RunnableConfig): The config to set.
|
config (RunnableConfig): The config to set.
|
||||||
@ -216,7 +216,6 @@ def patch_config(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (Optional[RunnableConfig]): The config to patch.
|
config (Optional[RunnableConfig]): The config to patch.
|
||||||
copy_locals (bool, optional): Whether to copy locals. Defaults to False.
|
|
||||||
callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
|
callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
recursion_limit (Optional[int], optional): The recursion limit to set.
|
recursion_limit (Optional[int], optional): The recursion limit to set.
|
||||||
@ -362,9 +361,9 @@ def call_func_with_variable_args(
|
|||||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]):
|
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]):
|
||||||
The function to call.
|
The function to call.
|
||||||
input (Input): The input to the function.
|
input (Input): The input to the function.
|
||||||
run_manager (CallbackManagerForChainRun): The run manager to
|
|
||||||
pass to the function.
|
|
||||||
config (RunnableConfig): The config to pass to the function.
|
config (RunnableConfig): The config to pass to the function.
|
||||||
|
run_manager (CallbackManagerForChainRun): The run manager to
|
||||||
|
pass to the function. Defaults to None.
|
||||||
**kwargs (Any): The keyword arguments to pass to the function.
|
**kwargs (Any): The keyword arguments to pass to the function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -395,7 +394,7 @@ def acall_func_with_variable_args(
|
|||||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Awaitable[Output]:
|
) -> Awaitable[Output]:
|
||||||
"""Call function that may optionally accept a run_manager and/or config.
|
"""Async call function that may optionally accept a run_manager and/or config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input,
|
func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input,
|
||||||
@ -403,9 +402,9 @@ def acall_func_with_variable_args(
|
|||||||
AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]):
|
AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]):
|
||||||
The function to call.
|
The function to call.
|
||||||
input (Input): The input to the function.
|
input (Input): The input to the function.
|
||||||
run_manager (AsyncCallbackManagerForChainRun): The run manager
|
|
||||||
to pass to the function.
|
|
||||||
config (RunnableConfig): The config to pass to the function.
|
config (RunnableConfig): The config to pass to the function.
|
||||||
|
run_manager (AsyncCallbackManagerForChainRun): The run manager
|
||||||
|
to pass to the function. Defaults to None.
|
||||||
**kwargs (Any): The keyword arguments to pass to the function.
|
**kwargs (Any): The keyword arguments to pass to the function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -493,6 +492,18 @@ class ContextThreadPoolExecutor(ThreadPoolExecutor):
|
|||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
chunksize: int = 1,
|
chunksize: int = 1,
|
||||||
) -> Iterator[T]:
|
) -> Iterator[T]:
|
||||||
|
"""Map a function to multiple iterables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn (Callable[..., T]): The function to map.
|
||||||
|
*iterables (Iterable[Any]): The iterables to map over.
|
||||||
|
timeout (float | None, optional): The timeout for the map.
|
||||||
|
Defaults to None.
|
||||||
|
chunksize (int, optional): The chunksize for the map. Defaults to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterator[T]: The iterator for the mapped function.
|
||||||
|
"""
|
||||||
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
|
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
|
||||||
|
|
||||||
def _wrapped_fn(*args: Any) -> T:
|
def _wrapped_fn(*args: Any) -> T:
|
||||||
@ -534,13 +545,16 @@ async def run_in_executor(
|
|||||||
"""Run a function in an executor.
|
"""Run a function in an executor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
executor (Executor): The executor.
|
executor_or_config: The executor or config to run in.
|
||||||
func (Callable[P, Output]): The function.
|
func (Callable[P, Output]): The function.
|
||||||
*args (Any): The positional arguments to the function.
|
*args (Any): The positional arguments to the function.
|
||||||
**kwargs (Any): The keyword arguments to the function.
|
**kwargs (Any): The keyword arguments to the function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Output: The output of the function.
|
Output: The output of the function.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the function raises a StopIteration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper() -> T:
|
def wrapper() -> T:
|
||||||
|
@ -44,7 +44,15 @@ from langchain_core.runnables.utils import (
|
|||||||
|
|
||||||
|
|
||||||
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||||
"""Serializable Runnable that can be dynamically configured."""
|
"""Serializable Runnable that can be dynamically configured.
|
||||||
|
|
||||||
|
A DynamicRunnable should be initiated using the `configurable_fields` or
|
||||||
|
`configurable_alternatives` method of a Runnable.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
default: The default Runnable to use.
|
||||||
|
config: The configuration to use.
|
||||||
|
"""
|
||||||
|
|
||||||
default: RunnableSerializable[Input, Output]
|
default: RunnableSerializable[Input, Output]
|
||||||
|
|
||||||
@ -99,6 +107,15 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
def prepare(
|
def prepare(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||||
|
"""Prepare the Runnable for invocation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The configuration to use. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Runnable[Input, Output], RunnableConfig]: The prepared Runnable and
|
||||||
|
configuration.
|
||||||
|
"""
|
||||||
runnable: Runnable[Input, Output] = self
|
runnable: Runnable[Input, Output] = self
|
||||||
while isinstance(runnable, DynamicRunnable):
|
while isinstance(runnable, DynamicRunnable):
|
||||||
runnable, config = runnable._prepare(merge_configs(runnable.config, config))
|
runnable, config = runnable._prepare(merge_configs(runnable.config, config))
|
||||||
@ -284,6 +301,9 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
A RunnableConfigurableFields should be initiated using the
|
A RunnableConfigurableFields should be initiated using the
|
||||||
`configurable_fields` method of a Runnable.
|
`configurable_fields` method of a Runnable.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
fields: The configurable fields to use.
|
||||||
|
|
||||||
Here is an example of using a RunnableConfigurableFields with LLMs:
|
Here is an example of using a RunnableConfigurableFields with LLMs:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -348,6 +368,11 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||||
|
"""Get the configuration specs for the RunnableConfigurableFields.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ConfigurableFieldSpec]: The configuration specs.
|
||||||
|
"""
|
||||||
return get_unique_config_specs(
|
return get_unique_config_specs(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
@ -374,6 +399,8 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
def configurable_fields(
|
def configurable_fields(
|
||||||
self, **kwargs: AnyConfigurableField
|
self, **kwargs: AnyConfigurableField
|
||||||
) -> RunnableSerializable[Input, Output]:
|
) -> RunnableSerializable[Input, Output]:
|
||||||
|
"""Get a new RunnableConfigurableFields with the specified
|
||||||
|
configurable fields."""
|
||||||
return self.default.configurable_fields(**{**self.fields, **kwargs})
|
return self.default.configurable_fields(**{**self.fields, **kwargs})
|
||||||
|
|
||||||
def _prepare(
|
def _prepare(
|
||||||
@ -493,11 +520,13 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
which: ConfigurableField
|
which: ConfigurableField
|
||||||
|
"""The ConfigurableField to use to choose between alternatives."""
|
||||||
|
|
||||||
alternatives: Dict[
|
alternatives: Dict[
|
||||||
str,
|
str,
|
||||||
Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
|
Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
|
||||||
]
|
]
|
||||||
|
"""The alternatives to choose from."""
|
||||||
|
|
||||||
default_key: str = "default"
|
default_key: str = "default"
|
||||||
"""The enum value to use for the default option. Defaults to "default"."""
|
"""The enum value to use for the default option. Defaults to "default"."""
|
||||||
@ -619,7 +648,7 @@ def prefix_config_spec(
|
|||||||
prefix: The prefix to add.
|
prefix: The prefix to add.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
ConfigurableFieldSpec: The prefixed ConfigurableFieldSpec.
|
||||||
"""
|
"""
|
||||||
return (
|
return (
|
||||||
ConfigurableFieldSpec(
|
ConfigurableFieldSpec(
|
||||||
@ -641,6 +670,13 @@ def make_options_spec(
|
|||||||
) -> ConfigurableFieldSpec:
|
) -> ConfigurableFieldSpec:
|
||||||
"""Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
|
"""Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
|
||||||
ConfigurableFieldMultiOption.
|
ConfigurableFieldMultiOption.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec: The ConfigurableFieldSingleOption or ConfigurableFieldMultiOption.
|
||||||
|
description: The description to use if the spec does not have one.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ConfigurableFieldSpec.
|
||||||
"""
|
"""
|
||||||
with _enums_for_spec_lock:
|
with _enums_for_spec_lock:
|
||||||
if enum := _enums_for_spec.get(spec):
|
if enum := _enums_for_spec.get(spec):
|
||||||
|
@ -91,7 +91,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
runnable: Runnable[Input, Output]
|
runnable: Runnable[Input, Output]
|
||||||
"""The runnable to run first."""
|
"""The Runnable to run first."""
|
||||||
fallbacks: Sequence[Runnable[Input, Output]]
|
fallbacks: Sequence[Runnable[Input, Output]]
|
||||||
"""A sequence of fallbacks to try."""
|
"""A sequence of fallbacks to try."""
|
||||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
||||||
@ -102,7 +102,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
exception_key: Optional[str] = None
|
exception_key: Optional[str] = None
|
||||||
"""If string is specified then handled exceptions will be passed to fallbacks as
|
"""If string is specified then handled exceptions will be passed to fallbacks as
|
||||||
part of the input under the specified key. If None, exceptions
|
part of the input under the specified key. If None, exceptions
|
||||||
will not be passed to fallbacks. If used, the base runnable and its fallbacks
|
will not be passed to fallbacks. If used, the base Runnable and its fallbacks
|
||||||
must accept a dictionary as input."""
|
must accept a dictionary as input."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -554,7 +554,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
await run_manager.on_chain_end(output)
|
await run_manager.on_chain_end(output)
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> Any:
|
def __getattr__(self, name: str) -> Any:
|
||||||
"""Get an attribute from the wrapped runnable and its fallbacks.
|
"""Get an attribute from the wrapped Runnable and its fallbacks.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
If the attribute is anything other than a method that outputs a Runnable,
|
If the attribute is anything other than a method that outputs a Runnable,
|
||||||
|
@ -57,7 +57,14 @@ def is_uuid(value: str) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
class Edge(NamedTuple):
|
class Edge(NamedTuple):
|
||||||
"""Edge in a graph."""
|
"""Edge in a graph.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
source: The source node id.
|
||||||
|
target: The target node id.
|
||||||
|
data: Optional data associated with the edge. Defaults to None.
|
||||||
|
conditional: Whether the edge is conditional. Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
source: str
|
source: str
|
||||||
target: str
|
target: str
|
||||||
@ -67,6 +74,15 @@ class Edge(NamedTuple):
|
|||||||
def copy(
|
def copy(
|
||||||
self, *, source: Optional[str] = None, target: Optional[str] = None
|
self, *, source: Optional[str] = None, target: Optional[str] = None
|
||||||
) -> Edge:
|
) -> Edge:
|
||||||
|
"""Return a copy of the edge with optional new source and target nodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: The new source node id. Defaults to None.
|
||||||
|
target: The new target node id. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A copy of the edge with the new source and target nodes.
|
||||||
|
"""
|
||||||
return Edge(
|
return Edge(
|
||||||
source=source or self.source,
|
source=source or self.source,
|
||||||
target=target or self.target,
|
target=target or self.target,
|
||||||
@ -76,7 +92,14 @@ class Edge(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class Node(NamedTuple):
|
class Node(NamedTuple):
|
||||||
"""Node in a graph."""
|
"""Node in a graph.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
id: The unique identifier of the node.
|
||||||
|
name: The name of the node.
|
||||||
|
data: The data of the node.
|
||||||
|
metadata: Optional metadata for the node. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
@ -84,6 +107,15 @@ class Node(NamedTuple):
|
|||||||
metadata: Optional[Dict[str, Any]]
|
metadata: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
||||||
|
"""Return a copy of the node with optional new id and name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id: The new node id. Defaults to None.
|
||||||
|
name: The new node name. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A copy of the node with the new id and name.
|
||||||
|
"""
|
||||||
return Node(
|
return Node(
|
||||||
id=id or self.id,
|
id=id or self.id,
|
||||||
name=name or self.name,
|
name=name or self.name,
|
||||||
@ -93,7 +125,13 @@ class Node(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class Branch(NamedTuple):
|
class Branch(NamedTuple):
|
||||||
"""Branch in a graph."""
|
"""Branch in a graph.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
condition: A callable that returns a string representation of the condition.
|
||||||
|
ends: Optional dictionary of end node ids for the branches. Defaults
|
||||||
|
to None.
|
||||||
|
"""
|
||||||
|
|
||||||
condition: Callable[..., str]
|
condition: Callable[..., str]
|
||||||
ends: Optional[dict[str, str]]
|
ends: Optional[dict[str, str]]
|
||||||
@ -118,7 +156,13 @@ class CurveStyle(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NodeStyles:
|
class NodeStyles:
|
||||||
"""Schema for Hexadecimal color codes for different node types"""
|
"""Schema for Hexadecimal color codes for different node types.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
default: The default color code. Defaults to "fill:#f2f0ff,line-height:1.2".
|
||||||
|
first: The color code for the first node. Defaults to "fill-opacity:0".
|
||||||
|
last: The color code for the last node. Defaults to "fill:#bfb6fc".
|
||||||
|
"""
|
||||||
|
|
||||||
default: str = "fill:#f2f0ff,line-height:1.2"
|
default: str = "fill:#f2f0ff,line-height:1.2"
|
||||||
first: str = "fill-opacity:0"
|
first: str = "fill-opacity:0"
|
||||||
@ -161,7 +205,7 @@ def node_data_json(
|
|||||||
Args:
|
Args:
|
||||||
node: The node to convert.
|
node: The node to convert.
|
||||||
with_schemas: Whether to include the schema of the data if
|
with_schemas: Whether to include the schema of the data if
|
||||||
it is a Pydantic model.
|
it is a Pydantic model. Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary with the type of the data and the data itself.
|
A dictionary with the type of the data and the data itself.
|
||||||
@ -209,13 +253,26 @@ def node_data_json(
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Graph:
|
class Graph:
|
||||||
"""Graph of nodes and edges."""
|
"""Graph of nodes and edges.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
nodes: Dictionary of nodes in the graph. Defaults to an empty dictionary.
|
||||||
|
edges: List of edges in the graph. Defaults to an empty list.
|
||||||
|
"""
|
||||||
|
|
||||||
nodes: Dict[str, Node] = field(default_factory=dict)
|
nodes: Dict[str, Node] = field(default_factory=dict)
|
||||||
edges: List[Edge] = field(default_factory=list)
|
edges: List[Edge] = field(default_factory=list)
|
||||||
|
|
||||||
def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
|
def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""Convert the graph to a JSON-serializable format."""
|
"""Convert the graph to a JSON-serializable format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
with_schemas: Whether to include the schemas of the nodes if they are
|
||||||
|
Pydantic models. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the nodes and edges of the graph.
|
||||||
|
"""
|
||||||
stable_node_ids = {
|
stable_node_ids = {
|
||||||
node.id: i if is_uuid(node.id) else node.id
|
node.id: i if is_uuid(node.id) else node.id
|
||||||
for i, node in enumerate(self.nodes.values())
|
for i, node in enumerate(self.nodes.values())
|
||||||
@ -247,6 +304,8 @@ class Graph:
|
|||||||
return bool(self.nodes)
|
return bool(self.nodes)
|
||||||
|
|
||||||
def next_id(self) -> str:
|
def next_id(self) -> str:
|
||||||
|
"""Return a new unique node
|
||||||
|
identifier that can be used to add a node to the graph."""
|
||||||
return uuid4().hex
|
return uuid4().hex
|
||||||
|
|
||||||
def add_node(
|
def add_node(
|
||||||
@ -256,7 +315,19 @@ class Graph:
|
|||||||
*,
|
*,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Node:
|
) -> Node:
|
||||||
"""Add a node to the graph and return it."""
|
"""Add a node to the graph and return it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The data of the node.
|
||||||
|
id: The id of the node. Defaults to None.
|
||||||
|
metadata: Optional metadata for the node. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The node that was added to the graph.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a node with the same id already exists.
|
||||||
|
"""
|
||||||
if id is not None and id in self.nodes:
|
if id is not None and id in self.nodes:
|
||||||
raise ValueError(f"Node with id {id} already exists")
|
raise ValueError(f"Node with id {id} already exists")
|
||||||
id = id or self.next_id()
|
id = id or self.next_id()
|
||||||
@ -265,7 +336,11 @@ class Graph:
|
|||||||
return node
|
return node
|
||||||
|
|
||||||
def remove_node(self, node: Node) -> None:
|
def remove_node(self, node: Node) -> None:
|
||||||
"""Remove a node from the graph and all edges connected to it."""
|
"""Remove a node from the graph and all edges connected to it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to remove.
|
||||||
|
"""
|
||||||
self.nodes.pop(node.id)
|
self.nodes.pop(node.id)
|
||||||
self.edges = [
|
self.edges = [
|
||||||
edge
|
edge
|
||||||
@ -280,7 +355,20 @@ class Graph:
|
|||||||
data: Optional[Stringifiable] = None,
|
data: Optional[Stringifiable] = None,
|
||||||
conditional: bool = False,
|
conditional: bool = False,
|
||||||
) -> Edge:
|
) -> Edge:
|
||||||
"""Add an edge to the graph and return it."""
|
"""Add an edge to the graph and return it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: The source node of the edge.
|
||||||
|
target: The target node of the edge.
|
||||||
|
data: Optional data associated with the edge. Defaults to None.
|
||||||
|
conditional: Whether the edge is conditional. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The edge that was added to the graph.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the source or target node is not in the graph.
|
||||||
|
"""
|
||||||
if source.id not in self.nodes:
|
if source.id not in self.nodes:
|
||||||
raise ValueError(f"Source node {source.id} not in graph")
|
raise ValueError(f"Source node {source.id} not in graph")
|
||||||
if target.id not in self.nodes:
|
if target.id not in self.nodes:
|
||||||
@ -295,7 +383,15 @@ class Graph:
|
|||||||
self, graph: Graph, *, prefix: str = ""
|
self, graph: Graph, *, prefix: str = ""
|
||||||
) -> Tuple[Optional[Node], Optional[Node]]:
|
) -> Tuple[Optional[Node], Optional[Node]]:
|
||||||
"""Add all nodes and edges from another graph.
|
"""Add all nodes and edges from another graph.
|
||||||
Note this doesn't check for duplicates, nor does it connect the graphs."""
|
Note this doesn't check for duplicates, nor does it connect the graphs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph: The graph to add.
|
||||||
|
prefix: The prefix to add to the node ids. Defaults to "".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of the first and last nodes of the subgraph.
|
||||||
|
"""
|
||||||
if all(is_uuid(node.id) for node in graph.nodes.values()):
|
if all(is_uuid(node.id) for node in graph.nodes.values()):
|
||||||
prefix = ""
|
prefix = ""
|
||||||
|
|
||||||
@ -350,7 +446,7 @@ class Graph:
|
|||||||
def first_node(self) -> Optional[Node]:
|
def first_node(self) -> Optional[Node]:
|
||||||
"""Find the single node that is not a target of any edge.
|
"""Find the single node that is not a target of any edge.
|
||||||
If there is no such node, or there are multiple, return None.
|
If there is no such node, or there are multiple, return None.
|
||||||
When drawing the graph this node would be the origin."""
|
When drawing the graph, this node would be the origin."""
|
||||||
targets = {edge.target for edge in self.edges}
|
targets = {edge.target for edge in self.edges}
|
||||||
found: List[Node] = []
|
found: List[Node] = []
|
||||||
for node in self.nodes.values():
|
for node in self.nodes.values():
|
||||||
@ -361,7 +457,7 @@ class Graph:
|
|||||||
def last_node(self) -> Optional[Node]:
|
def last_node(self) -> Optional[Node]:
|
||||||
"""Find the single node that is not a source of any edge.
|
"""Find the single node that is not a source of any edge.
|
||||||
If there is no such node, or there are multiple, return None.
|
If there is no such node, or there are multiple, return None.
|
||||||
When drawing the graph this node would be the destination.
|
When drawing the graph, this node would be the destination.
|
||||||
"""
|
"""
|
||||||
sources = {edge.source for edge in self.edges}
|
sources = {edge.source for edge in self.edges}
|
||||||
found: List[Node] = []
|
found: List[Node] = []
|
||||||
@ -372,7 +468,7 @@ class Graph:
|
|||||||
|
|
||||||
def trim_first_node(self) -> None:
|
def trim_first_node(self) -> None:
|
||||||
"""Remove the first node if it exists and has a single outgoing edge,
|
"""Remove the first node if it exists and has a single outgoing edge,
|
||||||
ie. if removing it would not leave the graph without a "first" node."""
|
i.e., if removing it would not leave the graph without a "first" node."""
|
||||||
first_node = self.first_node()
|
first_node = self.first_node()
|
||||||
if first_node:
|
if first_node:
|
||||||
if (
|
if (
|
||||||
@ -384,7 +480,7 @@ class Graph:
|
|||||||
|
|
||||||
def trim_last_node(self) -> None:
|
def trim_last_node(self) -> None:
|
||||||
"""Remove the last node if it exists and has a single incoming edge,
|
"""Remove the last node if it exists and has a single incoming edge,
|
||||||
ie. if removing it would not leave the graph without a "last" node."""
|
i.e., if removing it would not leave the graph without a "last" node."""
|
||||||
last_node = self.last_node()
|
last_node = self.last_node()
|
||||||
if last_node:
|
if last_node:
|
||||||
if (
|
if (
|
||||||
@ -395,6 +491,7 @@ class Graph:
|
|||||||
self.remove_node(last_node)
|
self.remove_node(last_node)
|
||||||
|
|
||||||
def draw_ascii(self) -> str:
|
def draw_ascii(self) -> str:
|
||||||
|
"""Draw the graph as an ASCII art string."""
|
||||||
from langchain_core.runnables.graph_ascii import draw_ascii
|
from langchain_core.runnables.graph_ascii import draw_ascii
|
||||||
|
|
||||||
return draw_ascii(
|
return draw_ascii(
|
||||||
@ -403,6 +500,7 @@ class Graph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def print_ascii(self) -> None:
|
def print_ascii(self) -> None:
|
||||||
|
"""Print the graph as an ASCII art string."""
|
||||||
print(self.draw_ascii()) # noqa: T201
|
print(self.draw_ascii()) # noqa: T201
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@ -427,6 +525,17 @@ class Graph:
|
|||||||
fontname: Optional[str] = None,
|
fontname: Optional[str] = None,
|
||||||
labels: Optional[LabelsDict] = None,
|
labels: Optional[LabelsDict] = None,
|
||||||
) -> Union[bytes, None]:
|
) -> Union[bytes, None]:
|
||||||
|
"""Draw the graph as a PNG image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file_path: The path to save the image to. If None, the image
|
||||||
|
is not saved. Defaults to None.
|
||||||
|
fontname: The name of the font to use. Defaults to None.
|
||||||
|
labels: Optional labels for nodes and edges in the graph. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The PNG image as bytes if output_file_path is None, None otherwise.
|
||||||
|
"""
|
||||||
from langchain_core.runnables.graph_png import PngDrawer
|
from langchain_core.runnables.graph_png import PngDrawer
|
||||||
|
|
||||||
default_node_labels = {node.id: node.name for node in self.nodes.values()}
|
default_node_labels = {node.id: node.name for node in self.nodes.values()}
|
||||||
@ -450,6 +559,18 @@ class Graph:
|
|||||||
node_colors: NodeStyles = NodeStyles(),
|
node_colors: NodeStyles = NodeStyles(),
|
||||||
wrap_label_n_words: int = 9,
|
wrap_label_n_words: int = 9,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Draw the graph as a Mermaid syntax string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
with_styles: Whether to include styles in the syntax. Defaults to True.
|
||||||
|
curve_style: The style of the edges. Defaults to CurveStyle.LINEAR.
|
||||||
|
node_colors: The colors of the nodes. Defaults to NodeStyles().
|
||||||
|
wrap_label_n_words: The number of words to wrap the node labels at.
|
||||||
|
Defaults to 9.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The Mermaid syntax string.
|
||||||
|
"""
|
||||||
from langchain_core.runnables.graph_mermaid import draw_mermaid
|
from langchain_core.runnables.graph_mermaid import draw_mermaid
|
||||||
|
|
||||||
graph = self.reid()
|
graph = self.reid()
|
||||||
@ -478,6 +599,23 @@ class Graph:
|
|||||||
background_color: str = "white",
|
background_color: str = "white",
|
||||||
padding: int = 10,
|
padding: int = 10,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
|
"""Draw the graph as a PNG image using Mermaid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
curve_style: The style of the edges. Defaults to CurveStyle.LINEAR.
|
||||||
|
node_colors: The colors of the nodes. Defaults to NodeStyles().
|
||||||
|
wrap_label_n_words: The number of words to wrap the node labels at.
|
||||||
|
Defaults to 9.
|
||||||
|
output_file_path: The path to save the image to. If None, the image
|
||||||
|
is not saved. Defaults to None.
|
||||||
|
draw_method: The method to use to draw the graph.
|
||||||
|
Defaults to MermaidDrawMethod.API.
|
||||||
|
background_color: The color of the background. Defaults to "white".
|
||||||
|
padding: The padding around the graph. Defaults to 10.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The PNG image as bytes.
|
||||||
|
"""
|
||||||
from langchain_core.runnables.graph_mermaid import draw_mermaid_png
|
from langchain_core.runnables.graph_mermaid import draw_mermaid_png
|
||||||
|
|
||||||
mermaid_syntax = self.draw_mermaid(
|
mermaid_syntax = self.draw_mermaid(
|
||||||
|
@ -17,6 +17,7 @@ class VertexViewer:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
HEIGHT = 3 # top and bottom box edges + text
|
HEIGHT = 3 # top and bottom box edges + text
|
||||||
|
"""Height of the box."""
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
def __init__(self, name: str) -> None:
|
||||||
self._h = self.HEIGHT # top and bottom box edges + text
|
self._h = self.HEIGHT # top and bottom box edges + text
|
||||||
|
@ -23,18 +23,25 @@ def draw_mermaid(
|
|||||||
node_styles: NodeStyles = NodeStyles(),
|
node_styles: NodeStyles = NodeStyles(),
|
||||||
wrap_label_n_words: int = 9,
|
wrap_label_n_words: int = 9,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Draws a Mermaid graph using the provided graph data
|
"""Draws a Mermaid graph using the provided graph data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
nodes (dict[str, str]): List of node ids
|
nodes (dict[str, str]): List of node ids.
|
||||||
edges (List[Edge]): List of edges, object with source,
|
edges (List[Edge]): List of edges, object with a source,
|
||||||
target and data.
|
target and data.
|
||||||
|
first_node (str, optional): Id of the first node. Defaults to None.
|
||||||
|
last_node (str, optional): Id of the last node. Defaults to None.
|
||||||
|
with_styles (bool, optional): Whether to include styles in the graph.
|
||||||
|
Defaults to True.
|
||||||
curve_style (CurveStyle, optional): Curve style for the edges.
|
curve_style (CurveStyle, optional): Curve style for the edges.
|
||||||
node_colors (NodeColors, optional): Node colors for different types.
|
Defaults to CurveStyle.LINEAR.
|
||||||
|
node_styles (NodeStyles, optional): Node colors for different types.
|
||||||
|
Defaults to NodeStyles().
|
||||||
wrap_label_n_words (int, optional): Words to wrap the edge labels.
|
wrap_label_n_words (int, optional): Words to wrap the edge labels.
|
||||||
|
Defaults to 9.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Mermaid graph syntax
|
str: Mermaid graph syntax.
|
||||||
"""
|
"""
|
||||||
# Initialize Mermaid graph configuration
|
# Initialize Mermaid graph configuration
|
||||||
mermaid_graph = (
|
mermaid_graph = (
|
||||||
@ -139,7 +146,24 @@ def draw_mermaid_png(
|
|||||||
background_color: Optional[str] = "white",
|
background_color: Optional[str] = "white",
|
||||||
padding: int = 10,
|
padding: int = 10,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Draws a Mermaid graph as PNG using provided syntax."""
|
"""Draws a Mermaid graph as PNG using provided syntax.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mermaid_syntax (str): Mermaid graph syntax.
|
||||||
|
output_file_path (str, optional): Path to save the PNG image.
|
||||||
|
Defaults to None.
|
||||||
|
draw_method (MermaidDrawMethod, optional): Method to draw the graph.
|
||||||
|
Defaults to MermaidDrawMethod.API.
|
||||||
|
background_color (str, optional): Background color of the image.
|
||||||
|
Defaults to "white".
|
||||||
|
padding (int, optional): Padding around the image. Defaults to 10.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: PNG image bytes.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If an invalid draw method is provided.
|
||||||
|
"""
|
||||||
if draw_method == MermaidDrawMethod.PYPPETEER:
|
if draw_method == MermaidDrawMethod.PYPPETEER:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from langchain_core.runnables.graph import Graph, LabelsDict
|
|||||||
class PngDrawer:
|
class PngDrawer:
|
||||||
"""Helper class to draw a state graph into a PNG file.
|
"""Helper class to draw a state graph into a PNG file.
|
||||||
|
|
||||||
It requires graphviz and pygraphviz to be installed.
|
It requires `graphviz` and `pygraphviz` to be installed.
|
||||||
:param fontname: The font to use for the labels
|
:param fontname: The font to use for the labels
|
||||||
:param labels: A dictionary of label overrides. The dictionary
|
:param labels: A dictionary of label overrides. The dictionary
|
||||||
should have the following format:
|
should have the following format:
|
||||||
@ -33,7 +33,7 @@ class PngDrawer:
|
|||||||
"""Initializes the PNG drawer.
|
"""Initializes the PNG drawer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fontname: The font to use for the labels
|
fontname: The font to use for the labels. Defaults to "arial".
|
||||||
labels: A dictionary of label overrides. The dictionary
|
labels: A dictionary of label overrides. The dictionary
|
||||||
should have the following format:
|
should have the following format:
|
||||||
{
|
{
|
||||||
@ -48,6 +48,7 @@ class PngDrawer:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
The keys are the original labels, and the values are the new labels.
|
The keys are the original labels, and the values are the new labels.
|
||||||
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.fontname = fontname or "arial"
|
self.fontname = fontname or "arial"
|
||||||
self.labels = labels or LabelsDict(nodes={}, edges={})
|
self.labels = labels or LabelsDict(nodes={}, edges={})
|
||||||
@ -56,7 +57,7 @@ class PngDrawer:
|
|||||||
"""Returns the label to use for a node.
|
"""Returns the label to use for a node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
label: The original label
|
label: The original label.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The new label.
|
The new label.
|
||||||
@ -68,7 +69,7 @@ class PngDrawer:
|
|||||||
"""Returns the label to use for an edge.
|
"""Returns the label to use for an edge.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
label: The original label
|
label: The original label.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The new label.
|
The new label.
|
||||||
@ -80,8 +81,8 @@ class PngDrawer:
|
|||||||
"""Adds a node to the graph.
|
"""Adds a node to the graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
viz: The graphviz object
|
viz: The graphviz object.
|
||||||
node: The node to add
|
node: The node to add.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
@ -106,9 +107,9 @@ class PngDrawer:
|
|||||||
"""Adds an edge to the graph.
|
"""Adds an edge to the graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
viz: The graphviz object
|
viz: The graphviz object.
|
||||||
source: The source node
|
source: The source node.
|
||||||
target: The target node
|
target: The target node.
|
||||||
label: The label for the edge. Defaults to None.
|
label: The label for the edge. Defaults to None.
|
||||||
conditional: Whether the edge is conditional. Defaults to False.
|
conditional: Whether the edge is conditional. Defaults to False.
|
||||||
|
|
||||||
@ -127,7 +128,7 @@ class PngDrawer:
|
|||||||
def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
|
def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
|
||||||
"""Draw the given state graph into a PNG file.
|
"""Draw the given state graph into a PNG file.
|
||||||
|
|
||||||
Requires graphviz and pygraphviz to be installed.
|
Requires `graphviz` and `pygraphviz` to be installed.
|
||||||
:param graph: The graph to draw
|
:param graph: The graph to draw
|
||||||
:param output_path: The path to save the PNG. If None, PNG bytes are returned.
|
:param output_path: The path to save the PNG. If None, PNG bytes are returned.
|
||||||
"""
|
"""
|
||||||
@ -156,14 +157,32 @@ class PngDrawer:
|
|||||||
viz.close()
|
viz.close()
|
||||||
|
|
||||||
def add_nodes(self, viz: Any, graph: Graph) -> None:
|
def add_nodes(self, viz: Any, graph: Graph) -> None:
|
||||||
|
"""Add nodes to the graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
viz: The graphviz object.
|
||||||
|
graph: The graph to draw.
|
||||||
|
"""
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
self.add_node(viz, node)
|
self.add_node(viz, node)
|
||||||
|
|
||||||
def add_edges(self, viz: Any, graph: Graph) -> None:
|
def add_edges(self, viz: Any, graph: Graph) -> None:
|
||||||
|
"""Add edges to the graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
viz: The graphviz object.
|
||||||
|
graph: The graph to draw.
|
||||||
|
"""
|
||||||
for start, end, data, cond in graph.edges:
|
for start, end, data, cond in graph.edges:
|
||||||
self.add_edge(viz, start, end, str(data), cond)
|
self.add_edge(viz, start, end, str(data), cond)
|
||||||
|
|
||||||
def update_styles(self, viz: Any, graph: Graph) -> None:
|
def update_styles(self, viz: Any, graph: Graph) -> None:
|
||||||
|
"""Update the styles of the entrypoint and END nodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
viz: The graphviz object.
|
||||||
|
graph: The graph to draw.
|
||||||
|
"""
|
||||||
if first := graph.first_node():
|
if first := graph.first_node():
|
||||||
viz.get_node(first.id).attr.update(fillcolor="lightblue")
|
viz.get_node(first.id).attr.update(fillcolor="lightblue")
|
||||||
if last := graph.last_node():
|
if last := graph.last_node():
|
||||||
|
@ -45,13 +45,13 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
history for it; it is responsible for reading and updating the chat message
|
history for it; it is responsible for reading and updating the chat message
|
||||||
history.
|
history.
|
||||||
|
|
||||||
The formats supports for the inputs and outputs of the wrapped Runnable
|
The formats supported for the inputs and outputs of the wrapped Runnable
|
||||||
are described below.
|
are described below.
|
||||||
|
|
||||||
RunnableWithMessageHistory must always be called with a config that contains
|
RunnableWithMessageHistory must always be called with a config that contains
|
||||||
the appropriate parameters for the chat message history factory.
|
the appropriate parameters for the chat message history factory.
|
||||||
|
|
||||||
By default the Runnable is expected to take a single configuration parameter
|
By default, the Runnable is expected to take a single configuration parameter
|
||||||
called `session_id` which is a string. This parameter is used to create a new
|
called `session_id` which is a string. This parameter is used to create a new
|
||||||
or look up an existing chat message history that matches the given session_id.
|
or look up an existing chat message history that matches the given session_id.
|
||||||
|
|
||||||
@ -70,6 +70,19 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
For production use cases, you will want to use a persistent implementation
|
For production use cases, you will want to use a persistent implementation
|
||||||
of chat message history, such as ``RedisChatMessageHistory``.
|
of chat message history, such as ``RedisChatMessageHistory``.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
get_session_history: Function that returns a new BaseChatMessageHistory.
|
||||||
|
This function should either take a single positional argument
|
||||||
|
`session_id` of type string and return a corresponding
|
||||||
|
chat message history instance.
|
||||||
|
input_messages_key: Must be specified if the base runnable accepts a dict
|
||||||
|
as input. The key in the input dict that contains the messages.
|
||||||
|
output_messages_key: Must be specified if the base Runnable returns a dict
|
||||||
|
as output. The key in the output dict that contains the messages.
|
||||||
|
history_messages_key: Must be specified if the base runnable accepts a dict
|
||||||
|
as input and expects a separate key for historical messages.
|
||||||
|
history_factory_config: Configure fields that should be passed to the
|
||||||
|
chat history factory. See ``ConfigurableFieldSpec`` for more details.
|
||||||
|
|
||||||
Example: Chat message history with an in-memory implementation for testing.
|
Example: Chat message history with an in-memory implementation for testing.
|
||||||
|
|
||||||
@ -287,9 +300,9 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
...
|
...
|
||||||
|
|
||||||
input_messages_key: Must be specified if the base runnable accepts a dict
|
input_messages_key: Must be specified if the base runnable accepts a dict
|
||||||
as input.
|
as input. Default is None.
|
||||||
output_messages_key: Must be specified if the base runnable returns a dict
|
output_messages_key: Must be specified if the base runnable returns a dict
|
||||||
as output.
|
as output. Default is None.
|
||||||
history_messages_key: Must be specified if the base runnable accepts a dict
|
history_messages_key: Must be specified if the base runnable accepts a dict
|
||||||
as input and expects a separate key for historical messages.
|
as input and expects a separate key for historical messages.
|
||||||
history_factory_config: Configure fields that should be passed to the
|
history_factory_config: Configure fields that should be passed to the
|
||||||
@ -347,6 +360,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||||
|
"""Get the configuration specs for the RunnableWithMessageHistory."""
|
||||||
return get_unique_config_specs(
|
return get_unique_config_specs(
|
||||||
super().config_specs + list(self.history_factory_config)
|
super().config_specs + list(self.history_factory_config)
|
||||||
)
|
)
|
||||||
|
@ -53,19 +53,33 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
def identity(x: Other) -> Other:
|
def identity(x: Other) -> Other:
|
||||||
"""Identity function"""
|
"""Identity function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Other): input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Other: output.
|
||||||
|
"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
async def aidentity(x: Other) -> Other:
|
async def aidentity(x: Other) -> Other:
|
||||||
"""Async identity function"""
|
"""Async identity function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Other): input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Other: output.
|
||||||
|
"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||||
"""Runnable to passthrough inputs unchanged or with additional keys.
|
"""Runnable to passthrough inputs unchanged or with additional keys.
|
||||||
|
|
||||||
This runnable behaves almost like the identity function, except that it
|
This Runnable behaves almost like the identity function, except that it
|
||||||
can be configured to add additional keys to the output, if the input is a
|
can be configured to add additional keys to the output, if the input is a
|
||||||
dict.
|
dict.
|
||||||
|
|
||||||
@ -73,6 +87,13 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
chains. The chains rely on simple lambdas to make the examples easy to execute
|
chains. The chains rely on simple lambdas to make the examples easy to execute
|
||||||
and experiment with.
|
and experiment with.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
func (Callable[[Other], None], optional): Function to be called with the input.
|
||||||
|
afunc (Callable[[Other], Awaitable[None]], optional): Async function to
|
||||||
|
be called with the input.
|
||||||
|
input_type (Optional[Type[Other]], optional): Type of the input.
|
||||||
|
**kwargs (Any): Additional keyword arguments.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -199,10 +220,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
"""Merge the Dict input with the output produced by the mapping argument.
|
"""Merge the Dict input with the output produced by the mapping argument.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mapping: A mapping from keys to runnables or callables.
|
**kwargs: Runnable, Callable or a Mapping from keys to Runnables
|
||||||
|
or Callables.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A runnable that merges the Dict input with the output produced by the
|
A Runnable that merges the Dict input with the output produced by the
|
||||||
mapping argument.
|
mapping argument.
|
||||||
"""
|
"""
|
||||||
return RunnableAssign(RunnableParallel(kwargs))
|
return RunnableAssign(RunnableParallel(kwargs))
|
||||||
@ -336,6 +358,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
these with the original data, introducing new key-value pairs based
|
these with the original data, introducing new key-value pairs based
|
||||||
on the mapper's logic.
|
on the mapper's logic.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
mapper (RunnableParallel[Dict[str, Any]]): A `RunnableParallel` instance
|
||||||
|
that will be used to transform the input dictionary.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@ -627,11 +653,15 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||||
"""Runnable that picks keys from Dict[str, Any] inputs.
|
"""Runnable that picks keys from Dict[str, Any] inputs.
|
||||||
|
|
||||||
RunnablePick class represents a runnable that selectively picks keys from a
|
RunnablePick class represents a Runnable that selectively picks keys from a
|
||||||
dictionary input. It allows you to specify one or more keys to extract
|
dictionary input. It allows you to specify one or more keys to extract
|
||||||
from the input dictionary. It returns a new dictionary containing only
|
from the input dictionary. It returns a new dictionary containing only
|
||||||
the selected keys.
|
the selected keys.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
keys (Union[str, List[str]]): A single key or a list of keys to pick from
|
||||||
|
the input dictionary.
|
||||||
|
|
||||||
Example :
|
Example :
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
"""Whether to add jitter to the exponential backoff."""
|
"""Whether to add jitter to the exponential backoff."""
|
||||||
|
|
||||||
max_attempt_number: int = 3
|
max_attempt_number: int = 3
|
||||||
"""The maximum number of attempts to retry the runnable."""
|
"""The maximum number of attempts to retry the Runnable."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
|
@ -38,7 +38,7 @@ class RouterInput(TypedDict):
|
|||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
key: The key to route on.
|
key: The key to route on.
|
||||||
input: The input to pass to the selected runnable.
|
input: The input to pass to the selected Runnable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key: str
|
key: str
|
||||||
@ -50,6 +50,9 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
|||||||
Runnable that routes to a set of Runnables based on Input['key'].
|
Runnable that routes to a set of Runnables based on Input['key'].
|
||||||
Returns the output of the selected Runnable.
|
Returns the output of the selected Runnable.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
runnables: A mapping of keys to Runnables.
|
||||||
|
|
||||||
For example,
|
For example,
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""Module contains typedefs that are used with runnables."""
|
"""Module contains typedefs that are used with Runnables."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ class EventData(TypedDict, total=False):
|
|||||||
"""Data associated with a streaming event."""
|
"""Data associated with a streaming event."""
|
||||||
|
|
||||||
input: Any
|
input: Any
|
||||||
"""The input passed to the runnable that generated the event.
|
"""The input passed to the Runnable that generated the event.
|
||||||
|
|
||||||
Inputs will sometimes be available at the *START* of the Runnable, and
|
Inputs will sometimes be available at the *START* of the Runnable, and
|
||||||
sometimes at the *END* of the Runnable.
|
sometimes at the *END* of the Runnable.
|
||||||
@ -86,39 +86,42 @@ class BaseStreamEvent(TypedDict):
|
|||||||
"""Event names are of the format: on_[runnable_type]_(start|stream|end).
|
"""Event names are of the format: on_[runnable_type]_(start|stream|end).
|
||||||
|
|
||||||
Runnable types are one of:
|
Runnable types are one of:
|
||||||
* llm - used by non chat models
|
|
||||||
* chat_model - used by chat models
|
- **llm** - used by non chat models
|
||||||
* prompt -- e.g., ChatPromptTemplate
|
- **chat_model** - used by chat models
|
||||||
* tool -- from tools defined via @tool decorator or inheriting from Tool/BaseTool
|
- **prompt** -- e.g., ChatPromptTemplate
|
||||||
* chain - most Runnables are of this type
|
- **tool** -- from tools defined via @tool decorator or inheriting
|
||||||
|
from Tool/BaseTool
|
||||||
|
- **chain** - most Runnables are of this type
|
||||||
|
|
||||||
Further, the events are categorized as one of:
|
Further, the events are categorized as one of:
|
||||||
* start - when the runnable starts
|
|
||||||
* stream - when the runnable is streaming
|
- **start** - when the Runnable starts
|
||||||
* end - when the runnable ends
|
- **stream** - when the Runnable is streaming
|
||||||
|
- **end* - when the Runnable ends
|
||||||
|
|
||||||
start, stream and end are associated with slightly different `data` payload.
|
start, stream and end are associated with slightly different `data` payload.
|
||||||
|
|
||||||
Please see the documentation for `EventData` for more details.
|
Please see the documentation for `EventData` for more details.
|
||||||
"""
|
"""
|
||||||
run_id: str
|
run_id: str
|
||||||
"""An randomly generated ID to keep track of the execution of the given runnable.
|
"""An randomly generated ID to keep track of the execution of the given Runnable.
|
||||||
|
|
||||||
Each child runnable that gets invoked as part of the execution of a parent runnable
|
Each child Runnable that gets invoked as part of the execution of a parent Runnable
|
||||||
is assigned its own unique ID.
|
is assigned its own unique ID.
|
||||||
"""
|
"""
|
||||||
tags: NotRequired[List[str]]
|
tags: NotRequired[List[str]]
|
||||||
"""Tags associated with the runnable that generated this event.
|
"""Tags associated with the Runnable that generated this event.
|
||||||
|
|
||||||
Tags are always inherited from parent runnables.
|
Tags are always inherited from parent Runnables.
|
||||||
|
|
||||||
Tags can either be bound to a runnable using `.with_config({"tags": ["hello"]})`
|
Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})`
|
||||||
or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
|
or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
|
||||||
"""
|
"""
|
||||||
metadata: NotRequired[Dict[str, Any]]
|
metadata: NotRequired[Dict[str, Any]]
|
||||||
"""Metadata associated with the runnable that generated this event.
|
"""Metadata associated with the Runnable that generated this event.
|
||||||
|
|
||||||
Metadata can either be bound to a runnable using
|
Metadata can either be bound to a Runnable using
|
||||||
|
|
||||||
`.with_config({"metadata": { "foo": "bar" }})`
|
`.with_config({"metadata": { "foo": "bar" }})`
|
||||||
|
|
||||||
@ -150,21 +153,20 @@ class StandardStreamEvent(BaseStreamEvent):
|
|||||||
The contents of the event data depend on the event type.
|
The contents of the event data depend on the event type.
|
||||||
"""
|
"""
|
||||||
name: str
|
name: str
|
||||||
"""The name of the runnable that generated the event."""
|
"""The name of the Runnable that generated the event."""
|
||||||
|
|
||||||
|
|
||||||
class CustomStreamEvent(BaseStreamEvent):
|
class CustomStreamEvent(BaseStreamEvent):
|
||||||
"""A custom stream event created by the user.
|
"""Custom stream event created by the user.
|
||||||
|
|
||||||
.. versionadded:: 0.2.14
|
.. versionadded:: 0.2.14
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Overwrite the event field to be more specific.
|
# Overwrite the event field to be more specific.
|
||||||
event: Literal["on_custom_event"] # type: ignore[misc]
|
event: Literal["on_custom_event"] # type: ignore[misc]
|
||||||
|
|
||||||
"""The event type."""
|
"""The event type."""
|
||||||
name: str
|
name: str
|
||||||
"""A user defined name for the event."""
|
"""User defined name for the event."""
|
||||||
data: Any
|
data: Any
|
||||||
"""The data associated with the event. Free form and can be anything."""
|
"""The data associated with the event. Free form and can be anything."""
|
||||||
|
|
||||||
|
@ -43,6 +43,7 @@ Output = TypeVar("Output", covariant=True)
|
|||||||
|
|
||||||
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||||
"""Run a coroutine with a semaphore.
|
"""Run a coroutine with a semaphore.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
semaphore: The semaphore to use.
|
semaphore: The semaphore to use.
|
||||||
coro: The coroutine to run.
|
coro: The coroutine to run.
|
||||||
@ -59,7 +60,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
n: The number of coroutines to run concurrently.
|
n: The number of coroutines to run concurrently.
|
||||||
coros: The coroutines to run.
|
*coros: The coroutines to run.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The results of the coroutines.
|
The results of the coroutines.
|
||||||
@ -73,7 +74,14 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
|
|||||||
|
|
||||||
|
|
||||||
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
||||||
"""Check if a callable accepts a run_manager argument."""
|
"""Check if a callable accepts a run_manager argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callable: The callable to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the callable accepts a run_manager argument, False otherwise.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return signature(callable).parameters.get("run_manager") is not None
|
return signature(callable).parameters.get("run_manager") is not None
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -81,7 +89,14 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def accepts_config(callable: Callable[..., Any]) -> bool:
|
def accepts_config(callable: Callable[..., Any]) -> bool:
|
||||||
"""Check if a callable accepts a config argument."""
|
"""Check if a callable accepts a config argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callable: The callable to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the callable accepts a config argument, False otherwise.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return signature(callable).parameters.get("config") is not None
|
return signature(callable).parameters.get("config") is not None
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -89,7 +104,14 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def accepts_context(callable: Callable[..., Any]) -> bool:
|
def accepts_context(callable: Callable[..., Any]) -> bool:
|
||||||
"""Check if a callable accepts a context argument."""
|
"""Check if a callable accepts a context argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callable: The callable to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the callable accepts a context argument, False otherwise.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return signature(callable).parameters.get("context") is not None
|
return signature(callable).parameters.get("context") is not None
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -100,10 +122,24 @@ class IsLocalDict(ast.NodeVisitor):
|
|||||||
"""Check if a name is a local dict."""
|
"""Check if a name is a local dict."""
|
||||||
|
|
||||||
def __init__(self, name: str, keys: Set[str]) -> None:
|
def __init__(self, name: str, keys: Set[str]) -> None:
|
||||||
|
"""Initialize the visitor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name to check.
|
||||||
|
keys: The keys to populate.
|
||||||
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
self.keys = keys
|
self.keys = keys
|
||||||
|
|
||||||
def visit_Subscript(self, node: ast.Subscript) -> Any:
|
def visit_Subscript(self, node: ast.Subscript) -> Any:
|
||||||
|
"""Visit a subscript node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
if (
|
if (
|
||||||
isinstance(node.ctx, ast.Load)
|
isinstance(node.ctx, ast.Load)
|
||||||
and isinstance(node.value, ast.Name)
|
and isinstance(node.value, ast.Name)
|
||||||
@ -115,6 +151,14 @@ class IsLocalDict(ast.NodeVisitor):
|
|||||||
self.keys.add(node.slice.value)
|
self.keys.add(node.slice.value)
|
||||||
|
|
||||||
def visit_Call(self, node: ast.Call) -> Any:
|
def visit_Call(self, node: ast.Call) -> Any:
|
||||||
|
"""Visit a call node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
if (
|
if (
|
||||||
isinstance(node.func, ast.Attribute)
|
isinstance(node.func, ast.Attribute)
|
||||||
and isinstance(node.func.value, ast.Name)
|
and isinstance(node.func.value, ast.Name)
|
||||||
@ -135,18 +179,42 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
|||||||
self.keys: Set[str] = set()
|
self.keys: Set[str] = set()
|
||||||
|
|
||||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||||
|
"""Visit a lambda function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
if not node.args.args:
|
if not node.args.args:
|
||||||
return
|
return
|
||||||
input_arg_name = node.args.args[0].arg
|
input_arg_name = node.args.args[0].arg
|
||||||
IsLocalDict(input_arg_name, self.keys).visit(node.body)
|
IsLocalDict(input_arg_name, self.keys).visit(node.body)
|
||||||
|
|
||||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||||
|
"""Visit a function definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
if not node.args.args:
|
if not node.args.args:
|
||||||
return
|
return
|
||||||
input_arg_name = node.args.args[0].arg
|
input_arg_name = node.args.args[0].arg
|
||||||
IsLocalDict(input_arg_name, self.keys).visit(node)
|
IsLocalDict(input_arg_name, self.keys).visit(node)
|
||||||
|
|
||||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
||||||
|
"""Visit an async function definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
if not node.args.args:
|
if not node.args.args:
|
||||||
return
|
return
|
||||||
input_arg_name = node.args.args[0].arg
|
input_arg_name = node.args.args[0].arg
|
||||||
@ -161,12 +229,28 @@ class NonLocals(ast.NodeVisitor):
|
|||||||
self.stores: Set[str] = set()
|
self.stores: Set[str] = set()
|
||||||
|
|
||||||
def visit_Name(self, node: ast.Name) -> Any:
|
def visit_Name(self, node: ast.Name) -> Any:
|
||||||
|
"""Visit a name node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
if isinstance(node.ctx, ast.Load):
|
if isinstance(node.ctx, ast.Load):
|
||||||
self.loads.add(node.id)
|
self.loads.add(node.id)
|
||||||
elif isinstance(node.ctx, ast.Store):
|
elif isinstance(node.ctx, ast.Store):
|
||||||
self.stores.add(node.id)
|
self.stores.add(node.id)
|
||||||
|
|
||||||
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
||||||
|
"""Visit an attribute node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
if isinstance(node.ctx, ast.Load):
|
if isinstance(node.ctx, ast.Load):
|
||||||
parent = node.value
|
parent = node.value
|
||||||
attr_expr = node.attr
|
attr_expr = node.attr
|
||||||
@ -185,16 +269,40 @@ class FunctionNonLocals(ast.NodeVisitor):
|
|||||||
self.nonlocals: Set[str] = set()
|
self.nonlocals: Set[str] = set()
|
||||||
|
|
||||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||||
|
"""Visit a function definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
visitor = NonLocals()
|
visitor = NonLocals()
|
||||||
visitor.visit(node)
|
visitor.visit(node)
|
||||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||||
|
|
||||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
||||||
|
"""Visit an async function definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
visitor = NonLocals()
|
visitor = NonLocals()
|
||||||
visitor.visit(node)
|
visitor.visit(node)
|
||||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||||
|
|
||||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||||
|
"""Visit a lambda function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
visitor = NonLocals()
|
visitor = NonLocals()
|
||||||
visitor.visit(node)
|
visitor.visit(node)
|
||||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||||
@ -209,14 +317,29 @@ class GetLambdaSource(ast.NodeVisitor):
|
|||||||
self.count = 0
|
self.count = 0
|
||||||
|
|
||||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||||
"""Visit a lambda function."""
|
"""Visit a lambda function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: The node to visit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of the visit.
|
||||||
|
"""
|
||||||
self.count += 1
|
self.count += 1
|
||||||
if hasattr(ast, "unparse"):
|
if hasattr(ast, "unparse"):
|
||||||
self.source = ast.unparse(node)
|
self.source = ast.unparse(node)
|
||||||
|
|
||||||
|
|
||||||
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||||
"""Get the keys of the first argument of a function if it is a dict."""
|
"""Get the keys of the first argument of a function if it is a dict.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[List[str]]: The keys of the first argument if it is a dict,
|
||||||
|
None otherwise.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
code = inspect.getsource(func)
|
code = inspect.getsource(func)
|
||||||
tree = ast.parse(textwrap.dedent(code))
|
tree = ast.parse(textwrap.dedent(code))
|
||||||
@ -231,10 +354,10 @@ def get_lambda_source(func: Callable) -> Optional[str]:
|
|||||||
"""Get the source code of a lambda function.
|
"""Get the source code of a lambda function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: a callable that can be a lambda function
|
func: a Callable that can be a lambda function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: the source code of the lambda function
|
str: the source code of the lambda function.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
name = func.__name__ if func.__name__ != "<lambda>" else None
|
name = func.__name__ if func.__name__ != "<lambda>" else None
|
||||||
@ -251,7 +374,14 @@ def get_lambda_source(func: Callable) -> Optional[str]:
|
|||||||
|
|
||||||
|
|
||||||
def get_function_nonlocals(func: Callable) -> List[Any]:
|
def get_function_nonlocals(func: Callable) -> List[Any]:
|
||||||
"""Get the nonlocal variables accessed by a function."""
|
"""Get the nonlocal variables accessed by a function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Any]: The nonlocal variables accessed by the function.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
code = inspect.getsource(func)
|
code = inspect.getsource(func)
|
||||||
tree = ast.parse(textwrap.dedent(code))
|
tree = ast.parse(textwrap.dedent(code))
|
||||||
@ -283,11 +413,11 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
|
|||||||
"""Indent all lines of text after the first line.
|
"""Indent all lines of text after the first line.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: The text to indent
|
text: The text to indent.
|
||||||
prefix: Used to determine the number of spaces to indent
|
prefix: Used to determine the number of spaces to indent.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The indented text
|
str: The indented text.
|
||||||
"""
|
"""
|
||||||
n_spaces = len(prefix)
|
n_spaces = len(prefix)
|
||||||
spaces = " " * n_spaces
|
spaces = " " * n_spaces
|
||||||
@ -341,7 +471,14 @@ Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
|
|||||||
|
|
||||||
|
|
||||||
def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
||||||
"""Add a sequence of addable objects together."""
|
"""Add a sequence of addable objects together.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
addables: The addable objects to add.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Addable]: The result of adding the addable objects.
|
||||||
|
"""
|
||||||
final = None
|
final = None
|
||||||
for chunk in addables:
|
for chunk in addables:
|
||||||
if final is None:
|
if final is None:
|
||||||
@ -352,7 +489,14 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
|||||||
|
|
||||||
|
|
||||||
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||||
"""Asynchronously add a sequence of addable objects together."""
|
"""Asynchronously add a sequence of addable objects together.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
addables: The addable objects to add.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Addable]: The result of adding the addable objects.
|
||||||
|
"""
|
||||||
final = None
|
final = None
|
||||||
async for chunk in addables:
|
async for chunk in addables:
|
||||||
if final is None:
|
if final is None:
|
||||||
@ -363,7 +507,15 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
|||||||
|
|
||||||
|
|
||||||
class ConfigurableField(NamedTuple):
|
class ConfigurableField(NamedTuple):
|
||||||
"""Field that can be configured by the user."""
|
"""Field that can be configured by the user.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
id: The unique identifier of the field.
|
||||||
|
name: The name of the field. Defaults to None.
|
||||||
|
description: The description of the field. Defaults to None.
|
||||||
|
annotation: The annotation of the field. Defaults to None.
|
||||||
|
is_shared: Whether the field is shared. Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
|
|
||||||
@ -377,7 +529,16 @@ class ConfigurableField(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class ConfigurableFieldSingleOption(NamedTuple):
|
class ConfigurableFieldSingleOption(NamedTuple):
|
||||||
"""Field that can be configured by the user with a default value."""
|
"""Field that can be configured by the user with a default value.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
id: The unique identifier of the field.
|
||||||
|
options: The options for the field.
|
||||||
|
default: The default value for the field.
|
||||||
|
name: The name of the field. Defaults to None.
|
||||||
|
description: The description of the field. Defaults to None.
|
||||||
|
is_shared: Whether the field is shared. Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
options: Mapping[str, Any]
|
options: Mapping[str, Any]
|
||||||
@ -392,7 +553,16 @@ class ConfigurableFieldSingleOption(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class ConfigurableFieldMultiOption(NamedTuple):
|
class ConfigurableFieldMultiOption(NamedTuple):
|
||||||
"""Field that can be configured by the user with multiple default values."""
|
"""Field that can be configured by the user with multiple default values.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
id: The unique identifier of the field.
|
||||||
|
options: The options for the field.
|
||||||
|
default: The default values for the field.
|
||||||
|
name: The name of the field. Defaults to None.
|
||||||
|
description: The description of the field. Defaults to None.
|
||||||
|
is_shared: Whether the field is shared. Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
options: Mapping[str, Any]
|
options: Mapping[str, Any]
|
||||||
@ -412,7 +582,17 @@ AnyConfigurableField = Union[
|
|||||||
|
|
||||||
|
|
||||||
class ConfigurableFieldSpec(NamedTuple):
|
class ConfigurableFieldSpec(NamedTuple):
|
||||||
"""Field that can be configured by the user. It is a specification of a field."""
|
"""Field that can be configured by the user. It is a specification of a field.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
id: The unique identifier of the field.
|
||||||
|
annotation: The annotation of the field.
|
||||||
|
name: The name of the field. Defaults to None.
|
||||||
|
description: The description of the field. Defaults to None.
|
||||||
|
default: The default value for the field. Defaults to None.
|
||||||
|
is_shared: Whether the field is shared. Defaults to False.
|
||||||
|
dependencies: The dependencies of the field. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
annotation: Any
|
annotation: Any
|
||||||
@ -427,7 +607,17 @@ class ConfigurableFieldSpec(NamedTuple):
|
|||||||
def get_unique_config_specs(
|
def get_unique_config_specs(
|
||||||
specs: Iterable[ConfigurableFieldSpec],
|
specs: Iterable[ConfigurableFieldSpec],
|
||||||
) -> List[ConfigurableFieldSpec]:
|
) -> List[ConfigurableFieldSpec]:
|
||||||
"""Get the unique config specs from a sequence of config specs."""
|
"""Get the unique config specs from a sequence of config specs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
specs: The config specs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ConfigurableFieldSpec]: The unique config specs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the runnable sequence contains conflicting config specs.
|
||||||
|
"""
|
||||||
grouped = groupby(
|
grouped = groupby(
|
||||||
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
|
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
|
||||||
)
|
)
|
||||||
@ -542,7 +732,15 @@ def _create_model_cached(
|
|||||||
def is_async_generator(
|
def is_async_generator(
|
||||||
func: Any,
|
func: Any,
|
||||||
) -> TypeGuard[Callable[..., AsyncIterator]]:
|
) -> TypeGuard[Callable[..., AsyncIterator]]:
|
||||||
"""Check if a function is an async generator."""
|
"""Check if a function is an async generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TypeGuard[Callable[..., AsyncIterator]: True if the function is
|
||||||
|
an async generator, False otherwise.
|
||||||
|
"""
|
||||||
return (
|
return (
|
||||||
inspect.isasyncgenfunction(func)
|
inspect.isasyncgenfunction(func)
|
||||||
or hasattr(func, "__call__")
|
or hasattr(func, "__call__")
|
||||||
@ -553,7 +751,15 @@ def is_async_generator(
|
|||||||
def is_async_callable(
|
def is_async_callable(
|
||||||
func: Any,
|
func: Any,
|
||||||
) -> TypeGuard[Callable[..., Awaitable]]:
|
) -> TypeGuard[Callable[..., Awaitable]]:
|
||||||
"""Check if a function is async."""
|
"""Check if a function is async.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TypeGuard[Callable[..., Awaitable]: True if the function is async,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
return (
|
return (
|
||||||
asyncio.iscoroutinefunction(func)
|
asyncio.iscoroutinefunction(func)
|
||||||
or hasattr(func, "__call__")
|
or hasattr(func, "__call__")
|
||||||
|
Loading…
Reference in New Issue
Block a user