diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql new file mode 100644 index 000000000..2c50e42b4 --- /dev/null +++ b/assets/schema/knowledge_management.sql @@ -0,0 +1,41 @@ +CREATE TABLE `knowledge_space` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `name` varchar(100) NOT NULL COMMENT 'knowledge space name', + `vector_type` varchar(50) NOT NULL COMMENT 'vector type', + `desc` varchar(500) NOT NULL COMMENT 'description', + `owner` varchar(100) DEFAULT NULL COMMENT 'owner', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_name` (`name`) COMMENT 'index:idx_name' +) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge space table'; + +CREATE TABLE `knowledge_document` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `space` varchar(50) NOT NULL COMMENT 'knowledge space', + `chunk_size` int NOT NULL COMMENT 'chunk size', + `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', + `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', + `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', + `result` TEXT NULL COMMENT 'knowledge content', + `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name' +) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document table'; + +CREATE TABLE `document_chunk` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `document_id` int NOT NULL COMMENT 'document parent id', + `content` longtext NOT NULL COMMENT 'chunk content', + `meta_info` varchar(200) NOT NULL COMMENT 'metadata info', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id' +) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document chunk detail' \ No newline at end of file diff --git a/datacenter/app/agents/[agentId]/page.tsx b/datacenter/app/agents/[agentId]/page.tsx index 61a99ee10..848c51f40 100644 --- a/datacenter/app/agents/[agentId]/page.tsx +++ b/datacenter/app/agents/[agentId]/page.tsx @@ -19,7 +19,7 @@ const AgentPage = (props) => { }); const { history, handleChatSubmit } = useAgentChat({ - queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`, + queryAgentURL: `http://localhost:5000/v1/chat/completions`, queryBody: { conv_uid: props.params?.agentId, chat_mode: props.searchParams?.scene || 'chat_normal', diff --git a/datacenter/app/agents/page.tsx b/datacenter/app/agents/page.tsx index d808b340d..ad40a354e 100644 --- a/datacenter/app/agents/page.tsx +++ b/datacenter/app/agents/page.tsx @@ -16,7 +16,7 @@ const Item = styled(Sheet)(({ theme }) => ({ const Agents = () => { const { handleChatSubmit, history } = useAgentChat({ - queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`, + queryAgentURL: `http://localhost:5000/v1/chat/completions`, }); const data = [ diff --git a/pilot/embedding_engine/markdown_embedding.py b/pilot/embedding_engine/markdown_embedding.py index e9a97dce9..0d70ba34f 100644 --- a/pilot/embedding_engine/markdown_embedding.py +++ b/pilot/embedding_engine/markdown_embedding.py @@ -6,12 +6,11 @@ from typing import List import markdown from bs4 import BeautifulSoup from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter +from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader -from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() @@ -30,12 +29,20 @@ class MarkdownEmbedding(SourceEmbedding): def read(self): """Load from markdown path.""" loader = EncodeTextLoader(self.file_path) - textsplitter = SpacyTextSplitter( - pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=100, - ) - return loader.load_and_split(textsplitter) + + if CFG.LANGUAGE == "en": + text_splitter = CharacterTextSplitter( + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=20, + length_function=len, + ) + else: + text_splitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=100, + ) + return loader.load_and_split(text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/pdf_embedding.py b/pilot/embedding_engine/pdf_embedding.py index ea4276460..2b8f244e3 100644 --- a/pilot/embedding_engine/pdf_embedding.py +++ b/pilot/embedding_engine/pdf_embedding.py @@ -4,7 +4,7 @@ from typing import List from langchain.document_loaders import PyPDFLoader from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter +from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register @@ -28,12 +28,24 @@ class PDFEmbedding(SourceEmbedding): # textsplitter = CHNDocumentSplitter( # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE # ) - textsplitter = SpacyTextSplitter( - pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=100, - ) - return loader.load_and_split(textsplitter) + # textsplitter = SpacyTextSplitter( + # pipeline="zh_core_web_sm", + # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + # chunk_overlap=100, + # ) + if CFG.LANGUAGE == "en": + text_splitter = CharacterTextSplitter( + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=20, + length_function=len, + ) + else: + text_splitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=100, + ) + return loader.load_and_split(text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/ppt_embedding.py b/pilot/embedding_engine/ppt_embedding.py index 485083d1c..da4390849 100644 --- a/pilot/embedding_engine/ppt_embedding.py +++ b/pilot/embedding_engine/ppt_embedding.py @@ -4,10 +4,11 @@ from typing import List from langchain.document_loaders import UnstructuredPowerPointLoader from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter +from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register +from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() @@ -25,12 +26,24 @@ class PPTEmbedding(SourceEmbedding): def read(self): """Load from ppt path.""" loader = UnstructuredPowerPointLoader(self.file_path) - textsplitter = SpacyTextSplitter( - pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=200, - ) - return loader.load_and_split(textsplitter) + # textsplitter = SpacyTextSplitter( + # pipeline="zh_core_web_sm", + # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + # chunk_overlap=200, + # ) + if CFG.LANGUAGE == "en": + text_splitter = CharacterTextSplitter( + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=20, + length_function=len, + ) + else: + text_splitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=100, + ) + return loader.load_and_split(text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/word_embedding.py b/pilot/embedding_engine/word_embedding.py index 34fc48450..55988a240 100644 --- a/pilot/embedding_engine/word_embedding.py +++ b/pilot/embedding_engine/word_embedding.py @@ -2,12 +2,12 @@ # -*- coding: utf-8 -*- from typing import List -from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader +from langchain.document_loaders import UnstructuredWordDocumentLoader from langchain.schema import Document +from langchain.text_splitter import CharacterTextSplitter, SpacyTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register -from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() @@ -25,10 +25,19 @@ class WordEmbedding(SourceEmbedding): def read(self): """Load from word path.""" loader = UnstructuredWordDocumentLoader(self.file_path) - textsplitter = CHNDocumentSplitter( - pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE - ) - return loader.load_and_split(textsplitter) + if CFG.LANGUAGE == "en": + text_splitter = CharacterTextSplitter( + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=20, + length_function=len, + ) + else: + text_splitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=100, + ) + return loader.load_and_split(text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index 6aa7c344c..6ccba5c4f 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -15,13 +15,12 @@ from pilot.common.formatting import MyEncoder default_db_path = os.path.join(os.getcwd(), "message") duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") -table_name = 'chat_history' +table_name = "chat_history" CFG = Config() class DuckdbHistoryMemory(BaseChatHistoryMemory): - def __init__(self, chat_session_id: str): self.chat_seesion_id = chat_session_id os.makedirs(default_db_path, exist_ok=True) @@ -29,10 +28,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): self.__init_chat_history_tables() def __init_chat_history_tables(self): - # 检查表是否存在 - result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", - [table_name]).fetchall() + result = self.connect.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name] + ).fetchall() if not result: # 如果表不存在,则创建新表 @@ -74,8 +73,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): conversations.append(_conversation_to_dic(once_message)) cursor = self.connect.cursor() if context: - cursor.execute("UPDATE chat_history set messages=? where conv_uid=?", - [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id]) + cursor.execute( + "UPDATE chat_history set messages=? where conv_uid=?", + [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id], + ) else: cursor.execute( "INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)", @@ -85,13 +86,17 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): def clear(self) -> None: cursor = self.connect.cursor() - cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) cursor.commit() self.connect.commit() def delete(self) -> bool: cursor = self.connect.cursor() - cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) cursor.commit() return True @@ -100,7 +105,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): if os.path.isfile(duckdb_path): cursor = duckdb.connect(duckdb_path).cursor() if user_name: - cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name]) + cursor.execute( + "SELECT * FROM chat_history where user_name=? limit 20", [user_name] + ) else: cursor.execute("SELECT * FROM chat_history limit 20") # 获取查询结果字段名 @@ -118,7 +125,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): def get_messages(self) -> List[OnceConversation]: cursor = self.connect.cursor() - cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) context = cursor.fetchone() if context: if context[0]: diff --git a/pilot/openapi/knowledge/__init__.py b/pilot/openapi/knowledge/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/openapi/knowledge/document_chunk_dao.py b/pilot/openapi/knowledge/document_chunk_dao.py index cb728e85c..d67f4a01a 100644 --- a/pilot/openapi/knowledge/document_chunk_dao.py +++ b/pilot/openapi/knowledge/document_chunk_dao.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List -from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine +from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func from sqlalchemy.orm import declarative_base, sessionmaker from pilot.configs.config import Config @@ -83,6 +83,30 @@ class DocumentChunkDao: result = document_chunks.all() return result + def get_document_chunks_count(self, query: DocumentChunkEntity): + session = self.Session() + document_chunks = session.query(func.count(DocumentChunkEntity.id)) + if query.id is not None: + document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) + if query.document_id is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.document_id == query.document_id + ) + if query.doc_type is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.doc_type == query.doc_type + ) + if query.doc_name is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.doc_name == query.doc_name + ) + if query.meta_info is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.meta_info == query.meta_info + ) + count = document_chunks.scalar() + return count + # def update_knowledge_document(self, document:KnowledgeDocumentEntity): # session = self.Session() # updated_space = session.merge(document) diff --git a/pilot/openapi/knowledge/knowledge_controller.py b/pilot/openapi/knowledge/knowledge_controller.py index bebbc8a3f..26323e098 100644 --- a/pilot/openapi/knowledge/knowledge_controller.py +++ b/pilot/openapi/knowledge/knowledge_controller.py @@ -1,11 +1,13 @@ +import os +import shutil from tempfile import NamedTemporaryFile -from fastapi import APIRouter, File, UploadFile +from fastapi import APIRouter, File, UploadFile, Request, Form from langchain.embeddings import HuggingFaceEmbeddings from pilot.configs.config import Config -from pilot.configs.model_config import LLM_MODEL_CONFIG +from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.openapi.api_v1.api_view_model import Result from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding @@ -74,18 +76,43 @@ def document_list(space_name: str, query_request: DocumentQueryRequest): @router.post("/knowledge/{space_name}/document/upload") -async def document_sync(space_name: str, file: UploadFile = File(...)): +async def document_upload( + space_name: str, + doc_name: str = Form(...), + doc_type: str = Form(...), + doc_file: UploadFile = File(...), +): print(f"/document/upload params: {space_name}") try: - with NamedTemporaryFile(delete=False) as tmp: - tmp.write(file.read()) - tmp_path = tmp.name - tmp_content = tmp.read() - - return {"file_path": tmp_path, "file_content": tmp_content} - Result.succ([]) + if doc_file: + if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)): + os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)) + with NamedTemporaryFile( + dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False + ) as tmp: + tmp.write(await doc_file.read()) + tmp_path = tmp.name + shutil.move( + tmp_path, + os.path.join( + KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename + ), + ) + request = KnowledgeDocumentRequest() + request.doc_name = doc_name + request.doc_type = doc_type + request.content = ( + os.path.join( + KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename + ), + ) + knowledge_space_service.create_knowledge_document( + space=space_name, request=request + ) + return Result.succ([]) + return Result.faild(code="E000X", msg=f"doc_file is None") except Exception as e: - return Result.faild(code="E000X", msg=f"document sync error {e}") + return Result.faild(code="E000X", msg=f"document add error {e}") @router.post("/knowledge/{space_name}/document/sync") diff --git a/pilot/openapi/knowledge/knowledge_document_dao.py b/pilot/openapi/knowledge/knowledge_document_dao.py index 1f7afc401..f99b81a72 100644 --- a/pilot/openapi/knowledge/knowledge_document_dao.py +++ b/pilot/openapi/knowledge/knowledge_document_dao.py @@ -1,6 +1,6 @@ from datetime import datetime -from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine +from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func from sqlalchemy.orm import declarative_base, sessionmaker from pilot.configs.config import Config @@ -92,15 +92,41 @@ class KnowledgeDocumentDao: result = knowledge_documents.all() return result + def get_knowledge_documents_count(self, query): + session = self.Session() + knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id)) + if query.id is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.id == query.id + ) + if query.doc_name is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.doc_name == query.doc_name + ) + if query.doc_type is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.doc_type == query.doc_type + ) + if query.space is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.space == query.space + ) + if query.status is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.status == query.status + ) + count = knowledge_documents.scalar() + return count + def update_knowledge_document(self, document: KnowledgeDocumentEntity): session = self.Session() updated_space = session.merge(document) session.commit() return updated_space.id - - def delete_knowledge_document(self, document_id: int): - cursor = self.conn.cursor() - query = "DELETE FROM knowledge_document WHERE id = %s" - cursor.execute(query, (document_id,)) - self.conn.commit() - cursor.close() + # + # def delete_knowledge_document(self, document_id: int): + # cursor = self.conn.cursor() + # query = "DELETE FROM knowledge_document WHERE id = %s" + # cursor.execute(query, (document_id,)) + # self.conn.commit() + # cursor.close() diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index c41630ede..49e6a0fa3 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -25,6 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import ( ) from enum import Enum +from pilot.openapi.knowledge.request.knowledge_response import ( + ChunkQueryResponse, + DocumentQueryResponse, +) knowledge_space_dao = KnowledgeSpaceDao() knowledge_document_dao = KnowledgeDocumentDao() @@ -72,6 +76,7 @@ class KnowledgeService: status=SyncStatus.TODO.name, last_sync=datetime.now(), content=request.content, + result="", ) knowledge_document_dao.create_knowledge_document(document) return True @@ -93,9 +98,13 @@ class KnowledgeService: space=space, status=request.status, ) - return knowledge_document_dao.get_knowledge_documents( + res = DocumentQueryResponse() + res.data = knowledge_document_dao.get_knowledge_documents( query, page=request.page, page_size=request.page_size ) + res.total = knowledge_document_dao.get_knowledge_documents_count(query) + res.page = request.page + return res """sync knowledge document chunk into vector store""" @@ -106,6 +115,8 @@ class KnowledgeService: space=space_name, ) doc = knowledge_document_dao.get_knowledge_documents(query)[0] + if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name: + raise Exception(f" doc:{doc.doc_name} status is {doc.status}, can not sync") client = KnowledgeEmbedding( knowledge_source=doc.content, knowledge_type=doc.doc_type.upper(), @@ -164,9 +175,13 @@ class KnowledgeService: doc_name=request.doc_name, doc_type=request.doc_type, ) - return document_chunk_dao.get_document_chunks( + res = ChunkQueryResponse() + res.data = document_chunk_dao.get_document_chunks( query, page=request.page, page_size=request.page_size ) + res.total = document_chunk_dao.get_document_chunks_count(query) + res.page = request.page + return res def async_doc_embedding(self, client, chunk_docs, doc): logger.info( diff --git a/pilot/openapi/knowledge/request/__init__.py b/pilot/openapi/knowledge/request/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/openapi/knowledge/request/knowledge_request.py b/pilot/openapi/knowledge/request/knowledge_request.py index 1c5916f7c..d393ca9b7 100644 --- a/pilot/openapi/knowledge/request/knowledge_request.py +++ b/pilot/openapi/knowledge/request/knowledge_request.py @@ -1,6 +1,7 @@ from typing import List from pydantic import BaseModel +from fastapi import UploadFile class KnowledgeQueryRequest(BaseModel): @@ -26,11 +27,14 @@ class KnowledgeSpaceRequest(BaseModel): class KnowledgeDocumentRequest(BaseModel): """doc_name: doc path""" - doc_name: str + doc_name: str = None """doc_type: doc type""" - doc_type: str + doc_type: str = None """content: content""" content: str = None + """content: content""" + source: str = None + """text_chunk_size: text_chunk_size""" # text_chunk_size: int diff --git a/pilot/openapi/knowledge/request/knowledge_response.py b/pilot/openapi/knowledge/request/knowledge_response.py new file mode 100644 index 000000000..7fbf36155 --- /dev/null +++ b/pilot/openapi/knowledge/request/knowledge_response.py @@ -0,0 +1,23 @@ +from typing import List + +from pydantic import BaseModel + + +class ChunkQueryResponse(BaseModel): + """data: data""" + + data: List = None + """total: total size""" + total: int = None + """page: current page""" + page: int = None + + +class DocumentQueryResponse(BaseModel): + """data: data""" + + data: List = None + """total: total size""" + total: int = None + """page: current page""" + page: int = None diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 6f0da16fa..7202d3019 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -122,7 +122,7 @@ class BaseOutputParser(ABC): def __extract_json(slef, s): i = s.index("{") count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1:], start=i + 1): + for j, c in enumerate(s[i + 1 :], start=i + 1): if c == "}": count -= 1 elif c == "{": @@ -130,7 +130,7 @@ class BaseOutputParser(ABC): if count == 0: break assert count == 0 # 检查是否找到最后一个'}' - return s[i: j + 1] + return s[i : j + 1] def parse_prompt_response(self, model_out_text) -> T: """ @@ -147,9 +147,9 @@ class BaseOutputParser(ABC): # if "```" in cleaned_output: # cleaned_output, _ = cleaned_output.split("```") if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json"):] + cleaned_output = cleaned_output[len("```json") :] if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```"):] + cleaned_output = cleaned_output[len("```") :] if cleaned_output.endswith("```"): cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output.strip() @@ -158,9 +158,9 @@ class BaseOutputParser(ABC): cleaned_output = self.__extract_json(cleaned_output) cleaned_output = ( cleaned_output.strip() - .replace("\n", " ") - .replace("\\n", " ") - .replace("\\", " ") + .replace("\n", " ") + .replace("\\n", " ") + .replace("\\", " ") ) return cleaned_output diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 6c005dbf3..dff8528ac 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -60,10 +60,10 @@ class BaseChat(ABC): arbitrary_types_allowed = True def __init__( - self, - chat_mode, - chat_session_id, - current_user_input, + self, + chat_mode, + chat_session_id, + current_user_input, ): self.chat_session_id = chat_session_id self.chat_mode = chat_mode @@ -172,11 +172,18 @@ class BaseChat(ABC): print("[TEST: output]:", rsp_str) ### output parse - ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str, - self.prompt_template.sep) + ai_response_text = ( + self.prompt_template.output_parser.parse_model_nostream_resp( + rsp_str, self.prompt_template.sep + ) + ) ### model result deal self.current_message.add_ai_message(ai_response_text) - prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) + prompt_define_response = ( + self.prompt_template.output_parser.parse_prompt_response( + ai_response_text + ) + ) result = self.do_action(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): @@ -236,7 +243,9 @@ class BaseChat(ABC): system_convs = self.current_message.get_system_conv() system_text = "" for system_conv in system_convs: - system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep + system_text += ( + system_conv.type + ":" + system_conv.content + self.prompt_template.sep + ) return system_text def __load_user_message(self): @@ -250,13 +259,16 @@ class BaseChat(ABC): example_text = "" if self.prompt_template.example_selector: for round_conv in self.prompt_template.example_selector.examples(): - for round_message in round_conv['messages']: - if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: + for round_message in round_conv["messages"]: + if not round_message["type"] in [ + SystemMessage.type, + ViewMessage.type, + ]: example_text += ( - round_message['type'] - + ":" - + round_message['data']['content'] - + self.prompt_template.sep + round_message["type"] + + ":" + + round_message["data"]["content"] + + self.prompt_template.sep ) return example_text @@ -268,37 +280,46 @@ class BaseChat(ABC): f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!" ) if len(self.history_message) > self.chat_retention_rounds: - for first_message in self.history_message[0]['messages']: - if not first_message['type'] in [ViewMessage.type, SystemMessage.type]: + for first_message in self.history_message[0]["messages"]: + if not first_message["type"] in [ + ViewMessage.type, + SystemMessage.type, + ]: history_text += ( - first_message['type'] - + ":" - + first_message['data']['content'] - + self.prompt_template.sep + first_message["type"] + + ":" + + first_message["data"]["content"] + + self.prompt_template.sep ) index = self.chat_retention_rounds - 1 for round_conv in self.history_message[-index:]: - for round_message in round_conv['messages']: - if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: + for round_message in round_conv["messages"]: + if not round_message["type"] in [ + SystemMessage.type, + ViewMessage.type, + ]: history_text += ( - round_message['type'] - + ":" - + round_message['data']['content'] - + self.prompt_template.sep + round_message["type"] + + ":" + + round_message["data"]["content"] + + self.prompt_template.sep ) else: ### user all history for conversation in self.history_message: - for message in conversation['messages']: + for message in conversation["messages"]: ### histroy message not have promot and view info - if not message['type'] in [SystemMessage.type, ViewMessage.type]: + if not message["type"] in [ + SystemMessage.type, + ViewMessage.type, + ]: history_text += ( - message['type'] - + ":" - + message['data']['content'] - + self.prompt_template.sep + message["type"] + + ":" + + message["data"]["content"] + + self.prompt_template.sep ) return history_text diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py index 63d00b36c..436860d7f 100644 --- a/pilot/scene/chat_factory.py +++ b/pilot/scene/chat_factory.py @@ -9,6 +9,7 @@ from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge +from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 404a9347b..1c4786725 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -18,7 +18,7 @@ from pilot.configs.model_config import ( LOGDIR, ) -from pilot.scene.chat_knowledge.default.prompt import prompt +from pilot.scene.chat_knowledge.v1.prompt import prompt from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding CFG = Config() diff --git a/pilot/scene/message.py b/pilot/scene/message.py index 972331bbb..51ec2643e 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -98,9 +98,10 @@ class OnceConversation: system_convs.append(message) return system_convs + def _conversation_to_dic(once: OnceConversation) -> dict: start_str: str = "" - if hasattr(once, 'start_date') and once.start_date: + if hasattr(once, "start_date") and once.start_date: if isinstance(once.start_date, datetime): start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S") else: diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index b69ca3ab3..fbae996c1 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -23,15 +23,23 @@ from fastapi import FastAPI, applications from fastapi.openapi.docs import get_swagger_ui_html from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware +from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router + from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler + static_file_path = os.path.join(os.getcwd(), "server/static") CFG = Config() logger = build_logger("webserver", LOGDIR + "webserver.log") +def signal_handler(sig, frame): + print("in order to avoid chroma db atexit problem") + os._exit(0) + + def swagger_monkey_patch(*args, **kwargs): return get_swagger_ui_html( *args, **kwargs, @@ -55,23 +63,27 @@ app.add_middleware( ) app.mount("/static", StaticFiles(directory=static_file_path), name="static") - +app.include_router(knowledge_router) app.include_router(api_v1) app.add_exception_handler(RequestValidationError, validation_exception_handler) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) + parser.add_argument( + "--model_list_mode", type=str, default="once", choices=["once", "reload"] + ) # old version server config parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) + parser.add_argument("--port", type=int, default=5000) parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument("--share", default=False, action="store_true") + signal.signal(signal.SIGINT, signal_handler) # init server config args = parser.parse_args() server_init(args) CFG.NEW_SERVER_MODE = True import uvicorn - uvicorn.run(app, host="0.0.0.0", port=5000) + + uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 67e6183b2..d87540a8e 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -9,7 +9,8 @@ import sys import uvicorn from fastapi import BackgroundTasks, FastAPI, Request from fastapi.responses import StreamingResponse -from fastapi.middleware.cors import CORSMiddleware + +# from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel global_counter = 0 @@ -41,11 +42,11 @@ class ModelWorker: if not isinstance(self.model, str): if hasattr(self.model, "config") and hasattr( - self.model.config, "max_sequence_length" + self.model.config, "max_sequence_length" ): self.context_len = self.model.config.max_sequence_length elif hasattr(self.model, "config") and hasattr( - self.model.config, "max_position_embeddings" + self.model.config, "max_position_embeddings" ): self.context_len = self.model.config.max_position_embeddings @@ -60,22 +61,22 @@ class ModelWorker: def get_queue_length(self): if ( - model_semaphore is None - or model_semaphore._value is None - or model_semaphore._waiters is None + model_semaphore is None + or model_semaphore._value is None + or model_semaphore._waiters is None ): return 0 else: ( - CFG.LIMIT_MODEL_CONCURRENCY - - model_semaphore._value - + len(model_semaphore._waiters) + CFG.LIMIT_MODEL_CONCURRENCY + - model_semaphore._value + + len(model_semaphore._waiters) ) def generate_stream_gate(self, params): try: for output in self.generate_stream_func( - self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS + self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): # Please do not open the output in production! # The gpt4all thread shares stdout with the parent process, @@ -107,23 +108,23 @@ worker = ModelWorker( ) app = FastAPI() -from pilot.openapi.knowledge.knowledge_controller import router - -app.include_router(router) - -origins = [ - "http://localhost", - "http://localhost:8000", - "http://localhost:3000", -] - -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +# from pilot.openapi.knowledge.knowledge_controller import router +# +# app.include_router(router) +# +# origins = [ +# "http://localhost", +# "http://localhost:8000", +# "http://localhost:3000", +# ] +# +# app.add_middleware( +# CORSMiddleware, +# allow_origins=origins, +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) class PromptRequest(BaseModel): diff --git a/pilot/server/webserver_base.py b/pilot/server/webserver_base.py index c76984c37..33486c439 100644 --- a/pilot/server/webserver_base.py +++ b/pilot/server/webserver_base.py @@ -40,6 +40,7 @@ def server_init(args): cfg = Config() from pilot.server.llmserver import worker + worker.start_check() load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler)