mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 14:35:50 +00:00
Shared Executor (#11028)
This commit is contained in:
parent
926e4b6bad
commit
e9b51513e9
@ -2,20 +2,33 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
import weakref
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Set, Union
|
from concurrent.futures import Future, wait
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import langsmith
|
import langsmith
|
||||||
|
from langsmith import schemas as langsmith_schemas
|
||||||
from langsmith.evaluation.evaluator import EvaluationResult
|
from langsmith.evaluation.evaluator import EvaluationResult
|
||||||
|
|
||||||
from langchain.callbacks import manager
|
from langchain.callbacks import manager
|
||||||
from langchain.callbacks.tracers import langchain as langchain_tracer
|
from langchain.callbacks.tracers import langchain as langchain_tracer
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
from langchain.callbacks.tracers.langchain import _get_executor
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet()
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_all_evaluators() -> None:
|
||||||
|
"""Wait for all tracers to finish."""
|
||||||
|
global _TRACERS
|
||||||
|
for tracer in list(_TRACERS):
|
||||||
|
if tracer is not None:
|
||||||
|
tracer.wait_for_futures()
|
||||||
|
|
||||||
|
|
||||||
class EvaluatorCallbackHandler(BaseTracer):
|
class EvaluatorCallbackHandler(BaseTracer):
|
||||||
"""A tracer that runs a run evaluator whenever a run is persisted.
|
"""A tracer that runs a run evaluator whenever a run is persisted.
|
||||||
@ -24,9 +37,6 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
----------
|
----------
|
||||||
evaluators : Sequence[RunEvaluator]
|
evaluators : Sequence[RunEvaluator]
|
||||||
The run evaluators to apply to all top level runs.
|
The run evaluators to apply to all top level runs.
|
||||||
max_workers : int, optional
|
|
||||||
The maximum number of worker threads to use for running the evaluators.
|
|
||||||
If not specified, it will default to the number of evaluators.
|
|
||||||
client : LangSmith Client, optional
|
client : LangSmith Client, optional
|
||||||
The LangSmith client instance to use for evaluating the runs.
|
The LangSmith client instance to use for evaluating the runs.
|
||||||
If not specified, a new instance will be created.
|
If not specified, a new instance will be created.
|
||||||
@ -59,7 +69,6 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
evaluators: Sequence[langsmith.RunEvaluator],
|
evaluators: Sequence[langsmith.RunEvaluator],
|
||||||
max_workers: Optional[int] = None,
|
|
||||||
client: Optional[langsmith.Client] = None,
|
client: Optional[langsmith.Client] = None,
|
||||||
example_id: Optional[Union[UUID, str]] = None,
|
example_id: Optional[Union[UUID, str]] = None,
|
||||||
skip_unfinished: bool = True,
|
skip_unfinished: bool = True,
|
||||||
@ -72,11 +81,14 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
)
|
)
|
||||||
self.client = client or langchain_tracer.get_client()
|
self.client = client or langchain_tracer.get_client()
|
||||||
self.evaluators = evaluators
|
self.evaluators = evaluators
|
||||||
self.max_workers = max_workers or len(evaluators)
|
self.executor = _get_executor()
|
||||||
self.futures: Set[Future] = set()
|
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
|
||||||
self.skip_unfinished = skip_unfinished
|
self.skip_unfinished = skip_unfinished
|
||||||
self.project_name = project_name
|
self.project_name = project_name
|
||||||
|
self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
|
||||||
self.logged_eval_results: Dict[str, List[EvaluationResult]] = {}
|
self.logged_eval_results: Dict[str, List[EvaluationResult]] = {}
|
||||||
|
global _TRACERS
|
||||||
|
_TRACERS.add(self)
|
||||||
|
|
||||||
def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
|
def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
|
||||||
"""Evaluate the run in the project.
|
"""Evaluate the run in the project.
|
||||||
@ -120,15 +132,11 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
return
|
return
|
||||||
run_ = run.copy()
|
run_ = run.copy()
|
||||||
run_.reference_example_id = self.example_id
|
run_.reference_example_id = self.example_id
|
||||||
if self.max_workers > 0:
|
for evaluator in self.evaluators:
|
||||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
self.futures.add(
|
||||||
list(
|
self.executor.submit(self._evaluate_in_project, run_, evaluator)
|
||||||
executor.map(
|
)
|
||||||
self._evaluate_in_project,
|
|
||||||
[run_ for _ in range(len(self.evaluators))],
|
def wait_for_futures(self) -> None:
|
||||||
self.evaluators,
|
"""Wait for all futures to complete."""
|
||||||
)
|
wait(self.futures)
|
||||||
)
|
|
||||||
else:
|
|
||||||
for evaluator in self.evaluators:
|
|
||||||
self._evaluate_in_project(run_, evaluator)
|
|
||||||
|
@ -6,7 +6,7 @@ import os
|
|||||||
import weakref
|
import weakref
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langsmith import Client
|
from langsmith import Client
|
||||||
@ -21,8 +21,7 @@ logger = logging.getLogger(__name__)
|
|||||||
_LOGGED = set()
|
_LOGGED = set()
|
||||||
_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
|
_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
|
||||||
_CLIENT: Optional[Client] = None
|
_CLIENT: Optional[Client] = None
|
||||||
_MAX_EXECUTORS = 10 # TODO: Remove once write queue is implemented
|
_EXECUTOR: Optional[ThreadPoolExecutor] = None
|
||||||
_EXECUTORS: List[ThreadPoolExecutor] = []
|
|
||||||
|
|
||||||
|
|
||||||
def log_error_once(method: str, exception: Exception) -> None:
|
def log_error_once(method: str, exception: Exception) -> None:
|
||||||
@ -50,6 +49,14 @@ def get_client() -> Client:
|
|||||||
return _CLIENT
|
return _CLIENT
|
||||||
|
|
||||||
|
|
||||||
|
def _get_executor() -> ThreadPoolExecutor:
|
||||||
|
"""Get the executor."""
|
||||||
|
global _EXECUTOR
|
||||||
|
if _EXECUTOR is None:
|
||||||
|
_EXECUTOR = ThreadPoolExecutor()
|
||||||
|
return _EXECUTOR
|
||||||
|
|
||||||
|
|
||||||
class LangChainTracer(BaseTracer):
|
class LangChainTracer(BaseTracer):
|
||||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||||
|
|
||||||
@ -71,21 +78,10 @@ class LangChainTracer(BaseTracer):
|
|||||||
self.project_name = project_name or os.getenv(
|
self.project_name = project_name or os.getenv(
|
||||||
"LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
|
"LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
|
||||||
)
|
)
|
||||||
if use_threading:
|
|
||||||
global _MAX_EXECUTORS
|
|
||||||
if len(_EXECUTORS) < _MAX_EXECUTORS:
|
|
||||||
self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
|
|
||||||
max_workers=1
|
|
||||||
)
|
|
||||||
_EXECUTORS.append(self.executor)
|
|
||||||
else:
|
|
||||||
self.executor = _EXECUTORS.pop(0)
|
|
||||||
_EXECUTORS.append(self.executor)
|
|
||||||
else:
|
|
||||||
self.executor = None
|
|
||||||
self.client = client or get_client()
|
self.client = client or get_client()
|
||||||
self._futures: Set[Future] = set()
|
self._futures: weakref.WeakSet[Future] = weakref.WeakSet()
|
||||||
self.tags = tags or []
|
self.tags = tags or []
|
||||||
|
self.executor = _get_executor() if use_threading else None
|
||||||
global _TRACERS
|
global _TRACERS
|
||||||
_TRACERS.add(self)
|
_TRACERS.add(self)
|
||||||
|
|
||||||
@ -229,7 +225,4 @@ class LangChainTracer(BaseTracer):
|
|||||||
|
|
||||||
def wait_for_futures(self) -> None:
|
def wait_for_futures(self) -> None:
|
||||||
"""Wait for the given futures to complete."""
|
"""Wait for the given futures to complete."""
|
||||||
futures = list(self._futures)
|
wait(self._futures)
|
||||||
wait(futures)
|
|
||||||
for future in futures:
|
|
||||||
self._futures.remove(future)
|
|
||||||
|
@ -24,8 +24,11 @@ from langsmith import Client, RunEvaluator
|
|||||||
from langsmith.schemas import Dataset, DataType, Example
|
from langsmith.schemas import Dataset, DataType, Example
|
||||||
|
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
|
from langchain.callbacks.tracers.evaluation import (
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
|
EvaluatorCallbackHandler,
|
||||||
|
wait_for_all_evaluators,
|
||||||
|
)
|
||||||
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.evaluation.loading import load_evaluator
|
from langchain.evaluation.loading import load_evaluator
|
||||||
from langchain.evaluation.schema import (
|
from langchain.evaluation.schema import (
|
||||||
@ -915,7 +918,6 @@ def _prepare_run_on_dataset(
|
|||||||
EvaluatorCallbackHandler(
|
EvaluatorCallbackHandler(
|
||||||
evaluators=run_evaluators or [],
|
evaluators=run_evaluators or [],
|
||||||
client=client,
|
client=client,
|
||||||
max_workers=0,
|
|
||||||
example_id=example.id,
|
example_id=example.id,
|
||||||
),
|
),
|
||||||
progress_bar,
|
progress_bar,
|
||||||
@ -934,7 +936,7 @@ def _collect_test_results(
|
|||||||
configs: List[RunnableConfig],
|
configs: List[RunnableConfig],
|
||||||
project_name: str,
|
project_name: str,
|
||||||
) -> TestResult:
|
) -> TestResult:
|
||||||
wait_for_all_tracers()
|
wait_for_all_evaluators()
|
||||||
all_eval_results = {}
|
all_eval_results = {}
|
||||||
for c in configs:
|
for c in configs:
|
||||||
for callback in cast(list, c["callbacks"]):
|
for callback in cast(list, c["callbacks"]):
|
||||||
|
Loading…
Reference in New Issue
Block a user