mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 10:43:36 +00:00
core[patch]: runnable config ensure_config deep copy from var_child_runnable… (#24862)
**issue**: #24660 RunnableWithMessageHistory.stream result in error because the [evaluation](https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/runnables/branch.py#L220) of the branch [condition](99eb31ec41/libs/core/langchain_core/runnables/history.py (L328C1-L329C1)
) unexpectedly trigger the "[on_end](99eb31ec41/libs/core/langchain_core/runnables/history.py (L332)
)" (exit_history) callback of the default branch **descriptions** After a lot of investigation I'm convinced that the root cause is that 1. during the execution of the runnable, the [var_child_runnable_config](99eb31ec41/libs/core/langchain_core/runnables/config.py (L122)
) is shared between the branch [condition](99eb31ec41/libs/core/langchain_core/runnables/history.py (L328C1-L329C1)
) runnable and the [default branch runnable](99eb31ec41/libs/core/langchain_core/runnables/history.py (L332)
) within the same context 2. when the default branch runnable runs, it gets the [var_child_runnable_config](99eb31ec41/libs/core/langchain_core/runnables/config.py (L163)
) and may unintentionally [add more handlers ](99eb31ec41/libs/core/langchain_core/runnables/config.py (L325)
)to the callback manager of this config 3. when it is again the turn for the [condition](99eb31ec41/libs/core/langchain_core/runnables/history.py (L328C1-L329C1)
) to run, it gets the `var_child_runnable_config` whose callback manager has the handlers added by the default branch. When it runs that handler (`exit_history`) it leads to the error with the assumption that, the `ensure_config` function actually does want to create a immutable copy from `var_child_runnable_config` because it starts with an [`empty` variable ](99eb31ec41/libs/core/langchain_core/runnables/config.py (L156)
), i go ahead to do a deepcopy to ensure that future modification to the returned value won't affect the `var_child_runnable_config` variable Having said that I actually 1. don't know if this is a proper fix 2. don't know whether it will lead to other unintended consequence 3. don't know why only "stream" runs into this issue while "invoke" runs without problem so @nfcampos @hwchase17 please help review, thanks! --------- Co-authored-by: Lifu Wu <lifu@nextbillion.ai> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
3ab09d87d6
commit
ad16eed119
@ -116,6 +116,13 @@ CONFIG_KEYS = [
|
|||||||
"run_id",
|
"run_id",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
COPIABLE_KEYS = [
|
||||||
|
"tags",
|
||||||
|
"metadata",
|
||||||
|
"callbacks",
|
||||||
|
"configurable",
|
||||||
|
]
|
||||||
|
|
||||||
DEFAULT_RECURSION_LIMIT = 25
|
DEFAULT_RECURSION_LIMIT = 25
|
||||||
|
|
||||||
|
|
||||||
@ -162,14 +169,29 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
|||||||
)
|
)
|
||||||
if var_config := var_child_runnable_config.get():
|
if var_config := var_child_runnable_config.get():
|
||||||
empty.update(
|
empty.update(
|
||||||
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
|
cast(
|
||||||
|
RunnableConfig,
|
||||||
|
{
|
||||||
|
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
|
||||||
|
for k, v in var_config.items()
|
||||||
|
if v is not None
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if config is not None:
|
||||||
|
empty.update(
|
||||||
|
cast(
|
||||||
|
RunnableConfig,
|
||||||
|
{
|
||||||
|
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
|
||||||
|
for k, v in config.items()
|
||||||
|
if v is not None and k in CONFIG_KEYS
|
||||||
|
},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if config is not None:
|
if config is not None:
|
||||||
for k, v in config.items():
|
for k, v in config.items():
|
||||||
if v is not None:
|
if k not in CONFIG_KEYS and v is not None:
|
||||||
if k in CONFIG_KEYS:
|
|
||||||
empty[k] = v # type: ignore[literal-required]
|
|
||||||
else:
|
|
||||||
empty["configurable"][k] = v
|
empty["configurable"][k] = v
|
||||||
for key, value in empty.get("configurable", {}).items():
|
for key, value in empty.get("configurable", {}).items():
|
||||||
if (
|
if (
|
||||||
@ -291,7 +313,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
**(config.get(key) or {}), # type: ignore
|
**(config.get(key) or {}), # type: ignore
|
||||||
}
|
}
|
||||||
elif key == "tags":
|
elif key == "tags":
|
||||||
base[key] = list( # type: ignore
|
base[key] = sorted( # type: ignore
|
||||||
set(base.get(key, []) + (config.get(key) or [])), # type: ignore
|
set(base.get(key, []) + (config.get(key) or [])), # type: ignore
|
||||||
)
|
)
|
||||||
elif key == "configurable":
|
elif key == "configurable":
|
||||||
@ -306,7 +328,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
# so merging two callbacks values has 6 cases
|
# so merging two callbacks values has 6 cases
|
||||||
if isinstance(these_callbacks, list):
|
if isinstance(these_callbacks, list):
|
||||||
if base_callbacks is None:
|
if base_callbacks is None:
|
||||||
base["callbacks"] = these_callbacks
|
base["callbacks"] = these_callbacks.copy()
|
||||||
elif isinstance(base_callbacks, list):
|
elif isinstance(base_callbacks, list):
|
||||||
base["callbacks"] = base_callbacks + these_callbacks
|
base["callbacks"] = base_callbacks + these_callbacks
|
||||||
else:
|
else:
|
||||||
@ -318,7 +340,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
elif these_callbacks is not None:
|
elif these_callbacks is not None:
|
||||||
# these_callbacks is a manager
|
# these_callbacks is a manager
|
||||||
if base_callbacks is None:
|
if base_callbacks is None:
|
||||||
base["callbacks"] = these_callbacks
|
base["callbacks"] = these_callbacks.copy()
|
||||||
elif isinstance(base_callbacks, list):
|
elif isinstance(base_callbacks, list):
|
||||||
mngr = these_callbacks.copy()
|
mngr = these_callbacks.copy()
|
||||||
for callback in base_callbacks:
|
for callback in base_callbacks:
|
||||||
@ -361,6 +383,8 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
elif key == "recursion_limit":
|
elif key == "recursion_limit":
|
||||||
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
|
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
|
||||||
base["recursion_limit"] = config["recursion_limit"]
|
base["recursion_limit"] = config["recursion_limit"]
|
||||||
|
elif key in COPIABLE_KEYS and config[key] is not None: # type: ignore[literal-required]
|
||||||
|
base[key] = config[key].copy() # type: ignore[literal-required]
|
||||||
else:
|
else:
|
||||||
base[key] = config[key] or base.get(key) # type: ignore
|
base[key] = config[key] or base.get(key) # type: ignore
|
||||||
return base
|
return base
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
from typing import Any, cast
|
import json
|
||||||
|
import uuid
|
||||||
|
from contextvars import copy_context
|
||||||
|
from typing import Any, Dict, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -8,12 +11,61 @@ from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHan
|
|||||||
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
|
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
|
_set_config_context,
|
||||||
|
ensure_config,
|
||||||
merge_configs,
|
merge_configs,
|
||||||
run_in_executor,
|
run_in_executor,
|
||||||
)
|
)
|
||||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_config() -> None:
|
||||||
|
run_id = str(uuid.uuid4())
|
||||||
|
arg: Dict = {
|
||||||
|
"something": "else",
|
||||||
|
"metadata": {"foo": "bar"},
|
||||||
|
"configurable": {"baz": "qux"},
|
||||||
|
"callbacks": [StdOutCallbackHandler()],
|
||||||
|
"tags": ["tag1", "tag2"],
|
||||||
|
"max_concurrency": 1,
|
||||||
|
"recursion_limit": 100,
|
||||||
|
"run_id": run_id,
|
||||||
|
"run_name": "test",
|
||||||
|
}
|
||||||
|
arg_str = json.dumps({**arg, "callbacks": []})
|
||||||
|
ctx = copy_context()
|
||||||
|
ctx.run(
|
||||||
|
_set_config_context,
|
||||||
|
{
|
||||||
|
"callbacks": [ConsoleCallbackHandler()],
|
||||||
|
"metadata": {"a": "b"},
|
||||||
|
"configurable": {"c": "d"},
|
||||||
|
"tags": ["tag3", "tag4"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
config = ctx.run(ensure_config, cast(RunnableConfig, arg))
|
||||||
|
assert (
|
||||||
|
len(arg["callbacks"]) == 1
|
||||||
|
), "ensure_config should not modify the original config"
|
||||||
|
assert (
|
||||||
|
json.dumps({**arg, "callbacks": []}) == arg_str
|
||||||
|
), "ensure_config should not modify the original config"
|
||||||
|
assert config is not arg
|
||||||
|
assert config["callbacks"] is not arg["callbacks"]
|
||||||
|
assert config["metadata"] is not arg["metadata"]
|
||||||
|
assert config["configurable"] is not arg["configurable"]
|
||||||
|
assert config == {
|
||||||
|
"tags": ["tag1", "tag2"],
|
||||||
|
"metadata": {"foo": "bar", "baz": "qux", "something": "else"},
|
||||||
|
"callbacks": [arg["callbacks"][0]],
|
||||||
|
"recursion_limit": 100,
|
||||||
|
"configurable": {"baz": "qux", "something": "else"},
|
||||||
|
"max_concurrency": 1,
|
||||||
|
"run_id": run_id,
|
||||||
|
"run_name": "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_merge_config_callbacks() -> None:
|
def test_merge_config_callbacks() -> None:
|
||||||
manager: RunnableConfig = {
|
manager: RunnableConfig = {
|
||||||
"callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])
|
"callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])
|
||||||
|
@ -53,6 +53,8 @@ def test_input_messages() -> None:
|
|||||||
assert output == "you said: hello"
|
assert output == "you said: hello"
|
||||||
output = with_history.invoke([HumanMessage(content="good bye")], config)
|
output = with_history.invoke([HumanMessage(content="good bye")], config)
|
||||||
assert output == "you said: hello\ngood bye"
|
assert output == "you said: hello\ngood bye"
|
||||||
|
output = [*with_history.stream([HumanMessage(content="hi again")], config)]
|
||||||
|
assert output == ["you said: hello\ngood bye\nhi again"]
|
||||||
assert store == {
|
assert store == {
|
||||||
"1": InMemoryChatMessageHistory(
|
"1": InMemoryChatMessageHistory(
|
||||||
messages=[
|
messages=[
|
||||||
@ -60,6 +62,8 @@ def test_input_messages() -> None:
|
|||||||
AIMessage(content="you said: hello"),
|
AIMessage(content="you said: hello"),
|
||||||
HumanMessage(content="good bye"),
|
HumanMessage(content="good bye"),
|
||||||
AIMessage(content="you said: hello\ngood bye"),
|
AIMessage(content="you said: hello\ngood bye"),
|
||||||
|
HumanMessage(content="hi again"),
|
||||||
|
AIMessage(content="you said: hello\ngood bye\nhi again"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -78,6 +82,10 @@ async def test_input_messages_async() -> None:
|
|||||||
assert output == "you said: hello"
|
assert output == "you said: hello"
|
||||||
output = await with_history.ainvoke([HumanMessage(content="good bye")], config) # type: ignore[arg-type]
|
output = await with_history.ainvoke([HumanMessage(content="good bye")], config) # type: ignore[arg-type]
|
||||||
assert output == "you said: hello\ngood bye"
|
assert output == "you said: hello\ngood bye"
|
||||||
|
output = [
|
||||||
|
c
|
||||||
|
async for c in with_history.astream([HumanMessage(content="hi again")], config) # type: ignore[arg-type]
|
||||||
|
] == ["you said: hello\ngood bye\nhi again"]
|
||||||
assert store == {
|
assert store == {
|
||||||
"1_async": InMemoryChatMessageHistory(
|
"1_async": InMemoryChatMessageHistory(
|
||||||
messages=[
|
messages=[
|
||||||
@ -85,6 +93,8 @@ async def test_input_messages_async() -> None:
|
|||||||
AIMessage(content="you said: hello"),
|
AIMessage(content="you said: hello"),
|
||||||
HumanMessage(content="good bye"),
|
HumanMessage(content="good bye"),
|
||||||
AIMessage(content="you said: hello\ngood bye"),
|
AIMessage(content="you said: hello\ngood bye"),
|
||||||
|
HumanMessage(content="hi again"),
|
||||||
|
AIMessage(content="you said: hello\ngood bye\nhi again"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -5402,7 +5402,7 @@ def test_listeners() -> None:
|
|||||||
|
|
||||||
shared_state = {}
|
shared_state = {}
|
||||||
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||||
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
value2 = {"inputs": {"name": "two"}, "outputs": {"name": "two"}}
|
||||||
|
|
||||||
def on_start(run: Run) -> None:
|
def on_start(run: Run) -> None:
|
||||||
shared_state[run.id] = {"inputs": run.inputs}
|
shared_state[run.id] = {"inputs": run.inputs}
|
||||||
@ -5432,7 +5432,7 @@ async def test_listeners_async() -> None:
|
|||||||
|
|
||||||
shared_state = {}
|
shared_state = {}
|
||||||
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||||
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
value2 = {"inputs": {"name": "two"}, "outputs": {"name": "two"}}
|
||||||
|
|
||||||
def on_start(run: Run) -> None:
|
def on_start(run: Run) -> None:
|
||||||
shared_state[run.id] = {"inputs": run.inputs}
|
shared_state[run.id] = {"inputs": run.inputs}
|
||||||
|
Loading…
Reference in New Issue
Block a user