Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-07 22:11:51 +00:00
core, community: propagate context between threads (#15171)
While using `chain.batch`, the default implementation uses a `ThreadPoolExecutor` and runs the chains in separate threads. An issue with this approach is that [the token counting callback](https://python.langchain.com/docs/modules/callbacks/token_counting) fails to work, since the context is not propagated to those threads. This PR adds context propagation to the new threads and adds thread synchronization to the OpenAI callback. With this change, the token counting callback works as intended. The context propagation change is also beneficial for anyone implementing custom callbacks with similar functionality.

---------

Co-authored-by: Nuno Campos <nuno@langchain.dev>
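The mechanism behind the fix is Python's `contextvars`: a newly spawned thread starts with an empty context, so any `ContextVar` set by `get_openai_callback()` in the calling thread is invisible to pool workers. The cure is to capture the caller's context with `contextvars.copy_context()` and run each submitted task inside that copy. A minimal sketch of the idea, not LangChain's actual helper (`submit_with_context` and the `current_handler` variable are illustrative):

```python
import contextvars
from concurrent.futures import ThreadPoolExecutor

# Illustrative stand-in for the ContextVar that get_openai_callback() sets.
current_handler: contextvars.ContextVar = contextvars.ContextVar(
    "current_handler", default=None
)

def submit_with_context(executor: ThreadPoolExecutor, fn, *args):
    # Copy the submitting thread's context and run fn inside it, so
    # ContextVar lookups in the worker see the caller's values.
    ctx = contextvars.copy_context()
    return executor.submit(ctx.run, fn, *args)

def worker():
    return current_handler.get()

current_handler.set("token-counting-handler")
with ThreadPoolExecutor() as pool:
    print(pool.submit(worker).result())                # None: context was lost
    print(submit_with_context(pool, worker).result())  # "token-counting-handler"
```

Copying a context is cheap (contexts are immutable mappings), so doing it once per submitted task is the standard pattern.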
```diff
@@ -1,4 +1,5 @@
 """Callback Handler that prints to std out."""
+import threading
 from typing import Any, Dict, List
 
 from langchain_core.callbacks import BaseCallbackHandler
@@ -154,6 +155,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
     successful_requests: int = 0
     total_cost: float = 0.0
 
+    def __init__(self) -> None:
+        super().__init__()
+        self._lock = threading.Lock()
+
     def __repr__(self) -> str:
         return (
             f"Tokens Used: {self.total_tokens}\n"
@@ -182,9 +187,13 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         """Collect token usage."""
         if response.llm_output is None:
             return None
-        self.successful_requests += 1
+
         if "token_usage" not in response.llm_output:
+            with self._lock:
+                self.successful_requests += 1
             return None
+
+        # compute tokens and cost for this request
         token_usage = response.llm_output["token_usage"]
         completion_tokens = token_usage.get("completion_tokens", 0)
         prompt_tokens = token_usage.get("prompt_tokens", 0)
@@ -194,10 +203,17 @@ class OpenAICallbackHandler(BaseCallbackHandler):
                 model_name, completion_tokens, is_completion=True
             )
             prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
-            self.total_cost += prompt_cost + completion_cost
-        self.total_tokens += token_usage.get("total_tokens", 0)
-        self.prompt_tokens += prompt_tokens
-        self.completion_tokens += completion_tokens
+        else:
+            completion_cost = 0
+            prompt_cost = 0
+
+        # update shared state behind lock
+        with self._lock:
+            self.total_cost += prompt_cost + completion_cost
+            self.total_tokens += token_usage.get("total_tokens", 0)
+            self.prompt_tokens += prompt_tokens
+            self.completion_tokens += completion_tokens
+            self.successful_requests += 1
 
     def __copy__(self) -> "OpenAICallbackHandler":
         """Return a copy of the callback handler."""
```
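With contexts propagated into the worker threads and the counter updates serialized behind the new lock (plain `+=` on shared attributes is not atomic across threads), the documented token-counting pattern also works for batched calls. A hedged usage sketch; the import paths follow the docs of this era and may vary by version, and the model name is a placeholder:

```python
from langchain.callbacks import get_openai_callback
from langchain_community.chat_models import ChatOpenAI

llm = ChatOpenAI(model_name="gpt-3.5-turbo")

with get_openai_callback() as cb:
    # batch() fans these prompts out over a ThreadPoolExecutor; before this
    # change the handler never saw those threads and reported zero tokens.
    llm.batch(["Tell me a joke", "Tell me a fact"])

print(cb.total_tokens, cb.successful_requests, cb.total_cost)
```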