Nc/dec22/runnable graph lambda (#15078)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
  - **Twitter handle:** we announce bigger features on Twitter. If your PR
    gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-22 14:36:46 -08:00 committed by GitHub
parent 59d4b80a92
commit 0d0901ea18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 214 additions and 48 deletions

View File

@ -35,7 +35,7 @@ from typing_extensions import Literal, get_args
from langchain_core.load.dump import dumpd, dumps
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.pydantic_v1 import BaseConfig, BaseModel, Field, create_model
from langchain_core.runnables.config import (
RunnableConfig,
acall_func_with_variable_args,
@ -48,6 +48,7 @@ from langchain_core.runnables.config import (
merge_configs,
patch_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
AddableDict,
AnyConfigurableField,
@ -59,6 +60,7 @@ from langchain_core.runnables.utils import (
accepts_run_manager,
gather_with_concurrency,
get_function_first_arg_dict_keys,
get_function_nonlocals,
get_lambda_source,
get_unique_config_specs,
indent_lines_after_first,
@ -74,7 +76,6 @@ if TYPE_CHECKING:
from langchain_core.runnables.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT,
)
from langchain_core.runnables.graph import Graph
from langchain_core.tracers.log_stream import RunLog, RunLogPatch
from langchain_core.tracers.root_listeners import Listener
@ -82,6 +83,10 @@ if TYPE_CHECKING:
Other = TypeVar("Other")
class _SchemaConfig(BaseConfig):
    """Shared pydantic config for models built via ``create_model`` in this module.

    ``arbitrary_types_allowed`` lets the generated schema models carry field
    types that pydantic has no validator for, instead of raising at model
    creation time.
    """

    arbitrary_types_allowed = True
class Runnable(Generic[Input, Output], ABC):
"""A unit of work that can be invoked, batched, streamed, transformed and composed.
@ -266,7 +271,9 @@ class Runnable(Generic[Input, Output], ABC):
return root_type
return create_model(
self.__class__.__name__ + "Input", __root__=(root_type, None)
self.__class__.__name__ + "Input",
__root__=(root_type, None),
__config__=_SchemaConfig,
)
@property
@ -297,7 +304,9 @@ class Runnable(Generic[Input, Output], ABC):
return root_type
return create_model(
self.__class__.__name__ + "Output", __root__=(root_type, None)
self.__class__.__name__ + "Output",
__root__=(root_type, None),
__config__=_SchemaConfig,
)
@property
@ -320,9 +329,6 @@ class Runnable(Generic[Input, Output], ABC):
A pydantic model that can be used to validate config.
"""
class _Config:
arbitrary_types_allowed = True
include = include or []
config_specs = self.config_specs
configurable = (
@ -337,6 +343,7 @@ class Runnable(Generic[Input, Output], ABC):
)
for spec in config_specs
},
__config__=_SchemaConfig,
)
if config_specs
else None
@ -344,7 +351,7 @@ class Runnable(Generic[Input, Output], ABC):
return create_model( # type: ignore[call-overload]
self.__class__.__name__ + "Config",
__config__=_Config,
__config__=_SchemaConfig,
**({"configurable": (configurable, None)} if configurable else {}),
**{
field_name: (field_type, None)
@ -1405,6 +1412,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
for k, v in next_input_schema.__fields__.items()
if k not in first.mapper.steps
},
__config__=_SchemaConfig,
)
return self.first.get_input_schema(config)
@ -2006,6 +2014,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
for k, v in step.get_input_schema(config).__fields__.items()
if k != "__root__"
},
__config__=_SchemaConfig,
)
return super().get_input_schema(config)
@ -2017,6 +2026,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
return create_model( # type: ignore[call-overload]
"RunnableParallelOutput",
**{k: (v.OutputType, None) for k, v in self.steps.items()},
__config__=_SchemaConfig,
)
@property
@ -2548,9 +2558,14 @@ class RunnableLambda(Runnable[Input, Output]):
return create_model(
"RunnableLambdaInput",
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
__config__=_SchemaConfig,
)
else:
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
return create_model(
"RunnableLambdaInput",
__root__=(List[Any], None),
__config__=_SchemaConfig,
)
if self.InputType != Any:
return super().get_input_schema(config)
@ -2559,6 +2574,7 @@ class RunnableLambda(Runnable[Input, Output]):
return create_model(
"RunnableLambdaInput",
**{key: (Any, None) for key in dict_keys}, # type: ignore
__config__=_SchemaConfig,
)
return super().get_input_schema(config)
@ -2577,6 +2593,50 @@ class RunnableLambda(Runnable[Input, Output]):
except ValueError:
return Any
@property
def deps(self) -> List[Runnable]:
    """The dependencies of this runnable.

    Inspects the wrapped callable's closure (nonlocal) variables and returns
    those that are themselves Runnables.
    """
    if hasattr(self, "func"):
        target = self.func
    elif hasattr(self, "afunc"):
        target = self.afunc
    else:
        target = None
    candidates = get_function_nonlocals(target) if target is not None else []
    return [candidate for candidate in candidates if isinstance(candidate, Runnable)]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
    """Config specs aggregated from all runnable dependencies, de-duplicated."""
    all_specs = [spec for dep in self.deps for spec in dep.config_specs]
    return get_unique_config_specs(all_specs)
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
    """Return a graph representation of this lambda.

    If the wrapped function references other runnables via its closure, each
    dependency's graph is spliced in between this lambda's input and output
    schema nodes; otherwise the default single-node graph is returned.

    Note: signature uses ``Optional[RunnableConfig]`` to match the other
    ``get_graph`` overrides in this module.
    """
    deps = self.deps
    if not deps:
        return super().get_graph(config)

    graph = Graph()
    input_node = graph.add_node(self.get_input_schema(config))
    output_node = graph.add_node(self.get_output_schema(config))
    for dep in deps:
        dep_graph = dep.get_graph()
        # Drop the dependency's own schema nodes; this lambda's input/output
        # nodes take their place.
        dep_graph.trim_first_node()
        dep_graph.trim_last_node()
        if not dep_graph:
            # Dependency reduced to nothing: connect input straight to output.
            graph.add_edge(input_node, output_node)
        else:
            graph.extend(dep_graph)
            dep_first_node = dep_graph.first_node()
            if not dep_first_node:
                raise ValueError(f"Runnable {dep} has no first node")
            dep_last_node = dep_graph.last_node()
            if not dep_last_node:
                raise ValueError(f"Runnable {dep} has no last node")
            graph.add_edge(input_node, dep_first_node)
            graph.add_edge(dep_last_node, output_node)
    return graph
