feat: knowledge backend management

1. knowledge_type.py
2. knowledge backend API
This commit is contained in:
aries_ckt 2023-06-27 15:29:13 +08:00
parent 0dd2e5e12c
commit a06342425b
16 changed files with 153 additions and 100 deletions

View File

@ -5,6 +5,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.embedding_engine.csv_embedding import CSVEmbedding from pilot.embedding_engine.csv_embedding import CSVEmbedding
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
from pilot.embedding_engine.pdf_embedding import PDFEmbedding from pilot.embedding_engine.pdf_embedding import PDFEmbedding
from pilot.embedding_engine.ppt_embedding import PPTEmbedding from pilot.embedding_engine.ppt_embedding import PPTEmbedding
@ -14,16 +15,16 @@ from pilot.vector_store.connector import VectorStoreConnector
CFG = Config() CFG = Config()
KnowledgeEmbeddingType = { # KnowledgeEmbeddingType = {
".txt": (MarkdownEmbedding, {}), # ".txt": (MarkdownEmbedding, {}),
".md": (MarkdownEmbedding, {}), # ".md": (MarkdownEmbedding, {}),
".pdf": (PDFEmbedding, {}), # ".pdf": (PDFEmbedding, {}),
".doc": (WordEmbedding, {}), # ".doc": (WordEmbedding, {}),
".docx": (WordEmbedding, {}), # ".docx": (WordEmbedding, {}),
".csv": (CSVEmbedding, {}), # ".csv": (CSVEmbedding, {}),
".ppt": (PPTEmbedding, {}), # ".ppt": (PPTEmbedding, {}),
".pptx": (PPTEmbedding, {}), # ".pptx": (PPTEmbedding, {}),
} # }
class KnowledgeEmbedding: class KnowledgeEmbedding:
@ -31,14 +32,14 @@ class KnowledgeEmbedding:
self, self,
model_name, model_name,
vector_store_config, vector_store_config,
file_type: Optional[str] = "default", knowledge_type: Optional[str],
file_path: Optional[str] = None, knowledge_source: Optional[str] = None,
): ):
"""Initialize with Loader url, model_name, vector_store_config""" """Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source"""
self.file_path = file_path self.knowledge_source = knowledge_source
self.model_name = model_name self.model_name = model_name
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.file_type = file_type self.knowledge_type = knowledge_type
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings self.vector_store_config["embeddings"] = self.embeddings
@ -55,23 +56,24 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch() return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self): def init_knowledge_embedding(self):
if self.file_type == "url": return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
embedding = URLEmbedding( # if self.file_type == "url":
file_path=self.file_path, # embedding = URLEmbedding(
vector_store_config=self.vector_store_config, # file_path=self.file_path,
) # vector_store_config=self.vector_store_config,
return embedding # )
extension = "." + self.file_path.rsplit(".", 1)[-1] # return embedding
if extension in KnowledgeEmbeddingType: # extension = "." + self.file_path.rsplit(".", 1)[-1]
knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension] # if extension in KnowledgeEmbeddingType:
embedding = knowledge_class( # knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
self.file_path, # embedding = knowledge_class(
vector_store_config=self.vector_store_config, # self.file_path,
**knowledge_args # vector_store_config=self.vector_store_config,
) # **knowledge_args
return embedding # )
raise ValueError(f"Unsupported knowledge file type '{extension}'") # return embedding
return embedding # raise ValueError(f"Unsupported knowledge file type '{extension}'")
# return embedding
def similar_search(self, text, topk): def similar_search(self, text, topk):
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(

View File

@ -0,0 +1,62 @@
from enum import Enum
from pilot.embedding_engine.csv_embedding import CSVEmbedding
from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
from pilot.embedding_engine.pdf_embedding import PDFEmbedding
from pilot.embedding_engine.ppt_embedding import PPTEmbedding
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.embedding_engine.url_embedding import URLEmbedding
from pilot.embedding_engine.word_embedding import WordEmbedding
# Registry used for DOCUMENT knowledge sources: maps a dot-prefixed file
# extension to a ``(embedding class, extra keyword arguments)`` pair.
# get_knowledge_embedding() unpacks the pair and passes the kwargs dict to the
# class constructor, so per-format options can be added here without touching
# the dispatch code.
DocumentEmbeddingType = {
    # Plain text files are handled by the Markdown loader as well.
    ".txt": (MarkdownEmbedding, {}),
    ".md": (MarkdownEmbedding, {}),
    ".pdf": (PDFEmbedding, {}),
    # Both Word formats share one loader.
    ".doc": (WordEmbedding, {}),
    ".docx": (WordEmbedding, {}),
    ".csv": (CSVEmbedding, {}),
    # Both PowerPoint formats share one loader.
    ".ppt": (PPTEmbedding, {}),
    ".pptx": (PPTEmbedding, {}),
}
class KnowledgeType(Enum):
    """Kinds of knowledge source an embedding client can be built for.

    Each member's value is the upper-case string form of its name; that
    string is what callers pass around as ``knowledge_type``.
    """

    # A local file; the concrete loader is chosen from the file extension.
    DOCUMENT = "DOCUMENT"
    # A web page address.
    URL = "URL"
    # Raw text supplied directly as the source.
    TEXT = "TEXT"
    # Declared but not integrated yet.
    OSS = "OSS"
    # Declared but not integrated yet.
    NOTION = "NOTION"
def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config):
match knowledge_type:
case KnowledgeType.DOCUMENT.value:
extension = "." + knowledge_source.rsplit(".", 1)[-1]
if extension in DocumentEmbeddingType:
knowledge_class, knowledge_args = DocumentEmbeddingType[extension]
embedding = knowledge_class(
knowledge_source,
vector_store_config=vector_store_config,
**knowledge_args,
)
return embedding
raise ValueError(f"Unsupported knowledge document type '{extension}'")
case KnowledgeType.URL.value:
embedding = URLEmbedding(
file_path=knowledge_source,
vector_store_config=vector_store_config,
)
return embedding
case KnowledgeType.TEXT.value:
embedding = StringEmbedding(
file_path=knowledge_source,
vector_store_config=vector_store_config,
)
return embedding
case KnowledgeType.OSS.value:
raise Exception("OSS have not integrate")
case KnowledgeType.NOTION.value:
raise Exception("NOTION have not integrate")
case _:
raise Exception("unknown knowledge type")

