feat(Agent): ChatAgent and AgentHub
1. AgentHub module completed. 2. New ChatAgent mode completed. 3. Metadata store added, with automatic DDL upgrade.
Parent: bf84663b83
Commit: ed702db9bc
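The central refactor visible throughout the diff below is that callers no longer construct DuckdbHistoryMemory directly; they go through the new ChatHistory factory, which picks the store class from CFG.CHAT_HISTORY_STORE_TYPE. A minimal usage sketch, assembled only from the call sites changed in this commit (the conversation id and user name are placeholders):

from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory

chat_history_fac = ChatHistory()
# Per-conversation store instance, selected by CFG.CHAT_HISTORY_STORE_TYPE.
history_mem = chat_history_fac.get_store_instance("some-conv-uid")
history_messages = history_mem.get_messages()
history_mem.delete()

# Class-level access, e.g. to list a user's recent conversations.
datas = chat_history_fac.get_store_cls().conv_list("some-user")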
@@ -6,7 +6,7 @@ from fastapi import (
    UploadFile,
    File,
)

from abc import ABC, abstractmethod
from typing import List
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
@@ -18,14 +18,44 @@ from pilot.openapi.api_view_model import (
from .model import PluginHubParam, PagenationFilter, PagenationResult, PluginHubFilter, MyPluginFilter
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
from .db.my_plugin_db import MyPluginEntity
from .plugins_util import scan_plugins
from .commands.generator import PluginPromptGenerator

from pilot.configs.model_config import PLUGINS_DIR
from pilot.componet import BaseComponet, ComponetType, SystemApp

router = APIRouter()
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")

@router.post("/api/v1/agent/hub/update", response_model=Result[str])
class ModuleAgent(BaseComponet, ABC):
    name = ComponetType.AGENT_HUB

    def __init__(self):
        # load plugins
        self.plugins = scan_plugins(PLUGINS_DIR)

    def init_app(self, system_app: SystemApp):
        system_app.app.include_router(router, prefix="/api", tags=["Agent"])

    def refresh_plugins(self):
        self.plugins = scan_plugins(PLUGINS_DIR)

    def load_select_plugin(self, generator: PluginPromptGenerator, select_plugins: List[str]) -> PluginPromptGenerator:
        logger.info(f"load_select_plugin:{select_plugins}")
        # load the selected plugins
        for plugin in self.plugins:
            if plugin._name in select_plugins:
                if not plugin.can_handle_post_prompt():
                    continue
                generator = plugin.post_prompt(generator)
        return generator

module_agent = ModuleAgent()

@router.post("/v1/agent/hub/update", response_model=Result[str])
async def agent_hub_update(update_param: PluginHubParam = Body()):
    logger.info(f"agent_hub_update:{update_param.__dict__}")
    try:
@@ -38,14 +68,15 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):

@router.post("/api/v1/agent/query", response_model=Result[str])
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
    logger.info(f"get_agent_list:{json.dumps(filter)}")
    logger.info(f"get_agent_list:{filter.__dict__}")
    agent_hub = AgentHub(PLUGINS_DIR)
    filter_enetity: PluginHubEntity = PluginHubEntity()
    attrs = vars(filter.filter)  # get the attribute dict of the original filter object
    for attr, value in attrs.items():
        setattr(filter_enetity, attr, value)  # copy the attribute values onto the entity
    if filter.filter:
        attrs = vars(filter.filter)  # get the attribute dict of the original filter object
        for attr, value in attrs.items():
            setattr(filter_enetity, attr, value)  # copy the attribute values onto the entity

    datas, total_pages, total_count = agent_hub.hub_dao.list(filter_enetity, filter.page_index, filter.page_size)
    result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
@@ -54,21 +85,30 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
    result.total_page = total_pages
    result.total_row_count = total_count
    result.datas = datas
    return Result.succ(result)
    # print(json.dumps(result.to_dic()))
    return Result.succ(result.to_dic())

@router.post("/api/v1/agent/my", response_model=Result[str])
@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user: str = None):
    logger.info(f"my_agents:{json.dumps(my_agents)}")
    logger.info(f"my_agents:{user}")
    agent_hub = AgentHub(PLUGINS_DIR)
    return Result.succ(agent_hub.get_my_plugin(user))
    agents = agent_hub.get_my_plugin(user)
    agent_dicts = []
    for agent in agents:
        agent_dicts.append(agent.__dict__)

    return Result.succ(agent_dicts)

@router.post("/api/v1/agent/install", response_model=Result[str])
@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name: str, user: str = None):
    logger.info(f"agent_install:{plugin_name},{user}")
    try:
        agent_hub = AgentHub(PLUGINS_DIR)
        agent_hub.install_plugin(plugin_name, user)

        module_agent.refresh_plugins()

        return Result.succ(None)
    except Exception as e:
        logger.error("Plugin Install Error!", e)
@@ -76,19 +116,21 @@ async def agent_install(plugin_name: str, user: str = None):

@router.post("/api/v1/agent/uninstall", response_model=Result[str])
@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name: str, user: str = None):
    logger.info(f"agent_uninstall:{plugin_name},{user}")
    try:
        agent_hub = AgentHub(PLUGINS_DIR)
        agent_hub.uninstall_plugin(plugin_name, user)

        module_agent.refresh_plugins()
        return Result.succ(None)
    except Exception as e:
        logger.error("Plugin Uninstall Error!", e)
        return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")

@router.post("/api/v1/personal/agent/upload", response_model=Result[str])
@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
    logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
    try:
@@ -14,13 +14,13 @@ class MyPluginEntity(Base):
    __tablename__ = 'my_plugin'

    id = Column(Integer, primary_key=True, comment="autoincrement id")
    tenant = Column(String, nullable=True, comment="user's tenant")
    user_code = Column(String, nullable=True, comment="user code")
    user_name = Column(String, nullable=True, comment="user name")
    name = Column(String, unique=True, nullable=False, comment="plugin name")
    file_name = Column(String, nullable=False, comment="plugin package file name")
    type = Column(String, comment="plugin type")
    version = Column(String, comment="plugin version")
    tenant = Column(String(255), nullable=True, comment="user's tenant")
    user_code = Column(String(255), nullable=True, comment="user code")
    user_name = Column(String(255), nullable=True, comment="user name")
    name = Column(String(255), unique=True, nullable=False, comment="plugin name")
    file_name = Column(String(255), nullable=False, comment="plugin package file name")
    type = Column(String(255), comment="plugin type")
    version = Column(String(255), comment="plugin version")
    use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
    succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count")
    created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")
@@ -12,15 +12,15 @@ from pilot.base_modules.meta_data.meta_data import Base, engine, session
class PluginHubEntity(Base):
    __tablename__ = 'plugin_hub'
    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
    name = Column(String, unique=True, nullable=False, comment="plugin name")
    description = Column(String, nullable=False, comment="plugin description")
    author = Column(String, nullable=True, comment="plugin author")
    email = Column(String, nullable=True, comment="plugin author email")
    type = Column(String, comment="plugin type")
    version = Column(String, comment="plugin version")
    storage_channel = Column(String, comment="plugin storage channel")
    storage_url = Column(String, comment="plugin download url")
    download_param = Column(String, comment="plugin download param")
    name = Column(String(255), unique=True, nullable=False, comment="plugin name")
    description = Column(String(255), nullable=False, comment="plugin description")
    author = Column(String(255), nullable=True, comment="plugin author")
    email = Column(String(255), nullable=True, comment="plugin author email")
    type = Column(String(255), comment="plugin type")
    version = Column(String(255), comment="plugin version")
    storage_channel = Column(String(255), comment="plugin storage channel")
    storage_url = Column(String(255), comment="plugin download url")
    download_param = Column(String(255), comment="plugin download param")
    created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
    installed = Column(Integer, default=False, comment="plugin already installed count")

@@ -146,11 +146,10 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
        session = self.get_session()
        if plugin_id is None:
            raise Exception("plugin_id is None")
        query = PluginHubEntity(id=plugin_id)
        plugin_hubs = session.query(PluginHubEntity)
        if query.id is not None:
        if plugin_id is not None:
            plugin_hubs = plugin_hubs.filter(
                PluginHubEntity.id == query.id
                PluginHubEntity.id == plugin_id
            )
        plugin_hubs.delete()
        session.commit()
@@ -17,6 +17,18 @@ class PagenationResult(BaseModel, Generic[T]):
    total_row_count: int = 0
    datas: List[T] = []

    def to_dic(self):
        data_dicts = []
        for item in self.datas:
            data_dicts.append(item.__dict__)
        return {
            'page_index': self.page_index,
            'page_size': self.page_size,
            'total_page': self.total_page,
            'total_row_count': self.total_row_count,
            'datas': data_dicts
        }

@dataclass
class PluginHubFilter(BaseModel):
    name: str
@@ -41,10 +53,9 @@ class MyPluginFilter(BaseModel):

class PluginHubParam(BaseModel):
    channel: str = Field(..., description="Plugin storage channel")
    url: str = Field(..., description="Plugin storage url")

    branch: Optional[str] = Field(None, description="github download branch", nullable=True)
    channel: Optional[str] = Field("git", description="Plugin storage channel")
    url: Optional[str] = Field("https://github.com/eosphoros-ai/DB-GPT-Plugins.git", description="Plugin storage url")
    branch: Optional[str] = Field("main", description="github download branch", nullable=True)
    authorization: Optional[str] = Field(None, description="github download authorization", nullable=True)
@@ -2,33 +2,72 @@ import uuid
import os
import duckdb
import sqlite3
import logging
import fnmatch
from datetime import datetime
from typing import Optional, Type, TypeVar

import sqlalchemy as sa

from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate, upgrade
from flask.cli import with_appcontext
import subprocess

from sqlalchemy import create_engine, DateTime, String, func, MetaData
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

from alembic import context, command
from alembic.config import Config
from alembic.config import Config as AlembicConfig
from urllib.parse import quote
from pilot.configs.config import Config

logger = logging.getLogger("meta_data")

CFG = Config()
default_db_path = os.path.join(os.getcwd(), "meta_data")

os.makedirs(default_db_path, exist_ok=True)

db_path = default_db_path + "/dbgpt.db"
# Meta Info
db_name = "dbgpt"
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)
engine = create_engine(f'sqlite:///{db_path}')

if CFG.LOCAL_DB_TYPE == 'mysql':
    engine_temp = create_engine(f"mysql+pymysql://"
                                + quote(CFG.LOCAL_DB_USER)
                                + ":"
                                + quote(CFG.LOCAL_DB_PASSWORD)
                                + "@"
                                + CFG.LOCAL_DB_HOST
                                + ":"
                                + str(CFG.LOCAL_DB_PORT)
                                )
    # check and auto-create the mysql database
    try:
        # try to connect
        with engine_temp.connect() as conn:
            conn.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}")
        print(f"Already connect '{db_name}'")

    except OperationalError as e:
        # if the connection failed, create the dbgpt database
        logger.error(f"{db_name} not connect success!")

    engine = create_engine(f"mysql+pymysql://"
                           + quote(CFG.LOCAL_DB_USER)
                           + ":"
                           + quote(CFG.LOCAL_DB_PASSWORD)
                           + "@"
                           + CFG.LOCAL_DB_HOST
                           + ":"
                           + str(CFG.LOCAL_DB_PORT)
                           + f"/{db_name}"
                           )
else:
    engine = create_engine(f'sqlite:///{db_path}')


Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
@@ -40,7 +79,7 @@ Base = declarative_base(bind=engine)
# create the Alembic config object

alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = Config(alembic_ini_path)
alembic_cfg = AlembicConfig(alembic_ini_path)

alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))

@@ -59,36 +98,6 @@ alembic_cfg.attributes['session'] = session
# Base.metadata.drop_all(engine)

# app = Flask(__name__)
# default_db_path = os.path.join(os.getcwd(), "meta_data")
# duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/dbgpt.db")
# app.config['SQLALCHEMY_DATABASE_URI'] = f'duckdb://{duckdb_path}'
# app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# db = SQLAlchemy(app)
# migrate = Migrate(app, db)
#
# # set the FLASK_APP environment variable
# import os
# os.environ['FLASK_APP'] = 'server.dbgpt_server.py'
#
# @app.cli.command("db_init")
# @with_appcontext
# def db_init():
#     subprocess.run(["flask", "db", "init"])
#
# @app.cli.command("db_migrate")
# @with_appcontext
# def db_migrate():
#     subprocess.run(["flask", "db", "migrate"])
#
# @app.cli.command("db_upgrade")
# @with_appcontext
# def db_upgrade():
#     subprocess.run(["flask", "db", "upgrade"])
#


def ddl_init_and_upgrade():
    # Base.metadata.create_all(bind=engine)
    # generate and apply the migration script
@@ -96,14 +105,8 @@ def ddl_init_and_upgrade():
    # subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"])
    with engine.connect() as connection:
        alembic_cfg.attributes['connection'] = connection
        command.revision(alembic_cfg, "test", True)
        heads = command.heads(alembic_cfg)
        print("heads:" + str(heads))

        command.revision(alembic_cfg, "dbgpt ddl upate", True)
        command.upgrade(alembic_cfg, "head")
    # alembic_cfg.attributes['connection'] = engine.connect()
    # command.upgrade(alembic_cfg, 'head')

    # with app.app_context():
    #     db_init()
    #     db_migrate()
    #     db_upgrade()
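For the "auto update ddl" part of the commit message: meta_data.py binds every declarative Base model to the dbgpt meta database and drives Alembic programmatically, so ddl_init_and_upgrade() autogenerates a revision and upgrades to head. A hedged sketch of how a new table would be picked up, assuming ddl_init_and_upgrade() is called during server startup (the call site is outside this excerpt) and using a hypothetical ExampleEntity purely for illustration:

from sqlalchemy import Column, Integer, String
from pilot.base_modules.meta_data.meta_data import Base, ddl_init_and_upgrade

class ExampleEntity(Base):  # hypothetical table, for illustration only
    __tablename__ = "example_entity"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(255), nullable=False)

# Importing the model registers it on Base.metadata; the next call lets
# Alembic autogenerate a revision for the new table and apply it.
ddl_init_and_upgrade()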
@@ -0,0 +1,2 @@
@@ -44,6 +44,7 @@ class LifeCycle:
class ComponetType(str, Enum):
    WORKER_MANAGER = "dbgpt_worker_manager"
    MODEL_CONTROLLER = "dbgpt_model_controller"
    AGENT_HUB = "dbgpt_agent_hub"


class BaseComponet(LifeCycle, ABC):
@@ -98,26 +98,6 @@ class Config(metaclass=Singleton):
        ### message store file
        self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")

        ### The associated configuration parameters of the plug-in control the loading and use of the plug-in
        from auto_gpt_plugin_template import AutoGPTPluginTemplate

        self.plugins: List[AutoGPTPluginTemplate] = []
        self.plugins_openai = []
        self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"

        self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")

        plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
        if plugins_allowlist:
            self.plugins_allowlist = plugins_allowlist.split(",")
        else:
            self.plugins_allowlist = []

        plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
        if plugins_denylist:
            self.plugins_denylist = plugins_denylist.split(",")
        else:
            self.plugins_denylist = []
        ### Native SQL Execution Capability Control Configuration
        self.NATIVE_SQL_CAN_RUN_DDL = (
            os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
@@ -126,7 +106,10 @@ class Config(metaclass=Singleton):
            os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
        )

        ### default local database connection configuration

        self.LOCAL_DB_MANAGE = None

        ### dbgpt meta info database connection configuration
        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
@@ -138,7 +121,8 @@ class Config(metaclass=Singleton):
        self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
        self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")

        self.LOCAL_DB_MANAGE = None
        self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "duckdb")

        ### LLM Model Service Configuration
        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
@@ -197,10 +181,6 @@ class Config(metaclass=Singleton):
        """Set the debug mode value"""
        self.debug_mode = value

    def set_plugins(self, value: list) -> None:
        """Set the plugins value."""
        self.plugins = value

    def set_templature(self, value: int) -> None:
        """Set the temperature value."""
        self.temperature = value
@@ -0,0 +1 @@
from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao
pilot/connections/manages/connect_config_db.py (new file, 62 lines)
@@ -0,0 +1,62 @@
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint

class ConnectConfigEntity(Base):
    __tablename__ = 'connect_config'
    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
    db_type = Column(String(255), nullable=False, comment="db type")
    db_name = Column(String(255), nullable=False, comment="db name")
    db_path = Column(String(255), nullable=True, comment="file db path")
    db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
    db_port = Column(String(255), nullable=True, comment="db connect port(not file db)")
    db_user = Column(String(255), nullable=True, comment="db user")
    db_pwd = Column(String(255), nullable=True, comment="db password")
    comment = Column(Text, nullable=True, comment="db comment")

    __table_args__ = (
        UniqueConstraint('db_name', name="uk_db"),
        Index('idx_q_db_type', 'db_type'),
    )


class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
    def __init__(self):
        super().__init__(
            database="dbgpt", orm_base=Base, db_engine=engine, session=session
        )

    def update(self, entity: ConnectConfigEntity):
        session = self.get_session()
        try:
            updated = session.merge(entity)
            session.commit()
            return updated.id
        finally:
            session.close()

    def delete(self, db_name: int):
        session = self.get_session()
        if db_name is None:
            raise Exception("db_name is None")

        db_connect = session.query(ConnectConfigEntity)
        db_connect = db_connect.filter(
            ConnectConfigEntity.db_name == db_name
        )
        db_connect.delete()
        session.commit()
        session.close()

    def get_by_name(self, db_name: str) -> ConnectConfigEntity:
        session = self.get_session()
        db_connect = session.query(ConnectConfigEntity)
        db_connect = db_connect.filter(
            ConnectConfigEntity.db_name == db_name
        )
        result = db_connect.first()
        session.close()
        return result
@@ -50,7 +50,7 @@ class ConnectManager:
    def __init__(self, system_app: SystemApp):
        self.storage = DuckdbConnectConfig()
        self.db_summary_client = DBSummaryClient(system_app)
        self.__load_config_db()
        # self.__load_config_db()

    def __load_config_db(self):
        if CFG.LOCAL_DB_HOST:
@@ -0,0 +1 @@
from .chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao
@@ -2,11 +2,18 @@ from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from enum import Enum
from pilot.scene.message import OnceConversation

class MemoryStoreType(Enum):
    File = 'file'
    Memory = 'memory'
    DB = 'db'
    DuckDb = 'duckdb'


class BaseChatHistoryMemory(ABC):
    store_type: MemoryStoreType

    def __init__(self):
        self.conversations: List[OnceConversation] = []
@@ -22,10 +29,31 @@ class BaseChatHistoryMemory(ABC):
    def append(self, message: OnceConversation) -> None:
        """Append the message to the record in the local file"""

    @abstractmethod
    def clear(self) -> None:
        """Clear session memory from the local file"""
    # @abstractmethod
    # def clear(self) -> None:
    #     """Clear session memory from the local file"""

    @abstractmethod
    def conv_list(self, user_name: str = None) -> None:
        """get user's conversation list"""
        pass

    @abstractmethod
    def update(self, messages: List[OnceConversation]) -> None:
        pass

    @abstractmethod
    def delete(self) -> bool:
        pass

    @abstractmethod
    def conv_info(self, conv_uid: str = None) -> None:
        pass

    @abstractmethod
    def get_messages(self) -> List[OnceConversation]:
        pass

    @staticmethod
    def conv_list(cls, user_name: str = None) -> None:
        pass
pilot/memory/chat_history/chat_hisotry_factory.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from .base import MemoryStoreType
from pilot.configs.config import Config

CFG = Config()


class ChatHistory:

    def __init__(self):
        self.memory_type = MemoryStoreType.DB.value
        self.mem_store_class_map = {}
        from .store_type.duckdb_history import DuckdbHistoryMemory
        from .store_type.file_history import FileHistoryMemory
        from .store_type.meta_db_history import DbHistoryMemory
        from .store_type.mem_history import MemHistoryMemory
        self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
        self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
        self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
        self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory

    def get_store_instance(self, chat_session_id):
        return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(chat_session_id)

    def get_store_cls(self):
        return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)
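Each store class advertises its key through a store_type class attribute, and the factory constructor above maps those keys to classes. A hedged sketch of registering a custom store, assuming only the interfaces shown in this diff (RedisHistoryMemory is a hypothetical name, and the omitted abstract methods would still need implementations):

from typing import List
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
from pilot.scene.message import OnceConversation

class RedisHistoryMemory(BaseChatHistoryMemory):  # hypothetical example store
    store_type: str = "redis"

    def __init__(self, chat_session_id: str):
        self.chat_session_id = chat_session_id

    def messages(self) -> List[OnceConversation]:
        return []
    # append, update, delete, conv_info, get_messages, conv_list omitted here.

chat_history = ChatHistory()
# Registering the class makes it selectable via CHAT_HISTORY_STORE_TYPE="redis".
chat_history.mem_store_class_map[RedisHistoryMemory.store_type] = RedisHistoryMemory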
pilot/memory/chat_history/chat_history_db.py (new file, 88 lines)
@@ -0,0 +1,88 @@
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint

class ChatHistoryEntity(Base):
    __tablename__ = 'chat_history'
    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
    conv_uid = Column(String(255), unique=False, nullable=False, comment="Conversation record unique id")
    chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
    summary = Column(String(255), nullable=False, comment="Conversation record summary")
    user_name = Column(String(255), nullable=True, comment="interlocutor")
    messages = Column(Text, nullable=True, comment="Conversation details")

    __table_args__ = (
        UniqueConstraint('conv_uid', name="uk_conversation"),
        Index('idx_q_user', 'user_name'),
        Index('idx_q_mode', 'chat_mode'),
        Index('idx_q_conv', 'summary'),
    )


class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
    def __init__(self):
        super().__init__(
            database="dbgpt", orm_base=Base, db_engine=engine, session=session
        )

    def list_last_20(self, user_name: str = None):
        session = self.get_session()
        chat_history = session.query(ChatHistoryEntity)
        if user_name:
            chat_history = chat_history.filter(
                ChatHistoryEntity.user_name == user_name
            )

        chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())

        result = chat_history.limit(20).all()
        session.close()
        return result

    def update(self, entity: ChatHistoryEntity):
        session = self.get_session()
        try:
            updated = session.merge(entity)
            session.commit()
            return updated.id
        finally:
            session.close()

    def update_message_by_uid(self, message: str, conv_uid: str):
        session = self.get_session()
        try:
            chat_history = session.query(ChatHistoryEntity)
            chat_history = chat_history.filter(
                ChatHistoryEntity.conv_uid == conv_uid
            )
            updated = chat_history.update({ChatHistoryEntity.messages: message})
            session.commit()
            return updated.id
        finally:
            session.close()

    def delete(self, conv_uid: int):
        session = self.get_session()
        if conv_uid is None:
            raise Exception("conv_uid is None")

        chat_history = session.query(ChatHistoryEntity)
        chat_history = chat_history.filter(
            ChatHistoryEntity.conv_uid == conv_uid
        )
        chat_history.delete()
        session.commit()
        session.close()

    def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
        session = self.get_session()
        chat_history = session.query(ChatHistoryEntity)
        chat_history = chat_history.filter(
            ChatHistoryEntity.conv_uid == conv_uid
        )
        result = chat_history.first()
        session.close()
        return result
@@ -10,6 +10,7 @@ from pilot.scene.message import (
    _conversation_to_dic,
)
from pilot.common.formatting import MyEncoder
from ..base import MemoryStoreType

default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
@@ -19,6 +20,8 @@ CFG = Config()

class DuckdbHistoryMemory(BaseChatHistoryMemory):
    store_type: str = MemoryStoreType.DuckDb.value

    def __init__(self, chat_session_id: str):
        self.chat_seesion_id = chat_session_id
        os.makedirs(default_db_path, exist_ok=True)
@@ -117,30 +120,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        cursor.commit()
        return True

    @staticmethod
    def conv_list(cls, user_name: str = None) -> None:
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            if user_name:
                cursor.execute(
                    "SELECT * FROM chat_history where user_name=? order by id desc limit 20",
                    [user_name],
                )
            else:
                cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
            # get the field names of the query result
            fields = [field[0] for field in cursor.description]
            data = []
            for row in cursor.fetchall():
                row_dict = {}
                for i, field in enumerate(fields):
                    row_dict[field] = row[i]
                data.append(row_dict)

            return data

        return []

    def conv_info(self, conv_uid: str = None) -> None:
        cursor = self.connect.cursor()
        cursor.execute(
@@ -168,3 +147,28 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
            if context[0]:
                return json.loads(context[0])
        return None

    @staticmethod
    def conv_list(cls, user_name: str = None) -> None:
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            if user_name:
                cursor.execute(
                    "SELECT * FROM chat_history where user_name=? order by id desc limit 20",
                    [user_name],
                )
            else:
                cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
            # get the field names of the query result
            fields = [field[0] for field in cursor.description]
            data = []
            for row in cursor.fetchall():
                row_dict = {}
                for i, field in enumerate(fields):
                    row_dict[field] = row[i]
                data.append(row_dict)

            return data

        return []
@@ -11,12 +11,14 @@ from pilot.scene.message import (
    conversation_from_dict,
    conversations_to_dict,
)

from pilot.memory.chat_history.base import MemoryStoreType

CFG = Config()


class FileHistoryMemory(BaseChatHistoryMemory):
    store_type: str = MemoryStoreType.File.value

    def __init__(self, chat_session_id: str):
        now = datetime.datetime.now()
        date_string = now.strftime("%Y%m%d")
@@ -47,3 +49,5 @@ class FileHistoryMemory(BaseChatHistoryMemory):

    def clear(self) -> None:
        self.file_path.write_text(json.dumps([]))
@@ -4,11 +4,14 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.configs.config import Config
from pilot.scene.message import OnceConversation
from pilot.common.custom_data_structure import FixedSizeDict
from pilot.memory.chat_history.base import MemoryStoreType

CFG = Config()


class MemHistoryMemory(BaseChatHistoryMemory):
    store_type: str = MemoryStoreType.Memory.value

    histroies_map = FixedSizeDict(100)

    def __init__(self, chat_session_id: str):
pilot/memory/chat_history/store_type/meta_db_history.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import json
import logging
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.scene.message import (
    OnceConversation,
    _conversation_to_dic,
)
from ..chat_history_db import ChatHistoryEntity, ChatHistoryDao

from pilot.memory.chat_history.base import MemoryStoreType
CFG = Config()
logger = logging.getLogger("db_chat_history")

class DbHistoryMemory(BaseChatHistoryMemory):
    store_type: str = MemoryStoreType.DB.value
    def __init__(self, chat_session_id: str):
        self.chat_seesion_id = chat_session_id
        self.chat_history_dao = ChatHistoryDao()

    def messages(self) -> List[OnceConversation]:
        chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
        if chat_history:
            context = chat_history.messages
            if context:
                conversations: List[OnceConversation] = json.loads(context)
                return conversations
        return []

    def create(self, chat_mode, summary: str, user_name: str) -> None:
        try:
            chat_history: ChatHistoryEntity = ChatHistoryEntity()
            chat_history.chat_mode = chat_mode
            chat_history.summary = summary
            chat_history.user_name = user_name

            self.chat_history_dao.update(chat_history)
        except Exception as e:
            logger.error("init create conversation log error!" + str(e))

    def append(self, once_message: OnceConversation) -> None:
        logger.info("db history append:{}", once_message)
        chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
        conversations: List[OnceConversation] = []
        if chat_history:
            context = chat_history.messages
            if context:
                conversations = json.loads(context)
            else:
                chat_history.summary = once_message.get_user_conv().content
        else:
            chat_history: ChatHistoryEntity = ChatHistoryEntity()
            chat_history.conv_uid = self.chat_seesion_id
            chat_history.chat_mode = once_message.chat_mode
            chat_history.user_name = "default"
            chat_history.summary = once_message.get_user_conv().content

        conversations.append(_conversation_to_dic(once_message))
        chat_history.messages = json.dumps(conversations, ensure_ascii=False)

        self.chat_history_dao.update(chat_history)

    def update(self, messages: List[OnceConversation]) -> None:
        self.chat_history_dao.update_message_by_uid(json.dumps(messages, ensure_ascii=False), self.chat_seesion_id)

    def delete(self) -> bool:
        self.chat_history_dao.delete(self.chat_seesion_id)

    def conv_info(self, conv_uid: str = None) -> None:
        logger.info("conv_info:{}", conv_uid)
        chat_history = self.chat_history_dao.get_by_uid(conv_uid)
        return chat_history.__dict__

    def get_messages(self) -> List[OnceConversation]:
        logger.info("get_messages:{}", self.chat_seesion_id)
        chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
        if chat_history:
            context = chat_history.messages
            return json.loads(context)
        return []

    @staticmethod
    def conv_list(cls, user_name: str = None) -> None:
        chat_history_dao = ChatHistoryDao()
        history_list = chat_history_dao.list_last_20()
        result = []
        for history in history_list:
            result.append(history.__dict__)
        return result
@@ -142,12 +142,12 @@ def initialize_controller(
        controller.backend = LocalModelController()

    if app:
        app.include_router(router, prefix="/api")
        app.include_router(router, prefix="/api", tags=['Model'])
    else:
        import uvicorn

        app = FastAPI()
        app.include_router(router, prefix="/api")
        app.include_router(router, prefix="/api", tags=['Model'])
        uvicorn.run(app, host=host, port=port, log_level="info")
@@ -8,13 +8,10 @@ from fastapi import (
    Request,
    File,
    UploadFile,
    Form,
    Body,
    BackgroundTasks,
)

from fastapi.responses import StreamingResponse
from fastapi.exceptions import RequestValidationError
from typing import List
import tempfile

@@ -35,11 +32,10 @@ from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.common.schema import DBType
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory

router = APIRouter()
CFG = Config()
@@ -61,7 +57,6 @@ def __get_conv_user_message(conversations: dict):

def __new_conversation(chat_mode, user_id) -> ConversationVo:
    unique_id = uuid.uuid1()
    # history_mem = DuckdbHistoryMemory(str(unique_id))
    return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)

@@ -145,7 +140,8 @@ async def db_support_types():
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(user_id: str = None):
    dialogues: List = []
    datas = DuckdbHistoryMemory.conv_list(user_id)
    chat_history_service = ChatHistory()
    datas = chat_history_service.get_store_cls().conv_list(user_id)
    for item in datas:
        conv_uid = item.get("conv_uid")
        summary = item.get("summary")
@@ -257,14 +253,18 @@ async def params_load(

@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
    history_mem = DuckdbHistoryMemory(con_uid)

    history_fac = ChatHistory()
    history_mem = history_fac.get_store_instance(con_uid)
    history_mem.delete()
    return Result.succ(None)

def get_hist_messages(conv_uid: str):
    message_vos: List[MessageVo] = []
    history_mem = DuckdbHistoryMemory(conv_uid)
    history_fac = ChatHistory()
    history_mem = history_fac.get_store_instance(conv_uid)

    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        for once in history_messages:
@@ -26,10 +26,10 @@ from pilot.openapi.editor_view_model import (
)

from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.scene.chat_db.data_loader import DbDataLoader
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory

router = APIRouter()
CFG = Config()
@@ -67,7 +67,8 @@ async def get_editor_tables(
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
async def get_editor_sql_rounds(con_uid: str):
    logger.info("get_editor_sql_rounds:{con_uid}")
    history_mem = DuckdbHistoryMemory(con_uid)
    chat_history_fac = ChatHistory()
    history_mem = chat_history_fac.get_store_instance(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        result: List = []
@@ -89,7 +90,8 @@ async def get_editor_sql_rounds(con_uid: str):
@router.get("/v1/editor/sql", response_model=Result[dict])
async def get_editor_sql(con_uid: str, round: int):
    logger.info(f"get_editor_sql:{con_uid},{round}")
    history_mem = DuckdbHistoryMemory(con_uid)
    chat_history_fac = ChatHistory()
    history_mem = chat_history_fac.get_store_instance(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        for once in history_messages:
@@ -138,7 +140,9 @@ async def editor_sql_run(run_param: dict = Body()):
@router.post("/v1/sql/editor/submit")
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
    logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
    history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)

    chat_history_fac = ChatHistory()
    history_mem = chat_history_fac.get_store_instance(sql_edit_context.con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
@@ -171,7 +175,8 @@ async def get_editor_chart_list(con_uid: str):
    logger.info(
        f"get_editor_sql_rounds:{con_uid}",
    )
    history_mem = DuckdbHistoryMemory(con_uid)
    chat_history_fac = ChatHistory()
    history_mem = chat_history_fac.get_store_instance(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        last_round = max(history_messages, key=lambda x: x["chat_order"])
@@ -193,7 +198,8 @@ async def get_editor_chart_info(param: dict = Body()):
    conv_uid = param["con_uid"]
    chart_title = param["chart_title"]

    history_mem = DuckdbHistoryMemory(conv_uid)
    chat_history_fac = ChatHistory()
    history_mem = chat_history_fac.get_store_instance(conv_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        last_round = max(history_messages, key=lambda x: x["chat_order"])
@@ -269,7 +275,9 @@ async def editor_chart_run(run_param: dict = Body()):
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
    logger.info(f"sql_editor_submit:{chart_edit_context.__dict__}")
    history_mem = DuckdbHistoryMemory(chart_edit_context.conv_uid)

    chat_history_fac = ChatHistory()
    history_mem = chat_history_fac.get_store_instance(chart_edit_context.con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
@@ -7,15 +7,12 @@ from typing import Any, List, Dict

from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation
from pilot.utils import build_logger, get_or_create_event_loop
from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory

logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"}
@@ -54,9 +51,9 @@ class BaseChat(ABC):
                proxyllm_backend=CFG.PROXYLLM_BACKEND,
            )
        )

        chat_history_fac = ChatHistory()
        ### the storage method is configurable
        self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
        self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])

        self.history_message: List[OnceConversation] = self.memory.messages()
        self.current_message: OnceConversation = OnceConversation(
@@ -162,6 +159,7 @@ class BaseChat(ABC):
                    output, self.skip_echo_len
                )
                view_msg = self.stream_plugin_call(msg)
                view_msg = view_msg.replace("\n", "\\n")
                yield view_msg
            self.current_message.add_ai_message(msg)
            self.current_message.add_view_message(view_msg)
@@ -9,6 +9,8 @@ from pilot.base_modules.agent.commands.command_mange import ApiCall
from pilot.base_modules.agent import PluginPromptGenerator
from pilot.common.string_utils import extract_content
from .prompt import prompt
from pilot.componet import ComponetType
from pilot.base_modules.agent.controller import ModuleAgent

CFG = Config()

@@ -27,14 +29,10 @@ class ChatAgent(BaseChat):
        super().__init__(chat_param=chat_param)
        self.plugins_prompt_generator = PluginPromptGenerator()
        self.plugins_prompt_generator.command_registry = CFG.command_registry
        # load the selected plugins
        for plugin in CFG.plugins:
            if plugin._name in self.select_plugins:
                if not plugin.can_handle_post_prompt():
                    continue
                self.plugins_prompt_generator = plugin.post_prompt(
                    self.plugins_prompt_generator
                )

        # load the selected plugins
        agent_module = CFG.SYSTEM_APP.get_componet(ComponetType.AGENT_HUB, ModuleAgent)
        self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)

        self.api_call = ApiCall(self.plugins_prompt_generator)
@@ -32,7 +32,7 @@ _DEFAULT_TEMPLATE_ZH = """
根据用户目标，请一步步思考，如何在满足下面约束条件的前提下，优先使用给出工具回答或者完成用户目标。

约束条件:
1.从下面给定工具列表找到可用的工具后，请确保输出结果包含以下内容用来使用工具:
1.从下面给定工具列表找到可用的工具后，请输出以下内容用来使用工具, 注意要确保下面内容在输出结果中只出现一次:
<api-call><name>Selected Tool name</name><args><arg1>value</arg1><arg2>value</arg2></args></api-call>
2.请根据工具列表对应工具的定义来生成上述调用文本, 参考案例如下:
工具作用介绍: "工具名称", args: "参数1": "<参数1取值描述>","参数2": "<参数2取值描述>" 对应调用文本:<api-call><name>工具名称</name><args><参数1>value</参数1><参数2>value</参数2></args></api-call>
@@ -32,7 +32,6 @@ def async_db_summery(system_app: SystemApp):
def server_init(args, system_app: SystemApp):
    from pilot.base_modules.agent.commands.command_mange import CommandRegistry

    from pilot.base_modules.agent.plugins_util import scan_plugins

    # logger.info(f"args: {args}")

@@ -45,7 +44,7 @@ def server_init(args, system_app: SystemApp):
    # load_native_plugins(cfg)
    signal.signal(signal.SIGINT, signal_handler)

    cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode))

    # Load plugins and commands
    command_categories = [
@@ -24,9 +24,11 @@ def initialize_componets(
    embedding_model_path: str,
):
    from pilot.model.cluster.controller.controller import controller

    system_app.register_instance(controller)

    from pilot.base_modules.agent.controller import module_agent
    system_app.register_instance(module_agent)

    _initialize_embedding_model(
        param, system_app, embedding_model_name, embedding_model_path
    )
@@ -71,11 +71,11 @@ app.add_middleware(
)


app.include_router(api_v1, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(agent_route, prefix="/api")
app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])

app.include_router(knowledge_router)

app.include_router(knowledge_router, tags=["Knowledge"])


def mount_static_files(app):