Mirror of https://github.com/csunny/DB-GPT.git
Synced 2026-01-13 19:55:44 +00:00
chore: make code format
@@ -95,12 +95,18 @@ class BenchmarkServeRequest(BaseModel):
    log_info: Optional[str] = Field(None, description="evaluation task error message")
    gmt_create: Optional[str] = Field(None, description="create time")
    gmt_modified: Optional[str] = Field(None, description="modified time")
    benchmark_type: Optional[str] = Field(None, description="execute benchmark type, llm or agent")
    benchmark_type: Optional[str] = Field(
        None, description="execute benchmark type, llm or agent"
    )
    api_url: Optional[str] = Field(None, description="api url")
    http_method: Optional[str] = Field(None, description="http method")
    headers: Optional[dict] = Field(None, description="http headers")
    parse_strategy: Optional[str] = Field(None, description="agent response parse strategy")
    response_mapping: Optional[dict] = Field(None, description="agent response extract result mapping")
    parse_strategy: Optional[str] = Field(
        None, description="agent response parse strategy"
    )
    response_mapping: Optional[dict] = Field(
        None, description="agent response extract result mapping"
    )


class BenchmarkServeResponse(BenchmarkServeRequest):
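For illustration, a minimal sketch of how an agent-type benchmark request might be populated with these fields; the import path and all values are assumptions, not taken from this commit:

# Hypothetical usage sketch (import path and example values are assumptions).
from dbgpt_serve.evaluate.api.schemas import BenchmarkServeRequest

request = BenchmarkServeRequest(
    benchmark_type="agent",
    api_url="https://example.com/api/v1/chat/completions",
    http_method="POST",
    headers={"Authorization": "Bearer <token>"},
    parse_strategy="JSON_PATH",
    response_mapping={"sql": "$.data.sql", "tokens": "$.data.usage.total_tokens"},
)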
@@ -18,8 +18,12 @@ from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.storage.metadata import BaseDao
from dbgpt.util import PaginationResult, get_or_create_event_loop
from .ext.excel_file_parse import ExcelFileParseService
from ..fetchdata.benchmark_data_manager import get_benchmark_manager
from dbgpt_serve.evaluate.service.benchmark.task.benchmark_agent_task import (
    BenchmarkAgentTask,
)
from dbgpt_serve.evaluate.service.benchmark.task.benchmark_llm_task import (
    BenchmarkLLMTask,
)

from ....core import BaseService
from ...prompt.service.service import Service as PromptService
@@ -35,21 +39,19 @@ from ...api.schemas import (
from ...config import ServeConfig
from ...models.models import ServeDao, ServeEntity
from ..fetchdata.benchmark_data_manager import get_benchmark_manager
from dbgpt_serve.evaluate.service.benchmark.task.benchmark_llm_task import BenchmarkLLMTask
from dbgpt_serve.evaluate.service.benchmark.task.benchmark_agent_task import (
    BenchmarkAgentTask,
)
from .data_compare_service import DataCompareService
from .ext.excel_file_parse import ExcelFileParseService
from .models import (
    BaseInputModel,
    BenchmarkDataSets,
    BenchmarkExecuteConfig,
    BenchmarkInvokeType,
    BenchmarkModeTypeEnum,
    BenchmarkTaskResult,
    ContentTypeEnum,
    FileParseTypeEnum,
    InputType,
    OutputType,
    BenchmarkInvokeType, FileParseTypeEnum
)
from .user_input_execute_service import UserInputExecuteService

@@ -259,7 +261,7 @@ class BenchmarkService(
            )
        except Exception as e:
            logger.error(f"Failed to load benchmark dataset before run: {e}")
            raise
            raise e

        output_file_path = self._generate_output_file_full_path(
            output_file_path, evaluate_code
@@ -277,7 +279,7 @@ class BenchmarkService(
            http_method,
            headers,
            parse_strategy,
            response_mapping
            response_mapping,
        )
        logger.info(f"run benchmark with benchmarkConfig={config}")

@@ -364,9 +366,7 @@ class BenchmarkService(
        try:
            return HttpMethod(http_method.upper())
        except ValueError:
            logger.warning(
                f"Invalid HTTP method: {http_method}, using default POST"
            )
            logger.warning(f"Invalid HTTP method: {http_method}, using default POST")
            return HttpMethod.POST

    def _parse_response_strategy(self, parse_strategy: Optional[str]):
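The collapsed warning above sits in a helper that falls back to POST for unrecognized verbs; a minimal standalone sketch of that pattern follows (the reduced HttpMethod enum here is illustrative, not the project's full enum):

from enum import Enum

class HttpMethod(str, Enum):  # reduced sketch of the real enum
    GET = "GET"
    POST = "POST"

def parse_http_method(http_method: str) -> HttpMethod:
    try:
        return HttpMethod(http_method.upper())
    except ValueError:
        # Unknown verbs fall back to POST, mirroring the logic above
        return HttpMethod.POST

print(parse_http_method("get"))    # HttpMethod.GET
print(parse_http_method("fetch"))  # HttpMethod.POST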
@@ -418,7 +418,7 @@ class BenchmarkService(
        http_method,
        headers,
        parse_strategy,
        response_mapping
        response_mapping,
    ) -> BenchmarkExecuteConfig:
        config = BenchmarkExecuteConfig(
            benchmark_mode_type=BenchmarkModeTypeEnum.EXECUTE,
@@ -459,9 +459,7 @@ class BenchmarkService(
            }
            return prompt.format(**format_params)
        except Exception as e:
            logger.warning(
                f"Failed to format prompt template: {e}. "
            )
            logger.warning(f"Failed to format prompt template: {e}. ")
            return prompt

    def _get_database_dialect(self) -> str | None:
@@ -157,6 +157,7 @@ class ContentTypeEnum(Enum):
    SQL = "SQL"
    JSON = "JSON"


class BenchmarkInvokeType(str, Enum):
    LLM = "LLM"
    AGENT = "AGENT"
@@ -178,6 +179,7 @@ class ResponseParseStrategy(str, Enum):
    JSON_PATH = "JSON_PATH"  # Use JSON path to extract content
    DIRECT = "DIRECT"  # Directly use response as content


@dataclass
class AgentApiConfig:
    """Agent API configuration.
@@ -254,6 +256,7 @@ class AgentApiConfig:
            extra_config=config_dict.get("extra_config", {}),
        )


@dataclass
class BenchmarkExecuteConfig:
    """
@@ -439,12 +442,13 @@ class ReasoningResponse:
        self.think = think
        self.content = content


@dataclass
class AgentCompletionRequest:
    """benchmark Agent request entity."""

    model: Optional[str] = None
    messages: Optional[List[dict]] = None,
    messages: Optional[List[dict]] = (None,)
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
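Worth noting: the stray trailing comma on the messages default already made it a one-element tuple, and the formatter only makes that explicit as (None,). A quick standalone check of the Python semantics:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Example:
    # Same shape as above: the trailing comma yields a tuple default, not None
    messages: Optional[List[dict]] = (None,)

print(Example().messages)  # (None,)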
@@ -8,43 +8,48 @@ from typing import Any, Dict, Optional, Union

import aiohttp

from dbgpt_serve.evaluate.service.benchmark.models import ReasoningResponse, AgentCompletionRequest, AgentApiConfig, \
    ResponseParseStrategy, HttpMethod
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import get_benchmark_manager
from dbgpt_serve.evaluate.service.benchmark.models import (
    AgentApiConfig,
    AgentCompletionRequest,
    HttpMethod,
    ReasoningResponse,
    ResponseParseStrategy,
)
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    get_benchmark_manager,
)

logger = logging.getLogger(__name__)


class ResponseParser:
    """Response parser for extracting content from API responses."""

    @staticmethod
    def parse_json_path(response_data: Any, json_path: str) -> Any:
        """Parse response using JSON path expression.

        Args:
            response_data: The response data (dict or list)
            json_path: JSON path expression (e.g., "$.data.content")

        Returns:
            Extracted value or None if path not found
        """
        if not json_path:
            return response_data

        # Remove leading $. if present
        path = json_path.lstrip("$.")

        # Split path by dots and brackets
        parts = path.replace("[", ".").replace("]", "").split(".")

        current = response_data
        for part in parts:
            if not part:
                continue

            try:
                if isinstance(current, dict):
                    current = current.get(part)
@@ -53,21 +58,21 @@ class ResponseParser:
                    current = current[index]
                else:
                    return None

                if current is None:
                    return None
            except (KeyError, IndexError, ValueError, TypeError):
                return None

        return current

    @staticmethod
    def parse_direct(response_data: Any) -> str:
        """Parse response directly as string.

        Args:
            response_data: The response data

        Returns:
            String representation of the response
        """
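As a standalone illustration of the simplified path handling reconstructed above (dotted keys plus integer indexes, not full JSONPath), with a made-up payload and path:

response = {"data": {"choices": [{"message": {"content": "SELECT 1"}}]}}
path = "$.data.choices[0].message.content"

# Same normalization as parse_json_path: strip "$.", turn brackets into dots
parts = path.lstrip("$.").replace("[", ".").replace("]", "").split(".")

current = response
for part in parts:
    if not part:
        continue
    if isinstance(current, dict):
        current = current.get(part)
    elif isinstance(current, list):
        current = current[int(part)]

print(current)  # SELECT 1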
@@ -85,7 +90,7 @@ class ResponseParser:

class BenchmarkAgentTask:
    """Benchmark Agent Task for evaluating remote agent APIs.

    This class provides functionality to:
    1. Call remote agent APIs with configurable parameters
    2. Parse API responses according to configuration
@@ -99,7 +104,7 @@ class BenchmarkAgentTask:
        agent_name: Optional[str] = None,
    ):
        """Initialize the BenchmarkAgentTask.

        Args:
            api_config: Agent API configuration
            agent_name: Optional name for the agent (for logging)
@@ -121,12 +126,10 @@ class BenchmarkAgentTask:
            raise ValueError("API URL is required")

    async def invoke_agent(
        self,
        prompt: Optional[str] = None,
        **kwargs
        self, prompt: Optional[str] = None, **kwargs
    ) -> Union[ReasoningResponse, None]:
        """Invoke the remote agent API.

        Args:
            prompt: The prompt to send to the agent
            **kwargs: Additional parameters for request body mapping
@@ -136,37 +139,38 @@ class BenchmarkAgentTask:
        return await self._invoke_task(prompt, **kwargs)

    async def _invoke_task(
        self,
        prompt: Optional[str],
        **kwargs
        self, prompt: Optional[str], **kwargs
    ) -> Union[ReasoningResponse, None]:
        """Internal method to invoke the agent task.

        Args:
            prompt: The prompt to send
            **kwargs: Additional parameters

        Returns:
            ReasoningResponse or None
        """
        start_time = time.time()

        # Build request body
        request_body_obj = self._build_request_body(prompt, **kwargs)
        request_body = {k: v for k, v in asdict(request_body_obj).items() if v is not None}

        request_body = {
            k: v for k, v in asdict(request_body_obj).items() if v is not None
        }

        # Execute request with retries
        for attempt in range(self._api_config.max_retries):
            try:
                response_data = await self._execute_request(request_body)

                # Parse response
                reasoning_response = self._parse_response(response_data)

                if reasoning_response:
                    logger.info(
                        f"[{self._agent_name}] Successfully invoked agent API, "
                        f"reasoning_response={reasoning_response}, duration={(time.time() - start_time):.2f}s"
                        f"reasoning_response={reasoning_response},"
                        f" duration={(time.time() - start_time):.2f}s"
                    )
                    return reasoning_response
                else:
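The request_body comprehension above simply drops unset fields before the body is serialized; a small self-contained sketch of that step (Req is a stand-in for AgentCompletionRequest):

from dataclasses import asdict, dataclass
from typing import List, Optional

@dataclass
class Req:  # stand-in for AgentCompletionRequest
    model: Optional[str] = None
    messages: Optional[List[dict]] = None
    temperature: Optional[float] = None

body_obj = Req(messages=[{"role": "user", "content": "hi"}], temperature=0.1)
request_body = {k: v for k, v in asdict(body_obj).items() if v is not None}
print(request_body)  # {'messages': [{'role': 'user', 'content': 'hi'}], 'temperature': 0.1}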
@@ -174,44 +178,37 @@ class BenchmarkAgentTask:
                        f"[{self._agent_name}] Failed to parse response, "
                        f"attempt={attempt + 1}"
                    )

            except Exception as e:
                logger.error(
                    f"[{self._agent_name}] Request failed on attempt {attempt + 1}: {e}"
                )

            if attempt < self._api_config.max_retries - 1:
                # Wait before retry
                await self._async_sleep(self._api_config.retry_delay)
            else:
                logger.error(
                    f"[{self._agent_name}] All retry attempts exhausted"
                )
                logger.error(f"[{self._agent_name}] All retry attempts exhausted")
                return None

        return None

    def _build_request_body(
        self,
        prompt: Optional[str],
        **kwargs
        self, prompt: Optional[str], **kwargs
    ) -> AgentCompletionRequest:
        """Build request body from template and parameters.

        Args:
            prompt: The prompt text
            **kwargs: Additional parameters including model, temperature, top_p,
            **kwargs: Additional parameters including model, temperature, top_p,
                top_k, max_tokens, stream, user, question

        Returns:
            AgentCompletionRequest object
        """
        messages = []
        if prompt:
            messages.append({
                "role": "user",
                "content": prompt
            })
            messages.append({"role": "user", "content": prompt})

        return AgentCompletionRequest(
            messages=messages,
@@ -220,68 +217,65 @@ class BenchmarkAgentTask:
            top_k=kwargs.get("top_k"),
            max_tokens=kwargs.get("max_tokens"),
            stream=kwargs.get("stream"),
            user=kwargs.get("user")
            user=kwargs.get("user"),
        )

    async def _execute_request(self, request_body: Dict[str, Any]) -> Any:
        """Execute HTTP request to the agent API.

        Args:
            request_body: The request body

        Returns:
            Response data (parsed JSON or text)

        Raises:
            Exception: If request fails
        """
        connector = None
        if not self._api_config.verify_ssl:
            connector = aiohttp.TCPConnector(ssl=False)

        timeout = aiohttp.ClientTimeout(total=self._api_config.timeout)

        async with aiohttp.ClientSession(
            connector=connector,
            timeout=timeout
            connector=connector, timeout=timeout
        ) as session:
            request_kwargs = {
                "url": self._api_config.api_url,
                "headers": self._api_config.headers,
                "params": self._api_config.query_params,
            }

            # Add body for methods that support it
            if self._api_config.http_method in [HttpMethod.POST, HttpMethod.PUT, HttpMethod.PATCH]:
            if self._api_config.http_method in [
                HttpMethod.POST,
                HttpMethod.PUT,
                HttpMethod.PATCH,
            ]:
                request_kwargs["json"] = request_body

            logger.debug(
                f"[{self._agent_name}] Sending {self._api_config.http_method.value} "
                f"request to {self._api_config.api_url}"
            )

            async with session.request(
                self._api_config.http_method.value,
                **request_kwargs
                self._api_config.http_method.value, **request_kwargs
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(
                        f"HTTP {response.status}: {error_text}"
                    )

                    raise Exception(f"HTTP {response.status}: {error_text}")

                # Try to parse as JSON, fallback to text
                try:
                    return await response.json()
                except Exception:
                    return await response.text()

    def _parse_response(
        self,
        response_data: Any
    ) -> Optional[ReasoningResponse]:
    def _parse_response(self, response_data: Any) -> Optional[ReasoningResponse]:
        """Parse the API response into ReasoningResponse.

        Args:
            response_data: The raw response data
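For reference, a minimal standalone version of the aiohttp request pattern used above; the URL, headers, and payload are placeholders, and the real code takes these values from AgentApiConfig:

import asyncio
import aiohttp

async def call_agent() -> object:
    connector = aiohttp.TCPConnector(ssl=False)  # only when verify_ssl is disabled
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
        async with session.request(
            "POST",
            url="https://example.com/api/agent",  # placeholder endpoint
            headers={"Content-Type": "application/json"},
            json={"messages": [{"role": "user", "content": "hi"}]},
        ) as response:
            if response.status != 200:
                raise Exception(f"HTTP {response.status}: {await response.text()}")
            try:
                return await response.json()
            except Exception:
                return await response.text()

# asyncio.run(call_agent())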
@@ -291,54 +285,46 @@ class BenchmarkAgentTask:
        try:
            if self._api_config.parse_strategy == ResponseParseStrategy.DIRECT:
                content = self._parser.parse_direct(response_data)
                return ReasoningResponse(
                    content=content,
                    cot_tokens=0,
                    think=None
                )

                return ReasoningResponse(content=content, cot_tokens=0, think=None)

            elif self._api_config.parse_strategy == ResponseParseStrategy.JSON_PATH:
                # Extract fields using JSON path
                content = None
                tokens = 0
                think = None

                if "sql" in self._api_config.response_mapping:
                    content = self._parser.parse_json_path(
                        response_data,
                        self._api_config.response_mapping["sql"]
                        response_data, self._api_config.response_mapping["sql"]
                    )

                if "tokens" in self._api_config.response_mapping:
                    tokens_value = self._parser.parse_json_path(
                        response_data,
                        self._api_config.response_mapping["tokens"]
                        response_data, self._api_config.response_mapping["tokens"]
                    )
                    if tokens_value is not None:
                        try:
                            tokens = int(tokens_value)
                        except (ValueError, TypeError):
                            tokens = 0

                if "think" in self._api_config.response_mapping:
                    think = self._parser.parse_json_path(
                        response_data,
                        self._api_config.response_mapping["think"]
                        response_data, self._api_config.response_mapping["think"]
                    )

                # If content is None, try to extract from response directly
                if content is None:
                    content = self._parser.parse_direct(response_data)

                return ReasoningResponse(
                    content=str(content) if content is not None else "",
                    cot_tokens=tokens,
                    think=str(think) if think is not None else None
                    think=str(think) if think is not None else None,
                )
        except Exception as e:
            logger.error(
                f"[{self._agent_name}] Failed to parse response: {e}",
                exc_info=True
                f"[{self._agent_name}] Failed to parse response: {e}", exc_info=True
            )
            return None
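An example of what a JSON_PATH response_mapping could look like for the three keys read above ("sql", "tokens", "think"); the concrete paths are made up:

response_mapping = {
    "sql": "$.data.sql",                    # where the generated SQL lives
    "tokens": "$.data.usage.total_tokens",  # value fed into cot_tokens
    "think": "$.data.reasoning",            # optional reasoning text
}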
@@ -346,11 +332,12 @@ class BenchmarkAgentTask:
    async def _async_sleep(seconds: int):
        """Async sleep utility."""
        import asyncio

        await asyncio.sleep(seconds)

    def get_config(self) -> AgentApiConfig:
        """Get the current API configuration.

        Returns:
            AgentApiConfig object
        """
@@ -358,7 +345,7 @@ class BenchmarkAgentTask:

    def update_config(self, **kwargs):
        """Update API configuration.

        Args:
            **kwargs: Configuration fields to update
        """
@@ -366,6 +353,4 @@ class BenchmarkAgentTask:
            if hasattr(self._api_config, key):
                setattr(self._api_config, key, value)
            else:
                logger.warning(
                    f"[{self._agent_name}] Unknown config field: {key}"
                )
                logger.warning(f"[{self._agent_name}] Unknown config field: {key}")