add interface for evaluating messages

This commit is contained in:
Ankush Gola
2023-07-06 22:37:31 -07:00
parent 1c650f98a8
commit 0d8df345f5
3 changed files with 338 additions and 57 deletions

View File

@@ -0,0 +1,138 @@
"""Run evaluator mapper for message evaluators."""
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union, TypedDict
from langchain.schema import BaseMessage
from langchainplus_sdk.schemas import Example, Run
from langchain.load.serializable import Serializable
from langchain.schema import messages_from_dict
class RunMapping(TypedDict):
    """Items extracted from a Run for message-based evaluation."""
    # The model's output message to be evaluated.
    prediction: BaseMessage
    # The chat history that was given to the model as input.
    input: List[BaseMessage]
class ExampleMapping(TypedDict):
    """Items extracted from an Example (dataset row) for evaluation."""
    # The reference (ground-truth) message to compare the prediction against.
    reference: BaseMessage
class MessageRunMapper(Serializable):
    """Base mapper that pulls evaluation items out of a LangSmith Run.

    Subclasses implement :meth:`map` to turn a Run into the
    ``prediction``/``input`` messages an evaluator consumes.
    """

    @property
    def output_keys(self) -> List[str]:
        """Names of the items this mapper extracts from a run."""
        return ["prediction", "input"]

    @abstractmethod
    def map(self, run: Run) -> RunMapping:
        """Translate the Run into a RunMapping dictionary."""

    def __call__(self, run: Run) -> RunMapping:
        """Validate that the run produced output, then delegate to :meth:`map`."""
        if run.outputs:
            return self.map(run)
        raise ValueError(f"Run {run.id} has no outputs to evaluate.")
class MessageExampleMapper(Serializable):
    """Map an example, or row in the dataset, to the inputs of an evaluation."""

    # Key in example.outputs holding the reference message. If None, the
    # example must have exactly one output, which is used as the reference.
    reference_key: Optional[str] = None

    @property
    def output_keys(self) -> List[str]:
        """The keys to extract from the example."""
        return ["reference"]

    @staticmethod
    def _to_message(output: Any) -> BaseMessage:
        """Coerce a raw output (message object or serialized dict) to a BaseMessage."""
        if isinstance(output, BaseMessage):
            return output
        return messages_from_dict([output])[0]

    def map(self, example: Example) -> ExampleMapping:
        """Maps the Example, or dataset row, to a dictionary.

        Raises:
            ValueError: if the example has no outputs, has multiple outputs but
                no ``reference_key`` was given, or ``reference_key`` is missing.
        """
        if not example.outputs:
            raise ValueError(
                f"Example {example.id} has no outputs to use as a reference."
            )
        if self.reference_key is None:
            if len(example.outputs) > 1:
                raise ValueError(
                    f"Example {example.id} has multiple outputs, so you must"
                    " specify a reference_key."
                )
            output = next(iter(example.outputs.values()))
        elif self.reference_key not in example.outputs:
            raise ValueError(
                f"Example {example.id} does not have reference key"
                f" {self.reference_key}."
            )
        else:
            output = example.outputs[self.reference_key]
        return {"reference": self._to_message(output)}

    def __call__(self, example: Example) -> ExampleMapping:
        """Validate that the example has outputs, then delegate to :meth:`map`."""
        if not example.outputs:
            # Fixed typo in the original message: "areference" -> "a reference".
            raise ValueError(
                f"Example {example.id} has no outputs to use as a reference label."
            )
        return self.map(example)
