feat: query benchmark dataset api
@@ -19,6 +19,7 @@ from dbgpt_serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt_serve.rag.models.chunk_db import DocumentChunkEntity
from dbgpt_serve.rag.models.document_db import KnowledgeDocumentEntity
from dbgpt_serve.rag.models.models import KnowledgeSpaceEntity
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkCompareEntity, BenchmarkSummaryEntity

_MODELS = [
    PluginHubEntity,
@@ -36,4 +37,6 @@ _MODELS = [
    FlowServeEntity,
    RecommendQuestionEntity,
    FlowVariableEntity,
    BenchmarkCompareEntity,
    BenchmarkSummaryEntity,
]

@@ -8,9 +8,16 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import ComponentType, SystemApp
from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from dbgpt_serve.core import Result
from dbgpt_serve.evaluate.api.schemas import EvaluateServeRequest
from dbgpt_serve.evaluate.api.schemas import EvaluateServeRequest, BuildDemoRequest, ExecuteDemoRequest
from dbgpt_serve.evaluate.config import SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from dbgpt_serve.evaluate.service.service import Service
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
import json
from dbgpt_serve.evaluate.service.benchmark.file_parse_service import FileParseService
from dbgpt_serve.evaluate.service.benchmark.data_compare_service import DataCompareService
from dbgpt_serve.evaluate.service.benchmark.user_input_execute_service import UserInputExecuteService
from dbgpt_serve.evaluate.service.benchmark.models import BenchmarkExecuteConfig, BenchmarkModeTypeEnum
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import get_benchmark_manager

from ...prompt.service.service import Service as PromptService

@@ -122,32 +129,127 @@ async def get_scenes():
    return Result.succ(scene_list)


@router.post("/evaluation")
async def evaluation(
    request: EvaluateServeRequest,
    service: Service = Depends(get_service),
) -> Result:
    """Evaluate results by the scene
@router.get("/benchmark/compare", dependencies=[Depends(check_api_key)])
async def list_benchmark_compare(
    round_id: int,
    limit: int = 50,
    offset: int = 0,
):
    dao = BenchmarkResultDao()
    rows = dao.list_compare_by_round(round_id, limit=limit, offset=offset)
    result = []
    for r in rows:
        result.append({
            "id": r.id,
            "round_id": r.round_id,
            "mode": r.mode,
            "serialNo": r.serial_no,
            "analysisModelId": r.analysis_model_id,
            "question": r.question,
            "selfDefineTags": r.self_define_tags,
            "prompt": r.prompt,
            "standardAnswerSql": r.standard_answer_sql,
            "llmOutput": r.llm_output,
            "executeResult": json.loads(r.execute_result) if r.execute_result else None,
            "errorMsg": r.error_msg,
            "compareResult": r.compare_result,
            "isExecute": r.is_execute,
            "llmCount": r.llm_count,
            "outputPath": r.output_path,
            "gmtCreated": r.gmt_created.isoformat() if r.gmt_created else None,
        })
    return Result.succ(result)

    Args:
        request (EvaluateServeRequest): The request
        service (Service): The service
    Returns:
        ServerResponse: The response
    """
    return Result.succ(
        await service.run_evaluation(
            request.scene_key,
            request.scene_value,
            request.datasets,
            request.context,
            request.evaluate_metrics,
        )
    )

@router.post("/benchmark/run_build", dependencies=[Depends(check_api_key)])
async def benchmark_run_build(req: BuildDemoRequest):
    fps = FileParseService()
    dcs = DataCompareService()
    svc = UserInputExecuteService(fps, dcs)

    inputs = fps.parse_input_sets(req.input_file_path)
    left = fps.parse_llm_outputs(req.left_output_file_path)
    right = fps.parse_llm_outputs(req.right_output_file_path)

    config = BenchmarkExecuteConfig(
        benchmarkModeType=BenchmarkModeTypeEnum.BUILD,
        compareResultEnable=True,
        standardFilePath=None,
        compareConfig=req.compare_config or {"check": "FULL_TEXT"},
    )

    svc.post_dispatch(
        round_id=req.round_id,
        config=config,
        inputs=inputs,
        left_outputs=left,
        right_outputs=right,
        input_file_path=req.input_file_path,
        output_file_path=req.right_output_file_path,
    )

    summary = fps.summary_and_write_multi_round_benchmark_result(
        req.right_output_file_path, req.round_id
    )
    return Result.succ({"summary": json.loads(summary)})

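Note: a minimal client sketch for trying the new build route, assuming the evaluate serve app is mounted under /api/v2/serve/evaluate and the JSONL files already exist; the host, port, token, and paths below are placeholders, not part of this change.

    import requests  # any HTTP client works; requests is used here for brevity

    payload = {
        "round_id": 1,
        "input_file_path": "/data/benchmark/input.jsonl",         # placeholder path
        "left_output_file_path": "/data/benchmark/left.jsonl",    # placeholder path
        "right_output_file_path": "/data/benchmark/right.jsonl",  # placeholder path
        "compare_config": {"check": "FULL_TEXT"},
    }
    resp = requests.post(
        "http://localhost:5670/api/v2/serve/evaluate/benchmark/run_build",  # assumed mount prefix
        json=payload,
        headers={"Authorization": "Bearer <API_KEY>"},  # check_api_key expects a bearer token
        timeout=60,
    )
    print(resp.json())  # expected shape: {"success": true, "data": {"summary": {"right": ..., "wrong": ...}}}
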
def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
    """Initialize the endpoints"""
    global global_system_app
    system_app.register(Service, config=config)
    global_system_app = system_app


@router.get("/benchmark/datasets", dependencies=[Depends(check_api_key)])
async def list_benchmark_datasets():
    manager = get_benchmark_manager(global_system_app)
    info = await manager.get_table_info()
    result = [
        {"name": name, "rowCount": meta.get("row_count", 0), "columns": meta.get("columns", [])}
        for name, meta in info.items()
    ]
    return Result.succ(result)


@router.get("/benchmark/datasets/{table}/rows", dependencies=[Depends(check_api_key)])
async def get_benchmark_table_rows(table: str, limit: int = 10):
    manager = get_benchmark_manager(global_system_app)
    info = await manager.get_table_info()
    if table not in info:
        raise HTTPException(status_code=404, detail=f"table '{table}' not found")
    sql = f'SELECT * FROM "{table}" LIMIT :limit'
    rows = await manager.query(sql, {"limit": limit})
    return Result.succ({"table": table, "limit": limit, "rows": rows})

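Note: a quick sketch of how the two dataset inspection routes can be queried, under the same assumed mount prefix and placeholder credentials as above; the table name is only an illustration taken from the sample data further down.

    import requests

    base = "http://localhost:5670/api/v2/serve/evaluate"  # assumed mount prefix
    headers = {"Authorization": "Bearer <API_KEY>"}

    # List available benchmark tables with row counts and columns
    tables = requests.get(f"{base}/benchmark/datasets", headers=headers, timeout=30).json()

    # Fetch a few sample rows from one table (404 if the table name is unknown)
    rows = requests.get(
        f"{base}/benchmark/datasets/di_finance_data/rows",  # illustrative table name
        params={"limit": 5},
        headers=headers,
        timeout=30,
    ).json()
    print(tables, rows)
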
@router.post("/benchmark/run_execute", dependencies=[Depends(check_api_key)])
async def benchmark_run_execute(req: ExecuteDemoRequest):
    fps = FileParseService()
    dcs = DataCompareService()
    svc = UserInputExecuteService(fps, dcs)

    inputs = fps.parse_input_sets(req.input_file_path)
    right = fps.parse_llm_outputs(req.right_output_file_path)

    config = BenchmarkExecuteConfig(
        benchmarkModeType=BenchmarkModeTypeEnum.EXECUTE,
        compareResultEnable=True,
        standardFilePath=req.standard_file_path,
        compareConfig=req.compare_config,
    )

    svc.post_dispatch(
        round_id=req.round_id,
        config=config,
        inputs=inputs,
        left_outputs=[],
        right_outputs=right,
        input_file_path=req.input_file_path,
        output_file_path=req.right_output_file_path,
    )

    summary = fps.summary_and_write_multi_round_benchmark_result(
        req.right_output_file_path, req.round_id
    )
    return Result.succ({"summary": json.loads(summary)})

@@ -37,3 +37,19 @@ class EvaluateServeRequest(BaseModel):
class EvaluateServeResponse(EvaluateServeRequest):
    class Config:
        title = f"EvaluateServeResponse for {SERVE_APP_NAME_HUMP}"


class BuildDemoRequest(BaseModel):
    round_id: int = Field(..., description="benchmark round id")
    input_file_path: str = Field(..., description="path to input jsonl")
    left_output_file_path: str = Field(..., description="path to left llm outputs jsonl")
    right_output_file_path: str = Field(..., description="path to right llm outputs jsonl")
    compare_config: Optional[dict] = Field(None, description="compare config, e.g. {'check': 'FULL_TEXT'}")


class ExecuteDemoRequest(BaseModel):
    round_id: int = Field(..., description="benchmark round id")
    input_file_path: str = Field(..., description="path to input jsonl")
    right_output_file_path: str = Field(..., description="path to right llm outputs jsonl")
    standard_file_path: str = Field(..., description="path to standard answers excel")
    compare_config: Optional[dict] = Field(None, description="optional compare config for json compare")

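Note: an example of building a request body for the EXECUTE flow with the new ExecuteDemoRequest schema; the file paths and the compare_config value are placeholders, not values defined by this change.

    from dbgpt_serve.evaluate.api.schemas import ExecuteDemoRequest

    req = ExecuteDemoRequest(
        round_id=2,
        input_file_path="/data/benchmark/input.jsonl",         # placeholder
        right_output_file_path="/data/benchmark/right.jsonl",  # placeholder
        standard_file_path="/data/benchmark/standard.xlsx",    # placeholder
        compare_config={"check": "FULL_TEXT"},                 # illustrative; json-compare config is optional
    )
    print(req.model_dump())  # or req.dict() on Pydantic v1
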
11  packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/__init__.py  Normal file
@@ -0,0 +1,11 @@
from .benchmark_db import (
    BenchmarkCompareEntity,
    BenchmarkSummaryEntity,
    BenchmarkResultDao,
)

__all__ = [
    "BenchmarkCompareEntity",
    "BenchmarkSummaryEntity",
    "BenchmarkResultDao",
]
210  packages/dbgpt-serve/src/dbgpt_serve/evaluate/db/benchmark_db.py  Normal file
@@ -0,0 +1,210 @@
import json
import logging
from datetime import datetime
from typing import List, Optional

from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    Index,
    Integer,
    String,
    Text,
    UniqueConstraint,
    desc,
)

from dbgpt.storage.metadata import BaseDao, Model

logger = logging.getLogger(__name__)


class BenchmarkCompareEntity(Model):
    """Single compare record for one input serialNo in one round.

    Fields match the JSON lines produced by FileParseService.write_data_compare_result.
    """

    __tablename__ = "benchmark_compare"
    __table_args__ = (
        UniqueConstraint(
            "round_id", "serial_no", "output_path", name="uk_round_serial_output"
        ),
    )

    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
    # Round and mode
    round_id = Column(Integer, nullable=False, comment="Benchmark round id")
    mode = Column(String(16), nullable=False, comment="BUILD or EXECUTE")

    # Input & outputs
    serial_no = Column(Integer, nullable=False, comment="Input serial number")
    analysis_model_id = Column(String(255), nullable=False, comment="Analysis model id")
    question = Column(Text, nullable=False, comment="User question")
    self_define_tags = Column(String(255), nullable=True, comment="Self define tags")
    prompt = Column(Text, nullable=True, comment="Prompt text")

    standard_answer_sql = Column(Text, nullable=True, comment="Standard answer SQL")
    llm_output = Column(Text, nullable=True, comment="LLM output text or JSON")
    execute_result = Column(Text, nullable=True, comment="Execution result JSON (serialized)")
    error_msg = Column(Text, nullable=True, comment="Error message")

    compare_result = Column(String(16), nullable=True, comment="RIGHT/WRONG/FAILED/EXCEPTION")
    is_execute = Column(Boolean, default=False, comment="Whether this is EXECUTE mode")
    llm_count = Column(Integer, default=0, comment="Number of LLM outputs compared")

    # Source path for traceability (original output jsonl file path)
    output_path = Column(String(512), nullable=False, comment="Original output file path")

    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    gmt_modified = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="Record update time")

    Index("idx_bm_comp_round", "round_id")
    Index("idx_bm_comp_mode", "mode")
    Index("idx_bm_comp_serial", "serial_no")


class BenchmarkSummaryEntity(Model):
    """Summary result for one round and one output path.

    Counts of RIGHT/WRONG/FAILED/EXCEPTION.
    """

    __tablename__ = "benchmark_summary"
    __table_args__ = (
        UniqueConstraint("round_id", "output_path", name="uk_round_output"),
    )

    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
    round_id = Column(Integer, nullable=False, comment="Benchmark round id")
    output_path = Column(String(512), nullable=False, comment="Original output file path")

    right = Column(Integer, default=0, comment="RIGHT count")
    wrong = Column(Integer, default=0, comment="WRONG count")
    failed = Column(Integer, default=0, comment="FAILED count")
    exception = Column(Integer, default=0, comment="EXCEPTION count")

    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    gmt_modified = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="Record update time")

    Index("idx_bm_sum_round", "round_id")


class BenchmarkResultDao(BaseDao):
    """DAO for benchmark compare and summary results."""

    def write_compare_results(
        self,
        round_id: int,
        mode: str,  # "BUILD" or "EXECUTE"
        output_path: str,
        records: List[dict],
        is_execute: bool,
        llm_count: int,
    ) -> int:
        """Write multiple compare records to DB.

        records: each dict contains keys like in FileParseService.write_data_compare_result rows.
        Returns number of records inserted.
        """
        inserted = 0
        with self.session() as session:
            for r in records:
                try:
                    entity = BenchmarkCompareEntity(
                        round_id=round_id,
                        mode=mode,
                        serial_no=r.get("serialNo"),
                        analysis_model_id=r.get("analysisModelId"),
                        question=r.get("question"),
                        self_define_tags=r.get("selfDefineTags"),
                        prompt=r.get("prompt"),
                        standard_answer_sql=r.get("standardAnswerSql"),
                        llm_output=r.get("llmOutput"),
                        execute_result=json.dumps(r.get("executeResult")) if r.get("executeResult") is not None else None,
                        error_msg=r.get("errorMsg"),
                        compare_result=r.get("compareResult"),
                        is_execute=is_execute,
                        llm_count=llm_count,
                        output_path=output_path,
                    )
                    session.add(entity)
                    inserted += 1
                except Exception as e:
                    logger.error(f"Insert compare record failed: {e}")
            session.commit()
        return inserted

    def compute_and_save_summary(self, round_id: int, output_path: str) -> Optional[int]:
        """Compute summary from compare table and save to summary table.
        Returns summary id if saved, else None.
        """
        with self.session() as session:
            # compute counts
            q = (
                session.query(BenchmarkCompareEntity.compare_result)
                .filter(
                    BenchmarkCompareEntity.round_id == round_id,
                    BenchmarkCompareEntity.output_path == output_path,
                )
                .all()
            )
            right = sum(1 for x in q if x[0] == "RIGHT")
            wrong = sum(1 for x in q if x[0] == "WRONG")
            failed = sum(1 for x in q if x[0] == "FAILED")
            exception = sum(1 for x in q if x[0] == "EXCEPTION")

            # upsert summary
            existing = (
                session.query(BenchmarkSummaryEntity)
                .filter(
                    BenchmarkSummaryEntity.round_id == round_id,
                    BenchmarkSummaryEntity.output_path == output_path,
                )
                .first()
            )
            if existing:
                existing.right = right
                existing.wrong = wrong
                existing.failed = failed
                existing.exception = exception
                existing.gmt_modified = datetime.now()
                session.commit()
                return existing.id
            else:
                summary = BenchmarkSummaryEntity(
                    round_id=round_id,
                    output_path=output_path,
                    right=right,
                    wrong=wrong,
                    failed=failed,
                    exception=exception,
                )
                session.add(summary)
                session.commit()
                return summary.id

    # Basic query helpers
    def list_compare_by_round(self, round_id: int, limit: int = 100, offset: int = 0):
        with self.session(commit=False) as session:
            return (
                session.query(BenchmarkCompareEntity)
                .filter(BenchmarkCompareEntity.round_id == round_id)
                .order_by(desc(BenchmarkCompareEntity.id))
                .limit(limit)
                .offset(offset)
                .all()
            )

    def get_summary(self, round_id: int, output_path: str) -> Optional[BenchmarkSummaryEntity]:
        with self.session(commit=False) as session:
            return (
                session.query(BenchmarkSummaryEntity)
                .filter(
                    BenchmarkSummaryEntity.round_id == round_id,
                    BenchmarkSummaryEntity.output_path == output_path,
                )
                .first()
            )

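Note: a minimal usage sketch of the new DAO, assuming the serve metadata database has been initialized so BaseDao sessions work; the record values and paths are illustrative only.

    from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao

    dao = BenchmarkResultDao()
    inserted = dao.write_compare_results(
        round_id=1,
        mode="BUILD",
        output_path="/data/benchmark/right.jsonl",  # placeholder path
        records=[
            {
                "serialNo": 1,
                "analysisModelId": "demo-model",       # illustrative value
                "question": "average age by gender",   # illustrative value
                "compareResult": "RIGHT",
            }
        ],
        is_execute=False,
        llm_count=2,
    )
    summary_id = dao.compute_and_save_summary(1, "/data/benchmark/right.jsonl")
    summary = dao.get_summary(1, "/data/benchmark/right.jsonl")
    print(inserted, summary_id, summary.right if summary else 0)
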
@@ -1,4 +0,0 @@
{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (\n select \n gender,\n cast(age as int) as age\n from \n ant_icube_dev.di_finance_data\n where \n age rlike '^[0-9]+$'\n)\nselect\n gender as `性别`,\n avg(age) as `平均年龄`\nfrom \n converted_data\ngroup by \n gender\norder by \n `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2}
{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (\n select\n objective,\n cast(government_bonds as bigint) as gov_bond_value\n from\n ant_icube_dev.di_finance_data\n where\n government_bonds is not null\n and government_bonds rlike '^[0-9]+$'\n)\nselect\n objective as `objective`,\n sum(gov_bond_value) as `政府债券总量`\nfrom\n gov_bonds_data\ngroup by\n `objective`\norder by\n `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2}
{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": null, "llmOutput": "select 1", "executeResult": {"colA": ["x", "y"]}, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2}
{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": null, "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2}
@@ -1,6 +0,0 @@
{
  "right": 2,
  "wrong": 0,
  "failed": 2,
  "exception": 0
}
@@ -1,4 +0,0 @@
{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": "select 1", "llmOutput": "select 1", "executeResult": {"colB": ["x", "z", "w"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": "{\"check\":\"ok\"}", "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "RIGHT", "isExecute": false, "llmCount": 2}
@@ -1,6 +0,0 @@
{
  "right": 1,
  "wrong": 0,
  "failed": 0,
  "exception": 3
}
@@ -4,7 +4,7 @@ from copy import deepcopy
from decimal import ROUND_HALF_UP, Decimal
from typing import Dict, List, Optional

from models import (
from .models import (
    AnswerExecuteModel,
    DataCompareResult,
    DataCompareResultEnum,

@@ -3,7 +3,7 @@ import os
from typing import List

import pandas as pd
from models import (
from .models import (
    AnswerExecuteModel,
    BaseInputModel,
    DataCompareResultEnum,
@@ -11,8 +11,13 @@ from models import (
    RoundAnswerConfirmModel,
)

from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao


class FileParseService:
    def __init__(self):
        self._benchmark_dao = BenchmarkResultDao()

    def parse_input_sets(self, path: str) -> List[BaseInputModel]:
        data = []
        with open(path, "r", encoding="utf-8") as f:
@@ -50,59 +55,44 @@ class FileParseService:
        is_execute: bool,
        llm_count: int,
    ):
        if not path.endswith(".jsonl"):
            raise ValueError(f"output_file_path must end with .jsonl, got {path}")
        out_path = path.replace(".jsonl", f".round{round_id}.compare.jsonl")
        with open(out_path, "w", encoding="utf-8") as f:
            for cm in confirm_models:
                row = dict(
                    serialNo=cm.serialNo,
                    analysisModelId=cm.analysisModelId,
                    question=cm.question,
                    selfDefineTags=cm.selfDefineTags,
                    prompt=cm.prompt,
                    standardAnswerSql=cm.standardAnswerSql,
                    llmOutput=cm.llmOutput,
                    executeResult=cm.executeResult,
                    errorMsg=cm.errorMsg,
                    compareResult=cm.compareResult.value if cm.compareResult else None,
                    isExecute=is_execute,
                    llmCount=llm_count,
                )
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
        print(f"[write_data_compare_result] compare written to: {out_path}")
        mode = "EXECUTE" if is_execute else "BUILD"
        records = []
        for cm in confirm_models:
            row = dict(
                serialNo=cm.serialNo,
                analysisModelId=cm.analysisModelId,
                question=cm.question,
                selfDefineTags=cm.selfDefineTags,
                prompt=cm.prompt,
                standardAnswerSql=cm.standardAnswerSql,
                llmOutput=cm.llmOutput,
                executeResult=cm.executeResult,
                errorMsg=cm.errorMsg,
                compareResult=cm.compareResult.value if cm.compareResult else None,
            )
            records.append(row)
        self._benchmark_dao.write_compare_results(
            round_id=round_id,
            mode=mode,
            output_path=path,
            records=records,
            is_execute=is_execute,
            llm_count=llm_count,
        )
        print(f"[write_data_compare_result] compare written to DB for: {path}")

    def summary_and_write_multi_round_benchmark_result(
        self, output_path: str, round_id: int
    ) -> str:
        if not output_path.endswith(".jsonl"):
            raise ValueError(
                f"output_file_path must end with .jsonl, got {output_path}"
            )
        compare_path = output_path.replace(".jsonl", f".round{round_id}.compare.jsonl")
        right, wrong, failed, exception = 0, 0, 0, 0
        if os.path.exists(compare_path):
            with open(compare_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    obj = json.loads(line)
                    cr = obj.get("compareResult")
                    if cr == DataCompareResultEnum.RIGHT.value:
                        right += 1
                    elif cr == DataCompareResultEnum.WRONG.value:
                        wrong += 1
                    elif cr == DataCompareResultEnum.FAILED.value:
                        failed += 1
                    elif cr == DataCompareResultEnum.EXCEPTION.value:
                        exception += 1
        else:
            print(f"[summary] compare file not found: {compare_path}")
        summary_path = output_path.replace(".jsonl", f".round{round_id}.summary.json")
        result = dict(right=right, wrong=wrong, failed=failed, exception=exception)
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"[summary] summary written to: {summary_path} -> {result}")
        summary_id = self._benchmark_dao.compute_and_save_summary(round_id, output_path)
        summary = self._benchmark_dao.get_summary(round_id, output_path)
        result = dict(
            right=summary.right if summary else 0,
            wrong=summary.wrong if summary else 0,
            failed=summary.failed if summary else 0,
            exception=summary.exception if summary else 0,
        )
        print(f"[summary] summary saved to DB for round={round_id}, output_path={output_path} -> {result}")
        return json.dumps(result, ensure_ascii=False)

    def parse_standard_benchmark_sets(

@@ -1,9 +1,9 @@
# app/services/user_input_execute_service.py
from typing import List

from data_compare_service import DataCompareService
from file_parse_service import FileParseService
from models import (
from .data_compare_service import DataCompareService
from .file_parse_service import FileParseService
from .models import (
    AnswerExecuteModel,
    BaseInputModel,
    BenchmarkExecuteConfig,