style:format code

format code
This commit is contained in:
aries_ckt 2023-06-27 22:20:21 +08:00
parent 2475ffe282
commit 7f979c0880
47 changed files with 438 additions and 290 deletions

View File

@ -43,7 +43,7 @@ class MyEncoder(json.JSONEncoder):
def default(self, obj):
    """Serialize objects the stock JSON encoder rejects.

    Sets are emitted as lists, and any object that carries a ``__dict__`` is
    emitted as that attribute dictionary. Everything else is delegated to
    ``json.JSONEncoder.default``.
    """
    if isinstance(obj, set):
        return list(obj)
    if hasattr(obj, "__dict__"):
        return obj.__dict__
    return json.JSONEncoder.default(self, obj)

View File

@ -78,6 +78,7 @@ def load_native_plugins(cfg: Config):
if not cfg.plugins_auto_load:
print("not auto load_native_plugins")
return
def load_from_git(cfg: Config):
print("async load_native_plugins")
branch_name = cfg.plugins_git_branch
@ -85,16 +86,20 @@ def load_native_plugins(cfg: Config):
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
try:
session = requests.Session()
response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
headers={'Authorization': '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'})
response = session.get(
url.format(repo=native_plugin_repo, branch=branch_name),
headers={"Authorization": "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"},
)
if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR)
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
files = glob.glob(
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
)
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime('%Y%m%d%H%M%S')
time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
print(file_name)
with open(file_name, "wb") as f:
@ -110,7 +115,6 @@ def load_native_plugins(cfg: Config):
t.start()
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.

View File

@ -8,6 +8,7 @@ class SeparatorStyle(Enum):
THREE = auto()
FOUR = auto()
class ExampleType(Enum):
ONE_SHOT = "one_shot"
FEW_SHOT = "few_shot"

View File

@ -90,7 +90,7 @@ class Config(metaclass=Singleton):
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
self.plugins: List[AutoGPTPluginTemplate] = []
self.plugins_openai = []
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")

View File

@ -6,7 +6,7 @@ import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
from test_cls_1 import TestBase,Test1
from test_cls_1 import TestBase, Test1
from test_cls_2 import Test2
CFG = Config()
@ -60,21 +60,21 @@ CFG = Config()
# if __name__ == "__main__":
# def __extract_json(s):
# i = s.index("{")
# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
# for j, c in enumerate(s[i + 1 :], start=i + 1):
# if c == "}":
# count -= 1
# elif c == "{":
# count += 1
# if count == 0:
# break
# assert count == 0 # 检查是否找到最后一个'}'
# return s[i : j + 1]
#
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
# print(__extract_json(ss))
# def __extract_json(s):
# i = s.index("{")
# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
# for j, c in enumerate(s[i + 1 :], start=i + 1):
# if c == "}":
# count -= 1
# elif c == "{":
# count += 1
# if count == 0:
# break
# assert count == 0 # 检查是否找到最后一个'}'
# return s[i : j + 1]
#
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
# print(__extract_json(ss))
if __name__ == "__main__":
test1 = Test1()
@ -83,4 +83,4 @@ if __name__ == "__main__":
test1.test()
test2.write()
test1.test()
test2.test()
test2.test()

View File

@ -4,9 +4,9 @@ from test_cls_base import TestBase
class Test1(TestBase):
mode:str = "456"
mode: str = "456"
def write(self):
self.test_values.append("x")
self.test_values.append("y")
self.test_values.append("g")

View File

@ -3,13 +3,15 @@ from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class Test2(TestBase):
    """Sample ``TestBase`` subclass with an extra list field and its own ``mode``."""

    test_2_values: List = []
    mode: str = "789"

    def write(self):
        """Append 1, 2, 3 to ``test_values`` and "x", "y", "z" to ``test_2_values``."""
        self.test_values.append(1)
        self.test_values.append(2)
        self.test_values.append(3)
        self.test_2_values.append("x")
        self.test_2_values.append("y")
        self.test_2_values.append("z")

View File

@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class TestBase(BaseModel, ABC):
    """Base model for the sample classes; holds a value list and a mode string."""

    test_values: List = []
    mode: str = "123"

    def test(self):
        """Print the concrete class name, then ``test_values`` and ``mode``."""
        print(self.__class__.__name__ + ":")
        print(self.test_values)
        print(self.mode)

View File

@ -39,7 +39,9 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
    """Build the embedding client for this knowledge source.

    Passes the instance's ``knowledge_type``, ``knowledge_source`` and
    ``vector_store_config`` to ``get_knowledge_embedding`` and returns
    whatever that call returns.
    """
    return get_knowledge_embedding(
        self.knowledge_type, self.knowledge_source, self.vector_store_config
    )
def similar_search(self, text, topk):
vector_client = VectorStoreConnector(
@ -56,3 +58,9 @@ class KnowledgeEmbedding:
CFG.VECTOR_STORE_TYPE, self.vector_store_config
)
return vector_client.vector_name_exists()
def delete_by_ids(self, ids):
    """Ask the configured vector store to delete the entries named by ``ids``.

    A new ``VectorStoreConnector`` is created for ``CFG.VECTOR_STORE_TYPE``
    with this instance's ``vector_store_config``; nothing is returned.
    """
    vector_client = VectorStoreConnector(
        CFG.VECTOR_STORE_TYPE, self.vector_store_config
    )
    vector_client.delete_by_ids(ids=ids)

View File

@ -33,8 +33,6 @@ class BaseChatHistoryMemory(ABC):
def clear(self) -> None:
"""Clear session memory from the local file"""
def conv_list(self, user_name: str = None) -> None:
    """Get the user's conversation list.

    The base implementation has no body and returns ``None``.
    """
    pass

