diff --git a/libs/community/langchain_community/callbacks/openai_info.py b/libs/community/langchain_community/callbacks/openai_info.py
index bf0c59b746e..58cb6aab3b8 100644
--- a/libs/community/langchain_community/callbacks/openai_info.py
+++ b/libs/community/langchain_community/callbacks/openai_info.py
@@ -1,4 +1,5 @@
 """Callback Handler that prints to std out."""
+import threading
 from typing import Any, Dict, List
 
 from langchain_core.callbacks import BaseCallbackHandler
@@ -154,6 +155,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
     successful_requests: int = 0
     total_cost: float = 0.0
 
+    def __init__(self) -> None:
+        super().__init__()
+        self._lock = threading.Lock()
+
     def __repr__(self) -> str:
         return (
             f"Tokens Used: {self.total_tokens}\n"
@@ -182,9 +187,13 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         """Collect token usage."""
         if response.llm_output is None:
             return None
-        self.successful_requests += 1
+
         if "token_usage" not in response.llm_output:
+            with self._lock:
+                self.successful_requests += 1
             return None
+
+        # compute tokens and cost for this request
         token_usage = response.llm_output["token_usage"]
         completion_tokens = token_usage.get("completion_tokens", 0)
         prompt_tokens = token_usage.get("prompt_tokens", 0)
@@ -194,10 +203,17 @@ class OpenAICallbackHandler(BaseCallbackHandler):
                 model_name, completion_tokens, is_completion=True
             )
             prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
+        else:
+            completion_cost = 0
+            prompt_cost = 0
+
+        # update shared state behind lock
+        with self._lock:
             self.total_cost += prompt_cost + completion_cost
-        self.total_tokens += token_usage.get("total_tokens", 0)
-        self.prompt_tokens += prompt_tokens
-        self.completion_tokens += completion_tokens
+            self.total_tokens += token_usage.get("total_tokens", 0)
+            self.prompt_tokens += prompt_tokens
+            self.completion_tokens += completion_tokens
+            self.successful_requests += 1
 
     def __copy__(self) -> "OpenAICallbackHandler":
         """Return a copy of the callback handler."""
diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py
index b1bb0119279..8a799a3f76d 100644
--- a/libs/core/langchain_core/callbacks/manager.py
+++ b/libs/core/langchain_core/callbacks/manager.py
@@ -6,6 +6,7 @@ import logging
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager, contextmanager
+from contextvars import Context, copy_context
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -271,12 +272,25 @@ def handle_event(
                 # we end up in a deadlock, as we'd have gotten here from a
                 # running coroutine, which we cannot interrupt to run this one.
                 # The solution is to create a new loop in a new thread.
-                with ThreadPoolExecutor(1) as executor:
+                with _executor_w_context(1) as executor:
                     executor.submit(_run_coros, coros).result()
             else:
                 _run_coros(coros)
 
 
+def _set_context(context: Context) -> None:
+    for var, value in context.items():
+        var.set(value)
+
+
+def _executor_w_context(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
+    return ThreadPoolExecutor(
+        max_workers=max_workers,
+        initializer=_set_context,
+        initargs=(copy_context(),),
+    )
+
+
 def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
     if hasattr(asyncio, "Runner"):
         # Python 3.11+
@@ -301,6 +315,7 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
 
 
 async def _ahandle_event_for_handler(
+    executor: ThreadPoolExecutor,
     handler: BaseCallbackHandler,
     event_name: str,
     ignore_condition_name: Optional[str],
@@ -317,12 +332,13 @@ async def _ahandle_event_for_handler(
                     event(*args, **kwargs)
                 else:
                     await asyncio.get_event_loop().run_in_executor(
-                        None, functools.partial(event, *args, **kwargs)
+                        executor, functools.partial(event, *args, **kwargs)
                     )
     except NotImplementedError as e:
         if event_name == "on_chat_model_start":
             message_strings = [get_buffer_string(m) for m in args[1]]
             await _ahandle_event_for_handler(
+                executor,
                 handler,
                 "on_llm_start",
                 "ignore_llm",
@@ -364,19 +380,25 @@ async def ahandle_event(
         *args: The arguments to pass to the event handler
         **kwargs: The keyword arguments to pass to the event handler
     """
-    for handler in [h for h in handlers if h.run_inline]:
-        await _ahandle_event_for_handler(
-            handler, event_name, ignore_condition_name, *args, **kwargs
-        )
-    await asyncio.gather(
-        *(
-            _ahandle_event_for_handler(
-                handler, event_name, ignore_condition_name, *args, **kwargs
+    with _executor_w_context() as executor:
+        for handler in [h for h in handlers if h.run_inline]:
+            await _ahandle_event_for_handler(
+                executor, handler, event_name, ignore_condition_name, *args, **kwargs
+            )
+        await asyncio.gather(
+            *(
+                _ahandle_event_for_handler(
+                    executor,
+                    handler,
+                    event_name,
+                    ignore_condition_name,
+                    *args,
+                    **kwargs,
+                )
+                for handler in handlers
+                if not handler.run_inline
             )
-            for handler in handlers
-            if not handler.run_inline
         )
-    )
 
 
 BRM = TypeVar("BRM", bound="BaseRunManager")
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index b913782b897..dba21ba71c1 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -260,7 +260,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if type(self)._astream == BaseChatModel._astream:
             # model doesn't implement streaming, so use default implementation
             yield cast(
-                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
+                BaseMessageChunk,
+                await self.ainvoke(input, config=config, stop=stop, **kwargs),
             )
         else:
             config = config or {}
diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index d6cf846413e..b42a17ec4ec 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -472,9 +472,10 @@ class Runnable(Generic[Input, Output], ABC):
 
         Subclasses should override this method if they can run asynchronously.
""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, **kwargs), input, config - ) + with get_executor_for_config(config) as executor: + return await asyncio.get_running_loop().run_in_executor( + executor, partial(self.invoke, **kwargs), input, config + ) def batch( self, @@ -2882,9 +2883,10 @@ class RunnableLambda(Runnable[Input, Output]): @wraps(self.func) async def f(*args, **kwargs): # type: ignore[no-untyped-def] - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.func, **kwargs), *args - ) + with get_executor_for_config(config) as executor: + return await asyncio.get_running_loop().run_in_executor( + executor, partial(self.func, **kwargs), *args + ) afunc = f diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index e68b7080f1d..5672a60fa2d 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -2,6 +2,7 @@ from __future__ import annotations from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import contextmanager +from contextvars import Context, copy_context from typing import ( TYPE_CHECKING, Any, @@ -387,8 +388,15 @@ def get_async_callback_manager_for_config( ) +def _set_context(context: Context) -> None: + for var, value in context.items(): + var.set(value) + + @contextmanager -def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]: +def get_executor_for_config( + config: Optional[RunnableConfig] +) -> Generator[Executor, None, None]: """Get an executor for a config. Args: @@ -397,5 +405,10 @@ def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, Yields: Generator[Executor, None, None]: The executor. """ - with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor: + config = config or {} + with ThreadPoolExecutor( + max_workers=config.get("max_concurrency"), + initializer=_set_context, + initargs=(copy_context(),), + ) as executor: yield executor