Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
1b67b432b2 assert output keys 2023-06-30 06:12:09 -07:00
2 changed files with 34 additions and 0 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional
from langchainplus_sdk import EvaluationResult, RunEvaluator
from langchainplus_sdk.schemas import Example, Run
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@@ -17,6 +18,13 @@ from langchain.schema import RUN_KEY, BaseOutputParser
class RunEvaluatorInputMapper:
"""Map the inputs of a run to the inputs of an evaluation."""
@property
def output_keys(self) -> List[str]:
    """Names of the keys this input mapper emits.

    The base implementation only signals that a subclass has to
    provide its own version.

    Raises:
        NotImplementedError: Always.
    """
    owner = self.__class__.__name__
    message = f"{owner} must implement output_keys"
    raise NotImplementedError(message)
@abstractmethod
def map(self, run: Run, example: Optional[Example] = None) -> Dict[str, Any]:
    """Map the Run and optional Example to a dictionary.

    Abstract hook: this definition carries no implementation.

    Args:
        run: The run to map.
        example: Optional example accompanying the run; defaults to None.

    Returns:
        A dictionary with string keys, per the return annotation.
    """
@@ -43,12 +51,25 @@ class RunEvaluatorChain(Chain, RunEvaluator):
output_parser: RunEvaluatorOutputParser
"""Parse the output of the eval chain into feedback."""
@root_validator
def validate_mapper_keys(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Validate that the input mapper and eval chain have matching keys.

    Args:
        values: The field values pydantic collected for this model.

    Returns:
        The unmodified ``values`` dict. A root validator must hand the
        values back; returning nothing would make pydantic continue
        with ``None`` in place of the model's fields.

    Raises:
        ValueError: If the set of the input mapper's ``output_keys``
            differs from the set of the eval chain's ``input_keys``.
    """
    input_mapper: Optional[RunEvaluatorInputMapper] = values.get("input_mapper")
    eval_chain: Optional[Chain] = values.get("eval_chain")
    if input_mapper is None or eval_chain is None:
        # A field that failed its own validation is absent from ``values``;
        # skip the cross-check so pydantic reports that error rather than
        # this validator raising an uncaught KeyError.
        return values

    mapper_keys = input_mapper.output_keys
    chain_keys = eval_chain.input_keys
    # Compare as sets: order and duplicates are irrelevant to the match.
    if set(mapper_keys) != set(chain_keys):
        raise ValueError(
            f"Input mapper output_keys ({mapper_keys}) "
            f"must match eval_chain input_keys ({chain_keys})"
        )
    return values
@property
def input_keys(self) -> List[str]:
    """Names of the inputs the evaluation chain accepts."""
    expected: List[str] = ["run", "example"]
    return expected
@property
def output_keys(self) -> List[str]:
    """Names of the outputs the evaluation chain produces."""
    produced: List[str] = ["feedback"]
    return produced
def _call(

View File

@@ -44,6 +44,15 @@ class StringRunEvaluatorInputMapper(RunEvaluatorInputMapper, BaseModel):
answer_map: Optional[Dict[str, str]] = None
"""Map from example outputs to the evaluation inputs."""
@property
def output_keys(self) -> List[str]:
    """Names of the keys this input mapper emits.

    These are the target names (the dict values) of ``prediction_map``,
    then ``input_map``, then ``answer_map`` when it is set and non-empty.
    """
    keys: List[str] = list(self.prediction_map.values())
    keys.extend(self.input_map.values())
    if self.answer_map:
        keys.extend(self.answer_map.values())
    return keys
def map(self, run: Run, example: Optional[Example] = None) -> Dict[str, Any]:
"""Maps the Run and Optional[Example] to a dictionary"""
if run.outputs is None and self.prediction_map:
@@ -237,6 +246,10 @@ class TrajectoryInputMapper(RunEvaluatorInputMapper, BaseModel):
tool_output_key: str = "output"
"""The key to load from the tool executor's run output dictionary."""
@property
def output_keys(self) -> List[str]:
    """Names of the keys this input mapper emits."""
    emitted: List[str] = [
        "tool_descriptions",
        "question",
        "agent_trajectory",
        "answer",
    ]
    return emitted
def map(self, run: Run, example: Optional[Example] = None) -> Dict[str, str]:
"""Maps the Run and Optional[Example] to a dictionary"""
if run.child_runs is None: