feat(agent): Multi agent sdk (#976)

Co-authored-by: xtyuns <xtyuns@163.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Co-authored-by: qidanrui <qidanrui@gmail.com>
This commit is contained in:
明天
2023-12-27 16:25:55 +08:00
committed by GitHub
parent 69fb97e508
commit 9aec636b02
79 changed files with 6359 additions and 121 deletions

View File

View File

@@ -0,0 +1,338 @@
import logging
import json
import asyncio
import uuid
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from fastapi import (
APIRouter,
Body,
UploadFile,
File,
)
from fastapi.responses import StreamingResponse
from abc import ABC
from typing import List
from dbgpt.core.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG
from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator
from dbgpt.app.openapi.api_view_model import Result, ConversationVo
from dbgpt.util.json_utils import EnhancedJSONEncoder
from dbgpt.serve.agent.model import (
PluginHubParam,
PagenationFilter,
PagenationResult,
PluginHubFilter,
)
from dbgpt.agent.common.schema import Status
from dbgpt.agent.agents.agents_mange import AgentsMange
from dbgpt.agent.agents.planner_agent import PlannerAgent
from dbgpt.agent.agents.user_proxy_agent import UserProxyAgent
from dbgpt.agent.agents.plan_group_chat import PlanChat, PlanChatManager
from dbgpt.agent.agents.agent import AgentContext
from dbgpt.agent.memory.gpts_memory import GptsMemory
from .db_gpts_memory import MetaDbGptsPlansMemory, MetaDbGptsMessageMemory
from ..db.gpts_mange_db import GptsInstanceDao, GptsInstanceEntity
from ..db.gpts_conversations_db import GptsConversationsDao, GptsConversationsEntity
from .dbgpts import DbGptsCompletion, DbGptsTaskStep, DbGptsMessage, DbGptsInstance
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.agent.agents.agents_mange import agent_mange
from dbgpt._private.config import Config
from dbgpt.model.cluster.controller.controller import BaseModelController
from dbgpt.agent.memory.gpts_memory import GptsMessage
from dbgpt.model.cluster import WorkerManager, WorkerManagerFactory
from dbgpt.model.cluster.client import DefaultLLMClient
CFG = Config()
import asyncio
router = APIRouter()
logger = logging.getLogger(__name__)
class MultiAgents(BaseComponent, ABC):
name = ComponentType.MULTI_AGENTS
def init_app(self, system_app: SystemApp):
system_app.app.include_router(router, prefix="/api", tags=["Multi-Agents"])
def __init__(self):
self.gpts_intance = GptsInstanceDao()
self.gpts_conversations = GptsConversationsDao()
self.memory = GptsMemory(
plans_memory=MetaDbGptsPlansMemory(),
message_memory=MetaDbGptsMessageMemory(),
)
def gpts_create(self, entity: GptsInstanceEntity):
self.gpts_intance.add(entity)
async def plan_chat(
self,
name: str,
user_query: str,
conv_id: str,
user_code: str = None,
sys_code: str = None,
):
gpts_instance: GptsInstanceEntity = self.gpts_intance.get_by_name(name)
if gpts_instance is None:
raise ValueError(f"can't find dbgpts!{name}")
agents_names = json.loads(gpts_instance.gpts_agents)
llm_models_priority = json.loads(gpts_instance.gpts_models)
resource_db = (
json.loads(gpts_instance.resource_db) if gpts_instance.resource_db else None
)
resource_knowledge = (
json.loads(gpts_instance.resource_knowledge)
if gpts_instance.resource_knowledge
else None
)
resource_internet = (
json.loads(gpts_instance.resource_internet)
if gpts_instance.resource_internet
else None
)
### init chat param
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
llm_task = DefaultLLMClient(worker_manager)
context: AgentContext = AgentContext(conv_id=conv_id, llm_provider=llm_task)
context.gpts_name = gpts_instance.gpts_name
context.resource_db = resource_db
context.resource_internet = resource_internet
context.resource_knowledge = resource_knowledge
context.agents = agents_names
context.llm_models = await llm_task.models()
context.model_priority = llm_models_priority
agent_map = defaultdict()
### default plan excute mode
agents = []
for name in agents_names:
cls = agent_mange.get_by_name(name)
agent = cls(
agent_context=context,
memory=self.memory,
)
agents.append(agent)
agent_map[name] = agent
groupchat = PlanChat(agents=agents, messages=[], max_round=50)
planner = PlannerAgent(
agent_context=context,
memory=self.memory,
plan_chat=groupchat,
)
agent_map[planner.name] = planner
manager = PlanChatManager(
agent_context=context,
memory=self.memory,
plan_chat=groupchat,
planner=planner,
)
agent_map[manager.name] = manager
user_proxy = UserProxyAgent(memory=self.memory, agent_context=context)
agent_map[user_proxy.name] = user_proxy
gpts_conversation = self.gpts_conversations.get_by_conv_id(conv_id)
if gpts_conversation is None:
self.gpts_conversations.add(
GptsConversationsEntity(
conv_id=conv_id,
user_goal=user_query,
gpts_name=gpts_instance.gpts_name,
state=Status.RUNNING.value,
max_auto_reply_round=context.max_chat_round,
auto_reply_count=0,
user_code=user_code,
sys_code=sys_code,
)
)
## dbgpts conversation save
try:
await user_proxy.a_initiate_chat(
recipient=manager,
message=user_query,
memory=self.memory,
)
except Exception as e:
logger.error(f"chat abnormal termination{str(e)}", e)
self.gpts_conversations.update(conv_id, Status.FAILED.value)
else:
# retry chat
self.gpts_conversations.update(conv_id, Status.RUNNING.value)
try:
await user_proxy.a_retry_chat(
recipient=manager,
agent_map=agent_map,
memory=self.memory,
)
except Exception as e:
logger.error(f"chat abnormal termination{str(e)}", e)
self.gpts_conversations.update(conv_id, Status.FAILED.value)
self.gpts_conversations.update(conv_id, Status.COMPLETE.value)
return conv_id
async def chat_completions(
self, conv_id: str, user_code: str = None, system_app: str = None
):
is_complete = False
while True:
gpts_conv = self.gpts_conversations.get_by_conv_id(conv_id)
if gpts_conv:
is_complete = (
True
if gpts_conv.state
in [
Status.COMPLETE.value,
Status.WAITING.value,
Status.FAILED.value,
]
else False
)
yield await self.memory.one_plan_chat_competions(conv_id)
if is_complete:
return
else:
await asyncio.sleep(5)
async def stable_message(
self, conv_id: str, user_code: str = None, system_app: str = None
):
gpts_conv = self.gpts_conversations.get_by_conv_id(conv_id)
if gpts_conv:
is_complete = (
True
if gpts_conv.state
in [Status.COMPLETE.value, Status.WAITING.value, Status.FAILED.value]
else False
)
if is_complete:
return await self.self.memory.one_plan_chat_competions(conv_id)
else:
raise ValueError(
"The conversation has not been completed yet, so we cannot directly obtain information."
)
else:
raise ValueError("No conversation record found!")
def gpts_conv_list(self, user_code: str = None, system_app: str = None):
return self.gpts_conversations.get_convs(user_code, system_app)
multi_agents = MultiAgents()
@router.post("/v1/dbbgpts/agents/list", response_model=Result[str])
async def agents_list():
logger.info("agents_list!")
try:
agents = agent_mange.all_agents()
return Result.succ(agents)
except Exception as e:
return Result.failed(code="E30001", msg=str(e))
@router.post("/v1/dbbgpts/create", response_model=Result[str])
async def create_dbgpts(gpts_instance: DbGptsInstance = Body()):
logger.info(f"create_dbgpts:{gpts_instance}")
try:
multi_agents.gpts_create(
GptsInstanceEntity(
gpts_name=gpts_instance.gpts_name,
gpts_describe=gpts_instance.gpts_describe,
resource_db=json.dumps(gpts_instance.resource_db.to_dict()),
resource_internet=json.dumps(gpts_instance.resource_internet.to_dict()),
resource_knowledge=json.dumps(
gpts_instance.resource_knowledge.to_dict()
),
gpts_agents=json.dumps(gpts_instance.gpts_agents),
gpts_models=json.dumps(gpts_instance.gpts_models),
language=gpts_instance.language,
user_code=gpts_instance.user_code,
sys_code=gpts_instance.sys_code,
)
)
return Result.succ(None)
except Exception as e:
logger.error(f"create_dbgpts failed:{str(e)}")
return Result.failed(msg=str(e), code="E300002")
async def stream_generator(conv_id: str):
async for chunk in multi_agents.chat_completions(conv_id):
if chunk:
yield f"data: {chunk}\n\n"
@router.post("/v1/dbbgpts/chat/plan/completions", response_model=Result[str])
async def dgpts_completions(
gpts_name: str,
user_query: str,
conv_id: str = None,
user_code: str = None,
sys_code: str = None,
):
logger.info(f"dgpts_completions:{gpts_name},{user_query},{conv_id}")
if conv_id is None:
conv_id = str(uuid.uuid1())
asyncio.create_task(
multi_agents.plan_chat(gpts_name, user_query, conv_id, user_code, sys_code)
)
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
stream_generator(conv_id),
headers=headers,
media_type="text/plain",
)
@router.post("/v1/dbbgpts/plan/chat/cancel", response_model=Result[str])
async def dgpts_plan_chat_cancel(
conv_id: str = None, user_code: str = None, sys_code: str = None
):
pass
@router.get("/v1/dbbgpts/chat/plan/messages", response_model=Result[str])
async def plan_chat_messages(conv_id: str, user_code: str = None, sys_code: str = None):
logger.info(f"plan_chat_messages:{conv_id},{user_code},{sys_code}")
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
stream_generator(conv_id),
headers=headers,
media_type="text/plain",
)
@router.post("/v1/dbbgpts/chat/feedback", response_model=Result[str])
async def dgpts_chat_feedback(filter: PagenationFilter[PluginHubFilter] = Body()):
pass

View File

@@ -0,0 +1,121 @@
from typing import List, Optional
from dbgpt.agent.memory.gpts_memory import (
GptsPlansMemory,
GptsPlan,
GptsMessageMemory,
GptsMessage,
)
from ..db.gpts_plans_db import GptsPlansEntity, GptsPlansDao
from ..db.gpts_messages_db import GptsMessagesDao, GptsMessagesEntity
class MetaDbGptsPlansMemory(GptsPlansMemory):
def __init__(self):
self.gpts_plan = GptsPlansDao()
def batch_save(self, plans: list[GptsPlan]):
self.gpts_plan.batch_save([item.to_dict() for item in plans])
def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
db_results: List[GptsPlansEntity] = self.gpts_plan.get_by_conv_id(
conv_id=conv_id
)
results = []
for item in db_results:
results.append(GptsPlan.from_dict(item.__dict__))
return results
def get_by_conv_id_and_num(
self, conv_id: str, task_nums: List[int]
) -> List[GptsPlan]:
db_results: List[GptsPlansEntity] = self.gpts_plan.get_by_conv_id_and_num(
conv_id=conv_id, task_nums=task_nums
)
results = []
for item in db_results:
results.append(GptsPlan.from_dict(item.__dict__))
return results
def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
db_results: List[GptsPlansEntity] = self.gpts_plan.get_todo_plans(
conv_id=conv_id
)
results = []
for item in db_results:
results.append(GptsPlan.from_dict(item.__dict__))
return results
def complete_task(self, conv_id: str, task_num: int, result: str):
self.gpts_plan.complete_task(conv_id=conv_id, task_num=task_num, result=result)
def update_task(
self,
conv_id: str,
task_num: int,
state: str,
retry_times: int,
agent: str = None,
model: str = None,
result: str = None,
):
self.gpts_plan.update_task(
conv_id=conv_id,
task_num=task_num,
state=state,
retry_times=retry_times,
agent=agent,
model=model,
result=result,
)
def remove_by_conv_id(self, conv_id: str):
self.gpts_plan.remove_by_conv_id(conv_id=conv_id)
class MetaDbGptsMessageMemory(GptsMessageMemory):
def __init__(self):
self.gpts_message = GptsMessagesDao()
def append(self, message: GptsMessage):
self.gpts_message.append(message.to_dict())
def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
db_results = self.gpts_message.get_by_agent(conv_id, agent)
results = []
db_results = sorted(db_results, key=lambda x: x.rounds)
for item in db_results:
results.append(GptsMessage.from_dict(item.__dict__))
return results
def get_between_agents(
self,
conv_id: str,
agent1: str,
agent2: str,
current_gogal: Optional[str] = None,
) -> Optional[List[GptsMessage]]:
db_results = self.gpts_message.get_between_agents(
conv_id, agent1, agent2, current_gogal
)
results = []
db_results = sorted(db_results, key=lambda x: x.rounds)
for item in db_results:
results.append(GptsMessage.from_dict(item.__dict__))
return results
def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessage]]:
db_results = self.gpts_message.get_by_conv_id(conv_id)
results = []
db_results = sorted(db_results, key=lambda x: x.rounds)
for item in db_results:
results.append(GptsMessage.from_dict(item.__dict__))
return results
def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
db_result = self.gpts_message.get_last_message(conv_id)
if db_result:
return GptsMessage.from_dict(db_result.__dict__)
else:
return None

