mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 06:47:30 +00:00
feat(agent):Fix agent bug (#1953)
Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent
d72bfb2f5f
commit
b951b50689
@ -82,10 +82,10 @@ class ActionOutput(BaseModel):
|
|||||||
class Action(ABC, Generic[T]):
|
class Action(ABC, Generic[T]):
|
||||||
"""Base Action class for defining agent actions."""
|
"""Base Action class for defining agent actions."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, language: str = "en"):
|
||||||
"""Create an action."""
|
"""Create an action."""
|
||||||
self.resource: Optional[Resource] = None
|
self.resource: Optional[Resource] = None
|
||||||
self.language: str = "en"
|
self.language: str = language
|
||||||
|
|
||||||
def init_resource(self, resource: Optional[Resource]):
|
def init_resource(self, resource: Optional[Resource]):
|
||||||
"""Initialize the resource."""
|
"""Initialize the resource."""
|
||||||
|
@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
|
|||||||
class BlankAction(Action):
|
class BlankAction(Action):
|
||||||
"""Blank action class."""
|
"""Blank action class."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
"""Create a blank action."""
|
"""Blank action init."""
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ai_out_schema(self) -> Optional[str]:
|
def ai_out_schema(self) -> Optional[str]:
|
||||||
|
@ -127,9 +127,6 @@ class AgentManager(BaseComponent):
|
|||||||
def list_agents(self):
|
def list_agents(self):
|
||||||
"""Return a list of all registered agents and their descriptions."""
|
"""Return a list of all registered agents and their descriptions."""
|
||||||
result = []
|
result = []
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
logger.info(f"List Agent Begin:{datetime.now()}")
|
|
||||||
for name, value in self._agents.items():
|
for name, value in self._agents.items():
|
||||||
result.append(
|
result.append(
|
||||||
{
|
{
|
||||||
@ -137,7 +134,6 @@ class AgentManager(BaseComponent):
|
|||||||
"desc": value[1].goal,
|
"desc": value[1].goal,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
logger.info(f"List Agent End:{datetime.now()}")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ -673,7 +673,7 @@ class ConversableAgent(Role, Agent):
|
|||||||
self.actions = []
|
self.actions = []
|
||||||
for idx, action in enumerate(actions):
|
for idx, action in enumerate(actions):
|
||||||
if issubclass(action, Action):
|
if issubclass(action, Action):
|
||||||
self.actions.append(action())
|
self.actions.append(action(language=self.language))
|
||||||
|
|
||||||
async def _a_append_message(
|
async def _a_append_message(
|
||||||
self, message: AgentMessage, role, sender: Agent
|
self, message: AgentMessage, role, sender: Agent
|
||||||
|
@ -40,7 +40,7 @@ class PlanAction(Action[List[PlanInput]]):
|
|||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
"""Create a plan action."""
|
"""Create a plan action."""
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
self._render_protocol = VisAgentPlans()
|
self._render_protocol = VisAgentPlans()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -31,9 +31,9 @@ class SqlInput(BaseModel):
|
|||||||
class ChartAction(Action[SqlInput]):
|
class ChartAction(Action[SqlInput]):
|
||||||
"""Chart action class."""
|
"""Chart action class."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
"""Create a chart action."""
|
"""Chart action init."""
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
self._render_protocol = VisChart()
|
self._render_protocol = VisChart()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -16,9 +16,9 @@ logger = logging.getLogger(__name__)
|
|||||||
class CodeAction(Action[None]):
|
class CodeAction(Action[None]):
|
||||||
"""Code Action Module."""
|
"""Code Action Module."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
"""Create a code action."""
|
"""Code action init."""
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
self._render_protocol = VisCode()
|
self._render_protocol = VisCode()
|
||||||
self._code_execution_config = {}
|
self._code_execution_config = {}
|
||||||
|
|
||||||
|
@ -39,9 +39,9 @@ class ChartItem(BaseModel):
|
|||||||
class DashboardAction(Action[List[ChartItem]]):
|
class DashboardAction(Action[List[ChartItem]]):
|
||||||
"""Dashboard action class."""
|
"""Dashboard action class."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
"""Create a dashboard action."""
|
"""Dashboard action init."""
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
self._render_protocol = VisDashboard()
|
self._render_protocol = VisDashboard()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -41,9 +41,9 @@ class IndicatorInput(BaseModel):
|
|||||||
class IndicatorAction(Action[IndicatorInput]):
|
class IndicatorAction(Action[IndicatorInput]):
|
||||||
"""Indicator Action."""
|
"""Indicator Action."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
"""Init Indicator Action."""
|
"""Init indicator action."""
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
self._render_protocol = VisApiResponse()
|
self._render_protocol = VisApiResponse()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -34,9 +34,9 @@ class ToolInput(BaseModel):
|
|||||||
class ToolAction(Action[ToolInput]):
|
class ToolAction(Action[ToolInput]):
|
||||||
"""Tool action class."""
|
"""Tool action class."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
"""Create a plugin action."""
|
"""Tool action init."""
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
self._render_protocol = VisPlugin()
|
self._render_protocol = VisPlugin()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -73,8 +73,8 @@ class RetrieverResource(Resource[ResourceParameters]):
|
|||||||
prompt_template = f"\nResources-{self.name}:\n {content}"
|
prompt_template = f"\nResources-{self.name}:\n {content}"
|
||||||
prompt_template_zh = f"\n资源-{self.name}:\n {content}"
|
prompt_template_zh = f"\n资源-{self.name}:\n {content}"
|
||||||
if lang == "en":
|
if lang == "en":
|
||||||
return prompt_template.format(content=content), self._get_references(chunks)
|
return prompt_template, self._get_references(chunks)
|
||||||
return prompt_template_zh.format(content=content), self._get_references(chunks)
|
return prompt_template_zh, self._get_references(chunks)
|
||||||
|
|
||||||
async def get_resources(
|
async def get_resources(
|
||||||
self,
|
self,
|
||||||
|
@ -548,6 +548,17 @@ async def chat_completions(
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
media_type="text/plain",
|
media_type="text/plain",
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Chat Exception!{dialogue}", e)
|
||||||
|
|
||||||
|
async def error_text(err_msg):
|
||||||
|
yield f"data:{err_msg}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
error_text(str(e)),
|
||||||
|
headers=headers,
|
||||||
|
media_type="text/plain",
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
# write to recent usage app.
|
# write to recent usage app.
|
||||||
if dialogue.user_name is not None and dialogue.app_code is not None:
|
if dialogue.user_name is not None and dialogue.app_code is not None:
|
||||||
|
@ -34,6 +34,8 @@ class ChatExcel(BaseChat):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self.select_param = chat_param["select_param"]
|
self.select_param = chat_param["select_param"]
|
||||||
|
if not self.select_param:
|
||||||
|
raise ValueError("Please upload the Excel document you want to talk to!")
|
||||||
self.model_name = chat_param["model_name"]
|
self.model_name = chat_param["model_name"]
|
||||||
chat_param["chat_mode"] = ChatScene.ChatExcel
|
chat_param["chat_mode"] = ChatScene.ChatExcel
|
||||||
self.chat_param = chat_param
|
self.chat_param = chat_param
|
||||||
|
@ -230,7 +230,7 @@ def is_chinese(text):
|
|||||||
|
|
||||||
|
|
||||||
class ExcelReader:
|
class ExcelReader:
|
||||||
def __init__(self, conv_uid, file_param):
|
def __init__(self, conv_uid: str, file_param: str):
|
||||||
self.conv_uid = conv_uid
|
self.conv_uid = conv_uid
|
||||||
self.file_param = file_param
|
self.file_param = file_param
|
||||||
if isinstance(file_param, str) and os.path.isabs(file_param):
|
if isinstance(file_param, str) and os.path.isabs(file_param):
|
||||||
|
@ -214,10 +214,15 @@ class ConnectConfigDao(BaseDao):
|
|||||||
def get_db_list(self, db_name: Optional[str] = None, user_id: Optional[str] = None):
|
def get_db_list(self, db_name: Optional[str] = None, user_id: Optional[str] = None):
|
||||||
"""Get db list."""
|
"""Get db list."""
|
||||||
session = self.get_raw_session()
|
session = self.get_raw_session()
|
||||||
if db_name:
|
if db_name and user_id:
|
||||||
sql = f"SELECT * FROM connect_config where (user_id='{user_id}' or user_id='' or user_id IS NULL) and db_name='{db_name}'" # noqa
|
sql = f"SELECT * FROM connect_config where (user_id='{user_id}' or user_id='' or user_id IS NULL) and db_name='{db_name}'" # noqa
|
||||||
else:
|
elif user_id:
|
||||||
sql = f"SELECT * FROM connect_config where user_id='{user_id}' or user_id='' or user_id IS NULL" # noqa
|
sql = f"SELECT * FROM connect_config where user_id='{user_id}' or user_id='' or user_id IS NULL" # noqa
|
||||||
|
elif db_name:
|
||||||
|
sql = f"SELECT * FROM connect_config where db_name='{db_name}'" # noqa
|
||||||
|
else:
|
||||||
|
sql = f"SELECT * FROM connect_config" # noqa
|
||||||
|
|
||||||
result = session.execute(text(sql))
|
result = session.execute(text(sql))
|
||||||
fields = [field[0] for field in result.cursor.description] # type: ignore
|
fields = [field[0] for field in result.cursor.description] # type: ignore
|
||||||
data = []
|
data = []
|
||||||
|
@ -50,6 +50,7 @@ class StartAppAction(Action[LinkAppInput]):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ActionOutput:
|
) -> ActionOutput:
|
||||||
conv_id = kwargs.get("conv_id")
|
conv_id = kwargs.get("conv_id")
|
||||||
|
user_input = kwargs.get("user_input")
|
||||||
paren_agent = kwargs.get("paren_agent")
|
paren_agent = kwargs.get("paren_agent")
|
||||||
init_message_rounds = kwargs.get("init_message_rounds")
|
init_message_rounds = kwargs.get("init_message_rounds")
|
||||||
|
|
||||||
@ -83,7 +84,7 @@ class StartAppAction(Action[LinkAppInput]):
|
|||||||
from dbgpt.serve.agent.agents.controller import multi_agents
|
from dbgpt.serve.agent.agents.controller import multi_agents
|
||||||
|
|
||||||
await multi_agents.agent_team_chat_new(
|
await multi_agents.agent_team_chat_new(
|
||||||
new_user_input,
|
new_user_input if new_user_input else user_input,
|
||||||
conv_id,
|
conv_id,
|
||||||
gpts_app,
|
gpts_app,
|
||||||
paren_agent.memory,
|
paren_agent.memory,
|
||||||
|
@ -54,18 +54,28 @@ class IntentRecognitionAction(Action[IntentRecognitionInput]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def ai_out_schema(self) -> Optional[str]:
|
def ai_out_schema(self) -> Optional[str]:
|
||||||
out_put_schema = {
|
|
||||||
"intent": "[The recognized intent is placed here]",
|
|
||||||
"app_code": "[App code in selected intent]",
|
|
||||||
"slots": {"意图定义中槽位属性1": "具体值", "意图定义中槽位属性2": "具体值"},
|
|
||||||
"ask_user": "If you want the user to supplement slot data, ask the user a question",
|
|
||||||
"user_input": "[Complete instructions generated based on intent and slot]",
|
|
||||||
}
|
|
||||||
if self.language == "en":
|
if self.language == "en":
|
||||||
|
out_put_schema = {
|
||||||
|
"intent": "[The recognized intent is placed here]",
|
||||||
|
"app_code": "[App code in selected intent]",
|
||||||
|
"slots": {
|
||||||
|
"Slot attribute 1 in intent definition": "value",
|
||||||
|
"Slot attribute 2 in intent definition": "value",
|
||||||
|
},
|
||||||
|
"ask_user": "[If you want the user to supplement slot data, ask the user a question]",
|
||||||
|
"user_input": "[Complete instructions generated based on intent and slot]",
|
||||||
|
}
|
||||||
return f"""Please reply in the following json format:
|
return f"""Please reply in the following json format:
|
||||||
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
|
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
|
||||||
Make sure the output is only json and can be parsed by Python json.loads.""" # noqa: E501
|
Make sure the output is only json and can be parsed by Python json.loads.""" # noqa: E501
|
||||||
else:
|
else:
|
||||||
|
out_put_schema = {
|
||||||
|
"intent": "选择的意图放在这里",
|
||||||
|
"app_code": "选择意图对应的Appcode值",
|
||||||
|
"slots": {"意图定义中槽位属性1": "具体值", "意图定义中槽位属性2": "具体值"},
|
||||||
|
"ask_user": "如果需要用户补充槽位属性的具体值,请向用户进行提问",
|
||||||
|
"user_input": "根据意图和槽位生成完整指令问题",
|
||||||
|
}
|
||||||
return f"""请按如下JSON格式输出:
|
return f"""请按如下JSON格式输出:
|
||||||
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
|
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
|
||||||
确保输出只有json,且可以被python json.loads加载."""
|
确保输出只有json,且可以被python json.loads加载."""
|
||||||
|
@ -32,21 +32,17 @@ class CommunityStore:
|
|||||||
|
|
||||||
async def build_communities(self):
|
async def build_communities(self):
|
||||||
"""Discover communities."""
|
"""Discover communities."""
|
||||||
community_ids = await (self._community_store_adapter.discover_communities())
|
community_ids = await self._community_store_adapter.discover_communities()
|
||||||
|
|
||||||
# summarize communities
|
# summarize communities
|
||||||
communities = []
|
communities = []
|
||||||
for community_id in community_ids:
|
for community_id in community_ids:
|
||||||
community = await (
|
community = await self._community_store_adapter.get_community(community_id)
|
||||||
self._community_store_adapter.get_community(community_id)
|
|
||||||
)
|
|
||||||
graph = community.data.format()
|
graph = community.data.format()
|
||||||
if not graph:
|
if not graph:
|
||||||
break
|
break
|
||||||
|
|
||||||
community.summary = await (
|
community.summary = await self._community_summarizer.summarize(graph=graph)
|
||||||
self._community_summarizer.summarize(graph=graph)
|
|
||||||
)
|
|
||||||
communities.append(community)
|
communities.append(community)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
|
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
|
||||||
|
Loading…
Reference in New Issue
Block a user