Merge branch 'llm_framework' into dev_ty_06_end

# Conflicts:
#	pilot/openapi/api_v1/api_v1.py
#	pilot/server/dbgpt_server.py
tuyang.yhj 2023-06-30 09:58:32 +08:00
commit 6f8f182d1d
24 changed files with 384 additions and 137 deletions


@@ -0,0 +1,41 @@
CREATE TABLE `knowledge_space` (
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
`name` varchar(100) NOT NULL COMMENT 'knowledge space name',
`vector_type` varchar(50) NOT NULL COMMENT 'vector type',
`desc` varchar(500) NOT NULL COMMENT 'description',
`owner` varchar(100) DEFAULT NULL COMMENT 'owner',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
KEY `idx_name` (`name`) COMMENT 'index:idx_name'
) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge space table';
CREATE TABLE `knowledge_document` (
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
`doc_name` varchar(100) NOT NULL COMMENT 'document path name',
`doc_type` varchar(50) NOT NULL COMMENT 'doc type',
`space` varchar(50) NOT NULL COMMENT 'knowledge space',
`chunk_size` int NOT NULL COMMENT 'chunk size',
`last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
`status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
`content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
`result` TEXT NULL COMMENT 'knowledge content',
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name'
) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document table';
CREATE TABLE `document_chunk` (
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
`doc_name` varchar(100) NOT NULL COMMENT 'document path name',
`doc_type` varchar(50) NOT NULL COMMENT 'doc type',
`document_id` int NOT NULL COMMENT 'document parent id',
`content` longtext NOT NULL COMMENT 'chunk content',
`meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id'
) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document chunk detail'
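
Taken together, the three tables form a simple hierarchy: knowledge_document rows reference their space by name (space), and document_chunk rows reference their parent document by id (document_id). A minimal join sketch, assuming a local MySQL instance loaded with this schema and the pymysql driver (neither ships with the diff):

import pymysql

# count embedded chunks per document in one space (hypothetical connection settings)
conn = pymysql.connect(host="localhost", user="root", password="", database="knowledge")
with conn.cursor() as cur:
    cur.execute(
        "SELECT d.doc_name, d.status, COUNT(c.id) AS chunks "
        "FROM knowledge_document d "
        "LEFT JOIN document_chunk c ON c.document_id = d.id "
        "WHERE d.space = %s GROUP BY d.id, d.doc_name, d.status",
        ("default",),
    )
    for doc_name, status, chunks in cur.fetchall():
        print(doc_name, status, chunks)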


@@ -19,7 +19,7 @@ const AgentPage = (props) => {
});
const { history, handleChatSubmit } = useAgentChat({
queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`,
queryAgentURL: `http://localhost:5000/v1/chat/completions`,
queryBody: {
conv_uid: props.params?.agentId,
chat_mode: props.searchParams?.scene || 'chat_normal',


@@ -16,7 +16,7 @@ const Item = styled(Sheet)(({ theme }) => ({
const Agents = () => {
const { handleChatSubmit, history } = useAgentChat({
queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`,
queryAgentURL: `http://localhost:5000/v1/chat/completions`,
});
const data = [


@@ -6,12 +6,11 @@ from typing import List
import markdown
from bs4 import BeautifulSoup
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter
from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader
from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter
CFG = Config()
@@ -30,12 +29,20 @@ class MarkdownEmbedding(SourceEmbedding):
def read(self):
"""Load from markdown path."""
loader = EncodeTextLoader(self.file_path)
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(textsplitter)
if CFG.LANGUAGE == "en":
text_splitter = CharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(text_splitter)
@register
def data_process(self, documents: List[Document]):
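
The pattern introduced here, and repeated in the PDF, PPT, and Word loaders below, is a language switch over text splitters: plain character chunking with a small overlap for English, spaCy sentence-aware chunking for Chinese. A minimal sketch of that selection as a standalone helper; the build_text_splitter name is hypothetical, while the splitter classes and arguments are the ones used in the diff:

from langchain.text_splitter import CharacterTextSplitter, SpacyTextSplitter

def build_text_splitter(language: str, chunk_size: int):
    if language == "en":
        # English: character-based chunks with a small overlap
        return CharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=20, length_function=len
        )
    # Chinese (default): sentence-aware chunks via the zh_core_web_sm pipeline
    return SpacyTextSplitter(
        pipeline="zh_core_web_sm", chunk_size=chunk_size, chunk_overlap=100
    )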


@@ -4,7 +4,7 @@ from typing import List
from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter
from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register
@@ -28,12 +28,24 @@ class PDFEmbedding(SourceEmbedding):
# textsplitter = CHNDocumentSplitter(
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# )
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(textsplitter)
# textsplitter = SpacyTextSplitter(
# pipeline="zh_core_web_sm",
# chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
# chunk_overlap=100,
# )
if CFG.LANGUAGE == "en":
text_splitter = CharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(text_splitter)
@register
def data_process(self, documents: List[Document]):


@@ -4,10 +4,11 @@ from typing import List
from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter
from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter
CFG = Config()
@@ -25,12 +26,24 @@ class PPTEmbedding(SourceEmbedding):
def read(self):
"""Load from ppt path."""
loader = UnstructuredPowerPointLoader(self.file_path)
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200,
)
return loader.load_and_split(textsplitter)
# textsplitter = SpacyTextSplitter(
# pipeline="zh_core_web_sm",
# chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
# chunk_overlap=200,
# )
if CFG.LANGUAGE == "en":
text_splitter = CharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(text_splitter)
@register
def data_process(self, documents: List[Document]):


@@ -2,12 +2,12 @@
# -*- coding: utf-8 -*-
from typing import List
from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader
from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter, SpacyTextSplitter
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter
CFG = Config()
@@ -25,10 +25,19 @@ class WordEmbedding(SourceEmbedding):
def read(self):
"""Load from word path."""
loader = UnstructuredWordDocumentLoader(self.file_path)
textsplitter = CHNDocumentSplitter(
pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
)
return loader.load_and_split(textsplitter)
if CFG.LANGUAGE == "en":
text_splitter = CharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(text_splitter)
@register
def data_process(self, documents: List[Document]):
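
With a helper like the build_text_splitter sketch above, each of the four read() methods would reduce to the same two lines; a hedged usage sketch for this Word loader:

loader = UnstructuredWordDocumentLoader(self.file_path)
return loader.load_and_split(
    build_text_splitter(CFG.LANGUAGE, CFG.KNOWLEDGE_CHUNK_SIZE)
)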


@@ -15,13 +15,12 @@ from pilot.common.formatting import MyEncoder
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
table_name = 'chat_history'
table_name = "chat_history"
CFG = Config()
class DuckdbHistoryMemory(BaseChatHistoryMemory):
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
os.makedirs(default_db_path, exist_ok=True)
@@ -29,10 +28,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
self.__init_chat_history_tables()
def __init_chat_history_tables(self):
# check whether the table exists
result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
[table_name]).fetchall()
result = self.connect.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
).fetchall()
if not result:
# if the table does not exist, create it
@@ -74,8 +73,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
conversations.append(_conversation_to_dic(once_message))
cursor = self.connect.cursor()
if context:
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
cursor.execute(
"UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id],
)
else:
cursor.execute(
"INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)",
@@ -85,13 +86,17 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
def clear(self) -> None:
cursor = self.connect.cursor()
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
cursor.execute(
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
cursor.commit()
self.connect.commit()
def delete(self) -> bool:
cursor = self.connect.cursor()
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
cursor.execute(
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
cursor.commit()
return True
@@ -100,7 +105,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
if user_name:
cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
cursor.execute(
"SELECT * FROM chat_history where user_name=? limit 20", [user_name]
)
else:
cursor.execute("SELECT * FROM chat_history limit 20")
# get the column names of the query result
@@ -118,7 +125,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor()
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
cursor.execute(
"SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
context = cursor.fetchone()
if context:
if context[0]:
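
Every statement in this class now follows the same parameterized DuckDB pattern. A minimal standalone sketch of the read path, assuming the chat_history.db file this module creates and a placeholder conversation id:

import json

import duckdb

conn = duckdb.connect("message/chat_history.db")
# table-existence check via DuckDB's sqlite_master compatibility view, as in the diff
exists = conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table' AND name=?", ["chat_history"]
).fetchall()
if exists:
    row = conn.execute(
        "SELECT messages FROM chat_history WHERE conv_uid=?", ["<conv_uid>"]
    ).fetchone()
    history = json.loads(row[0]) if row and row[0] else []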


@@ -1,7 +1,7 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker
from pilot.configs.config import Config
@@ -83,6 +83,30 @@ class DocumentChunkDao:
result = document_chunks.all()
return result
def get_document_chunks_count(self, query: DocumentChunkEntity):
session = self.Session()
document_chunks = session.query(func.count(DocumentChunkEntity.id))
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
if query.document_id is not None:
document_chunks = document_chunks.filter(
DocumentChunkEntity.document_id == query.document_id
)
if query.doc_type is not None:
document_chunks = document_chunks.filter(
DocumentChunkEntity.doc_type == query.doc_type
)
if query.doc_name is not None:
document_chunks = document_chunks.filter(
DocumentChunkEntity.doc_name == query.doc_name
)
if query.meta_info is not None:
document_chunks = document_chunks.filter(
DocumentChunkEntity.meta_info == query.meta_info
)
count = document_chunks.scalar()
return count
# def update_knowledge_document(self, document:KnowledgeDocumentEntity):
# session = self.Session()
# updated_space = session.merge(document)
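
The new count method reuses the filters of get_document_chunks but collapses the query with func.count(...).scalar(), which is what lets the service report a total alongside each page. A hedged usage sketch (keyword construction of the entity is assumed from the declarative model):

dao = DocumentChunkDao()
query = DocumentChunkEntity(document_id=42)  # hypothetical document id
total = dao.get_document_chunks_count(query)
page_size = 20
pages = (total + page_size - 1) // page_size  # ceiling division for the pager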


@@ -1,11 +1,13 @@
import os
import shutil
from tempfile import NamedTemporaryFile
from fastapi import APIRouter, File, UploadFile
from fastapi import APIRouter, File, UploadFile, Request, Form
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.openapi.api_v1.api_view_model import Result
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
@@ -74,18 +76,43 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
@router.post("/knowledge/{space_name}/document/upload")
async def document_sync(space_name: str, file: UploadFile = File(...)):
async def document_upload(
space_name: str,
doc_name: str = Form(...),
doc_type: str = Form(...),
doc_file: UploadFile = File(...),
):
print(f"/document/upload params: {space_name}")
try:
with NamedTemporaryFile(delete=False) as tmp:
tmp.write(file.read())
tmp_path = tmp.name
tmp_content = tmp.read()
return {"file_path": tmp_path, "file_content": tmp_content}
Result.succ([])
if doc_file:
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name))
with NamedTemporaryFile(
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False
) as tmp:
tmp.write(await doc_file.read())
tmp_path = tmp.name
shutil.move(
tmp_path,
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
),
)
request = KnowledgeDocumentRequest()
request.doc_name = doc_name
request.doc_type = doc_type
request.content = (
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
),
)
knowledge_space_service.create_knowledge_document(
space=space_name, request=request
)
return Result.succ([])
return Result.faild(code="E000X", msg=f"doc_file is None")
except Exception as e:
return Result.faild(code="E000X", msg=f"document sync error {e}")
return Result.faild(code="E000X", msg=f"document add error {e}")
@router.post("/knowledge/{space_name}/document/sync")


@@ -1,6 +1,6 @@
from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker
from pilot.configs.config import Config
@@ -92,15 +92,41 @@ class KnowledgeDocumentDao:
result = knowledge_documents.all()
return result
def get_knowledge_documents_count(self, query):
session = self.Session()
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.id == query.id
)
if query.doc_name is not None:
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_name == query.doc_name
)
if query.doc_type is not None:
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_type == query.doc_type
)
if query.space is not None:
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.space == query.space
)
if query.status is not None:
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.status == query.status
)
count = knowledge_documents.scalar()
return count
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session()
updated_space = session.merge(document)
session.commit()
return updated_space.id
def delete_knowledge_document(self, document_id: int):
cursor = self.conn.cursor()
query = "DELETE FROM knowledge_document WHERE id = %s"
cursor.execute(query, (document_id,))
self.conn.commit()
cursor.close()
#
# def delete_knowledge_document(self, document_id: int):
# cursor = self.conn.cursor()
# query = "DELETE FROM knowledge_document WHERE id = %s"
# cursor.execute(query, (document_id,))
# self.conn.commit()
# cursor.close()


@@ -25,6 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import (
)
from enum import Enum
from pilot.openapi.knowledge.request.knowledge_response import (
ChunkQueryResponse,
DocumentQueryResponse,
)
knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
@@ -72,6 +76,7 @@ class KnowledgeService:
status=SyncStatus.TODO.name,
last_sync=datetime.now(),
content=request.content,
result="",
)
knowledge_document_dao.create_knowledge_document(document)
return True
@@ -93,9 +98,13 @@ class KnowledgeService:
space=space,
status=request.status,
)
return knowledge_document_dao.get_knowledge_documents(
res = DocumentQueryResponse()
res.data = knowledge_document_dao.get_knowledge_documents(
query, page=request.page, page_size=request.page_size
)
res.total = knowledge_document_dao.get_knowledge_documents_count(query)
res.page = request.page
return res
"""sync knowledge document chunk into vector store"""
@@ -106,6 +115,8 @@ class KnowledgeService:
space=space_name,
)
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name:
raise Exception(f" doc:{doc.doc_name} status is {doc.status}, can not sync")
client = KnowledgeEmbedding(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
@@ -164,9 +175,13 @@ class KnowledgeService:
doc_name=request.doc_name,
doc_type=request.doc_type,
)
return document_chunk_dao.get_document_chunks(
res = ChunkQueryResponse()
res.data = document_chunk_dao.get_document_chunks(
query, page=request.page, page_size=request.page_size
)
res.total = document_chunk_dao.get_document_chunks_count(query)
res.page = request.page
return res
def async_doc_embedding(self, client, chunk_docs, doc):
logger.info(
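
The new guard before syncing encodes a small status machine over the values listed in the knowledge_document DDL (TODO, RUNNING, FAILED, FINISHED): a document that is already running or finished may not be synced again, so only TODO and FAILED documents are eligible. The check in isolation:

RESYNCABLE = {"TODO", "FAILED"}  # the statuses the diff leaves eligible for sync

def ensure_syncable(doc):
    if doc.status not in RESYNCABLE:
        raise Exception(f"doc:{doc.doc_name} status is {doc.status}, can not sync")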


@@ -1,6 +1,7 @@
from typing import List
from pydantic import BaseModel
from fastapi import UploadFile
class KnowledgeQueryRequest(BaseModel):
@@ -26,11 +27,14 @@ class KnowledgeSpaceRequest(BaseModel):
class KnowledgeDocumentRequest(BaseModel):
"""doc_name: doc path"""
doc_name: str
doc_name: str = None
"""doc_type: doc type"""
doc_type: str
doc_type: str = None
"""content: content"""
content: str = None
"""content: content"""
source: str = None
"""text_chunk_size: text_chunk_size"""
# text_chunk_size: int


@@ -0,0 +1,23 @@
from typing import List
from pydantic import BaseModel
class ChunkQueryResponse(BaseModel):
"""data: data"""
data: List = None
"""total: total size"""
total: int = None
"""page: current page"""
page: int = None
class DocumentQueryResponse(BaseModel):
"""data: data"""
data: List = None
"""total: total size"""
total: int = None
"""page: current page"""
page: int = None
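
Both models are plain pydantic containers, so the service can fill them field by field as the diff does, or construct them in one call. A short sketch with placeholder values:

res = ChunkQueryResponse(data=[{"id": 1, "content": "first chunk"}], total=57, page=1)
print(res.json())  # -> {"data": [...], "total": 57, "page": 1}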


@@ -122,7 +122,7 @@ class BaseOutputParser(ABC):
def __extract_json(slef, s):
i = s.index("{")
count = 1  # current nesting depth, i.e. the number of unclosed '{'
for j, c in enumerate(s[i + 1:], start=i + 1):
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
@@ -130,7 +130,7 @@ class BaseOutputParser(ABC):
if count == 0:
break
assert count == 0  # check that the final '}' was found
return s[i: j + 1]
return s[i : j + 1]
def parse_prompt_response(self, model_out_text) -> T:
"""
@@ -147,9 +147,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json"):]
cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```"):]
cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
@@ -158,9 +158,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = (
cleaned_output.strip()
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
)
return cleaned_output
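
The two hunks above only adjust black's slice-spacing style (s[i : j + 1]); the brace-counting extraction itself is unchanged. A self-contained sketch of the same algorithm, with the caveat that it does not skip braces inside JSON string values:

def extract_json(s: str) -> str:
    i = s.index("{")
    count = 1  # current nesting depth: the number of unclosed '{'
    for j, c in enumerate(s[i + 1 :], start=i + 1):
        if c == "}":
            count -= 1
        elif c == "{":
            count += 1
        if count == 0:
            break
    assert count == 0  # confirm the matching '}' was found
    return s[i : j + 1]

print(extract_json('text {"a": {"b": 1}} tail'))  # -> {"a": {"b": 1}}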


@@ -60,10 +60,10 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
def __init__(
self,
chat_mode,
chat_session_id,
current_user_input,
self,
chat_mode,
chat_session_id,
current_user_input,
):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
@@ -172,11 +172,18 @@ class BaseChat(ABC):
print("[TEST: output]:", rsp_str)
### output parse
ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
self.prompt_template.sep)
ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(
rsp_str, self.prompt_template.sep
)
)
### model result deal
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.do_action(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"):
@@ -236,7 +243,9 @@ class BaseChat(ABC):
system_convs = self.current_message.get_system_conv()
system_text = ""
for system_conv in system_convs:
system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
system_text += (
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
)
return system_text
def __load_user_message(self):
@@ -250,13 +259,16 @@ class BaseChat(ABC):
example_text = ""
if self.prompt_template.example_selector:
for round_conv in self.prompt_template.example_selector.examples():
for round_message in round_conv['messages']:
if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
for round_message in round_conv["messages"]:
if not round_message["type"] in [
SystemMessage.type,
ViewMessage.type,
]:
example_text += (
round_message['type']
+ ":"
+ round_message['data']['content']
+ self.prompt_template.sep
round_message["type"]
+ ":"
+ round_message["data"]["content"]
+ self.prompt_template.sep
)
return example_text
@@ -268,37 +280,46 @@ class BaseChat(ABC):
f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
)
if len(self.history_message) > self.chat_retention_rounds:
for first_message in self.history_message[0]['messages']:
if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
for first_message in self.history_message[0]["messages"]:
if not first_message["type"] in [
ViewMessage.type,
SystemMessage.type,
]:
history_text += (
first_message['type']
+ ":"
+ first_message['data']['content']
+ self.prompt_template.sep
first_message["type"]
+ ":"
+ first_message["data"]["content"]
+ self.prompt_template.sep
)
index = self.chat_retention_rounds - 1
for round_conv in self.history_message[-index:]:
for round_message in round_conv['messages']:
if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
for round_message in round_conv["messages"]:
if not round_message["type"] in [
SystemMessage.type,
ViewMessage.type,
]:
history_text += (
round_message['type']
+ ":"
+ round_message['data']['content']
+ self.prompt_template.sep
round_message["type"]
+ ":"
+ round_message["data"]["content"]
+ self.prompt_template.sep
)
else:
### user all history
for conversation in self.history_message:
for message in conversation['messages']:
for message in conversation["messages"]:
### history messages do not include prompt and view info
if not message['type'] in [SystemMessage.type, ViewMessage.type]:
if not message["type"] in [
SystemMessage.type,
ViewMessage.type,
]:
history_text += (
message['type']
+ ":"
+ message['data']['content']
+ self.prompt_template.sep
message["type"]
+ ":"
+ message["data"]["content"]
+ self.prompt_template.sep
)
return history_text
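
Most of this hunk is black reformatting (double quotes, wrapped concatenations), but the history rule it reflows is compact: keep every round when there are at most chat_retention_rounds of them; otherwise keep the opening round plus the most recent chat_retention_rounds - 1 rounds, always skipping system and view messages. A hedged sketch of just the windowing:

def select_rounds(history_message, retention_rounds):
    if len(history_message) <= retention_rounds:
        return history_message
    # opening round plus the most recent retention_rounds - 1 rounds, as in the diff
    return [history_message[0]] + history_message[-(retention_rounds - 1):]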


@@ -9,6 +9,7 @@ from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary


@@ -18,7 +18,7 @@ from pilot.configs.model_config import (
LOGDIR,
)
from pilot.scene.chat_knowledge.default.prompt import prompt
from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
CFG = Config()


@@ -98,9 +98,10 @@ class OnceConversation:
system_convs.append(message)
return system_convs
def _conversation_to_dic(once: OnceConversation) -> dict:
start_str: str = ""
if hasattr(once, 'start_date') and once.start_date:
if hasattr(once, "start_date") and once.start_date:
if isinstance(once.start_date, datetime):
start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
else:


@@ -23,15 +23,23 @@ from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem")
os._exit(0)
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args, **kwargs,
@@ -55,23 +63,27 @@ app.add_middleware(
)
app.mount("/static", StaticFiles(directory=static_file_path), name="static")
app.include_router(knowledge_router)
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument(
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
signal.signal(signal.SIGINT, signal_handler)
# init server config
args = parser.parse_args()
server_init(args)
CFG.NEW_SERVER_MODE = True
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
uvicorn.run(app, host="0.0.0.0", port=args.port)
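
With the hardcoded port replaced by args.port, the existing --port flag now actually takes effect, for example:

python pilot/server/dbgpt_server.py --port 5001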


@@ -9,7 +9,8 @@ import sys
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
global_counter = 0
@@ -41,11 +42,11 @@ class ModelWorker:
if not isinstance(self.model, str):
if hasattr(self.model, "config") and hasattr(
self.model.config, "max_sequence_length"
self.model.config, "max_sequence_length"
):
self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model, "config") and hasattr(
self.model.config, "max_position_embeddings"
self.model.config, "max_position_embeddings"
):
self.context_len = self.model.config.max_position_embeddings
@@ -60,22 +61,22 @@ def get_queue_length(self):
def get_queue_length(self):
if (
model_semaphore is None
or model_semaphore._value is None
or model_semaphore._waiters is None
model_semaphore is None
or model_semaphore._value is None
or model_semaphore._waiters is None
):
return 0
else:
(
CFG.LIMIT_MODEL_CONCURRENCY
- model_semaphore._value
+ len(model_semaphore._waiters)
CFG.LIMIT_MODEL_CONCURRENCY
- model_semaphore._value
+ len(model_semaphore._waiters)
)
def generate_stream_gate(self, params):
try:
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):
# Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,
@@ -107,23 +108,23 @@ worker = ModelWorker(
)
app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router
app.include_router(router)
origins = [
"http://localhost",
"http://localhost:8000",
"http://localhost:3000",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# from pilot.openapi.knowledge.knowledge_controller import router
#
# app.include_router(router)
#
# origins = [
# "http://localhost",
# "http://localhost:8000",
# "http://localhost:3000",
# ]
#
# app.add_middleware(
# CORSMiddleware,
# allow_origins=origins,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
class PromptRequest(BaseModel):


@@ -40,6 +40,7 @@ def server_init(args):
cfg = Config()
from pilot.server.llmserver import worker
worker.start_check()
load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)