View File

@@ -0,0 +1,93 @@
from __future__ import annotations
from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from dataclasses import dataclass, asdict, fields
import dataclasses
from dbgpt.agent.agents.agent import AgentResource
class AgentMode(Enum):
PLAN_EXCUTE = "plan_excute"
@dataclass
class DbGptsInstance:
gpts_name: str
gpts_describe: str
gpts_agents: list[str]
resource_db: Optional[AgentResource] = None
resource_internet: Optional[AgentResource] = None
resource_knowledge: Optional[AgentResource] = None
gpts_models: Optional[Dict[str, List[str]]] = None
language: str = "en"
user_code: str = None
sys_code: str = None
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
@dataclass
class DbGptsMessage:
sender: str
receiver: str
content: str
action_report: str
@staticmethod
def from_dict(d: Dict[str, Any]) -> DbGptsMessage:
return DbGptsMessage(
sender=d["sender"],
receiver=d["receiver"],
content=d["content"],
model_name=d["model_name"],
agent_name=d["agent_name"],
)
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
@dataclass
class DbGptsTaskStep:
task_num: str
task_content: str
state: str
result: str
agent_name: str
model_name: str
@staticmethod
def from_dict(d: Dict[str, Any]) -> DbGptsTaskStep:
return DbGptsTaskStep(
task_num=d["task_num"],
task_content=d["task_content"],
state=d["state"],
result=d["result"],
agent_name=d["agent_name"],
model_name=d["model_name"],
)
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
@dataclass
class DbGptsCompletion:
conv_id: str
task_steps: Optional[List[DbGptsTaskStep]]
messages: Optional[List[DbGptsMessage]]
@staticmethod
def from_dict(d: Dict[str, Any]) -> DbGptsCompletion:
return DbGptsCompletion(
conv_id=d.get("conv_id"),
task_steps=DbGptsTaskStep.from_dict(d["task_steps"]),
messages=DbGptsMessage.from_dict(d["messages"]),
)
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)

View File

@@ -0,0 +1,6 @@
from .gpts_conversations_db import GptsConversationsDao, GptsConversationsEntity
from .gpts_mange_db import GptsInstanceDao, GptsInstanceEntity
from .gpts_messages_db import GptsMessagesDao, GptsMessagesEntity
from .gpts_plans_db import GptsPlansDao, GptsPlansEntity
from .my_plugin_db import MyPluginDao, MyPluginEntity
from .plugin_hub_db import PluginHubDao, PluginHubEntity

View File

