mirror of
https://github.com/csunny/DB-GPT.git
synced 2026-01-29 21:49:35 +00:00
feat(Agent): ChatAgent And AgentHub
1.Upgrade sqlalchemy to version 2.0
This commit is contained in:
@@ -71,6 +71,20 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
|
||||
session = self.get_session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
if user:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.user_code == user
|
||||
)
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.name == plugin
|
||||
)
|
||||
result = my_plugins.first()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
|
||||
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
|
||||
session = self.get_session()
|
||||
|
||||
@@ -21,7 +21,7 @@ TEMP_PLUGIN_PATH = ""
|
||||
class AgentHub:
|
||||
def __init__(self, plugin_dir) -> None:
|
||||
self.hub_dao = PluginHubDao()
|
||||
self.my_lugin_dao = MyPluginDao()
|
||||
self.my_plugin_dao = MyPluginDao()
|
||||
os.makedirs(plugin_dir, exist_ok=True)
|
||||
self.plugin_dir = plugin_dir
|
||||
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
|
||||
@@ -43,11 +43,13 @@ class AgentHub:
|
||||
# add to my plugins and edit hub status
|
||||
plugin_entity.installed = plugin_entity.installed + 1
|
||||
|
||||
my_plugin_entity = self.__build_my_plugin(plugin_entity)
|
||||
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user_name, plugin_name)
|
||||
if my_plugin_entity is None:
|
||||
my_plugin_entity = self.__build_my_plugin(plugin_entity)
|
||||
my_plugin_entity.file_name = file_name
|
||||
if user_name:
|
||||
# TODO use user
|
||||
my_plugin_entity.user_code = ""
|
||||
my_plugin_entity.user_code = user_name
|
||||
my_plugin_entity.user_name = user_name
|
||||
my_plugin_entity.tenant = ""
|
||||
else:
|
||||
@@ -55,11 +57,15 @@ class AgentHub:
|
||||
|
||||
with self.hub_dao.get_session() as session:
|
||||
try:
|
||||
session.add(my_plugin_entity)
|
||||
if my_plugin_entity.id is None:
|
||||
session.add(my_plugin_entity)
|
||||
else:
|
||||
session.merge(my_plugin_entity)
|
||||
session.merge(plugin_entity)
|
||||
session.commit()
|
||||
session.close()
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error("install merge roll back!" + str(e))
|
||||
session.rollback()
|
||||
except Exception as e:
|
||||
logger.error("install pluguin exception!", e)
|
||||
@@ -72,29 +78,39 @@ class AgentHub:
|
||||
def uninstall_plugin(self, plugin_name, user):
|
||||
logger.info(f"uninstall_plugin:{plugin_name},{user}")
|
||||
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||
plugin_entity.installed = plugin_entity.installed - 1
|
||||
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
|
||||
if plugin_entity is not None:
|
||||
plugin_entity.installed = plugin_entity.installed - 1
|
||||
with self.hub_dao.get_session() as session:
|
||||
try:
|
||||
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name)
|
||||
if user:
|
||||
my_plugin_q.filter(MyPluginEntity.user_code == user)
|
||||
my_plugin_q.delete()
|
||||
session.merge(plugin_entity)
|
||||
if plugin_entity is not None:
|
||||
session.merge(plugin_entity)
|
||||
session.commit()
|
||||
except:
|
||||
session.rollback()
|
||||
|
||||
# delete package file if not use
|
||||
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
|
||||
have_installed = False
|
||||
for plugin_info in plugin_infos:
|
||||
if plugin_info.installed > 0:
|
||||
have_installed = True
|
||||
break
|
||||
if not have_installed:
|
||||
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
|
||||
if plugin_entity is not None:
|
||||
# delete package file if not use
|
||||
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
|
||||
have_installed = False
|
||||
for plugin_info in plugin_infos:
|
||||
if plugin_info.installed > 0:
|
||||
have_installed = True
|
||||
break
|
||||
if not have_installed:
|
||||
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
|
||||
files = glob.glob(
|
||||
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
|
||||
)
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
else:
|
||||
files = glob.glob(
|
||||
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
|
||||
os.path.join(self.plugin_dir, f"{my_plugin_entity.file_name}")
|
||||
)
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
@@ -162,8 +178,9 @@ class AgentHub:
|
||||
user = Default_User
|
||||
|
||||
for my_plugin in my_plugins:
|
||||
my_plugin_entiy = MyPluginEntity()
|
||||
|
||||
my_plugin_entiy = self.my_plugin_dao.get_by_user_and_plugin(user, my_plugin._name)
|
||||
if my_plugin_entiy is None :
|
||||
my_plugin_entiy = MyPluginEntity()
|
||||
my_plugin_entiy.name = my_plugin._name
|
||||
my_plugin_entiy.version = my_plugin._version
|
||||
my_plugin_entiy.type = "Personal"
|
||||
@@ -171,8 +188,7 @@ class AgentHub:
|
||||
my_plugin_entiy.user_name = user
|
||||
my_plugin_entiy.tenant = ""
|
||||
my_plugin_entiy.file_name = doc_file.filename
|
||||
|
||||
self.my_lugin_dao.update(my_plugin_entiy)
|
||||
self.my_plugin_dao.update(my_plugin_entiy)
|
||||
|
||||
def reload_my_plugins(self):
|
||||
logger.info(f"load_plugins start!")
|
||||
@@ -182,5 +198,5 @@ class AgentHub:
|
||||
logger.info(f"get_my_plugin:{user}")
|
||||
if not user:
|
||||
user = Default_User
|
||||
return self.my_lugin_dao.get_by_user(user)
|
||||
return self.my_plugin_dao.get_by_user(user)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user