Merge branch 'llm_framework' into dev_ty_06_end

# Conflicts:
#	pilot/openapi/api_v1/api_v1.py
#	pilot/server/dbgpt_server.py
This commit is contained in:
tuyang.yhj 2023-06-30 09:58:32 +08:00
commit 6f8f182d1d
24 changed files with 384 additions and 137 deletions

View File

@ -0,0 +1,41 @@
-- Knowledge-base schema: spaces, the documents inside a space, and the
-- embedding chunks derived from each document.

CREATE TABLE `knowledge_space` (
  `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
  `name` varchar(100) NOT NULL COMMENT 'knowledge space name',
  `vector_type` varchar(50) NOT NULL COMMENT 'vector type',
  `desc` varchar(500) NOT NULL COMMENT 'description',
  `owner` varchar(100) DEFAULT NULL COMMENT 'owner',
  `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
  `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
  PRIMARY KEY (`id`),
  KEY `idx_name` (`name`) COMMENT 'index:idx_name'
) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge space table';

CREATE TABLE `knowledge_document` (
  `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
  `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
  `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
  `space` varchar(50) NOT NULL COMMENT 'knowledge space',
  `chunk_size` int NOT NULL COMMENT 'chunk size',
  `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
  `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
  -- FIX: the COMMENT strings on `content` and `result` were swapped in the
  -- original DDL; `content` holds the document content/source, `result`
  -- holds the outcome of the embedding sync.
  `content` LONGTEXT NOT NULL COMMENT 'knowledge content',
  `result` TEXT NULL COMMENT 'knowledge embedding sync result',
  `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
  `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
  `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
  PRIMARY KEY (`id`),
  KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name'
) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document table';

CREATE TABLE `document_chunk` (
  `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
  `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
  `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
  `document_id` int NOT NULL COMMENT 'document parent id',
  `content` longtext NOT NULL COMMENT 'chunk content',
  `meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
  `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
  `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
  PRIMARY KEY (`id`),
  KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id'
-- FIX: terminating semicolon was missing on the last statement.
) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document chunk detail';

View File

@ -19,7 +19,7 @@ const AgentPage = (props) => {
}); });
const { history, handleChatSubmit } = useAgentChat({ const { history, handleChatSubmit } = useAgentChat({
queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`, queryAgentURL: `http://localhost:5000/v1/chat/completions`,
queryBody: { queryBody: {
conv_uid: props.params?.agentId, conv_uid: props.params?.agentId,
chat_mode: props.searchParams?.scene || 'chat_normal', chat_mode: props.searchParams?.scene || 'chat_normal',

View File

@ -16,7 +16,7 @@ const Item = styled(Sheet)(({ theme }) => ({
const Agents = () => { const Agents = () => {
const { handleChatSubmit, history } = useAgentChat({ const { handleChatSubmit, history } = useAgentChat({
queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`, queryAgentURL: `http://localhost:5000/v1/chat/completions`,
}); });
const data = [ const data = [

View File

@ -6,12 +6,11 @@ from typing import List
import markdown import markdown
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader
from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter
CFG = Config() CFG = Config()
@ -30,12 +29,20 @@ class MarkdownEmbedding(SourceEmbedding):
def read(self): def read(self):
"""Load from markdown path.""" """Load from markdown path."""
loader = EncodeTextLoader(self.file_path) loader = EncodeTextLoader(self.file_path)
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", if CFG.LANGUAGE == "en":
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, text_splitter = CharacterTextSplitter(
chunk_overlap=100, chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
) chunk_overlap=20,
return loader.load_and_split(textsplitter) length_function=len,
)
else:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -4,7 +4,7 @@ from typing import List
from langchain.document_loaders import PyPDFLoader from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
@ -28,12 +28,24 @@ class PDFEmbedding(SourceEmbedding):
# textsplitter = CHNDocumentSplitter( # textsplitter = CHNDocumentSplitter(
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# ) # )
textsplitter = SpacyTextSplitter( # textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", # pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100, # chunk_overlap=100,
) # )
return loader.load_and_split(textsplitter) if CFG.LANGUAGE == "en":
text_splitter = CharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -4,10 +4,11 @@ from typing import List
from langchain.document_loaders import UnstructuredPowerPointLoader from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter
CFG = Config() CFG = Config()
@ -25,12 +26,24 @@ class PPTEmbedding(SourceEmbedding):
def read(self): def read(self):
"""Load from ppt path.""" """Load from ppt path."""
loader = UnstructuredPowerPointLoader(self.file_path) loader = UnstructuredPowerPointLoader(self.file_path)
textsplitter = SpacyTextSplitter( # textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", # pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=200, # chunk_overlap=200,
) # )
return loader.load_and_split(textsplitter) if CFG.LANGUAGE == "en":
text_splitter = CharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -2,12 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from typing import List from typing import List
from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter, SpacyTextSplitter
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter
CFG = Config() CFG = Config()
@ -25,10 +25,19 @@ class WordEmbedding(SourceEmbedding):
def read(self): def read(self):
"""Load from word path.""" """Load from word path."""
loader = UnstructuredWordDocumentLoader(self.file_path) loader = UnstructuredWordDocumentLoader(self.file_path)
textsplitter = CHNDocumentSplitter( if CFG.LANGUAGE == "en":
pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE text_splitter = CharacterTextSplitter(
) chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
return loader.load_and_split(textsplitter) chunk_overlap=20,
length_function=len,
)
else:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
)
return loader.load_and_split(text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -15,13 +15,12 @@ from pilot.common.formatting import MyEncoder
default_db_path = os.path.join(os.getcwd(), "message") default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
table_name = 'chat_history' table_name = "chat_history"
CFG = Config() CFG = Config()
class DuckdbHistoryMemory(BaseChatHistoryMemory): class DuckdbHistoryMemory(BaseChatHistoryMemory):
def __init__(self, chat_session_id: str): def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id self.chat_seesion_id = chat_session_id
os.makedirs(default_db_path, exist_ok=True) os.makedirs(default_db_path, exist_ok=True)
@ -29,10 +28,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
self.__init_chat_history_tables() self.__init_chat_history_tables()
def __init_chat_history_tables(self): def __init_chat_history_tables(self):
# 检查表是否存在 # 检查表是否存在
result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", result = self.connect.execute(
[table_name]).fetchall() "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
).fetchall()
if not result: if not result:
# 如果表不存在,则创建新表 # 如果表不存在,则创建新表
@ -74,8 +73,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
conversations.append(_conversation_to_dic(once_message)) conversations.append(_conversation_to_dic(once_message))
cursor = self.connect.cursor() cursor = self.connect.cursor()
if context: if context:
cursor.execute("UPDATE chat_history set messages=? where conv_uid=?", cursor.execute(
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id]) "UPDATE chat_history set messages=? where conv_uid=?",
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id],
)
else: else:
cursor.execute( cursor.execute(
"INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)", "INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)",
@ -85,13 +86,17 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
def clear(self) -> None: def clear(self) -> None:
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) cursor.execute(
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
cursor.commit() cursor.commit()
self.connect.commit() self.connect.commit()
def delete(self) -> bool: def delete(self) -> bool:
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) cursor.execute(
"DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
cursor.commit() cursor.commit()
return True return True
@ -100,7 +105,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if os.path.isfile(duckdb_path): if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor() cursor = duckdb.connect(duckdb_path).cursor()
if user_name: if user_name:
cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name]) cursor.execute(
"SELECT * FROM chat_history where user_name=? limit 20", [user_name]
)
else: else:
cursor.execute("SELECT * FROM chat_history limit 20") cursor.execute("SELECT * FROM chat_history limit 20")
# 获取查询结果字段名 # 获取查询结果字段名
@ -118,7 +125,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
def get_messages(self) -> List[OnceConversation]: def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]) cursor.execute(
"SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
)
context = cursor.fetchone() context = cursor.fetchone()
if context: if context:
if context[0]: if context[0]:

View File

View File

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.orm import declarative_base, sessionmaker
from pilot.configs.config import Config from pilot.configs.config import Config
@ -83,6 +83,30 @@ class DocumentChunkDao:
result = document_chunks.all() result = document_chunks.all()
return result return result
def get_document_chunks_count(self, query: DocumentChunkEntity):
    """Return the number of document_chunk rows matching *query*.

    Every non-None field of *query* is applied as an equality filter,
    mirroring the filter logic of get_document_chunks().

    Args:
        query: DocumentChunkEntity whose set fields act as filters.

    Returns:
        int: the matching row count.
    """
    session = self.Session()
    # FIX: the original never closed the session, leaking a connection on
    # every call; use try/finally so it is released even if the query raises.
    try:
        chunks = session.query(func.count(DocumentChunkEntity.id))
        if query.id is not None:
            chunks = chunks.filter(DocumentChunkEntity.id == query.id)
        if query.document_id is not None:
            chunks = chunks.filter(
                DocumentChunkEntity.document_id == query.document_id
            )
        if query.doc_type is not None:
            chunks = chunks.filter(DocumentChunkEntity.doc_type == query.doc_type)
        if query.doc_name is not None:
            chunks = chunks.filter(DocumentChunkEntity.doc_name == query.doc_name)
        if query.meta_info is not None:
            chunks = chunks.filter(DocumentChunkEntity.meta_info == query.meta_info)
        return chunks.scalar()
    finally:
        session.close()
# def update_knowledge_document(self, document:KnowledgeDocumentEntity): # def update_knowledge_document(self, document:KnowledgeDocumentEntity):
# session = self.Session() # session = self.Session()
# updated_space = session.merge(document) # updated_space = session.merge(document)

View File

@ -1,11 +1,13 @@
import os
import shutil
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from fastapi import APIRouter, File, UploadFile from fastapi import APIRouter, File, UploadFile, Request, Form
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.openapi.api_v1.api_view_model import Result from pilot.openapi.api_v1.api_view_model import Result
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
@ -74,18 +76,43 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
@router.post("/knowledge/{space_name}/document/upload") @router.post("/knowledge/{space_name}/document/upload")
async def document_sync(space_name: str, file: UploadFile = File(...)): async def document_upload(
space_name: str,
doc_name: str = Form(...),
doc_type: str = Form(...),
doc_file: UploadFile = File(...),
):
print(f"/document/upload params: {space_name}") print(f"/document/upload params: {space_name}")
try: try:
with NamedTemporaryFile(delete=False) as tmp: if doc_file:
tmp.write(file.read()) if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)):
tmp_path = tmp.name os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name))
tmp_content = tmp.read() with NamedTemporaryFile(
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False
return {"file_path": tmp_path, "file_content": tmp_content} ) as tmp:
Result.succ([]) tmp.write(await doc_file.read())
tmp_path = tmp.name
shutil.move(
tmp_path,
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
),
)
request = KnowledgeDocumentRequest()
request.doc_name = doc_name
request.doc_type = doc_type
request.content = (
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
),
)
knowledge_space_service.create_knowledge_document(
space=space_name, request=request
)
return Result.succ([])
return Result.faild(code="E000X", msg=f"doc_file is None")
except Exception as e: except Exception as e:
return Result.faild(code="E000X", msg=f"document sync error {e}") return Result.faild(code="E000X", msg=f"document add error {e}")
@router.post("/knowledge/{space_name}/document/sync") @router.post("/knowledge/{space_name}/document/sync")