@@ -0,0 +1,95 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Text, desc
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
class GptsConversationsEntity(Model):
__tablename__ = "gpts_conversations"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True, comment="autoincrement id")
conv_id = Column(
String(255), nullable=False, comment="The unique id of the conversation record"
)
user_goal = Column(Text, nullable=False, comment="User's goals content")
gpts_name = Column(String(255), nullable=False, comment="The gpts name")
state = Column(String(255), nullable=True, comment="The gpts state")
max_auto_reply_round = Column(
Integer, nullable=False, comment="max auto reply round"
)
auto_reply_count = Column(Integer, nullable=False, comment="auto reply count")
user_code = Column(String(255), nullable=True, comment="user code")
sys_code = Column(String(255), nullable=True, comment="system app ")
created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
updated_at = Column(
DateTime,
default=datetime.utcnow,
onupdate=datetime.utcnow,
comment="last update time",
)
__table_args__ = (
UniqueConstraint("conv_id", name="uk_gpts_conversations"),
Index("idx_gpts_name", "gpts_name"),
)
class GptsConversationsDao(BaseDao):
def add(self, engity: GptsConversationsEntity):
session = self.get_raw_session()
session.add(engity)
session.commit()
id = engity.id
session.close()
return id
def get_by_conv_id(self, conv_id: str):
session = self.get_raw_session()
gpts_conv = session.query(GptsConversationsEntity)
if conv_id:
gpts_conv = gpts_conv.filter(GptsConversationsEntity.conv_id == conv_id)
result = gpts_conv.first()
session.close()
return result
def get_convs(self, user_code: str = None, system_app: str = None):
session = self.get_raw_session()
gpts_conversations = session.query(GptsConversationsEntity)
if user_code:
gpts_conversations = gpts_conversations.filter(
GptsConversationsEntity.user_code == user_code
)
if system_app:
gpts_conversations = gpts_conversations.filter(
GptsConversationsEntity.system_app == system_app
)
result = (
gpts_conversations.limit(20)
.order_by(desc(GptsConversationsEntity.id))
.all()
)
session.close()
return result
def update(self, conv_id: str, state: str):
session = self.get_raw_session()
gpts_convs = session.query(GptsConversationsEntity)
gpts_convs = gpts_convs.filter(GptsConversationsEntity.conv_id == conv_id)
gpts_convs.update(
{GptsConversationsEntity.state: state}, synchronize_session="fetch"
)
session.commit()
session.close()

View File

@@ -0,0 +1,78 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Text, Boolean
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
class GptsInstanceEntity(Model):
__tablename__ = "gpts_instance"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True, comment="autoincrement id")
gpts_name = Column(String(255), nullable=False, comment="Current AI assistant name")
gpts_describe = Column(
String(2255), nullable=False, comment="Current AI assistant describe"
)
resource_db = Column(
Text,
nullable=True,
comment="List of structured database names contained in the current gpts",
)
resource_internet = Column(
Text,
nullable=True,
comment="Is it possible to retrieve information from the internet",
)
resource_knowledge = Column(
Text,
nullable=True,
comment="List of unstructured database names contained in the current gpts",
)
gpts_agents = Column(
String(1000),
nullable=True,
comment="List of agents names contained in the current gpts",
)
gpts_models = Column(
String(1000),
nullable=True,
comment="List of llm model names contained in the current gpts",
)
language = Column(String(100), nullable=True, comment="gpts language")
user_code = Column(String(255), nullable=False, comment="user code")
sys_code = Column(String(255), nullable=True, comment="system app code")
created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
updated_at = Column(
DateTime,
default=datetime.utcnow,
onupdate=datetime.utcnow,
comment="last update time",
)
__table_args__ = (UniqueConstraint("gpts_name", name="uk_gpts"),)
class GptsInstanceDao(BaseDao):
def add(self, engity: GptsInstanceEntity):
session = self.get_raw_session()
session.add(engity)
session.commit()
id = engity.id
session.close()
return id
def get_by_name(self, name: str) -> GptsInstanceEntity:
session = self.get_raw_session()
gpts_instance = session.query(GptsInstanceEntity)
if name:
gpts_instance = gpts_instance.filter(GptsInstanceEntity.gpts_name == name)
result = gpts_instance.first()
session.close()
return result

View File