View File

@ -14,13 +14,12 @@ from pilot.common.formatting import MyEncoder
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
table_name = 'chat_history'
table_name = "chat_history"
CFG = Config()
class DuckdbHistoryMemory(BaseChatHistoryMemory):
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
os.makedirs(default_db_path, exist_ok=True)
@ -28,15 +27,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
self.__init_chat_history_tables()
def __init_chat_history_tables(self):
    """Create the ``chat_history`` table if the catalog query does not list it."""
    # Check whether the table already exists.
    result = self.connect.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
    ).fetchall()

    if not result:
        # The table is missing, so create it.
        self.connect.execute(
            "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
        )
def __get_messages_by_conv_uid(self, conv_uid: str):
cursor = self.connect.cursor()
@ -58,23 +58,46 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
conversations.append(once_message)
cursor = self.connect.cursor()
if context:
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id])
cursor.execute(
"UPDATE chat_history set messages=? where conv_uid=?",
[
json.dumps(
conversations_to_dict(conversations),
ensure_ascii=False,
indent=4,
),
self.chat_seesion_id,
],
)
else:
cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)])
cursor.execute(
"INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
[
self.chat_seesion_id,
"",
json.dumps(
conversations_to_dict(conversations),
ensure_ascii=False,
indent=4,
),
],
)
cursor.commit()
self.connect.commit()
def clear(self) -> None:
    """Delete this session's row from ``chat_history`` and commit.

    Commits on both the cursor and the owning connection.
    """
    cursor = self.connect.cursor()
    cursor.execute(
        "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
    )
    cursor.commit()
    self.connect.commit()
def delete(self) -> bool:
    """Delete this session's row from ``chat_history``.

    Returns:
        Always ``True``; the number of rows removed is not checked.
    """
    cursor = self.connect.cursor()
    cursor.execute(
        "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
    )
    cursor.commit()
    return True