View File

@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.orm import declarative_base, sessionmaker
from pilot.configs.config import Config from pilot.configs.config import Config
@ -92,15 +92,41 @@ class KnowledgeDocumentDao:
result = knowledge_documents.all() result = knowledge_documents.all()
return result return result
def get_knowledge_documents_count(self, query):
    """Return the number of knowledge_document rows matching *query*.

    Every non-None field of *query* is applied as an equality filter,
    mirroring the filter logic of get_knowledge_documents().

    Args:
        query: KnowledgeDocumentEntity whose set fields act as filters.

    Returns:
        int: the matching row count.
    """
    session = self.Session()
    # FIX: the original never closed the session, leaking a connection on
    # every call; use try/finally so it is released even if the query raises.
    try:
        documents = session.query(func.count(KnowledgeDocumentEntity.id))
        if query.id is not None:
            documents = documents.filter(KnowledgeDocumentEntity.id == query.id)
        if query.doc_name is not None:
            documents = documents.filter(
                KnowledgeDocumentEntity.doc_name == query.doc_name
            )
        if query.doc_type is not None:
            documents = documents.filter(
                KnowledgeDocumentEntity.doc_type == query.doc_type
            )
        if query.space is not None:
            documents = documents.filter(
                KnowledgeDocumentEntity.space == query.space
            )
        if query.status is not None:
            documents = documents.filter(
                KnowledgeDocumentEntity.status == query.status
            )
        return documents.scalar()
    finally:
        session.close()