View File

@ -9,7 +9,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import List from typing import List
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene

View File

@ -1,26 +1,22 @@
import json from tempfile import NamedTemporaryFile
import os
import sys from fastapi import APIRouter, File, UploadFile
from typing import List
from fastapi import APIRouter
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.server.api_v1.api_view_model import Result from pilot.openapi.api_v1.api_view_model import Result
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.server.knowledge.knowledge_service import KnowledgeService from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.server.knowledge.request.knowledge_request import ( from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeQueryRequest, KnowledgeQueryRequest,
KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest, KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
) )
from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
CFG = Config() CFG = Config()
router = APIRouter() router = APIRouter()
@ -74,6 +70,21 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
return Result.faild(code="E000X", msg=f"document list error {e}") return Result.faild(code="E000X", msg=f"document list error {e}")
@router.post("/knowledge/{space_name}/document/upload")
def document_sync(space_name: str, file: UploadFile = File(...)):
    """Store an uploaded document in a temporary file and report where it went.

    Args:
        space_name: Knowledge space taken from the URL path (only logged here).
        file: The multipart upload.

    Returns:
        A dict with ``file_path`` (path of the temporary copy, which is kept on
        disk) and ``file_content`` (the uploaded data as text), or a failed
        ``Result`` if anything raises while copying.
    """
    # NOTE(review): a later handler in this module is also named
    # ``document_sync``; the route is registered at decoration time, so the
    # shadowing is harmless, but a distinct name would be clearer.
    print(f"/document/upload params: {space_name}")
    try:
        # ``UploadFile.read()`` is a coroutine and cannot be awaited from this
        # synchronous handler; read from the underlying file object instead.
        uploaded_bytes = file.file.read()
        with NamedTemporaryFile(delete=False) as tmp:
            tmp.write(uploaded_bytes)
            tmp_path = tmp.name
        # Reading ``tmp`` straight after writing would start at end-of-file and
        # return nothing, so reuse the bytes already in memory. Undecodable
        # bytes are replaced so binary uploads cannot break JSON encoding.
        tmp_content = uploaded_bytes.decode("utf-8", errors="replace")
        return {"file_path": tmp_path, "file_content": tmp_content}
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document sync error {e}")
@router.post("/knowledge/{space_name}/document/sync") @router.post("/knowledge/{space_name}/document/sync")
def document_sync(space_name: str, request: DocumentSyncRequest): def document_sync(space_name: str, request: DocumentSyncRequest):
print(f"Received params: {space_name}, {request}") print(f"Received params: {space_name}, {request}")
@ -90,7 +101,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
def document_list(space_name: str, query_request: ChunkQueryRequest): def document_list(space_name: str, query_request: ChunkQueryRequest):
print(f"/document/list params: {space_name}, {query_request}") print(f"/document/list params: {space_name}, {query_request}")
try: try:
Result.succ(knowledge_space_service.get_document_chunks( return Result.succ(knowledge_space_service.get_document_chunks(
query_request query_request
)) ))
except Exception as e: except Exception as e:

