mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 02:33:19 +00:00
Wait for all futures (#6554)
- Expose method to wait for all futures - Wait for submissions in the run_on_dataset functions to ensure runs are fully submitted before cleaning up
This commit is contained in:
parent
e0605b464b
commit
5322bac5fc
@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchainplus_sdk import LangChainPlusClient
|
from langchainplus_sdk import LangChainPlusClient
|
||||||
@ -21,6 +21,7 @@ from langchain.schema import BaseMessage, messages_to_dict
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
_LOGGED = set()
|
_LOGGED = set()
|
||||||
|
_TRACERS: List[LangChainTracer] = []
|
||||||
|
|
||||||
|
|
||||||
def log_error_once(method: str, exception: Exception) -> None:
|
def log_error_once(method: str, exception: Exception) -> None:
|
||||||
@ -32,6 +33,12 @@ def log_error_once(method: str, exception: Exception) -> None:
|
|||||||
logger.error(exception)
|
logger.error(exception)
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_all_tracers() -> None:
|
||||||
|
global _TRACERS
|
||||||
|
for tracer in _TRACERS:
|
||||||
|
tracer.wait_for_futures()
|
||||||
|
|
||||||
|
|
||||||
class LangChainTracer(BaseTracer):
|
class LangChainTracer(BaseTracer):
|
||||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||||
|
|
||||||
@ -52,6 +59,9 @@ class LangChainTracer(BaseTracer):
|
|||||||
# set max_workers to 1 to process tasks in order
|
# set max_workers to 1 to process tasks in order
|
||||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
self.client = client or LangChainPlusClient()
|
self.client = client or LangChainPlusClient()
|
||||||
|
self._futures: Set[Future] = set()
|
||||||
|
global _TRACERS
|
||||||
|
_TRACERS.append(self)
|
||||||
|
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
@ -93,7 +103,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
extra["runtime"] = get_runtime_environment()
|
extra["runtime"] = get_runtime_environment()
|
||||||
run_dict["extra"] = extra
|
run_dict["extra"] = extra
|
||||||
try:
|
try:
|
||||||
run = self.client.create_run(**run_dict, session_name=self.session_name)
|
self.client.create_run(**run_dict, session_name=self.session_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Errors are swallowed by the thread executor so we need to log them here
|
# Errors are swallowed by the thread executor so we need to log them here
|
||||||
log_error_once("post", e)
|
log_error_once("post", e)
|
||||||
@ -110,40 +120,67 @@ class LangChainTracer(BaseTracer):
|
|||||||
|
|
||||||
def _on_llm_start(self, run: Run) -> None:
|
def _on_llm_start(self, run: Run) -> None:
|
||||||
"""Persist an LLM run."""
|
"""Persist an LLM run."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
def _on_chat_model_start(self, run: Run) -> None:
|
def _on_chat_model_start(self, run: Run) -> None:
|
||||||
"""Persist an LLM run."""
|
"""Persist an LLM run."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
def _on_llm_end(self, run: Run) -> None:
|
def _on_llm_end(self, run: Run) -> None:
|
||||||
"""Process the LLM Run."""
|
"""Process the LLM Run."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
def _on_llm_error(self, run: Run) -> None:
|
def _on_llm_error(self, run: Run) -> None:
|
||||||
"""Process the LLM Run upon error."""
|
"""Process the LLM Run upon error."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
def _on_chain_start(self, run: Run) -> None:
|
def _on_chain_start(self, run: Run) -> None:
|
||||||
"""Process the Chain Run upon start."""
|
"""Process the Chain Run upon start."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
def _on_chain_end(self, run: Run) -> None:
|
def _on_chain_end(self, run: Run) -> None:
|
||||||
"""Process the Chain Run."""
|
"""Process the Chain Run."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
def _on_chain_error(self, run: Run) -> None:
|
def _on_chain_error(self, run: Run) -> None:
|
||||||
"""Process the Chain Run upon error."""
|
"""Process the Chain Run upon error."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
def _on_tool_start(self, run: Run) -> None:
|
def _on_tool_start(self, run: Run) -> None:
|
||||||
"""Process the Tool Run upon start."""
|
"""Process the Tool Run upon start."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
def _on_tool_end(self, run: Run) -> None:
|
def _on_tool_end(self, run: Run) -> None:
|
||||||
"""Process the Tool Run."""
|
"""Process the Tool Run."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
def _on_tool_error(self, run: Run) -> None:
|
def _on_tool_error(self, run: Run) -> None:
|
||||||
"""Process the Tool Run upon error."""
|
"""Process the Tool Run upon error."""
|
||||||
|
self._futures.add(
|
||||||
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
)
|
||||||
|
|
||||||
|
def wait_for_futures(self) -> None:
|
||||||
|
"""Wait for the given futures to complete."""
|
||||||
|
futures = list(self._futures)
|
||||||
|
wait(futures)
|
||||||
|
for future in futures:
|
||||||
|
self._futures.remove(future)
|
||||||
|
@ -224,9 +224,17 @@ async def _gather_with_concurrency(
|
|||||||
tracer_queue.put_nowait(tracer)
|
tracer_queue.put_nowait(tracer)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*(run_coroutine_with_semaphore(function) for function in async_funcs)
|
*(run_coroutine_with_semaphore(function) for function in async_funcs)
|
||||||
)
|
)
|
||||||
|
while tracer_queue:
|
||||||
|
try:
|
||||||
|
tracer = tracer_queue.get_nowait()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
if tracer:
|
||||||
|
tracer.wait_for_futures()
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]:
|
async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]:
|
||||||
@ -412,6 +420,8 @@ def run_on_examples(
|
|||||||
if verbose:
|
if verbose:
|
||||||
print(f"{i+1} processed", flush=True, end="\r")
|
print(f"{i+1} processed", flush=True, end="\r")
|
||||||
results[str(example.id)] = result
|
results[str(example.id)] = result
|
||||||
|
if tracer:
|
||||||
|
tracer.wait_for_futures()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user