mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-22 10:08:34 +00:00
style:format code
format code
This commit is contained in:
parent
2475ffe282
commit
7f979c0880
@ -43,7 +43,7 @@ class MyEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, set):
|
||||
return list(obj)
|
||||
elif hasattr(obj, '__dict__'):
|
||||
elif hasattr(obj, "__dict__"):
|
||||
return obj.__dict__
|
||||
else:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
@ -78,6 +78,7 @@ def load_native_plugins(cfg: Config):
|
||||
if not cfg.plugins_auto_load:
|
||||
print("not auto load_native_plugins")
|
||||
return
|
||||
|
||||
def load_from_git(cfg: Config):
|
||||
print("async load_native_plugins")
|
||||
branch_name = cfg.plugins_git_branch
|
||||
@ -85,16 +86,20 @@ def load_native_plugins(cfg: Config):
|
||||
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
|
||||
try:
|
||||
session = requests.Session()
|
||||
response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
|
||||
headers={'Authorization': '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'})
|
||||
response = session.get(
|
||||
url.format(repo=native_plugin_repo, branch=branch_name),
|
||||
headers={"Authorization": "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"},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
plugins_path_path = Path(PLUGINS_DIR)
|
||||
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
|
||||
files = glob.glob(
|
||||
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
|
||||
)
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime('%Y%m%d%H%M%S')
|
||||
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
|
||||
print(file_name)
|
||||
with open(file_name, "wb") as f:
|
||||
@ -110,7 +115,6 @@ def load_native_plugins(cfg: Config):
|
||||
t.start()
|
||||
|
||||
|
||||
|
||||
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
"""Scan the plugins directory for plugins and loads them.
|
||||
|
||||
|
@ -8,6 +8,7 @@ class SeparatorStyle(Enum):
|
||||
THREE = auto()
|
||||
FOUR = auto()
|
||||
|
||||
|
||||
class ExampleType(Enum):
|
||||
ONE_SHOT = "one_shot"
|
||||
FEW_SHOT = "few_shot"
|
||||
|
@ -90,7 +90,7 @@ class Config(metaclass=Singleton):
|
||||
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
|
||||
self.plugins: List[AutoGPTPluginTemplate] = []
|
||||
self.plugins_openai = []
|
||||
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
|
||||
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
|
||||
|
||||
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
|
||||
|
||||
|
@ -6,7 +6,7 @@ import numpy as np
|
||||
from matplotlib.font_manager import FontProperties
|
||||
from pyecharts.charts import Bar
|
||||
from pyecharts import options as opts
|
||||
from test_cls_1 import TestBase,Test1
|
||||
from test_cls_1 import TestBase, Test1
|
||||
from test_cls_2 import Test2
|
||||
|
||||
CFG = Config()
|
||||
@ -60,21 +60,21 @@ CFG = Config()
|
||||
|
||||
# if __name__ == "__main__":
|
||||
|
||||
# def __extract_json(s):
|
||||
# i = s.index("{")
|
||||
# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
||||
# for j, c in enumerate(s[i + 1 :], start=i + 1):
|
||||
# if c == "}":
|
||||
# count -= 1
|
||||
# elif c == "{":
|
||||
# count += 1
|
||||
# if count == 0:
|
||||
# break
|
||||
# assert count == 0 # 检查是否找到最后一个'}'
|
||||
# return s[i : j + 1]
|
||||
#
|
||||
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
||||
# print(__extract_json(ss))
|
||||
# def __extract_json(s):
|
||||
# i = s.index("{")
|
||||
# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
||||
# for j, c in enumerate(s[i + 1 :], start=i + 1):
|
||||
# if c == "}":
|
||||
# count -= 1
|
||||
# elif c == "{":
|
||||
# count += 1
|
||||
# if count == 0:
|
||||
# break
|
||||
# assert count == 0 # 检查是否找到最后一个'}'
|
||||
# return s[i : j + 1]
|
||||
#
|
||||
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
||||
# print(__extract_json(ss))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test1 = Test1()
|
||||
@ -83,4 +83,4 @@ if __name__ == "__main__":
|
||||
test1.test()
|
||||
test2.write()
|
||||
test1.test()
|
||||
test2.test()
|
||||
test2.test()
|
||||
|
@ -4,9 +4,9 @@ from test_cls_base import TestBase
|
||||
|
||||
|
||||
class Test1(TestBase):
|
||||
mode:str = "456"
|
||||
mode: str = "456"
|
||||
|
||||
def write(self):
|
||||
self.test_values.append("x")
|
||||
self.test_values.append("y")
|
||||
self.test_values.append("g")
|
||||
|
||||
|
@ -3,13 +3,15 @@ from pydantic import BaseModel
|
||||
from test_cls_base import TestBase
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
|
||||
|
||||
class Test2(TestBase):
|
||||
test_2_values:List = []
|
||||
mode:str = "789"
|
||||
test_2_values: List = []
|
||||
mode: str = "789"
|
||||
|
||||
def write(self):
|
||||
self.test_values.append(1)
|
||||
self.test_values.append(2)
|
||||
self.test_values.append(3)
|
||||
self.test_2_values.append("x")
|
||||
self.test_2_values.append("y")
|
||||
self.test_2_values.append("z")
|
||||
self.test_2_values.append("z")
|
||||
|
@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
|
||||
class TestBase(BaseModel, ABC):
|
||||
test_values: List = []
|
||||
mode:str = "123"
|
||||
mode: str = "123"
|
||||
|
||||
def test(self):
|
||||
print(self.__class__.__name__ + ":" )
|
||||
print(self.__class__.__name__ + ":")
|
||||
print(self.test_values)
|
||||
print(self.mode)
|
||||
print(self.mode)
|
||||
|
@ -39,7 +39,9 @@ class KnowledgeEmbedding:
|
||||
return self.knowledge_embedding_client.read_batch()
|
||||
|
||||
def init_knowledge_embedding(self):
|
||||
return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
|
||||
return get_knowledge_embedding(
|
||||
self.knowledge_type, self.knowledge_source, self.vector_store_config
|
||||
)
|
||||
|
||||
def similar_search(self, text, topk):
|
||||
vector_client = VectorStoreConnector(
|
||||
@ -56,3 +58,9 @@ class KnowledgeEmbedding:
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
)
|
||||
return vector_client.vector_name_exists()
|
||||
|
||||
def delete_by_ids(self, ids):
|
||||
vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
)
|
||||
vector_client.delete_by_ids(ids=ids)
|
||||
|
@ -33,8 +33,6 @@ class BaseChatHistoryMemory(ABC):
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from the local file"""
|
||||
|
||||
|
||||
def conv_list(self, user_name:str=None) -> None:
|
||||
def conv_list(self, user_name: str = None) -> None:
|
||||
"""get user's conversation list"""
|
||||
pass
|
||||
|
||||
|
@ -14,13 +14,12 @@ from pilot.common.formatting import MyEncoder
|
||||
|
||||
default_db_path = os.path.join(os.getcwd(), "message")
|
||||
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
|
||||
table_name = 'chat_history'
|
||||
table_name = "chat_history"
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
|
||||
def __init__(self, chat_session_id: str):
|
||||
self.chat_seesion_id = chat_session_id
|
||||
os.makedirs(default_db_path, exist_ok=True)
|
||||
@ -28,15 +27,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
self.__init_chat_history_tables()
|
||||
|
||||
def __init_chat_history_tables(self):
|
||||
|
||||
# 检查表是否存在
|
||||
result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
[table_name]).fetchall()
|
||||
result = self.connect.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
|
||||
).fetchall()
|
||||
|
||||
if not result:
|
||||
# 如果表不存在,则创建新表
|
||||
self.connect.execute(
|
||||
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")
|
||||
"CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
|
||||
)
|
||||
|
||||
def __get_messages_by_conv_uid(self, conv_uid: str):
|
||||
cursor = self.connect.cursor()
|
||||
@ -58,23 +58,46 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
conversations.append(once_message)
|
||||
cursor = self.connect.cursor()
|
||||
if context:
|
||||
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
|
||||
[json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
|
||||
cursor.execute(
|
||||
"UPDATE chat_history set messages=? where conv_uid=?",
|
||||
[
|
||||
json.dumps(
|
||||
conversations_to_dict(conversations),
|
||||
ensure_ascii=False,
|
||||
indent=4,
|
||||
),
|
||||
self.chat_seesion_id,
|
||||
],
|
||||
)
|
||||
else:
|
||||
cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
|
||||
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
|
||||
cursor.execute(
|
||||
"INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
|
||||
[
|
||||
self.chat_seesion_id,
|
||||
"",
|
||||
json.dumps(
|
||||
conversations_to_dict(conversations),
|
||||
ensure_ascii=False,
|
||||
indent=4,
|
||||
),
|
||||
],
|
||||
)
|
||||
cursor.commit()
|
||||
self.connect.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
|
||||
cursor.execute(
|
||||
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
|
||||
)
|
||||
cursor.commit()
|
||||
self.connect.commit()
|
||||
|
||||
def delete(self) -> bool:
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
|
||||
cursor.execute(
|
||||
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
|
||||
)
|
||||
cursor.commit()
|
||||
return True
|
||||
|
||||
@ -83,7 +106,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
if os.path.isfile(duckdb_path):
|
||||
cursor = duckdb.connect(duckdb_path).cursor()
|
||||
if user_name:
|
||||
cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
|
||||
cursor.execute(
|
||||
"SELECT * FROM chat_history where user_name=? limit 20", [user_name]
|
||||
)
|
||||
else:
|
||||
cursor.execute("SELECT * FROM chat_history limit 20")
|
||||
# 获取查询结果字段名
|
||||
@ -99,10 +124,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_messages(self)-> List[OnceConversation]:
|
||||
def get_messages(self) -> List[OnceConversation]:
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
|
||||
cursor.execute(
|
||||
"SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
|
||||
)
|
||||
context = cursor.fetchone()
|
||||
if context:
|
||||
return json.loads(context[0])
|
||||
|
@ -11,7 +11,7 @@ from pilot.scene.message import (
|
||||
conversation_from_dict,
|
||||
conversations_to_dict,
|
||||
)
|
||||
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
|
||||
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -19,7 +19,6 @@ CFG = Config()
|
||||
class MemHistoryMemory(BaseChatHistoryMemory):
|
||||
histroies_map = FixedSizeDict(100)
|
||||
|
||||
|
||||
def __init__(self, chat_session_id: str):
|
||||
self.chat_seesion_id = chat_session_id
|
||||
self.histroies_map.update({chat_session_id: []})
|
||||
|
2
pilot/model/cache/__init__.py
vendored
2
pilot/model/cache/__init__.py
vendored
@ -1,4 +1,4 @@
|
||||
from .base import Cache
|
||||
from .disk_cache import DiskCache
|
||||
from .memory_cache import InMemoryCache
|
||||
from .gpt_cache import GPTCache
|
||||
from .gpt_cache import GPTCache
|
||||
|
4
pilot/model/cache/base.py
vendored
4
pilot/model/cache/base.py
vendored
@ -3,8 +3,8 @@ import hashlib
|
||||
from typing import Any, Dict
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Cache(ABC):
|
||||
|
||||
def create(self, key: str) -> bool:
|
||||
pass
|
||||
|
||||
@ -24,4 +24,4 @@ class Cache(ABC):
|
||||
@abstractmethod
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""see if we can return a cached value for the passed key"""
|
||||
pass
|
||||
pass
|
||||
|
12
pilot/model/cache/disk_cache.py
vendored
12
pilot/model/cache/disk_cache.py
vendored
@ -3,15 +3,15 @@ import diskcache
|
||||
import platformdirs
|
||||
from pilot.model.cache import Cache
|
||||
|
||||
|
||||
class DiskCache(Cache):
|
||||
"""DiskCache is a cache that uses diskcache lib.
|
||||
https://github.com/grantjenks/python-diskcache
|
||||
https://github.com/grantjenks/python-diskcache
|
||||
"""
|
||||
|
||||
def __init__(self, llm_name: str):
|
||||
self._diskcache = diskcache.Cache(
|
||||
os.path.join(
|
||||
platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache"
|
||||
)
|
||||
os.path.join(platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache")
|
||||
)
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
@ -22,6 +22,6 @@ class DiskCache(Cache):
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self._diskcache
|
||||
|
||||
|
||||
def clear(self):
|
||||
self._diskcache.clear()
|
||||
self._diskcache.clear()
|
||||
|
11
pilot/model/cache/gpt_cache.py
vendored
11
pilot/model/cache/gpt_cache.py
vendored
@ -9,22 +9,23 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class GPTCache(Cache):
|
||||
|
||||
"""
|
||||
GPTCache is a semantic cache that uses
|
||||
"""
|
||||
GPTCache is a semantic cache that uses
|
||||
"""
|
||||
|
||||
def __init__(self, cache) -> None:
|
||||
"""GPT Cache is a semantic cache that uses GPTCache lib."""
|
||||
|
||||
|
||||
if isinstance(cache, str):
|
||||
_cache = Cache()
|
||||
init_similar_cache(
|
||||
data_dir=os.path.join(
|
||||
platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache"
|
||||
),
|
||||
cache_obj=_cache
|
||||
cache_obj=_cache,
|
||||
)
|
||||
else:
|
||||
_cache = cache
|
||||
@ -41,4 +42,4 @@ class GPTCache(Cache):
|
||||
return get(key) is not None
|
||||
|
||||
def create(self, llm: str, **kwargs: Dict[str, Any]) -> str:
|
||||
pass
|
||||
pass
|
||||
|
11
pilot/model/cache/memory_cache.py
vendored
11
pilot/model/cache/memory_cache.py
vendored
@ -1,24 +1,23 @@
|
||||
from typing import Dict, Any
|
||||
from pilot.model.cache import Cache
|
||||
|
||||
|
||||
class InMemoryCache(Cache):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"Initialize that stores things in memory."
|
||||
self._cache: Dict[str, Any] = {}
|
||||
|
||||
def create(self, key: str) -> bool:
|
||||
pass
|
||||
pass
|
||||
|
||||
def clear(self):
|
||||
return self._cache.clear()
|
||||
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
self._cache[key] = value
|
||||
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
return self._cache[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return self._cache.get(key, None) is not None
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return self._cache.get(key, None) is not None
|
||||
|
@ -12,16 +12,21 @@ from fastapi.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from typing import List
|
||||
|
||||
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
|
||||
from pilot.server.api_v1.api_view_model import (
|
||||
Result,
|
||||
ConversationVo,
|
||||
MessageVo,
|
||||
ChatSceneVo,
|
||||
)
|
||||
from pilot.configs.config import Config
|
||||
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
|
||||
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
from pilot.configs.model_config import (LOGDIR)
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
from pilot.scene.base_message import (BaseMessage)
|
||||
from pilot.scene.base_message import BaseMessage
|
||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||
from pilot.scene.message import OnceConversation
|
||||
|
||||
@ -40,19 +45,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
||||
|
||||
|
||||
def __get_conv_user_message(conversations: dict):
|
||||
messages = conversations['messages']
|
||||
messages = conversations["messages"]
|
||||
for item in messages:
|
||||
if item['type'] == "human":
|
||||
return item['data']['content']
|
||||
if item["type"] == "human":
|
||||
return item["data"]["content"]
|
||||
return ""
|
||||
|
||||
|
||||
@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
|
||||
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
||||
async def dialogue_list(response: Response, user_id: str = None):
|
||||
# 设置CORS头部信息
|
||||
response.headers['Access-Control-Allow-Origin'] = '*'
|
||||
response.headers['Access-Control-Allow-Methods'] = 'GET'
|
||||
response.headers['Access-Control-Request-Headers'] = 'content-type'
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET"
|
||||
response.headers["Access-Control-Request-Headers"] = "content-type"
|
||||
|
||||
dialogues: List = []
|
||||
datas = DuckdbHistoryMemory.conv_list(user_id)
|
||||
@ -63,26 +68,44 @@ async def dialogue_list(response: Response, user_id: str = None):
|
||||
conversations = json.loads(messages)
|
||||
|
||||
first_conv: OnceConversation = conversations[0]
|
||||
conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv),
|
||||
chat_mode=first_conv['chat_mode'])
|
||||
conv_vo: ConversationVo = ConversationVo(
|
||||
conv_uid=conv_uid,
|
||||
user_input=__get_conv_user_message(first_conv),
|
||||
chat_mode=first_conv["chat_mode"],
|
||||
)
|
||||
dialogues.append(conv_vo)
|
||||
|
||||
return Result[ConversationVo].succ(dialogues)
|
||||
|
||||
|
||||
@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
|
||||
@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
|
||||
async def dialogue_scenes():
|
||||
scene_vos: List[ChatSceneVo] = []
|
||||
new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
|
||||
new_modes: List[ChatScene] = [
|
||||
ChatScene.ChatDb,
|
||||
ChatScene.ChatData,
|
||||
ChatScene.ChatDashboard,
|
||||
ChatScene.ChatKnowledge,
|
||||
ChatScene.ChatExecution,
|
||||
]
|
||||
for scene in new_modes:
|
||||
if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
|
||||
scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
|
||||
if not scene.value in [
|
||||
ChatScene.ChatNormal.value,
|
||||
ChatScene.InnerChatDBSummary.value,
|
||||
]:
|
||||
scene_vo = ChatSceneVo(
|
||||
chat_scene=scene.value,
|
||||
scene_name=scene.name,
|
||||
param_title="Selection Param",
|
||||
)
|
||||
scene_vos.append(scene_vo)
|
||||
return Result.succ(scene_vos)
|
||||
|
||||
|
||||
@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo])
|
||||
async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None):
|
||||
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
|
||||
async def dialogue_new(
|
||||
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
|
||||
):
|
||||
unique_id = uuid.uuid1()
|
||||
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
|
||||
|
||||
@ -90,7 +113,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str
|
||||
def get_db_list():
|
||||
db = CFG.local_db
|
||||
dbs = db.get_database_list()
|
||||
params:dict = {}
|
||||
params: dict = {}
|
||||
for name in dbs:
|
||||
params.update({name: name})
|
||||
return params
|
||||
@ -108,7 +131,7 @@ def knowledge_list():
|
||||
return knowledge_service.get_knowledge_space(request)
|
||||
|
||||
|
||||
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
|
||||
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
|
||||
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
|
||||
if ChatScene.ChatDb.value == chat_mode:
|
||||
return Result.succ(get_db_list())
|
||||
@ -124,14 +147,14 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
@router.post('/v1/chat/dialogue/delete')
|
||||
@router.post("/v1/chat/dialogue/delete")
|
||||
async def dialogue_delete(con_uid: str):
|
||||
history_mem = DuckdbHistoryMemory(con_uid)
|
||||
history_mem.delete()
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo])
|
||||
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
|
||||
async def dialogue_history_messages(con_uid: str):
|
||||
print(f"dialogue_history_messages:{con_uid}")
|
||||
message_vos: List[MessageVo] = []
|
||||
@ -140,17 +163,21 @@ async def dialogue_history_messages(con_uid: str):
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
for once in history_messages:
|
||||
once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']]
|
||||
once_message_vos = [
|
||||
message2Vo(element, once["chat_order"]) for element in once["messages"]
|
||||
]
|
||||
message_vos.extend(once_message_vos)
|
||||
return Result.succ(message_vos)
|
||||
|
||||
|
||||
@router.post('/v1/chat/completions')
|
||||
@router.post("/v1/chat/completions")
|
||||
async def chat_completions(dialogue: ConversationVo = Body()):
|
||||
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
|
||||
|
||||
if not ChatScene.is_valid_mode(dialogue.chat_mode):
|
||||
raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
|
||||
raise StopAsyncIteration(
|
||||
Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")
|
||||
)
|
||||
|
||||
chat_param = {
|
||||
"chat_session_id": dialogue.conv_uid,
|
||||
@ -188,7 +215,9 @@ def stream_generator(chat):
|
||||
model_response = chat.stream_call()
|
||||
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||
chunk, chat.skip_echo_len
|
||||
)
|
||||
chat.current_message.add_ai_message(msg)
|
||||
yield msg
|
||||
# chat.current_message.add_ai_message(msg)
|
||||
@ -206,7 +235,9 @@ def stream_generator(chat):
|
||||
|
||||
def message2Vo(message: dict, order) -> MessageVo:
|
||||
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
|
||||
return MessageVo(role=message['type'], context=message['data']['content'], order=order)
|
||||
return MessageVo(
|
||||
role=message["type"], context=message["data"]["content"], order=order
|
||||
)
|
||||
|
||||
|
||||
def non_stream_response(chat):
|
||||
@ -214,38 +245,38 @@ def non_stream_response(chat):
|
||||
return chat.nostream_call()
|
||||
|
||||
|
||||
@router.get('/v1/db/types', response_model=Result[str])
|
||||
@router.get("/v1/db/types", response_model=Result[str])
|
||||
async def db_types():
|
||||
return Result.succ(["mysql", "duckdb"])
|
||||
|
||||
|
||||
@router.get('/v1/db/list', response_model=Result[str])
|
||||
@router.get("/v1/db/list", response_model=Result[str])
|
||||
async def db_list():
|
||||
db = CFG.local_db
|
||||
dbs = db.get_database_list()
|
||||
return Result.succ(dbs)
|
||||
|
||||
|
||||
@router.get('/v1/knowledge/list')
|
||||
@router.get("/v1/knowledge/list")
|
||||
async def knowledge_list():
|
||||
return ["test1", "test2"]
|
||||
|
||||
|
||||
@router.post('/v1/knowledge/add')
|
||||
@router.post("/v1/knowledge/add")
|
||||
async def knowledge_add():
|
||||
return ["test1", "test2"]
|
||||
|
||||
|
||||
@router.post('/v1/knowledge/delete')
|
||||
@router.post("/v1/knowledge/delete")
|
||||
async def knowledge_delete():
|
||||
return ["test1", "test2"]
|
||||
|
||||
|
||||
@router.get('/v1/knowledge/types')
|
||||
@router.get("/v1/knowledge/types")
|
||||
async def knowledge_types():
|
||||
return ["test1", "test2"]
|
||||
|
||||
|
||||
@router.get('/v1/knowledge/detail')
|
||||
@router.get("/v1/knowledge/detail")
|
||||
async def knowledge_detail():
|
||||
return ["test1", "test2"]
|
||||
|
@ -1,7 +1,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import TypeVar, Union, List, Generic, Any
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Result(Generic[T], BaseModel):
|
||||
@ -24,15 +24,17 @@ class Result(Generic[T], BaseModel):
|
||||
|
||||
|
||||
class ChatSceneVo(BaseModel):
|
||||
chat_scene: str = Field(..., description="chat_scene")
|
||||
scene_name: str = Field(..., description="chat_scene name show for user")
|
||||
param_title: str = Field(..., description="chat_scene required parameter title")
|
||||
chat_scene: str = Field(..., description="chat_scene")
|
||||
scene_name: str = Field(..., description="chat_scene name show for user")
|
||||
param_title: str = Field(..., description="chat_scene required parameter title")
|
||||
|
||||
|
||||
class ConversationVo(BaseModel):
|
||||
"""
|
||||
dialogue_uid
|
||||
"""
|
||||
conv_uid: str = Field(..., description="dialogue uid")
|
||||
|
||||
conv_uid: str = Field(..., description="dialogue uid")
|
||||
"""
|
||||
user input
|
||||
"""
|
||||
@ -44,7 +46,7 @@ class ConversationVo(BaseModel):
|
||||
"""
|
||||
the scene of chat
|
||||
"""
|
||||
chat_mode: str = Field(..., description="the scene of chat ")
|
||||
chat_mode: str = Field(..., description="the scene of chat ")
|
||||
|
||||
"""
|
||||
chat scene select param
|
||||
@ -52,11 +54,11 @@ class ConversationVo(BaseModel):
|
||||
select_param: str = None
|
||||
|
||||
|
||||
|
||||
class MessageVo(BaseModel):
|
||||
"""
|
||||
role that sends out the current message
|
||||
role that sends out the current message
|
||||
"""
|
||||
|
||||
role: str
|
||||
"""
|
||||
current message
|
||||
@ -70,4 +72,3 @@ class MessageVo(BaseModel):
|
||||
time the current message was sent
|
||||
"""
|
||||
time_stamp: Any = None
|
||||
|
||||
|
@ -10,8 +10,10 @@ from pilot.configs.config import Config
|
||||
CFG = Config()
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class DocumentChunkEntity(Base):
|
||||
__tablename__ = 'document_chunk'
|
||||
__tablename__ = "document_chunk"
|
||||
id = Column(Integer, primary_key=True)
|
||||
document_id = Column(Integer)
|
||||
doc_name = Column(String(100))
|
||||
@ -29,43 +31,55 @@ class DocumentChunkDao:
|
||||
def __init__(self):
|
||||
database = "knowledge_management"
|
||||
self.db_engine = create_engine(
|
||||
f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
|
||||
echo=True)
|
||||
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
|
||||
echo=True,
|
||||
)
|
||||
self.Session = sessionmaker(bind=self.db_engine)
|
||||
|
||||
def create_documents_chunks(self, documents:List):
|
||||
def create_documents_chunks(self, documents: List):
|
||||
session = self.Session()
|
||||
docs = [
|
||||
DocumentChunkEntity(
|
||||
doc_name=document.doc_name,
|
||||
doc_type=document.doc_type,
|
||||
document_id=document.document_id,
|
||||
content=document.content or "",
|
||||
meta_info=document.meta_info or "",
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now()
|
||||
doc_name=document.doc_name,
|
||||
doc_type=document.doc_type,
|
||||
document_id=document.document_id,
|
||||
content=document.content or "",
|
||||
meta_info=document.meta_info or "",
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
for document in documents]
|
||||
for document in documents
|
||||
]
|
||||
session.add_all(docs)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def get_document_chunks(self, query:DocumentChunkEntity, page=1, page_size=20):
|
||||
def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
|
||||
session = self.Session()
|
||||
document_chunks = session.query(DocumentChunkEntity)
|
||||
if query.id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||
if query.document_id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.document_id == query.document_id)
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.document_id == query.document_id
|
||||
)
|
||||
if query.doc_type is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.doc_type == query.doc_type)
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.doc_type == query.doc_type
|
||||
)
|
||||
if query.doc_name is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.doc_name == query.doc_name)
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.doc_name == query.doc_name
|
||||
)
|
||||
if query.meta_info is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.meta_info == query.meta_info)
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.meta_info == query.meta_info
|
||||
)
|
||||
|
||||
document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
|
||||
document_chunks = document_chunks.offset((page - 1) * page_size).limit(page_size)
|
||||
document_chunks = document_chunks.offset((page - 1) * page_size).limit(
|
||||
page_size
|
||||
)
|
||||
result = document_chunks.all()
|
||||
return result
|
||||
|
||||
|
@ -13,7 +13,11 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
||||
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
|
||||
from pilot.openapi.knowledge.request.knowledge_request import (
|
||||
KnowledgeQueryRequest,
|
||||
KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
|
||||
KnowledgeQueryResponse,
|
||||
KnowledgeDocumentRequest,
|
||||
DocumentSyncRequest,
|
||||
ChunkQueryRequest,
|
||||
DocumentQueryRequest,
|
||||
)
|
||||
|
||||
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
|
||||
@ -62,16 +66,15 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
|
||||
def document_list(space_name: str, query_request: DocumentQueryRequest):
|
||||
print(f"/document/list params: {space_name}, {query_request}")
|
||||
try:
|
||||
return Result.succ(knowledge_space_service.get_knowledge_documents(
|
||||
space_name,
|
||||
query_request
|
||||
))
|
||||
return Result.succ(
|
||||
knowledge_space_service.get_knowledge_documents(space_name, query_request)
|
||||
)
|
||||
except Exception as e:
|
||||
return Result.faild(code="E000X", msg=f"document list error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/upload")
|
||||
def document_sync(space_name: str, file: UploadFile = File(...)):
|
||||
async def document_sync(space_name: str, file: UploadFile = File(...)):
|
||||
print(f"/document/upload params: {space_name}")
|
||||
try:
|
||||
with NamedTemporaryFile(delete=False) as tmp:
|
||||
@ -92,7 +95,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
|
||||
knowledge_space_service.sync_knowledge_document(
|
||||
space_name=space_name, doc_ids=request.doc_ids
|
||||
)
|
||||
Result.succ([])
|
||||
return Result.succ([])
|
||||
except Exception as e:
|
||||
return Result.faild(code="E000X", msg=f"document sync error {e}")
|
||||
|
||||
@ -101,9 +104,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
|
||||
def document_list(space_name: str, query_request: ChunkQueryRequest):
|
||||
print(f"/document/list params: {space_name}, {query_request}")
|
||||
try:
|
||||
return Result.succ(knowledge_space_service.get_document_chunks(
|
||||
query_request
|
||||
))
|
||||
return Result.succ(knowledge_space_service.get_document_chunks(query_request))
|
||||
except Exception as e:
|
||||
return Result.faild(code="E000X", msg=f"document chunk list error {e}")
|
||||
|
||||
|
@ -12,7 +12,7 @@ Base = declarative_base()
|
||||
|
||||
|
||||
class KnowledgeDocumentEntity(Base):
|
||||
__tablename__ = 'knowledge_document'
|
||||
__tablename__ = "knowledge_document"
|
||||
id = Column(Integer, primary_key=True)
|
||||
doc_name = Column(String(100))
|
||||
doc_type = Column(String(100))
|
||||
@ -21,23 +21,25 @@ class KnowledgeDocumentEntity(Base):
|
||||
status = Column(String(100))
|
||||
last_sync = Column(String(100))
|
||||
content = Column(Text)
|
||||
result = Column(Text)
|
||||
vector_ids = Column(Text)
|
||||
gmt_created = Column(DateTime)
|
||||
gmt_modified = Column(DateTime)
|
||||
|
||||
def __repr__(self):
|
||||
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
|
||||
|
||||
class KnowledgeDocumentDao:
|
||||
def __init__(self):
|
||||
database = "knowledge_management"
|
||||
self.db_engine = create_engine(
|
||||
f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
|
||||
echo=True)
|
||||
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
|
||||
echo=True,
|
||||
)
|
||||
self.Session = sessionmaker(bind=self.db_engine)
|
||||
|
||||
def create_knowledge_document(self, document:KnowledgeDocumentEntity):
|
||||
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.Session()
|
||||
knowledge_document = KnowledgeDocumentEntity(
|
||||
doc_name=document.doc_name,
|
||||
@ -47,9 +49,10 @@ class KnowledgeDocumentDao:
|
||||
status=document.status,
|
||||
last_sync=document.last_sync,
|
||||
content=document.content or "",
|
||||
result=document.result or "",
|
||||
vector_ids=document.vector_ids,
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now()
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
session.add(knowledge_document)
|
||||
session.commit()
|
||||
@ -60,28 +63,42 @@ class KnowledgeDocumentDao:
|
||||
session = self.Session()
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id)
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.id == query.id
|
||||
)
|
||||
if query.doc_name is not None:
|
||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name)
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.doc_name == query.doc_name
|
||||
)
|
||||
if query.doc_type is not None:
|
||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type)
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.doc_type == query.doc_type
|
||||
)
|
||||
if query.space is not None:
|
||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space)
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.space == query.space
|
||||
)
|
||||
if query.status is not None:
|
||||
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status)
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.status == query.status
|
||||
)
|
||||
|
||||
knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc())
|
||||
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size)
|
||||
knowledge_documents = knowledge_documents.order_by(
|
||||
KnowledgeDocumentEntity.id.desc()
|
||||
)
|
||||
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(
|
||||
page_size
|
||||
)
|
||||
result = knowledge_documents.all()
|
||||
return result
|
||||
|
||||
def update_knowledge_document(self, document:KnowledgeDocumentEntity):
|
||||
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.Session()
|
||||
updated_space = session.merge(document)
|
||||
session.commit()
|
||||
return updated_space.id
|
||||
|
||||
def delete_knowledge_document(self, document_id:int):
|
||||
def delete_knowledge_document(self, document_id: int):
|
||||
cursor = self.conn.cursor()
|
||||
query = "DELETE FROM knowledge_document WHERE id = %s"
|
||||
cursor.execute(query, (document_id,))
|
||||
|
@ -4,17 +4,24 @@ from datetime import datetime
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
||||
from pilot.embedding_engine.knowledge_type import KnowledgeType
|
||||
from pilot.logs import logger
|
||||
from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
|
||||
from pilot.openapi.knowledge.document_chunk_dao import (
|
||||
DocumentChunkEntity,
|
||||
DocumentChunkDao,
|
||||
)
|
||||
from pilot.openapi.knowledge.knowledge_document_dao import (
|
||||
KnowledgeDocumentDao,
|
||||
KnowledgeDocumentEntity,
|
||||
)
|
||||
from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
|
||||
from pilot.openapi.knowledge.knowledge_space_dao import (
|
||||
KnowledgeSpaceDao,
|
||||
KnowledgeSpaceEntity,
|
||||
)
|
||||
from pilot.openapi.knowledge.request.knowledge_request import (
|
||||
KnowledgeSpaceRequest,
|
||||
KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
|
||||
KnowledgeDocumentRequest,
|
||||
DocumentQueryRequest,
|
||||
ChunkQueryRequest,
|
||||
)
|
||||
from enum import Enum
|
||||
|
||||
@ -23,7 +30,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
|
||||
knowledge_document_dao = KnowledgeDocumentDao()
|
||||
document_chunk_dao = DocumentChunkDao()
|
||||
|
||||
CFG=Config()
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class SyncStatus(Enum):
|
||||
@ -53,10 +60,7 @@ class KnowledgeService:
|
||||
"""create knowledge document"""
|
||||
|
||||
def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
|
||||
query = KnowledgeDocumentEntity(
|
||||
doc_name=request.doc_name,
|
||||
space=space
|
||||
)
|
||||
query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
|
||||
documents = knowledge_document_dao.get_knowledge_documents(query)
|
||||
if len(documents) > 0:
|
||||
raise Exception(f"document name:{request.doc_name} have already named")
|
||||
@ -74,26 +78,27 @@ class KnowledgeService:
|
||||
|
||||
"""get knowledge space"""
|
||||
|
||||
def get_knowledge_space(self, request:KnowledgeSpaceRequest):
|
||||
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
|
||||
query = KnowledgeSpaceEntity(
|
||||
name=request.name,
|
||||
vector_type=request.vector_type,
|
||||
owner=request.owner
|
||||
name=request.name, vector_type=request.vector_type, owner=request.owner
|
||||
)
|
||||
return knowledge_space_dao.get_knowledge_space(query)
|
||||
|
||||
"""get knowledge get_knowledge_documents"""
|
||||
|
||||
def get_knowledge_documents(self, space, request:DocumentQueryRequest):
|
||||
def get_knowledge_documents(self, space, request: DocumentQueryRequest):
|
||||
query = KnowledgeDocumentEntity(
|
||||
doc_name=request.doc_name,
|
||||
doc_type=request.doc_type,
|
||||
space=space,
|
||||
status=request.status,
|
||||
)
|
||||
return knowledge_document_dao.get_knowledge_documents(query, page=request.page, page_size=request.page_size)
|
||||
return knowledge_document_dao.get_knowledge_documents(
|
||||
query, page=request.page, page_size=request.page_size
|
||||
)
|
||||
|
||||
"""sync knowledge document chunk into vector store"""
|
||||
|
||||
def sync_knowledge_document(self, space_name, doc_ids):
|
||||
for doc_id in doc_ids:
|
||||
query = KnowledgeDocumentEntity(
|
||||
@ -101,12 +106,14 @@ class KnowledgeService:
|
||||
space=space_name,
|
||||
)
|
||||
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
|
||||
client = KnowledgeEmbedding(knowledge_source=doc.content,
|
||||
knowledge_type=doc.doc_type.upper(),
|
||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
vector_store_config={
|
||||
"vector_store_name": space_name,
|
||||
})
|
||||
client = KnowledgeEmbedding(
|
||||
knowledge_source=doc.content,
|
||||
knowledge_type=doc.doc_type.upper(),
|
||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
vector_store_config={
|
||||
"vector_store_name": space_name,
|
||||
},
|
||||
)
|
||||
chunk_docs = client.read()
|
||||
# update document status
|
||||
doc.status = SyncStatus.RUNNING.name
|
||||
@ -114,9 +121,12 @@ class KnowledgeService:
|
||||
doc.gmt_modified = datetime.now()
|
||||
knowledge_document_dao.update_knowledge_document(doc)
|
||||
# async doc embeddings
|
||||
thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc))
|
||||
thread = threading.Thread(
|
||||
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
|
||||
)
|
||||
thread.start()
|
||||
#save chunk details
|
||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||
# save chunk details
|
||||
chunk_entities = [
|
||||
DocumentChunkEntity(
|
||||
doc_name=doc.doc_name,
|
||||
@ -125,9 +135,10 @@ class KnowledgeService:
|
||||
content=chunk_doc.page_content,
|
||||
meta_info=str(chunk_doc.metadata),
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now()
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
for chunk_doc in chunk_docs]
|
||||
for chunk_doc in chunk_docs
|
||||
]
|
||||
document_chunk_dao.create_documents_chunks(chunk_entities)
|
||||
|
||||
return True
|
||||
@ -145,26 +156,30 @@ class KnowledgeService:
|
||||
return knowledge_space_dao.delete_knowledge_space(space_id)
|
||||
|
||||
"""get document chunks"""
|
||||
def get_document_chunks(self, request:ChunkQueryRequest):
|
||||
|
||||
def get_document_chunks(self, request: ChunkQueryRequest):
|
||||
query = DocumentChunkEntity(
|
||||
id=request.id,
|
||||
document_id=request.document_id,
|
||||
doc_name=request.doc_name,
|
||||
doc_type=request.doc_type
|
||||
doc_type=request.doc_type,
|
||||
)
|
||||
return document_chunk_dao.get_document_chunks(
|
||||
query, page=request.page, page_size=request.page_size
|
||||
)
|
||||
return document_chunk_dao.get_document_chunks(query, page=request.page, page_size=request.page_size)
|
||||
|
||||
def async_doc_embedding(self, client, chunk_docs, doc):
|
||||
logger.info(f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}")
|
||||
logger.info(
|
||||
f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
|
||||
)
|
||||
try:
|
||||
vector_ids = client.knowledge_embedding_batch(chunk_docs)
|
||||
doc.status = SyncStatus.FINISHED.name
|
||||
doc.content = "embedding success"
|
||||
doc.result = "document embedding success"
|
||||
doc.vector_ids = ",".join(vector_ids)
|
||||
logger.info(f"async document embedding, success:{doc.doc_name}")
|
||||
except Exception as e:
|
||||
doc.status = SyncStatus.FAILED.name
|
||||
doc.content = str(e)
|
||||
|
||||
doc.result = "document embedding failed" + str(e)
|
||||
logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
|
||||
return knowledge_document_dao.update_knowledge_document(doc)
|
||||
|
||||
|
||||
|
@ -10,8 +10,10 @@ from sqlalchemy.orm import sessionmaker
|
||||
|
||||
CFG = Config()
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class KnowledgeSpaceEntity(Base):
|
||||
__tablename__ = 'knowledge_space'
|
||||
__tablename__ = "knowledge_space"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(100))
|
||||
vector_type = Column(String(100))
|
||||
@ -27,10 +29,13 @@ class KnowledgeSpaceEntity(Base):
|
||||
class KnowledgeSpaceDao:
|
||||
def __init__(self):
|
||||
database = "knowledge_management"
|
||||
self.db_engine = create_engine(f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', echo=True)
|
||||
self.db_engine = create_engine(
|
||||
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
|
||||
echo=True,
|
||||
)
|
||||
self.Session = sessionmaker(bind=self.db_engine)
|
||||
|
||||
def create_knowledge_space(self, space:KnowledgeSpaceRequest):
|
||||
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
|
||||
session = self.Session()
|
||||
knowledge_space = KnowledgeSpaceEntity(
|
||||
name=space.name,
|
||||
@ -38,43 +43,61 @@ class KnowledgeSpaceDao:
|
||||
desc=space.desc,
|
||||
owner=space.owner,
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now()
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
session.add(knowledge_space)
|
||||
session.commit()
|
||||
|
||||
session.close()
|
||||
|
||||
def get_knowledge_space(self, query:KnowledgeSpaceEntity):
|
||||
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
|
||||
session = self.Session()
|
||||
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
||||
if query.id is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.id == query.id)
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.id == query.id
|
||||
)
|
||||
if query.name is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.name == query.name)
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.name == query.name
|
||||
)
|
||||
if query.vector_type is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.vector_type == query.vector_type)
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.vector_type == query.vector_type
|
||||
)
|
||||
if query.desc is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.desc == query.desc)
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.desc == query.desc
|
||||
)
|
||||
if query.owner is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.owner == query.owner)
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.owner == query.owner
|
||||
)
|
||||
if query.gmt_created is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_created == query.gmt_created)
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.gmt_created == query.gmt_created
|
||||
)
|
||||
if query.gmt_modified is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_modified == query.gmt_modified)
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
|
||||
)
|
||||
|
||||
knowledge_spaces = knowledge_spaces.order_by(KnowledgeSpaceEntity.gmt_created.desc())
|
||||
knowledge_spaces = knowledge_spaces.order_by(
|
||||
KnowledgeSpaceEntity.gmt_created.desc()
|
||||
)
|
||||
result = knowledge_spaces.all()
|
||||
return result
|
||||
|
||||
def update_knowledge_space(self, space_id:int, space:KnowledgeSpaceEntity):
|
||||
def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
|
||||
cursor = self.conn.cursor()
|
||||
query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
|
||||
cursor.execute(query, (space.name, space.vector_type, space.desc, space.owner, space_id))
|
||||
cursor.execute(
|
||||
query, (space.name, space.vector_type, space.desc, space.owner, space_id)
|
||||
)
|
||||
self.conn.commit()
|
||||
cursor.close()
|
||||
|
||||
def delete_knowledge_space(self, space_id:int):
|
||||
def delete_knowledge_space(self, space_id: int):
|
||||
cursor = self.conn.cursor()
|
||||
query = "DELETE FROM knowledge_space WHERE id = %s"
|
||||
cursor.execute(query, (space_id,))
|
||||
|
@ -30,17 +30,19 @@ class KnowledgeDocumentRequest(BaseModel):
|
||||
"""doc_type: doc type"""
|
||||
doc_type: str
|
||||
"""content: content"""
|
||||
content: str
|
||||
content: str = None
|
||||
"""text_chunk_size: text_chunk_size"""
|
||||
# text_chunk_size: int
|
||||
|
||||
|
||||
class DocumentQueryRequest(BaseModel):
|
||||
"""doc_name: doc path"""
|
||||
|
||||
doc_name: str = None
|
||||
"""doc_type: doc type"""
|
||||
doc_type: str= None
|
||||
doc_type: str = None
|
||||
"""status: status"""
|
||||
status: str= None
|
||||
status: str = None
|
||||
"""page: page"""
|
||||
page: int = 1
|
||||
"""page_size: page size"""
|
||||
@ -49,10 +51,13 @@ class DocumentQueryRequest(BaseModel):
|
||||
|
||||
class DocumentSyncRequest(BaseModel):
|
||||
"""doc_ids: doc ids"""
|
||||
|
||||
doc_ids: List
|
||||
|
||||
|
||||
class ChunkQueryRequest(BaseModel):
|
||||
"""id: id"""
|
||||
|
||||
id: int = None
|
||||
"""document_id: doc id"""
|
||||
document_id: int = None
|
||||
|
@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
|
||||
from pilot.common.schema import ExampleType
|
||||
|
||||
|
||||
class ExampleSelector(BaseModel, ABC):
|
||||
examples: List[List]
|
||||
use_example: bool = False
|
||||
|
@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
from pilot.prompts.example_base import ExampleSelector
|
||||
|
||||
|
||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
"""Format a template using jinja2."""
|
||||
try:
|
||||
|
@ -1,11 +1,13 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Scene:
|
||||
def __init__(self, code, describe, is_inner):
|
||||
self.code = code
|
||||
self.describe = describe
|
||||
self.is_inner = is_inner
|
||||
|
||||
|
||||
class ChatScene(Enum):
|
||||
ChatWithDbExecute = "chat_with_db_execute"
|
||||
ChatWithDbQA = "chat_with_db_qa"
|
||||
@ -19,9 +21,8 @@ class ChatScene(Enum):
|
||||
ChatDashboard = "chat_dashboard"
|
||||
ChatKnowledge = "chat_knowledge"
|
||||
ChatDb = "chat_db"
|
||||
ChatData= "chat_data"
|
||||
ChatData = "chat_data"
|
||||
|
||||
@staticmethod
|
||||
def is_valid_mode(mode):
|
||||
return any(mode == item.value for item in ChatScene)
|
||||
|
||||
|
@ -104,7 +104,9 @@ class BaseChat(ABC):
|
||||
### Chat sequence advance
|
||||
self.current_message.chat_order = len(self.history_message) + 1
|
||||
self.current_message.add_user_message(self.current_user_input)
|
||||
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.current_message.start_date = datetime.datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
# TODO
|
||||
self.current_message.tokens = 0
|
||||
current_prompt = None
|
||||
@ -200,8 +202,11 @@ class BaseChat(ABC):
|
||||
# }"""
|
||||
|
||||
self.current_message.add_ai_message(ai_response_text)
|
||||
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
|
||||
|
||||
prompt_define_response = (
|
||||
self.prompt_template.output_parser.parse_prompt_response(
|
||||
ai_response_text
|
||||
)
|
||||
)
|
||||
|
||||
result = self.do_action(prompt_define_response)
|
||||
|
||||
|
@ -12,7 +12,10 @@ from pilot.common.markdown_text import (
|
||||
generate_htm_table,
|
||||
)
|
||||
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData
|
||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
||||
ChartData,
|
||||
ReportData,
|
||||
)
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -22,9 +25,7 @@ class ChatDashboard(BaseChat):
|
||||
report_name: str
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(
|
||||
self, chat_session_id, db_name, user_input, report_name
|
||||
):
|
||||
def __init__(self, chat_session_id, db_name, user_input, report_name):
|
||||
""" """
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatWithDbExecute,
|
||||
@ -51,7 +52,7 @@ class ChatDashboard(BaseChat):
|
||||
"input": self.current_user_input,
|
||||
"dialect": self.database.dialect,
|
||||
"table_info": self.database.table_simple_info(self.db_connect),
|
||||
"supported_chat_type": "" #TODO
|
||||
"supported_chat_type": "" # TODO
|
||||
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
|
||||
}
|
||||
return input_values
|
||||
@ -68,7 +69,6 @@ class ChatDashboard(BaseChat):
|
||||
# TODO 修复流程
|
||||
print(str(e))
|
||||
|
||||
|
||||
chart_datas.append(chart_data)
|
||||
|
||||
report_data.conv_uid = self.chat_session_id
|
||||
@ -77,5 +77,3 @@ class ChatDashboard(BaseChat):
|
||||
report_data.charts = chart_datas
|
||||
|
||||
return report_data
|
||||
|
||||
|
||||
|
@ -12,11 +12,7 @@ class ChartData(BaseModel):
|
||||
|
||||
|
||||
class ReportData(BaseModel):
|
||||
conv_uid:str
|
||||
template_name:str
|
||||
template_introduce:str
|
||||
conv_uid: str
|
||||
template_name: str
|
||||
template_introduce: str
|
||||
charts: List[ChartData]
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -10,9 +10,9 @@ from pilot.configs.model_config import LOGDIR
|
||||
|
||||
class ChartItem(NamedTuple):
|
||||
sql: str
|
||||
title:str
|
||||
title: str
|
||||
thoughts: str
|
||||
showcase:str
|
||||
showcase: str
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
|
||||
@ -28,7 +28,11 @@ class ChatDashboardOutputParser(BaseOutputParser):
|
||||
response = json.loads(clean_str)
|
||||
chart_items = List[ChartItem]
|
||||
for item in response:
|
||||
chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"]))
|
||||
chart_items.append(
|
||||
ChartItem(
|
||||
item["sql"], item["title"], item["thoughts"], item["showcase"]
|
||||
)
|
||||
)
|
||||
return chart_items
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
|
@ -24,12 +24,14 @@ give {dialect} data analysis SQL, analysis title, display method and analytical
|
||||
Ensure the response is correct json and can be parsed by Python json.loads
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT = [{
|
||||
"sql": "data analysis SQL",
|
||||
"title": "Data Analysis Title",
|
||||
"showcase": "What type of charts to show",
|
||||
"thoughts": "Current thinking and value of data analysis"
|
||||
}]
|
||||
RESPONSE_FORMAT = [
|
||||
{
|
||||
"sql": "data analysis SQL",
|
||||
"title": "Data Analysis Title",
|
||||
"showcase": "What type of charts to show",
|
||||
"thoughts": "Current thinking and value of data analysis",
|
||||
}
|
||||
]
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
|
@ -21,9 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(
|
||||
self, chat_session_id, db_name, user_input
|
||||
):
|
||||
def __init__(self, chat_session_id, db_name, user_input):
|
||||
""" """
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatWithDbExecute,
|
||||
|
@ -19,9 +19,7 @@ class ChatWithDbQA(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(
|
||||
self, chat_session_id, db_name, user_input
|
||||
):
|
||||
def __init__(self, chat_session_id, db_name, user_input):
|
||||
""" """
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatWithDbQA,
|
||||
@ -63,5 +61,3 @@ class ChatWithDbQA(BaseChat):
|
||||
"table_info": table_info,
|
||||
}
|
||||
return input_values
|
||||
|
||||
|
||||
|
@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector
|
||||
|
||||
## Two examples are defined by default
|
||||
EXAMPLES = [
|
||||
[{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}],
|
||||
[{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}]
|
||||
[{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
|
||||
[{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
|
||||
]
|
||||
|
||||
example = ExampleSelector(examples=EXAMPLES, use_example=True)
|
||||
|
@ -36,7 +36,6 @@ RESPONSE_FORMAT = {
|
||||
}
|
||||
|
||||
|
||||
|
||||
EXAMPLE_TYPE = ExampleType.ONE_SHOT
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
### Whether the model service is streaming output
|
||||
@ -49,8 +48,10 @@ prompt = PromptTemplate(
|
||||
template_define=PROMPT_SCENE_DEFINE,
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
stream_out=PROMPT_NEED_STREAM_OUT,
|
||||
output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
|
||||
example=example
|
||||
output_parser=PluginChatOutputParser(
|
||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
|
||||
),
|
||||
example=example,
|
||||
)
|
||||
|
||||
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
||||
|
@ -27,9 +27,7 @@ class ChatNewKnowledge(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(
|
||||
self, chat_session_id, user_input, knowledge_name
|
||||
):
|
||||
def __init__(self, chat_session_id, user_input, knowledge_name):
|
||||
""" """
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatNewKnowledge,
|
||||
@ -56,7 +54,6 @@ class ChatNewKnowledge(BaseChat):
|
||||
input_values = {"context": context, "question": self.current_user_input}
|
||||
return input_values
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatNewKnowledge.value
|
||||
|
@ -29,7 +29,7 @@ class ChatDefaultKnowledge(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input):
|
||||
def __init__(self, chat_session_id, user_input):
|
||||
""" """
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatDefaultKnowledge,
|
||||
@ -59,8 +59,6 @@ class ChatDefaultKnowledge(BaseChat):
|
||||
)
|
||||
return input_values
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatDefaultKnowledge.value
|
||||
|
@ -36,7 +36,6 @@ class InnerChatDBSummary(BaseChat):
|
||||
}
|
||||
return input_values
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.InnerChatDBSummary.value
|
||||
|
@ -61,7 +61,6 @@ class ChatUrlKnowledge(BaseChat):
|
||||
input_values = {"context": context, "question": self.current_user_input}
|
||||
return input_values
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatUrlKnowledge.value
|
||||
|
@ -29,7 +29,7 @@ class ChatKnowledge(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, knowledge_space):
|
||||
def __init__(self, chat_session_id, user_input, knowledge_space):
|
||||
""" """
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatKnowledge,
|
||||
@ -59,8 +59,6 @@ class ChatKnowledge(BaseChat):
|
||||
)
|
||||
return input_values
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatKnowledge.value
|
||||
|
@ -85,7 +85,6 @@ class OnceConversation:
|
||||
self.messages.clear()
|
||||
self.session_id = None
|
||||
|
||||
|
||||
def get_user_message(self):
|
||||
for once in self.messages:
|
||||
if isinstance(once, HumanMessage):
|
||||
|
@ -13,5 +13,3 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
||||
load_dotenv(verbose=True, override=True)
|
||||
|
||||
del load_dotenv
|
||||
|
||||
|
||||
|
@ -27,7 +27,6 @@ from pilot.server.chat_adapter import get_llm_chat_adapter
|
||||
CFG = Config()
|
||||
|
||||
|
||||
|
||||
class ModelWorker:
|
||||
def __init__(self, model_path, model_name, device, num_gpus=1):
|
||||
if model_path.endswith("/"):
|
||||
@ -106,6 +105,7 @@ worker = ModelWorker(
|
||||
|
||||
app = FastAPI()
|
||||
from pilot.openapi.knowledge.knowledge_controller import router
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
origins = [
|
||||
@ -119,9 +119,10 @@ app.add_middleware(
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"]
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
prompt: str
|
||||
temperature: float
|
||||
|
@ -56,7 +56,7 @@ from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, applications
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
@ -107,12 +107,16 @@ knowledge_qa_type_list = [
|
||||
add_knowledge_base_dialogue,
|
||||
]
|
||||
|
||||
|
||||
def swagger_monkey_patch(*args, **kwargs):
|
||||
return get_swagger_ui_html(
|
||||
*args, **kwargs,
|
||||
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
|
||||
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
|
||||
*args,
|
||||
**kwargs,
|
||||
swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
|
||||
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
|
||||
)
|
||||
|
||||
|
||||
applications.get_swagger_ui_html = swagger_monkey_patch
|
||||
|
||||
app = FastAPI()
|
||||
@ -360,14 +364,18 @@ def http_bot(
|
||||
response = chat.stream_call()
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
||||
state.messages[-1][-1] =msg
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||
chunk, chat.skip_echo_len
|
||||
)
|
||||
state.messages[-1][-1] = msg
|
||||
chat.current_message.add_ai_message(msg)
|
||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||
chat.memory.append(chat.current_message)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
|
||||
state.messages[-1][
|
||||
-1
|
||||
] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
|
||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||
|
||||
|
||||
@ -693,8 +701,12 @@ def signal_handler(sig, frame):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
|
||||
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
|
||||
parser.add_argument(
|
||||
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
|
||||
)
|
||||
parser.add_argument(
|
||||
"-new", "--new", action="store_true", help="enable new http mode"
|
||||
)
|
||||
|
||||
# old version server config
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
@ -702,27 +714,24 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--concurrency-count", type=int, default=10)
|
||||
parser.add_argument("--share", default=False, action="store_true")
|
||||
|
||||
|
||||
# init server config
|
||||
args = parser.parse_args()
|
||||
server_init(args)
|
||||
|
||||
if args.new:
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=5000)
|
||||
else:
|
||||
### Compatibility mode starts the old version server by default
|
||||
demo = build_webdemo()
|
||||
demo.queue(
|
||||
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
||||
concurrency_count=args.concurrency_count,
|
||||
status_update_rate=10,
|
||||
api_open=False,
|
||||
).launch(
|
||||
server_name=args.host,
|
||||
server_port=args.port,
|
||||
share=args.share,
|
||||
max_threads=200,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -38,7 +38,9 @@ class KnowledgeEmbedding:
|
||||
return self.knowledge_embedding_client.read_batch()
|
||||
|
||||
def init_knowledge_embedding(self):
|
||||
return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config)
|
||||
return get_knowledge_embedding(
|
||||
self.file_type.upper(), self.file_path, self.vector_store_config
|
||||
)
|
||||
|
||||
def similar_search(self, text, topk):
|
||||
vector_client = VectorStoreConnector(
|
||||
|
Loading…
Reference in New Issue
Block a user