mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-26 03:49:10 +00:00
style:format code
format code
This commit is contained in:
parent
0878f3c5d4
commit
682b1468d1
@ -43,7 +43,7 @@ class MyEncoder(json.JSONEncoder):
|
|||||||
def default(self, obj):
|
def default(self, obj):
|
||||||
if isinstance(obj, set):
|
if isinstance(obj, set):
|
||||||
return list(obj)
|
return list(obj)
|
||||||
elif hasattr(obj, '__dict__'):
|
elif hasattr(obj, "__dict__"):
|
||||||
return obj.__dict__
|
return obj.__dict__
|
||||||
else:
|
else:
|
||||||
return json.JSONEncoder.default(self, obj)
|
return json.JSONEncoder.default(self, obj)
|
@ -78,6 +78,7 @@ def load_native_plugins(cfg: Config):
|
|||||||
if not cfg.plugins_auto_load:
|
if not cfg.plugins_auto_load:
|
||||||
print("not auto load_native_plugins")
|
print("not auto load_native_plugins")
|
||||||
return
|
return
|
||||||
|
|
||||||
def load_from_git(cfg: Config):
|
def load_from_git(cfg: Config):
|
||||||
print("async load_native_plugins")
|
print("async load_native_plugins")
|
||||||
branch_name = cfg.plugins_git_branch
|
branch_name = cfg.plugins_git_branch
|
||||||
@ -85,16 +86,20 @@ def load_native_plugins(cfg: Config):
|
|||||||
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
|
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
|
||||||
try:
|
try:
|
||||||
session = requests.Session()
|
session = requests.Session()
|
||||||
response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
|
response = session.get(
|
||||||
headers={'Authorization': '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'})
|
url.format(repo=native_plugin_repo, branch=branch_name),
|
||||||
|
headers={"Authorization": "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"},
|
||||||
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
plugins_path_path = Path(PLUGINS_DIR)
|
plugins_path_path = Path(PLUGINS_DIR)
|
||||||
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
|
files = glob.glob(
|
||||||
|
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
|
||||||
|
)
|
||||||
for file in files:
|
for file in files:
|
||||||
os.remove(file)
|
os.remove(file)
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
time_str = now.strftime('%Y%m%d%H%M%S')
|
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||||
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
|
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
|
||||||
print(file_name)
|
print(file_name)
|
||||||
with open(file_name, "wb") as f:
|
with open(file_name, "wb") as f:
|
||||||
@ -110,7 +115,6 @@ def load_native_plugins(cfg: Config):
|
|||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||||
"""Scan the plugins directory for plugins and loads them.
|
"""Scan the plugins directory for plugins and loads them.
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ class SeparatorStyle(Enum):
|
|||||||
THREE = auto()
|
THREE = auto()
|
||||||
FOUR = auto()
|
FOUR = auto()
|
||||||
|
|
||||||
|
|
||||||
class ExampleType(Enum):
|
class ExampleType(Enum):
|
||||||
ONE_SHOT = "one_shot"
|
ONE_SHOT = "one_shot"
|
||||||
FEW_SHOT = "few_shot"
|
FEW_SHOT = "few_shot"
|
||||||
|
@ -5,8 +5,8 @@ from test_cls_base import TestBase
|
|||||||
|
|
||||||
class Test1(TestBase):
|
class Test1(TestBase):
|
||||||
mode: str = "456"
|
mode: str = "456"
|
||||||
|
|
||||||
def write(self):
|
def write(self):
|
||||||
self.test_values.append("x")
|
self.test_values.append("x")
|
||||||
self.test_values.append("y")
|
self.test_values.append("y")
|
||||||
self.test_values.append("g")
|
self.test_values.append("g")
|
||||||
|
|
||||||
|
@ -3,9 +3,11 @@ from pydantic import BaseModel
|
|||||||
from test_cls_base import TestBase
|
from test_cls_base import TestBase
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||||
|
|
||||||
|
|
||||||
class Test2(TestBase):
|
class Test2(TestBase):
|
||||||
test_2_values: List = []
|
test_2_values: List = []
|
||||||
mode: str = "789"
|
mode: str = "789"
|
||||||
|
|
||||||
def write(self):
|
def write(self):
|
||||||
self.test_values.append(1)
|
self.test_values.append(1)
|
||||||
self.test_values.append(2)
|
self.test_values.append(2)
|
||||||
|
@ -39,7 +39,9 @@ class KnowledgeEmbedding:
|
|||||||
return self.knowledge_embedding_client.read_batch()
|
return self.knowledge_embedding_client.read_batch()
|
||||||
|
|
||||||
def init_knowledge_embedding(self):
|
def init_knowledge_embedding(self):
|
||||||
return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
|
return get_knowledge_embedding(
|
||||||
|
self.knowledge_type, self.knowledge_source, self.vector_store_config
|
||||||
|
)
|
||||||
|
|
||||||
def similar_search(self, text, topk):
|
def similar_search(self, text, topk):
|
||||||
vector_client = VectorStoreConnector(
|
vector_client = VectorStoreConnector(
|
||||||
@ -56,3 +58,9 @@ class KnowledgeEmbedding:
|
|||||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||||
)
|
)
|
||||||
return vector_client.vector_name_exists()
|
return vector_client.vector_name_exists()
|
||||||
|
|
||||||
|
def delete_by_ids(self, ids):
|
||||||
|
vector_client = VectorStoreConnector(
|
||||||
|
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||||
|
)
|
||||||
|
vector_client.delete_by_ids(ids=ids)
|
||||||
|
@ -33,8 +33,6 @@ class BaseChatHistoryMemory(ABC):
|
|||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear session memory from the local file"""
|
"""Clear session memory from the local file"""
|
||||||
|
|
||||||
|
|
||||||
def conv_list(self, user_name: str = None) -> None:
|
def conv_list(self, user_name: str = None) -> None:
|
||||||
"""get user's conversation list"""
|
"""get user's conversation list"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -14,13 +14,12 @@ from pilot.common.formatting import MyEncoder
|
|||||||
|
|
||||||
default_db_path = os.path.join(os.getcwd(), "message")
|
default_db_path = os.path.join(os.getcwd(), "message")
|
||||||
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
|
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
|
||||||
table_name = 'chat_history'
|
table_name = "chat_history"
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||||
|
|
||||||
def __init__(self, chat_session_id: str):
|
def __init__(self, chat_session_id: str):
|
||||||
self.chat_seesion_id = chat_session_id
|
self.chat_seesion_id = chat_session_id
|
||||||
os.makedirs(default_db_path, exist_ok=True)
|
os.makedirs(default_db_path, exist_ok=True)
|
||||||
@ -28,15 +27,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
self.__init_chat_history_tables()
|
self.__init_chat_history_tables()
|
||||||
|
|
||||||
def __init_chat_history_tables(self):
|
def __init_chat_history_tables(self):
|
||||||
|
|
||||||
# 检查表是否存在
|
# 检查表是否存在
|
||||||
result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
result = self.connect.execute(
|
||||||
[table_name]).fetchall()
|
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
# 如果表不存在,则创建新表
|
# 如果表不存在,则创建新表
|
||||||
self.connect.execute(
|
self.connect.execute(
|
||||||
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")
|
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
|
||||||
|
)
|
||||||
|
|
||||||
def __get_messages_by_conv_uid(self, conv_uid: str):
|
def __get_messages_by_conv_uid(self, conv_uid: str):
|
||||||
cursor = self.connect.cursor()
|
cursor = self.connect.cursor()
|
||||||
@ -58,23 +58,46 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
conversations.append(once_message)
|
conversations.append(once_message)
|
||||||
cursor = self.connect.cursor()
|
cursor = self.connect.cursor()
|
||||||
if context:
|
if context:
|
||||||
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
|
cursor.execute(
|
||||||
[json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
|
"UPDATE chat_history set messages=? where conv_uid=?",
|
||||||
|
[
|
||||||
|
json.dumps(
|
||||||
|
conversations_to_dict(conversations),
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=4,
|
||||||
|
),
|
||||||
|
self.chat_seesion_id,
|
||||||
|
],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
|
cursor.execute(
|
||||||
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
|
"INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
|
||||||
|
[
|
||||||
|
self.chat_seesion_id,
|
||||||
|
"",
|
||||||
|
json.dumps(
|
||||||
|
conversations_to_dict(conversations),
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=4,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
cursor.commit()
|
cursor.commit()
|
||||||
self.connect.commit()
|
self.connect.commit()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
cursor = self.connect.cursor()
|
cursor = self.connect.cursor()
|
||||||
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
|
cursor.execute(
|
||||||
|
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
|
||||||
|
)
|
||||||
cursor.commit()
|
cursor.commit()
|
||||||
self.connect.commit()
|
self.connect.commit()
|
||||||
|
|
||||||
def delete(self) -> bool:
|
def delete(self) -> bool:
|
||||||
cursor = self.connect.cursor()
|
cursor = self.connect.cursor()
|
||||||
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
|
cursor.execute(
|
||||||
|
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
|
||||||
|
)
|
||||||
cursor.commit()
|
cursor.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -83,7 +106,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
if os.path.isfile(duckdb_path):
|
if os.path.isfile(duckdb_path):
|
||||||
cursor = duckdb.connect(duckdb_path).cursor()
|
cursor = duckdb.connect(duckdb_path).cursor()
|
||||||
if user_name:
|
if user_name:
|
||||||
cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
|
cursor.execute(
|
||||||
|
"SELECT * FROM chat_history where user_name=? limit 20", [user_name]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
cursor.execute("SELECT * FROM chat_history limit 20")
|
cursor.execute("SELECT * FROM chat_history limit 20")
|
||||||
# 获取查询结果字段名
|
# 获取查询结果字段名
|
||||||
@ -99,10 +124,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def get_messages(self) -> List[OnceConversation]:
|
def get_messages(self) -> List[OnceConversation]:
|
||||||
cursor = self.connect.cursor()
|
cursor = self.connect.cursor()
|
||||||
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
|
cursor.execute(
|
||||||
|
"SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
|
||||||
|
)
|
||||||
context = cursor.fetchone()
|
context = cursor.fetchone()
|
||||||
if context:
|
if context:
|
||||||
return json.loads(context[0])
|
return json.loads(context[0])
|
||||||
|
@ -19,7 +19,6 @@ CFG = Config()
|
|||||||
class MemHistoryMemory(BaseChatHistoryMemory):
|
class MemHistoryMemory(BaseChatHistoryMemory):
|
||||||
histroies_map = FixedSizeDict(100)
|
histroies_map = FixedSizeDict(100)
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, chat_session_id: str):
|
def __init__(self, chat_session_id: str):
|
||||||
self.chat_seesion_id = chat_session_id
|
self.chat_seesion_id = chat_session_id
|
||||||
self.histroies_map.update({chat_session_id: []})
|
self.histroies_map.update({chat_session_id: []})
|
||||||
|
2
pilot/model/cache/base.py
vendored
2
pilot/model/cache/base.py
vendored
@ -3,8 +3,8 @@ import hashlib
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
class Cache(ABC):
|
|
||||||
|
|
||||||
|
class Cache(ABC):
|
||||||
def create(self, key: str) -> bool:
|
def create(self, key: str) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
6
pilot/model/cache/disk_cache.py
vendored
6
pilot/model/cache/disk_cache.py
vendored
@ -3,15 +3,15 @@ import diskcache
|
|||||||
import platformdirs
|
import platformdirs
|
||||||
from pilot.model.cache import Cache
|
from pilot.model.cache import Cache
|
||||||
|
|
||||||
|
|
||||||
class DiskCache(Cache):
|
class DiskCache(Cache):
|
||||||
"""DiskCache is a cache that uses diskcache lib.
|
"""DiskCache is a cache that uses diskcache lib.
|
||||||
https://github.com/grantjenks/python-diskcache
|
https://github.com/grantjenks/python-diskcache
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm_name: str):
|
def __init__(self, llm_name: str):
|
||||||
self._diskcache = diskcache.Cache(
|
self._diskcache = diskcache.Cache(
|
||||||
os.path.join(
|
os.path.join(platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache")
|
||||||
platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> str:
|
def __getitem__(self, key: str) -> str:
|
||||||
|
3
pilot/model/cache/gpt_cache.py
vendored
3
pilot/model/cache/gpt_cache.py
vendored
@ -9,6 +9,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GPTCache(Cache):
|
class GPTCache(Cache):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -24,7 +25,7 @@ class GPTCache(Cache):
|
|||||||
data_dir=os.path.join(
|
data_dir=os.path.join(
|
||||||
platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache"
|
platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache"
|
||||||
),
|
),
|
||||||
cache_obj=_cache
|
cache_obj=_cache,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_cache = cache
|
_cache = cache
|
||||||
|
3
pilot/model/cache/memory_cache.py
vendored
3
pilot/model/cache/memory_cache.py
vendored
@ -1,8 +1,8 @@
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from pilot.model.cache import Cache
|
from pilot.model.cache import Cache
|
||||||
|
|
||||||
class InMemoryCache(Cache):
|
|
||||||
|
|
||||||
|
class InMemoryCache(Cache):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"Initialize that stores things in memory."
|
"Initialize that stores things in memory."
|
||||||
self._cache: Dict[str, Any] = {}
|
self._cache: Dict[str, Any] = {}
|
||||||
@ -21,4 +21,3 @@ class InMemoryCache(Cache):
|
|||||||
|
|
||||||
def __contains__(self, key: str) -> bool:
|
def __contains__(self, key: str) -> bool:
|
||||||
return self._cache.get(key, None) is not None
|
return self._cache.get(key, None) is not None
|
||||||
|
|
||||||
|
@ -12,16 +12,21 @@ from fastapi.responses import JSONResponse
|
|||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
|
from pilot.server.api_v1.api_view_model import (
|
||||||
|
Result,
|
||||||
|
ConversationVo,
|
||||||
|
MessageVo,
|
||||||
|
ChatSceneVo,
|
||||||
|
)
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
|
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
|
||||||
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
|
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.scene.chat_factory import ChatFactory
|
from pilot.scene.chat_factory import ChatFactory
|
||||||
from pilot.configs.model_config import (LOGDIR)
|
from pilot.configs.model_config import LOGDIR
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.scene.base_message import (BaseMessage)
|
from pilot.scene.base_message import BaseMessage
|
||||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||||
from pilot.scene.message import OnceConversation
|
from pilot.scene.message import OnceConversation
|
||||||
|
|
||||||
@ -40,19 +45,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
|||||||
|
|
||||||
|
|
||||||
def __get_conv_user_message(conversations: dict):
|
def __get_conv_user_message(conversations: dict):
|
||||||
messages = conversations['messages']
|
messages = conversations["messages"]
|
||||||
for item in messages:
|
for item in messages:
|
||||||
if item['type'] == "human":
|
if item["type"] == "human":
|
||||||
return item['data']['content']
|
return item["data"]["content"]
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
|
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
||||||
async def dialogue_list(response: Response, user_id: str = None):
|
async def dialogue_list(response: Response, user_id: str = None):
|
||||||
# 设置CORS头部信息
|
# 设置CORS头部信息
|
||||||
response.headers['Access-Control-Allow-Origin'] = '*'
|
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
response.headers['Access-Control-Allow-Methods'] = 'GET'
|
response.headers["Access-Control-Allow-Methods"] = "GET"
|
||||||
response.headers['Access-Control-Request-Headers'] = 'content-type'
|
response.headers["Access-Control-Request-Headers"] = "content-type"
|
||||||
|
|
||||||
dialogues: List = []
|
dialogues: List = []
|
||||||
datas = DuckdbHistoryMemory.conv_list(user_id)
|
datas = DuckdbHistoryMemory.conv_list(user_id)
|
||||||
@ -63,26 +68,44 @@ async def dialogue_list(response: Response, user_id: str = None):
|
|||||||
conversations = json.loads(messages)
|
conversations = json.loads(messages)
|
||||||
|
|
||||||
first_conv: OnceConversation = conversations[0]
|
first_conv: OnceConversation = conversations[0]
|
||||||
conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv),
|
conv_vo: ConversationVo = ConversationVo(
|
||||||
chat_mode=first_conv['chat_mode'])
|
conv_uid=conv_uid,
|
||||||
|
user_input=__get_conv_user_message(first_conv),
|
||||||
|
chat_mode=first_conv["chat_mode"],
|
||||||
|
)
|
||||||
dialogues.append(conv_vo)
|
dialogues.append(conv_vo)
|
||||||
|
|
||||||
return Result[ConversationVo].succ(dialogues)
|
return Result[ConversationVo].succ(dialogues)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
|
@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
|
||||||
async def dialogue_scenes():
|
async def dialogue_scenes():
|
||||||
scene_vos: List[ChatSceneVo] = []
|
scene_vos: List[ChatSceneVo] = []
|
||||||
new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
|
new_modes: List[ChatScene] = [
|
||||||
|
ChatScene.ChatDb,
|
||||||
|
ChatScene.ChatData,
|
||||||
|
ChatScene.ChatDashboard,
|
||||||
|
ChatScene.ChatKnowledge,
|
||||||
|
ChatScene.ChatExecution,
|
||||||
|
]
|
||||||
for scene in new_modes:
|
for scene in new_modes:
|
||||||
if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
|
if not scene.value in [
|
||||||
scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
|
ChatScene.ChatNormal.value,
|
||||||
|
ChatScene.InnerChatDBSummary.value,
|
||||||
|
]:
|
||||||
|
scene_vo = ChatSceneVo(
|
||||||
|
chat_scene=scene.value,
|
||||||
|
scene_name=scene.name,
|
||||||
|
param_title="Selection Param",
|
||||||
|
)
|
||||||
scene_vos.append(scene_vo)
|
scene_vos.append(scene_vo)
|
||||||
return Result.succ(scene_vos)
|
return Result.succ(scene_vos)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo])
|
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
|
||||||
async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None):
|
async def dialogue_new(
|
||||||
|
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
|
||||||
|
):
|
||||||
unique_id = uuid.uuid1()
|
unique_id = uuid.uuid1()
|
||||||
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
|
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
|
||||||
|
|
||||||
@ -108,7 +131,7 @@ def knowledge_list():
|
|||||||
return knowledge_service.get_knowledge_space(request)
|
return knowledge_service.get_knowledge_space(request)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
|
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
|
||||||
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
|
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
|
||||||
if ChatScene.ChatDb.value == chat_mode:
|
if ChatScene.ChatDb.value == chat_mode:
|
||||||
return Result.succ(get_db_list())
|
return Result.succ(get_db_list())
|
||||||
@ -124,14 +147,14 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
|
|||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/chat/dialogue/delete')
|
@router.post("/v1/chat/dialogue/delete")
|
||||||
async def dialogue_delete(con_uid: str):
|
async def dialogue_delete(con_uid: str):
|
||||||
history_mem = DuckdbHistoryMemory(con_uid)
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
history_mem.delete()
|
history_mem.delete()
|
||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
|
|
||||||
|
|
||||||
@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo])
|
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
|
||||||
async def dialogue_history_messages(con_uid: str):
|
async def dialogue_history_messages(con_uid: str):
|
||||||
print(f"dialogue_history_messages:{con_uid}")
|
print(f"dialogue_history_messages:{con_uid}")
|
||||||
message_vos: List[MessageVo] = []
|
message_vos: List[MessageVo] = []
|
||||||
@ -140,17 +163,21 @@ async def dialogue_history_messages(con_uid: str):
|
|||||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
if history_messages:
|
if history_messages:
|
||||||
for once in history_messages:
|
for once in history_messages:
|
||||||
once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']]
|
once_message_vos = [
|
||||||
|
message2Vo(element, once["chat_order"]) for element in once["messages"]
|
||||||
|
]
|
||||||
message_vos.extend(once_message_vos)
|
message_vos.extend(once_message_vos)
|
||||||
return Result.succ(message_vos)
|
return Result.succ(message_vos)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/chat/completions')
|
@router.post("/v1/chat/completions")
|
||||||
async def chat_completions(dialogue: ConversationVo = Body()):
|
async def chat_completions(dialogue: ConversationVo = Body()):
|
||||||
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
|
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
|
||||||
|
|
||||||
if not ChatScene.is_valid_mode(dialogue.chat_mode):
|
if not ChatScene.is_valid_mode(dialogue.chat_mode):
|
||||||
raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
|
raise StopAsyncIteration(
|
||||||
|
Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")
|
||||||
|
)
|
||||||
|
|
||||||
chat_param = {
|
chat_param = {
|
||||||
"chat_session_id": dialogue.conv_uid,
|
"chat_session_id": dialogue.conv_uid,
|
||||||
@ -188,7 +215,9 @@ def stream_generator(chat):
|
|||||||
model_response = chat.stream_call()
|
model_response = chat.stream_call()
|
||||||
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||||
if chunk:
|
if chunk:
|
||||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||||
|
chunk, chat.skip_echo_len
|
||||||
|
)
|
||||||
chat.current_message.add_ai_message(msg)
|
chat.current_message.add_ai_message(msg)
|
||||||
yield msg
|
yield msg
|
||||||
# chat.current_message.add_ai_message(msg)
|
# chat.current_message.add_ai_message(msg)
|
||||||
@ -206,7 +235,9 @@ def stream_generator(chat):
|
|||||||
|
|
||||||
def message2Vo(message: dict, order) -> MessageVo:
|
def message2Vo(message: dict, order) -> MessageVo:
|
||||||
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
|
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
|
||||||
return MessageVo(role=message['type'], context=message['data']['content'], order=order)
|
return MessageVo(
|
||||||
|
role=message["type"], context=message["data"]["content"], order=order
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def non_stream_response(chat):
|
def non_stream_response(chat):
|
||||||
@ -214,38 +245,38 @@ def non_stream_response(chat):
|
|||||||
return chat.nostream_call()
|
return chat.nostream_call()
|
||||||
|
|
||||||
|
|
||||||
@router.get('/v1/db/types', response_model=Result[str])
|
@router.get("/v1/db/types", response_model=Result[str])
|
||||||
async def db_types():
|
async def db_types():
|
||||||
return Result.succ(["mysql", "duckdb"])
|
return Result.succ(["mysql", "duckdb"])
|
||||||
|
|
||||||
|
|
||||||
@router.get('/v1/db/list', response_model=Result[str])
|
@router.get("/v1/db/list", response_model=Result[str])
|
||||||
async def db_list():
|
async def db_list():
|
||||||
db = CFG.local_db
|
db = CFG.local_db
|
||||||
dbs = db.get_database_list()
|
dbs = db.get_database_list()
|
||||||
return Result.succ(dbs)
|
return Result.succ(dbs)
|
||||||
|
|
||||||
|
|
||||||
@router.get('/v1/knowledge/list')
|
@router.get("/v1/knowledge/list")
|
||||||
async def knowledge_list():
|
async def knowledge_list():
|
||||||
return ["test1", "test2"]
|
return ["test1", "test2"]
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/knowledge/add')
|
@router.post("/v1/knowledge/add")
|
||||||
async def knowledge_add():
|
async def knowledge_add():
|
||||||
return ["test1", "test2"]
|
return ["test1", "test2"]
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/knowledge/delete')
|
@router.post("/v1/knowledge/delete")
|
||||||
async def knowledge_delete():
|
async def knowledge_delete():
|
||||||
return ["test1", "test2"]
|
return ["test1", "test2"]
|
||||||
|
|
||||||
|
|
||||||
@router.get('/v1/knowledge/types')
|
@router.get("/v1/knowledge/types")
|
||||||
async def knowledge_types():
|
async def knowledge_types():
|
||||||
return ["test1", "test2"]
|
return ["test1", "test2"]
|
||||||
|
|
||||||
|
|
||||||
@router.get('/v1/knowledge/detail')
|
@router.get("/v1/knowledge/detail")
|
||||||
async def knowledge_detail():
|
async def knowledge_detail():
|
||||||
return ["test1", "test2"]
|
return ["test1", "test2"]
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import TypeVar, Union, List, Generic, Any
|
from typing import TypeVar, Union, List, Generic, Any
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Result(Generic[T], BaseModel):
|
class Result(Generic[T], BaseModel):
|
||||||
@ -28,10 +28,12 @@ class ChatSceneVo(BaseModel):
|
|||||||
scene_name: str = Field(..., description="chat_scene name show for user")
|
scene_name: str = Field(..., description="chat_scene name show for user")
|
||||||
param_title: str = Field(..., description="chat_scene required parameter title")
|
param_title: str = Field(..., description="chat_scene required parameter title")
|
||||||
|
|
||||||
|
|
||||||
class ConversationVo(BaseModel):
|
class ConversationVo(BaseModel):
|
||||||
"""
|
"""
|
||||||
dialogue_uid
|
dialogue_uid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
conv_uid: str = Field(..., description="dialogue uid")
|
conv_uid: str = Field(..., description="dialogue uid")
|
||||||
"""
|
"""
|
||||||
user input
|
user input
|
||||||
@ -52,11 +54,11 @@ class ConversationVo(BaseModel):
|
|||||||
select_param: str = None
|
select_param: str = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MessageVo(BaseModel):
|
class MessageVo(BaseModel):
|
||||||
"""
|
"""
|
||||||
role that sends out the current message
|
role that sends out the current message
|
||||||
"""
|
"""
|
||||||
|
|
||||||
role: str
|
role: str
|
||||||
"""
|
"""
|
||||||
current message
|
current message
|
||||||
@ -70,4 +72,3 @@ class MessageVo(BaseModel):
|
|||||||
time the current message was sent
|
time the current message was sent
|
||||||
"""
|
"""
|
||||||
time_stamp: Any = None
|
time_stamp: Any = None
|
||||||
|
|
||||||
|
@ -10,8 +10,10 @@ from pilot.configs.config import Config
|
|||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class DocumentChunkEntity(Base):
|
class DocumentChunkEntity(Base):
|
||||||
__tablename__ = 'document_chunk'
|
__tablename__ = "document_chunk"
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
document_id = Column(Integer)
|
document_id = Column(Integer)
|
||||||
doc_name = Column(String(100))
|
doc_name = Column(String(100))
|
||||||
@ -29,8 +31,9 @@ class DocumentChunkDao:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
database = "knowledge_management"
|
database = "knowledge_management"
|
||||||
self.db_engine = create_engine(
|
self.db_engine = create_engine(
|
||||||
f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
|
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
|
||||||
echo=True)
|
echo=True,
|
||||||
|
)
|
||||||
self.Session = sessionmaker(bind=self.db_engine)
|
self.Session = sessionmaker(bind=self.db_engine)
|
||||||
|
|
||||||
def create_documents_chunks(self, documents: List):
|
def create_documents_chunks(self, documents: List):
|
||||||
@ -43,9 +46,10 @@ class DocumentChunkDao:
|
|||||||
content=document.content or "",
|
content=document.content or "",
|
||||||
meta_info=document.meta_info or "",
|
meta_info=document.meta_info or "",
|
||||||
gmt_created=datetime.now(),
|
gmt_created=datetime.now(),
|
||||||
gmt_modified=datetime.now()
|
gmt_modified=datetime.now(),
|
||||||
)
|
)
|
||||||
for document in documents]
|
for document in documents
|
||||||
|
]
|
||||||
session.add_all(docs)
|
session.add_all(docs)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
@ -56,16 +60,26 @@ class DocumentChunkDao:
|
|||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||||
if query.document_id is not None:
|
if query.document_id is not None:
|
||||||
document_chunks = document_chunks.filter(DocumentChunkEntity.document_id == query.document_id)
|
document_chunks = document_chunks.filter(
|
||||||
|
DocumentChunkEntity.document_id == query.document_id
|
||||||
|
)
|
||||||
if query.doc_type is not None:
|
if query.doc_type is not None:
|
||||||
document_chunks = document_chunks.filter(DocumentChunkEntity.doc_type == query.doc_type)
|
document_chunks = document_chunks.filter(
|
||||||
|
DocumentChunkEntity.doc_type == query.doc_type
|
||||||
|
)
|
||||||
if query.doc_name is not None:
|
if query.doc_name is not None:
|
||||||
document_chunks = document_chunks.filter(DocumentChunkEntity.doc_name == query.doc_name)
|
document_chunks = document_chunks.filter(
|
||||||
|
DocumentChunkEntity.doc_name == query.doc_name
|
||||||
|
)
|
||||||
if query.meta_info is not None:
|
if query.meta_info is not None:
|
||||||
document_chunks = document_chunks.filter(DocumentChunkEntity.meta_info == query.meta_info)
|
document_chunks = document_chunks.filter(
|
||||||
|
DocumentChunkEntity.meta_info == query.meta_info
|
||||||
|
)
|
||||||
|
|
||||||
document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
|
document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
|
||||||
document_chunks = document_chunks.offset((page - 1) * page_size).limit(page_size)
|
document_chunks = document_chunks.offset((page - 1) * page_size).limit(
|
||||||
|
page_size
|
||||||
|
)
|
||||||
result = document_chunks.all()
|
result = document_chunks.all()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -13,7 +13,11 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
|||||||
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
|
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
|
||||||
from pilot.openapi.knowledge.request.knowledge_request import (
|
from pilot.openapi.knowledge.request.knowledge_request import (
|
||||||
KnowledgeQueryRequest,
|
KnowledgeQueryRequest,
|
||||||
KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
|
KnowledgeQueryResponse,
|
||||||
|
KnowledgeDocumentRequest,
|
||||||
|
DocumentSyncRequest,
|
||||||
|
ChunkQueryRequest,
|
||||||
|
DocumentQueryRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
|
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
|
||||||
@ -62,16 +66,15 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
|
|||||||
def document_list(space_name: str, query_request: DocumentQueryRequest):
|
def document_list(space_name: str, query_request: DocumentQueryRequest):
|
||||||
print(f"/document/list params: {space_name}, {query_request}")
|
print(f"/document/list params: {space_name}, {query_request}")
|
||||||
try:
|
try:
|
||||||
return Result.succ(knowledge_space_service.get_knowledge_documents(
|
return Result.succ(
|
||||||
space_name,
|
knowledge_space_service.get_knowledge_documents(space_name, query_request)
|
||||||
query_request
|
)
|
||||||
))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return Result.faild(code="E000X", msg=f"document list error {e}")
|
return Result.faild(code="E000X", msg=f"document list error {e}")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/knowledge/{space_name}/document/upload")
|
@router.post("/knowledge/{space_name}/document/upload")
|
||||||
def document_sync(space_name: str, file: UploadFile = File(...)):
|
async def document_sync(space_name: str, file: UploadFile = File(...)):
|
||||||
print(f"/document/upload params: {space_name}")
|
print(f"/document/upload params: {space_name}")
|
||||||
try:
|
try:
|
||||||
with NamedTemporaryFile(delete=False) as tmp:
|
with NamedTemporaryFile(delete=False) as tmp:
|
||||||
@ -92,7 +95,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
|
|||||||
knowledge_space_service.sync_knowledge_document(
|
knowledge_space_service.sync_knowledge_document(
|
||||||
space_name=space_name, doc_ids=request.doc_ids
|
space_name=space_name, doc_ids=request.doc_ids
|
||||||
)
|
)
|
||||||
Result.succ([])
|
return Result.succ([])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return Result.faild(code="E000X", msg=f"document sync error {e}")
|
return Result.faild(code="E000X", msg=f"document sync error {e}")
|
||||||
|
|
||||||
@ -101,9 +104,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
|
|||||||
def document_list(space_name: str, query_request: ChunkQueryRequest):
|
def document_list(space_name: str, query_request: ChunkQueryRequest):
|
||||||
print(f"/document/list params: {space_name}, {query_request}")
|
print(f"/document/list params: {space_name}, {query_request}")
|
||||||
try:
|
try:
|
||||||
return Result.succ(knowledge_space_service.get_document_chunks(
|
return Result.succ(knowledge_space_service.get_document_chunks(query_request))
|
||||||
query_request
|
|
||||||
))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return Result.faild(code="E000X", msg=f"document chunk list error {e}")
|
return Result.faild(code="E000X", msg=f"document chunk list error {e}")
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ Base = declarative_base()
|
|||||||
|
|
||||||
|
|
||||||
class KnowledgeDocumentEntity(Base):
|
class KnowledgeDocumentEntity(Base):
|
||||||
__tablename__ = 'knowledge_document'
|
__tablename__ = "knowledge_document"
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
doc_name = Column(String(100))
|
doc_name = Column(String(100))
|
||||||
doc_type = Column(String(100))
|
doc_type = Column(String(100))
|
||||||
@ -21,20 +21,22 @@ class KnowledgeDocumentEntity(Base):
|
|||||||
status = Column(String(100))
|
status = Column(String(100))
|
||||||
last_sync = Column(String(100))
|
last_sync = Column(String(100))
|
||||||
content = Column(Text)
|
content = Column(Text)
|
||||||
|
result = Column(Text)
|
||||||
vector_ids = Column(Text)
|
vector_ids = Column(Text)
|
||||||
gmt_created = Column(DateTime)
|
gmt_created = Column(DateTime)
|
||||||
gmt_modified = Column(DateTime)
|
gmt_modified = Column(DateTime)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeDocumentDao:
|
class KnowledgeDocumentDao:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
database = "knowledge_management"
|
database = "knowledge_management"
|
||||||
self.db_engine = create_engine(
|
self.db_engine = create_engine(
|
||||||
f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
|
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
|
||||||
echo=True)
|
echo=True,
|
||||||
|
)
|
||||||
self.Session = sessionmaker(bind=self.db_engine)
|
self.Session = sessionmaker(bind=self.db_engine)
|
||||||
|
|
||||||
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
|
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||||
@ -47,9 +49,10 @@ class KnowledgeDocumentDao:
|
|||||||
status=document.status,
|
status=document.status,
|
||||||
last_sync=document.last_sync,
|
last_sync=document.last_sync,
|
||||||
content=document.content or "",
|
content=document.content or "",
|
||||||
|
result=document.result or "",
|
||||||
vector_ids=document.vector_ids,
|
vector_ids=document.vector_ids,
|
||||||
gmt_created=datetime.now(),
|
gmt_created=datetime.now(),
|
||||||
gmt_modified=datetime.now()
|
gmt_modified=datetime.now(),
|
||||||
)
|
)
|
||||||
session.add(knowledge_document)
|
session.add(knowledge_document)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -60,18 +63,32 @@ class KnowledgeDocumentDao:
|
|||||||
session = self.Session()
|
session = self.Session()
|
||||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id)
|
knowledge_documents = knowledge_documents.filter(
|
||||||
|
KnowledgeDocumentEntity.id == query.id
|
||||||
|
)
|
||||||
if query.doc_name is not None:
|
if query.doc_name is not None:
|
||||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name)
|
knowledge_documents = knowledge_documents.filter(
|
||||||
|
KnowledgeDocumentEntity.doc_name == query.doc_name
|
||||||
|
)
|
||||||
if query.doc_type is not None:
|
if query.doc_type is not None:
|
||||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type)
|
knowledge_documents = knowledge_documents.filter(
|
||||||
|
KnowledgeDocumentEntity.doc_type == query.doc_type
|
||||||
|
)
|
||||||
if query.space is not None:
|
if query.space is not None:
|
||||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space)
|
knowledge_documents = knowledge_documents.filter(
|
||||||
|
KnowledgeDocumentEntity.space == query.space
|
||||||
|
)
|
||||||
if query.status is not None:
|
if query.status is not None:
|
||||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status)
|
knowledge_documents = knowledge_documents.filter(
|
||||||
|
KnowledgeDocumentEntity.status == query.status
|
||||||
|
)
|
||||||
|
|
||||||
knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc())
|
knowledge_documents = knowledge_documents.order_by(
|
||||||
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size)
|
KnowledgeDocumentEntity.id.desc()
|
||||||
|
)
|
||||||
|
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(
|
||||||
|
page_size
|
||||||
|
)
|
||||||
result = knowledge_documents.all()
|
result = knowledge_documents.all()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -4,17 +4,24 @@ from datetime import datetime
|
|||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||||
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
||||||
from pilot.embedding_engine.knowledge_type import KnowledgeType
|
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
|
from pilot.openapi.knowledge.document_chunk_dao import (
|
||||||
|
DocumentChunkEntity,
|
||||||
|
DocumentChunkDao,
|
||||||
|
)
|
||||||
from pilot.openapi.knowledge.knowledge_document_dao import (
|
from pilot.openapi.knowledge.knowledge_document_dao import (
|
||||||
KnowledgeDocumentDao,
|
KnowledgeDocumentDao,
|
||||||
KnowledgeDocumentEntity,
|
KnowledgeDocumentEntity,
|
||||||
)
|
)
|
||||||
from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
|
from pilot.openapi.knowledge.knowledge_space_dao import (
|
||||||
|
KnowledgeSpaceDao,
|
||||||
|
KnowledgeSpaceEntity,
|
||||||
|
)
|
||||||
from pilot.openapi.knowledge.request.knowledge_request import (
|
from pilot.openapi.knowledge.request.knowledge_request import (
|
||||||
KnowledgeSpaceRequest,
|
KnowledgeSpaceRequest,
|
||||||
KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
|
KnowledgeDocumentRequest,
|
||||||
|
DocumentQueryRequest,
|
||||||
|
ChunkQueryRequest,
|
||||||
)
|
)
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@ -53,10 +60,7 @@ class KnowledgeService:
|
|||||||
"""create knowledge document"""
|
"""create knowledge document"""
|
||||||
|
|
||||||
def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
|
def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
|
||||||
query = KnowledgeDocumentEntity(
|
query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
|
||||||
doc_name=request.doc_name,
|
|
||||||
space=space
|
|
||||||
)
|
|
||||||
documents = knowledge_document_dao.get_knowledge_documents(query)
|
documents = knowledge_document_dao.get_knowledge_documents(query)
|
||||||
if len(documents) > 0:
|
if len(documents) > 0:
|
||||||
raise Exception(f"document name:{request.doc_name} have already named")
|
raise Exception(f"document name:{request.doc_name} have already named")
|
||||||
@ -76,9 +80,7 @@ class KnowledgeService:
|
|||||||
|
|
||||||
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
|
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
|
||||||
query = KnowledgeSpaceEntity(
|
query = KnowledgeSpaceEntity(
|
||||||
name=request.name,
|
name=request.name, vector_type=request.vector_type, owner=request.owner
|
||||||
vector_type=request.vector_type,
|
|
||||||
owner=request.owner
|
|
||||||
)
|
)
|
||||||
return knowledge_space_dao.get_knowledge_space(query)
|
return knowledge_space_dao.get_knowledge_space(query)
|
||||||
|
|
||||||
@ -91,9 +93,12 @@ class KnowledgeService:
|
|||||||
space=space,
|
space=space,
|
||||||
status=request.status,
|
status=request.status,
|
||||||
)
|
)
|
||||||
return knowledge_document_dao.get_knowledge_documents(query, page=request.page, page_size=request.page_size)
|
return knowledge_document_dao.get_knowledge_documents(
|
||||||
|
query, page=request.page, page_size=request.page_size
|
||||||
|
)
|
||||||
|
|
||||||
"""sync knowledge document chunk into vector store"""
|
"""sync knowledge document chunk into vector store"""
|
||||||
|
|
||||||
def sync_knowledge_document(self, space_name, doc_ids):
|
def sync_knowledge_document(self, space_name, doc_ids):
|
||||||
for doc_id in doc_ids:
|
for doc_id in doc_ids:
|
||||||
query = KnowledgeDocumentEntity(
|
query = KnowledgeDocumentEntity(
|
||||||
@ -101,12 +106,14 @@ class KnowledgeService:
|
|||||||
space=space_name,
|
space=space_name,
|
||||||
)
|
)
|
||||||
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
|
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
|
||||||
client = KnowledgeEmbedding(knowledge_source=doc.content,
|
client = KnowledgeEmbedding(
|
||||||
|
knowledge_source=doc.content,
|
||||||
knowledge_type=doc.doc_type.upper(),
|
knowledge_type=doc.doc_type.upper(),
|
||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||||
vector_store_config={
|
vector_store_config={
|
||||||
"vector_store_name": space_name,
|
"vector_store_name": space_name,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
chunk_docs = client.read()
|
chunk_docs = client.read()
|
||||||
# update document status
|
# update document status
|
||||||
doc.status = SyncStatus.RUNNING.name
|
doc.status = SyncStatus.RUNNING.name
|
||||||
@ -114,8 +121,11 @@ class KnowledgeService:
|
|||||||
doc.gmt_modified = datetime.now()
|
doc.gmt_modified = datetime.now()
|
||||||
knowledge_document_dao.update_knowledge_document(doc)
|
knowledge_document_dao.update_knowledge_document(doc)
|
||||||
# async doc embeddings
|
# async doc embeddings
|
||||||
thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc))
|
thread = threading.Thread(
|
||||||
|
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
|
||||||
|
)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||||
# save chunk details
|
# save chunk details
|
||||||
chunk_entities = [
|
chunk_entities = [
|
||||||
DocumentChunkEntity(
|
DocumentChunkEntity(
|
||||||
@ -125,9 +135,10 @@ class KnowledgeService:
|
|||||||
content=chunk_doc.page_content,
|
content=chunk_doc.page_content,
|
||||||
meta_info=str(chunk_doc.metadata),
|
meta_info=str(chunk_doc.metadata),
|
||||||
gmt_created=datetime.now(),
|
gmt_created=datetime.now(),
|
||||||
gmt_modified=datetime.now()
|
gmt_modified=datetime.now(),
|
||||||
)
|
)
|
||||||
for chunk_doc in chunk_docs]
|
for chunk_doc in chunk_docs
|
||||||
|
]
|
||||||
document_chunk_dao.create_documents_chunks(chunk_entities)
|
document_chunk_dao.create_documents_chunks(chunk_entities)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -145,26 +156,30 @@ class KnowledgeService:
|
|||||||
return knowledge_space_dao.delete_knowledge_space(space_id)
|
return knowledge_space_dao.delete_knowledge_space(space_id)
|
||||||
|
|
||||||
"""get document chunks"""
|
"""get document chunks"""
|
||||||
|
|
||||||
def get_document_chunks(self, request: ChunkQueryRequest):
|
def get_document_chunks(self, request: ChunkQueryRequest):
|
||||||
query = DocumentChunkEntity(
|
query = DocumentChunkEntity(
|
||||||
id=request.id,
|
id=request.id,
|
||||||
document_id=request.document_id,
|
document_id=request.document_id,
|
||||||
doc_name=request.doc_name,
|
doc_name=request.doc_name,
|
||||||
doc_type=request.doc_type
|
doc_type=request.doc_type,
|
||||||
|
)
|
||||||
|
return document_chunk_dao.get_document_chunks(
|
||||||
|
query, page=request.page, page_size=request.page_size
|
||||||
)
|
)
|
||||||
return document_chunk_dao.get_document_chunks(query, page=request.page, page_size=request.page_size)
|
|
||||||
|
|
||||||
def async_doc_embedding(self, client, chunk_docs, doc):
|
def async_doc_embedding(self, client, chunk_docs, doc):
|
||||||
logger.info(f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}")
|
logger.info(
|
||||||
|
f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
vector_ids = client.knowledge_embedding_batch(chunk_docs)
|
vector_ids = client.knowledge_embedding_batch(chunk_docs)
|
||||||
doc.status = SyncStatus.FINISHED.name
|
doc.status = SyncStatus.FINISHED.name
|
||||||
doc.content = "embedding success"
|
doc.result = "document embedding success"
|
||||||
doc.vector_ids = ",".join(vector_ids)
|
doc.vector_ids = ",".join(vector_ids)
|
||||||
|
logger.info(f"async document embedding, success:{doc.doc_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
doc.status = SyncStatus.FAILED.name
|
doc.status = SyncStatus.FAILED.name
|
||||||
doc.content = str(e)
|
doc.result = "document embedding failed" + str(e)
|
||||||
|
logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
|
||||||
return knowledge_document_dao.update_knowledge_document(doc)
|
return knowledge_document_dao.update_knowledge_document(doc)
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,8 +10,10 @@ from sqlalchemy.orm import sessionmaker
|
|||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeSpaceEntity(Base):
|
class KnowledgeSpaceEntity(Base):
|
||||||
__tablename__ = 'knowledge_space'
|
__tablename__ = "knowledge_space"
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
name = Column(String(100))
|
name = Column(String(100))
|
||||||
vector_type = Column(String(100))
|
vector_type = Column(String(100))
|
||||||
@ -27,7 +29,10 @@ class KnowledgeSpaceEntity(Base):
|
|||||||
class KnowledgeSpaceDao:
|
class KnowledgeSpaceDao:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
database = "knowledge_management"
|
database = "knowledge_management"
|
||||||
self.db_engine = create_engine(f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', echo=True)
|
self.db_engine = create_engine(
|
||||||
|
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
|
||||||
|
echo=True,
|
||||||
|
)
|
||||||
self.Session = sessionmaker(bind=self.db_engine)
|
self.Session = sessionmaker(bind=self.db_engine)
|
||||||
|
|
||||||
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
|
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
|
||||||
@ -38,7 +43,7 @@ class KnowledgeSpaceDao:
|
|||||||
desc=space.desc,
|
desc=space.desc,
|
||||||
owner=space.owner,
|
owner=space.owner,
|
||||||
gmt_created=datetime.now(),
|
gmt_created=datetime.now(),
|
||||||
gmt_modified=datetime.now()
|
gmt_modified=datetime.now(),
|
||||||
)
|
)
|
||||||
session.add(knowledge_space)
|
session.add(knowledge_space)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -49,28 +54,46 @@ class KnowledgeSpaceDao:
|
|||||||
session = self.Session()
|
session = self.Session()
|
||||||
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.id == query.id)
|
knowledge_spaces = knowledge_spaces.filter(
|
||||||
|
KnowledgeSpaceEntity.id == query.id
|
||||||
|
)
|
||||||
if query.name is not None:
|
if query.name is not None:
|
||||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.name == query.name)
|
knowledge_spaces = knowledge_spaces.filter(
|
||||||
|
KnowledgeSpaceEntity.name == query.name
|
||||||
|
)
|
||||||
if query.vector_type is not None:
|
if query.vector_type is not None:
|
||||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.vector_type == query.vector_type)
|
knowledge_spaces = knowledge_spaces.filter(
|
||||||
|
KnowledgeSpaceEntity.vector_type == query.vector_type
|
||||||
|
)
|
||||||
if query.desc is not None:
|
if query.desc is not None:
|
||||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.desc == query.desc)
|
knowledge_spaces = knowledge_spaces.filter(
|
||||||
|
KnowledgeSpaceEntity.desc == query.desc
|
||||||
|
)
|
||||||
if query.owner is not None:
|
if query.owner is not None:
|
||||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.owner == query.owner)
|
knowledge_spaces = knowledge_spaces.filter(
|
||||||
|
KnowledgeSpaceEntity.owner == query.owner
|
||||||
|
)
|
||||||
if query.gmt_created is not None:
|
if query.gmt_created is not None:
|
||||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_created == query.gmt_created)
|
knowledge_spaces = knowledge_spaces.filter(
|
||||||
|
KnowledgeSpaceEntity.gmt_created == query.gmt_created
|
||||||
|
)
|
||||||
if query.gmt_modified is not None:
|
if query.gmt_modified is not None:
|
||||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_modified == query.gmt_modified)
|
knowledge_spaces = knowledge_spaces.filter(
|
||||||
|
KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
|
||||||
|
)
|
||||||
|
|
||||||
knowledge_spaces = knowledge_spaces.order_by(KnowledgeSpaceEntity.gmt_created.desc())
|
knowledge_spaces = knowledge_spaces.order_by(
|
||||||
|
KnowledgeSpaceEntity.gmt_created.desc()
|
||||||
|
)
|
||||||
result = knowledge_spaces.all()
|
result = knowledge_spaces.all()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
|
def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
|
||||||
cursor = self.conn.cursor()
|
cursor = self.conn.cursor()
|
||||||
query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
|
query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
|
||||||
cursor.execute(query, (space.name, space.vector_type, space.desc, space.owner, space_id))
|
cursor.execute(
|
||||||
|
query, (space.name, space.vector_type, space.desc, space.owner, space_id)
|
||||||
|
)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
|
@ -30,12 +30,14 @@ class KnowledgeDocumentRequest(BaseModel):
|
|||||||
"""doc_type: doc type"""
|
"""doc_type: doc type"""
|
||||||
doc_type: str
|
doc_type: str
|
||||||
"""content: content"""
|
"""content: content"""
|
||||||
content: str
|
content: str = None
|
||||||
"""text_chunk_size: text_chunk_size"""
|
"""text_chunk_size: text_chunk_size"""
|
||||||
# text_chunk_size: int
|
# text_chunk_size: int
|
||||||
|
|
||||||
|
|
||||||
class DocumentQueryRequest(BaseModel):
|
class DocumentQueryRequest(BaseModel):
|
||||||
"""doc_name: doc path"""
|
"""doc_name: doc path"""
|
||||||
|
|
||||||
doc_name: str = None
|
doc_name: str = None
|
||||||
"""doc_type: doc type"""
|
"""doc_type: doc type"""
|
||||||
doc_type: str = None
|
doc_type: str = None
|
||||||
@ -49,10 +51,13 @@ class DocumentQueryRequest(BaseModel):
|
|||||||
|
|
||||||
class DocumentSyncRequest(BaseModel):
|
class DocumentSyncRequest(BaseModel):
|
||||||
"""doc_ids: doc ids"""
|
"""doc_ids: doc ids"""
|
||||||
|
|
||||||
doc_ids: List
|
doc_ids: List
|
||||||
|
|
||||||
|
|
||||||
class ChunkQueryRequest(BaseModel):
|
class ChunkQueryRequest(BaseModel):
|
||||||
"""id: id"""
|
"""id: id"""
|
||||||
|
|
||||||
id: int = None
|
id: int = None
|
||||||
"""document_id: doc id"""
|
"""document_id: doc id"""
|
||||||
document_id: int = None
|
document_id: int = None
|
||||||
|
@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
|||||||
|
|
||||||
from pilot.common.schema import ExampleType
|
from pilot.common.schema import ExampleType
|
||||||
|
|
||||||
|
|
||||||
class ExampleSelector(BaseModel, ABC):
|
class ExampleSelector(BaseModel, ABC):
|
||||||
examples: List[List]
|
examples: List[List]
|
||||||
use_example: bool = False
|
use_example: bool = False
|
||||||
|
@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser
|
|||||||
from pilot.common.schema import SeparatorStyle
|
from pilot.common.schema import SeparatorStyle
|
||||||
from pilot.prompts.example_base import ExampleSelector
|
from pilot.prompts.example_base import ExampleSelector
|
||||||
|
|
||||||
|
|
||||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||||
"""Format a template using jinja2."""
|
"""Format a template using jinja2."""
|
||||||
try:
|
try:
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class Scene:
|
class Scene:
|
||||||
def __init__(self, code, describe, is_inner):
|
def __init__(self, code, describe, is_inner):
|
||||||
self.code = code
|
self.code = code
|
||||||
self.describe = describe
|
self.describe = describe
|
||||||
self.is_inner = is_inner
|
self.is_inner = is_inner
|
||||||
|
|
||||||
|
|
||||||
class ChatScene(Enum):
|
class ChatScene(Enum):
|
||||||
ChatWithDbExecute = "chat_with_db_execute"
|
ChatWithDbExecute = "chat_with_db_execute"
|
||||||
ChatWithDbQA = "chat_with_db_qa"
|
ChatWithDbQA = "chat_with_db_qa"
|
||||||
@ -24,4 +26,3 @@ class ChatScene(Enum):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def is_valid_mode(mode):
|
def is_valid_mode(mode):
|
||||||
return any(mode == item.value for item in ChatScene)
|
return any(mode == item.value for item in ChatScene)
|
||||||
|
|
||||||
|
@ -104,7 +104,9 @@ class BaseChat(ABC):
|
|||||||
### Chat sequence advance
|
### Chat sequence advance
|
||||||
self.current_message.chat_order = len(self.history_message) + 1
|
self.current_message.chat_order = len(self.history_message) + 1
|
||||||
self.current_message.add_user_message(self.current_user_input)
|
self.current_message.add_user_message(self.current_user_input)
|
||||||
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
self.current_message.start_date = datetime.datetime.now().strftime(
|
||||||
|
"%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
# TODO
|
# TODO
|
||||||
self.current_message.tokens = 0
|
self.current_message.tokens = 0
|
||||||
current_prompt = None
|
current_prompt = None
|
||||||
@ -200,8 +202,11 @@ class BaseChat(ABC):
|
|||||||
# }"""
|
# }"""
|
||||||
|
|
||||||
self.current_message.add_ai_message(ai_response_text)
|
self.current_message.add_ai_message(ai_response_text)
|
||||||
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
|
prompt_define_response = (
|
||||||
|
self.prompt_template.output_parser.parse_prompt_response(
|
||||||
|
ai_response_text
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
result = self.do_action(prompt_define_response)
|
result = self.do_action(prompt_define_response)
|
||||||
|
|
||||||
|
@ -12,7 +12,10 @@ from pilot.common.markdown_text import (
|
|||||||
generate_htm_table,
|
generate_htm_table,
|
||||||
)
|
)
|
||||||
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
||||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData
|
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
||||||
|
ChartData,
|
||||||
|
ReportData,
|
||||||
|
)
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -22,9 +25,7 @@ class ChatDashboard(BaseChat):
|
|||||||
report_name: str
|
report_name: str
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, chat_session_id, db_name, user_input, report_name):
|
||||||
self, chat_session_id, db_name, user_input, report_name
|
|
||||||
):
|
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatWithDbExecute,
|
chat_mode=ChatScene.ChatWithDbExecute,
|
||||||
@ -68,7 +69,6 @@ class ChatDashboard(BaseChat):
|
|||||||
# TODO 修复流程
|
# TODO 修复流程
|
||||||
print(str(e))
|
print(str(e))
|
||||||
|
|
||||||
|
|
||||||
chart_datas.append(chart_data)
|
chart_datas.append(chart_data)
|
||||||
|
|
||||||
report_data.conv_uid = self.chat_session_id
|
report_data.conv_uid = self.chat_session_id
|
||||||
@ -77,5 +77,3 @@ class ChatDashboard(BaseChat):
|
|||||||
report_data.charts = chart_datas
|
report_data.charts = chart_datas
|
||||||
|
|
||||||
return report_data
|
return report_data
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,7 +16,3 @@ class ReportData(BaseModel):
|
|||||||
template_name: str
|
template_name: str
|
||||||
template_introduce: str
|
template_introduce: str
|
||||||
charts: List[ChartData]
|
charts: List[ChartData]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +28,11 @@ class ChatDashboardOutputParser(BaseOutputParser):
|
|||||||
response = json.loads(clean_str)
|
response = json.loads(clean_str)
|
||||||
chart_items = List[ChartItem]
|
chart_items = List[ChartItem]
|
||||||
for item in response:
|
for item in response:
|
||||||
chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"]))
|
chart_items.append(
|
||||||
|
ChartItem(
|
||||||
|
item["sql"], item["title"], item["thoughts"], item["showcase"]
|
||||||
|
)
|
||||||
|
)
|
||||||
return chart_items
|
return chart_items
|
||||||
|
|
||||||
def parse_view_response(self, speak, data) -> str:
|
def parse_view_response(self, speak, data) -> str:
|
||||||
|
@ -24,12 +24,14 @@ give {dialect} data analysis SQL, analysis title, display method and analytical
|
|||||||
Ensure the response is correct json and can be parsed by Python json.loads
|
Ensure the response is correct json and can be parsed by Python json.loads
|
||||||
"""
|
"""
|
||||||
|
|
||||||
RESPONSE_FORMAT = [{
|
RESPONSE_FORMAT = [
|
||||||
|
{
|
||||||
"sql": "data analysis SQL",
|
"sql": "data analysis SQL",
|
||||||
"title": "Data Analysis Title",
|
"title": "Data Analysis Title",
|
||||||
"showcase": "What type of charts to show",
|
"showcase": "What type of charts to show",
|
||||||
"thoughts": "Current thinking and value of data analysis"
|
"thoughts": "Current thinking and value of data analysis",
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||||
|
|
||||||
|
@ -21,9 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, chat_session_id, db_name, user_input):
|
||||||
self, chat_session_id, db_name, user_input
|
|
||||||
):
|
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatWithDbExecute,
|
chat_mode=ChatScene.ChatWithDbExecute,
|
||||||
|
@ -19,9 +19,7 @@ class ChatWithDbQA(BaseChat):
|
|||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, chat_session_id, db_name, user_input):
|
||||||
self, chat_session_id, db_name, user_input
|
|
||||||
):
|
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatWithDbQA,
|
chat_mode=ChatScene.ChatWithDbQA,
|
||||||
@ -63,5 +61,3 @@ class ChatWithDbQA(BaseChat):
|
|||||||
"table_info": table_info,
|
"table_info": table_info,
|
||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from pilot.prompts.example_base import ExampleSelector
|
|||||||
## Two examples are defined by default
|
## Two examples are defined by default
|
||||||
EXAMPLES = [
|
EXAMPLES = [
|
||||||
[{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
|
[{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
|
||||||
[{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}]
|
[{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
|
||||||
]
|
]
|
||||||
|
|
||||||
example = ExampleSelector(examples=EXAMPLES, use_example=True)
|
example = ExampleSelector(examples=EXAMPLES, use_example=True)
|
||||||
|
@ -36,7 +36,6 @@ RESPONSE_FORMAT = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
EXAMPLE_TYPE = ExampleType.ONE_SHOT
|
EXAMPLE_TYPE = ExampleType.ONE_SHOT
|
||||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||||
### Whether the model service is streaming output
|
### Whether the model service is streaming output
|
||||||
@ -49,8 +48,10 @@ prompt = PromptTemplate(
|
|||||||
template_define=PROMPT_SCENE_DEFINE,
|
template_define=PROMPT_SCENE_DEFINE,
|
||||||
template=_DEFAULT_TEMPLATE,
|
template=_DEFAULT_TEMPLATE,
|
||||||
stream_out=PROMPT_NEED_STREAM_OUT,
|
stream_out=PROMPT_NEED_STREAM_OUT,
|
||||||
output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
|
output_parser=PluginChatOutputParser(
|
||||||
example=example
|
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
|
||||||
|
),
|
||||||
|
example=example,
|
||||||
)
|
)
|
||||||
|
|
||||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
||||||
|
@ -27,9 +27,7 @@ class ChatNewKnowledge(BaseChat):
|
|||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, chat_session_id, user_input, knowledge_name):
|
||||||
self, chat_session_id, user_input, knowledge_name
|
|
||||||
):
|
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatNewKnowledge,
|
chat_mode=ChatScene.ChatNewKnowledge,
|
||||||
@ -56,7 +54,6 @@ class ChatNewKnowledge(BaseChat):
|
|||||||
input_values = {"context": context, "question": self.current_user_input}
|
input_values = {"context": context, "question": self.current_user_input}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
return ChatScene.ChatNewKnowledge.value
|
return ChatScene.ChatNewKnowledge.value
|
||||||
|
@ -59,8 +59,6 @@ class ChatDefaultKnowledge(BaseChat):
|
|||||||
)
|
)
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
return ChatScene.ChatDefaultKnowledge.value
|
return ChatScene.ChatDefaultKnowledge.value
|
||||||
|
@ -36,7 +36,6 @@ class InnerChatDBSummary(BaseChat):
|
|||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
return ChatScene.InnerChatDBSummary.value
|
return ChatScene.InnerChatDBSummary.value
|
||||||
|
@ -61,7 +61,6 @@ class ChatUrlKnowledge(BaseChat):
|
|||||||
input_values = {"context": context, "question": self.current_user_input}
|
input_values = {"context": context, "question": self.current_user_input}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
return ChatScene.ChatUrlKnowledge.value
|
return ChatScene.ChatUrlKnowledge.value
|
||||||
|
@ -59,8 +59,6 @@ class ChatKnowledge(BaseChat):
|
|||||||
)
|
)
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_type(self) -> str:
|
def chat_type(self) -> str:
|
||||||
return ChatScene.ChatKnowledge.value
|
return ChatScene.ChatKnowledge.value
|
||||||
|
@ -85,7 +85,6 @@ class OnceConversation:
|
|||||||
self.messages.clear()
|
self.messages.clear()
|
||||||
self.session_id = None
|
self.session_id = None
|
||||||
|
|
||||||
|
|
||||||
def get_user_message(self):
|
def get_user_message(self):
|
||||||
for once in self.messages:
|
for once in self.messages:
|
||||||
if isinstance(once, HumanMessage):
|
if isinstance(once, HumanMessage):
|
||||||
|
@ -13,5 +13,3 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
|||||||
load_dotenv(verbose=True, override=True)
|
load_dotenv(verbose=True, override=True)
|
||||||
|
|
||||||
del load_dotenv
|
del load_dotenv
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,7 +27,6 @@ from pilot.server.chat_adapter import get_llm_chat_adapter
|
|||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelWorker:
|
class ModelWorker:
|
||||||
def __init__(self, model_path, model_name, device, num_gpus=1):
|
def __init__(self, model_path, model_name, device, num_gpus=1):
|
||||||
if model_path.endswith("/"):
|
if model_path.endswith("/"):
|
||||||
@ -106,6 +105,7 @@ worker = ModelWorker(
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
from pilot.openapi.knowledge.knowledge_controller import router
|
from pilot.openapi.knowledge.knowledge_controller import router
|
||||||
|
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
||||||
origins = [
|
origins = [
|
||||||
@ -119,9 +119,10 @@ app.add_middleware(
|
|||||||
allow_origins=origins,
|
allow_origins=origins,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"]
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PromptRequest(BaseModel):
|
class PromptRequest(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
temperature: float
|
temperature: float
|
||||||
|
@ -107,12 +107,16 @@ knowledge_qa_type_list = [
|
|||||||
add_knowledge_base_dialogue,
|
add_knowledge_base_dialogue,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def swagger_monkey_patch(*args, **kwargs):
|
def swagger_monkey_patch(*args, **kwargs):
|
||||||
return get_swagger_ui_html(
|
return get_swagger_ui_html(
|
||||||
*args, **kwargs,
|
*args,
|
||||||
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
|
**kwargs,
|
||||||
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
|
swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
|
||||||
|
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
applications.get_swagger_ui_html = swagger_monkey_patch
|
applications.get_swagger_ui_html = swagger_monkey_patch
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@ -360,14 +364,18 @@ def http_bot(
|
|||||||
response = chat.stream_call()
|
response = chat.stream_call()
|
||||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||||
if chunk:
|
if chunk:
|
||||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||||
|
chunk, chat.skip_echo_len
|
||||||
|
)
|
||||||
state.messages[-1][-1] = msg
|
state.messages[-1][-1] = msg
|
||||||
chat.current_message.add_ai_message(msg)
|
chat.current_message.add_ai_message(msg)
|
||||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||||
chat.memory.append(chat.current_message)
|
chat.memory.append(chat.current_message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
|
state.messages[-1][
|
||||||
|
-1
|
||||||
|
] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
|
||||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||||
|
|
||||||
|
|
||||||
@ -693,8 +701,12 @@ def signal_handler(sig, frame):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
|
parser.add_argument(
|
||||||
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
|
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-new", "--new", action="store_true", help="enable new http mode"
|
||||||
|
)
|
||||||
|
|
||||||
# old version server config
|
# old version server config
|
||||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||||
@ -702,27 +714,24 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--concurrency-count", type=int, default=10)
|
parser.add_argument("--concurrency-count", type=int, default=10)
|
||||||
parser.add_argument("--share", default=False, action="store_true")
|
parser.add_argument("--share", default=False, action="store_true")
|
||||||
|
|
||||||
|
|
||||||
# init server config
|
# init server config
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
server_init(args)
|
server_init(args)
|
||||||
|
|
||||||
if args.new:
|
if args.new:
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=5000)
|
uvicorn.run(app, host="0.0.0.0", port=5000)
|
||||||
else:
|
else:
|
||||||
### Compatibility mode starts the old version server by default
|
### Compatibility mode starts the old version server by default
|
||||||
demo = build_webdemo()
|
demo = build_webdemo()
|
||||||
demo.queue(
|
demo.queue(
|
||||||
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
concurrency_count=args.concurrency_count,
|
||||||
|
status_update_rate=10,
|
||||||
|
api_open=False,
|
||||||
).launch(
|
).launch(
|
||||||
server_name=args.host,
|
server_name=args.host,
|
||||||
server_port=args.port,
|
server_port=args.port,
|
||||||
share=args.share,
|
share=args.share,
|
||||||
max_threads=200,
|
max_threads=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,7 +38,9 @@ class KnowledgeEmbedding:
|
|||||||
return self.knowledge_embedding_client.read_batch()
|
return self.knowledge_embedding_client.read_batch()
|
||||||
|
|
||||||
def init_knowledge_embedding(self):
|
def init_knowledge_embedding(self):
|
||||||
return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config)
|
return get_knowledge_embedding(
|
||||||
|
self.file_type.upper(), self.file_path, self.vector_store_config
|
||||||
|
)
|
||||||
|
|
||||||
def similar_search(self, text, topk):
|
def similar_search(self, text, topk):
|
||||||
vector_client = VectorStoreConnector(
|
vector_client = VectorStoreConnector(
|
||||||
|
Loading…
Reference in New Issue
Block a user