From ed702db9bc3d191c0d0cbd18e1630b04d9dfe5da Mon Sep 17 00:00:00 2001
From: yhjun1026 <460342015@qq.com>
Date: Tue, 26 Sep 2023 17:07:42 +0800
Subject: [PATCH] feat(Agent): ChatAgent And AgentHub
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
1.AgentHub Module complete
2.New ChatAgent Mode complete
3.MetaData develop,auto update ddl
---
pilot/base_modules/agent/controller.py | 72 +++++++++---
pilot/base_modules/agent/db/my_plugin_db.py | 14 +--
pilot/base_modules/agent/db/plugin_hub_db.py | 23 ++--
pilot/base_modules/agent/model.py | 19 +++-
pilot/base_modules/meta_data/meta_data.py | 107 +++++++++---------
pilot/base_modules/module_factory.py | 2 +
pilot/componet.py | 1 +
pilot/configs/config.py | 32 +-----
pilot/connections/__init__.py | 1 +
.../connections/manages/connect_config_db.py | 62 ++++++++++
.../connections/manages/connection_manager.py | 2 +-
pilot/memory/__init__.py | 1 +
pilot/memory/chat_history/base.py | 36 +++++-
.../chat_history/chat_hisotry_factory.py | 28 +++++
pilot/memory/chat_history/chat_history_db.py | 88 ++++++++++++++
.../chat_history/store_type/__init__.py} | 0
.../{ => store_type}/duckdb_history.py | 52 +++++----
.../{ => store_type}/file_history.py | 6 +-
.../{ => store_type}/mem_history.py | 3 +
.../store_type/meta_db_history.py | 100 ++++++++++++++++
pilot/model/cluster/controller/controller.py | 4 +-
pilot/openapi/api_v1/api_v1.py | 20 ++--
pilot/openapi/api_v1/editor/api_editor_v1.py | 22 ++--
pilot/scene/base_chat.py | 10 +-
pilot/scene/chat_agent/chat.py | 14 +--
pilot/scene/chat_agent/prompt.py | 2 +-
pilot/server/base.py | 3 +-
pilot/server/componet_configs.py | 4 +-
pilot/server/dbgpt_server.py | 8 +-
29 files changed, 549 insertions(+), 187 deletions(-)
create mode 100644 pilot/connections/manages/connect_config_db.py
create mode 100644 pilot/memory/chat_history/chat_hisotry_factory.py
create mode 100644 pilot/memory/chat_history/chat_history_db.py
rename pilot/{base_modules/agent/plugins_loader.py => memory/chat_history/store_type/__init__.py} (100%)
rename pilot/memory/chat_history/{ => store_type}/duckdb_history.py (98%)
rename pilot/memory/chat_history/{ => store_type}/file_history.py (93%)
rename pilot/memory/chat_history/{ => store_type}/mem_history.py (88%)
create mode 100644 pilot/memory/chat_history/store_type/meta_db_history.py
diff --git a/pilot/base_modules/agent/controller.py b/pilot/base_modules/agent/controller.py
index 151f6e40b..86bef3569 100644
--- a/pilot/base_modules/agent/controller.py
+++ b/pilot/base_modules/agent/controller.py
@@ -6,7 +6,7 @@ from fastapi import (
UploadFile,
File,
)
-
+from abc import ABC, abstractmethod
from typing import List
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
@@ -18,14 +18,44 @@ from pilot.openapi.api_view_model import (
from .model import PluginHubParam, PagenationFilter, PagenationResult, PluginHubFilter, MyPluginFilter
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
-from .db.my_plugin_db import MyPluginEntity
+from .plugins_util import scan_plugins
+from .commands.generator import PluginPromptGenerator
+
from pilot.configs.model_config import PLUGINS_DIR
+from pilot.componet import BaseComponet, ComponetType, SystemApp
router = APIRouter()
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
-@router.post("/api/v1/agent/hub/update", response_model=Result[str])
+class ModuleAgent(BaseComponet, ABC):
+ name = ComponetType.AGENT_HUB
+
+ def __init__(self):
+ #load plugins
+ self.plugins = scan_plugins(PLUGINS_DIR)
+
+ def init_app(self, system_app: SystemApp):
+ system_app.app.include_router(router, prefix="/api", tags=["Agent"])
+
+
+ def refresh_plugins(self):
+ self.plugins = scan_plugins(PLUGINS_DIR)
+
+ def load_select_plugin(self, generator:PluginPromptGenerator, select_plugins:List[str])->PluginPromptGenerator:
+ logger.info(f"load_select_plugin:{select_plugins}")
+ # load select plugin
+ for plugin in self.plugins:
+ if plugin._name in select_plugins:
+ if not plugin.can_handle_post_prompt():
+ continue
+ generator = plugin.post_prompt(generator)
+ return generator
+
+module_agent = ModuleAgent()
+
+
+@router.post("/v1/agent/hub/update", response_model=Result[str])
async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}")
try:
@@ -38,14 +68,15 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):
-@router.post("/api/v1/agent/query", response_model=Result[str])
+@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
- logger.info(f"get_agent_list:{json.dumps(filter)}")
+ logger.info(f"get_agent_list:{filter.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
filter_enetity:PluginHubEntity = PluginHubEntity()
- attrs = vars(filter.filter) # 获取原始对象的属性字典
- for attr, value in attrs.items():
- setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值
+ if filter.filter:
+ attrs = vars(filter.filter) # 获取原始对象的属性字典
+ for attr, value in attrs.items():
+ setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值
datas, total_pages, total_count = agent_hub.hub_dao.list(filter_enetity, filter.page_index, filter.page_size)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
@@ -54,21 +85,30 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
result.total_page = total_pages
result.total_row_count = total_count
result.datas = datas
- return Result.succ(result)
+ # print(json.dumps(result.to_dic()))
+ return Result.succ(result.to_dic())
-@router.post("/api/v1/agent/my", response_model=Result[str])
+@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user:str= None):
- logger.info(f"my_agents:{json.dumps(my_agents)}")
+ logger.info(f"my_agents:{user}")
agent_hub = AgentHub(PLUGINS_DIR)
- return Result.succ(agent_hub.get_my_plugin(user))
+ agents = agent_hub.get_my_plugin(user)
+ agent_dicts = []
+ for agent in agents:
+ agent_dicts.append(agent.__dict__)
+
+ return Result.succ(agent_dicts)
-@router.post("/api/v1/agent/install", response_model=Result[str])
+@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name:str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.install_plugin(plugin_name, user)
+
+ module_agent.refresh_plugins()
+
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
@@ -76,19 +116,21 @@ async def agent_install(plugin_name:str, user: str = None):
-@router.post("/api/v1/agent/uninstall", response_model=Result[str])
+@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name:str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.uninstall_plugin(plugin_name, user)
+
+ module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
-@router.post("/api/v1/personal/agent/upload", response_model=Result[str])
+@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:
diff --git a/pilot/base_modules/agent/db/my_plugin_db.py b/pilot/base_modules/agent/db/my_plugin_db.py
index 553c9ce68..3958b757d 100644
--- a/pilot/base_modules/agent/db/my_plugin_db.py
+++ b/pilot/base_modules/agent/db/my_plugin_db.py
@@ -14,13 +14,13 @@ class MyPluginEntity(Base):
__tablename__ = 'my_plugin'
id = Column(Integer, primary_key=True, comment="autoincrement id")
- tenant = Column(String, nullable=True, comment="user's tenant")
- user_code = Column(String, nullable=True, comment="user code")
- user_name = Column(String, nullable=True, comment="user name")
- name = Column(String, unique=True, nullable=False, comment="plugin name")
- file_name = Column(String, nullable=False, comment="plugin package file name")
- type = Column(String, comment="plugin type")
- version = Column(String, comment="plugin version")
+ tenant = Column(String(255), nullable=True, comment="user's tenant")
+ user_code = Column(String(255), nullable=True, comment="user code")
+ user_name = Column(String(255), nullable=True, comment="user name")
+ name = Column(String(255), unique=True, nullable=False, comment="plugin name")
+ file_name = Column(String(255), nullable=False, comment="plugin package file name")
+ type = Column(String(255), comment="plugin type")
+ version = Column(String(255), comment="plugin version")
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")
diff --git a/pilot/base_modules/agent/db/plugin_hub_db.py b/pilot/base_modules/agent/db/plugin_hub_db.py
index 3c0360342..f5620fc79 100644
--- a/pilot/base_modules/agent/db/plugin_hub_db.py
+++ b/pilot/base_modules/agent/db/plugin_hub_db.py
@@ -12,15 +12,15 @@ from pilot.base_modules.meta_data.meta_data import Base, engine, session
class PluginHubEntity(Base):
__tablename__ = 'plugin_hub'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
- name = Column(String, unique=True, nullable=False, comment="plugin name")
- description = Column(String, nullable=False, comment="plugin description")
- author = Column(String, nullable=True, comment="plugin author")
- email = Column(String, nullable=True, comment="plugin author email")
- type = Column(String, comment="plugin type")
- version = Column(String, comment="plugin version")
- storage_channel = Column(String, comment="plugin storage channel")
- storage_url = Column(String, comment="plugin download url")
- download_param = Column(String, comment="plugin download param")
+ name = Column(String(255), unique=True, nullable=False, comment="plugin name")
+ description = Column(String(255), nullable=False, comment="plugin description")
+ author = Column(String(255), nullable=True, comment="plugin author")
+ email = Column(String(255), nullable=True, comment="plugin author email")
+ type = Column(String(255), comment="plugin type")
+ version = Column(String(255), comment="plugin version")
+ storage_channel = Column(String(255), comment="plugin storage channel")
+ storage_url = Column(String(255), comment="plugin download url")
+ download_param = Column(String(255), comment="plugin download param")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
installed = Column(Integer, default=False, comment="plugin already installed count")
@@ -146,11 +146,10 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
session = self.get_session()
if plugin_id is None:
raise Exception("plugin_id is None")
- query = PluginHubEntity(id=plugin_id)
plugin_hubs = session.query(PluginHubEntity)
- if query.id is not None:
+ if plugin_id is not None:
plugin_hubs = plugin_hubs.filter(
- PluginHubEntity.id == query.id
+ PluginHubEntity.id == plugin_id
)
plugin_hubs.delete()
session.commit()
diff --git a/pilot/base_modules/agent/model.py b/pilot/base_modules/agent/model.py
index c5b3872ba..8c02e4af4 100644
--- a/pilot/base_modules/agent/model.py
+++ b/pilot/base_modules/agent/model.py
@@ -17,6 +17,18 @@ class PagenationResult(BaseModel, Generic[T]):
total_row_count: int = 0
datas: List[T] = []
+ def to_dic(self):
+ data_dicts =[]
+ for item in self.datas:
+ data_dicts.append(item.__dict__)
+ return {
+ 'page_index': self.page_index,
+ 'page_size': self.page_size,
+ 'total_page': self.total_page,
+ 'total_row_count': self.total_row_count,
+ 'datas': data_dicts
+ }
+
@dataclass
class PluginHubFilter(BaseModel):
name: str
@@ -41,10 +53,9 @@ class MyPluginFilter(BaseModel):
class PluginHubParam(BaseModel):
- channel: str = Field(..., description="Plugin storage channel")
- url: str = Field(..., description="Plugin storage url")
-
- branch: Optional[str] = Field(None, description="github download branch", nullable=True)
+ channel: Optional[str] = Field("git", description="Plugin storage channel")
+ url: Optional[str] = Field("https://github.com/eosphoros-ai/DB-GPT-Plugins.git", description="Plugin storage url")
+ branch: Optional[str] = Field("main", description="github download branch", nullable=True)
authorization: Optional[str] = Field(None, description="github download authorization", nullable=True)
diff --git a/pilot/base_modules/meta_data/meta_data.py b/pilot/base_modules/meta_data/meta_data.py
index 74edf13a5..5db0ac464 100644
--- a/pilot/base_modules/meta_data/meta_data.py
+++ b/pilot/base_modules/meta_data/meta_data.py
@@ -2,33 +2,72 @@ import uuid
import os
import duckdb
import sqlite3
+import logging
+import fnmatch
from datetime import datetime
from typing import Optional, Type, TypeVar
-import sqlalchemy as sa
-
-from flask import Flask
-from flask_sqlalchemy import SQLAlchemy
-from flask_migrate import Migrate,upgrade
-from flask.cli import with_appcontext
-import subprocess
-
from sqlalchemy import create_engine,DateTime, String, func, MetaData
+from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from alembic import context, command
-from alembic.config import Config
+from alembic.config import Config as AlembicConfig
+from urllib.parse import quote
+from pilot.configs.config import Config
+
+logger = logging.getLogger("meta_data")
+
+CFG = Config()
default_db_path = os.path.join(os.getcwd(), "meta_data")
os.makedirs(default_db_path, exist_ok=True)
-db_path = default_db_path + "/dbgpt.db"
+# Meta Info
+db_name = "dbgpt"
+db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)
-engine = create_engine(f'sqlite:///{db_path}')
+
+if CFG.LOCAL_DB_TYPE == 'mysql':
+ engine_temp = create_engine(f"mysql+pymysql://"
+ + quote(CFG.LOCAL_DB_USER)
+ + ":"
+ + quote(CFG.LOCAL_DB_PASSWORD)
+ + "@"
+ + CFG.LOCAL_DB_HOST
+ + ":"
+ + str(CFG.LOCAL_DB_PORT)
+ )
+ # check and auto create mysqldatabase
+ try:
+ # try to connect
+ with engine_temp.connect() as conn:
+ conn.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}")
+ print(f"Already connect '{db_name}'")
+
+ except OperationalError as e:
+ # if connect failed, create dbgpt database
+ logger.error(f"{db_name} not connect success!")
+
+ engine = create_engine(f"mysql+pymysql://"
+ + quote(CFG.LOCAL_DB_USER)
+ + ":"
+ + quote(CFG.LOCAL_DB_PASSWORD)
+ + "@"
+ + CFG.LOCAL_DB_HOST
+ + ":"
+ + str(CFG.LOCAL_DB_PORT)
+ + f"/{db_name}"
+ )
+else:
+ engine = create_engine(f'sqlite:///{db_path}')
+
+
+
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
@@ -40,7 +79,7 @@ Base = declarative_base(bind=engine)
# 创建Alembic配置对象
alembic_ini_path = default_db_path + "/alembic.ini"
-alembic_cfg = Config(alembic_ini_path)
+alembic_cfg = AlembicConfig(alembic_ini_path)
alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))
@@ -59,36 +98,6 @@ alembic_cfg.attributes['session'] = session
# Base.metadata.drop_all(engine)
-# app = Flask(__name__)
-# default_db_path = os.path.join(os.getcwd(), "meta_data")
-# duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/dbgpt.db")
-# app.config['SQLALCHEMY_DATABASE_URI'] = f'duckdb://{duckdb_path}'
-# app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
-# db = SQLAlchemy(app)
-# migrate = Migrate(app, db)
-#
-# # 设置FLASK_APP环境变量
-# import os
-# os.environ['FLASK_APP'] = 'server.dbgpt_server.py'
-#
-# @app.cli.command("db_init")
-# @with_appcontext
-# def db_init():
-# subprocess.run(["flask", "db", "init"])
-#
-# @app.cli.command("db_migrate")
-# @with_appcontext
-# def db_migrate():
-# subprocess.run(["flask", "db", "migrate"])
-#
-# @app.cli.command("db_upgrade")
-# @with_appcontext
-# def db_upgrade():
-# subprocess.run(["flask", "db", "upgrade"])
-#
-
-
-
def ddl_init_and_upgrade():
# Base.metadata.create_all(bind=engine)
# 生成并应用迁移脚本
@@ -96,14 +105,8 @@ def ddl_init_and_upgrade():
# subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"])
with engine.connect() as connection:
alembic_cfg.attributes['connection'] = connection
- command.revision(alembic_cfg, "test", True)
+ heads = command.heads(alembic_cfg)
+ print("heads:" + str(heads))
+
+ command.revision(alembic_cfg, "dbgpt ddl upate", True)
command.upgrade(alembic_cfg, "head")
- # alembic_cfg.attributes['connection'] = engine.connect()
- # command.upgrade(alembic_cfg, 'head')
-
- # with app.app_context():
- # db_init()
- # db_migrate()
- # db_upgrade()
-
-
diff --git a/pilot/base_modules/module_factory.py b/pilot/base_modules/module_factory.py
index e69de29bb..139597f9c 100644
--- a/pilot/base_modules/module_factory.py
+++ b/pilot/base_modules/module_factory.py
@@ -0,0 +1,2 @@
+
+
diff --git a/pilot/componet.py b/pilot/componet.py
index 0897b3365..ceab72f62 100644
--- a/pilot/componet.py
+++ b/pilot/componet.py
@@ -44,6 +44,7 @@ class LifeCycle:
class ComponetType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
MODEL_CONTROLLER = "dbgpt_model_controller"
+ AGENT_HUB = "dbgpt_agent_hub"
class BaseComponet(LifeCycle, ABC):
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 0276c2a17..ce3add851 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -98,26 +98,6 @@ class Config(metaclass=Singleton):
### message stor file
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
- ### The associated configuration parameters of the plug-in control the loading and use of the plug-in
- from auto_gpt_plugin_template import AutoGPTPluginTemplate
-
- self.plugins: List[AutoGPTPluginTemplate] = []
- self.plugins_openai = []
- self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
-
- self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
-
- plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
- if plugins_allowlist:
- self.plugins_allowlist = plugins_allowlist.split(",")
- else:
- self.plugins_allowlist = []
-
- plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
- if plugins_denylist:
- self.plugins_denylist = plugins_denylist.split(",")
- else:
- self.plugins_denylist = []
### Native SQL Execution Capability Control Configuration
self.NATIVE_SQL_CAN_RUN_DDL = (
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
@@ -126,7 +106,10 @@ class Config(metaclass=Singleton):
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
)
- ### default Local database connection configuration
+
+ self.LOCAL_DB_MANAGE = None
+
+ ###dbgpt meta info database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
@@ -138,7 +121,8 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
- self.LOCAL_DB_MANAGE = None
+ self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "duckdb")
+
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
@@ -197,10 +181,6 @@ class Config(metaclass=Singleton):
"""Set the debug mode value"""
self.debug_mode = value
- def set_plugins(self, value: list) -> None:
- """Set the plugins value."""
- self.plugins = value
-
def set_templature(self, value: int) -> None:
"""Set the temperature value."""
self.temperature = value
diff --git a/pilot/connections/__init__.py b/pilot/connections/__init__.py
index e69de29bb..ce13a69f3 100644
--- a/pilot/connections/__init__.py
+++ b/pilot/connections/__init__.py
@@ -0,0 +1 @@
+from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao
\ No newline at end of file
diff --git a/pilot/connections/manages/connect_config_db.py b/pilot/connections/manages/connect_config_db.py
new file mode 100644
index 000000000..42307b243
--- /dev/null
+++ b/pilot/connections/manages/connect_config_db.py
@@ -0,0 +1,62 @@
+from pilot.base_modules.meta_data.base_dao import BaseDao
+from pilot.base_modules.meta_data.meta_data import Base, engine, session
+from typing import List
+from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
+from sqlalchemy import UniqueConstraint
+
+class ConnectConfigEntity(Base):
+ __tablename__ = 'connect_config'
+ id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
+ db_type = Column(String(255), nullable=False, comment="db type")
+ db_name = Column(String(255), nullable=False, comment="db name")
+ db_path = Column(String(255), nullable=True, comment="file db path")
+ db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
+ db_port = Column(String(255), nullable=True, comment="db cnnect port(not file db)")
+ db_user = Column(String(255), nullable=True, comment="db user")
+ db_pwd = Column(String(255), nullable=True, comment="db password")
+ comment = Column(Text, nullable=True, comment="db comment")
+
+ __table_args__ = (
+ UniqueConstraint('db_name', name="uk_db"),
+ Index('idx_q_db_type', 'db_type'),
+ )
+
+
+class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
+ def __init__(self):
+ super().__init__(
+ database="dbgpt", orm_base=Base, db_engine=engine, session=session
+ )
+
+ def update(self, entity: ConnectConfigEntity):
+ session = self.get_session()
+ try:
+ updated = session.merge(entity)
+ session.commit()
+ return updated.id
+ finally:
+ session.close()
+
+ def delete(self, db_name: int):
+ session = self.get_session()
+ if db_name is None:
+ raise Exception("db_name is None")
+
+ db_connect = session.query(ConnectConfigEntity)
+ db_connect = db_connect.filter(
+ ConnectConfigEntity.db_name == db_name
+ )
+ db_connect.delete()
+ session.commit()
+ session.close()
+
+ def get_by_name(self, db_name: str) -> ConnectConfigEntity:
+ session = self.get_session()
+ db_connect = session.query(ConnectConfigEntity)
+ db_connect = db_connect.filter(
+ ConnectConfigEntity.db_name == db_name
+ )
+ result = db_connect.first()
+ session.close()
+ return result
+
diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py
index b5cfbbfdd..3d3a69114 100644
--- a/pilot/connections/manages/connection_manager.py
+++ b/pilot/connections/manages/connection_manager.py
@@ -50,7 +50,7 @@ class ConnectManager:
def __init__(self, system_app: SystemApp):
self.storage = DuckdbConnectConfig()
self.db_summary_client = DBSummaryClient(system_app)
- self.__load_config_db()
+ # self.__load_config_db()
def __load_config_db(self):
if CFG.LOCAL_DB_HOST:
diff --git a/pilot/memory/__init__.py b/pilot/memory/__init__.py
index e69de29bb..2e8c7af1b 100644
--- a/pilot/memory/__init__.py
+++ b/pilot/memory/__init__.py
@@ -0,0 +1 @@
+from .chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao
\ No newline at end of file
diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py
index d9d1cba29..4d9291e0b 100644
--- a/pilot/memory/chat_history/base.py
+++ b/pilot/memory/chat_history/base.py
@@ -2,11 +2,18 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
-
+from enum import Enum
from pilot.scene.message import OnceConversation
+class MemoryStoreType(Enum):
+ File= 'file'
+ Memory = 'memory'
+ DB = 'db'
+ DuckDb = 'duckdb'
+
class BaseChatHistoryMemory(ABC):
+ store_type: MemoryStoreType
def __init__(self):
self.conversations: List[OnceConversation] = []
@@ -22,10 +29,31 @@ class BaseChatHistoryMemory(ABC):
def append(self, message: OnceConversation) -> None:
"""Append the message to the record in the local file"""
- @abstractmethod
- def clear(self) -> None:
- """Clear session memory from the local file"""
+ # @abstractmethod
+ # def clear(self) -> None:
+ # """Clear session memory from the local file"""
+ @abstractmethod
def conv_list(self, user_name: str = None) -> None:
"""get user's conversation list"""
pass
+
+ @abstractmethod
+ def update(self, messages: List[OnceConversation]) -> None:
+ pass
+
+ @abstractmethod
+ def delete(self) -> bool:
+ pass
+
+ @abstractmethod
+ def conv_info(self, conv_uid: str = None) -> None:
+ pass
+
+ @abstractmethod
+ def get_messages(self) -> List[OnceConversation]:
+ pass
+
+ @staticmethod
+ def conv_list(cls, user_name: str = None) -> None:
+ pass
\ No newline at end of file
diff --git a/pilot/memory/chat_history/chat_hisotry_factory.py b/pilot/memory/chat_history/chat_hisotry_factory.py
new file mode 100644
index 000000000..6c36053dd
--- /dev/null
+++ b/pilot/memory/chat_history/chat_hisotry_factory.py
@@ -0,0 +1,28 @@
+from .base import MemoryStoreType
+from pilot.configs.config import Config
+
+CFG = Config()
+
+
+
+class ChatHistory:
+
+ def __init__(self):
+ self.memory_type = MemoryStoreType.DB.value
+ self.mem_store_class_map = {}
+ from .store_type.duckdb_history import DuckdbHistoryMemory
+ from .store_type.file_history import FileHistoryMemory
+ from .store_type.meta_db_history import DbHistoryMemory
+ from .store_type.mem_history import MemHistoryMemory
+ self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
+ self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
+ self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
+ self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
+
+
+ def get_store_instance(self, chat_session_id):
+ return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(chat_session_id)
+
+
+ def get_store_cls(self):
+ return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)
diff --git a/pilot/memory/chat_history/chat_history_db.py b/pilot/memory/chat_history/chat_history_db.py
new file mode 100644
index 000000000..f49e1983c
--- /dev/null
+++ b/pilot/memory/chat_history/chat_history_db.py
@@ -0,0 +1,88 @@
+from pilot.base_modules.meta_data.base_dao import BaseDao
+from pilot.base_modules.meta_data.meta_data import Base, engine, session
+from typing import List
+from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
+from sqlalchemy import UniqueConstraint
+
+class ChatHistoryEntity(Base):
+ __tablename__ = 'chat_history'
+ id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
+ conv_uid = Column(String(255), unique=False, nullable=False, comment="Conversation record unique id")
+ chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
+ summary = Column(String(255), nullable=False, comment="Conversation record summary")
+ user_name = Column(String(255), nullable=True, comment="interlocutor")
+ messages = Column(Text, nullable=True, comment="Conversation details")
+
+ __table_args__ = (
+ UniqueConstraint('conv_uid', name="uk_conversation"),
+ Index('idx_q_user', 'user_name'),
+ Index('idx_q_mode', 'chat_mode'),
+ Index('idx_q_conv', 'summary'),
+ )
+
+
+class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
+ def __init__(self):
+ super().__init__(
+ database="dbgpt", orm_base=Base, db_engine=engine, session=session
+ )
+
+ def list_last_20(self, user_name: str = None):
+ session = self.get_session()
+ chat_history = session.query(ChatHistoryEntity)
+ if user_name:
+ chat_history = chat_history.filter(
+ ChatHistoryEntity.user_name == user_name
+ )
+
+ chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())
+
+ result = chat_history.limit(20).all()
+ session.close()
+ return result
+
+ def update(self, entity: ChatHistoryEntity):
+ session = self.get_session()
+ try:
+ updated = session.merge(entity)
+ session.commit()
+ return updated.id
+ finally:
+ session.close()
+
+ def update_message_by_uid(self, message: str, conv_uid:str):
+ session = self.get_session()
+ try:
+ chat_history = session.query(ChatHistoryEntity)
+ chat_history = chat_history.filter(
+ ChatHistoryEntity.conv_uid == conv_uid
+ )
+ updated = chat_history.update({ChatHistoryEntity.messages: message})
+ session.commit()
+ return updated.id
+ finally:
+ session.close()
+
+ def delete(self, conv_uid: int):
+ session = self.get_session()
+ if conv_uid is None:
+ raise Exception("conv_uid is None")
+
+ chat_history = session.query(ChatHistoryEntity)
+ chat_history = chat_history.filter(
+ ChatHistoryEntity.conv_uid == conv_uid
+ )
+ chat_history.delete()
+ session.commit()
+ session.close()
+
+ def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
+ session = self.get_session()
+ chat_history = session.query(ChatHistoryEntity)
+ chat_history = chat_history.filter(
+ ChatHistoryEntity.conv_uid == conv_uid
+ )
+ result = chat_history.first()
+ session.close()
+ return result
+
diff --git a/pilot/base_modules/agent/plugins_loader.py b/pilot/memory/chat_history/store_type/__init__.py
similarity index 100%
rename from pilot/base_modules/agent/plugins_loader.py
rename to pilot/memory/chat_history/store_type/__init__.py
diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/store_type/duckdb_history.py
similarity index 98%
rename from pilot/memory/chat_history/duckdb_history.py
rename to pilot/memory/chat_history/store_type/duckdb_history.py
index cbc21f03a..97aae159b 100644
--- a/pilot/memory/chat_history/duckdb_history.py
+++ b/pilot/memory/chat_history/store_type/duckdb_history.py
@@ -10,6 +10,7 @@ from pilot.scene.message import (
_conversation_to_dic,
)
from pilot.common.formatting import MyEncoder
+from ..base import MemoryStoreType
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
@@ -19,6 +20,8 @@ CFG = Config()
class DuckdbHistoryMemory(BaseChatHistoryMemory):
+ store_type: str = MemoryStoreType.DuckDb.value
+
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
os.makedirs(default_db_path, exist_ok=True)
@@ -117,30 +120,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
cursor.commit()
return True
- @staticmethod
- def conv_list(cls, user_name: str = None) -> None:
- if os.path.isfile(duckdb_path):
- cursor = duckdb.connect(duckdb_path).cursor()
- if user_name:
- cursor.execute(
- "SELECT * FROM chat_history where user_name=? order by id desc limit 20",
- [user_name],
- )
- else:
- cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
- # 获取查询结果字段名
- fields = [field[0] for field in cursor.description]
- data = []
- for row in cursor.fetchall():
- row_dict = {}
- for i, field in enumerate(fields):
- row_dict[field] = row[i]
- data.append(row_dict)
-
- return data
-
- return []
-
def conv_info(self, conv_uid: str = None) -> None:
cursor = self.connect.cursor()
cursor.execute(
@@ -168,3 +147,28 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if context[0]:
return json.loads(context[0])
return None
+
+
+ @staticmethod
+ def conv_list(cls, user_name: str = None) -> None:
+ if os.path.isfile(duckdb_path):
+ cursor = duckdb.connect(duckdb_path).cursor()
+ if user_name:
+ cursor.execute(
+ "SELECT * FROM chat_history where user_name=? order by id desc limit 20",
+ [user_name],
+ )
+ else:
+ cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
+ # 获取查询结果字段名
+ fields = [field[0] for field in cursor.description]
+ data = []
+ for row in cursor.fetchall():
+ row_dict = {}
+ for i, field in enumerate(fields):
+ row_dict[field] = row[i]
+ data.append(row_dict)
+
+ return data
+
+ return []
diff --git a/pilot/memory/chat_history/file_history.py b/pilot/memory/chat_history/store_type/file_history.py
similarity index 93%
rename from pilot/memory/chat_history/file_history.py
rename to pilot/memory/chat_history/store_type/file_history.py
index ffdd4169b..fa1143309 100644
--- a/pilot/memory/chat_history/file_history.py
+++ b/pilot/memory/chat_history/store_type/file_history.py
@@ -11,12 +11,14 @@ from pilot.scene.message import (
conversation_from_dict,
conversations_to_dict,
)
-
+from pilot.memory.chat_history.base import MemoryStoreType
CFG = Config()
class FileHistoryMemory(BaseChatHistoryMemory):
+ store_type: str = MemoryStoreType.File.value
+
def __init__(self, chat_session_id: str):
now = datetime.datetime.now()
date_string = now.strftime("%Y%m%d")
@@ -47,3 +49,5 @@ class FileHistoryMemory(BaseChatHistoryMemory):
def clear(self) -> None:
self.file_path.write_text(json.dumps([]))
+
+
diff --git a/pilot/memory/chat_history/mem_history.py b/pilot/memory/chat_history/store_type/mem_history.py
similarity index 88%
rename from pilot/memory/chat_history/mem_history.py
rename to pilot/memory/chat_history/store_type/mem_history.py
index 2e832041f..5c3ddc217 100644
--- a/pilot/memory/chat_history/mem_history.py
+++ b/pilot/memory/chat_history/store_type/mem_history.py
@@ -4,11 +4,14 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.configs.config import Config
from pilot.scene.message import OnceConversation
from pilot.common.custom_data_structure import FixedSizeDict
+from pilot.memory.chat_history.base import MemoryStoreType
CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
+ store_type: str = MemoryStoreType.Memory.value
+
histroies_map = FixedSizeDict(100)
def __init__(self, chat_session_id: str):
diff --git a/pilot/memory/chat_history/store_type/meta_db_history.py b/pilot/memory/chat_history/store_type/meta_db_history.py
new file mode 100644
index 000000000..c1fc0ec5d
--- /dev/null
+++ b/pilot/memory/chat_history/store_type/meta_db_history.py
@@ -0,0 +1,100 @@
+import json
+import logging
+from typing import List
+from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
+from sqlalchemy import UniqueConstraint
+from pilot.configs.config import Config
+from pilot.memory.chat_history.base import BaseChatHistoryMemory
+from pilot.scene.message import (
+ OnceConversation,
+ _conversation_to_dic,
+)
+from ..chat_history_db import ChatHistoryEntity, ChatHistoryDao
+
+from pilot.memory.chat_history.base import MemoryStoreType
+CFG = Config()
+logger = logging.getLogger("db_chat_history")
+
+class DbHistoryMemory(BaseChatHistoryMemory):
+ store_type: str = MemoryStoreType.DB.value
+ def __init__(self, chat_session_id: str):
+ self.chat_seesion_id = chat_session_id
+ self.chat_history_dao = ChatHistoryDao()
+
+ def messages(self) -> List[OnceConversation]:
+
+ chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
+ if chat_history:
+ context = chat_history.messages
+ if context:
+ conversations: List[OnceConversation] = json.loads(context)
+ return conversations
+ return []
+
+
+ def create(self, chat_mode, summary: str, user_name: str) -> None:
+ try:
+ chat_history: ChatHistoryEntity = ChatHistoryEntity()
+ chat_history.chat_mode = chat_mode
+ chat_history.summary = summary
+ chat_history.user_name = user_name
+
+ self.chat_history_dao.update(chat_history)
+ except Exception as e:
+ logger.error("init create conversation log error!" + str(e))
+
+
+ def append(self, once_message: OnceConversation) -> None:
+ logger.info("db history append:{}", once_message)
+ chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
+ conversations: List[OnceConversation] = []
+ if chat_history:
+ context = chat_history.messages
+ if context:
+ conversations = json.loads(context)
+ else:
+ chat_history.summary = once_message.get_user_conv().content
+ else:
+ chat_history: ChatHistoryEntity = ChatHistoryEntity()
+ chat_history.conv_uid = self.chat_seesion_id
+ chat_history.chat_mode = once_message.chat_mode
+ chat_history.user_name = "default"
+ chat_history.summary = once_message.get_user_conv().content
+
+ conversations.append(_conversation_to_dic(once_message))
+ chat_history.messages = json.dumps(conversations, ensure_ascii=False)
+
+ self.chat_history_dao.update(chat_history)
+
+ def update(self, messages: List[OnceConversation]) -> None:
+ self.chat_history_dao.update_message_by_uid(json.dumps(messages, ensure_ascii=False), self.chat_seesion_id)
+
+
+ def delete(self) -> bool:
+ self.chat_history_dao.delete(self.chat_seesion_id)
+
+
+ def conv_info(self, conv_uid: str = None) -> None:
+ logger.info("conv_info:{}", conv_uid)
+ chat_history = self.chat_history_dao.get_by_uid(conv_uid)
+ return chat_history.__dict__
+
+
+ def get_messages(self) -> List[OnceConversation]:
+ logger.info("get_messages:{}", self.chat_seesion_id)
+ chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
+ if chat_history:
+ context = chat_history.messages
+ return json.loads(context)
+ return []
+
+
+ @staticmethod
+ def conv_list(cls, user_name: str = None) -> None:
+
+ chat_history_dao = ChatHistoryDao()
+ history_list = chat_history_dao.list_last_20()
+ result = []
+ for history in history_list:
+ result.append(history.__dict__)
+ return result
\ No newline at end of file
diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py
index 54360e477..d9aeade44 100644
--- a/pilot/model/cluster/controller/controller.py
+++ b/pilot/model/cluster/controller/controller.py
@@ -142,12 +142,12 @@ def initialize_controller(
controller.backend = LocalModelController()
if app:
- app.include_router(router, prefix="/api")
+ app.include_router(router, prefix="/api", tags=['Model'])
else:
import uvicorn
app = FastAPI()
- app.include_router(router, prefix="/api")
+ app.include_router(router, prefix="/api", tags=['Model'])
uvicorn.run(app, host=host, port=port, log_level="info")
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index c51b9d254..dd4091ac0 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -8,13 +8,10 @@ from fastapi import (
Request,
File,
UploadFile,
- Form,
Body,
- BackgroundTasks,
)
from fastapi.responses import StreamingResponse
-from fastapi.exceptions import RequestValidationError
from typing import List
import tempfile
@@ -35,11 +32,10 @@ from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
-from pilot.common.schema import DBType
-from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
-from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
+from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.summary.db_summary_client import DBSummaryClient
+from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
router = APIRouter()
CFG = Config()
@@ -61,7 +57,6 @@ def __get_conv_user_message(conversations: dict):
def __new_conversation(chat_mode, user_id) -> ConversationVo:
unique_id = uuid.uuid1()
- # history_mem = DuckdbHistoryMemory(str(unique_id))
return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
@@ -145,7 +140,8 @@ async def db_support_types():
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(user_id: str = None):
dialogues: List = []
- datas = DuckdbHistoryMemory.conv_list(user_id)
+ chat_history_service = ChatHistory()
+ datas = chat_history_service.get_store_cls().conv_list(user_id)
for item in datas:
conv_uid = item.get("conv_uid")
summary = item.get("summary")
@@ -257,14 +253,18 @@ async def params_load(
@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
- history_mem = DuckdbHistoryMemory(con_uid)
+
+ history_fac = ChatHistory()
+ history_mem = history_fac.get_store_instance(con_uid)
history_mem.delete()
return Result.succ(None)
def get_hist_messages(conv_uid: str):
message_vos: List[MessageVo] = []
- history_mem = DuckdbHistoryMemory(conv_uid)
+ history_fac = ChatHistory()
+ history_mem = history_fac.get_store_instance(conv_uid)
+
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
diff --git a/pilot/openapi/api_v1/editor/api_editor_v1.py b/pilot/openapi/api_v1/editor/api_editor_v1.py
index e1b313664..0e43393be 100644
--- a/pilot/openapi/api_v1/editor/api_editor_v1.py
+++ b/pilot/openapi/api_v1/editor/api_editor_v1.py
@@ -26,10 +26,10 @@ from pilot.openapi.editor_view_model import (
)
from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData
-from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.scene.chat_db.data_loader import DbDataLoader
+from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
router = APIRouter()
CFG = Config()
@@ -67,7 +67,8 @@ async def get_editor_tables(
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
async def get_editor_sql_rounds(con_uid: str):
logger.info("get_editor_sql_rounds:{con_uid}")
- history_mem = DuckdbHistoryMemory(con_uid)
+ chat_history_fac = ChatHistory()
+ history_mem = chat_history_fac.get_store_instance(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
result: List = []
@@ -89,7 +90,8 @@ async def get_editor_sql_rounds(con_uid: str):
@router.get("/v1/editor/sql", response_model=Result[dict])
async def get_editor_sql(con_uid: str, round: int):
logger.info(f"get_editor_sql:{con_uid},{round}")
- history_mem = DuckdbHistoryMemory(con_uid)
+ chat_history_fac = ChatHistory()
+ history_mem = chat_history_fac.get_store_instance(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
@@ -138,7 +140,9 @@ async def editor_sql_run(run_param: dict = Body()):
@router.post("/v1/sql/editor/submit")
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
- history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
+
+ chat_history_fac = ChatHistory()
+ history_mem = chat_history_fac.get_store_instance(sql_edit_context.con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
@@ -171,7 +175,8 @@ async def get_editor_chart_list(con_uid: str):
logger.info(
f"get_editor_sql_rounds:{con_uid}",
)
- history_mem = DuckdbHistoryMemory(con_uid)
+ chat_history_fac = ChatHistory()
+ history_mem = chat_history_fac.get_store_instance(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
last_round = max(history_messages, key=lambda x: x["chat_order"])
@@ -193,7 +198,8 @@ async def get_editor_chart_info(param: dict = Body()):
conv_uid = param["con_uid"]
chart_title = param["chart_title"]
- history_mem = DuckdbHistoryMemory(conv_uid)
+ chat_history_fac = ChatHistory()
+ history_mem = chat_history_fac.get_store_instance(conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
last_round = max(history_messages, key=lambda x: x["chat_order"])
@@ -269,7 +275,9 @@ async def editor_chart_run(run_param: dict = Body()):
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
logger.info(f"sql_editor_submit:{chart_edit_context.__dict__}")
- history_mem = DuckdbHistoryMemory(chart_edit_context.conv_uid)
+
+ chat_history_fac = ChatHistory()
+ history_mem = chat_history_fac.get_store_instance(chart_edit_context.con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 57033cede..8594b7252 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -7,15 +7,12 @@ from typing import Any, List, Dict
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
-from pilot.memory.chat_history.base import BaseChatHistoryMemory
-from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
-from pilot.memory.chat_history.file_history import FileHistoryMemory
-from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation
from pilot.utils import build_logger, get_or_create_event_loop
from pydantic import Extra
+from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"}
@@ -54,9 +51,9 @@ class BaseChat(ABC):
proxyllm_backend=CFG.PROXYLLM_BACKEND,
)
)
-
+ chat_history_fac = ChatHistory()
### can configurable storage methods
- self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
+ self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(
@@ -162,6 +159,7 @@ class BaseChat(ABC):
output, self.skip_echo_len
)
view_msg = self.stream_plugin_call(msg)
+ view_msg = view_msg.replace("\n", "\\n")
yield view_msg
self.current_message.add_ai_message(msg)
self.current_message.add_view_message(view_msg)
diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py
index 26a58f73f..8175d3d7c 100644
--- a/pilot/scene/chat_agent/chat.py
+++ b/pilot/scene/chat_agent/chat.py
@@ -9,6 +9,8 @@ from pilot.base_modules.agent.commands.command_mange import ApiCall
from pilot.base_modules.agent import PluginPromptGenerator
from pilot.common.string_utils import extract_content
from .prompt import prompt
+from pilot.componet import ComponetType
+from pilot.base_modules.agent.controller import ModuleAgent
CFG = Config()
@@ -27,14 +29,10 @@ class ChatAgent(BaseChat):
super().__init__(chat_param=chat_param)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry
- # load select plugin
- for plugin in CFG.plugins:
- if plugin._name in self.select_plugins:
- if not plugin.can_handle_post_prompt():
- continue
- self.plugins_prompt_generator = plugin.post_prompt(
- self.plugins_prompt_generator
- )
+
+ # load select plugin
+ agent_module = CFG.SYSTEM_APP.get_componet(ComponetType.AGENT_HUB, ModuleAgent)
+ self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)
self.api_call = ApiCall(self.plugins_prompt_generator)
diff --git a/pilot/scene/chat_agent/prompt.py b/pilot/scene/chat_agent/prompt.py
index 493fef4cd..1c5229f89 100644
--- a/pilot/scene/chat_agent/prompt.py
+++ b/pilot/scene/chat_agent/prompt.py
@@ -32,7 +32,7 @@ _DEFAULT_TEMPLATE_ZH = """
根据用户目标,请一步步思考,如何在满足下面约束条件的前提下,优先使用给出工具回答或者完成用户目标。
约束条件:
- 1.从下面给定工具列表找到可用的工具后,请确保输出结果包含以下内容用来使用工具:
+ 1.从下面给定工具列表找到可用的工具后,请输出以下内容用来使用工具, 注意要确保下面内容在输出结果中只出现一次:
Selected Tool namevaluevalue
2.请根据工具列表对应工具的定义来生成上述调用文本, 参考案例如下:
工具作用介绍: "工具名称", args: "参数1": "<参数1取值描述>","参数2": "<参数2取值描述>" 对应调用文本:工具名称<参数1>value参数1><参数2>value参数2>
diff --git a/pilot/server/base.py b/pilot/server/base.py
index 064f2d247..fb1d518c7 100644
--- a/pilot/server/base.py
+++ b/pilot/server/base.py
@@ -32,7 +32,6 @@ def async_db_summery(system_app: SystemApp):
def server_init(args, system_app: SystemApp):
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
- from pilot.base_modules.agent.plugins_util import scan_plugins
# logger.info(f"args: {args}")
@@ -45,7 +44,7 @@ def server_init(args, system_app: SystemApp):
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
- cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode))
+
# Loader plugins and commands
command_categories = [
diff --git a/pilot/server/componet_configs.py b/pilot/server/componet_configs.py
index 755f13b21..c398181fa 100644
--- a/pilot/server/componet_configs.py
+++ b/pilot/server/componet_configs.py
@@ -24,9 +24,11 @@ def initialize_componets(
embedding_model_path: str,
):
from pilot.model.cluster.controller.controller import controller
-
system_app.register_instance(controller)
+ from pilot.base_modules.agent.controller import module_agent
+ system_app.register_instance(module_agent)
+
_initialize_embedding_model(
param, system_app, embedding_model_name, embedding_model_path
)
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index 078a2daa3..b27cba92e 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -71,11 +71,11 @@ app.add_middleware(
)
-app.include_router(api_v1, prefix="/api")
-app.include_router(api_editor_route_v1, prefix="/api")
-app.include_router(agent_route, prefix="/api")
+app.include_router(api_v1, prefix="/api", tags=["Chat"])
+app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
-app.include_router(knowledge_router)
+
+app.include_router(knowledge_router, tags=["Knowledge"])
def mount_static_files(app):