Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-09 21:08:59 +00:00
Add concurrency support for run_on_dataset (#8841)
Long-term, it would be better to use the lower-level batch() method(s), but that may take a bit longer to clean up. This unblocks concurrency in the meantime, though it may fail when the evaluated chain raises a `NotImplementedError` for the corresponding async method.
This commit is contained in:
parent fc2f450f2d
commit 91be7eee66
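For orientation before the diff below, here is a minimal sketch of how `run_on_dataset` might be called with the new `concurrency_level` argument. The import paths, evaluator name, dataset name, and `chain_factory` are assumptions for illustration, not part of this commit:

from langsmith import Client
from langchain.smith import RunEvalConfig, run_on_dataset


def chain_factory():
    # Assumed: builds a fresh chain per example so no state carries over.
    ...


client = Client()
eval_config = RunEvalConfig(evaluators=["qa"])  # assumed evaluator name

results = run_on_dataset(
    client,
    "my-eval-dataset",    # assumed dataset name
    chain_factory,
    evaluation=eval_config,
    concurrency_level=5,  # new in this commit: number of concurrent async tasks
)
print(results["project_name"], len(results["results"]))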
@@ -1278,6 +1278,27 @@ async def arun_on_dataset(
     }
 
 
+def _handle_coroutine(coro: Coroutine) -> Any:
+    """
+    Handles a coroutine from a sync context.
+
+    Args:
+        coro (asyncio.coroutine): The coroutine to be handled.
+
+    Returns:
+        any: The result of the executed coroutine.
+    """
+    # Check if there's a running event loop
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:  # No event loop
+        return asyncio.run(coro)
+    if loop.is_running():
+        return loop.create_task(coro)
+    else:
+        return asyncio.run(coro)
+
+
 def run_on_dataset(
     client: Client,
     dataset_name: str,
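Aside: the dispatch pattern `_handle_coroutine` adds above can be illustrated with a self-contained sketch; the toy coroutine and function name here are illustrative only, not from the commit:

import asyncio
from typing import Any, Coroutine


async def _toy_job() -> str:
    # Stand-in for the evaluation coroutine produced by _arun_on_examples.
    await asyncio.sleep(0)
    return "done"


def handle(coro: Coroutine) -> Any:
    # Same logic as _handle_coroutine: use asyncio.run() when no loop is
    # running; otherwise schedule the coroutine on the already-running loop
    # (e.g. inside Jupyter), in which case the caller gets a Task, not a result.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:  # no event loop in this thread
        return asyncio.run(coro)
    if loop.is_running():
        return loop.create_task(coro)
    return asyncio.run(coro)


print(handle(_toy_job()))  # in a plain script this prints "done"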
@@ -1285,6 +1306,7 @@ def run_on_dataset(
     *,
     evaluation: Optional[RunEvalConfig] = None,
     num_repetitions: int = 1,
+    concurrency_level: int = 5,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
@@ -1303,6 +1325,7 @@ def run_on_dataset(
             independent calls on each example without carrying over state.
         evaluation: Configuration for evaluators to run on the
             results of the chain
+        concurrency_level: The number of async tasks to run concurrently.
         num_repetitions: Number of times to run the model on each example.
             This is useful when testing success rates or generating confidence
             intervals.
@@ -1403,18 +1426,35 @@ def run_on_dataset(
     llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
-    results = _run_on_examples(
-        client,
-        examples,
-        llm_or_chain_factory,
-        num_repetitions=num_repetitions,
-        project_name=project_name,
-        verbose=verbose,
-        tags=tags,
-        evaluation=evaluation,
-        input_mapper=input_mapper,
-        data_type=dataset.data_type,
-    )
+    if concurrency_level in (0, 1):
+        results = _run_on_examples(
+            client,
+            examples,
+            llm_or_chain_factory,
+            num_repetitions=num_repetitions,
+            project_name=project_name,
+            verbose=verbose,
+            tags=tags,
+            evaluation=evaluation,
+            input_mapper=input_mapper,
+            data_type=dataset.data_type,
+        )
+    else:
+        # TODO: Use runnables and the batch method
+        coro = _arun_on_examples(
+            client,
+            examples,
+            llm_or_chain_factory,
+            concurrency_level=concurrency_level,
+            num_repetitions=num_repetitions,
+            project_name=project_name,
+            verbose=verbose,
+            tags=tags,
+            evaluation=evaluation,
+            input_mapper=input_mapper,
+            data_type=dataset.data_type,
+        )
+        results = _handle_coroutine(coro)
     return {
         "project_name": project_name,
         "results": results,
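One consequence of the dispatch above: passing a `concurrency_level` of 0 or 1 keeps evaluation on the synchronous `_run_on_examples` path, which is one way to avoid the `NotImplementedError` caveat from the commit message when a chain has no async implementation. Continuing the earlier sketch (names remain illustrative):

results = run_on_dataset(
    client,
    "my-eval-dataset",        # assumed dataset name
    sync_only_chain_factory,  # assumed: builds a chain whose LLM lacks async support
    concurrency_level=1,      # 0 or 1 -> synchronous _run_on_examples path
)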