@ -83,7 +106,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
if user_name:
cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
cursor.execute(
"SELECT * FROM chat_history where user_name=? limit 20", [user_name]
)
else:
cursor.execute("SELECT * FROM chat_history limit 20")
# 获取查询结果字段名
@ -99,10 +124,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return []
def get_messages(self)-> List[OnceConversation]:
def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor()
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
cursor.execute(
"SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
context = cursor.fetchone()
if context:
return json.loads(context[0])

View File

@ -11,7 +11,7 @@ from pilot.scene.message import (
conversation_from_dict,
conversations_to_dict,
)
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
CFG = Config()
@ -19,7 +19,6 @@ CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
histroies_map = FixedSizeDict(100)
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
self.histroies_map.update({chat_session_id: []})

View File

@ -1,4 +1,4 @@
from .base import Cache
from .disk_cache import DiskCache
from .memory_cache import InMemoryCache
from .gpt_cache import GPTCache
from .gpt_cache import GPTCache

View File

@ -3,8 +3,8 @@ import hashlib
from typing import Any, Dict
from abc import ABC, abstractmethod
class Cache(ABC):
def create(self, key: str) -> bool:
pass
@ -24,4 +24,4 @@ class Cache(ABC):
@abstractmethod
def __contains__(self, key: str) -> bool:
"""see if we can return a cached value for the passed key"""
pass
pass

View File

@ -3,15 +3,15 @@ import diskcache
import platformdirs
from pilot.model.cache import Cache
class DiskCache(Cache):
"""DiskCache is a cache that uses diskcache lib.
https://github.com/grantjenks/python-diskcache
https://github.com/grantjenks/python-diskcache
"""
def __init__(self, llm_name: str):
self._diskcache = diskcache.Cache(
os.path.join(
platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache"
)
os.path.join(platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache")
)
def __getitem__(self, key: str) -> str:
@ -22,6 +22,6 @@ class DiskCache(Cache):
def __contains__(self, key: str) -> bool:
return key in self._diskcache
def clear(self):
self._diskcache.clear()
self._diskcache.clear()

View File

@ -9,22 +9,23 @@ try:
except ImportError:
pass
class GPTCache(Cache):
"""
GPTCache is a semantic cache that uses
"""
GPTCache is a semantic cache that uses
"""
def __init__(self, cache) -> None:
"""GPT Cache is a semantic cache that uses GPTCache lib."""
if isinstance(cache, str):
_cache = Cache()
init_similar_cache(
data_dir=os.path.join(
platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache"
),
cache_obj=_cache
cache_obj=_cache,
)
else:
_cache = cache
@ -41,4 +42,4 @@ class GPTCache(Cache):
return get(key) is not None
def create(self, llm: str, **kwargs: Dict[str, Any]) -> str:
pass
pass

View File

@ -1,24 +1,23 @@
from typing import Dict, Any
from pilot.model.cache import Cache
class InMemoryCache(Cache):
    """Cache implementation backed by a plain in-process dictionary."""

    def __init__(self) -> None:
        """Initialize the dictionary that stores cached values in memory."""
        self._cache: Dict[str, Any] = {}

    def create(self, key: str) -> bool:
        """Placeholder: currently does nothing and returns ``None``."""
        pass

    def clear(self):
        """Remove every cached entry."""
        return self._cache.clear()

    def __setitem__(self, key: str, value: str) -> None:
        """Store ``value`` under ``key``, replacing any previous value."""
        self._cache[key] = value

    def __getitem__(self, key: str) -> str:
        """Return the value stored under ``key``; raises ``KeyError`` if absent."""
        return self._cache[key]

    def __contains__(self, key: str) -> bool:
        """Report whether a non-``None`` value is stored under ``key``."""
        # A key explicitly stored with the value None is reported as absent.
        return self._cache.get(key, None) is not None

View File

@ -12,16 +12,21 @@ from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
from pilot.server.api_v1.api_view_model import (
Result,
ConversationVo,
MessageVo,
ChatSceneVo,
)
from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR)
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage)
from pilot.scene.base_message import BaseMessage
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
@ -40,19 +45,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
def __get_conv_user_message(conversations: dict):
    """Return the content of the first "human" message in a conversation.

    Args:
        conversations: mapping with a "messages" list whose items carry a
            "type" and, for human messages, a nested "data" -> "content".

    Returns:
        The first human message's content, or "" when there is none.
    """
    for item in conversations["messages"]:
        if item["type"] == "human":
            return item["data"]["content"]
    return ""
@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
# 设置CORS头部信息
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Methods'] = 'GET'
response.headers['Access-Control-Request-Headers'] = 'content-type'
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET"
response.headers["Access-Control-Request-Headers"] = "content-type"
dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id)
@ -63,26 +68,44 @@ async def dialogue_list(response: Response, user_id: str = None):
conversations = json.loads(messages)
first_conv: OnceConversation = conversations[0]
conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv),
chat_mode=first_conv['chat_mode'])
conv_vo: ConversationVo = ConversationVo(
conv_uid=conv_uid,
user_input=__get_conv_user_message(first_conv),
chat_mode=first_conv["chat_mode"],
)
dialogues.append(conv_vo)
return Result[ConversationVo].succ(dialogues)
@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
    """List the selectable chat scenes as ``ChatSceneVo`` entries."""
    new_modes: List[ChatScene] = [
        ChatScene.ChatDb,
        ChatScene.ChatData,
        ChatScene.ChatDashboard,
        ChatScene.ChatKnowledge,
        ChatScene.ChatExecution,
    ]
    # Scene values that are never offered for selection.
    hidden_values = [
        ChatScene.ChatNormal.value,
        ChatScene.InnerChatDBSummary.value,
    ]
    scene_vos: List[ChatSceneVo] = [
        ChatSceneVo(
            chat_scene=scene.value,
            scene_name=scene.name,
            param_title="Selection Param",
        )
        for scene in new_modes
        if scene.value not in hidden_values
    ]
    return Result.succ(scene_vos)
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
    """Start a new dialogue and return its freshly generated conversation id.

    ``user_id`` is accepted but not used by this handler.
    """
    unique_id = uuid.uuid1()
    return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@ -90,7 +113,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str
def get_db_list():
    """Return a dict mapping each local database name to itself."""
    db = CFG.local_db
    return {name: name for name in db.get_database_list()}
@ -108,7 +131,7 @@ def knowledge_list():
return knowledge_service.get_knowledge_space(request)
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatDb.value == chat_mode:
return Result.succ(get_db_list())
@ -124,14 +147,14 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
return Result.succ(None)
@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
    """Delete the stored history for conversation ``con_uid``."""
    history_mem = DuckdbHistoryMemory(con_uid)
    history_mem.delete()
    return Result.succ(None)
@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo])
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str):
print(f"dialogue_history_messages:{con_uid}")
message_vos: List[MessageVo] = []
@ -140,17 +163,21 @@ async def dialogue_history_messages(con_uid: str):
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']]
once_message_vos = [
message2Vo(element, once["chat_order"]) for element in once["messages"]
]
message_vos.extend(once_message_vos)
return Result.succ(message_vos)
@router.post('/v1/chat/completions')
@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
if not ChatScene.is_valid_mode(dialogue.chat_mode):
raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
raise StopAsyncIteration(
Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")
)
chat_param = {
"chat_session_id": dialogue.conv_uid,
@ -188,7 +215,9 @@ def stream_generator(chat):
model_response = chat.stream_call()
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
chat.current_message.add_ai_message(msg)
yield msg
# chat.current_message.add_ai_message(msg)
@ -206,7 +235,9 @@ def stream_generator(chat):
def message2Vo(message: dict, order) -> MessageVo:
    """Convert a stored message dict into a ``MessageVo``.

    The message's "type" becomes the role and its nested "data" -> "content"
    becomes the context; ``order`` is passed through unchanged.
    """
    # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
    return MessageVo(
        role=message["type"], context=message["data"]["content"], order=order
    )
def non_stream_response(chat):
@ -214,38 +245,38 @@ def non_stream_response(chat):
return chat.nostream_call()
@router.get("/v1/db/types", response_model=Result[str])
async def db_types():
    """Return the fixed list of database types: "mysql" and "duckdb"."""
    # NOTE(review): response_model is Result[str] but a list is returned - confirm.
    return Result.succ(["mysql", "duckdb"])
@router.get("/v1/db/list", response_model=Result[str])
async def db_list():
    """Return the database list reported by the configured local database."""
    db = CFG.local_db
    dbs = db.get_database_list()
    return Result.succ(dbs)
@router.get("/v1/knowledge/list")
async def knowledge_list():
    """Stub endpoint: returns a fixed two-item placeholder list."""
    # NOTE(review): a synchronous `knowledge_list()` helper also appears in this
    # module; this later definition would rebind that name - confirm.
    return ["test1", "test2"]
@router.post("/v1/knowledge/add")
async def knowledge_add():
    """Stub endpoint: returns a fixed two-item placeholder list."""
    return ["test1", "test2"]
@router.post("/v1/knowledge/delete")
async def knowledge_delete():
    """Stub endpoint: returns a fixed two-item placeholder list."""
    return ["test1", "test2"]
@router.get("/v1/knowledge/types")
async def knowledge_types():
    """Stub endpoint: returns a fixed two-item placeholder list."""
    return ["test1", "test2"]
@router.get("/v1/knowledge/detail")
async def knowledge_detail():
    """Stub endpoint: returns a fixed two-item placeholder list."""
    return ["test1", "test2"]

View File

@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
T = TypeVar('T')
T = TypeVar("T")
class Result(Generic[T], BaseModel):
@ -24,15 +24,17 @@ class Result(Generic[T], BaseModel):
class ChatSceneVo(BaseModel):
chat_scene: str = Field(..., description="chat_scene")
scene_name: str = Field(..., description="chat_scene name show for user")
param_title: str = Field(..., description="chat_scene required parameter title")
chat_scene: str = Field(..., description="chat_scene")
scene_name: str = Field(..., description="chat_scene name show for user")
param_title: str = Field(..., description="chat_scene required parameter title")
class ConversationVo(BaseModel):
"""
dialogue_uid
"""
conv_uid: str = Field(..., description="dialogue uid")
conv_uid: str = Field(..., description="dialogue uid")
"""
user input
"""
@ -44,7 +46,7 @@ class ConversationVo(BaseModel):
"""
the scene of chat
"""
chat_mode: str = Field(..., description="the scene of chat ")
chat_mode: str = Field(..., description="the scene of chat ")
"""
chat scene select param
@ -52,11 +54,11 @@ class ConversationVo(BaseModel):
select_param: str = None
class MessageVo(BaseModel):
"""
role that sends out the current message
role that sends out the current message
"""
role: str
"""
current message
@ -70,4 +72,3 @@ class MessageVo(BaseModel):
time the current message was sent
"""
time_stamp: Any = None

View File

@ -10,8 +10,10 @@ from pilot.configs.config import Config
CFG = Config()
Base = declarative_base()
class DocumentChunkEntity(Base):
__tablename__ = 'document_chunk'
__tablename__ = "document_chunk"
id = Column(Integer, primary_key=True)
document_id = Column(Integer)
doc_name = Column(String(100))
@ -29,43 +31,55 @@ class DocumentChunkDao:
def __init__(self):
    """Create the MySQL engine and session factory for ``knowledge_management``.

    Connection user, password, host and port are read from ``CFG``.
    """
    database = "knowledge_management"
    self.db_engine = create_engine(
        f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
        echo=True,  # engine logs every SQL statement it emits
    )
    self.Session = sessionmaker(bind=self.db_engine)
def create_documents_chunks(self, documents: List):
    """Insert one ``DocumentChunkEntity`` row per item in ``documents``.

    Missing ``content`` / ``meta_info`` values are stored as empty strings and
    both timestamps are set to the current time.
    """
    session = self.Session()
    try:
        docs = [
            DocumentChunkEntity(
                doc_name=document.doc_name,
                doc_type=document.doc_type,
                document_id=document.document_id,
                content=document.content or "",
                meta_info=document.meta_info or "",
                gmt_created=datetime.now(),
                gmt_modified=datetime.now(),
            )
            for document in documents
        ]
        session.add_all(docs)
        session.commit()
    finally:
        # Release the connection even when add_all/commit raises.
        session.close()
def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
    """Return one page of chunks matching the non-``None`` fields of ``query``.

    Args:
        query: entity whose ``id``, ``document_id``, ``doc_type``, ``doc_name``
            and ``meta_info`` values, when not ``None``, become equality filters.
        page: 1-based page number.
        page_size: maximum number of rows to return.

    Returns:
        Matching ``DocumentChunkEntity`` rows, newest id first.
    """
    session = self.Session()
    try:
        document_chunks = session.query(DocumentChunkEntity)
        for column in ("id", "document_id", "doc_type", "doc_name", "meta_info"):
            value = getattr(query, column)
            if value is not None:
                document_chunks = document_chunks.filter(
                    getattr(DocumentChunkEntity, column) == value
                )

        document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
        document_chunks = document_chunks.offset((page - 1) * page_size).limit(
            page_size
        )
        return document_chunks.all()
    finally:
        # The session was previously never closed; rows are already loaded here.
        session.close()

View File

@ -13,7 +13,11 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeQueryRequest,
KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
KnowledgeQueryResponse,
KnowledgeDocumentRequest,
DocumentSyncRequest,
ChunkQueryRequest,
DocumentQueryRequest,
)
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@ -62,16 +66,15 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
def document_list(space_name: str, query_request: DocumentQueryRequest):
    """Return the knowledge documents of ``space_name`` matching ``query_request``.

    Any exception is converted into a failed ``Result`` with code "E000X".
    """
    print(f"/document/list params: {space_name}, {query_request}")
    try:
        return Result.succ(
            knowledge_space_service.get_knowledge_documents(space_name, query_request)
        )
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document list error {e}")
@router.post("/knowledge/{space_name}/document/upload")
def document_sync(space_name: str, file: UploadFile = File(...)):
async def document_sync(space_name: str, file: UploadFile = File(...)):
print(f"/document/upload params: {space_name}")
try:
with NamedTemporaryFile(delete=False) as tmp:
@ -92,7 +95,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
knowledge_space_service.sync_knowledge_document(
space_name=space_name, doc_ids=request.doc_ids
)
Result.succ([])
return Result.succ([])
except Exception as e:
return Result.faild(code="E000X", msg=f"document sync error {e}")
@ -101,9 +104,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
def document_list(space_name: str, query_request: ChunkQueryRequest):
    """Return the document chunks matching ``query_request``.

    ``space_name`` is only logged. Any exception is converted into a failed
    ``Result`` with code "E000X".
    """
    print(f"/document/list params: {space_name}, {query_request}")
    try:
        return Result.succ(knowledge_space_service.get_document_chunks(query_request))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document chunk list error {e}")

View File

@ -12,7 +12,7 @@ Base = declarative_base()
class KnowledgeDocumentEntity(Base):
__tablename__ = 'knowledge_document'
__tablename__ = "knowledge_document"
id = Column(Integer, primary_key=True)
doc_name = Column(String(100))
doc_type = Column(String(100))
@ -21,23 +21,25 @@ class KnowledgeDocumentEntity(Base):
status = Column(String(100))
last_sync = Column(String(100))
content = Column(Text)
result = Column(Text)
vector_ids = Column(Text)
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class KnowledgeDocumentDao:
def __init__(self):
    """Create the MySQL engine and session factory for ``knowledge_management``.

    Connection user, password, host and port are read from ``CFG``.
    """
    database = "knowledge_management"
    self.db_engine = create_engine(
        f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
        echo=True,  # engine logs every SQL statement it emits
    )
    self.Session = sessionmaker(bind=self.db_engine)
def create_knowledge_document(self, document:KnowledgeDocumentEntity):
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session()
knowledge_document = KnowledgeDocumentEntity(
doc_name=document.doc_name,
@ -47,9 +49,10 @@ class KnowledgeDocumentDao:
status=document.status,
last_sync=document.last_sync,
content=document.content or "",
result=document.result or "",
vector_ids=document.vector_ids,
gmt_created=datetime.now(),
gmt_modified=datetime.now()
gmt_modified=datetime.now(),
)
session.add(knowledge_document)
session.commit()
@ -60,28 +63,42 @@ class KnowledgeDocumentDao:
session = self.Session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.id == query.id
)
if query.doc_name is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_name == query.doc_name
)
if query.doc_type is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_type == query.doc_type
)
if query.space is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.space == query.space
)
if query.status is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.status == query.status
)
knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc())
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size)
knowledge_documents = knowledge_documents.order_by(
KnowledgeDocumentEntity.id.desc()
)
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(
page_size
)
result = knowledge_documents.all()
return result
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
    """Merge ``document`` into a fresh session, commit it and return its id.

    Args:
        document: entity whose state should be written to the database.

    Returns:
        The ``id`` of the merged entity.

    Raises:
        Whatever the session raises while merging or committing; the
        transaction is rolled back before the error propagates.
    """
    session = self.Session()
    try:
        merged = session.merge(document)
        session.commit()
        # Read the id while the session is still open; the value is
        # evaluated before the ``finally`` clause closes the session.
        return merged.id
    except Exception:
        session.rollback()
        raise
    finally:
        # Always release the session; previously it was never closed.
        session.close()
