mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 05:59:59 +00:00
feat(agent):Fix agent bug (#1953)
Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
@@ -82,10 +82,10 @@ class ActionOutput(BaseModel):
|
||||
class Action(ABC, Generic[T]):
|
||||
"""Base Action class for defining agent actions."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, language: str = "en"):
|
||||
"""Create an action."""
|
||||
self.resource: Optional[Resource] = None
|
||||
self.language: str = "en"
|
||||
self.language: str = language
|
||||
|
||||
def init_resource(self, resource: Optional[Resource]):
|
||||
"""Initialize the resource."""
|
||||
|
@@ -12,9 +12,9 @@ logger = logging.getLogger(__name__)
|
||||
class BlankAction(Action):
|
||||
"""Blank action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a blank action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Blank action init."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def ai_out_schema(self) -> Optional[str]:
|
||||
|
@@ -127,9 +127,6 @@ class AgentManager(BaseComponent):
|
||||
def list_agents(self):
|
||||
"""Return a list of all registered agents and their descriptions."""
|
||||
result = []
|
||||
from datetime import datetime
|
||||
|
||||
logger.info(f"List Agent Begin:{datetime.now()}")
|
||||
for name, value in self._agents.items():
|
||||
result.append(
|
||||
{
|
||||
@@ -137,7 +134,6 @@ class AgentManager(BaseComponent):
|
||||
"desc": value[1].goal,
|
||||
}
|
||||
)
|
||||
logger.info(f"List Agent End:{datetime.now()}")
|
||||
return result
|
||||
|
||||
|
||||
|
@@ -673,7 +673,7 @@ class ConversableAgent(Role, Agent):
|
||||
self.actions = []
|
||||
for idx, action in enumerate(actions):
|
||||
if issubclass(action, Action):
|
||||
self.actions.append(action())
|
||||
self.actions.append(action(language=self.language))
|
||||
|
||||
async def _a_append_message(
|
||||
self, message: AgentMessage, role, sender: Agent
|
||||
|
@@ -40,7 +40,7 @@ class PlanAction(Action[List[PlanInput]]):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a plan action."""
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisAgentPlans()
|
||||
|
||||
@property
|
||||
|
@@ -31,9 +31,9 @@ class SqlInput(BaseModel):
|
||||
class ChartAction(Action[SqlInput]):
|
||||
"""Chart action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a chart action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Chart action init."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisChart()
|
||||
|
||||
@property
|
||||
|
@@ -16,9 +16,9 @@ logger = logging.getLogger(__name__)
|
||||
class CodeAction(Action[None]):
|
||||
"""Code Action Module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a code action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Code action init."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisCode()
|
||||
self._code_execution_config = {}
|
||||
|
||||
|
@@ -39,9 +39,9 @@ class ChartItem(BaseModel):
|
||||
class DashboardAction(Action[List[ChartItem]]):
|
||||
"""Dashboard action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a dashboard action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Dashboard action init."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisDashboard()
|
||||
|
||||
@property
|
||||
|
@@ -41,9 +41,9 @@ class IndicatorInput(BaseModel):
|
||||
class IndicatorAction(Action[IndicatorInput]):
|
||||
"""Indicator Action."""
|
||||
|
||||
def __init__(self):
|
||||
"""Init Indicator Action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Init indicator action."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisApiResponse()
|
||||
|
||||
@property
|
||||
|
@@ -34,9 +34,9 @@ class ToolInput(BaseModel):
|
||||
class ToolAction(Action[ToolInput]):
|
||||
"""Tool action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a plugin action."""
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
"""Tool action init."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisPlugin()
|
||||
|
||||
@property
|
||||
|
@@ -73,8 +73,8 @@ class RetrieverResource(Resource[ResourceParameters]):
|
||||
prompt_template = f"\nResources-{self.name}:\n {content}"
|
||||
prompt_template_zh = f"\n资源-{self.name}:\n {content}"
|
||||
if lang == "en":
|
||||
return prompt_template.format(content=content), self._get_references(chunks)
|
||||
return prompt_template_zh.format(content=content), self._get_references(chunks)
|
||||
return prompt_template, self._get_references(chunks)
|
||||
return prompt_template_zh, self._get_references(chunks)
|
||||
|
||||
async def get_resources(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user