Add revision identifier to run_on_dataset (#16167)

Allow specifying a revision identifier for better project versioning
SN 2024-01-17 20:27:43 -08:00 committed by GitHub
parent 5d8c147332
commit 7d444724d7

@@ -657,6 +657,7 @@ async def _arun_llm(
     tags: Optional[List[str]] = None,
     callbacks: Callbacks = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[str, BaseMessage]:
     """Asynchronously run the language model.
@@ -682,7 +683,9 @@ async def _arun_llm(
         ):
             return await llm.ainvoke(
                 prompt_or_messages,
-                config=RunnableConfig(callbacks=callbacks, tags=tags or []),
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         else:
             raise InputFormatError(
@@ -695,12 +698,18 @@ async def _arun_llm(
         try:
             prompt = _get_prompt(inputs)
             llm_output: Union[str, BaseMessage] = await llm.ainvoke(
-                prompt, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+                prompt,
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         except InputFormatError:
             messages = _get_messages(inputs)
             llm_output = await llm.ainvoke(
-                messages, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+                messages,
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
     return llm_output
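
For orientation, the repeated edit above threads an optional metadata dict into each invocation's `RunnableConfig`. A minimal sketch of that pattern follows; the `build_config` helper is hypothetical and only illustrates the normalization the diff performs inline:

```python
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import Callbacks
from langchain_core.runnables import RunnableConfig


def build_config(
    callbacks: Callbacks = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> RunnableConfig:
    # Normalize None to empty containers, as the diff does inline, so every
    # invocation carries callbacks, tags, and metadata on its config.
    return RunnableConfig(
        callbacks=callbacks, tags=tags or [], metadata=metadata or {}
    )
```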
@@ -712,6 +721,7 @@ async def _arun_chain(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[dict, str]:
     """Run a chain asynchronously on inputs."""
     inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -723,10 +733,15 @@ async def _arun_chain(
     ):
         val = next(iter(inputs_.values()))
         output = await chain.ainvoke(
-            val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            val,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
         )
     else:
-        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        runnable_config = RunnableConfig(
+            tags=tags or [], callbacks=callbacks, metadata=metadata or {}
+        )
         output = await chain.ainvoke(inputs_, config=runnable_config)
     return output
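
The `val = next(iter(inputs_.values()))` branch above unwraps single-key dataset rows so a chain that accepts a bare value still works. A small illustration, with a hypothetical example input:

```python
# When a dataset row has exactly one input key, the bare value is passed
# to the chain instead of the whole dict.
inputs_ = {"question": "What does revision_id do?"}

if len(inputs_) == 1:
    val = next(iter(inputs_.values()))
    print(val)  # -> "What does revision_id do?"
```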
@@ -762,6 +777,7 @@ async def _arun_llm_or_chain(
             tags=config["tags"],
             callbacks=config["callbacks"],
             input_mapper=input_mapper,
+            metadata=config.get("metadata"),
         )
     else:
         chain = llm_or_chain_factory()
@@ -771,6 +787,7 @@ async def _arun_llm_or_chain(
             tags=config["tags"],
             callbacks=config["callbacks"],
             input_mapper=input_mapper,
+            metadata=config.get("metadata"),
         )
         result = output
     except Exception as e:
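
Note that these call sites use `config.get("metadata")` rather than `config["metadata"]`, so configs built without a metadata key pass `None`, which the `metadata or {}` normalization shown earlier turns back into an empty dict. A tiny sketch, with a hypothetical config:

```python
config = {"tags": [], "callbacks": None}  # no "metadata" key

metadata = config.get("metadata")  # -> None (no KeyError)
normalized = metadata or {}        # -> {}
```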
@@ -793,6 +810,7 @@ def _run_llm(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[str, BaseMessage]:
     """
     Run the language model on the example.
@@ -819,7 +837,9 @@ def _run_llm(
         ):
             llm_output: Union[str, BaseMessage] = llm.invoke(
                 prompt_or_messages,
-                config=RunnableConfig(callbacks=callbacks, tags=tags or []),
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         else:
             raise InputFormatError(
@@ -831,12 +851,16 @@ def _run_llm(
         try:
             llm_prompts = _get_prompt(inputs)
             llm_output = llm.invoke(
-                llm_prompts, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+                llm_prompts,
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
            )
         except InputFormatError:
             llm_messages = _get_messages(inputs)
             llm_output = llm.invoke(
-                llm_messages, config=RunnableConfig(callbacks=callbacks)
+                llm_messages,
+                config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}),
             )
     return llm_output
@@ -848,6 +872,7 @@ def _run_chain(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[Dict, str]:
     """Run a chain on inputs."""
     inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -859,10 +884,15 @@ def _run_chain(
     ):
         val = next(iter(inputs_.values()))
         output = chain.invoke(
-            val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            val,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
         )
     else:
-        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        runnable_config = RunnableConfig(
+            tags=tags or [], callbacks=callbacks, metadata=metadata or {}
+        )
         output = chain.invoke(inputs_, config=runnable_config)
     return output
@@ -899,6 +929,7 @@ def _run_llm_or_chain(
             config["callbacks"],
             tags=config["tags"],
             input_mapper=input_mapper,
+            metadata=config.get("metadata"),
         )
     else:
         chain = llm_or_chain_factory()
@@ -908,6 +939,7 @@ def _run_llm_or_chain(
             config["callbacks"],
             tags=config["tags"],
             input_mapper=input_mapper,
+            metadata=config.get("metadata"),
         )
         result = output
     except Exception as e:
@@ -1083,8 +1115,13 @@ class _DatasetRunContainer:
         input_mapper: Optional[Callable[[Dict], Any]] = None,
         concurrency_level: int = 5,
         project_metadata: Optional[Dict[str, Any]] = None,
+        revision_id: Optional[str] = None,
     ) -> _DatasetRunContainer:
         project_name = project_name or name_generation.random_name()
+        if revision_id:
+            if not project_metadata:
+                project_metadata = {}
+            project_metadata.update({"revision_id": revision_id})
         wrapped_model, project, dataset, examples = _prepare_eval_run(
             client,
             dataset_name,
@@ -1121,6 +1158,7 @@ class _DatasetRunContainer:
                 ],
                 tags=tags,
                 max_concurrency=concurrency_level,
+                metadata={"revision_id": revision_id} if revision_id else {},
             )
             for example in examples
         ]
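
The revision identifier is recorded in two places: once on the project via `project_metadata`, and once in the metadata of every per-example run config. A standalone sketch of that merge logic, with hypothetical values:

```python
from typing import Any, Dict, Optional

revision_id: Optional[str] = "v1.2.0"  # e.g. a git SHA or release tag
project_metadata: Optional[Dict[str, Any]] = None

# Project-level: merge revision_id into the (possibly absent) project metadata.
if revision_id:
    if not project_metadata:
        project_metadata = {}
    project_metadata.update({"revision_id": revision_id})

# Run-level: attach the same identifier to each example's run metadata.
run_metadata = {"revision_id": revision_id} if revision_id else {}

assert project_metadata == {"revision_id": "v1.2.0"}
assert run_metadata == {"revision_id": "v1.2.0"}
```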
@@ -1183,6 +1221,7 @@ async def arun_on_dataset(
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
+    revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     input_mapper = kwargs.pop("input_mapper", None)
@@ -1208,6 +1247,7 @@ async def arun_on_dataset(
         input_mapper,
         concurrency_level,
         project_metadata=project_metadata,
+        revision_id=revision_id,
     )
     batch_results = await runnable_utils.gather_with_concurrency(
         container.configs[0].get("max_concurrency"),
@@ -1235,6 +1275,7 @@ def run_on_dataset(
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
+    revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     input_mapper = kwargs.pop("input_mapper", None)
@@ -1260,6 +1301,7 @@ def run_on_dataset(
         input_mapper,
         concurrency_level,
         project_metadata=project_metadata,
+        revision_id=revision_id,
     )
     if concurrency_level == 0:
         batch_results = [
@@ -1309,6 +1351,8 @@ Args:
         log feedback and run traces.
     verbose: Whether to print progress.
     tags: Tags to add to each run in the project.
+    revision_id: Optional revision identifier to assign this test run to
+        track the performance of different versions of your system.
 Returns:
     A dictionary containing the run's project name and the resulting model outputs.
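
To close, a hedged usage sketch of the new parameter; the client setup, dataset name, and chain factory below are illustrative, not taken from this commit:

```python
from langchain.smith import RunEvalConfig, run_on_dataset
from langsmith import Client


def make_chain():
    # Hypothetical factory; return any chain or Runnable under test.
    ...


client = Client()
results = run_on_dataset(
    client=client,
    dataset_name="my-eval-dataset",               # hypothetical dataset
    llm_or_chain_factory=make_chain,
    evaluation=RunEvalConfig(evaluators=["qa"]),
    revision_id="7d444724d7",                     # e.g. the commit SHA under test
)
print(results["project_name"])
```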