From ed702db9bc3d191c0d0cbd18e1630b04d9dfe5da Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Tue, 26 Sep 2023 17:07:42 +0800 Subject: [PATCH] feat(Agent): ChatAgent And AgentHub MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1.AgentHub Module complete 2.New ChatAgent Mode complete 3.MetaData develop,auto update ddl --- pilot/base_modules/agent/controller.py | 72 +++++++++--- pilot/base_modules/agent/db/my_plugin_db.py | 14 +-- pilot/base_modules/agent/db/plugin_hub_db.py | 23 ++-- pilot/base_modules/agent/model.py | 19 +++- pilot/base_modules/meta_data/meta_data.py | 107 +++++++++--------- pilot/base_modules/module_factory.py | 2 + pilot/componet.py | 1 + pilot/configs/config.py | 32 +----- pilot/connections/__init__.py | 1 + .../connections/manages/connect_config_db.py | 62 ++++++++++ .../connections/manages/connection_manager.py | 2 +- pilot/memory/__init__.py | 1 + pilot/memory/chat_history/base.py | 36 +++++- .../chat_history/chat_hisotry_factory.py | 28 +++++ pilot/memory/chat_history/chat_history_db.py | 88 ++++++++++++++ .../chat_history/store_type/__init__.py} | 0 .../{ => store_type}/duckdb_history.py | 52 +++++---- .../{ => store_type}/file_history.py | 6 +- .../{ => store_type}/mem_history.py | 3 + .../store_type/meta_db_history.py | 100 ++++++++++++++++ pilot/model/cluster/controller/controller.py | 4 +- pilot/openapi/api_v1/api_v1.py | 20 ++-- pilot/openapi/api_v1/editor/api_editor_v1.py | 22 ++-- pilot/scene/base_chat.py | 10 +- pilot/scene/chat_agent/chat.py | 14 +-- pilot/scene/chat_agent/prompt.py | 2 +- pilot/server/base.py | 3 +- pilot/server/componet_configs.py | 4 +- pilot/server/dbgpt_server.py | 8 +- 29 files changed, 549 insertions(+), 187 deletions(-) create mode 100644 pilot/connections/manages/connect_config_db.py create mode 100644 pilot/memory/chat_history/chat_hisotry_factory.py create mode 100644 pilot/memory/chat_history/chat_history_db.py rename pilot/{base_modules/agent/plugins_loader.py => memory/chat_history/store_type/__init__.py} (100%) rename pilot/memory/chat_history/{ => store_type}/duckdb_history.py (98%) rename pilot/memory/chat_history/{ => store_type}/file_history.py (93%) rename pilot/memory/chat_history/{ => store_type}/mem_history.py (88%) create mode 100644 pilot/memory/chat_history/store_type/meta_db_history.py diff --git a/pilot/base_modules/agent/controller.py b/pilot/base_modules/agent/controller.py index 151f6e40b..86bef3569 100644 --- a/pilot/base_modules/agent/controller.py +++ b/pilot/base_modules/agent/controller.py @@ -6,7 +6,7 @@ from fastapi import ( UploadFile, File, ) - +from abc import ABC, abstractmethod from typing import List from pilot.configs.model_config import LOGDIR from pilot.utils import build_logger @@ -18,14 +18,44 @@ from pilot.openapi.api_view_model import ( from .model import PluginHubParam, PagenationFilter, PagenationResult, PluginHubFilter, MyPluginFilter from .hub.agent_hub import AgentHub from .db.plugin_hub_db import PluginHubEntity -from .db.my_plugin_db import MyPluginEntity +from .plugins_util import scan_plugins +from .commands.generator import PluginPromptGenerator + from pilot.configs.model_config import PLUGINS_DIR +from pilot.componet import BaseComponet, ComponetType, SystemApp router = APIRouter() logger = build_logger("agent_mange", LOGDIR + "agent_mange.log") -@router.post("/api/v1/agent/hub/update", response_model=Result[str]) +class ModuleAgent(BaseComponet, ABC): + name = ComponetType.AGENT_HUB + + def __init__(self): + #load plugins + self.plugins = scan_plugins(PLUGINS_DIR) + + def init_app(self, system_app: SystemApp): + system_app.app.include_router(router, prefix="/api", tags=["Agent"]) + + + def refresh_plugins(self): + self.plugins = scan_plugins(PLUGINS_DIR) + + def load_select_plugin(self, generator:PluginPromptGenerator, select_plugins:List[str])->PluginPromptGenerator: + logger.info(f"load_select_plugin:{select_plugins}") + # load select plugin + for plugin in self.plugins: + if plugin._name in select_plugins: + if not plugin.can_handle_post_prompt(): + continue + generator = plugin.post_prompt(generator) + return generator + +module_agent = ModuleAgent() + + +@router.post("/v1/agent/hub/update", response_model=Result[str]) async def agent_hub_update(update_param: PluginHubParam = Body()): logger.info(f"agent_hub_update:{update_param.__dict__}") try: @@ -38,14 +68,15 @@ async def agent_hub_update(update_param: PluginHubParam = Body()): -@router.post("/api/v1/agent/query", response_model=Result[str]) +@router.post("/v1/agent/query", response_model=Result[str]) async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()): - logger.info(f"get_agent_list:{json.dumps(filter)}") + logger.info(f"get_agent_list:{filter.__dict__}") agent_hub = AgentHub(PLUGINS_DIR) filter_enetity:PluginHubEntity = PluginHubEntity() - attrs = vars(filter.filter) # 获取原始对象的属性字典 - for attr, value in attrs.items(): - setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值 + if filter.filter: + attrs = vars(filter.filter) # 获取原始对象的属性字典 + for attr, value in attrs.items(): + setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值 datas, total_pages, total_count = agent_hub.hub_dao.list(filter_enetity, filter.page_index, filter.page_size) result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]() @@ -54,21 +85,30 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()): result.total_page = total_pages result.total_row_count = total_count result.datas = datas - return Result.succ(result) + # print(json.dumps(result.to_dic())) + return Result.succ(result.to_dic()) -@router.post("/api/v1/agent/my", response_model=Result[str]) +@router.post("/v1/agent/my", response_model=Result[str]) async def my_agents(user:str= None): - logger.info(f"my_agents:{json.dumps(my_agents)}") + logger.info(f"my_agents:{user}") agent_hub = AgentHub(PLUGINS_DIR) - return Result.succ(agent_hub.get_my_plugin(user)) + agents = agent_hub.get_my_plugin(user) + agent_dicts = [] + for agent in agents: + agent_dicts.append(agent.__dict__) + + return Result.succ(agent_dicts) -@router.post("/api/v1/agent/install", response_model=Result[str]) +@router.post("/v1/agent/install", response_model=Result[str]) async def agent_install(plugin_name:str, user: str = None): logger.info(f"agent_install:{plugin_name},{user}") try: agent_hub = AgentHub(PLUGINS_DIR) agent_hub.install_plugin(plugin_name, user) + + module_agent.refresh_plugins() + return Result.succ(None) except Exception as e: logger.error("Plugin Install Error!", e) @@ -76,19 +116,21 @@ async def agent_install(plugin_name:str, user: str = None): -@router.post("/api/v1/agent/uninstall", response_model=Result[str]) +@router.post("/v1/agent/uninstall", response_model=Result[str]) async def agent_uninstall(plugin_name:str, user: str = None): logger.info(f"agent_uninstall:{plugin_name},{user}") try: agent_hub = AgentHub(PLUGINS_DIR) agent_hub.uninstall_plugin(plugin_name, user) + + module_agent.refresh_plugins() return Result.succ(None) except Exception as e: logger.error("Plugin Uninstall Error!", e) return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}") -@router.post("/api/v1/personal/agent/upload", response_model=Result[str]) +@router.post("/v1/personal/agent/upload", response_model=Result[str]) async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None): logger.info(f"personal_agent_upload:{doc_file.filename},{user}") try: diff --git a/pilot/base_modules/agent/db/my_plugin_db.py b/pilot/base_modules/agent/db/my_plugin_db.py index 553c9ce68..3958b757d 100644 --- a/pilot/base_modules/agent/db/my_plugin_db.py +++ b/pilot/base_modules/agent/db/my_plugin_db.py @@ -14,13 +14,13 @@ class MyPluginEntity(Base): __tablename__ = 'my_plugin' id = Column(Integer, primary_key=True, comment="autoincrement id") - tenant = Column(String, nullable=True, comment="user's tenant") - user_code = Column(String, nullable=True, comment="user code") - user_name = Column(String, nullable=True, comment="user name") - name = Column(String, unique=True, nullable=False, comment="plugin name") - file_name = Column(String, nullable=False, comment="plugin package file name") - type = Column(String, comment="plugin type") - version = Column(String, comment="plugin version") + tenant = Column(String(255), nullable=True, comment="user's tenant") + user_code = Column(String(255), nullable=True, comment="user code") + user_name = Column(String(255), nullable=True, comment="user name") + name = Column(String(255), unique=True, nullable=False, comment="plugin name") + file_name = Column(String(255), nullable=False, comment="plugin package file name") + type = Column(String(255), comment="plugin type") + version = Column(String(255), comment="plugin version") use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count") succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count") created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time") diff --git a/pilot/base_modules/agent/db/plugin_hub_db.py b/pilot/base_modules/agent/db/plugin_hub_db.py index 3c0360342..f5620fc79 100644 --- a/pilot/base_modules/agent/db/plugin_hub_db.py +++ b/pilot/base_modules/agent/db/plugin_hub_db.py @@ -12,15 +12,15 @@ from pilot.base_modules.meta_data.meta_data import Base, engine, session class PluginHubEntity(Base): __tablename__ = 'plugin_hub' id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") - name = Column(String, unique=True, nullable=False, comment="plugin name") - description = Column(String, nullable=False, comment="plugin description") - author = Column(String, nullable=True, comment="plugin author") - email = Column(String, nullable=True, comment="plugin author email") - type = Column(String, comment="plugin type") - version = Column(String, comment="plugin version") - storage_channel = Column(String, comment="plugin storage channel") - storage_url = Column(String, comment="plugin download url") - download_param = Column(String, comment="plugin download param") + name = Column(String(255), unique=True, nullable=False, comment="plugin name") + description = Column(String(255), nullable=False, comment="plugin description") + author = Column(String(255), nullable=True, comment="plugin author") + email = Column(String(255), nullable=True, comment="plugin author email") + type = Column(String(255), comment="plugin type") + version = Column(String(255), comment="plugin version") + storage_channel = Column(String(255), comment="plugin storage channel") + storage_url = Column(String(255), comment="plugin download url") + download_param = Column(String(255), comment="plugin download param") created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time") installed = Column(Integer, default=False, comment="plugin already installed count") @@ -146,11 +146,10 @@ class PluginHubDao(BaseDao[PluginHubEntity]): session = self.get_session() if plugin_id is None: raise Exception("plugin_id is None") - query = PluginHubEntity(id=plugin_id) plugin_hubs = session.query(PluginHubEntity) - if query.id is not None: + if plugin_id is not None: plugin_hubs = plugin_hubs.filter( - PluginHubEntity.id == query.id + PluginHubEntity.id == plugin_id ) plugin_hubs.delete() session.commit() diff --git a/pilot/base_modules/agent/model.py b/pilot/base_modules/agent/model.py index c5b3872ba..8c02e4af4 100644 --- a/pilot/base_modules/agent/model.py +++ b/pilot/base_modules/agent/model.py @@ -17,6 +17,18 @@ class PagenationResult(BaseModel, Generic[T]): total_row_count: int = 0 datas: List[T] = [] + def to_dic(self): + data_dicts =[] + for item in self.datas: + data_dicts.append(item.__dict__) + return { + 'page_index': self.page_index, + 'page_size': self.page_size, + 'total_page': self.total_page, + 'total_row_count': self.total_row_count, + 'datas': data_dicts + } + @dataclass class PluginHubFilter(BaseModel): name: str @@ -41,10 +53,9 @@ class MyPluginFilter(BaseModel): class PluginHubParam(BaseModel): - channel: str = Field(..., description="Plugin storage channel") - url: str = Field(..., description="Plugin storage url") - - branch: Optional[str] = Field(None, description="github download branch", nullable=True) + channel: Optional[str] = Field("git", description="Plugin storage channel") + url: Optional[str] = Field("https://github.com/eosphoros-ai/DB-GPT-Plugins.git", description="Plugin storage url") + branch: Optional[str] = Field("main", description="github download branch", nullable=True) authorization: Optional[str] = Field(None, description="github download authorization", nullable=True) diff --git a/pilot/base_modules/meta_data/meta_data.py b/pilot/base_modules/meta_data/meta_data.py index 74edf13a5..5db0ac464 100644 --- a/pilot/base_modules/meta_data/meta_data.py +++ b/pilot/base_modules/meta_data/meta_data.py @@ -2,33 +2,72 @@ import uuid import os import duckdb import sqlite3 +import logging +import fnmatch from datetime import datetime from typing import Optional, Type, TypeVar -import sqlalchemy as sa - -from flask import Flask -from flask_sqlalchemy import SQLAlchemy -from flask_migrate import Migrate,upgrade -from flask.cli import with_appcontext -import subprocess - from sqlalchemy import create_engine,DateTime, String, func, MetaData +from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base from alembic import context, command -from alembic.config import Config +from alembic.config import Config as AlembicConfig +from urllib.parse import quote +from pilot.configs.config import Config + +logger = logging.getLogger("meta_data") + +CFG = Config() default_db_path = os.path.join(os.getcwd(), "meta_data") os.makedirs(default_db_path, exist_ok=True) -db_path = default_db_path + "/dbgpt.db" +# Meta Info +db_name = "dbgpt" +db_path = default_db_path + f"/{db_name}.db" connection = sqlite3.connect(db_path) -engine = create_engine(f'sqlite:///{db_path}') + +if CFG.LOCAL_DB_TYPE == 'mysql': + engine_temp = create_engine(f"mysql+pymysql://" + + quote(CFG.LOCAL_DB_USER) + + ":" + + quote(CFG.LOCAL_DB_PASSWORD) + + "@" + + CFG.LOCAL_DB_HOST + + ":" + + str(CFG.LOCAL_DB_PORT) + ) + # check and auto create mysqldatabase + try: + # try to connect + with engine_temp.connect() as conn: + conn.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}") + print(f"Already connect '{db_name}'") + + except OperationalError as e: + # if connect failed, create dbgpt database + logger.error(f"{db_name} not connect success!") + + engine = create_engine(f"mysql+pymysql://" + + quote(CFG.LOCAL_DB_USER) + + ":" + + quote(CFG.LOCAL_DB_PASSWORD) + + "@" + + CFG.LOCAL_DB_HOST + + ":" + + str(CFG.LOCAL_DB_PORT) + + f"/{db_name}" + ) +else: + engine = create_engine(f'sqlite:///{db_path}') + + + Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) session = Session() @@ -40,7 +79,7 @@ Base = declarative_base(bind=engine) # 创建Alembic配置对象 alembic_ini_path = default_db_path + "/alembic.ini" -alembic_cfg = Config(alembic_ini_path) +alembic_cfg = AlembicConfig(alembic_ini_path) alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url)) @@ -59,36 +98,6 @@ alembic_cfg.attributes['session'] = session # Base.metadata.drop_all(engine) -# app = Flask(__name__) -# default_db_path = os.path.join(os.getcwd(), "meta_data") -# duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/dbgpt.db") -# app.config['SQLALCHEMY_DATABASE_URI'] = f'duckdb://{duckdb_path}' -# app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False -# db = SQLAlchemy(app) -# migrate = Migrate(app, db) -# -# # 设置FLASK_APP环境变量 -# import os -# os.environ['FLASK_APP'] = 'server.dbgpt_server.py' -# -# @app.cli.command("db_init") -# @with_appcontext -# def db_init(): -# subprocess.run(["flask", "db", "init"]) -# -# @app.cli.command("db_migrate") -# @with_appcontext -# def db_migrate(): -# subprocess.run(["flask", "db", "migrate"]) -# -# @app.cli.command("db_upgrade") -# @with_appcontext -# def db_upgrade(): -# subprocess.run(["flask", "db", "upgrade"]) -# - - - def ddl_init_and_upgrade(): # Base.metadata.create_all(bind=engine) # 生成并应用迁移脚本 @@ -96,14 +105,8 @@ def ddl_init_and_upgrade(): # subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"]) with engine.connect() as connection: alembic_cfg.attributes['connection'] = connection - command.revision(alembic_cfg, "test", True) + heads = command.heads(alembic_cfg) + print("heads:" + str(heads)) + + command.revision(alembic_cfg, "dbgpt ddl upate", True) command.upgrade(alembic_cfg, "head") - # alembic_cfg.attributes['connection'] = engine.connect() - # command.upgrade(alembic_cfg, 'head') - - # with app.app_context(): - # db_init() - # db_migrate() - # db_upgrade() - - diff --git a/pilot/base_modules/module_factory.py b/pilot/base_modules/module_factory.py index e69de29bb..139597f9c 100644 --- a/pilot/base_modules/module_factory.py +++ b/pilot/base_modules/module_factory.py @@ -0,0 +1,2 @@ + + diff --git a/pilot/componet.py b/pilot/componet.py index 0897b3365..ceab72f62 100644 --- a/pilot/componet.py +++ b/pilot/componet.py @@ -44,6 +44,7 @@ class LifeCycle: class ComponetType(str, Enum): WORKER_MANAGER = "dbgpt_worker_manager" MODEL_CONTROLLER = "dbgpt_model_controller" + AGENT_HUB = "dbgpt_agent_hub" class BaseComponet(LifeCycle, ABC): diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 0276c2a17..ce3add851 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -98,26 +98,6 @@ class Config(metaclass=Singleton): ### message stor file self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message") - ### The associated configuration parameters of the plug-in control the loading and use of the plug-in - from auto_gpt_plugin_template import AutoGPTPluginTemplate - - self.plugins: List[AutoGPTPluginTemplate] = [] - self.plugins_openai = [] - self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True" - - self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard") - - plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS") - if plugins_allowlist: - self.plugins_allowlist = plugins_allowlist.split(",") - else: - self.plugins_allowlist = [] - - plugins_denylist = os.getenv("DENYLISTED_PLUGINS") - if plugins_denylist: - self.plugins_denylist = plugins_denylist.split(",") - else: - self.plugins_denylist = [] ### Native SQL Execution Capability Control Configuration self.NATIVE_SQL_CAN_RUN_DDL = ( os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True" @@ -126,7 +106,10 @@ class Config(metaclass=Singleton): os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True" ) - ### default Local database connection configuration + + self.LOCAL_DB_MANAGE = None + + ###dbgpt meta info database connection configuration self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST") self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "") self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql") @@ -138,7 +121,8 @@ class Config(metaclass=Singleton): self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root") self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456") - self.LOCAL_DB_MANAGE = None + self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "duckdb") + ### LLM Model Service Configuration self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b") @@ -197,10 +181,6 @@ class Config(metaclass=Singleton): """Set the debug mode value""" self.debug_mode = value - def set_plugins(self, value: list) -> None: - """Set the plugins value.""" - self.plugins = value - def set_templature(self, value: int) -> None: """Set the temperature value.""" self.temperature = value diff --git a/pilot/connections/__init__.py b/pilot/connections/__init__.py index e69de29bb..ce13a69f3 100644 --- a/pilot/connections/__init__.py +++ b/pilot/connections/__init__.py @@ -0,0 +1 @@ +from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao \ No newline at end of file diff --git a/pilot/connections/manages/connect_config_db.py b/pilot/connections/manages/connect_config_db.py new file mode 100644 index 000000000..42307b243 --- /dev/null +++ b/pilot/connections/manages/connect_config_db.py @@ -0,0 +1,62 @@ +from pilot.base_modules.meta_data.base_dao import BaseDao +from pilot.base_modules.meta_data.meta_data import Base, engine, session +from typing import List +from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text +from sqlalchemy import UniqueConstraint + +class ConnectConfigEntity(Base): + __tablename__ = 'connect_config' + id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") + db_type = Column(String(255), nullable=False, comment="db type") + db_name = Column(String(255), nullable=False, comment="db name") + db_path = Column(String(255), nullable=True, comment="file db path") + db_host = Column(String(255), nullable=True, comment="db connect host(not file db)") + db_port = Column(String(255), nullable=True, comment="db cnnect port(not file db)") + db_user = Column(String(255), nullable=True, comment="db user") + db_pwd = Column(String(255), nullable=True, comment="db password") + comment = Column(Text, nullable=True, comment="db comment") + + __table_args__ = ( + UniqueConstraint('db_name', name="uk_db"), + Index('idx_q_db_type', 'db_type'), + ) + + +class ConnectConfigDao(BaseDao[ConnectConfigEntity]): + def __init__(self): + super().__init__( + database="dbgpt", orm_base=Base, db_engine=engine, session=session + ) + + def update(self, entity: ConnectConfigEntity): + session = self.get_session() + try: + updated = session.merge(entity) + session.commit() + return updated.id + finally: + session.close() + + def delete(self, db_name: int): + session = self.get_session() + if db_name is None: + raise Exception("db_name is None") + + db_connect = session.query(ConnectConfigEntity) + db_connect = db_connect.filter( + ConnectConfigEntity.db_name == db_name + ) + db_connect.delete() + session.commit() + session.close() + + def get_by_name(self, db_name: str) -> ConnectConfigEntity: + session = self.get_session() + db_connect = session.query(ConnectConfigEntity) + db_connect = db_connect.filter( + ConnectConfigEntity.db_name == db_name + ) + result = db_connect.first() + session.close() + return result + diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index b5cfbbfdd..3d3a69114 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -50,7 +50,7 @@ class ConnectManager: def __init__(self, system_app: SystemApp): self.storage = DuckdbConnectConfig() self.db_summary_client = DBSummaryClient(system_app) - self.__load_config_db() + # self.__load_config_db() def __load_config_db(self): if CFG.LOCAL_DB_HOST: diff --git a/pilot/memory/__init__.py b/pilot/memory/__init__.py index e69de29bb..2e8c7af1b 100644 --- a/pilot/memory/__init__.py +++ b/pilot/memory/__init__.py @@ -0,0 +1 @@ +from .chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao \ No newline at end of file diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py index d9d1cba29..4d9291e0b 100644 --- a/pilot/memory/chat_history/base.py +++ b/pilot/memory/chat_history/base.py @@ -2,11 +2,18 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import List - +from enum import Enum from pilot.scene.message import OnceConversation +class MemoryStoreType(Enum): + File= 'file' + Memory = 'memory' + DB = 'db' + DuckDb = 'duckdb' + class BaseChatHistoryMemory(ABC): + store_type: MemoryStoreType def __init__(self): self.conversations: List[OnceConversation] = [] @@ -22,10 +29,31 @@ class BaseChatHistoryMemory(ABC): def append(self, message: OnceConversation) -> None: """Append the message to the record in the local file""" - @abstractmethod - def clear(self) -> None: - """Clear session memory from the local file""" + # @abstractmethod + # def clear(self) -> None: + # """Clear session memory from the local file""" + @abstractmethod def conv_list(self, user_name: str = None) -> None: """get user's conversation list""" pass + + @abstractmethod + def update(self, messages: List[OnceConversation]) -> None: + pass + + @abstractmethod + def delete(self) -> bool: + pass + + @abstractmethod + def conv_info(self, conv_uid: str = None) -> None: + pass + + @abstractmethod + def get_messages(self) -> List[OnceConversation]: + pass + + @staticmethod + def conv_list(cls, user_name: str = None) -> None: + pass \ No newline at end of file diff --git a/pilot/memory/chat_history/chat_hisotry_factory.py b/pilot/memory/chat_history/chat_hisotry_factory.py new file mode 100644 index 000000000..6c36053dd --- /dev/null +++ b/pilot/memory/chat_history/chat_hisotry_factory.py @@ -0,0 +1,28 @@ +from .base import MemoryStoreType +from pilot.configs.config import Config + +CFG = Config() + + + +class ChatHistory: + + def __init__(self): + self.memory_type = MemoryStoreType.DB.value + self.mem_store_class_map = {} + from .store_type.duckdb_history import DuckdbHistoryMemory + from .store_type.file_history import FileHistoryMemory + from .store_type.meta_db_history import DbHistoryMemory + from .store_type.mem_history import MemHistoryMemory + self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory + self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory + self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory + self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory + + + def get_store_instance(self, chat_session_id): + return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(chat_session_id) + + + def get_store_cls(self): + return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE) diff --git a/pilot/memory/chat_history/chat_history_db.py b/pilot/memory/chat_history/chat_history_db.py new file mode 100644 index 000000000..f49e1983c --- /dev/null +++ b/pilot/memory/chat_history/chat_history_db.py @@ -0,0 +1,88 @@ +from pilot.base_modules.meta_data.base_dao import BaseDao +from pilot.base_modules.meta_data.meta_data import Base, engine, session +from typing import List +from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text +from sqlalchemy import UniqueConstraint + +class ChatHistoryEntity(Base): + __tablename__ = 'chat_history' + id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") + conv_uid = Column(String(255), unique=False, nullable=False, comment="Conversation record unique id") + chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode") + summary = Column(String(255), nullable=False, comment="Conversation record summary") + user_name = Column(String(255), nullable=True, comment="interlocutor") + messages = Column(Text, nullable=True, comment="Conversation details") + + __table_args__ = ( + UniqueConstraint('conv_uid', name="uk_conversation"), + Index('idx_q_user', 'user_name'), + Index('idx_q_mode', 'chat_mode'), + Index('idx_q_conv', 'summary'), + ) + + +class ChatHistoryDao(BaseDao[ChatHistoryEntity]): + def __init__(self): + super().__init__( + database="dbgpt", orm_base=Base, db_engine=engine, session=session + ) + + def list_last_20(self, user_name: str = None): + session = self.get_session() + chat_history = session.query(ChatHistoryEntity) + if user_name: + chat_history = chat_history.filter( + ChatHistoryEntity.user_name == user_name + ) + + chat_history = chat_history.order_by(ChatHistoryEntity.id.desc()) + + result = chat_history.limit(20).all() + session.close() + return result + + def update(self, entity: ChatHistoryEntity): + session = self.get_session() + try: + updated = session.merge(entity) + session.commit() + return updated.id + finally: + session.close() + + def update_message_by_uid(self, message: str, conv_uid:str): + session = self.get_session() + try: + chat_history = session.query(ChatHistoryEntity) + chat_history = chat_history.filter( + ChatHistoryEntity.conv_uid == conv_uid + ) + updated = chat_history.update({ChatHistoryEntity.messages: message}) + session.commit() + return updated.id + finally: + session.close() + + def delete(self, conv_uid: int): + session = self.get_session() + if conv_uid is None: + raise Exception("conv_uid is None") + + chat_history = session.query(ChatHistoryEntity) + chat_history = chat_history.filter( + ChatHistoryEntity.conv_uid == conv_uid + ) + chat_history.delete() + session.commit() + session.close() + + def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity: + session = self.get_session() + chat_history = session.query(ChatHistoryEntity) + chat_history = chat_history.filter( + ChatHistoryEntity.conv_uid == conv_uid + ) + result = chat_history.first() + session.close() + return result + diff --git a/pilot/base_modules/agent/plugins_loader.py b/pilot/memory/chat_history/store_type/__init__.py similarity index 100% rename from pilot/base_modules/agent/plugins_loader.py rename to pilot/memory/chat_history/store_type/__init__.py diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/store_type/duckdb_history.py similarity index 98% rename from pilot/memory/chat_history/duckdb_history.py rename to pilot/memory/chat_history/store_type/duckdb_history.py index cbc21f03a..97aae159b 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/store_type/duckdb_history.py @@ -10,6 +10,7 @@ from pilot.scene.message import ( _conversation_to_dic, ) from pilot.common.formatting import MyEncoder +from ..base import MemoryStoreType default_db_path = os.path.join(os.getcwd(), "message") duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") @@ -19,6 +20,8 @@ CFG = Config() class DuckdbHistoryMemory(BaseChatHistoryMemory): + store_type: str = MemoryStoreType.DuckDb.value + def __init__(self, chat_session_id: str): self.chat_seesion_id = chat_session_id os.makedirs(default_db_path, exist_ok=True) @@ -117,30 +120,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): cursor.commit() return True - @staticmethod - def conv_list(cls, user_name: str = None) -> None: - if os.path.isfile(duckdb_path): - cursor = duckdb.connect(duckdb_path).cursor() - if user_name: - cursor.execute( - "SELECT * FROM chat_history where user_name=? order by id desc limit 20", - [user_name], - ) - else: - cursor.execute("SELECT * FROM chat_history order by id desc limit 20") - # 获取查询结果字段名 - fields = [field[0] for field in cursor.description] - data = [] - for row in cursor.fetchall(): - row_dict = {} - for i, field in enumerate(fields): - row_dict[field] = row[i] - data.append(row_dict) - - return data - - return [] - def conv_info(self, conv_uid: str = None) -> None: cursor = self.connect.cursor() cursor.execute( @@ -168,3 +147,28 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): if context[0]: return json.loads(context[0]) return None + + + @staticmethod + def conv_list(cls, user_name: str = None) -> None: + if os.path.isfile(duckdb_path): + cursor = duckdb.connect(duckdb_path).cursor() + if user_name: + cursor.execute( + "SELECT * FROM chat_history where user_name=? order by id desc limit 20", + [user_name], + ) + else: + cursor.execute("SELECT * FROM chat_history order by id desc limit 20") + # 获取查询结果字段名 + fields = [field[0] for field in cursor.description] + data = [] + for row in cursor.fetchall(): + row_dict = {} + for i, field in enumerate(fields): + row_dict[field] = row[i] + data.append(row_dict) + + return data + + return [] diff --git a/pilot/memory/chat_history/file_history.py b/pilot/memory/chat_history/store_type/file_history.py similarity index 93% rename from pilot/memory/chat_history/file_history.py rename to pilot/memory/chat_history/store_type/file_history.py index ffdd4169b..fa1143309 100644 --- a/pilot/memory/chat_history/file_history.py +++ b/pilot/memory/chat_history/store_type/file_history.py @@ -11,12 +11,14 @@ from pilot.scene.message import ( conversation_from_dict, conversations_to_dict, ) - +from pilot.memory.chat_history.base import MemoryStoreType CFG = Config() class FileHistoryMemory(BaseChatHistoryMemory): + store_type: str = MemoryStoreType.File.value + def __init__(self, chat_session_id: str): now = datetime.datetime.now() date_string = now.strftime("%Y%m%d") @@ -47,3 +49,5 @@ class FileHistoryMemory(BaseChatHistoryMemory): def clear(self) -> None: self.file_path.write_text(json.dumps([])) + + diff --git a/pilot/memory/chat_history/mem_history.py b/pilot/memory/chat_history/store_type/mem_history.py similarity index 88% rename from pilot/memory/chat_history/mem_history.py rename to pilot/memory/chat_history/store_type/mem_history.py index 2e832041f..5c3ddc217 100644 --- a/pilot/memory/chat_history/mem_history.py +++ b/pilot/memory/chat_history/store_type/mem_history.py @@ -4,11 +4,14 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.configs.config import Config from pilot.scene.message import OnceConversation from pilot.common.custom_data_structure import FixedSizeDict +from pilot.memory.chat_history.base import MemoryStoreType CFG = Config() class MemHistoryMemory(BaseChatHistoryMemory): + store_type: str = MemoryStoreType.Memory.value + histroies_map = FixedSizeDict(100) def __init__(self, chat_session_id: str): diff --git a/pilot/memory/chat_history/store_type/meta_db_history.py b/pilot/memory/chat_history/store_type/meta_db_history.py new file mode 100644 index 000000000..c1fc0ec5d --- /dev/null +++ b/pilot/memory/chat_history/store_type/meta_db_history.py @@ -0,0 +1,100 @@ +import json +import logging +from typing import List +from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text +from sqlalchemy import UniqueConstraint +from pilot.configs.config import Config +from pilot.memory.chat_history.base import BaseChatHistoryMemory +from pilot.scene.message import ( + OnceConversation, + _conversation_to_dic, +) +from ..chat_history_db import ChatHistoryEntity, ChatHistoryDao + +from pilot.memory.chat_history.base import MemoryStoreType +CFG = Config() +logger = logging.getLogger("db_chat_history") + +class DbHistoryMemory(BaseChatHistoryMemory): + store_type: str = MemoryStoreType.DB.value + def __init__(self, chat_session_id: str): + self.chat_seesion_id = chat_session_id + self.chat_history_dao = ChatHistoryDao() + + def messages(self) -> List[OnceConversation]: + + chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id) + if chat_history: + context = chat_history.messages + if context: + conversations: List[OnceConversation] = json.loads(context) + return conversations + return [] + + + def create(self, chat_mode, summary: str, user_name: str) -> None: + try: + chat_history: ChatHistoryEntity = ChatHistoryEntity() + chat_history.chat_mode = chat_mode + chat_history.summary = summary + chat_history.user_name = user_name + + self.chat_history_dao.update(chat_history) + except Exception as e: + logger.error("init create conversation log error!" + str(e)) + + + def append(self, once_message: OnceConversation) -> None: + logger.info("db history append:{}", once_message) + chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id) + conversations: List[OnceConversation] = [] + if chat_history: + context = chat_history.messages + if context: + conversations = json.loads(context) + else: + chat_history.summary = once_message.get_user_conv().content + else: + chat_history: ChatHistoryEntity = ChatHistoryEntity() + chat_history.conv_uid = self.chat_seesion_id + chat_history.chat_mode = once_message.chat_mode + chat_history.user_name = "default" + chat_history.summary = once_message.get_user_conv().content + + conversations.append(_conversation_to_dic(once_message)) + chat_history.messages = json.dumps(conversations, ensure_ascii=False) + + self.chat_history_dao.update(chat_history) + + def update(self, messages: List[OnceConversation]) -> None: + self.chat_history_dao.update_message_by_uid(json.dumps(messages, ensure_ascii=False), self.chat_seesion_id) + + + def delete(self) -> bool: + self.chat_history_dao.delete(self.chat_seesion_id) + + + def conv_info(self, conv_uid: str = None) -> None: + logger.info("conv_info:{}", conv_uid) + chat_history = self.chat_history_dao.get_by_uid(conv_uid) + return chat_history.__dict__ + + + def get_messages(self) -> List[OnceConversation]: + logger.info("get_messages:{}", self.chat_seesion_id) + chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id) + if chat_history: + context = chat_history.messages + return json.loads(context) + return [] + + + @staticmethod + def conv_list(cls, user_name: str = None) -> None: + + chat_history_dao = ChatHistoryDao() + history_list = chat_history_dao.list_last_20() + result = [] + for history in history_list: + result.append(history.__dict__) + return result \ No newline at end of file diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index 54360e477..d9aeade44 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -142,12 +142,12 @@ def initialize_controller( controller.backend = LocalModelController() if app: - app.include_router(router, prefix="/api") + app.include_router(router, prefix="/api", tags=['Model']) else: import uvicorn app = FastAPI() - app.include_router(router, prefix="/api") + app.include_router(router, prefix="/api", tags=['Model']) uvicorn.run(app, host=host, port=port, log_level="info") diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index c51b9d254..dd4091ac0 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -8,13 +8,10 @@ from fastapi import ( Request, File, UploadFile, - Form, Body, - BackgroundTasks, ) from fastapi.responses import StreamingResponse -from fastapi.exceptions import RequestValidationError from typing import List import tempfile @@ -35,11 +32,10 @@ from pilot.scene.base import ChatScene from pilot.scene.chat_factory import ChatFactory from pilot.configs.model_config import LOGDIR from pilot.utils import build_logger -from pilot.common.schema import DBType -from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.scene.message import OnceConversation -from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH +from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.summary.db_summary_client import DBSummaryClient +from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory router = APIRouter() CFG = Config() @@ -61,7 +57,6 @@ def __get_conv_user_message(conversations: dict): def __new_conversation(chat_mode, user_id) -> ConversationVo: unique_id = uuid.uuid1() - # history_mem = DuckdbHistoryMemory(str(unique_id)) return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode) @@ -145,7 +140,8 @@ async def db_support_types(): @router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo]) async def dialogue_list(user_id: str = None): dialogues: List = [] - datas = DuckdbHistoryMemory.conv_list(user_id) + chat_history_service = ChatHistory() + datas = chat_history_service.get_store_cls().conv_list(user_id) for item in datas: conv_uid = item.get("conv_uid") summary = item.get("summary") @@ -257,14 +253,18 @@ async def params_load( @router.post("/v1/chat/dialogue/delete") async def dialogue_delete(con_uid: str): - history_mem = DuckdbHistoryMemory(con_uid) + + history_fac = ChatHistory() + history_mem = history_fac.get_store_instance(con_uid) history_mem.delete() return Result.succ(None) def get_hist_messages(conv_uid: str): message_vos: List[MessageVo] = [] - history_mem = DuckdbHistoryMemory(conv_uid) + history_fac = ChatHistory() + history_mem = history_fac.get_store_instance(conv_uid) + history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: for once in history_messages: diff --git a/pilot/openapi/api_v1/editor/api_editor_v1.py b/pilot/openapi/api_v1/editor/api_editor_v1.py index e1b313664..0e43393be 100644 --- a/pilot/openapi/api_v1/editor/api_editor_v1.py +++ b/pilot/openapi/api_v1/editor/api_editor_v1.py @@ -26,10 +26,10 @@ from pilot.openapi.editor_view_model import ( ) from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData -from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.scene.message import OnceConversation from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader from pilot.scene.chat_db.data_loader import DbDataLoader +from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory router = APIRouter() CFG = Config() @@ -67,7 +67,8 @@ async def get_editor_tables( @router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds]) async def get_editor_sql_rounds(con_uid: str): logger.info("get_editor_sql_rounds:{con_uid}") - history_mem = DuckdbHistoryMemory(con_uid) + chat_history_fac = ChatHistory() + history_mem = chat_history_fac.get_store_instance(con_uid) history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: result: List = [] @@ -89,7 +90,8 @@ async def get_editor_sql_rounds(con_uid: str): @router.get("/v1/editor/sql", response_model=Result[dict]) async def get_editor_sql(con_uid: str, round: int): logger.info(f"get_editor_sql:{con_uid},{round}") - history_mem = DuckdbHistoryMemory(con_uid) + chat_history_fac = ChatHistory() + history_mem = chat_history_fac.get_store_instance(con_uid) history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: for once in history_messages: @@ -138,7 +140,9 @@ async def editor_sql_run(run_param: dict = Body()): @router.post("/v1/sql/editor/submit") async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()): logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}") - history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid) + + chat_history_fac = ChatHistory() + history_mem = chat_history_fac.get_store_instance(sql_edit_context.con_uid) history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name) @@ -171,7 +175,8 @@ async def get_editor_chart_list(con_uid: str): logger.info( f"get_editor_sql_rounds:{con_uid}", ) - history_mem = DuckdbHistoryMemory(con_uid) + chat_history_fac = ChatHistory() + history_mem = chat_history_fac.get_store_instance(con_uid) history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: last_round = max(history_messages, key=lambda x: x["chat_order"]) @@ -193,7 +198,8 @@ async def get_editor_chart_info(param: dict = Body()): conv_uid = param["con_uid"] chart_title = param["chart_title"] - history_mem = DuckdbHistoryMemory(conv_uid) + chat_history_fac = ChatHistory() + history_mem = chat_history_fac.get_store_instance(conv_uid) history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: last_round = max(history_messages, key=lambda x: x["chat_order"]) @@ -269,7 +275,9 @@ async def editor_chart_run(run_param: dict = Body()): @router.post("/v1/chart/editor/submit", response_model=Result[bool]) async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()): logger.info(f"sql_editor_submit:{chart_edit_context.__dict__}") - history_mem = DuckdbHistoryMemory(chart_edit_context.conv_uid) + + chat_history_fac = ChatHistory() + history_mem = chat_history_fac.get_store_instance(chart_edit_context.con_uid) history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 57033cede..8594b7252 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -7,15 +7,12 @@ from typing import Any, List, Dict from pilot.configs.config import Config from pilot.configs.model_config import LOGDIR -from pilot.memory.chat_history.base import BaseChatHistoryMemory -from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory -from pilot.memory.chat_history.file_history import FileHistoryMemory -from pilot.memory.chat_history.mem_history import MemHistoryMemory from pilot.prompts.prompt_new import PromptTemplate from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from pilot.scene.message import OnceConversation from pilot.utils import build_logger, get_or_create_event_loop from pydantic import Extra +from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory logger = build_logger("BaseChat", LOGDIR + "BaseChat.log") headers = {"User-Agent": "dbgpt Client"} @@ -54,9 +51,9 @@ class BaseChat(ABC): proxyllm_backend=CFG.PROXYLLM_BACKEND, ) ) - + chat_history_fac = ChatHistory() ### can configurable storage methods - self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"]) + self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"]) self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation( @@ -162,6 +159,7 @@ class BaseChat(ABC): output, self.skip_echo_len ) view_msg = self.stream_plugin_call(msg) + view_msg = view_msg.replace("\n", "\\n") yield view_msg self.current_message.add_ai_message(msg) self.current_message.add_view_message(view_msg) diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py index 26a58f73f..8175d3d7c 100644 --- a/pilot/scene/chat_agent/chat.py +++ b/pilot/scene/chat_agent/chat.py @@ -9,6 +9,8 @@ from pilot.base_modules.agent.commands.command_mange import ApiCall from pilot.base_modules.agent import PluginPromptGenerator from pilot.common.string_utils import extract_content from .prompt import prompt +from pilot.componet import ComponetType +from pilot.base_modules.agent.controller import ModuleAgent CFG = Config() @@ -27,14 +29,10 @@ class ChatAgent(BaseChat): super().__init__(chat_param=chat_param) self.plugins_prompt_generator = PluginPromptGenerator() self.plugins_prompt_generator.command_registry = CFG.command_registry - # load select plugin - for plugin in CFG.plugins: - if plugin._name in self.select_plugins: - if not plugin.can_handle_post_prompt(): - continue - self.plugins_prompt_generator = plugin.post_prompt( - self.plugins_prompt_generator - ) + + # load select plugin + agent_module = CFG.SYSTEM_APP.get_componet(ComponetType.AGENT_HUB, ModuleAgent) + self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins) self.api_call = ApiCall(self.plugins_prompt_generator) diff --git a/pilot/scene/chat_agent/prompt.py b/pilot/scene/chat_agent/prompt.py index 493fef4cd..1c5229f89 100644 --- a/pilot/scene/chat_agent/prompt.py +++ b/pilot/scene/chat_agent/prompt.py @@ -32,7 +32,7 @@ _DEFAULT_TEMPLATE_ZH = """ 根据用户目标,请一步步思考,如何在满足下面约束条件的前提下,优先使用给出工具回答或者完成用户目标。 约束条件: - 1.从下面给定工具列表找到可用的工具后,请确保输出结果包含以下内容用来使用工具: + 1.从下面给定工具列表找到可用的工具后,请输出以下内容用来使用工具, 注意要确保下面内容在输出结果中只出现一次: Selected Tool namevaluevalue 2.请根据工具列表对应工具的定义来生成上述调用文本, 参考案例如下: 工具作用介绍: "工具名称", args: "参数1": "<参数1取值描述>","参数2": "<参数2取值描述>" 对应调用文本:工具名称<参数1>value<参数2>value diff --git a/pilot/server/base.py b/pilot/server/base.py index 064f2d247..fb1d518c7 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -32,7 +32,6 @@ def async_db_summery(system_app: SystemApp): def server_init(args, system_app: SystemApp): from pilot.base_modules.agent.commands.command_mange import CommandRegistry - from pilot.base_modules.agent.plugins_util import scan_plugins # logger.info(f"args: {args}") @@ -45,7 +44,7 @@ def server_init(args, system_app: SystemApp): # load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler) - cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode)) + # Loader plugins and commands command_categories = [ diff --git a/pilot/server/componet_configs.py b/pilot/server/componet_configs.py index 755f13b21..c398181fa 100644 --- a/pilot/server/componet_configs.py +++ b/pilot/server/componet_configs.py @@ -24,9 +24,11 @@ def initialize_componets( embedding_model_path: str, ): from pilot.model.cluster.controller.controller import controller - system_app.register_instance(controller) + from pilot.base_modules.agent.controller import module_agent + system_app.register_instance(module_agent) + _initialize_embedding_model( param, system_app, embedding_model_name, embedding_model_path ) diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 078a2daa3..b27cba92e 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -71,11 +71,11 @@ app.add_middleware( ) -app.include_router(api_v1, prefix="/api") -app.include_router(api_editor_route_v1, prefix="/api") -app.include_router(agent_route, prefix="/api") +app.include_router(api_v1, prefix="/api", tags=["Chat"]) +app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"]) -app.include_router(knowledge_router) + +app.include_router(knowledge_router, tags=["Knowledge"]) def mount_static_files(app):