Mirror of https://github.com/csunny/DB-GPT.git (synced 2026-01-13 19:55:44 +00:00)
chore: resolve conflict
@@ -12,6 +12,10 @@ from dbgpt_serve.agent.app.recommend_question.recommend_question import (
from dbgpt_serve.agent.hub.db.my_plugin_db import MyPluginEntity
from dbgpt_serve.agent.hub.db.plugin_hub_db import PluginHubEntity
from dbgpt_serve.datasource.manages.connect_config_db import ConnectConfigEntity
from dbgpt_serve.evaluate.db.benchmark_db import (
    BenchmarkCompareEntity,
    BenchmarkSummaryEntity,
)
from dbgpt_serve.file.models.models import ServeEntity as FileServeEntity
from dbgpt_serve.flow.models.models import ServeEntity as FlowServeEntity
from dbgpt_serve.flow.models.models import VariablesEntity as FlowVariableEntity
@@ -19,7 +23,6 @@ from dbgpt_serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt_serve.rag.models.chunk_db import DocumentChunkEntity
from dbgpt_serve.rag.models.document_db import KnowledgeDocumentEntity
from dbgpt_serve.rag.models.models import KnowledgeSpaceEntity
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkCompareEntity, BenchmarkSummaryEntity

_MODELS = [
    PluginHubEntity,

@@ -1,3 +1,4 @@
import json
import logging
from functools import cache
from typing import List, Optional
@@ -8,16 +9,29 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import ComponentType, SystemApp
from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from dbgpt_serve.core import Result
from dbgpt_serve.evaluate.api.schemas import BenchmarkServeRequest, EvaluateServeRequest
from dbgpt_serve.evaluate.api.schemas import (
    BenchmarkServeRequest,
    BuildDemoRequest,
    EvaluateServeRequest,
    ExecuteDemoRequest,
)
from dbgpt_serve.evaluate.config import SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from dbgpt_serve.evaluate.service.service import Service
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
import json
from dbgpt_serve.evaluate.service.benchmark.data_compare_service import (
    DataCompareService,
)
from dbgpt_serve.evaluate.service.benchmark.file_parse_service import FileParseService
from dbgpt_serve.evaluate.service.benchmark.data_compare_service import DataCompareService
from dbgpt_serve.evaluate.service.benchmark.user_input_execute_service import UserInputExecuteService
from dbgpt_serve.evaluate.service.benchmark.models import BenchmarkExecuteConfig, BenchmarkModeTypeEnum
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import get_benchmark_manager
from dbgpt_serve.evaluate.service.benchmark.models import (
    BenchmarkExecuteConfig,
    BenchmarkModeTypeEnum,
)
from dbgpt_serve.evaluate.service.benchmark.user_input_execute_service import (
    UserInputExecuteService,
)
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    get_benchmark_manager,
)
from dbgpt_serve.evaluate.service.service import Service

from ...prompt.service.service import Service as PromptService
from ..service.benchmark.benchmark_service import (
@@ -139,6 +153,7 @@ async def get_scenes():

    return Result.succ(scene_list)


@router.post("/evaluation")
async def evaluation(
    request: EvaluateServeRequest,
@@ -162,22 +177,25 @@ async def evaluation(
        )
    )


@router.get("/benchmark/list_results", dependencies=[Depends(check_api_key)])
async def list_compare_runs(limit: int = 50, offset: int = 0):
    dao = BenchmarkResultDao()
    rows = dao.list_summaries(limit=limit, offset=offset)
    result = []
    for s in rows:
        result.append({
            "id": s.id,
            "roundId": s.round_id,
            "outputPath": s.output_path,
            "right": s.right,
            "wrong": s.wrong,
            "failed": s.failed,
            "exception": s.exception,
            "gmtCreated": s.gmt_created.isoformat() if s.gmt_created else None,
        })
        result.append(
            {
                "id": s.id,
                "roundId": s.round_id,
                "outputPath": s.output_path,
                "right": s.right,
                "wrong": s.wrong,
                "failed": s.failed,
                "exception": s.exception,
                "gmtCreated": s.gmt_created.isoformat() if s.gmt_created else None,
            }
        )
    return Result.succ(result)

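A minimal client-side sketch for the new listing route, assuming the evaluate router is mounted under a prefix such as /api/v2/serve/evaluate (the mount point is not shown in this diff), that check_api_key expects a bearer token, and that Result.succ serializes its payload under a "data" field:

# Usage sketch (not part of the diff): list benchmark compare runs.
import requests

EVALUATE_PREFIX = "http://localhost:5670/api/v2/serve/evaluate"  # assumed mount point
API_KEY = "your-api-key"  # assumed bearer token accepted by check_api_key

resp = requests.get(
    f"{EVALUATE_PREFIX}/benchmark/list_results",
    params={"limit": 20, "offset": 0},
    headers={"Authorization": f"Bearer {API_KEY}"},
)
resp.raise_for_status()
for run in resp.json().get("data", []):
    # Each item mirrors the dict built in list_compare_runs above.
    print(run["id"], run["roundId"], run["right"], run["wrong"], run["failed"])
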
@@ -187,7 +205,9 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
    s = dao.get_summary_by_id(summary_id)
    if not s:
        raise HTTPException(status_code=404, detail="compare run not found")
    compares = dao.list_compare_by_round_and_path(s.round_id, s.output_path, limit=limit, offset=offset)
    compares = dao.list_compare_by_round_and_path(
        s.round_id, s.output_path, limit=limit, offset=offset
    )
    detail = {
        "id": s.id,
        "roundId": s.round_id,
@@ -207,7 +227,9 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
            "prompt": r.prompt,
            "standardAnswerSql": r.standard_answer_sql,
            "llmOutput": r.llm_output,
            "executeResult": json.loads(r.execute_result) if r.execute_result else None,
            "executeResult": json.loads(r.execute_result)
            if r.execute_result
            else None,
            "errorMsg": r.error_msg,
            "compareResult": r.compare_result,
            "isExecute": r.is_execute,
@@ -219,6 +241,7 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
    }
    return Result.succ(detail)

@router.post("/benchmark/run_build", dependencies=[Depends(check_api_key)])
|
||||
async def benchmark_run_build(req: BuildDemoRequest):
|
||||
fps = FileParseService()
|
||||
@@ -244,6 +267,7 @@ async def benchmark_run_build(req: BuildDemoRequest):
|
||||
right_outputs=right,
|
||||
input_file_path=req.input_file_path,
|
||||
output_file_path=req.right_output_file_path,
|
||||
)
|
||||
|
||||
dao = BenchmarkResultDao()
|
||||
summary_id = dao.compute_and_save_summary(req.round_id, req.right_output_file_path)
|
||||
@@ -290,7 +314,11 @@ async def list_benchmark_datasets():
|
||||
manager = get_benchmark_manager(global_system_app)
|
||||
info = await manager.get_table_info()
|
||||
result = [
|
||||
{"name": name, "rowCount": meta.get("row_count", 0), "columns": meta.get("columns", [])}
|
||||
{
|
||||
"name": name,
|
||||
"rowCount": meta.get("row_count", 0),
|
||||
"columns": meta.get("columns", []),
|
||||
}
|
||||
for name, meta in info.items()
|
||||
]
|
||||
return Result.succ(result)
|
||||
|
||||
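A hedged sketch for triggering a BUILD compare run through the new route; the prefix, API key, and file paths are placeholders, and the field names follow BuildDemoRequest as declared in the schemas below:

# Usage sketch (not part of the diff): kick off a BUILD compare run.
import requests

EVALUATE_PREFIX = "http://localhost:5670/api/v2/serve/evaluate"  # assumed mount point
API_KEY = "your-api-key"  # assumed bearer token accepted by check_api_key

payload = {
    "round_id": 1,
    "input_file_path": "/data/benchmark/input.jsonl",
    "left_output_file_path": "/data/benchmark/left_outputs.jsonl",
    "right_output_file_path": "/data/benchmark/right_outputs.jsonl",
    "compare_config": {"check": "FULL_TEXT"},  # example taken from the field description
}
resp = requests.post(
    f"{EVALUATE_PREFIX}/benchmark/run_build",
    json=payload,
    headers={"Authorization": f"Bearer {API_KEY}"},
)
print(resp.json())
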
@@ -42,17 +42,27 @@ class EvaluateServeResponse(EvaluateServeRequest):
class BuildDemoRequest(BaseModel):
    round_id: int = Field(..., description="benchmark round id")
    input_file_path: str = Field(..., description="path to input jsonl")
    left_output_file_path: str = Field(..., description="path to left llm outputs jsonl")
    right_output_file_path: str = Field(..., description="path to right llm outputs jsonl")
    compare_config: Optional[dict] = Field(None, description="compare config, e.g. {'check': 'FULL_TEXT'}")
    left_output_file_path: str = Field(
        ..., description="path to left llm outputs jsonl"
    )
    right_output_file_path: str = Field(
        ..., description="path to right llm outputs jsonl"
    )
    compare_config: Optional[dict] = Field(
        None, description="compare config, e.g. {'check': 'FULL_TEXT'}"
    )


class ExecuteDemoRequest(BaseModel):
    round_id: int = Field(..., description="benchmark round id")
    input_file_path: str = Field(..., description="path to input jsonl")
    right_output_file_path: str = Field(..., description="path to right llm outputs jsonl")
    right_output_file_path: str = Field(
        ..., description="path to right llm outputs jsonl"
    )
    standard_file_path: str = Field(..., description="path to standard answers excel")
    compare_config: Optional[dict] = Field(None, description="optional compare config for json compare")
    compare_config: Optional[dict] = Field(
        None, description="optional compare config for json compare"
    )


class BenchmarkServeRequest(BaseModel):

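For reference, a sketch that constructs the two new request models directly, for example in a unit test; the paths are placeholders and the compare_config values are taken from the field descriptions:

# Usage sketch (not part of the diff): build the new Pydantic requests in-process.
from dbgpt_serve.evaluate.api.schemas import BuildDemoRequest, ExecuteDemoRequest

build_req = BuildDemoRequest(
    round_id=1,
    input_file_path="/data/benchmark/input.jsonl",
    left_output_file_path="/data/benchmark/left_outputs.jsonl",
    right_output_file_path="/data/benchmark/right_outputs.jsonl",
    compare_config={"check": "FULL_TEXT"},
)
execute_req = ExecuteDemoRequest(
    round_id=1,
    input_file_path="/data/benchmark/input.jsonl",
    right_output_file_path="/data/benchmark/right_outputs.jsonl",
    standard_file_path="/data/benchmark/standard_answers.xlsx",
    compare_config=None,  # optional json-compare config
)
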
@@ -1,7 +1,7 @@
from .benchmark_db import (
    BenchmarkCompareEntity,
    BenchmarkSummaryEntity,
    BenchmarkResultDao,
    BenchmarkSummaryEntity,
)

__all__ = [

@@ -33,7 +33,9 @@ class BenchmarkCompareEntity(Model):
        ),
    )

    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
    id = Column(
        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
    )
    # Round and mode
    round_id = Column(Integer, nullable=False, comment="Benchmark round id")
    mode = Column(String(16), nullable=False, comment="BUILD or EXECUTE")
@@ -47,18 +49,29 @@ class BenchmarkCompareEntity(Model):

    standard_answer_sql = Column(Text, nullable=True, comment="Standard answer SQL")
    llm_output = Column(Text, nullable=True, comment="LLM output text or JSON")
    execute_result = Column(Text, nullable=True, comment="Execution result JSON (serialized)")
    execute_result = Column(
        Text, nullable=True, comment="Execution result JSON (serialized)"
    )
    error_msg = Column(Text, nullable=True, comment="Error message")

    compare_result = Column(String(16), nullable=True, comment="RIGHT/WRONG/FAILED/EXCEPTION")
    compare_result = Column(
        String(16), nullable=True, comment="RIGHT/WRONG/FAILED/EXCEPTION"
    )
    is_execute = Column(Boolean, default=False, comment="Whether this is EXECUTE mode")
    llm_count = Column(Integer, default=0, comment="Number of LLM outputs compared")

    # Source path for traceability (original output jsonl file path)
    output_path = Column(String(512), nullable=False, comment="Original output file path")
    output_path = Column(
        String(512), nullable=False, comment="Original output file path"
    )

    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    gmt_modified = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="Record update time")
    gmt_modified = Column(
        DateTime,
        default=datetime.now,
        onupdate=datetime.now,
        comment="Record update time",
    )

    Index("idx_bm_comp_round", "round_id")
    Index("idx_bm_comp_mode", "mode")
@@ -76,9 +89,13 @@ class BenchmarkSummaryEntity(Model):
        UniqueConstraint("round_id", "output_path", name="uk_round_output"),
    )

    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
    id = Column(
        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
    )
    round_id = Column(Integer, nullable=False, comment="Benchmark round id")
    output_path = Column(String(512), nullable=False, comment="Original output file path")
    output_path = Column(
        String(512), nullable=False, comment="Original output file path"
    )

    right = Column(Integer, default=0, comment="RIGHT count")
    wrong = Column(Integer, default=0, comment="WRONG count")
@@ -86,7 +103,12 @@ class BenchmarkSummaryEntity(Model):
    exception = Column(Integer, default=0, comment="EXCEPTION count")

    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    gmt_modified = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="Record update time")
    gmt_modified = Column(
        DateTime,
        default=datetime.now,
        onupdate=datetime.now,
        comment="Record update time",
    )

    Index("idx_bm_sum_round", "round_id")

@@ -105,7 +127,8 @@ class BenchmarkResultDao(BaseDao):
    ) -> int:
        """Write multiple compare records to DB.

        records: each dict contains keys like in FileParseService.write_data_compare_result rows.
        records: each dict contains keys like in
            FileParseService.write_data_compare_result rows.
        Returns number of records inserted.
        """
        inserted = 0
@@ -122,7 +145,9 @@ class BenchmarkResultDao(BaseDao):
                    prompt=r.get("prompt"),
                    standard_answer_sql=r.get("standardAnswerSql"),
                    llm_output=r.get("llmOutput"),
                    execute_result=json.dumps(r.get("executeResult")) if r.get("executeResult") is not None else None,
                    execute_result=json.dumps(r.get("executeResult"))
                    if r.get("executeResult") is not None
                    else None,
                    error_msg=r.get("errorMsg"),
                    compare_result=r.get("compareResult"),
                    is_execute=is_execute,
@@ -136,7 +161,9 @@ class BenchmarkResultDao(BaseDao):
            session.commit()
        return inserted

    def compute_and_save_summary(self, round_id: int, output_path: str) -> Optional[int]:
    def compute_and_save_summary(
        self, round_id: int, output_path: str
    ) -> Optional[int]:
        """Compute summary from compare table and save to summary table.
        Returns summary id if saved, else None.
        """
@@ -197,7 +224,9 @@ class BenchmarkResultDao(BaseDao):
                .all()
            )

    def get_summary(self, round_id: int, output_path: str) -> Optional[BenchmarkSummaryEntity]:
    def get_summary(
        self, round_id: int, output_path: str
    ) -> Optional[BenchmarkSummaryEntity]:
        with self.session(commit=False) as session:
            return (
                session.query(BenchmarkSummaryEntity)
@@ -242,4 +271,3 @@ class BenchmarkResultDao(BaseDao):
                .offset(offset)
                .all()
            )

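A minimal sketch of the DAO flow the endpoints above rely on, using only methods visible in this diff; the round id and output path are placeholders:

# Usage sketch (not part of the diff): summarize one output file and read it back.
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao

dao = BenchmarkResultDao()
output_path = "/data/benchmark/right_outputs.jsonl"  # placeholder

# Aggregate RIGHT/WRONG/FAILED/EXCEPTION counts for round 1 of this output file.
summary_id = dao.compute_and_save_summary(1, output_path)
if summary_id is not None:
    summary = dao.get_summary(1, output_path)
    print(summary.right, summary.wrong, summary.failed, summary.exception)

# Page through the per-case compare rows, as get_compare_run_detail does.
for row in dao.list_compare_by_round_and_path(1, output_path, limit=20, offset=0):
    print(row.compare_result, row.error_msg)
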
@@ -1,27 +1,24 @@
import io
import json
import logging
import os
from typing import List

import pandas as pd

from openpyxl.reader.excel import load_workbook

from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao

from .models import (
    AnswerExecuteModel,
    BaseInputModel,
    BenchmarkDataSets,
    DataCompareResultEnum,
    DataCompareStrategyConfig,
    RoundAnswerConfirmModel,
)

logger = logging.getLogger(__name__)

from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao


class FileParseService:
    def __init__(self):
@@ -101,7 +98,10 @@ class FileParseService:
            failed=summary.failed if summary else 0,
            exception=summary.exception if summary else 0,
        )
        print(f"[summary] summary saved to DB for round={round_id}, output_path={output_path} -> {result}")
        logger.info(
            f"[summary] summary saved to DB for round={round_id},"
            f" output_path={output_path} -> {result}"
        )
        return json.dumps(result, ensure_ascii=False)

    def parse_standard_benchmark_sets(
@@ -153,7 +153,6 @@ class FileParseService:
        return outputs



class ExcelFileParseService(FileParseService):
    def parse_input_sets(self, location: str) -> BenchmarkDataSets:
        """

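A short sketch of the parser entry points visible in this hunk; the workbook path is a placeholder and no constructor arguments beyond the no-arg __init__ shown above are assumed:

# Usage sketch (not part of the diff): parse a standard benchmark workbook.
from dbgpt_serve.evaluate.service.benchmark.file_parse_service import (
    ExcelFileParseService,
)

parser = ExcelFileParseService()
datasets = parser.parse_input_sets("/data/benchmark/standard_answers.xlsx")  # placeholder
print(type(datasets))  # BenchmarkDataSets, per the signature above
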
@@ -13,6 +13,7 @@ class BenchmarkModeTypeEnum(str, Enum):
    BUILD = "BUILD"
    EXECUTE = "EXECUTE"


@dataclass
class DataCompareStrategyConfig:
    strategy: str  # "EXACT_MATCH" | "CONTAIN_MATCH"
@@ -130,7 +131,6 @@ class RoundAnswerConfirmModel:
    compareResult: Optional[DataCompareResultEnum] = None



class FileParseTypeEnum(Enum):
    """File parse type enum."""

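A small sketch of the model types touched here; DataCompareStrategyConfig may define more fields than the strategy attribute shown in this hunk, so only that attribute is set:

# Usage sketch (not part of the diff): the mode enum and compare-strategy config.
from dbgpt_serve.evaluate.service.benchmark.models import (
    BenchmarkModeTypeEnum,
    DataCompareStrategyConfig,
)

mode = BenchmarkModeTypeEnum.BUILD  # or BenchmarkModeTypeEnum.EXECUTE
config = DataCompareStrategyConfig(strategy="EXACT_MATCH")  # or "CONTAIN_MATCH"
print(mode.value, config.strategy)
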
@@ -2,8 +2,6 @@
import logging
from typing import Dict, List

from .data_compare_service import DataCompareService
from .file_parse_service import FileParseService
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    get_benchmark_manager,
)