feat: knowledge backend management
1. Add pilot/embedding_engine/knowledge_type.py; 2. Add the knowledge backend management API.
This commit is contained in:
parent 0dd2e5e12c
commit a06342425b
@@ -5,6 +5,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
 
 from pilot.configs.config import Config
 from pilot.embedding_engine.csv_embedding import CSVEmbedding
+from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
 from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
 from pilot.embedding_engine.pdf_embedding import PDFEmbedding
 from pilot.embedding_engine.ppt_embedding import PPTEmbedding
@@ -14,16 +15,16 @@ from pilot.vector_store.connector import VectorStoreConnector
 
 CFG = Config()
 
-KnowledgeEmbeddingType = {
-    ".txt": (MarkdownEmbedding, {}),
-    ".md": (MarkdownEmbedding, {}),
-    ".pdf": (PDFEmbedding, {}),
-    ".doc": (WordEmbedding, {}),
-    ".docx": (WordEmbedding, {}),
-    ".csv": (CSVEmbedding, {}),
-    ".ppt": (PPTEmbedding, {}),
-    ".pptx": (PPTEmbedding, {}),
-}
+# KnowledgeEmbeddingType = {
+#     ".txt": (MarkdownEmbedding, {}),
+#     ".md": (MarkdownEmbedding, {}),
+#     ".pdf": (PDFEmbedding, {}),
+#     ".doc": (WordEmbedding, {}),
+#     ".docx": (WordEmbedding, {}),
+#     ".csv": (CSVEmbedding, {}),
+#     ".ppt": (PPTEmbedding, {}),
+#     ".pptx": (PPTEmbedding, {}),
+# }
 
 
 class KnowledgeEmbedding:
@@ -31,14 +32,14 @@ class KnowledgeEmbedding:
         self,
         model_name,
         vector_store_config,
-        file_type: Optional[str] = "default",
-        file_path: Optional[str] = None,
+        knowledge_type: Optional[str],
+        knowledge_source: Optional[str] = None,
     ):
-        """Initialize with Loader url, model_name, vector_store_config"""
-        self.file_path = file_path
+        """Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source"""
+        self.knowledge_source = knowledge_source
         self.model_name = model_name
         self.vector_store_config = vector_store_config
-        self.file_type = file_type
+        self.knowledge_type = knowledge_type
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
         self.vector_store_config["embeddings"] = self.embeddings
 
@@ -55,23 +56,24 @@ class KnowledgeEmbedding:
         return self.knowledge_embedding_client.read_batch()
 
     def init_knowledge_embedding(self):
-        if self.file_type == "url":
-            embedding = URLEmbedding(
-                file_path=self.file_path,
-                vector_store_config=self.vector_store_config,
-            )
-            return embedding
-        extension = "." + self.file_path.rsplit(".", 1)[-1]
-        if extension in KnowledgeEmbeddingType:
-            knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
-            embedding = knowledge_class(
-                self.file_path,
-                vector_store_config=self.vector_store_config,
-                **knowledge_args
-            )
-            return embedding
-        raise ValueError(f"Unsupported knowledge file type '{extension}'")
-        return embedding
+        return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
+        # if self.file_type == "url":
+        #     embedding = URLEmbedding(
+        #         file_path=self.file_path,
+        #         vector_store_config=self.vector_store_config,
+        #     )
+        #     return embedding
+        # extension = "." + self.file_path.rsplit(".", 1)[-1]
+        # if extension in KnowledgeEmbeddingType:
+        #     knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
+        #     embedding = knowledge_class(
+        #         self.file_path,
+        #         vector_store_config=self.vector_store_config,
+        #         **knowledge_args
+        #     )
+        #     return embedding
+        # raise ValueError(f"Unsupported knowledge file type '{extension}'")
+        # return embedding
 
     def similar_search(self, text, topk):
         vector_client = VectorStoreConnector(
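A minimal usage sketch of the refactored facade (not part of the commit): callers now describe the source with knowledge_type/knowledge_source instead of file_type/file_path. The model lookup and the vector_store_config keys mirror the call sites changed later in this diff; the space name and paths below are assumptions.

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType

CFG = Config()

# Build a client for a local document; the ".pdf" extension routes to PDFEmbedding.
client = KnowledgeEmbedding(
    model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
    vector_store_config={
        "vector_store_name": "my_space",            # assumed knowledge space name
        "vector_store_path": "/tmp/vector_store",   # assumed path
    },
    knowledge_type=KnowledgeType.DOCUMENT.value,
    knowledge_source="/tmp/docs/manual.pdf",        # assumed file path
)

embedding = client.init_knowledge_embedding()       # delegates to get_knowledge_embedding
docs = client.similar_search("what is DB-GPT?", 5)  # assumes the space was embedded earlier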
pilot/embedding_engine/knowledge_type.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+from enum import Enum
+
+from pilot.embedding_engine.csv_embedding import CSVEmbedding
+from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
+from pilot.embedding_engine.pdf_embedding import PDFEmbedding
+from pilot.embedding_engine.ppt_embedding import PPTEmbedding
+from pilot.embedding_engine.string_embedding import StringEmbedding
+from pilot.embedding_engine.url_embedding import URLEmbedding
+from pilot.embedding_engine.word_embedding import WordEmbedding
+
+DocumentEmbeddingType = {
+    ".txt": (MarkdownEmbedding, {}),
+    ".md": (MarkdownEmbedding, {}),
+    ".pdf": (PDFEmbedding, {}),
+    ".doc": (WordEmbedding, {}),
+    ".docx": (WordEmbedding, {}),
+    ".csv": (CSVEmbedding, {}),
+    ".ppt": (PPTEmbedding, {}),
+    ".pptx": (PPTEmbedding, {}),
+}
+
+
+class KnowledgeType(Enum):
+    DOCUMENT = "DOCUMENT"
+    URL = "URL"
+    TEXT = "TEXT"
+    OSS = "OSS"
+    NOTION = "NOTION"
+
+
+def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config):
+    match knowledge_type:
+        case KnowledgeType.DOCUMENT.value:
+            extension = "." + knowledge_source.rsplit(".", 1)[-1]
+            if extension in DocumentEmbeddingType:
+                knowledge_class, knowledge_args = DocumentEmbeddingType[extension]
+                embedding = knowledge_class(
+                    knowledge_source,
+                    vector_store_config=vector_store_config,
+                    **knowledge_args,
+                )
+                return embedding
+            raise ValueError(f"Unsupported knowledge document type '{extension}'")
+        case KnowledgeType.URL.value:
+            embedding = URLEmbedding(
+                file_path=knowledge_source,
+                vector_store_config=vector_store_config,
+            )
+            return embedding
+        case KnowledgeType.TEXT.value:
+            embedding = StringEmbedding(
+                file_path=knowledge_source,
+                vector_store_config=vector_store_config,
+            )
+            return embedding
+        case KnowledgeType.OSS.value:
+            raise Exception("OSS have not integrate")
+        case KnowledgeType.NOTION.value:
+            raise Exception("NOTION have not integrate")
+
+        case _:
+            raise Exception("unknown knowledge type")
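The match/case dispatch above requires Python 3.10 or newer; unsupported document extensions raise a ValueError, while the OSS and NOTION branches are placeholders. Note that KnowledgeEmbedding injects a HuggingFaceEmbeddings instance into vector_store_config before delegating here, so the factory is normally reached through that facade. A hypothetical extension sketch (not part of this commit), routing one more extension through the existing registry:

from pilot.embedding_engine.knowledge_type import DocumentEmbeddingType
from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding

# Treat ".markdown" files like ".md"; the key is an illustrative assumption.
DocumentEmbeddingType[".markdown"] = (MarkdownEmbedding, {})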
@@ -9,7 +9,7 @@ from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
 from typing import List
 
-from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
+from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo
 from pilot.configs.config import Config
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
@@ -1,26 +1,22 @@
 import json
 import os
 import sys
 from typing import List
+from tempfile import NamedTemporaryFile
 
+from fastapi import APIRouter, File, UploadFile
 
-from fastapi import APIRouter
 from langchain.embeddings import HuggingFaceEmbeddings
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
 
 from pilot.configs.config import Config
 from pilot.configs.model_config import LLM_MODEL_CONFIG
 
-from pilot.server.api_v1.api_view_model import Result
+from pilot.openapi.api_v1.api_view_model import Result
 from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
 
-from pilot.server.knowledge.knowledge_service import KnowledgeService
-from pilot.server.knowledge.request.knowledge_request import (
+from pilot.openapi.knowledge.knowledge_service import KnowledgeService
+from pilot.openapi.knowledge.request.knowledge_request import (
     KnowledgeQueryRequest,
     KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
 )
 
-from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest
+from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
 
 CFG = Config()
 router = APIRouter()
@@ -74,6 +70,21 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
         return Result.faild(code="E000X", msg=f"document list error {e}")
 
 
+@router.post("/knowledge/{space_name}/document/upload")
+def document_sync(space_name: str, file: UploadFile = File(...)):
+    print(f"/document/upload params: {space_name}")
+    try:
+        with NamedTemporaryFile(delete=False) as tmp:
+            tmp.write(file.read())
+            tmp_path = tmp.name
+            tmp_content = tmp.read()
+
+            return {"file_path": tmp_path, "file_content": tmp_content}
+        Result.succ([])
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"document sync error {e}")
+
+
 @router.post("/knowledge/{space_name}/document/sync")
 def document_sync(space_name: str, request: DocumentSyncRequest):
     print(f"Received params: {space_name}, {request}")
@@ -90,7 +101,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
 def document_list(space_name: str, query_request: ChunkQueryRequest):
     print(f"/document/list params: {space_name}, {query_request}")
     try:
-        Result.succ(knowledge_space_service.get_document_chunks(
+        return Result.succ(knowledge_space_service.get_document_chunks(
             query_request
         ))
     except Exception as e:
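As committed, the upload handler has a few rough edges: UploadFile.read() is a coroutine, so calling it from this synchronous handler passes a coroutine to tmp.write(); the temporary file is read back immediately after writing without seeking, so tmp_content comes back empty; the raw dict is returned before Result.succ([]) is ever reached; and the function reuses the name document_sync already taken by the sync endpoint below. A corrected sketch under those observations (the document_upload name and the flush/seek ordering are suggestions, not part of the commit):

@router.post("/knowledge/{space_name}/document/upload")
def document_upload(space_name: str, file: UploadFile = File(...)):
    print(f"/document/upload params: {space_name}")
    try:
        with NamedTemporaryFile(delete=False) as tmp:
            # UploadFile wraps a SpooledTemporaryFile; read it synchronously via .file.
            tmp.write(file.file.read())
            tmp.flush()
            tmp_path = tmp.name
            # Rewind before reading back, otherwise the position sits at end-of-file.
            tmp.seek(0)
            tmp_content = tmp.read()
        return Result.succ({"file_path": tmp_path, "file_content": tmp_content})
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document upload error {e}")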
@@ -9,6 +9,8 @@ from pilot.configs.config import Config
 CFG = Config()
 
+Base = declarative_base()
+
 
 class KnowledgeDocumentEntity(Base):
     __tablename__ = 'knowledge_document'
     id = Column(Integer, primary_key=True)
@@ -4,14 +4,15 @@ from datetime import datetime
 from pilot.configs.config import Config
 from pilot.configs.model_config import LLM_MODEL_CONFIG
 from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
+from pilot.embedding_engine.knowledge_type import KnowledgeType
 from pilot.logs import logger
-from pilot.server.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
-from pilot.server.knowledge.knowledge_document_dao import (
+from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
+from pilot.openapi.knowledge.knowledge_document_dao import (
     KnowledgeDocumentDao,
     KnowledgeDocumentEntity,
 )
-from pilot.server.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
-from pilot.server.knowledge.request.knowledge_request import (
+from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
+from pilot.openapi.knowledge.request.knowledge_request import (
     KnowledgeSpaceRequest,
     KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
 )
@@ -24,6 +25,7 @@ document_chunk_dao = DocumentChunkDao()
 
+CFG=Config()
 
 
 class SyncStatus(Enum):
     TODO = "TODO"
     FAILED = "FAILED"
@@ -65,7 +67,7 @@ class KnowledgeService:
             chunk_size=0,
             status=SyncStatus.TODO.name,
             last_sync=datetime.now(),
-            content="",
+            content=request.content,
         )
         knowledge_document_dao.create_knowledge_document(document)
         return True
@@ -99,8 +101,8 @@ class KnowledgeService:
             space=space_name,
         )
         doc = knowledge_document_dao.get_knowledge_documents(query)[0]
-        client = KnowledgeEmbedding(file_path=doc.doc_name,
-                                    file_type="url",
+        client = KnowledgeEmbedding(knowledge_source=doc.content,
+                                    knowledge_type=doc.doc_type.upper(),
                                     model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                                     vector_store_config={
                                         "vector_store_name": space_name,
@@ -127,11 +129,6 @@ class KnowledgeService:
             )
             for chunk_doc in chunk_docs]
         document_chunk_dao.create_documents_chunks(chunk_entities)
-        #update document status
-        # doc.status = SyncStatus.RUNNING.name
-        # doc.chunk_size = len(chunk_docs)
-        # doc.gmt_modified = datetime.now()
-        # knowledge_document_dao.update_knowledge_document(doc)
 
         return True
 
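For the sync path above, doc.doc_type.upper() has to land on one of the KnowledgeType values ("DOCUMENT", "URL", "TEXT", "OSS", "NOTION"); anything else only fails later, inside get_knowledge_embedding's unknown-type branch. A small guard like the following (a suggestion, not part of the commit) would surface bad rows before the embedding client is built:

from pilot.embedding_engine.knowledge_type import KnowledgeType

def validate_doc_type(doc_type: str) -> str:
    """Return the normalized KnowledgeType value, or raise before sync starts."""
    normalized = doc_type.upper()
    if normalized not in {member.value for member in KnowledgeType}:
        raise ValueError(f"Unsupported doc_type '{doc_type}' for knowledge sync")
    return normalized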
@@ -5,7 +5,7 @@ from sqlalchemy.ext.declarative import declarative_base
 
 from pilot.configs.config import Config
 
-from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest
+from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
 from sqlalchemy.orm import sessionmaker
 
 CFG = Config()
@@ -29,6 +29,8 @@ class KnowledgeDocumentRequest(BaseModel):
     doc_name: str
     """doc_type: doc type"""
     doc_type: str
+    """content: content"""
+    content: str
     """text_chunk_size: text_chunk_size"""
     # text_chunk_size: int
 
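A hypothetical request construction (values are examples only): for a URL document, the content field added here carries the source itself, which document_sync later reads back as doc.content and hands to KnowledgeEmbedding.

from pilot.openapi.knowledge.request.knowledge_request import KnowledgeDocumentRequest

request = KnowledgeDocumentRequest(
    doc_name="db_gpt_docs",             # example name
    doc_type="URL",                     # upper-cased later to select KnowledgeType.URL
    content="https://example.com/docs", # the URL the URL loader will embed
)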
@@ -1,3 +1,4 @@
+from pilot.embedding_engine.knowledge_type import KnowledgeType
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -44,8 +45,8 @@ class ChatUrlKnowledge(BaseChat):
         self.knowledge_embedding_client = KnowledgeEmbedding(
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
-            file_type="url",
-            file_path=url,
+            knowledge_type=KnowledgeType.URL.value,
+            knowledge_source=url,
         )
 
         # url soruce in vector
@@ -13,3 +13,5 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
     load_dotenv(verbose=True, override=True)
 
 del load_dotenv
+
+
@@ -23,7 +23,6 @@ from pilot.configs.model_config import *
 from pilot.model.llm_out.vicuna_base_llm import get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.server.chat_adapter import get_llm_chat_adapter
-from knowledge.knowledge_controller import router
 
 CFG = Config()
 
@@ -106,6 +105,7 @@ worker = ModelWorker(
 )
 
 app = FastAPI()
+from pilot.openapi.knowledge.knowledge_controller import router
 app.include_router(router)
 
 origins = [
@@ -119,7 +119,7 @@ app.add_middleware(
     allow_origins=origins,
     allow_credentials=True,
     allow_methods=["*"],
-    allow_headers=["*"],
+    allow_headers=["*"]
 )
 
 class PromptRequest(BaseModel):
@@ -11,6 +11,8 @@ import uuid
 
 import gradio as gr
 
+from pilot.embedding_engine.knowledge_type import KnowledgeType
+
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
 
@@ -57,7 +59,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
 from fastapi.staticfiles import StaticFiles
 
-from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
 
 # 加载插件
 CFG = Config()
@@ -652,8 +654,9 @@ def knowledge_embedding_store(vs_id, files):
         file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
     )
     knowledge_embedding_client = KnowledgeEmbedding(
-        file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
-        model_name=LLM_MODEL_CONFIG["text2vec"],
+        knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
+        knowledge_type=KnowledgeType.DOCUMENT.value,
+        model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
         vector_store_config={
             "vector_store_name": vector_store_name["vs_name"],
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -4,27 +4,11 @@ from chromadb.errors import NotEnoughElementsException
 from langchain.embeddings import HuggingFaceEmbeddings
 
 from pilot.configs.config import Config
-from pilot.source_embedding.csv_embedding import CSVEmbedding
-from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
-from pilot.source_embedding.pdf_embedding import PDFEmbedding
-from pilot.source_embedding.ppt_embedding import PPTEmbedding
-from pilot.source_embedding.url_embedding import URLEmbedding
-from pilot.source_embedding.word_embedding import WordEmbedding
+from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
 from pilot.vector_store.connector import VectorStoreConnector
 
 CFG = Config()
-
-KnowledgeEmbeddingType = {
-    ".txt": (MarkdownEmbedding, {}),
-    ".md": (MarkdownEmbedding, {}),
-    ".pdf": (PDFEmbedding, {}),
-    ".doc": (WordEmbedding, {}),
-    ".docx": (WordEmbedding, {}),
-    ".csv": (CSVEmbedding, {}),
-    ".ppt": (PPTEmbedding, {}),
-    ".pptx": (PPTEmbedding, {}),
-}
 
 
 class KnowledgeEmbedding:
     def __init__(
@@ -54,23 +38,7 @@ class KnowledgeEmbedding:
         return self.knowledge_embedding_client.read_batch()
 
     def init_knowledge_embedding(self):
-        if self.file_type == "url":
-            embedding = URLEmbedding(
-                file_path=self.file_path,
-                vector_store_config=self.vector_store_config,
-            )
-            return embedding
-        extension = "." + self.file_path.rsplit(".", 1)[-1]
-        if extension in KnowledgeEmbeddingType:
-            knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
-            embedding = knowledge_class(
-                self.file_path,
-                vector_store_config=self.vector_store_config,
-                **knowledge_args,
-            )
-            return embedding
-        raise ValueError(f"Unsupported knowledge file type '{extension}'")
-        return embedding
+        return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config)
 
     def similar_search(self, text, topk):
         vector_client = VectorStoreConnector(
@@ -4,6 +4,8 @@ import argparse
 import os
 import sys
 
+from pilot.embedding_engine.knowledge_type import KnowledgeType
+
 sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
 
 from pilot.configs.config import Config
@@ -30,7 +32,8 @@ class LocalKnowledgeInit:
             filename = os.path.join(root, file)
             # docs = self._load_file(filename)
             ke = KnowledgeEmbedding(
-                file_path=filename,
+                knowledge_source=filename,
+                knowledge_type=KnowledgeType.DOCUMENT.value,
                 model_name=self.model_name,
                 vector_store_config=self.vector_store_config,
            )