def delete_knowledge_document(self, document_id:int):
def delete_knowledge_document(self, document_id: int):
cursor = self.conn.cursor()
query = "DELETE FROM knowledge_document WHERE id = %s"
cursor.execute(query, (document_id,))

View File

@ -4,17 +4,24 @@ from datetime import datetime
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.logs import logger
from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
from pilot.openapi.knowledge.document_chunk_dao import (
DocumentChunkEntity,
DocumentChunkDao,
)
from pilot.openapi.knowledge.knowledge_document_dao import (
KnowledgeDocumentDao,
KnowledgeDocumentEntity,
)
from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
from pilot.openapi.knowledge.knowledge_space_dao import (
KnowledgeSpaceDao,
KnowledgeSpaceEntity,
)
from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeSpaceRequest,
KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
KnowledgeDocumentRequest,
DocumentQueryRequest,
ChunkQueryRequest,
)
from enum import Enum
@ -23,7 +30,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
document_chunk_dao = DocumentChunkDao()
CFG=Config()
CFG = Config()
class SyncStatus(Enum):
@ -53,10 +60,7 @@ class KnowledgeService:
"""create knowledge document"""
def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
query = KnowledgeDocumentEntity(
doc_name=request.doc_name,
space=space
)
query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
documents = knowledge_document_dao.get_knowledge_documents(query)
if len(documents) > 0:
raise Exception(f"document name:{request.doc_name} have already named")
@ -74,26 +78,27 @@ class KnowledgeService:
"""get knowledge space"""
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
    """Look up knowledge spaces by the request's name, vector type and owner.

    The three request fields are copied into a ``KnowledgeSpaceEntity`` that
    serves as the query object, and the lookup is delegated to
    ``knowledge_space_dao.get_knowledge_space``, whose result is returned.
    """
    query = KnowledgeSpaceEntity(
        name=request.name, vector_type=request.vector_type, owner=request.owner
    )
    return knowledge_space_dao.get_knowledge_space(query)
