mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
Evaluate on Version (#18471)
This commit is contained in:
parent
55b69d5ad1
commit
1eec67e8fe
@ -968,9 +968,6 @@ def _run_llm_or_chain(
|
||||
return result
|
||||
|
||||
|
||||
## Public API
|
||||
|
||||
|
||||
def _prepare_eval_run(
|
||||
client: Client,
|
||||
dataset_name: str,
|
||||
@ -978,10 +975,17 @@ def _prepare_eval_run(
|
||||
project_name: str,
|
||||
project_metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
dataset_version: Optional[Union[str, datetime]] = None,
|
||||
) -> Tuple[MCF, TracerSession, Dataset, List[Example]]:
|
||||
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
|
||||
dataset = client.read_dataset(dataset_name=dataset_name)
|
||||
examples = list(client.list_examples(dataset_id=dataset.id))
|
||||
as_of = dataset_version if isinstance(dataset_version, datetime) else None
|
||||
if isinstance(dataset_version, str):
|
||||
raise NotImplementedError(
|
||||
"Selecting dataset_version by tag is not yet supported."
|
||||
" Please use a datetime object."
|
||||
)
|
||||
examples = list(client.list_examples(dataset_id=dataset.id, as_of=as_of))
|
||||
if not examples:
|
||||
raise ValueError(f"Dataset {dataset_name} has no example rows.")
|
||||
modified_at = [ex.modified_at for ex in examples if ex.modified_at]
|
||||
@ -1173,6 +1177,7 @@ class _DatasetRunContainer:
|
||||
concurrency_level: int = 5,
|
||||
project_metadata: Optional[Dict[str, Any]] = None,
|
||||
revision_id: Optional[str] = None,
|
||||
dataset_version: Optional[Union[datetime, str]] = None,
|
||||
) -> _DatasetRunContainer:
|
||||
project_name = project_name or name_generation.random_name()
|
||||
if revision_id:
|
||||
@ -1186,6 +1191,7 @@ class _DatasetRunContainer:
|
||||
project_name,
|
||||
project_metadata=project_metadata,
|
||||
tags=tags,
|
||||
dataset_version=dataset_version,
|
||||
)
|
||||
tags = tags or []
|
||||
for k, v in (project.metadata.get("git") or {}).items():
|
||||
@ -1269,6 +1275,8 @@ _INPUT_MAPPER_DEP_WARNING = (
|
||||
"langchain.schema.runnable.base.RunnableLambda.html)"
|
||||
)
|
||||
|
||||
## Public API
|
||||
|
||||
|
||||
async def arun_on_dataset(
|
||||
client: Optional[Client],
|
||||
@ -1276,11 +1284,11 @@ async def arun_on_dataset(
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
*,
|
||||
evaluation: Optional[smith_eval.RunEvalConfig] = None,
|
||||
dataset_version: Optional[Union[datetime, str]] = None,
|
||||
concurrency_level: int = 5,
|
||||
project_name: Optional[str] = None,
|
||||
project_metadata: Optional[Dict[str, Any]] = None,
|
||||
verbose: bool = False,
|
||||
tags: Optional[List[str]] = None,
|
||||
revision_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
@ -1289,6 +1297,13 @@ async def arun_on_dataset(
|
||||
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
|
||||
if revision_id is None:
|
||||
revision_id = get_langchain_env_var_metadata().get("revision_id")
|
||||
tags = kwargs.pop("tags", None)
|
||||
if tags:
|
||||
warn_deprecated(
|
||||
"0.1.9",
|
||||
message="The tags argument is deprecated and will be"
|
||||
" removed in a future release. Please specify project_metadata instead.",
|
||||
)
|
||||
|
||||
if kwargs:
|
||||
warn_deprecated(
|
||||
@ -1310,6 +1325,7 @@ async def arun_on_dataset(
|
||||
concurrency_level,
|
||||
project_metadata=project_metadata,
|
||||
revision_id=revision_id,
|
||||
dataset_version=dataset_version,
|
||||
)
|
||||
batch_results = await runnable_utils.gather_with_concurrency(
|
||||
container.configs[0].get("max_concurrency"),
|
||||
@ -1332,17 +1348,24 @@ def run_on_dataset(
|
||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||
*,
|
||||
evaluation: Optional[smith_eval.RunEvalConfig] = None,
|
||||
dataset_version: Optional[Union[datetime, str]] = None,
|
||||
concurrency_level: int = 5,
|
||||
project_name: Optional[str] = None,
|
||||
project_metadata: Optional[Dict[str, Any]] = None,
|
||||
verbose: bool = False,
|
||||
tags: Optional[List[str]] = None,
|
||||
revision_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
input_mapper = kwargs.pop("input_mapper", None)
|
||||
if input_mapper:
|
||||
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
|
||||
tags = kwargs.pop("tags", None)
|
||||
if tags:
|
||||
warn_deprecated(
|
||||
"0.1.9",
|
||||
message="The tags argument is deprecated and will be"
|
||||
" removed in a future release. Please specify project_metadata instead.",
|
||||
)
|
||||
if revision_id is None:
|
||||
revision_id = get_langchain_env_var_metadata().get("revision_id")
|
||||
|
||||
@ -1366,6 +1389,7 @@ def run_on_dataset(
|
||||
concurrency_level,
|
||||
project_metadata=project_metadata,
|
||||
revision_id=revision_id,
|
||||
dataset_version=dataset_version,
|
||||
)
|
||||
if concurrency_level == 0:
|
||||
batch_results = [
|
||||
@ -1458,8 +1482,8 @@ Examples
|
||||
client = Client()
|
||||
run_on_dataset(
|
||||
client,
|
||||
"<my_dataset_name>",
|
||||
construct_chain,
|
||||
dataset_name="<my_dataset_name>",
|
||||
llm_or_chain_factory=construct_chain,
|
||||
evaluation=evaluation_config,
|
||||
)
|
||||
|
||||
@ -1496,8 +1520,8 @@ or LangSmith's `RunEvaluator` classes.
|
||||
|
||||
run_on_dataset(
|
||||
client,
|
||||
"<my_dataset_name>",
|
||||
construct_chain,
|
||||
dataset_name="<my_dataset_name>",
|
||||
llm_or_chain_factory=construct_chain,
|
||||
evaluation=evaluation_config,
|
||||
)
|
||||
""" # noqa: E501
|
||||
|
Loading…
Reference in New Issue
Block a user