This commit is contained in:
Bagatur 2023-08-17 17:28:02 -07:00
parent 6f69b19ff5
commit ab21af71be
2 changed files with 44 additions and 10 deletions

View File

@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
config = config or {}
config = config or _empty_config()
callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
config = config or {}
config = config or _empty_config()
callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None
final_output_supported = True
config = config or {}
config = config or _empty_config()
callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None
final_output_supported = True
config = config or {}
config = config or _empty_config()
callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks
config = config or {}
config = config or _empty_config()
callback_manager = _get_callback_manager(config)
# start the root run
run_manager = callback_manager.on_chain_start(
@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
# setup callbacks
config = config or {}
config = config or _empty_config()
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
config = config or _empty_config()
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
@ -1091,7 +1091,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.invoke,
input,
# mark each step as a child run
patch_config(config, run_manager.get_child()),
patch_config(deepcopy(config), run_manager.get_child()),
)
for step in steps.values()
]
@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
# setup callbacks
config = config or {}
config = config or _empty_config()
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(

View File

@ -1,6 +1,13 @@
import pytest
from langchain.schema.runnable import GetLocalVar, PutLocalVar
from langchain import PromptTemplate
from langchain.llms import FakeListLLM
from langchain.schema.runnable import (
GetLocalVar,
PutLocalVar,
RunnablePassthrough,
RunnableSequence,
)
@pytest.mark.asyncio
@ -29,3 +36,30 @@ def test_get_missing_var_invoke() -> None:
with pytest.raises(KeyError):
runnable.invoke("foo")
def test_get_in_map() -> None:
    """A value stored with PutLocalVar is readable by GetLocalVar inside a map step."""
    chain: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
    result = chain.invoke("foo")
    assert result == {"bar": "foo"}
def test_cant_put_in_map() -> None:
    """A var put inside a map step is not visible to later steps (lookup raises KeyError)."""
    chain: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
    with pytest.raises(KeyError):
        chain.invoke("foo")
def test_get_passthrough_key() -> None:
    """With passthrough_key, GetLocalVar returns the stored var alongside the piped input."""
    chain = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output")
    result = chain.invoke("foo")
    assert result == {"input": "foo", "output": "foo"}
def test_multi_step_sequence() -> None:
    """A local var set at the start of a multi-step chain is still retrievable at the end,
    merged (via passthrough_key) with the fake LLM's output."""
    fake_llm = FakeListLLM(responses=["bar"])
    template = PromptTemplate.from_template("say {foo}")
    chain = (
        PutLocalVar("foo")
        | {"foo": RunnablePassthrough()}
        | template
        | fake_llm
        | GetLocalVar("foo", passthrough_key="output")
    )
    result = chain.invoke("hello")
    assert result == {"foo": "hello", "output": "bar"}