def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"):
@ -2740,6 +2800,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
List[self.bound.get_input_schema(config)], # type: ignore
None,
),
__config__=_SchemaConfig,
)
@property
@ -2756,12 +2817,16 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
List[schema], # type: ignore
None,
),
__config__=_SchemaConfig,
)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.bound.config_specs
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
    """Delegate graph rendering to the wrapped (bound) runnable."""
    return self.bound.get_graph(config)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@ -2973,6 +3038,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.bound.config_specs
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
    """Delegate graph rendering to the wrapped (bound) runnable."""
    return self.bound.get_graph(config)
@classmethod
def is_lc_serializable(cls) -> bool:
return True

View File

@ -26,6 +26,7 @@ from langchain_core.runnables.config import (
get_config_list,
get_executor_for_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
AnyConfigurableField,
ConfigurableField,
@ -76,6 +77,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
runnable, config = self._prepare(config)
return runnable.get_output_schema(config)
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
    """Resolve the concrete runnable for this config, then return its graph."""
    runnable, config = self._prepare(config)
    return runnable.get_graph(config)
@abstractmethod
def _prepare(
self, config: Optional[RunnableConfig] = None

View File

@ -1,13 +1,15 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, NamedTuple, Optional, Type, Union
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
from uuid import uuid4
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable
from langchain_core.runnables.graph_draw import draw
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType
class Edge(NamedTuple):
source: str
@ -16,7 +18,7 @@ class Edge(NamedTuple):
class Node(NamedTuple):
id: str
data: Union[Type[BaseModel], Runnable]
data: Union[Type[BaseModel], RunnableType]
@dataclass
@ -30,7 +32,7 @@ class Graph:
def next_id(self) -> str:
return uuid4().hex
def add_node(self, data: Union[Type[BaseModel], Runnable]) -> Node:
def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node:
"""Add a node to the graph and return it."""
node = Node(id=self.next_id(), data=data)
self.nodes[node.id] = node
@ -109,6 +111,8 @@ class Graph:
self.remove_node(last_node)
def draw_ascii(self) -> str:
from langchain_core.runnables.base import Runnable
def node_data(node: Node) -> str:
if isinstance(node.data, Runnable):
try:

View File

@ -119,6 +119,53 @@ class IsFunctionArgDict(ast.NodeVisitor):
IsLocalDict(input_arg_name, self.keys).visit(node)
class NonLocals(ast.NodeVisitor):
    """Get nonlocal variables accessed.

    Records every name loaded (``self.loads``) and stored (``self.stores``)
    within the visited subtree. Attribute chains rooted at a plain name are
    recorded as a single dotted string (e.g. ``"a.b.c"``).
    """

    def __init__(self) -> None:
        self.loads: Set[str] = set()
        self.stores: Set[str] = set()

    def visit_Name(self, node: ast.Name) -> Any:
        if isinstance(node.ctx, ast.Load):
            self.loads.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.stores.add(node.id)

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        if isinstance(node.ctx, ast.Load):
            parent = node.value
            attr_expr = node.attr
            # Collapse a chain like a.b.c into the dotted name "a.b.c".
            while isinstance(parent, ast.Attribute):
                attr_expr = parent.attr + "." + attr_expr
                parent = parent.value
            if isinstance(parent, ast.Name):
                self.loads.add(parent.id + "." + attr_expr)
                # Keep only the dotted form; drop the bare root name so the
                # attribute access is not double-counted as a plain load.
                self.loads.discard(parent.id)


class FunctionNonLocals(ast.NodeVisitor):
    """Get the nonlocal variables accessed of a function.

    Accumulates, into ``self.nonlocals``, the names that a function, async
    function, or lambda loads but never stores — i.e. names that must come
    from an enclosing scope.
    """

    def __init__(self) -> None:
        self.nonlocals: Set[str] = set()

    def _collect(self, node: ast.AST) -> None:
        # Shared implementation for all callable node types (previously
        # triplicated verbatim across the three visit_* methods).
        visitor = NonLocals()
        visitor.visit(node)
        self.nonlocals.update(visitor.loads - visitor.stores)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        self._collect(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        self._collect(node)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        self._collect(node)
class GetLambdaSource(ast.NodeVisitor):
"""Get the source code of a lambda function."""
@ -169,6 +216,28 @@ def get_lambda_source(func: Callable) -> Optional[str]:
return name
def get_function_nonlocals(func: Callable) -> List[Any]:
    """Get the nonlocal variables accessed by a function.

    Returns the runtime values of the closure variables that the function's
    source actually references, resolving dotted accesses (``obj.attr``)
    against the closed-over object. Returns ``[]`` if the source cannot be
    retrieved or parsed.
    """
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = FunctionNonLocals()
        visitor.visit(tree)
        values: List[Any] = []
        for k, v in inspect.getclosurevars(func).nonlocals.items():
            if k in visitor.nonlocals:
                values.append(v)
            for kk in visitor.nonlocals:
                # Bug fix: require "k." as the prefix, not just "k" —
                # otherwise a closure var that is a prefix of another name
                # (k="o" vs kk="obj.attr") is resolved against the wrong
                # object.
                if "." in kk and kk.startswith(k + "."):
                    vv = v
                    for part in kk.split(".")[1:]:
                        if vv is None:
                            # Chain broken by a missing attribute; skip it
                            # rather than raising on getattr(None, ...).
                            break
                        vv = getattr(vv, part, None)
                    else:
                        values.append(vv)
        return values
    except (SyntaxError, TypeError, OSError):
        return []
def indent_lines_after_first(text: str, prefix: str) -> str:
"""Indent all lines of text after the first line.

View File

@ -32,39 +32,51 @@
# ---
# name: test_graph_sequence_map
'''
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+-----------------------+
| RunnableParallelInput |
+-----------------------+
**** ***
**** ****
** **
+---------------------+ +--------------------------------+
| RunnablePassthrough | | CommaSeparatedListOutputParser |
+---------------------+ +--------------------------------+
**** ***
**** ****
** **
+------------------------+
| RunnableParallelOutput |
+------------------------+
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+-----------------------+
| RunnableParallelInput |
+-----------------------+**
**** *******
**** *****
** *******
+---------------------+ ***
| RunnableLambdaInput | *
+---------------------+ *
*** *** *
*** *** *
** ** *
+-----------------+ +-----------------+ *
| StrOutputParser | | XMLOutputParser | *
+-----------------+ +-----------------+ *
*** *** *
*** *** *
** ** *
+----------------------+ +--------------------------------+
| RunnableLambdaOutput | | CommaSeparatedListOutputParser |
+----------------------+ +--------------------------------+
**** *******
**** *****
** ****
+------------------------+
| RunnableParallelOutput |
+------------------------+
'''
# ---
# name: test_graph_single_runnable

View File

@ -2,9 +2,9 @@ from syrupy import SnapshotAssertion
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable
from langchain_core.runnables.passthrough import RunnablePassthrough
from tests.unit_tests.fake.llm import FakeListLLM
@ -38,13 +38,21 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"])
prompt = PromptTemplate.from_template("Hello, {name}!")
list_parser = CommaSeparatedListOutputParser()
str_parser = StrOutputParser()
xml_parser = XMLOutputParser()
def conditional_str_parser(input: str) -> Runnable:
if input == "a":
return str_parser
else:
return xml_parser
sequence: Runnable = (
prompt
| fake_llm
| {
"original": RunnablePassthrough(input_type=str),
"as_list": list_parser,
"as_str": conditional_str_parser,
}
)
graph = sequence.get_graph()