diff --git a/packages/dbgpt-app/src/dbgpt_app/initialization/db_model_initialization.py b/packages/dbgpt-app/src/dbgpt_app/initialization/db_model_initialization.py
index befb0ff80..578cbb0f0 100644
--- a/packages/dbgpt-app/src/dbgpt_app/initialization/db_model_initialization.py
+++ b/packages/dbgpt-app/src/dbgpt_app/initialization/db_model_initialization.py
@@ -12,6 +12,10 @@ from dbgpt_serve.agent.app.recommend_question.recommend_question import (
 from dbgpt_serve.agent.hub.db.my_plugin_db import MyPluginEntity
 from dbgpt_serve.agent.hub.db.plugin_hub_db import PluginHubEntity
 from dbgpt_serve.datasource.manages.connect_config_db import ConnectConfigEntity
+from dbgpt_serve.evaluate.db.benchmark_db import (
+    BenchmarkCompareEntity,
+    BenchmarkSummaryEntity,
+)
 from dbgpt_serve.file.models.models import ServeEntity as FileServeEntity
 from dbgpt_serve.flow.models.models import ServeEntity as FlowServeEntity
 from dbgpt_serve.flow.models.models import VariablesEntity as FlowVariableEntity
@@ -19,7 +23,6 @@ from dbgpt_serve.prompt.models.models import ServeEntity as PromptManageEntity
 from dbgpt_serve.rag.models.chunk_db import DocumentChunkEntity
 from dbgpt_serve.rag.models.document_db import KnowledgeDocumentEntity
 from dbgpt_serve.rag.models.models import KnowledgeSpaceEntity
-from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkCompareEntity, BenchmarkSummaryEntity
 
 _MODELS = [
     PluginHubEntity,
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py
index 83aca363d..8a3b36800 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py
@@ -1,3 +1,4 @@
+import json
 import logging
 from functools import cache
 from typing import List, Optional
@@ -8,16 +9,29 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
 from dbgpt.component import ComponentType, SystemApp
 from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
 from dbgpt_serve.core import Result
-from dbgpt_serve.evaluate.api.schemas import BenchmarkServeRequest, EvaluateServeRequest
+from dbgpt_serve.evaluate.api.schemas import (
+    BenchmarkServeRequest,
+    BuildDemoRequest,
+    EvaluateServeRequest,
+    ExecuteDemoRequest,
+)
 from dbgpt_serve.evaluate.config import SERVE_SERVICE_COMPONENT_NAME, ServeConfig
-from dbgpt_serve.evaluate.service.service import Service
 from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
-import json
+from dbgpt_serve.evaluate.service.benchmark.data_compare_service import (
+    DataCompareService,
+)
 from dbgpt_serve.evaluate.service.benchmark.file_parse_service import FileParseService
-from dbgpt_serve.evaluate.service.benchmark.data_compare_service import DataCompareService
-from dbgpt_serve.evaluate.service.benchmark.user_input_execute_service import UserInputExecuteService
-from dbgpt_serve.evaluate.service.benchmark.models import BenchmarkExecuteConfig, BenchmarkModeTypeEnum
-from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import get_benchmark_manager
+from dbgpt_serve.evaluate.service.benchmark.models import (
+    BenchmarkExecuteConfig,
+    BenchmarkModeTypeEnum,
+)
+from dbgpt_serve.evaluate.service.benchmark.user_input_execute_service import (
+    UserInputExecuteService,
+)
+from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
+    get_benchmark_manager,
+)
+from dbgpt_serve.evaluate.service.service import Service
 from ...prompt.service.service import Service as PromptService
 
 from ..service.benchmark.benchmark_service import (
@@ -139,6 +153,7 @@ async def get_scenes():
 
     return Result.succ(scene_list)
 
+
 @router.post("/evaluation")
 async def evaluation(
     request: EvaluateServeRequest,
@@ -162,22 +177,25 @@ async def evaluation(
         )
     )
 
+
 @router.get("/benchmark/list_results", dependencies=[Depends(check_api_key)])
 async def list_compare_runs(limit: int = 50, offset: int = 0):
     dao = BenchmarkResultDao()
     rows = dao.list_summaries(limit=limit, offset=offset)
     result = []
     for s in rows:
-        result.append({
-            "id": s.id,
-            "roundId": s.round_id,
-            "outputPath": s.output_path,
-            "right": s.right,
-            "wrong": s.wrong,
-            "failed": s.failed,
-            "exception": s.exception,
-            "gmtCreated": s.gmt_created.isoformat() if s.gmt_created else None,
-        })
+        result.append(
+            {
+                "id": s.id,
+                "roundId": s.round_id,
+                "outputPath": s.output_path,
+                "right": s.right,
+                "wrong": s.wrong,
+                "failed": s.failed,
+                "exception": s.exception,
+                "gmtCreated": s.gmt_created.isoformat() if s.gmt_created else None,
+            }
+        )
     return Result.succ(result)
 
 
@@ -187,7 +205,9 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
     s = dao.get_summary_by_id(summary_id)
     if not s:
         raise HTTPException(status_code=404, detail="compare run not found")
-    compares = dao.list_compare_by_round_and_path(s.round_id, s.output_path, limit=limit, offset=offset)
+    compares = dao.list_compare_by_round_and_path(
+        s.round_id, s.output_path, limit=limit, offset=offset
+    )
     detail = {
         "id": s.id,
         "roundId": s.round_id,
@@ -207,7 +227,9 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
             "prompt": r.prompt,
             "standardAnswerSql": r.standard_answer_sql,
             "llmOutput": r.llm_output,
-            "executeResult": json.loads(r.execute_result) if r.execute_result else None,
+            "executeResult": json.loads(r.execute_result)
+            if r.execute_result
+            else None,
             "errorMsg": r.error_msg,
             "compareResult": r.compare_result,
             "isExecute": r.is_execute,
@@ -219,6 +241,7 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
     }
     return Result.succ(detail)
 
+
 @router.post("/benchmark/run_build", dependencies=[Depends(check_api_key)])
 async def benchmark_run_build(req: BuildDemoRequest):
     fps = FileParseService()
@@ -244,6 +267,7 @@ async def benchmark_run_build(req: BuildDemoRequest):
         right_outputs=right,
         input_file_path=req.input_file_path,
         output_file_path=req.right_output_file_path,
+    )
 
     dao = BenchmarkResultDao()
     summary_id = dao.compute_and_save_summary(req.round_id, req.right_output_file_path)
@@ -290,7 +314,11 @@ async def list_benchmark_datasets():
     manager = get_benchmark_manager(global_system_app)
     info = await manager.get_table_info()
     result = [
-        {"name": name, "rowCount": meta.get("row_count", 0), "columns": meta.get("columns", [])}
+        {
+            "name": name,
+            "rowCount": meta.get("row_count", 0),
+            "columns": meta.get("columns", []),
+        }
         for name, meta in info.items()
     ]
     return Result.succ(result)
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/schemas.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/schemas.py
index a24805aa7..e7d658ca3 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/schemas.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/schemas.py
@@ -42,17 +42,27 @@ class EvaluateServeResponse(EvaluateServeRequest):
 class BuildDemoRequest(BaseModel):
     round_id: int = Field(..., description="benchmark round id")
     input_file_path: str = Field(..., description="path to input jsonl")
-    left_output_file_path: str = Field(..., description="path to left llm outputs jsonl")
-    right_output_file_path: str = Field(..., description="path to right llm outputs jsonl")
-    compare_config: Optional[dict] = Field(None, description="compare config, e.g. {'check': 'FULL_TEXT'}")
+    left_output_file_path: str = Field(
+        ..., description="path to left llm outputs jsonl"
+    )
+    right_output_file_path: str = Field(
+        ..., description="path to right llm outputs jsonl"
+    )
+    compare_config: Optional[dict] = Field(
+        None, description="compare config, e.g. {'check': 'FULL_TEXT'}"
+    )
 
 
 class ExecuteDemoRequest(BaseModel):
     round_id: int = Field(..., description="benchmark round id")
     input_file_path: str = Field(..., description="path to input jsonl")
-    right_output_file_path: str = Field(..., description="path to right llm outputs jsonl")
+    right_output_file_path: str = Field(
+        ..., description="path to right llm outputs jsonl"
+    )
     standard_file_path: str = Field(..., description="path to standard answers excel")
-    compare_config: Optional[dict] = Field(None, description="optional compare config for json compare")
+    compare_config: Optional[dict] = Field(
+        None, description="optional compare config for json compare"
+    )
 
 
 class BenchmarkServeRequest(BaseModel):
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/__init__.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/__init__.py
index 8462bb0bd..7f1db962d 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/__init__.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/__init__.py
@@ -1,7 +1,7 @@
 from .benchmark_db import (
     BenchmarkCompareEntity,
-    BenchmarkSummaryEntity,
     BenchmarkResultDao,
+    BenchmarkSummaryEntity,
 )
 
 __all__ = [
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/benchmark_db.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/benchmark_db.py
index 68f7f07c6..6d2b50c68 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/benchmark_db.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/benchmark_db.py
@@ -33,7 +33,9 @@ class BenchmarkCompareEntity(Model):
         ),
     )
 
