diff --git a/pilot/openapi/knowledge/__init__.py b/pilot/openapi/knowledge/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/openapi/knowledge/document_chunk_dao.py b/pilot/openapi/knowledge/document_chunk_dao.py index cb728e85c..d67f4a01a 100644 --- a/pilot/openapi/knowledge/document_chunk_dao.py +++ b/pilot/openapi/knowledge/document_chunk_dao.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List -from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine +from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func from sqlalchemy.orm import declarative_base, sessionmaker from pilot.configs.config import Config @@ -83,6 +83,30 @@ class DocumentChunkDao: result = document_chunks.all() return result + def get_document_chunks_count(self, query: DocumentChunkEntity): + session = self.Session() + document_chunks = session.query(func.count(DocumentChunkEntity.id)) + if query.id is not None: + document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) + if query.document_id is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.document_id == query.document_id + ) + if query.doc_type is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.doc_type == query.doc_type + ) + if query.doc_name is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.doc_name == query.doc_name + ) + if query.meta_info is not None: + document_chunks = document_chunks.filter( + DocumentChunkEntity.meta_info == query.meta_info + ) + count = document_chunks.scalar() + return count + # def update_knowledge_document(self, document:KnowledgeDocumentEntity): # session = self.Session() # updated_space = session.merge(document) diff --git a/pilot/openapi/knowledge/knowledge_controller.py b/pilot/openapi/knowledge/knowledge_controller.py index bebbc8a3f..1b452119a 100644 --- a/pilot/openapi/knowledge/knowledge_controller.py +++ b/pilot/openapi/knowledge/knowledge_controller.py @@ -1,11 +1,13 @@ +import os +import shutil from tempfile import NamedTemporaryFile -from fastapi import APIRouter, File, UploadFile +from fastapi import APIRouter, File, UploadFile, Request, Form from langchain.embeddings import HuggingFaceEmbeddings from pilot.configs.config import Config -from pilot.configs.model_config import LLM_MODEL_CONFIG +from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.openapi.api_v1.api_view_model import Result from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding @@ -74,18 +76,25 @@ def document_list(space_name: str, query_request: DocumentQueryRequest): @router.post("/knowledge/{space_name}/document/upload") -async def document_sync(space_name: str, file: UploadFile = File(...)): +async def document_upload(space_name: str, doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...)): print(f"/document/upload params: {space_name}") try: - with NamedTemporaryFile(delete=False) as tmp: - tmp.write(file.read()) - tmp_path = tmp.name - tmp_content = tmp.read() - - return {"file_path": tmp_path, "file_content": tmp_content} - Result.succ([]) + if doc_file: + with NamedTemporaryFile(dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False) as tmp: + tmp.write(await doc_file.read()) + tmp_path = tmp.name + shutil.move(tmp_path, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename)) + request = KnowledgeDocumentRequest() + request.doc_name = doc_name + request.doc_type = doc_type + request.content = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename), + knowledge_space_service.create_knowledge_document( + space=space_name, request=request + ) + return Result.succ([]) + return Result.faild(code="E000X", msg=f"doc_file is None") except Exception as e: - return Result.faild(code="E000X", msg=f"document sync error {e}") + return Result.faild(code="E000X", msg=f"document add error {e}") @router.post("/knowledge/{space_name}/document/sync") diff --git a/pilot/openapi/knowledge/knowledge_document_dao.py b/pilot/openapi/knowledge/knowledge_document_dao.py index 1f7afc401..cad881e71 100644 --- a/pilot/openapi/knowledge/knowledge_document_dao.py +++ b/pilot/openapi/knowledge/knowledge_document_dao.py @@ -1,6 +1,6 @@ from datetime import datetime -from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine +from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func from sqlalchemy.orm import declarative_base, sessionmaker from pilot.configs.config import Config @@ -92,15 +92,41 @@ class KnowledgeDocumentDao: result = knowledge_documents.all() return result - def update_knowledge_document(self, document: KnowledgeDocumentEntity): + def get_knowledge_documents_count(self, query): session = self.Session() - updated_space = session.merge(document) - session.commit() - return updated_space.id + knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id)) + if query.id is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.id == query.id + ) + if query.doc_name is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.doc_name == query.doc_name + ) + if query.doc_type is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.doc_type == query.doc_type + ) + if query.space is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.space == query.space + ) + if query.status is not None: + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.status == query.status + ) + count = knowledge_documents.scalar() + return count - def delete_knowledge_document(self, document_id: int): - cursor = self.conn.cursor() - query = "DELETE FROM knowledge_document WHERE id = %s" - cursor.execute(query, (document_id,)) - self.conn.commit() - cursor.close() + # def update_knowledge_document(self, document: KnowledgeDocumentEntity): + # session = self.Session() + # updated_space = session.merge(document) + # session.commit() + # return updated_space.id + # + # def delete_knowledge_document(self, document_id: int): + # cursor = self.conn.cursor() + # query = "DELETE FROM knowledge_document WHERE id = %s" + # cursor.execute(query, (document_id,)) + # self.conn.commit() + # cursor.close() diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index c41630ede..9ee9f3c40 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -25,6 +25,7 @@ from pilot.openapi.knowledge.request.knowledge_request import ( ) from enum import Enum +from pilot.openapi.knowledge.request.knowledge_response import ChunkQueryResponse, DocumentQueryResponse knowledge_space_dao = KnowledgeSpaceDao() knowledge_document_dao = KnowledgeDocumentDao() @@ -93,9 +94,13 @@ class KnowledgeService: space=space, status=request.status, ) - return knowledge_document_dao.get_knowledge_documents( + res = DocumentQueryResponse() + res.data = knowledge_document_dao.get_knowledge_documents( query, page=request.page, page_size=request.page_size ) + res.total = knowledge_document_dao.get_knowledge_documents_count(query) + res.page = request.page + return res """sync knowledge document chunk into vector store""" @@ -164,9 +169,13 @@ class KnowledgeService: doc_name=request.doc_name, doc_type=request.doc_type, ) - return document_chunk_dao.get_document_chunks( + res = ChunkQueryResponse() + res.data = document_chunk_dao.get_document_chunks( query, page=request.page, page_size=request.page_size ) + res.total = document_chunk_dao.get_document_chunks_count(query) + res.page = request.page + return res def async_doc_embedding(self, client, chunk_docs, doc): logger.info( diff --git a/pilot/openapi/knowledge/request/__init__.py b/pilot/openapi/knowledge/request/__init__.py new file mode 100644 index 000000000..e69de29bb