"""get knowledge get_knowledge_documents"""
def get_knowledge_documents(self, space, request: DocumentQueryRequest):
    """Return one page of the documents in ``space`` that match ``request``.

    Args:
        space: value placed in the query entity's ``space`` field.
        request: supplies ``doc_name``, ``doc_type`` and ``status`` for the
            query entity, plus ``page`` and ``page_size`` for paging.

    Returns:
        Whatever ``knowledge_document_dao.get_knowledge_documents`` returns.
    """
    query = KnowledgeDocumentEntity(
        doc_name=request.doc_name,
        doc_type=request.doc_type,
        space=space,
        status=request.status,
    )
    return knowledge_document_dao.get_knowledge_documents(
        query, page=request.page, page_size=request.page_size
    )
"""sync knowledge document chunk into vector store"""
def sync_knowledge_document(self, space_name, doc_ids):
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(
@ -101,12 +106,14 @@ class KnowledgeService:
space=space_name,
)
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
client = KnowledgeEmbedding(knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={
"vector_store_name": space_name,
})
client = KnowledgeEmbedding(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={
"vector_store_name": space_name,
},
)
chunk_docs = client.read()
# update document status
doc.status = SyncStatus.RUNNING.name
@ -114,9 +121,12 @@ class KnowledgeService:
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
# async doc embeddings
thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc))
thread = threading.Thread(
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
)
thread.start()
#save chunk details
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [
DocumentChunkEntity(
doc_name=doc.doc_name,
@ -125,9 +135,10 @@ class KnowledgeService:
content=chunk_doc.page_content,
meta_info=str(chunk_doc.metadata),
gmt_created=datetime.now(),
gmt_modified=datetime.now()
gmt_modified=datetime.now(),
)
for chunk_doc in chunk_docs]
for chunk_doc in chunk_docs
]
document_chunk_dao.create_documents_chunks(chunk_entities)
return True
@ -145,26 +156,30 @@ class KnowledgeService:
return knowledge_space_dao.delete_knowledge_space(space_id)
"""get document chunks"""
def get_document_chunks(self, request: ChunkQueryRequest):
    """Return one page of document chunks that match ``request``.

    ``id``, ``document_id``, ``doc_name`` and ``doc_type`` are copied from
    the request into a ``DocumentChunkEntity`` query object; ``page`` and
    ``page_size`` are forwarded to ``document_chunk_dao.get_document_chunks``.
    """
    query = DocumentChunkEntity(
        id=request.id,
        document_id=request.document_id,
        doc_name=request.doc_name,
        doc_type=request.doc_type,
    )
    return document_chunk_dao.get_document_chunks(
        query, page=request.page, page_size=request.page_size
    )
