mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-14 13:40:54 +00:00)

feat(agent): Multi agent sdk (#976)

Co-authored-by: xtyuns <xtyuns@163.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Co-authored-by: qidanrui <qidanrui@gmail.com>
0 dbgpt/serve/agent/agents/__init__.py (new file)
338 dbgpt/serve/agent/agents/controller.py (new file)
@@ -0,0 +1,338 @@
import logging
import json
import asyncio
import uuid
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from fastapi import (
    APIRouter,
    Body,
    UploadFile,
    File,
)
from fastapi.responses import StreamingResponse

from abc import ABC
from dbgpt.core.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG
from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator
from dbgpt.app.openapi.api_view_model import Result, ConversationVo
from dbgpt.util.json_utils import EnhancedJSONEncoder
from dbgpt.serve.agent.model import (
    PluginHubParam,
    PagenationFilter,
    PagenationResult,
    PluginHubFilter,
)

from dbgpt.agent.common.schema import Status
from dbgpt.agent.agents.agents_mange import AgentsMange, agent_mange
from dbgpt.agent.agents.planner_agent import PlannerAgent
from dbgpt.agent.agents.user_proxy_agent import UserProxyAgent
from dbgpt.agent.agents.plan_group_chat import PlanChat, PlanChatManager
from dbgpt.agent.agents.agent import AgentContext
from dbgpt.agent.memory.gpts_memory import GptsMemory, GptsMessage

from .db_gpts_memory import MetaDbGptsPlansMemory, MetaDbGptsMessageMemory

from ..db.gpts_mange_db import GptsInstanceDao, GptsInstanceEntity
from ..db.gpts_conversations_db import GptsConversationsDao, GptsConversationsEntity

from .dbgpts import DbGptsCompletion, DbGptsTaskStep, DbGptsMessage, DbGptsInstance
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt._private.config import Config
from dbgpt.model.cluster.controller.controller import BaseModelController
from dbgpt.model.cluster import WorkerManager, WorkerManagerFactory
from dbgpt.model.cluster.client import DefaultLLMClient

CFG = Config()

router = APIRouter()
logger = logging.getLogger(__name__)


class MultiAgents(BaseComponent, ABC):
    name = ComponentType.MULTI_AGENTS

    def init_app(self, system_app: SystemApp):
        system_app.app.include_router(router, prefix="/api", tags=["Multi-Agents"])

    def __init__(self):
        self.gpts_instance = GptsInstanceDao()
        self.gpts_conversations = GptsConversationsDao()
        self.memory = GptsMemory(
            plans_memory=MetaDbGptsPlansMemory(),
            message_memory=MetaDbGptsMessageMemory(),
        )

    def gpts_create(self, entity: GptsInstanceEntity):
        self.gpts_instance.add(entity)

    async def plan_chat(
        self,
        name: str,
        user_query: str,
        conv_id: str,
        user_code: str = None,
        sys_code: str = None,
    ):
        gpts_instance: GptsInstanceEntity = self.gpts_instance.get_by_name(name)
        if gpts_instance is None:
            raise ValueError(f"Can't find dbgpts: {name}")
        agents_names = json.loads(gpts_instance.gpts_agents)
        llm_models_priority = json.loads(gpts_instance.gpts_models)
        resource_db = (
            json.loads(gpts_instance.resource_db) if gpts_instance.resource_db else None
        )
        resource_knowledge = (
            json.loads(gpts_instance.resource_knowledge)
            if gpts_instance.resource_knowledge
            else None
        )
        resource_internet = (
            json.loads(gpts_instance.resource_internet)
            if gpts_instance.resource_internet
            else None
        )

        ### init chat param
        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        llm_task = DefaultLLMClient(worker_manager)
        context: AgentContext = AgentContext(conv_id=conv_id, llm_provider=llm_task)
        context.gpts_name = gpts_instance.gpts_name
        context.resource_db = resource_db
        context.resource_internet = resource_internet
        context.resource_knowledge = resource_knowledge
        context.agents = agents_names

        context.llm_models = await llm_task.models()
        context.model_priority = llm_models_priority

        agent_map = defaultdict()

        ### default plan execute mode
        agents = []
        for name in agents_names:
            cls = agent_mange.get_by_name(name)
            agent = cls(
                agent_context=context,
                memory=self.memory,
            )
            agents.append(agent)
            agent_map[name] = agent

        groupchat = PlanChat(agents=agents, messages=[], max_round=50)
        planner = PlannerAgent(
            agent_context=context,
            memory=self.memory,
            plan_chat=groupchat,
        )
        agent_map[planner.name] = planner

        manager = PlanChatManager(
            agent_context=context,
            memory=self.memory,
            plan_chat=groupchat,
            planner=planner,
        )
        agent_map[manager.name] = manager

        user_proxy = UserProxyAgent(memory=self.memory, agent_context=context)
        agent_map[user_proxy.name] = user_proxy

        gpts_conversation = self.gpts_conversations.get_by_conv_id(conv_id)
        if gpts_conversation is None:
            self.gpts_conversations.add(
                GptsConversationsEntity(
                    conv_id=conv_id,
                    user_goal=user_query,
                    gpts_name=gpts_instance.gpts_name,
                    state=Status.RUNNING.value,
                    max_auto_reply_round=context.max_chat_round,
                    auto_reply_count=0,
                    user_code=user_code,
                    sys_code=sys_code,
                )
            )

            ## dbgpts conversation save
            try:
                await user_proxy.a_initiate_chat(
                    recipient=manager,
                    message=user_query,
                    memory=self.memory,
                )
            except Exception as e:
                logger.error(f"chat abnormal termination: {str(e)}", exc_info=True)
                self.gpts_conversations.update(conv_id, Status.FAILED.value)
                # return here so a failed chat is not marked COMPLETE below
                return conv_id

        else:
            # retry chat
            self.gpts_conversations.update(conv_id, Status.RUNNING.value)
            try:
                await user_proxy.a_retry_chat(
                    recipient=manager,
                    agent_map=agent_map,
                    memory=self.memory,
                )
            except Exception as e:
                logger.error(f"chat abnormal termination: {str(e)}", exc_info=True)
                self.gpts_conversations.update(conv_id, Status.FAILED.value)
                # return here so a failed chat is not marked COMPLETE below
                return conv_id

        self.gpts_conversations.update(conv_id, Status.COMPLETE.value)
        return conv_id

    async def chat_completions(
        self, conv_id: str, user_code: str = None, system_app: str = None
    ):
        is_complete = False
        while True:
            gpts_conv = self.gpts_conversations.get_by_conv_id(conv_id)
            if gpts_conv:
                is_complete = gpts_conv.state in [
                    Status.COMPLETE.value,
                    Status.WAITING.value,
                    Status.FAILED.value,
                ]
            yield await self.memory.one_plan_chat_competions(conv_id)
            if is_complete:
                return
            else:
                await asyncio.sleep(5)

    async def stable_message(
        self, conv_id: str, user_code: str = None, system_app: str = None
    ):
        gpts_conv = self.gpts_conversations.get_by_conv_id(conv_id)
        if gpts_conv:
            is_complete = gpts_conv.state in [
                Status.COMPLETE.value,
                Status.WAITING.value,
                Status.FAILED.value,
            ]
            if is_complete:
                return await self.memory.one_plan_chat_competions(conv_id)
            else:
                raise ValueError(
                    "The conversation has not been completed yet, so we cannot directly obtain information."
                )
        else:
            raise ValueError("No conversation record found!")

    def gpts_conv_list(self, user_code: str = None, system_app: str = None):
        return self.gpts_conversations.get_convs(user_code, system_app)


multi_agents = MultiAgents()


@router.post("/v1/dbbgpts/agents/list", response_model=Result[str])
async def agents_list():
    logger.info("agents_list!")
    try:
        agents = agent_mange.all_agents()
        return Result.succ(agents)
    except Exception as e:
        return Result.failed(code="E30001", msg=str(e))


@router.post("/v1/dbbgpts/create", response_model=Result[str])
async def create_dbgpts(gpts_instance: DbGptsInstance = Body()):
    logger.info(f"create_dbgpts:{gpts_instance}")
    try:
        multi_agents.gpts_create(
            GptsInstanceEntity(
                gpts_name=gpts_instance.gpts_name,
                gpts_describe=gpts_instance.gpts_describe,
                resource_db=json.dumps(gpts_instance.resource_db.to_dict())
                if gpts_instance.resource_db
                else None,
                resource_internet=json.dumps(gpts_instance.resource_internet.to_dict())
                if gpts_instance.resource_internet
                else None,
                resource_knowledge=json.dumps(
                    gpts_instance.resource_knowledge.to_dict()
                )
                if gpts_instance.resource_knowledge
                else None,
                gpts_agents=json.dumps(gpts_instance.gpts_agents),
                gpts_models=json.dumps(gpts_instance.gpts_models),
                language=gpts_instance.language,
                user_code=gpts_instance.user_code,
                sys_code=gpts_instance.sys_code,
            )
        )
        return Result.succ(None)
    except Exception as e:
        logger.error(f"create_dbgpts failed: {str(e)}")
        return Result.failed(msg=str(e), code="E300002")


async def stream_generator(conv_id: str):
    async for chunk in multi_agents.chat_completions(conv_id):
        if chunk:
            yield f"data: {chunk}\n\n"


@router.post("/v1/dbbgpts/chat/plan/completions", response_model=Result[str])
async def dgpts_completions(
    gpts_name: str,
    user_query: str,
    conv_id: str = None,
    user_code: str = None,
    sys_code: str = None,
):
    logger.info(f"dgpts_completions:{gpts_name},{user_query},{conv_id}")
    if conv_id is None:
        conv_id = str(uuid.uuid1())
    asyncio.create_task(
        multi_agents.plan_chat(gpts_name, user_query, conv_id, user_code, sys_code)
    )

    headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    return StreamingResponse(
        stream_generator(conv_id),
        headers=headers,
        media_type="text/plain",
    )


@router.post("/v1/dbbgpts/plan/chat/cancel", response_model=Result[str])
async def dgpts_plan_chat_cancel(
    conv_id: str = None, user_code: str = None, sys_code: str = None
):
    pass


@router.get("/v1/dbbgpts/chat/plan/messages", response_model=Result[str])
async def plan_chat_messages(conv_id: str, user_code: str = None, sys_code: str = None):
    logger.info(f"plan_chat_messages:{conv_id},{user_code},{sys_code}")

    headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    return StreamingResponse(
        stream_generator(conv_id),
        headers=headers,
        media_type="text/plain",
    )


@router.post("/v1/dbbgpts/chat/feedback", response_model=Result[str])
async def dgpts_chat_feedback(filter: PagenationFilter[PluginHubFilter] = Body()):
    pass
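For orientation, here is a minimal client-side sketch of how these endpoints could be exercised end to end: first register a dbgpts app, then start a plan chat and consume the SSE stream. The paths and payload field names come from the handlers above; the host/port, the concrete payload values, and the shape of the streamed chunks are assumptions.

    # Hypothetical client sketch (assumes the DB-GPT server listens on localhost:5000).
    import requests

    BASE = "http://localhost:5000/api"

    # 1. Register a dbgpts app; the JSON body mirrors the DbGptsInstance dataclass.
    requests.post(
        f"{BASE}/v1/dbbgpts/create",
        json={
            "gpts_name": "sales_analyst",
            "gpts_describe": "Answer questions over the sales database",
            "gpts_agents": ["DataScientist"],
            "gpts_models": {"DataScientist": ["vicuna-13b-v1.5"]},
            "language": "en",
        },
    )

    # 2. Start a plan chat; the parameters are query params and the reply is an SSE stream.
    with requests.post(
        f"{BASE}/v1/dbbgpts/chat/plan/completions",
        params={"gpts_name": "sales_analyst", "user_query": "Top 10 customers by revenue?"},
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if line and line.startswith("data: "):
                print(line[len("data: "):])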
121 dbgpt/serve/agent/agents/db_gpts_memory.py (new file)
@@ -0,0 +1,121 @@
from typing import List, Optional
from dbgpt.agent.memory.gpts_memory import (
    GptsPlansMemory,
    GptsPlan,
    GptsMessageMemory,
    GptsMessage,
)

from ..db.gpts_plans_db import GptsPlansEntity, GptsPlansDao
from ..db.gpts_messages_db import GptsMessagesDao, GptsMessagesEntity


class MetaDbGptsPlansMemory(GptsPlansMemory):
    def __init__(self):
        self.gpts_plan = GptsPlansDao()

    def batch_save(self, plans: list[GptsPlan]):
        self.gpts_plan.batch_save([item.to_dict() for item in plans])

    def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
        db_results: List[GptsPlansEntity] = self.gpts_plan.get_by_conv_id(
            conv_id=conv_id
        )
        results = []
        for item in db_results:
            results.append(GptsPlan.from_dict(item.__dict__))
        return results

    def get_by_conv_id_and_num(
        self, conv_id: str, task_nums: List[int]
    ) -> List[GptsPlan]:
        db_results: List[GptsPlansEntity] = self.gpts_plan.get_by_conv_id_and_num(
            conv_id=conv_id, task_nums=task_nums
        )
        results = []
        for item in db_results:
            results.append(GptsPlan.from_dict(item.__dict__))
        return results

    def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
        db_results: List[GptsPlansEntity] = self.gpts_plan.get_todo_plans(
            conv_id=conv_id
        )
        results = []
        for item in db_results:
            results.append(GptsPlan.from_dict(item.__dict__))
        return results

    def complete_task(self, conv_id: str, task_num: int, result: str):
        self.gpts_plan.complete_task(conv_id=conv_id, task_num=task_num, result=result)

    def update_task(
        self,
        conv_id: str,
        task_num: int,
        state: str,
        retry_times: int,
        agent: str = None,
        model: str = None,
        result: str = None,
    ):
        self.gpts_plan.update_task(
            conv_id=conv_id,
            task_num=task_num,
            state=state,
            retry_times=retry_times,
            agent=agent,
            model=model,
            result=result,
        )

    def remove_by_conv_id(self, conv_id: str):
        self.gpts_plan.remove_by_conv_id(conv_id=conv_id)


class MetaDbGptsMessageMemory(GptsMessageMemory):
    def __init__(self):
        self.gpts_message = GptsMessagesDao()

    def append(self, message: GptsMessage):
        self.gpts_message.append(message.to_dict())

    def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
        db_results = self.gpts_message.get_by_agent(conv_id, agent)
        results = []
        db_results = sorted(db_results, key=lambda x: x.rounds)
        for item in db_results:
            results.append(GptsMessage.from_dict(item.__dict__))
        return results

    def get_between_agents(
        self,
        conv_id: str,
        agent1: str,
        agent2: str,
        current_gogal: Optional[str] = None,
    ) -> Optional[List[GptsMessage]]:
        db_results = self.gpts_message.get_between_agents(
            conv_id, agent1, agent2, current_gogal
        )
        results = []
        db_results = sorted(db_results, key=lambda x: x.rounds)
        for item in db_results:
            results.append(GptsMessage.from_dict(item.__dict__))
        return results

    def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessage]]:
        db_results = self.gpts_message.get_by_conv_id(conv_id)

        results = []
        db_results = sorted(db_results, key=lambda x: x.rounds)
        for item in db_results:
            results.append(GptsMessage.from_dict(item.__dict__))
        return results

    def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
        db_result = self.gpts_message.get_last_message(conv_id)
        if db_result:
            return GptsMessage.from_dict(db_result.__dict__)
        else:
            return None
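As a quick illustration, the two adapters above are the drop-in database backends for GptsMemory, the same wiring that MultiAgents.__init__ performs in controller.py. A minimal sketch, assuming a configured metadata database; the conv_id value is a placeholder.

    # Sketch: wiring the meta-DB adapters into GptsMemory and reading back state.
    from dbgpt.agent.memory.gpts_memory import GptsMemory

    plans_memory = MetaDbGptsPlansMemory()
    message_memory = MetaDbGptsMessageMemory()
    memory = GptsMemory(plans_memory=plans_memory, message_memory=message_memory)

    # Fetch persisted plans and the latest message for a (hypothetical) conversation.
    plans = plans_memory.get_by_conv_id("conv-123")
    last_message = message_memory.get_last_message("conv-123")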
93 dbgpt/serve/agent/agents/dbgpts.py (new file)
@@ -0,0 +1,93 @@
from __future__ import annotations

from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from dataclasses import dataclass, asdict, fields
import dataclasses

from dbgpt.agent.agents.agent import AgentResource


class AgentMode(Enum):
    PLAN_EXCUTE = "plan_excute"


@dataclass
class DbGptsInstance:
    gpts_name: str
    gpts_describe: str
    gpts_agents: list[str]
    resource_db: Optional[AgentResource] = None
    resource_internet: Optional[AgentResource] = None
    resource_knowledge: Optional[AgentResource] = None
    gpts_models: Optional[Dict[str, List[str]]] = None
    language: str = "en"
    user_code: str = None
    sys_code: str = None

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)


@dataclass
class DbGptsMessage:
    sender: str
    receiver: str
    content: str
    action_report: str

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> DbGptsMessage:
        return DbGptsMessage(
            sender=d["sender"],
            receiver=d["receiver"],
            content=d["content"],
            action_report=d["action_report"],
        )

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)


@dataclass
class DbGptsTaskStep:
    task_num: str
    task_content: str
    state: str
    result: str
    agent_name: str
    model_name: str

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> DbGptsTaskStep:
        return DbGptsTaskStep(
            task_num=d["task_num"],
            task_content=d["task_content"],
            state=d["state"],
            result=d["result"],
            agent_name=d["agent_name"],
            model_name=d["model_name"],
        )

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)


@dataclass
class DbGptsCompletion:
    conv_id: str
    task_steps: Optional[List[DbGptsTaskStep]]
    messages: Optional[List[DbGptsMessage]]

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> DbGptsCompletion:
        return DbGptsCompletion(
            conv_id=d.get("conv_id"),
            task_steps=[DbGptsTaskStep.from_dict(step) for step in d["task_steps"]],
            messages=[DbGptsMessage.from_dict(msg) for msg in d["messages"]],
        )

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)
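A small round-trip sketch of the dataclasses above, showing the shape that the /v1/dbbgpts/create endpoint serializes and persists; the field values here are made up.

    # Sketch: building a DbGptsInstance and serializing it the way the create
    # endpoint does (the resource fields are optional and left as None here).
    instance = DbGptsInstance(
        gpts_name="sales_analyst",
        gpts_describe="Answer questions over the sales database",
        gpts_agents=["DataScientist", "Reporter"],
        gpts_models={"DataScientist": ["vicuna-13b-v1.5"]},
    )
    payload = instance.to_dict()
    assert payload["gpts_agents"] == ["DataScientist", "Reporter"]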
6 dbgpt/serve/agent/db/__init__.py (new file)
@@ -0,0 +1,6 @@
from .gpts_conversations_db import GptsConversationsDao, GptsConversationsEntity
from .gpts_mange_db import GptsInstanceDao, GptsInstanceEntity
from .gpts_messages_db import GptsMessagesDao, GptsMessagesEntity
from .gpts_plans_db import GptsPlansDao, GptsPlansEntity
from .my_plugin_db import MyPluginDao, MyPluginEntity
from .plugin_hub_db import PluginHubDao, PluginHubEntity
95 dbgpt/serve/agent/db/gpts_conversations_db.py (new file)
@@ -0,0 +1,95 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Text, desc
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao, Model


class GptsConversationsEntity(Model):
    __tablename__ = "gpts_conversations"
    __table_args__ = (
        UniqueConstraint("conv_id", name="uk_gpts_conversations"),
        Index("idx_gpts_name", "gpts_name"),
        {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
        },
    )

    id = Column(Integer, primary_key=True, comment="autoincrement id")

    conv_id = Column(
        String(255), nullable=False, comment="The unique id of the conversation record"
    )
    user_goal = Column(Text, nullable=False, comment="User's goals content")

    gpts_name = Column(String(255), nullable=False, comment="The gpts name")
    state = Column(String(255), nullable=True, comment="The gpts state")

    max_auto_reply_round = Column(
        Integer, nullable=False, comment="max auto reply round"
    )
    auto_reply_count = Column(Integer, nullable=False, comment="auto reply count")

    user_code = Column(String(255), nullable=True, comment="user code")
    sys_code = Column(String(255), nullable=True, comment="system app code")

    created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        comment="last update time",
    )


class GptsConversationsDao(BaseDao):
    def add(self, entity: GptsConversationsEntity):
        session = self.get_raw_session()
        session.add(entity)
        session.commit()
        id = entity.id
        session.close()
        return id

    def get_by_conv_id(self, conv_id: str):
        session = self.get_raw_session()
        gpts_conv = session.query(GptsConversationsEntity)
        if conv_id:
            gpts_conv = gpts_conv.filter(GptsConversationsEntity.conv_id == conv_id)
        result = gpts_conv.first()
        session.close()
        return result

    def get_convs(self, user_code: str = None, system_app: str = None):
        session = self.get_raw_session()
        gpts_conversations = session.query(GptsConversationsEntity)
        if user_code:
            gpts_conversations = gpts_conversations.filter(
                GptsConversationsEntity.user_code == user_code
            )
        if system_app:
            gpts_conversations = gpts_conversations.filter(
                GptsConversationsEntity.sys_code == system_app
            )

        result = (
            gpts_conversations.order_by(desc(GptsConversationsEntity.id))
            .limit(20)
            .all()
        )
        session.close()
        return result

    def update(self, conv_id: str, state: str):
        session = self.get_raw_session()
        gpts_convs = session.query(GptsConversationsEntity)
        gpts_convs = gpts_convs.filter(GptsConversationsEntity.conv_id == conv_id)
        gpts_convs.update(
            {GptsConversationsEntity.state: state}, synchronize_session="fetch"
        )
        session.commit()
        session.close()
78 dbgpt/serve/agent/db/gpts_mange_db.py (new file)
@@ -0,0 +1,78 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Text, Boolean
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao, Model


class GptsInstanceEntity(Model):
    __tablename__ = "gpts_instance"
    __table_args__ = (
        UniqueConstraint("gpts_name", name="uk_gpts"),
        {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
        },
    )
    id = Column(Integer, primary_key=True, comment="autoincrement id")

    gpts_name = Column(String(255), nullable=False, comment="Current AI assistant name")
    gpts_describe = Column(
        String(2255), nullable=False, comment="Current AI assistant description"
    )
    resource_db = Column(
        Text,
        nullable=True,
        comment="List of structured database names contained in the current gpts",
    )
    resource_internet = Column(
        Text,
        nullable=True,
        comment="Whether the current gpts can retrieve information from the internet",
    )
    resource_knowledge = Column(
        Text,
        nullable=True,
        comment="List of unstructured database names contained in the current gpts",
    )
    gpts_agents = Column(
        String(1000),
        nullable=True,
        comment="List of agent names contained in the current gpts",
    )
    gpts_models = Column(
        String(1000),
        nullable=True,
        comment="List of llm model names contained in the current gpts",
    )
    language = Column(String(100), nullable=True, comment="gpts language")

    user_code = Column(String(255), nullable=False, comment="user code")
    sys_code = Column(String(255), nullable=True, comment="system app code")

    created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        comment="last update time",
    )


class GptsInstanceDao(BaseDao):
    def add(self, entity: GptsInstanceEntity):
        session = self.get_raw_session()
        session.add(entity)
        session.commit()
        id = entity.id
        session.close()
        return id

    def get_by_name(self, name: str) -> GptsInstanceEntity:
        session = self.get_raw_session()
        gpts_instance = session.query(GptsInstanceEntity)
        if name:
            gpts_instance = gpts_instance.filter(GptsInstanceEntity.gpts_name == name)
        result = gpts_instance.first()
        session.close()
        return result
160 dbgpt/serve/agent/db/gpts_messages_db.py (new file)
@@ -0,0 +1,160 @@
from datetime import datetime
from typing import List, Optional
from sqlalchemy import (
    Column,
    Integer,
    String,
    Index,
    DateTime,
    func,
    Text,
    or_,
    and_,
    desc,
)
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao, Model


class GptsMessagesEntity(Model):
    __tablename__ = "gpts_messages"
    __table_args__ = (
        Index("idx_q_messages", "conv_id", "rounds", "sender"),
        {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
        },
    )
    id = Column(Integer, primary_key=True, comment="autoincrement id")

    conv_id = Column(
        String(255), nullable=False, comment="The unique id of the conversation record"
    )
    sender = Column(
        String(255),
        nullable=False,
        comment="Who is speaking in the current conversation turn",
    )
    receiver = Column(
        String(255),
        nullable=False,
        comment="Who receives the message in the current conversation turn",
    )
    model_name = Column(String(255), nullable=True, comment="message generate model")
    rounds = Column(Integer, nullable=False, comment="dialogue turns")
    content = Column(Text, nullable=True, comment="Content of the speech")
    current_gogal = Column(
        Text, nullable=True, comment="The target corresponding to the current message"
    )
    context = Column(Text, nullable=True, comment="Current conversation context")
    review_info = Column(
        Text, nullable=True, comment="Current conversation review info"
    )
    action_report = Column(
        Text, nullable=True, comment="Current conversation action report"
    )

    role = Column(
        String(255), nullable=True, comment="The role of the current message content"
    )

    created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        comment="last update time",
    )


class GptsMessagesDao(BaseDao):
    def append(self, entity: dict):
        session = self.get_raw_session()
        message = GptsMessagesEntity(
            conv_id=entity.get("conv_id"),
            sender=entity.get("sender"),
            receiver=entity.get("receiver"),
            content=entity.get("content"),
            role=entity.get("role", None),
            model_name=entity.get("model_name", None),
            context=entity.get("context", None),
            rounds=entity.get("rounds", None),
            current_gogal=entity.get("current_gogal", None),
            review_info=entity.get("review_info", None),
            action_report=entity.get("action_report", None),
        )
        session.add(message)
        session.commit()
        id = message.id
        session.close()
        return id

    def get_by_agent(
        self, conv_id: str, agent: str
    ) -> Optional[List[GptsMessagesEntity]]:
        session = self.get_raw_session()
        gpts_messages = session.query(GptsMessagesEntity)
        if agent:
            gpts_messages = gpts_messages.filter(
                GptsMessagesEntity.conv_id == conv_id
            ).filter(
                or_(
                    GptsMessagesEntity.sender == agent,
                    GptsMessagesEntity.receiver == agent,
                )
            )
        result = gpts_messages.order_by(GptsMessagesEntity.rounds).all()
        session.close()
        return result

    def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessagesEntity]]:
        session = self.get_raw_session()
        gpts_messages = session.query(GptsMessagesEntity)
        if conv_id:
            gpts_messages = gpts_messages.filter(GptsMessagesEntity.conv_id == conv_id)
        result = gpts_messages.order_by(GptsMessagesEntity.rounds).all()
        session.close()
        return result

    def get_between_agents(
        self,
        conv_id: str,
        agent1: str,
        agent2: str,
        current_gogal: Optional[str] = None,
    ) -> Optional[List[GptsMessagesEntity]]:
        session = self.get_raw_session()
        gpts_messages = session.query(GptsMessagesEntity)
        if agent1 and agent2:
            gpts_messages = gpts_messages.filter(
                GptsMessagesEntity.conv_id == conv_id
            ).filter(
                or_(
                    and_(
                        GptsMessagesEntity.sender == agent1,
                        GptsMessagesEntity.receiver == agent2,
                    ),
                    and_(
                        GptsMessagesEntity.sender == agent2,
                        GptsMessagesEntity.receiver == agent1,
                    ),
                )
            )
        if current_gogal:
            gpts_messages = gpts_messages.filter(
                GptsMessagesEntity.current_gogal == current_gogal
            )
        result = gpts_messages.order_by(GptsMessagesEntity.rounds).all()
        session.close()
        return result

    def get_last_message(self, conv_id: str) -> Optional[GptsMessagesEntity]:
        session = self.get_raw_session()
        gpts_messages = session.query(GptsMessagesEntity)
        if conv_id:
            gpts_messages = gpts_messages.filter(
                GptsMessagesEntity.conv_id == conv_id
            ).order_by(desc(GptsMessagesEntity.rounds))

        result = gpts_messages.first()
        session.close()
        return result
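To make the query surface concrete, here is a short sketch of how this DAO might be used to replay one side-conversation; the conversation id, agent names, and goal string are placeholders.

    # Sketch: fetching the ordered exchange between two agents for one sub-goal.
    dao = GptsMessagesDao()
    exchange = dao.get_between_agents(
        conv_id="conv-123",
        agent1="Planner",
        agent2="DataScientist",
        current_gogal="Step 1: profile the sales table",
    )
    for msg in exchange:
        print(msg.rounds, msg.sender, "->", msg.receiver)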
156 dbgpt/serve/agent/db/gpts_plans_db.py (new file)
@@ -0,0 +1,156 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Text
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao, Model
from dbgpt.agent.common.schema import Status


class GptsPlansEntity(Model):
    __tablename__ = "gpts_plans"
    __table_args__ = (
        UniqueConstraint("conv_id", "sub_task_num", name="uk_sub_task"),
        {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
        },
    )
    id = Column(Integer, primary_key=True, comment="autoincrement id")

    conv_id = Column(
        String(255), nullable=False, comment="The unique id of the conversation record"
    )
    sub_task_num = Column(Integer, nullable=False, comment="Subtask number")
    sub_task_title = Column(String(255), nullable=False, comment="subtask title")
    sub_task_content = Column(Text, nullable=False, comment="subtask content")
    sub_task_agent = Column(
        String(255), nullable=True, comment="Available agents corresponding to subtasks"
    )
    resource_name = Column(String(255), nullable=True, comment="resource name")
    rely = Column(
        String(255), nullable=True, comment="Subtask dependencies, like: 1,2,3"
    )

    agent_model = Column(
        String(255),
        nullable=True,
        comment="LLM model used by subtask processing agents",
    )
    retry_times = Column(Integer, default=0, comment="number of retries")
    max_retry_times = Column(
        Integer, default=0, comment="Maximum number of retries"
    )
    state = Column(String(255), nullable=True, comment="subtask status")
    result = Column(Text(length=2**31 - 1), nullable=True, comment="subtask result")

    created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        comment="last update time",
    )


class GptsPlansDao(BaseDao):
    def batch_save(self, plans: list[dict]):
        session = self.get_raw_session()
        session.bulk_insert_mappings(GptsPlansEntity, plans)
        session.commit()
        session.close()

    def get_by_conv_id(self, conv_id: str) -> list[GptsPlansEntity]:
        session = self.get_raw_session()
        gpts_plans = session.query(GptsPlansEntity)
        if conv_id:
            gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id)
        result = gpts_plans.all()
        session.close()
        return result

    def get_by_task_id(self, task_id: int) -> GptsPlansEntity:
        session = self.get_raw_session()
        gpts_plans = session.query(GptsPlansEntity)
        if task_id:
            gpts_plans = gpts_plans.filter(GptsPlansEntity.id == task_id)
        result = gpts_plans.first()
        session.close()
        return result

    def get_by_conv_id_and_num(
        self, conv_id: str, task_nums: list
    ) -> list[GptsPlansEntity]:
        session = self.get_raw_session()
        gpts_plans = session.query(GptsPlansEntity)
        if conv_id:
            gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).filter(
                GptsPlansEntity.sub_task_num.in_(task_nums)
            )
        result = gpts_plans.all()
        session.close()
        return result

    def get_todo_plans(self, conv_id: str) -> list[GptsPlansEntity]:
        if not conv_id:
            return []
        session = self.get_raw_session()
        gpts_plans = session.query(GptsPlansEntity)
        gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).filter(
            GptsPlansEntity.state.in_([Status.TODO.value, Status.RETRYING.value])
        )
        result = gpts_plans.order_by(GptsPlansEntity.sub_task_num).all()
        session.close()
        return result

    def complete_task(self, conv_id: str, task_num: int, result: str):
        session = self.get_raw_session()
        gpts_plans = session.query(GptsPlansEntity)
        gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).filter(
            GptsPlansEntity.sub_task_num == task_num
        )
        gpts_plans.update(
            {
                GptsPlansEntity.state: Status.COMPLETE.value,
                GptsPlansEntity.result: result,
            },
            synchronize_session="fetch",
        )
        session.commit()
        session.close()

    def update_task(
        self,
        conv_id: str,
        task_num: int,
        state: str,
        retry_times: int,
        agent: str = None,
        model: str = None,
        result: str = None,
    ):
        session = self.get_raw_session()
        gpts_plans = session.query(GptsPlansEntity)
        gpts_plans = gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).filter(
            GptsPlansEntity.sub_task_num == task_num
        )
        update_param = {}
        update_param[GptsPlansEntity.state] = state
        update_param[GptsPlansEntity.retry_times] = retry_times
        update_param[GptsPlansEntity.result] = result
        if agent:
            update_param[GptsPlansEntity.sub_task_agent] = agent
        if model:
            update_param[GptsPlansEntity.agent_model] = model

        gpts_plans.update(update_param, synchronize_session="fetch")
        session.commit()
        session.close()

    def remove_by_conv_id(self, conv_id: str):
        session = self.get_raw_session()
        if conv_id is None:
            raise Exception("conv_id is None")

        gpts_plans = session.query(GptsPlansEntity)
        gpts_plans.filter(GptsPlansEntity.conv_id == conv_id).delete()
        session.commit()
        session.close()
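A lifecycle sketch for this DAO: persist a plan row produced by the planner, pick up TODO steps, and mark one complete. The dict keys mirror the GptsPlansEntity columns above; the values are placeholders.

    # Sketch: typical write path for plan rows during one conversation.
    dao = GptsPlansDao()
    dao.batch_save([
        {
            "conv_id": "conv-123",
            "sub_task_num": 1,
            "sub_task_title": "Profile data",
            "sub_task_content": "Inspect the sales table schema",
            "state": Status.TODO.value,
            "retry_times": 0,
            "max_retry_times": 3,
        }
    ])
    for plan in dao.get_todo_plans("conv-123"):
        dao.complete_task("conv-123", plan.sub_task_num, result="schema: ...")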
137 dbgpt/serve/agent/db/my_plugin_db.py (new file)
@@ -0,0 +1,137 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao, Model


class MyPluginEntity(Model):
    __tablename__ = "my_plugin"
    __table_args__ = (UniqueConstraint("user_code", "name", name="uk_name"),)
    id = Column(Integer, primary_key=True, comment="autoincrement id")
    tenant = Column(String(255), nullable=True, comment="user's tenant")
    user_code = Column(String(255), nullable=False, comment="user code")
    user_name = Column(String(255), nullable=True, comment="user name")
    name = Column(String(255), unique=True, nullable=False, comment="plugin name")
    file_name = Column(String(255), nullable=False, comment="plugin package file name")
    type = Column(String(255), comment="plugin type")
    version = Column(String(255), comment="plugin version")
    use_count = Column(
        Integer, nullable=True, default=0, comment="plugin total use count"
    )
    succ_count = Column(
        Integer, nullable=True, default=0, comment="plugin total success count"
    )
    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
    gmt_created = Column(
        DateTime, default=datetime.utcnow, comment="plugin install time"
    )


class MyPluginDao(BaseDao):
    def add(self, entity: MyPluginEntity):
        session = self.get_raw_session()
        my_plugin = MyPluginEntity(
            tenant=entity.tenant,
            user_code=entity.user_code,
            user_name=entity.user_name,
            name=entity.name,
            file_name=entity.file_name,
            type=entity.type,
            version=entity.version,
            use_count=entity.use_count or 0,
            succ_count=entity.succ_count or 0,
            sys_code=entity.sys_code,
            gmt_created=datetime.now(),
        )
        session.add(my_plugin)
        session.commit()
        id = my_plugin.id
        session.close()
        return id

    def raw_update(self, entity: MyPluginEntity):
        session = self.get_raw_session()
        try:
            updated = session.merge(entity)
            session.commit()
            return updated.id
        finally:
            session.close()

    def get_by_user(self, user: str) -> list[MyPluginEntity]:
        session = self.get_raw_session()
        my_plugins = session.query(MyPluginEntity)
        if user:
            my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
        result = my_plugins.all()
        session.close()
        return result

    def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
        session = self.get_raw_session()
        my_plugins = session.query(MyPluginEntity)
        if user:
            my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
        my_plugins = my_plugins.filter(MyPluginEntity.name == plugin)
        result = my_plugins.first()
        session.close()
        return result

    def list(
        self, query: MyPluginEntity, page=1, page_size=20
    ) -> tuple[list[MyPluginEntity], int, int]:
        session = self.get_raw_session()
        my_plugins = session.query(MyPluginEntity)
        all_count = my_plugins.count()
        if query.id is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
        if query.name is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
        if query.tenant is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
        if query.type is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
        if query.user_code is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
        if query.user_name is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
        if query.sys_code is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)

        my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
        my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
        result = my_plugins.all()
        session.close()
        total_pages = all_count // page_size
        if all_count % page_size != 0:
            total_pages += 1

        return result, total_pages, all_count

    def count(self, query: MyPluginEntity):
        session = self.get_raw_session()
        my_plugins = session.query(func.count(MyPluginEntity.id))
        if query.id is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
        if query.name is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
        if query.type is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
        if query.tenant is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
        if query.user_code is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
        if query.user_name is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
        if query.sys_code is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
        count = my_plugins.scalar()
        session.close()
        return count

    def raw_delete(self, plugin_id: int):
        session = self.get_raw_session()
        if plugin_id is None:
            raise Exception("plugin_id is None")
        query = MyPluginEntity(id=plugin_id)
        my_plugins = session.query(MyPluginEntity)
        if query.id is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
        my_plugins.delete()
        session.commit()
        session.close()
139 dbgpt/serve/agent/db/plugin_hub_db.py (new file)
@@ -0,0 +1,139 @@
from datetime import datetime
import pytz
from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
from sqlalchemy import UniqueConstraint

from dbgpt.storage.metadata import BaseDao, Model

# TODO We should consider that the production environment may not have
# permission to execute DDL statements.
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")


class PluginHubEntity(Model):
    __tablename__ = "plugin_hub"
    __table_args__ = (
        UniqueConstraint("name", name="uk_name"),
        Index("idx_q_type", "type"),
    )
    id = Column(
        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
    )
    name = Column(String(255), unique=True, nullable=False, comment="plugin name")
    description = Column(String(255), nullable=False, comment="plugin description")
    author = Column(String(255), nullable=True, comment="plugin author")
    email = Column(String(255), nullable=True, comment="plugin author email")
    type = Column(String(255), comment="plugin type")
    version = Column(String(255), comment="plugin version")
    storage_channel = Column(String(255), comment="plugin storage channel")
    storage_url = Column(String(255), comment="plugin download url")
    download_param = Column(String(255), comment="plugin download param")
    gmt_created = Column(
        DateTime, default=datetime.utcnow, comment="plugin upload time"
    )
    installed = Column(Integer, default=0, comment="plugin already installed count")


class PluginHubDao(BaseDao):
    def add(self, entity: PluginHubEntity):
        session = self.get_raw_session()
        timezone = pytz.timezone("Asia/Shanghai")
        plugin_hub = PluginHubEntity(
            name=entity.name,
            description=entity.description,
            author=entity.author,
            email=entity.email,
            type=entity.type,
            version=entity.version,
            storage_channel=entity.storage_channel,
            storage_url=entity.storage_url,
            gmt_created=timezone.localize(datetime.now()),
        )
        session.add(plugin_hub)
        session.commit()
        id = plugin_hub.id
        session.close()
        return id

    def raw_update(self, entity: PluginHubEntity):
        session = self.get_raw_session()
        try:
            updated = session.merge(entity)
            session.commit()
            return updated.id
        finally:
            session.close()

    def list(
        self, query: PluginHubEntity, page=1, page_size=20
    ) -> tuple[list[PluginHubEntity], int, int]:
        session = self.get_raw_session()
        plugin_hubs = session.query(PluginHubEntity)
        all_count = plugin_hubs.count()

        if query.id is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
        if query.name is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
        if query.type is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
        if query.author is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
        if query.storage_channel is not None:
            plugin_hubs = plugin_hubs.filter(
                PluginHubEntity.storage_channel == query.storage_channel
            )

        plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc())
        plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
        result = plugin_hubs.all()
        session.close()

        total_pages = all_count // page_size
        if all_count % page_size != 0:
            total_pages += 1

        return result, total_pages, all_count

    def get_by_storage_url(self, storage_url):
        session = self.get_raw_session()
        plugin_hubs = session.query(PluginHubEntity)
        plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
        result = plugin_hubs.all()
        session.close()
        return result

    def get_by_name(self, name: str) -> PluginHubEntity:
        session = self.get_raw_session()
        plugin_hubs = session.query(PluginHubEntity)
        plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
        result = plugin_hubs.first()
        session.close()
        return result

    def count(self, query: PluginHubEntity):
        session = self.get_raw_session()
        plugin_hubs = session.query(func.count(PluginHubEntity.id))
        if query.id is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
        if query.name is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
        if query.type is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
        if query.author is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
        if query.storage_channel is not None:
            plugin_hubs = plugin_hubs.filter(
                PluginHubEntity.storage_channel == query.storage_channel
            )
        count = plugin_hubs.scalar()
        session.close()
        return count

    def raw_delete(self, plugin_id: int):
        session = self.get_raw_session()
        if plugin_id is None:
            raise Exception("plugin_id is None")
        plugin_hubs = session.query(PluginHubEntity)
        plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == plugin_id)
        plugin_hubs.delete()
        session.commit()
        session.close()
0 dbgpt/serve/agent/dbgpts/__init__.py (new file)
0 dbgpt/serve/agent/hub/__init__.py (new file)
208 dbgpt/serve/agent/hub/agent_hub.py (new file)
@@ -0,0 +1,208 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import glob
|
||||
import shutil
|
||||
from fastapi import UploadFile
|
||||
from typing import Any
|
||||
import tempfile
|
||||
|
||||
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
|
||||
|
||||
from dbgpt.agent.common.schema import PluginStorageType
|
||||
from dbgpt.agent.plugin.plugins_util import scan_plugins, update_from_git
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Default_User = "default"
|
||||
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
|
||||
TEMP_PLUGIN_PATH = ""
|
||||
|
||||
|
||||
class AgentHub:
|
||||
def __init__(self, plugin_dir) -> None:
|
||||
self.hub_dao = PluginHubDao()
|
||||
self.my_plugin_dao = MyPluginDao()
|
||||
os.makedirs(plugin_dir, exist_ok=True)
|
||||
self.plugin_dir = plugin_dir
|
||||
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
|
||||
|
||||
def install_plugin(self, plugin_name: str, user_name: str = None):
|
||||
logger.info(f"install_plugin {plugin_name}")
|
||||
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||
if plugin_entity:
|
||||
if plugin_entity.storage_channel == PluginStorageType.Git.value:
|
||||
try:
|
||||
branch_name = None
|
||||
authorization = None
|
||||
if plugin_entity.download_param:
|
||||
download_param = json.loads(plugin_entity.download_param)
|
||||
branch_name = download_param.get("branch_name")
|
||||
authorization = download_param.get("authorization")
|
||||
file_name = self.__download_from_git(
|
||||
plugin_entity.storage_url, branch_name, authorization
|
||||
)
|
||||
|
||||
# add to my plugins and edit hub status
|
||||
plugin_entity.installed = plugin_entity.installed + 1
|
||||
|
||||
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
|
||||
user_name, plugin_name
|
||||
)
|
||||
if my_plugin_entity is None:
|
||||
my_plugin_entity = self.__build_my_plugin(plugin_entity)
|
||||
my_plugin_entity.file_name = file_name
|
||||
if user_name:
|
||||
# TODO use user
|
||||
my_plugin_entity.user_code = user_name
|
||||
my_plugin_entity.user_name = user_name
|
||||
my_plugin_entity.tenant = ""
|
||||
else:
|
||||
my_plugin_entity.user_code = Default_User
|
||||
|
||||
with self.hub_dao.session() as session:
|
||||
if my_plugin_entity.id is None:
|
||||
session.add(my_plugin_entity)
|
||||
else:
|
||||
session.merge(my_plugin_entity)
|
||||
session.merge(plugin_entity)
|
||||
except Exception as e:
|
||||
logger.error("install pluguin exception!", e)
|
||||
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Can't Find Plugin {plugin_name}!")
|
||||
|
||||
def uninstall_plugin(self, plugin_name, user):
|
||||
logger.info(f"uninstall_plugin:{plugin_name},{user}")
|
||||
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
|
||||
if plugin_entity is not None:
|
||||
plugin_entity.installed = plugin_entity.installed - 1
|
||||
with self.hub_dao.session() as session:
|
||||
my_plugin_q = session.query(MyPluginEntity).filter(
|
||||
MyPluginEntity.name == plugin_name
|
||||
)
|
||||
if user:
|
||||
my_plugin_q.filter(MyPluginEntity.user_code == user)
|
||||
my_plugin_q.delete()
|
||||
if plugin_entity is not None:
|
||||
session.merge(plugin_entity)
|
||||
|
||||
if plugin_entity is not None:
|
||||
# delete package file if not use
|
||||
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
|
||||
have_installed = False
|
||||
for plugin_info in plugin_infos:
|
||||
if plugin_info.installed > 0:
|
||||
have_installed = True
|
||||
break
|
||||
if not have_installed:
|
||||
plugin_repo_name = (
|
||||
plugin_entity.storage_url.replace(".git", "")
|
||||
.strip("/")
|
||||
.split("/")[-1]
|
||||
)
|
||||
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
else:
|
||||
files = glob.glob(
|
||||
os.path.join(self.plugin_dir, f"{my_plugin_entity.file_name}")
|
||||
)
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
|
||||
def __download_from_git(self, github_repo, branch_name, authorization):
|
||||
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
|
||||
|
||||
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
|
||||
my_plugin_entity = MyPluginEntity()
|
||||
my_plugin_entity.name = hub_plugin.name
|
||||
my_plugin_entity.type = hub_plugin.type
|
||||
my_plugin_entity.version = hub_plugin.version
|
||||
return my_plugin_entity
|
||||
|
||||
def refresh_hub_from_git(
|
||||
self,
|
||||
github_repo: str = None,
|
||||
branch_name: str = "main",
|
||||
authorization: str = None,
|
||||
):
|
||||
logger.info("refresh_hub_by_git start!")
|
||||
update_from_git(
|
||||
self.temp_hub_file_path, github_repo, branch_name, authorization
|
||||
)
|
||||
git_plugins = scan_plugins(self.temp_hub_file_path)
|
||||
try:
|
||||
for git_plugin in git_plugins:
|
||||
old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
|
||||
if old_hub_info:
|
||||
plugin_hub_info = old_hub_info
|
||||
else:
|
||||
plugin_hub_info = PluginHubEntity()
|
||||
plugin_hub_info.type = ""
|
||||
plugin_hub_info.storage_channel = PluginStorageType.Git.value
|
||||
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
|
||||
plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
|
||||
plugin_hub_info.email = getattr(git_plugin, "_email", "")
|
||||
download_param = {}
|
||||
if branch_name:
|
||||
download_param["branch_name"] = branch_name
|
||||
if authorization and len(authorization) > 0:
|
||||
download_param["authorization"] = authorization
|
||||
plugin_hub_info.download_param = json.dumps(download_param)
|
||||
plugin_hub_info.installed = 0
|
||||
|
||||
plugin_hub_info.name = git_plugin._name
|
||||
plugin_hub_info.version = git_plugin._version
|
||||
plugin_hub_info.description = git_plugin._description
|
||||
self.hub_dao.raw_update(plugin_hub_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
|
||||
|
||||

    async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
        # We cannot move a temp file on Windows while it is still open inside a
        # `with` block, so write it first, let the block close it, then move it.
        file_path = os.path.join(self.plugin_dir, doc_file.filename)
        if os.path.exists(file_path):
            os.remove(file_path)
        tmp_fd, tmp_path = tempfile.mkstemp(dir=self.plugin_dir)
        with os.fdopen(tmp_fd, "wb") as tmp:
            tmp.write(await doc_file.read())
        shutil.move(
            tmp_path,
            os.path.join(self.plugin_dir, doc_file.filename),
        )

        my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)

        if user is None or len(user) <= 0:
            user = Default_User

        for my_plugin in my_plugins:
            my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
                user, my_plugin._name
            )
            if my_plugin_entity is None:
                my_plugin_entity = MyPluginEntity()
            my_plugin_entity.name = my_plugin._name
            my_plugin_entity.version = my_plugin._version
            my_plugin_entity.type = "Personal"
            my_plugin_entity.user_code = user
            my_plugin_entity.user_name = user
            my_plugin_entity.tenant = ""
            my_plugin_entity.file_name = doc_file.filename
            self.my_plugin_dao.raw_update(my_plugin_entity)

    def reload_my_plugins(self):
        logger.info("reload_my_plugins start!")
        return scan_plugins(self.plugin_dir)

    def get_my_plugin(self, user: str):
        logger.info(f"get_my_plugin:{user}")
        if not user:
            user = Default_User
        return self.my_plugin_dao.get_by_user(user)
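
For orientation, a hedged end-to-end sketch of the hub flow above (plugin name and user are made up; install_plugin/uninstall_plugin are AgentHub methods, as the controller below shows):

    # Illustrative only, not part of the commit.
    from dbgpt.configs.model_config import PLUGINS_DIR

    hub = AgentHub(PLUGINS_DIR)
    hub.refresh_hub_from_git("https://github.com/eosphoros-ai/DB-GPT-Plugins.git")
    hub.install_plugin("db_query_plugin", "user_a")  # hypothetical plugin/user
    print([p.name for p in hub.get_my_plugin("user_a")])
    hub.uninstall_plugin("db_query_plugin", "user_a")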

160
dbgpt/serve/agent/hub/controller.py
Normal file

@@ -0,0 +1,160 @@
import logging
from abc import ABC
from typing import List

from fastapi import (
    APIRouter,
    Body,
    UploadFile,
    File,
)

from dbgpt.app.openapi.api_view_model import (
    Result,
)
from dbgpt.serve.agent.model import (
    PluginHubParam,
    PagenationFilter,
    PagenationResult,
    PluginHubFilter,
)
from dbgpt.serve.agent.hub.agent_hub import AgentHub
from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity

from dbgpt.agent.plugin.plugins_util import scan_plugins
from dbgpt.agent.plugin.generator import PluginPromptGenerator

from dbgpt.configs.model_config import PLUGINS_DIR
from dbgpt.component import BaseComponent, ComponentType, SystemApp

router = APIRouter()
logger = logging.getLogger(__name__)


class ModuleAgent(BaseComponent, ABC):
    name = ComponentType.AGENT_HUB

    def __init__(self):
        # Load all plugins found under the plugin directory at startup.
        self.plugins = scan_plugins(PLUGINS_DIR)

    def init_app(self, system_app: SystemApp):
        system_app.app.include_router(router, prefix="/api", tags=["Agent"])

    def refresh_plugins(self):
        self.plugins = scan_plugins(PLUGINS_DIR)

    def load_select_plugin(
        self, generator: PluginPromptGenerator, select_plugins: List[str]
    ) -> PluginPromptGenerator:
        logger.info(f"load_select_plugin:{select_plugins}")
        # Apply only the selected plugins that support post-prompt handling.
        for plugin in self.plugins:
            if plugin._name in select_plugins:
                if not plugin.can_handle_post_prompt():
                    continue
                generator = plugin.post_prompt(generator)
        return generator


module_agent = ModuleAgent()
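
A hedged sketch of selecting plugins for prompting (plugin names are placeholders; assumes PluginPromptGenerator can be constructed without arguments):

    generator = PluginPromptGenerator()  # assumed no-arg constructor
    generator = module_agent.load_select_plugin(generator, ["plugin_a", "plugin_b"])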


@router.post("/v1/agent/hub/update", response_model=Result[str])
async def agent_hub_update(update_param: PluginHubParam = Body()):
    logger.info(f"agent_hub_update:{update_param.__dict__}")
    try:
        agent_hub = AgentHub(PLUGINS_DIR)
        branch = (
            update_param.branch
            if update_param.branch is not None and len(update_param.branch) > 0
            else "main"
        )
        # Fall back to None when no authorization token is supplied.
        authorization = (
            update_param.authorization
            if update_param.authorization is not None
            and len(update_param.authorization) > 0
            else None
        )
        # TODO change it to async
        agent_hub.refresh_hub_from_git(update_param.url, branch, authorization)
        return Result.succ(None)
    except Exception as e:
        logger.error("Agent Hub Update Error!", exc_info=e)
        return Result.failed(code="E0020", msg=f"Agent Hub Update Error! {e}")


@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
    logger.info(f"get_agent_list:{filter.__dict__}")
    agent_hub = AgentHub(PLUGINS_DIR)
    filter_entity: PluginHubEntity = PluginHubEntity()
    if filter.filter:
        attrs = vars(filter.filter)  # attribute dict of the incoming filter
        for attr, value in attrs.items():
            setattr(filter_entity, attr, value)  # copy each value onto the entity

    datas, total_pages, total_count = agent_hub.hub_dao.list(
        filter_entity, filter.page_index, filter.page_size
    )
    result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
    result.page_index = filter.page_index
    result.page_size = filter.page_size
    result.total_page = total_pages
    result.total_row_count = total_count
    result.datas = datas
    return Result.succ(result.to_dic())
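
A hedged paged-query example against this endpoint (host, port, and plugin name are illustrative; the filter keys match PluginHubFilter in model.py):

    import requests

    resp = requests.post(
        "http://localhost:5000/api/v1/agent/query",
        json={"page_index": 1, "page_size": 20, "filter": {"name": "db_query_plugin"}},
    )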


@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user: str = None):
    logger.info(f"my_agents:{user}")
    agent_hub = AgentHub(PLUGINS_DIR)
    agents = agent_hub.get_my_plugin(user)
    agent_dicts = [agent.__dict__ for agent in agents]

    return Result.succ(agent_dicts)


@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name: str, user: str = None):
    logger.info(f"agent_install:{plugin_name},{user}")
    try:
        agent_hub = AgentHub(PLUGINS_DIR)
        agent_hub.install_plugin(plugin_name, user)

        module_agent.refresh_plugins()

        return Result.succ(None)
    except Exception as e:
        logger.error("Plugin Install Error!", exc_info=e)
        return Result.failed(code="E0021", msg=f"Plugin Install Error {e}")


@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name: str, user: str = None):
    logger.info(f"agent_uninstall:{plugin_name},{user}")
    try:
        agent_hub = AgentHub(PLUGINS_DIR)
        agent_hub.uninstall_plugin(plugin_name, user)

        module_agent.refresh_plugins()
        return Result.succ(None)
    except Exception as e:
        logger.error("Plugin Uninstall Error!", exc_info=e)
        return Result.failed(code="E0022", msg=f"Plugin Uninstall Error {e}")


@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
    logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
    try:
        agent_hub = AgentHub(PLUGINS_DIR)
        await agent_hub.upload_my_plugin(doc_file, user)
        module_agent.refresh_plugins()
        return Result.succ(None)
    except Exception as e:
        logger.error("Upload Personal Plugin Error!", exc_info=e)
        return Result.failed(code="E0023", msg=f"Upload Personal Plugin Error {e}")

69
dbgpt/serve/agent/model.py
Normal file

@@ -0,0 +1,69 @@
from typing import Generic, List, Optional, TypeVar

from dbgpt._private.pydantic import BaseModel, Field

T = TypeVar("T")


class PagenationFilter(BaseModel, Generic[T]):
    page_index: int = 1
    page_size: int = 20
    filter: Optional[T] = None


class PagenationResult(BaseModel, Generic[T]):
    page_index: int = 1
    page_size: int = 20
    total_page: int = 0
    total_row_count: int = 0
    datas: List[T] = []

    def to_dic(self):
        data_dicts = []
        for item in self.datas:
            data_dicts.append(item.__dict__)
        return {
            "page_index": self.page_index,
            "page_size": self.page_size,
            "total_page": self.total_page,
            "total_row_count": self.total_row_count,
            "datas": data_dicts,
        }


class PluginHubFilter(BaseModel):
    # Fields are optional so that partial filters validate.
    name: Optional[str] = None
    description: Optional[str] = None
    author: Optional[str] = None
    email: Optional[str] = None
    type: Optional[str] = None
    version: Optional[str] = None
    storage_channel: Optional[str] = None
    storage_url: Optional[str] = None


class MyPluginFilter(BaseModel):
    tenant: Optional[str] = None
    user_code: Optional[str] = None
    user_name: Optional[str] = None
    name: Optional[str] = None
    file_name: Optional[str] = None
    type: Optional[str] = None
    version: Optional[str] = None


class PluginHubParam(BaseModel):
    channel: Optional[str] = Field("git", description="Plugin storage channel")
    url: Optional[str] = Field(
        "https://github.com/eosphoros-ai/DB-GPT-Plugins.git",
        description="Plugin storage url",
    )
    branch: Optional[str] = Field(
        "main", description="github download branch", nullable=True
    )
    authorization: Optional[str] = Field(
        None, description="github download authorization", nullable=True
    )
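
A short sketch of these models in use, mirroring the controller's PagenationResult[...]() pattern (values are illustrative):

    # Illustrative only: build a paged filter, then serialize a result page.
    page_filter = PagenationFilter[PluginHubFilter](
        page_index=1, page_size=10, filter=PluginHubFilter(name="db_query_plugin")
    )
    result = PagenationResult[PluginHubFilter]()
    result.datas = [PluginHubFilter(name="db_query_plugin", author="DB-GPT")]
    result.total_row_count = len(result.datas)
    print(result.to_dic())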