diff --git a/pilot/common/formatting.py b/pilot/common/formatting.py
index c2db4126a..6bf10c1b2 100644
--- a/pilot/common/formatting.py
+++ b/pilot/common/formatting.py
@@ -43,7 +43,7 @@ class MyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
- elif hasattr(obj, '__dict__'):
+ elif hasattr(obj, "__dict__"):
return obj.__dict__
else:
- return json.JSONEncoder.default(self, obj)
\ No newline at end of file
+ return json.JSONEncoder.default(self, obj)
diff --git a/pilot/common/plugins.py b/pilot/common/plugins.py
index 832144d22..e22224399 100644
--- a/pilot/common/plugins.py
+++ b/pilot/common/plugins.py
@@ -78,6 +78,7 @@ def load_native_plugins(cfg: Config):
if not cfg.plugins_auto_load:
print("not auto load_native_plugins")
return
+
def load_from_git(cfg: Config):
print("async load_native_plugins")
branch_name = cfg.plugins_git_branch
@@ -85,16 +86,20 @@ def load_native_plugins(cfg: Config):
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
try:
session = requests.Session()
- response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
- headers={'Authorization': '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'})
+ response = session.get(
+ url.format(repo=native_plugin_repo, branch=branch_name),
+                headers={"Authorization": os.getenv("GITHUB_TOKEN", "")},
+ )
if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR)
- files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
+ files = glob.glob(
+ os.path.join(plugins_path_path, f"{native_plugin_repo}*")
+ )
for file in files:
os.remove(file)
now = datetime.datetime.now()
- time_str = now.strftime('%Y%m%d%H%M%S')
+ time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
print(file_name)
with open(file_name, "wb") as f:
@@ -110,7 +115,6 @@ def load_native_plugins(cfg: Config):
t.start()
-
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.
diff --git a/pilot/common/schema.py b/pilot/common/schema.py
index dc7a6c7cc..ab882e30b 100644
--- a/pilot/common/schema.py
+++ b/pilot/common/schema.py
@@ -8,6 +8,7 @@ class SeparatorStyle(Enum):
THREE = auto()
FOUR = auto()
+
class ExampleType(Enum):
ONE_SHOT = "one_shot"
FEW_SHOT = "few_shot"
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 94ff19e21..7e259f9fc 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -90,7 +90,7 @@ class Config(metaclass=Singleton):
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
self.plugins: List[AutoGPTPluginTemplate] = []
self.plugins_openai = []
- self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
+ self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
diff --git a/pilot/connections/rdbms/py_study/pd_study.py b/pilot/connections/rdbms/py_study/pd_study.py
index 5ad5be08f..31b060ef1 100644
--- a/pilot/connections/rdbms/py_study/pd_study.py
+++ b/pilot/connections/rdbms/py_study/pd_study.py
@@ -6,7 +6,7 @@ import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
-from test_cls_1 import TestBase,Test1
+from test_cls_1 import TestBase, Test1
from test_cls_2 import Test2
CFG = Config()
@@ -60,21 +60,21 @@ CFG = Config()
# if __name__ == "__main__":
- # def __extract_json(s):
- # i = s.index("{")
- # count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
- # for j, c in enumerate(s[i + 1 :], start=i + 1):
- # if c == "}":
- # count -= 1
- # elif c == "{":
- # count += 1
- # if count == 0:
- # break
- # assert count == 0 # 检查是否找到最后一个'}'
- # return s[i : j + 1]
- #
- # ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
- # print(__extract_json(ss))
+# def __extract_json(s):
+# i = s.index("{")
+# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
+# for j, c in enumerate(s[i + 1 :], start=i + 1):
+# if c == "}":
+# count -= 1
+# elif c == "{":
+# count += 1
+# if count == 0:
+# break
+# assert count == 0 # 检查是否找到最后一个'}'
+# return s[i : j + 1]
+#
+# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
+# print(__extract_json(ss))
if __name__ == "__main__":
test1 = Test1()
@@ -83,4 +83,4 @@ if __name__ == "__main__":
test1.test()
test2.write()
test1.test()
- test2.test()
\ No newline at end of file
+ test2.test()
diff --git a/pilot/connections/rdbms/py_study/test_cls_1.py b/pilot/connections/rdbms/py_study/test_cls_1.py
index c7d26a674..1b91b5601 100644
--- a/pilot/connections/rdbms/py_study/test_cls_1.py
+++ b/pilot/connections/rdbms/py_study/test_cls_1.py
@@ -4,9 +4,9 @@ from test_cls_base import TestBase
class Test1(TestBase):
- mode:str = "456"
+ mode: str = "456"
+
def write(self):
self.test_values.append("x")
self.test_values.append("y")
self.test_values.append("g")
-
diff --git a/pilot/connections/rdbms/py_study/test_cls_2.py b/pilot/connections/rdbms/py_study/test_cls_2.py
index e911f0542..1fb4d5e88 100644
--- a/pilot/connections/rdbms/py_study/test_cls_2.py
+++ b/pilot/connections/rdbms/py_study/test_cls_2.py
@@ -3,13 +3,15 @@ from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
+
class Test2(TestBase):
- test_2_values:List = []
- mode:str = "789"
+ test_2_values: List = []
+ mode: str = "789"
+
def write(self):
self.test_values.append(1)
self.test_values.append(2)
self.test_values.append(3)
self.test_2_values.append("x")
self.test_2_values.append("y")
- self.test_2_values.append("z")
\ No newline at end of file
+ self.test_2_values.append("z")
diff --git a/pilot/connections/rdbms/py_study/test_cls_base.py b/pilot/connections/rdbms/py_study/test_cls_base.py
index b8377c73d..676c1f2a5 100644
--- a/pilot/connections/rdbms/py_study/test_cls_base.py
+++ b/pilot/connections/rdbms/py_study/test_cls_base.py
@@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class TestBase(BaseModel, ABC):
test_values: List = []
- mode:str = "123"
+ mode: str = "123"
def test(self):
- print(self.__class__.__name__ + ":" )
+ print(self.__class__.__name__ + ":")
print(self.test_values)
- print(self.mode)
\ No newline at end of file
+ print(self.mode)
diff --git a/pilot/embedding_engine/knowledge_embedding.py b/pilot/embedding_engine/knowledge_embedding.py
index 5a8a2f944..fd28b938b 100644
--- a/pilot/embedding_engine/knowledge_embedding.py
+++ b/pilot/embedding_engine/knowledge_embedding.py
@@ -39,7 +39,9 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
- return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
+ return get_knowledge_embedding(
+ self.knowledge_type, self.knowledge_source, self.vector_store_config
+ )
def similar_search(self, text, topk):
vector_client = VectorStoreConnector(
@@ -56,3 +58,9 @@ class KnowledgeEmbedding:
CFG.VECTOR_STORE_TYPE, self.vector_store_config
)
return vector_client.vector_name_exists()
+
+ def delete_by_ids(self, ids):
+ vector_client = VectorStoreConnector(
+ CFG.VECTOR_STORE_TYPE, self.vector_store_config
+ )
+ vector_client.delete_by_ids(ids=ids)
diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py
index 649afa4ab..830968504 100644
--- a/pilot/memory/chat_history/base.py
+++ b/pilot/memory/chat_history/base.py
@@ -33,8 +33,6 @@ class BaseChatHistoryMemory(ABC):
def clear(self) -> None:
"""Clear session memory from the local file"""
-
- def conv_list(self, user_name:str=None) -> None:
+ def conv_list(self, user_name: str = None) -> None:
"""get user's conversation list"""
pass
-
diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py
index b24546d19..d97232217 100644
--- a/pilot/memory/chat_history/duckdb_history.py
+++ b/pilot/memory/chat_history/duckdb_history.py
@@ -14,13 +14,12 @@ from pilot.common.formatting import MyEncoder
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
-table_name = 'chat_history'
+table_name = "chat_history"
CFG = Config()
class DuckdbHistoryMemory(BaseChatHistoryMemory):
-
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
os.makedirs(default_db_path, exist_ok=True)
@@ -28,15 +27,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
self.__init_chat_history_tables()
def __init_chat_history_tables(self):
-
# 检查表是否存在
- result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
- [table_name]).fetchall()
+ result = self.connect.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
+ ).fetchall()
if not result:
# 如果表不存在,则创建新表
self.connect.execute(
- "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")
+ "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
+ )
def __get_messages_by_conv_uid(self, conv_uid: str):
cursor = self.connect.cursor()
@@ -58,23 +58,46 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
conversations.append(once_message)
cursor = self.connect.cursor()
if context:
- cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
- [json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
+ cursor.execute(
+ "UPDATE chat_history set messages=? where conv_uid=?",
+ [
+ json.dumps(
+ conversations_to_dict(conversations),
+ ensure_ascii=False,
+ indent=4,
+ ),
+ self.chat_seesion_id,
+ ],
+ )
else:
- cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
- [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
+ cursor.execute(
+ "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
+ [
+ self.chat_seesion_id,
+ "",
+ json.dumps(
+ conversations_to_dict(conversations),
+ ensure_ascii=False,
+ indent=4,
+ ),
+ ],
+ )
cursor.commit()
self.connect.commit()
def clear(self) -> None:
cursor = self.connect.cursor()
- cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+ cursor.execute(
+ "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+ )
cursor.commit()
self.connect.commit()
def delete(self) -> bool:
cursor = self.connect.cursor()
- cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+ cursor.execute(
+ "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+ )
cursor.commit()
return True
@@ -83,7 +106,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
if user_name:
- cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
+ cursor.execute(
+ "SELECT * FROM chat_history where user_name=? limit 20", [user_name]
+ )
else:
cursor.execute("SELECT * FROM chat_history limit 20")
# 获取查询结果字段名
@@ -99,10 +124,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return []
-
- def get_messages(self)-> List[OnceConversation]:
+ def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor()
- cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+ cursor.execute(
+ "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+ )
context = cursor.fetchone()
if context:
return json.loads(context[0])
diff --git a/pilot/memory/chat_history/mem_history.py b/pilot/memory/chat_history/mem_history.py
index a46428e75..c53595d3d 100644
--- a/pilot/memory/chat_history/mem_history.py
+++ b/pilot/memory/chat_history/mem_history.py
@@ -11,7 +11,7 @@ from pilot.scene.message import (
conversation_from_dict,
conversations_to_dict,
)
-from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
+from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
CFG = Config()
@@ -19,7 +19,6 @@ CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
histroies_map = FixedSizeDict(100)
-
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
self.histroies_map.update({chat_session_id: []})
diff --git a/pilot/model/cache/__init__.py b/pilot/model/cache/__init__.py
index 6a2d6fc7e..37dd0025b 100644
--- a/pilot/model/cache/__init__.py
+++ b/pilot/model/cache/__init__.py
@@ -1,4 +1,4 @@
from .base import Cache
from .disk_cache import DiskCache
from .memory_cache import InMemoryCache
-from .gpt_cache import GPTCache
\ No newline at end of file
+from .gpt_cache import GPTCache
diff --git a/pilot/model/cache/base.py b/pilot/model/cache/base.py
index d8c3d5851..d0b814f84 100644
--- a/pilot/model/cache/base.py
+++ b/pilot/model/cache/base.py
@@ -3,8 +3,8 @@ import hashlib
from typing import Any, Dict
from abc import ABC, abstractmethod
+
class Cache(ABC):
-
def create(self, key: str) -> bool:
pass
@@ -24,4 +24,4 @@ class Cache(ABC):
@abstractmethod
def __contains__(self, key: str) -> bool:
"""see if we can return a cached value for the passed key"""
- pass
\ No newline at end of file
+ pass
diff --git a/pilot/model/cache/disk_cache.py b/pilot/model/cache/disk_cache.py
index c461a37ae..25895dc9f 100644
--- a/pilot/model/cache/disk_cache.py
+++ b/pilot/model/cache/disk_cache.py
@@ -3,15 +3,15 @@ import diskcache
import platformdirs
from pilot.model.cache import Cache
+
class DiskCache(Cache):
"""DiskCache is a cache that uses diskcache lib.
- https://github.com/grantjenks/python-diskcache
+ https://github.com/grantjenks/python-diskcache
"""
+
def __init__(self, llm_name: str):
self._diskcache = diskcache.Cache(
- os.path.join(
- platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache"
- )
+ os.path.join(platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache")
)
def __getitem__(self, key: str) -> str:
@@ -22,6 +22,6 @@ class DiskCache(Cache):
def __contains__(self, key: str) -> bool:
return key in self._diskcache
-
+
def clear(self):
- self._diskcache.clear()
\ No newline at end of file
+ self._diskcache.clear()
diff --git a/pilot/model/cache/gpt_cache.py b/pilot/model/cache/gpt_cache.py
index 0fc680510..ee4529a5c 100644
--- a/pilot/model/cache/gpt_cache.py
+++ b/pilot/model/cache/gpt_cache.py
@@ -9,22 +9,23 @@ try:
except ImportError:
pass
+
class GPTCache(Cache):
- """
- GPTCache is a semantic cache that uses
+ """
+ GPTCache is a semantic cache that uses
"""
def __init__(self, cache) -> None:
"""GPT Cache is a semantic cache that uses GPTCache lib."""
-
+
if isinstance(cache, str):
_cache = Cache()
init_similar_cache(
data_dir=os.path.join(
platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache"
),
- cache_obj=_cache
+ cache_obj=_cache,
)
else:
_cache = cache
@@ -41,4 +42,4 @@ class GPTCache(Cache):
return get(key) is not None
def create(self, llm: str, **kwargs: Dict[str, Any]) -> str:
- pass
\ No newline at end of file
+ pass
diff --git a/pilot/model/cache/memory_cache.py b/pilot/model/cache/memory_cache.py
index b5311a341..7960b3d2e 100644
--- a/pilot/model/cache/memory_cache.py
+++ b/pilot/model/cache/memory_cache.py
@@ -1,24 +1,23 @@
from typing import Dict, Any
from pilot.model.cache import Cache
+
class InMemoryCache(Cache):
-
def __init__(self) -> None:
"Initialize that stores things in memory."
self._cache: Dict[str, Any] = {}
def create(self, key: str) -> bool:
- pass
+ pass
def clear(self):
return self._cache.clear()
def __setitem__(self, key: str, value: str) -> None:
self._cache[key] = value
-
+
def __getitem__(self, key: str) -> str:
return self._cache[key]
-
- def __contains__(self, key: str) -> bool:
- return self._cache.get(key, None) is not None
+ def __contains__(self, key: str) -> bool:
+ return self._cache.get(key, None) is not None
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 376e17ae4..9a3d3d88e 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -12,16 +12,21 @@ from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List
-from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
+from pilot.openapi.api_v1.api_view_model import (
+ Result,
+ ConversationVo,
+ MessageVo,
+ ChatSceneVo,
+)
from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
-from pilot.configs.model_config import (LOGDIR)
+from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
-from pilot.scene.base_message import (BaseMessage)
+from pilot.scene.base_message import BaseMessage
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
@@ -40,19 +45,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
def __get_conv_user_message(conversations: dict):
- messages = conversations['messages']
+ messages = conversations["messages"]
for item in messages:
- if item['type'] == "human":
- return item['data']['content']
+ if item["type"] == "human":
+ return item["data"]["content"]
return ""
-@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
+@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
# 设置CORS头部信息
- response.headers['Access-Control-Allow-Origin'] = '*'
- response.headers['Access-Control-Allow-Methods'] = 'GET'
- response.headers['Access-Control-Request-Headers'] = 'content-type'
+ response.headers["Access-Control-Allow-Origin"] = "*"
+ response.headers["Access-Control-Allow-Methods"] = "GET"
+ response.headers["Access-Control-Request-Headers"] = "content-type"
dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id)
@@ -63,26 +68,44 @@ async def dialogue_list(response: Response, user_id: str = None):
conversations = json.loads(messages)
first_conv: OnceConversation = conversations[0]
- conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv),
- chat_mode=first_conv['chat_mode'])
+ conv_vo: ConversationVo = ConversationVo(
+ conv_uid=conv_uid,
+ user_input=__get_conv_user_message(first_conv),
+ chat_mode=first_conv["chat_mode"],
+ )
dialogues.append(conv_vo)
return Result[ConversationVo].succ(dialogues)
-@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
+@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = []
- new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
+ new_modes: List[ChatScene] = [
+ ChatScene.ChatDb,
+ ChatScene.ChatData,
+ ChatScene.ChatDashboard,
+ ChatScene.ChatKnowledge,
+ ChatScene.ChatExecution,
+ ]
for scene in new_modes:
- if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
- scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
+ if not scene.value in [
+ ChatScene.ChatNormal.value,
+ ChatScene.InnerChatDBSummary.value,
+ ]:
+ scene_vo = ChatSceneVo(
+ chat_scene=scene.value,
+ scene_name=scene.name,
+ param_title="Selection Param",
+ )
scene_vos.append(scene_vo)
return Result.succ(scene_vos)
-@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo])
-async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None):
+@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
+async def dialogue_new(
+ chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
+):
unique_id = uuid.uuid1()
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -90,7 +113,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str
def get_db_list():
db = CFG.local_db
dbs = db.get_database_list()
- params:dict = {}
+ params: dict = {}
for name in dbs:
params.update({name: name})
return params
@@ -108,7 +131,7 @@ def knowledge_list():
return knowledge_service.get_knowledge_space(request)
-@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
+@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatDb.value == chat_mode:
return Result.succ(get_db_list())
@@ -124,14 +147,14 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
return Result.succ(None)
-@router.post('/v1/chat/dialogue/delete')
+@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
history_mem = DuckdbHistoryMemory(con_uid)
history_mem.delete()
return Result.succ(None)
-@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo])
+@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str):
print(f"dialogue_history_messages:{con_uid}")
message_vos: List[MessageVo] = []
@@ -140,17 +163,21 @@ async def dialogue_history_messages(con_uid: str):
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
- once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']]
+ once_message_vos = [
+ message2Vo(element, once["chat_order"]) for element in once["messages"]
+ ]
message_vos.extend(once_message_vos)
return Result.succ(message_vos)
-@router.post('/v1/chat/completions')
+@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
if not ChatScene.is_valid_mode(dialogue.chat_mode):
- raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
+ raise StopAsyncIteration(
+ Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")
+ )
chat_param = {
"chat_session_id": dialogue.conv_uid,
@@ -188,7 +215,9 @@ def stream_generator(chat):
model_response = chat.stream_call()
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
- msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+ msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+ chunk, chat.skip_echo_len
+ )
chat.current_message.add_ai_message(msg)
yield msg
# chat.current_message.add_ai_message(msg)
@@ -206,7 +235,9 @@ def stream_generator(chat):
def message2Vo(message: dict, order) -> MessageVo:
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
- return MessageVo(role=message['type'], context=message['data']['content'], order=order)
+ return MessageVo(
+ role=message["type"], context=message["data"]["content"], order=order
+ )
def non_stream_response(chat):
@@ -214,38 +245,38 @@ def non_stream_response(chat):
return chat.nostream_call()
-@router.get('/v1/db/types', response_model=Result[str])
+@router.get("/v1/db/types", response_model=Result[str])
async def db_types():
return Result.succ(["mysql", "duckdb"])
-@router.get('/v1/db/list', response_model=Result[str])
+@router.get("/v1/db/list", response_model=Result[str])
async def db_list():
db = CFG.local_db
dbs = db.get_database_list()
return Result.succ(dbs)
-@router.get('/v1/knowledge/list')
+@router.get("/v1/knowledge/list")
async def knowledge_list():
return ["test1", "test2"]
-@router.post('/v1/knowledge/add')
+@router.post("/v1/knowledge/add")
async def knowledge_add():
return ["test1", "test2"]
-@router.post('/v1/knowledge/delete')
+@router.post("/v1/knowledge/delete")
async def knowledge_delete():
return ["test1", "test2"]
-@router.get('/v1/knowledge/types')
+@router.get("/v1/knowledge/types")
async def knowledge_types():
return ["test1", "test2"]
-@router.get('/v1/knowledge/detail')
+@router.get("/v1/knowledge/detail")
async def knowledge_detail():
return ["test1", "test2"]
diff --git a/pilot/openapi/api_v1/api_view_model.py b/pilot/openapi/api_v1/api_view_model.py
index f2d58599b..02eb70ddd 100644
--- a/pilot/openapi/api_v1/api_view_model.py
+++ b/pilot/openapi/api_v1/api_view_model.py
@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
-T = TypeVar('T')
+T = TypeVar("T")
class Result(Generic[T], BaseModel):
@@ -24,15 +24,17 @@ class Result(Generic[T], BaseModel):
class ChatSceneVo(BaseModel):
- chat_scene: str = Field(..., description="chat_scene")
- scene_name: str = Field(..., description="chat_scene name show for user")
- param_title: str = Field(..., description="chat_scene required parameter title")
+ chat_scene: str = Field(..., description="chat_scene")
+ scene_name: str = Field(..., description="chat_scene name show for user")
+ param_title: str = Field(..., description="chat_scene required parameter title")
+
class ConversationVo(BaseModel):
"""
dialogue_uid
"""
- conv_uid: str = Field(..., description="dialogue uid")
+
+ conv_uid: str = Field(..., description="dialogue uid")
"""
user input
"""
@@ -44,7 +46,7 @@ class ConversationVo(BaseModel):
"""
the scene of chat
"""
- chat_mode: str = Field(..., description="the scene of chat ")
+ chat_mode: str = Field(..., description="the scene of chat ")
"""
chat scene select param
@@ -52,11 +54,11 @@ class ConversationVo(BaseModel):
select_param: str = None
-
class MessageVo(BaseModel):
"""
- role that sends out the current message
+ role that sends out the current message
"""
+
role: str
"""
current message
@@ -70,4 +72,3 @@ class MessageVo(BaseModel):
time the current message was sent
"""
time_stamp: Any = None
-
diff --git a/pilot/openapi/knowledge/document_chunk_dao.py b/pilot/openapi/knowledge/document_chunk_dao.py
index e9a994e66..cb728e85c 100644
--- a/pilot/openapi/knowledge/document_chunk_dao.py
+++ b/pilot/openapi/knowledge/document_chunk_dao.py
@@ -10,8 +10,10 @@ from pilot.configs.config import Config
CFG = Config()
Base = declarative_base()
+
+
class DocumentChunkEntity(Base):
- __tablename__ = 'document_chunk'
+ __tablename__ = "document_chunk"
id = Column(Integer, primary_key=True)
document_id = Column(Integer)
doc_name = Column(String(100))
@@ -29,43 +31,55 @@ class DocumentChunkDao:
def __init__(self):
database = "knowledge_management"
self.db_engine = create_engine(
- f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
- echo=True)
+ f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
+ echo=True,
+ )
self.Session = sessionmaker(bind=self.db_engine)
- def create_documents_chunks(self, documents:List):
+ def create_documents_chunks(self, documents: List):
session = self.Session()
docs = [
DocumentChunkEntity(
- doc_name=document.doc_name,
- doc_type=document.doc_type,
- document_id=document.document_id,
- content=document.content or "",
- meta_info=document.meta_info or "",
- gmt_created=datetime.now(),
- gmt_modified=datetime.now()
+ doc_name=document.doc_name,
+ doc_type=document.doc_type,
+ document_id=document.document_id,
+ content=document.content or "",
+ meta_info=document.meta_info or "",
+ gmt_created=datetime.now(),
+ gmt_modified=datetime.now(),
)
- for document in documents]
+ for document in documents
+ ]
session.add_all(docs)
session.commit()
session.close()
- def get_document_chunks(self, query:DocumentChunkEntity, page=1, page_size=20):
+ def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
session = self.Session()
document_chunks = session.query(DocumentChunkEntity)
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
if query.document_id is not None:
- document_chunks = document_chunks.filter(DocumentChunkEntity.document_id == query.document_id)
+ document_chunks = document_chunks.filter(
+ DocumentChunkEntity.document_id == query.document_id
+ )
if query.doc_type is not None:
- document_chunks = document_chunks.filter(DocumentChunkEntity.doc_type == query.doc_type)
+ document_chunks = document_chunks.filter(
+ DocumentChunkEntity.doc_type == query.doc_type
+ )
if query.doc_name is not None:
- document_chunks = document_chunks.filter(DocumentChunkEntity.doc_name == query.doc_name)
+ document_chunks = document_chunks.filter(
+ DocumentChunkEntity.doc_name == query.doc_name
+ )
if query.meta_info is not None:
- document_chunks = document_chunks.filter(DocumentChunkEntity.meta_info == query.meta_info)
+ document_chunks = document_chunks.filter(
+ DocumentChunkEntity.meta_info == query.meta_info
+ )
document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
- document_chunks = document_chunks.offset((page - 1) * page_size).limit(page_size)
+ document_chunks = document_chunks.offset((page - 1) * page_size).limit(
+ page_size
+ )
result = document_chunks.all()
return result
diff --git a/pilot/openapi/knowledge/knowledge_controller.py b/pilot/openapi/knowledge/knowledge_controller.py
index cf368136f..bebbc8a3f 100644
--- a/pilot/openapi/knowledge/knowledge_controller.py
+++ b/pilot/openapi/knowledge/knowledge_controller.py
@@ -13,7 +13,11 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeQueryRequest,
- KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
+ KnowledgeQueryResponse,
+ KnowledgeDocumentRequest,
+ DocumentSyncRequest,
+ ChunkQueryRequest,
+ DocumentQueryRequest,
)
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -62,16 +66,15 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
def document_list(space_name: str, query_request: DocumentQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
try:
- return Result.succ(knowledge_space_service.get_knowledge_documents(
- space_name,
- query_request
- ))
+ return Result.succ(
+ knowledge_space_service.get_knowledge_documents(space_name, query_request)
+ )
except Exception as e:
return Result.faild(code="E000X", msg=f"document list error {e}")
@router.post("/knowledge/{space_name}/document/upload")
-def document_sync(space_name: str, file: UploadFile = File(...)):
+async def document_sync(space_name: str, file: UploadFile = File(...)):
print(f"/document/upload params: {space_name}")
try:
with NamedTemporaryFile(delete=False) as tmp:
@@ -92,7 +95,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
knowledge_space_service.sync_knowledge_document(
space_name=space_name, doc_ids=request.doc_ids
)
- Result.succ([])
+ return Result.succ([])
except Exception as e:
return Result.faild(code="E000X", msg=f"document sync error {e}")
@@ -101,9 +104,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
def document_list(space_name: str, query_request: ChunkQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
try:
- return Result.succ(knowledge_space_service.get_document_chunks(
- query_request
- ))
+ return Result.succ(knowledge_space_service.get_document_chunks(query_request))
except Exception as e:
return Result.faild(code="E000X", msg=f"document chunk list error {e}")
diff --git a/pilot/openapi/knowledge/knowledge_document_dao.py b/pilot/openapi/knowledge/knowledge_document_dao.py
index 276907b69..1f7afc401 100644
--- a/pilot/openapi/knowledge/knowledge_document_dao.py
+++ b/pilot/openapi/knowledge/knowledge_document_dao.py
@@ -12,7 +12,7 @@ Base = declarative_base()
class KnowledgeDocumentEntity(Base):
- __tablename__ = 'knowledge_document'
+ __tablename__ = "knowledge_document"
id = Column(Integer, primary_key=True)
doc_name = Column(String(100))
doc_type = Column(String(100))
@@ -21,23 +21,25 @@ class KnowledgeDocumentEntity(Base):
status = Column(String(100))
last_sync = Column(String(100))
content = Column(Text)
+ result = Column(Text)
vector_ids = Column(Text)
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
- return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
+ return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class KnowledgeDocumentDao:
def __init__(self):
database = "knowledge_management"
self.db_engine = create_engine(
- f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
- echo=True)
+ f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
+ echo=True,
+ )
self.Session = sessionmaker(bind=self.db_engine)
- def create_knowledge_document(self, document:KnowledgeDocumentEntity):
+ def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session()
knowledge_document = KnowledgeDocumentEntity(
doc_name=document.doc_name,
@@ -47,9 +49,10 @@ class KnowledgeDocumentDao:
status=document.status,
last_sync=document.last_sync,
content=document.content or "",
+ result=document.result or "",
vector_ids=document.vector_ids,
gmt_created=datetime.now(),
- gmt_modified=datetime.now()
+ gmt_modified=datetime.now(),
)
session.add(knowledge_document)
session.commit()
@@ -60,28 +63,42 @@ class KnowledgeDocumentDao:
session = self.Session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
- knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id)
+ knowledge_documents = knowledge_documents.filter(
+ KnowledgeDocumentEntity.id == query.id
+ )
if query.doc_name is not None:
- knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name)
+ knowledge_documents = knowledge_documents.filter(
+ KnowledgeDocumentEntity.doc_name == query.doc_name
+ )
if query.doc_type is not None:
- knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type)
+ knowledge_documents = knowledge_documents.filter(
+ KnowledgeDocumentEntity.doc_type == query.doc_type
+ )
if query.space is not None:
- knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space)
+ knowledge_documents = knowledge_documents.filter(
+ KnowledgeDocumentEntity.space == query.space
+ )
if query.status is not None:
- knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status)
+ knowledge_documents = knowledge_documents.filter(
+ KnowledgeDocumentEntity.status == query.status
+ )
- knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc())
- knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size)
+ knowledge_documents = knowledge_documents.order_by(
+ KnowledgeDocumentEntity.id.desc()
+ )
+ knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(
+ page_size
+ )
result = knowledge_documents.all()
return result
- def update_knowledge_document(self, document:KnowledgeDocumentEntity):
+ def update_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session()
updated_space = session.merge(document)
session.commit()
return updated_space.id
- def delete_knowledge_document(self, document_id:int):
+ def delete_knowledge_document(self, document_id: int):
cursor = self.conn.cursor()
query = "DELETE FROM knowledge_document WHERE id = %s"
cursor.execute(query, (document_id,))
diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py
index b6713dff4..c41630ede 100644
--- a/pilot/openapi/knowledge/knowledge_service.py
+++ b/pilot/openapi/knowledge/knowledge_service.py
@@ -4,17 +4,24 @@ from datetime import datetime
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
-from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.logs import logger
-from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
+from pilot.openapi.knowledge.document_chunk_dao import (
+ DocumentChunkEntity,
+ DocumentChunkDao,
+)
from pilot.openapi.knowledge.knowledge_document_dao import (
KnowledgeDocumentDao,
KnowledgeDocumentEntity,
)
-from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
+from pilot.openapi.knowledge.knowledge_space_dao import (
+ KnowledgeSpaceDao,
+ KnowledgeSpaceEntity,
+)
from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeSpaceRequest,
- KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
+ KnowledgeDocumentRequest,
+ DocumentQueryRequest,
+ ChunkQueryRequest,
)
from enum import Enum
@@ -23,7 +30,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
document_chunk_dao = DocumentChunkDao()
-CFG=Config()
+CFG = Config()
class SyncStatus(Enum):
@@ -53,10 +60,7 @@ class KnowledgeService:
"""create knowledge document"""
def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
- query = KnowledgeDocumentEntity(
- doc_name=request.doc_name,
- space=space
- )
+ query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
documents = knowledge_document_dao.get_knowledge_documents(query)
if len(documents) > 0:
raise Exception(f"document name:{request.doc_name} have already named")
@@ -74,26 +78,27 @@ class KnowledgeService:
"""get knowledge space"""
- def get_knowledge_space(self, request:KnowledgeSpaceRequest):
+ def get_knowledge_space(self, request: KnowledgeSpaceRequest):
query = KnowledgeSpaceEntity(
- name=request.name,
- vector_type=request.vector_type,
- owner=request.owner
+ name=request.name, vector_type=request.vector_type, owner=request.owner
)
return knowledge_space_dao.get_knowledge_space(query)
"""get knowledge get_knowledge_documents"""
- def get_knowledge_documents(self, space, request:DocumentQueryRequest):
+ def get_knowledge_documents(self, space, request: DocumentQueryRequest):
query = KnowledgeDocumentEntity(
doc_name=request.doc_name,
doc_type=request.doc_type,
space=space,
status=request.status,
)
- return knowledge_document_dao.get_knowledge_documents(query, page=request.page, page_size=request.page_size)
+ return knowledge_document_dao.get_knowledge_documents(
+ query, page=request.page, page_size=request.page_size
+ )
"""sync knowledge document chunk into vector store"""
+
def sync_knowledge_document(self, space_name, doc_ids):
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(
@@ -101,12 +106,14 @@ class KnowledgeService:
space=space_name,
)
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
- client = KnowledgeEmbedding(knowledge_source=doc.content,
- knowledge_type=doc.doc_type.upper(),
- model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
- vector_store_config={
- "vector_store_name": space_name,
- })
+ client = KnowledgeEmbedding(
+ knowledge_source=doc.content,
+ knowledge_type=doc.doc_type.upper(),
+ model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+ vector_store_config={
+ "vector_store_name": space_name,
+ },
+ )
chunk_docs = client.read()
# update document status
doc.status = SyncStatus.RUNNING.name
@@ -114,9 +121,12 @@ class KnowledgeService:
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
# async doc embeddings
- thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc))
+ thread = threading.Thread(
+ target=self.async_doc_embedding, args=(client, chunk_docs, doc)
+ )
thread.start()
- #save chunk details
+ logger.info(f"begin save document chunks, doc:{doc.doc_name}")
+ # save chunk details
chunk_entities = [
DocumentChunkEntity(
doc_name=doc.doc_name,
@@ -125,9 +135,10 @@ class KnowledgeService:
content=chunk_doc.page_content,
meta_info=str(chunk_doc.metadata),
gmt_created=datetime.now(),
- gmt_modified=datetime.now()
+ gmt_modified=datetime.now(),
)
- for chunk_doc in chunk_docs]
+ for chunk_doc in chunk_docs
+ ]
document_chunk_dao.create_documents_chunks(chunk_entities)
return True
@@ -145,26 +156,30 @@ class KnowledgeService:
return knowledge_space_dao.delete_knowledge_space(space_id)
"""get document chunks"""
- def get_document_chunks(self, request:ChunkQueryRequest):
+
+ def get_document_chunks(self, request: ChunkQueryRequest):
query = DocumentChunkEntity(
id=request.id,
document_id=request.document_id,
doc_name=request.doc_name,
- doc_type=request.doc_type
+ doc_type=request.doc_type,
+ )
+ return document_chunk_dao.get_document_chunks(
+ query, page=request.page, page_size=request.page_size
)
- return document_chunk_dao.get_document_chunks(query, page=request.page, page_size=request.page_size)
def async_doc_embedding(self, client, chunk_docs, doc):
- logger.info(f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}")
+ logger.info(
+ f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
+ )
try:
vector_ids = client.knowledge_embedding_batch(chunk_docs)
doc.status = SyncStatus.FINISHED.name
- doc.content = "embedding success"
+ doc.result = "document embedding success"
doc.vector_ids = ",".join(vector_ids)
+ logger.info(f"async document embedding, success:{doc.doc_name}")
except Exception as e:
doc.status = SyncStatus.FAILED.name
- doc.content = str(e)
-
+ doc.result = "document embedding failed" + str(e)
+ logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
return knowledge_document_dao.update_knowledge_document(doc)
-
-
diff --git a/pilot/openapi/knowledge/knowledge_space_dao.py b/pilot/openapi/knowledge/knowledge_space_dao.py
index 31894d6ac..16def1e99 100644
--- a/pilot/openapi/knowledge/knowledge_space_dao.py
+++ b/pilot/openapi/knowledge/knowledge_space_dao.py
@@ -10,8 +10,10 @@ from sqlalchemy.orm import sessionmaker
CFG = Config()
Base = declarative_base()
+
+
class KnowledgeSpaceEntity(Base):
- __tablename__ = 'knowledge_space'
+ __tablename__ = "knowledge_space"
id = Column(Integer, primary_key=True)
name = Column(String(100))
vector_type = Column(String(100))
@@ -27,10 +29,13 @@ class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceDao:
def __init__(self):
database = "knowledge_management"
- self.db_engine = create_engine(f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', echo=True)
+ self.db_engine = create_engine(
+ f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
+ echo=True,
+ )
self.Session = sessionmaker(bind=self.db_engine)
- def create_knowledge_space(self, space:KnowledgeSpaceRequest):
+ def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.Session()
knowledge_space = KnowledgeSpaceEntity(
name=space.name,
@@ -38,43 +43,61 @@ class KnowledgeSpaceDao:
desc=space.desc,
owner=space.owner,
gmt_created=datetime.now(),
- gmt_modified=datetime.now()
+ gmt_modified=datetime.now(),
)
session.add(knowledge_space)
session.commit()
session.close()
- def get_knowledge_space(self, query:KnowledgeSpaceEntity):
+ def get_knowledge_space(self, query: KnowledgeSpaceEntity):
session = self.Session()
knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.id is not None:
- knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.id == query.id)
+ knowledge_spaces = knowledge_spaces.filter(
+ KnowledgeSpaceEntity.id == query.id
+ )
if query.name is not None:
- knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.name == query.name)
+ knowledge_spaces = knowledge_spaces.filter(
+ KnowledgeSpaceEntity.name == query.name
+ )
if query.vector_type is not None:
- knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.vector_type == query.vector_type)
+ knowledge_spaces = knowledge_spaces.filter(
+ KnowledgeSpaceEntity.vector_type == query.vector_type
+ )
if query.desc is not None:
- knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.desc == query.desc)
+ knowledge_spaces = knowledge_spaces.filter(
+ KnowledgeSpaceEntity.desc == query.desc
+ )
if query.owner is not None:
- knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.owner == query.owner)
+ knowledge_spaces = knowledge_spaces.filter(
+ KnowledgeSpaceEntity.owner == query.owner
+ )
if query.gmt_created is not None:
- knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_created == query.gmt_created)
+ knowledge_spaces = knowledge_spaces.filter(
+ KnowledgeSpaceEntity.gmt_created == query.gmt_created
+ )
if query.gmt_modified is not None:
- knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_modified == query.gmt_modified)
+ knowledge_spaces = knowledge_spaces.filter(
+ KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
+ )
- knowledge_spaces = knowledge_spaces.order_by(KnowledgeSpaceEntity.gmt_created.desc())
+ knowledge_spaces = knowledge_spaces.order_by(
+ KnowledgeSpaceEntity.gmt_created.desc()
+ )
result = knowledge_spaces.all()
return result
- def update_knowledge_space(self, space_id:int, space:KnowledgeSpaceEntity):
+ def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
cursor = self.conn.cursor()
query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
- cursor.execute(query, (space.name, space.vector_type, space.desc, space.owner, space_id))
+ cursor.execute(
+ query, (space.name, space.vector_type, space.desc, space.owner, space_id)
+ )
self.conn.commit()
cursor.close()
- def delete_knowledge_space(self, space_id:int):
+ def delete_knowledge_space(self, space_id: int):
cursor = self.conn.cursor()
query = "DELETE FROM knowledge_space WHERE id = %s"
cursor.execute(query, (space_id,))
diff --git a/pilot/openapi/knowledge/request/knowledge_request.py b/pilot/openapi/knowledge/request/knowledge_request.py
index bf68b06f1..1c5916f7c 100644
--- a/pilot/openapi/knowledge/request/knowledge_request.py
+++ b/pilot/openapi/knowledge/request/knowledge_request.py
@@ -30,17 +30,19 @@ class KnowledgeDocumentRequest(BaseModel):
"""doc_type: doc type"""
doc_type: str
"""content: content"""
- content: str
+ content: str = None
"""text_chunk_size: text_chunk_size"""
# text_chunk_size: int
+
class DocumentQueryRequest(BaseModel):
"""doc_name: doc path"""
+
doc_name: str = None
"""doc_type: doc type"""
- doc_type: str= None
+ doc_type: str = None
"""status: status"""
- status: str= None
+ status: str = None
"""page: page"""
page: int = 1
"""page_size: page size"""
@@ -49,10 +51,13 @@ class DocumentQueryRequest(BaseModel):
class DocumentSyncRequest(BaseModel):
"""doc_ids: doc ids"""
+
doc_ids: List
+
class ChunkQueryRequest(BaseModel):
"""id: id"""
+
id: int = None
"""document_id: doc id"""
document_id: int = None
diff --git a/pilot/prompts/example_base.py b/pilot/prompts/example_base.py
index 4d876aa51..e930927c9 100644
--- a/pilot/prompts/example_base.py
+++ b/pilot/prompts/example_base.py
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pilot.common.schema import ExampleType
+
class ExampleSelector(BaseModel, ABC):
examples: List[List]
use_example: bool = False
diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py
index 65e107d11..475f82ea4 100644
--- a/pilot/prompts/prompt_new.py
+++ b/pilot/prompts/prompt_new.py
@@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
from pilot.prompts.example_base import ExampleSelector
+
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
try:
diff --git a/pilot/scene/base.py b/pilot/scene/base.py
index cec443beb..d0bb99255 100644
--- a/pilot/scene/base.py
+++ b/pilot/scene/base.py
@@ -1,11 +1,13 @@
from enum import Enum
+
class Scene:
def __init__(self, code, describe, is_inner):
self.code = code
self.describe = describe
self.is_inner = is_inner
+
class ChatScene(Enum):
ChatWithDbExecute = "chat_with_db_execute"
ChatWithDbQA = "chat_with_db_qa"
@@ -19,9 +21,8 @@ class ChatScene(Enum):
ChatDashboard = "chat_dashboard"
ChatKnowledge = "chat_knowledge"
ChatDb = "chat_db"
- ChatData= "chat_data"
+ ChatData = "chat_data"
@staticmethod
def is_valid_mode(mode):
return any(mode == item.value for item in ChatScene)
-
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index d6628a19a..c9ada6b29 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -104,7 +104,9 @@ class BaseChat(ABC):
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
- self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ self.current_message.start_date = datetime.datetime.now().strftime(
+ "%Y-%m-%d %H:%M:%S"
+ )
# TODO
self.current_message.tokens = 0
current_prompt = None
@@ -200,8 +202,11 @@ class BaseChat(ABC):
# }"""
self.current_message.add_ai_message(ai_response_text)
- prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
-
+ prompt_define_response = (
+ self.prompt_template.output_parser.parse_prompt_response(
+ ai_response_text
+ )
+ )
result = self.do_action(prompt_define_response)
diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py
index 19a59c39a..a9f84e431 100644
--- a/pilot/scene/chat_dashboard/chat.py
+++ b/pilot/scene/chat_dashboard/chat.py
@@ -12,7 +12,10 @@ from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
-from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData
+from pilot.scene.chat_dashboard.data_preparation.report_schma import (
+ ChartData,
+ ReportData,
+)
CFG = Config()
@@ -22,9 +25,7 @@ class ChatDashboard(BaseChat):
report_name: str
"""Number of results to return from the query"""
- def __init__(
- self, chat_session_id, db_name, user_input, report_name
- ):
+ def __init__(self, chat_session_id, db_name, user_input, report_name):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbExecute,
@@ -51,7 +52,7 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input,
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect),
- "supported_chat_type": "" #TODO
+ "supported_chat_type": "" # TODO
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
return input_values
@@ -68,7 +69,6 @@ class ChatDashboard(BaseChat):
# TODO 修复流程
print(str(e))
-
chart_datas.append(chart_data)
report_data.conv_uid = self.chat_session_id
@@ -77,5 +77,3 @@ class ChatDashboard(BaseChat):
report_data.charts = chart_datas
return report_data
-
-
diff --git a/pilot/scene/chat_dashboard/data_preparation/report_schma.py b/pilot/scene/chat_dashboard/data_preparation/report_schma.py
index 9323aacc3..84468f02e 100644
--- a/pilot/scene/chat_dashboard/data_preparation/report_schma.py
+++ b/pilot/scene/chat_dashboard/data_preparation/report_schma.py
@@ -12,11 +12,7 @@ class ChartData(BaseModel):
class ReportData(BaseModel):
- conv_uid:str
- template_name:str
- template_introduce:str
+ conv_uid: str
+ template_name: str
+ template_introduce: str
charts: List[ChartData]
-
-
-
-
diff --git a/pilot/scene/chat_dashboard/out_parser.py b/pilot/scene/chat_dashboard/out_parser.py
index 975196978..079b5f59a 100644
--- a/pilot/scene/chat_dashboard/out_parser.py
+++ b/pilot/scene/chat_dashboard/out_parser.py
@@ -10,9 +10,9 @@ from pilot.configs.model_config import LOGDIR
class ChartItem(NamedTuple):
sql: str
- title:str
+ title: str
thoughts: str
- showcase:str
+ showcase: str
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
@@ -28,7 +28,11 @@ class ChatDashboardOutputParser(BaseOutputParser):
response = json.loads(clean_str)
chart_items = List[ChartItem]
for item in response:
- chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"]))
+ chart_items.append(
+ ChartItem(
+ item["sql"], item["title"], item["thoughts"], item["showcase"]
+ )
+ )
return chart_items
def parse_view_response(self, speak, data) -> str:
diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py
index e44144d4d..481ecac22 100644
--- a/pilot/scene/chat_dashboard/prompt.py
+++ b/pilot/scene/chat_dashboard/prompt.py
@@ -24,12 +24,14 @@ give {dialect} data analysis SQL, analysis title, display method and analytical
Ensure the response is correct json and can be parsed by Python json.loads
"""
-RESPONSE_FORMAT = [{
- "sql": "data analysis SQL",
- "title": "Data Analysis Title",
- "showcase": "What type of charts to show",
- "thoughts": "Current thinking and value of data analysis"
-}]
+RESPONSE_FORMAT = [
+ {
+ "sql": "data analysis SQL",
+ "title": "Data Analysis Title",
+ "showcase": "What type of charts to show",
+ "thoughts": "Current thinking and value of data analysis",
+ }
+]
PROMPT_SEP = SeparatorStyle.SINGLE.value
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index 0a87f51f6..613c660c7 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -21,9 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query"""
- def __init__(
- self, chat_session_id, db_name, user_input
- ):
+ def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbExecute,
diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py
index 3d6cd0db4..3c8d0eb37 100644
--- a/pilot/scene/chat_db/professional_qa/chat.py
+++ b/pilot/scene/chat_db/professional_qa/chat.py
@@ -19,9 +19,7 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query"""
- def __init__(
- self, chat_session_id, db_name, user_input
- ):
+ def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbQA,
@@ -63,5 +61,3 @@ class ChatWithDbQA(BaseChat):
"table_info": table_info,
}
return input_values
-
-
diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py
index 6cd71b39c..e41de3abd 100644
--- a/pilot/scene/chat_execution/example.py
+++ b/pilot/scene/chat_execution/example.py
@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector
## Two examples are defined by default
EXAMPLES = [
- [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}],
- [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}]
+ [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
+ [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
]
example = ExampleSelector(examples=EXAMPLES, use_example=True)
diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py
index b5bc38bb2..ceb5db09d 100644
--- a/pilot/scene/chat_execution/prompt.py
+++ b/pilot/scene/chat_execution/prompt.py
@@ -36,7 +36,6 @@ RESPONSE_FORMAT = {
}
-
EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
@@ -49,8 +48,10 @@ prompt = PromptTemplate(
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
- output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
- example=example
+ output_parser=PluginChatOutputParser(
+ sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
+ ),
+ example=example,
)
CFG.prompt_templates.update({prompt.template_scene: prompt})
diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py
index 370b5518d..9a67b4d75 100644
--- a/pilot/scene/chat_knowledge/custom/chat.py
+++ b/pilot/scene/chat_knowledge/custom/chat.py
@@ -27,9 +27,7 @@ class ChatNewKnowledge(BaseChat):
"""Number of results to return from the query"""
- def __init__(
- self, chat_session_id, user_input, knowledge_name
- ):
+ def __init__(self, chat_session_id, user_input, knowledge_name):
""" """
super().__init__(
chat_mode=ChatScene.ChatNewKnowledge,
@@ -56,7 +54,6 @@ class ChatNewKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input}
return input_values
-
@property
def chat_type(self) -> str:
return ChatScene.ChatNewKnowledge.value
diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py
index 8052de910..51587a048 100644
--- a/pilot/scene/chat_knowledge/default/chat.py
+++ b/pilot/scene/chat_knowledge/default/chat.py
@@ -29,7 +29,7 @@ class ChatDefaultKnowledge(BaseChat):
"""Number of results to return from the query"""
- def __init__(self, chat_session_id, user_input):
+ def __init__(self, chat_session_id, user_input):
""" """
super().__init__(
chat_mode=ChatScene.ChatDefaultKnowledge,
@@ -59,8 +59,6 @@ class ChatDefaultKnowledge(BaseChat):
)
return input_values
-
-
@property
def chat_type(self) -> str:
return ChatScene.ChatDefaultKnowledge.value
diff --git a/pilot/scene/chat_knowledge/inner_db_summary/chat.py b/pilot/scene/chat_knowledge/inner_db_summary/chat.py
index b4dcc536f..c1e8d279a 100644
--- a/pilot/scene/chat_knowledge/inner_db_summary/chat.py
+++ b/pilot/scene/chat_knowledge/inner_db_summary/chat.py
@@ -36,7 +36,6 @@ class InnerChatDBSummary(BaseChat):
}
return input_values
-
@property
def chat_type(self) -> str:
return ChatScene.InnerChatDBSummary.value
diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py
index bdb964214..aab4bdb9f 100644
--- a/pilot/scene/chat_knowledge/url/chat.py
+++ b/pilot/scene/chat_knowledge/url/chat.py
@@ -61,7 +61,6 @@ class ChatUrlKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input}
return input_values
-
@property
def chat_type(self) -> str:
return ChatScene.ChatUrlKnowledge.value
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index 321c4d8eb..404a9347b 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -29,7 +29,7 @@ class ChatKnowledge(BaseChat):
"""Number of results to return from the query"""
- def __init__(self, chat_session_id, user_input, knowledge_space):
+ def __init__(self, chat_session_id, user_input, knowledge_space):
""" """
super().__init__(
chat_mode=ChatScene.ChatKnowledge,
@@ -59,8 +59,6 @@ class ChatKnowledge(BaseChat):
)
return input_values
-
-
@property
def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value
diff --git a/pilot/scene/message.py b/pilot/scene/message.py
index a2a894fe8..ae32bbfe7 100644
--- a/pilot/scene/message.py
+++ b/pilot/scene/message.py
@@ -85,7 +85,6 @@ class OnceConversation:
self.messages.clear()
self.session_id = None
-
def get_user_message(self):
for once in self.messages:
if isinstance(once, HumanMessage):
diff --git a/pilot/server/__init__.py b/pilot/server/__init__.py
index 0435c3679..ac72fc637 100644
--- a/pilot/server/__init__.py
+++ b/pilot/server/__init__.py
@@ -13,5 +13,3 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
load_dotenv(verbose=True, override=True)
del load_dotenv
-
-
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 3030d1fdc..367671d68 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -27,7 +27,6 @@ from pilot.server.chat_adapter import get_llm_chat_adapter
CFG = Config()
-
class ModelWorker:
def __init__(self, model_path, model_name, device, num_gpus=1):
if model_path.endswith("/"):
@@ -106,6 +105,7 @@ worker = ModelWorker(
app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router
+
app.include_router(router)
origins = [
@@ -119,9 +119,10 @@ app.add_middleware(
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
- allow_headers=["*"]
+ allow_headers=["*"],
)
+
class PromptRequest(BaseModel):
prompt: str
temperature: float
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 9513c0c5b..fd7401833 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -56,7 +56,7 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
-from fastapi.exceptions import RequestValidationError
+from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
@@ -107,12 +107,16 @@ knowledge_qa_type_list = [
add_knowledge_base_dialogue,
]
+
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
- *args, **kwargs,
- swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
- swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+ *args,
+ **kwargs,
+ swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
+ swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
)
+
+
applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI()
@@ -360,14 +364,18 @@ def http_bot(
response = chat.stream_call()
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
- msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
- state.messages[-1][-1] =msg
+ msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+ chunk, chat.skip_echo_len
+ )
+ state.messages[-1][-1] = msg
chat.current_message.add_ai_message(msg)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.memory.append(chat.current_message)
except Exception as e:
print(traceback.format_exc())
- state.messages[-1][-1] = f"""ERROR!{str(e)} """
+ state.messages[-1][
+ -1
+ ] = f"""ERROR!{str(e)} """
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
@@ -693,8 +701,12 @@ def signal_handler(sig, frame):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
- parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
+ parser.add_argument(
+ "--model_list_mode", type=str, default="once", choices=["once", "reload"]
+ )
+ parser.add_argument(
+ "-new", "--new", action="store_true", help="enable new http mode"
+ )
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -702,27 +714,24 @@ if __name__ == "__main__":
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
-
# init server config
args = parser.parse_args()
server_init(args)
if args.new:
import uvicorn
+
uvicorn.run(app, host="0.0.0.0", port=5000)
else:
### Compatibility mode starts the old version server by default
demo = build_webdemo()
demo.queue(
- concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+ concurrency_count=args.concurrency_count,
+ status_update_rate=10,
+ api_open=False,
).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)
-
-
-
-
-
diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py
index 6caccc474..cf13b3017 100644
--- a/pilot/source_embedding/knowledge_embedding.py
+++ b/pilot/source_embedding/knowledge_embedding.py
@@ -38,7 +38,9 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
- return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config)
+ return get_knowledge_embedding(
+ self.file_type.upper(), self.file_path, self.vector_store_config
+ )
def similar_search(self, text, topk):
vector_client = VectorStoreConnector(