feat(Agent): ChatAgent And AgentHub

1.AgentHub Module complete
2.New ChatAgent Mode complete
3.MetaData develop,auto update ddl
This commit is contained in:
yhjun1026 2023-09-26 17:07:42 +08:00
parent bf84663b83
commit ed702db9bc
29 changed files with 549 additions and 187 deletions

View File

@ -6,7 +6,7 @@ from fastapi import (
UploadFile,
File,
)
from abc import ABC, abstractmethod
from typing import List
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
@ -18,14 +18,44 @@ from pilot.openapi.api_view_model import (
from .model import PluginHubParam, PagenationFilter, PagenationResult, PluginHubFilter, MyPluginFilter
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
from .db.my_plugin_db import MyPluginEntity
from .plugins_util import scan_plugins
from .commands.generator import PluginPromptGenerator
from pilot.configs.model_config import PLUGINS_DIR
from pilot.componet import BaseComponet, ComponetType, SystemApp
router = APIRouter()
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
@router.post("/api/v1/agent/hub/update", response_model=Result[str])
class ModuleAgent(BaseComponet, ABC):
name = ComponetType.AGENT_HUB
def __init__(self):
#load plugins
self.plugins = scan_plugins(PLUGINS_DIR)
def init_app(self, system_app: SystemApp):
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
def refresh_plugins(self):
self.plugins = scan_plugins(PLUGINS_DIR)
def load_select_plugin(self, generator:PluginPromptGenerator, select_plugins:List[str])->PluginPromptGenerator:
logger.info(f"load_select_plugin:{select_plugins}")
# load select plugin
for plugin in self.plugins:
if plugin._name in select_plugins:
if not plugin.can_handle_post_prompt():
continue
generator = plugin.post_prompt(generator)
return generator
module_agent = ModuleAgent()
@router.post("/v1/agent/hub/update", response_model=Result[str])
async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}")
try:
@ -38,14 +68,15 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):
@router.post("/api/v1/agent/query", response_model=Result[str])
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
logger.info(f"get_agent_list:{json.dumps(filter)}")
logger.info(f"get_agent_list:{filter.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
filter_enetity:PluginHubEntity = PluginHubEntity()
attrs = vars(filter.filter) # 获取原始对象的属性字典
for attr, value in attrs.items():
setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值
if filter.filter:
attrs = vars(filter.filter) # 获取原始对象的属性字典
for attr, value in attrs.items():
setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值
datas, total_pages, total_count = agent_hub.hub_dao.list(filter_enetity, filter.page_index, filter.page_size)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
@ -54,21 +85,30 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
result.total_page = total_pages
result.total_row_count = total_count
result.datas = datas
return Result.succ(result)
# print(json.dumps(result.to_dic()))
return Result.succ(result.to_dic())
@router.post("/api/v1/agent/my", response_model=Result[str])
@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user:str= None):
logger.info(f"my_agents:{json.dumps(my_agents)}")
logger.info(f"my_agents:{user}")
agent_hub = AgentHub(PLUGINS_DIR)
return Result.succ(agent_hub.get_my_plugin(user))
agents = agent_hub.get_my_plugin(user)
agent_dicts = []
for agent in agents:
agent_dicts.append(agent.__dict__)
return Result.succ(agent_dicts)
@router.post("/api/v1/agent/install", response_model=Result[str])
@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name:str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.install_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
@ -76,19 +116,21 @@ async def agent_install(plugin_name:str, user: str = None):
@router.post("/api/v1/agent/uninstall", response_model=Result[str])
@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name:str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.uninstall_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
@router.post("/api/v1/personal/agent/upload", response_model=Result[str])
@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:

View File

@ -14,13 +14,13 @@ class MyPluginEntity(Base):
__tablename__ = 'my_plugin'
id = Column(Integer, primary_key=True, comment="autoincrement id")
tenant = Column(String, nullable=True, comment="user's tenant")
user_code = Column(String, nullable=True, comment="user code")
user_name = Column(String, nullable=True, comment="user name")
name = Column(String, unique=True, nullable=False, comment="plugin name")
file_name = Column(String, nullable=False, comment="plugin package file name")
type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version")
tenant = Column(String(255), nullable=True, comment="user's tenant")
user_code = Column(String(255), nullable=True, comment="user code")
user_name = Column(String(255), nullable=True, comment="user name")
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
file_name = Column(String(255), nullable=False, comment="plugin package file name")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")

