From 57e8bab48256ea9807482b9191aeef1dbda8c5ba Mon Sep 17 00:00:00 2001 From: yaoyifan-yyf Date: Thu, 9 Oct 2025 16:37:20 +0800 Subject: [PATCH] feat: query benchmark dataset api --- .../initialization/db_model_initialization.py | 3 + .../src/dbgpt_serve/evaluate/api/endpoints.py | 144 ++++++++++-- .../src/dbgpt_serve/evaluate/api/schemas.py | 16 ++ .../src/dbgpt_serve/evaluate/db/__init__.py | 11 + .../dbgpt_serve/evaluate/db/benchmark_db.py | 210 ++++++++++++++++++ .../output_execute_model.round1.compare.jsonl | 4 - .../output_execute_model.round1.summary.json | 6 - .../output_round1_modelB.round1.compare.jsonl | 4 - .../output_round1_modelB.round1.summary.json | 6 - .../service/benchmark/data_compare_service.py | 2 +- .../service/benchmark/file_parse_service.py | 90 ++++---- .../benchmark/user_input_execute_service.py | 6 +- 12 files changed, 407 insertions(+), 95 deletions(-) create mode 100644 packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/__init__.py create mode 100644 packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/benchmark_db.py delete mode 100644 packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_execute_model.round1.compare.jsonl delete mode 100644 packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_execute_model.round1.summary.json delete mode 100644 packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_round1_modelB.round1.compare.jsonl delete mode 100644 packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_round1_modelB.round1.summary.json diff --git a/packages/dbgpt-app/src/dbgpt_app/initialization/db_model_initialization.py b/packages/dbgpt-app/src/dbgpt_app/initialization/db_model_initialization.py index 539cb9327..befb0ff80 100644 --- a/packages/dbgpt-app/src/dbgpt_app/initialization/db_model_initialization.py +++ b/packages/dbgpt-app/src/dbgpt_app/initialization/db_model_initialization.py @@ -19,6 +19,7 @@ from dbgpt_serve.prompt.models.models import ServeEntity as PromptManageEntity from dbgpt_serve.rag.models.chunk_db import DocumentChunkEntity from dbgpt_serve.rag.models.document_db import KnowledgeDocumentEntity from dbgpt_serve.rag.models.models import KnowledgeSpaceEntity +from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkCompareEntity, BenchmarkSummaryEntity _MODELS = [ PluginHubEntity, @@ -36,4 +37,6 @@ _MODELS = [ FlowServeEntity, RecommendQuestionEntity, FlowVariableEntity, + BenchmarkCompareEntity, + BenchmarkSummaryEntity, ] diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py index 94c0b1338..ffa0b9c05 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py @@ -8,9 +8,16 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from dbgpt.component import ComponentType, SystemApp from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory from dbgpt_serve.core import Result -from dbgpt_serve.evaluate.api.schemas import EvaluateServeRequest +from dbgpt_serve.evaluate.api.schemas import EvaluateServeRequest, BuildDemoRequest, ExecuteDemoRequest from dbgpt_serve.evaluate.config import SERVE_SERVICE_COMPONENT_NAME, ServeConfig from dbgpt_serve.evaluate.service.service import Service +from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao +import json +from dbgpt_serve.evaluate.service.benchmark.file_parse_service import 
FileParseService +from dbgpt_serve.evaluate.service.benchmark.data_compare_service import DataCompareService +from dbgpt_serve.evaluate.service.benchmark.user_input_execute_service import UserInputExecuteService +from dbgpt_serve.evaluate.service.benchmark.models import BenchmarkExecuteConfig, BenchmarkModeTypeEnum +from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import get_benchmark_manager from ...prompt.service.service import Service as PromptService @@ -122,32 +129,127 @@ async def get_scenes(): return Result.succ(scene_list) -@router.post("/evaluation") -async def evaluation( - request: EvaluateServeRequest, - service: Service = Depends(get_service), -) -> Result: - """Evaluate results by the scene +@router.get("/benchmark/compare", dependencies=[Depends(check_api_key)]) +async def list_benchmark_compare( + round_id: int, + limit: int = 50, + offset: int = 0, +): + dao = BenchmarkResultDao() + rows = dao.list_compare_by_round(round_id, limit=limit, offset=offset) + result = [] + for r in rows: + result.append({ + "id": r.id, + "round_id": r.round_id, + "mode": r.mode, + "serialNo": r.serial_no, + "analysisModelId": r.analysis_model_id, + "question": r.question, + "selfDefineTags": r.self_define_tags, + "prompt": r.prompt, + "standardAnswerSql": r.standard_answer_sql, + "llmOutput": r.llm_output, + "executeResult": json.loads(r.execute_result) if r.execute_result else None, + "errorMsg": r.error_msg, + "compareResult": r.compare_result, + "isExecute": r.is_execute, + "llmCount": r.llm_count, + "outputPath": r.output_path, + "gmtCreated": r.gmt_created.isoformat() if r.gmt_created else None, + }) + return Result.succ(result) - Args: - request (EvaluateServeRequest): The request - service (Service): The service - Returns: - ServerResponse: The response - """ - return Result.succ( - await service.run_evaluation( - request.scene_key, - request.scene_value, - request.datasets, - request.context, - request.evaluate_metrics, - ) + +@router.post("/benchmark/run_build", dependencies=[Depends(check_api_key)]) +async def benchmark_run_build(req: BuildDemoRequest): + fps = FileParseService() + dcs = DataCompareService() + svc = UserInputExecuteService(fps, dcs) + + inputs = fps.parse_input_sets(req.input_file_path) + left = fps.parse_llm_outputs(req.left_output_file_path) + right = fps.parse_llm_outputs(req.right_output_file_path) + + config = BenchmarkExecuteConfig( + benchmarkModeType=BenchmarkModeTypeEnum.BUILD, + compareResultEnable=True, + standardFilePath=None, + compareConfig=req.compare_config or {"check": "FULL_TEXT"}, ) + svc.post_dispatch( + round_id=req.round_id, + config=config, + inputs=inputs, + left_outputs=left, + right_outputs=right, + input_file_path=req.input_file_path, + output_file_path=req.right_output_file_path, + ) + + summary = fps.summary_and_write_multi_round_benchmark_result( + req.right_output_file_path, req.round_id + ) + return Result.succ({"summary": json.loads(summary)}) + def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None: """Initialize the endpoints""" global global_system_app system_app.register(Service, config=config) global_system_app = system_app + + +@router.get("/benchmark/datasets", dependencies=[Depends(check_api_key)]) +async def list_benchmark_datasets(): + manager = get_benchmark_manager(global_system_app) + info = await manager.get_table_info() + result = [ + {"name": name, "rowCount": meta.get("row_count", 0), "columns": meta.get("columns", [])} + for name, meta in info.items() + ] + return Result.succ(result) 
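# ---------------------------------------------------------------------------
# Illustrative client sketch (not part of this patch). It shows how the new
# benchmark endpoints added above could be called once the router is mounted.
# The base URL, the "/api/v2/serve/evaluate" prefix, the API key value and the
# example table name are assumptions for illustration only; the paths, query
# parameters, request fields and Bearer auth (check_api_key) come from the
# code in this diff.
# ---------------------------------------------------------------------------
import requests

BASE = "http://localhost:5670/api/v2/serve/evaluate"  # assumed mount point
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}  # HTTPBearer token

# List registered benchmark dataset tables with row counts and columns
datasets = requests.get(f"{BASE}/benchmark/datasets", headers=HEADERS).json()

# Preview a few rows of one dataset table (table name is a placeholder)
rows = requests.get(
    f"{BASE}/benchmark/datasets/di_finance_data/rows",
    params={"limit": 5},
    headers=HEADERS,
).json()

# Page through the persisted compare records for a round
compare = requests.get(
    f"{BASE}/benchmark/compare",
    params={"round_id": 1, "limit": 20, "offset": 0},
    headers=HEADERS,
).json()

# Trigger a BUILD-mode run from local jsonl files (BuildDemoRequest fields)
build = requests.post(
    f"{BASE}/benchmark/run_build",
    json={
        "round_id": 1,
        "input_file_path": "/path/to/input.jsonl",
        "left_output_file_path": "/path/to/left_output.jsonl",
        "right_output_file_path": "/path/to/right_output.jsonl",
        "compare_config": {"check": "FULL_TEXT"},
    },
    headers=HEADERS,
).json()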
+ + +@router.get("/benchmark/datasets/{table}/rows", dependencies=[Depends(check_api_key)]) +async def get_benchmark_table_rows(table: str, limit: int = 10): + manager = get_benchmark_manager(global_system_app) + info = await manager.get_table_info() + if table not in info: + raise HTTPException(status_code=404, detail=f"table '{table}' not found") + sql = f'SELECT * FROM "{table}" LIMIT :limit' + rows = await manager.query(sql, {"limit": limit}) + return Result.succ({"table": table, "limit": limit, "rows": rows}) + + +@router.post("/benchmark/run_execute", dependencies=[Depends(check_api_key)]) +async def benchmark_run_execute(req: ExecuteDemoRequest): + fps = FileParseService() + dcs = DataCompareService() + svc = UserInputExecuteService(fps, dcs) + + inputs = fps.parse_input_sets(req.input_file_path) + right = fps.parse_llm_outputs(req.right_output_file_path) + + config = BenchmarkExecuteConfig( + benchmarkModeType=BenchmarkModeTypeEnum.EXECUTE, + compareResultEnable=True, + standardFilePath=req.standard_file_path, + compareConfig=req.compare_config, + ) + + svc.post_dispatch( + round_id=req.round_id, + config=config, + inputs=inputs, + left_outputs=[], + right_outputs=right, + input_file_path=req.input_file_path, + output_file_path=req.right_output_file_path, + ) + + summary = fps.summary_and_write_multi_round_benchmark_result( + req.right_output_file_path, req.round_id + ) + return Result.succ({"summary": json.loads(summary)}) diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/schemas.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/schemas.py index 9f03e91d7..e738e8491 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/schemas.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/schemas.py @@ -37,3 +37,19 @@ class EvaluateServeRequest(BaseModel): class EvaluateServeResponse(EvaluateServeRequest): class Config: title = f"EvaluateServeResponse for {SERVE_APP_NAME_HUMP}" + + +class BuildDemoRequest(BaseModel): + round_id: int = Field(..., description="benchmark round id") + input_file_path: str = Field(..., description="path to input jsonl") + left_output_file_path: str = Field(..., description="path to left llm outputs jsonl") + right_output_file_path: str = Field(..., description="path to right llm outputs jsonl") + compare_config: Optional[dict] = Field(None, description="compare config, e.g. 
{'check': 'FULL_TEXT'}") + + +class ExecuteDemoRequest(BaseModel): + round_id: int = Field(..., description="benchmark round id") + input_file_path: str = Field(..., description="path to input jsonl") + right_output_file_path: str = Field(..., description="path to right llm outputs jsonl") + standard_file_path: str = Field(..., description="path to standard answers excel") + compare_config: Optional[dict] = Field(None, description="optional compare config for json compare") diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/__init__.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/__init__.py new file mode 100644 index 000000000..8462bb0bd --- /dev/null +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/__init__.py @@ -0,0 +1,11 @@ +from .benchmark_db import ( + BenchmarkCompareEntity, + BenchmarkSummaryEntity, + BenchmarkResultDao, +) + +__all__ = [ + "BenchmarkCompareEntity", + "BenchmarkSummaryEntity", + "BenchmarkResultDao", +] diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/benchmark_db.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/benchmark_db.py new file mode 100644 index 000000000..b2c9fbdcf --- /dev/null +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/benchmark_db.py @@ -0,0 +1,210 @@ +import json +import logging +from datetime import datetime +from typing import List, Optional + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Index, + Integer, + String, + Text, + UniqueConstraint, + desc, +) + +from dbgpt.storage.metadata import BaseDao, Model + +logger = logging.getLogger(__name__) + + +class BenchmarkCompareEntity(Model): + """Single compare record for one input serialNo in one round. + + Fields match the JSON lines produced by FileParseService.write_data_compare_result. + """ + + __tablename__ = "benchmark_compare" + __table_args__ = ( + UniqueConstraint( + "round_id", "serial_no", "output_path", name="uk_round_serial_output" + ), + ) + + id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") + # Round and mode + round_id = Column(Integer, nullable=False, comment="Benchmark round id") + mode = Column(String(16), nullable=False, comment="BUILD or EXECUTE") + + # Input & outputs + serial_no = Column(Integer, nullable=False, comment="Input serial number") + analysis_model_id = Column(String(255), nullable=False, comment="Analysis model id") + question = Column(Text, nullable=False, comment="User question") + self_define_tags = Column(String(255), nullable=True, comment="Self define tags") + prompt = Column(Text, nullable=True, comment="Prompt text") + + standard_answer_sql = Column(Text, nullable=True, comment="Standard answer SQL") + llm_output = Column(Text, nullable=True, comment="LLM output text or JSON") + execute_result = Column(Text, nullable=True, comment="Execution result JSON (serialized)") + error_msg = Column(Text, nullable=True, comment="Error message") + + compare_result = Column(String(16), nullable=True, comment="RIGHT/WRONG/FAILED/EXCEPTION") + is_execute = Column(Boolean, default=False, comment="Whether this is EXECUTE mode") + llm_count = Column(Integer, default=0, comment="Number of LLM outputs compared") + + # Source path for traceability (original output jsonl file path) + output_path = Column(String(512), nullable=False, comment="Original output file path") + + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="Record update time") + + 
Index("idx_bm_comp_round", "round_id") + Index("idx_bm_comp_mode", "mode") + Index("idx_bm_comp_serial", "serial_no") + + +class BenchmarkSummaryEntity(Model): + """Summary result for one round and one output path. + + Counts of RIGHT/WRONG/FAILED/EXCEPTION. + """ + + __tablename__ = "benchmark_summary" + __table_args__ = ( + UniqueConstraint("round_id", "output_path", name="uk_round_output"), + ) + + id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") + round_id = Column(Integer, nullable=False, comment="Benchmark round id") + output_path = Column(String(512), nullable=False, comment="Original output file path") + + right = Column(Integer, default=0, comment="RIGHT count") + wrong = Column(Integer, default=0, comment="WRONG count") + failed = Column(Integer, default=0, comment="FAILED count") + exception = Column(Integer, default=0, comment="EXCEPTION count") + + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="Record update time") + + Index("idx_bm_sum_round", "round_id") + + +class BenchmarkResultDao(BaseDao): + """DAO for benchmark compare and summary results.""" + + def write_compare_results( + self, + round_id: int, + mode: str, # "BUILD" or "EXECUTE" + output_path: str, + records: List[dict], + is_execute: bool, + llm_count: int, + ) -> int: + """Write multiple compare records to DB. + + records: each dict contains keys like in FileParseService.write_data_compare_result rows. + Returns number of records inserted. + """ + inserted = 0 + with self.session() as session: + for r in records: + try: + entity = BenchmarkCompareEntity( + round_id=round_id, + mode=mode, + serial_no=r.get("serialNo"), + analysis_model_id=r.get("analysisModelId"), + question=r.get("question"), + self_define_tags=r.get("selfDefineTags"), + prompt=r.get("prompt"), + standard_answer_sql=r.get("standardAnswerSql"), + llm_output=r.get("llmOutput"), + execute_result=json.dumps(r.get("executeResult")) if r.get("executeResult") is not None else None, + error_msg=r.get("errorMsg"), + compare_result=r.get("compareResult"), + is_execute=is_execute, + llm_count=llm_count, + output_path=output_path, + ) + session.add(entity) + inserted += 1 + except Exception as e: + logger.error(f"Insert compare record failed: {e}") + session.commit() + return inserted + + def compute_and_save_summary(self, round_id: int, output_path: str) -> Optional[int]: + """Compute summary from compare table and save to summary table. + Returns summary id if saved, else None. 
+ """ + with self.session() as session: + # compute counts + q = ( + session.query(BenchmarkCompareEntity.compare_result) + .filter( + BenchmarkCompareEntity.round_id == round_id, + BenchmarkCompareEntity.output_path == output_path, + ) + .all() + ) + right = sum(1 for x in q if x[0] == "RIGHT") + wrong = sum(1 for x in q if x[0] == "WRONG") + failed = sum(1 for x in q if x[0] == "FAILED") + exception = sum(1 for x in q if x[0] == "EXCEPTION") + + # upsert summary + existing = ( + session.query(BenchmarkSummaryEntity) + .filter( + BenchmarkSummaryEntity.round_id == round_id, + BenchmarkSummaryEntity.output_path == output_path, + ) + .first() + ) + if existing: + existing.right = right + existing.wrong = wrong + existing.failed = failed + existing.exception = exception + existing.gmt_modified = datetime.now() + session.commit() + return existing.id + else: + summary = BenchmarkSummaryEntity( + round_id=round_id, + output_path=output_path, + right=right, + wrong=wrong, + failed=failed, + exception=exception, + ) + session.add(summary) + session.commit() + return summary.id + + # Basic query helpers + def list_compare_by_round(self, round_id: int, limit: int = 100, offset: int = 0): + with self.session(commit=False) as session: + return ( + session.query(BenchmarkCompareEntity) + .filter(BenchmarkCompareEntity.round_id == round_id) + .order_by(desc(BenchmarkCompareEntity.id)) + .limit(limit) + .offset(offset) + .all() + ) + + def get_summary(self, round_id: int, output_path: str) -> Optional[BenchmarkSummaryEntity]: + with self.session(commit=False) as session: + return ( + session.query(BenchmarkSummaryEntity) + .filter( + BenchmarkSummaryEntity.round_id == round_id, + BenchmarkSummaryEntity.output_path == output_path, + ) + .first() + ) + diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_execute_model.round1.compare.jsonl b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_execute_model.round1.compare.jsonl deleted file mode 100644 index 8041ece4e..000000000 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_execute_model.round1.compare.jsonl +++ /dev/null @@ -1,4 +0,0 @@ -{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (\n select \n gender,\n cast(age as int) as age\n from \n ant_icube_dev.di_finance_data\n where \n age rlike '^[0-9]+$'\n)\nselect\n gender as `性别`,\n avg(age) as `平均年龄`\nfrom \n converted_data\ngroup by \n gender\norder by \n `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2} -{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (\n select\n objective,\n cast(government_bonds as bigint) as gov_bond_value\n from\n ant_icube_dev.di_finance_data\n where\n government_bonds is not null\n and government_bonds rlike '^[0-9]+$'\n)\nselect\n objective as `objective`,\n sum(gov_bond_value) as `政府债券总量`\nfrom\n gov_bonds_data\ngroup by\n `objective`\norder by\n `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect 
objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2} -{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": null, "llmOutput": "select 1", "executeResult": {"colA": ["x", "y"]}, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2} -{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": null, "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2} diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_execute_model.round1.summary.json b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_execute_model.round1.summary.json deleted file mode 100644 index 05ac3319b..000000000 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_execute_model.round1.summary.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "right": 2, - "wrong": 0, - "failed": 2, - "exception": 0 -} \ No newline at end of file diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_round1_modelB.round1.compare.jsonl b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_round1_modelB.round1.compare.jsonl deleted file mode 100644 index 5482d825c..000000000 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_round1_modelB.round1.compare.jsonl +++ /dev/null @@ -1,4 +0,0 @@ -{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2} -{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2} -{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": "select 1", "llmOutput": "select 1", "executeResult": {"colB": ["x", "z", "w"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 
2} -{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": "{\"check\":\"ok\"}", "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "RIGHT", "isExecute": false, "llmCount": 2} diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_round1_modelB.round1.summary.json b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_round1_modelB.round1.summary.json deleted file mode 100644 index 03a2ef66d..000000000 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data/output_round1_modelB.round1.summary.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "right": 1, - "wrong": 0, - "failed": 0, - "exception": 3 -} \ No newline at end of file diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data_compare_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data_compare_service.py index 1cb6f5144..53efa1bb5 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data_compare_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data_compare_service.py @@ -4,7 +4,7 @@ from copy import deepcopy from decimal import ROUND_HALF_UP, Decimal from typing import Dict, List, Optional -from models import ( +from .models import ( AnswerExecuteModel, DataCompareResult, DataCompareResultEnum, diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py index 59ecc4b19..66e1ee711 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py @@ -3,7 +3,7 @@ import os from typing import List import pandas as pd -from models import ( +from .models import ( AnswerExecuteModel, BaseInputModel, DataCompareResultEnum, @@ -11,8 +11,13 @@ from models import ( RoundAnswerConfirmModel, ) +from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao + class FileParseService: + def __init__(self): + self._benchmark_dao = BenchmarkResultDao() + def parse_input_sets(self, path: str) -> List[BaseInputModel]: data = [] with open(path, "r", encoding="utf-8") as f: @@ -50,59 +55,44 @@ class FileParseService: is_execute: bool, llm_count: int, ): - if not path.endswith(".jsonl"): - raise ValueError(f"output_file_path must end with .jsonl, got {path}") - out_path = path.replace(".jsonl", f".round{round_id}.compare.jsonl") - with open(out_path, "w", encoding="utf-8") as f: - for cm in confirm_models: - row = dict( - serialNo=cm.serialNo, - analysisModelId=cm.analysisModelId, - question=cm.question, - selfDefineTags=cm.selfDefineTags, - prompt=cm.prompt, - standardAnswerSql=cm.standardAnswerSql, - llmOutput=cm.llmOutput, - executeResult=cm.executeResult, - errorMsg=cm.errorMsg, - compareResult=cm.compareResult.value if cm.compareResult else None, - isExecute=is_execute, - llmCount=llm_count, - ) - f.write(json.dumps(row, ensure_ascii=False) + "\n") - print(f"[write_data_compare_result] compare written to: {out_path}") + mode = "EXECUTE" if is_execute else "BUILD" + records = [] + for cm in confirm_models: + row = dict( + serialNo=cm.serialNo, + analysisModelId=cm.analysisModelId, + question=cm.question, + selfDefineTags=cm.selfDefineTags, + prompt=cm.prompt, + 
standardAnswerSql=cm.standardAnswerSql, + llmOutput=cm.llmOutput, + executeResult=cm.executeResult, + errorMsg=cm.errorMsg, + compareResult=cm.compareResult.value if cm.compareResult else None, + ) + records.append(row) + self._benchmark_dao.write_compare_results( + round_id=round_id, + mode=mode, + output_path=path, + records=records, + is_execute=is_execute, + llm_count=llm_count, + ) + print(f"[write_data_compare_result] compare written to DB for: {path}") def summary_and_write_multi_round_benchmark_result( self, output_path: str, round_id: int ) -> str: - if not output_path.endswith(".jsonl"): - raise ValueError( - f"output_file_path must end with .jsonl, got {output_path}" - ) - compare_path = output_path.replace(".jsonl", f".round{round_id}.compare.jsonl") - right, wrong, failed, exception = 0, 0, 0, 0 - if os.path.exists(compare_path): - with open(compare_path, "r", encoding="utf-8") as f: - for line in f: - if not line.strip(): - continue - obj = json.loads(line) - cr = obj.get("compareResult") - if cr == DataCompareResultEnum.RIGHT.value: - right += 1 - elif cr == DataCompareResultEnum.WRONG.value: - wrong += 1 - elif cr == DataCompareResultEnum.FAILED.value: - failed += 1 - elif cr == DataCompareResultEnum.EXCEPTION.value: - exception += 1 - else: - print(f"[summary] compare file not found: {compare_path}") - summary_path = output_path.replace(".jsonl", f".round{round_id}.summary.json") - result = dict(right=right, wrong=wrong, failed=failed, exception=exception) - with open(summary_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"[summary] summary written to: {summary_path} -> {result}") + summary_id = self._benchmark_dao.compute_and_save_summary(round_id, output_path) + summary = self._benchmark_dao.get_summary(round_id, output_path) + result = dict( + right=summary.right if summary else 0, + wrong=summary.wrong if summary else 0, + failed=summary.failed if summary else 0, + exception=summary.exception if summary else 0, + ) + print(f"[summary] summary saved to DB for round={round_id}, output_path={output_path} -> {result}") return json.dumps(result, ensure_ascii=False) def parse_standard_benchmark_sets( diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py index 68fefb5f9..9d46ea2a7 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py @@ -1,9 +1,9 @@ # app/services/user_input_execute_service.py from typing import List -from data_compare_service import DataCompareService -from file_parse_service import FileParseService -from models import ( +from .data_compare_service import DataCompareService +from .file_parse_service import FileParseService +from .models import ( AnswerExecuteModel, BaseInputModel, BenchmarkExecuteConfig,