mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 21:51:25 +00:00
feat(agent): Multi agent sdk (#976)
Co-authored-by: xtyuns <xtyuns@163.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: csunny <cfqsunny@163.com> Co-authored-by: qidanrui <qidanrui@gmail.com>
This commit is contained in:
6
dbgpt/serve/agent/db/__init__.py
Normal file
6
dbgpt/serve/agent/db/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .gpts_conversations_db import GptsConversationsDao, GptsConversationsEntity
|
||||
from .gpts_mange_db import GptsInstanceDao, GptsInstanceEntity
|
||||
from .gpts_messages_db import GptsMessagesDao, GptsMessagesEntity
|
||||
from .gpts_plans_db import GptsPlansDao, GptsPlansEntity
|
||||
from .my_plugin_db import MyPluginDao, MyPluginEntity
|
||||
from .plugin_hub_db import PluginHubDao, PluginHubEntity
|
95
dbgpt/serve/agent/db/gpts_conversations_db.py
Normal file
95
dbgpt/serve/agent/db/gpts_conversations_db.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Text, desc
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
|
||||
class GptsConversationsEntity(Model):
    """ORM model for one multi-agent (gpts) conversation record."""

    __tablename__ = "gpts_conversations"
    # BUG FIX: the original assigned ``__table_args__`` twice — first the
    # MySQL charset dict, then a constraint tuple that silently replaced it.
    # Declarative accepts a tuple whose last element is the kwargs dict, so
    # both are kept here.
    __table_args__ = (
        UniqueConstraint("conv_id", name="uk_gpts_conversations"),
        Index("idx_gpts_name", "gpts_name"),
        {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
        },
    )

    id = Column(Integer, primary_key=True, comment="autoincrement id")

    conv_id = Column(
        String(255), nullable=False, comment="The unique id of the conversation record"
    )
    user_goal = Column(Text, nullable=False, comment="User's goals content")

    gpts_name = Column(String(255), nullable=False, comment="The gpts name")
    state = Column(String(255), nullable=True, comment="The gpts state")

    max_auto_reply_round = Column(
        Integer, nullable=False, comment="max auto reply round"
    )
    auto_reply_count = Column(Integer, nullable=False, comment="auto reply count")

    user_code = Column(String(255), nullable=True, comment="user code")
    sys_code = Column(String(255), nullable=True, comment="system app ")

    created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        comment="last update time",
    )
|
||||
|
||||
|
||||
class GptsConversationsDao(BaseDao):
    """Data access object for ``GptsConversationsEntity`` rows."""

    def add(self, engity: GptsConversationsEntity):
        """Insert a new conversation record and return its generated id."""
        session = self.get_raw_session()
        try:
            session.add(engity)
            session.commit()
            return engity.id
        finally:
            # Always release the session, even if commit raises.
            session.close()

    def get_by_conv_id(self, conv_id: str):
        """Return the first conversation matching *conv_id*, or None."""
        session = self.get_raw_session()
        try:
            query = session.query(GptsConversationsEntity)
            if conv_id:
                query = query.filter(GptsConversationsEntity.conv_id == conv_id)
            return query.first()
        finally:
            session.close()

    def get_convs(self, user_code: str = None, system_app: str = None):
        """Return the 20 most recent conversations, optionally filtered.

        Args:
            user_code: if given, only conversations of this user.
            system_app: if given, only conversations of this system app code.
        """
        session = self.get_raw_session()
        try:
            query = session.query(GptsConversationsEntity)
            if user_code:
                query = query.filter(
                    GptsConversationsEntity.user_code == user_code
                )
            if system_app:
                # BUG FIX: the entity has no ``system_app`` column; the
                # system application code is stored in ``sys_code``.
                query = query.filter(
                    GptsConversationsEntity.sys_code == system_app
                )
            # BUG FIX: order_by must precede limit so the *latest* 20 rows
            # are returned (the original called .limit(20).order_by(...)).
            return (
                query.order_by(desc(GptsConversationsEntity.id))
                .limit(20)
                .all()
            )
        finally:
            session.close()

    def update(self, conv_id: str, state: str):
        """Set the state of the conversation identified by *conv_id*."""
        session = self.get_raw_session()
        try:
            query = session.query(GptsConversationsEntity).filter(
                GptsConversationsEntity.conv_id == conv_id
            )
            query.update(
                {GptsConversationsEntity.state: state}, synchronize_session="fetch"
            )
            session.commit()
        finally:
            session.close()
|
78
dbgpt/serve/agent/db/gpts_mange_db.py
Normal file
78
dbgpt/serve/agent/db/gpts_mange_db.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Text, Boolean
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
|
||||
class GptsInstanceEntity(Model):
    """ORM model describing one configured gpts (AI assistant) application."""

    __tablename__ = "gpts_instance"
    # BUG FIX: the original assigned ``__table_args__`` twice; the second
    # (tuple) assignment discarded the MySQL charset options.  Both are
    # combined here: constraints first, kwargs dict last.
    __table_args__ = (
        UniqueConstraint("gpts_name", name="uk_gpts"),
        {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
        },
    )

    id = Column(Integer, primary_key=True, comment="autoincrement id")

    gpts_name = Column(String(255), nullable=False, comment="Current AI assistant name")
    gpts_describe = Column(
        String(2255), nullable=False, comment="Current AI assistant describe"
    )
    resource_db = Column(
        Text,
        nullable=True,
        comment="List of structured database names contained in the current gpts",
    )
    resource_internet = Column(
        Text,
        nullable=True,
        comment="Is it possible to retrieve information from the internet",
    )
    resource_knowledge = Column(
        Text,
        nullable=True,
        comment="List of unstructured database names contained in the current gpts",
    )
    gpts_agents = Column(
        String(1000),
        nullable=True,
        comment="List of agents names contained in the current gpts",
    )
    gpts_models = Column(
        String(1000),
        nullable=True,
        comment="List of llm model names contained in the current gpts",
    )
    language = Column(String(100), nullable=True, comment="gpts language")

    user_code = Column(String(255), nullable=False, comment="user code")
    sys_code = Column(String(255), nullable=True, comment="system app code")

    created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        comment="last update time",
    )
|
||||
|
||||
|
||||
class GptsInstanceDao(BaseDao):
    """Data access object for ``GptsInstanceEntity`` rows."""

    def add(self, engity: GptsInstanceEntity):
        """Insert a new gpts app definition and return its generated id."""
        session = self.get_raw_session()
        try:
            session.add(engity)
            session.commit()
            return engity.id
        finally:
            # Release the session even if the insert/commit raises.
            session.close()

    def get_by_name(self, name: str) -> GptsInstanceEntity:
        """Return the gpts app with the given unique name, or None."""
        session = self.get_raw_session()
        try:
            query = session.query(GptsInstanceEntity)
            if name:
                query = query.filter(GptsInstanceEntity.gpts_name == name)
            return query.first()
        finally:
            session.close()
|
160
dbgpt/serve/agent/db/gpts_messages_db.py
Normal file
160
dbgpt/serve/agent/db/gpts_messages_db.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
Index,
|
||||
DateTime,
|
||||
func,
|
||||
Text,
|
||||
or_,
|
||||
and_,
|
||||
desc,
|
||||
)
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
|
||||
class GptsMessagesEntity(Model):
    """ORM model for a single message exchanged between agents in a gpts conversation."""

    __tablename__ = "gpts_messages"
    # BUG FIX: the original assigned ``__table_args__`` twice; the later
    # ``(Index(...),)`` assignment discarded the MySQL charset options.
    __table_args__ = (
        Index("idx_q_messages", "conv_id", "rounds", "sender"),
        {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
        },
    )

    id = Column(Integer, primary_key=True, comment="autoincrement id")

    conv_id = Column(
        String(255), nullable=False, comment="The unique id of the conversation record"
    )
    sender = Column(
        String(255),
        nullable=False,
        comment="Who speaking in the current conversation turn",
    )
    receiver = Column(
        String(255),
        nullable=False,
        comment="Who receive message in the current conversation turn",
    )
    model_name = Column(String(255), nullable=True, comment="message generate model")
    rounds = Column(Integer, nullable=False, comment="dialogue turns")
    content = Column(Text, nullable=True, comment="Content of the speech")
    # NOTE(review): "gogal" looks like a typo for "goal", but the column name
    # is part of the persisted schema and referenced by callers — keep it.
    current_gogal = Column(
        Text, nullable=True, comment="The target corresponding to the current message"
    )
    context = Column(Text, nullable=True, comment="Current conversation context")
    review_info = Column(
        Text, nullable=True, comment="Current conversation review info"
    )
    action_report = Column(
        Text, nullable=True, comment="Current conversation action report"
    )

    role = Column(
        String(255), nullable=True, comment="The role of the current message content"
    )

    created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        comment="last update time",
    )
