mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
Evaluate on Version (#18471)
This commit is contained in:
parent
55b69d5ad1
commit
1eec67e8fe
@ -968,9 +968,6 @@ def _run_llm_or_chain(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
## Public API
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_eval_run(
|
def _prepare_eval_run(
|
||||||
client: Client,
|
client: Client,
|
||||||
dataset_name: str,
|
dataset_name: str,
|
||||||
@ -978,10 +975,17 @@ def _prepare_eval_run(
|
|||||||
project_name: str,
|
project_name: str,
|
||||||
project_metadata: Optional[Dict[str, Any]] = None,
|
project_metadata: Optional[Dict[str, Any]] = None,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
dataset_version: Optional[Union[str, datetime]] = None,
|
||||||
) -> Tuple[MCF, TracerSession, Dataset, List[Example]]:
|
) -> Tuple[MCF, TracerSession, Dataset, List[Example]]:
|
||||||
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
|
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
|
||||||
dataset = client.read_dataset(dataset_name=dataset_name)
|
dataset = client.read_dataset(dataset_name=dataset_name)
|
||||||
examples = list(client.list_examples(dataset_id=dataset.id))
|
as_of = dataset_version if isinstance(dataset_version, datetime) else None
|
||||||
|
if isinstance(dataset_version, str):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Selecting dataset_version by tag is not yet supported."
|
||||||
|
" Please use a datetime object."
|
||||||
|
)
|
||||||
|
examples = list(client.list_examples(dataset_id=dataset.id, as_of=as_of))
|
||||||
if not examples:
|
if not examples:
|
||||||
raise ValueError(f"Dataset {dataset_name} has no example rows.")
|
raise ValueError(f"Dataset {dataset_name} has no example rows.")
|
||||||
modified_at = [ex.modified_at for ex in examples if ex.modified_at]
|
modified_at = [ex.modified_at for ex in examples if ex.modified_at]
|
||||||
@ -1173,6 +1177,7 @@ class _DatasetRunContainer:
|
|||||||
concurrency_level: int = 5,
|
concurrency_level: int = 5,
|
||||||
project_metadata: Optional[Dict[str, Any]] = None,
|
project_metadata: Optional[Dict[str, Any]] = None,
|
||||||
revision_id: Optional[str] = None,
|
revision_id: Optional[str] = None,
|
||||||
|
dataset_version: Optional[Union[datetime, str]] = None,
|
||||||
) -> _DatasetRunContainer:
|
) -> _DatasetRunContainer:
|
||||||
project_name = project_name or name_generation.random_name()
|
project_name = project_name or name_generation.random_name()
|
||||||
if revision_id:
|
if revision_id:
|
||||||
@ -1186,6 +1191,7 @@ class _DatasetRunContainer:
|
|||||||
project_name,
|
project_name,
|
||||||
project_metadata=project_metadata,
|
project_metadata=project_metadata,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
|
dataset_version=dataset_version,
|
||||||
)
|
)
|
||||||
tags = tags or []
|
tags = tags or []
|
||||||
for k, v in (project.metadata.get("git") or {}).items():
|
for k, v in (project.metadata.get("git") or {}).items():
|
||||||
@ -1269,6 +1275,8 @@ _INPUT_MAPPER_DEP_WARNING = (
|
|||||||
"langchain.schema.runnable.base.RunnableLambda.html)"
|
"langchain.schema.runnable.base.RunnableLambda.html)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## Public API
|
||||||
|
|
||||||
|
|
||||||
async def arun_on_dataset(
|
async def arun_on_dataset(
|
||||||
client: Optional[Client],
|
client: Optional[Client],
|
||||||
@ -1276,11 +1284,11 @@ async def arun_on_dataset(
|
|||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
*,
|
*,
|
||||||
evaluation: Optional[smith_eval.RunEvalConfig] = None,
|
evaluation: Optional[smith_eval.RunEvalConfig] = None,
|
||||||
|
dataset_version: Optional[Union[datetime, str]] = None,
|
||||||
concurrency_level: int = 5,
|
concurrency_level: int = 5,
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
project_metadata: Optional[Dict[str, Any]] = None,
|
project_metadata: Optional[Dict[str, Any]] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
tags: Optional[List[str]] = None,
|
|
||||||
revision_id: Optional[str] = None,
|
revision_id: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@ -1289,6 +1297,13 @@ async def arun_on_dataset(
|
|||||||
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
|
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
|
||||||
if revision_id is None:
|
if revision_id is None:
|
||||||
revision_id = get_langchain_env_var_metadata().get("revision_id")
|
revision_id = get_langchain_env_var_metadata().get("revision_id")
|
||||||
|
tags = kwargs.pop("tags", None)
|
||||||
|
if tags:
|
||||||
|
warn_deprecated(
|
||||||
|
"0.1.9",
|
||||||
|
message="The tags argument is deprecated and will be"
|
||||||
|
" removed in a future release. Please specify project_metadata instead.",
|
||||||
|
)
|
||||||
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
warn_deprecated(
|
warn_deprecated(
|
||||||
@ -1310,6 +1325,7 @@ async def arun_on_dataset(
|
|||||||
concurrency_level,
|
concurrency_level,
|
||||||
project_metadata=project_metadata,
|
project_metadata=project_metadata,
|
||||||
revision_id=revision_id,
|
revision_id=revision_id,
|
||||||
|
dataset_version=dataset_version,
|
||||||
)
|
)
|
||||||
batch_results = await runnable_utils.gather_with_concurrency(
|
batch_results = await runnable_utils.gather_with_concurrency(
|
||||||
container.configs[0].get("max_concurrency"),
|
container.configs[0].get("max_concurrency"),
|
||||||
@ -1332,17 +1348,24 @@ def run_on_dataset(
|
|||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
*,
|
*,
|
||||||
evaluation: Optional[smith_eval.RunEvalConfig] = None,
|
evaluation: Optional[smith_eval.RunEvalConfig] = None,
|
||||||
|
dataset_version: Optional[Union[datetime, str]] = None,
|
||||||
concurrency_level: int = 5,
|
concurrency_level: int = 5,
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
project_metadata: Optional[Dict[str, Any]] = None,
|
project_metadata: Optional[Dict[str, Any]] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
tags: Optional[List[str]] = None,
|
|
||||||
revision_id: Optional[str] = None,
|
revision_id: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
input_mapper = kwargs.pop("input_mapper", None)
|
input_mapper = kwargs.pop("input_mapper", None)
|
||||||
if input_mapper:
|
if input_mapper:
|
||||||
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
|
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
|
||||||
|
tags = kwargs.pop("tags", None)
|
||||||
|
if tags:
|
||||||
|
warn_deprecated(
|
||||||
|
"0.1.9",
|
||||||
|
message="The tags argument is deprecated and will be"
|
||||||
|
" removed in a future release. Please specify project_metadata instead.",
|
||||||
|
)
|
||||||
if revision_id is None:
|
if revision_id is None:
|
||||||
revision_id = get_langchain_env_var_metadata().get("revision_id")
|
revision_id = get_langchain_env_var_metadata().get("revision_id")
|
||||||
|
|
||||||
@ -1366,6 +1389,7 @@ def run_on_dataset(
|
|||||||
concurrency_level,
|
concurrency_level,
|
||||||
project_metadata=project_metadata,
|
project_metadata=project_metadata,
|
||||||
revision_id=revision_id,
|
revision_id=revision_id,
|
||||||
|
dataset_version=dataset_version,
|
||||||
)
|
)
|
||||||
if concurrency_level == 0:
|
if concurrency_level == 0:
|
||||||
batch_results = [
|
batch_results = [
|
||||||
@ -1458,8 +1482,8 @@ Examples
|
|||||||
client = Client()
|
client = Client()
|
||||||
run_on_dataset(
|
run_on_dataset(
|
||||||
client,
|
client,
|
||||||
"<my_dataset_name>",
|
dataset_name="<my_dataset_name>",
|
||||||
construct_chain,
|
llm_or_chain_factory=construct_chain,
|
||||||
evaluation=evaluation_config,
|
evaluation=evaluation_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1496,8 +1520,8 @@ or LangSmith's `RunEvaluator` classes.
|
|||||||
|
|
||||||
run_on_dataset(
|
run_on_dataset(
|
||||||
client,
|
client,
|
||||||
"<my_dataset_name>",
|
dataset_name="<my_dataset_name>",
|
||||||
construct_chain,
|
llm_or_chain_factory=construct_chain,
|
||||||
evaluation=evaluation_config,
|
evaluation=evaluation_config,
|
||||||
)
|
)
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
Loading…
Reference in New Issue
Block a user