def update_knowledge_document(self, document: KnowledgeDocumentEntity): def update_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session() session = self.Session()
updated_space = session.merge(document) updated_space = session.merge(document)
session.commit() session.commit()
return updated_space.id return updated_space.id
#
def delete_knowledge_document(self, document_id: int): # def delete_knowledge_document(self, document_id: int):
cursor = self.conn.cursor() # cursor = self.conn.cursor()
query = "DELETE FROM knowledge_document WHERE id = %s" # query = "DELETE FROM knowledge_document WHERE id = %s"
cursor.execute(query, (document_id,)) # cursor.execute(query, (document_id,))
self.conn.commit() # self.conn.commit()
cursor.close() # cursor.close()

View File

@ -25,6 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import (
) )
from enum import Enum from enum import Enum
from pilot.openapi.knowledge.request.knowledge_response import (
ChunkQueryResponse,
DocumentQueryResponse,
)
knowledge_space_dao = KnowledgeSpaceDao() knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao() knowledge_document_dao = KnowledgeDocumentDao()
@ -72,6 +76,7 @@ class KnowledgeService:
status=SyncStatus.TODO.name, status=SyncStatus.TODO.name,
last_sync=datetime.now(), last_sync=datetime.now(),
content=request.content, content=request.content,
result="",
) )
knowledge_document_dao.create_knowledge_document(document) knowledge_document_dao.create_knowledge_document(document)
return True return True
@ -93,9 +98,13 @@ class KnowledgeService:
space=space, space=space,
status=request.status, status=request.status,
) )
return knowledge_document_dao.get_knowledge_documents( res = DocumentQueryResponse()
res.data = knowledge_document_dao.get_knowledge_documents(
query, page=request.page, page_size=request.page_size query, page=request.page, page_size=request.page_size
) )
res.total = knowledge_document_dao.get_knowledge_documents_count(query)
res.page = request.page
return res
"""sync knowledge document chunk into vector store""" """sync knowledge document chunk into vector store"""
@ -106,6 +115,8 @@ class KnowledgeService:
space=space_name, space=space_name,
) )
doc = knowledge_document_dao.get_knowledge_documents(query)[0] doc = knowledge_document_dao.get_knowledge_documents(query)[0]
if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name:
raise Exception(f" doc:{doc.doc_name} status is {doc.status}, can not sync")
client = KnowledgeEmbedding( client = KnowledgeEmbedding(
knowledge_source=doc.content, knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(), knowledge_type=doc.doc_type.upper(),
@ -164,9 +175,13 @@ class KnowledgeService:
doc_name=request.doc_name, doc_name=request.doc_name,
doc_type=request.doc_type, doc_type=request.doc_type,
) )
return document_chunk_dao.get_document_chunks( res = ChunkQueryResponse()
res.data = document_chunk_dao.get_document_chunks(
query, page=request.page, page_size=request.page_size query, page=request.page, page_size=request.page_size
) )
res.total = document_chunk_dao.get_document_chunks_count(query)
res.page = request.page
return res
def async_doc_embedding(self, client, chunk_docs, doc): def async_doc_embedding(self, client, chunk_docs, doc):
logger.info( logger.info(

View File

@ -1,6 +1,7 @@
from typing import List from typing import List
from pydantic import BaseModel from pydantic import BaseModel
from fastapi import UploadFile
class KnowledgeQueryRequest(BaseModel): class KnowledgeQueryRequest(BaseModel):
@ -26,11 +27,14 @@ class KnowledgeSpaceRequest(BaseModel):
class KnowledgeDocumentRequest(BaseModel): class KnowledgeDocumentRequest(BaseModel):
"""doc_name: doc path""" """doc_name: doc path"""
doc_name: str doc_name: str = None
"""doc_type: doc type""" """doc_type: doc type"""
doc_type: str doc_type: str = None
"""content: content""" """content: content"""
content: str = None content: str = None
"""content: content"""
source: str = None
"""text_chunk_size: text_chunk_size""" """text_chunk_size: text_chunk_size"""
# text_chunk_size: int # text_chunk_size: int

View File

@ -0,0 +1,23 @@
from typing import List, Optional

from pydantic import BaseModel
class ChunkQueryResponse(BaseModel):
    """Paged response wrapper for document-chunk queries."""

    # data: the current page of chunk rows
    data: Optional[List] = None
    # total: total number of rows matching the query
    total: Optional[int] = None
    # page: the page number this response covers
    page: Optional[int] = None
class DocumentQueryResponse(BaseModel):
    """Paged response wrapper for knowledge-document queries."""

    # data: the current page of document rows
    data: Optional[List] = None
    # total: total number of rows matching the query
    total: Optional[int] = None
    # page: the page number this response covers
    page: Optional[int] = None

View File

@ -122,7 +122,7 @@ class BaseOutputParser(ABC):
def __extract_json(slef, s): def __extract_json(slef, s):
i = s.index("{") i = s.index("{")
count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
for j, c in enumerate(s[i + 1:], start=i + 1): for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}": if c == "}":
count -= 1 count -= 1
elif c == "{": elif c == "{":
@ -130,7 +130,7 @@ class BaseOutputParser(ABC):
if count == 0: if count == 0:
break break
assert count == 0 # 检查是否找到最后一个'}' assert count == 0 # 检查是否找到最后一个'}'
return s[i: j + 1] return s[i : j + 1]
def parse_prompt_response(self, model_out_text) -> T: def parse_prompt_response(self, model_out_text) -> T:
""" """
@ -147,9 +147,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output: # if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```") # cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"): if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json"):] cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"): if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```"):] cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"): if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip() cleaned_output = cleaned_output.strip()
@ -158,9 +158,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output) cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = ( cleaned_output = (
cleaned_output.strip() cleaned_output.strip()
.replace("\n", " ") .replace("\n", " ")
.replace("\\n", " ") .replace("\\n", " ")
.replace("\\", " ") .replace("\\", " ")
) )
return cleaned_output return cleaned_output