@@ -0,0 +1,160 @@
from datetime import datetime
from typing import List, Optional
from sqlalchemy import (
Column,
Integer,
String,
Index,
DateTime,
func,
Text,
or_,
and_,
desc,
)
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
class GptsMessagesEntity(Model):
__tablename__ = "gpts_messages"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True, comment="autoincrement id")
conv_id = Column(
String(255), nullable=False, comment="The unique id of the conversation record"
)
sender = Column(
String(255),
nullable=False,
comment="Who speaking in the current conversation turn",
)
receiver = Column(
String(255),
nullable=False,
comment="Who receive message in the current conversation turn",
)
model_name = Column(String(255), nullable=True, comment="message generate model")
rounds = Column(Integer, nullable=False, comment="dialogue turns")
content = Column(Text, nullable=True, comment="Content of the speech")
current_gogal = Column(
Text, nullable=True, comment="The target corresponding to the current message"
)
context = Column(Text, nullable=True, comment="Current conversation context")
review_info = Column(
Text, nullable=True, comment="Current conversation review info"
)
action_report = Column(
Text, nullable=True, comment="Current conversation action report"
)
role = Column(
String(255), nullable=True, comment="The role of the current message content"
)
created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
updated_at = Column(
DateTime,
default=datetime.utcnow,
onupdate=datetime.utcnow,
comment="last update time",
)
__table_args__ = (Index("idx_q_messages", "conv_id", "rounds", "sender"),)
class GptsMessagesDao(BaseDao):
def append(self, entity: dict):
session = self.get_raw_session()
message = GptsMessagesEntity(
conv_id=entity.get("conv_id"),
sender=entity.get("sender"),
receiver=entity.get("receiver"),
content=entity.get("content"),
role=entity.get("role", None),
model_name=entity.get("model_name", None),
context=entity.get("context", None),
rounds=entity.get("rounds", None),
current_gogal=entity.get("current_gogal", None),
review_info=entity.get("review_info", None),
action_report=entity.get("action_report", None),
)
session.add(message)
session.commit()
id = message.id
session.close()
return id
def get_by_agent(
self, conv_id: str, agent: str
) -> Optional[List[GptsMessagesEntity]]:
session = self.get_raw_session()
gpts_messages = session.query(GptsMessagesEntity)
if agent:
gpts_messages = gpts_messages.filter(
GptsMessagesEntity.conv_id == conv_id
).filter(
or_(
GptsMessagesEntity.sender == agent,
GptsMessagesEntity.receiver == agent,
)
)
result = gpts_messages.order_by(GptsMessagesEntity.rounds).all()
session.close()
return result
def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessagesEntity]]:
session = self.get_raw_session()
gpts_messages = session.query(GptsMessagesEntity)
if conv_id:
gpts_messages = gpts_messages.filter(GptsMessagesEntity.conv_id == conv_id)
result = gpts_messages.order_by(GptsMessagesEntity.rounds).all()
session.close()
return result
def get_between_agents(
self,
conv_id: str,
agent1: str,
agent2: str,
current_gogal: Optional[str] = None,
) -> Optional[List[GptsMessagesEntity]]:
session = self.get_raw_session()
gpts_messages = session.query(GptsMessagesEntity)
if agent1 and agent2:
gpts_messages = gpts_messages.filter(
GptsMessagesEntity.conv_id == conv_id
).filter(
or_(
and_(
GptsMessagesEntity.sender == agent1,
GptsMessagesEntity.receiver == agent2,
),
and_(
GptsMessagesEntity.sender == agent2,
GptsMessagesEntity.receiver == agent1,
),
)
)
if current_gogal:
gpts_messages = gpts_messages.filter(
GptsMessagesEntity.current_gogal == current_gogal
)
result = gpts_messages.order_by(GptsMessagesEntity.rounds).all()
session.close()
return result
def get_last_message(self, conv_id: str) -> Optional[GptsMessagesEntity]:
session = self.get_raw_session()
gpts_messages = session.query(GptsMessagesEntity)
if conv_id:
gpts_messages = gpts_messages.filter(
GptsMessagesEntity.conv_id == conv_id
).order_by(desc(GptsMessagesEntity.rounds))
result = gpts_messages.first()
session.close()
return result

View File

@@ -0,0 +1,156 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Text
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt.agent.common.schema import Status
class GptsPlansEntity(Model):
__tablename__ = "gpts_plans"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True, comment="autoincrement id")
conv_id = Column(
String(255), nullable=False, comment="The unique id of the conversation record"
)
sub_task_num = Column(Integer, nullable=False, comment="Subtask number")
sub_task_title = Column(String(255), nullable=False, comment="subtask title")
sub_task_content = Column(Text, nullable=False, comment="subtask content")
sub_task_agent = Column(
String(255), nullable=True, comment="Available agents corresponding to subtasks"
)
resource_name = Column(String(255), nullable=True, comment="resource name")
rely = Column(
String(255), nullable=True, comment="Subtask dependencieslike: 1,2,3"
)
agent_model = Column(
String(255),
nullable=True,
comment="LLM model used by subtask processing agents",
)
retry_times = Column(Integer, default=False, comment="number of retries")
max_retry_times = Column(
Integer, default=False, comment="Maximum number of retries"
)
state = Column(String(255), nullable=True, comment="subtask status")
result = Column(Text(length=2**31 - 1), nullable=True, comment="subtask result")
created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
updated_at = Column(
DateTime,
default=datetime.utcnow,
onupdate=datetime.utcnow,
comment="last update time",
)
__table_args__ = (UniqueConstraint("conv_id", "sub_task_num", name="uk_sub_task"),)
class GptsPlansDao(BaseDao):
def batch_save(self, plans: list[dict]):
session = self.get_raw_session()
session.bulk_insert_mappings(GptsPlansEntity, plans)
session.commit()
session.close()
def get_by_conv_id(self, conv_id: str) -> list[GptsPlansEntity]:
session = self.get_raw_session()
gpts_plans = session.query(GptsPlansEntity)
if conv_id:
gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id)
result = gpts_plans.all()
session.close()
return result
def get_by_task_id(self, task_id: int) -> list[GptsPlansEntity]:
session = self.get_raw_session()
gpts_plans = session.query(GptsPlansEntity)
if task_id:
gpts_plans = gpts_plans.filter(GptsPlansEntity.id == task_id)
result = gpts_plans.first()
session.close()
return result
def get_by_conv_id_and_num(
self, conv_id: str, task_nums: list
) -> list[GptsPlansEntity]:
session = self.get_raw_session()
gpts_plans = session.query(GptsPlansEntity)
if conv_id:
gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).filter(
GptsPlansEntity.sub_task_num.in_(task_nums)
)
result = gpts_plans.all()
session.close()
return result
def get_todo_plans(self, conv_id: str) -> list[GptsPlansEntity]:
session = self.get_raw_session()
gpts_plans = session.query(GptsPlansEntity)
if not conv_id:
return []
gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).filter(
GptsPlansEntity.state.in_([Status.TODO.value, Status.RETRYING.value])
)
result = gpts_plans.order_by(GptsPlansEntity.sub_task_num).all()
session.close()
return result
def complete_task(self, conv_id: str, task_num: int, result: str):
session = self.get_raw_session()
gpts_plans = session.query(GptsPlansEntity)
gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).filter(
GptsPlansEntity.sub_task_num == task_num
)
gpts_plans.update(
{
GptsPlansEntity.state: Status.COMPLETE.value,
GptsPlansEntity.result: result,
},
synchronize_session="fetch",
)
session.commit()
session.close()
def update_task(
self,
conv_id: str,
task_num: int,
state: str,
retry_times: int,
agent: str = None,
model: str = None,
result: str = None,
):
session = self.get_raw_session()
gpts_plans = session.query(GptsPlansEntity)
gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).filter(
GptsPlansEntity.sub_task_num == task_num
)
update_param = {}
update_param[GptsPlansEntity.state] = state
update_param[GptsPlansEntity.retry_times] = retry_times
update_param[GptsPlansEntity.result] = result
if agent:
update_param[GptsPlansEntity.sub_task_agent] = agent
if model:
update_param[GptsPlansEntity.agent_model] = model
gpts_plans.update(update_param, synchronize_session="fetch")
session.commit()
session.close()
def remove_by_conv_id(self, conv_id: str):
session = self.get_raw_session()
if conv_id is None:
raise Exception("conv_id is None")
gpts_plans = session.query(GptsPlansEntity)
gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).delete()
session.commit()
session.close()

View File

@@ -0,0 +1,137 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
class MyPluginEntity(Model):
__tablename__ = "my_plugin"
id = Column(Integer, primary_key=True, comment="autoincrement id")
tenant = Column(String(255), nullable=True, comment="user's tenant")
user_code = Column(String(255), nullable=False, comment="user code")
user_name = Column(String(255), nullable=True, comment="user name")
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
file_name = Column(String(255), nullable=False, comment="plugin package file name")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
use_count = Column(
Integer, nullable=True, default=0, comment="plugin total use count"
)
succ_count = Column(
Integer, nullable=True, default=0, comment="plugin total success count"
)
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(
DateTime, default=datetime.utcnow, comment="plugin install time"
)
UniqueConstraint("user_code", "name", name="uk_name")
class MyPluginDao(BaseDao):
def add(self, engity: MyPluginEntity):
session = self.get_raw_session()
my_plugin = MyPluginEntity(
tenant=engity.tenant,
user_code=engity.user_code,
user_name=engity.user_name,
name=engity.name,
type=engity.type,
version=engity.version,
use_count=engity.use_count or 0,
succ_count=engity.succ_count or 0,
sys_code=engity.sys_code,
gmt_created=datetime.now(),
)
session.add(my_plugin)
session.commit()
id = my_plugin.id
session.close()
return id
def raw_update(self, entity: MyPluginEntity):
session = self.get_raw_session()
updated = session.merge(entity)
session.commit()
return updated.id
def get_by_user(self, user: str) -> list[MyPluginEntity]:
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
result = my_plugins.all()
session.close()
return result
def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
my_plugins = my_plugins.filter(MyPluginEntity.name == plugin)
result = my_plugins.first()
session.close()
return result
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count()
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
if query.tenant is not None:
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
if query.type is not None:
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
if query.user_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
if query.sys_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
result = my_plugins.all()
session.close()
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def count(self, query: MyPluginEntity):
session = self.get_raw_session()
my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
if query.type is not None:
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
if query.tenant is not None:
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
if query.user_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
if query.sys_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
count = my_plugins.scalar()
session.close()
return count
def raw_delete(self, plugin_id: int):
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
query = MyPluginEntity(id=plugin_id)
my_plugins = session.query(MyPluginEntity)
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
my_plugins.delete()
session.commit()
session.close()

View File

@@ -0,0 +1,139 @@
from datetime import datetime
import pytz
from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model
# TODO We should consider that the production environment does not have permission to execute the DDL
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
class PluginHubEntity(Model):
__tablename__ = "plugin_hub"
id = Column(
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
)
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
description = Column(String(255), nullable=False, comment="plugin description")
author = Column(String(255), nullable=True, comment="plugin author")
email = Column(String(255), nullable=True, comment="plugin author email")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
storage_channel = Column(String(255), comment="plugin storage channel")
storage_url = Column(String(255), comment="plugin download url")
download_param = Column(String(255), comment="plugin download param")
gmt_created = Column(
DateTime, default=datetime.utcnow, comment="plugin upload time"
)
installed = Column(Integer, default=False, comment="plugin already installed count")
UniqueConstraint("name", name="uk_name")
Index("idx_q_type", "type")
class PluginHubDao(BaseDao):
def add(self, engity: PluginHubEntity):
session = self.get_raw_session()
timezone = pytz.timezone("Asia/Shanghai")
plugin_hub = PluginHubEntity(
name=engity.name,
author=engity.author,
email=engity.email,
type=engity.type,
version=engity.version,
storage_channel=engity.storage_channel,
storage_url=engity.storage_url,
gmt_created=timezone.localize(datetime.now()),
)
session.add(plugin_hub)
session.commit()
id = plugin_hub.id
session.close()
return id
def raw_update(self, entity: PluginHubEntity):
session = self.get_raw_session()
try:
updated = session.merge(entity)
session.commit()
return updated.id
finally:
session.close()
def list(
self, query: PluginHubEntity, page=1, page_size=20
) -> list[PluginHubEntity]:
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
)
plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc())
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
result = plugin_hubs.all()
session.close()
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def get_by_storage_url(self, storage_url):
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
result = plugin_hubs.all()
session.close()
return result
def get_by_name(self, name: str) -> PluginHubEntity:
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
result = plugin_hubs.first()
session.close()
return result
def count(self, query: PluginHubEntity):
session = self.get_raw_session()
plugin_hubs = session.query(func.count(PluginHubEntity.id))
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
)
count = plugin_hubs.scalar()
session.close()
return count
def raw_delete(self, plugin_id: int):
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
plugin_hubs = session.query(PluginHubEntity)
if plugin_id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == plugin_id)
plugin_hubs.delete()
session.commit()
session.close()

