core[patch]: Update AsyncCallbackManager to honor run_inline attribute and prevent context loss (#26885)

## Description

This PR fixes a context-loss issue in `AsyncCallbackManager`, specifically
in the `on_llm_start` and `on_chat_model_start` methods, by properly
honoring the `run_inline` attribute of callback handlers and thereby
preventing race conditions and ordering issues.

Key changes:
1. Separate handlers into inline and non-inline groups.
2. Execute inline handlers sequentially for each prompt.
3. Execute non-inline handlers concurrently across all prompts.
4. Preserve context for stateful handlers.
5. Maintain performance benefits for non-inline handlers.
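To make the scheduling change concrete, here is a minimal, self-contained sketch of the pattern (an illustration only, not the library code; `_Handler` and `dispatch` are hypothetical names): inline handlers are awaited one at a time so their callbacks fire in order, while non-inline handlers are still gathered concurrently.

```python
# Hypothetical illustration of the run_inline split; not langchain-core code.
import asyncio


class _Handler:
    """Stand-in for a callback handler exposing a run_inline flag."""

    def __init__(self, name: str, run_inline: bool) -> None:
        self.name = name
        self.run_inline = run_inline

    async def on_llm_start(self, prompt: str) -> None:
        print(f"{self.name} saw {prompt}")


async def dispatch(handlers: list[_Handler], prompts: list[str]) -> None:
    inline = [h for h in handlers if h.run_inline]
    non_inline = [h for h in handlers if not h.run_inline]

    inline_tasks = [h.on_llm_start(p) for p in prompts for h in inline]
    non_inline_tasks = [h.on_llm_start(p) for p in prompts for h in non_inline]

    # Inline callbacks run sequentially, preserving per-prompt order and context.
    for coro in inline_tasks:
        await coro

    # Non-inline callbacks keep the concurrent fast path.
    if non_inline_tasks:
        await asyncio.gather(*non_inline_tasks)


asyncio.run(dispatch([_Handler("inline", True), _Handler("bg", False)], ["p1", "p2"]))
```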

**These changes are implemented in `AsyncCallbackManager` rather than
`ahandle_event` because the issue occurs at the prompt and message_list
levels, not within individual events.**

## Testing

- Test case implemented in #26857 now passes, verifying execution order
for inline handlers.
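For reference, a rough sketch of the kind of ordering check involved (this is not the test from #26857; `RecordingHandler` and the fake `serialized` payload are invented for illustration):

```python
# Sketch of an ordering check for an inline handler; not the actual test code.
import asyncio
from typing import Any

from langchain_core.callbacks import AsyncCallbackHandler, AsyncCallbackManager


class RecordingHandler(AsyncCallbackHandler):
    """Hypothetical handler that records prompts in the order it sees them."""

    run_inline = True  # ask the manager to await this handler sequentially

    def __init__(self) -> None:
        self.seen: list[str] = []

    async def on_llm_start(
        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
    ) -> None:
        await asyncio.sleep(0)  # yield control; order should still be preserved
        self.seen.extend(prompts)


async def main() -> None:
    handler = RecordingHandler()
    manager = AsyncCallbackManager(handlers=[handler])
    await manager.on_llm_start({"name": "fake-llm"}, ["p1", "p2", "p3"])
    assert handler.seen == ["p1", "p2", "p3"]


asyncio.run(main())
```

With `run_inline = True`, the manager awaits the handler's coroutine for `p1` before starting `p2`, so stateful handlers no longer observe interleaved or out-of-order callbacks.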

## Related Issues

- Fixes the issue discussed in #23909

## Dependencies

No new dependencies are required.

---

@eyurtsev: This PR implements the discussed changes to respect
`run_inline` in `AsyncCallbackManager`. Please review and advise on any
needed changes.

Twitter handle: @parambharat

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Commit 931ce8d026 (parent c61b9daef5), authored by Bharat Ramanathan on 2024-10-08 00:29:29 +05:30 and committed via GitHub.
2 changed files with 207 additions and 21 deletions.


@@ -1729,7 +1729,12 @@ class AsyncCallbackManager(BaseCallbackManager):
                 to each prompt.
         """
-        tasks = []
+        inline_tasks = []
+        non_inline_tasks = []
+        inline_handlers = [handler for handler in self.handlers if handler.run_inline]
+        non_inline_handlers = [
+            handler for handler in self.handlers if not handler.run_inline
+        ]
         managers = []
         for prompt in prompts:
@@ -1739,20 +1744,36 @@ class AsyncCallbackManager(BaseCallbackManager):
             else:
                 run_id_ = uuid.uuid4()
-            tasks.append(
-                ahandle_event(
-                    self.handlers,
-                    "on_llm_start",
-                    "ignore_llm",
-                    serialized,
-                    [prompt],
-                    run_id=run_id_,
-                    parent_run_id=self.parent_run_id,
-                    tags=self.tags,
-                    metadata=self.metadata,
-                    **kwargs,
-                )
-            )
+            if inline_handlers:
+                inline_tasks.append(
+                    ahandle_event(
+                        inline_handlers,
+                        "on_llm_start",
+                        "ignore_llm",
+                        serialized,
+                        [prompt],
+                        run_id=run_id_,
+                        parent_run_id=self.parent_run_id,
+                        tags=self.tags,
+                        metadata=self.metadata,
+                        **kwargs,
+                    )
+                )
+            else:
+                non_inline_tasks.append(
+                    ahandle_event(
+                        non_inline_handlers,
+                        "on_llm_start",
+                        "ignore_llm",
+                        serialized,
+                        [prompt],
+                        run_id=run_id_,
+                        parent_run_id=self.parent_run_id,
+                        tags=self.tags,
+                        metadata=self.metadata,
+                        **kwargs,
+                    )
+                )
             managers.append(
                 AsyncCallbackManagerForLLMRun(
@@ -1767,7 +1788,13 @@ class AsyncCallbackManager(BaseCallbackManager):
                 )
             )
-        await asyncio.gather(*tasks)
+        # Run inline tasks sequentially
+        for inline_task in inline_tasks:
+            await inline_task
+        # Run non-inline tasks concurrently
+        if non_inline_tasks:
+            await asyncio.gather(*non_inline_tasks)
         return managers
@@ -1791,7 +1818,8 @@ class AsyncCallbackManager(BaseCallbackManager):
                 async callback managers, one for each LLM Run
                 corresponding to each inner message list.
         """
-        tasks = []
+        inline_tasks = []
+        non_inline_tasks = []
         managers = []
         for message_list in messages:
@@ -1801,9 +1829,9 @@ class AsyncCallbackManager(BaseCallbackManager):
             else:
                 run_id_ = uuid.uuid4()
-            tasks.append(
-                ahandle_event(
-                    self.handlers,
+            for handler in self.handlers:
+                task = ahandle_event(
+                    [handler],
                     "on_chat_model_start",
                     "ignore_chat_model",
                     serialized,
@@ -1814,7 +1842,10 @@ class AsyncCallbackManager(BaseCallbackManager):
                     metadata=self.metadata,
                     **kwargs,
                 )
-            )
+                if handler.run_inline:
+                    inline_tasks.append(task)
+                else:
+                    non_inline_tasks.append(task)
             managers.append(
                 AsyncCallbackManagerForLLMRun(
@@ -1829,7 +1860,14 @@ class AsyncCallbackManager(BaseCallbackManager):
                 )
             )
-        await asyncio.gather(*tasks)
+        # Run inline tasks sequentially
+        for task in inline_tasks:
+            await task
+        # Run non-inline tasks concurrently
+        if non_inline_tasks:
+            await asyncio.gather(*non_inline_tasks)
         return managers
     async def on_chain_start(