From ab21af71be3c5a2fbe548061228df525c635ba86 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 17:28:02 -0700 Subject: [PATCH] wip --- .../langchain/schema/runnable/base.py | 18 +++++----- .../unit_tests/schema/runnable/test_locals.py | 36 ++++++++++++++++++- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 704a518cde8..c91456394e2 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - config = config or {} + config = config or _empty_config() callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), @@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), @@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or {} + config = config or _empty_config() callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), @@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), @@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): from langchain.callbacks.manager import CallbackManager # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), local_callbacks=None, @@ -1091,7 +1091,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): step.invoke, input, # mark each step as a child run - patch_config(config, run_manager.get_child()), + patch_config(deepcopy(config), run_manager.get_child()), ) for step in steps.values() ] @@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Dict[str, Any]: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py index d0a3fb38d9d..dce548fc69f 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -1,6 +1,13 @@ import pytest -from langchain.schema.runnable import GetLocalVar, PutLocalVar +from langchain import PromptTemplate +from langchain.llms import FakeListLLM +from langchain.schema.runnable import ( + GetLocalVar, + PutLocalVar, + RunnablePassthrough, + RunnableSequence, +) @pytest.mark.asyncio @@ -29,3 +36,30 @@ def test_get_missing_var_invoke() -> None: with pytest.raises(KeyError): runnable.invoke("foo") + +def test_get_in_map() -> None: + runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")} + assert runnable.invoke("foo") == {"bar": "foo"} + + +def test_cant_put_in_map() -> None: + runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input") + with pytest.raises(KeyError): + runnable.invoke("foo") + + +def test_get_passthrough_key() -> None: + runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output") + assert runnable.invoke("foo") == {"input": "foo", "output": "foo"} + + +def test_multi_step_sequence() -> None: + prompt = PromptTemplate.from_template("say {foo}") + runnable = ( + PutLocalVar("foo") + | {"foo": RunnablePassthrough()} + | prompt + | FakeListLLM(responses=["bar"]) + | GetLocalVar("foo", passthrough_key="output") + ) + assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"}