-    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
+    id = Column(
+        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
+    )
     # Round and mode
     round_id = Column(Integer, nullable=False, comment="Benchmark round id")
     mode = Column(String(16), nullable=False, comment="BUILD or EXECUTE")
@@ -47,18 +49,29 @@
 
     standard_answer_sql = Column(Text, nullable=True, comment="Standard answer SQL")
     llm_output = Column(Text, nullable=True, comment="LLM output text or JSON")
-    execute_result = Column(Text, nullable=True, comment="Execution result JSON (serialized)")
+    execute_result = Column(
+        Text, nullable=True, comment="Execution result JSON (serialized)"
+    )
     error_msg = Column(Text, nullable=True, comment="Error message")
 
-    compare_result = Column(String(16), nullable=True, comment="RIGHT/WRONG/FAILED/EXCEPTION")
+    compare_result = Column(
+        String(16), nullable=True, comment="RIGHT/WRONG/FAILED/EXCEPTION"
+    )
     is_execute = Column(Boolean, default=False, comment="Whether this is EXECUTE mode")
     llm_count = Column(Integer, default=0, comment="Number of LLM outputs compared")
 
     # Source path for traceability (original output jsonl file path)
-    output_path = Column(String(512), nullable=False, comment="Original output file path")
+    output_path = Column(
+        String(512), nullable=False, comment="Original output file path"
+    )
 
     gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