View File

View File

View File

@@ -0,0 +1,208 @@
import json
import logging
import os
import glob
import shutil
from fastapi import UploadFile
from typing import Any
import tempfile
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
from dbgpt.agent.common.schema import PluginStorageType
from dbgpt.agent.plugin.plugins_util import scan_plugins, update_from_git
logger = logging.getLogger(__name__)
Default_User = "default"
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
TEMP_PLUGIN_PATH = ""
class AgentHub:
def __init__(self, plugin_dir) -> None:
self.hub_dao = PluginHubDao()
self.my_plugin_dao = MyPluginDao()
os.makedirs(plugin_dir, exist_ok=True)
self.plugin_dir = plugin_dir
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
def install_plugin(self, plugin_name: str, user_name: str = None):
logger.info(f"install_plugin {plugin_name}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
if plugin_entity:
if plugin_entity.storage_channel == PluginStorageType.Git.value:
try:
branch_name = None
authorization = None
if plugin_entity.download_param:
download_param = json.loads(plugin_entity.download_param)
branch_name = download_param.get("branch_name")
authorization = download_param.get("authorization")
file_name = self.__download_from_git(
plugin_entity.storage_url, branch_name, authorization
)
# add to my plugins and edit hub status
plugin_entity.installed = plugin_entity.installed + 1
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
user_name, plugin_name
)
if my_plugin_entity is None:
my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity.file_name = file_name
if user_name:
# TODO use user
my_plugin_entity.user_code = user_name
my_plugin_entity.user_name = user_name
my_plugin_entity.tenant = ""
else:
my_plugin_entity.user_code = Default_User
with self.hub_dao.session() as session:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
except Exception as e:
logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
else:
raise ValueError(
f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
)
else:
raise ValueError(f"Can't Find Plugin {plugin_name}!")
def uninstall_plugin(self, plugin_name, user):
logger.info(f"uninstall_plugin:{plugin_name},{user}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
if plugin_entity is not None:
plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.session() as session:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)
if plugin_entity is not None:
# delete package file if not use
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
have_installed = False
for plugin_info in plugin_infos:
if plugin_info.installed > 0:
have_installed = True
break
if not have_installed:
plugin_repo_name = (
plugin_entity.storage_url.replace(".git", "")
.strip("/")
.split("/")[-1]
)
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
for file in files:
os.remove(file)
else:
files = glob.glob(
os.path.join(self.plugin_dir, f"{my_plugin_entity.file_name}")
)
for file in files:
os.remove(file)
def __download_from_git(self, github_repo, branch_name, authorization):
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = hub_plugin.name
my_plugin_entity.type = hub_plugin.type
my_plugin_entity.version = hub_plugin.version
return my_plugin_entity
def refresh_hub_from_git(
self,
github_repo: str = None,
branch_name: str = "main",
authorization: str = None,
):
logger.info("refresh_hub_by_git start!")
update_from_git(
self.temp_hub_file_path, github_repo, branch_name, authorization
)
git_plugins = scan_plugins(self.temp_hub_file_path)
try:
for git_plugin in git_plugins:
old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
if old_hub_info:
plugin_hub_info = old_hub_info
else:
plugin_hub_info = PluginHubEntity()
plugin_hub_info.type = ""
plugin_hub_info.storage_channel = PluginStorageType.Git.value
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
plugin_hub_info.email = getattr(git_plugin, "_email", "")
download_param = {}
if branch_name:
download_param["branch_name"] = branch_name
if authorization and len(authorization) > 0:
download_param["authorization"] = authorization
plugin_hub_info.download_param = json.dumps(download_param)
plugin_hub_info.installed = 0
plugin_hub_info.name = git_plugin._name
plugin_hub_info.version = git_plugin._version
plugin_hub_info.description = git_plugin._description
self.hub_dao.raw_update(plugin_hub_info)
except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
# We can not move temp file in windows system when we open file in context of `with`
file_path = os.path.join(self.plugin_dir, doc_file.filename)
if os.path.exists(file_path):
os.remove(file_path)
tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir))
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(
tmp_path,
os.path.join(self.plugin_dir, doc_file.filename),
)
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
if user is None or len(user) <= 0:
user = Default_User
for my_plugin in my_plugins:
my_plugin_entiy = self.my_plugin_dao.get_by_user_and_plugin(
user, my_plugin._name
)
if my_plugin_entiy is None:
my_plugin_entiy = MyPluginEntity()
my_plugin_entiy.name = my_plugin._name
my_plugin_entiy.version = my_plugin._version
my_plugin_entiy.type = "Personal"
my_plugin_entiy.user_code = user
my_plugin_entiy.user_name = user
my_plugin_entiy.tenant = ""
my_plugin_entiy.file_name = doc_file.filename
self.my_plugin_dao.raw_update(my_plugin_entiy)
def reload_my_plugins(self):
logger.info(f"load_plugins start!")
return scan_plugins(self.plugin_dir)
def get_my_plugin(self, user: str):
logger.info(f"get_my_plugin:{user}")
if not user:
user = Default_User
return self.my_plugin_dao.get_by_user(user)