View File

@ -12,15 +12,15 @@ from pilot.base_modules.meta_data.meta_data import Base, engine, session
class PluginHubEntity(Base):
__tablename__ = 'plugin_hub'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
name = Column(String, unique=True, nullable=False, comment="plugin name")
description = Column(String, nullable=False, comment="plugin description")
author = Column(String, nullable=True, comment="plugin author")
email = Column(String, nullable=True, comment="plugin author email")
type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version")
storage_channel = Column(String, comment="plugin storage channel")
storage_url = Column(String, comment="plugin download url")
download_param = Column(String, comment="plugin download param")
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
description = Column(String(255), nullable=False, comment="plugin description")
author = Column(String(255), nullable=True, comment="plugin author")
email = Column(String(255), nullable=True, comment="plugin author email")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
storage_channel = Column(String(255), comment="plugin storage channel")
storage_url = Column(String(255), comment="plugin download url")
download_param = Column(String(255), comment="plugin download param")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
installed = Column(Integer, default=False, comment="plugin already installed count")
@ -146,11 +146,10 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
session = self.get_session()
if plugin_id is None:
raise Exception("plugin_id is None")
query = PluginHubEntity(id=plugin_id)
plugin_hubs = session.query(PluginHubEntity)
if query.id is not None:
if plugin_id is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.id == query.id
PluginHubEntity.id == plugin_id
)
plugin_hubs.delete()
session.commit()

View File

@ -17,6 +17,18 @@ class PagenationResult(BaseModel, Generic[T]):
total_row_count: int = 0
datas: List[T] = []
def to_dic(self):
data_dicts =[]
for item in self.datas:
data_dicts.append(item.__dict__)
return {
'page_index': self.page_index,
'page_size': self.page_size,
'total_page': self.total_page,
'total_row_count': self.total_row_count,
'datas': data_dicts
}
@dataclass
class PluginHubFilter(BaseModel):
name: str
@ -41,10 +53,9 @@ class MyPluginFilter(BaseModel):
class PluginHubParam(BaseModel):
channel: str = Field(..., description="Plugin storage channel")
url: str = Field(..., description="Plugin storage url")
branch: Optional[str] = Field(None, description="github download branch", nullable=True)
channel: Optional[str] = Field("git", description="Plugin storage channel")
url: Optional[str] = Field("https://github.com/eosphoros-ai/DB-GPT-Plugins.git", description="Plugin storage url")
branch: Optional[str] = Field("main", description="github download branch", nullable=True)
authorization: Optional[str] = Field(None, description="github download authorization", nullable=True)

View File

@ -2,33 +2,72 @@ import uuid
import os
import duckdb
import sqlite3
import logging
import fnmatch
from datetime import datetime
from typing import Optional, Type, TypeVar
import sqlalchemy as sa
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate,upgrade
from flask.cli import with_appcontext
import subprocess
from sqlalchemy import create_engine,DateTime, String, func, MetaData
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from alembic import context, command
from alembic.config import Config
from alembic.config import Config as AlembicConfig
from urllib.parse import quote
from pilot.configs.config import Config
logger = logging.getLogger("meta_data")
CFG = Config()
default_db_path = os.path.join(os.getcwd(), "meta_data")
os.makedirs(default_db_path, exist_ok=True)
db_path = default_db_path + "/dbgpt.db"
# Meta Info
db_name = "dbgpt"
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)
engine = create_engine(f'sqlite:///{db_path}')
if CFG.LOCAL_DB_TYPE == 'mysql':
engine_temp = create_engine(f"mysql+pymysql://"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
)
# check and auto create mysqldatabase
try:
# try to connect
with engine_temp.connect() as conn:
conn.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}")
print(f"Already connect '{db_name}'")
except OperationalError as e:
# if connect failed, create dbgpt database
logger.error(f"{db_name} not connect success!")
engine = create_engine(f"mysql+pymysql://"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
+ f"/{db_name}"
)
else:
engine = create_engine(f'sqlite:///{db_path}')
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
@ -40,7 +79,7 @@ Base = declarative_base(bind=engine)
# 创建Alembic配置对象
alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = Config(alembic_ini_path)
alembic_cfg = AlembicConfig(alembic_ini_path)
alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))
@ -59,36 +98,6 @@ alembic_cfg.attributes['session'] = session
# Base.metadata.drop_all(engine)
# app = Flask(__name__)
# default_db_path = os.path.join(os.getcwd(), "meta_data")
# duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/dbgpt.db")
# app.config['SQLALCHEMY_DATABASE_URI'] = f'duckdb://{duckdb_path}'
# app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# db = SQLAlchemy(app)
# migrate = Migrate(app, db)
#
# # 设置FLASK_APP环境变量
# import os
# os.environ['FLASK_APP'] = 'server.dbgpt_server.py'
#
# @app.cli.command("db_init")
# @with_appcontext
# def db_init():
# subprocess.run(["flask", "db", "init"])
#
# @app.cli.command("db_migrate")
# @with_appcontext
# def db_migrate():
# subprocess.run(["flask", "db", "migrate"])
#
# @app.cli.command("db_upgrade")
# @with_appcontext
# def db_upgrade():
# subprocess.run(["flask", "db", "upgrade"])
#
def ddl_init_and_upgrade():
# Base.metadata.create_all(bind=engine)
# 生成并应用迁移脚本
@ -96,14 +105,8 @@ def ddl_init_and_upgrade():
# subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"])
with engine.connect() as connection:
alembic_cfg.attributes['connection'] = connection
command.revision(alembic_cfg, "test", True)
heads = command.heads(alembic_cfg)
print("heads:" + str(heads))
command.revision(alembic_cfg, "dbgpt ddl upate", True)
command.upgrade(alembic_cfg, "head")
# alembic_cfg.attributes['connection'] = engine.connect()
# command.upgrade(alembic_cfg, 'head')
# with app.app_context():
# db_init()
# db_migrate()
# db_upgrade()

