mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 21:11:43 +00:00
Add single run eval loader (#7390)
Plus: - Add an evaluation name so that the string and embedding validators work with the run evaluator loader. - Remove the unused root validator.
This commit is contained in:
@@ -3,7 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field, root_validator
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
@@ -61,19 +61,6 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
"""
|
||||
return ["score"]
|
||||
|
||||
@root_validator
|
||||
def _validate_distance_metric(cls, values: dict) -> dict:
|
||||
"""Validate the distance metric.
|
||||
|
||||
Args:
|
||||
values (dict): The values to validate.
|
||||
|
||||
Returns:
|
||||
dict: The validated values.
|
||||
"""
|
||||
values["distance_metric"] = values["distance_metric"].lower()
|
||||
return values
|
||||
|
||||
def _get_metric(self, metric: EmbeddingDistance) -> Any:
|
||||
"""Get the metric function for the given metric name.
|
||||
|
||||
@@ -194,6 +181,10 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
def evaluation_name(self) -> str:
    """Return the evaluation name for this chain.

    Returns:
        str: ``embedding_<metric>_distance``, where ``<metric>`` is
            ``self.distance_metric.value``.
    """
    metric_name = self.distance_metric.value
    return f"embedding_{metric_name}_distance"
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys of the chain.
|
||||
@@ -219,9 +210,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
Dict[str, Any]: The computed score.
|
||||
"""
|
||||
vectors = np.array(
|
||||
self.embeddings.embed_documents(
|
||||
[inputs["prediction"], inputs["prediction_b"]]
|
||||
)
|
||||
self.embeddings.embed_documents([inputs["prediction"], inputs["reference"]])
|
||||
)
|
||||
score = self._compute_score(vectors)
|
||||
return {"score": score}
|
||||
@@ -242,7 +231,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
Dict[str, Any]: The computed score.
|
||||
"""
|
||||
embedded = await self.embeddings.aembed_documents(
|
||||
[inputs["prediction"], inputs["prediction_b"]]
|
||||
[inputs["prediction"], inputs["reference"]]
|
||||
)
|
||||
vectors = np.array(embedded)
|
||||
score = self._compute_score(vectors)
|
||||
@@ -324,6 +313,10 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
"""
|
||||
return ["prediction", "prediction_b"]
|
||||
|
||||
@property
def evaluation_name(self) -> str:
    """Return the evaluation name for this pairwise chain.

    Returns:
        str: ``pairwise_embedding_<metric>_distance``, where ``<metric>``
            is ``self.distance_metric.value``.
    """
    metric_name = self.distance_metric.value
    return f"pairwise_embedding_{metric_name}_distance"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
|
@@ -11,6 +11,10 @@ from langchain.evaluation.run_evaluators.implementations import (
|
||||
get_qa_evaluator,
|
||||
get_trajectory_evaluator,
|
||||
)
|
||||
from langchain.evaluation.run_evaluators.loading import (
|
||||
load_run_evaluator_for_model,
|
||||
load_run_evaluators_for_model,
|
||||
)
|
||||
from langchain.evaluation.run_evaluators.string_run_evaluator import (
|
||||
StringRunEvaluatorChain,
|
||||
)
|
||||
@@ -25,4 +29,6 @@ __all__ = [
|
||||
"StringRunEvaluatorInputMapper",
|
||||
"ChoicesOutputParser",
|
||||
"StringRunEvaluatorChain",
|
||||
"load_run_evaluators_for_model",
|
||||
"load_run_evaluator_for_model",
|
||||
]
|
||||
|
@@ -1,13 +1,11 @@
|
||||
""""Loading helpers for run evaluators."""
|
||||
|
||||
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
|
||||
from langchainplus_sdk import RunEvaluator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.evaluation.loading import load_evaluators
|
||||
from langchain.evaluation.loading import load_evaluator
|
||||
from langchain.evaluation.run_evaluators.string_run_evaluator import (
|
||||
StringRunEvaluatorChain,
|
||||
)
|
||||
@@ -15,6 +13,55 @@ from langchain.evaluation.schema import EvaluatorType, StringEvaluator
|
||||
from langchain.tools.base import Tool
|
||||
|
||||
|
||||
def load_run_evaluator_for_model(
    evaluator: EvaluatorType,
    model: Union[Chain, BaseLanguageModel, Tool],
    *,
    input_key: Optional[str] = None,
    prediction_key: Optional[str] = None,
    reference_key: Optional[str] = None,
    eval_llm: Optional[BaseLanguageModel] = None,
    **kwargs: Any,
) -> RunEvaluator:
    """Load a single run evaluator for the given evaluator type.

    Parameters
    ----------
    evaluator : EvaluatorType
        The evaluator type to load.
    model : Union[Chain, BaseLanguageModel, Tool]
        The model to evaluate. Used to infer how to parse the run.
    input_key : Optional[str]
        A chain run's input key to map to the evaluator's input.
    prediction_key : Optional[str]
        The key in the run's outputs to represent the Chain prediction.
    reference_key : Optional[str]
        The key in the dataset example (row) outputs to represent the
        reference, or ground-truth label.
    eval_llm : BaseLanguageModel, optional
        The language model to use for evaluation, if none is provided, a default
        ChatOpenAI gpt-4 model will be used.
    **kwargs : Any
        Additional keyword arguments forwarded to ``load_evaluator``.

    Returns
    -------
    RunEvaluator
        The loaded Run evaluator.

    Raises
    ------
    NotImplementedError
        If the loaded evaluator is not a ``StringEvaluator``; no other
        evaluator kind can be wrapped as a run evaluator here.
    """
    evaluator_ = load_evaluator(evaluator, llm=eval_llm, **kwargs)
    # Only string evaluators have a run-evaluator wrapper; fail fast otherwise.
    if not isinstance(evaluator_, StringEvaluator):
        raise NotImplementedError(f"Run evaluator for {evaluator} is not implemented")
    return StringRunEvaluatorChain.from_model_and_evaluator(
        model,
        evaluator_,
        input_key=input_key,
        prediction_key=prediction_key,
        reference_key=reference_key,
    )
|
||||
|
||||
|
||||
def load_run_evaluators_for_model(
|
||||
evaluators: Sequence[EvaluatorType],
|
||||
model: Union[Chain, BaseLanguageModel, Tool],
|
||||
@@ -23,6 +70,7 @@ def load_run_evaluators_for_model(
|
||||
prediction_key: Optional[str] = None,
|
||||
reference_key: Optional[str] = None,
|
||||
eval_llm: Optional[BaseLanguageModel] = None,
|
||||
config: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[RunEvaluator]:
|
||||
"""Load evaluators specified by a list of evaluator types.
|
||||
@@ -50,20 +98,18 @@ def load_run_evaluators_for_model(
|
||||
List[RunEvaluator]
|
||||
The loaded Run evaluators.
|
||||
"""
|
||||
evaluators_ = load_evaluators(evaluators, llm=eval_llm, **kwargs)
|
||||
run_evaluators = []
|
||||
for evaluator in evaluators_:
|
||||
if isinstance(evaluator, StringEvaluator):
|
||||
run_evaluator = StringRunEvaluatorChain.from_model_and_evaluator(
|
||||
model,
|
||||
for evaluator in evaluators:
|
||||
_kwargs = config.get(evaluator, {}) if config else {}
|
||||
run_evaluators.append(
|
||||
load_run_evaluator_for_model(
|
||||
evaluator,
|
||||
model,
|
||||
input_key=input_key,
|
||||
prediction_key=prediction_key,
|
||||
reference_key=reference_key,
|
||||
eval_llm=eval_llm,
|
||||
**{**kwargs, **_kwargs},
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Run evaluator for {evaluator} is not implemented"
|
||||
)
|
||||
run_evaluators.append(run_evaluator)
|
||||
)
|
||||
return run_evaluators
|
||||
|
@@ -141,6 +141,10 @@ class StringDistanceEvalChain(_RapidFuzzChainMixin, StringEvaluator):
|
||||
"""
|
||||
return ["reference", "prediction"]
|
||||
|
||||
@property
def evaluation_name(self) -> str:
    """Return the evaluation name for this chain.

    Returns:
        str: ``<distance>_distance``, where ``<distance>`` is
            ``self.distance.value``.
    """
    distance_name = self.distance.value
    return f"{distance_name}_distance"
|
||||
|
||||
@staticmethod
|
||||
def _get_metric(distance: str) -> Callable:
|
||||
"""
|
||||
@@ -275,6 +279,10 @@ class PairwiseStringDistanceEvalChain(_RapidFuzzChainMixin, PairwiseStringEvalua
|
||||
"""
|
||||
return ["prediction", "prediction_b"]
|
||||
|
||||
@property
def evaluation_name(self) -> str:
    """Return the evaluation name for this pairwise chain.

    Returns:
        str: ``pairwise_<distance>_distance``, where ``<distance>`` is
            ``self.distance.value``.
    """
    distance_name = self.distance.value
    return f"pairwise_{distance_name}_distance"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
|
Reference in New Issue
Block a user