Wfh/rm num repetitions (#9425)

The `num_repetitions` option makes it hard to build test-run comparison views, and for now we'd probably want to just launch multiple separate runs instead; a sketch of that replacement pattern follows.
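
For anyone relying on the removed option, here is a minimal sketch of the replacement pattern. It assumes a reachable LangSmith instance and that `run_on_dataset` is importable from `langchain.smith` in this release; the dataset name, chain factory, and project names are hypothetical:

    from langchain.smith import run_on_dataset  # import path assumed for this release
    from langsmith import Client

    client = Client()

    def make_chain():
        """Hypothetical factory; building a fresh chain per call keeps runs independent."""
        ...

    # Instead of num_repetitions=3, launch three separate test runs. Giving each
    # repetition its own project keeps the comparison views usable.
    for i in range(3):
        run_on_dataset(
            client=client,
            dataset_name="my-dataset",
            llm_or_chain_factory=make_chain,
            project_name=f"my-test-run-{i}",
        )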
William FH 2023-08-18 10:08:39 -07:00 committed by GitHub
parent eee0d1d0dd
commit c29fbede59
2 changed files with 66 additions and 89 deletions

View File

@@ -8,6 +8,7 @@ import inspect
 import itertools
 import logging
 import uuid
+import warnings
 from enum import Enum
 from typing import (
     Any,
@@ -662,7 +663,6 @@ async def _arun_chain(
 async def _arun_llm_or_chain(
     example: Example,
     llm_or_chain_factory: MCF,
-    n_repetitions: int,
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
@@ -673,7 +673,6 @@ async def _arun_llm_or_chain(
     Args:
         example: The example to run.
         llm_or_chain_factory: The Chain or language model constructor to run.
-        n_repetitions: The number of times to run the model on each example.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
         input_mapper: Optional function to map the input to the expected format.
@@ -694,31 +693,28 @@ async def _arun_llm_or_chain(
     chain_or_llm = (
         "LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
     )
-    for _ in range(n_repetitions):
-        try:
-            if isinstance(llm_or_chain_factory, BaseLanguageModel):
-                output: Any = await _arun_llm(
-                    llm_or_chain_factory,
-                    example.inputs,
-                    tags=tags,
-                    callbacks=callbacks,
-                    input_mapper=input_mapper,
-                )
-            else:
-                chain = llm_or_chain_factory()
-                output = await _arun_chain(
-                    chain,
-                    example.inputs,
-                    tags=tags,
-                    callbacks=callbacks,
-                    input_mapper=input_mapper,
-                )
-            outputs.append(output)
-        except Exception as e:
-            logger.warning(
-                f"{chain_or_llm} failed for example {example.id}. Error: {e}"
-            )
-            outputs.append({"Error": str(e)})
+    try:
+        if isinstance(llm_or_chain_factory, BaseLanguageModel):
+            output: Any = await _arun_llm(
+                llm_or_chain_factory,
+                example.inputs,
+                tags=tags,
+                callbacks=callbacks,
+                input_mapper=input_mapper,
+            )
+        else:
+            chain = llm_or_chain_factory()
+            output = await _arun_chain(
+                chain,
+                example.inputs,
+                tags=tags,
+                callbacks=callbacks,
+                input_mapper=input_mapper,
+            )
+        outputs.append(output)
+    except Exception as e:
+        logger.warning(f"{chain_or_llm} failed for example {example.id}. Error: {e}")
+        outputs.append({"Error": str(e)})
     if callbacks and previous_example_ids:
         for example_id, tracer in zip(previous_example_ids, callbacks):
             if hasattr(tracer, "example_id"):
@@ -822,7 +818,6 @@ async def _arun_on_examples(
     *,
     evaluation: Optional[RunEvalConfig] = None,
     concurrency_level: int = 5,
-    num_repetitions: int = 1,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
@@ -841,9 +836,6 @@ async def _arun_on_examples(
             independent calls on each example without carrying over state.
         evaluation: Optional evaluation configuration to use when evaluating
         concurrency_level: The number of async tasks to run concurrently.
-        num_repetitions: Number of times to run the model on each example.
-            This is useful when testing success rates or generating confidence
-            intervals.
         project_name: Project name to use when tracing runs.
             Defaults to {dataset_name}-{chain class name}-{datetime}.
         verbose: Whether to print progress.
@@ -873,7 +865,6 @@ async def _arun_on_examples(
         result = await _arun_llm_or_chain(
             example,
             wrapped_model,
-            num_repetitions,
             tags=tags,
             callbacks=callbacks,
             input_mapper=input_mapper,
@@ -983,7 +974,6 @@ def _run_chain(
 def _run_llm_or_chain(
     example: Example,
     llm_or_chain_factory: MCF,
-    n_repetitions: int,
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
@@ -995,7 +985,6 @@ def _run_llm_or_chain(
     Args:
         example: The example to run.
         llm_or_chain_factory: The Chain or language model constructor to run.
-        n_repetitions: The number of times to run the model on each example.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
         input_mapper: Optional function to map the input to the expected format.
@@ -1016,32 +1005,31 @@ def _run_llm_or_chain(
     chain_or_llm = (
         "LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
     )
-    for _ in range(n_repetitions):
-        try:
-            if isinstance(llm_or_chain_factory, BaseLanguageModel):
-                output: Any = _run_llm(
-                    llm_or_chain_factory,
-                    example.inputs,
-                    callbacks,
-                    tags=tags,
-                    input_mapper=input_mapper,
-                )
-            else:
-                chain = llm_or_chain_factory()
-                output = _run_chain(
-                    chain,
-                    example.inputs,
-                    callbacks,
-                    tags=tags,
-                    input_mapper=input_mapper,
-                )
-            outputs.append(output)
-        except Exception as e:
-            logger.warning(
-                f"{chain_or_llm} failed for example {example.id} with inputs:"
-                f" {example.inputs}.\nError: {e}",
-            )
-            outputs.append({"Error": str(e)})
+    try:
+        if isinstance(llm_or_chain_factory, BaseLanguageModel):
+            output: Any = _run_llm(
+                llm_or_chain_factory,
+                example.inputs,
+                callbacks,
+                tags=tags,
+                input_mapper=input_mapper,
+            )
+        else:
+            chain = llm_or_chain_factory()
+            output = _run_chain(
+                chain,
+                example.inputs,
+                callbacks,
+                tags=tags,
+                input_mapper=input_mapper,
+            )
+        outputs.append(output)
+    except Exception as e:
+        logger.warning(
+            f"{chain_or_llm} failed for example {example.id} with inputs:"
+            f" {example.inputs}.\nError: {e}",
+        )
+        outputs.append({"Error": str(e)})
     if callbacks and previous_example_ids:
         for example_id, tracer in zip(previous_example_ids, callbacks):
             if hasattr(tracer, "example_id"):
@@ -1055,7 +1043,6 @@ def _run_on_examples(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
     evaluation: Optional[RunEvalConfig] = None,
-    num_repetitions: int = 1,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
@@ -1073,9 +1060,6 @@ def _run_on_examples(
             over the dataset. The Chain constructor is used to permit
             independent calls on each example without carrying over state.
         evaluation: Optional evaluation configuration to use when evaluating
-        num_repetitions: Number of times to run the model on each example.
-            This is useful when testing success rates or generating confidence
-            intervals.
         project_name: Name of the project to store the traces in.
             Defaults to {dataset_name}-{chain class name}-{datetime}.
         verbose: Whether to print progress.
@@ -1110,7 +1094,6 @@ def _run_on_examples(
         result = _run_llm_or_chain(
             example,
             wrapped_model,
-            num_repetitions,
             tags=tags,
             callbacks=callbacks,
             input_mapper=input_mapper,
@@ -1158,11 +1141,11 @@ async def arun_on_dataset(
     *,
     evaluation: Optional[RunEvalConfig] = None,
     concurrency_level: int = 5,
-    num_repetitions: int = 1,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    **kwargs: Any,
 ) -> Dict[str, Any]:
     """
     Asynchronously run the Chain or language model on a dataset
@@ -1177,9 +1160,6 @@ async def arun_on_dataset(
             independent calls on each example without carrying over state.
         evaluation: Optional evaluation configuration to use when evaluating
         concurrency_level: The number of async tasks to run concurrently.
-        num_repetitions: Number of times to run the model on each example.
-            This is useful when testing success rates or generating confidence
-            intervals.
         project_name: Name of the project to store the traces in.
             Defaults to {dataset_name}-{chain class name}-{datetime}.
         verbose: Whether to print progress.
@@ -1274,6 +1254,13 @@ async def arun_on_dataset(
             evaluation=evaluation_config,
         )
     """  # noqa: E501
+    if kwargs:
+        warnings.warn(
+            "The following arguments are deprecated and will "
+            "be removed in a future release: "
+            f"{kwargs.keys()}.",
+            DeprecationWarning,
+        )
     wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
@@ -1282,7 +1269,6 @@ async def arun_on_dataset(
         examples,
         wrapped_model,
         concurrency_level=concurrency_level,
-        num_repetitions=num_repetitions,
         project_name=project_name,
         verbose=verbose,
         tags=tags,
@@ -1323,12 +1309,12 @@ def run_on_dataset(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
     evaluation: Optional[RunEvalConfig] = None,
-    num_repetitions: int = 1,
     concurrency_level: int = 5,
     project_name: Optional[str] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    **kwargs: Any,
 ) -> Dict[str, Any]:
     """
     Run the Chain or language model on a dataset and store traces
@@ -1344,9 +1330,6 @@ def run_on_dataset(
         evaluation: Configuration for evaluators to run on the
             results of the chain
         concurrency_level: The number of async tasks to run concurrently.
-        num_repetitions: Number of times to run the model on each example.
-            This is useful when testing success rates or generating confidence
-            intervals.
         project_name: Name of the project to store the traces in.
             Defaults to {dataset_name}-{chain class name}-{datetime}.
         verbose: Whether to print progress.
@@ -1441,6 +1424,13 @@ def run_on_dataset(
             evaluation=evaluation_config,
         )
     """  # noqa: E501
+    if kwargs:
+        warnings.warn(
+            "The following arguments are deprecated and "
+            "will be removed in a future release: "
+            f"{kwargs.keys()}.",
+            DeprecationWarning,
+        )
     wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
@@ -1449,7 +1439,6 @@ def run_on_dataset(
         client,
         examples,
         wrapped_model,
-        num_repetitions=num_repetitions,
         project_name=project_name,
         verbose=verbose,
         tags=tags,
@@ -1464,7 +1453,6 @@ def run_on_dataset(
         examples,
         wrapped_model,
         concurrency_level=concurrency_level,
-        num_repetitions=num_repetitions,
         project_name=project_name,
         verbose=verbose,
         tags=tags,
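
As context for the **kwargs catch-all added above, a quick sketch of what a caller still passing the removed argument now gets (reusing the hypothetical `client` and `make_chain` from the earlier sketch):

    import warnings

    # num_repetitions no longer repeats anything: it lands in **kwargs and is
    # reported once as a DeprecationWarning, then ignored.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        run_on_dataset(
            client=client,
            dataset_name="my-dataset",
            llm_or_chain_factory=make_chain,
            num_repetitions=3,
        )
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)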

View File

@@ -181,15 +181,12 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
         assert "the wrong input" in inputs
         return {"the right input": inputs["the wrong input"]}

-    result = _run_llm_or_chain(
-        example, lambda: mock_chain, n_repetitions=1, input_mapper=input_mapper
-    )
+    result = _run_llm_or_chain(example, lambda: mock_chain, input_mapper=input_mapper)
     assert len(result) == 1
     assert result[0] == {"output": "2", "the right input": "1"}
     bad_result = _run_llm_or_chain(
         example,
         lambda: mock_chain,
-        n_repetitions=1,
     )
     assert len(bad_result) == 1
     assert "Error" in bad_result[0]
@@ -200,9 +197,7 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
         return "the right input"

     mock_llm = FakeLLM(queries={"the right input": "somenumber"})
-    result = _run_llm_or_chain(
-        example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
-    )
+    result = _run_llm_or_chain(example, mock_llm, input_mapper=llm_input_mapper)
     assert len(result) == 1
     llm_result = result[0]
     assert isinstance(llm_result, str)
@@ -302,14 +297,11 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
     async def mock_arun_chain(
         example: Example,
         llm_or_chain: Union[BaseLanguageModel, Chain],
-        n_repetitions: int,
         tags: Optional[List[str]] = None,
         callbacks: Optional[Any] = None,
         **kwargs: Any,
     ) -> List[Dict[str, Any]]:
-        return [
-            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
-        ]
+        return [{"result": f"Result for example {example.id}"}]

     def mock_create_project(*args: Any, **kwargs: Any) -> Any:
         proj = mock.MagicMock()
@@ -327,20 +319,17 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
     client = Client(api_url="http://localhost:1984", api_key="123")
     chain = mock.MagicMock()
     chain.input_keys = ["foothing"]
-    num_repetitions = 3
     results = await arun_on_dataset(
         dataset_name="test",
         llm_or_chain_factory=lambda: chain,
         concurrency_level=2,
         project_name="test_project",
-        num_repetitions=num_repetitions,
         client=client,
     )

     expected = {
         uuid_: [
-            {"result": f"Result for example {uuid.UUID(uuid_)}"}
-            for _ in range(num_repetitions)
+            {"result": f"Result for example {uuid.UUID(uuid_)}"} for _ in range(1)
         ]
         for uuid_ in uuids
     }