DB-GPT/dbgpt/agent/expand/data_scientist_agent.py
明天 b124ecc10b
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
2024-08-21 17:37:45 +08:00

157 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Data Scientist Agent."""
import json
import logging
from typing import List, Optional, Tuple
from ..core.agent import AgentMessage
from ..core.base_agent import ConversableAgent
from ..core.profile import DynConfig, ProfileConfig
from ..resource.database import DBResource
from .actions.chart_action import ChartAction
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class DataScientistAgent(ConversableAgent):
    """Data Scientist Agent.

    A conversable agent whose profile instructs the model to write SQL (in the
    dialect of the bound database resource) for the user's analysis target and
    to pick a display type for the result. It registers ``ChartAction`` as its
    only action and validates each reply by re-running the generated SQL
    against the database resource.
    """

    # Default profile. Every value is wrapped in ``DynConfig`` with a category
    # and key; the literals below are the defaults passed to it.
    # NOTE(review): the keys use a ``dashboard_assistant`` prefix although this
    # is the data scientist agent - confirm this sharing is intentional.
    profile: ProfileConfig = ProfileConfig(
        name=DynConfig(
            "Edgar",
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_name",
        ),
        role=DynConfig(
            "DataScientist",
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_role",
        ),
        # ``{{dialect}}`` is a template placeholder; a value for it is put into
        # the reply context by ``_init_reply_message``.
        goal=DynConfig(
            "Use correct {{dialect}} SQL to analyze and resolve user "
            "input targets based on the data structure information of the "
            "database given in the resource.",
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_goal",
        ),
        constraints=DynConfig(
            [
                "Please ensure that the output is in the required format. "
                "Please ensure that each analysis only outputs one analysis "
                "result SQL, including as much analysis target content as possible.",
                "If there is a recent message record, pay attention to refer to "
                "the answers and execution results inside when analyzing, "
                "and do not generate the same wrong answer.Please check carefully "
                "to make sure the correct SQL is generated. Please strictly adhere "
                "to the data structure definition given. The use of non-existing "
                "fields is prohibited. Be careful not to confuse fields from "
                "different tables, and you can perform multi-table related queries.",
                "If the data and fields that need to be analyzed in the target are in "
                "different tables, it is recommended to use multi-table correlation "
                "queries first, and pay attention to the correlation between multiple "
                "table structures.",
                # NOTE(review): "famous songs" looks like a mistranslation
                # (probably "the user" / "explicitly"); left unchanged because
                # it is prompt text - confirm the intended wording.
                "It is prohibited to construct data yourself as query conditions. "
                "Only the data values given by the famous songs in the input can "
                "be used as query conditions.",
                # ``{{ display_type }}`` is a template placeholder; a value for
                # it is put into the reply context by ``_init_reply_message``.
                "Please select an appropriate one from the supported display methods "
                "for data display. If no suitable display type is found, "
                "use 'response_table' as default value. Supported display types: \n"
                "{{ display_type }}",
            ],
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_constraints",
        ),
        desc=DynConfig(
            "Use database resources to conduct data analysis, analyze SQL, and provide "
            "recommended rendering methods.",
            category="agent",
            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_desc",
        ),
    )

    max_retry_count: int = 5
    language: str = "zh"

    def __init__(self, **kwargs):
        """Create a new DataScientistAgent instance.

        All keyword arguments are forwarded to the parent constructor, then
        ``ChartAction`` is registered as the agent's only action.
        """
        super().__init__(**kwargs)
        self._init_actions([ChartAction])

    def _init_reply_message(
        self,
        received_message: AgentMessage,
        rely_messages: Optional[List[AgentMessage]] = None,
    ) -> AgentMessage:
        """Build the reply message and attach the prompt template context.

        Args:
            received_message: The message being replied to.
            rely_messages: Optional messages forwarded unchanged to the parent
                implementation.

        Returns:
            The reply message created by the parent class, with ``context``
            replaced by a dict holding ``display_type`` (the first action's
            rendered prompt) and ``dialect`` (the database resource's dialect).
        """
        reply_message = super()._init_reply_message(received_message, rely_messages)
        # Note: this replaces any context set by the parent implementation.
        reply_message.context = {
            "display_type": self.actions[0].render_prompt(),
            "dialect": self.database.dialect,
        }
        return reply_message

    @property
    def database(self) -> DBResource:
        """Get the database resource.

        Returns:
            The first ``DBResource`` extracted from ``self.resource``.

        Raises:
            ValueError: If no database resource can be extracted.
        """
        dbs: List[DBResource] = DBResource.from_resource(self.resource)
        if not dbs:
            raise ValueError(
                f"Resource type {self.actions[0].resource_need} is not supported."
            )
        return dbs[0]

    async def correctness_check(
        self, message: AgentMessage
    ) -> Tuple[bool, Optional[str]]:
        """Verify whether the current execution results meet the target expectations.

        The action report attached to ``message`` must exist, be marked as
        successfully executed, contain a JSON object with a non-empty ``sql``
        entry and carry a ``resource_value``. The SQL is then executed against
        the database resource and must return at least one row.

        Args:
            message: The reply message whose ``action_report`` is validated.

        Returns:
            ``(True, None)`` when every check passes, otherwise ``(False, msg)``
            where ``msg`` describes the problem that was found.
        """
        action_out = message.action_report
        if action_out is None:
            return (
                False,
                f"No executable analysis SQL is generated,{message.content}.",
            )

        if not action_out.is_exe_success:
            return (
                False,
                f"Please check your answer, {action_out.content}.",
            )

        # A malformed action output is reported as feedback, like every other
        # failure, instead of raising out of the check. ``JSONDecodeError`` is
        # a ``ValueError``; ``TypeError`` covers non-string content (``None``).
        try:
            action_reply_obj = json.loads(action_out.content)
        except (TypeError, ValueError) as e:
            logger.warning("DataScientist check: action output is not JSON: %s", e)
            return (
                False,
                "Please check your answer, the action output is not valid JSON "
                f"and cannot be parsed: {e}.",
            )

        # Only a JSON object can carry the "sql" key; any other JSON value is
        # treated the same as a missing SQL statement.
        sql = (
            action_reply_obj.get("sql")
            if isinstance(action_reply_obj, dict)
            else None
        )
        if not sql:
            return (
                False,
                "Please check your answer, the sql information that needs to be "
                "generated is not found.",
            )

        if not action_out.resource_value:
            return (
                False,
                "Please check your answer, the data resource information is not "
                "found.",
            )

        try:
            # ``self.database`` may raise ValueError; it stays inside the try
            # so that it is reported through the same error feedback.
            columns, values = await self.database.query(
                sql=sql,
                db=action_out.resource_value,
            )
        except Exception as e:
            logger.exception("DataScientist check exception%s", e)
            return (
                False,
                f"SQL execution error, please re-read the historical information to "
                f"fix this SQL. The error message is as follows:{str(e)}",
            )

        if not values:
            return (
                False,
                "Please check your answer, the current SQL cannot find the data to "
                "determine whether filtered field values or inappropriate filter "
                "conditions are used.",
            )

        logger.info("reply check success! There are %s rows of data", len(values))
        return True, None