mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
Wait for all futures (#6554)
- Expose method to wait for all futures - Wait for submissions in the run_on_dataset functions to ensure runs are fully submitted before cleaning up
This commit is contained in:
parent
e0605b464b
commit
5322bac5fc
@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchainplus_sdk import LangChainPlusClient
|
||||
@ -21,6 +21,7 @@ from langchain.schema import BaseMessage, messages_to_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_LOGGED = set()
|
||||
_TRACERS: List[LangChainTracer] = []
|
||||
|
||||
|
||||
def log_error_once(method: str, exception: Exception) -> None:
|
||||
@ -32,6 +33,12 @@ def log_error_once(method: str, exception: Exception) -> None:
|
||||
logger.error(exception)
|
||||
|
||||
|
||||
def wait_for_all_tracers() -> None:
|
||||
global _TRACERS
|
||||
for tracer in _TRACERS:
|
||||
tracer.wait_for_futures()
|
||||
|
||||
|
||||
class LangChainTracer(BaseTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
@ -52,6 +59,9 @@ class LangChainTracer(BaseTracer):
|
||||
# set max_workers to 1 to process tasks in order
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.client = client or LangChainPlusClient()
|
||||
self._futures: Set[Future] = set()
|
||||
global _TRACERS
|
||||
_TRACERS.append(self)
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
@ -93,7 +103,7 @@ class LangChainTracer(BaseTracer):
|
||||
extra["runtime"] = get_runtime_environment()
|
||||
run_dict["extra"] = extra
|
||||
try:
|
||||
run = self.client.create_run(**run_dict, session_name=self.session_name)
|
||||
self.client.create_run(**run_dict, session_name=self.session_name)
|
||||
except Exception as e:
|
||||
# Errors are swallowed by the thread executor so we need to log them here
|
||||
log_error_once("post", e)
|
||||
@ -110,40 +120,67 @@ class LangChainTracer(BaseTracer):
|
||||
|
||||
def _on_llm_start(self, run: Run) -> None:
|
||||
"""Persist an LLM run."""
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_chat_model_start(self, run: Run) -> None:
|
||||
"""Persist an LLM run."""
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_llm_end(self, run: Run) -> None:
|
||||
"""Process the LLM Run."""
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_llm_error(self, run: Run) -> None:
|
||||
"""Process the LLM Run upon error."""
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_chain_start(self, run: Run) -> None:
|
||||
"""Process the Chain Run upon start."""
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_chain_end(self, run: Run) -> None:
|
||||
"""Process the Chain Run."""
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_chain_error(self, run: Run) -> None:
|
||||
"""Process the Chain Run upon error."""
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_tool_start(self, run: Run) -> None:
|
||||
"""Process the Tool Run upon start."""
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_tool_end(self, run: Run) -> None:
|
||||
"""Process the Tool Run."""
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def _on_tool_error(self, run: Run) -> None:
|
||||
"""Process the Tool Run upon error."""
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
self._futures.add(
|
||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||
)
|
||||
|
||||
def wait_for_futures(self) -> None:
|
||||
"""Wait for the given futures to complete."""
|
||||
futures = list(self._futures)
|
||||
wait(futures)
|
||||
for future in futures:
|
||||
self._futures.remove(future)
|
||||
|
@ -224,9 +224,17 @@ async def _gather_with_concurrency(
|
||||
tracer_queue.put_nowait(tracer)
|
||||
return result
|
||||
|
||||
return await asyncio.gather(
|
||||
results = await asyncio.gather(
|
||||
*(run_coroutine_with_semaphore(function) for function in async_funcs)
|
||||
)
|
||||
while tracer_queue:
|
||||
try:
|
||||
tracer = tracer_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
if tracer:
|
||||
tracer.wait_for_futures()
|
||||
return results
|
||||
|
||||
|
||||
async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]:
|
||||
@ -411,7 +419,9 @@ def run_on_examples(
|
||||
)
|
||||
if verbose:
|
||||
print(f"{i+1} processed", flush=True, end="\r")
|
||||
results[str(example.id)] = result
|
||||
results[str(example.id)] = result
|
||||
if tracer:
|
||||
tracer.wait_for_futures()
|
||||
return results
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user