Add revision identifier to run_on_dataset (#16167)
Allows specifying a revision identifier for better project versioning.
parent 5d8c147332
commit 7d444724d7
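
As a usage sketch (not part of the commit itself): with this change a caller can pass revision_id straight through run_on_dataset, and the matching arun_on_dataset accepts the same parameter. The dataset name, model, and evaluator below are placeholders.

from langsmith import Client

from langchain.smith import RunEvalConfig, run_on_dataset
from langchain_openai import ChatOpenAI

client = Client()
results = run_on_dataset(
    client,
    dataset_name="my-dataset",  # placeholder dataset name
    llm_or_chain_factory=ChatOpenAI(temperature=0),  # placeholder model
    evaluation=RunEvalConfig(evaluators=["qa"]),
    # New in this commit: tag the test run with a revision identifier
    # (e.g. a git SHA) so different versions of your system can be compared.
    revision_id="8f1e2c9",  # placeholder revision identifier
)

The identifier is stored in the test project's metadata and attached to each run, so projects produced by different revisions can be told apart in LangSmith.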
@@ -657,6 +657,7 @@ async def _arun_llm(
     tags: Optional[List[str]] = None,
     callbacks: Callbacks = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[str, BaseMessage]:
     """Asynchronously run the language model.
 
@@ -682,7 +683,9 @@ async def _arun_llm(
         ):
             return await llm.ainvoke(
                 prompt_or_messages,
-                config=RunnableConfig(callbacks=callbacks, tags=tags or []),
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         else:
             raise InputFormatError(
@@ -695,12 +698,18 @@ async def _arun_llm(
         try:
             prompt = _get_prompt(inputs)
             llm_output: Union[str, BaseMessage] = await llm.ainvoke(
-                prompt, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+                prompt,
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         except InputFormatError:
             messages = _get_messages(inputs)
             llm_output = await llm.ainvoke(
-                messages, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+                messages,
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
     return llm_output
 
@@ -712,6 +721,7 @@ async def _arun_chain(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[dict, str]:
     """Run a chain asynchronously on inputs."""
     inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -723,10 +733,15 @@ async def _arun_chain(
     ):
         val = next(iter(inputs_.values()))
         output = await chain.ainvoke(
-            val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            val,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
         )
     else:
-        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        runnable_config = RunnableConfig(
+            tags=tags or [], callbacks=callbacks, metadata=metadata or {}
+        )
         output = await chain.ainvoke(inputs_, config=runnable_config)
     return output
 
@@ -762,6 +777,7 @@ async def _arun_llm_or_chain(
                 tags=config["tags"],
                 callbacks=config["callbacks"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         else:
             chain = llm_or_chain_factory()
@@ -771,6 +787,7 @@ async def _arun_llm_or_chain(
                 tags=config["tags"],
                 callbacks=config["callbacks"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         result = output
     except Exception as e:
@@ -793,6 +810,7 @@ def _run_llm(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[str, BaseMessage]:
     """
     Run the language model on the example.
@@ -819,7 +837,9 @@ def _run_llm(
         ):
             llm_output: Union[str, BaseMessage] = llm.invoke(
                 prompt_or_messages,
-                config=RunnableConfig(callbacks=callbacks, tags=tags or []),
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         else:
             raise InputFormatError(
@@ -831,12 +851,16 @@ def _run_llm(
         try:
             llm_prompts = _get_prompt(inputs)
             llm_output = llm.invoke(
-                llm_prompts, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+                llm_prompts,
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         except InputFormatError:
             llm_messages = _get_messages(inputs)
             llm_output = llm.invoke(
-                llm_messages, config=RunnableConfig(callbacks=callbacks)
+                llm_messages,
+                config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}),
             )
     return llm_output
 
@@ -848,6 +872,7 @@ def _run_chain(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[Dict, str]:
     """Run a chain on inputs."""
     inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -859,10 +884,15 @@ def _run_chain(
     ):
         val = next(iter(inputs_.values()))
         output = chain.invoke(
-            val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            val,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
        )
     else:
-        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        runnable_config = RunnableConfig(
+            tags=tags or [], callbacks=callbacks, metadata=metadata or {}
+        )
         output = chain.invoke(inputs_, config=runnable_config)
     return output
 
@@ -899,6 +929,7 @@ def _run_llm_or_chain(
                 config["callbacks"],
                 tags=config["tags"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         else:
             chain = llm_or_chain_factory()
@@ -908,6 +939,7 @@ def _run_llm_or_chain(
                 config["callbacks"],
                 tags=config["tags"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         result = output
     except Exception as e:
@@ -1083,8 +1115,13 @@ class _DatasetRunContainer:
         input_mapper: Optional[Callable[[Dict], Any]] = None,
         concurrency_level: int = 5,
         project_metadata: Optional[Dict[str, Any]] = None,
+        revision_id: Optional[str] = None,
     ) -> _DatasetRunContainer:
         project_name = project_name or name_generation.random_name()
+        if revision_id:
+            if not project_metadata:
+                project_metadata = {}
+            project_metadata.update({"revision_id": revision_id})
         wrapped_model, project, dataset, examples = _prepare_eval_run(
             client,
             dataset_name,
@@ -1121,6 +1158,7 @@ class _DatasetRunContainer:
                 ],
                 tags=tags,
                 max_concurrency=concurrency_level,
+                metadata={"revision_id": revision_id} if revision_id else {},
             )
             for example in examples
         ]
@@ -1183,6 +1221,7 @@ async def arun_on_dataset(
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
+    revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     input_mapper = kwargs.pop("input_mapper", None)
@@ -1208,6 +1247,7 @@ async def arun_on_dataset(
         input_mapper,
         concurrency_level,
         project_metadata=project_metadata,
+        revision_id=revision_id,
     )
     batch_results = await runnable_utils.gather_with_concurrency(
         container.configs[0].get("max_concurrency"),
@@ -1235,6 +1275,7 @@ def run_on_dataset(
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
+    revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     input_mapper = kwargs.pop("input_mapper", None)
@@ -1260,6 +1301,7 @@ def run_on_dataset(
         input_mapper,
         concurrency_level,
         project_metadata=project_metadata,
+        revision_id=revision_id,
     )
     if concurrency_level == 0:
         batch_results = [
@@ -1309,6 +1351,8 @@ Args:
         log feedback and run traces.
     verbose: Whether to print progress.
     tags: Tags to add to each run in the project.
+    revision_id: Optional revision identifier to assign this test run to
+        track the performance of different versions of your system.
 
 Returns:
     A dictionary containing the run's project name and the resulting model outputs.
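
The net effect of the container changes above is that the revision identifier lands in two places: the project-level metadata and each example's run metadata. A standalone sketch of the merge semantics (mirroring the diff, not the library's verbatim code):

from typing import Any, Dict, Optional

def merge_revision_metadata(
    project_metadata: Optional[Dict[str, Any]], revision_id: Optional[str]
) -> Optional[Dict[str, Any]]:
    # Project level: fold the revision into the test project's metadata,
    # creating the dict if the caller supplied none.
    if revision_id:
        if not project_metadata:
            project_metadata = {}
        project_metadata.update({"revision_id": revision_id})
    return project_metadata

assert merge_revision_metadata(None, "8f1e2c9") == {"revision_id": "8f1e2c9"}
assert merge_revision_metadata({"team": "qa"}, "8f1e2c9") == {
    "team": "qa",
    "revision_id": "8f1e2c9",
}
# Run level: each per-example RunnableConfig is additionally built with
# metadata={"revision_id": revision_id} if revision_id else {}.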