mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-25 13:06:53 +00:00
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
157 lines
6.2 KiB
Python
157 lines
6.2 KiB
Python
"""Data Scientist Agent."""
|
||
|
||
import json
|
||
import logging
|
||
from typing import List, Optional, Tuple
|
||
|
||
from ..core.agent import AgentMessage
|
||
from ..core.base_agent import ConversableAgent
|
||
from ..core.profile import DynConfig, ProfileConfig
|
||
from ..resource.database import DBResource
|
||
from .actions.chart_action import ChartAction
|
||
|
||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class DataScientistAgent(ConversableAgent):
    """Data Scientist Agent.

    An agent whose profile instructs the model to answer the user's target
    with a single analysis SQL statement plus a display type, and whose
    ``correctness_check`` re-runs the generated SQL against the bound
    database resource to confirm that it executes and returns rows.
    """

    # Default profile. The ``{{dialect}}`` and ``{{ display_type }}``
    # placeholders correspond to the context keys set in
    # ``_init_reply_message``.
    profile: ProfileConfig = ProfileConfig(
        name=DynConfig(
            "Edgar",
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_name",
        ),
        role=DynConfig(
            "DataScientist",
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_role",
        ),
        goal=DynConfig(
            "Use correct {{dialect}} SQL to analyze and resolve user "
            "input targets based on the data structure information of the "
            "database given in the resource.",
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_goal",
        ),
        constraints=DynConfig(
            [
                "Please ensure that the output is in the required format. "
                "Please ensure that each analysis only outputs one analysis "
                "result SQL, including as much analysis target content as possible.",
                "If there is a recent message record, pay attention to refer to "
                "the answers and execution results inside when analyzing, "
                "and do not generate the same wrong answer.Please check carefully "
                "to make sure the correct SQL is generated. Please strictly adhere "
                "to the data structure definition given. The use of non-existing "
                "fields is prohibited. Be careful not to confuse fields from "
                "different tables, and you can perform multi-table related queries.",
                "If the data and fields that need to be analyzed in the target are in "
                "different tables, it is recommended to use multi-table correlation "
                "queries first, and pay attention to the correlation between multiple "
                "table structures.",
                # NOTE: this sentence previously read "given by the famous songs
                # in the input", a mistranslation of "explicitly given".
                "It is prohibited to construct data yourself as query conditions. "
                "Only the data values explicitly given in the input can "
                "be used as query conditions.",
                "Please select an appropriate one from the supported display methods "
                "for data display. If no suitable display type is found, "
                "use 'response_table' as default value. Supported display types: \n"
                "{{ display_type }}",
            ],
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_constraints",
        ),
        desc=DynConfig(
            "Use database resources to conduct data analysis, analyze SQL, and provide "
            "recommended rendering methods.",
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_desc",
        ),
    )

    max_retry_count: int = 5
    language: str = "zh"

    def __init__(self, **kwargs):
        """Create a new DataScientistAgent instance."""
        super().__init__(**kwargs)
        # ChartAction is registered as the agent's only action; the rest of
        # the class reads it back through ``self.actions[0]``.
        self._init_actions([ChartAction])

    def _init_reply_message(
        self,
        received_message: AgentMessage,
        rely_messages: Optional[List[AgentMessage]] = None,
    ) -> AgentMessage:
        """Build the reply message and set its prompt context.

        Args:
            received_message: The message this agent is replying to.
            rely_messages: Optional messages the reply depends on; passed
                through unchanged to the parent implementation.

        Returns:
            The reply message created by the parent class, with ``context``
            replaced by a dict holding ``display_type`` (the first action's
            rendered prompt) and ``dialect`` (the database dialect).
        """
        reply_message = super()._init_reply_message(received_message, rely_messages)
        # This assignment replaces whatever context the parent may have set.
        reply_message.context = {
            "display_type": self.actions[0].render_prompt(),
            "dialect": self.database.dialect,
        }
        return reply_message

    @property
    def database(self) -> DBResource:
        """Get the database resource.

        Returns:
            The first ``DBResource`` extracted from ``self.resource``.

        Raises:
            ValueError: If no ``DBResource`` can be extracted from
                ``self.resource``.
        """
        dbs: List[DBResource] = DBResource.from_resource(self.resource)
        if not dbs:
            raise ValueError(
                f"Resource type {self.actions[0].resource_need} is not supported."
            )
        return dbs[0]

    async def correctness_check(
        self, message: AgentMessage
    ) -> Tuple[bool, Optional[str]]:
        """Verify whether the current execution results meet the target expectations.

        The action report attached to ``message`` must exist, be marked as
        successfully executed, contain a JSON object with a non-empty ``sql``
        entry, and name a data resource. The SQL is then run against the
        database and must return at least one row.

        Args:
            message: The agent message whose ``action_report`` is checked.

        Returns:
            ``(True, None)`` when the SQL runs and returns data, otherwise
            ``(False, reason)`` where ``reason`` describes what to fix.
        """
        action_out = message.action_report
        if action_out is None:
            return (
                False,
                f"No executable analysis SQL is generated,{message.content}.",
            )

        if not action_out.is_exe_success:
            return (
                False,
                f"Please check your answer, {action_out.content}.",
            )

        # A malformed action output is reported as a failed check rather than
        # raised: json.loads raises ValueError (JSONDecodeError) for invalid
        # JSON and TypeError when the content is not str/bytes (e.g. None).
        try:
            action_reply_obj = json.loads(action_out.content)
        except (TypeError, ValueError) as e:
            logger.warning("DataScientist action output is not valid JSON: %s", e)
            return (
                False,
                "Please check your answer, the output cannot be parsed as JSON: "
                f"{e}.",
            )

        # Valid JSON that is not an object cannot carry an "sql" entry.
        sql = action_reply_obj.get("sql") if isinstance(action_reply_obj, dict) else None
        if not sql:
            return (
                False,
                "Please check your answer, the sql information that needs to be "
                "generated is not found.",
            )

        if not action_out.resource_value:
            return (
                False,
                "Please check your answer, the data resource information is not "
                "found.",
            )

        # Deliberately broad: any failure to resolve the database or execute
        # the SQL is logged and returned as a reason instead of being raised.
        try:
            columns, values = await self.database.query(
                sql=sql,
                db=action_out.resource_value,
            )
        except Exception as e:
            logger.exception("DataScientist check exception!%s", e)
            return (
                False,
                f"SQL execution error, please re-read the historical information to "
                f"fix this SQL. The error message is as follows:{str(e)}",
            )

        if not values:
            return (
                False,
                "Please check your answer, the current SQL cannot find the data to "
                "determine whether filtered field values or inappropriate filter "
                "conditions are used.",
            )

        logger.info("reply check success! There are %s rows of data", len(values))
        return True, None
|