def async_doc_embedding(self, client, chunk_docs, doc):
    """Embed ``chunk_docs`` through ``client`` and record the outcome on ``doc``.

    On success ``doc.status`` becomes ``FINISHED``, ``doc.result`` a success
    message and ``doc.vector_ids`` the comma-joined ids returned by the
    client. On any exception ``doc.status`` becomes ``FAILED`` and
    ``doc.result`` carries the error text. In both cases the document is
    persisted with ``knowledge_document_dao.update_knowledge_document``.

    Returns:
        The value returned by ``update_knowledge_document``.
    """
    logger.info(
        f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
    )
    try:
        vector_ids = client.knowledge_embedding_batch(chunk_docs)
        doc.status = SyncStatus.FINISHED.name
        # The outcome goes into ``result`` only; ``content`` is left alone so
        # the document's source text is not replaced by a status message.
        doc.result = "document embedding success"
        doc.vector_ids = ",".join(vector_ids)
        logger.info(f"async document embedding, success:{doc.doc_name}")
    except Exception as e:
        # Deliberately broad: any failure is recorded on the document
        # instead of escaping from the worker thread.
        doc.status = SyncStatus.FAILED.name
        doc.result = "document embedding failed" + str(e)
        logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
    return knowledge_document_dao.update_knowledge_document(doc)

View File

@ -10,8 +10,10 @@ from sqlalchemy.orm import sessionmaker
CFG = Config()
Base = declarative_base()
class KnowledgeSpaceEntity(Base):
__tablename__ = 'knowledge_space'
__tablename__ = "knowledge_space"
id = Column(Integer, primary_key=True)
name = Column(String(100))
vector_type = Column(String(100))
@ -27,10 +29,13 @@ class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceDao:
def __init__(self):
    """Create the MySQL engine and the session factory used by this DAO.

    Connection settings are read from the module-level ``CFG`` object
    (``LOCAL_DB_USER``, ``LOCAL_DB_PASSWORD``, ``LOCAL_DB_HOST`` and
    ``LOCAL_DB_PORT``); the schema name is fixed to ``knowledge_management``.
    """
    from urllib.parse import quote_plus

    database = "knowledge_management"
    # Percent-encode the credentials: a raw "@", ":" or "/" in the user name
    # or password would otherwise be read as part of the URL structure.
    user = quote_plus(str(CFG.LOCAL_DB_USER))
    password = quote_plus(str(CFG.LOCAL_DB_PASSWORD))
    self.db_engine = create_engine(
        f"mysql+pymysql://{user}:{password}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
        echo=True,  # the engine logs every SQL statement it emits
    )
    self.Session = sessionmaker(bind=self.db_engine)
def create_knowledge_space(self, space:KnowledgeSpaceRequest):
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.Session()
knowledge_space = KnowledgeSpaceEntity(
name=space.name,
@ -38,43 +43,61 @@ class KnowledgeSpaceDao:
desc=space.desc,
owner=space.owner,
gmt_created=datetime.now(),
gmt_modified=datetime.now()
gmt_modified=datetime.now(),
)
session.add(knowledge_space)
session.commit()
session.close()
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
    """Return the knowledge spaces matching every non-``None`` field of ``query``.

    Args:
        query: entity used as a filter template. Each of ``id``, ``name``,
            ``vector_type``, ``desc``, ``owner``, ``gmt_created`` and
            ``gmt_modified`` that is not ``None`` adds an equality filter.

    Returns:
        The matching ``KnowledgeSpaceEntity`` rows, newest ``gmt_created``
        first.
    """
    filterable_fields = (
        "id",
        "name",
        "vector_type",
        "desc",
        "owner",
        "gmt_created",
        "gmt_modified",
    )
    session = self.Session()
    try:
        knowledge_spaces = session.query(KnowledgeSpaceEntity)
        for field in filterable_fields:
            wanted = getattr(query, field)
            if wanted is not None:
                knowledge_spaces = knowledge_spaces.filter(
                    getattr(KnowledgeSpaceEntity, field) == wanted
                )
        knowledge_spaces = knowledge_spaces.order_by(
            KnowledgeSpaceEntity.gmt_created.desc()
        )
        return knowledge_spaces.all()
    finally:
        # Release the session once the rows are loaded, as
        # create_knowledge_space does; it was previously left open.
        session.close()