|
||||
|
||||
|
||||
class GptsMessagesDao(BaseDao):
    """Data access object for ``GptsMessagesEntity`` rows."""

    def append(self, entity: dict):
        """Insert one message row built from a plain dict and return its id."""
        session = self.get_raw_session()
        try:
            message = GptsMessagesEntity(
                conv_id=entity.get("conv_id"),
                sender=entity.get("sender"),
                receiver=entity.get("receiver"),
                content=entity.get("content"),
                role=entity.get("role", None),
                model_name=entity.get("model_name", None),
                context=entity.get("context", None),
                rounds=entity.get("rounds", None),
                current_gogal=entity.get("current_gogal", None),
                review_info=entity.get("review_info", None),
                action_report=entity.get("action_report", None),
            )
            session.add(message)
            session.commit()
            return message.id
        finally:
            # Always release the session, even if the insert/commit raises.
            session.close()

    def get_by_agent(
        self, conv_id: str, agent: str
    ) -> Optional[List[GptsMessagesEntity]]:
        """Messages of *conv_id* where *agent* is sender or receiver, ordered by round.

        NOTE(review): when *agent* is falsy, the conv_id filter is skipped as
        well and all messages are returned — preserved from the original.
        """
        session = self.get_raw_session()
        try:
            query = session.query(GptsMessagesEntity)
            if agent:
                query = query.filter(
                    GptsMessagesEntity.conv_id == conv_id
                ).filter(
                    or_(
                        GptsMessagesEntity.sender == agent,
                        GptsMessagesEntity.receiver == agent,
                    )
                )
            return query.order_by(GptsMessagesEntity.rounds).all()
        finally:
            session.close()

    def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessagesEntity]]:
        """All messages of the conversation *conv_id*, ordered by round."""
        session = self.get_raw_session()
        try:
            query = session.query(GptsMessagesEntity)
            if conv_id:
                query = query.filter(GptsMessagesEntity.conv_id == conv_id)
            return query.order_by(GptsMessagesEntity.rounds).all()
        finally:
            session.close()

    def get_between_agents(
        self,
        conv_id: str,
        agent1: str,
        agent2: str,
        current_gogal: Optional[str] = None,
    ) -> Optional[List[GptsMessagesEntity]]:
        """Messages exchanged between *agent1* and *agent2* (either direction).

        Optionally restricted to one ``current_gogal`` value; ordered by round.
        """
        session = self.get_raw_session()
        try:
            query = session.query(GptsMessagesEntity)
            if agent1 and agent2:
                query = query.filter(
                    GptsMessagesEntity.conv_id == conv_id
                ).filter(
                    or_(
                        and_(
                            GptsMessagesEntity.sender == agent1,
                            GptsMessagesEntity.receiver == agent2,
                        ),
                        and_(
                            GptsMessagesEntity.sender == agent2,
                            GptsMessagesEntity.receiver == agent1,
                        ),
                    )
                )
            if current_gogal:
                query = query.filter(
                    GptsMessagesEntity.current_gogal == current_gogal
                )
            return query.order_by(GptsMessagesEntity.rounds).all()
        finally:
            session.close()

    def get_last_message(self, conv_id: str) -> Optional[GptsMessagesEntity]:
        """The highest-round message of *conv_id*, or None."""
        session = self.get_raw_session()
        try:
            query = session.query(GptsMessagesEntity)
            if conv_id:
                query = query.filter(
                    GptsMessagesEntity.conv_id == conv_id
                ).order_by(desc(GptsMessagesEntity.rounds))
            return query.first()
        finally:
            session.close()
