Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
d37842232b Add Tqdm to eval wait 2023-11-22 05:46:44 -08:00
2 changed files with 21 additions and 7 deletions

View File

@@ -22,12 +22,12 @@ logger = logging.getLogger(__name__)
_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet()
def wait_for_all_evaluators() -> None:
def wait_for_all_evaluators(verbose: bool = False) -> None:
"""Wait for all tracers to finish."""
global _TRACERS
for tracer in list(_TRACERS):
if tracer is not None:
tracer.wait_for_futures()
tracer.wait_for_futures(verbose=verbose)
class EvaluatorCallbackHandler(BaseTracer):
@@ -218,6 +218,19 @@ class EvaluatorCallbackHandler(BaseTracer):
self.executor.submit(self._evaluate_in_project, run_, evaluator)
)
def wait_for_futures(self) -> None:
def wait_for_futures(self, verbose: bool = False) -> None:
"""Wait for all futures to complete."""
wait(self.futures)
if verbose:
logger.info(
f"Waiting for {len(self.futures)} evaluators to finish."
)
try:
from tqdm.auto import tqdm
wait(tqdm(self.futures))
except ImportError:
wait(self.futures)
else:
wait(self.futures)

View File

@@ -995,8 +995,9 @@ def _collect_test_results(
batch_results: List[Union[dict, str, LLMResult, ChatResult]],
configs: List[RunnableConfig],
project_name: str,
verbose: bool = False,
) -> TestResult:
wait_for_all_evaluators()
wait_for_all_evaluators(verbose=verbose)
all_eval_results = {}
all_execution_time = {}
for c in configs:
@@ -1101,7 +1102,7 @@ async def arun_on_dataset(
configs,
),
)
results = _collect_test_results(examples, batch_results, configs, project_name)
results = _collect_test_results(examples, batch_results, configs, project_name, verbose=verbose)
if verbose:
try:
agg_feedback = results.get_aggregate_feedback()
@@ -1173,7 +1174,7 @@ def run_on_dataset(
)
)
results = _collect_test_results(examples, batch_results, configs, project_name)
results = _collect_test_results(examples, batch_results, configs, project_name, verbose=verbose)
if verbose:
try:
agg_feedback = results.get_aggregate_feedback()