mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
core, community: propagate context between threads (#15171)
While using `chain.batch`, the default implementation uses a `ThreadPoolExecutor` and runs the chains in separate threads. An issue with this approach is that [the token counting callback](https://python.langchain.com/docs/modules/callbacks/token_counting) fails to work, because the context is not propagated between threads. This PR adds context propagation to the new threads and adds thread synchronization in the OpenAI callback. With this change, the token counting callback works as intended. The context propagation change should also be highly beneficial for anyone implementing custom callbacks with similar functionality.

---------

Co-authored-by: Nuno Campos <nuno@langchain.dev>
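For background, this is the failure mode being fixed: a `contextvars.ContextVar` set in the calling thread is invisible inside a plain `ThreadPoolExecutor` worker, because each worker thread starts with its own empty context. A minimal standalone sketch (the variable and function names are illustrative, not from this PR):

```python
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar

# Stand-in for per-run callback state kept in context variables.
request_id: ContextVar[str] = ContextVar("request_id", default="<unset>")


def read_var() -> str:
    # Executes in a worker thread, which starts with an empty context.
    return request_id.get()


request_id.set("run-42")
with ThreadPoolExecutor(max_workers=1) as executor:
    # The worker does not see the value set above.
    print(executor.submit(read_var).result())  # -> "<unset>"
```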
This commit is contained in:
parent f74151b4e4
commit bf5385592e
@@ -1,4 +1,5 @@
 """Callback Handler that prints to std out."""
+import threading
 from typing import Any, Dict, List

 from langchain_core.callbacks import BaseCallbackHandler
@@ -154,6 +155,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
     successful_requests: int = 0
     total_cost: float = 0.0

+    def __init__(self) -> None:
+        super().__init__()
+        self._lock = threading.Lock()
+
     def __repr__(self) -> str:
         return (
             f"Tokens Used: {self.total_tokens}\n"
@@ -182,9 +187,13 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         """Collect token usage."""
         if response.llm_output is None:
            return None
-        self.successful_requests += 1
         if "token_usage" not in response.llm_output:
+            with self._lock:
+                self.successful_requests += 1
             return None
+
+        # compute tokens and cost for this request
         token_usage = response.llm_output["token_usage"]
         completion_tokens = token_usage.get("completion_tokens", 0)
         prompt_tokens = token_usage.get("prompt_tokens", 0)
@@ -194,10 +203,17 @@ class OpenAICallbackHandler(BaseCallbackHandler):
                 model_name, completion_tokens, is_completion=True
             )
             prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
-        self.total_cost += prompt_cost + completion_cost
-        self.total_tokens += token_usage.get("total_tokens", 0)
-        self.prompt_tokens += prompt_tokens
-        self.completion_tokens += completion_tokens
+        else:
+            completion_cost = 0
+            prompt_cost = 0
+
+        # update shared state behind lock
+        with self._lock:
+            self.total_cost += prompt_cost + completion_cost
+            self.total_tokens += token_usage.get("total_tokens", 0)
+            self.prompt_tokens += prompt_tokens
+            self.completion_tokens += completion_tokens
+            self.successful_requests += 1

     def __copy__(self) -> "OpenAICallbackHandler":
         """Return a copy of the callback handler."""
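The new lock matters because `chain.batch` can drive one handler instance from several worker threads at once, and `self.total_tokens += n` is a read-modify-write that can drop updates under contention. A toy sketch of the hazard and the fix (`Counter` is a hypothetical stand-in, not part of the codebase):

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class Counter:
    """Toy stand-in for the handler's token/cost accumulators."""

    def __init__(self) -> None:
        self.total = 0
        self._lock = threading.Lock()

    def add_unsafe(self, n: int) -> None:
        self.total += n  # read-modify-write; concurrent calls can lose updates

    def add_safe(self, n: int) -> None:
        with self._lock:  # serialize the update, as the handler now does
            self.total += n


counter = Counter()
with ThreadPoolExecutor(max_workers=8) as executor:
    for _ in range(10_000):
        executor.submit(counter.add_safe, 1)
print(counter.total)  # reliably 10000; with add_unsafe it can fall short
```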
@@ -6,6 +6,7 @@ import logging
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager, contextmanager
+from contextvars import Context, copy_context
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -271,12 +272,25 @@ def handle_event(
                 # we end up in a deadlock, as we'd have gotten here from a
                 # running coroutine, which we cannot interrupt to run this one.
                 # The solution is to create a new loop in a new thread.
-                with ThreadPoolExecutor(1) as executor:
+                with _executor_w_context(1) as executor:
                     executor.submit(_run_coros, coros).result()
             else:
                 _run_coros(coros)


+def _set_context(context: Context) -> None:
+    for var, value in context.items():
+        var.set(value)
+
+
+def _executor_w_context(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
+    return ThreadPoolExecutor(
+        max_workers=max_workers,
+        initializer=_set_context,
+        initargs=(copy_context(),),
+    )
+
+
 def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
     if hasattr(asyncio, "Runner"):
         # Python 3.11+
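`_executor_w_context` is the heart of the fix: `copy_context()` snapshots the caller's context once, and the executor's `initializer` replays every variable into each worker thread before it runs any task. The same pattern as a standalone sketch (names illustrative):

```python
from concurrent.futures import ThreadPoolExecutor
from contextvars import Context, ContextVar, copy_context

request_id: ContextVar[str] = ContextVar("request_id", default="<unset>")


def _set_context(context: Context) -> None:
    # Runs once per worker thread, replaying every captured variable
    # into that thread's own context.
    for var, value in context.items():
        var.set(value)


request_id.set("run-42")
with ThreadPoolExecutor(
    max_workers=2,
    initializer=_set_context,
    initargs=(copy_context(),),
) as executor:
    print(executor.submit(request_id.get).result())  # -> "run-42"
```

Note that the snapshot is taken when the executor is created; values set by the submitting thread after that point are not visible to the workers.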
@@ -301,6 +315,7 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:


 async def _ahandle_event_for_handler(
+    executor: ThreadPoolExecutor,
     handler: BaseCallbackHandler,
     event_name: str,
     ignore_condition_name: Optional[str],
@@ -317,12 +332,13 @@ async def _ahandle_event_for_handler(
                    event(*args, **kwargs)
                else:
                    await asyncio.get_event_loop().run_in_executor(
-                        None, functools.partial(event, *args, **kwargs)
+                        executor, functools.partial(event, *args, **kwargs)
                    )
    except NotImplementedError as e:
        if event_name == "on_chat_model_start":
            message_strings = [get_buffer_string(m) for m in args[1]]
            await _ahandle_event_for_handler(
+                executor,
                handler,
                "on_llm_start",
                "ignore_llm",
@@ -364,19 +380,25 @@ async def ahandle_event(
        *args: The arguments to pass to the event handler
        **kwargs: The keyword arguments to pass to the event handler
    """
-    for handler in [h for h in handlers if h.run_inline]:
-        await _ahandle_event_for_handler(
-            handler, event_name, ignore_condition_name, *args, **kwargs
-        )
-    await asyncio.gather(
-        *(
-            _ahandle_event_for_handler(
-                handler, event_name, ignore_condition_name, *args, **kwargs
-            )
-            for handler in handlers
-            if not handler.run_inline
-        )
-    )
+    with _executor_w_context() as executor:
+        for handler in [h for h in handlers if h.run_inline]:
+            await _ahandle_event_for_handler(
+                executor, handler, event_name, ignore_condition_name, *args, **kwargs
+            )
+        await asyncio.gather(
+            *(
+                _ahandle_event_for_handler(
+                    executor,
+                    handler,
+                    event_name,
+                    ignore_condition_name,
+                    *args,
+                    **kwargs,
+                )
+                for handler in handlers
+                if not handler.run_inline
+            )
+        )


 BRM = TypeVar("BRM", bound="BaseRunManager")
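The async path needs the same treatment: `run_in_executor(None, ...)` falls back to the loop's default executor, whose threads carry empty contexts, so `ahandle_event` now creates a context-propagating executor and threads it through to every handler call. A standalone sketch of dispatching a blocking callback from async code this way (names illustrative):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import Context, ContextVar, copy_context

run_name: ContextVar[str] = ContextVar("run_name", default="<unset>")


def _set_context(context: Context) -> None:
    for var, value in context.items():
        var.set(value)


def sync_handler() -> str:
    # A blocking callback that must still see the caller's context.
    return run_name.get()


async def main() -> None:
    run_name.set("batch-7")
    with ThreadPoolExecutor(
        initializer=_set_context, initargs=(copy_context(),)
    ) as executor:
        loop = asyncio.get_running_loop()
        print(await loop.run_in_executor(executor, sync_handler))  # -> "batch-7"


asyncio.run(main())
```

(`asyncio.to_thread` copies the caller's context automatically, but `run_in_executor` does not, which is why the explicit executor is needed here.)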
@@ -260,7 +260,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
        if type(self)._astream == BaseChatModel._astream:
            # model doesn't implement streaming, so use default implementation
            yield cast(
-                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
+                BaseMessageChunk,
+                await self.ainvoke(input, config=config, stop=stop, **kwargs),
            )
        else:
            config = config or {}
@@ -472,9 +472,10 @@ class Runnable(Generic[Input, Output], ABC):

        Subclasses should override this method if they can run asynchronously.
        """
-        return await asyncio.get_running_loop().run_in_executor(
-            None, partial(self.invoke, **kwargs), input, config
-        )
+        with get_executor_for_config(config) as executor:
+            return await asyncio.get_running_loop().run_in_executor(
+                executor, partial(self.invoke, **kwargs), input, config
+            )

    def batch(
        self,
@@ -2882,9 +2883,10 @@ class RunnableLambda(Runnable[Input, Output]):

            @wraps(self.func)
            async def f(*args, **kwargs):  # type: ignore[no-untyped-def]
-                return await asyncio.get_running_loop().run_in_executor(
-                    None, partial(self.func, **kwargs), *args
-                )
+                with get_executor_for_config(config) as executor:
+                    return await asyncio.get_running_loop().run_in_executor(
+                        executor, partial(self.func, **kwargs), *args
+                    )

            afunc = f

@@ -2,6 +2,7 @@ from __future__ import annotations

 from concurrent.futures import Executor, ThreadPoolExecutor
 from contextlib import contextmanager
+from contextvars import Context, copy_context
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -387,8 +388,15 @@ def get_async_callback_manager_for_config(
    )


+def _set_context(context: Context) -> None:
+    for var, value in context.items():
+        var.set(value)
+
+
 @contextmanager
-def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
+def get_executor_for_config(
+    config: Optional[RunnableConfig]
+) -> Generator[Executor, None, None]:
    """Get an executor for a config.

    Args:
@@ -397,5 +405,10 @@ def get_executor_for_config(
    Yields:
        Generator[Executor, None, None]: The executor.
    """
-    with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
+    config = config or {}
+    with ThreadPoolExecutor(
+        max_workers=config.get("max_concurrency"),
+        initializer=_set_context,
+        initargs=(copy_context(),),
+    ) as executor:
        yield executor
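With `config` now optional and the initializer wired in, any code path that fans work out over threads can reuse this helper and keep the caller's context intact. An illustrative use, assuming the helper is imported from `langchain_core.runnables.config` (the config values and submitted function are made up for the example):

```python
from langchain_core.runnables.config import get_executor_for_config

# "max_concurrency" bounds the worker count; context variables set by the
# caller (e.g. an active callback handler) stay visible in each worker.
with get_executor_for_config({"max_concurrency": 4}) as executor:
    futures = [executor.submit(len, text) for text in ["a", "bb", "ccc"]]
    print([f.result() for f in futures])  # -> [1, 2, 3]
```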