View File

@ -0,0 +1,2 @@

View File

@ -44,6 +44,7 @@ class LifeCycle:
class ComponetType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
MODEL_CONTROLLER = "dbgpt_model_controller"
AGENT_HUB = "dbgpt_agent_hub"
class BaseComponet(LifeCycle, ABC):

View File

@ -98,26 +98,6 @@ class Config(metaclass=Singleton):
### message stor file
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
from auto_gpt_plugin_template import AutoGPTPluginTemplate
self.plugins: List[AutoGPTPluginTemplate] = []
self.plugins_openai = []
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
if plugins_allowlist:
self.plugins_allowlist = plugins_allowlist.split(",")
else:
self.plugins_allowlist = []
plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
if plugins_denylist:
self.plugins_denylist = plugins_denylist.split(",")
else:
self.plugins_denylist = []
### Native SQL Execution Capability Control Configuration
self.NATIVE_SQL_CAN_RUN_DDL = (
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
@ -126,7 +106,10 @@ class Config(metaclass=Singleton):
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
)
### default Local database connection configuration
self.LOCAL_DB_MANAGE = None
###dbgpt meta info database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
@ -138,7 +121,8 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
self.LOCAL_DB_MANAGE = None
self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "duckdb")
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
@ -197,10 +181,6 @@ class Config(metaclass=Singleton):
"""Set the debug mode value"""
self.debug_mode = value
def set_plugins(self, value: list) -> None:
"""Set the plugins value."""
self.plugins = value
def set_templature(self, value: int) -> None:
"""Set the temperature value."""
self.temperature = value

View File

@ -0,0 +1 @@
from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao

View File

