feat: knowledge backend management

1. knowledge_type.py: add the KnowledgeType enum and get_knowledge_embedding dispatch
2. knowledge backend API: move the knowledge modules from pilot.server to pilot.openapi
aries_ckt 2023-06-27 15:29:13 +08:00
parent 0dd2e5e12c
commit a06342425b
16 changed files with 153 additions and 100 deletions
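In short, the KnowledgeEmbedding client now takes a knowledge_type / knowledge_source pair instead of file_type / file_path and delegates loader selection to get_knowledge_embedding in the new knowledge_type.py. A minimal usage sketch of the new constructor (the model name, store name and paths below are illustrative, not taken from this commit):

from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType

client = KnowledgeEmbedding(
    model_name="text2vec-large-chinese",  # illustrative; real callers use LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
    vector_store_config={
        "vector_store_name": "my_space",  # illustrative space name
        "vector_store_path": "/data/vector_store",  # illustrative path
    },
    knowledge_type=KnowledgeType.DOCUMENT.value,
    knowledge_source="/data/docs/getting_started.md",  # illustrative document path
)
embedding = client.init_knowledge_embedding()  # dispatches to MarkdownEmbedding via get_knowledge_embedding
docs = client.similar_search("how do I sync a document", topk=5)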

View File

@@ -5,6 +5,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.embedding_engine.csv_embedding import CSVEmbedding
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
from pilot.embedding_engine.pdf_embedding import PDFEmbedding
from pilot.embedding_engine.ppt_embedding import PPTEmbedding
@@ -14,16 +15,16 @@ from pilot.vector_store.connector import VectorStoreConnector
CFG = Config()
KnowledgeEmbeddingType = {
    ".txt": (MarkdownEmbedding, {}),
    ".md": (MarkdownEmbedding, {}),
    ".pdf": (PDFEmbedding, {}),
    ".doc": (WordEmbedding, {}),
    ".docx": (WordEmbedding, {}),
    ".csv": (CSVEmbedding, {}),
    ".ppt": (PPTEmbedding, {}),
    ".pptx": (PPTEmbedding, {}),
}
# KnowledgeEmbeddingType = {
# ".txt": (MarkdownEmbedding, {}),
# ".md": (MarkdownEmbedding, {}),
# ".pdf": (PDFEmbedding, {}),
# ".doc": (WordEmbedding, {}),
# ".docx": (WordEmbedding, {}),
# ".csv": (CSVEmbedding, {}),
# ".ppt": (PPTEmbedding, {}),
# ".pptx": (PPTEmbedding, {}),
# }
class KnowledgeEmbedding:
@@ -31,14 +32,14 @@
self,
model_name,
vector_store_config,
file_type: Optional[str] = "default",
file_path: Optional[str] = None,
knowledge_type: Optional[str],
knowledge_source: Optional[str] = None,
):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
"""Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source"""
self.knowledge_source = knowledge_source
self.model_name = model_name
self.vector_store_config = vector_store_config
self.file_type = file_type
self.knowledge_type = knowledge_type
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
@@ -55,23 +56,24 @@
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
if self.file_type == "url":
embedding = URLEmbedding(
file_path=self.file_path,
vector_store_config=self.vector_store_config,
)
return embedding
extension = "." + self.file_path.rsplit(".", 1)[-1]
if extension in KnowledgeEmbeddingType:
knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
embedding = knowledge_class(
self.file_path,
vector_store_config=self.vector_store_config,
**knowledge_args
)
return embedding
raise ValueError(f"Unsupported knowledge file type '{extension}'")
return embedding
return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
# if self.file_type == "url":
# embedding = URLEmbedding(
# file_path=self.file_path,
# vector_store_config=self.vector_store_config,
# )
# return embedding
# extension = "." + self.file_path.rsplit(".", 1)[-1]
# if extension in KnowledgeEmbeddingType:
# knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
# embedding = knowledge_class(
# self.file_path,
# vector_store_config=self.vector_store_config,
# **knowledge_args
# )
# return embedding
# raise ValueError(f"Unsupported knowledge file type '{extension}'")
# return embedding
def similar_search(self, text, topk):
vector_client = VectorStoreConnector(

View File

@@ -0,0 +1,62 @@
from enum import Enum

from pilot.embedding_engine.csv_embedding import CSVEmbedding
from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
from pilot.embedding_engine.pdf_embedding import PDFEmbedding
from pilot.embedding_engine.ppt_embedding import PPTEmbedding
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.embedding_engine.url_embedding import URLEmbedding
from pilot.embedding_engine.word_embedding import WordEmbedding

DocumentEmbeddingType = {
    ".txt": (MarkdownEmbedding, {}),
    ".md": (MarkdownEmbedding, {}),
    ".pdf": (PDFEmbedding, {}),
    ".doc": (WordEmbedding, {}),
    ".docx": (WordEmbedding, {}),
    ".csv": (CSVEmbedding, {}),
    ".ppt": (PPTEmbedding, {}),
    ".pptx": (PPTEmbedding, {}),
}


class KnowledgeType(Enum):
    DOCUMENT = "DOCUMENT"
    URL = "URL"
    TEXT = "TEXT"
    OSS = "OSS"
    NOTION = "NOTION"


def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config):
    match knowledge_type:
        case KnowledgeType.DOCUMENT.value:
            extension = "." + knowledge_source.rsplit(".", 1)[-1]
            if extension in DocumentEmbeddingType:
                knowledge_class, knowledge_args = DocumentEmbeddingType[extension]
                embedding = knowledge_class(
                    knowledge_source,
                    vector_store_config=vector_store_config,
                    **knowledge_args,
                )
                return embedding
            raise ValueError(f"Unsupported knowledge document type '{extension}'")
        case KnowledgeType.URL.value:
            embedding = URLEmbedding(
                file_path=knowledge_source,
                vector_store_config=vector_store_config,
            )
            return embedding
        case KnowledgeType.TEXT.value:
            embedding = StringEmbedding(
                file_path=knowledge_source,
                vector_store_config=vector_store_config,
            )
            return embedding
        case KnowledgeType.OSS.value:
            raise Exception("OSS knowledge type is not integrated yet")
        case KnowledgeType.NOTION.value:
            raise Exception("NOTION knowledge type is not integrated yet")
        case _:
            raise Exception("unknown knowledge type")

View File

@@ -9,7 +9,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from typing import List
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo
from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene

View File

@@ -1,26 +1,22 @@
import json
import os
import sys
from typing import List
from tempfile import NamedTemporaryFile
from fastapi import APIRouter, File, UploadFile
from fastapi import APIRouter
from langchain.embeddings import HuggingFaceEmbeddings
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.server.api_v1.api_view_model import Result
from pilot.openapi.api_v1.api_view_model import Result
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.server.knowledge.knowledge_service import KnowledgeService
from pilot.server.knowledge.request.knowledge_request import (
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeQueryRequest,
KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
)
from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
CFG = Config()
router = APIRouter()
@@ -74,6 +70,21 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
return Result.faild(code="E000X", msg=f"document list error {e}")
@router.post("/knowledge/{space_name}/document/upload")
def document_sync(space_name: str, file: UploadFile = File(...)):
print(f"/document/upload params: {space_name}")
try:
with NamedTemporaryFile(delete=False) as tmp:
tmp.write(file.read())
tmp_path = tmp.name
tmp_content = tmp.read()
return {"file_path": tmp_path, "file_content": tmp_content}
Result.succ([])
except Exception as e:
return Result.faild(code="E000X", msg=f"document sync error {e}")
@router.post("/knowledge/{space_name}/document/sync")
def document_sync(space_name: str, request: DocumentSyncRequest):
print(f"Received params: {space_name}, {request}")
@@ -90,7 +101,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
def document_list(space_name: str, query_request: ChunkQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
try:
Result.succ(knowledge_space_service.get_document_chunks(
return Result.succ(knowledge_space_service.get_document_chunks(
query_request
))
except Exception as e:

View File

@@ -9,6 +9,8 @@ from pilot.configs.config import Config
CFG = Config()
Base = declarative_base()
class KnowledgeDocumentEntity(Base):
__tablename__ = 'knowledge_document'
id = Column(Integer, primary_key=True)

View File

@@ -4,14 +4,15 @@ from datetime import datetime
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.logs import logger
from pilot.server.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
from pilot.server.knowledge.knowledge_document_dao import (
from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
from pilot.openapi.knowledge.knowledge_document_dao import (
KnowledgeDocumentDao,
KnowledgeDocumentEntity,
)
from pilot.server.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
from pilot.server.knowledge.request.knowledge_request import (
from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeSpaceRequest,
KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
)
@@ -24,6 +25,7 @@ document_chunk_dao = DocumentChunkDao()
CFG=Config()
class SyncStatus(Enum):
TODO = "TODO"
FAILED = "FAILED"
@@ -65,7 +67,7 @@ class KnowledgeService:
chunk_size=0,
status=SyncStatus.TODO.name,
last_sync=datetime.now(),
content="",
content=request.content,
)
knowledge_document_dao.create_knowledge_document(document)
return True
@@ -99,8 +101,8 @@ class KnowledgeService:
space=space_name,
)
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
client = KnowledgeEmbedding(file_path=doc.doc_name,
file_type="url",
client = KnowledgeEmbedding(knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={
"vector_store_name": space_name,
@@ -127,11 +129,6 @@ class KnowledgeService:
)
for chunk_doc in chunk_docs]
document_chunk_dao.create_documents_chunks(chunk_entities)
#update document status
# doc.status = SyncStatus.RUNNING.name
# doc.chunk_size = len(chunk_docs)
# doc.gmt_modified = datetime.now()
# knowledge_document_dao.update_knowledge_document(doc)
return True

View File

@@ -5,7 +5,7 @@ from sqlalchemy.ext.declarative import declarative_base
from pilot.configs.config import Config
from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from sqlalchemy.orm import sessionmaker
CFG = Config()

View File

@@ -29,6 +29,8 @@ class KnowledgeDocumentRequest(BaseModel):
doc_name: str
"""doc_type: doc type"""
doc_type: str
"""content: content"""
content: str
"""text_chunk_size: text_chunk_size"""
# text_chunk_size: int
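With the new content field, a URL-type document request might be built as below (the space and document names are illustrative); the service stores content on the document entity and later passes it on as the knowledge_source during sync:

from pilot.openapi.knowledge.request.knowledge_request import KnowledgeDocumentRequest

request = KnowledgeDocumentRequest(
    doc_name="dbgpt-faq",  # illustrative
    doc_type="URL",  # upper-cased by the service to match KnowledgeType
    content="https://example.com/faq",  # used as knowledge_source when the document is synced
)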

View File

@@ -1,3 +1,4 @@
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@@ -44,8 +45,8 @@ class ChatUrlKnowledge(BaseChat):
self.knowledge_embedding_client = KnowledgeEmbedding(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
file_type="url",
file_path=url,
knowledge_type=KnowledgeType.URL.value,
knowledge_source=url,
)
# url source in vector

View File

@@ -13,3 +13,5 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
load_dotenv(verbose=True, override=True)
del load_dotenv

View File

@@ -23,7 +23,6 @@ from pilot.configs.model_config import *
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
from pilot.model.loader import ModelLoader
from pilot.server.chat_adapter import get_llm_chat_adapter
from knowledge.knowledge_controller import router
CFG = Config()
@@ -106,6 +105,7 @@ worker = ModelWorker(
)
app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router
app.include_router(router)
origins = [
@@ -119,7 +119,7 @@ app.add_middleware(
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_headers=["*"]
)
class PromptRequest(BaseModel):

View File

@@ -11,6 +11,8 @@ import uuid
import gradio as gr
from pilot.embedding_engine.knowledge_type import KnowledgeType
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -57,7 +59,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.staticfiles import StaticFiles
from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
# load plugins
CFG = Config()
@@ -652,8 +654,9 @@ def knowledge_embedding_store(vs_id, files):
file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
)
knowledge_embedding_client = KnowledgeEmbedding(
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
model_name=LLM_MODEL_CONFIG["text2vec"],
knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
knowledge_type=KnowledgeType.DOCUMENT.value,
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,

View File

@@ -4,27 +4,11 @@ from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.ppt_embedding import PPTEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding
from pilot.source_embedding.word_embedding import WordEmbedding
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
from pilot.vector_store.connector import VectorStoreConnector
CFG = Config()
KnowledgeEmbeddingType = {
    ".txt": (MarkdownEmbedding, {}),
    ".md": (MarkdownEmbedding, {}),
    ".pdf": (PDFEmbedding, {}),
    ".doc": (WordEmbedding, {}),
    ".docx": (WordEmbedding, {}),
    ".csv": (CSVEmbedding, {}),
    ".ppt": (PPTEmbedding, {}),
    ".pptx": (PPTEmbedding, {}),
}
class KnowledgeEmbedding:
def __init__(
@@ -54,23 +38,7 @@
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
if self.file_type == "url":
embedding = URLEmbedding(
file_path=self.file_path,
vector_store_config=self.vector_store_config,
)
return embedding
extension = "." + self.file_path.rsplit(".", 1)[-1]
if extension in KnowledgeEmbeddingType:
knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
embedding = knowledge_class(
self.file_path,
vector_store_config=self.vector_store_config,
**knowledge_args,
)
return embedding
raise ValueError(f"Unsupported knowledge file type '{extension}'")
return embedding
return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config)
def similar_search(self, text, topk):
vector_client = VectorStoreConnector(

View File

@@ -4,6 +4,8 @@ import argparse
import os
import sys
from pilot.embedding_engine.knowledge_type import KnowledgeType
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from pilot.configs.config import Config
@@ -30,7 +32,8 @@ class LocalKnowledgeInit:
filename = os.path.join(root, file)
# docs = self._load_file(filename)
ke = KnowledgeEmbedding(
file_path=filename,
knowledge_source=filename,
knowledge_type=KnowledgeType.DOCUMENT.value,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)