|
156
dbgpt/serve/agent/db/gpts_plans_db.py
Normal file
156
dbgpt/serve/agent/db/gpts_plans_db.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt.agent.common.schema import Status
|
||||
|
||||
|
||||
class GptsPlansEntity(Model):
    """ORM model for one planned sub-task of a gpts conversation."""

    __tablename__ = "gpts_plans"
    # BUG FIX: the original assigned ``__table_args__`` twice; the tuple
    # assignment discarded the MySQL charset options.  Combined here.
    __table_args__ = (
        UniqueConstraint("conv_id", "sub_task_num", name="uk_sub_task"),
        {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
        },
    )

    id = Column(Integer, primary_key=True, comment="autoincrement id")

    conv_id = Column(
        String(255), nullable=False, comment="The unique id of the conversation record"
    )
    sub_task_num = Column(Integer, nullable=False, comment="Subtask number")
    sub_task_title = Column(String(255), nullable=False, comment="subtask title")
    sub_task_content = Column(Text, nullable=False, comment="subtask content")
    sub_task_agent = Column(
        String(255), nullable=True, comment="Available agents corresponding to subtasks"
    )
    resource_name = Column(String(255), nullable=True, comment="resource name")
    rely = Column(
        String(255), nullable=True, comment="Subtask dependencies,like: 1,2,3"
    )

    agent_model = Column(
        String(255),
        nullable=True,
        comment="LLM model used by subtask processing agents",
    )
    # The original used ``default=False``; 0 is the intended (and equivalent)
    # integer default for these counters.
    retry_times = Column(Integer, default=0, comment="number of retries")
    max_retry_times = Column(
        Integer, default=0, comment="Maximum number of retries"
    )
    state = Column(String(255), nullable=True, comment="subtask status")
    result = Column(Text(length=2**31 - 1), nullable=True, comment="subtask result")

    created_at = Column(DateTime, default=datetime.utcnow, comment="create time")
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        comment="last update time",
    )
|
||||
|
||||
|
||||
class GptsPlansDao(BaseDao):
    """Data access object for ``GptsPlansEntity`` (sub-task plan) rows."""

    def batch_save(self, plans: list[dict]):
        """Bulk-insert plan rows from plain dicts."""
        session = self.get_raw_session()
        try:
            session.bulk_insert_mappings(GptsPlansEntity, plans)
            session.commit()
        finally:
            # Always release the session, even if the insert/commit raises.
            session.close()

    def get_by_conv_id(self, conv_id: str) -> list[GptsPlansEntity]:
        """All plan rows belonging to the conversation *conv_id*."""
        session = self.get_raw_session()
        try:
            query = session.query(GptsPlansEntity)
            if conv_id:
                query = query.filter(GptsPlansEntity.conv_id == conv_id)
            return query.all()
        finally:
            session.close()

    def get_by_task_id(self, task_id: int) -> list[GptsPlansEntity]:
        """The plan row with primary key *task_id*, or None."""
        session = self.get_raw_session()
        try:
            query = session.query(GptsPlansEntity)
            if task_id:
                query = query.filter(GptsPlansEntity.id == task_id)
            return query.first()
        finally:
            session.close()

    def get_by_conv_id_and_num(
        self, conv_id: str, task_nums: list
    ) -> list[GptsPlansEntity]:
        """Plan rows of *conv_id* whose sub_task_num is in *task_nums*."""
        session = self.get_raw_session()
        try:
            query = session.query(GptsPlansEntity)
            if conv_id:
                query = query.filter(GptsPlansEntity.conv_id == conv_id).filter(
                    GptsPlansEntity.sub_task_num.in_(task_nums)
                )
            return query.all()
        finally:
            session.close()

    def get_todo_plans(self, conv_id: str) -> list[GptsPlansEntity]:
        """Unfinished (TODO / RETRYING) plans of *conv_id*, ordered by task number."""
        # BUG FIX: the original opened a session *before* this check and
        # returned without closing it, leaking the session.
        if not conv_id:
            return []
        session = self.get_raw_session()
        try:
            query = (
                session.query(GptsPlansEntity)
                .filter(GptsPlansEntity.conv_id == conv_id)
                .filter(
                    GptsPlansEntity.state.in_(
                        [Status.TODO.value, Status.RETRYING.value]
                    )
                )
            )
            return query.order_by(GptsPlansEntity.sub_task_num).all()
        finally:
            session.close()

    def complete_task(self, conv_id: str, task_num: int, result: str):
        """Mark one sub-task COMPLETE and store its result."""
        session = self.get_raw_session()
        try:
            query = (
                session.query(GptsPlansEntity)
                .filter(GptsPlansEntity.conv_id == conv_id)
                .filter(GptsPlansEntity.sub_task_num == task_num)
            )
            query.update(
                {
                    GptsPlansEntity.state: Status.COMPLETE.value,
                    GptsPlansEntity.result: result,
                },
                synchronize_session="fetch",
            )
            session.commit()
        finally:
            session.close()

    def update_task(
        self,
        conv_id: str,
        task_num: int,
        state: str,
        retry_times: int,
        agent: str = None,
        model: str = None,
        result: str = None,
    ):
        """Update state / retry bookkeeping of one sub-task.

        ``agent`` and ``model`` are only written when truthy; ``result`` is
        always written (possibly None), matching the original behavior.
        """
        session = self.get_raw_session()
        try:
            query = (
                session.query(GptsPlansEntity)
                .filter(GptsPlansEntity.conv_id == conv_id)
                .filter(GptsPlansEntity.sub_task_num == task_num)
            )
            update_param = {
                GptsPlansEntity.state: state,
                GptsPlansEntity.retry_times: retry_times,
                GptsPlansEntity.result: result,
            }
            if agent:
                update_param[GptsPlansEntity.sub_task_agent] = agent
            if model:
                update_param[GptsPlansEntity.agent_model] = model

            query.update(update_param, synchronize_session="fetch")
            session.commit()
        finally:
            session.close()

    def remove_by_conv_id(self, conv_id: str):
        """Delete all plan rows of *conv_id*; raise if conv_id is None."""
        # BUG FIX: validate before opening the session so the raise path
        # does not leak an open session.
        if conv_id is None:
            raise Exception("conv_id is None")
        session = self.get_raw_session()
        try:
            session.query(GptsPlansEntity).filter(
                GptsPlansEntity.conv_id == conv_id
            ).delete()
            session.commit()
        finally:
            session.close()
