From ad16eed11969c42388d9f2a499eb47ae87701fbf Mon Sep 17 00:00:00 2001 From: WU LIFU Date: Fri, 2 Aug 2024 08:30:32 +0800 Subject: [PATCH] =?UTF-8?q?core[patch]:=20runnable=20config=20ensure=5Fcon?= =?UTF-8?q?fig=20deep=20copy=20from=20var=5Fchild=5Frunnable=E2=80=A6=20(#?= =?UTF-8?q?24862)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **issue**: #24660 RunnableWithMessageHistory.stream results in an error because the [evaluation](https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/runnables/branch.py#L220) of the branch [condition](https://github.com/langchain-ai/langchain/blob/99eb31ec4121e146e007ea584b2d5f4aff2a4337/libs/core/langchain_core/runnables/history.py#L328C1-L329C1) unexpectedly triggers the "[on_end](https://github.com/langchain-ai/langchain/blob/99eb31ec4121e146e007ea584b2d5f4aff2a4337/libs/core/langchain_core/runnables/history.py#L332)" (exit_history) callback of the default branch **description** After a lot of investigation, I'm convinced that the root cause is that 1. during the execution of the runnable, the [var_child_runnable_config](https://github.com/langchain-ai/langchain/blob/99eb31ec4121e146e007ea584b2d5f4aff2a4337/libs/core/langchain_core/runnables/config.py#L122) is shared between the branch [condition](https://github.com/langchain-ai/langchain/blob/99eb31ec4121e146e007ea584b2d5f4aff2a4337/libs/core/langchain_core/runnables/history.py#L328C1-L329C1) runnable and the [default branch runnable](https://github.com/langchain-ai/langchain/blob/99eb31ec4121e146e007ea584b2d5f4aff2a4337/libs/core/langchain_core/runnables/history.py#L332) within the same context 2. 
when the default branch runnable runs, it gets the [var_child_runnable_config](https://github.com/langchain-ai/langchain/blob/99eb31ec4121e146e007ea584b2d5f4aff2a4337/libs/core/langchain_core/runnables/config.py#L163) and may unintentionally [add more handlers](https://github.com/langchain-ai/langchain/blob/99eb31ec4121e146e007ea584b2d5f4aff2a4337/libs/core/langchain_core/runnables/config.py#L325) to the callback manager of this config 3. when it is again the turn for the [condition](https://github.com/langchain-ai/langchain/blob/99eb31ec4121e146e007ea584b2d5f4aff2a4337/libs/core/langchain_core/runnables/history.py#L328C1-L329C1) to run, it gets the `var_child_runnable_config` whose callback manager has the handlers added by the default branch. When it runs that handler (`exit_history`), it leads to the error. Assuming that the `ensure_config` function actually does want to create an immutable copy from `var_child_runnable_config`, because it starts with an [`empty` variable](https://github.com/langchain-ai/langchain/blob/99eb31ec4121e146e007ea584b2d5f4aff2a4337/libs/core/langchain_core/runnables/config.py#L156), I went ahead and did a deepcopy to ensure that future modifications to the returned value won't affect the `var_child_runnable_config` variable. Having said that, I actually 1. don't know if this is a proper fix 2. don't know whether it will lead to other unintended consequences 3. don't know why only "stream" runs into this issue while "invoke" runs without problems, so @nfcampos @hwchase17 please help review, thanks! 
--------- Co-authored-by: Lifu Wu Co-authored-by: Nuno Campos Co-authored-by: Bagatur --- libs/core/langchain_core/runnables/config.py | 42 +++++++++++---- .../tests/unit_tests/runnables/test_config.py | 54 ++++++++++++++++++- .../unit_tests/runnables/test_history.py | 10 ++++ .../unit_tests/runnables/test_runnable.py | 4 +- 4 files changed, 98 insertions(+), 12 deletions(-) diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 33174270a72..0d71b7fa9f7 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -116,6 +116,13 @@ CONFIG_KEYS = [ "run_id", ] +COPIABLE_KEYS = [ + "tags", + "metadata", + "callbacks", + "configurable", +] + DEFAULT_RECURSION_LIMIT = 25 @@ -162,15 +169,30 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: ) if var_config := var_child_runnable_config.get(): empty.update( - cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None}) + cast( + RunnableConfig, + { + k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined] + for k, v in var_config.items() + if v is not None + }, + ) + ) + if config is not None: + empty.update( + cast( + RunnableConfig, + { + k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined] + for k, v in config.items() + if v is not None and k in CONFIG_KEYS + }, + ) ) if config is not None: for k, v in config.items(): - if v is not None: - if k in CONFIG_KEYS: - empty[k] = v # type: ignore[literal-required] - else: - empty["configurable"][k] = v + if k not in CONFIG_KEYS and v is not None: + empty["configurable"][k] = v for key, value in empty.get("configurable", {}).items(): if ( not key.startswith("__") @@ -291,7 +313,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: **(config.get(key) or {}), # type: ignore } elif key == "tags": - base[key] = list( # type: ignore + base[key] = sorted( # type: ignore 
set(base.get(key, []) + (config.get(key) or [])), # type: ignore ) elif key == "configurable": @@ -306,7 +328,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: # so merging two callbacks values has 6 cases if isinstance(these_callbacks, list): if base_callbacks is None: - base["callbacks"] = these_callbacks + base["callbacks"] = these_callbacks.copy() elif isinstance(base_callbacks, list): base["callbacks"] = base_callbacks + these_callbacks else: @@ -318,7 +340,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: elif these_callbacks is not None: # these_callbacks is a manager if base_callbacks is None: - base["callbacks"] = these_callbacks + base["callbacks"] = these_callbacks.copy() elif isinstance(base_callbacks, list): mngr = these_callbacks.copy() for callback in base_callbacks: @@ -361,6 +383,8 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: elif key == "recursion_limit": if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT: base["recursion_limit"] = config["recursion_limit"] + elif key in COPIABLE_KEYS and config[key] is not None: # type: ignore[literal-required] + base[key] = config[key].copy() # type: ignore[literal-required] else: base[key] = config[key] or base.get(key) # type: ignore return base diff --git a/libs/core/tests/unit_tests/runnables/test_config.py b/libs/core/tests/unit_tests/runnables/test_config.py index 6e68b40484d..a8118bbd3bb 100644 --- a/libs/core/tests/unit_tests/runnables/test_config.py +++ b/libs/core/tests/unit_tests/runnables/test_config.py @@ -1,4 +1,7 @@ -from typing import Any, cast +import json +import uuid +from contextvars import copy_context +from typing import Any, Dict, cast import pytest @@ -8,12 +11,61 @@ from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHan from langchain_core.runnables import RunnableBinding, RunnablePassthrough from langchain_core.runnables.config import ( RunnableConfig, + 
_set_config_context, + ensure_config, merge_configs, run_in_executor, ) from langchain_core.tracers.stdout import ConsoleCallbackHandler +def test_ensure_config() -> None: + run_id = str(uuid.uuid4()) + arg: Dict = { + "something": "else", + "metadata": {"foo": "bar"}, + "configurable": {"baz": "qux"}, + "callbacks": [StdOutCallbackHandler()], + "tags": ["tag1", "tag2"], + "max_concurrency": 1, + "recursion_limit": 100, + "run_id": run_id, + "run_name": "test", + } + arg_str = json.dumps({**arg, "callbacks": []}) + ctx = copy_context() + ctx.run( + _set_config_context, + { + "callbacks": [ConsoleCallbackHandler()], + "metadata": {"a": "b"}, + "configurable": {"c": "d"}, + "tags": ["tag3", "tag4"], + }, + ) + config = ctx.run(ensure_config, cast(RunnableConfig, arg)) + assert ( + len(arg["callbacks"]) == 1 + ), "ensure_config should not modify the original config" + assert ( + json.dumps({**arg, "callbacks": []}) == arg_str + ), "ensure_config should not modify the original config" + assert config is not arg + assert config["callbacks"] is not arg["callbacks"] + assert config["metadata"] is not arg["metadata"] + assert config["configurable"] is not arg["configurable"] + assert config == { + "tags": ["tag1", "tag2"], + "metadata": {"foo": "bar", "baz": "qux", "something": "else"}, + "callbacks": [arg["callbacks"][0]], + "recursion_limit": 100, + "configurable": {"baz": "qux", "something": "else"}, + "max_concurrency": 1, + "run_id": run_id, + "run_name": "test", + } + + def test_merge_config_callbacks() -> None: manager: RunnableConfig = { "callbacks": CallbackManager(handlers=[StdOutCallbackHandler()]) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index e85789736be..3de59e22973 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -53,6 +53,8 @@ def test_input_messages() -> None: assert output == "you said: hello" 
output = with_history.invoke([HumanMessage(content="good bye")], config) assert output == "you said: hello\ngood bye" + output = [*with_history.stream([HumanMessage(content="hi again")], config)] + assert output == ["you said: hello\ngood bye\nhi again"] assert store == { "1": InMemoryChatMessageHistory( messages=[ @@ -60,6 +62,8 @@ def test_input_messages() -> None: AIMessage(content="you said: hello"), HumanMessage(content="good bye"), AIMessage(content="you said: hello\ngood bye"), + HumanMessage(content="hi again"), + AIMessage(content="you said: hello\ngood bye\nhi again"), ] ) } @@ -78,6 +82,10 @@ async def test_input_messages_async() -> None: assert output == "you said: hello" output = await with_history.ainvoke([HumanMessage(content="good bye")], config) # type: ignore[arg-type] assert output == "you said: hello\ngood bye" + output = [ + c + async for c in with_history.astream([HumanMessage(content="hi again")], config) # type: ignore[arg-type] + ] == ["you said: hello\ngood bye\nhi again"] assert store == { "1_async": InMemoryChatMessageHistory( messages=[ @@ -85,6 +93,8 @@ async def test_input_messages_async() -> None: AIMessage(content="you said: hello"), HumanMessage(content="good bye"), AIMessage(content="you said: hello\ngood bye"), + HumanMessage(content="hi again"), + AIMessage(content="you said: hello\ngood bye\nhi again"), ] ) } diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index c1044885c5e..25125be8c10 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5402,7 +5402,7 @@ def test_listeners() -> None: shared_state = {} value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}} - value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}} + value2 = {"inputs": {"name": "two"}, "outputs": {"name": "two"}} def on_start(run: Run) -> None: shared_state[run.id] = {"inputs": run.inputs} @@ 
-5432,7 +5432,7 @@ async def test_listeners_async() -> None: shared_state = {} value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}} - value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}} + value2 = {"inputs": {"name": "two"}, "outputs": {"name": "two"}} def on_start(run: Run) -> None: shared_state[run.id] = {"inputs": run.inputs}