@ -0,0 +1,62 @@
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
class ConnectConfigEntity(Base):
__tablename__ = 'connect_config'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
db_type = Column(String(255), nullable=False, comment="db type")
db_name = Column(String(255), nullable=False, comment="db name")
db_path = Column(String(255), nullable=True, comment="file db path")
db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
db_port = Column(String(255), nullable=True, comment="db cnnect port(not file db)")
db_user = Column(String(255), nullable=True, comment="db user")
db_pwd = Column(String(255), nullable=True, comment="db password")
comment = Column(Text, nullable=True, comment="db comment")
__table_args__ = (
UniqueConstraint('db_name', name="uk_db"),
Index('idx_q_db_type', 'db_type'),
)
class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def __init__(self):
super().__init__(
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def update(self, entity: ConnectConfigEntity):
session = self.get_session()
try:
updated = session.merge(entity)
session.commit()
return updated.id
finally:
session.close()
def delete(self, db_name: int):
session = self.get_session()
if db_name is None:
raise Exception("db_name is None")
db_connect = session.query(ConnectConfigEntity)
db_connect = db_connect.filter(
ConnectConfigEntity.db_name == db_name
)
db_connect.delete()
session.commit()
session.close()
def get_by_name(self, db_name: str) -> ConnectConfigEntity:
session = self.get_session()
db_connect = session.query(ConnectConfigEntity)
db_connect = db_connect.filter(
ConnectConfigEntity.db_name == db_name
)
result = db_connect.first()
session.close()
return result

View File

@ -50,7 +50,7 @@ class ConnectManager:
def __init__(self, system_app: SystemApp):
self.storage = DuckdbConnectConfig()
self.db_summary_client = DBSummaryClient(system_app)
self.__load_config_db()
# self.__load_config_db()
def __load_config_db(self):
if CFG.LOCAL_DB_HOST:

View File

@ -0,0 +1 @@
from .chat_history.chat_history_db import ChatHistoryEntity, ChatHistoryDao

View File

@ -2,11 +2,18 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from enum import Enum
from pilot.scene.message import OnceConversation
class MemoryStoreType(Enum):
File= 'file'
Memory = 'memory'
DB = 'db'
DuckDb = 'duckdb'
class BaseChatHistoryMemory(ABC):
store_type: MemoryStoreType
def __init__(self):
self.conversations: List[OnceConversation] = []
@ -22,10 +29,31 @@ class BaseChatHistoryMemory(ABC):
def append(self, message: OnceConversation) -> None:
"""Append the message to the record in the local file"""
@abstractmethod
def clear(self) -> None:
"""Clear session memory from the local file"""
# @abstractmethod
# def clear(self) -> None:
# """Clear session memory from the local file"""
@abstractmethod
def conv_list(self, user_name: str = None) -> None:
"""get user's conversation list"""
pass
@abstractmethod
def update(self, messages: List[OnceConversation]) -> None:
pass
@abstractmethod
def delete(self) -> bool:
pass
@abstractmethod
def conv_info(self, conv_uid: str = None) -> None:
pass
@abstractmethod
def get_messages(self) -> List[OnceConversation]:
pass
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
pass

View File

@ -0,0 +1,28 @@
from .base import MemoryStoreType
from pilot.configs.config import Config
CFG = Config()
class ChatHistory:
def __init__(self):
self.memory_type = MemoryStoreType.DB.value
self.mem_store_class_map = {}
from .store_type.duckdb_history import DuckdbHistoryMemory
from .store_type.file_history import FileHistoryMemory
from .store_type.meta_db_history import DbHistoryMemory
from .store_type.mem_history import MemHistoryMemory
self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
def get_store_instance(self, chat_session_id):
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(chat_session_id)
def get_store_cls(self):
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)

View File

@ -0,0 +1,88 @@
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
class ChatHistoryEntity(Base):
__tablename__ = 'chat_history'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
conv_uid = Column(String(255), unique=False, nullable=False, comment="Conversation record unique id")
chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
summary = Column(String(255), nullable=False, comment="Conversation record summary")
user_name = Column(String(255), nullable=True, comment="interlocutor")
messages = Column(Text, nullable=True, comment="Conversation details")
__table_args__ = (
UniqueConstraint('conv_uid', name="uk_conversation"),
Index('idx_q_user', 'user_name'),
Index('idx_q_mode', 'chat_mode'),
Index('idx_q_conv', 'summary'),
)
class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
def __init__(self):
super().__init__(
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def list_last_20(self, user_name: str = None):
session = self.get_session()
chat_history = session.query(ChatHistoryEntity)
if user_name:
chat_history = chat_history.filter(
ChatHistoryEntity.user_name == user_name
)
chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())
result = chat_history.limit(20).all()
session.close()
return result
def update(self, entity: ChatHistoryEntity):
session = self.get_session()
try:
updated = session.merge(entity)
session.commit()
return updated.id
finally:
session.close()
def update_message_by_uid(self, message: str, conv_uid:str):
session = self.get_session()
try:
chat_history = session.query(ChatHistoryEntity)
chat_history = chat_history.filter(
ChatHistoryEntity.conv_uid == conv_uid
)
updated = chat_history.update({ChatHistoryEntity.messages: message})
session.commit()
return updated.id
finally:
session.close()
def delete(self, conv_uid: int):
session = self.get_session()
if conv_uid is None:
raise Exception("conv_uid is None")
chat_history = session.query(ChatHistoryEntity)
chat_history = chat_history.filter(
ChatHistoryEntity.conv_uid == conv_uid
)
chat_history.delete()
session.commit()
session.close()
def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
session = self.get_session()
chat_history = session.query(ChatHistoryEntity)
chat_history = chat_history.filter(
ChatHistoryEntity.conv_uid == conv_uid
)
result = chat_history.first()
session.close()
return result