class ChatModelMessageRunMapper(MessageRunMapper):
    """Extract items to evaluate from a chat-model (llm) run object."""

    @staticmethod
    def extract_inputs(inputs: Dict) -> List[BaseMessage]:
        """Return the chat history from a chat-model run's inputs.

        Accepts three shapes of ``inputs["messages"]``:
        a list of BaseMessage (in-memory run), a list of serialized message
        dicts (API run), or a list of lists of dicts (Tracer run).

        Raises:
            ValueError: if no messages are present or the shape is unrecognized.
        """
        if not inputs.get("messages"):
            raise ValueError("Run must have messages as inputs.")
        messages = inputs["messages"]
        if isinstance(messages, list) and messages:
            if isinstance(messages[0], BaseMessage):
                # Already deserialized; the original passed these through
                # messages_from_dict, which expects dicts, and would fail.
                return messages
            if isinstance(messages[0], dict):
                return messages_from_dict(messages)
            if isinstance(messages[0], list):
                # Runs from Tracer have messages as a list of lists of dicts
                return messages_from_dict(messages[0])
        raise ValueError(f"Could not extract messages from inputs: {inputs}")

    @staticmethod
    def extract_outputs(outputs: Dict) -> BaseMessage:
        """Return the predicted message from a chat-model run's outputs.

        Raises:
            ValueError: if generations are missing or carry no message.
        """
        if not outputs.get("generations"):
            raise ValueError("LLM Run must have generations as outputs.")
        first_generation: Dict = outputs["generations"][0]
        if isinstance(first_generation, list):
            # Runs from Tracer have generations as a list of lists of dicts
            # Whereas Runs from the API have a list of dicts
            first_generation = first_generation[0]
        if "message" in first_generation:
            return messages_from_dict([first_generation["message"]])[0]
        # The original implicitly returned None here; fail loudly instead so
        # map() surfaces a parse error rather than propagating a None prediction.
        raise ValueError(
            f"Could not extract message from generation: {first_generation}"
        )

    def map(self, run: Run) -> RunMapping:
        """Maps the Run to a dictionary with 'input' and 'prediction' messages."""
        if run.run_type != "llm":
            raise ValueError("ChatModel RunMapper only supports LangSmith runs of type llm.")
        if not run.outputs:
            if run.error:
                raise ValueError(
                    f"Cannot evaluate errored LLM run {run.id}: {run.error}"
                )
            raise ValueError(
                f"Run {run.id} has no outputs. Cannot evaluate this run."
            )
        try:
            inputs = self.extract_inputs(run.inputs)
        except Exception as e:
            raise ValueError(
                f"Could not parse LM input from run inputs {run.inputs}"
            ) from e
        try:
            output_ = self.extract_outputs(run.outputs)
        except Exception as e:
            raise ValueError(
                f"Could not parse LM prediction from run outputs {run.outputs}"
            ) from e
        return {"input": inputs, "prediction": output_}

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Protocol
from langchainplus_sdk import EvaluationResult, RunEvaluator
from langchainplus_sdk.schemas import Example, Run
@@ -13,7 +13,7 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.evaluation.schema import StringEvaluator
from langchain.evaluation.schema import StringEvaluator, MessageEvaluator
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, get_buffer_string, messages_from_dict
from langchain.tools.base import Tool
@@ -28,10 +28,10 @@ class StringRunMapper(Serializable):
return ["prediction", "input"]
@abstractmethod
def map(self, run: Run) -> Dict[str, Any]:
def map(self, run: Run) -> Dict[str, str]:
"""Maps the Run to a dictionary."""
def __call__(self, run: Run) -> Dict[str, Any]:
def __call__(self, run: Run) -> Dict[str, str]:
"""Maps the Run to a dictionary."""
if not run.outputs:
raise ValueError(f"Run {run.id} has no outputs to evaluate.")
@@ -71,7 +71,7 @@ class LLMStringRunMapper(StringRunMapper):
output_ = first_generation["text"]
return output_
def map(self, run: Run) -> Dict[str, Any]:
def map(self, run: Run) -> Dict[str, str]:
"""Maps the Run to a dictionary."""
if run.run_type != "llm":
raise ValueError("LLM RunMapper only supports LLM runs.")
@@ -104,9 +104,9 @@ class ChainStringRunMapper(StringRunMapper):
"""Extract items to evaluate from the run object from a chain."""
input_key: str
"""The key from the model Run's inputs to use as the eval input."""
"""The key from the chain Run's inputs to use as the eval input."""
prediction_key: str
"""The key from the model Run's outputs to use as the eval prediction."""
"""The key from the chain Run's outputs to use as the eval prediction."""
@classmethod
def from_chain(
@@ -150,7 +150,7 @@ class ChainStringRunMapper(StringRunMapper):
raise ValueError(f"Chain {model.lc_namespace} has no input or output keys.")
return cls(input_key=input_key, prediction_key=prediction_key)
def map(self, run: Run) -> Dict[str, Any]:
def map(self, run: Run) -> Dict[str, str]:
"""Maps the Run to a dictionary."""
if not run.outputs:
raise ValueError(f"Run {run.id} has no outputs to evaluate.")
@@ -172,7 +172,7 @@ class ChainStringRunMapper(StringRunMapper):
class ToolStringRunMapper(StringRunMapper):
"""Map an input to the tool."""
def map(self, run: Run) -> Dict[str, Any]:
def map(self, run: Run) -> Dict[str, str]:
if not run.outputs:
raise ValueError(f"Run {run.id} has no outputs to evaluate.")
return {"input": run.inputs["input"], "prediction": run.outputs["output"]}
@@ -193,7 +193,7 @@ class StringExampleMapper(Serializable):
chat_messages = messages_from_dict(messages)
return get_buffer_string(chat_messages)
def map(self, example: Example) -> Dict[str, Any]:
def map(self, example: Example) -> Dict[str, str]:
"""Maps the Example, or dataset row to a dictionary."""
if not example.outputs:
raise ValueError(
@@ -228,17 +228,26 @@ class StringExampleMapper(Serializable):
return self.map(example)
class StringRunEvaluatorChain(Chain, RunEvaluator):
# TODO(agola11) can make these abstract classes
class BaseRunMapper(Protocol):
def map(self, run: Run) -> Dict[str, Any]: ...
class BaseExampleMapper(Protocol):
def map(self, example: Example) -> Dict[str, Any]: ...
class SimpleRunEvaluatorChain(Chain, RunEvaluator):
"""Evaluate Run and optional examples."""
run_mapper: StringRunMapper
run_mapper: BaseRunMapper
"""Maps the Run to a dictionary with 'input' and 'prediction' strings."""
example_mapper: Optional[StringExampleMapper] = None
example_mapper: Optional[BaseExampleMapper] = None
"""Maps the Example (dataset row) to a dictionary
with a 'reference' string."""
name: str
"""The name of the evaluation metric."""
string_evaluator: StringEvaluator
evaluator: Union[StringEvaluator, MessageEvaluator]
"""The evaluation chain."""
@property
@@ -252,7 +261,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
def _prepare_input(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
run: Run = inputs["run"]
example: Optional[Example] = inputs.get("example")
evaluate_strings_inputs = self.run_mapper(run)
evaluate_inputs = self.run_mapper.map(run)
if self.example_mapper:
if not example:
raise ValueError(
@@ -260,8 +269,8 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
" example from the dataset,"
f" but none was provided for run {run.id}."
)
evaluate_strings_inputs.update(self.example_mapper(example))
return evaluate_strings_inputs
evaluate_inputs.update(self.example_mapper.map(example))
return evaluate_inputs
def _prepare_output(self, output: Dict[str, Any]) -> EvaluationResult:
evaluation_result = EvaluationResult(key=self.name, **output)
@@ -275,14 +284,23 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Call the evaluation chain."""
evaluate_strings_inputs = self._prepare_input(inputs)
evaluate_inputs = self._prepare_input(inputs)
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
chain_output = self.string_evaluator.evaluate_strings(
**evaluate_strings_inputs,
callbacks=callbacks,
)
if isinstance(self.evaluator, StringEvaluator):
chain_output = self.evaluator.evaluate_strings(
**evaluate_inputs,
callbacks=callbacks,
)
elif isinstance(self.evaluator, MessageEvaluator):
chain_output = self.evaluator.evaluate_messages(
**evaluate_inputs,
callbacks=callbacks,
)
else:
raise ValueError("Unsupported evaluator type")
evaluation_result = self._prepare_output(chain_output)
return {"feedback": evaluation_result}
@@ -291,14 +309,23 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
inputs: Dict[str, Any],
run_manager: AsyncCallbackManagerForChainRun | None = None,
) -> Dict[str, Any]:
"""Call the evaluation chain."""
evaluate_strings_inputs = self._prepare_input(inputs)
evaluate_inputs = self._prepare_input(inputs)
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
chain_output = await self.string_evaluator.aevaluate_strings(
**evaluate_strings_inputs,
callbacks=callbacks,
)
if isinstance(self.evaluator, StringEvaluator):
chain_output = await self.evaluator.aevaluate_strings(
**evaluate_inputs,
callbacks=callbacks,
)
elif isinstance(self.evaluator, MessageEvaluator):
chain_output = await self.evaluator.aevaluate_messages(
**evaluate_inputs,
callbacks=callbacks,
)
else:
raise ValueError("Unsupported evaluator type")
evaluation_result = self._prepare_output(chain_output)
return {"feedback": evaluation_result}
@@ -315,43 +342,51 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
result = await self.acall({"run": run, "example": example})
return result["feedback"]
# TODO: Add ability to load message evaluators
@classmethod
def from_model_and_evaluator(
cls,
model: Union[Chain, BaseLanguageModel, Tool],
evaluator: StringEvaluator,
evaluator: Union[StringEvaluator, MessageEvaluator],
input_key: Optional[str] = None,
prediction_key: Optional[str] = None,
reference_key: Optional[str] = None,
) -> StringRunEvaluatorChain:
) -> SimpleRunEvaluatorChain:
"""Create a StringRunEvaluatorChain from a model and evaluator."""
if isinstance(model, BaseLanguageModel):
run_mapper: StringRunMapper = LLMStringRunMapper()
elif isinstance(model, Chain):
run_mapper = ChainStringRunMapper.from_chain(
model, input_key=input_key, prediction_key=prediction_key
)
elif isinstance(model, Tool):
run_mapper = ToolStringRunMapper()
if isinstance(evaluator, StringEvaluator):
if isinstance(model, BaseLanguageModel):
run_mapper: StringRunMapper = LLMStringRunMapper()
elif isinstance(model, Chain):
run_mapper = ChainStringRunMapper.from_chain(
model, input_key=input_key, prediction_key=prediction_key
)
elif isinstance(model, Tool):
run_mapper = ToolStringRunMapper()
else:
raise NotImplementedError(
f"{cls.__name__}.from_model_and_evaluator({type(model)})"
" not yet implemented."
"Expected one of [BaseLanguageModel, Chain, Tool]."
)
if reference_key is not None or isinstance(model, BaseLanguageModel):
example_mapper = StringExampleMapper(reference_key=reference_key)
elif evaluator.requires_reference:
raise ValueError(
f"Evaluator {evaluator.evaluation_name} requires a reference"
" example from the dataset. Please specify the reference key from"
" amongst the dataset outputs keys."
)
else:
example_mapper = None
elif isinstance(evaluator, MessageEvaluator):
raise NotImplementedError()
else:
raise NotImplementedError(
f"{cls.__name__}.from_model_and_evaluator({type(model)})"
" not yet implemented."
"Expected one of [BaseLanguageModel, Chain, Tool]."
)
if reference_key is not None or isinstance(model, BaseLanguageModel):
example_mapper = StringExampleMapper(reference_key=reference_key)
elif evaluator.requires_reference:
raise ValueError(
f"Evaluator {evaluator.evaluation_name} requires a reference"
" example from the dataset. Please specify the reference key from"
" amongst the dataset outputs keys."
)
else:
example_mapper = None
raise NotImplementedError()
return cls(
name=evaluator.evaluation_name,
run_mapper=run_mapper,
example_mapper=example_mapper,
string_evaluator=evaluator,
)
evaluator=evaluator,
)

View File

@@ -3,10 +3,11 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Optional
from typing import Any, Optional, List
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.schema import BaseMessage, get_buffer_string
class EvaluatorType(str, Enum):
@@ -105,6 +106,113 @@ class StringEvaluator(ABC):
)
class MessageEvaluator(ABC):
    """Protocol for evaluating messages."""

    @property
    def evaluation_name(self) -> str:
        # Subclasses override with the metric name reported in feedback.
        raise NotImplementedError()

    @property
    def requires_reference(self) -> bool:
        # Whether the evaluator needs a reference message from the dataset.
        # Defaults to False; override in subclasses that compare to a label.
        return False

    @abstractmethod
    def evaluate_messages(
        self,
        *,
        prediction: BaseMessage,
        reference: Optional[BaseMessage] = None,
        input: Optional[List[BaseMessage]] = None,
        **kwargs: Any,
    ) -> dict:
        """Evaluate Chain or LLM output, based on optional input and label.

        Args:
            prediction (BaseMessage): the prediction to evaluate.
            reference (Optional[BaseMessage], optional): the reference label
                to evaluate against.
            input (Optional[List[BaseMessage]], optional): the input to consider during
                evaluation
            **kwargs: additional keyword arguments, including callbacks, tags, etc.
        Returns:
            dict: The evaluation results containing the score or value.
                It is recommended that the dictionary contain the following keys:
                    - score: the score of the evaluation, if applicable.
                    - value: the string value of the evaluation, if applicable.
                    - reasoning: the reasoning for the evaluation, if applicable.
        """

    async def aevaluate_messages(
        self,
        *,
        prediction: BaseMessage,
        reference: Optional[BaseMessage] = None,
        input: Optional[List[BaseMessage]] = None,
        **kwargs: Any,
    ) -> dict:
        """Asynchronously evaluate Chain or LLM output, based on optional
        input and label.

        Args:
            prediction (BaseMessage): the prediction to evaluate.
            reference (Optional[BaseMessage], optional): the reference label
                to evaluate against.
            input (Optional[List[BaseMessage]], optional): the input to consider during
                evaluation
            **kwargs: additional keyword arguments, including callbacks, tags, etc.
        Returns:
            dict: The evaluation results containing the score or value.
                It is recommended that the dictionary contain the following keys:
                    - score: the score of the evaluation, if applicable.
                    - value: the string value of the evaluation, if applicable.
                    - reasoning: the reasoning for the evaluation, if applicable.
        """
        # Not abstract: sync-only evaluators may skip the async variant.
        raise NotImplementedError(
            f"{self.__class__.__name__} hasn't implemented an "
            "async aevaluate_messages method."
        )
# TODO(agola11): move this out of schema
class SimpleMessageEvaluator(MessageEvaluator):
    """Simple implementation of MessageEvaluator that delegates to a StringEvaluator.

    Messages are flattened to strings with ``get_buffer_string`` before being
    handed to the wrapped evaluator.
    """

    def __init__(self, string_evaluator: StringEvaluator):
        """Wrap ``string_evaluator`` so it can evaluate chat messages."""
        self.string_evaluator = string_evaluator

    @property
    def evaluation_name(self) -> str:
        # Delegate so callers reading evaluator.evaluation_name (e.g. when
        # naming feedback) don't hit the base class's NotImplementedError.
        return self.string_evaluator.evaluation_name

    @property
    def requires_reference(self) -> bool:
        # Delegate so a reference-requiring string evaluator isn't silently
        # reported as requires_reference=False by this wrapper.
        return self.string_evaluator.requires_reference

    def evaluate_messages(
        self,
        *,
        prediction: BaseMessage,
        reference: Optional[BaseMessage] = None,
        input: Optional[List[BaseMessage]] = None,
        **kwargs: Any,
    ) -> dict:
        """Render the messages to strings and delegate to the wrapped evaluator."""
        return self.string_evaluator.evaluate_strings(
            prediction=get_buffer_string([prediction]),
            reference=get_buffer_string([reference]) if reference else None,
            input=get_buffer_string(input) if input else None,
            **kwargs,
        )

    async def aevaluate_messages(
        self,
        *,
        prediction: BaseMessage,
        reference: Optional[BaseMessage] = None,
        input: Optional[List[BaseMessage]] = None,
        **kwargs: Any,
    ) -> dict:
        """Async counterpart of :meth:`evaluate_messages`."""
        return await self.string_evaluator.aevaluate_strings(
            prediction=get_buffer_string([prediction]),
            reference=get_buffer_string([reference]) if reference else None,
            input=get_buffer_string(input) if input else None,
            **kwargs,
        )
class PairwiseStringEvaluator(ABC):
"""A protocol for comparing the output of two models."""