From db28894443ebb6b7ff02fcabcc5abada84a68ef0 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Mon, 26 Jun 2023 15:24:25 +0800 Subject: [PATCH] feat: knowledge management backend api 1.create knowledge space 2.list knowledge space 3.create knowledge document 4.list knowledge document 5.save document chunks 6.sync embedding document --- pilot/embedding_engine/knowledge_embedding.py | 5 +- pilot/embedding_engine/source_embedding.py | 2 +- pilot/scene/chat_knowledge/url/chat.py | 2 +- pilot/server/api_v1/api_view_model.py | 12 +- pilot/server/knowledge/document_chunk_dao.py | 83 +++++++++ .../server/knowledge/knowledge_controller.py | 111 +++++++++++ .../knowledge/knowledge_document_dao.py | 87 +++++++++ pilot/server/knowledge/knowledge_service.py | 173 ++++++++++++++++++ pilot/server/knowledge/knowledge_space_dao.py | 82 +++++++++ .../knowledge/request/knowledge_request.py | 74 ++++++++ pilot/server/llmserver.py | 17 ++ pilot/vector_store/chroma_store.py | 7 +- pilot/vector_store/connector.py | 5 +- 13 files changed, 648 insertions(+), 12 deletions(-) create mode 100644 pilot/server/knowledge/document_chunk_dao.py create mode 100644 pilot/server/knowledge/knowledge_controller.py create mode 100644 pilot/server/knowledge/knowledge_document_dao.py create mode 100644 pilot/server/knowledge/knowledge_service.py create mode 100644 pilot/server/knowledge/knowledge_space_dao.py create mode 100644 pilot/server/knowledge/request/knowledge_request.py diff --git a/pilot/embedding_engine/knowledge_embedding.py b/pilot/embedding_engine/knowledge_embedding.py index 2d7780510..132b360fc 100644 --- a/pilot/embedding_engine/knowledge_embedding.py +++ b/pilot/embedding_engine/knowledge_embedding.py @@ -48,9 +48,10 @@ class KnowledgeEmbedding: def knowledge_embedding_batch(self, docs): # docs = self.knowledge_embedding_client.read_batch() - self.knowledge_embedding_client.index_to_store(docs) + return self.knowledge_embedding_client.index_to_store(docs) def read(self): + self.knowledge_embedding_client = self.init_knowledge_embedding() return self.knowledge_embedding_client.read_batch() def init_knowledge_embedding(self): @@ -66,7 +67,7 @@ class KnowledgeEmbedding: embedding = knowledge_class( self.file_path, vector_store_config=self.vector_store_config, - **knowledge_args, + **knowledge_args ) return embedding raise ValueError(f"Unsupported knowledge file type '{extension}'") diff --git a/pilot/embedding_engine/source_embedding.py b/pilot/embedding_engine/source_embedding.py index 3d881fcdf..b99529cf9 100644 --- a/pilot/embedding_engine/source_embedding.py +++ b/pilot/embedding_engine/source_embedding.py @@ -59,7 +59,7 @@ class SourceEmbedding(ABC): @register def index_to_store(self, docs): """index to vector store""" - self.vector_client.load_document(docs) + return self.vector_client.load_document(docs) @register def similar_search(self, doc, topk): diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index 57fb8b618..acf7bbaeb 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -42,7 +42,7 @@ class ChatUrlKnowledge(BaseChat): "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = KnowledgeEmbedding( - model_name=LLM_MODEL_CONFIG["text2vec"], + model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config=vector_store_config, file_type="url", file_path=url, diff --git a/pilot/server/api_v1/api_view_model.py b/pilot/server/api_v1/api_view_model.py index 938ce22ec..e58e581a9 100644 --- a/pilot/server/api_v1/api_view_model.py +++ b/pilot/server/api_v1/api_view_model.py @@ -6,21 +6,21 @@ T = TypeVar('T') class Result(Generic[T], BaseModel): success: bool - err_code: str - err_msg: str - data: List[T] + err_code: str = None + err_msg: str = None + data: List[T] = None @classmethod def succ(cls, data: List[T]): - return Result(True, None, None, data) + return Result(success=True, err_code=None, err_msg=None, data=data) @classmethod def faild(cls, msg): - return Result(True, "E000X", msg, None) + return Result(success=False, err_code="E000X", err_msg=msg, data=None) @classmethod def faild(cls, code, msg): - return Result(True, code, msg, None) + return Result(success=False, err_code=code, err_msg=msg, data=None) class ConversationVo(BaseModel): diff --git a/pilot/server/knowledge/document_chunk_dao.py b/pilot/server/knowledge/document_chunk_dao.py new file mode 100644 index 000000000..e9a994e66 --- /dev/null +++ b/pilot/server/knowledge/document_chunk_dao.py @@ -0,0 +1,83 @@ +from datetime import datetime +from typing import List + +from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine +from sqlalchemy.orm import declarative_base, sessionmaker + +from pilot.configs.config import Config + + +CFG = Config() + +Base = declarative_base() +class DocumentChunkEntity(Base): + __tablename__ = 'document_chunk' + id = Column(Integer, primary_key=True) + document_id = Column(Integer) + doc_name = Column(String(100)) + doc_type = Column(String(100)) + content = Column(Text) + meta_info = Column(String(500)) + gmt_created = Column(DateTime) + gmt_modified = Column(DateTime) + + def __repr__(self): + return f"DocumentChunkEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', document_id='{self.document_id}', content='{self.content}', meta_info='{self.meta_info}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + + +class DocumentChunkDao: + def __init__(self): + database = "knowledge_management" + self.db_engine = create_engine( + f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', + echo=True) + self.Session = sessionmaker(bind=self.db_engine) + + def create_documents_chunks(self, documents:List): + session = self.Session() + docs = [ + DocumentChunkEntity( + doc_name=document.doc_name, + doc_type=document.doc_type, + document_id=document.document_id, + content=document.content or "", + meta_info=document.meta_info or "", + gmt_created=datetime.now(), + gmt_modified=datetime.now() + ) + for document in documents] + session.add_all(docs) + session.commit() + session.close() + + def get_document_chunks(self, query:DocumentChunkEntity, page=1, page_size=20): + session = self.Session() + document_chunks = session.query(DocumentChunkEntity) + if query.id is not None: + document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) + if query.document_id is not None: + document_chunks = document_chunks.filter(DocumentChunkEntity.document_id == query.document_id) + if query.doc_type is not None: + document_chunks = document_chunks.filter(DocumentChunkEntity.doc_type == query.doc_type) + if query.doc_name is not None: + document_chunks = document_chunks.filter(DocumentChunkEntity.doc_name == query.doc_name) + if query.meta_info is not None: + document_chunks = document_chunks.filter(DocumentChunkEntity.meta_info == query.meta_info) + + document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc()) + document_chunks = document_chunks.offset((page - 1) * page_size).limit(page_size) + result = document_chunks.all() + return result + + # def update_knowledge_document(self, document:KnowledgeDocumentEntity): + # session = self.Session() + # updated_space = session.merge(document) + # session.commit() + # return updated_space.id + + # def delete_knowledge_document(self, document_id:int): + # cursor = self.conn.cursor() + # query = "DELETE FROM knowledge_document WHERE id = %s" + # cursor.execute(query, (document_id,)) + # self.conn.commit() + # cursor.close() diff --git a/pilot/server/knowledge/knowledge_controller.py b/pilot/server/knowledge/knowledge_controller.py new file mode 100644 index 000000000..123e9238b --- /dev/null +++ b/pilot/server/knowledge/knowledge_controller.py @@ -0,0 +1,111 @@ +import json +import os +import sys +from typing import List + +from fastapi import APIRouter +from langchain.embeddings import HuggingFaceEmbeddings +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) + +from pilot.configs.config import Config +from pilot.configs.model_config import LLM_MODEL_CONFIG + +from pilot.server.api_v1.api_view_model import Result +from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding + +from pilot.server.knowledge.knowledge_service import KnowledgeService +from pilot.server.knowledge.request.knowledge_request import ( + KnowledgeQueryRequest, + KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest, +) + +from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest + +CFG = Config() +router = APIRouter() + + +embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]) + +knowledge_space_service = KnowledgeService() + + +@router.post("/knowledge/space/add") +def space_add(request: KnowledgeSpaceRequest): + print(f"/space/add params: {request}") + try: + knowledge_space_service.create_knowledge_space(request) + return Result.succ([]) + except Exception as e: + return Result.faild(code="E000X", msg=f"space add error {e}") + + +@router.post("/knowledge/space/list") +def space_list(request: KnowledgeSpaceRequest): + print(f"/space/list params:") + try: + return Result.succ(knowledge_space_service.get_knowledge_space(request)) + except Exception as e: + return Result.faild(code="E000X", msg=f"space list error {e}") + + +@router.post("/knowledge/{space_name}/document/add") +def document_add(space_name: str, request: KnowledgeDocumentRequest): + print(f"/document/add params: {space_name}, {request}") + try: + knowledge_space_service.create_knowledge_document( + space=space_name, request=request + ) + return Result.succ([]) + except Exception as e: + return Result.faild(code="E000X", msg=f"document add error {e}") + + +@router.post("/knowledge/{space_name}/document/list") +def document_list(space_name: str, query_request: DocumentQueryRequest): + print(f"/document/list params: {space_name}, {query_request}") + try: + return Result.succ(knowledge_space_service.get_knowledge_documents( + space_name, + query_request + )) + except Exception as e: + return Result.faild(code="E000X", msg=f"document list error {e}") + + +@router.post("/knowledge/{space_name}/document/sync") +def document_sync(space_name: str, request: DocumentSyncRequest): + print(f"Received params: {space_name}, {request}") + try: + knowledge_space_service.sync_knowledge_document( + space_name=space_name, doc_ids=request.doc_ids + ) + Result.succ([]) + except Exception as e: + return Result.faild(code="E000X", msg=f"document sync error {e}") + + +@router.post("/knowledge/{space_name}/chunk/list") +def document_list(space_name: str, query_request: ChunkQueryRequest): + print(f"/document/list params: {space_name}, {query_request}") + try: + Result.succ(knowledge_space_service.get_document_chunks( + query_request + )) + except Exception as e: + return Result.faild(code="E000X", msg=f"document chunk list error {e}") + + +@router.post("/knowledge/{vector_name}/query") +def similar_query(space_name: str, query_request: KnowledgeQueryRequest): + print(f"Received params: {space_name}, {query_request}") + client = KnowledgeEmbedding( + model_name=embeddings, vector_store_config={"vector_store_name": space_name} + ) + docs = client.similar_search(query_request.query, query_request.top_k) + res = [ + KnowledgeQueryResponse(text=d.page_content, source=d.metadata["source"]) + for d in docs + ] + return {"response": res} diff --git a/pilot/server/knowledge/knowledge_document_dao.py b/pilot/server/knowledge/knowledge_document_dao.py new file mode 100644 index 000000000..497ad6b21 --- /dev/null +++ b/pilot/server/knowledge/knowledge_document_dao.py @@ -0,0 +1,87 @@ +from datetime import datetime + +from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine +from sqlalchemy.orm import declarative_base, sessionmaker + +from pilot.configs.config import Config + + +CFG = Config() + +Base = declarative_base() +class KnowledgeDocumentEntity(Base): + __tablename__ = 'knowledge_document' + id = Column(Integer, primary_key=True) + doc_name = Column(String(100)) + doc_type = Column(String(100)) + space = Column(String(100)) + chunk_size = Column(Integer) + status = Column(String(100)) + last_sync = Column(String(100)) + content = Column(Text) + vector_ids = Column(Text) + gmt_created = Column(DateTime) + gmt_modified = Column(DateTime) + + def __repr__(self): + return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + + +class KnowledgeDocumentDao: + def __init__(self): + database = "knowledge_management" + self.db_engine = create_engine( + f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', + echo=True) + self.Session = sessionmaker(bind=self.db_engine) + + def create_knowledge_document(self, document:KnowledgeDocumentEntity): + session = self.Session() + knowledge_document = KnowledgeDocumentEntity( + doc_name=document.doc_name, + doc_type=document.doc_type, + space=document.space, + chunk_size=0.0, + status=document.status, + last_sync=document.last_sync, + content=document.content or "", + vector_ids=document.vector_ids, + gmt_created=datetime.now(), + gmt_modified=datetime.now() + ) + session.add(knowledge_document) + session.commit() + + session.close() + + def get_knowledge_documents(self, query, page=1, page_size=20): + session = self.Session() + knowledge_documents = session.query(KnowledgeDocumentEntity) + if query.id is not None: + knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id) + if query.doc_name is not None: + knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name) + if query.doc_type is not None: + knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type) + if query.space is not None: + knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space) + if query.status is not None: + knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status) + + knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc()) + knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size) + result = knowledge_documents.all() + return result + + def update_knowledge_document(self, document:KnowledgeDocumentEntity): + session = self.Session() + updated_space = session.merge(document) + session.commit() + return updated_space.id + + def delete_knowledge_document(self, document_id:int): + cursor = self.conn.cursor() + query = "DELETE FROM knowledge_document WHERE id = %s" + cursor.execute(query, (document_id,)) + self.conn.commit() + cursor.close() diff --git a/pilot/server/knowledge/knowledge_service.py b/pilot/server/knowledge/knowledge_service.py new file mode 100644 index 000000000..7372276bd --- /dev/null +++ b/pilot/server/knowledge/knowledge_service.py @@ -0,0 +1,173 @@ +import threading +from datetime import datetime + +from pilot.configs.config import Config +from pilot.configs.model_config import LLM_MODEL_CONFIG +from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.logs import logger +from pilot.server.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao +from pilot.server.knowledge.knowledge_document_dao import ( + KnowledgeDocumentDao, + KnowledgeDocumentEntity, +) +from pilot.server.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity +from pilot.server.knowledge.request.knowledge_request import ( + KnowledgeSpaceRequest, + KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest, +) +from enum import Enum + + +knowledge_space_dao = KnowledgeSpaceDao() +knowledge_document_dao = KnowledgeDocumentDao() +document_chunk_dao = DocumentChunkDao() + +CFG=Config() + +class SyncStatus(Enum): + TODO = "TODO" + FAILED = "FAILED" + RUNNING = "RUNNING" + FINISHED = "FINISHED" + + +# @singleton +class KnowledgeService: + def __init__(self): + pass + + """create knowledge space""" + + def create_knowledge_space(self, request: KnowledgeSpaceRequest): + query = KnowledgeSpaceEntity( + name=request.name, + ) + spaces = knowledge_space_dao.get_knowledge_space(query) + if len(spaces) > 0: + raise Exception(f"space name:{request.name} have already named") + knowledge_space_dao.create_knowledge_space(request) + return True + + """create knowledge document""" + + def create_knowledge_document(self, space, request: KnowledgeDocumentRequest): + query = KnowledgeDocumentEntity( + doc_name=request.doc_name, + space=space + ) + documents = knowledge_document_dao.get_knowledge_documents(query) + if len(documents) > 0: + raise Exception(f"document name:{request.doc_name} have already named") + document = KnowledgeDocumentEntity( + doc_name=request.doc_name, + doc_type=request.doc_type, + space=space, + chunk_size=0, + status=SyncStatus.TODO.name, + last_sync=datetime.now(), + content="", + ) + knowledge_document_dao.create_knowledge_document(document) + return True + + """get knowledge space""" + + def get_knowledge_space(self, request:KnowledgeSpaceRequest): + query = KnowledgeSpaceEntity( + name=request.name, + vector_type=request.vector_type, + owner=request.owner + ) + return knowledge_space_dao.get_knowledge_space(query) + + """get knowledge get_knowledge_documents""" + + def get_knowledge_documents(self, space, request:DocumentQueryRequest): + query = KnowledgeDocumentEntity( + doc_name=request.doc_name, + doc_type=request.doc_type, + space=space, + status=request.status, + ) + return knowledge_document_dao.get_knowledge_documents(query, page=request.page, page_size=request.page_size) + + """sync knowledge document chunk into vector store""" + def sync_knowledge_document(self, space_name, doc_ids): + for doc_id in doc_ids: + query = KnowledgeDocumentEntity( + id=doc_id, + space=space_name, + ) + doc = knowledge_document_dao.get_knowledge_documents(query)[0] + client = KnowledgeEmbedding(file_path=doc.doc_name, + file_type="url", + model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + vector_store_config={ + "vector_store_name": space_name, + }) + chunk_docs = client.read() + # update document status + doc.status = SyncStatus.RUNNING.name + doc.chunk_size = len(chunk_docs) + doc.gmt_modified = datetime.now() + knowledge_document_dao.update_knowledge_document(doc) + # async doc embeddings + thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc)) + thread.start() + #save chunk details + chunk_entities = [ + DocumentChunkEntity( + doc_name=doc.doc_name, + doc_type=doc.doc_type, + document_id=doc.id, + content=chunk_doc.page_content, + meta_info=str(chunk_doc.metadata), + gmt_created=datetime.now(), + gmt_modified=datetime.now() + ) + for chunk_doc in chunk_docs] + document_chunk_dao.create_documents_chunks(chunk_entities) + #update document status + # doc.status = SyncStatus.RUNNING.name + # doc.chunk_size = len(chunk_docs) + # doc.gmt_modified = datetime.now() + # knowledge_document_dao.update_knowledge_document(doc) + + return True + + """update knowledge space""" + + def update_knowledge_space( + self, space_id: int, space_request: KnowledgeSpaceRequest + ): + knowledge_space_dao.update_knowledge_space(space_id, space_request) + + """delete knowledge space""" + + def delete_knowledge_space(self, space_id: int): + return knowledge_space_dao.delete_knowledge_space(space_id) + + """get document chunks""" + def get_document_chunks(self, request:ChunkQueryRequest): + query = DocumentChunkEntity( + id=request.id, + document_id=request.document_id, + doc_name=request.doc_name, + doc_type=request.doc_type + ) + return document_chunk_dao.get_document_chunks(query, page=request.page, page_size=request.page_size) + + def async_doc_embedding(self, client, chunk_docs, doc): + logger.info(f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}") + try: + vector_ids = client.knowledge_embedding_batch(chunk_docs) + doc.status = SyncStatus.FINISHED.name + doc.content = "embedding success" + doc.vector_ids = ",".join(vector_ids) + except Exception as e: + doc.status = SyncStatus.FAILED.name + doc.content = str(e) + + return knowledge_document_dao.update_knowledge_document(doc) + + diff --git a/pilot/server/knowledge/knowledge_space_dao.py b/pilot/server/knowledge/knowledge_space_dao.py new file mode 100644 index 000000000..20bb578a8 --- /dev/null +++ b/pilot/server/knowledge/knowledge_space_dao.py @@ -0,0 +1,82 @@ +from datetime import datetime + +from sqlalchemy import Column, Integer, String, DateTime, create_engine +from sqlalchemy.ext.declarative import declarative_base + +from pilot.configs.config import Config + +from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest +from sqlalchemy.orm import sessionmaker + +CFG = Config() +Base = declarative_base() +class KnowledgeSpaceEntity(Base): + __tablename__ = 'knowledge_space' + id = Column(Integer, primary_key=True) + name = Column(String(100)) + vector_type = Column(String(100)) + desc = Column(String(100)) + owner = Column(String(100)) + gmt_created = Column(DateTime) + gmt_modified = Column(DateTime) + + def __repr__(self): + return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + + +class KnowledgeSpaceDao: + def __init__(self): + database = "knowledge_management" + self.db_engine = create_engine(f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', echo=True) + self.Session = sessionmaker(bind=self.db_engine) + + def create_knowledge_space(self, space:KnowledgeSpaceRequest): + session = self.Session() + knowledge_space = KnowledgeSpaceEntity( + name=space.name, + vector_type=space.vector_type, + desc=space.desc, + owner=space.owner, + gmt_created=datetime.now(), + gmt_modified=datetime.now() + ) + session.add(knowledge_space) + session.commit() + + session.close() + + def get_knowledge_space(self, query:KnowledgeSpaceEntity): + session = self.Session() + knowledge_spaces = session.query(KnowledgeSpaceEntity) + if query.id is not None: + knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.id == query.id) + if query.name is not None: + knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.name == query.name) + if query.vector_type is not None: + knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.vector_type == query.vector_type) + if query.desc is not None: + knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.desc == query.desc) + if query.owner is not None: + knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.owner == query.owner) + if query.gmt_created is not None: + knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_created == query.gmt_created) + if query.gmt_modified is not None: + knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_modified == query.gmt_modified) + + knowledge_spaces = knowledge_spaces.order_by(KnowledgeSpaceEntity.gmt_created.desc()) + result = knowledge_spaces.all() + return result + + def update_knowledge_space(self, space_id:int, space:KnowledgeSpaceEntity): + cursor = self.conn.cursor() + query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s" + cursor.execute(query, (space.name, space.vector_type, space.desc, space.owner, space_id)) + self.conn.commit() + cursor.close() + + def delete_knowledge_space(self, space_id:int): + cursor = self.conn.cursor() + query = "DELETE FROM knowledge_space WHERE id = %s" + cursor.execute(query, (space_id,)) + self.conn.commit() + cursor.close() diff --git a/pilot/server/knowledge/request/knowledge_request.py b/pilot/server/knowledge/request/knowledge_request.py new file mode 100644 index 000000000..039d84727 --- /dev/null +++ b/pilot/server/knowledge/request/knowledge_request.py @@ -0,0 +1,74 @@ +from typing import List + +from pydantic import BaseModel + + +class KnowledgeQueryRequest(BaseModel): + """query: knowledge query""" + + query: str + """top_k: return topK documents""" + top_k: int + + +class KnowledgeSpaceRequest(BaseModel): + """name: knowledge space name""" + + name: str = None + """vector_type: vector type""" + vector_type: str = None + """desc: description""" + desc: str = None + """owner: owner""" + owner: str = None + + +class KnowledgeDocumentRequest(BaseModel): + """doc_name: doc path""" + + doc_name: str + """doc_type: doc type""" + doc_type: str + """text_chunk_size: text_chunk_size""" + # text_chunk_size: int + +class DocumentQueryRequest(BaseModel): + """doc_name: doc path""" + doc_name: str = None + """doc_type: doc type""" + doc_type: str= None + """status: status""" + status: str= None + """page: page""" + page: int = 1 + """page_size: page size""" + page_size: int = 20 + + +class DocumentSyncRequest(BaseModel): + """doc_ids: doc ids""" + doc_ids: List + +class ChunkQueryRequest(BaseModel): + """id: id""" + id: int = None + """document_id: doc id""" + document_id: int = None + """doc_name: doc path""" + doc_name: str = None + """doc_type: doc type""" + doc_type: str = None + """page: page""" + page: int = 1 + """page_size: page size""" + page_size: int = 20 + + +class KnowledgeQueryResponse: + """source: knowledge reference source""" + + source: str + """score: knowledge vector query similarity score""" + score: float = 0.0 + """text: raw text info""" + text: str diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 1e3a4dcb3..1e6278e84 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -9,6 +9,7 @@ import sys import uvicorn from fastapi import BackgroundTasks, FastAPI, Request from fastapi.responses import StreamingResponse +from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel global_counter = 0 @@ -22,10 +23,12 @@ from pilot.configs.model_config import * from pilot.model.llm_out.vicuna_base_llm import get_embeddings from pilot.model.loader import ModelLoader from pilot.server.chat_adapter import get_llm_chat_adapter +from knowledge.knowledge_controller import router CFG = Config() + class ModelWorker: def __init__(self, model_path, model_name, device, num_gpus=1): if model_path.endswith("/"): @@ -103,7 +106,21 @@ worker = ModelWorker( ) app = FastAPI() +app.include_router(router) +origins = [ + "http://localhost", + "http://localhost:8000", + "http://localhost:3000", +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) class PromptRequest(BaseModel): prompt: str diff --git a/pilot/vector_store/chroma_store.py b/pilot/vector_store/chroma_store.py index 4949924d4..35016aa09 100644 --- a/pilot/vector_store/chroma_store.py +++ b/pilot/vector_store/chroma_store.py @@ -32,5 +32,10 @@ class ChromaStore(VectorStoreBase): logger.info("ChromaStore load document") texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] - self.vector_store_client.add_texts(texts=texts, metadatas=metadatas) + ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas) self.vector_store_client.persist() + return ids + + def delete_by_ids(self, ids): + collection = self.vector_store_client._collection + collection.delete(ids=ids) diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index 482f43007..8eecb74e0 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -16,7 +16,7 @@ class VectorStoreConnector: def load_document(self, docs): """load document in vector database.""" - self.client.load_document(docs) + return self.client.load_document(docs) def similar_search(self, docs, topk): """similar search in vector database.""" @@ -25,3 +25,6 @@ class VectorStoreConnector: def vector_name_exists(self): """is vector store name exist.""" return self.client.vector_name_exists() + + def delete_by_ids(self, ids): + self.client.delete_by_ids(ids=ids)