Wfh/eval max concurrency (#11368)
commit 06f39be1c2 (parent 1165767df2)
@@ -2,13 +2,13 @@
 from __future__ import annotations
 
 import logging
+import threading
 import weakref
-from concurrent.futures import Future, wait
-from typing import Any, Dict, List, Optional, Sequence, Union
+from concurrent.futures import Future, ThreadPoolExecutor, wait
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
 from uuid import UUID
 
 import langsmith
-from langsmith import schemas as langsmith_schemas
 from langsmith.evaluation.evaluator import EvaluationResult
 
 from langchain.callbacks import manager
@@ -73,6 +73,7 @@ class EvaluatorCallbackHandler(BaseTracer):
         example_id: Optional[Union[UUID, str]] = None,
         skip_unfinished: bool = True,
         project_name: Optional[str] = "evaluators",
+        max_concurrency: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -81,12 +82,21 @@ class EvaluatorCallbackHandler(BaseTracer):
         )
         self.client = client or langchain_tracer.get_client()
         self.evaluators = evaluators
-        self.executor = _get_executor()
+        if max_concurrency is None:
+            self.executor: Optional[ThreadPoolExecutor] = _get_executor()
+        elif max_concurrency > 0:
+            self.executor = ThreadPoolExecutor(max_workers=max_concurrency)
+            weakref.finalize(
+                self,
+                lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True),
+            )
+        else:
+            self.executor = None
         self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
         self.skip_unfinished = skip_unfinished
         self.project_name = project_name
-        self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
-        self.logged_eval_results: Dict[str, List[EvaluationResult]] = {}
+        self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {}
+        self.lock = threading.Lock()
         global _TRACERS
         _TRACERS.add(self)
 
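
For readability, a minimal standalone sketch of the three-way branch added above (the helper name and `default_executor` parameter are illustrative, not library API): `None` keeps the shared evaluator executor, a positive value creates a dedicated pool capped at that many workers, and `0` (or a negative value) disables the executor entirely.

from concurrent.futures import ThreadPoolExecutor
from typing import Optional


def pick_executor(
    max_concurrency: Optional[int],
    default_executor: ThreadPoolExecutor,
) -> Optional[ThreadPoolExecutor]:
    """Mirror the executor selection sketched in the hunk above."""
    if max_concurrency is None:
        # No explicit limit: reuse the shared executor (_get_executor() above).
        return default_executor
    if max_concurrency > 0:
        # Explicit limit: dedicated pool with at most max_concurrency workers.
        return ThreadPoolExecutor(max_workers=max_concurrency)
    # 0 or negative: no executor; evaluation runs inline in the caller's thread.
    return None

In the handler itself, the dedicated pool is additionally registered with `weakref.finalize` so it is shut down when the handler is garbage collected.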
@@ -111,12 +121,15 @@ class EvaluatorCallbackHandler(BaseTracer):
         except Exception as e:
             logger.error(
                 f"Error evaluating run {run.id} with "
-                f"{evaluator.__class__.__name__}: {e}",
+                f"{evaluator.__class__.__name__}: {repr(e)}",
                 exc_info=True,
             )
             raise e
         example_id = str(run.reference_example_id)
-        self.logged_eval_results.setdefault(example_id, []).append(eval_result)
+        with self.lock:
+            self.logged_eval_results.setdefault((str(run.id), example_id), []).append(
+                eval_result
+            )
 
     def _persist_run(self, run: Run) -> None:
         """Run the evaluator on the run.
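
A small sketch of the new bookkeeping, with made-up IDs and plain strings standing in for `EvaluationResult`: results are keyed by `(run_id, example_id)` and appended under a lock so evaluator threads do not race on the shared dict.

import threading
from typing import Dict, List, Tuple

logged_eval_results: Dict[Tuple[str, str], List[str]] = {}
lock = threading.Lock()


def record(run_id: str, example_id: str, eval_result: str) -> None:
    # One result list per (run, example) pair, guarded because evaluators
    # may be running concurrently on executor threads.
    with lock:
        logged_eval_results.setdefault((run_id, example_id), []).append(eval_result)


record("run-1", "ex-1", "correctness: 1")  # illustrative values only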
@@ -133,9 +146,12 @@ class EvaluatorCallbackHandler(BaseTracer):
         run_ = run.copy()
         run_.reference_example_id = self.example_id
         for evaluator in self.evaluators:
-            self.futures.add(
-                self.executor.submit(self._evaluate_in_project, run_, evaluator)
-            )
+            if self.executor is None:
+                self._evaluate_in_project(run_, evaluator)
+            else:
+                self.futures.add(
+                    self.executor.submit(self._evaluate_in_project, run_, evaluator)
+                )
 
     def wait_for_futures(self) -> None:
         """Wait for all futures to complete."""
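
The `_persist_run` change above is the usual inline-or-submit dispatch; a rough standalone sketch of the pattern (names here are illustrative, not the handler's):

import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Callable, Optional

futures: "weakref.WeakSet[Future]" = weakref.WeakSet()


def dispatch(task: Callable[[], None], executor: Optional[ThreadPoolExecutor]) -> None:
    if executor is None:
        task()  # no executor configured: run in the caller's thread
    else:
        futures.add(executor.submit(task))  # otherwise run on the pool


def wait_for_all() -> None:
    # Counterpart of wait_for_futures(): block until submitted work finishes.
    wait(list(futures))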
@@ -948,7 +948,10 @@ def _collect_test_results(
     for c in configs:
         for callback in cast(list, c["callbacks"]):
             if isinstance(callback, EvaluatorCallbackHandler):
-                all_eval_results.update(callback.logged_eval_results)
+                eval_results = callback.logged_eval_results
+                all_eval_results.update(
+                    {example_id: v for (_, example_id), v in eval_results.items()}
+                )
     results = {}
     for example, output in zip(examples, batch_results):
         feedback = all_eval_results.get(str(example.id), [])
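
A worked example of the key remapping in `_collect_test_results`, with made-up IDs: the run id is dropped from each `(run_id, example_id)` key so downstream lookups by example id keep working.

eval_results = {
    ("run-1", "ex-1"): ["score=1"],
    ("run-2", "ex-2"): ["score=0"],
}
all_eval_results: dict = {}
all_eval_results.update(
    {example_id: v for (_, example_id), v in eval_results.items()}
)
assert all_eval_results == {"ex-1": ["score=1"], "ex-2": ["score=0"]}

As with the real update call, a later run sharing the same example id would overwrite an earlier one.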