Compare commits

1 commit

William Fu-Hinthorn · 981c5e9f75 · draft · 2023-10-03 21:08:09 -07:00

@@ -16,6 +16,7 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
+    TypedDict,
     Union,
     cast,
 )
@@ -68,6 +69,11 @@ class InputFormatError(Exception):
 ## Shared Utilities
 
 
+class ProjectConfig(TypedDict, total=False):
+    project_name: Optional[str]
+    project_metadata: Optional[Dict[str, Any]]
+
+
 class TestResult(dict):
     """A dictionary of the results of a single test run."""
@@ -859,13 +865,21 @@ def _run_llm_or_chain(
 def _prepare_eval_run(
     client: Client,
-    dataset_name: str,
+    dataset_name: Optional[str],
+    dataset_share_token: Optional[str],
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     project_name: str,
     project_metadata: Optional[Dict[str, Any]] = None,
 ) -> Tuple[MCF, str, Dataset, List[Example]]:
-    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
-    dataset = client.read_dataset(dataset_name=dataset_name)
+    wrapped_model = _wrap_in_chain_factory(
+        llm_or_chain_factory, dataset_name or dataset_share_token
+    )
+    if dataset_name is not None:
+        dataset = client.read_dataset(dataset_name=dataset_name)
+    elif dataset_share_token is not None:
+        dataset = client.read_shared_dataset(share_token=dataset_share_token)
+    else:
+        raise ValueError("Must specify either dataset_name or dataset_share_token")
     try:
         project = client.create_project(
             project_name,
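
The rewritten body resolves the dataset from one of two sources, with dataset_name taking precedence and a fast ValueError when neither is given. Both reads below are the langsmith Client methods used in the hunk above; the dataset name and share token are placeholders:

    from langsmith import Client

    client = Client()  # assumes LangSmith credentials in the environment
    # Dataset owned by the caller, resolved by name.
    owned = client.read_dataset(dataset_name="my-dataset")  # placeholder name
    # Publicly shared dataset, resolved by its share token.
    shared = client.read_shared_dataset(share_token="abc123")  # placeholder token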
@@ -890,7 +904,7 @@ def _prepare_eval_run(
 def _prepare_run_on_dataset(
     client: Client,
-    dataset_name: str,
+    dataset_name: Optional[str],
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     project_name: Optional[str],
     evaluation: Optional[smith_eval.RunEvalConfig] = None,
@@ -898,11 +912,13 @@ def _prepare_run_on_dataset(
     input_mapper: Optional[Callable[[Dict], Any]] = None,
     concurrency_level: int = 5,
     project_metadata: Optional[Dict[str, Any]] = None,
+    dataset_share_token: Optional[str] = None,
 ) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
     project_name = project_name or name_generation.random_name()
     wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client,
         dataset_name,
+        dataset_share_token,
         llm_or_chain_factory,
         project_name,
         project_metadata=project_metadata,
@@ -1046,6 +1062,84 @@ async def arun_on_dataset(
     return results
 
 
+def _run_on_examples(
+    examples: List[Example],
+    configs: List[RunnableConfig],
+    wrapped_model: MODEL_OR_CHAIN_FACTORY,
+    project_name: str,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
+    concurrency_level: int = 5,
+    verbose: bool = False,
+) -> TestResult:
+    if concurrency_level == 0:
+        batch_results = [
+            _run_llm_or_chain(
+                example,
+                config,
+                llm_or_chain_factory=wrapped_model,
+                input_mapper=input_mapper,
+            )
+            for example, config in zip(examples, configs)
+        ]
+    else:
+        with runnable_config.get_executor_for_config(configs[0]) as executor:
+            batch_results = list(
+                executor.map(
+                    functools.partial(
+                        _run_llm_or_chain,
+                        llm_or_chain_factory=wrapped_model,
+                        input_mapper=input_mapper,
+                    ),
+                    examples,
+                    configs,
+                )
+            )
+    results = _collect_test_results(examples, batch_results, configs, project_name)
+    if verbose:
+        try:
+            agg_feedback = results.get_aggregate_feedback()
+            print("\n Eval quantiles:")
+            print(agg_feedback)
+        except Exception as e:
+            logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
+    return results
+
+
+def run_on_public_dataset(
+    component: MODEL_OR_CHAIN_FACTORY,
+    share_token: str,
+    *,
+    client: Optional[Client] = None,
+    evaluation: Optional[smith_eval.RunEvalConfig] = None,
+    project_config: Optional[ProjectConfig] = None,
+    config: Optional[RunnableConfig] = None,
+    verbose: bool = False,
+) -> Dict[str, Any]:
+    project_config = project_config or ProjectConfig()
+    config = config or RunnableConfig()
+    client = client or Client()
+    wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
+        client,
+        None,
+        component,
+        project_config.get("project_name"),
+        evaluation,
+        tags=config.get("tags"),
+        concurrency_level=config.get("max_concurrency"),
+        project_metadata=project_config.get("project_metadata"),
+        dataset_share_token=share_token,
+    )
+    return _run_on_examples(
+        examples,
+        configs,
+        wrapped_model,
+        project_name=project_name,
+        concurrency_level=config.get("max_concurrency"),
+        verbose=verbose,
+    )
+
+
 def run_on_dataset(
     client: Optional[Client],
     dataset_name: str,
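
Together, the extracted _run_on_examples helper and the new run_on_public_dataset entry point let a caller evaluate a component against a dataset shared via token. A minimal usage sketch, assuming a hypothetical chain factory my_chain and a placeholder share token; RunEvalConfig is the existing langchain.smith evaluation config:

    from langchain.smith import RunEvalConfig

    results = run_on_public_dataset(
        lambda: my_chain,                 # hypothetical chain factory
        share_token="d0f0e0c0-0000-...",  # placeholder share token
        evaluation=RunEvalConfig(evaluators=["qa"]),
        project_config={"project_name": "public-eval-draft"},  # hypothetical
        verbose=True,
    )

One thing worth double-checking in this draft: config.get("max_concurrency") is None when the caller does not set it, so _run_on_examples receives concurrency_level=None rather than its declared default of 5; since None != 0 this still takes the executor branch, but it silently bypasses the default.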
@@ -1084,39 +1178,15 @@ def run_on_dataset(
         concurrency_level,
         project_metadata=project_metadata,
     )
-    if concurrency_level == 0:
-        batch_results = [
-            _run_llm_or_chain(
-                example,
-                config,
-                llm_or_chain_factory=wrapped_model,
-                input_mapper=input_mapper,
-            )
-            for example, config in zip(examples, configs)
-        ]
-    else:
-        with runnable_config.get_executor_for_config(configs[0]) as executor:
-            batch_results = list(
-                executor.map(
-                    functools.partial(
-                        _run_llm_or_chain,
-                        llm_or_chain_factory=wrapped_model,
-                        input_mapper=input_mapper,
-                    ),
-                    examples,
-                    configs,
-                )
-            )
-    results = _collect_test_results(examples, batch_results, configs, project_name)
-    if verbose:
-        try:
-            agg_feedback = results.get_aggregate_feedback()
-            print("\n Eval quantiles:")
-            print(agg_feedback)
-        except Exception as e:
-            logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
-    return results
+    return _run_on_examples(
+        examples,
+        configs,
+        wrapped_model,
+        project_name=project_name,
+        input_mapper=input_mapper,
+        concurrency_level=concurrency_level,
+        verbose=verbose,
+    )
 
 
 _RUN_ON_DATASET_DOCSTRING = """
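
The deletion above is the other half of the _run_on_examples extraction: run_on_dataset should behave exactly as before, with the serial/parallel dispatch now living in the shared helper. A smoke-test sketch under that assumption (dataset name and factory are hypothetical, and the asserted keys follow the TestResult shape built by _collect_test_results):

    results = run_on_dataset(
        None,                   # client=None falls back to a default Client()
        "my-dataset",           # hypothetical dataset name
        lambda: my_chain,       # hypothetical chain factory
        concurrency_level=0,    # force the serial path through the new helper
    )
    assert "project_name" in results and "results" in results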