Support evaluating runnables and arbitrary functions (#8698)

Added a couple of "integration tests" for these that I ran.

Main design point of feedback: at this point, would it just be better to
have separate arguments for each type? Little confusing what is or isn't
supported and what is the intended usage at this point since I try to
wrap the function as runnable or pack or unpack chains/llms.

```
run_on_dataset(
...
llm_or_chain_factory = None,
llm = None,
chain = NOne,
runnable=None,
function=None
):
# raise error if none set
```

Downside with runnables and arbitrary function support is that you get
much less helpful validation and error messages, but I don't think we
should block you from this, at least.
This commit is contained in:
William FH 2023-08-04 16:39:04 -07:00 committed by GitHub
parent d00a247da7
commit c8f3615aa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 44 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import functools import functools
import inspect
import itertools import itertools
import logging import logging
import uuid import uuid
@ -19,6 +20,7 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
Union, Union,
cast,
) )
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
@ -37,12 +39,20 @@ from langchain.evaluation.schema import EvaluatorType, StringEvaluator
from langchain.schema import ChatResult, LLMResult from langchain.schema import ChatResult, LLMResult
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, messages_from_dict from langchain.schema.messages import BaseMessage, messages_from_dict
from langchain.schema.runnable import Runnable, RunnableConfig, RunnableLambda
from langchain.smith.evaluation.config import EvalConfig, RunEvalConfig from langchain.smith.evaluation.config import EvalConfig, RunEvalConfig
from langchain.smith.evaluation.string_run_evaluator import StringRunEvaluatorChain from langchain.smith.evaluation.string_run_evaluator import StringRunEvaluatorChain
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel] MODEL_OR_CHAIN_FACTORY = Union[
Callable[[], Union[Chain, Runnable]],
BaseLanguageModel,
Callable[[dict], Any],
Runnable,
Chain,
]
MCF = Union[Callable[[], Union[Chain, Runnable]], BaseLanguageModel]
class InputFormatError(Exception): class InputFormatError(Exception):
@ -66,9 +76,9 @@ def _get_eval_project_url(api_url: str, project_id: str) -> str:
def _wrap_in_chain_factory( def _wrap_in_chain_factory(
llm_or_chain_factory: Union[Chain, MODEL_OR_CHAIN_FACTORY], llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
dataset_name: str = "<my_dataset>", dataset_name: str = "<my_dataset>",
) -> MODEL_OR_CHAIN_FACTORY: ) -> MCF:
"""Forgive the user if they pass in a chain without memory instead of a chain """Forgive the user if they pass in a chain without memory instead of a chain
factory. It's a common mistake. Raise a more helpful error message as well.""" factory. It's a common mistake. Raise a more helpful error message as well."""
if isinstance(llm_or_chain_factory, Chain): if isinstance(llm_or_chain_factory, Chain):
@ -105,11 +115,31 @@ def _wrap_in_chain_factory(
return lambda: chain return lambda: chain
elif isinstance(llm_or_chain_factory, BaseLanguageModel): elif isinstance(llm_or_chain_factory, BaseLanguageModel):
return llm_or_chain_factory return llm_or_chain_factory
elif isinstance(llm_or_chain_factory, Runnable):
# Memory may exist here, but it's not elegant to check all those cases.
lcf = llm_or_chain_factory
return lambda: lcf
elif callable(llm_or_chain_factory): elif callable(llm_or_chain_factory):
_model = llm_or_chain_factory() try:
_model = llm_or_chain_factory() # type: ignore[call-arg]
except TypeError:
# It's an arbitrary function, wrap it in a RunnableLambda
user_func = cast(Callable, llm_or_chain_factory)
sig = inspect.signature(user_func)
logger.info(f"Wrapping function {sig} as RunnableLambda.")
wrapped = RunnableLambda(user_func)
return lambda: wrapped
constructor = cast(Callable, llm_or_chain_factory)
if isinstance(_model, BaseLanguageModel): if isinstance(_model, BaseLanguageModel):
# It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user.
return _model return _model
return llm_or_chain_factory elif not isinstance(_model, Runnable):
# This is unlikely to happen - a constructor for a model function
return lambda: RunnableLambda(constructor)
else:
# Typical correct case
return constructor # noqa
return llm_or_chain_factory return llm_or_chain_factory
@ -220,7 +250,7 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
def _get_project_name( def _get_project_name(
project_name: Optional[str], project_name: Optional[str],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MCF,
) -> str: ) -> str:
""" """
Get the project name. Get the project name.
@ -315,7 +345,7 @@ def _validate_example_inputs_for_chain(
def _validate_example_inputs( def _validate_example_inputs(
examples: Iterator[Example], examples: Iterator[Example],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]], input_mapper: Optional[Callable[[Dict], Any]],
) -> Iterator[Example]: ) -> Iterator[Example]:
"""Validate that the example inputs are valid for the model.""" """Validate that the example inputs are valid for the model."""
@ -324,7 +354,11 @@ def _validate_example_inputs(
_validate_example_inputs_for_language_model(first_example, input_mapper) _validate_example_inputs_for_language_model(first_example, input_mapper)
else: else:
chain = llm_or_chain_factory() chain = llm_or_chain_factory()
_validate_example_inputs_for_chain(first_example, chain, input_mapper) if isinstance(chain, Chain):
# Otherwise it's a runnable
_validate_example_inputs_for_chain(first_example, chain, input_mapper)
elif isinstance(chain, Runnable):
logger.debug(f"Skipping input validation for {chain}")
return examples return examples
@ -332,7 +366,7 @@ def _validate_example_inputs(
def _setup_evaluation( def _setup_evaluation(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MCF,
examples: Iterator[Example], examples: Iterator[Example],
evaluation: Optional[RunEvalConfig], evaluation: Optional[RunEvalConfig],
data_type: DataType, data_type: DataType,
@ -353,8 +387,8 @@ def _setup_evaluation(
"Please specify a dataset with the default 'kv' data type." "Please specify a dataset with the default 'kv' data type."
) )
chain = llm_or_chain_factory() chain = llm_or_chain_factory()
run_inputs = chain.input_keys run_inputs = chain.input_keys if isinstance(chain, Chain) else None
run_outputs = chain.output_keys run_outputs = chain.output_keys if isinstance(chain, Chain) else None
run_evaluators = _load_run_evaluators( run_evaluators = _load_run_evaluators(
evaluation, evaluation,
run_type, run_type,
@ -372,17 +406,15 @@ def _setup_evaluation(
def _determine_input_key( def _determine_input_key(
config: RunEvalConfig, config: RunEvalConfig,
run_inputs: Optional[List[str]], run_inputs: Optional[List[str]],
run_type: str,
) -> Optional[str]: ) -> Optional[str]:
input_key = None
if config.input_key: if config.input_key:
input_key = config.input_key input_key = config.input_key
if run_inputs and input_key not in run_inputs: if run_inputs and input_key not in run_inputs:
raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}") raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}")
elif run_type == "llm":
input_key = None
elif run_inputs and len(run_inputs) == 1: elif run_inputs and len(run_inputs) == 1:
input_key = run_inputs[0] input_key = run_inputs[0]
else: elif run_inputs is not None and len(run_inputs) > 1:
raise ValueError( raise ValueError(
f"Must specify input key for model with multiple inputs: {run_inputs}" f"Must specify input key for model with multiple inputs: {run_inputs}"
) )
@ -393,19 +425,17 @@ def _determine_input_key(
def _determine_prediction_key( def _determine_prediction_key(
config: RunEvalConfig, config: RunEvalConfig,
run_outputs: Optional[List[str]], run_outputs: Optional[List[str]],
run_type: str,
) -> Optional[str]: ) -> Optional[str]:
prediction_key = None
if config.prediction_key: if config.prediction_key:
prediction_key = config.prediction_key prediction_key = config.prediction_key
if run_outputs and prediction_key not in run_outputs: if run_outputs and prediction_key not in run_outputs:
raise ValueError( raise ValueError(
f"Prediction key {prediction_key} not in run outputs {run_outputs}" f"Prediction key {prediction_key} not in run outputs {run_outputs}"
) )
elif run_type == "llm":
prediction_key = None
elif run_outputs and len(run_outputs) == 1: elif run_outputs and len(run_outputs) == 1:
prediction_key = run_outputs[0] prediction_key = run_outputs[0]
else: elif run_outputs is not None and len(run_outputs) > 1:
raise ValueError( raise ValueError(
f"Must specify prediction key for model" f"Must specify prediction key for model"
f" with multiple outputs: {run_outputs}" f" with multiple outputs: {run_outputs}"
@ -491,8 +521,8 @@ def _load_run_evaluators(
""" """
eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0) eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
run_evaluators = [] run_evaluators = []
input_key = _determine_input_key(config, run_inputs, run_type) input_key = _determine_input_key(config, run_inputs)
prediction_key = _determine_prediction_key(config, run_outputs, run_type) prediction_key = _determine_prediction_key(config, run_outputs)
reference_key = _determine_reference_key(config, example_outputs) reference_key = _determine_reference_key(config, example_outputs)
for eval_config in config.evaluators: for eval_config in config.evaluators:
run_evaluator = _construct_run_evaluator( run_evaluator = _construct_run_evaluator(
@ -590,7 +620,7 @@ async def _arun_llm(
async def _arun_chain( async def _arun_chain(
chain: Chain, chain: Union[Chain, Runnable],
inputs: Dict[str, Any], inputs: Dict[str, Any],
callbacks: Callbacks, callbacks: Callbacks,
*, *,
@ -598,20 +628,22 @@ async def _arun_chain(
input_mapper: Optional[Callable[[Dict], Any]] = None, input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str]: ) -> Union[dict, str]:
"""Run a chain asynchronously on inputs.""" """Run a chain asynchronously on inputs."""
if input_mapper is not None: inputs_ = inputs if input_mapper is None else input_mapper(inputs)
inputs_ = input_mapper(inputs) if isinstance(chain, Chain):
output: Union[dict, str] = await chain.acall( if isinstance(inputs_, dict) and len(inputs_) == 1:
inputs_, callbacks=callbacks, tags=tags val = next(iter(inputs_.values()))
) output = await chain.acall(val, callbacks=callbacks, tags=tags)
else:
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
else: else:
inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags) output = await chain.ainvoke(inputs_, config=runnable_config)
return output return output
async def _arun_llm_or_chain( async def _arun_llm_or_chain(
example: Example, example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MCF,
n_repetitions: int, n_repetitions: int,
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
@ -810,12 +842,12 @@ async def _arun_on_examples(
Returns: Returns:
A dictionary mapping example ids to the model outputs. A dictionary mapping example ids to the model outputs.
""" """
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory) wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory) project_name = _get_project_name(project_name, wrapped_model)
run_evaluators, examples = _setup_evaluation( run_evaluators, examples = _setup_evaluation(
llm_or_chain_factory, examples, evaluation, data_type wrapped_model, examples, evaluation, data_type
) )
examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper) examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
results: Dict[str, List[Any]] = {} results: Dict[str, List[Any]] = {}
async def process_example( async def process_example(
@ -824,7 +856,7 @@ async def _arun_on_examples(
"""Process a single example.""" """Process a single example."""
result = await _arun_llm_or_chain( result = await _arun_llm_or_chain(
example, example,
llm_or_chain_factory, wrapped_model,
num_repetitions, num_repetitions,
tags=tags, tags=tags,
callbacks=callbacks, callbacks=callbacks,
@ -911,7 +943,7 @@ def _run_llm(
def _run_chain( def _run_chain(
chain: Chain, chain: Union[Chain, Runnable],
inputs: Dict[str, Any], inputs: Dict[str, Any],
callbacks: Callbacks, callbacks: Callbacks,
*, *,
@ -919,18 +951,22 @@ def _run_chain(
input_mapper: Optional[Callable[[Dict], Any]] = None, input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[Dict, str]: ) -> Union[Dict, str]:
"""Run a chain on inputs.""" """Run a chain on inputs."""
if input_mapper is not None: inputs_ = inputs if input_mapper is None else input_mapper(inputs)
inputs_ = input_mapper(inputs) if isinstance(chain, Chain):
output: Union[dict, str] = chain(inputs_, callbacks=callbacks, tags=tags) if isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
output = chain(val, callbacks=callbacks, tags=tags)
else:
output = chain(inputs_, callbacks=callbacks, tags=tags)
else: else:
inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = chain(inputs_, callbacks=callbacks, tags=tags) output = chain.invoke(inputs_, config=runnable_config)
return output return output
def _run_llm_or_chain( def _run_llm_or_chain(
example: Example, example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MCF,
n_repetitions: int, n_repetitions: int,
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
@ -986,7 +1022,8 @@ def _run_llm_or_chain(
outputs.append(output) outputs.append(output)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f"{chain_or_llm} failed for example {example.id}. Error: {e}" f"{chain_or_llm} failed for example {example.id} with inputs:"
f" {example.inputs}.\nError: {e}",
) )
outputs.append({"Error": str(e)}) outputs.append({"Error": str(e)})
if callbacks and previous_example_ids: if callbacks and previous_example_ids:
@ -1080,7 +1117,7 @@ def _prepare_eval_run(
dataset_name: str, dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str], project_name: Optional[str],
) -> Tuple[MODEL_OR_CHAIN_FACTORY, str, Dataset, Iterator[Example]]: ) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name) llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
project_name = _get_project_name(project_name, llm_or_chain_factory) project_name = _get_project_name(project_name, llm_or_chain_factory)
try: try:

View File

@ -10,9 +10,11 @@ from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.evaluation import EvaluatorType from langchain.evaluation import EvaluatorType
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema.messages import BaseMessage, HumanMessage from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.smith import RunEvalConfig, run_on_dataset from langchain.smith import RunEvalConfig, run_on_dataset
from langchain.smith.evaluation import InputFormatError from langchain.smith.evaluation import InputFormatError
from langchain.smith.evaluation.runner_utils import arun_on_dataset
def _check_all_feedback_passed(_project_name: str, client: Client) -> None: def _check_all_feedback_passed(_project_name: str, client: Client) -> None:
@ -427,3 +429,47 @@ def test_chain_on_kv_singleio_dataset(
tags=["shouldpass"], tags=["shouldpass"],
) )
_check_all_feedback_passed(eval_project_name, client) _check_all_feedback_passed(eval_project_name, client)
@pytest.mark.asyncio
async def test_runnable_on_kv_singleio_dataset(
kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
runnable = (
ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
| ChatOpenAI()
)
eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
await arun_on_dataset(
client,
kv_singleio_dataset_name,
runnable,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
@pytest.mark.asyncio
async def test_arb_func_on_kv_singleio_dataset(
kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
runnable = (
ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
| ChatOpenAI()
)
def my_func(x: dict) -> str:
return runnable.invoke(x).content
eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
await arun_on_dataset(
client,
kv_singleio_dataset_name,
my_func,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)