community[patch]: Update UpTrain Callback Handler to support the new UpTrain evaluation schema (#21656)

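UpTrain now has a new dashboard that makes it easier to view projects
and evaluations. Using it requires specifying both `project_name` and
`evaluation_name` when performing evaluations. This change updates the
UpTrain callback handler to support that: the `project_name_prefix`
argument is replaced by `project_name`, and each evaluation is sent
with an explicit `evaluation_name`.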
This commit is contained in:
Dhruv Chawla
2024-05-21 05:36:00 +05:30
committed by GitHub
parent c0e3c3a350
commit d4359d3de6
2 changed files with 96 additions and 68 deletions

View File

@@ -83,10 +83,10 @@ class UpTrainDataSchema:
"""The UpTrain data schema for tracking evaluation results.
Args:
project_name_prefix (str): Prefix for the project name.
project_name (str): The project name to be shown in UpTrain dashboard.
Attributes:
project_name_prefix (str): Prefix for the project name.
project_name (str): The project name to be shown in UpTrain dashboard.
uptrain_results (DefaultDict[str, Any]): Dictionary to store evaluation results.
eval_types (Set[str]): Set to store the types of evaluations.
query (str): Query for the RAG evaluation.
@@ -101,10 +101,10 @@ class UpTrainDataSchema:
"""
def __init__(self, project_name_prefix: str) -> None:
def __init__(self, project_name: str) -> None:
"""Initialize the UpTrain data schema."""
# For tracking project name and results
self.project_name_prefix: str = project_name_prefix
self.project_name: str = project_name
self.uptrain_results: DefaultDict[str, Any] = defaultdict(list)
# For tracking event types
@@ -130,7 +130,7 @@ class UpTrainCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs evaluation results to uptrain and the console.
Args:
project_name_prefix (str): Prefix for the project name.
project_name (str): The project name to be shown in UpTrain dashboard.
key_type (str): Type of key to use. Must be 'uptrain' or 'openai'.
api_key (str): API key for the UpTrain or OpenAI API.
(This key is required to perform evaluations using GPT.)
@@ -144,7 +144,7 @@ class UpTrainCallbackHandler(BaseCallbackHandler):
def __init__(
self,
*,
project_name_prefix: str = "langchain",
project_name: str = "langchain",
key_type: str = "openai",
api_key: str = "sk-****************", # The API key to use for evaluation
model: str = "gpt-3.5-turbo", # The model to use for evaluation
@@ -158,7 +158,7 @@ class UpTrainCallbackHandler(BaseCallbackHandler):
self.log_results = log_results
# Set uptrain variables
self.schema = UpTrainDataSchema(project_name_prefix=project_name_prefix)
self.schema = UpTrainDataSchema(project_name=project_name)
self.first_score_printed_flag = False
if key_type == "uptrain":
@@ -166,7 +166,7 @@ class UpTrainCallbackHandler(BaseCallbackHandler):
self.uptrain_client = uptrain.APIClient(settings=settings)
elif key_type == "openai":
settings = uptrain.Settings(
openai_api_key=api_key, evaluate_locally=False, model=model
openai_api_key=api_key, evaluate_locally=True, model=model
)
self.uptrain_client = uptrain.EvalLLM(settings=settings)
else:
@@ -174,23 +174,26 @@ class UpTrainCallbackHandler(BaseCallbackHandler):
def uptrain_evaluate(
self,
project_name: str,
evaluation_name: str,
data: List[Dict[str, Any]],
checks: List[str],
) -> None:
"""Run an evaluation on the UpTrain server using UpTrain client."""
if self.uptrain_client.__class__.__name__ == "APIClient":
uptrain_result = self.uptrain_client.log_and_evaluate(
project_name=project_name,
project_name=self.schema.project_name,
evaluation_name=evaluation_name,
data=data,
checks=checks,
)
else:
uptrain_result = self.uptrain_client.evaluate(
project_name=self.schema.project_name,
evaluation_name=evaluation_name,
data=data,
checks=checks,
)
self.schema.uptrain_results[project_name].append(uptrain_result)
self.schema.uptrain_results[self.schema.project_name].append(uptrain_result)
score_name_map = {
"score_context_relevance": "Context Relevance Score",
@@ -258,7 +261,7 @@ class UpTrainCallbackHandler(BaseCallbackHandler):
]
self.uptrain_evaluate(
project_name=f"{self.schema.project_name_prefix}_rag",
evaluation_name="rag",
data=data,
checks=[
uptrain.Evals.CONTEXT_RELEVANCE,
@@ -340,7 +343,7 @@ class UpTrainCallbackHandler(BaseCallbackHandler):
]
self.uptrain_evaluate(
project_name=f"{self.schema.project_name_prefix}_multi_query",
evaluation_name="multi_query",
data=data,
checks=[uptrain.Evals.MULTI_QUERY_ACCURACY],
)
@@ -372,7 +375,7 @@ class UpTrainCallbackHandler(BaseCallbackHandler):
}
]
self.uptrain_evaluate(
project_name=f"{self.schema.project_name_prefix}_context_reranking",
evaluation_name="context_reranking",
data=data,
checks=[
uptrain.Evals.CONTEXT_CONCISENESS,