mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
Nc/dec22/runnable graph lambda (#15078)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
59d4b80a92
commit
0d0901ea18
@ -35,7 +35,7 @@ from typing_extensions import Literal, get_args
|
|||||||
|
|
||||||
from langchain_core.load.dump import dumpd, dumps
|
from langchain_core.load.dump import dumpd, dumps
|
||||||
from langchain_core.load.serializable import Serializable
|
from langchain_core.load.serializable import Serializable
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
from langchain_core.pydantic_v1 import BaseConfig, BaseModel, Field, create_model
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
acall_func_with_variable_args,
|
acall_func_with_variable_args,
|
||||||
@ -48,6 +48,7 @@ from langchain_core.runnables.config import (
|
|||||||
merge_configs,
|
merge_configs,
|
||||||
patch_config,
|
patch_config,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables.graph import Graph
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
AddableDict,
|
AddableDict,
|
||||||
AnyConfigurableField,
|
AnyConfigurableField,
|
||||||
@ -59,6 +60,7 @@ from langchain_core.runnables.utils import (
|
|||||||
accepts_run_manager,
|
accepts_run_manager,
|
||||||
gather_with_concurrency,
|
gather_with_concurrency,
|
||||||
get_function_first_arg_dict_keys,
|
get_function_first_arg_dict_keys,
|
||||||
|
get_function_nonlocals,
|
||||||
get_lambda_source,
|
get_lambda_source,
|
||||||
get_unique_config_specs,
|
get_unique_config_specs,
|
||||||
indent_lines_after_first,
|
indent_lines_after_first,
|
||||||
@ -74,7 +76,6 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.runnables.fallbacks import (
|
from langchain_core.runnables.fallbacks import (
|
||||||
RunnableWithFallbacks as RunnableWithFallbacksT,
|
RunnableWithFallbacks as RunnableWithFallbacksT,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.graph import Graph
|
|
||||||
from langchain_core.tracers.log_stream import RunLog, RunLogPatch
|
from langchain_core.tracers.log_stream import RunLog, RunLogPatch
|
||||||
from langchain_core.tracers.root_listeners import Listener
|
from langchain_core.tracers.root_listeners import Listener
|
||||||
|
|
||||||
@ -82,6 +83,10 @@ if TYPE_CHECKING:
|
|||||||
Other = TypeVar("Other")
|
Other = TypeVar("Other")
|
||||||
|
|
||||||
|
|
||||||
|
class _SchemaConfig(BaseConfig):
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
class Runnable(Generic[Input, Output], ABC):
|
class Runnable(Generic[Input, Output], ABC):
|
||||||
"""A unit of work that can be invoked, batched, streamed, transformed and composed.
|
"""A unit of work that can be invoked, batched, streamed, transformed and composed.
|
||||||
|
|
||||||
@ -266,7 +271,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return root_type
|
return root_type
|
||||||
|
|
||||||
return create_model(
|
return create_model(
|
||||||
self.__class__.__name__ + "Input", __root__=(root_type, None)
|
self.__class__.__name__ + "Input",
|
||||||
|
__root__=(root_type, None),
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -297,7 +304,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return root_type
|
return root_type
|
||||||
|
|
||||||
return create_model(
|
return create_model(
|
||||||
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
self.__class__.__name__ + "Output",
|
||||||
|
__root__=(root_type, None),
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -320,9 +329,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
A pydantic model that can be used to validate config.
|
A pydantic model that can be used to validate config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class _Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
include = include or []
|
include = include or []
|
||||||
config_specs = self.config_specs
|
config_specs = self.config_specs
|
||||||
configurable = (
|
configurable = (
|
||||||
@ -337,6 +343,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
for spec in config_specs
|
for spec in config_specs
|
||||||
},
|
},
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
if config_specs
|
if config_specs
|
||||||
else None
|
else None
|
||||||
@ -344,7 +351,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
return create_model( # type: ignore[call-overload]
|
return create_model( # type: ignore[call-overload]
|
||||||
self.__class__.__name__ + "Config",
|
self.__class__.__name__ + "Config",
|
||||||
__config__=_Config,
|
__config__=_SchemaConfig,
|
||||||
**({"configurable": (configurable, None)} if configurable else {}),
|
**({"configurable": (configurable, None)} if configurable else {}),
|
||||||
**{
|
**{
|
||||||
field_name: (field_type, None)
|
field_name: (field_type, None)
|
||||||
@ -1405,6 +1412,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
for k, v in next_input_schema.__fields__.items()
|
for k, v in next_input_schema.__fields__.items()
|
||||||
if k not in first.mapper.steps
|
if k not in first.mapper.steps
|
||||||
},
|
},
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.first.get_input_schema(config)
|
return self.first.get_input_schema(config)
|
||||||
@ -2006,6 +2014,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
|||||||
for k, v in step.get_input_schema(config).__fields__.items()
|
for k, v in step.get_input_schema(config).__fields__.items()
|
||||||
if k != "__root__"
|
if k != "__root__"
|
||||||
},
|
},
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().get_input_schema(config)
|
return super().get_input_schema(config)
|
||||||
@ -2017,6 +2026,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
|||||||
return create_model( # type: ignore[call-overload]
|
return create_model( # type: ignore[call-overload]
|
||||||
"RunnableParallelOutput",
|
"RunnableParallelOutput",
|
||||||
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -2548,9 +2558,14 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return create_model(
|
return create_model(
|
||||||
"RunnableLambdaInput",
|
"RunnableLambdaInput",
|
||||||
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
|
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
|
return create_model(
|
||||||
|
"RunnableLambdaInput",
|
||||||
|
__root__=(List[Any], None),
|
||||||
|
__config__=_SchemaConfig,
|
||||||
|
)
|
||||||
|
|
||||||
if self.InputType != Any:
|
if self.InputType != Any:
|
||||||
return super().get_input_schema(config)
|
return super().get_input_schema(config)
|
||||||
@ -2559,6 +2574,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return create_model(
|
return create_model(
|
||||||
"RunnableLambdaInput",
|
"RunnableLambdaInput",
|
||||||
**{key: (Any, None) for key in dict_keys}, # type: ignore
|
**{key: (Any, None) for key in dict_keys}, # type: ignore
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().get_input_schema(config)
|
return super().get_input_schema(config)
|
||||||
@ -2577,6 +2593,50 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return Any
|
return Any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def deps(self) -> List[Runnable]:
|
||||||
|
"""The dependencies of this runnable."""
|
||||||
|
if hasattr(self, "func"):
|
||||||
|
objects = get_function_nonlocals(self.func)
|
||||||
|
elif hasattr(self, "afunc"):
|
||||||
|
objects = get_function_nonlocals(self.afunc)
|
||||||
|
else:
|
||||||
|
objects = []
|
||||||
|
|
||||||
|
return [obj for obj in objects if isinstance(obj, Runnable)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||||
|
return get_unique_config_specs(
|
||||||
|
spec for dep in self.deps for spec in dep.config_specs
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_graph(self, config: RunnableConfig | None = None) -> Graph:
|
||||||
|
if deps := self.deps:
|
||||||
|
graph = Graph()
|
||||||
|
input_node = graph.add_node(self.get_input_schema(config))
|
||||||
|
output_node = graph.add_node(self.get_output_schema(config))
|
||||||
|
for dep in deps:
|
||||||
|
dep_graph = dep.get_graph()
|
||||||
|
dep_graph.trim_first_node()
|
||||||
|
dep_graph.trim_last_node()
|
||||||
|
if not dep_graph:
|
||||||
|
graph.add_edge(input_node, output_node)
|
||||||
|
else:
|
||||||
|
graph.extend(dep_graph)
|
||||||
|
dep_first_node = dep_graph.first_node()
|
||||||
|
if not dep_first_node:
|
||||||
|
raise ValueError(f"Runnable {dep} has no first node")
|
||||||
|
dep_last_node = dep_graph.last_node()
|
||||||
|
if not dep_last_node:
|
||||||
|
raise ValueError(f"Runnable {dep} has no last node")
|
||||||
|
graph.add_edge(input_node, dep_first_node)
|
||||||
|
graph.add_edge(dep_last_node, output_node)
|
||||||
|
else:
|
||||||
|
graph = super().get_graph(config)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
if isinstance(other, RunnableLambda):
|
if isinstance(other, RunnableLambda):
|
||||||
if hasattr(self, "func") and hasattr(other, "func"):
|
if hasattr(self, "func") and hasattr(other, "func"):
|
||||||
@ -2740,6 +2800,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
List[self.bound.get_input_schema(config)], # type: ignore
|
List[self.bound.get_input_schema(config)], # type: ignore
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -2756,12 +2817,16 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
List[schema], # type: ignore
|
List[schema], # type: ignore
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||||
return self.bound.config_specs
|
return self.bound.config_specs
|
||||||
|
|
||||||
|
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||||
|
return self.bound.get_graph(config)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -2973,6 +3038,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||||
return self.bound.config_specs
|
return self.bound.config_specs
|
||||||
|
|
||||||
|
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||||
|
return self.bound.get_graph(config)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
@ -26,6 +26,7 @@ from langchain_core.runnables.config import (
|
|||||||
get_config_list,
|
get_config_list,
|
||||||
get_executor_for_config,
|
get_executor_for_config,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables.graph import Graph
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
AnyConfigurableField,
|
AnyConfigurableField,
|
||||||
ConfigurableField,
|
ConfigurableField,
|
||||||
@ -76,6 +77,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
runnable, config = self._prepare(config)
|
runnable, config = self._prepare(config)
|
||||||
return runnable.get_output_schema(config)
|
return runnable.get_output_schema(config)
|
||||||
|
|
||||||
|
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||||
|
runnable, config = self._prepare(config)
|
||||||
|
return runnable.get_graph(config)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _prepare(
|
def _prepare(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
|
@ -1,13 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, NamedTuple, Optional, Type, Union
|
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables.base import Runnable
|
|
||||||
from langchain_core.runnables.graph_draw import draw
|
from langchain_core.runnables.graph_draw import draw
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.runnables.base import Runnable as RunnableType
|
||||||
|
|
||||||
|
|
||||||
class Edge(NamedTuple):
|
class Edge(NamedTuple):
|
||||||
source: str
|
source: str
|
||||||
@ -16,7 +18,7 @@ class Edge(NamedTuple):
|
|||||||
|
|
||||||
class Node(NamedTuple):
|
class Node(NamedTuple):
|
||||||
id: str
|
id: str
|
||||||
data: Union[Type[BaseModel], Runnable]
|
data: Union[Type[BaseModel], RunnableType]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -30,7 +32,7 @@ class Graph:
|
|||||||
def next_id(self) -> str:
|
def next_id(self) -> str:
|
||||||
return uuid4().hex
|
return uuid4().hex
|
||||||
|
|
||||||
def add_node(self, data: Union[Type[BaseModel], Runnable]) -> Node:
|
def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node:
|
||||||
"""Add a node to the graph and return it."""
|
"""Add a node to the graph and return it."""
|
||||||
node = Node(id=self.next_id(), data=data)
|
node = Node(id=self.next_id(), data=data)
|
||||||
self.nodes[node.id] = node
|
self.nodes[node.id] = node
|
||||||
@ -109,6 +111,8 @@ class Graph:
|
|||||||
self.remove_node(last_node)
|
self.remove_node(last_node)
|
||||||
|
|
||||||
def draw_ascii(self) -> str:
|
def draw_ascii(self) -> str:
|
||||||
|
from langchain_core.runnables.base import Runnable
|
||||||
|
|
||||||
def node_data(node: Node) -> str:
|
def node_data(node: Node) -> str:
|
||||||
if isinstance(node.data, Runnable):
|
if isinstance(node.data, Runnable):
|
||||||
try:
|
try:
|
||||||
|
@ -119,6 +119,53 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
|||||||
IsLocalDict(input_arg_name, self.keys).visit(node)
|
IsLocalDict(input_arg_name, self.keys).visit(node)
|
||||||
|
|
||||||
|
|
||||||
|
class NonLocals(ast.NodeVisitor):
|
||||||
|
"""Get nonlocal variables accessed."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.loads: Set[str] = set()
|
||||||
|
self.stores: Set[str] = set()
|
||||||
|
|
||||||
|
def visit_Name(self, node: ast.Name) -> Any:
|
||||||
|
if isinstance(node.ctx, ast.Load):
|
||||||
|
self.loads.add(node.id)
|
||||||
|
elif isinstance(node.ctx, ast.Store):
|
||||||
|
self.stores.add(node.id)
|
||||||
|
|
||||||
|
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
||||||
|
if isinstance(node.ctx, ast.Load):
|
||||||
|
parent = node.value
|
||||||
|
attr_expr = node.attr
|
||||||
|
while isinstance(parent, ast.Attribute):
|
||||||
|
attr_expr = parent.attr + "." + attr_expr
|
||||||
|
parent = parent.value
|
||||||
|
if isinstance(parent, ast.Name):
|
||||||
|
self.loads.add(parent.id + "." + attr_expr)
|
||||||
|
self.loads.discard(parent.id)
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionNonLocals(ast.NodeVisitor):
|
||||||
|
"""Get the nonlocal variables accessed of a function."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.nonlocals: Set[str] = set()
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||||
|
visitor = NonLocals()
|
||||||
|
visitor.visit(node)
|
||||||
|
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||||
|
|
||||||
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
||||||
|
visitor = NonLocals()
|
||||||
|
visitor.visit(node)
|
||||||
|
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||||
|
|
||||||
|
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||||
|
visitor = NonLocals()
|
||||||
|
visitor.visit(node)
|
||||||
|
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||||
|
|
||||||
|
|
||||||
class GetLambdaSource(ast.NodeVisitor):
|
class GetLambdaSource(ast.NodeVisitor):
|
||||||
"""Get the source code of a lambda function."""
|
"""Get the source code of a lambda function."""
|
||||||
|
|
||||||
@ -169,6 +216,28 @@ def get_lambda_source(func: Callable) -> Optional[str]:
|
|||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def get_function_nonlocals(func: Callable) -> List[Any]:
|
||||||
|
"""Get the nonlocal variables accessed by a function."""
|
||||||
|
try:
|
||||||
|
code = inspect.getsource(func)
|
||||||
|
tree = ast.parse(textwrap.dedent(code))
|
||||||
|
visitor = FunctionNonLocals()
|
||||||
|
visitor.visit(tree)
|
||||||
|
values: List[Any] = []
|
||||||
|
for k, v in inspect.getclosurevars(func).nonlocals.items():
|
||||||
|
if k in visitor.nonlocals:
|
||||||
|
values.append(v)
|
||||||
|
for kk in visitor.nonlocals:
|
||||||
|
if "." in kk and kk.startswith(k):
|
||||||
|
vv = v
|
||||||
|
for part in kk.split(".")[1:]:
|
||||||
|
vv = getattr(vv, part)
|
||||||
|
values.append(vv)
|
||||||
|
return values
|
||||||
|
except (SyntaxError, TypeError, OSError):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
def indent_lines_after_first(text: str, prefix: str) -> str:
|
def indent_lines_after_first(text: str, prefix: str) -> str:
|
||||||
"""Indent all lines of text after the first line.
|
"""Indent all lines of text after the first line.
|
||||||
|
|
||||||
|
@ -32,39 +32,51 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_graph_sequence_map
|
# name: test_graph_sequence_map
|
||||||
'''
|
'''
|
||||||
+-------------+
|
+-------------+
|
||||||
| PromptInput |
|
| PromptInput |
|
||||||
+-------------+
|
+-------------+
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
+----------------+
|
+----------------+
|
||||||
| PromptTemplate |
|
| PromptTemplate |
|
||||||
+----------------+
|
+----------------+
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
+-------------+
|
+-------------+
|
||||||
| FakeListLLM |
|
| FakeListLLM |
|
||||||
+-------------+
|
+-------------+
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
+-----------------------+
|
+-----------------------+
|
||||||
| RunnableParallelInput |
|
| RunnableParallelInput |
|
||||||
+-----------------------+
|
+-----------------------+**
|
||||||
**** ***
|
**** *******
|
||||||
**** ****
|
**** *****
|
||||||
** **
|
** *******
|
||||||
+---------------------+ +--------------------------------+
|
+---------------------+ ***
|
||||||
| RunnablePassthrough | | CommaSeparatedListOutputParser |
|
| RunnableLambdaInput | *
|
||||||
+---------------------+ +--------------------------------+
|
+---------------------+ *
|
||||||
**** ***
|
*** *** *
|
||||||
**** ****
|
*** *** *
|
||||||
** **
|
** ** *
|
||||||
+------------------------+
|
+-----------------+ +-----------------+ *
|
||||||
| RunnableParallelOutput |
|
| StrOutputParser | | XMLOutputParser | *
|
||||||
+------------------------+
|
+-----------------+ +-----------------+ *
|
||||||
|
*** *** *
|
||||||
|
*** *** *
|
||||||
|
** ** *
|
||||||
|
+----------------------+ +--------------------------------+
|
||||||
|
| RunnableLambdaOutput | | CommaSeparatedListOutputParser |
|
||||||
|
+----------------------+ +--------------------------------+
|
||||||
|
**** *******
|
||||||
|
**** *****
|
||||||
|
** ****
|
||||||
|
+------------------------+
|
||||||
|
| RunnableParallelOutput |
|
||||||
|
+------------------------+
|
||||||
'''
|
'''
|
||||||
# ---
|
# ---
|
||||||
# name: test_graph_single_runnable
|
# name: test_graph_single_runnable
|
||||||
|
@ -2,9 +2,9 @@ from syrupy import SnapshotAssertion
|
|||||||
|
|
||||||
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||||
from langchain_core.output_parsers.string import StrOutputParser
|
from langchain_core.output_parsers.string import StrOutputParser
|
||||||
|
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.runnables.base import Runnable
|
from langchain_core.runnables.base import Runnable
|
||||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
|
||||||
from tests.unit_tests.fake.llm import FakeListLLM
|
from tests.unit_tests.fake.llm import FakeListLLM
|
||||||
|
|
||||||
|
|
||||||
@ -38,13 +38,21 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
|||||||
fake_llm = FakeListLLM(responses=["a"])
|
fake_llm = FakeListLLM(responses=["a"])
|
||||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||||
list_parser = CommaSeparatedListOutputParser()
|
list_parser = CommaSeparatedListOutputParser()
|
||||||
|
str_parser = StrOutputParser()
|
||||||
|
xml_parser = XMLOutputParser()
|
||||||
|
|
||||||
|
def conditional_str_parser(input: str) -> Runnable:
|
||||||
|
if input == "a":
|
||||||
|
return str_parser
|
||||||
|
else:
|
||||||
|
return xml_parser
|
||||||
|
|
||||||
sequence: Runnable = (
|
sequence: Runnable = (
|
||||||
prompt
|
prompt
|
||||||
| fake_llm
|
| fake_llm
|
||||||
| {
|
| {
|
||||||
"original": RunnablePassthrough(input_type=str),
|
|
||||||
"as_list": list_parser,
|
"as_list": list_parser,
|
||||||
|
"as_str": conditional_str_parser,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
graph = sequence.get_graph()
|
graph = sequence.get_graph()
|
||||||
|
Loading…
Reference in New Issue
Block a user