Compare commits


1 Commit

Author                SHA1         Message               Date
William Fu-Hinthorn   0508176982   Update lambda typing  2023-09-08 15:49:00 -07:00
4 changed files with 162 additions and 67 deletions

View File

@@ -1692,8 +1692,18 @@ class RunnableLambda(Runnable[Input, Output]):
def __init__(
self,
func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
func: Union[
Callable[[Input], Output],
Callable[[Input, RunnableConfig], Output],
Callable[[Input], Awaitable[Output]],
Callable[[Input, RunnableConfig], Awaitable[Output]],
],
afunc: Optional[
Union[
Callable[[Input], Awaitable[Output]],
Callable[[Input, RunnableConfig], Awaitable[Output]],
]
] = None,
) -> None:
if afunc is not None:
self.afunc = afunc
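
In effect, the widened annotations say that RunnableLambda now accepts either a one-argument callable or a two-argument callable that also receives the active RunnableConfig. A minimal sketch, assuming the langchain.schema.runnable import path used at this revision:

.. code-block:: python

    from langchain.schema.runnable import RunnableConfig, RunnableLambda

    def plain(x: str) -> str:
        # Callable[[Input], Output]: unchanged, only the input is passed.
        return x.upper()

    def with_config(x: str, config: RunnableConfig) -> str:
        # Callable[[Input, RunnableConfig], Output]: also receives the run's
        # config (tags, metadata, callbacks) as a second positional argument.
        return x.upper()

    RunnableLambda(plain).invoke("hello")
    RunnableLambda(with_config).invoke("hello")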

View File

@@ -143,6 +143,7 @@ def call_func_with_variable_args(
func: Union[
Callable[[Input], Output],
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, RunnableConfig], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
input: Input,
@@ -166,6 +167,10 @@ async def acall_func_with_variable_args(
[Input, AsyncCallbackManagerForChainRun, RunnableConfig],
Awaitable[Output],
],
Callable[
[Input, RunnableConfig],
Awaitable[Output],
],
],
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
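
This hunk only widens the accepted signatures; the dispatch itself presumably inspects the callable's parameters and forwards run_manager and config only when they are declared. A hypothetical, self-contained sketch of that pattern (the helper name and logic are illustrative, not the library's actual implementation):

.. code-block:: python

    import inspect
    from typing import Any, Callable

    def call_with_declared_kwargs(func: Callable[..., Any], value: Any, **candidates: Any) -> Any:
        # Forward only the optional keyword arguments (e.g. run_manager, config)
        # that actually appear in the callable's signature.
        params = inspect.signature(func).parameters
        accepted = {k: v for k, v in candidates.items() if k in params}
        return func(value, **accepted)

    call_with_declared_kwargs(lambda x: x + 1, 1)
    call_with_declared_kwargs(lambda x, config: (x, config["tags"]), 1, config={"tags": ["demo"]})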

View File

@@ -278,21 +278,35 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
## Shared data validation utilities
def _validate_example_inputs_for_language_model(
first_example: Example,
input_mapper: Optional[Callable[[Dict], Any]],
) -> None:
try:
_get_prompt(first_example.inputs)
except InputFormatError:
try:
_get_messages(first_example.inputs)
except InputFormatError:
if input_mapper:
prompt_input = input_mapper(first_example.inputs)
if not isinstance(prompt_input, str) and not (
isinstance(prompt_input, list)
and all(isinstance(msg, BaseMessage) for msg in prompt_input)
):
raise InputFormatError(
"Example inputs do not match language model input format. "
"Expected a dictionary with messages or a single prompt."
f" Got: {first_example.inputs}"
" Please update your dataset OR provide an input_mapper"
" to convert the example.inputs to a compatible format"
" for the llm or chat model you wish to evaluate."
"When using an input_mapper to prepare dataset example inputs"
" for an LLM or chat model, the output must a single string or"
" a list of chat messages."
f"\nGot: {prompt_input} of type {type(prompt_input)}."
)
else:
try:
_get_prompt(first_example.inputs)
except InputFormatError:
try:
_get_messages(first_example.inputs)
except InputFormatError:
raise InputFormatError(
"Example inputs do not match language model input format. "
"Expected a dictionary with messages or a single prompt."
f" Got: {first_example.inputs}"
" Please update your dataset OR provide an input_mapper"
" to convert the example.inputs to a compatible format"
" for the llm or chat model you wish to evaluate."
)
def _validate_example_inputs_for_chain(
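
For orientation, the input_mapper being validated here converts a dataset example's inputs dict into whatever the model under test expects; for an LLM or chat model that must be a single prompt string or a list of chat messages. A hypothetical string-producing mapper (the "question" key is illustrative, not taken from the diff):

.. code-block:: python

    from typing import Any, Dict

    def question_to_prompt(inputs: Dict[str, Any]) -> str:
        # Collapse the example's inputs dict to the single prompt string an LLM expects.
        return f"Answer concisely: {inputs['question']}"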
@@ -335,15 +349,16 @@ def _validate_example_inputs_for_chain(
def _validate_example_inputs(
example: Example,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]],
) -> None:
"""Validate that the example inputs are valid for the model."""
if isinstance(llm_or_chain_factory, BaseLanguageModel):
_validate_example_inputs_for_language_model(example)
_validate_example_inputs_for_language_model(example, input_mapper)
else:
chain = llm_or_chain_factory()
if isinstance(chain, Chain):
# Otherwise it's a runnable
_validate_example_inputs_for_chain(example, chain)
_validate_example_inputs_for_chain(example, chain, input_mapper)
elif isinstance(chain, Runnable):
logger.debug(f"Skipping input validation for {chain}")
@@ -576,6 +591,7 @@ async def _arun_llm(
*,
tags: Optional[List[str]] = None,
callbacks: Callbacks = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[str, BaseMessage]:
"""Asynchronously run the language model.
@@ -584,6 +600,7 @@ async def _arun_llm(
inputs: The input dictionary.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
input_mapper: Optional function to map inputs to the expected format.
Returns:
The string or BaseMessage output of the language model.
@@ -591,16 +608,36 @@ async def _arun_llm(
ValueError: If the LLM type is unsupported.
InputFormatError: If the input format is invalid.
"""
try:
prompt = _get_prompt(inputs)
llm_output: Union[str, BaseMessage] = await llm.apredict(
prompt, callbacks=callbacks, tags=tags
)
except InputFormatError:
messages = _get_messages(inputs)
llm_output = await llm.apredict_messages(
messages, callbacks=callbacks, tags=tags
)
if input_mapper is not None:
prompt_or_messages = input_mapper(inputs)
if isinstance(prompt_or_messages, str):
return await llm.apredict(
prompt_or_messages, callbacks=callbacks, tags=tags
)
elif isinstance(prompt_or_messages, list) and all(
isinstance(msg, BaseMessage) for msg in prompt_or_messages
):
return await llm.apredict_messages(
prompt_or_messages, callbacks=callbacks, tags=tags
)
else:
raise InputFormatError(
"Input mapper returned invalid format"
f" {prompt_or_messages}"
"\nExpected a single string or list of chat messages."
)
else:
try:
prompt = _get_prompt(inputs)
llm_output: Union[str, BaseMessage] = await llm.apredict(
prompt, callbacks=callbacks, tags=tags
)
except InputFormatError:
messages = _get_messages(inputs)
llm_output = await llm.apredict_messages(
messages, callbacks=callbacks, tags=tags
)
return llm_output
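
_arun_llm now bypasses the usual prompt/messages detection whenever an input_mapper is supplied and routes its return value directly to apredict or apredict_messages. A sketch of a mapper that yields chat messages, assuming the message classes exported from langchain.schema (the "question" key is again illustrative):

.. code-block:: python

    from typing import Any, Dict, List

    from langchain.schema import BaseMessage, HumanMessage, SystemMessage

    def to_chat_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
        # Returning a list of BaseMessage objects routes the call to apredict_messages.
        return [
            SystemMessage(content="You are a terse assistant."),
            HumanMessage(content=inputs["question"]),
        ]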
@@ -610,17 +647,19 @@ async def _arun_chain(
callbacks: Callbacks,
*,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str]:
"""Run a chain asynchronously on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain):
if isinstance(inputs, dict) and len(inputs) == 1:
val = next(iter(inputs.values()))
if isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
output = await chain.acall(val, callbacks=callbacks, tags=tags)
else:
output = await chain.acall(inputs, callbacks=callbacks, tags=tags)
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = await chain.ainvoke(inputs, config=runnable_config)
output = await chain.ainvoke(inputs_, config=runnable_config)
return output
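
For chains, the mapped inputs simply replace the raw example inputs before the existing single-key unwrapping runs, which helps when the dataset's keys differ from the chain's input keys. A hypothetical remapping (key names are illustrative):

.. code-block:: python

    from typing import Any, Dict

    def to_chain_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Dataset rows store "question"/"context"; the chain expects "query"/"source_text".
        return {"query": inputs["question"], "source_text": inputs["context"]}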
@@ -629,6 +668,7 @@ async def _arun_llm_or_chain(
config: RunnableConfig,
*,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str, LLMResult, ChatResult]:
"""Asynchronously run the Chain or language model.
@@ -653,6 +693,7 @@ async def _arun_llm_or_chain(
example.inputs,
tags=config["tags"],
callbacks=config["callbacks"],
input_mapper=input_mapper,
)
else:
chain = llm_or_chain_factory()
@@ -661,6 +702,7 @@ async def _arun_llm_or_chain(
example.inputs,
tags=config["tags"],
callbacks=config["callbacks"],
input_mapper=input_mapper,
)
result = output
except Exception as e:
@@ -733,17 +775,19 @@ def _run_chain(
callbacks: Callbacks,
*,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[Dict, str]:
"""Run a chain on inputs."""
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain):
if isinstance(inputs, dict) and len(inputs) == 1:
val = next(iter(inputs.values()))
if isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
output = chain(val, callbacks=callbacks, tags=tags)
else:
output = chain(inputs, callbacks=callbacks, tags=tags)
output = chain(inputs_, callbacks=callbacks, tags=tags)
else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = chain.invoke(inputs, config=runnable_config)
output = chain.invoke(inputs_, config=runnable_config)
return output
@@ -752,6 +796,7 @@ def _run_llm_or_chain(
config: RunnableConfig,
*,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str, LLMResult, ChatResult]:
"""
Run the Chain or language model synchronously.
@@ -777,6 +822,7 @@ def _run_llm_or_chain(
example.inputs,
config["callbacks"],
tags=config["tags"],
input_mapper=input_mapper,
)
else:
chain = llm_or_chain_factory()
@@ -785,6 +831,7 @@ def _run_llm_or_chain(
example.inputs,
config["callbacks"],
tags=config["tags"],
input_mapper=input_mapper,
)
result = output
except Exception as e:
@@ -832,7 +879,9 @@ def _prepare_run_on_dataset(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
evaluation: Optional[smith_eval.RunEvalConfig] = None,
config: Optional[RunnableConfig] = None,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
concurrency_level: int = 5,
) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
project_name = project_name or name_generation.random_name()
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
@@ -842,10 +891,8 @@ def _prepare_run_on_dataset(
run_evaluators = _setup_evaluation(
wrapped_model, examples, evaluation, dataset.data_type
)
_validate_example_inputs(examples[0], wrapped_model)
_validate_example_inputs(examples[0], wrapped_model, input_mapper)
progress_bar = progress.ProgressBarCallback(len(examples))
run_config = config or RunnableConfig(max_concurrency=5)
_callbacks = run_config["callbacks"] if "callbacks" in run_config else []
configs = [
RunnableConfig(
callbacks=[
@@ -862,9 +909,9 @@ def _prepare_run_on_dataset(
example_id=example.id,
),
progress_bar,
]
+ _callbacks,
**run_config,
],
tags=tags or [],
max_concurrency=concurrency_level,
)
for example in examples
]
@@ -900,14 +947,16 @@ def _collect_test_results(
async def arun_on_dataset(
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
client: Optional[Client] = None,
evaluation: Optional[smith_eval.RunEvalConfig] = None,
config: Optional[RunnableConfig] = None,
concurrency_level: int = 5,
project_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""
@@ -915,22 +964,24 @@ async def arun_on_dataset(
and store traces to the specified project name.
Args:
dataset_name: Name of the dataset to run the chain on.
llm_or_chain_factory: Either a runnable or a constructor for a runnable.
If your runnable contains state (e.g. an agent or
chain with memory), you should pass in a constructor
function that returns a new runnable instance for each
example row in the dataset.
client: LangSmith client to use to read the dataset, and to
log feedback and run traces.
evaluation: Optional evaluation configuration to specify which
evaluators to run on the results of the chain and which
key(s) to consider during evaluation.
config: Optional configuration to specify concurrency, tags,
or other information to apply to the chain runs.
dataset_name: Name of the dataset to run the chain on.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
evaluation: Optional evaluation configuration to use when evaluating the results.
concurrency_level: The number of async tasks to run concurrently.
project_name: Name of the project to store the traces in.
Defaults to a friendly name.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
tags: Tags to add to each run in the project.
input_mapper: A function to map the inputs dictionary from an Example
to the format expected by the model to be evaluated. This is useful if
your model needs to deserialize more complex schema or if your dataset
has inputs with keys that differ from what is expected by your chain
or agent.
Returns:
A dictionary containing the run's project name and the
resulting model outputs.
@@ -942,6 +993,7 @@ async def arun_on_dataset(
.. code-block:: python
from langsmith import Client
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.smith import RunEvalConfig, arun_on_dataset
@@ -968,9 +1020,11 @@ async def arun_on_dataset(
]
)
client = Client()
await arun_on_dataset(
dataset_name="<my_dataset_name>",
llm_or_chain_config=construct_chain,
client,
"<my_dataset_name>",
construct_chain,
evaluation=evaluation_config,
)
@@ -1006,6 +1060,7 @@ async def arun_on_dataset(
)
await arun_on_dataset(
client,
"<my_dataset_name>",
construct_chain,
evaluation=evaluation_config,
@@ -1018,14 +1073,15 @@ async def arun_on_dataset(
f"{kwargs.keys()}.",
DeprecationWarning,
)
client = client or Client()
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
client,
dataset_name,
llm_or_chain_factory,
project_name,
evaluation,
config=config,
tags,
input_mapper,
concurrency_level,
)
batch_results = await runnable_utils.gather_with_concurrency(
@@ -1034,6 +1090,7 @@ async def arun_on_dataset(
functools.partial(
_arun_llm_or_chain,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,
@@ -1051,14 +1108,16 @@ async def arun_on_dataset(
def run_on_dataset(
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
client: Optional[Client] = None,
evaluation: Optional[smith_eval.RunEvalConfig] = None,
config: Optional[RunnableConfig] = None,
concurrency_level: int = 5,
project_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""
@@ -1076,7 +1135,7 @@ def run_on_dataset(
results of the chain
concurrency_level: The number of async tasks to run concurrently.
project_name: Name of the project to store the traces in.
Defaults to a friendly name
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
tags: Tags to add to each run in the project.
input_mapper: A function to map the inputs dictionary from an Example
@@ -1096,6 +1155,7 @@ def run_on_dataset(
.. code-block:: python
from langsmith import Client
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.smith import RunEvalConfig, run_on_dataset
@@ -1122,11 +1182,12 @@ def run_on_dataset(
]
)
client = Client()
run_on_dataset(
dataset_name="<my_dataset_name>",
llm_or_chain_factory=construct_chain,
client,
"<my_dataset_name>",
construct_chain,
evaluation=evaluation_config,
verbose=True,
)
You can also create custom evaluators by subclassing the
@@ -1161,6 +1222,7 @@ def run_on_dataset(
)
run_on_dataset(
client,
"<my_dataset_name>",
construct_chain,
evaluation=evaluation_config,
@@ -1173,14 +1235,15 @@ def run_on_dataset(
f"{kwargs.keys()}.",
DeprecationWarning,
)
client = client or Client()
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
client,
dataset_name,
llm_or_chain_factory,
project_name,
evaluation,
config=config,
tags,
input_mapper,
concurrency_level,
)
with runnable_config.get_executor_for_config(configs[0]) as executor:
batch_results = list(
@@ -1188,6 +1251,7 @@ def run_on_dataset(
functools.partial(
_run_llm_or_chain,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,

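Putting the pieces together, a hedged end-to-end sketch of the updated call shape with an input_mapper, assuming the langchain.smith exports used in the docstring examples above (dataset name, example keys, and evaluator choice are placeholders):

.. code-block:: python

    from langsmith import Client

    from langchain.chat_models import ChatOpenAI
    from langchain.smith import RunEvalConfig, run_on_dataset

    def only_question(inputs: dict) -> str:
        # Hypothetical dataset rows look like {"question": ..., "notes": ...};
        # the chat model under test only needs the question text.
        return inputs["question"]

    run_on_dataset(
        Client(),
        "<my_dataset_name>",
        ChatOpenAI(temperature=0),
        evaluation=RunEvalConfig(evaluators=["qa"]),
        input_mapper=only_question,
    )
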
View File

@@ -1,3 +1,4 @@
import asyncio
from operator import itemgetter
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
@@ -1785,3 +1786,18 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
assert parent_run_qux.outputs["output"] == "quxaaaa"
assert len(parent_run_qux.child_runs) == 4
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
@pytest.mark.asyncio
async def test_lambda_accept_config() -> None:
def sync_with_config(x: str, config: RunnableConfig) -> str:
return x
RunnableLambda(sync_with_config).invoke("foo")
async def async_with_config(x: str, config: RunnableConfig) -> str:
await asyncio.sleep(0.001)
return x
await RunnableLambda(async_with_config).ainvoke("foo")
await RunnableLambda(sync_with_config, async_with_config).abatch(["foo"])
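
As the new test exercises, RunnableLambda's second positional argument is still the async counterpart (afunc), and either callable may now declare the extra RunnableConfig parameter. A small usage sketch under those assumptions, reading tags off the config (import path as used at this revision):

.. code-block:: python

    import asyncio

    from langchain.schema.runnable import RunnableConfig, RunnableLambda

    def tag_sync(x: str, config: RunnableConfig) -> str:
        # The callable can inspect the run's tags, metadata, or callbacks.
        return f"{x}:{','.join(config.get('tags', []))}"

    async def tag_async(x: str, config: RunnableConfig) -> str:
        await asyncio.sleep(0)
        return f"{x}:{','.join(config.get('tags', []))}"

    # The second positional argument is the async implementation used by ainvoke/abatch.
    runnable = RunnableLambda(tag_sync, tag_async)
    runnable.invoke("foo", config={"tags": ["eval"]})
    asyncio.run(runnable.abatch(["foo", "bar"], config={"tags": ["eval"]}))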