mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 21:20:33 +00:00
RunnableLambda, if func returns a Runnable, run it
This commit is contained in:
parent
677da6a0fd
commit
6d19709b65
@ -13,7 +13,6 @@ from typing import (
|
|||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
Iterator,
|
Iterator,
|
||||||
@ -1347,8 +1346,8 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
func: Union[Callable[[Input], Output], Coroutine[Input, Any, Output]],
|
func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
|
||||||
afunc: Optional[Coroutine[Input, Any, Output]] = None,
|
afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if afunc is not None:
|
if afunc is not None:
|
||||||
self.afunc = afunc
|
self.afunc = afunc
|
||||||
@ -1356,7 +1355,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
if inspect.iscoroutinefunction(func):
|
if inspect.iscoroutinefunction(func):
|
||||||
self.afunc = func
|
self.afunc = func
|
||||||
elif callable(func):
|
elif callable(func):
|
||||||
self.func = func
|
self.func = cast(Callable[[Input], Output], func)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Expected a callable type for `func`."
|
"Expected a callable type for `func`."
|
||||||
@ -1374,6 +1373,54 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
run_manager: CallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Output:
|
||||||
|
output = self.func(input)
|
||||||
|
# If the output is a runnable, invoke it
|
||||||
|
if isinstance(output, Runnable):
|
||||||
|
recursion_limit = config["recursion_limit"]
|
||||||
|
if recursion_limit <= 0:
|
||||||
|
raise RecursionError(
|
||||||
|
f"Recursion limit reached when invoking {self} with input {input}."
|
||||||
|
)
|
||||||
|
output = output.invoke(
|
||||||
|
input,
|
||||||
|
patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(),
|
||||||
|
recursion_limit=recursion_limit - 1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def _ainvoke(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Output:
|
||||||
|
output = await self.afunc(input)
|
||||||
|
# If the output is a runnable, invoke it
|
||||||
|
if isinstance(output, Runnable):
|
||||||
|
recursion_limit = config["recursion_limit"]
|
||||||
|
if recursion_limit <= 0:
|
||||||
|
raise RecursionError(
|
||||||
|
f"Recursion limit reached when invoking {self} with input {input}."
|
||||||
|
)
|
||||||
|
output = await output.ainvoke(
|
||||||
|
input,
|
||||||
|
patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(),
|
||||||
|
recursion_limit=recursion_limit - 1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
@ -1381,7 +1428,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
if hasattr(self, "func"):
|
if hasattr(self, "func"):
|
||||||
return self._call_with_config(self.func, input, config)
|
return self._call_with_config(self._invoke, input, config)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Cannot invoke a coroutine function synchronously."
|
"Cannot invoke a coroutine function synchronously."
|
||||||
@ -1395,7 +1442,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
if hasattr(self, "afunc"):
|
if hasattr(self, "afunc"):
|
||||||
return await self._acall_with_config(self.afunc, input, config)
|
return await self._acall_with_config(self._ainvoke, input, config)
|
||||||
else:
|
else:
|
||||||
return await super().ainvoke(input, config)
|
return await super().ainvoke(input, config)
|
||||||
|
|
||||||
|
@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
ThreadPoolExecutor will be created.
|
ThreadPoolExecutor will be created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
recursion_limit: int
|
||||||
|
"""
|
||||||
|
Maximum number of times a call can recurse. If not provided, defaults to 10.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
empty = RunnableConfig(
|
empty = RunnableConfig(
|
||||||
@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
metadata={},
|
metadata={},
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
_locals={},
|
||||||
|
recursion_limit=10,
|
||||||
)
|
)
|
||||||
if config is not None:
|
if config is not None:
|
||||||
empty.update(config)
|
empty.update(config)
|
||||||
@ -66,6 +72,7 @@ def patch_config(
|
|||||||
deep_copy_locals: bool = False,
|
deep_copy_locals: bool = False,
|
||||||
callbacks: Optional[BaseCallbackManager] = None,
|
callbacks: Optional[BaseCallbackManager] = None,
|
||||||
executor: Optional[Executor] = None,
|
executor: Optional[Executor] = None,
|
||||||
|
recursion_limit: Optional[int] = None,
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
if deep_copy_locals:
|
if deep_copy_locals:
|
||||||
@ -74,6 +81,8 @@ def patch_config(
|
|||||||
config["callbacks"] = callbacks
|
config["callbacks"] = callbacks
|
||||||
if executor is not None:
|
if executor is not None:
|
||||||
config["executor"] = executor
|
config["executor"] = executor
|
||||||
|
if recursion_limit is not None:
|
||||||
|
config["recursion_limit"] = recursion_limit
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
|||||||
from ast import Not
|
from operator import itemgetter
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@ -39,7 +39,6 @@ from langchain.schema.runnable import (
|
|||||||
RunnablePassthrough,
|
RunnablePassthrough,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
RunnableWithFallbacks,
|
RunnableWithFallbacks,
|
||||||
passthrough,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -178,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
tags=[],
|
tags=[],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
_locals={},
|
||||||
|
recursion_limit=10,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
mocker.call(
|
mocker.call(
|
||||||
@ -187,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
tags=[],
|
tags=[],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
_locals={},
|
||||||
|
recursion_limit=10,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@ -768,6 +769,105 @@ async def test_router_runnable(
|
|||||||
assert len(router_run.child_runs) == 2
|
assert len(router_run.child_runs) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
async def test_higher_order_lambda_runnable(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||||
|
) -> None:
|
||||||
|
math_chain = ChatPromptTemplate.from_template(
|
||||||
|
"You are a math genius. Answer the question: {question}"
|
||||||
|
) | FakeListLLM(responses=["4"])
|
||||||
|
english_chain = ChatPromptTemplate.from_template(
|
||||||
|
"You are an english major. Answer the question: {question}"
|
||||||
|
) | FakeListLLM(responses=["2"])
|
||||||
|
input_map = RunnableMap(
|
||||||
|
{ # type: ignore[arg-type]
|
||||||
|
"key": lambda x: x["key"],
|
||||||
|
"input": {"question": lambda x: x["question"]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def router(input: Dict[str, Any]) -> Runnable:
|
||||||
|
if input["key"] == "math":
|
||||||
|
return itemgetter("input") | math_chain
|
||||||
|
elif input["key"] == "english":
|
||||||
|
return itemgetter("input") | english_chain
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown key: {input['key']}")
|
||||||
|
|
||||||
|
chain: Runnable = input_map | router
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
result = chain.invoke({"key": "math", "question": "2 + 2"})
|
||||||
|
assert result == "4"
|
||||||
|
|
||||||
|
result2 = chain.batch(
|
||||||
|
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
|
||||||
|
)
|
||||||
|
assert result2 == ["4", "2"]
|
||||||
|
|
||||||
|
result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
|
||||||
|
assert result == "4"
|
||||||
|
|
||||||
|
result2 = await chain.abatch(
|
||||||
|
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
|
||||||
|
)
|
||||||
|
assert result2 == ["4", "2"]
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
math_spy = mocker.spy(math_chain.__class__, "invoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert (
|
||||||
|
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
|
||||||
|
== "4"
|
||||||
|
)
|
||||||
|
assert math_spy.call_args.args[1] == {
|
||||||
|
"key": "math",
|
||||||
|
"input": {"question": "2 + 2"},
|
||||||
|
}
|
||||||
|
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
||||||
|
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||||
|
assert len(parent_run.child_runs) == 2
|
||||||
|
router_run = parent_run.child_runs[1]
|
||||||
|
assert router_run.name == "RunnableLambda"
|
||||||
|
assert len(router_run.child_runs) == 1
|
||||||
|
math_run = router_run.child_runs[0]
|
||||||
|
assert math_run.name == "RunnableSequence"
|
||||||
|
assert len(math_run.child_runs) == 3
|
||||||
|
|
||||||
|
# Test ainvoke
|
||||||
|
async def arouter(input: Dict[str, Any]) -> Runnable:
|
||||||
|
if input["key"] == "math":
|
||||||
|
return itemgetter("input") | math_chain
|
||||||
|
elif input["key"] == "english":
|
||||||
|
return itemgetter("input") | english_chain
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown key: {input['key']}")
|
||||||
|
|
||||||
|
achain: Runnable = input_map | arouter
|
||||||
|
math_spy = mocker.spy(math_chain.__class__, "ainvoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert (
|
||||||
|
await achain.ainvoke(
|
||||||
|
{"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
|
||||||
|
)
|
||||||
|
== "4"
|
||||||
|
)
|
||||||
|
assert math_spy.call_args.args[1] == {
|
||||||
|
"key": "math",
|
||||||
|
"input": {"question": "2 + 2"},
|
||||||
|
}
|
||||||
|
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
||||||
|
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||||
|
assert len(parent_run.child_runs) == 2
|
||||||
|
router_run = parent_run.child_runs[1]
|
||||||
|
assert router_run.name == "RunnableLambda"
|
||||||
|
assert len(router_run.child_runs) == 1
|
||||||
|
math_run = router_run.child_runs[0]
|
||||||
|
assert math_run.name == "RunnableSequence"
|
||||||
|
assert len(math_run.child_runs) == 3
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
|
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
|
||||||
passthrough = mocker.Mock(side_effect=lambda x: x)
|
passthrough = mocker.Mock(side_effect=lambda x: x)
|
||||||
|
Loading…
Reference in New Issue
Block a user