|
137
dbgpt/serve/agent/db/my_plugin_db.py
Normal file
137
dbgpt/serve/agent/db/my_plugin_db.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, DateTime, func
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
|
||||
class MyPluginEntity(Model):
    """ORM model for a plugin installed by a user."""

    __tablename__ = "my_plugin"
    # BUG FIX: the original had a bare ``UniqueConstraint(...)`` statement at
    # the bottom of the class body, which is a no-op; it must live in
    # ``__table_args__`` to take effect.
    # NOTE(review): ``name`` below also carries ``unique=True`` (globally
    # unique), which makes this per-user constraint redundant in practice —
    # confirm which uniqueness is actually intended.
    __table_args__ = (UniqueConstraint("user_code", "name", name="uk_name"),)

    id = Column(Integer, primary_key=True, comment="autoincrement id")
    tenant = Column(String(255), nullable=True, comment="user's tenant")
    user_code = Column(String(255), nullable=False, comment="user code")
    user_name = Column(String(255), nullable=True, comment="user name")
    name = Column(String(255), unique=True, nullable=False, comment="plugin name")
    file_name = Column(String(255), nullable=False, comment="plugin package file name")
    type = Column(String(255), comment="plugin type")
    version = Column(String(255), comment="plugin version")
    use_count = Column(
        Integer, nullable=True, default=0, comment="plugin total use count"
    )
    succ_count = Column(
        Integer, nullable=True, default=0, comment="plugin total success count"
    )
    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
    gmt_created = Column(
        DateTime, default=datetime.utcnow, comment="plugin install time"
    )
