feat(agent): Multi agent sdk (#976)

Co-authored-by: xtyuns <xtyuns@163.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Co-authored-by: qidanrui <qidanrui@gmail.com>
明天
2023-12-27 16:25:55 +08:00
committed by GitHub
parent 69fb97e508
commit 9aec636b02
79 changed files with 6359 additions and 121 deletions


@@ -1,11 +1 @@
from .db.my_plugin_db import MyPluginEntity, MyPluginDao
from .db.plugin_hub_db import PluginHubEntity, PluginHubDao
from .commands.command import execute_command, get_command
from .commands.generator import PluginPromptGenerator
from .commands.disply_type.show_chart_gen import static_message_img_path
from .common.schema import Status, PluginStorageType
from .commands.command_mange import ApiCall
from .commands.command import execute_command
from .common.schema import PluginStorageType

dbgpt/agent/agents/agent.py Normal file

@@ -0,0 +1,188 @@
from __future__ import annotations
import dataclasses
from collections import defaultdict
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, List, Optional, Union
from ..memory.gpts_memory import GptsMemory
from dbgpt.core import LLMClient
from dbgpt.core.interface.llm import ModelMetadata
class Agent:
"""
An interface for an AI agent.
An agent can communicate with other agents and perform actions.
"""
def __init__(
self,
name: str,
memory: GptsMemory,
describe: str,
):
"""
Args:
name (str): name of the agent.
memory (GptsMemory): the collective memory shared by the agents.
describe (str): description of the agent's capabilities.
"""
self._name = name
self._describe = describe
# the agent's collective memory
self._memory = memory
@property
def name(self):
"""Get the name of the agent."""
return self._name
@property
def memory(self):
return self._memory
@property
def describe(self):
"""Get the name of the agent."""
return self._describe
async def a_send(
self,
message: Union[Dict, str],
recipient: Agent,
reviewer: Agent,
request_reply: Optional[bool] = True,
is_recovery: Optional[bool] = False,
):
"""(Abstract async method) Send a message to another agent."""
async def a_receive(
self,
message: Optional[Dict],
sender: Agent,
reviewer: Agent,
request_reply: Optional[bool] = None,
silent: Optional[bool] = False,
is_recovery: Optional[bool] = False,
):
"""(Abstract async method) Receive a message from another agent."""
async def a_review(self, message: Union[Dict, str], censored: Agent):
"""
Args:
message:
censored:
Returns:
"""
def reset(self):
"""(Abstract method) Reset the agent."""
async def a_generate_reply(
self,
message: Optional[Dict],
sender: Agent,
reviewer: Agent,
silent: Optional[bool] = False,
**kwargs,
) -> Union[str, Dict, None]:
"""(Abstract async method) Generate a reply based on the received messages.
Args:
message (Optional[Dict]): the message received from another agent.
sender: sender of an Agent instance.
Returns:
str or dict or None: the generated reply. If None, no reply is generated.
"""
async def a_reasoning_reply(
self, messages: Optional[List[Dict]]
) -> Union[str, Dict, None]:
"""
Based on the requirements of the current agent, reason about the current task goal through the LLM.
Args:
messages (Optional[List[Dict]]): the conversation history to reason over.
Returns:
str or dict or None: the generated reply. If None, no reply is generated.
"""
async def a_action_reply(
self,
messages: Optional[str],
sender: Agent,
**kwargs,
) -> Union[str, Dict, None]:
"""
Parse the reasoning result for the current goal and execute it using the current agent's executor.
Args:
messages (Optional[str]): the reasoning result to act on.
sender: sender of an Agent instance.
**kwargs:
Returns:
str or dict or None: the agent action reply. If None, no reply is generated.
"""
async def a_verify_reply(
self,
message: Optional[Dict],
sender: Agent,
reviewer: Agent,
**kwargs,
) -> Union[str, Dict, None]:
"""
Verify whether the current execution result meets the goal's expectations.
Args:
message: the message carrying the execution result to verify.
sender: sender of an Agent instance.
**kwargs:
Returns:
bool, Any: whether verification passed, and the verified (or retry) message.
"""
@dataclass
class AgentResource:
type: str
name: str
introduce: str
@staticmethod
def from_dict(d: Dict[str, Any]) -> Optional[AgentResource]:
if d is None:
return None
return AgentResource(
type=d.get("type"),
name=d.get("name"),
introduce=d.get("introduce"),
)
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
@dataclass
class AgentContext:
conv_id: str
llm_provider: LLMClient
gpts_name: Optional[str] = None
resource_db: Optional[AgentResource] = None
resource_knowledge: Optional[AgentResource] = None
resource_internet: Optional[AgentResource] = None
llm_models: Optional[List[Union[ModelMetadata, str]]] = None
model_priority: Optional[dict] = None
agents: Optional[List[str]] = None
max_chat_round: Optional[int] = 100
max_retry_round: Optional[int] = 10
max_new_tokens: Optional[int] = 1024
temperature: Optional[float] = 0.5
allow_format_str_template: Optional[bool] = False
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
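
As a quick illustration of the two dataclasses above, a minimal usage sketch (hedged: llm_provider would normally be a concrete dbgpt.core.LLMClient implementation; None is used here only to keep the snippet self-contained):

from dbgpt.agent.agents.agent import AgentContext, AgentResource

# Round-trip an AgentResource through its dict form.
resource = AgentResource.from_dict(
    {"type": "db", "name": "sales_db", "introduce": "Local sales database"}
)
assert resource.to_dict()["name"] == "sales_db"

# Assemble a context for one conversation; llm_provider is a placeholder here.
context = AgentContext(
    conv_id="conv-1",
    llm_provider=None,  # assumption: replace with a real LLMClient
    gpts_name="sales_report",
    resource_db=resource,
    llm_models=["gpt-3.5-turbo"],
)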


@@ -0,0 +1,48 @@
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from .agent import Agent
from .expand.code_assistant_agent import CodeAssistantAgent
from .expand.dashboard_assistant_agent import DashboardAssistantAgent
from .expand.data_scientist_agent import DataScientistAgent
from .expand.sql_assistant_agent import SQLAssistantAgent
def get_all_subclasses(cls):
all_subclasses = []
direct_subclasses = cls.__subclasses__()
all_subclasses.extend(direct_subclasses)
for subclass in direct_subclasses:
all_subclasses.extend(get_all_subclasses(subclass))
return all_subclasses
class AgentsMange:
def __init__(self):
self._agents = defaultdict()
def register_agent(self, cls):
self._agents[cls.NAME] = cls
def get_by_name(self, name: str) -> Optional[Type[Agent]]:
if name not in self._agents:
raise ValueError(f"Agent:{name} not register!")
return self._agents[name]
def get_describe_by_name(self, name: str) -> Optional[str]:
return self._agents[name].DEFAULT_DESCRIBE
def all_agents(self):
result = {}
for name, cls in self._agents.items():
result[name] = cls.DEFAULT_DESCRIBE
return result
agent_mange = AgentsMange()
agent_mange.register_agent(CodeAssistantAgent)
agent_mange.register_agent(DashboardAssistantAgent)
agent_mange.register_agent(DataScientistAgent)
agent_mange.register_agent(SQLAssistantAgent)
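
A short sketch of how the registry above is meant to be used; the agent names come from the NAME class attributes of the registered classes, and the module path is assumed from the package layout:

from dbgpt.agent.agents.agents_mange import agent_mange

# Map of agent name -> default description for all registered agents.
print(agent_mange.all_agents())

# Look up a concrete agent class (raises ValueError for unknown names).
code_agent_cls = agent_mange.get_by_name("CodeEngineer")
print(agent_mange.get_describe_by_name("CodeEngineer"))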


@@ -0,0 +1,768 @@
import asyncio
import copy
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Type, Union
from dbgpt.agent.agents.llm.llm_client import AIWrapper
from dbgpt.core.awel import BaseOperator
from dbgpt.core.interface.message import ModelMessageRoleType
from dbgpt.util.error_types import LLMChatError
from ..memory.base import GptsMessage
from ..memory.gpts_memory import GptsMemory
from .agent import Agent, AgentContext
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
logger = logging.getLogger(__name__)
class ConversableAgent(Agent):
DEFAULT_SYSTEM_MESSAGE = "You are a helpful AI Assistant."
MAX_CONSECUTIVE_AUTO_REPLY = (
100 # maximum number of consecutive auto replies (subject to future change)
)
def __init__(
self,
name: str,
describe: str = DEFAULT_SYSTEM_MESSAGE,
memory: Optional[GptsMemory] = None,
agent_context: Optional[AgentContext] = None,
system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "TERMINATE",
default_auto_reply: Optional[Union[str, Dict, None]] = "",
):
super().__init__(name, memory if memory is not None else GptsMemory(), describe)
# a dictionary of conversations, default value is list
# self._oai_messages = defaultdict(list)
self._oai_system_message = [
{"content": system_message, "role": ModelMessageRoleType.SYSTEM}
]
self._rely_messages = []
self._is_termination_msg = (
is_termination_msg
if is_termination_msg is not None
else (lambda x: x.get("content") == "TERMINATE")
)
self.client = AIWrapper(llm_client=agent_context.llm_provider)
self.human_input_mode = human_input_mode
self._max_consecutive_auto_reply = (
max_consecutive_auto_reply
if max_consecutive_auto_reply is not None
else self.MAX_CONSECUTIVE_AUTO_REPLY
)
self.consecutive_auto_reply_counter: int = 0
self._current_retry_counter: int = 0
## By default, the memory of the last 5 rounds of dialogue is retained.
self.dialogue_memory_rounds = 5
self._default_auto_reply = default_auto_reply
self._reply_func_list = []
self._max_consecutive_auto_reply_dict = {}
self.agent_context = agent_context
def register_reply(
self,
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
reply_func: Callable,
position: int = 0,
config: Optional[Any] = None,
reset_config: Optional[Callable] = None,
):
if not isinstance(trigger, (type, str, Agent, Callable, list)):
raise ValueError(
"trigger must be a class, a string, an agent, a callable or a list."
)
self._reply_func_list.insert(
position,
{
"trigger": trigger,
"reply_func": reply_func,
"config": copy.copy(config),
"init_config": config,
"reset_config": reset_config,
},
)
@property
def system_message(self):
"""Return the system message."""
return self._oai_system_message[0]["content"]
def update_system_message(self, system_message: str):
"""Update the system message.
Args:
system_message (str): system message for the ChatCompletion inference.
"""
self._oai_system_message[0]["content"] = system_message
def update_max_consecutive_auto_reply(
self, value: int, sender: Optional[Agent] = None
):
"""Update the maximum number of consecutive auto replies.
Args:
value (int): the maximum number of consecutive auto replies.
sender (Agent): when the sender is provided, only update the max_consecutive_auto_reply for that sender.
"""
if sender is None:
self._max_consecutive_auto_reply = value
for k in self._max_consecutive_auto_reply_dict:
self._max_consecutive_auto_reply_dict[k] = value
else:
self._max_consecutive_auto_reply_dict[sender] = value
def max_consecutive_auto_reply(self, sender: Optional[Agent] = None) -> int:
"""The maximum number of consecutive auto replies."""
return (
self._max_consecutive_auto_reply
if sender is None
else self._max_consecutive_auto_reply_dict[sender]
)
def chat_messages(self) -> List[Dict]:
"""All of this agent's messages, rebuilt from the collective message memory."""
all_gpts_messages = self.memory.message_memory.get_by_agent(
self.agent_context.conv_id, self.name
)
return self._gpts_message_to_ai_message(all_gpts_messages)
def last_message(self, agent: Optional[Agent] = None) -> Optional[Dict]:
"""The last message exchanged with the agent.
Args:
agent (Agent): The agent in the conversation.
If None, the last message this agent exchanged with any agent is returned.
Returns:
The last message exchanged with the agent.
"""
if agent is None:
all_oai_messages = self.chat_messages()
if len(all_oai_messages) <= 0:
return None
return all_oai_messages[-1]
agent_messages = self.memory.message_memory.get_between_agents(
self.agent_context.conv_id, self.name, agent.name
)
if len(agent_messages) <= 0:
raise KeyError(
f"The agent '{agent.name}' is not present in any conversation. No history available for this agent."
)
return self._gpts_message_to_ai_message(agent_messages)[-1]
@staticmethod
def _message_to_dict(message: Union[Dict, str]):
"""Convert a message to a dictionary.
The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary.
"""
if isinstance(message, str):
return {"content": message}
elif isinstance(message, dict):
return message
else:
return dict(message)
def append_rely_message(self, message: Union[Dict, str], role) -> None:
message = self._message_to_dict(message)
message["role"] = role
# create oai message to be appended to the oai conversation that can be passed to oai directly.
self._rely_messages.append(message)
def reset_rely_message(self) -> None:
# Clear the dependency (rely) messages.
self._rely_messages = []
def append_message(self, message: Optional[Dict], role, sender: Agent) -> bool:
"""
Put the received message content into the collective message memory.
Args:
message: the received message.
role: the role of the message.
sender: the agent that sent the message.
Returns:
bool: True if the message was stored, False if it was invalid.
"""
oai_message = {
k: message[k]
for k in (
"content",
"function_call",
"name",
"context",
"action_report",
"review_info",
"current_gogal",
"model_name",
)
if k in message
}
if "content" not in oai_message:
if "function_call" in oai_message:
oai_message[
"content"
] = None # if only function_call is provided, content will be set to None.
else:
return False
oai_message["role"] = "function" if message.get("role") == "function" else role
if "function_call" in oai_message:
oai_message[
"role"
] = "assistant" # only messages with role 'assistant' can have a function call.
oai_message["function_call"] = dict(oai_message["function_call"])
gpts_message: GptsMessage = GptsMessage(
conv_id=self.agent_context.conv_id,
sender=sender.name,
receiver=self.name,
role=role,
rounds=self.consecutive_auto_reply_counter,
current_gogal=oai_message.get("current_gogal", None),
content=oai_message.get("content", None),
context=json.dumps(oai_message["context"])
if "context" in oai_message
else None,
review_info=json.dumps(oai_message["review_info"])
if "review_info" in oai_message
else None,
action_report=json.dumps(oai_message["action_report"])
if "action_report" in oai_message
else None,
model_name=oai_message.get("model_name", None),
)
self.memory.message_memory.append(gpts_message)
return True
async def a_send(
self,
message: Optional[Dict],
recipient: Agent,
reviewer: "Agent",
request_reply: Optional[bool] = True,
silent: Optional[bool] = False,
is_recovery: Optional[bool] = False,
):
await recipient.a_receive(
message=message,
sender=self,
reviewer=reviewer,
request_reply=request_reply,
silent=silent,
is_recovery=is_recovery,
)
def _print_received_message(self, message: Union[Dict, str], sender: Agent):
# print the message received
print(
colored(sender.name, "yellow"),
"(to",
f"{self.name})-[{message.get('model_name', '')}]:\n",
flush=True,
)
message = self._message_to_dict(message)
if message.get("role") == "function":
func_print = (
f"***** Response from calling function \"{message['name']}\" *****"
)
print(colored(func_print, "green"), flush=True)
print(message["content"], flush=True)
print(colored("*" * len(func_print), "green"), flush=True)
else:
content = json.dumps(message.get("content"))
if content is not None:
if "context" in message:
content = AIWrapper.instantiate(
content,
message["context"],
self.agent_context.allow_format_str_template,
)
print(content, flush=True)
if "function_call" in message:
function_call = dict(message["function_call"])
func_print = f"***** Suggested function Call: {function_call.get('name', '(No function name found)')} *****"
print(colored(func_print, "green"), flush=True)
print(
"Arguments: \n",
function_call.get("arguments", "(No arguments found)"),
flush=True,
sep="",
)
print(colored("*" * len(func_print), "green"), flush=True)
review_info = message.get("review_info", None)
if review_info:
approve_print = f">>>>>>>>{sender.name} Review info: \n {'Pass' if review_info.get('approve') else 'Reject'}.{review_info.get('comments')}"
print(colored(approve_print, "green"), flush=True)
action_report = message.get("action_report", None)
if action_report:
action_print = f">>>>>>>>{sender.name} Action report: \n{'execution succeeded' if action_report['is_exe_success'] else 'execution failed'},\n{action_report['content']}"
print(colored(action_print, "blue"), flush=True)
print("\n", "-" * 80, flush=True, sep="")
def _process_received_message(self, message, sender, silent):
message = self._message_to_dict(message)
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
valid = self.append_message(message, None, sender)
if not valid:
raise ValueError(
"Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
)
if not silent:
self._print_received_message(message, sender)
async def a_review(self, message: Union[Dict, str], censored: "Agent"):
return True, None
def _process_action_reply(self, action_reply: Optional[Union[str, Dict, None]]):
if isinstance(action_reply, str):
return {"is_exe_success": True, "content": action_reply}
elif isinstance(action_reply, dict):
return action_reply
elif action_reply is None:
return None
else:
return dict(action_reply)
def _gpts_message_to_ai_message(
self, gpts_messages: Optional[List[GptsMessage]]
) -> List[Dict]:
oai_messages: List[Dict] = []
### Based on the current agent: all messages received are "user" messages, and all messages sent are "assistant" messages.
for item in gpts_messages:
role = ""
if item.role:
role = item.role
else:
if item.receiver == self.name:
role = ModelMessageRoleType.HUMAN
elif item.sender == self.name:
role = ModelMessageRoleType.AI
else:
continue
oai_messages.append(
{
"content": item.content,
"role": role,
"context": json.loads(item.context)
if item.context is not None
else None,
"review_info": json.loads(item.review_info)
if item.review_info is not None
else None,
"action_report": json.loads(item.action_report)
if item.action_report is not None
else None,
}
)
return oai_messages
def process_now_message(self, sender, current_gogal: Optional[str] = None):
### Convert and tailor the information in collective memory into contextual memory available to the current Agent
current_gogal_messages = self._gpts_message_to_ai_message(
self.memory.message_memory.get_between_agents(
self.agent_context.conv_id, self.name, sender.name, current_gogal
)
)
### dependency (rely) messages are added first
cut_messages = []
cut_messages.extend(self._rely_messages)
if len(current_gogal_messages) < self.dialogue_memory_rounds:
cut_messages.extend(current_gogal_messages)
else:
### TODO: allocate history messages based on a token budget
cut_messages.extend(current_gogal_messages[:2])
# end_round = self.dialogue_memory_rounds - 2
cut_messages.extend(current_gogal_messages[-3:])
return cut_messages
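# Illustration of the trimming rule above (a hedged sketch): with
# dialogue_memory_rounds == 5 and eight goal messages m0..m7, the context
# keeps the dependency (_rely_messages) first, then the first two goal
# messages (task setup) and the last three (most recent rounds):
#   msgs = [f"m{i}" for i in range(8)]
#   cut = msgs[:2] + msgs[-3:]
#   assert cut == ["m0", "m1", "m5", "m6", "m7"]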
async def a_system_fill_param(self):
self.update_system_message(self.DEFAULT_SYSTEM_MESSAGE)
async def a_generate_reply(
self,
message: Optional[Dict],
sender: Agent,
reviewer: "Agent",
silent: Optional[bool] = False,
):
## 0. Build the new message
new_message = {}
new_message["context"] = message.get("context", None)
new_message["current_gogal"] = message.get("current_gogal", None)
## 1. LLM reasoning
await self.a_system_fill_param()
await asyncio.sleep(5)  ## TODO: work around the LLM rate limit (e.g. gpt-3.5-turbo)
current_messages = self.process_now_message(
sender, message.get("current_gogal", None)
)
if current_messages is None or len(current_messages) <= 0:
current_messages = [message]
ai_reply, model = await self.a_reasoning_reply(messages=current_messages)
new_message["content"] = ai_reply
new_message["model_name"] = model
## 2. Review the reasoning result
approve = True
comments = None
if reviewer and ai_reply:
approve, comments = await reviewer.a_review(ai_reply, self)
new_message["review_info"] = {"approve": approve, "comments": comments}
## 3. Act on the reasoning result
if approve:
execute_reply = await self.a_action_reply(
message=ai_reply,
sender=sender,
reviewer=reviewer,
)
new_message["action_report"] = self._process_action_reply(execute_reply)
## 4. Verify the reply
return await self.a_verify_reply(new_message, sender, reviewer)
async def a_receive(
self,
message: Optional[Dict],
sender: Agent,
reviewer: "Agent",
request_reply: Optional[bool] = True,
silent: Optional[bool] = False,
is_recovery: Optional[bool] = False,
):
if not is_recovery:
self.consecutive_auto_reply_counter = (
sender.consecutive_auto_reply_counter + 1
)
self._process_received_message(message, sender, silent)
else:
logger.info("Process received retrying")
self.consecutive_auto_reply_counter = sender.consecutive_auto_reply_counter
if request_reply is False or request_reply is None:
logger.info("Messages that do not require a reply")
return
verify_pass, reply = await self.a_generate_reply(
message=message, sender=sender, reviewer=reviewer, silent=silent
)
if verify_pass:
await self.a_send(
message=reply, recipient=sender, reviewer=reviewer, silent=silent
)
else:
self._current_retry_counter += 1
logger.info(
"The generated answer failed verification, so it is sent back for self-optimization."
)
### TODO: exit with an error after reaching the maximum number of self-optimization rounds
await sender.a_send(
message=reply, recipient=self, reviewer=reviewer, silent=silent
)
async def a_verify(self, message: Optional[Dict]):
return True, message
async def _optimization_check(self, message: Optional[Dict]):
need_retry = False
fail_reason = ""
## Check approval results
if "review_info" in message:
review_info = message.get("review_info")
if review_info and not review_info.get("approve"):
fail_reason = review_info.get("comments")
need_retry = True
## Check execution results
if "action_report" in message and not need_retry:
action_report = message["action_report"]
if action_report:
if not action_report["is_exe_success"]:
fail_reason = action_report["content"]
need_retry = True
else:
if (
not action_report["content"]
or len(action_report["content"].strip()) < 1
):
fail_reason = f'The code is executed successfully but the output:{action_report["content"]} is invalid or empty. Please reanalyze the target to generate valid code.'
need_retry = True
## Verify the correctness of the results
if not need_retry:
verify_pass, verify_msg = await self.a_verify(message)
if not verify_pass:
need_retry = True
fail_reason = verify_msg
return need_retry, fail_reason
async def a_verify_reply(
self, message: Optional[Dict], sender: "Agent", reviewer: "Agent", **kwargs
) -> Union[str, Dict, None]:
need_retry, fail_reason = await self._optimization_check(message)
if need_retry:
## Before optimization, wrong answers are stored in memory
await self.a_send(
message=message,
recipient=sender,
reviewer=reviewer,
request_reply=False,
)
## Send the error message to yourself for retry optimization and increase the retry count
retry_message = {}
retry_message["context"] = message.get("context", None)
retry_message["current_gogal"] = message.get("current_gogal", None)
retry_message["model_name"] = message.get("model_name", None)
retry_message["content"] = fail_reason
## Use the original sender to send the retry message to yourself
return False, retry_message
else:
## Verification passed: release the message and reset the retry counter to 0.
self._current_retry_counter = 0
return True, message
async def a_retry_chat(
self,
recipient: "ConversableAgent",
agent_map: dict,
reviewer: "Agent" = None,
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
**context,
):
last_message: GptsMessage = self.memory.message_memory.get_last_message(
self.agent_context.conv_id
)
self.consecutive_auto_reply_counter = last_message.rounds
message = {
"content": last_message.content,
"context": json.loads(last_message.context)
if last_message.context
else None,
"current_gogal": last_message.current_gogal,
"review_info": json.loads(last_message.review_info)
if last_message.review_info
else None,
"action_report": json.loads(last_message.action_report)
if last_message.action_report
else None,
"model_name": last_message.model_name,
}
await self.a_send(
message,
recipient,
reviewer,
request_reply=True,
silent=silent,
is_recovery=True,
)
async def a_initiate_chat(
self,
recipient: "ConversableAgent",
reviewer: "Agent" = None,
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
**context,
):
init_message = self.generate_init_message(**context)
await self.a_send(
{
"content": init_message,
"current_gogal": init_message,
},
recipient,
reviewer,
request_reply=True,
silent=silent,
)
def reset(self):
"""Reset the agent."""
self.clear_history()
self.reset_consecutive_auto_reply_counter()
for reply_func_tuple in self._reply_func_list:
if reply_func_tuple["reset_config"] is not None:
reply_func_tuple["reset_config"](reply_func_tuple["config"])
else:
reply_func_tuple["config"] = copy.copy(reply_func_tuple["init_config"])
def reset_consecutive_auto_reply_counter(self):
"""Reset the consecutive_auto_reply_counter of the sender."""
self.consecutive_auto_reply_counter = 0
def clear_history(self, agent: Optional[Agent] = None):
"""Clear the chat history of the agent.
Args:
agent: the agent with whom the chat history to clear. If None, clear the chat history with all agents.
"""
pass
# if agent is None:
# self._oai_messages.clear()
# else:
# self._oai_messages[agent].clear()
def _get_model_priority(self):
llm_models_priority = self.agent_context.model_priority
if llm_models_priority:
if self.name in llm_models_priority:
model_priority = llm_models_priority[self.name]
else:
model_priority = llm_models_priority["default"]
return model_priority
else:
return None
def _filter_health_models(self, need_uses: Optional[list]):
all_models = self.agent_context.llm_models
can_uses = []
for item in all_models:
if item.model in need_uses:
can_uses.append(item)
return can_uses
def _select_llm_model(self, old_model: str = None):
"""
LLM model selector. Currently only manual selection is supported; more strategies will be added in the future.
Returns:
str: the name of the selected model.
"""
all_models = self.agent_context.llm_models
model_priority = self._get_model_priority()
if model_priority and len(model_priority) > 0:
can_uses = self._filter_health_models(model_priority)
if len(can_uses) > 0:
return can_uses[0].model
now_model = all_models[0]
if old_model:
filtered_list = [item for item in all_models if item.model != old_model]
if filtered_list and len(filtered_list) >= 1:
now_model = filtered_list[0]
return now_model.model
async def a_reasoning_reply(
self, messages: Optional[List[Dict]] = None
) -> Union[str, Dict, None]:
"""(async) Reply based on the conversation history and the sender.
Args:
messages: a list of messages in the conversation history.
default_reply (str or dict): default reply.
sender: sender of an Agent instance.
exclude: a list of functions to exclude.
Returns:
str or dict or None: reply. None if no reply is generated.
"""
last_model = None
last_err = None
retry_count = 0
while retry_count < 3:
llm_model = self._select_llm_model(last_model)
try:
response = await self.client.create(
context=messages[-1].pop("context", None),
messages=self._oai_system_message + messages,
llm_model=llm_model,
max_new_tokens=self.agent_context.max_new_tokens,
temperature=self.agent_context.temperature,
)
return response, llm_model
except LLMChatError as e:
logger.error(f"model:{llm_model} generate Failed!{str(e)}")
retry_count += 1
last_model = llm_model
last_err = str(e)
await asyncio.sleep(10) ## TODORate limit reached for gpt-3.5-turbo
if last_err:
raise ValueError(last_err)
async def a_action_reply(
self,
message: Optional[str] = None,
sender: Optional[Agent] = None,
reviewer: "Agent" = None,
exclude: Optional[List[Callable]] = None,
**kwargs,
) -> Union[str, Dict, None]:
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
if asyncio.coroutines.iscoroutinefunction(reply_func):
final, reply = await reply_func(
self,
message=message,
sender=sender,
reviewer=reviewer,
config=reply_func_tuple["config"],
)
else:
final, reply = reply_func(
self,
message=message,
sender=sender,
reviewer=reviewer,
config=reply_func_tuple["config"],
)
if final:
return reply
return self._default_auto_reply
def _match_trigger(self, trigger, sender):
"""Check if the sender matches the trigger."""
if trigger is None:
return sender is None
elif isinstance(trigger, str):
return trigger == sender.name
elif isinstance(trigger, type):
return isinstance(sender, trigger)
elif isinstance(trigger, Agent):
return trigger == sender
elif isinstance(trigger, Callable):
return trigger(sender)
elif isinstance(trigger, list):
return any(self._match_trigger(t, sender) for t in trigger)
else:
raise ValueError(f"Unsupported trigger type: {type(trigger)}")
def generate_init_message(self, **context) -> Union[str, Dict]:
"""Generate the initial message for the agent.
Override this function to customize the initial message based on user's request.
If not overridden, "message" needs to be provided in the context.
"""
return context["message"]


@@ -0,0 +1,261 @@
import json
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from dbgpt.core.awel import BaseOperator
from dbgpt.util.code_utils import UNKNOWN, execute_code, extract_code, infer_lang
from dbgpt.util.string_utils import str_to_bool
from ...memory.gpts_memory import GptsMemory
from ..agent import Agent, AgentContext
from ..base_agent import ConversableAgent
from dbgpt.core.interface.message import ModelMessageRoleType
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
class CodeAssistantAgent(ConversableAgent):
"""(In preview) Assistant agent, designed to solve a task with LLM.
AssistantAgent is a subclass of ConversableAgent configured with a default system message.
The default system message is designed to solve a task with LLM,
including suggesting python code blocks and debugging.
`human_input_mode` is default to "NEVER"
and `code_execution_config` is default to False.
This agent doesn't execute code by default, and expects the user to execute the code.
"""
DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
*** IMPORTANT REMINDER ***
- The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user. Don't ask users to copy and paste results. Instead, the "Print" function must be used for output when relevant.
- When using code, you must indicate the script type in the code block. Please don't include multiple code blocks in one response.
- If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line.
- If you receive user input that indicates an error in the code execution, fix the error and output the complete code again. It is recommended to use the complete code rather than partial code or code changes. If the error cannot be fixed, or the task is not resolved even after the code executes successfully, analyze the problem, revisit your assumptions, gather additional information you need from historical conversation records, and consider trying a different approach.
- Unless necessary, give priority to solving problems with python code. If it involves downloading files or storing data locally, please use "Print" to output the full file path of the stored data and a brief introduction to the data.
- The output content of the "Print" function will be passed to other LLM agents. Please ensure that the information output by the "Print" function has been streamlined as much as possible and only retains key data information.
- The code is executed without user participation. It is forbidden to use methods that will block the process or need to be shut down, such as the plt.show() method of matplotlib.pyplot as plt.
"""
CHECK_RESULT_SYSTEM_MESSAGE = f"""
You are an expert in analyzing the results of task execution.
Your responsibility is to analyze the task goals and execution results provided by the user, and then make a judgment. You need to answer according to the following rules:
Rule 1: Analysis and judgment only focus on whether the execution result is related to the task goal and whether it is answering the target question, but does not pay attention to whether the result content is reasonable or the correctness of the scope boundary of the answer content.
Rule 2: If the target is a calculation type, there is no need to verify the correctness of the calculation of the values in the execution result.
As long as the execution result meets the task goal according to the above rules, True will be returned, otherwise False will be returned. Only returns True or False.
"""
NAME = "CodeEngineer"
DEFAULT_DESCRIBE = """According to the current planning steps, write python/shell code to solve the problem, such as: data crawling, data sorting and conversion, etc. Wrap the code in a code block of the specified script type. Users cannot modify your code. So don't suggest incomplete code that needs to be modified by others.
Don't include multiple code blocks in one response. Don't ask others to copy and paste the results
"""
def __init__(
self,
agent_context: AgentContext,
memory: Optional[GptsMemory] = None,
describe: Optional[str] = DEFAULT_DESCRIBE,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "NEVER",
code_execution_config: Optional[Union[Dict, Literal[False]]] = None,
**kwargs,
):
"""
Args:
name (str): agent name.
system_message (str): system message for the ChatCompletion inference.
Please override this attribute if you want to reprogram the agent.
llm_config (dict): llm inference configuration.
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
for available options.
is_termination_msg (function): a function that takes a message in the form of a dictionary
and returns a boolean value indicating if this received message is a termination message.
The dict can contain the following keys: "content", "role", "name", "function_call".
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
The limit only plays a role when human_input_mode is not "ALWAYS".
**kwargs (dict): Please refer to other kwargs in
[ConversableAgent](conversable_agent#__init__).
"""
super().__init__(
name=self.NAME,
memory=memory,
describe=describe,
system_message=self.DEFAULT_SYSTEM_MESSAGE,
is_termination_msg=is_termination_msg,
max_consecutive_auto_reply=max_consecutive_auto_reply,
human_input_mode=human_input_mode,
agent_context=agent_context,
**kwargs,
)
self._code_execution_config: Union[Dict, Literal[False]] = (
{} if code_execution_config is None else code_execution_config
)
### register the code execution reply function
self.register_reply(Agent, CodeAssistantAgent.generate_code_execution_reply)
def _vis_code_idea(self, code, exit_success, log, language):
param = {}
param["exit_success"] = exit_success
param["language"] = language
param["code"] = code
param["log"] = log
return f"```vis-code\n{json.dumps(param)}\n```"
async def generate_code_execution_reply(
self,
message: Optional[str] = None,
sender: Optional[Agent] = None,
reviewer: Optional[Agent] = None,
config: Optional[Union[Dict, Literal[False]]] = None,
):
"""Generate a reply using code execution."""
code_execution_config = (
config if config is not None else self._code_execution_config
)
if code_execution_config is False:
return False, None
last_n_messages = code_execution_config.pop("last_n_messages", 1)
# iterate through the last n messages in reverse;
# if code blocks are found, execute the code blocks and return the output
# if no code blocks are found, continue
code_blocks = extract_code(message)
if len(code_blocks) < 1:
await self.a_send(
f"Failed to get a valid answer, {message}", sender, None, silent=True
)
return False, None
elif len(code_blocks) > 1 and code_blocks[0][0] == UNKNOWN:
await self.a_send(
f"Failed to get a valid answer, {message}", self, reviewer, silent=True
)
# found code blocks, execute code and push "last_n_messages" back
exitcode, logs = self.execute_code_blocks(code_blocks)
code_execution_config["last_n_messages"] = last_n_messages
exit_success = exitcode == 0
if exit_success:
return True, {
"is_exe_success": exit_success,
"content": f"{logs}",
"view": self._vis_code_idea(
code_blocks, exit_success, logs, code_blocks[0][0]
),
}
else:
return True, {
"is_exe_success": exit_success,
"content": f"exitcode: {exitcode} (execution failed)\n {logs}",
"view": self._vis_code_idea(
code_blocks, exit_success, logs, code_blocks[0][0]
),
}
async def a_verify(self, message: Optional[Dict]):
self.update_system_message(self.CHECK_RESULT_SYSTEM_MESSAGE)
task_gogal = message.get("current_gogal", None)
action_report = message.get("action_report", None)
task_result = ""
if action_report:
task_result = action_report.get("content", "")
check_reult, model = await self.a_reasoning_reply(
[
{
"role": ModelMessageRoleType.HUMAN,
"content": f"""Please understand the following task objectives and results and give your judgment:
Task Gogal: {task_gogal}
Execution Result: {task_result}
Only True or False is returned.
""",
}
]
)
sucess = str_to_bool(check_reult)
fail_reason = None
if sucess == False:
fail_reason = "The execution result of the code you wrote is judged as not answering the task question. Please re-understand and complete the task."
return sucess, fail_reason
@property
def use_docker(self) -> Union[bool, str, None]:
"""Bool value of whether to use docker to execute the code,
or str value of the docker image name to use, or None when code execution is disabled.
"""
return (
None
if self._code_execution_config is False
else self._code_execution_config.get("use_docker")
)
def run_code(self, code, **kwargs):
"""Run the code and return the result.
Override this function to modify the way to run the code.
Args:
code (str): the code to be executed.
**kwargs: other keyword arguments.
Returns:
A tuple of (exitcode, logs, image).
exitcode (int): the exit code of the code execution.
logs (str): the logs of the code execution.
image (str or None): the docker image used for the code execution.
"""
return execute_code(code, **kwargs)
def execute_code_blocks(self, code_blocks):
"""Execute the code blocks and return the result."""
logs_all = ""
for i, code_block in enumerate(code_blocks):
lang, code = code_block
if not lang:
lang = infer_lang(code)
print(
colored(
f"\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...",
"red",
),
flush=True,
)
if lang in ["bash", "shell", "sh"]:
exitcode, logs, image = self.run_code(
code, lang=lang, **self._code_execution_config
)
elif lang in ["python", "Python"]:
if code.startswith("# filename: "):
filename = code[11 : code.find("\n")].strip()
else:
filename = None
exitcode, logs, image = self.run_code(
code,
lang="python",
filename=filename,
**self._code_execution_config,
)
else:
# In case the language is not supported, we return an error message.
exitcode, logs, image = (
1,
f"unknown language {lang}",
None,
)
# raise NotImplementedError
if image is not None:
self._code_execution_config["use_docker"] = image
logs_all += "\n" + logs
if exitcode != 0:
return exitcode, logs_all
return exitcode, logs_all
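
For reference, the (lang, code) tuples consumed by execute_code_blocks are produced by extract_code from dbgpt.util.code_utils; a minimal sketch of the round trip (the reply text is illustrative):

from dbgpt.util.code_utils import extract_code

reply = "Here you go:\n```python\nprint('hello from the agent')\n```"
code_blocks = extract_code(reply)  # e.g. [("python", "print('hello from the agent')")]
# agent.execute_code_blocks(code_blocks) would execute the block and
# return (exitcode, logs) for the action report.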


@@ -0,0 +1,121 @@
import json
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from dbgpt.agent.plugin.commands.command_mange import ApiCall
from dbgpt.util.json_utils import find_json_objects
from ...memory.gpts_memory import GptsMemory
from ..agent import Agent, AgentContext
from ..base_agent import ConversableAgent
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
from dbgpt._private.config import Config
from dbgpt.core.awel import BaseOperator
CFG = Config()
class DashboardAssistantAgent(ConversableAgent):
"""(In preview) Assistant agent, designed to solve a task with LLM.
AssistantAgent is a subclass of ConversableAgent configured with a default system message.
The default system message is designed to solve a task with LLM,
including suggesting python code blocks and debugging.
`human_input_mode` is default to "NEVER"
and `code_execution_config` is default to False.
This agent doesn't execute code by default, and expects the user to execute the code.
"""
DEFAULT_SYSTEM_MESSAGE = """Please read the historical messages, collect the generated JSON data of the analysis sql results, and integrate it into the following JSON format to return:
[
{{
"display_type":"The chart rendering method selected for the task 1 sql",
"sql": "Analysis sql of the step task 1",
"thought":"thoughts summary to say to user"
}},
{{
"display_type":"The chart rendering method selected for the task 2 sql",
"sql": "Analysis sql of the step task 2",
"thought":"thoughts summary to say to user"
}}
]
Make sure the response is correct json and can be parsed by Python json.loads.
"""
DEFAULT_DESCRIBE = "Integrate analytical data generated by data scientists into a required format for building sales reports."
NAME = "Reporter"
def __init__(
self,
memory: GptsMemory,
agent_context: AgentContext,
describe: Optional[str] = DEFAULT_DESCRIBE,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "NEVER",
**kwargs,
):
super().__init__(
name=self.NAME,
memory=memory,
describe=describe,
system_message=self.DEFAULT_SYSTEM_MESSAGE,
is_termination_msg=is_termination_msg,
max_consecutive_auto_reply=max_consecutive_auto_reply,
human_input_mode=human_input_mode,
agent_context=agent_context,
**kwargs,
)
self.register_reply(Agent, DashboardAssistantAgent.generate_dashboard_reply)
self.agent_context = agent_context
self.db_connect = CFG.LOCAL_DB_MANAGE.get_connect(
self.agent_context.resource_db.get("name", None)
)
async def generate_dashboard_reply(
self,
message: Optional[str] = None,
sender: Optional[Agent] = None,
reviewer: Optional[Agent] = None,
config: Optional[Union[Dict, Literal[False]]] = None,
):
"""Generate a reply using code execution."""
json_objects = find_json_objects(message)
plan_objects = []
fail_reason = (
"Please recheck your answer: no usable plans were generated in the correct format."
)
json_count = len(json_objects)
response_succ = True
view = None
content = None
if json_count != 1:
### Answer failed, turn on automatic repair
fail_reason += f" There are currently {json_count} json contents."
response_succ = False
else:
try:
chart_objs = json_objects[0]
content = json.dumps(chart_objs)
vis_client = ApiCall()
view = vis_client.display_dashboard_vis(
chart_objs, self.db_connect.run_to_df
)
except Exception as e:
fail_reason += f"Return json structure error and cannot be converted to a sql-rendered chart{str(e)}"
rensponse_succ = False
if not rensponse_succ:
content = fail_reason
return True, {
"is_exe_success": rensponse_succ,
"content": content,
"view": view,
}
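
The reply function above relies on find_json_objects to pull the dashboard JSON out of the model's free-form answer; a small sketch of that contract (values illustrative):

from dbgpt.util.json_utils import find_json_objects

reply = (
    "Here is the report data:\n"
    '[{"display_type": "Table", "sql": "SELECT 1", "thought": "demo"}]'
)
json_objects = find_json_objects(reply)
assert len(json_objects) == 1  # the agent requires exactly one JSON value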


@@ -0,0 +1,162 @@
import json
import logging
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from dbgpt._private.config import Config
from dbgpt.agent.plugin.commands.command_mange import ApiCall
from dbgpt.core.awel import BaseOperator
from dbgpt.util.json_utils import find_json_objects
from ...memory.gpts_memory import GptsMemory
from ..agent import Agent, AgentContext
from ..base_agent import ConversableAgent
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
CFG = Config()
logger = logging.getLogger(__name__)
class DataScientistAgent(ConversableAgent):
"""(In preview) Assistant agent, designed to solve a task with LLM.
AssistantAgent is a subclass of ConversableAgent configured with a default system message.
The default system message is designed to solve a task with LLM,
including suggesting python code blocks and debugging.
`human_input_mode` is default to "NEVER"
and `code_execution_config` is default to False.
This agent doesn't execute code by default, and expects the user to execute the code.
"""
DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant who is good at writing SQL for various databases.
Based on the given data structure information, use the correct {dialect} SQL to analyze and solve the task, subject to the following constraints.
Data structure information:
{data_structure}
constraint:
1. Please choose the best one from the display methods given below for data display, and put the type name into the name parameter value that returns the required format. If you can't find the most suitable display method, use Table as the display method. The available data display methods are as follows: {disply_type}
2. Please check the sql you generated. It is forbidden to use column names that do not exist in the table, and it is forbidden to make up fields and tables that do not exist.
3. Pay attention to the data association between tables and tables, and you can use multiple tables at the same time to generate a SQL.
Please think step by step and return it in the following json format
{{
"display_type":"The chart rendering method currently selected by SQL",
"sql": "Analysis sql of the current step task",
"thought":"Summary of thoughts to the user"
}}
Make sure the response is correct json and can be parsed by Python json.loads.
"""
DEFAULT_DESCRIBE = """It is possible use the local database to generate analysis SQL to obtain data based on the table structure, and at the same time generate visual charts of the corresponding data. Note that only local databases can be queried."""
NAME = "DataScientist"
def __init__(
self,
memory: GptsMemory,
agent_context: AgentContext,
describe: Optional[str] = DEFAULT_DESCRIBE,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "NEVER",
**kwargs,
):
super().__init__(
name=self.NAME,
memory=memory,
describe=describe,
system_message=self.DEFAULT_SYSTEM_MESSAGE,
is_termination_msg=is_termination_msg,
max_consecutive_auto_reply=max_consecutive_auto_reply,
human_input_mode=human_input_mode,
agent_context=agent_context,
**kwargs,
)
self.register_reply(Agent, DataScientistAgent.generate_analysis_chart_reply)
self.agent_context = agent_context
self.db_connect = CFG.LOCAL_DB_MANAGE.get_connect(
self.agent_context.resource_db.get("name", None)
)
async def a_system_fill_param(self):
params = {
"data_structure": self.db_connect.get_table_info(),
"disply_type": ApiCall.default_chart_type_promot(),
"dialect": self.db_connect.db_type,
}
self.update_system_message(self.DEFAULT_SYSTEM_MESSAGE.format(**params))
async def generate_analysis_chart_reply(
self,
message: Optional[str] = None,
sender: Optional[Agent] = None,
reviewer: Optional[Agent] = None,
config: Optional[Union[Dict, Literal[False]]] = None,
):
"""Generate a reply using code execution."""
json_objects = find_json_objects(message)
fail_reason = "The required json format answer was not generated."
json_count = len(json_objects)
response_succ = True
view = None
content = None
if json_count != 1:
### Answer failed, turn on automatic repair
response_succ = False
else:
try:
content = json.dumps(json_objects[0])
except Exception as e:
content = (
f"There is a format problem with the json of the answer: {str(e)}"
)
response_succ = False
try:
vis_client = ApiCall()
view = vis_client.display_only_sql_vis(
json_objects[0], self.db_connect.run_to_df
)
except Exception as e:
view = f"```vis-convert-error\n{content}\n```"
return True, {
"is_exe_success": rensponse_succ,
"content": content,
"view": view,
}
async def a_verify(self, message: Optional[Dict]):
action_reply = message.get("action_report", None)
if action_reply.get("is_exe_success", False) == False:
return (
False,
f"Please check your answer, {action_reply.get('content', '')}.",
)
action_reply_obj = json.loads(action_reply.get("content", ""))
sql = action_reply_obj.get("sql", None)
if not sql:
return (
False,
"Please check your answer, the sql information that needs to be generated is not found.",
)
try:
columns, values = self.db_connect.query_ex(sql)
if not values or len(values) <= 0:
return (
False,
"Please check your answer, the generated SQL cannot query any data.",
)
else:
logger.info(
f"reply check success! There are {len(values)} rows of data"
)
return True, None
except Exception as e:
return (
False,
f"SQL execution error, please re-read the historical information to fix this SQL. The error message is as follows:{str(e)}",
)
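
A sketch of the verification path above: a_verify re-parses the JSON in the action report, extracts the SQL, and rejects replies whose SQL returns no rows (the report content is illustrative):

import json

action_report = {
    "is_exe_success": True,
    "content": '{"display_type": "Table", "sql": "SELECT name FROM user", "thought": "demo"}',
}
sql = json.loads(action_report["content"]).get("sql")
# a_verify would then run self.db_connect.query_ex(sql) and return
# (False, "...cannot query any data.") when no rows come back.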


@@ -0,0 +1,125 @@
import json
import logging
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from dbgpt.core.awel import BaseOperator
from dbgpt.util.json_utils import find_json_objects
from ...memory.gpts_memory import GptsMemory
from ..agent import Agent, AgentContext
from ..base_agent import ConversableAgent
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
logger = logging.getLogger(__name__)
class PluginAgent(ConversableAgent):
"""(In preview) Assistant agent, designed to solve a task with LLM.
AssistantAgent is a subclass of ConversableAgent configured with a default system message.
The default system message is designed to solve a task with LLM,
including suggesting python code blocks and debugging.
`human_input_mode` is default to "NEVER"
and `code_execution_config` is default to False.
This agent doesn't execute code by default, and expects the user to execute the code.
"""
DEFAULT_SYSTEM_MESSAGE = """
You are a helpful AI tool agent assistant.
You have been assigned the following list of tools, please select the most appropriate tool to complete the task based on the current user's goals:
{tool_list}
*** IMPORTANT REMINDER ***
Please read the parameter definition of the tool carefully and extract the specific parameters required to execute the tool from the user's goal.
Please output the selected tool name and specific parameter information in json in the following required format, refer to the following example:
user: Search for the latest hot financial news
assistant: {{
"tool_name":"google-search",
"args": "{{
"query": "latest hot financial news",
}}",
"thought":"I will use the google-search tool to search for the latest hot financial news."
}}
Please think step by step and return it in the following json format
{{
"tool_name":"The chart rendering method currently selected by SQL",
"args": "{{
"arg name1": "arg value1",
"arg name2": "arg value2",
}}",
"thought":"Summary of thoughts to the user"
}}
Make sure the response is correct json and can be parsed by Python json.loads.
"""
DEFAULT_DESCRIBE = """You can use the following tools to complete the task objectives, tool information: {tool-infos}"""
NAME = "ToolScientist"
def __init__(
self,
memory: GptsMemory,
agent_context: AgentContext,
describe: Optional[str] = DEFAULT_DESCRIBE,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "NEVER",
**kwargs,
):
super().__init__(
name=self.NAME,
memory=memory,
describe=describe,
system_message=self.DEFAULT_SYSTEM_MESSAGE,
is_termination_msg=is_termination_msg,
max_consecutive_auto_reply=max_consecutive_auto_reply,
human_input_mode=human_input_mode,
agent_context=agent_context,
**kwargs,
)
self.register_reply(Agent, PluginAgent.tool_call)
self.agent_context = agent_context
async def a_system_fill_param(self):
params = {
"tool_infos": self.db_connect.get_table_info(),
"dialect": self.db_connect.db_type,
}
self.update_system_message(self.DEFAULT_SYSTEM_MESSAGE.format(**params))
async def tool_call(
self,
message: Optional[str] = None,
sender: Optional[Agent] = None,
reviewer: Optional[Agent] = None,
config: Optional[Union[Dict, Literal[False]]] = None,
):
"""Generate a reply using code execution."""
json_objects = find_json_objects(message)
fail_reason = "The required json format answer was not generated."
json_count = len(json_objects)
response_succ = True
view = None
content = None
if json_count != 1:
### Answer failed, turn on automatic repair
response_succ = False
else:
try:
view = ""
except Exception as e:
view = f"```vis-convert-error\n{content}\n```"
return True, {
"is_exe_success": rensponse_succ,
"content": content,
"view": view,
}
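
The JSON contract the plugin prompt asks the model to follow, sketched with the same parser used by tool_call (tool name and arguments are illustrative):

from dbgpt.util.json_utils import find_json_objects

reply = (
    '{"tool_name": "google-search", '
    '"args": {"query": "latest hot financial news"}, '
    '"thought": "demo"}'
)
tool_call = find_json_objects(reply)[0]
print(tool_call["tool_name"], tool_call["args"])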


@@ -0,0 +1,111 @@
from typing import Callable, Dict, List, Literal, Optional, Union
from dbgpt.agent.agents.base_agent import ConversableAgent
from dbgpt.core.awel import BaseOperator
from dbgpt.agent.plugin.commands.command_mange import ApiCall
from ...memory.gpts_memory import GptsMemory
from ..agent import Agent, AgentContext
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
from dbgpt._private.config import Config
CFG = Config()
class SQLAssistantAgent(ConversableAgent):
"""(In preview) Assistant agent, designed to solve a task with LLM.
AssistantAgent is a subclass of ConversableAgent configured with a default system message.
The default system message is designed to solve a task with LLM,
including suggesting python code blocks and debugging.
`human_input_mode` is default to "NEVER"
and `code_execution_config` is default to False.
This agent doesn't execute code by default, and expects the user to execute the code.
"""
DEFAULT_SYSTEM_MESSAGE = """You are a SQL expert and answer user questions by writing SQL using the following data structures.
Use the following data structure to write the best mysql SQL for the user's problem.
Data Structure information:
{data_structure}
- Please ensure that the SQL is correct and high-performance.
- Please be careful not to use tables or fields that are not mentioned.
- Make sure to only return SQL.
"""
DEFAULT_DESCRIBE = """You can analyze data with a known structure through SQL and generate a single analysis chart for a given target. Please note that you do not have the ability to obtain and process data and can only perform data analysis based on a given structure. If the task goal cannot or does not need to be solved by SQL analysis, please do not use"""
NAME = "SqlEngineer"
def __init__(
self,
memory: GptsMemory,
agent_context: AgentContext,
describe: Optional[str] = DEFAULT_DESCRIBE,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "NEVER",
**kwargs,
):
super().__init__(
name=self.NAME,
memory=memory,
describe=describe,
system_message=self.DEFAULT_SYSTEM_MESSAGE,
is_termination_msg=is_termination_msg,
max_consecutive_auto_reply=max_consecutive_auto_reply,
human_input_mode=human_input_mode,
agent_context=agent_context,
**kwargs,
)
self.register_reply(Agent, SQLAssistantAgent.generate_analysis_chart_reply)
self.agent_context = agent_context
self.db_connect = CFG.LOCAL_DB_MANAGE.get_connect(
self.agent_context.resource_db.get("name", None)
)
async def a_system_fill_param(self):
params = {
"data_structure": self.db_connect.get_table_info(),
"disply_type": ApiCall.default_chart_type_promot(),
"dialect": self.db_connect.db_type,
}
self.update_system_message(self.DEFAULT_SYSTEM_MESSAGE.format(**params))
async def generate_analysis_chart_reply(
self,
message: Optional[str] = None,
sender: Optional[Agent] = None,
reviewer: Optional[Agent] = None,
config: Optional[Union[Dict, Literal[False]]] = None,
):
"""Generate a reply using code execution."""
# Check whether the reply contains a plugin (chart) call; if so, render the
# SQL result as a chart, otherwise report failure.
self.api_call = ApiCall(display_registry=[])
if self.api_call.check_have_plugin_call(message):
exit_success = True
try:
chart_vis = self.api_call.display_sql_llmvis(
message, self.db_connect.run_to_df
)
except Exception as e:
err_info = f"{str(e)}"
exit_success = False
output = chart_vis if exit_success else err_info
else:
exit_success = False
output = message
return True, {"is_exe_success": exit_success, "content": f"{output}"}


@@ -0,0 +1,35 @@
from typing import Dict, Optional, Union
from dbgpt.core.interface.llm import ModelRequest
def _build_model_request(
input_value: Union[Dict, str], model: Optional[str] = None
) -> ModelRequest:
"""Build model request from input value.
Args:
input_value(str or dict): input value
model(Optional[str]): model name
Returns:
ModelRequest: model request, pass to llm client
"""
if isinstance(input_value, str):
return ModelRequest._build(model, input_value)
elif isinstance(input_value, dict):
params = {
"model": input_value.get("model"),
"messages": input_value.get("messages"),
"temperature": input_value.get("temperature", None),
"max_new_tokens": input_value.get("max_new_tokens", None),
"stop": input_value.get("stop", None),
"stop_token_ids": input_value.get("stop_token_ids", None),
"context_len": input_value.get("context_len", None),
"echo": input_value.get("echo", None),
"span_id": input_value.get("span_id", None),
}
return ModelRequest(**params)
else:
raise ValueError("Build model request input Error!")

View File

@@ -0,0 +1,194 @@
import json
import logging
import traceback
from typing import Callable, Dict, Optional, Union
from dbgpt.core import LLMClient
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.util.error_types import LLMChatError
from dbgpt.util.tracer import root_tracer
from ..llm.llm import _build_model_request
logger = logging.getLogger(__name__)
class AIWrapper:
cache_path_root: str = ".cache"
extra_kwargs = {
"cache_seed",
"filter_func",
"allow_format_str_template",
"context",
"llm_model",
}
def __init__(
self, llm_client: LLMClient, output_parser: Optional[BaseOutputParser] = None
):
self.llm_echo = False
self.model_cache_enable = False
self._llm_client = llm_client
self._output_parser = output_parser or BaseOutputParser(is_stream_out=False)
@classmethod
def instantiate(
cls,
template: Optional[Union[str, Callable]] = None,
context: Optional[Dict] = None,
allow_format_str_template: Optional[bool] = False,
):
if not context or template is None:
return template
if isinstance(template, str):
return template.format(**context) if allow_format_str_template else template
return template(context)
def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
"""Prime the create_config with additional_kwargs."""
# Validate the config
prompt = create_config.get("prompt")
messages = create_config.get("messages")
if (prompt is None) == (messages is None):
raise ValueError(
"Either prompt or messages should be in create config but not both."
)
context = extra_kwargs.get("context")
if context is None:
# No need to instantiate if no context is provided.
return create_config
# Instantiate the prompt or messages
allow_format_str_template = extra_kwargs.get("allow_format_str_template", False)
# Make a copy of the config
params = create_config.copy()
if prompt is not None:
# Instantiate the prompt
params["prompt"] = self.instantiate(
prompt, context, allow_format_str_template
)
elif context:
# Instantiate the messages
params["messages"] = [
{
**m,
"content": self.instantiate(
m["content"], context, allow_format_str_template
),
}
if m.get("content")
else m
for m in messages
]
return params
def _separate_create_config(self, config):
"""Separate the config into create_config and extra_kwargs."""
create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
return create_config, extra_kwargs
def _get_key(self, config):
"""Get a unique identifier of a configuration.
Args:
config (dict or list): A configuration.
Returns:
tuple: A unique identifier which can be used as a key for a dict.
"""
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
copied = False
for key in NON_CACHE_KEY:
if key in config:
if not copied:
config = config.copy()
copied = True
config.pop(key)
return json.dumps(config, sort_keys=True)
async def create(self, **config):
# merge the input config with the i-th config in the config list
full_config = {**config}
# separate the config into create_config and extra_kwargs
create_config, extra_kwargs = self._separate_create_config(full_config)
# construct the create params
params = self._construct_create_params(create_config, extra_kwargs)
# get the cache_seed, filter_func and context
cache_seed = extra_kwargs.get("cache_seed", 66)
filter_func = extra_kwargs.get("filter_func")
context = extra_kwargs.get("context")
llm_model = extra_kwargs.get("llm_model")
if context:
use_cache = context.get("use_cache", True)
if not use_cache:
cache_seed = None
# # Try to load the response from cache
# if cache_seed is not None:
# with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
# # Try to get the response from cache
# key = self._get_key(params)
# response = cache.get(key, None)
# if response is not None:
# # check the filter
# pass_filter = filter_func is None or filter_func(context=context, response=response)
# if pass_filter :
# # Return the response if it passes the filter
# # TODO: add response.cost
# return response
try:
response = await self._completions_create(llm_model, params)
except LLMChatError as e:
logger.debug(f"{llm_model} generate failed!{str(e)}")
raise e
else:
pass_filter = filter_func is None or filter_func(
context=context, response=response
)
if pass_filter:
# Return the response if it passes the filter
return response
def _get_span_metadata(self, payload: Dict) -> Dict:
metadata = {k: v for k, v in payload.items()}
metadata["messages"] = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
)
return metadata
def _llm_messages_convert(self, params):
gpts_messages = params["messages"]
### TODO
return gpts_messages
async def _completions_create(self, llm_model, params):
payload = {
"model": llm_model,
"prompt": params.get("prompt"),
"messages": self._llm_messages_convert(params),
"temperature": float(params.get("temperature")),
"max_new_tokens": int(params.get("max_new_tokens")),
"echo": self.llm_echo,
}
logger.info(f"Request: \n{payload}")
span = root_tracer.start_span(
"Agent.llm_client.no_streaming_call",
metadata=self._get_span_metadata(payload),
)
payload["span_id"] = span.span_id
payload["model_cache_enable"] = self.model_cache_enable
try:
model_request = _build_model_request(payload)
model_output = await self._llm_client.generate(model_request)
parsed_output = self._output_parser.parse_model_nostream_resp(
model_output, "###"
)
return parsed_output
except Exception as e:
logger.error(
f"Call LLMClient error, {str(e)}, detail: {traceback.format_exc()}"
)
raise LLMChatError(original_exception=e) from e
finally:
span.end()
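A minimal sketch of driving the wrapper, assuming an `LLMClient` instance `llm_client` is already available; parameter values are illustrative.
wrapper = AIWrapper(llm_client=llm_client)
reply = await wrapper.create(
    messages=[{"role": "human", "content": "Summarize {topic}."}],
    context={"topic": "sales trends"},  # routed into extra_kwargs
    allow_format_str_template=True,     # lets {topic} be filled from context
    llm_model="vicuna-13b-v1.5",
    temperature=0.5,
    max_new_tokens=1024,
)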

View File

@@ -0,0 +1,452 @@
import json
import logging
import random
import re
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from dbgpt.core.awel import BaseOperator
from dbgpt.util.string_utils import str_to_bool
from ..common.schema import Status
from ..memory.gpts_memory import GptsMemory, GptsMessage, GptsPlan
from .agent import Agent, AgentContext
from .base_agent import ConversableAgent
from dbgpt.core.interface.message import ModelMessageRoleType
logger = logging.getLogger(__name__)
@dataclass
class PlanChat:
"""(In preview) A group chat class that contains the following data fields:
- agents: a list of participating agents.
- messages: a list of messages in the group chat.
- max_round: the maximum number of rounds.
- admin_name: the name of the admin agent if there is one. Default is "Admin".
KeyboardInterrupt will make the admin agent take over.
- func_call_filter: whether to enforce function call filter. Default is True.
When set to True and when a message is a function call suggestion,
the next speaker will be chosen from an agent which contains the corresponding function name
in its `function_map`.
- speaker_selection_method: the method for selecting the next speaker. Default is "auto".
Could be any of the following (case insensitive), will raise ValueError if not recognized:
- "auto": the next speaker is selected automatically by LLM.
- "manual": the next speaker is selected manually by user input.
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True.
"""
agents: List[Agent]
messages: List[Dict]
max_round: int = 50
admin_name: str = "Admin"
func_call_filter: bool = True
speaker_selection_method: str = "auto"
allow_repeat_speaker: bool = True
_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
@property
def agent_names(self) -> List[str]:
"""Return the names of the agents in the group chat."""
return [agent.name for agent in self.agents]
def reset(self):
"""Reset the group chat."""
self.messages.clear()
def agent_by_name(self, name: str) -> Agent:
"""Returns the agent with a given name."""
return self.agents[self.agent_names.index(name)]
# def select_speaker_msg(self, agents: List[Agent], task_context: str, models: Optional[List[dict]]):
# f"""Return the message for selecting the next speaker."""
# return f"""You are in a role play game. Read and understand the following tasks and assign the appropriate role to complete them.
# Task content: {task_context}
# You can fill the following roles: {[agent.name for agent in agents]},
# Please answer only the role name, such as: {agents[0].name}"""
def select_speaker_msg(self, agents: List[Agent]):
"""Return the message for selecting the next speaker."""
return f"""You are in a role play game. The following roles are available:
{self._participant_roles(agents)}.
Read the following conversation.
Then select the next role from {[agent.name for agent in agents]} to play. The role can be selected repeatedly. Only return the role."""
async def a_select_speaker(
self,
last_speaker: Agent,
selector: ConversableAgent,
now_plan_context: str,
pre_allocated: Optional[str] = None,
):
"""Select the next speaker."""
if (
self.speaker_selection_method.lower()
not in self._VALID_SPEAKER_SELECTION_METHODS
):
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
)
agents = self.agents
n_agents = len(agents)
# Warn if GroupChat is underpopulated
if (
n_agents <= 2
and self.speaker_selection_method.lower() != "round_robin"
and self.allow_repeat_speaker
):
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. "
"It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False."
"Or, use direct communication instead."
)
# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
agents = (
agents
if self.allow_repeat_speaker
else [agent for agent in agents if agent != last_speaker]
)
# if self.speaker_selection_method.lower() == "manual":
# selected_agent = self.manual_select_speaker(agents)
# if selected_agent:
# return selected_agent
# elif self.speaker_selection_method.lower() == "round_robin":
# return self.next_agent(last_speaker, agents)
# elif self.speaker_selection_method.lower() == "random":
# return random.choice(agents)
if pre_allocated:
# Preselect speakers
logger.info(f"Preselect speakers:{pre_allocated}")
name = pre_allocated
model = None
else:
# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents))
final, name, model = await selector.a_generate_oai_reply(
self.messages
+ [
{
"role": ModelMessageRoleType.HUMAN,
"content": f"""Read and understand the following task content and assign the appropriate role to complete the task.
Task content: {now_plan_context}
Select the role from: {[agent.name for agent in agents]},
Please only return the role, such as: {agents[0].name}""",
}
]
)
if not final:
# the LLM client is None, thus no reply is generated. Use round robin instead.
return self.next_agent(last_speaker, agents), model
# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
mentions = self._mentioned_agents(name, agents)
if len(mentions) == 1:
name = next(iter(mentions))
else:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
)
# Return the result
try:
return self.agent_by_name(name), model
except Exception as e:
logger.warning(f"auto select speaker failed!{str(e)}")
return self.next_agent(last_speaker, agents), model
def _mentioned_agents(self, message_content: str, agents: List[Agent]) -> Dict:
"""
Finds and counts agent mentions in the string message_content, taking word boundaries into account.
Returns: A dictionary mapping agent names to mention counts (to be included, at least one mention must occur)
"""
mentions = dict()
for agent in agents:
regex = (
r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)"
) # Finds agent mentions, taking word boundaries into account
count = len(
re.findall(regex, " " + message_content + " ")
) # Pad the message to help with matching
if count > 0:
mentions[agent.name] = count
return mentions
def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
if agents is None:
agents = self.agents
roles = []
for agent in agents:
if agent.system_message.strip() == "":
logger.warning(
f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat."
)
roles.append(f"{agent.name}: {agent.describe}")
return "\n".join(roles)
def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
"""Return the next agent in the list."""
if agents == self.agents:
return agents[(self.agent_names.index(agent.name) + 1) % len(agents)]
else:
offset = self.agent_names.index(agent.name) + 1
for i in range(len(self.agents)):
if self.agents[(offset + i) % len(self.agents)] in agents:
return self.agents[(offset + i) % len(self.agents)]
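A quick sketch of the mention matching above; the stand-in class exists only because `_mentioned_agents` touches nothing but `agent.name`, and the word-boundary regex prevents partial-name hits.
class _NamedStub:  # illustration only, not a real Agent
    def __init__(self, name):
        self.name = name

chat = PlanChat(agents=[_NamedStub("DataScientist"), _NamedStub("Reporter")], messages=[])
mentions = chat._mentioned_agents("Ask DataScientist to fetch the data", chat.agents)
assert mentions == {"DataScientist": 1}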
class PlanChatManager(ConversableAgent):
"""(In preview) A chat manager agent that can manage a group chat of multiple agents."""
NAME = "plan_manager"
def __init__(
self,
plan_chat: PlanChat,
planner: Agent,
memory: GptsMemory,
agent_context: "AgentContext",
# unlimited consecutive auto reply by default
max_consecutive_auto_reply: Optional[int] = sys.maxsize,
human_input_mode: Optional[str] = "NEVER",
describe: Optional[str] = "Plan chat manager.",
**kwargs,
):
super().__init__(
name=self.NAME,
describe=describe,
memory=memory,
max_consecutive_auto_reply=max_consecutive_auto_reply,
human_input_mode=human_input_mode,
agent_context=agent_context,
**kwargs,
)
# Order of register_reply is important.
# Allow async chat if initiated using a_initiate_chat
self.register_reply(
Agent,
PlanChatManager.a_run_chat,
config=plan_chat,
reset_config=PlanChat.reset,
)
self.plan_chat = plan_chat
self.planner = planner
async def a_reasoning_reply(
self, messages: Optional[List[Dict]] = None
) -> Union[str, Dict, None]:
if messages is None or len(messages) <= 0:
return None, None
else:
message = messages[-1]
self.plan_chat.messages.append(message)
return message["content"], None
async def a_process_rely_message(
self, conv_id: str, now_plan: GptsPlan, speaker: ConversableAgent
):
rely_prompt = ""
speaker.reset_rely_message()
if now_plan.rely and len(now_plan.rely) > 0:
rely_tasks_list = now_plan.rely.split(",")
rely_tasks = self.memory.plans_memory.get_by_conv_id_and_num(
conv_id, rely_tasks_list
)
if rely_tasks:
rely_prompt = "Read the result data of the dependent steps in the above historical message to complete the current goal:"
for rely_task in rely_tasks:
speaker.append_rely_message(
{"content": rely_task.sub_task_content},
ModelMessageRoleType.HUMAN,
)
speaker.append_rely_message(
{"content": rely_task.result}, ModelMessageRoleType.AI
)
return rely_prompt
async def a_verify_reply(
self, message: Optional[Dict], sender: "Agent", reviewer: "Agent", **kwargs
) -> Union[str, Dict, None]:
return True, message
async def a_run_chat(
self,
message: Optional[str] = None,
sender: Optional[Agent] = None,
reviewer: Agent = None,
config: Optional[PlanChat] = None,
):
"""Run a group chat asynchronously."""
speaker = sender
groupchat = config
final_message = None
for i in range(groupchat.max_round):
plans = self.memory.plans_memory.get_by_conv_id(self.agent_context.conv_id)
if not plans or len(plans) <= 0:
### No plan yet: generate a new one (TODO: initialize the plan via a plan manager)
await self.a_send(
{"content": message, "current_gogal": message},
self.planner,
reviewer,
request_reply=False,
)
verify_pass, reply = await self.planner.a_generate_reply(
{"content": message, "current_gogal": message}, self, reviewer
)
await self.planner.a_send(
message=reply,
recipient=self,
reviewer=reviewer,
request_reply=False,
)
if not verify_pass:
final_message = reply
if i > 10:
break
else:
todo_plans = [
plan
for plan in plans
if plan.state in [Status.TODO.value, Status.RETRYING.value]
]
if not todo_plans or len(todo_plans) <= 0:
### The plan has been fully executed and a success message is sent to the user.
# complete
complete_message = {"content": f"TERMINATE", "is_exe_success": True}
return True, complete_message
else:
now_plan: GptsPlan = todo_plans[0]
# There is no need to broadcast the message to other agents, it will be automatically obtained from the collective memory according to the dependency relationship.
try:
if Status.RETRYING.value == now_plan.state:
if now_plan.retry_times <= now_plan.max_retry_times:
current_goal_message = {
"content": now_plan.result,
"current_gogal": now_plan.sub_task_content,
"context": {
"plan_task": now_plan.sub_task_content,
"plan_task_num": now_plan.sub_task_num,
},
}
else:
self.memory.plans_memory.update_task(
self.agent_context.conv_id,
now_plan.sub_task_num,
Status.FAILED.value,
now_plan.retry_times + 1,
speaker.name,
"",
now_plan.result,  # plan_result is not defined yet at this point; use the last recorded result
)
failed_report = {
"content": f"Task [{now_plan.sub_task_content}] was retried more than the maximum number of times and still failed. {now_plan.result}",
"is_exe_success": False,
}
return True, failed_report
else:
current_goal_message = {
"content": now_plan.sub_task_content,
"current_gogal": now_plan.sub_task_content,
"context": {
"plan_task": now_plan.sub_task_content,
"plan_task_num": now_plan.sub_task_num,
},
}
# select the next speaker
speaker, model = await groupchat.a_select_speaker(
speaker,
self,
now_plan.sub_task_content,
now_plan.sub_task_agent,
)
# Tell the speaker the dependent history information
rely_prompt = await self.a_process_rely_message(
conv_id=self.agent_context.conv_id,
now_plan=now_plan,
speaker=speaker,
)
current_goal_message["content"] = (
rely_prompt + current_goal_message["content"]
)
is_recovery = False
if message == current_goal_message["content"]:
is_recovery = True
await self.a_send(
message=current_goal_message,
recipient=speaker,
reviewer=reviewer,
request_reply=False,
is_recovery=is_recovery,
)
verify_pass, reply = await speaker.a_generate_reply(
current_goal_message, self, reviewer
)
plan_result = ""
if verify_pass:
if reply:
action_report = reply.get("action_report", None)
if action_report:
plan_result = action_report.get("content", "")
### The reply generated by the currently planned agent passed verification;
### mark the plan step as executed successfully.
self.memory.plans_memory.complete_task(
self.agent_context.conv_id,
now_plan.sub_task_num,
plan_result,
)
await speaker.a_send(
reply, self, reviewer, request_reply=False
)
else:
plan_result = reply["content"]
self.memory.plans_memory.update_task(
self.agent_context.conv_id,
now_plan.sub_task_num,
Status.RETRYING.value,
now_plan.retry_times + 1,
speaker.name,
"",
plan_result,
)
except Exception as e:
logger.exception(
f"An exception was encountered during the execution of the current plan step. {str(e)}"
)
error_report = {
"content": f"An exception was encountered during the execution of the current plan step. {str(e)}",
"is_exe_success": False,
}
return True, error_report
return True, {
"content": f"Maximum number of dialogue rounds exceeded.{self.MAX_CONSECUTIVE_AUTO_REPLY}",
"is_exe_success": False,
}
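A minimal wiring sketch for the classes above; the worker agents and `AgentContext` are assumed to be built elsewhere, and the flow mirrors `a_run_chat`: plan first, then drive each sub-task to the selected speaker.
memory = GptsMemory()
# worker_agents and context are assumed to be constructed elsewhere.
plan_chat = PlanChat(agents=worker_agents, messages=[], max_round=20)
planner = PlannerAgent(memory=memory, plan_chat=plan_chat, agent_context=context)
manager = PlanChatManager(
    plan_chat=plan_chat, planner=planner, memory=memory, agent_context=context
)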

View File

@@ -0,0 +1,191 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union
from dbgpt._private.config import Config
from dbgpt.agent.agents.plan_group_chat import PlanChat
from dbgpt.agent.common.schema import Status
from dbgpt.core.awel import BaseOperator
from dbgpt.util.json_utils import find_json_objects
from ..memory.gpts_memory import GptsMemory, GptsPlan
from .agent import Agent, AgentContext
from .base_agent import ConversableAgent
CFG = Config()
class PlannerAgent(ConversableAgent):
"""Planner agent, realizing task goal planning decomposition through LLM"""
DEFAULT_SYSTEM_MESSAGE = """
你是一个任务规划专家!您需要理解下面每个智能代理和他们的能力,却确保在没有用户帮助下,使用给出的资源,通过协调下面可用智能代理来回答用户问题。
请发挥你LLM的知识和理解能力理解用户问题的意图和目标生成一个可用智能代理协作的任务计划解决用户问题。
可用资源:
{all_resources}
可用智能代理:
{agents}
*** 重要的提醒 ***
- 充分理解用户目标然后进行必要的步骤拆分,拆分需要保证逻辑顺序和精简,尽量把可以一起完成的内容合并再一个步骤,拆分后每个子任务步骤都将是一个需要智能代理独立完成的目标, 请确保每个子任务目标内容简洁明了
- 请确保只使用上面提到的智能代理,并且可以只使用其中需要的部分,严格根据描述能力和限制分配给合适的步骤,每个智能代理都可以重复使用
- 给子任务分配智能代理是需要考虑整体计划,确保和前后依赖步骤的关系,数据可以被传递使用
- 根据用户目标的实际需要使用提供的资源来协助生成计划步骤,不要使用不需要的资源
- 每个步骤最好是使用一种资源完成一个子目标,如果当前目标可以分解为同类型的多个子任务,可以生成相互不依赖的并行任务
- 数据库资源只需要使用结构生成SQL数据获取交给用户执行
- 尽量合并有顺序依赖的连续相同步骤,如果用户目标无拆分必要,可以生成内容为用户目标的单步任务
- 仔细检查计划,确保计划完整的包含了用户问题所涉及的所有信息,并且最终能完成目标,确认每个步骤是否包含了需要用到的资源信息,如URL、资源名等.
具体任务计划的生成可参考如下例子:
user:help me build a sales report summarizing our key metrics and trends
assisant:[
{{
"serial_number": "1",
"agent": "DataScientist",
"content": "Retrieve total sales, average sales, and number of transactions grouped by "product_category"'.",
"rely": ""
}},
{{
"serial_number": "2",
"agent": "DataScientist",
"content": "Retrieve monthly sales and transaction number trends.",
"rely": ""
}},
{{
"serial_number": "3",
"agent": "DataScientist",
"content": "Count the number of transactions with "pay_status" as "paid" among all transactions to retrieve the sales conversion rate.",
"rely": ""
}},
{{
"serial_number": "4",
"agent": "Reporter",
"content": "Integrate analytical data into the format required to build sales reports.",
"rely": "1,2,3"
}}
]
请一步步思考并以如下json格式返回你的行动计划内容:
[{{
"serial_number":"0",
"agent": "用来完成当前步骤的智能代理",
"content": "当前步骤的任务内容,确保可以被智能代理执行",
"rely":"当前任务执行依赖的其他任务serial_number, 如:1,2,3, 无依赖为空"
}}]
确保回答的json可以被Python代码的json.loads函数加载解析.
"""
REPAIR_SYSTEM_MESSAGE = """
您是规划专家!现在你需要利用你的专业知识,仔细检查已生成的计划,进行重新评估和分析,确保计划的每个步骤都是清晰完整的,可以被智能代理理解的,解决当前计划中遇到的问题!并按要求返回新的计划内容。
"""
NAME = "Planner"
def __init__(
self,
memory: GptsMemory,
plan_chat: PlanChat,
agent_context: AgentContext,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "NEVER",
**kwargs,
):
super().__init__(
name=self.NAME,
memory=memory,
system_message=self.DEFAULT_SYSTEM_MESSAGE,
is_termination_msg=is_termination_msg,
max_consecutive_auto_reply=max_consecutive_auto_reply,
human_input_mode=human_input_mode,
agent_context=agent_context,
**kwargs,
)
self.plan_chat = plan_chat
### register the planning function
self.register_reply(Agent, PlannerAgent._a_planning)
def build_param(self, agent_context: AgentContext):
resources = []
if agent_context.resource_db is not None:
db_connect = CFG.LOCAL_DB_MANAGE.get_connect(
agent_context.resource_db.get("name")
)
resources.append(
f"{agent_context.resource_db.get('type')}:{agent_context.resource_db.get('name')}\n{db_connect.get_table_info()}"
)
if agent_context.resource_knowledge is not None:
resources.append(
f"{agent_context.resource_knowledge.get('type')}:{agent_context.resource_knowledge.get('name')}\n{agent_context.resource_knowledge.get('introduce')}"
)
if agent_context.resource_internet is not None:
resources.append(
f"{agent_context.resource_internet.get('type')}:{agent_context.resource_internet.get('name')}\n{agent_context.resource_internet.get('introduce')}"
)
return {
"all_resources": "\n".join([f"- {item}" for item in resources]),
"agents": "\n".join(
[f"- {item.name}:{item.describe}" for item in self.plan_chat.agents]
),
}
async def a_system_fill_param(self):
params = self.build_param(self.agent_context)
self.update_system_message(self.DEFAULT_SYSTEM_MESSAGE.format(**params))
async def _a_planning(
self,
message: Optional[str] = None,
sender: Optional[Agent] = None,
reviewer: Optional[Agent] = None,
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
json_objects = find_json_objects(message)
plan_objects = []
fail_reason = (
"Please recheck your answer: no usable plans were generated in the correct format."
)
json_count = len(json_objects)
response_success = True
if json_count != 1:
### Answer failed, turn on automatic repair
fail_reason += f" There are currently {json_count} json contents."
response_success = False
else:
try:
for item in json_objects[0]:
plan = GptsPlan(
conv_id=self.agent_context.conv_id,
sub_task_num=item.get("serial_number"),
sub_task_content=item.get("content"),
)
plan.resource_name = item.get("resource")
plan.max_retry_times = self.agent_context.max_retry_round
plan.sub_task_agent = item.get("agent")
plan.sub_task_title = item.get("content")
plan.rely = item.get("rely")
plan.retry_times = 0
plan.status = Status.TODO.value
plan_objects.append(plan)
except Exception as e:
fail_reason += f"Return json structure error and cannot be converted to a usable plan{str(e)}"
rensponse_succ = False
if rensponse_succ:
if len(plan_objects) > 0:
### Delete the old plan every time before saving it
self.memory.plans_memory.remove_by_conv_id(self.agent_context.conv_id)
self.memory.plans_memory.batch_save(plan_objects)
content = "\n".join(
[
"{},{}".format(index + 1, item.get("content"))
for index, item in enumerate(json_objects[0])
]
)
else:
content = fail_reason
return True, {
"is_exe_success": rensponse_succ,
"content": content,
"view": content,
}
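A small sketch of what `_a_planning` consumes: the LLM reply must contain exactly one json value of the shape requested above. `find_json_objects` is assumed to return a list of the JSON values found in the text.
plan_text = """Here is my plan:
[{"serial_number": "1", "agent": "DataScientist",
  "content": "Retrieve monthly sales trends.", "rely": ""}]"""
json_objects = find_json_objects(plan_text)
assert len(json_objects) == 1
assert json_objects[0][0]["agent"] == "DataScientist"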

View File

@@ -0,0 +1,89 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from ..memory.gpts_memory import GptsMemory
from .agent import Agent, AgentContext
from .base_agent import ConversableAgent
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
class UserProxyAgent(ConversableAgent):
"""(In preview) A proxy agent for the user, that can execute code and provide feedback to the other agents."""
NAME = "User"
DEFAULT_DESCRIBE = (
"A human admin. Interact with the planner to discuss the plan. Plan execution needs to be approved by this admin."
)
def __init__(
self,
memory: GptsMemory,
agent_context: AgentContext,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "ALWAYS",
default_auto_reply: Optional[Union[str, Dict, None]] = "",
):
super().__init__(
name=self.NAME,
memory=memory,
describe=self.DEFAULT_DESCRIBE,
system_message=self.DEFAULT_DESCRIBE,
is_termination_msg=is_termination_msg,
max_consecutive_auto_reply=max_consecutive_auto_reply,
human_input_mode=human_input_mode,
agent_context=agent_context,
)
self.register_reply(Agent, UserProxyAgent.check_termination_and_human_reply)
def get_human_input(self, prompt: str) -> str:
"""Get human input.
Override this method to customize the way to get human input.
Args:
prompt (str): prompt for the human input.
Returns:
str: human input.
"""
reply = input(prompt)
return reply
async def a_reasoning_reply(
self, messages: Optional[List[Dict]] = None
) -> Union[str, Dict, None]:
if messages is None or len(messages) <= 0:
return None, None
else:
message = messages[-1]
return message["content"], None
async def a_receive(
self,
message: Optional[Dict],
sender: Agent,
reviewer: Agent,
request_reply: Optional[bool] = True,
silent: Optional[bool] = False,
is_recovery: Optional[bool] = False,
):
self.consecutive_auto_reply_counter = sender.consecutive_auto_reply_counter + 1
self._process_received_message(message, sender, silent)
async def check_termination_and_human_reply(
self,
message: Optional[str] = None,
sender: Optional[Agent] = None,
reviewer: Agent = None,
config: Optional[Union[Dict, Literal[False]]] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""Check if the conversation should be terminated, and if human reply is provided."""
return True, None
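A sketch of how the proxy typically kicks off a run, reusing the `memory`, `context`, and `manager` from the earlier wiring sketch; `a_initiate_chat` follows the comment in PlanChatManager above and its exact signature is an assumption.
user_proxy = UserProxyAgent(memory=memory, agent_context=context)
await user_proxy.a_initiate_chat(
    recipient=manager,  # the PlanChatManager wired earlier
    reviewer=user_proxy,
    message="help me build a sales report summarizing our key metrics and trends",
)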

View File

@@ -6,13 +6,15 @@ class PluginStorageType(Enum):
Oss = "oss"
class Status(Enum):
TODO = "todo"
RUNNING = "running"
FAILED = "failed"
COMPLETED = "completed"
class ApiTagType(Enum):
API_VIEW = "dbgpt_view"
API_CALL = "dbgpt_call"
class Status(Enum):
TODO = "todo"
RUNNING = "running"
WAITING = "waiting"
RETRYING = "retrying"
FAILED = "failed"
COMPLETE = "complete"

View File

@@ -1,159 +0,0 @@
import logging
from fastapi import (
APIRouter,
Body,
UploadFile,
File,
)
from abc import ABC
from typing import List
from dbgpt.app.openapi.api_view_model import (
Result,
)
from .model import (
PluginHubParam,
PagenationFilter,
PagenationResult,
PluginHubFilter,
)
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
from .plugins_util import scan_plugins
from .commands.generator import PluginPromptGenerator
from dbgpt.configs.model_config import PLUGINS_DIR
from dbgpt.component import BaseComponent, ComponentType, SystemApp
router = APIRouter()
logger = logging.getLogger(__name__)
class ModuleAgent(BaseComponent, ABC):
name = ComponentType.AGENT_HUB
def __init__(self):
# load plugins
self.plugins = scan_plugins(PLUGINS_DIR)
def init_app(self, system_app: SystemApp):
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
def refresh_plugins(self):
self.plugins = scan_plugins(PLUGINS_DIR)
def load_select_plugin(
self, generator: PluginPromptGenerator, select_plugins: List[str]
) -> PluginPromptGenerator:
logger.info(f"load_select_plugin:{select_plugins}")
# load select plugin
for plugin in self.plugins:
if plugin._name in select_plugins:
if not plugin.can_handle_post_prompt():
continue
generator = plugin.post_prompt(generator)
return generator
module_agent = ModuleAgent()
@router.post("/v1/agent/hub/update", response_model=Result[str])
async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
branch = (
update_param.branch
if update_param.branch is not None and len(update_param.branch) > 0
else "main"
)
authorization = (
update_param.authorization
if update_param.authorization is not None and len(update_param.authorization) > 0
else None
)
# TODO change it to async
agent_hub.refresh_hub_from_git(update_param.url, branch, authorization)
return Result.succ(None)
except Exception as e:
logger.error("Agent Hub Update Error!", e)
return Result.failed(code="E0020", msg=f"Agent Hub Update Error! {e}")
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
logger.info(f"get_agent_list:{filter.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
filter_entity: PluginHubEntity = PluginHubEntity()
if filter.filter:
attrs = vars(filter.filter)  # get the attribute dict of the original filter object
for attr, value in attrs.items():
setattr(filter_entity, attr, value)  # copy each attribute value onto the entity
datas, total_pages, total_count = agent_hub.hub_dao.list(
filter_entity, filter.page_index, filter.page_size
)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
result.page_index = filter.page_index
result.page_size = filter.page_size
result.total_page = total_pages
result.total_row_count = total_count
result.datas = datas
# print(json.dumps(result.to_dic()))
return Result.succ(result.to_dic())
@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user: str = None):
logger.info(f"my_agents:{user}")
agent_hub = AgentHub(PLUGINS_DIR)
agents = agent_hub.get_my_plugin(user)
agent_dicts = []
for agent in agents:
agent_dicts.append(agent.__dict__)
return Result.succ(agent_dicts)
@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name: str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.install_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.failed(code="E0021", msg=f"Plugin Install Error {e}")
@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name: str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.uninstall_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.failed(code="E0022", msg=f"Plugin Uninstall Error {e}")
@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
await agent_hub.upload_my_plugin(doc_file, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Upload Personal Plugin Error!", e)
return Result.failed(code="E0023", msg=f"Upload Personal Plugin Error {e}")

View File

@@ -1,137 +0,0 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
class MyPluginEntity(Model):
__tablename__ = "my_plugin"
id = Column(Integer, primary_key=True, comment="autoincrement id")
tenant = Column(String(255), nullable=True, comment="user's tenant")
user_code = Column(String(255), nullable=False, comment="user code")
user_name = Column(String(255), nullable=True, comment="user name")
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
file_name = Column(String(255), nullable=False, comment="plugin package file name")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
use_count = Column(
Integer, nullable=True, default=0, comment="plugin total use count"
)
succ_count = Column(
Integer, nullable=True, default=0, comment="plugin total success count"
)
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(
DateTime, default=datetime.utcnow, comment="plugin install time"
)
UniqueConstraint("user_code", "name", name="uk_name")
class MyPluginDao(BaseDao):
def add(self, entity: MyPluginEntity):
session = self.get_raw_session()
my_plugin = MyPluginEntity(
tenant=entity.tenant,
user_code=entity.user_code,
user_name=entity.user_name,
name=entity.name,
type=entity.type,
version=entity.version,
use_count=entity.use_count or 0,
succ_count=entity.succ_count or 0,
sys_code=entity.sys_code,
gmt_created=datetime.now(),
)
session.add(my_plugin)
session.commit()
id = my_plugin.id
session.close()
return id
def raw_update(self, entity: MyPluginEntity):
session = self.get_raw_session()
updated = session.merge(entity)
session.commit()
return updated.id
def get_by_user(self, user: str) -> list[MyPluginEntity]:
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
result = my_plugins.all()
session.close()
return result
def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
my_plugins = my_plugins.filter(MyPluginEntity.name == plugin)
result = my_plugins.first()
session.close()
return result
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count()
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
if query.tenant is not None:
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
if query.type is not None:
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
if query.user_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
if query.sys_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
result = my_plugins.all()
session.close()
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def count(self, query: MyPluginEntity):
session = self.get_raw_session()
my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
if query.type is not None:
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
if query.tenant is not None:
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
if query.user_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
if query.sys_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
count = my_plugins.scalar()
session.close()
return count
def raw_delete(self, plugin_id: int):
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
query = MyPluginEntity(id=plugin_id)
my_plugins = session.query(MyPluginEntity)
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
my_plugins.delete()
session.commit()
session.close()

View File

@@ -1,139 +0,0 @@
from datetime import datetime
import pytz
from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
# TODO We should consider that the production environment does not have permission to execute the DDL
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
class PluginHubEntity(Model):
__tablename__ = "plugin_hub"
id = Column(
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
)
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
description = Column(String(255), nullable=False, comment="plugin description")
author = Column(String(255), nullable=True, comment="plugin author")
email = Column(String(255), nullable=True, comment="plugin author email")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
storage_channel = Column(String(255), comment="plugin storage channel")
storage_url = Column(String(255), comment="plugin download url")
download_param = Column(String(255), comment="plugin download param")
gmt_created = Column(
DateTime, default=datetime.utcnow, comment="plugin upload time"
)
installed = Column(Integer, default=False, comment="plugin already installed count")
UniqueConstraint("name", name="uk_name")
Index("idx_q_type", "type")
class PluginHubDao(BaseDao):
def add(self, entity: PluginHubEntity):
session = self.get_raw_session()
timezone = pytz.timezone("Asia/Shanghai")
plugin_hub = PluginHubEntity(
name=entity.name,
author=entity.author,
email=entity.email,
type=entity.type,
version=entity.version,
storage_channel=entity.storage_channel,
storage_url=entity.storage_url,
gmt_created=timezone.localize(datetime.now()),
)
session.add(plugin_hub)
session.commit()
id = plugin_hub.id
session.close()
return id
def raw_update(self, entity: PluginHubEntity):
session = self.get_raw_session()
try:
updated = session.merge(entity)
session.commit()
return updated.id
finally:
session.close()
def list(
self, query: PluginHubEntity, page=1, page_size=20
) -> list[PluginHubEntity]:
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
)
plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc())
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
result = plugin_hubs.all()
session.close()
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def get_by_storage_url(self, storage_url):
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
result = plugin_hubs.all()
session.close()
return result
def get_by_name(self, name: str) -> PluginHubEntity:
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
result = plugin_hubs.first()
session.close()
return result
def count(self, query: PluginHubEntity):
session = self.get_raw_session()
plugin_hubs = session.query(func.count(PluginHubEntity.id))
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
)
count = plugin_hubs.scalar()
session.close()
return count
def raw_delete(self, plugin_id: int):
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
plugin_hubs = session.query(PluginHubEntity)
if plugin_id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == plugin_id)
plugin_hubs.delete()
session.commit()
session.close()

View File

@@ -1,207 +0,0 @@
import json
import logging
import os
import glob
import shutil
from fastapi import UploadFile
from typing import Any
import tempfile
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
from ..common.schema import PluginStorageType
from ..plugins_util import scan_plugins, update_from_git
logger = logging.getLogger(__name__)
Default_User = "default"
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
TEMP_PLUGIN_PATH = ""
class AgentHub:
def __init__(self, plugin_dir) -> None:
self.hub_dao = PluginHubDao()
self.my_plugin_dao = MyPluginDao()
os.makedirs(plugin_dir, exist_ok=True)
self.plugin_dir = plugin_dir
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
def install_plugin(self, plugin_name: str, user_name: str = None):
logger.info(f"install_plugin {plugin_name}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
if plugin_entity:
if plugin_entity.storage_channel == PluginStorageType.Git.value:
try:
branch_name = None
authorization = None
if plugin_entity.download_param:
download_param = json.loads(plugin_entity.download_param)
branch_name = download_param.get("branch_name")
authorization = download_param.get("authorization")
file_name = self.__download_from_git(
plugin_entity.storage_url, branch_name, authorization
)
# add to my plugins and edit hub status
plugin_entity.installed = plugin_entity.installed + 1
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
user_name, plugin_name
)
if my_plugin_entity is None:
my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity.file_name = file_name
if user_name:
# TODO use user
my_plugin_entity.user_code = user_name
my_plugin_entity.user_name = user_name
my_plugin_entity.tenant = ""
else:
my_plugin_entity.user_code = Default_User
with self.hub_dao.session() as session:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
except Exception as e:
logger.error("install plugin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Failed! {str(e)}")
else:
raise ValueError(
f"Unsupported Storage Channel {plugin_entity.storage_channel}!"
)
else:
raise ValueError(f"Can't Find Plugin {plugin_name}!")
def uninstall_plugin(self, plugin_name, user):
logger.info(f"uninstall_plugin:{plugin_name},{user}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
if plugin_entity is not None:
plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.session() as session:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)
if plugin_entity is not None:
# delete package file if not use
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
have_installed = False
for plugin_info in plugin_infos:
if plugin_info.installed > 0:
have_installed = True
break
if not have_installed:
plugin_repo_name = (
plugin_entity.storage_url.replace(".git", "")
.strip("/")
.split("/")[-1]
)
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
for file in files:
os.remove(file)
else:
files = glob.glob(
os.path.join(self.plugin_dir, f"{my_plugin_entity.file_name}")
)
for file in files:
os.remove(file)
def __download_from_git(self, github_repo, branch_name, authorization):
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = hub_plugin.name
my_plugin_entity.type = hub_plugin.type
my_plugin_entity.version = hub_plugin.version
return my_plugin_entity
def refresh_hub_from_git(
self,
github_repo: str = None,
branch_name: str = "main",
authorization: str = None,
):
logger.info("refresh_hub_by_git start!")
update_from_git(
self.temp_hub_file_path, github_repo, branch_name, authorization
)
git_plugins = scan_plugins(self.temp_hub_file_path)
try:
for git_plugin in git_plugins:
old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
if old_hub_info:
plugin_hub_info = old_hub_info
else:
plugin_hub_info = PluginHubEntity()
plugin_hub_info.type = ""
plugin_hub_info.storage_channel = PluginStorageType.Git.value
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
plugin_hub_info.email = getattr(git_plugin, "_email", "")
download_param = {}
if branch_name:
download_param["branch_name"] = branch_name
if authorization and len(authorization) > 0:
download_param["authorization"] = authorization
plugin_hub_info.download_param = json.dumps(download_param)
plugin_hub_info.installed = 0
plugin_hub_info.name = git_plugin._name
plugin_hub_info.version = git_plugin._version
plugin_hub_info.description = git_plugin._description
self.hub_dao.raw_update(plugin_hub_info)
except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
# We can not move temp file in windows system when we open file in context of `with`
file_path = os.path.join(self.plugin_dir, doc_file.filename)
if os.path.exists(file_path):
os.remove(file_path)
tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir))
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(
tmp_path,
os.path.join(self.plugin_dir, doc_file.filename),
)
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
if user is None or len(user) <= 0:
user = Default_User
for my_plugin in my_plugins:
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
user, my_plugin._name
)
if my_plugin_entity is None:
my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = my_plugin._name
my_plugin_entity.version = my_plugin._version
my_plugin_entity.type = "Personal"
my_plugin_entity.user_code = user
my_plugin_entity.user_name = user
my_plugin_entity.tenant = ""
my_plugin_entity.file_name = doc_file.filename
self.my_plugin_dao.raw_update(my_plugin_entity)
def reload_my_plugins(self):
logger.info(f"load_plugins start!")
return scan_plugins(self.plugin_dir)
def get_my_plugin(self, user: str):
logger.info(f"get_my_plugin:{user}")
if not user:
user = Default_User
return self.my_plugin_dao.get_by_user(user)

dbgpt/agent/memory/base.py Normal file
View File

@@ -0,0 +1,235 @@
from __future__ import annotations
import dataclasses
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, fields
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from dbgpt.agent.common.schema import Status
@dataclass
class GptsPlan:
"""Gpts plan"""
conv_id: str
sub_task_num: int
sub_task_content: Optional[str]
sub_task_title: Optional[str] = None
sub_task_agent: Optional[str] = None
resource_name: Optional[str] = None
rely: Optional[str] = None
agent_model: Optional[str] = None
retry_times: Optional[int] = 0
max_retry_times: Optional[int] = 5
state: Optional[str] = Status.TODO.value
result: Optional[str] = None
@staticmethod
def from_dict(d: Dict[str, Any]) -> GptsPlan:
return GptsPlan(
conv_id=d.get("conv_id"),
sub_task_num=d["sub_task_num"],
sub_task_content=d["sub_task_content"],
sub_task_agent=d["sub_task_agent"],
resource_name=d["resource_name"],
rely=d["rely"],
agent_model=d["agent_model"],
retry_times=d["retry_times"],
max_retry_times=d["max_retry_times"],
state=d["state"],
result=d["result"],
)
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
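A quick round trip for the dataclass above:
plan = GptsPlan(conv_id="c1", sub_task_num=1, sub_task_content="Retrieve monthly sales.")
d = plan.to_dict()
assert GptsPlan.from_dict(d).sub_task_content == "Retrieve monthly sales."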
@dataclass
class GptsMessage:
"""Gpts plan"""
conv_id: str
sender: str
receiver: str
role: str
content: str
rounds: Optional[int]
current_gogal: Optional[str] = None
context: Optional[str] = None
review_info: Optional[str] = None
action_report: Optional[str] = None
model_name: Optional[str] = None
created_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
updated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
@staticmethod
def from_dict(d: Dict[str, Any]) -> GptsMessage:
return GptsMessage(
conv_id=d["conv_id"],
sender=d["sender"],
receiver=d["receiver"],
role=d["role"],
content=d["content"],
rounds=d["rounds"],
model_name=d["model_name"],
current_gogal=d["current_gogal"],
context=d["context"],
review_info=d["review_info"],
action_report=d["action_report"],
created_at=d["created_at"],
updated_at=d["updated_at"],
)
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
class GptsPlansMemory(ABC):
def batch_save(self, plans: list[GptsPlan]):
"""
batch save gpts plans
Args:
plans: plans generated by the planner
Returns:
None
"""
pass
def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
"""
get plans by conv_id
Args:
conv_id: conversation id
Returns:
List of planning steps
"""
def get_by_conv_id_and_num(
self, conv_id: str, task_nums: List[int]
) -> List[GptsPlan]:
"""
get plans by conv_id and the given task numbers
Args:
conv_id: conversation id
task_nums: List of sequence numbers of plans in the same conversation
Returns:
List of planning steps
"""
def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
"""
Get unfinished planning steps
Args:
conv_id: conversation id
Returns:
List of planning steps
"""
def complete_task(self, conv_id: str, task_num: int, result: str):
"""
Complete designated planning step
Args:
conv_id: conversation id
task_num: Planning step num
result: Plan step results
Returns:
None
"""
def update_task(
self,
conv_id: str,
task_num: int,
state: str,
retry_times: int,
agent: str = None,
model: str = None,
result: str = None,
):
"""
Update planning step information
Args:
conv_id: conversation id
task_num: Planning step num
state: the status to update to
retry_times: Latest number of retries
agent: Agent's name
Returns:
"""
def remove_by_conv_id(self, conv_id: str):
"""
Delete planning
Args:
conv_id:
Returns:
"""
class GptsMessageMemory(ABC):
def append(self, message: GptsMessage):
"""
Add a message
Args:
message:
Returns:
"""
def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
"""
Query information related to an agent
Args:
agent: agent's name
Returns:
messages
"""
def get_between_agents(
self,
conv_id: str,
agent1: str,
agent2: str,
current_gogal: Optional[str] = None,
) -> Optional[List[GptsMessage]]:
"""
Query messages exchanged between two agents, optionally filtered by current goal
Args:
agent1, agent2: the two agents' names
Returns:
messages
"""
def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessage]]:
"""
Query messages by conv id
Args:
conv_id:
Returns:
"""
def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
"""
Query last message
Args:
conv_id:
Returns:
"""

View File

@@ -0,0 +1,119 @@
from dataclasses import fields
from typing import List, Optional
import pandas as pd
from dbgpt.agent.common.schema import Status
from .base import GptsMessage, GptsMessageMemory, GptsPlan, GptsPlansMemory
class DefaultGptsPlansMemory(GptsPlansMemory):
def __init__(self):
self.df = pd.DataFrame(columns=[field.name for field in fields(GptsPlan)])
def batch_save(self, plans: list[GptsPlan]):
new_rows = pd.DataFrame([item.to_dict() for item in plans])
self.df = pd.concat([self.df, new_rows], ignore_index=True)
def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
result = self.df.query(f"conv_id==@conv_id")
plans = []
for row in result.itertuples(index=False, name=None):
row_dict = dict(zip(self.df.columns, row))
plans.append(GptsPlan.from_dict(row_dict))
return plans
def get_by_conv_id_and_num(
self, conv_id: str, task_nums: List[int]
) -> List[GptsPlan]:
result = self.df.query(f"conv_id==@conv_id and sub_task_num in @task_nums")
plans = []
for row in result.itertuples(index=False, name=None):
row_dict = dict(zip(self.df.columns, row))
plans.append(GptsPlan.from_dict(row_dict))
return plans
def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
todo_states = [Status.TODO.value, Status.RETRYING.value]
result = self.df.query(f"conv_id==@conv_id and state in @todo_states")
plans = []
for row in result.itertuples(index=False, name=None):
row_dict = dict(zip(self.df.columns, row))
plans.append(GptsPlan.from_dict(row_dict))
return plans
def complete_task(self, conv_id: str, task_num: int, result: str):
condition = (self.df["conv_id"] == conv_id) & (
self.df["sub_task_num"] == task_num
)
self.df.loc[condition, "state"] = Status.COMPLETE.value
self.df.loc[condition, "result"] = result
def update_task(
self,
conv_id: str,
task_num: int,
state: str,
retry_times: int,
agent: str = None,
model=None,
result: str = None,
):
condition = (self.df["conv_id"] == conv_id) & (
self.df["sub_task_num"] == task_num
)
self.df.loc[condition, "state"] = state
self.df.loc[condition, "retry_times"] = retry_times
self.df.loc[condition, "result"] = result
if agent:
self.df.loc[condition, "sub_task_agent"] = agent
if model:
self.df.loc[condition, "agent_model"] = model
def remove_by_conv_id(self, conv_id: str):
self.df.drop(self.df[self.df["conv_id"] == conv_id].index, inplace=True)
class DefaultGptsMessageMemory(GptsMessageMemory):
def __init__(self):
self.df = pd.DataFrame(columns=[field.name for field in fields(GptsMessage)])
def append(self, message: GptsMessage):
self.df.loc[len(self.df)] = message.to_dict()
def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
result = self.df.query(
f"conv_id==@conv_id and (sender==@agent or receiver==@agent)"
)
messages = []
for row in result.itertuples(index=False, name=None):
row_dict = dict(zip(self.df.columns, row))
messages.append(GptsMessage.from_dict(row_dict))
return messages
def get_between_agents(
self,
conv_id: str,
agent1: str,
agent2: str,
current_gogal: Optional[str] = None,
) -> Optional[List[GptsMessage]]:
result = self.df.query(
f"conv_id==@conv_id and ((sender==@agent1 and receiver==@agent2) or (sender==@agent2 and receiver==@agent1)) and current_gogal==@current_gogal"
)
messages = []
for row in result.itertuples(index=False, name=None):
row_dict = dict(zip(self.df.columns, row))
messages.append(GptsMessage.from_dict(row_dict))
return messages
def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessage]]:
result = self.df.query(f"conv_id==@conv_id")
messages = []
for row in result.itertuples(index=False, name=None):
row_dict = dict(zip(self.df.columns, row))
messages.append(GptsMessage.from_dict(row_dict))
return messages
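As a quick orientation, a minimal usage sketch of the pandas-backed defaults. It assumes GptsPlan in .base mirrors the fields seen above and defaults state to Status.TODO, as GptsPlanStorage does later in this commit:

import logging
from dbgpt.agent.memory.base import GptsPlan
from dbgpt.agent.memory.default_gpts_memory import DefaultGptsPlansMemory

plans_memory = DefaultGptsPlansMemory()
plans_memory.batch_save(
    [GptsPlan(conv_id="conv-1", sub_task_num=1, sub_task_content="Collect sales data")]
)
print(len(plans_memory.get_todo_plans("conv-1")))  # 1: new plans start as TODO
plans_memory.complete_task("conv-1", 1, result="done")
print(len(plans_memory.get_todo_plans("conv-1")))  # 0: completed plans are filtered out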

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
import json
from collections import defaultdict
from typing import Dict, List, Optional
from dbgpt.util.json_utils import EnhancedJSONEncoder
from .base import GptsMessage, GptsMessageMemory, GptsPlan, GptsPlansMemory
from .default_gpts_memory import DefaultGptsMessageMemory, DefaultGptsPlansMemory
class GptsMemory:
def __init__(
self,
plans_memory: Optional[GptsPlansMemory] = None,
message_memory: Optional[GptsMessageMemory] = None,
):
self._plans_memory: GptsPlansMemory = (
plans_memory if plans_memory is not None else DefaultGptsPlansMemory()
)
self._message_memory: GptsMessageMemory = (
message_memory if message_memory is not None else DefaultGptsMessageMemory()
)
@property
def plans_memory(self):
return self._plans_memory
@property
def message_memory(self):
return self._message_memory
async def one_plan_chat_competions(self, conv_id: str):
plans = self.plans_memory.get_by_conv_id(conv_id=conv_id)
messages = self.message_memory.get_by_conv_id(conv_id=conv_id)
messages_group = defaultdict(list)
for item in messages:
messages_group[item.current_gogal].append(item)
plans_info_map = {}
for plan in plans:
plans_info_map[plan.sub_task_content] = {
"name": plan.sub_task_content,
"num": plan.sub_task_num,
"status": plan.state,
"agent": plan.sub_task_agent,
"markdown": self._messages_to_agents_vis(
messages_group.get(plan.sub_task_content)
),
}
normal_messages = []
if messages_group:
for key, value in messages_group.items():
if key not in plans_info_map:
normal_messages.extend(value)
return f"{self._messages_to_agents_vis(normal_messages)}\n{self._messages_to_plan_vis(list(plans_info_map.values()))}"
@staticmethod
def _messages_to_agents_vis(messages: List[GptsMessage]):
if messages is None or len(messages) <= 0:
return ""
messages_view = []
for message in messages:
action_report_str = message.action_report
view_info = message.content
if action_report_str and len(action_report_str) > 0:
action_report = json.loads(action_report_str)
if action_report:
view = action_report.get("view", None)
view_info = view if view else action_report.get("content", "")
messages_view.append(
{
"sender": message.sender,
"receiver": message.receiver,
"model": message.model_name,
"markdown": view_info,
}
)
messages_content = json.dumps(
messages_view, ensure_ascii=False, cls=EnhancedJSONEncoder
)
return f"```agent-messages\n{messages_content}\n```"
@staticmethod
def _messages_to_plan_vis(messages: List[Dict]):
if messages is None or len(messages) <= 0:
return ""
messages_content = json.dumps(
messages, ensure_ascii=False, cls=EnhancedJSONEncoder
)
return f"```agent-plans\n{messages_content}\n```"
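A short sketch of the aggregation flow above. It assumes GptsMessage accepts the constructor fields mirrored in the from_dict implementations in this commit, with the remaining fields defaulted:

import asyncio
from dbgpt.agent.memory.base import GptsMessage
from dbgpt.agent.memory.gpts_memory import GptsMemory

memory = GptsMemory()  # falls back to the pandas-backed default memories
memory.message_memory.append(
    GptsMessage(
        conv_id="conv-1",
        sender="Planner",
        receiver="Engineer",
        role="assistant",
        content="Write the query",
        rounds=1,
        current_gogal="task-1",
    )
)
# With no plans saved, the message lands in an ```agent-messages``` block
print(asyncio.run(memory.one_plan_chat_competions("conv-1")))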

View File

@@ -0,0 +1,443 @@
import dataclasses
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from dbgpt.agent.common.schema import Status
# Assumed import path: prefer_query below references StoragePromptTemplate,
# which is defined alongside PromptManager in dbgpt.core.interface.prompt.
from dbgpt.core.interface.prompt import StoragePromptTemplate
from dbgpt.core.interface.storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
from .base import GptsMessage, GptsMessageMemory, GptsPlan, GptsPlansMemory
@dataclass
class GptsPlanIdentifier(ResourceIdentifier):
identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
conv_id: str
sub_task_num: Optional[str]
def __post_init__(self):
if self.conv_id is None or self.sub_task_num is None:
raise ValueError("conv_id and sub_task_num cannot be None")
if any(
self.identifier_split in key
for key in [
self.conv_id,
self.sub_task_num,
]
if key is not None
):
raise ValueError(
f"identifier_split {self.identifier_split} is not allowed in conv_id, sub_task_num"
)
@property
def str_identifier(self) -> str:
return self.identifier_split.join(
key
for key in [
self.conv_id,
self.sub_task_num,
]
if key is not None
)
def to_dict(self) -> Dict:
return {
"conv_id": self.conv_id,
"sub_task_num": self.sub_task_num,
}
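For instance, the composite key for conversation "conv-1", sub-task "2" serializes as below (a minimal sketch; the module path is assumed from this commit's layout):

from dbgpt.agent.memory.gpts_memory_storage import GptsPlanIdentifier

ident = GptsPlanIdentifier(conv_id="conv-1", sub_task_num="2")
print(ident.str_identifier)  # -> conv-1___$$$$___2
print(ident.to_dict())       # -> {'conv_id': 'conv-1', 'sub_task_num': '2'}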
@dataclass
class GptsPlanStorage(StorageItem):
"""Gpts plan"""
conv_id: str
sub_task_num: int
sub_task_content: Optional[str]
sub_task_title: Optional[str] = None
sub_task_agent: Optional[str] = None
resource_name: Optional[str] = None
rely: Optional[str] = None
agent_model: Optional[str] = None
retry_times: Optional[int] = 0
max_retry_times: Optional[int] = 5
state: Optional[str] = Status.TODO.value
result: Optional[str] = None
_identifier: GptsPlanIdentifier = dataclasses.field(init=False)
def __post_init__(self):
# Build the identifier eagerly so the `identifier` property is always usable
self._identifier = GptsPlanIdentifier(
conv_id=self.conv_id,
sub_task_num=str(self.sub_task_num),
)
self._check()
@staticmethod
def from_dict(d: Dict[str, Any]):
return GptsPlanStorage(
conv_id=d.get("conv_id"),
sub_task_num=d["sub_task_num"],
sub_task_content=d["sub_task_content"],
sub_task_agent=d["sub_task_agent"],
resource_name=d["resource_name"],
rely=d["rely"],
agent_model=d["agent_model"],
retry_times=d["retry_times"],
max_retry_times=d["max_retry_times"],
state=d["state"],
result=d["result"],
)
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
def _check(self):
if self.conv_id is None:
raise ValueError("conv_id cannot be None")
if self.sub_task_num is None:
raise ValueError("sub_task_num cannot be None")
if self.sub_task_content is None:
raise ValueError("sub_task_content cannot be None")
if self.state is None:
raise ValueError("state cannot be None")
@property
def identifier(self) -> GptsPlanIdentifier:
return self._identifier
def merge(self, other: "StorageItem") -> None:
"""Merge the other item into the current item.
Args:
other (StorageItem): The other item to merge
"""
if not isinstance(other, GptsPlanStorage):
raise ValueError(
f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
)
self.from_object(other)
@dataclass
class GptsMessageIdentifier(ResourceIdentifier):
identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
conv_id: str
sender: Optional[str]
receiver: Optional[str]
rounds: Optional[int]
def __post_init__(self):
if (
self.conv_id is None
or self.sender is None
or self.receiver is None
or self.rounds is None
):
raise ValueError("conv_id, sender, receiver and rounds cannot be None")
if any(
self.identifier_split in str(key)
for key in [
self.conv_id,
self.sender,
self.receiver,
self.rounds,
]
if key is not None
):
raise ValueError(
f"identifier_split {self.identifier_split} is not allowed in conv_id, sender, receiver, rounds"
)
@property
def str_identifier(self) -> str:
return self.identifier_split.join(
str(key)
for key in [
self.conv_id,
self.sender,
self.receiver,
self.rounds,
]
if key is not None
)
def to_dict(self) -> Dict:
return {
"conv_id": self.conv_id,
"sender": self.sender,
"receiver": self.receiver,
"rounds": self.rounds,
}
@dataclass
class GptsMessageStorage(StorageItem):
"""Gpts Message"""
conv_id: str
sender: str
receiver: str
role: str
content: str
rounds: Optional[int]
current_gogal: Optional[str] = None
context: Optional[str] = None
review_info: Optional[str] = None
action_report: Optional[str] = None
model_name: Optional[str] = None
created_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
updated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
_identifier: GptsMessageIdentifier = dataclasses.field(init=False)
def __post_init__(self):
# Build the identifier eagerly so the `identifier` property is always usable
self._identifier = GptsMessageIdentifier(
conv_id=self.conv_id,
sender=self.sender,
receiver=self.receiver,
rounds=self.rounds,
)
self._check()
@staticmethod
def from_dict(d: Dict[str, Any]):
return GptsMessageStorage(
conv_id=d["conv_id"],
sender=d["sender"],
receiver=d["receiver"],
role=d["role"],
content=d["content"],
rounds=d["rounds"],
model_name=d["model_name"],
current_gogal=d["current_gogal"],
context=d["context"],
review_info=d["review_info"],
action_report=d["action_report"],
created_at=d["created_at"],
updated_at=d["updated_at"],
)
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
def _check(self):
if self.conv_id is None:
raise ValueError("conv_id cannot be None")
if self.sender is None:
raise ValueError("sender cannot be None")
if self.receiver is None:
raise ValueError("receiver cannot be None")
if self.rounds is None:
raise ValueError("rounds cannot be None")
def to_gpts_message(self) -> GptsMessage:
"""Convert the storage item to a GptsMessage."""
return GptsMessage(
conv_id=self.conv_id,
sender=self.sender,
receiver=self.receiver,
role=self.role,
content=self.content,
rounds=self.rounds,
current_gogal=self.current_gogal,
context=self.context,
review_info=self.review_info,
action_report=self.action_report,
model_name=self.model_name,
created_at=self.created_at,
updated_at=self.updated_at,
)
@staticmethod
def from_gpts_message(gpts_message: GptsMessage) -> "GptsMessageStorage":
"""Convert a GptsMessage to a storage item."""
return GptsMessageStorage(
conv_id=gpts_message.conv_id,
sender=gpts_message.sender,
receiver=gpts_message.receiver,
role=gpts_message.role,
content=gpts_message.content,
rounds=gpts_message.rounds,
current_gogal=gpts_message.current_gogal,
context=gpts_message.context,
review_info=gpts_message.review_info,
action_report=gpts_message.action_report,
model_name=gpts_message.model_name,
created_at=gpts_message.created_at,
updated_at=gpts_message.updated_at,
)
@property
def identifier(self) -> GptsMessageIdentifier:
return self._identifier
def merge(self, other: "StorageItem") -> None:
"""Merge the other item into the current item.
Args:
other (StorageItem): The other item to merge
"""
if not isinstance(other, GptsMessageStorage):
raise ValueError(
f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
)
self.from_object(other)
class GptsMessageManager(GptsMessageMemory):
"""The manager class for GptsMessage.
Simple wrapper for the storage interface.
"""
def __init__(self, storage: Optional[StorageInterface[GptsMessage, Any]] = None):
if storage is None:
storage = InMemoryStorage()
self._storage = storage
@property
def storage(self) -> StorageInterface[GptsMessage, Any]:
"""The storage interface for gpts messages."""
return self._storage
def append(self, message: GptsMessage):
self.storage.save(GptsMessageStorage.from_gpts_message(message))
def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
query_spec = QuerySpec(conditions={"conv_id": conv_id})
queries: List[GptsMessageStorage] = self.storage.query(
query_spec, GptsMessageStorage
)
# Keep only the messages sent or received by the given agent
return [
item.to_gpts_message()
for item in queries
if item.sender == agent or item.receiver == agent
]
def get_between_agents(
self,
conv_id: str,
agent1: str,
agent2: str,
current_gogal: Optional[str] = None,
) -> Optional[List[GptsMessage]]:
return super().get_between_agents(conv_id, agent1, agent2, current_gogal)
def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessage]]:
return super().get_by_conv_id(conv_id)
def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
return super().get_last_message(conv_id)
def prefer_query(
self,
prompt_name: str,
sys_code: Optional[str] = None,
prefer_prompt_language: Optional[str] = None,
prefer_model: Optional[str] = None,
**kwargs,
) -> List[StoragePromptTemplate]:
"""Query prompt templates from storage with prefer params.
Sometimes, we want to query prompt templates with prefer params(e.g. some language or some model).
This method queries prompt templates with the prefer params first; if none are found, it returns all matching prompt templates.
Examples:
Query a prompt template.
.. code-block:: python
prompt_template_list = prompt_manager.prefer_query("hello")
Query with sys_code and username.
.. code-block:: python
prompt_template_list = prompt_manager.prefer_query(
"hello", sys_code="sys_code", user_name="user_name"
)
Query with prefer prompt language.
.. code-block:: python
# First query with prompt name "hello" exactly.
# Second filter with prompt language "zh-cn", if not found, will return all prompt templates.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_prompt_language="zh-cn"
)
Query with prefer model.
.. code-block:: python
# First query with prompt name "hello" exactly.
# Second filter with model "vicuna-13b-v1.5", if not found, will return all prompt templates.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_model="vicuna-13b-v1.5"
)
Args:
prompt_name (str): The name of the prompt template.
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
prefer_prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
prefer_model (Optional[str], optional): The model of the prompt template. Defaults to None.
kwargs (Dict): Other query params (if a key's value is not None, it is matched exactly).
"""
query_spec = QuerySpec(
conditions={
"prompt_name": prompt_name,
"sys_code": sys_code,
**kwargs,
}
)
queries: List[StoragePromptTemplate] = self.storage.query(
query_spec, StoragePromptTemplate
)
if not queries:
return []
if prefer_prompt_language:
prefer_prompt_language = prefer_prompt_language.lower()
temp_queries = [
query
for query in queries
if query.prompt_language
and query.prompt_language.lower() == prefer_prompt_language
]
if temp_queries:
queries = temp_queries
if prefer_model:
prefer_model = prefer_model.lower()
temp_queries = [
query
for query in queries
if query.model and query.model.lower() == prefer_model
]
if temp_queries:
queries = temp_queries
return queries
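And a sketch wiring the storage-backed manager to the in-memory backend (the gpts_memory_storage module path is assumed from this commit's layout; GptsMessage construction is hedged as above):

from dbgpt.agent.memory.base import GptsMessage
from dbgpt.agent.memory.gpts_memory_storage import GptsMessageManager
from dbgpt.core.interface.storage import InMemoryStorage

manager = GptsMessageManager(InMemoryStorage())
manager.append(
    GptsMessage(
        conv_id="conv-1",
        sender="Planner",
        receiver="Engineer",
        role="assistant",
        content="hello",
        rounds=1,
    )
)
print(manager.get_by_agent("conv-1", "Planner"))  # one GptsMessage back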

View File

@@ -1,69 +0,0 @@
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from typing import TypeVar, Generic, Any
from dbgpt._private.pydantic import BaseModel, Field
T = TypeVar("T")
class PagenationFilter(BaseModel, Generic[T]):
page_index: int = 1
page_size: int = 20
filter: T = None
class PagenationResult(BaseModel, Generic[T]):
page_index: int = 1
page_size: int = 20
total_page: int = 0
total_row_count: int = 0
datas: List[T] = []
def to_dic(self):
data_dicts = []
for item in self.datas:
data_dicts.append(item.__dict__)
return {
"page_index": self.page_index,
"page_size": self.page_size,
"total_page": self.total_page,
"total_row_count": self.total_row_count,
"datas": data_dicts,
}
@dataclass
class PluginHubFilter(BaseModel):
name: str
description: str
author: str
email: str
type: str
version: str
storage_channel: str
storage_url: str
@dataclass
class MyPluginFilter(BaseModel):
tenant: str
user_code: str
user_name: str
name: str
file_name: str
type: str
version: str
class PluginHubParam(BaseModel):
channel: Optional[str] = Field("git", description="Plugin storage channel")
url: Optional[str] = Field(
"https://github.com/eosphoros-ai/DB-GPT-Plugins.git",
description="Plugin storage url",
)
branch: Optional[str] = Field(
"main", description="github download branch", nullable=True
)
authorization: Optional[str] = Field(
None, description="github download authorization", nullable=True
)

View File

View File

@@ -3,8 +3,8 @@ import json
import requests
from dbgpt.agent.commands.command_mange import command
from dbgpt._private.config import Config
from ..command_mange import command
CFG = Config()

View File

@@ -0,0 +1 @@
from .show_chart_gen import static_message_img_path

View File

@@ -1,20 +1,22 @@
import os
import uuid
import matplotlib
import pandas as pd
import seaborn as sns
from pandas import DataFrame
from dbgpt.agent.commands.command_mange import command
import pandas as pd
import uuid
import os
import matplotlib
import seaborn as sns
from ...command_mange import command
matplotlib.use("Agg")
import logging
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
from dbgpt.util.string_utils import is_scientific_notation
from dbgpt.configs.model_config import PILOT_PATH
import logging
from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.util.string_utils import is_scientific_notation
logger = logging.getLogger(__name__)

View File

@@ -1,8 +1,8 @@
import logging
from pandas import DataFrame
from dbgpt.agent.commands.command_mange import command
import logging
from ...command_mange import command
logger = logging.getLogger(__name__)

View File

@@ -1,8 +1,8 @@
import logging
from pandas import DataFrame
from dbgpt.agent.commands.command_mange import command
import logging
from ...command_mange import command
logger = logging.getLogger(__name__)

View File

@@ -1,14 +1,14 @@
""" Image Generation Module for AutoGPT."""
import io
import logging
import uuid
from base64 import b64decode
import logging
import requests
from PIL import Image
from dbgpt.agent.commands.command_mange import command
from dbgpt._private.config import Config
from ..command_mange import command
logger = logging.getLogger(__name__)
CFG = Config()

View File

@@ -4,11 +4,11 @@
import json
from typing import Dict
from .exception_not_commands import NotCommands
from .generator import PluginPromptGenerator
from dbgpt._private.config import Config
from .exception_not_commands import NotCommands
from dbgpt.agent.plugin.generator import PluginPromptGenerator
def _resolve_pathlike_command_args(command_args):
if "directory" in command_args and command_args["directory"] in {"", "/"}:

View File

@@ -4,17 +4,19 @@ import inspect
import json
import logging
import xml.etree.ElementTree as ET
from dbgpt.util.json_utils import serialize
from datetime import datetime
from typing import Any, Callable, Optional, List
from typing import Any, Callable, List, Optional
from dbgpt._private.pydantic import BaseModel
from .command import execute_command
from dbgpt.agent.common.schema import Status
from dbgpt.agent.commands.command import execute_command
from dbgpt.util.string_utils import extract_content_open_ending, extract_content
from dbgpt.util.json_utils import serialize
from dbgpt.util.string_utils import extract_content, extract_content_open_ending
# Unique identifier for auto-gpt commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
logger = logging.getLogger(__name__)
class Command:
@@ -404,7 +406,7 @@ class ApiCall:
value.api_result = execute_command(
value.name, value.args, self.plugin_generator
)
value.status = Status.COMPLETED.value
value.status = Status.COMPLETE.value
except Exception as e:
value.status = Status.FAILED.value
value.err_msg = str(e)
@@ -436,7 +438,7 @@ class ApiCall:
"response_table", **param
)
value.status = Status.COMPLETED.value
value.status = Status.COMPLETE.value
except Exception as e:
value.status = Status.FAILED.value
value.err_msg = str(e)
@@ -474,12 +476,13 @@ class ApiCall:
date_unit="s",
)
)
value.status = Status.COMPLETED.value
value.status = Status.COMPLETE.value
else:
value.status = Status.FAILED.value
value.err_msg = "No executable sql"
except Exception as e:
logger.error(f"data prepare exception: {str(e)}")
value.status = Status.FAILED.value
value.err_msg = str(e)
value.end_time = datetime.now().timestamp() * 1000
@@ -488,3 +491,125 @@ class ApiCall:
raise ValueError("Api parsing exception," + str(e))
return self.api_view_context(llm_text, True)
def display_only_sql_vis(self, chart: dict, sql_2_df_func):
err_msg = None
try:
sql = chart.get("sql", None)
if not sql or len(sql) <= 0:
return None
param = {}
df = sql_2_df_func(sql)
param["sql"] = sql
param["type"] = chart.get("display_type", "response_table")
param["title"] = chart.get("title", "")
param["describe"] = chart.get("thought", "")
param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s")
)
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
except Exception as e:
logger.error("parse_view_response error: " + str(e))
err_param = {}
err_param["sql"] = f"{sql}"
err_param["type"] = "response_table"
# err_param["err_msg"] = str(e)
err_param["data"] = []
err_msg = str(e)
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
# api_call_element.text = view_json_str
result = f"```vis-chart\n{view_json_str}\n```"
if err_msg:
return f"""<span style=\"color:red\">ERROR!</span>{err_msg} \n {result}"""
else:
return result
def display_dashboard_vis(
self, charts: List[dict], sql_2_df_func, title: Optional[str] = None
):
err_msg = None
view_json_str = None
chart_items = []
try:
if not charts or len(charts) <= 0:
return "No chart data available!"
for chart in charts:
param = {}
sql = chart.get("sql", "")
param["sql"] = sql
param["type"] = chart.get("display_type", "response_table")
param["title"] = chart.get("title", "")
param["describe"] = chart.get("thought", "")
try:
df = sql_2_df_func(sql)
param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s")
)
except Exception as e:
param["data"] = []
param["err_msg"] = str(e)
chart_items.append(
f"```vis-chart-item\n{json.dumps(param, default=serialize, ensure_ascii=False)}\n```"
)
dashboard_param = {
"markdown": "\n".join(chart_items),
"chart_count": len(chart_items),
"title": title,
}
view_json_str = json.dumps(
dashboard_param, default=serialize, ensure_ascii=False
)
except Exception as e:
logger.error("parse_view_response error: " + str(e))
return f"```error\nReport rendering exception: {str(e)}\n```"
result = f"```vis-dashboard\n{view_json_str}\n```"
if err_msg:
return (
f"""\n <span style="color:red">ERROR!</span>{err_msg} \n {result}"""
)
else:
return result
@staticmethod
def default_chart_type_promot() -> str:
"""Default chart-type prompt. Moved from excel_analyze/chat.py and used by subclasses.
Returns:
str: Newline-separated chart types with their applicable scenarios.
"""
antv_charts = [
{"response_line_chart": "used to display comparative trend analysis data"},
{
"response_pie_chart": "suitable for scenarios such as proportion and distribution statistics"
},
{
"response_table": "suitable for display with many display columns or non-numeric columns"
},
# {"response_data_text":" the default display method, suitable for single-line or simple content display"},
{
"response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc."
},
{
"response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
},
{
"response_donut_chart": "Suitable for hierarchical structure representation, category proportion display and highlighting key categories, etc."
},
{
"response_area_chart": "Suitable for visualization of time series data, comparison of multiple groups of data, analysis of data change trends, etc."
},
{
"response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
},
]
return "\n".join(
f"{key}:{value}"
for dict_item in antv_charts
for key, value in dict_item.items()
)
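To see the vis-chart payload shape without a live database, one can mirror the body of display_only_sql_vis with a stub DataFrame function (a sketch; fake_sql_2_df is a stand-in for the real SQL executor):

import json
import pandas as pd

def fake_sql_2_df(sql: str) -> pd.DataFrame:
    # Stand-in for the real sql-to-DataFrame executor
    return pd.DataFrame({"city": ["SF", "NY"], "sales": [10, 20]})

chart = {
    "sql": "SELECT city, sales FROM t",
    "display_type": "response_line_chart",
    "title": "Sales",
    "thought": "demo",
}
df = fake_sql_2_df(chart["sql"])
param = {
    "sql": chart["sql"],
    "type": chart.get("display_type", "response_table"),
    "title": chart.get("title", ""),
    "describe": chart.get("thought", ""),
    "data": json.loads(df.to_json(orient="records", date_format="iso", date_unit="s")),
}
print(f"```vis-chart\n{json.dumps(param, ensure_ascii=False)}\n```")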

View File

@@ -0,0 +1,19 @@
import logging
from .generator import PluginPromptGenerator
from typing import List
logger = logging.getLogger(__name__)
class PluginLoader:
def load_plugins(
self, generator: PluginPromptGenerator, my_plugins: List[str]
) -> PluginPromptGenerator:
logger.info(f"load_select_plugin:{my_plugins}")
# Load only the selected plugins; self.plugins is expected to be
# populated by the caller (e.g. with scanned plugin instances) first.
for plugin in self.plugins:
if plugin._name in my_plugins:
if not plugin.can_handle_post_prompt():
continue
generator = plugin.post_prompt(generator)
return generator
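A usage sketch. The loader module path, zero-argument PluginPromptGenerator construction, and the scanned_plugins name are assumptions; the loader expects plugin instances to be assigned before selection:

from dbgpt.agent.plugin.generator import PluginPromptGenerator
from dbgpt.agent.plugin.loader import PluginLoader  # module path assumed

loader = PluginLoader()
loader.plugins = scanned_plugins  # hypothetical: result of a prior plugin scan
generator = loader.load_plugins(PluginPromptGenerator(), my_plugins=["chart_plugin"])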

View File

@@ -1,17 +1,17 @@
"""Load components."""
import json
import os
import glob
import zipfile
import git
import threading
import datetime
import glob
import json
import logging
import os
import threading
import zipfile
from pathlib import Path
from typing import List
from zipimport import zipimporter
import git
import requests
from auto_gpt_plugin_template import AutoGPTPluginTemplate