View File

@ -60,10 +60,10 @@ class BaseChat(ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __init__( def __init__(
self, self,
chat_mode, chat_mode,
chat_session_id, chat_session_id,
current_user_input, current_user_input,
): ):
self.chat_session_id = chat_session_id self.chat_session_id = chat_session_id
self.chat_mode = chat_mode self.chat_mode = chat_mode
@ -172,11 +172,18 @@ class BaseChat(ABC):
print("[TEST: output]:", rsp_str) print("[TEST: output]:", rsp_str)
### output parse ### output parse
ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str, ai_response_text = (
self.prompt_template.sep) self.prompt_template.output_parser.parse_model_nostream_resp(
rsp_str, self.prompt_template.sep
)
)
### model result deal ### model result deal
self.current_message.add_ai_message(ai_response_text) self.current_message.add_ai_message(ai_response_text)
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.do_action(prompt_define_response) result = self.do_action(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"): if hasattr(prompt_define_response, "thoughts"):
@ -236,7 +243,9 @@ class BaseChat(ABC):
system_convs = self.current_message.get_system_conv() system_convs = self.current_message.get_system_conv()
system_text = "" system_text = ""
for system_conv in system_convs: for system_conv in system_convs:
system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep system_text += (
system_conv.type + ":" + system_conv.content + self.prompt_template.sep
)
return system_text return system_text
def __load_user_message(self): def __load_user_message(self):
@ -250,13 +259,16 @@ class BaseChat(ABC):
example_text = "" example_text = ""
if self.prompt_template.example_selector: if self.prompt_template.example_selector:
for round_conv in self.prompt_template.example_selector.examples(): for round_conv in self.prompt_template.example_selector.examples():
for round_message in round_conv['messages']: for round_message in round_conv["messages"]:
if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: if not round_message["type"] in [
SystemMessage.type,
ViewMessage.type,
]:
example_text += ( example_text += (
round_message['type'] round_message["type"]
+ ":" + ":"
+ round_message['data']['content'] + round_message["data"]["content"]
+ self.prompt_template.sep + self.prompt_template.sep
) )
return example_text return example_text
@ -268,37 +280,46 @@ class BaseChat(ABC):
f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!" f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
) )
if len(self.history_message) > self.chat_retention_rounds: if len(self.history_message) > self.chat_retention_rounds:
for first_message in self.history_message[0]['messages']: for first_message in self.history_message[0]["messages"]:
if not first_message['type'] in [ViewMessage.type, SystemMessage.type]: if not first_message["type"] in [
ViewMessage.type,
SystemMessage.type,
]:
history_text += ( history_text += (
first_message['type'] first_message["type"]
+ ":" + ":"
+ first_message['data']['content'] + first_message["data"]["content"]
+ self.prompt_template.sep + self.prompt_template.sep
) )
index = self.chat_retention_rounds - 1 index = self.chat_retention_rounds - 1
for round_conv in self.history_message[-index:]: for round_conv in self.history_message[-index:]:
for round_message in round_conv['messages']: for round_message in round_conv["messages"]:
if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: if not round_message["type"] in [
SystemMessage.type,
ViewMessage.type,
]:
history_text += ( history_text += (
round_message['type'] round_message["type"]
+ ":" + ":"
+ round_message['data']['content'] + round_message["data"]["content"]
+ self.prompt_template.sep + self.prompt_template.sep
) )
else: else:
### user all history ### user all history
for conversation in self.history_message: for conversation in self.history_message:
for message in conversation['messages']: for message in conversation["messages"]:
### histroy message not have promot and view info ### histroy message not have promot and view info
if not message['type'] in [SystemMessage.type, ViewMessage.type]: if not message["type"] in [
SystemMessage.type,
ViewMessage.type,
]:
history_text += ( history_text += (
message['type'] message["type"]
+ ":" + ":"
+ message['data']['content'] + message["data"]["content"]
+ self.prompt_template.sep + self.prompt_template.sep
) )
return history_text return history_text