def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
    """Overwrite name, vector_type, desc and owner of the space ``space_id``.

    Args:
        space_id: primary key of the row to update.
        space: entity carrying the new column values.

    Raises:
        Whatever the session raises while updating or committing; the
        transaction is rolled back before the error propagates.
    """
    # Goes through the ORM session: the constructor only creates
    # ``db_engine`` and ``Session``, so the raw ``self.conn`` cursor used
    # before does not exist. Letting SQLAlchemy build the statement also
    # quotes the reserved column name ``desc``.
    session = self.Session()
    try:
        session.query(KnowledgeSpaceEntity).filter(
            KnowledgeSpaceEntity.id == space_id
        ).update(
            {
                "name": space.name,
                "vector_type": space.vector_type,
                "desc": space.desc,
                "owner": space.owner,
            }
        )
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
def delete_knowledge_space(self, space_id:int):
def delete_knowledge_space(self, space_id: int):
cursor = self.conn.cursor()
query = "DELETE FROM knowledge_space WHERE id = %s"
cursor.execute(query, (space_id,))

View File

@ -30,17 +30,19 @@ class KnowledgeDocumentRequest(BaseModel):
"""doc_type: doc type"""
doc_type: str
"""content: content"""
content: str
content: str = None
"""text_chunk_size: text_chunk_size"""
# text_chunk_size: int
class DocumentQueryRequest(BaseModel):
"""doc_name: doc path"""
doc_name: str = None
"""doc_type: doc type"""
doc_type: str= None
doc_type: str = None
"""status: status"""
status: str= None
status: str = None
"""page: page"""
page: int = 1
"""page_size: page size"""
@ -49,10 +51,13 @@ class DocumentQueryRequest(BaseModel):
class DocumentSyncRequest(BaseModel):
    """doc_ids: doc ids"""

    # Untyped list; the element type is not constrained by this model.
    doc_ids: List
class ChunkQueryRequest(BaseModel):
"""id: id"""
id: int = None
"""document_id: doc id"""
document_id: int = None

View File

@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pilot.common.schema import ExampleType
class ExampleSelector(BaseModel, ABC):
examples: List[List]
use_example: bool = False

View File

@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
from pilot.prompts.example_base import ExampleSelector
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
try:

View File

@ -1,11 +1,13 @@
from enum import Enum
class Scene:
    """Plain holder for a scene's code, description and inner flag."""

    def __init__(self, code, describe, is_inner):
        """Store the three values unchanged as attributes of the same name."""
        self.code = code
        self.describe = describe
        self.is_inner = is_inner
class ChatScene(Enum):
ChatWithDbExecute = "chat_with_db_execute"
ChatWithDbQA = "chat_with_db_qa"
@ -19,9 +21,8 @@ class ChatScene(Enum):
ChatDashboard = "chat_dashboard"
ChatKnowledge = "chat_knowledge"
ChatDb = "chat_db"
ChatData= "chat_data"
ChatData = "chat_data"
@staticmethod
def is_valid_mode(mode):
    """Return True when ``mode`` equals the value of some ``ChatScene`` member."""
    return any(mode == scene.value for scene in ChatScene)

View File

@ -104,7 +104,9 @@ class BaseChat(ABC):
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
# TODO
self.current_message.tokens = 0
current_prompt = None
@ -200,8 +202,11 @@ class BaseChat(ABC):
# }"""
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.do_action(prompt_define_response)

View File

@ -12,7 +12,10 @@ from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
)
CFG = Config()
@ -22,9 +25,7 @@ class ChatDashboard(BaseChat):
report_name: str
"""Number of results to return from the query"""
def __init__(
self, chat_session_id, db_name, user_input, report_name
):
def __init__(self, chat_session_id, db_name, user_input, report_name):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbExecute,
@ -51,7 +52,7 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input,
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect),
"supported_chat_type": "" #TODO
"supported_chat_type": "" # TODO
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
return input_values
@ -68,7 +69,6 @@ class ChatDashboard(BaseChat):
# TODO: fix this flow
print(str(e))
chart_datas.append(chart_data)
report_data.conv_uid = self.chat_session_id
@ -77,5 +77,3 @@ class ChatDashboard(BaseChat):
report_data.charts = chart_datas
return report_data

View File

@ -12,11 +12,7 @@ class ChartData(BaseModel):
class ReportData(BaseModel):
    """Report payload: conversation uid, template name and introduction, charts."""

    conv_uid: str
    template_name: str
    template_introduce: str
    charts: List[ChartData]

View File

@ -10,9 +10,9 @@ from pilot.configs.model_config import LOGDIR
class ChartItem(NamedTuple):
    """One chart entry with its SQL, title, thoughts and showcase values."""

    sql: str
    title: str
    thoughts: str
    showcase: str
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
@ -28,7 +28,11 @@ class ChatDashboardOutputParser(BaseOutputParser):
response = json.loads(clean_str)
chart_items = List[ChartItem]
for item in response:
chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"]))
chart_items.append(
ChartItem(
item["sql"], item["title"], item["thoughts"], item["showcase"]
)
)
return chart_items
def parse_view_response(self, speak, data) -> str:

View File

@ -24,12 +24,14 @@ give {dialect} data analysis SQL, analysis title, display method and analytical
Ensure the response is correct json and can be parsed by Python json.loads
"""
# Shape of the JSON reply: a list of objects with these four keys.
RESPONSE_FORMAT = [
    {
        "sql": "data analysis SQL",
        "title": "Data Analysis Title",
        "showcase": "What type of charts to show",
        "thoughts": "Current thinking and value of data analysis",
    }
]
PROMPT_SEP = SeparatorStyle.SINGLE.value

View File

@ -21,9 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query"""
def __init__(
self, chat_session_id, db_name, user_input
):
def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbExecute,