-    gmt_modified = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="Record update time")
+    gmt_modified = Column(
+        DateTime,
+        default=datetime.now,
+        onupdate=datetime.now,
+        comment="Record update time",
+    )
 
     Index("idx_bm_comp_round", "round_id")
     Index("idx_bm_comp_mode", "mode")
@@ -76,9 +89,13 @@
         UniqueConstraint("round_id", "output_path", name="uk_round_output"),
     )
 
-    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
+    id = Column(
+        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
+    )
     round_id = Column(Integer, nullable=False, comment="Benchmark round id")
-    output_path = Column(String(512), nullable=False, comment="Original output file path")
+    output_path = Column(
+        String(512), nullable=False, comment="Original output file path"
+    )
 
     right = Column(Integer, default=0, comment="RIGHT count")
     wrong = Column(Integer, default=0, comment="WRONG count")
@@ -86,7 +103,12 @@
     exception = Column(Integer, default=0, comment="EXCEPTION count")
 
     gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
-    gmt_modified = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="Record update time")
+    gmt_modified = Column(
+        DateTime,
+        default=datetime.now,
+        onupdate=datetime.now,
+        comment="Record update time",
+    )
 
     Index("idx_bm_sum_round", "round_id")
 
@@ -105,7 +127,8 @@
     ) -> int:
        """Write multiple compare records to DB.
 
-        records: each dict contains keys like in FileParseService.write_data_compare_result rows.
+        records: each dict contains keys like in
+        FileParseService.write_data_compare_result rows.
         Returns number of records inserted.
         """
         inserted = 0
