mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00
core[patch]: docstrings runnables
update (#24161)
Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
parent
14ba1d4b45
commit
aa3e3cfa40
File diff suppressed because it is too large
Load Diff
@ -48,6 +48,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
|
||||
If no condition evaluates to True, the default branch is run on the input.
|
||||
|
||||
Parameters:
|
||||
branches: A list of (condition, Runnable) pairs.
|
||||
default: A Runnable to run if no condition is met.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
@ -82,7 +86,18 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
RunnableLike, # To accommodate the default branch
|
||||
],
|
||||
) -> None:
|
||||
"""A Runnable that runs one of two branches based on a condition."""
|
||||
"""A Runnable that runs one of two branches based on a condition.
|
||||
|
||||
Args:
|
||||
*branches: A list of (condition, Runnable) pairs.
|
||||
Defaults a Runnable to run if no condition is met.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of branches is less than 2.
|
||||
TypeError: If the default branch is not Runnable, Callable or Mapping.
|
||||
TypeError: If a branch is not a tuple or list.
|
||||
ValueError: If a branch is not of length 2.
|
||||
"""
|
||||
if len(branches) < 2:
|
||||
raise ValueError("RunnableBranch requires at least two branches")
|
||||
|
||||
@ -93,7 +108,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
(Runnable, Callable, Mapping), # type: ignore[arg-type]
|
||||
):
|
||||
raise TypeError(
|
||||
"RunnableBranch default must be runnable, callable or mapping."
|
||||
"RunnableBranch default must be Runnable, callable or mapping."
|
||||
)
|
||||
|
||||
default_ = cast(
|
||||
@ -176,7 +191,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
"""First evaluates the condition, then delegate to true or false branch."""
|
||||
"""First evaluates the condition, then delegate to true or false branch.
|
||||
|
||||
Args:
|
||||
input: The input to the Runnable.
|
||||
config: The configuration for the Runnable. Defaults to None.
|
||||
**kwargs: Additional keyword arguments to pass to the Runnable.
|
||||
|
||||
Returns:
|
||||
The output of the branch that was run.
|
||||
|
||||
Raises:
|
||||
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@ -277,7 +304,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
"""First evaluates the condition,
|
||||
then delegate to true or false branch."""
|
||||
then delegate to true or false branch.
|
||||
|
||||
Args:
|
||||
input: The input to the Runnable.
|
||||
config: The configuration for the Runnable. Defaults to None.
|
||||
**kwargs: Additional keyword arguments to pass to the Runnable.
|
||||
|
||||
Yields:
|
||||
The output of the branch that was run.
|
||||
|
||||
Raises:
|
||||
BaseException: If an error occurs during the execution of the Runnable.
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@ -352,7 +391,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
"""First evaluates the condition,
|
||||
then delegate to true or false branch."""
|
||||
then delegate to true or false branch.
|
||||
|
||||
Args:
|
||||
input: The input to the Runnable.
|
||||
config: The configuration for the Runnable. Defaults to None.
|
||||
**kwargs: Additional keyword arguments to pass to the Runnable.
|
||||
|
||||
Yields:
|
||||
The output of the branch that was run.
|
||||
|
||||
Raises:
|
||||
BaseException: If an error occurs during the execution of the Runnable.
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
|
@ -111,7 +111,7 @@ var_child_runnable_config = ContextVar(
|
||||
|
||||
|
||||
def _set_config_context(config: RunnableConfig) -> None:
|
||||
"""Set the child runnable config + tracing context
|
||||
"""Set the child Runnable config + tracing context
|
||||
|
||||
Args:
|
||||
config (RunnableConfig): The config to set.
|
||||
@ -216,7 +216,6 @@ def patch_config(
|
||||
|
||||
Args:
|
||||
config (Optional[RunnableConfig]): The config to patch.
|
||||
copy_locals (bool, optional): Whether to copy locals. Defaults to False.
|
||||
callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
|
||||
Defaults to None.
|
||||
recursion_limit (Optional[int], optional): The recursion limit to set.
|
||||
@ -362,9 +361,9 @@ def call_func_with_variable_args(
|
||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]):
|
||||
The function to call.
|
||||
input (Input): The input to the function.
|
||||
run_manager (CallbackManagerForChainRun): The run manager to
|
||||
pass to the function.
|
||||
config (RunnableConfig): The config to pass to the function.
|
||||
run_manager (CallbackManagerForChainRun): The run manager to
|
||||
pass to the function. Defaults to None.
|
||||
**kwargs (Any): The keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
@ -395,7 +394,7 @@ def acall_func_with_variable_args(
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[Output]:
|
||||
"""Call function that may optionally accept a run_manager and/or config.
|
||||
"""Async call function that may optionally accept a run_manager and/or config.
|
||||
|
||||
Args:
|
||||
func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input,
|
||||
@ -403,9 +402,9 @@ def acall_func_with_variable_args(
|
||||
AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]):
|
||||
The function to call.
|
||||
input (Input): The input to the function.
|
||||
run_manager (AsyncCallbackManagerForChainRun): The run manager
|
||||
to pass to the function.
|
||||
config (RunnableConfig): The config to pass to the function.
|
||||
run_manager (AsyncCallbackManagerForChainRun): The run manager
|
||||
to pass to the function. Defaults to None.
|
||||
**kwargs (Any): The keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
@ -493,6 +492,18 @@ class ContextThreadPoolExecutor(ThreadPoolExecutor):
|
||||
timeout: float | None = None,
|
||||
chunksize: int = 1,
|
||||
) -> Iterator[T]:
|
||||
"""Map a function to multiple iterables.
|
||||
|
||||
Args:
|
||||
fn (Callable[..., T]): The function to map.
|
||||
*iterables (Iterable[Any]): The iterables to map over.
|
||||
timeout (float | None, optional): The timeout for the map.
|
||||
Defaults to None.
|
||||
chunksize (int, optional): The chunksize for the map. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
Iterator[T]: The iterator for the mapped function.
|
||||
"""
|
||||
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
|
||||
|
||||
def _wrapped_fn(*args: Any) -> T:
|
||||
@ -534,13 +545,16 @@ async def run_in_executor(
|
||||
"""Run a function in an executor.
|
||||
|
||||
Args:
|
||||
executor (Executor): The executor.
|
||||
executor_or_config: The executor or config to run in.
|
||||
func (Callable[P, Output]): The function.
|
||||
*args (Any): The positional arguments to the function.
|
||||
**kwargs (Any): The keyword arguments to the function.
|
||||
|
||||
Returns:
|
||||
Output: The output of the function.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the function raises a StopIteration.
|
||||
"""
|
||||
|
||||
def wrapper() -> T:
|
||||
|
@ -44,7 +44,15 @@ from langchain_core.runnables.utils import (
|
||||
|
||||
|
||||
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
"""Serializable Runnable that can be dynamically configured."""
|
||||
"""Serializable Runnable that can be dynamically configured.
|
||||
|
||||
A DynamicRunnable should be initiated using the `configurable_fields` or
|
||||
`configurable_alternatives` method of a Runnable.
|
||||
|
||||
Parameters:
|
||||
default: The default Runnable to use.
|
||||
config: The configuration to use.
|
||||
"""
|
||||
|
||||
default: RunnableSerializable[Input, Output]
|
||||
|
||||
@ -99,6 +107,15 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
def prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
"""Prepare the Runnable for invocation.
|
||||
|
||||
Args:
|
||||
config: The configuration to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tuple[Runnable[Input, Output], RunnableConfig]: The prepared Runnable and
|
||||
configuration.
|
||||
"""
|
||||
runnable: Runnable[Input, Output] = self
|
||||
while isinstance(runnable, DynamicRunnable):
|
||||
runnable, config = runnable._prepare(merge_configs(runnable.config, config))
|
||||
@ -284,6 +301,9 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
A RunnableConfigurableFields should be initiated using the
|
||||
`configurable_fields` method of a Runnable.
|
||||
|
||||
Parameters:
|
||||
fields: The configurable fields to use.
|
||||
|
||||
Here is an example of using a RunnableConfigurableFields with LLMs:
|
||||
|
||||
.. code-block:: python
|
||||
@ -348,6 +368,11 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
"""Get the configuration specs for the RunnableConfigurableFields.
|
||||
|
||||
Returns:
|
||||
List[ConfigurableFieldSpec]: The configuration specs.
|
||||
"""
|
||||
return get_unique_config_specs(
|
||||
[
|
||||
(
|
||||
@ -374,6 +399,8 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
def configurable_fields(
|
||||
self, **kwargs: AnyConfigurableField
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
"""Get a new RunnableConfigurableFields with the specified
|
||||
configurable fields."""
|
||||
return self.default.configurable_fields(**{**self.fields, **kwargs})
|
||||
|
||||
def _prepare(
|
||||
@ -493,11 +520,13 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
""" # noqa: E501
|
||||
|
||||
which: ConfigurableField
|
||||
"""The ConfigurableField to use to choose between alternatives."""
|
||||
|
||||
alternatives: Dict[
|
||||
str,
|
||||
Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
|
||||
]
|
||||
"""The alternatives to choose from."""
|
||||
|
||||
default_key: str = "default"
|
||||
"""The enum value to use for the default option. Defaults to "default"."""
|
||||
@ -619,7 +648,7 @@ def prefix_config_spec(
|
||||
prefix: The prefix to add.
|
||||
|
||||
Returns:
|
||||
|
||||
ConfigurableFieldSpec: The prefixed ConfigurableFieldSpec.
|
||||
"""
|
||||
return (
|
||||
ConfigurableFieldSpec(
|
||||
@ -641,6 +670,13 @@ def make_options_spec(
|
||||
) -> ConfigurableFieldSpec:
|
||||
"""Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
|
||||
ConfigurableFieldMultiOption.
|
||||
|
||||
Args:
|
||||
spec: The ConfigurableFieldSingleOption or ConfigurableFieldMultiOption.
|
||||
description: The description to use if the spec does not have one.
|
||||
|
||||
Returns:
|
||||
The ConfigurableFieldSpec.
|
||||
"""
|
||||
with _enums_for_spec_lock:
|
||||
if enum := _enums_for_spec.get(spec):
|
||||
|
@ -91,7 +91,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
"""
|
||||
|
||||
runnable: Runnable[Input, Output]
|
||||
"""The runnable to run first."""
|
||||
"""The Runnable to run first."""
|
||||
fallbacks: Sequence[Runnable[Input, Output]]
|
||||
"""A sequence of fallbacks to try."""
|
||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
||||
@ -102,7 +102,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
exception_key: Optional[str] = None
|
||||
"""If string is specified then handled exceptions will be passed to fallbacks as
|
||||
part of the input under the specified key. If None, exceptions
|
||||
will not be passed to fallbacks. If used, the base runnable and its fallbacks
|
||||
will not be passed to fallbacks. If used, the base Runnable and its fallbacks
|
||||
must accept a dictionary as input."""
|
||||
|
||||
class Config:
|
||||
@ -554,7 +554,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
await run_manager.on_chain_end(output)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Get an attribute from the wrapped runnable and its fallbacks.
|
||||
"""Get an attribute from the wrapped Runnable and its fallbacks.
|
||||
|
||||
Returns:
|
||||
If the attribute is anything other than a method that outputs a Runnable,
|
||||
|
@ -57,7 +57,14 @@ def is_uuid(value: str) -> bool:
|
||||
|
||||
|
||||
class Edge(NamedTuple):
|
||||
"""Edge in a graph."""
|
||||
"""Edge in a graph.
|
||||
|
||||
Parameters:
|
||||
source: The source node id.
|
||||
target: The target node id.
|
||||
data: Optional data associated with the edge. Defaults to None.
|
||||
conditional: Whether the edge is conditional. Defaults to False.
|
||||
"""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
@ -67,6 +74,15 @@ class Edge(NamedTuple):
|
||||
def copy(
|
||||
self, *, source: Optional[str] = None, target: Optional[str] = None
|
||||
) -> Edge:
|
||||
"""Return a copy of the edge with optional new source and target nodes.
|
||||
|
||||
Args:
|
||||
source: The new source node id. Defaults to None.
|
||||
target: The new target node id. Defaults to None.
|
||||
|
||||
Returns:
|
||||
A copy of the edge with the new source and target nodes.
|
||||
"""
|
||||
return Edge(
|
||||
source=source or self.source,
|
||||
target=target or self.target,
|
||||
@ -76,7 +92,14 @@ class Edge(NamedTuple):
|
||||
|
||||
|
||||
class Node(NamedTuple):
|
||||
"""Node in a graph."""
|
||||
"""Node in a graph.
|
||||
|
||||
Parameters:
|
||||
id: The unique identifier of the node.
|
||||
name: The name of the node.
|
||||
data: The data of the node.
|
||||
metadata: Optional metadata for the node. Defaults to None.
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
@ -84,6 +107,15 @@ class Node(NamedTuple):
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
|
||||
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
||||
"""Return a copy of the node with optional new id and name.
|
||||
|
||||
Args:
|
||||
id: The new node id. Defaults to None.
|
||||
name: The new node name. Defaults to None.
|
||||
|
||||
Returns:
|
||||
A copy of the node with the new id and name.
|
||||
"""
|
||||
return Node(
|
||||
id=id or self.id,
|
||||
name=name or self.name,
|
||||
@ -93,7 +125,13 @@ class Node(NamedTuple):
|
||||
|
||||
|
||||
class Branch(NamedTuple):
|
||||
"""Branch in a graph."""
|
||||
"""Branch in a graph.
|
||||
|
||||
Parameters:
|
||||
condition: A callable that returns a string representation of the condition.
|
||||
ends: Optional dictionary of end node ids for the branches. Defaults
|
||||
to None.
|
||||
"""
|
||||
|
||||
condition: Callable[..., str]
|
||||
ends: Optional[dict[str, str]]
|
||||
@ -118,7 +156,13 @@ class CurveStyle(Enum):
|
||||
|
||||
@dataclass
|
||||
class NodeStyles:
|
||||
"""Schema for Hexadecimal color codes for different node types"""
|
||||
"""Schema for Hexadecimal color codes for different node types.
|
||||
|
||||
Parameters:
|
||||
default: The default color code. Defaults to "fill:#f2f0ff,line-height:1.2".
|
||||
first: The color code for the first node. Defaults to "fill-opacity:0".
|
||||
last: The color code for the last node. Defaults to "fill:#bfb6fc".
|
||||
"""
|
||||
|
||||
default: str = "fill:#f2f0ff,line-height:1.2"
|
||||
first: str = "fill-opacity:0"
|
||||
@ -161,7 +205,7 @@ def node_data_json(
|
||||
Args:
|
||||
node: The node to convert.
|
||||
with_schemas: Whether to include the schema of the data if
|
||||
it is a Pydantic model.
|
||||
it is a Pydantic model. Defaults to False.
|
||||
|
||||
Returns:
|
||||
A dictionary with the type of the data and the data itself.
|
||||
@ -209,13 +253,26 @@ def node_data_json(
|
||||
|
||||
@dataclass
|
||||
class Graph:
|
||||
"""Graph of nodes and edges."""
|
||||
"""Graph of nodes and edges.
|
||||
|
||||
Parameters:
|
||||
nodes: Dictionary of nodes in the graph. Defaults to an empty dictionary.
|
||||
edges: List of edges in the graph. Defaults to an empty list.
|
||||
"""
|
||||
|
||||
nodes: Dict[str, Node] = field(default_factory=dict)
|
||||
edges: List[Edge] = field(default_factory=list)
|
||||
|
||||
def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Convert the graph to a JSON-serializable format."""
|
||||
"""Convert the graph to a JSON-serializable format.
|
||||
|
||||
Args:
|
||||
with_schemas: Whether to include the schemas of the nodes if they are
|
||||
Pydantic models. Defaults to False.
|
||||
|
||||
Returns:
|
||||
A dictionary with the nodes and edges of the graph.
|
||||
"""
|
||||
stable_node_ids = {
|
||||
node.id: i if is_uuid(node.id) else node.id
|
||||
for i, node in enumerate(self.nodes.values())
|
||||
@ -247,6 +304,8 @@ class Graph:
|
||||
return bool(self.nodes)
|
||||
|
||||
def next_id(self) -> str:
|
||||
"""Return a new unique node
|
||||
identifier that can be used to add a node to the graph."""
|
||||
return uuid4().hex
|
||||
|
||||
def add_node(
|
||||
@ -256,7 +315,19 @@ class Graph:
|
||||
*,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Node:
|
||||
"""Add a node to the graph and return it."""
|
||||
"""Add a node to the graph and return it.
|
||||
|
||||
Args:
|
||||
data: The data of the node.
|
||||
id: The id of the node. Defaults to None.
|
||||
metadata: Optional metadata for the node. Defaults to None.
|
||||
|
||||
Returns:
|
||||
The node that was added to the graph.
|
||||
|
||||
Raises:
|
||||
ValueError: If a node with the same id already exists.
|
||||
"""
|
||||
if id is not None and id in self.nodes:
|
||||
raise ValueError(f"Node with id {id} already exists")
|
||||
id = id or self.next_id()
|
||||
@ -265,7 +336,11 @@ class Graph:
|
||||
return node
|
||||
|
||||
def remove_node(self, node: Node) -> None:
|
||||
"""Remove a node from the graph and all edges connected to it."""
|
||||
"""Remove a node from the graph and all edges connected to it.
|
||||
|
||||
Args:
|
||||
node: The node to remove.
|
||||
"""
|
||||
self.nodes.pop(node.id)
|
||||
self.edges = [
|
||||
edge
|
||||
@ -280,7 +355,20 @@ class Graph:
|
||||
data: Optional[Stringifiable] = None,
|
||||
conditional: bool = False,
|
||||
) -> Edge:
|
||||
"""Add an edge to the graph and return it."""
|
||||
"""Add an edge to the graph and return it.
|
||||
|
||||
Args:
|
||||
source: The source node of the edge.
|
||||
target: The target node of the edge.
|
||||
data: Optional data associated with the edge. Defaults to None.
|
||||
conditional: Whether the edge is conditional. Defaults to False.
|
||||
|
||||
Returns:
|
||||
The edge that was added to the graph.
|
||||
|
||||
Raises:
|
||||
ValueError: If the source or target node is not in the graph.
|
||||
"""
|
||||
if source.id not in self.nodes:
|
||||
raise ValueError(f"Source node {source.id} not in graph")
|
||||
if target.id not in self.nodes:
|
||||
@ -295,7 +383,15 @@ class Graph:
|
||||
self, graph: Graph, *, prefix: str = ""
|
||||
) -> Tuple[Optional[Node], Optional[Node]]:
|
||||
"""Add all nodes and edges from another graph.
|
||||
Note this doesn't check for duplicates, nor does it connect the graphs."""
|
||||
Note this doesn't check for duplicates, nor does it connect the graphs.
|
||||
|
||||
Args:
|
||||
graph: The graph to add.
|
||||
prefix: The prefix to add to the node ids. Defaults to "".
|
||||
|
||||
Returns:
|
||||
A tuple of the first and last nodes of the subgraph.
|
||||
"""
|
||||
if all(is_uuid(node.id) for node in graph.nodes.values()):
|
||||
prefix = ""
|
||||
|
||||
@ -350,7 +446,7 @@ class Graph:
|
||||
def first_node(self) -> Optional[Node]:
|
||||
"""Find the single node that is not a target of any edge.
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph this node would be the origin."""
|
||||
When drawing the graph, this node would be the origin."""
|
||||
targets = {edge.target for edge in self.edges}
|
||||
found: List[Node] = []
|
||||
for node in self.nodes.values():
|
||||
@ -361,7 +457,7 @@ class Graph:
|
||||
def last_node(self) -> Optional[Node]:
|
||||
"""Find the single node that is not a source of any edge.
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph this node would be the destination.
|
||||
When drawing the graph, this node would be the destination.
|
||||
"""
|
||||
sources = {edge.source for edge in self.edges}
|
||||
found: List[Node] = []
|
||||
@ -372,7 +468,7 @@ class Graph:
|
||||
|
||||
def trim_first_node(self) -> None:
|
||||
"""Remove the first node if it exists and has a single outgoing edge,
|
||||
ie. if removing it would not leave the graph without a "first" node."""
|
||||
i.e., if removing it would not leave the graph without a "first" node."""
|
||||
first_node = self.first_node()
|
||||
if first_node:
|
||||
if (
|
||||
@ -384,7 +480,7 @@ class Graph:
|
||||
|
||||
def trim_last_node(self) -> None:
|
||||
"""Remove the last node if it exists and has a single incoming edge,
|
||||
ie. if removing it would not leave the graph without a "last" node."""
|
||||
i.e., if removing it would not leave the graph without a "last" node."""
|
||||
last_node = self.last_node()
|
||||
if last_node:
|
||||
if (
|
||||
@ -395,6 +491,7 @@ class Graph:
|
||||
self.remove_node(last_node)
|
||||
|
||||
def draw_ascii(self) -> str:
|
||||
"""Draw the graph as an ASCII art string."""
|
||||
from langchain_core.runnables.graph_ascii import draw_ascii
|
||||
|
||||
return draw_ascii(
|
||||
@ -403,6 +500,7 @@ class Graph:
|
||||
)
|
||||
|
||||
def print_ascii(self) -> None:
|
||||
"""Print the graph as an ASCII art string."""
|
||||
print(self.draw_ascii()) # noqa: T201
|
||||
|
||||
@overload
|
||||
@ -427,6 +525,17 @@ class Graph:
|
||||
fontname: Optional[str] = None,
|
||||
labels: Optional[LabelsDict] = None,
|
||||
) -> Union[bytes, None]:
|
||||
"""Draw the graph as a PNG image.
|
||||
|
||||
Args:
|
||||
output_file_path: The path to save the image to. If None, the image
|
||||
is not saved. Defaults to None.
|
||||
fontname: The name of the font to use. Defaults to None.
|
||||
labels: Optional labels for nodes and edges in the graph. Defaults to None.
|
||||
|
||||
Returns:
|
||||
The PNG image as bytes if output_file_path is None, None otherwise.
|
||||
"""
|
||||
from langchain_core.runnables.graph_png import PngDrawer
|
||||
|
||||
default_node_labels = {node.id: node.name for node in self.nodes.values()}
|
||||
@ -450,6 +559,18 @@ class Graph:
|
||||
node_colors: NodeStyles = NodeStyles(),
|
||||
wrap_label_n_words: int = 9,
|
||||
) -> str:
|
||||
"""Draw the graph as a Mermaid syntax string.
|
||||
|
||||
Args:
|
||||
with_styles: Whether to include styles in the syntax. Defaults to True.
|
||||
curve_style: The style of the edges. Defaults to CurveStyle.LINEAR.
|
||||
node_colors: The colors of the nodes. Defaults to NodeStyles().
|
||||
wrap_label_n_words: The number of words to wrap the node labels at.
|
||||
Defaults to 9.
|
||||
|
||||
Returns:
|
||||
The Mermaid syntax string.
|
||||
"""
|
||||
from langchain_core.runnables.graph_mermaid import draw_mermaid
|
||||
|
||||
graph = self.reid()
|
||||
@ -478,6 +599,23 @@ class Graph:
|
||||
background_color: str = "white",
|
||||
padding: int = 10,
|
||||
) -> bytes:
|
||||
"""Draw the graph as a PNG image using Mermaid.
|
||||
|
||||
Args:
|
||||
curve_style: The style of the edges. Defaults to CurveStyle.LINEAR.
|
||||
node_colors: The colors of the nodes. Defaults to NodeStyles().
|
||||
wrap_label_n_words: The number of words to wrap the node labels at.
|
||||
Defaults to 9.
|
||||
output_file_path: The path to save the image to. If None, the image
|
||||
is not saved. Defaults to None.
|
||||
draw_method: The method to use to draw the graph.
|
||||
Defaults to MermaidDrawMethod.API.
|
||||
background_color: The color of the background. Defaults to "white".
|
||||
padding: The padding around the graph. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
The PNG image as bytes.
|
||||
"""
|
||||
from langchain_core.runnables.graph_mermaid import draw_mermaid_png
|
||||
|
||||
mermaid_syntax = self.draw_mermaid(
|
||||
|
@ -17,6 +17,7 @@ class VertexViewer:
|
||||
"""
|
||||
|
||||
HEIGHT = 3 # top and bottom box edges + text
|
||||
"""Height of the box."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self._h = self.HEIGHT # top and bottom box edges + text
|
||||
|
@ -23,18 +23,25 @@ def draw_mermaid(
|
||||
node_styles: NodeStyles = NodeStyles(),
|
||||
wrap_label_n_words: int = 9,
|
||||
) -> str:
|
||||
"""Draws a Mermaid graph using the provided graph data
|
||||
"""Draws a Mermaid graph using the provided graph data.
|
||||
|
||||
Args:
|
||||
nodes (dict[str, str]): List of node ids
|
||||
edges (List[Edge]): List of edges, object with source,
|
||||
target and data.
|
||||
nodes (dict[str, str]): List of node ids.
|
||||
edges (List[Edge]): List of edges, object with a source,
|
||||
target and data.
|
||||
first_node (str, optional): Id of the first node. Defaults to None.
|
||||
last_node (str, optional): Id of the last node. Defaults to None.
|
||||
with_styles (bool, optional): Whether to include styles in the graph.
|
||||
Defaults to True.
|
||||
curve_style (CurveStyle, optional): Curve style for the edges.
|
||||
node_colors (NodeColors, optional): Node colors for different types.
|
||||
Defaults to CurveStyle.LINEAR.
|
||||
node_styles (NodeStyles, optional): Node colors for different types.
|
||||
Defaults to NodeStyles().
|
||||
wrap_label_n_words (int, optional): Words to wrap the edge labels.
|
||||
Defaults to 9.
|
||||
|
||||
Returns:
|
||||
str: Mermaid graph syntax
|
||||
str: Mermaid graph syntax.
|
||||
"""
|
||||
# Initialize Mermaid graph configuration
|
||||
mermaid_graph = (
|
||||
@ -139,7 +146,24 @@ def draw_mermaid_png(
|
||||
background_color: Optional[str] = "white",
|
||||
padding: int = 10,
|
||||
) -> bytes:
|
||||
"""Draws a Mermaid graph as PNG using provided syntax."""
|
||||
"""Draws a Mermaid graph as PNG using provided syntax.
|
||||
|
||||
Args:
|
||||
mermaid_syntax (str): Mermaid graph syntax.
|
||||
output_file_path (str, optional): Path to save the PNG image.
|
||||
Defaults to None.
|
||||
draw_method (MermaidDrawMethod, optional): Method to draw the graph.
|
||||
Defaults to MermaidDrawMethod.API.
|
||||
background_color (str, optional): Background color of the image.
|
||||
Defaults to "white".
|
||||
padding (int, optional): Padding around the image. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
bytes: PNG image bytes.
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid draw method is provided.
|
||||
"""
|
||||
if draw_method == MermaidDrawMethod.PYPPETEER:
|
||||
import asyncio
|
||||
|
||||
|
@ -6,7 +6,7 @@ from langchain_core.runnables.graph import Graph, LabelsDict
|
||||
class PngDrawer:
|
||||
"""Helper class to draw a state graph into a PNG file.
|
||||
|
||||
It requires graphviz and pygraphviz to be installed.
|
||||
It requires `graphviz` and `pygraphviz` to be installed.
|
||||
:param fontname: The font to use for the labels
|
||||
:param labels: A dictionary of label overrides. The dictionary
|
||||
should have the following format:
|
||||
@ -33,7 +33,7 @@ class PngDrawer:
|
||||
"""Initializes the PNG drawer.
|
||||
|
||||
Args:
|
||||
fontname: The font to use for the labels
|
||||
fontname: The font to use for the labels. Defaults to "arial".
|
||||
labels: A dictionary of label overrides. The dictionary
|
||||
should have the following format:
|
||||
{
|
||||
@ -48,6 +48,7 @@ class PngDrawer:
|
||||
}
|
||||
}
|
||||
The keys are the original labels, and the values are the new labels.
|
||||
Defaults to None.
|
||||
"""
|
||||
self.fontname = fontname or "arial"
|
||||
self.labels = labels or LabelsDict(nodes={}, edges={})
|
||||
@ -56,7 +57,7 @@ class PngDrawer:
|
||||
"""Returns the label to use for a node.
|
||||
|
||||
Args:
|
||||
label: The original label
|
||||
label: The original label.
|
||||
|
||||
Returns:
|
||||
The new label.
|
||||
@ -68,7 +69,7 @@ class PngDrawer:
|
||||
"""Returns the label to use for an edge.
|
||||
|
||||
Args:
|
||||
label: The original label
|
||||
label: The original label.
|
||||
|
||||
Returns:
|
||||
The new label.
|
||||
@ -80,8 +81,8 @@ class PngDrawer:
|
||||
"""Adds a node to the graph.
|
||||
|
||||
Args:
|
||||
viz: The graphviz object
|
||||
node: The node to add
|
||||
viz: The graphviz object.
|
||||
node: The node to add.
|
||||
|
||||
Returns:
|
||||
None
|
||||
@ -106,9 +107,9 @@ class PngDrawer:
|
||||
"""Adds an edge to the graph.
|
||||
|
||||
Args:
|
||||
viz: The graphviz object
|
||||
source: The source node
|
||||
target: The target node
|
||||
viz: The graphviz object.
|
||||
source: The source node.
|
||||
target: The target node.
|
||||
label: The label for the edge. Defaults to None.
|
||||
conditional: Whether the edge is conditional. Defaults to False.
|
||||
|
||||
@ -127,7 +128,7 @@ class PngDrawer:
|
||||
def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
|
||||
"""Draw the given state graph into a PNG file.
|
||||
|
||||
Requires graphviz and pygraphviz to be installed.
|
||||
Requires `graphviz` and `pygraphviz` to be installed.
|
||||
:param graph: The graph to draw
|
||||
:param output_path: The path to save the PNG. If None, PNG bytes are returned.
|
||||
"""
|
||||
@ -156,14 +157,32 @@ class PngDrawer:
|
||||
viz.close()
|
||||
|
||||
def add_nodes(self, viz: Any, graph: Graph) -> None:
|
||||
"""Add nodes to the graph.
|
||||
|
||||
Args:
|
||||
viz: The graphviz object.
|
||||
graph: The graph to draw.
|
||||
"""
|
||||
for node in graph.nodes:
|
||||
self.add_node(viz, node)
|
||||
|
||||
def add_edges(self, viz: Any, graph: Graph) -> None:
|
||||
"""Add edges to the graph.
|
||||
|
||||
Args:
|
||||
viz: The graphviz object.
|
||||
graph: The graph to draw.
|
||||
"""
|
||||
for start, end, data, cond in graph.edges:
|
||||
self.add_edge(viz, start, end, str(data), cond)
|
||||
|
||||
def update_styles(self, viz: Any, graph: Graph) -> None:
|
||||
"""Update the styles of the entrypoint and END nodes.
|
||||
|
||||
Args:
|
||||
viz: The graphviz object.
|
||||
graph: The graph to draw.
|
||||
"""
|
||||
if first := graph.first_node():
|
||||
viz.get_node(first.id).attr.update(fillcolor="lightblue")
|
||||
if last := graph.last_node():
|
||||
|
@ -45,13 +45,13 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
history for it; it is responsible for reading and updating the chat message
|
||||
history.
|
||||
|
||||
The formats supports for the inputs and outputs of the wrapped Runnable
|
||||
The formats supported for the inputs and outputs of the wrapped Runnable
|
||||
are described below.
|
||||
|
||||
RunnableWithMessageHistory must always be called with a config that contains
|
||||
the appropriate parameters for the chat message history factory.
|
||||
|
||||
By default the Runnable is expected to take a single configuration parameter
|
||||
By default, the Runnable is expected to take a single configuration parameter
|
||||
called `session_id` which is a string. This parameter is used to create a new
|
||||
or look up an existing chat message history that matches the given session_id.
|
||||
|
||||
@ -70,6 +70,19 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
For production use cases, you will want to use a persistent implementation
|
||||
of chat message history, such as ``RedisChatMessageHistory``.
|
||||
|
||||
Parameters:
|
||||
get_session_history: Function that returns a new BaseChatMessageHistory.
|
||||
This function should either take a single positional argument
|
||||
`session_id` of type string and return a corresponding
|
||||
chat message history instance.
|
||||
input_messages_key: Must be specified if the base runnable accepts a dict
|
||||
as input. The key in the input dict that contains the messages.
|
||||
output_messages_key: Must be specified if the base Runnable returns a dict
|
||||
as output. The key in the output dict that contains the messages.
|
||||
history_messages_key: Must be specified if the base runnable accepts a dict
|
||||
as input and expects a separate key for historical messages.
|
||||
history_factory_config: Configure fields that should be passed to the
|
||||
chat history factory. See ``ConfigurableFieldSpec`` for more details.
|
||||
|
||||
Example: Chat message history with an in-memory implementation for testing.
|
||||
|
||||
@ -287,9 +300,9 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
...
|
||||
|
||||
input_messages_key: Must be specified if the base runnable accepts a dict
|
||||
as input.
|
||||
as input. Default is None.
|
||||
output_messages_key: Must be specified if the base runnable returns a dict
|
||||
as output.
|
||||
as output. Default is None.
|
||||
history_messages_key: Must be specified if the base runnable accepts a dict
|
||||
as input and expects a separate key for historical messages.
|
||||
history_factory_config: Configure fields that should be passed to the
|
||||
@ -347,6 +360,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
"""Get the configuration specs for the RunnableWithMessageHistory."""
|
||||
return get_unique_config_specs(
|
||||
super().config_specs + list(self.history_factory_config)
|
||||
)
|
||||
|
@ -53,19 +53,33 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def identity(x: Other) -> Other:
|
||||
"""Identity function"""
|
||||
"""Identity function.
|
||||
|
||||
Args:
|
||||
x (Other): input.
|
||||
|
||||
Returns:
|
||||
Other: output.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
async def aidentity(x: Other) -> Other:
|
||||
"""Async identity function"""
|
||||
"""Async identity function.
|
||||
|
||||
Args:
|
||||
x (Other): input.
|
||||
|
||||
Returns:
|
||||
Other: output.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
"""Runnable to passthrough inputs unchanged or with additional keys.
|
||||
|
||||
This runnable behaves almost like the identity function, except that it
|
||||
This Runnable behaves almost like the identity function, except that it
|
||||
can be configured to add additional keys to the output, if the input is a
|
||||
dict.
|
||||
|
||||
@ -73,6 +87,13 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
chains. The chains rely on simple lambdas to make the examples easy to execute
|
||||
and experiment with.
|
||||
|
||||
Parameters:
|
||||
func (Callable[[Other], None], optional): Function to be called with the input.
|
||||
afunc (Callable[[Other], Awaitable[None]], optional): Async function to
|
||||
be called with the input.
|
||||
input_type (Optional[Type[Other]], optional): Type of the input.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
@ -199,10 +220,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
"""Merge the Dict input with the output produced by the mapping argument.
|
||||
|
||||
Args:
|
||||
mapping: A mapping from keys to runnables or callables.
|
||||
**kwargs: Runnable, Callable or a Mapping from keys to Runnables
|
||||
or Callables.
|
||||
|
||||
Returns:
|
||||
A runnable that merges the Dict input with the output produced by the
|
||||
A Runnable that merges the Dict input with the output produced by the
|
||||
mapping argument.
|
||||
"""
|
||||
return RunnableAssign(RunnableParallel(kwargs))
|
||||
@ -336,6 +358,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
these with the original data, introducing new key-value pairs based
|
||||
on the mapper's logic.
|
||||
|
||||
Parameters:
|
||||
mapper (RunnableParallel[Dict[str, Any]]): A `RunnableParallel` instance
|
||||
that will be used to transform the input dictionary.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
@ -627,11 +653,15 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
"""Runnable that picks keys from Dict[str, Any] inputs.
|
||||
|
||||
RunnablePick class represents a runnable that selectively picks keys from a
|
||||
RunnablePick class represents a Runnable that selectively picks keys from a
|
||||
dictionary input. It allows you to specify one or more keys to extract
|
||||
from the input dictionary. It returns a new dictionary containing only
|
||||
the selected keys.
|
||||
|
||||
Parameters:
|
||||
keys (Union[str, List[str]]): A single key or a list of keys to pick from
|
||||
the input dictionary.
|
||||
|
||||
Example :
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -112,7 +112,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
"""Whether to add jitter to the exponential backoff."""
|
||||
|
||||
max_attempt_number: int = 3
|
||||
"""The maximum number of attempts to retry the runnable."""
|
||||
"""The maximum number of attempts to retry the Runnable."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
|
@ -38,7 +38,7 @@ class RouterInput(TypedDict):
|
||||
|
||||
Attributes:
|
||||
key: The key to route on.
|
||||
input: The input to pass to the selected runnable.
|
||||
input: The input to pass to the selected Runnable.
|
||||
"""
|
||||
|
||||
key: str
|
||||
@ -50,6 +50,9 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
Runnable that routes to a set of Runnables based on Input['key'].
|
||||
Returns the output of the selected Runnable.
|
||||
|
||||
Parameters:
|
||||
runnables: A mapping of keys to Runnables.
|
||||
|
||||
For example,
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""Module contains typedefs that are used with runnables."""
|
||||
"""Module contains typedefs that are used with Runnables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -11,7 +11,7 @@ class EventData(TypedDict, total=False):
|
||||
"""Data associated with a streaming event."""
|
||||
|
||||
input: Any
|
||||
"""The input passed to the runnable that generated the event.
|
||||
"""The input passed to the Runnable that generated the event.
|
||||
|
||||
Inputs will sometimes be available at the *START* of the Runnable, and
|
||||
sometimes at the *END* of the Runnable.
|
||||
@ -85,40 +85,43 @@ class BaseStreamEvent(TypedDict):
|
||||
event: str
|
||||
"""Event names are of the format: on_[runnable_type]_(start|stream|end).
|
||||
|
||||
Runnable types are one of:
|
||||
* llm - used by non chat models
|
||||
* chat_model - used by chat models
|
||||
* prompt -- e.g., ChatPromptTemplate
|
||||
* tool -- from tools defined via @tool decorator or inheriting from Tool/BaseTool
|
||||
* chain - most Runnables are of this type
|
||||
Runnable types are one of:
|
||||
|
||||
- **llm** - used by non chat models
|
||||
- **chat_model** - used by chat models
|
||||
- **prompt** -- e.g., ChatPromptTemplate
|
||||
- **tool** -- from tools defined via @tool decorator or inheriting
|
||||
from Tool/BaseTool
|
||||
- **chain** - most Runnables are of this type
|
||||
|
||||
Further, the events are categorized as one of:
|
||||
* start - when the runnable starts
|
||||
* stream - when the runnable is streaming
|
||||
* end - when the runnable ends
|
||||
|
||||
- **start** - when the Runnable starts
|
||||
- **stream** - when the Runnable is streaming
|
||||
- **end* - when the Runnable ends
|
||||
|
||||
start, stream and end are associated with slightly different `data` payload.
|
||||
|
||||
Please see the documentation for `EventData` for more details.
|
||||
"""
|
||||
run_id: str
|
||||
"""An randomly generated ID to keep track of the execution of the given runnable.
|
||||
"""An randomly generated ID to keep track of the execution of the given Runnable.
|
||||
|
||||
Each child runnable that gets invoked as part of the execution of a parent runnable
|
||||
Each child Runnable that gets invoked as part of the execution of a parent Runnable
|
||||
is assigned its own unique ID.
|
||||
"""
|
||||
tags: NotRequired[List[str]]
|
||||
"""Tags associated with the runnable that generated this event.
|
||||
"""Tags associated with the Runnable that generated this event.
|
||||
|
||||
Tags are always inherited from parent runnables.
|
||||
Tags are always inherited from parent Runnables.
|
||||
|
||||
Tags can either be bound to a runnable using `.with_config({"tags": ["hello"]})`
|
||||
Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})`
|
||||
or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
|
||||
"""
|
||||
metadata: NotRequired[Dict[str, Any]]
|
||||
"""Metadata associated with the runnable that generated this event.
|
||||
"""Metadata associated with the Runnable that generated this event.
|
||||
|
||||
Metadata can either be bound to a runnable using
|
||||
Metadata can either be bound to a Runnable using
|
||||
|
||||
`.with_config({"metadata": { "foo": "bar" }})`
|
||||
|
||||
@ -150,21 +153,20 @@ class StandardStreamEvent(BaseStreamEvent):
|
||||
The contents of the event data depend on the event type.
|
||||
"""
|
||||
name: str
|
||||
"""The name of the runnable that generated the event."""
|
||||
"""The name of the Runnable that generated the event."""
|
||||
|
||||
|
||||
class CustomStreamEvent(BaseStreamEvent):
|
||||
"""A custom stream event created by the user.
|
||||
"""Custom stream event created by the user.
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
|
||||
# Overwrite the event field to be more specific.
|
||||
event: Literal["on_custom_event"] # type: ignore[misc]
|
||||
|
||||
"""The event type."""
|
||||
name: str
|
||||
"""A user defined name for the event."""
|
||||
"""User defined name for the event."""
|
||||
data: Any
|
||||
"""The data associated with the event. Free form and can be anything."""
|
||||
|
||||
|
@ -43,6 +43,7 @@ Output = TypeVar("Output", covariant=True)
|
||||
|
||||
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||
"""Run a coroutine with a semaphore.
|
||||
|
||||
Args:
|
||||
semaphore: The semaphore to use.
|
||||
coro: The coroutine to run.
|
||||
@ -59,7 +60,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
|
||||
|
||||
Args:
|
||||
n: The number of coroutines to run concurrently.
|
||||
coros: The coroutines to run.
|
||||
*coros: The coroutines to run.
|
||||
|
||||
Returns:
|
||||
The results of the coroutines.
|
||||
@ -73,7 +74,14 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
|
||||
|
||||
|
||||
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
||||
"""Check if a callable accepts a run_manager argument."""
|
||||
"""Check if a callable accepts a run_manager argument.
|
||||
|
||||
Args:
|
||||
callable: The callable to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the callable accepts a run_manager argument, False otherwise.
|
||||
"""
|
||||
try:
|
||||
return signature(callable).parameters.get("run_manager") is not None
|
||||
except ValueError:
|
||||
@ -81,7 +89,14 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
||||
|
||||
|
||||
def accepts_config(callable: Callable[..., Any]) -> bool:
|
||||
"""Check if a callable accepts a config argument."""
|
||||
"""Check if a callable accepts a config argument.
|
||||
|
||||
Args:
|
||||
callable: The callable to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the callable accepts a config argument, False otherwise.
|
||||
"""
|
||||
try:
|
||||
return signature(callable).parameters.get("config") is not None
|
||||
except ValueError:
|
||||
@ -89,7 +104,14 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
|
||||
|
||||
|
||||
def accepts_context(callable: Callable[..., Any]) -> bool:
|
||||
"""Check if a callable accepts a context argument."""
|
||||
"""Check if a callable accepts a context argument.
|
||||
|
||||
Args:
|
||||
callable: The callable to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the callable accepts a context argument, False otherwise.
|
||||
"""
|
||||
try:
|
||||
return signature(callable).parameters.get("context") is not None
|
||||
except ValueError:
|
||||
@ -100,10 +122,24 @@ class IsLocalDict(ast.NodeVisitor):
|
||||
"""Check if a name is a local dict."""
|
||||
|
||||
def __init__(self, name: str, keys: Set[str]) -> None:
|
||||
"""Initialize the visitor.
|
||||
|
||||
Args:
|
||||
name: The name to check.
|
||||
keys: The keys to populate.
|
||||
"""
|
||||
self.name = name
|
||||
self.keys = keys
|
||||
|
||||
def visit_Subscript(self, node: ast.Subscript) -> Any:
|
||||
"""Visit a subscript node.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
if (
|
||||
isinstance(node.ctx, ast.Load)
|
||||
and isinstance(node.value, ast.Name)
|
||||
@ -115,6 +151,14 @@ class IsLocalDict(ast.NodeVisitor):
|
||||
self.keys.add(node.slice.value)
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> Any:
|
||||
"""Visit a call node.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
if (
|
||||
isinstance(node.func, ast.Attribute)
|
||||
and isinstance(node.func.value, ast.Name)
|
||||
@ -135,18 +179,42 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
||||
self.keys: Set[str] = set()
|
||||
|
||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||
"""Visit a lambda function.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
if not node.args.args:
|
||||
return
|
||||
input_arg_name = node.args.args[0].arg
|
||||
IsLocalDict(input_arg_name, self.keys).visit(node.body)
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||
"""Visit a function definition.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
if not node.args.args:
|
||||
return
|
||||
input_arg_name = node.args.args[0].arg
|
||||
IsLocalDict(input_arg_name, self.keys).visit(node)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
||||
"""Visit an async function definition.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
if not node.args.args:
|
||||
return
|
||||
input_arg_name = node.args.args[0].arg
|
||||
@ -161,12 +229,28 @@ class NonLocals(ast.NodeVisitor):
|
||||
self.stores: Set[str] = set()
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> Any:
|
||||
"""Visit a name node.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
self.loads.add(node.id)
|
||||
elif isinstance(node.ctx, ast.Store):
|
||||
self.stores.add(node.id)
|
||||
|
||||
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
||||
"""Visit an attribute node.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
parent = node.value
|
||||
attr_expr = node.attr
|
||||
@ -185,16 +269,40 @@ class FunctionNonLocals(ast.NodeVisitor):
|
||||
self.nonlocals: Set[str] = set()
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||
"""Visit a function definition.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
visitor = NonLocals()
|
||||
visitor.visit(node)
|
||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
||||
"""Visit an async function definition.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
visitor = NonLocals()
|
||||
visitor.visit(node)
|
||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||
|
||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||
"""Visit a lambda function.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
visitor = NonLocals()
|
||||
visitor.visit(node)
|
||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||
@ -209,14 +317,29 @@ class GetLambdaSource(ast.NodeVisitor):
|
||||
self.count = 0
|
||||
|
||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||
"""Visit a lambda function."""
|
||||
"""Visit a lambda function.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
self.count += 1
|
||||
if hasattr(ast, "unparse"):
|
||||
self.source = ast.unparse(node)
|
||||
|
||||
|
||||
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
"""Get the keys of the first argument of a function if it is a dict."""
|
||||
"""Get the keys of the first argument of a function if it is a dict.
|
||||
|
||||
Args:
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
Optional[List[str]]: The keys of the first argument if it is a dict,
|
||||
None otherwise.
|
||||
"""
|
||||
try:
|
||||
code = inspect.getsource(func)
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
@ -231,10 +354,10 @@ def get_lambda_source(func: Callable) -> Optional[str]:
|
||||
"""Get the source code of a lambda function.
|
||||
|
||||
Args:
|
||||
func: a callable that can be a lambda function
|
||||
func: a Callable that can be a lambda function.
|
||||
|
||||
Returns:
|
||||
str: the source code of the lambda function
|
||||
str: the source code of the lambda function.
|
||||
"""
|
||||
try:
|
||||
name = func.__name__ if func.__name__ != "<lambda>" else None
|
||||
@ -251,7 +374,14 @@ def get_lambda_source(func: Callable) -> Optional[str]:
|
||||
|
||||
|
||||
def get_function_nonlocals(func: Callable) -> List[Any]:
|
||||
"""Get the nonlocal variables accessed by a function."""
|
||||
"""Get the nonlocal variables accessed by a function.
|
||||
|
||||
Args:
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
List[Any]: The nonlocal variables accessed by the function.
|
||||
"""
|
||||
try:
|
||||
code = inspect.getsource(func)
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
@ -283,11 +413,11 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
|
||||
"""Indent all lines of text after the first line.
|
||||
|
||||
Args:
|
||||
text: The text to indent
|
||||
prefix: Used to determine the number of spaces to indent
|
||||
text: The text to indent.
|
||||
prefix: Used to determine the number of spaces to indent.
|
||||
|
||||
Returns:
|
||||
str: The indented text
|
||||
str: The indented text.
|
||||
"""
|
||||
n_spaces = len(prefix)
|
||||
spaces = " " * n_spaces
|
||||
@ -341,7 +471,14 @@ Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
|
||||
|
||||
|
||||
def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
||||
"""Add a sequence of addable objects together."""
|
||||
"""Add a sequence of addable objects together.
|
||||
|
||||
Args:
|
||||
addables: The addable objects to add.
|
||||
|
||||
Returns:
|
||||
Optional[Addable]: The result of adding the addable objects.
|
||||
"""
|
||||
final = None
|
||||
for chunk in addables:
|
||||
if final is None:
|
||||
@ -352,7 +489,14 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
||||
|
||||
|
||||
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
"""Asynchronously add a sequence of addable objects together."""
|
||||
"""Asynchronously add a sequence of addable objects together.
|
||||
|
||||
Args:
|
||||
addables: The addable objects to add.
|
||||
|
||||
Returns:
|
||||
Optional[Addable]: The result of adding the addable objects.
|
||||
"""
|
||||
final = None
|
||||
async for chunk in addables:
|
||||
if final is None:
|
||||
@ -363,7 +507,15 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
|
||||
|
||||
class ConfigurableField(NamedTuple):
|
||||
"""Field that can be configured by the user."""
|
||||
"""Field that can be configured by the user.
|
||||
|
||||
Parameters:
|
||||
id: The unique identifier of the field.
|
||||
name: The name of the field. Defaults to None.
|
||||
description: The description of the field. Defaults to None.
|
||||
annotation: The annotation of the field. Defaults to None.
|
||||
is_shared: Whether the field is shared. Defaults to False.
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
||||
@ -377,7 +529,16 @@ class ConfigurableField(NamedTuple):
|
||||
|
||||
|
||||
class ConfigurableFieldSingleOption(NamedTuple):
|
||||
"""Field that can be configured by the user with a default value."""
|
||||
"""Field that can be configured by the user with a default value.
|
||||
|
||||
Parameters:
|
||||
id: The unique identifier of the field.
|
||||
options: The options for the field.
|
||||
default: The default value for the field.
|
||||
name: The name of the field. Defaults to None.
|
||||
description: The description of the field. Defaults to None.
|
||||
is_shared: Whether the field is shared. Defaults to False.
|
||||
"""
|
||||
|
||||
id: str
|
||||
options: Mapping[str, Any]
|
||||
@ -392,7 +553,16 @@ class ConfigurableFieldSingleOption(NamedTuple):
|
||||
|
||||
|
||||
class ConfigurableFieldMultiOption(NamedTuple):
|
||||
"""Field that can be configured by the user with multiple default values."""
|
||||
"""Field that can be configured by the user with multiple default values.
|
||||
|
||||
Parameters:
|
||||
id: The unique identifier of the field.
|
||||
options: The options for the field.
|
||||
default: The default values for the field.
|
||||
name: The name of the field. Defaults to None.
|
||||
description: The description of the field. Defaults to None.
|
||||
is_shared: Whether the field is shared. Defaults to False.
|
||||
"""
|
||||
|
||||
id: str
|
||||
options: Mapping[str, Any]
|
||||
@ -412,7 +582,17 @@ AnyConfigurableField = Union[
|
||||
|
||||
|
||||
class ConfigurableFieldSpec(NamedTuple):
|
||||
"""Field that can be configured by the user. It is a specification of a field."""
|
||||
"""Field that can be configured by the user. It is a specification of a field.
|
||||
|
||||
Parameters:
|
||||
id: The unique identifier of the field.
|
||||
annotation: The annotation of the field.
|
||||
name: The name of the field. Defaults to None.
|
||||
description: The description of the field. Defaults to None.
|
||||
default: The default value for the field. Defaults to None.
|
||||
is_shared: Whether the field is shared. Defaults to False.
|
||||
dependencies: The dependencies of the field. Defaults to None.
|
||||
"""
|
||||
|
||||
id: str
|
||||
annotation: Any
|
||||
@ -427,7 +607,17 @@ class ConfigurableFieldSpec(NamedTuple):
|
||||
def get_unique_config_specs(
|
||||
specs: Iterable[ConfigurableFieldSpec],
|
||||
) -> List[ConfigurableFieldSpec]:
|
||||
"""Get the unique config specs from a sequence of config specs."""
|
||||
"""Get the unique config specs from a sequence of config specs.
|
||||
|
||||
Args:
|
||||
specs: The config specs.
|
||||
|
||||
Returns:
|
||||
List[ConfigurableFieldSpec]: The unique config specs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the runnable sequence contains conflicting config specs.
|
||||
"""
|
||||
grouped = groupby(
|
||||
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
|
||||
)
|
||||
@ -542,7 +732,15 @@ def _create_model_cached(
|
||||
def is_async_generator(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., AsyncIterator]]:
|
||||
"""Check if a function is an async generator."""
|
||||
"""Check if a function is an async generator.
|
||||
|
||||
Args:
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
TypeGuard[Callable[..., AsyncIterator]: True if the function is
|
||||
an async generator, False otherwise.
|
||||
"""
|
||||
return (
|
||||
inspect.isasyncgenfunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
@ -553,7 +751,15 @@ def is_async_generator(
|
||||
def is_async_callable(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., Awaitable]]:
|
||||
"""Check if a function is async."""
|
||||
"""Check if a function is async.
|
||||
|
||||
Args:
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
TypeGuard[Callable[..., Awaitable]: True if the function is async,
|
||||
False otherwise.
|
||||
"""
|
||||
return (
|
||||
asyncio.iscoroutinefunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
|
Loading…
Reference in New Issue
Block a user