This commit is contained in:
Bagatur 2023-08-17 17:28:02 -07:00
parent 6f69b19ff5
commit ab21af71be
2 changed files with 44 additions and 10 deletions

View File

@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
config = config or {} config = config or _empty_config()
callback_manager = _get_callback_manager(config) callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses.""" with callbacks. Use this method to implement ainvoke() in subclasses."""
config = config or {} config = config or _empty_config()
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None final_output: Optional[Output] = None
final_output_supported = True final_output_supported = True
config = config or {} config = config or _empty_config()
callback_manager = _get_callback_manager(config) callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None final_output: Optional[Output] = None
final_output_supported = True final_output_supported = True
config = config or {} config = config or _empty_config()
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks # setup callbacks
config = config or {} config = config or _empty_config()
callback_manager = _get_callback_manager(config) callback_manager = _get_callback_manager(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Output: ) -> Output:
# setup callbacks # setup callbacks
config = config or {} config = config or _empty_config()
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
config = config or {} config = config or _empty_config()
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
local_callbacks=None, local_callbacks=None,
@ -1091,7 +1091,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.invoke, step.invoke,
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, run_manager.get_child()), patch_config(deepcopy(config), run_manager.get_child()),
) )
for step in steps.values() for step in steps.values()
] ]
@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# setup callbacks # setup callbacks
config = config or {} config = config or _empty_config()
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(

View File

@ -1,6 +1,13 @@
import pytest import pytest
from langchain.schema.runnable import GetLocalVar, PutLocalVar from langchain import PromptTemplate
from langchain.llms import FakeListLLM
from langchain.schema.runnable import (
GetLocalVar,
PutLocalVar,
RunnablePassthrough,
RunnableSequence,
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -29,3 +36,30 @@ def test_get_missing_var_invoke() -> None:
with pytest.raises(KeyError): with pytest.raises(KeyError):
runnable.invoke("foo") runnable.invoke("foo")
def test_get_in_map() -> None:
runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
assert runnable.invoke("foo") == {"bar": "foo"}
def test_cant_put_in_map() -> None:
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
with pytest.raises(KeyError):
runnable.invoke("foo")
def test_get_passthrough_key() -> None:
runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output")
assert runnable.invoke("foo") == {"input": "foo", "output": "foo"}
def test_multi_step_sequence() -> None:
prompt = PromptTemplate.from_template("say {foo}")
runnable = (
PutLocalVar("foo")
| {"foo": RunnablePassthrough()}
| prompt
| FakeListLLM(responses=["bar"])
| GetLocalVar("foo", passthrough_key="output")
)
assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"}