feat(Agent): ChatAgent And AgentHub

1.Upgrade sqlalchemy to version 2.0
This commit is contained in:
yhjun1026
2023-10-17 15:39:15 +08:00
parent 87243ae504
commit c4f8e0ecad
20 changed files with 94 additions and 68 deletions

View File

@@ -71,6 +71,20 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
session.close()
return result
def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
session = self.get_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(
MyPluginEntity.user_code == user
)
my_plugins = my_plugins.filter(
MyPluginEntity.name == plugin
)
result = my_plugins.first()
session.close()
return result
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
session = self.get_session()

View File

@@ -21,7 +21,7 @@ TEMP_PLUGIN_PATH = ""
class AgentHub:
def __init__(self, plugin_dir) -> None:
self.hub_dao = PluginHubDao()
self.my_lugin_dao = MyPluginDao()
self.my_plugin_dao = MyPluginDao()
os.makedirs(plugin_dir, exist_ok=True)
self.plugin_dir = plugin_dir
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
@@ -43,11 +43,13 @@ class AgentHub:
# add to my plugins and edit hub status
plugin_entity.installed = plugin_entity.installed + 1
my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user_name, plugin_name)
if my_plugin_entity is None:
my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity.file_name = file_name
if user_name:
# TODO use user
my_plugin_entity.user_code = ""
my_plugin_entity.user_code = user_name
my_plugin_entity.user_name = user_name
my_plugin_entity.tenant = ""
else:
@@ -55,11 +57,15 @@ class AgentHub:
with self.hub_dao.get_session() as session:
try:
session.add(my_plugin_entity)
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
session.commit()
session.close()
except:
except Exception as e:
logger.error("install merge roll back!" + str(e))
session.rollback()
except Exception as e:
logger.error("install pluguin exception!", e)
@@ -72,29 +78,39 @@ class AgentHub:
def uninstall_plugin(self, plugin_name, user):
logger.info(f"uninstall_plugin:{plugin_name},{user}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
plugin_entity.installed = plugin_entity.installed - 1
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
if plugin_entity is not None:
plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.get_session() as session:
try:
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
session.merge(plugin_entity)
if plugin_entity is not None:
session.merge(plugin_entity)
session.commit()
except:
session.rollback()
# delete package file if not use
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
have_installed = False
for plugin_info in plugin_infos:
if plugin_info.installed > 0:
have_installed = True
break
if not have_installed:
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
if plugin_entity is not None:
# delete package file if not use
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
have_installed = False
for plugin_info in plugin_infos:
if plugin_info.installed > 0:
have_installed = True
break
if not have_installed:
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
files = glob.glob(
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
)
for file in files:
os.remove(file)
else:
files = glob.glob(
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
os.path.join(self.plugin_dir, f"{my_plugin_entity.file_name}")
)
for file in files:
os.remove(file)
@@ -162,8 +178,9 @@ class AgentHub:
user = Default_User
for my_plugin in my_plugins:
my_plugin_entiy = MyPluginEntity()
my_plugin_entiy = self.my_plugin_dao.get_by_user_and_plugin(user, my_plugin._name)
if my_plugin_entiy is None :
my_plugin_entiy = MyPluginEntity()
my_plugin_entiy.name = my_plugin._name
my_plugin_entiy.version = my_plugin._version
my_plugin_entiy.type = "Personal"
@@ -171,8 +188,7 @@ class AgentHub:
my_plugin_entiy.user_name = user
my_plugin_entiy.tenant = ""
my_plugin_entiy.file_name = doc_file.filename
self.my_lugin_dao.update(my_plugin_entiy)
self.my_plugin_dao.update(my_plugin_entiy)
def reload_my_plugins(self):
logger.info(f"load_plugins start!")
@@ -182,5 +198,5 @@ class AgentHub:
logger.info(f"get_my_plugin:{user}")
if not user:
user = Default_User
return self.my_lugin_dao.get_by_user(user)
return self.my_plugin_dao.get_by_user(user)