mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 02:25:08 +00:00
refactor(agent): Refactor resource of agents (#1518)
This commit is contained in:
@@ -2,14 +2,13 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Tuple, cast
|
||||
from typing import List, Optional, Tuple, cast
|
||||
|
||||
from ..core.action.base import ActionOutput
|
||||
from ..core.agent import AgentMessage
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
from ..resource.resource_api import ResourceType
|
||||
from ..resource.resource_db_api import ResourceDbClient
|
||||
from ..resource.database import DBResource
|
||||
from .actions.chart_action import ChartAction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -74,18 +73,21 @@ class DataScientistAgent(ConversableAgent):
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
client = self.not_null_resource_loader.get_resource_api(
|
||||
self.actions[0].resource_need, ResourceDbClient
|
||||
)
|
||||
if not client:
|
||||
reply_message.context = {
|
||||
"display_type": self.actions[0].render_prompt(),
|
||||
"dialect": self.database.dialect,
|
||||
}
|
||||
return reply_message
|
||||
|
||||
@property
|
||||
def database(self) -> DBResource:
|
||||
"""Get the database resource."""
|
||||
dbs: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
if not dbs:
|
||||
raise ValueError(
|
||||
f"Resource type {self.actions[0].resource_need} is not supported."
|
||||
)
|
||||
reply_message.context = {
|
||||
"display_type": self.actions[0].render_prompt(),
|
||||
"dialect": client.get_data_type(self.resources[0]),
|
||||
}
|
||||
return reply_message
|
||||
return dbs[0]
|
||||
|
||||
async def correctness_check(
|
||||
self, message: AgentMessage
|
||||
@@ -112,17 +114,6 @@ class DataScientistAgent(ConversableAgent):
|
||||
"generated is not found.",
|
||||
)
|
||||
try:
|
||||
resource_db_client: Optional[
|
||||
ResourceDbClient
|
||||
] = self.not_null_resource_loader.get_resource_api(
|
||||
ResourceType(action_out.resource_type), ResourceDbClient
|
||||
)
|
||||
if not resource_db_client:
|
||||
return (
|
||||
False,
|
||||
"Please check your answer, the data resource type is not "
|
||||
"supported.",
|
||||
)
|
||||
if not action_out.resource_value:
|
||||
return (
|
||||
False,
|
||||
@@ -130,8 +121,9 @@ class DataScientistAgent(ConversableAgent):
|
||||
"found.",
|
||||
)
|
||||
|
||||
columns, values = await resource_db_client.query(
|
||||
db=action_out.resource_value, sql=sql
|
||||
columns, values = await self.database.query(
|
||||
sql=sql,
|
||||
db=action_out.resource_value,
|
||||
)
|
||||
if not values or len(values) <= 0:
|
||||
return (
|
||||
|
Reference in New Issue
Block a user