View File

@ -9,6 +9,8 @@ from pilot.configs.config import Config
CFG = Config() CFG = Config()
Base = declarative_base() Base = declarative_base()
class KnowledgeDocumentEntity(Base): class KnowledgeDocumentEntity(Base):
__tablename__ = 'knowledge_document' __tablename__ = 'knowledge_document'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)

View File

@ -4,14 +4,15 @@ from datetime import datetime
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.logs import logger from pilot.logs import logger
from pilot.server.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
from pilot.server.knowledge.knowledge_document_dao import ( from pilot.openapi.knowledge.knowledge_document_dao import (
KnowledgeDocumentDao, KnowledgeDocumentDao,
KnowledgeDocumentEntity, KnowledgeDocumentEntity,
) )
from pilot.server.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
from pilot.server.knowledge.request.knowledge_request import ( from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeSpaceRequest, KnowledgeSpaceRequest,
KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest, KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
) )
@ -24,6 +25,7 @@ document_chunk_dao = DocumentChunkDao()
CFG=Config() CFG=Config()
class SyncStatus(Enum): class SyncStatus(Enum):
TODO = "TODO" TODO = "TODO"
FAILED = "FAILED" FAILED = "FAILED"
@ -65,7 +67,7 @@ class KnowledgeService:
chunk_size=0, chunk_size=0,
status=SyncStatus.TODO.name, status=SyncStatus.TODO.name,
last_sync=datetime.now(), last_sync=datetime.now(),
content="", content=request.content,
) )
knowledge_document_dao.create_knowledge_document(document) knowledge_document_dao.create_knowledge_document(document)
return True return True
@ -99,8 +101,8 @@ class KnowledgeService:
space=space_name, space=space_name,
) )
doc = knowledge_document_dao.get_knowledge_documents(query)[0] doc = knowledge_document_dao.get_knowledge_documents(query)[0]
client = KnowledgeEmbedding(file_path=doc.doc_name, client = KnowledgeEmbedding(knowledge_source=doc.content,
file_type="url", knowledge_type=doc.doc_type.upper(),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={ vector_store_config={
"vector_store_name": space_name, "vector_store_name": space_name,
@ -127,11 +129,6 @@ class KnowledgeService:
) )
for chunk_doc in chunk_docs] for chunk_doc in chunk_docs]
document_chunk_dao.create_documents_chunks(chunk_entities) document_chunk_dao.create_documents_chunks(chunk_entities)
#update document status
# doc.status = SyncStatus.RUNNING.name
# doc.chunk_size = len(chunk_docs)
# doc.gmt_modified = datetime.now()
# knowledge_document_dao.update_knowledge_document(doc)
return True return True

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.declarative import declarative_base
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
CFG = Config() CFG = Config()

View File

@ -29,6 +29,8 @@ class KnowledgeDocumentRequest(BaseModel):
doc_name: str doc_name: str
"""doc_type: doc type""" """doc_type: doc type"""
doc_type: str doc_type: str
"""content: content"""
content: str
"""text_chunk_size: text_chunk_size""" """text_chunk_size: text_chunk_size"""
# text_chunk_size: int # text_chunk_size: int

View File

@ -1,3 +1,4 @@
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
@ -44,8 +45,8 @@ class ChatUrlKnowledge(BaseChat):
self.knowledge_embedding_client = KnowledgeEmbedding( self.knowledge_embedding_client = KnowledgeEmbedding(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
file_type="url", knowledge_type=KnowledgeType.URL.value,
file_path=url, knowledge_source=url,
) )
# url soruce in vector # url soruce in vector