|
||||
|
||||
|
||||
class MyPluginDao(BaseDao):
    """Data access object for ``MyPluginEntity`` rows."""

    def add(self, engity: MyPluginEntity):
        """Insert a new installed-plugin record and return its generated id."""
        session = self.get_raw_session()
        try:
            my_plugin = MyPluginEntity(
                tenant=engity.tenant,
                user_code=engity.user_code,
                user_name=engity.user_name,
                name=engity.name,
                # BUG FIX: ``file_name`` is nullable=False but was not copied
                # from the argument, so every insert would fail with an
                # integrity error.
                file_name=engity.file_name,
                type=engity.type,
                version=engity.version,
                use_count=engity.use_count or 0,
                succ_count=engity.succ_count or 0,
                sys_code=engity.sys_code,
                gmt_created=datetime.now(),
            )
            session.add(my_plugin)
            session.commit()
            return my_plugin.id
        finally:
            # Always release the session, even if the insert/commit raises.
            session.close()

    def raw_update(self, entity: MyPluginEntity):
        """Merge *entity* into the session, persist it and return its id."""
        session = self.get_raw_session()
        try:
            updated = session.merge(entity)
            session.commit()
            return updated.id
        finally:
            # BUG FIX: the original never closed this session (cf.
            # PluginHubDao.raw_update, which does).
            session.close()

    def get_by_user(self, user: str) -> list[MyPluginEntity]:
        """All plugins installed by *user* (all plugins when user is falsy)."""
        session = self.get_raw_session()
        try:
            query = session.query(MyPluginEntity)
            if user:
                query = query.filter(MyPluginEntity.user_code == user)
            return query.all()
        finally:
            session.close()

    def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
        """The plugin named *plugin* installed by *user*, or None."""
        session = self.get_raw_session()
        try:
            query = session.query(MyPluginEntity)
            if user:
                query = query.filter(MyPluginEntity.user_code == user)
            query = query.filter(MyPluginEntity.name == plugin)
            return query.first()
        finally:
            session.close()

    def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
        """Paged query-by-example.

        Returns a 3-tuple ``(rows, total_pages, all_count)``.
        NOTE(review): ``all_count`` is taken *before* the filters are applied
        (counts the whole table) — preserved from the original; confirm intent.
        """
        session = self.get_raw_session()
        try:
            my_plugins = session.query(MyPluginEntity)
            all_count = my_plugins.count()
            if query.id is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
            if query.name is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
            if query.tenant is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
            if query.type is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
            if query.user_code is not None:
                my_plugins = my_plugins.filter(
                    MyPluginEntity.user_code == query.user_code
                )
            if query.user_name is not None:
                my_plugins = my_plugins.filter(
                    MyPluginEntity.user_name == query.user_name
                )
            if query.sys_code is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)

            my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
            my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
            result = my_plugins.all()
        finally:
            session.close()
        total_pages = all_count // page_size
        if all_count % page_size != 0:
            total_pages += 1

        return result, total_pages, all_count

    def count(self, query: MyPluginEntity):
        """Count rows matching the query-by-example entity."""
        session = self.get_raw_session()
        try:
            my_plugins = session.query(func.count(MyPluginEntity.id))
            if query.id is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
            if query.name is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
            if query.type is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
            if query.tenant is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
            if query.user_code is not None:
                my_plugins = my_plugins.filter(
                    MyPluginEntity.user_code == query.user_code
                )
            if query.user_name is not None:
                my_plugins = my_plugins.filter(
                    MyPluginEntity.user_name == query.user_name
                )
            if query.sys_code is not None:
                my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
            return my_plugins.scalar()
        finally:
            session.close()

    def raw_delete(self, plugin_id: int):
        """Delete the plugin row with primary key *plugin_id*."""
        # Validate before opening the session so the raise path cannot leak.
        if plugin_id is None:
            raise Exception("plugin_id is None")
        session = self.get_raw_session()
        try:
            # Simplified: the original built a throwaway MyPluginEntity just
            # to carry the id into the filter.
            session.query(MyPluginEntity).filter(
                MyPluginEntity.id == plugin_id
            ).delete()
            session.commit()
        finally:
            session.close()
|
139
dbgpt/serve/agent/db/plugin_hub_db.py
Normal file
139
dbgpt/serve/agent/db/plugin_hub_db.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
# TODO We should consider that the production environment does not have permission to execute the DDL
|
||||
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
|
||||
|
||||
|
||||
class PluginHubEntity(Model):
    """ORM model for a plugin published in the plugin hub."""

    __tablename__ = "plugin_hub"
    # BUG FIX: the original had bare ``UniqueConstraint``/``Index`` statements
    # at the bottom of the class body, which are no-ops; they must live in
    # ``__table_args__``.  The unique constraint on ``name`` is already
    # provided by ``unique=True`` on the column, so only the index is
    # declared here.
    __table_args__ = (Index("idx_q_type", "type"),)

    id = Column(
        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
    )
    name = Column(String(255), unique=True, nullable=False, comment="plugin name")
    description = Column(String(255), nullable=False, comment="plugin description")
    author = Column(String(255), nullable=True, comment="plugin author")
    email = Column(String(255), nullable=True, comment="plugin author email")
    type = Column(String(255), comment="plugin type")
    version = Column(String(255), comment="plugin version")
    storage_channel = Column(String(255), comment="plugin storage channel")
    storage_url = Column(String(255), comment="plugin download url")
    download_param = Column(String(255), comment="plugin download param")
    gmt_created = Column(
        DateTime, default=datetime.utcnow, comment="plugin upload time"
    )
    installed = Column(Integer, default=False, comment="plugin already installed count")
