mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 18:48:50 +00:00
RunnableBranch (#10594)
Runnable Branch implementation, no optimization for streaming logic yet
This commit is contained in:
parent
287c81db89
commit
1eefb9052b
@ -2,6 +2,7 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
|
||||
from langchain.schema.runnable.base import (
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
RunnableBranch,
|
||||
RunnableLambda,
|
||||
RunnableMap,
|
||||
RunnableSequence,
|
||||
@ -12,16 +13,17 @@ from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||
|
||||
__all__ = [
|
||||
"patch_config",
|
||||
"GetLocalVar",
|
||||
"patch_config",
|
||||
"PutLocalVar",
|
||||
"RouterInput",
|
||||
"RouterRunnable",
|
||||
"Runnable",
|
||||
"RunnableBinding",
|
||||
"RunnableBranch",
|
||||
"RunnableConfig",
|
||||
"RunnableMap",
|
||||
"RunnableLambda",
|
||||
"RunnableMap",
|
||||
"RunnablePassthrough",
|
||||
"RunnableSequence",
|
||||
"RunnableWithFallbacks",
|
||||
|
@ -658,6 +658,188 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
await run_manager.on_chain_end(final_output, inputs=final_input)
|
||||
|
||||
|
||||
class RunnableBranch(Serializable, Runnable[Input, Output]):
|
||||
"""A Runnable that selects which branch to run based on a condition.
|
||||
|
||||
The runnable is initialized with a list of (condition, runnable) pairs and
|
||||
a default branch.
|
||||
|
||||
When operating on an input, the first condition that evaluates to True is
|
||||
selected, and the corresponding runnable is run on the input.
|
||||
|
||||
If no condition evaluates to True, the default branch is run on the input.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.schema.runnable import RunnableBranch
|
||||
|
||||
branch = RunnableBranch(
|
||||
(lambda x: isinstance(x, str), lambda x: x.upper()),
|
||||
(lambda x: isinstance(x, int), lambda x: x + 1),
|
||||
(lambda x: isinstance(x, float), lambda x: x * 2),
|
||||
lambda x: "goodbye",
|
||||
)
|
||||
|
||||
branch.invoke("hello") # "HELLO"
|
||||
branch.invoke(None) # "goodbye"
|
||||
"""
|
||||
|
||||
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||
default: Runnable[Input, Output]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*branches: Union[
|
||||
Tuple[
|
||||
Union[
|
||||
Runnable[Input, bool],
|
||||
Callable[[Input], bool],
|
||||
Callable[[Input], Awaitable[bool]],
|
||||
],
|
||||
RunnableLike,
|
||||
],
|
||||
RunnableLike, # To accommodate the default branch
|
||||
],
|
||||
) -> None:
|
||||
"""A Runnable that runs one of two branches based on a condition."""
|
||||
if len(branches) < 2:
|
||||
raise ValueError("RunnableBranch requires at least two branches")
|
||||
|
||||
default = branches[-1]
|
||||
|
||||
if not isinstance(
|
||||
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
|
||||
):
|
||||
raise TypeError(
|
||||
"RunnableBranch default must be runnable, callable or mapping."
|
||||
)
|
||||
|
||||
default_ = cast(
|
||||
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
|
||||
)
|
||||
|
||||
_branches = []
|
||||
|
||||
for branch in branches[:-1]:
|
||||
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
|
||||
raise TypeError(
|
||||
f"RunnableBranch branches must be "
|
||||
f"tuples or lists, not {type(branch)}"
|
||||
)
|
||||
|
||||
if not len(branch) == 2:
|
||||
raise ValueError(
|
||||
f"RunnableBranch branches must be "
|
||||
f"tuples or lists of length 2, not {len(branch)}"
|
||||
)
|
||||
condition, runnable = branch
|
||||
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
|
||||
runnable = coerce_to_runnable(runnable)
|
||||
_branches.append((condition, runnable))
|
||||
|
||||
super().__init__(branches=_branches, default=default_)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""RunnableBranch is serializable if all its branches are serializable."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_namespace(self) -> List[str]:
|
||||
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
"""First evaluates the condition, then delegate to true or false branch."""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
|
||||
try:
|
||||
for idx, branch in enumerate(self.branches):
|
||||
condition, runnable = branch
|
||||
|
||||
expression_value = condition.invoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag=f"condition:{idx}")
|
||||
),
|
||||
)
|
||||
|
||||
if expression_value:
|
||||
return runnable.invoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag=f"branch:{idx}")
|
||||
),
|
||||
)
|
||||
|
||||
output = self.default.invoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
run_manager.on_chain_end(dumpd(output))
|
||||
return output
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
"""Async version of invoke."""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
for idx, branch in enumerate(self.branches):
|
||||
condition, runnable = branch
|
||||
|
||||
expression_value = await condition.ainvoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag=f"condition:{idx}")
|
||||
),
|
||||
)
|
||||
|
||||
if expression_value:
|
||||
return await runnable.ainvoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag=f"branch:{idx}")
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = await self.default.ainvoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
run_manager.on_chain_end(dumpd(output))
|
||||
return output
|
||||
|
||||
|
||||
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
"""
|
||||
A Runnable that can fallback to other Runnables if it fails.
|
||||
@ -2007,14 +2189,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
|
||||
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||
|
||||
RunnableLike = Union[
|
||||
Runnable[Input, Output],
|
||||
Callable[[Input], Output],
|
||||
Callable[[Input], Awaitable[Output]],
|
||||
Mapping[str, Any],
|
||||
]
|
||||
|
||||
def coerce_to_runnable(
|
||||
thing: Union[
|
||||
Runnable[Input, Output],
|
||||
Callable[[Input], Output],
|
||||
Mapping[str, Any],
|
||||
]
|
||||
) -> Runnable[Input, Output]:
|
||||
|
||||
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
|
||||
if isinstance(thing, Runnable):
|
||||
return thing
|
||||
elif callable(thing):
|
||||
|
@ -1,5 +1,5 @@
|
||||
from operator import itemgetter
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
@ -34,6 +34,7 @@ from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.schema.runnable import (
|
||||
RouterRunnable,
|
||||
Runnable,
|
||||
RunnableBranch,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
RunnableMap,
|
||||
@ -541,7 +542,7 @@ async def test_prompt_with_llm(
|
||||
mocker.stop(prompt_spy)
|
||||
mocker.stop(llm_spy)
|
||||
|
||||
# Test stream
|
||||
# Test stream#
|
||||
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||
llm_spy = mocker.spy(llm.__class__, "astream")
|
||||
tracer = FakeTracer()
|
||||
@ -1816,3 +1817,205 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
assert parent_run_qux.outputs["output"] == "quxaaaa"
|
||||
assert len(parent_run_qux.child_runs) == 4
|
||||
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
|
||||
|
||||
|
||||
def test_runnable_branch_init() -> None:
|
||||
"""Verify that runnable branch gets initialized properly."""
|
||||
add = RunnableLambda(lambda x: x + 1)
|
||||
condition = RunnableLambda(lambda x: x > 0)
|
||||
|
||||
# Test failure with less than 2 branches
|
||||
with pytest.raises(ValueError):
|
||||
RunnableBranch((condition, add))
|
||||
|
||||
# Test failure with less than 2 branches
|
||||
with pytest.raises(ValueError):
|
||||
RunnableBranch(condition)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"branches",
|
||||
[
|
||||
[
|
||||
(RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
|
||||
RunnableLambda(lambda x: x - 1),
|
||||
],
|
||||
[
|
||||
(RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
|
||||
(RunnableLambda(lambda x: x > 5), RunnableLambda(lambda x: x + 1)),
|
||||
RunnableLambda(lambda x: x - 1),
|
||||
],
|
||||
[
|
||||
(lambda x: x > 0, lambda x: x + 1),
|
||||
(lambda x: x > 5, lambda x: x + 1),
|
||||
lambda x: x - 1,
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
|
||||
"""Verify that runnable branch gets initialized properly."""
|
||||
runnable = RunnableBranch[int, int](*branches)
|
||||
for branch in runnable.branches:
|
||||
condition, body = branch
|
||||
assert isinstance(condition, Runnable)
|
||||
assert isinstance(body, Runnable)
|
||||
|
||||
assert isinstance(runnable.default, Runnable)
|
||||
|
||||
|
||||
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
|
||||
"""Verify that runnables are invoked only when necessary."""
|
||||
# Test with single branch
|
||||
add = RunnableLambda(lambda x: x + 1)
|
||||
sub = RunnableLambda(lambda x: x - 1)
|
||||
condition = RunnableLambda(lambda x: x > 0)
|
||||
spy = mocker.spy(condition, "invoke")
|
||||
add_spy = mocker.spy(add, "invoke")
|
||||
|
||||
branch = RunnableBranch[int, int]((condition, add), (condition, add), sub)
|
||||
assert spy.call_count == 0
|
||||
assert add_spy.call_count == 0
|
||||
|
||||
assert branch.invoke(1) == 2
|
||||
assert add_spy.call_count == 1
|
||||
assert spy.call_count == 1
|
||||
|
||||
assert branch.invoke(2) == 3
|
||||
assert spy.call_count == 2
|
||||
assert add_spy.call_count == 2
|
||||
|
||||
assert branch.invoke(-3) == -4
|
||||
# Should fall through to default branch with condition being evaluated twice!
|
||||
assert spy.call_count == 4
|
||||
# Add should not be invoked
|
||||
assert add_spy.call_count == 2
|
||||
|
||||
|
||||
def test_runnable_branch_invoke() -> None:
|
||||
# Test with single branch
|
||||
def raise_value_error(x: int) -> int:
|
||||
"""Raise a value error."""
|
||||
raise ValueError("x is too large")
|
||||
|
||||
branch = RunnableBranch[int, int](
|
||||
(lambda x: x > 100, raise_value_error),
|
||||
# mypy cannot infer types from the lambda
|
||||
(lambda x: x > 0 and x < 5, lambda x: x + 1), # type: ignore[misc]
|
||||
(lambda x: x > 5, lambda x: x * 10),
|
||||
lambda x: x - 1,
|
||||
)
|
||||
|
||||
assert branch.invoke(1) == 2
|
||||
assert branch.invoke(10) == 100
|
||||
assert branch.invoke(0) == -1
|
||||
# Should raise an exception
|
||||
with pytest.raises(ValueError):
|
||||
branch.invoke(1000)
|
||||
|
||||
|
||||
def test_runnable_branch_batch() -> None:
|
||||
"""Test batch variant."""
|
||||
# Test with single branch
|
||||
branch = RunnableBranch[int, int](
|
||||
(lambda x: x > 0 and x < 5, lambda x: x + 1),
|
||||
(lambda x: x > 5, lambda x: x * 10),
|
||||
lambda x: x - 1,
|
||||
)
|
||||
|
||||
assert branch.batch([1, 10, 0]) == [2, 100, -1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runnable_branch_ainvoke() -> None:
|
||||
"""Test async variant of invoke."""
|
||||
branch = RunnableBranch[int, int](
|
||||
(lambda x: x > 0 and x < 5, lambda x: x + 1),
|
||||
(lambda x: x > 5, lambda x: x * 10),
|
||||
lambda x: x - 1,
|
||||
)
|
||||
|
||||
assert await branch.ainvoke(1) == 2
|
||||
assert await branch.ainvoke(10) == 100
|
||||
assert await branch.ainvoke(0) == -1
|
||||
|
||||
# Verify that the async variant is used if available
|
||||
async def condition(x: int) -> bool:
|
||||
return x > 0
|
||||
|
||||
async def add(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
async def sub(x: int) -> int:
|
||||
return x - 1
|
||||
|
||||
branch = RunnableBranch[int, int]((condition, add), sub)
|
||||
|
||||
assert await branch.ainvoke(1) == 2
|
||||
assert await branch.ainvoke(-10) == -11
|
||||
|
||||
|
||||
def test_runnable_branch_invoke_callbacks() -> None:
|
||||
"""Verify that callbacks are correctly used in invoke."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
def raise_value_error(x: int) -> int:
|
||||
"""Raise a value error."""
|
||||
raise ValueError("x is too large")
|
||||
|
||||
branch = RunnableBranch[int, int](
|
||||
(lambda x: x > 100, raise_value_error),
|
||||
lambda x: x - 1,
|
||||
)
|
||||
|
||||
assert branch.invoke(1, config={"callbacks": [tracer]}) == 0
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].error is None
|
||||
assert tracer.runs[0].outputs == {"output": 0}
|
||||
|
||||
# Check that the chain on end is invoked
|
||||
with pytest.raises(ValueError):
|
||||
branch.invoke(1000, config={"callbacks": [tracer]})
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[1].error == "ValueError('x is too large')"
|
||||
assert tracer.runs[1].outputs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runnable_branch_ainvoke_callbacks() -> None:
|
||||
"""Verify that callbacks are invoked correctly in ainvoke."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
async def raise_value_error(x: int) -> int:
|
||||
"""Raise a value error."""
|
||||
raise ValueError("x is too large")
|
||||
|
||||
branch = RunnableBranch[int, int](
|
||||
(lambda x: x > 100, raise_value_error),
|
||||
lambda x: x - 1,
|
||||
)
|
||||
|
||||
assert await branch.ainvoke(1, config={"callbacks": [tracer]}) == 0
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].error is None
|
||||
assert tracer.runs[0].outputs == {"output": 0}
|
||||
|
||||
# Check that the chain on end is invoked
|
||||
with pytest.raises(ValueError):
|
||||
await branch.ainvoke(1000, config={"callbacks": [tracer]})
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[1].error == "ValueError('x is too large')"
|
||||
assert tracer.runs[1].outputs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runnable_branch_abatch() -> None:
|
||||
"""Test async variant of invoke."""
|
||||
branch = RunnableBranch[int, int](
|
||||
(lambda x: x > 0 and x < 5, lambda x: x + 1),
|
||||
(lambda x: x > 5, lambda x: x * 10),
|
||||
lambda x: x - 1,
|
||||
)
|
||||
|
||||
assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
|
||||
|
Loading…
Reference in New Issue
Block a user