View File

@ -13,3 +13,5 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
load_dotenv(verbose=True, override=True) load_dotenv(verbose=True, override=True)
del load_dotenv del load_dotenv

View File

@ -23,7 +23,6 @@ from pilot.configs.model_config import *
from pilot.model.llm_out.vicuna_base_llm import get_embeddings from pilot.model.llm_out.vicuna_base_llm import get_embeddings
from pilot.model.loader import ModelLoader from pilot.model.loader import ModelLoader
from pilot.server.chat_adapter import get_llm_chat_adapter from pilot.server.chat_adapter import get_llm_chat_adapter
from knowledge.knowledge_controller import router
CFG = Config() CFG = Config()
@ -106,6 +105,7 @@ worker = ModelWorker(
) )
app = FastAPI() app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router
app.include_router(router) app.include_router(router)
origins = [ origins = [
@ -119,7 +119,7 @@ app.add_middleware(
allow_origins=origins, allow_origins=origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"]
) )
class PromptRequest(BaseModel): class PromptRequest(BaseModel):

View File

@ -11,6 +11,8 @@ import uuid
import gradio as gr import gradio as gr
from pilot.embedding_engine.knowledge_type import KnowledgeType
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH) sys.path.append(ROOT_PATH)
@ -57,7 +59,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
# 加载插件 # 加载插件
CFG = Config() CFG = Config()
@ -652,8 +654,9 @@ def knowledge_embedding_store(vs_id, files):
file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename) file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
) )
knowledge_embedding_client = KnowledgeEmbedding( knowledge_embedding_client = KnowledgeEmbedding(
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
model_name=LLM_MODEL_CONFIG["text2vec"], knowledge_type=KnowledgeType.DOCUMENT.value,
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={ vector_store_config={
"vector_store_name": vector_store_name["vs_name"], "vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,

View File

@ -4,27 +4,11 @@ from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.ppt_embedding import PPTEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding
from pilot.source_embedding.word_embedding import WordEmbedding
from pilot.vector_store.connector import VectorStoreConnector from pilot.vector_store.connector import VectorStoreConnector
CFG = Config() CFG = Config()
KnowledgeEmbeddingType = {
".txt": (MarkdownEmbedding, {}),
".md": (MarkdownEmbedding, {}),
".pdf": (PDFEmbedding, {}),
".doc": (WordEmbedding, {}),
".docx": (WordEmbedding, {}),
".csv": (CSVEmbedding, {}),
".ppt": (PPTEmbedding, {}),
".pptx": (PPTEmbedding, {}),
}
class KnowledgeEmbedding: class KnowledgeEmbedding:
def __init__( def __init__(
@ -54,23 +38,7 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch() return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self): def init_knowledge_embedding(self):
if self.file_type == "url": return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config)
embedding = URLEmbedding(
file_path=self.file_path,
vector_store_config=self.vector_store_config,
)
return embedding
extension = "." + self.file_path.rsplit(".", 1)[-1]
if extension in KnowledgeEmbeddingType:
knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
embedding = knowledge_class(
self.file_path,
vector_store_config=self.vector_store_config,
**knowledge_args,
)
return embedding
raise ValueError(f"Unsupported knowledge file type '{extension}'")
return embedding
def similar_search(self, text, topk): def similar_search(self, text, topk):
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(

View File

@ -4,6 +4,8 @@ import argparse
import os import os
import sys import sys
from pilot.embedding_engine.knowledge_type import KnowledgeType
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from pilot.configs.config import Config from pilot.configs.config import Config
@ -30,7 +32,8 @@ class LocalKnowledgeInit:
filename = os.path.join(root, file) filename = os.path.join(root, file)
# docs = self._load_file(filename) # docs = self._load_file(filename)
ke = KnowledgeEmbedding( ke = KnowledgeEmbedding(
file_path=filename, knowledge_source=filename,
knowledge_type=KnowledgeType.DOCUMENT.value,
model_name=self.model_name, model_name=self.model_name,
vector_store_config=self.vector_store_config, vector_store_config=self.vector_store_config,
) )