fix: benchmark compare summary write to db

@@ -204,68 +204,47 @@ async def evaluation(
    )


@router.get("/benchmark/list_results", dependencies=[Depends(check_api_key)])
async def list_compare_runs(limit: int = 50, offset: int = 0):
@router.get("/benchmark/result/{serial_no}", dependencies=[Depends(check_api_key)])
async def get_compare_run_detail(serial_no: str, limit: int = 200, offset: int = 0):
    dao = BenchmarkResultDao()
    rows = dao.list_summaries(limit=limit, offset=offset)
    result = []
    for s in rows:
        result.append(
    summaries = dao.list_summaries_by_task(serial_no, limit=10000, offset=0)
    if not summaries:
        return Result.succ(
            {"serialNo": serial_no, "summaries": [], "metrics": {}, "cotTokens": {"total": 0, "byModel": {}}})

    detail_list = []
    total_counts = {"right": 0, "wrong": 0, "failed": 0, "exception": 0}
    round_ids = set()
    for s in summaries:
        r, w, f, e = s.right, s.wrong, s.failed, s.exception
        denom_exec = max(r + w + f + e, 1)
        accuracy = r / denom_exec
        exec_rate = (r + w) / denom_exec
        total_counts["right"] += r
        total_counts["wrong"] += w
        total_counts["failed"] += f
        total_counts["exception"] += e
        round_ids.add(s.round_id)
        detail_list.append(
            {
                "id": s.id,
                "roundId": s.round_id,
                "llmCode": getattr(s, "llm_code", None),
                "right": r,
                "wrong": w,
                "failed": f,
                "exception": e,
                "accuracy": accuracy,
                "execRate": exec_rate,
                "outputPath": s.output_path,
                "right": s.right,
                "wrong": s.wrong,
                "failed": s.failed,
                "exception": s.exception,
                "gmtCreated": s.gmt_created.isoformat() if s.gmt_created else None,
            }
        )
    return Result.succ(result)


@router.get("/benchmark/result/{summary_id}", dependencies=[Depends(check_api_key)])
async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int = 0):
    dao = BenchmarkResultDao()
    s = dao.get_summary_by_id(summary_id)
    if not s:
        raise HTTPException(status_code=404, detail="compare run not found")
    compares = dao.list_compare_by_round_and_path(
        s.round_id, s.output_path, limit=limit, offset=offset
    return Result.succ(
        {
            "serialNo": serial_no,
            "summaries": detail_list,
        }
    )
    detail = {
        "id": s.id,
        "roundId": s.round_id,
        "outputPath": s.output_path,
        "summary": {
            "right": s.right,
            "wrong": s.wrong,
            "failed": s.failed,
            "exception": s.exception,
        },
        "items": [
            {
                "id": r.id,
                "serialNo": r.serial_no,
                "analysisModelId": r.analysis_model_id,
                "question": r.question,
                "prompt": r.prompt,
                "standardAnswerSql": r.standard_answer_sql,
                "llmOutput": r.llm_output,
                "executeResult": json.loads(r.execute_result)
                if r.execute_result
                else None,
                "errorMsg": r.error_msg,
                "compareResult": r.compare_result,
                "isExecute": r.is_execute,
                "llmCount": r.llm_count,
                "gmtCreated": r.gmt_created.isoformat() if r.gmt_created else None,
            }
            for r in compares
        ],
    }
    return Result.succ(detail)
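
For reference, a minimal sketch (not part of the commit) of the metric arithmetic the reworked `/benchmark/result/{serial_no}` handler applies to each summary row; `summary_metrics` is a hypothetical helper name.

```python
# Hedged sketch (not part of the commit): the per-row metrics the serial_no
# endpoint derives from one BenchmarkSummaryEntity's counters.
def summary_metrics(right: int, wrong: int, failed: int, exception: int) -> dict:
    # Guard against an empty summary so the division never hits zero.
    denom = max(right + wrong + failed + exception, 1)
    return {
        "accuracy": right / denom,            # RIGHT over all cases
        "execRate": (right + wrong) / denom,  # cases that at least executed
    }

# Example: 7 RIGHT, 2 WRONG, 1 FAILED, 0 EXCEPTION
# -> {"accuracy": 0.7, "execRate": 0.9}
print(summary_metrics(7, 2, 1, 0))
```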

@router.post("/execute_benchmark_task", dependencies=[Depends(check_api_key)])
@@ -396,6 +375,32 @@ async def download_benchmark_result(
        raise HTTPException(status_code=404, detail=str(e))


@router.get("/benchmark/list_compare_tasks", dependencies=[Depends(check_api_key)])
async def list_benchmark_tasks(limit: int = 50, offset: int = 0):
    dao = BenchmarkResultDao()
    tasks = dao.list_tasks(limit=limit, offset=offset)
    result = []
    for task_id in tasks:
        summaries = dao.list_summaries_by_task(task_id, limit=10000, offset=0)
        result.append(
            {
                "serialNo": task_id,
                "summaries": [
                    {
                        "roundId": s.round_id,
                        "llmCode": getattr(s, "llm_code", None),
                        "right": s.right,
                        "wrong": s.wrong,
                        "failed": s.failed,
                        "exception": s.exception,
                        "outputPath": s.output_path,
                    }
                    for s in summaries
                ],
            }
        )
    return Result.succ(result)
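
A hedged usage sketch for the two new read endpoints. The host, port, router prefix, response envelope key and API-key scheme are assumptions, since none of them are visible in this diff.

```python
# Hedged usage sketch (not in the commit): query the new endpoints with httpx.
import httpx

BASE = "http://localhost:5670/api/v2/serve/evaluate"  # assumed prefix
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}     # assumed key scheme

with httpx.Client(headers=HEADERS, timeout=30) as client:
    # One entry per submitted task (task_serial_no), each with per-LLM summaries.
    tasks = client.get(f"{BASE}/benchmark/list_compare_tasks",
                       params={"limit": 10, "offset": 0}).json()
    for task in tasks.get("data", []):
        serial_no = task["serialNo"]
        # Drill into one task: per-round / per-llmCode counts plus accuracy.
        detail = client.get(f"{BASE}/benchmark/result/{serial_no}").json()
        print(serial_no, [s.get("llmCode") for s in detail["data"]["summaries"]])
```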

def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
    """Initialize the endpoints"""
    global global_system_app

@@ -13,6 +13,7 @@ from sqlalchemy import (
    Text,
    UniqueConstraint,
    desc,
    func,
)

from dbgpt.storage.metadata import BaseDao, Model
@@ -81,12 +82,12 @@ class BenchmarkCompareEntity(Model):
class BenchmarkSummaryEntity(Model):
    """Summary result for one round and one output path.

    Counts of RIGHT/WRONG/FAILED/EXCEPTION.
    Counts of RIGHT/WRONG/FAILED/EXCEPTION, per llm_code.
    """

    __tablename__ = "benchmark_summary"
    __table_args__ = (
        UniqueConstraint("round_id", "output_path", name="uk_round_output"),
        UniqueConstraint("round_id", "output_path", "llm_code", name="uk_round_output_llm"),
    )

    id = Column(
@@ -96,6 +97,8 @@ class BenchmarkSummaryEntity(Model):
    output_path = Column(
        String(512), nullable=False, comment="Original output file path"
    )
    task_serial_no = Column(String(255), nullable=True, comment="Task serial number (unique id per submitted task)")
    llm_code = Column(String(255), nullable=True, comment="LLM code for this summary")

    right = Column(Integer, default=0, comment="RIGHT count")
    wrong = Column(Integer, default=0, comment="WRONG count")
@@ -111,6 +114,7 @@ class BenchmarkSummaryEntity(Model):
    )

Index("idx_bm_sum_round", "round_id")
Index("idx_bm_sum_task", "task_serial_no")

class BenchmarkResultDao(BaseDao):
@@ -212,6 +216,53 @@ class BenchmarkResultDao(BaseDao):
            session.commit()
            return summary.id

    def upsert_summary(
        self,
        round_id: int,
        output_path: str,
        llm_code: Optional[str],
        right: int,
        wrong: int,
        failed: int,
        exception: int,
        task_serial_no: Optional[str] = None,
    ) -> int:
        """Upsert summary counts directly into DB (per llm_code), with task serial no."""
        with self.session() as session:
            existing = (
                session.query(BenchmarkSummaryEntity)
                .filter(
                    BenchmarkSummaryEntity.round_id == round_id,
                    BenchmarkSummaryEntity.output_path == output_path,
                    BenchmarkSummaryEntity.llm_code == llm_code,
                )
                .first()
            )
            if existing:
                existing.right = right
                existing.wrong = wrong
                existing.failed = failed
                existing.exception = exception
                if task_serial_no is not None:
                    existing.task_serial_no = task_serial_no
                existing.gmt_modified = datetime.now()
                session.commit()
                return existing.id
            else:
                summary = BenchmarkSummaryEntity(
                    round_id=round_id,
                    output_path=output_path,
                    llm_code=llm_code,
                    right=right,
                    wrong=wrong,
                    failed=failed,
                    exception=exception,
                    task_serial_no=task_serial_no,
                )
                session.add(summary)
                session.commit()
                return summary.id
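
The helper above reads then writes in two steps, so two workers summarizing the same round could both pass the existence check and one insert would then hit `uk_round_output_llm`. A hedged alternative (not what the commit does) pushes the upsert into the database so it is atomic; it assumes the metadata store is MySQL, and other dialects would need their own upsert construct.

```python
# Hedged alternative sketch: atomic upsert under the uk_round_output_llm key,
# assuming a MySQL-backed metadata store.
from sqlalchemy.dialects.mysql import insert as mysql_insert

def upsert_summary_atomic(session, round_id, output_path, llm_code,
                          right, wrong, failed, exception, task_serial_no=None):
    stmt = mysql_insert(BenchmarkSummaryEntity.__table__).values(
        round_id=round_id,
        output_path=output_path,
        llm_code=llm_code,
        right=right,
        wrong=wrong,
        failed=failed,
        exception=exception,
        task_serial_no=task_serial_no,
    )
    # On a duplicate (round_id, output_path, llm_code) key, overwrite the counts.
    stmt = stmt.on_duplicate_key_update(
        right=stmt.inserted.right,
        wrong=stmt.inserted.wrong,
        failed=stmt.inserted.failed,
        exception=stmt.inserted.exception,
        task_serial_no=stmt.inserted.task_serial_no,
    )
    session.execute(stmt)
    session.commit()
```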

    # Basic query helpers
    def list_compare_by_round(self, round_id: int, limit: int = 100, offset: int = 0):
        with self.session(commit=False) as session:
@@ -237,7 +288,33 @@ class BenchmarkResultDao(BaseDao):
                .first()
            )

    # New helpers for listing summaries and detail by id
    def list_summaries_by_round(self, round_id: int, limit: int = 100, offset: int = 0):
        with self.session(commit=False) as session:
            return (
                session.query(BenchmarkSummaryEntity)
                .filter(BenchmarkSummaryEntity.round_id == round_id)
                .order_by(desc(BenchmarkSummaryEntity.id))
                .limit(limit)
                .offset(offset)
                .all()
            )

    def list_rounds(self, limit: int = 100, offset: int = 0):
        with self.session(commit=False) as session:
            rows = (
                session.query(
                    BenchmarkSummaryEntity.round_id,
                    func.max(BenchmarkSummaryEntity.gmt_created).label("last_time"),
                )
                .group_by(BenchmarkSummaryEntity.round_id)
                .order_by(desc("last_time"))
                .limit(limit)
                .offset(offset)
                .all()
            )
            # return only round ids in order
            return [r[0] for r in rows]

    def list_summaries(self, limit: int = 100, offset: int = 0):
        with self.session(commit=False) as session:
            return (
@@ -256,17 +333,32 @@ class BenchmarkResultDao(BaseDao):
            .first()
        )

    def list_compare_by_round_and_path(
        self, round_id: int, output_path: str, limit: int = 200, offset: int = 0
    ):
    def list_tasks(self, limit: int = 100, offset: int = 0) -> List[str]:
        """List submitted task ids (task_serial_no), ordered by latest summary time."""
        with self.session(commit=False) as session:
            return (
                session.query(BenchmarkCompareEntity)
                .filter(
                    BenchmarkCompareEntity.round_id == round_id,
                    BenchmarkCompareEntity.output_path == output_path,
            rows = (
                session.query(
                    BenchmarkSummaryEntity.task_serial_no,
                    func.max(BenchmarkSummaryEntity.gmt_created).label("last_time"),
                )
                .order_by(desc(BenchmarkCompareEntity.id))
                .filter(BenchmarkSummaryEntity.task_serial_no.isnot(None))
                .group_by(BenchmarkSummaryEntity.task_serial_no)
                .order_by(desc("last_time"))
                .limit(limit)
                .offset(offset)
                .all()
            )
            return [r[0] for r in rows]

    def list_summaries_by_task(
        self, task_serial_no: str, limit: int = 1000, offset: int = 0
    ):
        """List summaries for a given task (may include multiple rounds)."""
        with self.session(commit=False) as session:
            return (
                session.query(BenchmarkSummaryEntity)
                .filter(BenchmarkSummaryEntity.task_serial_no == task_serial_no)
                .order_by(desc(BenchmarkSummaryEntity.id))
                .limit(limit)
                .offset(offset)
                .all()
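
A short hedged sketch chaining the new DAO helpers the same way the list endpoint does; it assumes a configured DB-GPT metadata database with some benchmark runs already written.

```python
# Hedged usage sketch: enumerate recent benchmark tasks and print each task's
# per-model counts, mirroring what /benchmark/list_compare_tasks returns.
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao

dao = BenchmarkResultDao()
for task_id in dao.list_tasks(limit=5):
    for s in dao.list_summaries_by_task(task_id, limit=100):
        print(task_id, s.round_id, getattr(s, "llm_code", None),
              s.right, s.wrong, s.failed, s.exception)
```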

@@ -78,57 +78,54 @@ class FileParseService(ABC):
    def summary_and_write_multi_round_benchmark_result(
        self, output_path: str, round_id: int
    ) -> str:
        """Compute summary from the Excel file and return JSON string.
        """Compute summary from the Excel file grouped by llmCode and return JSON list.

        It will read the '<base>_round{round_id}.xlsx' file and sheet
        'benchmark_compare_result', then count the compareResult column
        (RIGHT/WRONG/FAILED/EXCEPTION) to build summary.
        It reads the '<base>_round{round_id}.xlsx' file and sheet
        'benchmark_compare_result', then for each llmCode counts the compareResult column
        (RIGHT/WRONG/FAILED/EXCEPTION) to build summary list.
        """
        try:
            base_name = Path(output_path).stem
            extension = Path(output_path).suffix
            if extension.lower() not in [".xlsx", ".xls"]:
                extension = ".xlsx"
            excel_file = (
                Path(output_path).parent / f"{base_name}_round{round_id}{extension}"
            )
            excel_file = Path(output_path).parent / f"{base_name}_round{round_id}{extension}"
            if not excel_file.exists():
                logger.warning(f"summary excel not found: {excel_file}")
                result = dict(right=0, wrong=0, failed=0, exception=0)
                return json.dumps(result, ensure_ascii=False)
                return json.dumps([], ensure_ascii=False)

            df = pd.read_excel(str(excel_file), sheet_name="benchmark_compare_result")
            right = (
                int((df["compareResult"] == "RIGHT").sum())
                if "compareResult" in df.columns
                else 0
            )
            wrong = (
                int((df["compareResult"] == "WRONG").sum())
                if "compareResult" in df.columns
                else 0
            )
            failed = (
                int((df["compareResult"] == "FAILED").sum())
                if "compareResult" in df.columns
                else 0
            )
            exception = (
                int((df["compareResult"] == "EXCEPTION").sum())
                if "compareResult" in df.columns
                else 0
            )
            if "compareResult" not in df.columns:
                logger.warning("compareResult column missing in excel")
                return json.dumps([], ensure_ascii=False)

            # ensure llmCode column exists
            if "llmCode" not in df.columns:
                df["llmCode"] = None

            summaries = []
            for llm_code, group in df.groupby("llmCode"):
                right = int((group["compareResult"] == "RIGHT").sum())
                wrong = int((group["compareResult"] == "WRONG").sum())
                failed = int((group["compareResult"] == "FAILED").sum())
                exception = int((group["compareResult"] == "EXCEPTION").sum())
                summaries.append(
                    {
                        "llmCode": None if pd.isna(llm_code) else str(llm_code),
                        "right": right,
                        "wrong": wrong,
                        "failed": failed,
                        "exception": exception,
                    }
                )

            result = dict(right=right, wrong=wrong, failed=failed, exception=exception)
            logger.info(
                f"[summary] summary computed from Excel for round={round_id},"
                f" output_path={output_path} -> {result}"
                f"[summary] computed per llmCode for round={round_id}, output_path={output_path} -> {summaries}"
            )
            return json.dumps(result, ensure_ascii=False)
            return json.dumps(summaries, ensure_ascii=False)
        except Exception as e:
            logger.error(f"summary compute error from excel: {e}", exc_info=True)
            result = dict(right=0, wrong=0, failed=0, exception=0)
            return json.dumps(result, ensure_ascii=False)
            return json.dumps([], ensure_ascii=False)
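
A self-contained, hedged sketch of the per-llmCode counting above, run on an in-memory frame instead of the round Excel file. One caveat worth knowing: pandas drops NaN group keys by default, so rows whose llmCode is missing are only grouped under None when dropna=False is passed (pandas >= 1.1).

```python
# Hedged sketch: the same RIGHT/WRONG/FAILED/EXCEPTION counting per llmCode,
# with dropna=False so rows lacking an llmCode still get counted.
import json
import pandas as pd

df = pd.DataFrame(
    {
        "llmCode": ["gpt-4o", "gpt-4o", "qwen-max", None],
        "compareResult": ["RIGHT", "WRONG", "RIGHT", "EXCEPTION"],
    }
)

summaries = []
for llm_code, group in df.groupby("llmCode", dropna=False):
    counts = group["compareResult"].value_counts()
    summaries.append(
        {
            "llmCode": None if pd.isna(llm_code) else str(llm_code),
            "right": int(counts.get("RIGHT", 0)),
            "wrong": int(counts.get("WRONG", 0)),
            "failed": int(counts.get("FAILED", 0)),
            "exception": int(counts.get("EXCEPTION", 0)),
        }
    )

# One entry per model; the llmCode=None row keeps its EXCEPTION count.
print(json.dumps(summaries, ensure_ascii=False))
```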

    def get_input_stream(self, location: str):
        """Get input stream from location

@@ -6,6 +6,7 @@ from dbgpt.util.benchmarks import StorageUtil
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    get_benchmark_manager,
)
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao

from .data_compare_service import DataCompareService
from .file_parse_service import FileParseService
@@ -197,6 +198,26 @@ class UserInputExecuteService:
            config.benchmark_mode_type == BenchmarkModeTypeEnum.EXECUTE,
            llm_count,
        )
        try:
            summary_json = self.file_service.summary_and_write_multi_round_benchmark_result(
                location, round_id
            )
            import json as _json

            results = _json.loads(summary_json) if summary_json else []
            dao = BenchmarkResultDao()
            for item in results:
                llm_code = item.get("llmCode")
                right = int(item.get("right", 0))
                wrong = int(item.get("wrong", 0))
                failed = int(item.get("failed", 0))
                exception = int(item.get("exception", 0))
                dao.upsert_summary(round_id, location, llm_code, right, wrong, failed, exception, task_serial_no=config.evaluate_code)
        except Exception as e:
            logger.error(
                f"[execute_llm_compare_result] summary from excel or write db failed: {e}",
                exc_info=True,
            )
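
A hedged illustration of the contract between the two halves of the fix: the file parser returns a JSON list with one object per llmCode, and the service turns each entry into one `upsert_summary` row tagged with the task's `evaluate_code`, which is what later surfaces as `task_serial_no` in the list and detail endpoints. The values below are made up.

```python
# Hedged illustration of the JSON that crosses the parser -> service boundary.
import json

summary_json = json.dumps(
    [
        {"llmCode": "gpt-4o", "right": 8, "wrong": 1, "failed": 1, "exception": 0},
        {"llmCode": "qwen-max", "right": 7, "wrong": 2, "failed": 0, "exception": 1},
    ],
    ensure_ascii=False,
)

for item in json.loads(summary_json):
    # Each entry becomes one upsert_summary(...) call in the service code above.
    print(item["llmCode"], item["right"], item["wrong"],
          item["failed"], item["exception"])
```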

    def _convert_query_result_to_column_format(
        self, result: List[Dict]