feat:chunks page list

add chunk page list
This commit is contained in:
aries_ckt 2023-06-29 11:21:48 +08:00
parent 219b61caae
commit 1283b24a5f
6 changed files with 93 additions and 25 deletions

View File

View File

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.orm import declarative_base, sessionmaker
from pilot.configs.config import Config from pilot.configs.config import Config
@ -83,6 +83,30 @@ class DocumentChunkDao:
result = document_chunks.all() result = document_chunks.all()
return result return result
def get_document_chunks_count(self, query: DocumentChunkEntity):
session = self.Session()
document_chunks = session.query(func.count(DocumentChunkEntity.id))
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
if query.document_id is not None:
document_chunks = document_chunks.filter(
DocumentChunkEntity.document_id == query.document_id
)
if query.doc_type is not None:
document_chunks = document_chunks.filter(
DocumentChunkEntity.doc_type == query.doc_type
)
if query.doc_name is not None:
document_chunks = document_chunks.filter(
DocumentChunkEntity.doc_name == query.doc_name
)
if query.meta_info is not None:
document_chunks = document_chunks.filter(
DocumentChunkEntity.meta_info == query.meta_info
)
count = document_chunks.scalar()
return count
# def update_knowledge_document(self, document:KnowledgeDocumentEntity): # def update_knowledge_document(self, document:KnowledgeDocumentEntity):
# session = self.Session() # session = self.Session()
# updated_space = session.merge(document) # updated_space = session.merge(document)

View File

@ -1,11 +1,13 @@
import os
import shutil
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from fastapi import APIRouter, File, UploadFile from fastapi import APIRouter, File, UploadFile, Request, Form
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.openapi.api_v1.api_view_model import Result from pilot.openapi.api_v1.api_view_model import Result
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
@ -74,18 +76,25 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
@router.post("/knowledge/{space_name}/document/upload") @router.post("/knowledge/{space_name}/document/upload")
async def document_sync(space_name: str, file: UploadFile = File(...)): async def document_upload(space_name: str, doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...)):
print(f"/document/upload params: {space_name}") print(f"/document/upload params: {space_name}")
try: try:
with NamedTemporaryFile(delete=False) as tmp: if doc_file:
tmp.write(file.read()) with NamedTemporaryFile(dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False) as tmp:
tmp.write(await doc_file.read())
tmp_path = tmp.name tmp_path = tmp.name
tmp_content = tmp.read() shutil.move(tmp_path, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename))
request = KnowledgeDocumentRequest()
return {"file_path": tmp_path, "file_content": tmp_content} request.doc_name = doc_name
Result.succ([]) request.doc_type = doc_type
request.content = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename),
knowledge_space_service.create_knowledge_document(
space=space_name, request=request
)
return Result.succ([])
return Result.faild(code="E000X", msg=f"doc_file is None")
except Exception as e: except Exception as e:
return Result.faild(code="E000X", msg=f"document sync error {e}") return Result.faild(code="E000X", msg=f"document add error {e}")
@router.post("/knowledge/{space_name}/document/sync") @router.post("/knowledge/{space_name}/document/sync")

View File

@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.orm import declarative_base, sessionmaker
from pilot.configs.config import Config from pilot.configs.config import Config
@ -92,15 +92,41 @@ class KnowledgeDocumentDao:
result = knowledge_documents.all() result = knowledge_documents.all()
return result return result
def update_knowledge_document(self, document: KnowledgeDocumentEntity): def get_knowledge_documents_count(self, query):
session = self.Session() session = self.Session()
updated_space = session.merge(document) knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
session.commit() if query.id is not None:
return updated_space.id knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.id == query.id
)
if query.doc_name is not None:
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_name == query.doc_name
)
if query.doc_type is not None:
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_type == query.doc_type
)
if query.space is not None:
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.space == query.space
)
if query.status is not None:
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.status == query.status
)
count = knowledge_documents.scalar()
return count
def delete_knowledge_document(self, document_id: int): # def update_knowledge_document(self, document: KnowledgeDocumentEntity):
cursor = self.conn.cursor() # session = self.Session()
query = "DELETE FROM knowledge_document WHERE id = %s" # updated_space = session.merge(document)
cursor.execute(query, (document_id,)) # session.commit()
self.conn.commit() # return updated_space.id
cursor.close() #
# def delete_knowledge_document(self, document_id: int):
# cursor = self.conn.cursor()
# query = "DELETE FROM knowledge_document WHERE id = %s"
# cursor.execute(query, (document_id,))
# self.conn.commit()
# cursor.close()

View File

@ -25,6 +25,7 @@ from pilot.openapi.knowledge.request.knowledge_request import (
) )
from enum import Enum from enum import Enum
from pilot.openapi.knowledge.request.knowledge_response import ChunkQueryResponse, DocumentQueryResponse
knowledge_space_dao = KnowledgeSpaceDao() knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao() knowledge_document_dao = KnowledgeDocumentDao()
@ -93,9 +94,13 @@ class KnowledgeService:
space=space, space=space,
status=request.status, status=request.status,
) )
return knowledge_document_dao.get_knowledge_documents( res = DocumentQueryResponse()
res.data = knowledge_document_dao.get_knowledge_documents(
query, page=request.page, page_size=request.page_size query, page=request.page, page_size=request.page_size
) )
res.total = knowledge_document_dao.get_knowledge_documents_count(query)
res.page = request.page
return res
"""sync knowledge document chunk into vector store""" """sync knowledge document chunk into vector store"""
@ -164,9 +169,13 @@ class KnowledgeService:
doc_name=request.doc_name, doc_name=request.doc_name,
doc_type=request.doc_type, doc_type=request.doc_type,
) )
return document_chunk_dao.get_document_chunks( res = ChunkQueryResponse()
res.data = document_chunk_dao.get_document_chunks(
query, page=request.page, page_size=request.page_size query, page=request.page, page_size=request.page_size
) )
res.total = document_chunk_dao.get_document_chunks_count(query)
res.page = request.page
return res
def async_doc_embedding(self, client, chunk_docs, doc): def async_doc_embedding(self, client, chunk_docs, doc):
logger.info( logger.info(