|
||||
|
||||
|
||||
class PluginHubDao(BaseDao):
    """Data access object for ``PluginHubEntity`` rows."""

    def add(self, engity: PluginHubEntity):
        """Insert a new hub-plugin record and return its generated id."""
        session = self.get_raw_session()
        try:
            timezone = pytz.timezone("Asia/Shanghai")
            plugin_hub = PluginHubEntity(
                name=engity.name,
                # BUG FIX: ``description`` is nullable=False but was not
                # copied from the argument, so every insert would fail with
                # an integrity error.
                description=engity.description,
                author=engity.author,
                email=engity.email,
                type=engity.type,
                version=engity.version,
                storage_channel=engity.storage_channel,
                storage_url=engity.storage_url,
                # BUG FIX: ``download_param`` was silently dropped.
                download_param=engity.download_param,
                gmt_created=timezone.localize(datetime.now()),
            )
            session.add(plugin_hub)
            session.commit()
            return plugin_hub.id
        finally:
            # Always release the session, even if the insert/commit raises.
            session.close()

    def raw_update(self, entity: PluginHubEntity):
        """Merge *entity* into the session, persist it and return its id."""
        session = self.get_raw_session()
        try:
            updated = session.merge(entity)
            session.commit()
            return updated.id
        finally:
            session.close()

    def list(
        self, query: PluginHubEntity, page=1, page_size=20
    ) -> list[PluginHubEntity]:
        """Paged query-by-example.

        Returns a 3-tuple ``(rows, total_pages, all_count)``.
        NOTE(review): ``all_count`` is taken *before* the filters are applied
        (counts the whole table) — preserved from the original; confirm intent.
        """
        session = self.get_raw_session()
        try:
            plugin_hubs = session.query(PluginHubEntity)
            all_count = plugin_hubs.count()

            if query.id is not None:
                plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
            if query.name is not None:
                plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
            if query.type is not None:
                plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
            if query.author is not None:
                plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
            if query.storage_channel is not None:
                plugin_hubs = plugin_hubs.filter(
                    PluginHubEntity.storage_channel == query.storage_channel
                )

            plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc())
            plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
            result = plugin_hubs.all()
        finally:
            session.close()

        total_pages = all_count // page_size
        if all_count % page_size != 0:
            total_pages += 1

        return result, total_pages, all_count

    def get_by_storage_url(self, storage_url):
        """All hub plugins whose download url equals *storage_url*."""
        session = self.get_raw_session()
        try:
            query = session.query(PluginHubEntity).filter(
                PluginHubEntity.storage_url == storage_url
            )
            return query.all()
        finally:
            session.close()

    def get_by_name(self, name: str) -> PluginHubEntity:
        """The hub plugin with the given unique name, or None."""
        session = self.get_raw_session()
        try:
            query = session.query(PluginHubEntity).filter(
                PluginHubEntity.name == name
            )
            return query.first()
        finally:
            session.close()

    def count(self, query: PluginHubEntity):
        """Count rows matching the query-by-example entity."""
        session = self.get_raw_session()
        try:
            plugin_hubs = session.query(func.count(PluginHubEntity.id))
            if query.id is not None:
                plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
            if query.name is not None:
                plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
            if query.type is not None:
                plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
            if query.author is not None:
                plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
            if query.storage_channel is not None:
                plugin_hubs = plugin_hubs.filter(
                    PluginHubEntity.storage_channel == query.storage_channel
                )
            return plugin_hubs.scalar()
        finally:
            session.close()

    def raw_delete(self, plugin_id: int):
        """Delete the hub-plugin row with primary key *plugin_id*."""
        # Validate before opening the session so the raise path cannot leak.
        if plugin_id is None:
            raise Exception("plugin_id is None")
        session = self.get_raw_session()
        try:
            session.query(PluginHubEntity).filter(
                PluginHubEntity.id == plugin_id
            ).delete()
            session.commit()
        finally:
            session.close()
|
Reference in New Issue
Block a user