mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
core[patch]: Deduplicate of callback handlers in merge_configs (#22478)
This PR adds deduplication of callback handlers in merge_configs. Fix for this issue: https://github.com/langchain-ai/langchain/issues/22227 The issue appears when the code is: 1) running python >=3.11 2) invokes a runnable from within a runnable 3) binds the callbacks to the child runnable from the parent runnable using with_config In this case, the same callbacks end up appearing twice: (1) the first time from with_config, (2) the second time with langchain automatically propagating them on behalf of the user. Prior to this PR this will emit duplicate events: ```python @tool async def get_items(question: str, callbacks: Callbacks): # <--- Accept callbacks """Ask question""" template = ChatPromptTemplate.from_messages( [ ( "human", "'{question}" ) ] ) chain = template | chat_model.with_config( { "callbacks": callbacks, # <-- Propagate callbacks } ) return await chain.ainvoke({"question": question}) ``` Prior to this PR this will work work correctly (no duplicate events): ```python @tool async def get_items(question: str, callbacks: Callbacks): # <--- Accept callbacks """Ask question""" template = ChatPromptTemplate.from_messages( [ ( "human", "'{question}" ) ] ) chain = template | chat_model return await chain.ainvoke({"question": question}, {"callbacks": callbacks}) ``` This will also work (as long as the user is using python >= 3.11) -- as langchain will automatically propagate callbacks ```python @tool async def get_items(question: str,): """Ask question""" template = ChatPromptTemplate.from_messages( [ ( "human", "'{question}" ) ] ) chain = template | chat_model return await chain.ainvoke({"question": question}) ```
This commit is contained in:
parent
64dbc52cae
commit
9120cf5df2
@ -305,12 +305,12 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
base["callbacks"] = mngr
|
||||
else:
|
||||
# base_callbacks is also a manager
|
||||
base["callbacks"] = base_callbacks.__class__(
|
||||
|
||||
manager = base_callbacks.__class__(
|
||||
parent_run_id=base_callbacks.parent_run_id
|
||||
or these_callbacks.parent_run_id,
|
||||
handlers=base_callbacks.handlers + these_callbacks.handlers,
|
||||
inheritable_handlers=base_callbacks.inheritable_handlers
|
||||
+ these_callbacks.inheritable_handlers,
|
||||
handlers=[],
|
||||
inheritable_handlers=[],
|
||||
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
|
||||
inheritable_tags=list(
|
||||
set(
|
||||
@ -323,6 +323,20 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
**these_callbacks.metadata,
|
||||
},
|
||||
)
|
||||
|
||||
handlers = base_callbacks.handlers + these_callbacks.handlers
|
||||
inheritable_handlers = (
|
||||
base_callbacks.inheritable_handlers
|
||||
+ these_callbacks.inheritable_handlers
|
||||
)
|
||||
|
||||
for handler in handlers:
|
||||
manager.add_handler(handler)
|
||||
|
||||
for handler in inheritable_handlers:
|
||||
manager.add_handler(handler, inherit=True)
|
||||
|
||||
base["callbacks"] = manager
|
||||
else:
|
||||
base[key] = config[key] or base.get(key) # type: ignore
|
||||
return base
|
||||
|
@ -1876,3 +1876,34 @@ async def test_runnable_generator() -> None:
|
||||
"tags": [],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def test_with_explicit_config() -> None:
|
||||
"""Test astream events with explicit callbacks being passed."""
|
||||
infinite_cycle = cycle([AIMessage(content="hello world", id="ai3")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
@tool
|
||||
async def say_hello(query: str, callbacks: Callbacks) -> BaseMessage:
|
||||
"""Use this tool to look up which items are in the given place."""
|
||||
|
||||
@RunnableLambda
|
||||
def passthrough_to_trigger_issue(x: str) -> str:
|
||||
"""Add passthrough to trigger issue."""
|
||||
return x
|
||||
|
||||
chain = passthrough_to_trigger_issue | model.with_config(
|
||||
{"tags": ["hello"], "callbacks": callbacks}
|
||||
)
|
||||
|
||||
return await chain.ainvoke(query)
|
||||
|
||||
events = await _collect_events(
|
||||
say_hello.astream_events("meow", version="v2") # type: ignore
|
||||
)
|
||||
|
||||
assert [
|
||||
event["data"]["chunk"].content
|
||||
for event in events
|
||||
if event["event"] == "on_chat_model_stream"
|
||||
] == ["hello", " ", "world"]
|
||||
|
Loading…
Reference in New Issue
Block a user