View File

@@ -0,0 +1,160 @@
import logging
from fastapi import (
APIRouter,
Body,
UploadFile,
File,
)
from abc import ABC
from typing import List
from dbgpt.app.openapi.api_view_model import (
Result,
)
from dbgpt.serve.agent.model import (
PluginHubParam,
PagenationFilter,
PagenationResult,
PluginHubFilter,
)
from dbgpt.serve.agent.hub.agent_hub import AgentHub
from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity
from dbgpt.agent.plugin.plugins_util import scan_plugins
from dbgpt.agent.plugin.generator import PluginPromptGenerator
from dbgpt.configs.model_config import PLUGINS_DIR
from dbgpt.component import BaseComponent, ComponentType, SystemApp
router = APIRouter()
logger = logging.getLogger(__name__)
class ModuleAgent(BaseComponent, ABC):
name = ComponentType.AGENT_HUB
def __init__(self):
# load plugins
self.plugins = scan_plugins(PLUGINS_DIR)
def init_app(self, system_app: SystemApp):
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
def refresh_plugins(self):
self.plugins = scan_plugins(PLUGINS_DIR)
def load_select_plugin(
self, generator: PluginPromptGenerator, select_plugins: List[str]
) -> PluginPromptGenerator:
logger.info(f"load_select_plugin:{select_plugins}")
# load select plugin
for plugin in self.plugins:
if plugin._name in select_plugins:
if not plugin.can_handle_post_prompt():
continue
generator = plugin.post_prompt(generator)
return generator
module_agent = ModuleAgent()
@router.post("/v1/agent/hub/update", response_model=Result[str])
async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
branch = (
update_param.branch
if update_param.branch is not None and len(update_param.branch) > 0
else "main"
)
authorization = (
update_param.authorization
if update_param.branch is not None and len(update_param.branch) > 0
else None
)
# TODO change it to async
agent_hub.refresh_hub_from_git(update_param.url, branch, authorization)
return Result.succ(None)
except Exception as e:
logger.error("Agent Hub Update Error!", e)
return Result.failed(code="E0020", msg=f"Agent Hub Update Error! {e}")
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
logger.info(f"get_agent_list:{filter.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
filter_enetity: PluginHubEntity = PluginHubEntity()
if filter.filter:
attrs = vars(filter.filter) # 获取原始对象的属性字典
for attr, value in attrs.items():
setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值
datas, total_pages, total_count = agent_hub.hub_dao.list(
filter_enetity, filter.page_index, filter.page_size
)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
result.page_index = filter.page_index
result.page_size = filter.page_size
result.total_page = total_pages
result.total_row_count = total_count
result.datas = datas
# print(json.dumps(result.to_dic()))
return Result.succ(result.to_dic())
@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user: str = None):
logger.info(f"my_agents:{user}")
agent_hub = AgentHub(PLUGINS_DIR)
agents = agent_hub.get_my_plugin(user)
agent_dicts = []
for agent in agents:
agent_dicts.append(agent.__dict__)
return Result.succ(agent_dicts)
@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name: str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.install_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.failed(code="E0021", msg=f"Plugin Install Error {e}")
@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name: str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.uninstall_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.failed(code="E0022", msg=f"Plugin Uninstall Error {e}")
@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
await agent_hub.upload_my_plugin(doc_file, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Upload Personal Plugin Error!", e)
return Result.failed(code="E0023", msg=f"Upload Personal Plugin Error {e}")

View File

@@ -0,0 +1,69 @@
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from typing import TypeVar, Generic, Any
from dbgpt._private.pydantic import BaseModel, Field
T = TypeVar("T")
class PagenationFilter(BaseModel, Generic[T]):
page_index: int = 1
page_size: int = 20
filter: T = None
class PagenationResult(BaseModel, Generic[T]):
page_index: int = 1
page_size: int = 20
total_page: int = 0
total_row_count: int = 0
datas: List[T] = []
def to_dic(self):
data_dicts = []
for item in self.datas:
data_dicts.append(item.__dict__)
return {
"page_index": self.page_index,
"page_size": self.page_size,
"total_page": self.total_page,
"total_row_count": self.total_row_count,
"datas": data_dicts,
}
@dataclass
class PluginHubFilter(BaseModel):
name: str
description: str
author: str
email: str
type: str
version: str
storage_channel: str
storage_url: str
@dataclass
class MyPluginFilter(BaseModel):
tenant: str
user_code: str
user_name: str
name: str
file_name: str
type: str
version: str
class PluginHubParam(BaseModel):
channel: Optional[str] = Field("git", description="Plugin storage channel")
url: Optional[str] = Field(
"https://github.com/eosphoros-ai/DB-GPT-Plugins.git",
description="Plugin storage url",
)
branch: Optional[str] = Field(
"main", description="github download branch", nullable=True
)
authorization: Optional[str] = Field(
None, description="github download authorization", nullable=True
)