@@ -122,7 +145,9 @@
                     prompt=r.get("prompt"),
                     standard_answer_sql=r.get("standardAnswerSql"),
                     llm_output=r.get("llmOutput"),
-                    execute_result=json.dumps(r.get("executeResult")) if r.get("executeResult") is not None else None,
+                    execute_result=json.dumps(r.get("executeResult"))
+                    if r.get("executeResult") is not None
+                    else None,
                     error_msg=r.get("errorMsg"),
                     compare_result=r.get("compareResult"),
                     is_execute=is_execute,
@@ -136,7 +161,9 @@
             session.commit()
             return inserted
 
-    def compute_and_save_summary(self, round_id: int, output_path: str) -> Optional[int]:
+    def compute_and_save_summary(
+        self, round_id: int, output_path: str
+    ) -> Optional[int]:
         """Compute summary from compare table and save to summary table.
         Returns summary id if saved, else None.
         """
@@ -197,7 +224,9 @@
                 .all()
             )
 
-    def get_summary(self, round_id: int, output_path: str) -> Optional[BenchmarkSummaryEntity]:
+    def get_summary(
+        self, round_id: int, output_path: str
+    ) -> Optional[BenchmarkSummaryEntity]:
         with self.session(commit=False) as session:
             return (
                 session.query(BenchmarkSummaryEntity)
@@ -242,4 +271,3 @@
                 .offset(offset)
                 .all()
             )
-
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py
index 82d9a3707..5f5c37edb 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py
@@ -1,27 +1,24 @@
 import io
 import json
 import logging
-import os
 from typing import List
 
 import pandas as pd
-
 from openpyxl.reader.excel import load_workbook
+
 from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils
+from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
 
 from .models import (
     AnswerExecuteModel,
     BaseInputModel,
     BenchmarkDataSets,
-    DataCompareResultEnum,
     DataCompareStrategyConfig,
     RoundAnswerConfirmModel,
 )
 
 logger = logging.getLogger(__name__)
 
-from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
-
 
 class FileParseService:
     def __init__(self):
@@ -101,7 +98,10 @@
             failed=summary.failed if summary else 0,
             exception=summary.exception if summary else 0,
         )
-        print(f"[summary] summary saved to DB for round={round_id}, output_path={output_path} -> {result}")
+        logger.info(
+            f"[summary] summary saved to DB for round={round_id},"
+            f" output_path={output_path} -> {result}"
+        )
         return json.dumps(result, ensure_ascii=False)
 
     def parse_standard_benchmark_sets(
@@ -153,7 +153,6 @@
         return outputs
 
 
-
 class ExcelFileParseService(FileParseService):
     def parse_input_sets(self, location: str) -> BenchmarkDataSets:
         """
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py
index c9398bde1..e605a8746 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py
@@ -13,6 +13,7 @@ class BenchmarkModeTypeEnum(str, Enum):
     BUILD = "BUILD"
     EXECUTE = "EXECUTE"
 
+
 @dataclass
 class DataCompareStrategyConfig:
     strategy: str  # "EXACT_MATCH" | "CONTAIN_MATCH"
@@ -130,7 +131,6 @@ class RoundAnswerConfirmModel:
     compareResult: Optional[DataCompareResultEnum] = None
 
 
-
 class FileParseTypeEnum(Enum):
     """文件解析类型枚举"""
 
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
index bf17eae94..82f669737 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
@@ -2,8 +2,6 @@
 import logging
 from typing import Dict, List
 
-from .data_compare_service import DataCompareService
-from .file_parse_service import FileParseService
 from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
     get_benchmark_manager,
 )