View File

@ -19,9 +19,7 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query"""
def __init__(
self, chat_session_id, db_name, user_input
):
def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbQA,
@ -63,5 +61,3 @@ class ChatWithDbQA(BaseChat):
"table_info": table_info,
}
return input_values

View File

@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector
## Two examples are defined by default
# Exactly two example conversations; each is a list of single-key
# role -> text dicts.
EXAMPLES = [
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
]
example = ExampleSelector(examples=EXAMPLES, use_example=True)

View File

@ -36,7 +36,6 @@ RESPONSE_FORMAT = {
}
EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
@ -49,8 +48,10 @@ prompt = PromptTemplate(
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
example=example
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
example=example,
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -27,9 +27,7 @@ class ChatNewKnowledge(BaseChat):
"""Number of results to return from the query"""
def __init__(
self, chat_session_id, user_input, knowledge_name
):
def __init__(self, chat_session_id, user_input, knowledge_name):
""" """
super().__init__(
chat_mode=ChatScene.ChatNewKnowledge,
@ -56,7 +54,6 @@ class ChatNewKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input}
return input_values
@property
def chat_type(self) -> str:
    """Return ``ChatScene.ChatNewKnowledge.value``, the scene code of this chat."""
    return ChatScene.ChatNewKnowledge.value

View File

@ -29,7 +29,7 @@ class ChatDefaultKnowledge(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input):
def __init__(self, chat_session_id, user_input):
""" """
super().__init__(
chat_mode=ChatScene.ChatDefaultKnowledge,
@ -59,8 +59,6 @@ class ChatDefaultKnowledge(BaseChat):
)
return input_values
@property
def chat_type(self) -> str:
    """Return ``ChatScene.ChatDefaultKnowledge.value``, the scene code of this chat."""
    return ChatScene.ChatDefaultKnowledge.value

View File

@ -36,7 +36,6 @@ class InnerChatDBSummary(BaseChat):
}
return input_values
@property
def chat_type(self) -> str:
    """Return ``ChatScene.InnerChatDBSummary.value``, the scene code of this chat."""
    return ChatScene.InnerChatDBSummary.value

View File

@ -61,7 +61,6 @@ class ChatUrlKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input}
return input_values
@property
def chat_type(self) -> str:
    """Return ``ChatScene.ChatUrlKnowledge.value``, the scene code of this chat."""
    return ChatScene.ChatUrlKnowledge.value

View File

@ -29,7 +29,7 @@ class ChatKnowledge(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, knowledge_space):
def __init__(self, chat_session_id, user_input, knowledge_space):
""" """
super().__init__(
chat_mode=ChatScene.ChatKnowledge,
@ -59,8 +59,6 @@ class ChatKnowledge(BaseChat):
)
return input_values
@property
def chat_type(self) -> str:
    """Return ``ChatScene.ChatKnowledge.value``, the scene code of this chat."""
    return ChatScene.ChatKnowledge.value

View File

@ -85,7 +85,6 @@ class OnceConversation:
self.messages.clear()
self.session_id = None
def get_user_message(self):
for once in self.messages:
if isinstance(once, HumanMessage):

View File

@ -13,5 +13,3 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
load_dotenv(verbose=True, override=True)
del load_dotenv

View File

@ -27,7 +27,6 @@ from pilot.server.chat_adapter import get_llm_chat_adapter
CFG = Config()
class ModelWorker:
def __init__(self, model_path, model_name, device, num_gpus=1):
if model_path.endswith("/"):
@ -106,6 +105,7 @@ worker = ModelWorker(
app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router
app.include_router(router)
origins = [
@ -119,9 +119,10 @@ app.add_middleware(
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
allow_headers=["*"],
)
class PromptRequest(BaseModel):
prompt: str
temperature: float

View File

@ -56,7 +56,7 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
@ -107,12 +107,16 @@ knowledge_qa_type_list = [
add_knowledge_base_dialogue,
]
def swagger_monkey_patch(*args, **kwargs):
    """Call ``get_swagger_ui_html`` with the JS and CSS URLs pinned to bootcdn.

    All positional and keyword arguments are forwarded unchanged; only the
    two Swagger UI asset URLs are supplied here.
    """
    return get_swagger_ui_html(
        *args,
        **kwargs,
        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
    )
applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI()
@ -360,14 +364,18 @@ def http_bot(
response = chat.stream_call()
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
state.messages[-1][-1] =msg
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
state.messages[-1][-1] = msg
chat.current_message.add_ai_message(msg)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.memory.append(chat.current_message)
except Exception as e:
print(traceback.format_exc())
state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
state.messages[-1][
-1
] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
@ -693,8 +701,12 @@ def signal_handler(sig, frame):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
parser.add_argument(
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument(
"-new", "--new", action="store_true", help="enable new http mode"
)
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
@ -702,27 +714,24 @@ if __name__ == "__main__":
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
# init server config
args = parser.parse_args()
server_init(args)
if args.new:
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
else:
### Compatibility mode starts the old version server by default
demo = build_webdemo()
demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
concurrency_count=args.concurrency_count,
status_update_rate=10,
api_open=False,
).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)

View File

@ -38,7 +38,9 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
    """Build the embedding client for this instance's file.

    Passes the upper-cased ``file_type``, the ``file_path`` and the
    ``vector_store_config`` to ``get_knowledge_embedding`` and returns its
    result.
    """
    return get_knowledge_embedding(
        self.file_type.upper(), self.file_path, self.vector_store_config
    )
def similar_search(self, text, topk):
vector_client = VectorStoreConnector(