From 4a57192879f6062be7775691f1b91a8bcbfed61e Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 11:20:16 +0800 Subject: [PATCH 01/12] feat:chunks page list 1.add document chunk list --- .../knowledge/request/knowledge_request.py | 8 +++++-- .../knowledge/request/knowledge_response.py | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 pilot/openapi/knowledge/request/knowledge_response.py diff --git a/pilot/openapi/knowledge/request/knowledge_request.py b/pilot/openapi/knowledge/request/knowledge_request.py index 1c5916f7c..d393ca9b7 100644 --- a/pilot/openapi/knowledge/request/knowledge_request.py +++ b/pilot/openapi/knowledge/request/knowledge_request.py @@ -1,6 +1,7 @@ from typing import List from pydantic import BaseModel +from fastapi import UploadFile class KnowledgeQueryRequest(BaseModel): @@ -26,11 +27,14 @@ class KnowledgeSpaceRequest(BaseModel): class KnowledgeDocumentRequest(BaseModel): """doc_name: doc path""" - doc_name: str + doc_name: str = None """doc_type: doc type""" - doc_type: str + doc_type: str = None """content: content""" content: str = None + """content: content""" + source: str = None + """text_chunk_size: text_chunk_size""" # text_chunk_size: int diff --git a/pilot/openapi/knowledge/request/knowledge_response.py b/pilot/openapi/knowledge/request/knowledge_response.py new file mode 100644 index 000000000..71d426643 --- /dev/null +++ b/pilot/openapi/knowledge/request/knowledge_response.py @@ -0,0 +1,22 @@ +from typing import List + +from pydantic import BaseModel + + +class ChunkQueryResponse(BaseModel): + """data: data""" + data: List = None + """total: total size""" + total: int = None + """page: current page""" + page: int = None + + +class DocumentQueryResponse(BaseModel): + """data: data""" + data: List = None + """total: total size""" + total: int = None + """page: current page""" + page: int = None + From 06688b1c367b57aa9fbfb8995daf87ae5c2416b1 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 11:21:48 +0800 Subject: [PATCH 02/12] feat:chunks page list add chunk page list --- pilot/openapi/knowledge/__init__.py | 0 pilot/openapi/knowledge/document_chunk_dao.py | 26 +++++++++- .../openapi/knowledge/knowledge_controller.py | 31 +++++++----- .../knowledge/knowledge_document_dao.py | 48 ++++++++++++++----- pilot/openapi/knowledge/knowledge_service.py | 13 ++++- pilot/openapi/knowledge/request/__init__.py | 0 6 files changed, 93 insertions(+), 25 deletions(-) create mode 100644 pilot/openapi/knowledge/__init__.py create mode 100644 pilot/openapi/knowledge/request/__init__.py diff --git a/pilot/openapi/knowledge/__init__.py b/pilot/openapi/knowledge/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/openapi/knowledge/document_chunk_dao.py b/pilot/openapi/knowledge/document_chunk_dao.py index cb728e85c..d67f4a01a 100644 --- a/pilot/openapi/knowledge/document_chunk_dao.py +++ b/pilot/openapi/knowledge/document_chunk_dao.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List -from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine +from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func from sqlalchemy.orm import declarative_base, sessionmaker from pilot.configs.config import Config @@ -83,6 +83,30 @@ class DocumentChunkDao: result = document_chunks.all() return result + def get_document_chunks_count(self, query: DocumentChunkEntity): + session = self.Session() + document_chunks = session.query(func.count(DocumentChunkEntity.id)) + if query.id is not None: + document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) + if query.document_id is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.document_id == query.document_id + ) + if query.doc_type is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.doc_type == query.doc_type + ) + if query.doc_name is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.doc_name == query.doc_name + ) + if query.meta_info is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.meta_info == query.meta_info + ) + count = document_chunks.scalar() + return count + # def update_knowledge_document(self, document:KnowledgeDocumentEntity): # session = self.Session() # updated_space = session.merge(document) diff --git a/pilot/openapi/knowledge/knowledge_controller.py b/pilot/openapi/knowledge/knowledge_controller.py index bebbc8a3f..1b452119a 100644 --- a/pilot/openapi/knowledge/knowledge_controller.py +++ b/pilot/openapi/knowledge/knowledge_controller.py @@ -1,11 +1,13 @@ +import os +import shutil from tempfile import NamedTemporaryFile -from fastapi import APIRouter, File, UploadFile +from fastapi import APIRouter, File, UploadFile, Request, Form from langchain.embeddings import HuggingFaceEmbeddings from pilot.configs.config import Config -from pilot.configs.model_config import LLM_MODEL_CONFIG +from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.openapi.api_v1.api_view_model import Result from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding @@ -74,18 +76,25 @@ def document_list(space_name: str, query_request: DocumentQueryRequest): @router.post("/knowledge/{space_name}/document/upload") -async def document_sync(space_name: str, file: UploadFile = File(...)): +async def document_upload(space_name: str, doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...)): print(f"/document/upload params: {space_name}") try: - with NamedTemporaryFile(delete=False) as tmp: - tmp.write(file.read()) - tmp_path = tmp.name - tmp_content = tmp.read() - - return {"file_path": tmp_path, "file_content": tmp_content} - Result.succ([]) + if doc_file: + with NamedTemporaryFile(dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False) as tmp: + tmp.write(await doc_file.read()) + tmp_path = tmp.name + shutil.move(tmp_path, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename)) + request = KnowledgeDocumentRequest() + request.doc_name = doc_name + request.doc_type = doc_type + request.content = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename), + knowledge_space_service.create_knowledge_document( + space=space_name, request=request + ) + return Result.succ([]) + return Result.faild(code="E000X", msg=f"doc_file is None") except Exception as e: - return Result.faild(code="E000X", msg=f"document sync error {e}") + return Result.faild(code="E000X", msg=f"document add error {e}") @router.post("/knowledge/{space_name}/document/sync") diff --git a/pilot/openapi/knowledge/knowledge_document_dao.py b/pilot/openapi/knowledge/knowledge_document_dao.py index 1f7afc401..cad881e71 100644 --- a/pilot/openapi/knowledge/knowledge_document_dao.py +++ b/pilot/openapi/knowledge/knowledge_document_dao.py @@ -1,6 +1,6 @@ from datetime import datetime -from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine +from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func from sqlalchemy.orm import declarative_base, sessionmaker from pilot.configs.config import Config @@ -92,15 +92,41 @@ class KnowledgeDocumentDao: result = knowledge_documents.all() return result - def update_knowledge_document(self, document: KnowledgeDocumentEntity): + def get_knowledge_documents_count(self, query): session = self.Session() - updated_space = session.merge(document) - session.commit() - return updated_space.id + knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id)) + if query.id is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.id == query.id + ) + if query.doc_name is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.doc_name == query.doc_name + ) + if query.doc_type is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.doc_type == query.doc_type + ) + if query.space is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.space == query.space + ) + if query.status is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.status == query.status + ) + count = knowledge_documents.scalar() + return count - def delete_knowledge_document(self, document_id: int): - cursor = self.conn.cursor() - query = "DELETE FROM knowledge_document WHERE id = %s" - cursor.execute(query, (document_id,)) - self.conn.commit() - cursor.close() + # def update_knowledge_document(self, document: KnowledgeDocumentEntity): + # session = self.Session() + # updated_space = session.merge(document) + # session.commit() + # return updated_space.id + # + # def delete_knowledge_document(self, document_id: int): + # cursor = self.conn.cursor() + # query = "DELETE FROM knowledge_document WHERE id = %s" + # cursor.execute(query, (document_id,)) + # self.conn.commit() + # cursor.close() diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index c41630ede..9ee9f3c40 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -25,6 +25,7 @@ from pilot.openapi.knowledge.request.knowledge_request import ( ) from enum import Enum +from pilot.openapi.knowledge.request.knowledge_response import ChunkQueryResponse, DocumentQueryResponse knowledge_space_dao = KnowledgeSpaceDao() knowledge_document_dao = KnowledgeDocumentDao() @@ -93,9 +94,13 @@ class KnowledgeService: space=space, status=request.status, ) - return knowledge_document_dao.get_knowledge_documents( + res = DocumentQueryResponse() + res.data = knowledge_document_dao.get_knowledge_documents( query, page=request.page, page_size=request.page_size ) + res.total = knowledge_document_dao.get_knowledge_documents_count(query) + res.page = request.page + return res """sync knowledge document chunk into vector store""" @@ -164,9 +169,13 @@ class KnowledgeService: doc_name=request.doc_name, doc_type=request.doc_type, ) - return document_chunk_dao.get_document_chunks( + res = ChunkQueryResponse() + res.data = document_chunk_dao.get_document_chunks( query, page=request.page, page_size=request.page_size ) + res.total = document_chunk_dao.get_document_chunks_count(query) + res.page = request.page + return res def async_doc_embedding(self, client, chunk_docs, doc): logger.info( diff --git a/pilot/openapi/knowledge/request/__init__.py b/pilot/openapi/knowledge/request/__init__.py new file mode 100644 index 000000000..e69de29bb From cdb616cc210a35959881ac8dd8eeca59dbcf77e1 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 13:52:53 +0800 Subject: [PATCH 03/12] style:format code style format code style --- pilot/memory/chat_history/duckdb_history.py | 48 ++++++---- pilot/openapi/api_v1/api_v1.py | 49 +++++++--- .../openapi/knowledge/knowledge_controller.py | 24 ++++- pilot/openapi/knowledge/knowledge_service.py | 5 +- .../knowledge/request/knowledge_response.py | 3 +- pilot/out_parser/base.py | 14 +-- pilot/scene/base_chat.py | 89 ++++++++++++------- pilot/scene/chat_execution/example.py | 4 +- pilot/scene/message.py | 3 +- pilot/server/dbgpt_server.py | 19 ++-- pilot/server/llmserver.py | 55 ++++++------ pilot/server/webserver_base.py | 1 + 12 files changed, 205 insertions(+), 109 deletions(-) diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index de80a5bc2..659690120 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -15,13 +15,12 @@ from pilot.common.formatting import MyEncoder default_db_path = os.path.join(os.getcwd(), "message") duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") -table_name = 'chat_history' +table_name = "chat_history" CFG = Config() class DuckdbHistoryMemory(BaseChatHistoryMemory): - def __init__(self, chat_session_id: str): self.chat_seesion_id = chat_session_id os.makedirs(default_db_path, exist_ok=True) @@ -29,15 +28,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): self.__init_chat_history_tables() def __init_chat_history_tables(self): - # 检查表是否存在 - result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", - [table_name]).fetchall() + result = self.connect.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name] + ).fetchall() if not result: # 如果表不存在,则创建新表 self.connect.execute( - "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)") + "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)" + ) def __get_messages_by_conv_uid(self, conv_uid: str): cursor = self.connect.cursor() @@ -47,6 +47,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): return content[0] else: return None + def messages(self) -> List[OnceConversation]: context = self.__get_messages_by_conv_uid(self.chat_seesion_id) if context: @@ -62,23 +63,35 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): conversations.append(_conversation_to_dic(once_message)) cursor = self.connect.cursor() if context: - cursor.execute("UPDATE chat_history set messages=? where conv_uid=?", - [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id]) + cursor.execute( + "UPDATE chat_history set messages=? where conv_uid=?", + [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id], + ) else: - cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", - [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)]) + cursor.execute( + "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", + [ + self.chat_seesion_id, + "", + json.dumps(conversations, ensure_ascii=False), + ], + ) cursor.commit() self.connect.commit() def clear(self) -> None: cursor = self.connect.cursor() - cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) cursor.commit() self.connect.commit() def delete(self) -> bool: cursor = self.connect.cursor() - cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) cursor.commit() return True @@ -87,7 +100,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): if os.path.isfile(duckdb_path): cursor = duckdb.connect(duckdb_path).cursor() if user_name: - cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name]) + cursor.execute( + "SELECT * FROM chat_history where user_name=? limit 20", [user_name] + ) else: cursor.execute("SELECT * FROM chat_history limit 20") # 获取查询结果字段名 @@ -103,10 +118,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): return [] - - def get_messages(self)-> List[OnceConversation]: + def get_messages(self) -> List[OnceConversation]: cursor = self.connect.cursor() - cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) context = cursor.fetchone() if context: return json.loads(context[0]) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 4f2e23946..d391c6e41 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -2,17 +2,29 @@ import uuid import json import asyncio import time -from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks +from fastapi import ( + APIRouter, + Request, + Body, + status, + HTTPException, + Response, + BackgroundTasks, +) from fastapi.responses import JSONResponse from fastapi.responses import StreamingResponse from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse -from sse_starlette.sse import EventSourceResponse from typing import List -from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo +from pilot.openapi.api_v1.api_view_model import ( + Result, + ConversationVo, + MessageVo, + ChatSceneVo, +) from pilot.configs.config import Config from pilot.openapi.knowledge.knowledge_service import KnowledgeService from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest @@ -103,7 +115,7 @@ async def dialogue_scenes(): @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) async def dialogue_new( - chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None + chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None ): unique_id = uuid.uuid1() return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)) @@ -220,11 +232,19 @@ async def chat_completions(dialogue: ConversationVo = Body()): } if not chat.prompt_template.stream_out: - return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream', - background=background_tasks) + return StreamingResponse( + no_stream_generator(chat), + headers=headers, + media_type="text/event-stream", + background=background_tasks, + ) else: - return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain', - background=background_tasks) + return StreamingResponse( + stream_generator(chat), + headers=headers, + media_type="text/plain", + background=background_tasks, + ) def release_model_semaphore(): @@ -236,12 +256,15 @@ async def no_stream_generator(chat): msg = msg.replace("\n", "\\n") yield f"data: {msg}\n\n" + async def stream_generator(chat): model_response = chat.stream_call() if not CFG.NEW_SERVER_MODE: for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( + chunk, chat.skip_echo_len + ) chat.current_message.add_ai_message(msg) msg = msg.replace("\n", "\\n") yield f"data:{msg}\n\n" @@ -249,7 +272,9 @@ async def stream_generator(chat): else: for chunk in model_response: if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( + chunk, chat.skip_echo_len + ) chat.current_message.add_ai_message(msg) msg = msg.replace("\n", "\\n") @@ -259,4 +284,6 @@ async def stream_generator(chat): def message2Vo(message: dict, order) -> MessageVo: - return MessageVo(role=message['type'], context=message['data']['content'], order=order) + return MessageVo( + role=message["type"], context=message["data"]["content"], order=order + ) diff --git a/pilot/openapi/knowledge/knowledge_controller.py b/pilot/openapi/knowledge/knowledge_controller.py index 1b452119a..aec612e9c 100644 --- a/pilot/openapi/knowledge/knowledge_controller.py +++ b/pilot/openapi/knowledge/knowledge_controller.py @@ -76,18 +76,34 @@ def document_list(space_name: str, query_request: DocumentQueryRequest): @router.post("/knowledge/{space_name}/document/upload") -async def document_upload(space_name: str, doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...)): +async def document_upload( + space_name: str, + doc_name: str = Form(...), + doc_type: str = Form(...), + doc_file: UploadFile = File(...), +): print(f"/document/upload params: {space_name}") try: if doc_file: - with NamedTemporaryFile(dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False) as tmp: + with NamedTemporaryFile( + dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False + ) as tmp: tmp.write(await doc_file.read()) tmp_path = tmp.name - shutil.move(tmp_path, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename)) + shutil.move( + tmp_path, + os.path.join( + KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename + ), + ) request = KnowledgeDocumentRequest() request.doc_name = doc_name request.doc_type = doc_type - request.content = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename), + request.content = ( + os.path.join( + KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename + ), + ) knowledge_space_service.create_knowledge_document( space=space_name, request=request ) diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index 9ee9f3c40..2f035fcb5 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -25,7 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import ( ) from enum import Enum -from pilot.openapi.knowledge.request.knowledge_response import ChunkQueryResponse, DocumentQueryResponse +from pilot.openapi.knowledge.request.knowledge_response import ( + ChunkQueryResponse, + DocumentQueryResponse, +) knowledge_space_dao = KnowledgeSpaceDao() knowledge_document_dao = KnowledgeDocumentDao() diff --git a/pilot/openapi/knowledge/request/knowledge_response.py b/pilot/openapi/knowledge/request/knowledge_response.py index 71d426643..7fbf36155 100644 --- a/pilot/openapi/knowledge/request/knowledge_response.py +++ b/pilot/openapi/knowledge/request/knowledge_response.py @@ -5,6 +5,7 @@ from pydantic import BaseModel class ChunkQueryResponse(BaseModel): """data: data""" + data: List = None """total: total size""" total: int = None @@ -14,9 +15,9 @@ class ChunkQueryResponse(BaseModel): class DocumentQueryResponse(BaseModel): """data: data""" + data: List = None """total: total size""" total: int = None """page: current page""" page: int = None - diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index ca308e92f..cd75c950c 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -122,7 +122,7 @@ class BaseOutputParser(ABC): def __extract_json(slef, s): i = s.index("{") count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1:], start=i + 1): + for j, c in enumerate(s[i + 1 :], start=i + 1): if c == "}": count -= 1 elif c == "{": @@ -130,7 +130,7 @@ class BaseOutputParser(ABC): if count == 0: break assert count == 0 # 检查是否找到最后一个'}' - return s[i: j + 1] + return s[i : j + 1] def parse_prompt_response(self, model_out_text) -> T: """ @@ -147,9 +147,9 @@ class BaseOutputParser(ABC): # if "```" in cleaned_output: # cleaned_output, _ = cleaned_output.split("```") if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json"):] + cleaned_output = cleaned_output[len("```json") :] if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```"):] + cleaned_output = cleaned_output[len("```") :] if cleaned_output.endswith("```"): cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output.strip() @@ -158,9 +158,9 @@ class BaseOutputParser(ABC): cleaned_output = self.__extract_json(cleaned_output) cleaned_output = ( cleaned_output.strip() - .replace("\n", " ") - .replace("\\n", " ") - .replace("\\", " ") + .replace("\n", " ") + .replace("\\n", " ") + .replace("\\", " ") ) return cleaned_output diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 8e7b3dfe7..245851062 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -60,10 +60,10 @@ class BaseChat(ABC): arbitrary_types_allowed = True def __init__( - self, - chat_mode, - chat_session_id, - current_user_input, + self, + chat_mode, + chat_session_id, + current_user_input, ): self.chat_session_id = chat_session_id self.chat_mode = chat_mode @@ -102,7 +102,9 @@ class BaseChat(ABC): ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) - self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.current_message.start_date = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) # TODO self.current_message.tokens = 0 @@ -168,11 +170,18 @@ class BaseChat(ABC): print("[TEST: output]:", rsp_str) ### output parse - ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str, - self.prompt_template.sep) + ai_response_text = ( + self.prompt_template.output_parser.parse_model_nostream_resp( + rsp_str, self.prompt_template.sep + ) + ) ### model result deal self.current_message.add_ai_message(ai_response_text) - prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) + prompt_define_response = ( + self.prompt_template.output_parser.parse_prompt_response( + ai_response_text + ) + ) result = self.do_action(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): @@ -232,7 +241,9 @@ class BaseChat(ABC): system_convs = self.current_message.get_system_conv() system_text = "" for system_conv in system_convs: - system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep + system_text += ( + system_conv.type + ":" + system_conv.content + self.prompt_template.sep + ) return system_text def __load_user_message(self): @@ -246,13 +257,16 @@ class BaseChat(ABC): example_text = "" if self.prompt_template.example_selector: for round_conv in self.prompt_template.example_selector.examples(): - for round_message in round_conv['messages']: - if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: + for round_message in round_conv["messages"]: + if not round_message["type"] in [ + SystemMessage.type, + ViewMessage.type, + ]: example_text += ( - round_message['type'] - + ":" - + round_message['data']['content'] - + self.prompt_template.sep + round_message["type"] + + ":" + + round_message["data"]["content"] + + self.prompt_template.sep ) return example_text @@ -264,37 +278,46 @@ class BaseChat(ABC): f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!" ) if len(self.history_message) > self.chat_retention_rounds: - for first_message in self.history_message[0]['messages']: - if not first_message['type'] in [ViewMessage.type, SystemMessage.type]: + for first_message in self.history_message[0]["messages"]: + if not first_message["type"] in [ + ViewMessage.type, + SystemMessage.type, + ]: history_text += ( - first_message['type'] - + ":" - + first_message['data']['content'] - + self.prompt_template.sep + first_message["type"] + + ":" + + first_message["data"]["content"] + + self.prompt_template.sep ) index = self.chat_retention_rounds - 1 for round_conv in self.history_message[-index:]: - for round_message in round_conv['messages']: - if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: + for round_message in round_conv["messages"]: + if not round_message["type"] in [ + SystemMessage.type, + ViewMessage.type, + ]: history_text += ( - round_message['type'] - + ":" - + round_message['data']['content'] - + self.prompt_template.sep + round_message["type"] + + ":" + + round_message["data"]["content"] + + self.prompt_template.sep ) else: ### user all history for conversation in self.history_message: - for message in conversation['messages']: + for message in conversation["messages"]: ### histroy message not have promot and view info - if not message['type'] in [SystemMessage.type, ViewMessage.type]: + if not message["type"] in [ + SystemMessage.type, + ViewMessage.type, + ]: history_text += ( - message['type'] - + ":" - + message['data']['content'] - + self.prompt_template.sep + message["type"] + + ":" + + message["data"]["content"] + + self.prompt_template.sep ) return history_text diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py index f50a7f546..9e2aee6a2 100644 --- a/pilot/scene/chat_execution/example.py +++ b/pilot/scene/chat_execution/example.py @@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector ## Two examples are defined by default EXAMPLES = [ - [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}], - [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}] + [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}], + [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}], ] plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True) diff --git a/pilot/scene/message.py b/pilot/scene/message.py index 972331bbb..51ec2643e 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -98,9 +98,10 @@ class OnceConversation: system_convs.append(message) return system_convs + def _conversation_to_dic(once: OnceConversation) -> dict: start_str: str = "" - if hasattr(once, 'start_date') and once.start_date: + if hasattr(once, "start_date") and once.start_date: if isinstance(once.start_date, datetime): start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S") else: diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 6c22105cd..c4f9ad87e 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -23,9 +23,12 @@ from fastapi import FastAPI, applications from fastapi.openapi.docs import get_swagger_ui_html from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware +from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router + from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler + static_file_path = os.path.join(os.getcwd(), "server/static") CFG = Config() @@ -34,9 +37,10 @@ logger = build_logger("webserver", LOGDIR + "webserver.log") def swagger_monkey_patch(*args, **kwargs): return get_swagger_ui_html( - *args, **kwargs, - swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js', - swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css' + *args, + **kwargs, + swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js", + swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css" ) @@ -55,14 +59,16 @@ app.add_middleware( ) app.mount("/static", StaticFiles(directory=static_file_path), name="static") -app.add_route("/test", "static/test.html") - +app.add_route("/test", "static/test.html") +app.include_router(knowledge_router) app.include_router(api_v1) app.add_exception_handler(RequestValidationError, validation_exception_handler) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) + parser.add_argument( + "--model_list_mode", type=str, default="once", choices=["once", "reload"] + ) # old version server config parser.add_argument("--host", type=str, default="0.0.0.0") @@ -75,4 +81,5 @@ if __name__ == "__main__": server_init(args) CFG.NEW_SERVER_MODE = True import uvicorn + uvicorn.run(app, host="0.0.0.0", port=5000) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 67e6183b2..d87540a8e 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -9,7 +9,8 @@ import sys import uvicorn from fastapi import BackgroundTasks, FastAPI, Request from fastapi.responses import StreamingResponse -from fastapi.middleware.cors import CORSMiddleware + +# from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel global_counter = 0 @@ -41,11 +42,11 @@ class ModelWorker: if not isinstance(self.model, str): if hasattr(self.model, "config") and hasattr( - self.model.config, "max_sequence_length" + self.model.config, "max_sequence_length" ): self.context_len = self.model.config.max_sequence_length elif hasattr(self.model, "config") and hasattr( - self.model.config, "max_position_embeddings" + self.model.config, "max_position_embeddings" ): self.context_len = self.model.config.max_position_embeddings @@ -60,22 +61,22 @@ class ModelWorker: def get_queue_length(self): if ( - model_semaphore is None - or model_semaphore._value is None - or model_semaphore._waiters is None + model_semaphore is None + or model_semaphore._value is None + or model_semaphore._waiters is None ): return 0 else: ( - CFG.LIMIT_MODEL_CONCURRENCY - - model_semaphore._value - + len(model_semaphore._waiters) + CFG.LIMIT_MODEL_CONCURRENCY + - model_semaphore._value + + len(model_semaphore._waiters) ) def generate_stream_gate(self, params): try: for output in self.generate_stream_func( - self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS + self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): # Please do not open the output in production! # The gpt4all thread shares stdout with the parent process, @@ -107,23 +108,23 @@ worker = ModelWorker( ) app = FastAPI() -from pilot.openapi.knowledge.knowledge_controller import router - -app.include_router(router) - -origins = [ - "http://localhost", - "http://localhost:8000", - "http://localhost:3000", -] - -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +# from pilot.openapi.knowledge.knowledge_controller import router +# +# app.include_router(router) +# +# origins = [ +# "http://localhost", +# "http://localhost:8000", +# "http://localhost:3000", +# ] +# +# app.add_middleware( +# CORSMiddleware, +# allow_origins=origins, +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) class PromptRequest(BaseModel): diff --git a/pilot/server/webserver_base.py b/pilot/server/webserver_base.py index c76984c37..33486c439 100644 --- a/pilot/server/webserver_base.py +++ b/pilot/server/webserver_base.py @@ -40,6 +40,7 @@ def server_init(args): cfg = Config() from pilot.server.llmserver import worker + worker.start_check() load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler) From 295b759ba42b5527ccb1fd04c5819b6c5e7b677a Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 15:27:48 +0800 Subject: [PATCH 04/12] doc:add knowledge backend sql script prepare knowledge backend sql script --- assets/schema/knowledge_management.sql | 40 ++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 assets/schema/knowledge_management.sql diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql new file mode 100644 index 000000000..df0b09982 --- /dev/null +++ b/assets/schema/knowledge_management.sql @@ -0,0 +1,40 @@ +CREATE TABLE `knowledge_space` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `name` varchar(100) NOT NULL COMMENT 'knowledge space name', + `vector_type` varchar(50) NOT NULL COMMENT 'vector type', + `desc` varchar(500) NOT NULL COMMENT 'description', + `owner` varchar(100) DEFAULT NULL COMMENT 'owner', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_name` (`name`) COMMENT 'index:idx_name' +) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge space table'; + +CREATE TABLE `knowledge_document` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `space` varchar(50) NOT NULL COMMENT 'knowledge space', + `chunk_size` int NOT NULL COMMENT 'chunk size', + `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', + `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', + `content` LONGTEXT NOT NULL COMMENT 'knowledge content', + `vector_ids` LONGTEXT NOT NULL COMMENT 'vector_ids', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name' +) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document table'; + +CREATE TABLE `document_chunk` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `document_id` int NOT NULL COMMENT 'document parent id', + `content` longtext NOT NULL COMMENT 'chunk content', + `meta_info` varchar(200) NOT NULL COMMENT 'metadata info', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id' +) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document chunk detail' \ No newline at end of file From 5951215c0b0cc3b75daaa769435a1ee9138f58bb Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 16:07:39 +0800 Subject: [PATCH 05/12] fix:port define can re-define port --- pilot/server/dbgpt_server.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index c4f9ad87e..bdfb48737 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -1,3 +1,4 @@ +import signal import traceback import os import shutil @@ -9,10 +10,7 @@ sys.path.append(ROOT_PATH) from pilot.configs.config import Config from pilot.configs.model_config import ( - DATASETS_DIR, - KNOWLEDGE_UPLOAD_ROOT_PATH, - LLM_MODEL_CONFIG, - LOGDIR, + LOGDIR ) from pilot.utils import build_logger @@ -35,6 +33,11 @@ CFG = Config() logger = build_logger("webserver", LOGDIR + "webserver.log") +def signal_handler(sig, frame): + print("in order to avoid chroma db atexit problem") + os._exit(0) + + def swagger_monkey_patch(*args, **kwargs): return get_swagger_ui_html( *args, @@ -72,9 +75,10 @@ if __name__ == "__main__": # old version server config parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) + parser.add_argument("--port", type=int, default=5000) parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument("--share", default=False, action="store_true") + signal.signal(signal.SIGINT, signal_handler) # init server config args = parser.parse_args() @@ -82,4 +86,4 @@ if __name__ == "__main__": CFG.NEW_SERVER_MODE = True import uvicorn - uvicorn.run(app, host="0.0.0.0", port=5000) + uvicorn.run(app, host="0.0.0.0", port=args.port) From b2f1b1319f7a94f95a894cfb32b334ff4989ac6a Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 16:34:12 +0800 Subject: [PATCH 06/12] fix:update knowledge schema sql update knowledge schema sql --- assets/schema/knowledge_management.sql | 5 +++-- pilot/openapi/knowledge/knowledge_document_dao.py | 10 +++++----- pilot/openapi/knowledge/knowledge_service.py | 1 + 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql index df0b09982..2c50e42b4 100644 --- a/assets/schema/knowledge_management.sql +++ b/assets/schema/knowledge_management.sql @@ -18,8 +18,9 @@ CREATE TABLE `knowledge_document` ( `chunk_size` int NOT NULL COMMENT 'chunk size', `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', - `content` LONGTEXT NOT NULL COMMENT 'knowledge content', - `vector_ids` LONGTEXT NOT NULL COMMENT 'vector_ids', + `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', + `result` TEXT NULL COMMENT 'knowledge content', + `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', PRIMARY KEY (`id`), diff --git a/pilot/openapi/knowledge/knowledge_document_dao.py b/pilot/openapi/knowledge/knowledge_document_dao.py index cad881e71..f99b81a72 100644 --- a/pilot/openapi/knowledge/knowledge_document_dao.py +++ b/pilot/openapi/knowledge/knowledge_document_dao.py @@ -118,11 +118,11 @@ class KnowledgeDocumentDao: count = knowledge_documents.scalar() return count - # def update_knowledge_document(self, document: KnowledgeDocumentEntity): - # session = self.Session() - # updated_space = session.merge(document) - # session.commit() - # return updated_space.id + def update_knowledge_document(self, document: KnowledgeDocumentEntity): + session = self.Session() + updated_space = session.merge(document) + session.commit() + return updated_space.id # # def delete_knowledge_document(self, document_id: int): # cursor = self.conn.cursor() diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index 2f035fcb5..4ad9dcec3 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -76,6 +76,7 @@ class KnowledgeService: status=SyncStatus.TODO.name, last_sync=datetime.now(), content=request.content, + result="", ) knowledge_document_dao.create_knowledge_document(document) return True From 92c5e209cf2c07f4f599c24c7bd551820532aced Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 17:06:15 +0800 Subject: [PATCH 07/12] fix:solve upload file bug solve upload file bug --- pilot/openapi/knowledge/knowledge_controller.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pilot/openapi/knowledge/knowledge_controller.py b/pilot/openapi/knowledge/knowledge_controller.py index aec612e9c..26323e098 100644 --- a/pilot/openapi/knowledge/knowledge_controller.py +++ b/pilot/openapi/knowledge/knowledge_controller.py @@ -85,6 +85,8 @@ async def document_upload( print(f"/document/upload params: {space_name}") try: if doc_file: + if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)): + os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)) with NamedTemporaryFile( dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False ) as tmp: From 7ad629eb3d5070d6c7d83020bac0ccc7c3bda3f9 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 17:26:05 +0800 Subject: [PATCH 08/12] feat:chunk split method replace chunk split method replace --- pilot/embedding_engine/markdown_embedding.py | 22 +++++++++++++------ pilot/embedding_engine/pdf_embedding.py | 23 ++++++++++++++------ pilot/embedding_engine/ppt_embedding.py | 23 ++++++++++++++------ pilot/embedding_engine/word_embedding.py | 14 ++++++++---- 4 files changed, 57 insertions(+), 25 deletions(-) diff --git a/pilot/embedding_engine/markdown_embedding.py b/pilot/embedding_engine/markdown_embedding.py index e9a97dce9..2bbd20878 100644 --- a/pilot/embedding_engine/markdown_embedding.py +++ b/pilot/embedding_engine/markdown_embedding.py @@ -6,7 +6,7 @@ from typing import List import markdown from bs4 import BeautifulSoup from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter +from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register @@ -30,12 +30,20 @@ class MarkdownEmbedding(SourceEmbedding): def read(self): """Load from markdown path.""" loader = EncodeTextLoader(self.file_path) - textsplitter = SpacyTextSplitter( - pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=100, - ) - return loader.load_and_split(textsplitter) + # text_splitter = SpacyTextSplitter( + # pipeline="zh_core_web_sm", + # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + # chunk_overlap=100, + # ) + if CFG.LANGUAGE == "en": + text_splitter = CharacterTextSplitter( + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=20, + length_function=len, + ) + else: + text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000) + return loader.load_and_split(text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/pdf_embedding.py b/pilot/embedding_engine/pdf_embedding.py index ea4276460..a51eccbda 100644 --- a/pilot/embedding_engine/pdf_embedding.py +++ b/pilot/embedding_engine/pdf_embedding.py @@ -4,10 +4,11 @@ from typing import List from langchain.document_loaders import PyPDFLoader from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter +from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register +from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() @@ -28,12 +29,20 @@ class PDFEmbedding(SourceEmbedding): # textsplitter = CHNDocumentSplitter( # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE # ) - textsplitter = SpacyTextSplitter( - pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=100, - ) - return loader.load_and_split(textsplitter) + # textsplitter = SpacyTextSplitter( + # pipeline="zh_core_web_sm", + # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + # chunk_overlap=100, + # ) + if CFG.LANGUAGE == "en": + text_splitter = CharacterTextSplitter( + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=20, + length_function=len, + ) + else: + text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000) + return loader.load_and_split(text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/ppt_embedding.py b/pilot/embedding_engine/ppt_embedding.py index 485083d1c..4ff06c6b7 100644 --- a/pilot/embedding_engine/ppt_embedding.py +++ b/pilot/embedding_engine/ppt_embedding.py @@ -4,10 +4,11 @@ from typing import List from langchain.document_loaders import UnstructuredPowerPointLoader from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter +from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register +from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() @@ -25,12 +26,20 @@ class PPTEmbedding(SourceEmbedding): def read(self): """Load from ppt path.""" loader = UnstructuredPowerPointLoader(self.file_path) - textsplitter = SpacyTextSplitter( - pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=200, - ) - return loader.load_and_split(textsplitter) + # textsplitter = SpacyTextSplitter( + # pipeline="zh_core_web_sm", + # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + # chunk_overlap=200, + # ) + if CFG.LANGUAGE == "en": + text_splitter = CharacterTextSplitter( + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=20, + length_function=len, + ) + else: + text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000) + return loader.load_and_split(text_splitter) @register def data_process(self, documents: List[Document]): diff --git a/pilot/embedding_engine/word_embedding.py b/pilot/embedding_engine/word_embedding.py index 34fc48450..9668700a1 100644 --- a/pilot/embedding_engine/word_embedding.py +++ b/pilot/embedding_engine/word_embedding.py @@ -4,6 +4,7 @@ from typing import List from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader from langchain.schema import Document +from langchain.text_splitter import CharacterTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register @@ -25,10 +26,15 @@ class WordEmbedding(SourceEmbedding): def read(self): """Load from word path.""" loader = UnstructuredWordDocumentLoader(self.file_path) - textsplitter = CHNDocumentSplitter( - pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE - ) - return loader.load_and_split(textsplitter) + if CFG.LANGUAGE == "en": + text_splitter = CharacterTextSplitter( + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=20, + length_function=len, + ) + else: + text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000) + return loader.load_and_split(text_splitter) @register def data_process(self, documents: List[Document]): From f29a0a247252e4114ba3a9a7d4f8578bee94edac Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 17:46:53 +0800 Subject: [PATCH 09/12] fix:sync status idempotent when RUNNING AND FINISHED cannot sync --- pilot/openapi/knowledge/knowledge_service.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index 4ad9dcec3..8f0f184d6 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -115,6 +115,8 @@ class KnowledgeService: space=space_name, ) doc = knowledge_document_dao.get_knowledge_documents(query)[0] + if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name: + raise Exception(f"RUNNING AND FINISHED doc:{doc.name} can not sync") client = KnowledgeEmbedding( knowledge_source=doc.content, knowledge_type=doc.doc_type.upper(), From 96f3dcf095205f143b74b7192179f3c8a893009e Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 18:00:55 +0800 Subject: [PATCH 10/12] fix:sync status idempotent when RUNNING AND FINISHED cannot sync --- pilot/openapi/knowledge/knowledge_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index 8f0f184d6..7456e8813 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -116,7 +116,7 @@ class KnowledgeService: ) doc = knowledge_document_dao.get_knowledge_documents(query)[0] if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name: - raise Exception(f"RUNNING AND FINISHED doc:{doc.name} can not sync") + raise Exception(f"RUNNING AND FINISHED doc:{doc.doc_name} can not sync") client = KnowledgeEmbedding( knowledge_source=doc.content, knowledge_type=doc.doc_type.upper(), From ff31d9a22b8aedfb2e17b6e750eb36659264340b Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 18:08:26 +0800 Subject: [PATCH 11/12] fix:sync status idempotent when RUNNING AND FINISHED cannot sync --- pilot/openapi/knowledge/knowledge_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index 7456e8813..49e6a0fa3 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -116,7 +116,7 @@ class KnowledgeService: ) doc = knowledge_document_dao.get_knowledge_documents(query)[0] if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name: - raise Exception(f"RUNNING AND FINISHED doc:{doc.doc_name} can not sync") + raise Exception(f" doc:{doc.doc_name} status is {doc.status}, can not sync") client = KnowledgeEmbedding( knowledge_source=doc.content, knowledge_type=doc.doc_type.upper(), From 24130a60973897d41cb8dedbb2f833c2fce4896f Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 29 Jun 2023 18:32:36 +0800 Subject: [PATCH 12/12] fix:use spacy replace chunk method use spacy replace chunk method --- pilot/embedding_engine/markdown_embedding.py | 13 ++++++------- pilot/embedding_engine/pdf_embedding.py | 7 +++++-- pilot/embedding_engine/ppt_embedding.py | 6 +++++- pilot/embedding_engine/word_embedding.py | 11 +++++++---- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/pilot/embedding_engine/markdown_embedding.py b/pilot/embedding_engine/markdown_embedding.py index 2bbd20878..0d70ba34f 100644 --- a/pilot/embedding_engine/markdown_embedding.py +++ b/pilot/embedding_engine/markdown_embedding.py @@ -11,7 +11,6 @@ from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader -from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() @@ -30,11 +29,7 @@ class MarkdownEmbedding(SourceEmbedding): def read(self): """Load from markdown path.""" loader = EncodeTextLoader(self.file_path) - # text_splitter = SpacyTextSplitter( - # pipeline="zh_core_web_sm", - # chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - # chunk_overlap=100, - # ) + if CFG.LANGUAGE == "en": text_splitter = CharacterTextSplitter( chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, @@ -42,7 +37,11 @@ class MarkdownEmbedding(SourceEmbedding): length_function=len, ) else: - text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000) + text_splitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=100, + ) return loader.load_and_split(text_splitter) @register diff --git a/pilot/embedding_engine/pdf_embedding.py b/pilot/embedding_engine/pdf_embedding.py index a51eccbda..2b8f244e3 100644 --- a/pilot/embedding_engine/pdf_embedding.py +++ b/pilot/embedding_engine/pdf_embedding.py @@ -8,7 +8,6 @@ from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register -from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() @@ -41,7 +40,11 @@ class PDFEmbedding(SourceEmbedding): length_function=len, ) else: - text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000) + text_splitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=100, + ) return loader.load_and_split(text_splitter) @register diff --git a/pilot/embedding_engine/ppt_embedding.py b/pilot/embedding_engine/ppt_embedding.py index 4ff06c6b7..da4390849 100644 --- a/pilot/embedding_engine/ppt_embedding.py +++ b/pilot/embedding_engine/ppt_embedding.py @@ -38,7 +38,11 @@ class PPTEmbedding(SourceEmbedding): length_function=len, ) else: - text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000) + text_splitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=100, + ) return loader.load_and_split(text_splitter) @register diff --git a/pilot/embedding_engine/word_embedding.py b/pilot/embedding_engine/word_embedding.py index 9668700a1..55988a240 100644 --- a/pilot/embedding_engine/word_embedding.py +++ b/pilot/embedding_engine/word_embedding.py @@ -2,13 +2,12 @@ # -*- coding: utf-8 -*- from typing import List -from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader +from langchain.document_loaders import UnstructuredWordDocumentLoader from langchain.schema import Document -from langchain.text_splitter import CharacterTextSplitter +from langchain.text_splitter import CharacterTextSplitter, SpacyTextSplitter from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register -from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() @@ -33,7 +32,11 @@ class WordEmbedding(SourceEmbedding): length_function=len, ) else: - text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000) + text_splitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, + chunk_overlap=100, + ) return loader.load_and_split(text_splitter) @register