View File

@ -9,6 +9,7 @@ from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary

View File

@ -18,7 +18,7 @@ from pilot.configs.model_config import (
LOGDIR, LOGDIR,
) )
from pilot.scene.chat_knowledge.default.prompt import prompt from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
CFG = Config() CFG = Config()

View File

@ -98,9 +98,10 @@ class OnceConversation:
system_convs.append(message) system_convs.append(message)
return system_convs return system_convs
def _conversation_to_dic(once: OnceConversation) -> dict: def _conversation_to_dic(once: OnceConversation) -> dict:
start_str: str = "" start_str: str = ""
if hasattr(once, 'start_date') and once.start_date: if hasattr(once, "start_date") and once.start_date:
if isinstance(once.start_date, datetime): if isinstance(once.start_date, datetime):
start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S") start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
else: else:

View File

@ -23,15 +23,23 @@ from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
static_file_path = os.path.join(os.getcwd(), "server/static") static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config() CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log") logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem")
os._exit(0)
def swagger_monkey_patch(*args, **kwargs): def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html( return get_swagger_ui_html(
*args, **kwargs, *args, **kwargs,
@ -55,23 +63,27 @@ app.add_middleware(
) )
app.mount("/static", StaticFiles(directory=static_file_path), name="static") app.mount("/static", StaticFiles(directory=static_file_path), name="static")
app.include_router(knowledge_router)
app.include_router(api_v1) app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler) app.add_exception_handler(RequestValidationError, validation_exception_handler)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) parser.add_argument(
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
# old version server config # old version server config
parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true") parser.add_argument("--share", default=False, action="store_true")
signal.signal(signal.SIGINT, signal_handler)
# init server config # init server config
args = parser.parse_args() args = parser.parse_args()
server_init(args) server_init(args)
CFG.NEW_SERVER_MODE = True CFG.NEW_SERVER_MODE = True
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -9,7 +9,8 @@ import sys
import uvicorn import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
global_counter = 0 global_counter = 0
@ -41,11 +42,11 @@ class ModelWorker:
if not isinstance(self.model, str): if not isinstance(self.model, str):
if hasattr(self.model, "config") and hasattr( if hasattr(self.model, "config") and hasattr(
self.model.config, "max_sequence_length" self.model.config, "max_sequence_length"
): ):
self.context_len = self.model.config.max_sequence_length self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model, "config") and hasattr( elif hasattr(self.model, "config") and hasattr(
self.model.config, "max_position_embeddings" self.model.config, "max_position_embeddings"
): ):
self.context_len = self.model.config.max_position_embeddings self.context_len = self.model.config.max_position_embeddings
@ -60,22 +61,22 @@ class ModelWorker:
def get_queue_length(self): def get_queue_length(self):
if ( if (
model_semaphore is None model_semaphore is None
or model_semaphore._value is None or model_semaphore._value is None
or model_semaphore._waiters is None or model_semaphore._waiters is None
): ):
return 0 return 0
else: else:
( (
CFG.LIMIT_MODEL_CONCURRENCY CFG.LIMIT_MODEL_CONCURRENCY
- model_semaphore._value - model_semaphore._value
+ len(model_semaphore._waiters) + len(model_semaphore._waiters)
) )
def generate_stream_gate(self, params): def generate_stream_gate(self, params):
try: try:
for output in self.generate_stream_func( for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
): ):
# Please do not open the output in production! # Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process, # The gpt4all thread shares stdout with the parent process,
@ -107,23 +108,23 @@ worker = ModelWorker(
) )
app = FastAPI() app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router # from pilot.openapi.knowledge.knowledge_controller import router
#
app.include_router(router) # app.include_router(router)
#
origins = [ # origins = [
"http://localhost", # "http://localhost",
"http://localhost:8000", # "http://localhost:8000",
"http://localhost:3000", # "http://localhost:3000",
] # ]
#
app.add_middleware( # app.add_middleware(
CORSMiddleware, # CORSMiddleware,
allow_origins=origins, # allow_origins=origins,
allow_credentials=True, # allow_credentials=True,
allow_methods=["*"], # allow_methods=["*"],
allow_headers=["*"], # allow_headers=["*"],
) # )
class PromptRequest(BaseModel): class PromptRequest(BaseModel):

View File

@ -40,6 +40,7 @@ def server_init(args):
cfg = Config() cfg = Config()
from pilot.server.llmserver import worker from pilot.server.llmserver import worker
worker.start_check() worker.start_check()
load_native_plugins(cfg) load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)