mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
feat(agent):Fix agent bug (#1953)
Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent
d72bfb2f5f
commit
b951b50689
@ -82,10 +82,10 @@ class ActionOutput(BaseModel):
|
||||
class Action(ABC, Generic[T]):
|
||||
"""Base Action class for defining agent actions."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, language: str = "en"):
|
||||
"""Create an action."""
|
||||
self.resource: Optional[Resource] = None
|
||||
self.language: str = "en"
|
||||
self.language: str = language
|
||||
|
||||
def init_resource(self, resource: Optional[Resource]):
|
||||
"""Initialize the resource."""
|
||||
|
@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
|
||||
class BlankAction(Action):
|
||||
"""Blank action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a blank action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Blank action init."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def ai_out_schema(self) -> Optional[str]:
|
||||
|
@ -127,9 +127,6 @@ class AgentManager(BaseComponent):
|
||||
def list_agents(self):
|
||||
"""Return a list of all registered agents and their descriptions."""
|
||||
result = []
|
||||
from datetime import datetime
|
||||
|
||||
logger.info(f"List Agent Begin:{datetime.now()}")
|
||||
for name, value in self._agents.items():
|
||||
result.append(
|
||||
{
|
||||
@ -137,7 +134,6 @@ class AgentManager(BaseComponent):
|
||||
"desc": value[1].goal,
|
||||
}
|
||||
)
|
||||
logger.info(f"List Agent End:{datetime.now()}")
|
||||
return result
|
||||
|
||||
|
||||
|
@ -673,7 +673,7 @@ class ConversableAgent(Role, Agent):
|
||||
self.actions = []
|
||||
for idx, action in enumerate(actions):
|
||||
if issubclass(action, Action):
|
||||
self.actions.append(action())
|
||||
self.actions.append(action(language=self.language))
|
||||
|
||||
async def _a_append_message(
|
||||
self, message: AgentMessage, role, sender: Agent
|
||||
|
@ -40,7 +40,7 @@ class PlanAction(Action[List[PlanInput]]):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a plan action."""
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisAgentPlans()
|
||||
|
||||
@property
|
||||
|
@ -31,9 +31,9 @@ class SqlInput(BaseModel):
|
||||
class ChartAction(Action[SqlInput]):
|
||||
"""Chart action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a chart action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Chart action init."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisChart()
|
||||
|
||||
@property
|
||||
|
@ -16,9 +16,9 @@ logger = logging.getLogger(__name__)
|
||||
class CodeAction(Action[None]):
|
||||
"""Code Action Module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a code action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Code action init."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisCode()
|
||||
self._code_execution_config = {}
|
||||
|
||||
|
@ -39,9 +39,9 @@ class ChartItem(BaseModel):
|
||||
class DashboardAction(Action[List[ChartItem]]):
|
||||
"""Dashboard action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a dashboard action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Dashboard action init."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisDashboard()
|
||||
|
||||
@property
|
||||
|
@ -41,9 +41,9 @@ class IndicatorInput(BaseModel):
|
||||
class IndicatorAction(Action[IndicatorInput]):
|
||||
"""Indicator Action."""
|
||||
|
||||
def __init__(self):
|
||||
"""Init Indicator Action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Init indicator action."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisApiResponse()
|
||||
|
||||
@property
|
||||
|
@ -34,9 +34,9 @@ class ToolInput(BaseModel):
|
||||
class ToolAction(Action[ToolInput]):
|
||||
"""Tool action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a plugin action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Tool action init."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisPlugin()
|
||||
|
||||
@property
|
||||
|
@ -73,8 +73,8 @@ class RetrieverResource(Resource[ResourceParameters]):
|
||||
prompt_template = f"\nResources-{self.name}:\n {content}"
|
||||
prompt_template_zh = f"\n资源-{self.name}:\n {content}"
|
||||
if lang == "en":
|
||||
return prompt_template.format(content=content), self._get_references(chunks)
|
||||
return prompt_template_zh.format(content=content), self._get_references(chunks)
|
||||
return prompt_template, self._get_references(chunks)
|
||||
return prompt_template_zh, self._get_references(chunks)
|
||||
|
||||
async def get_resources(
|
||||
self,
|
||||
|
@ -548,6 +548,17 @@ async def chat_completions(
|
||||
headers=headers,
|
||||
media_type="text/plain",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Chat Exception!{dialogue}", e)
|
||||
|
||||
async def error_text(err_msg):
|
||||
yield f"data:{err_msg}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
error_text(str(e)),
|
||||
headers=headers,
|
||||
media_type="text/plain",
|
||||
)
|
||||
finally:
|
||||
# write to recent usage app.
|
||||
if dialogue.user_name is not None and dialogue.app_code is not None:
|
||||
|
@ -34,6 +34,8 @@ class ChatExcel(BaseChat):
|
||||
"""
|
||||
|
||||
self.select_param = chat_param["select_param"]
|
||||
if not self.select_param:
|
||||
raise ValueError("Please upload the Excel document you want to talk to!")
|
||||
self.model_name = chat_param["model_name"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatExcel
|
||||
self.chat_param = chat_param
|
||||
|
@ -230,7 +230,7 @@ def is_chinese(text):
|
||||
|
||||
|
||||
class ExcelReader:
|
||||
def __init__(self, conv_uid, file_param):
|
||||
def __init__(self, conv_uid: str, file_param: str):
|
||||
self.conv_uid = conv_uid
|
||||
self.file_param = file_param
|
||||
if isinstance(file_param, str) and os.path.isabs(file_param):
|
||||
|
@ -214,10 +214,15 @@ class ConnectConfigDao(BaseDao):
|
||||
def get_db_list(self, db_name: Optional[str] = None, user_id: Optional[str] = None):
|
||||
"""Get db list."""
|
||||
session = self.get_raw_session()
|
||||
if db_name:
|
||||
if db_name and user_id:
|
||||
sql = f"SELECT * FROM connect_config where (user_id='{user_id}' or user_id='' or user_id IS NULL) and db_name='{db_name}'" # noqa
|
||||
else:
|
||||
elif user_id:
|
||||
sql = f"SELECT * FROM connect_config where user_id='{user_id}' or user_id='' or user_id IS NULL" # noqa
|
||||
elif db_name:
|
||||
sql = f"SELECT * FROM connect_config where db_name='{db_name}'" # noqa
|
||||
else:
|
||||
sql = f"SELECT * FROM connect_config" # noqa
|
||||
|
||||
result = session.execute(text(sql))
|
||||
fields = [field[0] for field in result.cursor.description] # type: ignore
|
||||
data = []
|
||||
|
@ -50,6 +50,7 @@ class StartAppAction(Action[LinkAppInput]):
|
||||
**kwargs,
|
||||
) -> ActionOutput:
|
||||
conv_id = kwargs.get("conv_id")
|
||||
user_input = kwargs.get("user_input")
|
||||
paren_agent = kwargs.get("paren_agent")
|
||||
init_message_rounds = kwargs.get("init_message_rounds")
|
||||
|
||||
@ -83,7 +84,7 @@ class StartAppAction(Action[LinkAppInput]):
|
||||
from dbgpt.serve.agent.agents.controller import multi_agents
|
||||
|
||||
await multi_agents.agent_team_chat_new(
|
||||
new_user_input,
|
||||
new_user_input if new_user_input else user_input,
|
||||
conv_id,
|
||||
gpts_app,
|
||||
paren_agent.memory,
|
||||
|
@ -54,18 +54,28 @@ class IntentRecognitionAction(Action[IntentRecognitionInput]):
|
||||
|
||||
@property
|
||||
def ai_out_schema(self) -> Optional[str]:
|
||||
out_put_schema = {
|
||||
"intent": "[The recognized intent is placed here]",
|
||||
"app_code": "[App code in selected intent]",
|
||||
"slots": {"意图定义中槽位属性1": "具体值", "意图定义中槽位属性2": "具体值"},
|
||||
"ask_user": "If you want the user to supplement slot data, ask the user a question",
|
||||
"user_input": "[Complete instructions generated based on intent and slot]",
|
||||
}
|
||||
if self.language == "en":
|
||||
out_put_schema = {
|
||||
"intent": "[The recognized intent is placed here]",
|
||||
"app_code": "[App code in selected intent]",
|
||||
"slots": {
|
||||
"Slot attribute 1 in intent definition": "value",
|
||||
"Slot attribute 2 in intent definition": "value",
|
||||
},
|
||||
"ask_user": "[If you want the user to supplement slot data, ask the user a question]",
|
||||
"user_input": "[Complete instructions generated based on intent and slot]",
|
||||
}
|
||||
return f"""Please reply in the following json format:
|
||||
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
|
||||
Make sure the output is only json and can be parsed by Python json.loads.""" # noqa: E501
|
||||
else:
|
||||
out_put_schema = {
|
||||
"intent": "选择的意图放在这里",
|
||||
"app_code": "选择意图对应的Appcode值",
|
||||
"slots": {"意图定义中槽位属性1": "具体值", "意图定义中槽位属性2": "具体值"},
|
||||
"ask_user": "如果需要用户补充槽位属性的具体值,请向用户进行提问",
|
||||
"user_input": "根据意图和槽位生成完整指令问题",
|
||||
}
|
||||
return f"""请按如下JSON格式输出:
|
||||
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
|
||||
确保输出只有json,且可以被python json.loads加载."""
|
||||
|
@ -32,21 +32,17 @@ class CommunityStore:
|
||||
|
||||
async def build_communities(self):
|
||||
"""Discover communities."""
|
||||
community_ids = await (self._community_store_adapter.discover_communities())
|
||||
community_ids = await self._community_store_adapter.discover_communities()
|
||||
|
||||
# summarize communities
|
||||
communities = []
|
||||
for community_id in community_ids:
|
||||
community = await (
|
||||
self._community_store_adapter.get_community(community_id)
|
||||
)
|
||||
community = await self._community_store_adapter.get_community(community_id)
|
||||
graph = community.data.format()
|
||||
if not graph:
|
||||
break
|
||||
|
||||
community.summary = await (
|
||||
self._community_summarizer.summarize(graph=graph)
|
||||
)
|
||||
community.summary = await self._community_summarizer.summarize(graph=graph)
|
||||
communities.append(community)
|
||||
logger.info(
|
||||
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
|
||||
|
Loading…
Reference in New Issue
Block a user