mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
Nc/dec22/runnable graph lambda (#15078)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
59d4b80a92
commit
0d0901ea18
@ -35,7 +35,7 @@ from typing_extensions import Literal, get_args
|
||||
|
||||
from langchain_core.load.dump import dumpd, dumps
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.pydantic_v1 import BaseConfig, BaseModel, Field, create_model
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
acall_func_with_variable_args,
|
||||
@ -48,6 +48,7 @@ from langchain_core.runnables.config import (
|
||||
merge_configs,
|
||||
patch_config,
|
||||
)
|
||||
from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.runnables.utils import (
|
||||
AddableDict,
|
||||
AnyConfigurableField,
|
||||
@ -59,6 +60,7 @@ from langchain_core.runnables.utils import (
|
||||
accepts_run_manager,
|
||||
gather_with_concurrency,
|
||||
get_function_first_arg_dict_keys,
|
||||
get_function_nonlocals,
|
||||
get_lambda_source,
|
||||
get_unique_config_specs,
|
||||
indent_lines_after_first,
|
||||
@ -74,7 +76,6 @@ if TYPE_CHECKING:
|
||||
from langchain_core.runnables.fallbacks import (
|
||||
RunnableWithFallbacks as RunnableWithFallbacksT,
|
||||
)
|
||||
from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.tracers.log_stream import RunLog, RunLogPatch
|
||||
from langchain_core.tracers.root_listeners import Listener
|
||||
|
||||
@ -82,6 +83,10 @@ if TYPE_CHECKING:
|
||||
Other = TypeVar("Other")
|
||||
|
||||
|
||||
class _SchemaConfig(BaseConfig):
    """Shared pydantic config for the schema models built with ``create_model``.

    It is passed as ``__config__`` wherever this module creates input, output
    or config schemas dynamically.
    """

    # Schema fields may be annotated with types pydantic has no validator for.
    arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Runnable(Generic[Input, Output], ABC):
|
||||
"""A unit of work that can be invoked, batched, streamed, transformed and composed.
|
||||
|
||||
@ -266,7 +271,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.__class__.__name__ + "Input", __root__=(root_type, None)
|
||||
self.__class__.__name__ + "Input",
|
||||
__root__=(root_type, None),
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -297,7 +304,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
||||
self.__class__.__name__ + "Output",
|
||||
__root__=(root_type, None),
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -320,9 +329,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
A pydantic model that can be used to validate config.
|
||||
"""
|
||||
|
||||
class _Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
include = include or []
|
||||
config_specs = self.config_specs
|
||||
configurable = (
|
||||
@ -337,6 +343,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
for spec in config_specs
|
||||
},
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
if config_specs
|
||||
else None
|
||||
@ -344,7 +351,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.__class__.__name__ + "Config",
|
||||
__config__=_Config,
|
||||
__config__=_SchemaConfig,
|
||||
**({"configurable": (configurable, None)} if configurable else {}),
|
||||
**{
|
||||
field_name: (field_type, None)
|
||||
@ -1405,6 +1412,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
for k, v in next_input_schema.__fields__.items()
|
||||
if k not in first.mapper.steps
|
||||
},
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
return self.first.get_input_schema(config)
|
||||
@ -2006,6 +2014,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
for k, v in step.get_input_schema(config).__fields__.items()
|
||||
if k != "__root__"
|
||||
},
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
return super().get_input_schema(config)
|
||||
@ -2017,6 +2026,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableParallelOutput",
|
||||
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -2548,9 +2558,14 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return create_model(
|
||||
"RunnableLambdaInput",
|
||||
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
else:
|
||||
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
|
||||
return create_model(
|
||||
"RunnableLambdaInput",
|
||||
__root__=(List[Any], None),
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
if self.InputType != Any:
|
||||
return super().get_input_schema(config)
|
||||
@ -2559,6 +2574,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return create_model(
|
||||
"RunnableLambdaInput",
|
||||
**{key: (Any, None) for key in dict_keys}, # type: ignore
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
return super().get_input_schema(config)
|
||||
@ -2577,6 +2593,50 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
except ValueError:
|
||||
return Any
|
||||
|
||||
@property
def deps(self) -> List[Runnable]:
    """Runnables referenced from the closure of the wrapped callable.

    The sync callable (``func``) is inspected when present, otherwise the
    async one (``afunc``). Only values that are ``Runnable`` instances are
    kept; when neither attribute exists the result is an empty list.
    """
    if hasattr(self, "func"):
        wrapped = self.func
    elif hasattr(self, "afunc"):
        wrapped = self.afunc
    else:
        return []

    return [
        candidate
        for candidate in get_function_nonlocals(wrapped)
        if isinstance(candidate, Runnable)
    ]
|
||||
|
||||
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
    """Config specs of every dependency in ``self.deps``, de-duplicated."""
    dependency_specs = (
        field_spec
        for dependency in self.deps
        for field_spec in dependency.config_specs
    )
    return get_unique_config_specs(dependency_specs)
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
    """Build the graph for this lambda, inlining the runnables it references.

    When the wrapped callable references no runnables (``self.deps`` is
    empty) the default graph from the parent class is returned. Otherwise a
    new graph is built with this runnable's input and output schema nodes,
    and the graph of every dependency is spliced between them.

    Args:
        config: Optional config. It is used for this runnable's schemas and
            forwarded to each dependency, so that dependencies whose graph
            depends on the config are drawn the way they would actually run.

    Returns:
        The graph describing this runnable.

    Raises:
        ValueError: If a non-empty dependency graph has no first or no last
            node after its own input/output nodes were trimmed.
    """
    deps = self.deps
    if not deps:
        return super().get_graph(config)

    graph = Graph()
    input_node = graph.add_node(self.get_input_schema(config))
    output_node = graph.add_node(self.get_output_schema(config))

    for dep in deps:
        # Forward the config, like every other get_graph override does.
        dep_graph = dep.get_graph(config)
        # The dependency's own schema nodes are replaced by ours.
        dep_graph.trim_first_node()
        dep_graph.trim_last_node()

        if not dep_graph:
            # Nothing left after trimming: connect input straight to output.
            graph.add_edge(input_node, output_node)
            continue

        graph.extend(dep_graph)
        dep_first_node = dep_graph.first_node()
        if not dep_first_node:
            raise ValueError(f"Runnable {dep} has no first node")
        dep_last_node = dep_graph.last_node()
        if not dep_last_node:
            raise ValueError(f"Runnable {dep} has no last node")
        graph.add_edge(input_node, dep_first_node)
        graph.add_edge(dep_last_node, output_node)

    return graph
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, RunnableLambda):
|
||||
if hasattr(self, "func") and hasattr(other, "func"):
|
||||
@ -2740,6 +2800,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
List[self.bound.get_input_schema(config)], # type: ignore
|
||||
None,
|
||||
),
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -2756,12 +2817,16 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
List[schema], # type: ignore
|
||||
None,
|
||||
),
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
    """Return the graph of the bound runnable, built with the same config."""
    inner = self.bound
    return inner.get_graph(config)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
@ -2973,6 +3038,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
    """Delegate graph construction to the bound runnable."""
    delegate = self.bound
    return delegate.get_graph(config)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
@ -26,6 +26,7 @@ from langchain_core.runnables.config import (
|
||||
get_config_list,
|
||||
get_executor_for_config,
|
||||
)
|
||||
from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.runnables.utils import (
|
||||
AnyConfigurableField,
|
||||
ConfigurableField,
|
||||
@ -76,6 +77,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
runnable, config = self._prepare(config)
|
||||
return runnable.get_output_schema(config)
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
    """Return the graph of the runnable selected for ``config``.

    ``_prepare`` yields the runnable to use together with the config to hand
    to it; the graph is taken from that runnable.
    """
    prepared = self._prepare(config)
    selected_runnable, selected_config = prepared
    return selected_runnable.get_graph(selected_config)
|
||||
|
||||
@abstractmethod
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
|
@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, NamedTuple, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from langchain_core.runnables.graph_draw import draw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables.base import Runnable as RunnableType
|
||||
|
||||
|
||||
class Edge(NamedTuple):
|
||||
source: str
|
||||
@ -16,7 +18,7 @@ class Edge(NamedTuple):
|
||||
|
||||
class Node(NamedTuple):
|
||||
id: str
|
||||
data: Union[Type[BaseModel], Runnable]
|
||||
data: Union[Type[BaseModel], RunnableType]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -30,7 +32,7 @@ class Graph:
|
||||
def next_id(self) -> str:
|
||||
return uuid4().hex
|
||||
|
||||
def add_node(self, data: Union[Type[BaseModel], Runnable]) -> Node:
|
||||
def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node:
|
||||
"""Add a node to the graph and return it."""
|
||||
node = Node(id=self.next_id(), data=data)
|
||||
self.nodes[node.id] = node
|
||||
@ -109,6 +111,8 @@ class Graph:
|
||||
self.remove_node(last_node)
|
||||
|
||||
def draw_ascii(self) -> str:
|
||||
from langchain_core.runnables.base import Runnable
|
||||
|
||||
def node_data(node: Node) -> str:
|
||||
if isinstance(node.data, Runnable):
|
||||
try:
|
||||
|
@ -119,6 +119,53 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
||||
IsLocalDict(input_arg_name, self.keys).visit(node)
|
||||
|
||||
|
||||
class NonLocals(ast.NodeVisitor):
    """Collect the names a piece of code reads and the names it assigns.

    After visiting a tree, ``loads`` holds every name read and ``stores``
    every name assigned. An attribute chain rooted at a plain name is
    recorded as a single dotted entry (``"a.b.c"``) in ``loads``.
    """

    def __init__(self) -> None:
        self.loads: Set[str] = set()
        self.stores: Set[str] = set()

    def visit_Name(self, node: ast.Name) -> Any:
        """Record a bare name as read or assigned, depending on its context."""
        if isinstance(node.ctx, ast.Load):
            self.loads.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.stores.add(node.id)

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Record a read attribute chain as one dotted name.

        Attributes in a non-load context (assignment or deletion targets)
        are ignored, as before.
        """
        if not isinstance(node.ctx, ast.Load):
            return

        # Walk down to the root of the chain, building "b.c" for a.b.c.
        root = node.value
        dotted = node.attr
        while isinstance(root, ast.Attribute):
            dotted = root.attr + "." + dotted
            root = root.value

        if isinstance(root, ast.Name):
            self.loads.add(root.id + "." + dotted)
            # The dotted entry replaces a bare-name entry recorded so far.
            self.loads.discard(root.id)
        else:
            # The chain hangs off a call, subscript or other expression,
            # e.g. ``g(h).i``. Visit that expression so the names inside it
            # (``g`` and ``h``) are recorded instead of being dropped.
            self.visit(root)
|
||||
|
||||
|
||||
class FunctionNonLocals(ast.NodeVisitor):
    """Gather candidate nonlocal names from function and lambda definitions.

    For each ``def``, ``async def`` or ``lambda`` visited, the names it reads
    but never assigns are added to ``nonlocals``.
    """

    def __init__(self) -> None:
        self.nonlocals: Set[str] = set()

    def _collect(self, definition: ast.AST) -> None:
        """Add the names ``definition`` reads but does not assign."""
        collector = NonLocals()
        collector.visit(definition)
        self.nonlocals |= collector.loads - collector.stores

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        self._collect(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        self._collect(node)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        self._collect(node)
|
||||
|
||||
|
||||
class GetLambdaSource(ast.NodeVisitor):
|
||||
"""Get the source code of a lambda function."""
|
||||
|
||||
@ -169,6 +216,28 @@ def get_lambda_source(func: Callable) -> Optional[str]:
|
||||
return name
|
||||
|
||||
|
||||
def get_function_nonlocals(func: Callable) -> List[Any]:
    """Get the values of the nonlocal variables accessed by a function.

    The function's source is parsed to find the names it reads, and those
    names are matched against its closure variables. A closure variable that
    is read directly contributes its value; one that is read through an
    attribute chain (``x.a.b``) contributes the value at the end of the chain.

    Args:
        func: The function (or lambda) to inspect.

    Returns:
        The resolved values. The list is empty when the source cannot be
        retrieved or parsed, or when ``func`` cannot be inspected.
    """
    try:
        source = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(source))
        visitor = FunctionNonLocals()
        visitor.visit(tree)
        values: List[Any] = []
        for name, value in inspect.getclosurevars(func).nonlocals.items():
            if name in visitor.nonlocals:
                values.append(value)
            for dotted in visitor.nonlocals:
                root, _, attr_path = dotted.partition(".")
                # Compare the whole root segment: a plain prefix test would
                # let variable "a" match the unrelated chain "ab.c".
                if not attr_path or root != name:
                    continue
                target = value
                try:
                    for part in attr_path.split("."):
                        target = getattr(target, part)
                except AttributeError:
                    # The attribute does not exist right now; this is only
                    # static inspection, so skip it rather than crash.
                    continue
                values.append(target)
        return values
    except (SyntaxError, TypeError, OSError):
        return []
|
||||
|
||||
|
||||
def indent_lines_after_first(text: str, prefix: str) -> str:
|
||||
"""Indent all lines of text after the first line.
|
||||
|
||||
|
@ -32,39 +32,51 @@
|
||||
# ---
|
||||
# name: test_graph_sequence_map
|
||||
'''
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-----------------------+
|
||||
| RunnableParallelInput |
|
||||
+-----------------------+
|
||||
**** ***
|
||||
**** ****
|
||||
** **
|
||||
+---------------------+ +--------------------------------+
|
||||
| RunnablePassthrough | | CommaSeparatedListOutputParser |
|
||||
+---------------------+ +--------------------------------+
|
||||
**** ***
|
||||
**** ****
|
||||
** **
|
||||
+------------------------+
|
||||
| RunnableParallelOutput |
|
||||
+------------------------+
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-----------------------+
|
||||
| RunnableParallelInput |
|
||||
+-----------------------+**
|
||||
**** *******
|
||||
**** *****
|
||||
** *******
|
||||
+---------------------+ ***
|
||||
| RunnableLambdaInput | *
|
||||
+---------------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-----------------+ +-----------------+ *
|
||||
| StrOutputParser | | XMLOutputParser | *
|
||||
+-----------------+ +-----------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+----------------------+ +--------------------------------+
|
||||
| RunnableLambdaOutput | | CommaSeparatedListOutputParser |
|
||||
+----------------------+ +--------------------------------+
|
||||
**** *******
|
||||
**** *****
|
||||
** ****
|
||||
+------------------------+
|
||||
| RunnableParallelOutput |
|
||||
+------------------------+
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_single_runnable
|
||||
|
@ -2,9 +2,9 @@ from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from tests.unit_tests.fake.llm import FakeListLLM
|
||||
|
||||
|
||||
@ -38,13 +38,21 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
||||
fake_llm = FakeListLLM(responses=["a"])
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||
list_parser = CommaSeparatedListOutputParser()
|
||||
str_parser = StrOutputParser()
|
||||
xml_parser = XMLOutputParser()
|
||||
|
||||
def conditional_str_parser(input: str) -> Runnable:
|
||||
if input == "a":
|
||||
return str_parser
|
||||
else:
|
||||
return xml_parser
|
||||
|
||||
sequence: Runnable = (
|
||||
prompt
|
||||
| fake_llm
|
||||
| {
|
||||
"original": RunnablePassthrough(input_type=str),
|
||||
"as_list": list_parser,
|
||||
"as_str": conditional_str_parser,
|
||||
}
|
||||
)
|
||||
graph = sequence.get_graph()
|
||||
|
Loading…
Reference in New Issue
Block a user