mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 14:35:50 +00:00
Shared Executor (#11028)
This commit is contained in:
parent
926e4b6bad
commit
e9b51513e9
@ -2,20 +2,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Union
|
||||
import weakref
|
||||
from concurrent.futures import Future, wait
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
import langsmith
|
||||
from langsmith import schemas as langsmith_schemas
|
||||
from langsmith.evaluation.evaluator import EvaluationResult
|
||||
|
||||
from langchain.callbacks import manager
|
||||
from langchain.callbacks.tracers import langchain as langchain_tracer
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.langchain import _get_executor
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet()
|
||||
|
||||
|
||||
def wait_for_all_evaluators() -> None:
|
||||
"""Wait for all tracers to finish."""
|
||||
global _TRACERS
|
||||
for tracer in list(_TRACERS):
|
||||
if tracer is not None:
|
||||
tracer.wait_for_futures()
|
||||
|
||||
|
||||
class EvaluatorCallbackHandler(BaseTracer):
|
||||
"""A tracer that runs a run evaluator whenever a run is persisted.
|
||||
@ -24,9 +37,6 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
----------
|
||||
evaluators : Sequence[RunEvaluator]
|
||||
The run evaluators to apply to all top level runs.
|
||||
max_workers : int, optional
|
||||
The maximum number of worker threads to use for running the evaluators.
|
||||
If not specified, it will default to the number of evaluators.
|
||||
client : LangSmith Client, optional
|
||||
The LangSmith client instance to use for evaluating the runs.
|
||||
If not specified, a new instance will be created.
|
||||
@ -59,7 +69,6 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
def __init__(
|
||||
self,
|
||||
evaluators: Sequence[langsmith.RunEvaluator],
|
||||
max_workers: Optional[int] = None,
|
||||
client: Optional[langsmith.Client] = None,
|
||||
example_id: Optional[Union[UUID, str]] = None,
|
||||
skip_unfinished: bool = True,
|
||||
@ -72,11 +81,14 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
)
|
||||
self.client = client or langchain_tracer.get_client()
|
||||
self.evaluators = evaluators
|
||||
self.max_workers = max_workers or len(evaluators)
|
||||
self.futures: Set[Future] = set()
|
||||
self.executor = _get_executor()
|
||||
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
|
||||
self.skip_unfinished = skip_unfinished
|
||||
self.project_name = project_name
|
||||
self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
|
||||
self.logged_eval_results: Dict[str, List[EvaluationResult]] = {}
|
||||
global _TRACERS
|
||||
_TRACERS.add(self)
|
||||
|
||||
def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
|
||||
"""Evaluate the run in the project.
|
||||
@ -120,15 +132,11 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
return
|
||||
run_ = run.copy()
|
||||
run_.reference_example_id = self.example_id
|
||||
if self.max_workers > 0:
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
list(
|
||||
executor.map(
|
||||
self._evaluate_in_project,
|
||||
[run_ for _ in range(len(self.evaluators))],
|
||||
self.evaluators,
|
||||
)
|
||||
)
|
||||
else:
|
||||
for evaluator in self.evaluators:
|
||||
self._evaluate_in_project(run_, evaluator)
|
||||
self.futures.add(
|
||||
self.executor.submit(self._evaluate_in_project, run_, evaluator)
|
||||
)
|
||||
|
||||
def wait_for_futures(self) -> None:
|
||||
"""Wait for all futures to complete."""
|
||||
wait(self.futures)
|
||||
|
@ -6,7 +6,7 @@ import os
|
||||
import weakref
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith import Client
|
||||
@ -21,8 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
_LOGGED = set()
|
||||
_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
|
||||
_CLIENT: Optional[Client] = None
|
||||
_MAX_EXECUTORS = 10 # TODO: Remove once write queue is implemented
|
||||
_EXECUTORS: List[ThreadPoolExecutor] = []
|
||||
_EXECUTOR: Optional[ThreadPoolExecutor] = None
|
||||
|
||||
|
||||
def log_error_once(method: str, exception: Exception) -> None:
|
||||
@ -50,6 +49,14 @@ def get_client() -> Client:
|
||||
return _CLIENT
|
||||
|
||||
|
||||
def _get_executor() -> ThreadPoolExecutor:
|
||||
"""Get the executor."""
|
||||
global _EXECUTOR
|
||||
if _EXECUTOR is None:
|
||||
_EXECUTOR = ThreadPoolExecutor()
|
||||
return _EXECUTOR
|
||||
|
||||
|
||||
class LangChainTracer(BaseTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
@ -71,21 +78,10 @@ class LangChainTracer(BaseTracer):
|
||||
self.project_name = project_name or os.getenv(
|
||||
"LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
|
||||
)
|
||||
if use_threading:
|
||||
global _MAX_EXECUTORS
|
||||
if len(_EXECUTORS) < _MAX_EXECUTORS:
|
||||
self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
|
||||
max_workers=1
|
||||
)
|
||||
_EXECUTORS.append(self.executor)
|
||||
else:
|
||||
self.executor = _EXECUTORS.pop(0)
|
||||
_EXECUTORS.append(self.executor)
|
||||
else:
|
||||
self.executor = None
|
||||
self.client = client or get_client()
|
||||
self._futures: Set[Future] = set()
|
||||
self._futures: weakref.WeakSet[Future] = weakref.WeakSet()
|
||||
self.tags = tags or []
|
||||
self.executor = _get_executor() if use_threading else None
|
||||
global _TRACERS
|
||||
_TRACERS.add(self)
|
||||
|
||||
@ -229,7 +225,4 @@ class LangChainTracer(BaseTracer):
|
||||
|
||||
def wait_for_futures(self) -> None:
|
||||
"""Wait for the given futures to complete."""
|
||||
futures = list(self._futures)
|
||||
wait(futures)
|
||||
for future in futures:
|
||||
self._futures.remove(future)
|
||||
wait(self._futures)
|
||||
|
@ -24,8 +24,11 @@ from langsmith import Client, RunEvaluator
|
||||
from langsmith.schemas import Dataset, DataType, Example
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
|
||||
from langchain.callbacks.tracers.evaluation import (
|
||||
EvaluatorCallbackHandler,
|
||||
wait_for_all_evaluators,
|
||||
)
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.evaluation.loading import load_evaluator
|
||||
from langchain.evaluation.schema import (
|
||||
@ -915,7 +918,6 @@ def _prepare_run_on_dataset(
|
||||
EvaluatorCallbackHandler(
|
||||
evaluators=run_evaluators or [],
|
||||
client=client,
|
||||
max_workers=0,
|
||||
example_id=example.id,
|
||||
),
|
||||
progress_bar,
|
||||
@ -934,7 +936,7 @@ def _collect_test_results(
|
||||
configs: List[RunnableConfig],
|
||||
project_name: str,
|
||||
) -> TestResult:
|
||||
wait_for_all_tracers()
|
||||
wait_for_all_evaluators()
|
||||
all_eval_results = {}
|
||||
for c in configs:
|
||||
for callback in cast(list, c["callbacks"]):
|
||||
|
Loading…
Reference in New Issue
Block a user