feat(agent):Fix agent bug (#1953)

Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
明天
2024-09-04 10:59:03 +08:00
committed by GitHub
parent d72bfb2f5f
commit b951b50689
18 changed files with 67 additions and 46 deletions

View File

@@ -82,10 +82,10 @@ class ActionOutput(BaseModel):
class Action(ABC, Generic[T]):
"""Base Action class for defining agent actions."""
def __init__(self):
def __init__(self, language: str = "en"):
"""Create an action."""
self.resource: Optional[Resource] = None
self.language: str = "en"
self.language: str = language
def init_resource(self, resource: Optional[Resource]):
"""Initialize the resource."""

View File

@@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
class BlankAction(Action):
"""Blank action class."""
def __init__(self):
"""Create a blank action."""
super().__init__()
def __init__(self, **kwargs):
"""Blank action init."""
super().__init__(**kwargs)
@property
def ai_out_schema(self) -> Optional[str]:

View File

@@ -127,9 +127,6 @@ class AgentManager(BaseComponent):
def list_agents(self):
"""Return a list of all registered agents and their descriptions."""
result = []
from datetime import datetime
logger.info(f"List Agent Begin:{datetime.now()}")
for name, value in self._agents.items():
result.append(
{
@@ -137,7 +134,6 @@ class AgentManager(BaseComponent):
"desc": value[1].goal,
}
)
logger.info(f"List Agent End:{datetime.now()}")
return result

View File

@@ -673,7 +673,7 @@ class ConversableAgent(Role, Agent):
self.actions = []
for idx, action in enumerate(actions):
if issubclass(action, Action):
self.actions.append(action())
self.actions.append(action(language=self.language))
async def _a_append_message(
self, message: AgentMessage, role, sender: Agent

View File

@@ -40,7 +40,7 @@ class PlanAction(Action[List[PlanInput]]):
def __init__(self, **kwargs):
"""Create a plan action."""
super().__init__()
super().__init__(**kwargs)
self._render_protocol = VisAgentPlans()
@property

View File

@@ -31,9 +31,9 @@ class SqlInput(BaseModel):
class ChartAction(Action[SqlInput]):
"""Chart action class."""
def __init__(self):
"""Create a chart action."""
super().__init__()
def __init__(self, **kwargs):
"""Chart action init."""
super().__init__(**kwargs)
self._render_protocol = VisChart()
@property

View File

@@ -16,9 +16,9 @@ logger = logging.getLogger(__name__)
class CodeAction(Action[None]):
"""Code Action Module."""
def __init__(self):
"""Create a code action."""
super().__init__()
def __init__(self, **kwargs):
"""Code action init."""
super().__init__(**kwargs)
self._render_protocol = VisCode()
self._code_execution_config = {}

View File

@@ -39,9 +39,9 @@ class ChartItem(BaseModel):
class DashboardAction(Action[List[ChartItem]]):
"""Dashboard action class."""
def __init__(self):
"""Create a dashboard action."""
super().__init__()
def __init__(self, **kwargs):
"""Dashboard action init."""
super().__init__(**kwargs)
self._render_protocol = VisDashboard()
@property

View File

@@ -41,9 +41,9 @@ class IndicatorInput(BaseModel):
class IndicatorAction(Action[IndicatorInput]):
"""Indicator Action."""
def __init__(self):
"""Init Indicator Action."""
super().__init__()
def __init__(self, **kwargs):
"""Init indicator action."""
super().__init__(**kwargs)
self._render_protocol = VisApiResponse()
@property

View File

@@ -34,9 +34,9 @@ class ToolInput(BaseModel):
class ToolAction(Action[ToolInput]):
"""Tool action class."""
def __init__(self):
"""Create a plugin action."""
super().__init__()
def __init__(self, **kwargs):
"""Tool action init."""
super().__init__(**kwargs)
self._render_protocol = VisPlugin()
@property

View File

@@ -73,8 +73,8 @@ class RetrieverResource(Resource[ResourceParameters]):
prompt_template = f"\nResources-{self.name}:\n {content}"
prompt_template_zh = f"\n资源-{self.name}:\n {content}"
if lang == "en":
return prompt_template.format(content=content), self._get_references(chunks)
return prompt_template_zh.format(content=content), self._get_references(chunks)
return prompt_template, self._get_references(chunks)
return prompt_template_zh, self._get_references(chunks)
async def get_resources(
self,