View File

@ -10,6 +10,7 @@ from pilot.scene.message import (
_conversation_to_dic,
)
from pilot.common.formatting import MyEncoder
from ..base import MemoryStoreType
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
@ -19,6 +20,8 @@ CFG = Config()
class DuckdbHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.DuckDb.value
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
os.makedirs(default_db_path, exist_ok=True)
@ -117,30 +120,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
cursor.commit()
return True
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
if user_name:
cursor.execute(
"SELECT * FROM chat_history where user_name=? order by id desc limit 20",
[user_name],
)
else:
cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
# 获取查询结果字段名
fields = [field[0] for field in cursor.description]
data = []
for row in cursor.fetchall():
row_dict = {}
for i, field in enumerate(fields):
row_dict[field] = row[i]
data.append(row_dict)
return data
return []
def conv_info(self, conv_uid: str = None) -> None:
cursor = self.connect.cursor()
cursor.execute(
@ -168,3 +147,28 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if context[0]:
return json.loads(context[0])
return None
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
if user_name:
cursor.execute(
"SELECT * FROM chat_history where user_name=? order by id desc limit 20",
[user_name],
)
else:
cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
# 获取查询结果字段名
fields = [field[0] for field in cursor.description]
data = []
for row in cursor.fetchall():
row_dict = {}
for i, field in enumerate(fields):
row_dict[field] = row[i]
data.append(row_dict)
return data
return []

View File

@ -11,12 +11,14 @@ from pilot.scene.message import (
conversation_from_dict,
conversations_to_dict,
)
from pilot.memory.chat_history.base import MemoryStoreType
CFG = Config()
class FileHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.File.value
def __init__(self, chat_session_id: str):
now = datetime.datetime.now()
date_string = now.strftime("%Y%m%d")
@ -47,3 +49,5 @@ class FileHistoryMemory(BaseChatHistoryMemory):
def clear(self) -> None:
self.file_path.write_text(json.dumps([]))

View File

@ -4,11 +4,14 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.configs.config import Config
from pilot.scene.message import OnceConversation
from pilot.common.custom_data_structure import FixedSizeDict
from pilot.memory.chat_history.base import MemoryStoreType
CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.Memory.value
histroies_map = FixedSizeDict(100)
def __init__(self, chat_session_id: str):

View File

@ -0,0 +1,100 @@
import json
import logging
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.scene.message import (
OnceConversation,
_conversation_to_dic,
)
from ..chat_history_db import ChatHistoryEntity, ChatHistoryDao
from pilot.memory.chat_history.base import MemoryStoreType
CFG = Config()
logger = logging.getLogger("db_chat_history")
class DbHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.DB.value
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
self.chat_history_dao = ChatHistoryDao()
def messages(self) -> List[OnceConversation]:
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
if chat_history:
context = chat_history.messages
if context:
conversations: List[OnceConversation] = json.loads(context)
return conversations
return []
def create(self, chat_mode, summary: str, user_name: str) -> None:
try:
chat_history: ChatHistoryEntity = ChatHistoryEntity()
chat_history.chat_mode = chat_mode
chat_history.summary = summary
chat_history.user_name = user_name
self.chat_history_dao.update(chat_history)
except Exception as e:
logger.error("init create conversation log error" + str(e))
def append(self, once_message: OnceConversation) -> None:
logger.info("db history append:{}", once_message)
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
conversations: List[OnceConversation] = []
if chat_history:
context = chat_history.messages
if context:
conversations = json.loads(context)
else:
chat_history.summary = once_message.get_user_conv().content
else:
chat_history: ChatHistoryEntity = ChatHistoryEntity()
chat_history.conv_uid = self.chat_seesion_id
chat_history.chat_mode = once_message.chat_mode
chat_history.user_name = "default"
chat_history.summary = once_message.get_user_conv().content
conversations.append(_conversation_to_dic(once_message))
chat_history.messages = json.dumps(conversations, ensure_ascii=False)
self.chat_history_dao.update(chat_history)
def update(self, messages: List[OnceConversation]) -> None:
self.chat_history_dao.update_message_by_uid(json.dumps(messages, ensure_ascii=False), self.chat_seesion_id)
def delete(self) -> bool:
self.chat_history_dao.delete(self.chat_seesion_id)
def conv_info(self, conv_uid: str = None) -> None:
logger.info("conv_info:{}", conv_uid)
chat_history = self.chat_history_dao.get_by_uid(conv_uid)
return chat_history.__dict__
def get_messages(self) -> List[OnceConversation]:
logger.info("get_messages:{}", self.chat_seesion_id)
chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
if chat_history:
context = chat_history.messages
return json.loads(context)
return []
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
chat_history_dao = ChatHistoryDao()
history_list = chat_history_dao.list_last_20()
result = []
for history in history_list:
result.append(history.__dict__)
return result

View File

@ -142,12 +142,12 @@ def initialize_controller(
controller.backend = LocalModelController()
if app:
app.include_router(router, prefix="/api")
app.include_router(router, prefix="/api", tags=['Model'])
else:
import uvicorn
app = FastAPI()
app.include_router(router, prefix="/api")
app.include_router(router, prefix="/api", tags=['Model'])
uvicorn.run(app, host=host, port=port, log_level="info")

View File

@ -8,13 +8,10 @@ from fastapi import (
Request,
File,
UploadFile,
Form,
Body,
BackgroundTasks,
)
from fastapi.responses import StreamingResponse
from fastapi.exceptions import RequestValidationError
from typing import List
import tempfile
@ -35,11 +32,10 @@ from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.common.schema import DBType
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
router = APIRouter()
CFG = Config()
@ -61,7 +57,6 @@ def __get_conv_user_message(conversations: dict):
def __new_conversation(chat_mode, user_id) -> ConversationVo:
unique_id = uuid.uuid1()
# history_mem = DuckdbHistoryMemory(str(unique_id))
return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
@ -145,7 +140,8 @@ async def db_support_types():
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(user_id: str = None):
dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id)
chat_history_service = ChatHistory()
datas = chat_history_service.get_store_cls().conv_list(user_id)
for item in datas:
conv_uid = item.get("conv_uid")
summary = item.get("summary")
@ -257,14 +253,18 @@ async def params_load(
@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
history_mem = DuckdbHistoryMemory(con_uid)
history_fac = ChatHistory()
history_mem = history_fac.get_store_instance(con_uid)
history_mem.delete()
return Result.succ(None)
def get_hist_messages(conv_uid: str):
message_vos: List[MessageVo] = []
history_mem = DuckdbHistoryMemory(conv_uid)
history_fac = ChatHistory()
history_mem = history_fac.get_store_instance(conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:

View File

@ -26,10 +26,10 @@ from pilot.openapi.editor_view_model import (
)
from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.scene.chat_db.data_loader import DbDataLoader
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
router = APIRouter()
CFG = Config()
@ -67,7 +67,8 @@ async def get_editor_tables(
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
async def get_editor_sql_rounds(con_uid: str):
logger.info("get_editor_sql_rounds:{con_uid}")
history_mem = DuckdbHistoryMemory(con_uid)
chat_history_fac = ChatHistory()
history_mem = chat_history_fac.get_store_instance(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
result: List = []
@ -89,7 +90,8 @@ async def get_editor_sql_rounds(con_uid: str):
@router.get("/v1/editor/sql", response_model=Result[dict])
async def get_editor_sql(con_uid: str, round: int):
logger.info(f"get_editor_sql:{con_uid},{round}")
history_mem = DuckdbHistoryMemory(con_uid)
chat_history_fac = ChatHistory()
history_mem = chat_history_fac.get_store_instance(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
@ -138,7 +140,9 @@ async def editor_sql_run(run_param: dict = Body()):
@router.post("/v1/sql/editor/submit")
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
chat_history_fac = ChatHistory()
history_mem = chat_history_fac.get_store_instance(sql_edit_context.con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
@ -171,7 +175,8 @@ async def get_editor_chart_list(con_uid: str):
logger.info(
f"get_editor_sql_rounds:{con_uid}",
)
history_mem = DuckdbHistoryMemory(con_uid)
chat_history_fac = ChatHistory()
history_mem = chat_history_fac.get_store_instance(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
last_round = max(history_messages, key=lambda x: x["chat_order"])
@ -193,7 +198,8 @@ async def get_editor_chart_info(param: dict = Body()):
conv_uid = param["con_uid"]
chart_title = param["chart_title"]
history_mem = DuckdbHistoryMemory(conv_uid)
chat_history_fac = ChatHistory()
history_mem = chat_history_fac.get_store_instance(conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
last_round = max(history_messages, key=lambda x: x["chat_order"])
@ -269,7 +275,9 @@ async def editor_chart_run(run_param: dict = Body()):
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
logger.info(f"sql_editor_submit:{chart_edit_context.__dict__}")
history_mem = DuckdbHistoryMemory(chart_edit_context.conv_uid)
chat_history_fac = ChatHistory()
history_mem = chat_history_fac.get_store_instance(chart_edit_context.con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()

View File

@ -7,15 +7,12 @@ from typing import Any, List, Dict
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation
from pilot.utils import build_logger, get_or_create_event_loop
from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"}
@ -54,9 +51,9 @@ class BaseChat(ABC):
proxyllm_backend=CFG.PROXYLLM_BACKEND,
)
)
chat_history_fac = ChatHistory()
### can configurable storage methods
self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(
@ -162,6 +159,7 @@ class BaseChat(ABC):
output, self.skip_echo_len
)
view_msg = self.stream_plugin_call(msg)
view_msg = view_msg.replace("\n", "\\n")
yield view_msg
self.current_message.add_ai_message(msg)
self.current_message.add_view_message(view_msg)

View File

@ -9,6 +9,8 @@ from pilot.base_modules.agent.commands.command_mange import ApiCall
from pilot.base_modules.agent import PluginPromptGenerator
from pilot.common.string_utils import extract_content
from .prompt import prompt
from pilot.componet import ComponetType
from pilot.base_modules.agent.controller import ModuleAgent
CFG = Config()
@ -27,14 +29,10 @@ class ChatAgent(BaseChat):
super().__init__(chat_param=chat_param)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry
# load select plugin
for plugin in CFG.plugins:
if plugin._name in self.select_plugins:
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(
self.plugins_prompt_generator
)
# load select plugin
agent_module = CFG.SYSTEM_APP.get_componet(ComponetType.AGENT_HUB, ModuleAgent)
self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)
self.api_call = ApiCall(self.plugins_prompt_generator)

View File

@ -32,7 +32,7 @@ _DEFAULT_TEMPLATE_ZH = """
根据用户目标请一步步思考如何在满足下面约束条件的前提下优先使用给出工具回答或者完成用户目标
约束条件:
1.从下面给定工具列表找到可用的工具后确保输出结果包含以下内容用来使用工具:
1.从下面给定工具列表找到可用的工具后输出以下内容用来使用工具, 注意要确保下面内容在输出结果中只出现一次:
<api-call><name>Selected Tool name</name><args><arg1>value</arg1><arg2>value</arg2></args></api-call>
2.请根据工具列表对应工具的定义来生成上述调用文本, 参考案例如下:
工具作用介绍: "工具名称", args: "参数1": "<参数1取值描述>","参数2": "<参数2取值描述>" 对应调用文本:<api-call><name>工具名称</name><args><参数1>value</参数1><参数2>value</参数2></args></api-call>

View File

@ -32,7 +32,6 @@ def async_db_summery(system_app: SystemApp):
def server_init(args, system_app: SystemApp):
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
from pilot.base_modules.agent.plugins_util import scan_plugins
# logger.info(f"args: {args}")
@ -45,7 +44,7 @@ def server_init(args, system_app: SystemApp):
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode))
# Loader plugins and commands
command_categories = [

View File

@ -24,9 +24,11 @@ def initialize_componets(
embedding_model_path: str,
):
from pilot.model.cluster.controller.controller import controller
system_app.register_instance(controller)
from pilot.base_modules.agent.controller import module_agent
system_app.register_instance(module_agent)
_initialize_embedding_model(
param, system_app, embedding_model_name, embedding_model_path
)

View File

@ -71,11 +71,11 @@ app.add_middleware(
)
app.include_router(api_v1, prefix="/api")
app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(agent_route, prefix="/api")
app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
app.include_router(knowledge_router)
app.include_router(knowledge_router, tags=["Knowledge"])
def mount_static_files(app):