feat(agent):Fix agent bug (#1953)

Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
明天 2024-09-04 10:59:03 +08:00 committed by GitHub
parent d72bfb2f5f
commit b951b50689
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 67 additions and 46 deletions

View File

@ -82,10 +82,10 @@ class ActionOutput(BaseModel):
class Action(ABC, Generic[T]): class Action(ABC, Generic[T]):
"""Base Action class for defining agent actions.""" """Base Action class for defining agent actions."""
def __init__(self): def __init__(self, language: str = "en"):
"""Create an action.""" """Create an action."""
self.resource: Optional[Resource] = None self.resource: Optional[Resource] = None
self.language: str = "en" self.language: str = language
def init_resource(self, resource: Optional[Resource]): def init_resource(self, resource: Optional[Resource]):
"""Initialize the resource.""" """Initialize the resource."""

View File

@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
class BlankAction(Action): class BlankAction(Action):
"""Blank action class.""" """Blank action class."""
def __init__(self): def __init__(self, **kwargs):
"""Create a blank action.""" """Blank action init."""
super().__init__() super().__init__(**kwargs)
@property @property
def ai_out_schema(self) -> Optional[str]: def ai_out_schema(self) -> Optional[str]:

View File

@ -127,9 +127,6 @@ class AgentManager(BaseComponent):
def list_agents(self): def list_agents(self):
"""Return a list of all registered agents and their descriptions.""" """Return a list of all registered agents and their descriptions."""
result = [] result = []
from datetime import datetime
logger.info(f"List Agent Begin:{datetime.now()}")
for name, value in self._agents.items(): for name, value in self._agents.items():
result.append( result.append(
{ {
@ -137,7 +134,6 @@ class AgentManager(BaseComponent):
"desc": value[1].goal, "desc": value[1].goal,
} }
) )
logger.info(f"List Agent End:{datetime.now()}")
return result return result

View File

@ -673,7 +673,7 @@ class ConversableAgent(Role, Agent):
self.actions = [] self.actions = []
for idx, action in enumerate(actions): for idx, action in enumerate(actions):
if issubclass(action, Action): if issubclass(action, Action):
self.actions.append(action()) self.actions.append(action(language=self.language))
async def _a_append_message( async def _a_append_message(
self, message: AgentMessage, role, sender: Agent self, message: AgentMessage, role, sender: Agent

View File

@ -40,7 +40,7 @@ class PlanAction(Action[List[PlanInput]]):
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Create a plan action.""" """Create a plan action."""
super().__init__() super().__init__(**kwargs)
self._render_protocol = VisAgentPlans() self._render_protocol = VisAgentPlans()
@property @property

View File

@ -31,9 +31,9 @@ class SqlInput(BaseModel):
class ChartAction(Action[SqlInput]): class ChartAction(Action[SqlInput]):
"""Chart action class.""" """Chart action class."""
def __init__(self): def __init__(self, **kwargs):
"""Create a chart action.""" """Chart action init."""
super().__init__() super().__init__(**kwargs)
self._render_protocol = VisChart() self._render_protocol = VisChart()
@property @property

View File

@ -16,9 +16,9 @@ logger = logging.getLogger(__name__)
class CodeAction(Action[None]): class CodeAction(Action[None]):
"""Code Action Module.""" """Code Action Module."""
def __init__(self): def __init__(self, **kwargs):
"""Create a code action.""" """Code action init."""
super().__init__() super().__init__(**kwargs)
self._render_protocol = VisCode() self._render_protocol = VisCode()
self._code_execution_config = {} self._code_execution_config = {}

View File

@ -39,9 +39,9 @@ class ChartItem(BaseModel):
class DashboardAction(Action[List[ChartItem]]): class DashboardAction(Action[List[ChartItem]]):
"""Dashboard action class.""" """Dashboard action class."""
def __init__(self): def __init__(self, **kwargs):
"""Create a dashboard action.""" """Dashboard action init."""
super().__init__() super().__init__(**kwargs)
self._render_protocol = VisDashboard() self._render_protocol = VisDashboard()
@property @property

View File

@ -41,9 +41,9 @@ class IndicatorInput(BaseModel):
class IndicatorAction(Action[IndicatorInput]): class IndicatorAction(Action[IndicatorInput]):
"""Indicator Action.""" """Indicator Action."""
def __init__(self): def __init__(self, **kwargs):
"""Init Indicator Action.""" """Init indicator action."""
super().__init__() super().__init__(**kwargs)
self._render_protocol = VisApiResponse() self._render_protocol = VisApiResponse()
@property @property

View File

@ -34,9 +34,9 @@ class ToolInput(BaseModel):
class ToolAction(Action[ToolInput]): class ToolAction(Action[ToolInput]):
"""Tool action class.""" """Tool action class."""
def __init__(self): def __init__(self, **kwargs):
"""Create a plugin action.""" """Tool action init."""
super().__init__() super().__init__(**kwargs)
self._render_protocol = VisPlugin() self._render_protocol = VisPlugin()
@property @property

View File

@ -73,8 +73,8 @@ class RetrieverResource(Resource[ResourceParameters]):
prompt_template = f"\nResources-{self.name}:\n {content}" prompt_template = f"\nResources-{self.name}:\n {content}"
prompt_template_zh = f"\n资源-{self.name}:\n {content}" prompt_template_zh = f"\n资源-{self.name}:\n {content}"
if lang == "en": if lang == "en":
return prompt_template.format(content=content), self._get_references(chunks) return prompt_template, self._get_references(chunks)
return prompt_template_zh.format(content=content), self._get_references(chunks) return prompt_template_zh, self._get_references(chunks)
async def get_resources( async def get_resources(
self, self,

View File

@ -548,6 +548,17 @@ async def chat_completions(
headers=headers, headers=headers,
media_type="text/plain", media_type="text/plain",
) )
except Exception as e:
logger.exception(f"Chat Exception!{dialogue}", e)
async def error_text(err_msg):
yield f"data:{err_msg}\n\n"
return StreamingResponse(
error_text(str(e)),
headers=headers,
media_type="text/plain",
)
finally: finally:
# write to recent usage app. # write to recent usage app.
if dialogue.user_name is not None and dialogue.app_code is not None: if dialogue.user_name is not None and dialogue.app_code is not None:

View File

@ -34,6 +34,8 @@ class ChatExcel(BaseChat):
""" """
self.select_param = chat_param["select_param"] self.select_param = chat_param["select_param"]
if not self.select_param:
raise ValueError("Please upload the Excel document you want to talk to")
self.model_name = chat_param["model_name"] self.model_name = chat_param["model_name"]
chat_param["chat_mode"] = ChatScene.ChatExcel chat_param["chat_mode"] = ChatScene.ChatExcel
self.chat_param = chat_param self.chat_param = chat_param

View File

@ -230,7 +230,7 @@ def is_chinese(text):
class ExcelReader: class ExcelReader:
def __init__(self, conv_uid, file_param): def __init__(self, conv_uid: str, file_param: str):
self.conv_uid = conv_uid self.conv_uid = conv_uid
self.file_param = file_param self.file_param = file_param
if isinstance(file_param, str) and os.path.isabs(file_param): if isinstance(file_param, str) and os.path.isabs(file_param):

View File

@ -214,10 +214,15 @@ class ConnectConfigDao(BaseDao):
def get_db_list(self, db_name: Optional[str] = None, user_id: Optional[str] = None): def get_db_list(self, db_name: Optional[str] = None, user_id: Optional[str] = None):
"""Get db list.""" """Get db list."""
session = self.get_raw_session() session = self.get_raw_session()
if db_name: if db_name and user_id:
sql = f"SELECT * FROM connect_config where (user_id='{user_id}' or user_id='' or user_id IS NULL) and db_name='{db_name}'" # noqa sql = f"SELECT * FROM connect_config where (user_id='{user_id}' or user_id='' or user_id IS NULL) and db_name='{db_name}'" # noqa
else: elif user_id:
sql = f"SELECT * FROM connect_config where user_id='{user_id}' or user_id='' or user_id IS NULL" # noqa sql = f"SELECT * FROM connect_config where user_id='{user_id}' or user_id='' or user_id IS NULL" # noqa
elif db_name:
sql = f"SELECT * FROM connect_config where db_name='{db_name}'" # noqa
else:
sql = f"SELECT * FROM connect_config" # noqa
result = session.execute(text(sql)) result = session.execute(text(sql))
fields = [field[0] for field in result.cursor.description] # type: ignore fields = [field[0] for field in result.cursor.description] # type: ignore
data = [] data = []

View File

@ -50,6 +50,7 @@ class StartAppAction(Action[LinkAppInput]):
**kwargs, **kwargs,
) -> ActionOutput: ) -> ActionOutput:
conv_id = kwargs.get("conv_id") conv_id = kwargs.get("conv_id")
user_input = kwargs.get("user_input")
paren_agent = kwargs.get("paren_agent") paren_agent = kwargs.get("paren_agent")
init_message_rounds = kwargs.get("init_message_rounds") init_message_rounds = kwargs.get("init_message_rounds")
@ -83,7 +84,7 @@ class StartAppAction(Action[LinkAppInput]):
from dbgpt.serve.agent.agents.controller import multi_agents from dbgpt.serve.agent.agents.controller import multi_agents
await multi_agents.agent_team_chat_new( await multi_agents.agent_team_chat_new(
new_user_input, new_user_input if new_user_input else user_input,
conv_id, conv_id,
gpts_app, gpts_app,
paren_agent.memory, paren_agent.memory,

View File

@ -54,18 +54,28 @@ class IntentRecognitionAction(Action[IntentRecognitionInput]):
@property @property
def ai_out_schema(self) -> Optional[str]: def ai_out_schema(self) -> Optional[str]:
out_put_schema = {
"intent": "[The recognized intent is placed here]",
"app_code": "[App code in selected intent]",
"slots": {"意图定义中槽位属性1": "具体值", "意图定义中槽位属性2": "具体值"},
"ask_user": "If you want the user to supplement slot data, ask the user a question",
"user_input": "[Complete instructions generated based on intent and slot]",
}
if self.language == "en": if self.language == "en":
out_put_schema = {
"intent": "[The recognized intent is placed here]",
"app_code": "[App code in selected intent]",
"slots": {
"Slot attribute 1 in intent definition": "value",
"Slot attribute 2 in intent definition": "value",
},
"ask_user": "[If you want the user to supplement slot data, ask the user a question]",
"user_input": "[Complete instructions generated based on intent and slot]",
}
return f"""Please reply in the following json format: return f"""Please reply in the following json format:
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)} {json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
Make sure the output is only json and can be parsed by Python json.loads.""" # noqa: E501 Make sure the output is only json and can be parsed by Python json.loads.""" # noqa: E501
else: else:
out_put_schema = {
"intent": "选择的意图放在这里",
"app_code": "选择意图对应的Appcode值",
"slots": {"意图定义中槽位属性1": "具体值", "意图定义中槽位属性2": "具体值"},
"ask_user": "如果需要用户补充槽位属性的具体值,请向用户进行提问",
"user_input": "根据意图和槽位生成完整指令问题",
}
return f"""请按如下JSON格式输出: return f"""请按如下JSON格式输出:
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)} {json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
确保输出只有json且可以被python json.loads加载.""" 确保输出只有json且可以被python json.loads加载."""

View File

@ -32,21 +32,17 @@ class CommunityStore:
async def build_communities(self): async def build_communities(self):
"""Discover communities.""" """Discover communities."""
community_ids = await (self._community_store_adapter.discover_communities()) community_ids = await self._community_store_adapter.discover_communities()
# summarize communities # summarize communities
communities = [] communities = []
for community_id in community_ids: for community_id in community_ids:
community = await ( community = await self._community_store_adapter.get_community(community_id)
self._community_store_adapter.get_community(community_id)
)
graph = community.data.format() graph = community.data.format()
if not graph: if not graph:
break break
community.summary = await ( community.summary = await self._community_summarizer.summarize(graph=graph)
self._community_summarizer.summarize(graph=graph)
)
communities.append(community) communities.append(community)
logger.info( logger.info(
f"Summarize community {community_id}: " f"{community.summary[:50]}..." f"Summarize community {community_id}: " f"{community.summary[:50]}..."