core, community: propagate context between threads (#15171)

The default implementation of `chain.batch` uses a
`ThreadPoolExecutor` and runs the chains in separate threads. An issue
with this approach is that [the token counting
callback](https://python.langchain.com/docs/modules/callbacks/token_counting)
fails to work, because the context is not propagated between threads.
This PR adds context propagation to the new threads and adds thread
synchronization to the OpenAI callback. With this change, the token
counting callback works as intended.
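
Under the hood, the `langchain_core` side of the fix amounts to running each task submitted to the executor inside a copy of the submitting thread's `contextvars.Context`, so the callback manager travels with the work. A minimal sketch of that pattern (the helper name `submit_with_context` is illustrative, not the actual API):

```python
import contextvars
from concurrent.futures import Executor, Future
from typing import Any, Callable


def submit_with_context(
    executor: Executor, fn: Callable[..., Any], *args: Any
) -> "Future[Any]":
    # Snapshot the submitting thread's contextvars, which carry the
    # callback manager for the current run...
    ctx = contextvars.copy_context()
    # ...then execute fn inside that snapshot on the worker thread, so
    # handlers registered in the parent context also fire in the child.
    return executor.submit(ctx.run, fn, *args)
```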

The context propagation change would also be highly beneficial for
anyone implementing custom callbacks with similar functionality.
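
For instance, a hypothetical custom handler that aggregates results across threads needs the same two ingredients the OpenAI callback gains here: the propagated context, so its hooks fire in worker threads at all, and a lock around its shared counters:

```python
import threading
from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class RequestCounterHandler(BaseCallbackHandler):
    """Illustrative handler that counts completed LLM calls across threads."""

    def __init__(self) -> None:
        super().__init__()
        self.requests = 0
        self._lock = threading.Lock()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # batch() may call this from several threads at once, so guard
        # the shared counter with the lock.
        with self._lock:
            self.requests += 1
```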

---------

Co-authored-by: Nuno Campos <nuno@langchain.dev>
Author: joshy-deshaw
Date: 2023-12-29 04:21:22 +05:30
Committed by: GitHub
Parent: f74151b4e4
Commit: bf5385592e
5 changed files with 80 additions and 26 deletions


@@ -1,4 +1,5 @@
 """Callback Handler that prints to std out."""
+import threading
 from typing import Any, Dict, List
 
 from langchain_core.callbacks import BaseCallbackHandler
@@ -154,6 +155,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
     successful_requests: int = 0
     total_cost: float = 0.0
 
+    def __init__(self) -> None:
+        super().__init__()
+        self._lock = threading.Lock()
+
     def __repr__(self) -> str:
         return (
             f"Tokens Used: {self.total_tokens}\n"
@@ -182,9 +187,13 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         """Collect token usage."""
         if response.llm_output is None:
             return None
-        self.successful_requests += 1
+
         if "token_usage" not in response.llm_output:
+            with self._lock:
+                self.successful_requests += 1
             return None
+
+        # compute tokens and cost for this request
         token_usage = response.llm_output["token_usage"]
         completion_tokens = token_usage.get("completion_tokens", 0)
         prompt_tokens = token_usage.get("prompt_tokens", 0)
@@ -194,10 +203,17 @@ class OpenAICallbackHandler(BaseCallbackHandler):
                 model_name, completion_tokens, is_completion=True
             )
             prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
-            self.total_cost += prompt_cost + completion_cost
-        self.total_tokens += token_usage.get("total_tokens", 0)
-        self.prompt_tokens += prompt_tokens
-        self.completion_tokens += completion_tokens
+        else:
+            completion_cost = 0
+            prompt_cost = 0
+
+        # update shared state behind lock
+        with self._lock:
+            self.total_cost += prompt_cost + completion_cost
+            self.total_tokens += token_usage.get("total_tokens", 0)
+            self.prompt_tokens += prompt_tokens
+            self.completion_tokens += completion_tokens
+            self.successful_requests += 1
 
     def __copy__(self) -> "OpenAICallbackHandler":
         """Return a copy of the callback handler."""