Evaluate on Version (#18471)

William FH 2024-03-03 17:47:35 -08:00 committed by GitHub
parent 55b69d5ad1
commit 1eec67e8fe


@@ -968,9 +968,6 @@ def _run_llm_or_chain(
     return result
-## Public API
 def _prepare_eval_run(
     client: Client,
     dataset_name: str,
@@ -978,10 +975,17 @@ def _prepare_eval_run(
     project_name: str,
     project_metadata: Optional[Dict[str, Any]] = None,
     tags: Optional[List[str]] = None,
+    dataset_version: Optional[Union[str, datetime]] = None,
 ) -> Tuple[MCF, TracerSession, Dataset, List[Example]]:
     wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
     dataset = client.read_dataset(dataset_name=dataset_name)
-    examples = list(client.list_examples(dataset_id=dataset.id))
+    as_of = dataset_version if isinstance(dataset_version, datetime) else None
+    if isinstance(dataset_version, str):
+        raise NotImplementedError(
+            "Selecting dataset_version by tag is not yet supported."
+            " Please use a datetime object."
+        )
+    examples = list(client.list_examples(dataset_id=dataset.id, as_of=as_of))
     if not examples:
         raise ValueError(f"Dataset {dataset_name} has no example rows.")
     modified_at = [ex.modified_at for ex in examples if ex.modified_at]
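The hunk above is the core behavior change: _prepare_eval_run now accepts a dataset_version, resolves it to an as_of timestamp, and passes it to client.list_examples so the run reads the dataset as it existed at that point in time. A minimal sketch of that resolution logic, pulled out as a standalone helper purely for illustration (_resolve_as_of is a hypothetical name, not part of the diff):

    from datetime import datetime
    from typing import Optional, Union

    def _resolve_as_of(dataset_version: Optional[Union[str, datetime]]) -> Optional[datetime]:
        # Only datetime versions are supported in this commit; string tags are rejected.
        if isinstance(dataset_version, str):
            raise NotImplementedError(
                "Selecting dataset_version by tag is not yet supported."
                " Please use a datetime object."
            )
        return dataset_version  # a datetime pins the dataset snapshot; None means latest

The resolved value is then forwarded as client.list_examples(dataset_id=dataset.id, as_of=as_of).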
@@ -1173,6 +1177,7 @@ class _DatasetRunContainer:
         concurrency_level: int = 5,
         project_metadata: Optional[Dict[str, Any]] = None,
         revision_id: Optional[str] = None,
+        dataset_version: Optional[Union[datetime, str]] = None,
     ) -> _DatasetRunContainer:
         project_name = project_name or name_generation.random_name()
         if revision_id:
@@ -1186,6 +1191,7 @@ class _DatasetRunContainer:
             project_name,
             project_metadata=project_metadata,
             tags=tags,
+            dataset_version=dataset_version,
         )
         tags = tags or []
         for k, v in (project.metadata.get("git") or {}).items():
@@ -1269,6 +1275,8 @@ _INPUT_MAPPER_DEP_WARNING = (
     "langchain.schema.runnable.base.RunnableLambda.html)"
 )
+## Public API
 async def arun_on_dataset(
     client: Optional[Client],
@@ -1276,11 +1284,11 @@ async def arun_on_dataset(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
     evaluation: Optional[smith_eval.RunEvalConfig] = None,
+    dataset_version: Optional[Union[datetime, str]] = None,
     concurrency_level: int = 5,
     project_name: Optional[str] = None,
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
-    tags: Optional[List[str]] = None,
     revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
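The async entry point gains the same dataset_version keyword while tags drops out of the explicit signature (it is handled via **kwargs in the next hunk). A hedged call sketch, assuming client, construct_chain, and evaluation_config are defined as in the docstring examples near the end of this file, with an illustrative timestamp:

    import asyncio
    from datetime import datetime, timezone

    results = asyncio.run(
        arun_on_dataset(
            client,
            dataset_name="<my_dataset_name>",
            llm_or_chain_factory=construct_chain,
            evaluation=evaluation_config,
            # New in this commit: evaluate against the dataset as of a point in time.
            dataset_version=datetime(2024, 3, 1, tzinfo=timezone.utc),
        )
    )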
@@ -1289,6 +1297,13 @@ async def arun_on_dataset(
         warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
     if revision_id is None:
         revision_id = get_langchain_env_var_metadata().get("revision_id")
+    tags = kwargs.pop("tags", None)
+    if tags:
+        warn_deprecated(
+            "0.1.9",
+            message="The tags argument is deprecated and will be"
+            " removed in a future release. Please specify project_metadata instead.",
+        )
     if kwargs:
         warn_deprecated(
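Here tags moves from a named parameter to a soft-deprecated **kwargs entry. The same pattern as a self-contained sketch, using the standard library warnings module in place of LangChain's warn_deprecated helper (_pop_deprecated_tags is a hypothetical name, not in the diff):

    import warnings
    from typing import Any, Dict, List, Optional

    def _pop_deprecated_tags(kwargs: Dict[str, Any]) -> Optional[List[str]]:
        # tags is no longer an explicit argument; if a caller still passes it,
        # pop it from **kwargs and emit a deprecation warning.
        tags = kwargs.pop("tags", None)
        if tags:
            warnings.warn(
                "The tags argument is deprecated and will be removed in a future"
                " release. Please specify project_metadata instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        return tags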
@@ -1310,6 +1325,7 @@ async def arun_on_dataset(
         concurrency_level,
         project_metadata=project_metadata,
         revision_id=revision_id,
+        dataset_version=dataset_version,
     )
     batch_results = await runnable_utils.gather_with_concurrency(
         container.configs[0].get("max_concurrency"),
@@ -1332,17 +1348,24 @@ def run_on_dataset(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
     evaluation: Optional[smith_eval.RunEvalConfig] = None,
+    dataset_version: Optional[Union[datetime, str]] = None,
     concurrency_level: int = 5,
     project_name: Optional[str] = None,
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
-    tags: Optional[List[str]] = None,
     revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     input_mapper = kwargs.pop("input_mapper", None)
     if input_mapper:
         warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
+    tags = kwargs.pop("tags", None)
+    if tags:
+        warn_deprecated(
+            "0.1.9",
+            message="The tags argument is deprecated and will be"
+            " removed in a future release. Please specify project_metadata instead.",
+        )
     if revision_id is None:
         revision_id = get_langchain_env_var_metadata().get("revision_id")
@@ -1366,6 +1389,7 @@ def run_on_dataset(
         concurrency_level,
         project_metadata=project_metadata,
         revision_id=revision_id,
+        dataset_version=dataset_version,
     )
     if concurrency_level == 0:
         batch_results = [
@@ -1458,8 +1482,8 @@ Examples
     client = Client()
     run_on_dataset(
         client,
-        "<my_dataset_name>",
-        construct_chain,
+        dataset_name="<my_dataset_name>",
+        llm_or_chain_factory=construct_chain,
         evaluation=evaluation_config,
     )
@@ -1496,8 +1520,8 @@ or LangSmith's `RunEvaluator` classes.
     run_on_dataset(
         client,
-        "<my_dataset_name>",
-        construct_chain,
+        dataset_name="<my_dataset_name>",
+        llm_or_chain_factory=construct_chain,
         evaluation=evaluation_config,
     )
 """  # noqa: E501