refactor(agent): Refactor resource of agents (#1518)

This commit is contained in:
Fangyin Cheng
2024-05-15 09:57:19 +08:00
committed by GitHub
parent db4d318a5f
commit 559affe87d
102 changed files with 2633 additions and 2549 deletions

View File

@@ -2,14 +2,13 @@
import json
import logging
from typing import Optional, Tuple, cast
from typing import List, Optional, Tuple, cast
from ..core.action.base import ActionOutput
from ..core.agent import AgentMessage
from ..core.base_agent import ConversableAgent
from ..core.profile import DynConfig, ProfileConfig
from ..resource.resource_api import ResourceType
from ..resource.resource_db_api import ResourceDbClient
from ..resource.database import DBResource
from .actions.chart_action import ChartAction
logger = logging.getLogger(__name__)
@@ -74,18 +73,21 @@ class DataScientistAgent(ConversableAgent):
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
reply_message = super()._init_reply_message(received_message)
client = self.not_null_resource_loader.get_resource_api(
self.actions[0].resource_need, ResourceDbClient
)
if not client:
reply_message.context = {
"display_type": self.actions[0].render_prompt(),
"dialect": self.database.dialect,
}
return reply_message
@property
def database(self) -> DBResource:
"""Get the database resource."""
dbs: List[DBResource] = DBResource.from_resource(self.resource)
if not dbs:
raise ValueError(
f"Resource type {self.actions[0].resource_need} is not supported."
)
reply_message.context = {
"display_type": self.actions[0].render_prompt(),
"dialect": client.get_data_type(self.resources[0]),
}
return reply_message
return dbs[0]
async def correctness_check(
self, message: AgentMessage
@@ -112,17 +114,6 @@ class DataScientistAgent(ConversableAgent):
"generated is not found.",
)
try:
resource_db_client: Optional[
ResourceDbClient
] = self.not_null_resource_loader.get_resource_api(
ResourceType(action_out.resource_type), ResourceDbClient
)
if not resource_db_client:
return (
False,
"Please check your answer, the data resource type is not "
"supported.",
)
if not action_out.resource_value:
return (
False,
@@ -130,8 +121,9 @@ class DataScientistAgent(ConversableAgent):
"found.",
)
columns, values = await resource_db_client.query(
db=action_out.resource_value, sql=sql
columns, values = await self.database.query(
sql=sql,
db=action_out.resource_value,
)
if not values or len(values) <= 0:
return (