diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index 34f6d1c8af4..07b2e935225 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -35,7 +35,7 @@ from typing_extensions import Literal, get_args
 
 from langchain_core.load.dump import dumpd, dumps
 from langchain_core.load.serializable import Serializable
-from langchain_core.pydantic_v1 import BaseModel, Field, create_model
+from langchain_core.pydantic_v1 import BaseConfig, BaseModel, Field, create_model
 from langchain_core.runnables.config import (
     RunnableConfig,
     acall_func_with_variable_args,
@@ -48,6 +48,7 @@ from langchain_core.runnables.config import (
     merge_configs,
     patch_config,
 )
+from langchain_core.runnables.graph import Graph
 from langchain_core.runnables.utils import (
     AddableDict,
     AnyConfigurableField,
@@ -59,6 +60,7 @@ from langchain_core.runnables.utils import (
     accepts_run_manager,
     gather_with_concurrency,
     get_function_first_arg_dict_keys,
+    get_function_nonlocals,
     get_lambda_source,
     get_unique_config_specs,
     indent_lines_after_first,
@@ -74,7 +76,6 @@ if TYPE_CHECKING:
     from langchain_core.runnables.fallbacks import (
         RunnableWithFallbacks as RunnableWithFallbacksT,
     )
-    from langchain_core.runnables.graph import Graph
     from langchain_core.tracers.log_stream import RunLog, RunLogPatch
     from langchain_core.tracers.root_listeners import Listener
 
@@ -82,6 +83,10 @@ if TYPE_CHECKING:
 Other = TypeVar("Other")
 
 
+class _SchemaConfig(BaseConfig):
+    arbitrary_types_allowed = True
+
+
 class Runnable(Generic[Input, Output], ABC):
     """A unit of work that can be invoked, batched, streamed, transformed and composed.
 
@@ -266,7 +271,9 @@ class Runnable(Generic[Input, Output], ABC):
             return root_type
 
         return create_model(
-            self.__class__.__name__ + "Input", __root__=(root_type, None)
+            self.__class__.__name__ + "Input",
+            __root__=(root_type, None),
+            __config__=_SchemaConfig,
         )
 
     @property
@@ -297,7 +304,9 @@ class Runnable(Generic[Input, Output], ABC):
             return root_type
 
         return create_model(
-            self.__class__.__name__ + "Output", __root__=(root_type, None)
+            self.__class__.__name__ + "Output",
+            __root__=(root_type, None),
+            __config__=_SchemaConfig,
         )
 
     @property
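The schema-building changes in this file all funnel through the new module-level _SchemaConfig. A minimal sketch of why it is needed (not part of the diff; the Token class is a made-up stand-in for an arbitrary runtime type): pydantic v1's create_model refuses to build a field around a plain, non-pydantic class unless the model config sets arbitrary_types_allowed.

from langchain_core.pydantic_v1 import BaseConfig, create_model


class Token:
    """An arbitrary class with no pydantic validators (hypothetical example)."""


class _SchemaConfig(BaseConfig):
    arbitrary_types_allowed = True


# Without __config__ the next call raises:
#   RuntimeError: no validator found for <class 'Token'>, see `arbitrary_types_allowed` in Config
ExampleInput = create_model(
    "ExampleInput", __root__=(Token, None), __config__=_SchemaConfig
)
ExampleInput(__root__=Token())  # validates, because arbitrary types are now allowed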
""" - class _Config: - arbitrary_types_allowed = True - include = include or [] config_specs = self.config_specs configurable = ( @@ -337,6 +343,7 @@ class Runnable(Generic[Input, Output], ABC): ) for spec in config_specs }, + __config__=_SchemaConfig, ) if config_specs else None @@ -344,7 +351,7 @@ class Runnable(Generic[Input, Output], ABC): return create_model( # type: ignore[call-overload] self.__class__.__name__ + "Config", - __config__=_Config, + __config__=_SchemaConfig, **({"configurable": (configurable, None)} if configurable else {}), **{ field_name: (field_type, None) @@ -1405,6 +1412,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): for k, v in next_input_schema.__fields__.items() if k not in first.mapper.steps }, + __config__=_SchemaConfig, ) return self.first.get_input_schema(config) @@ -2006,6 +2014,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): for k, v in step.get_input_schema(config).__fields__.items() if k != "__root__" }, + __config__=_SchemaConfig, ) return super().get_input_schema(config) @@ -2017,6 +2026,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): return create_model( # type: ignore[call-overload] "RunnableParallelOutput", **{k: (v.OutputType, None) for k, v in self.steps.items()}, + __config__=_SchemaConfig, ) @property @@ -2548,9 +2558,14 @@ class RunnableLambda(Runnable[Input, Output]): return create_model( "RunnableLambdaInput", **{item[1:-1]: (Any, None) for item in items}, # type: ignore + __config__=_SchemaConfig, ) else: - return create_model("RunnableLambdaInput", __root__=(List[Any], None)) + return create_model( + "RunnableLambdaInput", + __root__=(List[Any], None), + __config__=_SchemaConfig, + ) if self.InputType != Any: return super().get_input_schema(config) @@ -2559,6 +2574,7 @@ class RunnableLambda(Runnable[Input, Output]): return create_model( "RunnableLambdaInput", **{key: (Any, None) for key in dict_keys}, # type: ignore + __config__=_SchemaConfig, ) return super().get_input_schema(config) @@ -2577,6 +2593,50 @@ class RunnableLambda(Runnable[Input, Output]): except ValueError: return Any + @property + def deps(self) -> List[Runnable]: + """The dependencies of this runnable.""" + if hasattr(self, "func"): + objects = get_function_nonlocals(self.func) + elif hasattr(self, "afunc"): + objects = get_function_nonlocals(self.afunc) + else: + objects = [] + + return [obj for obj in objects if isinstance(obj, Runnable)] + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return get_unique_config_specs( + spec for dep in self.deps for spec in dep.config_specs + ) + + def get_graph(self, config: RunnableConfig | None = None) -> Graph: + if deps := self.deps: + graph = Graph() + input_node = graph.add_node(self.get_input_schema(config)) + output_node = graph.add_node(self.get_output_schema(config)) + for dep in deps: + dep_graph = dep.get_graph() + dep_graph.trim_first_node() + dep_graph.trim_last_node() + if not dep_graph: + graph.add_edge(input_node, output_node) + else: + graph.extend(dep_graph) + dep_first_node = dep_graph.first_node() + if not dep_first_node: + raise ValueError(f"Runnable {dep} has no first node") + dep_last_node = dep_graph.last_node() + if not dep_last_node: + raise ValueError(f"Runnable {dep} has no last node") + graph.add_edge(input_node, dep_first_node) + graph.add_edge(dep_last_node, output_node) + else: + graph = super().get_graph(config) + + return graph + def __eq__(self, other: Any) -> bool: if isinstance(other, RunnableLambda): if 
hasattr(self, "func") and hasattr(other, "func"): @@ -2740,6 +2800,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): List[self.bound.get_input_schema(config)], # type: ignore None, ), + __config__=_SchemaConfig, ) @property @@ -2756,12 +2817,16 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): List[schema], # type: ignore None, ), + __config__=_SchemaConfig, ) @property def config_specs(self) -> List[ConfigurableFieldSpec]: return self.bound.config_specs + def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: + return self.bound.get_graph(config) + @classmethod def is_lc_serializable(cls) -> bool: return True @@ -2973,6 +3038,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): def config_specs(self) -> List[ConfigurableFieldSpec]: return self.bound.config_specs + def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: + return self.bound.get_graph(config) + @classmethod def is_lc_serializable(cls) -> bool: return True diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index d4a55946fba..f7ad523ca55 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -26,6 +26,7 @@ from langchain_core.runnables.config import ( get_config_list, get_executor_for_config, ) +from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( AnyConfigurableField, ConfigurableField, @@ -76,6 +77,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): runnable, config = self._prepare(config) return runnable.get_output_schema(config) + def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: + runnable, config = self._prepare(config) + return runnable.get_graph(config) + @abstractmethod def _prepare( self, config: Optional[RunnableConfig] = None diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 5de27104cb6..211b44b1f1e 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -1,13 +1,15 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Dict, List, NamedTuple, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union from uuid import uuid4 from langchain_core.pydantic_v1 import BaseModel -from langchain_core.runnables.base import Runnable from langchain_core.runnables.graph_draw import draw +if TYPE_CHECKING: + from langchain_core.runnables.base import Runnable as RunnableType + class Edge(NamedTuple): source: str @@ -16,7 +18,7 @@ class Edge(NamedTuple): class Node(NamedTuple): id: str - data: Union[Type[BaseModel], Runnable] + data: Union[Type[BaseModel], RunnableType] @dataclass @@ -30,7 +32,7 @@ class Graph: def next_id(self) -> str: return uuid4().hex - def add_node(self, data: Union[Type[BaseModel], Runnable]) -> Node: + def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node: """Add a node to the graph and return it.""" node = Node(id=self.next_id(), data=data) self.nodes[node.id] = node @@ -109,6 +111,8 @@ class Graph: self.remove_node(last_node) def draw_ascii(self) -> str: + from langchain_core.runnables.base import Runnable + def node_data(node: Node) -> str: if isinstance(node.data, Runnable): try: diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py 
diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py
index 90401df9d7f..7b523e4e0f8 100644
--- a/libs/core/langchain_core/runnables/utils.py
+++ b/libs/core/langchain_core/runnables/utils.py
@@ -119,6 +119,53 @@ class IsFunctionArgDict(ast.NodeVisitor):
         IsLocalDict(input_arg_name, self.keys).visit(node)
 
 
+class NonLocals(ast.NodeVisitor):
+    """Get nonlocal variables accessed."""
+
+    def __init__(self) -> None:
+        self.loads: Set[str] = set()
+        self.stores: Set[str] = set()
+
+    def visit_Name(self, node: ast.Name) -> Any:
+        if isinstance(node.ctx, ast.Load):
+            self.loads.add(node.id)
+        elif isinstance(node.ctx, ast.Store):
+            self.stores.add(node.id)
+
+    def visit_Attribute(self, node: ast.Attribute) -> Any:
+        if isinstance(node.ctx, ast.Load):
+            parent = node.value
+            attr_expr = node.attr
+            while isinstance(parent, ast.Attribute):
+                attr_expr = parent.attr + "." + attr_expr
+                parent = parent.value
+            if isinstance(parent, ast.Name):
+                self.loads.add(parent.id + "." + attr_expr)
+                self.loads.discard(parent.id)
+
+
+class FunctionNonLocals(ast.NodeVisitor):
+    """Get the nonlocal variables accessed of a function."""
+
+    def __init__(self) -> None:
+        self.nonlocals: Set[str] = set()
+
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
+        visitor = NonLocals()
+        visitor.visit(node)
+        self.nonlocals.update(visitor.loads - visitor.stores)
+
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
+        visitor = NonLocals()
+        visitor.visit(node)
+        self.nonlocals.update(visitor.loads - visitor.stores)
+
+    def visit_Lambda(self, node: ast.Lambda) -> Any:
+        visitor = NonLocals()
+        visitor.visit(node)
+        self.nonlocals.update(visitor.loads - visitor.stores)
+
+
 class GetLambdaSource(ast.NodeVisitor):
     """Get the source code of a lambda function."""
 
@@ -169,6 +216,28 @@ def get_lambda_source(func: Callable) -> Optional[str]:
     return name
 
 
+def get_function_nonlocals(func: Callable) -> List[Any]:
+    """Get the nonlocal variables accessed by a function."""
+    try:
+        code = inspect.getsource(func)
+        tree = ast.parse(textwrap.dedent(code))
+        visitor = FunctionNonLocals()
+        visitor.visit(tree)
+        values: List[Any] = []
+        for k, v in inspect.getclosurevars(func).nonlocals.items():
+            if k in visitor.nonlocals:
+                values.append(v)
+            for kk in visitor.nonlocals:
+                if "." in kk and kk.startswith(k):
+                    vv = v
+                    for part in kk.split(".")[1:]:
+                        vv = getattr(vv, part)
+                    values.append(vv)
+        return values
+    except (SyntaxError, TypeError, OSError):
+        return []
+
+
 def indent_lines_after_first(text: str, prefix: str) -> str:
     """Indent all lines of text after the first line.
 
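A quick illustration of what the new get_function_nonlocals helper returns (plain example, not from the diff; run it from a file, since inspect.getsource cannot read source in a bare REPL and the helper then falls back to an empty list): it parses the function's source with ast, keeps names and dotted attribute paths that are read but never assigned, and resolves them against the function's closure.

from langchain_core.runnables.utils import get_function_nonlocals


def outer():
    class Holder:
        value = 42

    holder = Holder()
    threshold = 10

    def inner(x: int) -> int:
        # `threshold` and `holder.value` are closed-over reads; `y` is a local store.
        y = x + threshold
        return y + holder.value

    return inner


print(sorted(get_function_nonlocals(outer())))
# [10, 42] -- the values of `threshold` and `holder.value`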
diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
index c3674d22a84..f241ea429fd 100644
--- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
+++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
@@ -32,39 +32,51 @@
 # ---
 # name: test_graph_sequence_map
   '''
-            +-------------+
-            | PromptInput |
-            +-------------+
-                   *
-                   *
-                   *
-           +----------------+
-           | PromptTemplate |
-           +----------------+
-                   *
-                   *
-                   *
-            +-------------+
-            | FakeListLLM |
-            +-------------+
-                   *
-                   *
-                   *
-        +-----------------------+
-        | RunnableParallelInput |
-        +-----------------------+
-              ****        ****
-           ****              ****
-         **                      **
-    +---------------------+  +--------------------------------+
-    | RunnablePassthrough |  | CommaSeparatedListOutputParser |
-    +---------------------+  +--------------------------------+
-              ****        ****
-                 ****  ****
-                    ****
-        +------------------------+
-        | RunnableParallelOutput |
-        +------------------------+
+              +-------------+
+              | PromptInput |
+              +-------------+
+                     *
+                     *
+                     *
+             +----------------+
+             | PromptTemplate |
+             +----------------+
+                     *
+                     *
+                     *
+              +-------------+
+              | FakeListLLM |
+              +-------------+
+                     *
+                     *
+                     *
+          +-----------------------+
+          | RunnableParallelInput |
+          +-----------------------+**
+              ****        *******
+           ****                  *****
+         **                          *******
+    +---------------------+                ***
+    | RunnableLambdaInput |                  *
+    +---------------------+                  *
+          ***      ***                       *
+        ***           ***                    *
+       **                **                  *
+    +-----------------+ +-----------------+  *
+    | StrOutputParser | | XMLOutputParser |  *
+    +-----------------+ +-----------------+  *
+          ***          ***                   *
+             ***    ***                     *
+                ** **                      *
+    +----------------------+  +--------------------------------+
+    | RunnableLambdaOutput |  | CommaSeparatedListOutputParser |
+    +----------------------+  +--------------------------------+
+               ****           *******
+                  ****    *****
+                     ** ***
+          +------------------------+
+          | RunnableParallelOutput |
+          +------------------------+
   '''
 # ---
 # name: test_graph_single_runnable
diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py
index c0d763ad140..7d84bc2634e 100644
--- a/libs/core/tests/unit_tests/runnables/test_graph.py
+++ b/libs/core/tests/unit_tests/runnables/test_graph.py
@@ -2,9 +2,9 @@ from syrupy import SnapshotAssertion
 
 from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
 from langchain_core.output_parsers.string import StrOutputParser
+from langchain_core.output_parsers.xml import XMLOutputParser
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.runnables.base import Runnable
-from langchain_core.runnables.passthrough import RunnablePassthrough
 from tests.unit_tests.fake.llm import FakeListLLM
 
 
@@ -38,13 +38,21 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
     fake_llm = FakeListLLM(responses=["a"])
     prompt = PromptTemplate.from_template("Hello, {name}!")
     list_parser = CommaSeparatedListOutputParser()
+    str_parser = StrOutputParser()
+    xml_parser = XMLOutputParser()
+
+    def conditional_str_parser(input: str) -> Runnable:
+        if input == "a":
+            return str_parser
+        else:
+            return xml_parser
 
     sequence: Runnable = (
         prompt
         | fake_llm
         | {
-            "original": RunnablePassthrough(input_type=str),
             "as_list": list_parser,
+            "as_str": conditional_str_parser,
         }
     )
     graph = sequence.get_graph()
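A self-contained variant of the updated test (assumes this branch; the test-suite FakeListLLM is replaced by a plain RunnableLambda stand-in, and the ASCII drawing may need an optional dependency depending on the version): the conditional parser becomes a RunnableLambda whose parser dependencies show up in the rendered graph, much like the snapshot above.

from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable, RunnableLambda


def build_sequence() -> Runnable:
    str_parser = StrOutputParser()
    xml_parser = XMLOutputParser()

    def conditional_str_parser(input: str) -> Runnable:
        if input == "a":
            return str_parser
        else:
            return xml_parser

    prompt = PromptTemplate.from_template("Hello, {name}!")
    fake_llm = RunnableLambda(lambda _: "a")  # stand-in for the test's FakeListLLM
    return prompt | fake_llm | {
        "as_list": CommaSeparatedListOutputParser(),
        "as_str": conditional_str_parser,
    }


graph = build_sequence().get_graph()
print(graph.draw_ascii())  # renders a drawing comparable to the snapshot diff above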