mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-19 00:37:34 +00:00)

feat: knowledge backend management

1. knowledge_type.py
2. knowledge backend API

parent 0dd2e5e12c
commit a06342425b
@@ -5,6 +5,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from pilot.configs.config import Config
 from pilot.embedding_engine.csv_embedding import CSVEmbedding
+from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
 from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
 from pilot.embedding_engine.pdf_embedding import PDFEmbedding
 from pilot.embedding_engine.ppt_embedding import PPTEmbedding
@@ -14,16 +15,16 @@ from pilot.vector_store.connector import VectorStoreConnector
 
 CFG = Config()
 
-KnowledgeEmbeddingType = {
-    ".txt": (MarkdownEmbedding, {}),
-    ".md": (MarkdownEmbedding, {}),
-    ".pdf": (PDFEmbedding, {}),
-    ".doc": (WordEmbedding, {}),
-    ".docx": (WordEmbedding, {}),
-    ".csv": (CSVEmbedding, {}),
-    ".ppt": (PPTEmbedding, {}),
-    ".pptx": (PPTEmbedding, {}),
-}
+# KnowledgeEmbeddingType = {
+#     ".txt": (MarkdownEmbedding, {}),
+#     ".md": (MarkdownEmbedding, {}),
+#     ".pdf": (PDFEmbedding, {}),
+#     ".doc": (WordEmbedding, {}),
+#     ".docx": (WordEmbedding, {}),
+#     ".csv": (CSVEmbedding, {}),
+#     ".ppt": (PPTEmbedding, {}),
+#     ".pptx": (PPTEmbedding, {}),
+# }
 
 
 class KnowledgeEmbedding:
@@ -31,14 +32,14 @@ class KnowledgeEmbedding:
         self,
         model_name,
         vector_store_config,
-        file_type: Optional[str] = "default",
-        file_path: Optional[str] = None,
+        knowledge_type: Optional[str],
+        knowledge_source: Optional[str] = None,
     ):
-        """Initialize with Loader url, model_name, vector_store_config"""
-        self.file_path = file_path
+        """Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source"""
+        self.knowledge_source = knowledge_source
         self.model_name = model_name
         self.vector_store_config = vector_store_config
-        self.file_type = file_type
+        self.knowledge_type = knowledge_type
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
         self.vector_store_config["embeddings"] = self.embeddings
 
@@ -55,23 +56,24 @@ class KnowledgeEmbedding:
         return self.knowledge_embedding_client.read_batch()
 
     def init_knowledge_embedding(self):
-        if self.file_type == "url":
-            embedding = URLEmbedding(
-                file_path=self.file_path,
-                vector_store_config=self.vector_store_config,
-            )
-            return embedding
-        extension = "." + self.file_path.rsplit(".", 1)[-1]
-        if extension in KnowledgeEmbeddingType:
-            knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
-            embedding = knowledge_class(
-                self.file_path,
-                vector_store_config=self.vector_store_config,
-                **knowledge_args
-            )
-            return embedding
-        raise ValueError(f"Unsupported knowledge file type '{extension}'")
-        return embedding
+        return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
+        # if self.file_type == "url":
+        #     embedding = URLEmbedding(
+        #         file_path=self.file_path,
+        #         vector_store_config=self.vector_store_config,
+        #     )
+        #     return embedding
+        # extension = "." + self.file_path.rsplit(".", 1)[-1]
+        # if extension in KnowledgeEmbeddingType:
+        #     knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
+        #     embedding = knowledge_class(
+        #         self.file_path,
+        #         vector_store_config=self.vector_store_config,
+        #         **knowledge_args
+        #     )
+        #     return embedding
+        # raise ValueError(f"Unsupported knowledge file type '{extension}'")
+        # return embedding
 
     def similar_search(self, text, topk):
         vector_client = VectorStoreConnector(
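For illustration, a minimal sketch of how the refactored client is constructed after this change; the model path, vector store values, and source file below are assumptions, not values taken from this commit:

from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType

client = KnowledgeEmbedding(
    model_name="/models/text2vec-large-chinese",    # assumed embedding model path
    vector_store_config={
        "vector_store_name": "default",             # assumed knowledge space name
        "vector_store_path": "/data/vector_store",  # assumed storage path
    },
    knowledge_type=KnowledgeType.DOCUMENT.value,    # dispatched by file extension
    knowledge_source="/data/docs/intro.md",         # assumed markdown document
)
docs = client.similar_search("what is DB-GPT?", topk=5)  # similar_search as shown in the diff above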
pilot/embedding_engine/knowledge_type.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+from enum import Enum
+
+from pilot.embedding_engine.csv_embedding import CSVEmbedding
+from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
+from pilot.embedding_engine.pdf_embedding import PDFEmbedding
+from pilot.embedding_engine.ppt_embedding import PPTEmbedding
+from pilot.embedding_engine.string_embedding import StringEmbedding
+from pilot.embedding_engine.url_embedding import URLEmbedding
+from pilot.embedding_engine.word_embedding import WordEmbedding
+
+DocumentEmbeddingType = {
+    ".txt": (MarkdownEmbedding, {}),
+    ".md": (MarkdownEmbedding, {}),
+    ".pdf": (PDFEmbedding, {}),
+    ".doc": (WordEmbedding, {}),
+    ".docx": (WordEmbedding, {}),
+    ".csv": (CSVEmbedding, {}),
+    ".ppt": (PPTEmbedding, {}),
+    ".pptx": (PPTEmbedding, {}),
+}
+
+
+class KnowledgeType(Enum):
+    DOCUMENT = "DOCUMENT"
+    URL = "URL"
+    TEXT = "TEXT"
+    OSS = "OSS"
+    NOTION = "NOTION"
+
+
+def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config):
+    match knowledge_type:
+        case KnowledgeType.DOCUMENT.value:
+            extension = "." + knowledge_source.rsplit(".", 1)[-1]
+            if extension in DocumentEmbeddingType:
+                knowledge_class, knowledge_args = DocumentEmbeddingType[extension]
+                embedding = knowledge_class(
+                    knowledge_source,
+                    vector_store_config=vector_store_config,
+                    **knowledge_args,
+                )
+                return embedding
+            raise ValueError(f"Unsupported knowledge document type '{extension}'")
+        case KnowledgeType.URL.value:
+            embedding = URLEmbedding(
+                file_path=knowledge_source,
+                vector_store_config=vector_store_config,
+            )
+            return embedding
+        case KnowledgeType.TEXT.value:
+            embedding = StringEmbedding(
+                file_path=knowledge_source,
+                vector_store_config=vector_store_config,
+            )
+            return embedding
+        case KnowledgeType.OSS.value:
+            raise Exception("OSS have not integrate")
+        case KnowledgeType.NOTION.value:
+            raise Exception("NOTION have not integrate")
+
+        case _:
+            raise Exception("unknown knowledge type")
@@ -9,7 +9,7 @@ from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
 from typing import List
 
-from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
+from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo
 from pilot.configs.config import Config
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
@@ -1,26 +1,22 @@
-import json
-import os
-import sys
-from typing import List
+from tempfile import NamedTemporaryFile
 
-from fastapi import APIRouter
+from fastapi import APIRouter, File, UploadFile
 from langchain.embeddings import HuggingFaceEmbeddings
-ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-sys.path.append(ROOT_PATH)
 
 from pilot.configs.config import Config
 from pilot.configs.model_config import LLM_MODEL_CONFIG
 
-from pilot.server.api_v1.api_view_model import Result
+from pilot.openapi.api_v1.api_view_model import Result
 from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
 
-from pilot.server.knowledge.knowledge_service import KnowledgeService
-from pilot.server.knowledge.request.knowledge_request import (
+from pilot.openapi.knowledge.knowledge_service import KnowledgeService
+from pilot.openapi.knowledge.request.knowledge_request import (
     KnowledgeQueryRequest,
     KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
 )
 
-from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest
+from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
 
 CFG = Config()
 router = APIRouter()
@@ -74,6 +70,21 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
         return Result.faild(code="E000X", msg=f"document list error {e}")
 
 
+@router.post("/knowledge/{space_name}/document/upload")
+def document_sync(space_name: str, file: UploadFile = File(...)):
+    print(f"/document/upload params: {space_name}")
+    try:
+        with NamedTemporaryFile(delete=False) as tmp:
+            tmp.write(file.read())
+            tmp_path = tmp.name
+            tmp_content = tmp.read()
+
+        return {"file_path": tmp_path, "file_content": tmp_content}
+        Result.succ([])
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"document sync error {e}")
+
+
 @router.post("/knowledge/{space_name}/document/sync")
 def document_sync(space_name: str, request: DocumentSyncRequest):
     print(f"Received params: {space_name}, {request}")
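For reference, a hypothetical client call against the new upload route; the host, port, space name, and file are assumptions:

import requests  # assumed client-side dependency, not part of this commit

with open("report.pdf", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:5000/knowledge/default/document/upload",  # assumed server address and space
        files={"file": ("report.pdf", f)},
    )
print(resp.json())  # the handler above returns {"file_path": ..., "file_content": ...}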
@@ -90,7 +101,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
 def document_list(space_name: str, query_request: ChunkQueryRequest):
     print(f"/document/list params: {space_name}, {query_request}")
     try:
-        Result.succ(knowledge_space_service.get_document_chunks(
+        return Result.succ(knowledge_space_service.get_document_chunks(
             query_request
         ))
     except Exception as e:
@@ -9,6 +9,8 @@ from pilot.configs.config import Config
 CFG = Config()
 
 Base = declarative_base()
+
+
 class KnowledgeDocumentEntity(Base):
     __tablename__ = 'knowledge_document'
     id = Column(Integer, primary_key=True)
@@ -4,14 +4,15 @@ from datetime import datetime
 from pilot.configs.config import Config
 from pilot.configs.model_config import LLM_MODEL_CONFIG
 from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
+from pilot.embedding_engine.knowledge_type import KnowledgeType
 from pilot.logs import logger
-from pilot.server.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
-from pilot.server.knowledge.knowledge_document_dao import (
+from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
+from pilot.openapi.knowledge.knowledge_document_dao import (
     KnowledgeDocumentDao,
     KnowledgeDocumentEntity,
 )
-from pilot.server.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
-from pilot.server.knowledge.request.knowledge_request import (
+from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
+from pilot.openapi.knowledge.request.knowledge_request import (
     KnowledgeSpaceRequest,
     KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
 )
@@ -24,6 +25,7 @@ document_chunk_dao = DocumentChunkDao()
 
 CFG=Config()
 
+
 class SyncStatus(Enum):
     TODO = "TODO"
     FAILED = "FAILED"
@@ -65,7 +67,7 @@ class KnowledgeService:
             chunk_size=0,
             status=SyncStatus.TODO.name,
             last_sync=datetime.now(),
-            content="",
+            content=request.content,
         )
         knowledge_document_dao.create_knowledge_document(document)
         return True
@@ -99,8 +101,8 @@ class KnowledgeService:
             space=space_name,
         )
         doc = knowledge_document_dao.get_knowledge_documents(query)[0]
-        client = KnowledgeEmbedding(file_path=doc.doc_name,
-                                    file_type="url",
+        client = KnowledgeEmbedding(knowledge_source=doc.content,
+                                    knowledge_type=doc.doc_type.upper(),
                                     model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                                     vector_store_config={
                                         "vector_store_name": space_name,
@@ -127,11 +129,6 @@ class KnowledgeService:
         )
             for chunk_doc in chunk_docs]
         document_chunk_dao.create_documents_chunks(chunk_entities)
-        #update document status
-        # doc.status = SyncStatus.RUNNING.name
-        # doc.chunk_size = len(chunk_docs)
-        # doc.gmt_modified = datetime.now()
-        # knowledge_document_dao.update_knowledge_document(doc)
 
         return True
 
@@ -5,7 +5,7 @@ from sqlalchemy.ext.declarative import declarative_base
 
 from pilot.configs.config import Config
 
-from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest
+from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
 from sqlalchemy.orm import sessionmaker
 
 CFG = Config()
@@ -29,6 +29,8 @@ class KnowledgeDocumentRequest(BaseModel):
     doc_name: str
     """doc_type: doc type"""
     doc_type: str
+    """content: content"""
+    content: str
    """text_chunk_size: text_chunk_size"""
     # text_chunk_size: int
 
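The new content field is what the sync path consumes: the service now stores request.content on the document record and later passes doc.content as knowledge_source with doc.doc_type.upper() as the knowledge type (see the KnowledgeService hunks above). A hypothetical request body, with assumed field values:

from pilot.openapi.knowledge.request.knowledge_request import KnowledgeDocumentRequest

doc_request = KnowledgeDocumentRequest(
    doc_name="db-gpt-readme",                    # assumed document name
    doc_type="url",                              # upper-cased to KnowledgeType.URL on sync
    content="https://github.com/csunny/DB-GPT",  # consumed as knowledge_source during sync
)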
@@ -1,3 +1,4 @@
+from pilot.embedding_engine.knowledge_type import KnowledgeType
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -44,8 +45,8 @@ class ChatUrlKnowledge(BaseChat):
         self.knowledge_embedding_client = KnowledgeEmbedding(
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
-            file_type="url",
-            file_path=url,
+            knowledge_type=KnowledgeType.URL.value,
+            knowledge_source=url,
         )
 
         # url soruce in vector
@@ -13,3 +13,5 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
     load_dotenv(verbose=True, override=True)
 
 del load_dotenv
+
+
@@ -23,7 +23,6 @@ from pilot.configs.model_config import *
 from pilot.model.llm_out.vicuna_base_llm import get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.server.chat_adapter import get_llm_chat_adapter
-from knowledge.knowledge_controller import router
 
 CFG = Config()
 
@@ -106,6 +105,7 @@ worker = ModelWorker(
 )
 
 app = FastAPI()
+from pilot.openapi.knowledge.knowledge_controller import router
 app.include_router(router)
 
 origins = [
@@ -119,7 +119,7 @@ app.add_middleware(
     allow_origins=origins,
     allow_credentials=True,
     allow_methods=["*"],
-    allow_headers=["*"],
+    allow_headers=["*"]
 )
 
 class PromptRequest(BaseModel):
@@ -11,6 +11,8 @@ import uuid
 
 import gradio as gr
 
+from pilot.embedding_engine.knowledge_type import KnowledgeType
+
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
 
@@ -57,7 +59,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
 from fastapi.staticfiles import StaticFiles
 
-from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
 
 # Load plugins
 CFG = Config()
@@ -652,8 +654,9 @@ def knowledge_embedding_store(vs_id, files):
         file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
     )
     knowledge_embedding_client = KnowledgeEmbedding(
-        file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
-        model_name=LLM_MODEL_CONFIG["text2vec"],
+        knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
+        knowledge_type=KnowledgeType.DOCUMENT.value,
+        model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
         vector_store_config={
             "vector_store_name": vector_store_name["vs_name"],
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -4,27 +4,11 @@ from chromadb.errors import NotEnoughElementsException
 from langchain.embeddings import HuggingFaceEmbeddings
 
 from pilot.configs.config import Config
-from pilot.source_embedding.csv_embedding import CSVEmbedding
-from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
-from pilot.source_embedding.pdf_embedding import PDFEmbedding
-from pilot.source_embedding.ppt_embedding import PPTEmbedding
-from pilot.source_embedding.url_embedding import URLEmbedding
-from pilot.source_embedding.word_embedding import WordEmbedding
+from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
 from pilot.vector_store.connector import VectorStoreConnector
 
 CFG = Config()
 
-KnowledgeEmbeddingType = {
-    ".txt": (MarkdownEmbedding, {}),
-    ".md": (MarkdownEmbedding, {}),
-    ".pdf": (PDFEmbedding, {}),
-    ".doc": (WordEmbedding, {}),
-    ".docx": (WordEmbedding, {}),
-    ".csv": (CSVEmbedding, {}),
-    ".ppt": (PPTEmbedding, {}),
-    ".pptx": (PPTEmbedding, {}),
-}
-
 
 class KnowledgeEmbedding:
     def __init__(
@@ -54,23 +38,7 @@ class KnowledgeEmbedding:
         return self.knowledge_embedding_client.read_batch()
 
     def init_knowledge_embedding(self):
-        if self.file_type == "url":
-            embedding = URLEmbedding(
-                file_path=self.file_path,
-                vector_store_config=self.vector_store_config,
-            )
-            return embedding
-        extension = "." + self.file_path.rsplit(".", 1)[-1]
-        if extension in KnowledgeEmbeddingType:
-            knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
-            embedding = knowledge_class(
-                self.file_path,
-                vector_store_config=self.vector_store_config,
-                **knowledge_args,
-            )
-            return embedding
-        raise ValueError(f"Unsupported knowledge file type '{extension}'")
-        return embedding
+        return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config)
 
     def similar_search(self, text, topk):
         vector_client = VectorStoreConnector(
@@ -4,6 +4,8 @@ import argparse
 import os
 import sys
 
+from pilot.embedding_engine.knowledge_type import KnowledgeType
+
 sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
 
 from pilot.configs.config import Config
@@ -30,7 +32,8 @@ class LocalKnowledgeInit:
             filename = os.path.join(root, file)
             # docs = self._load_file(filename)
             ke = KnowledgeEmbedding(
-                file_path=filename,
+                knowledge_source=filename,
+                knowledge_type=KnowledgeType.DOCUMENT.value,
                 model_name=self.model_name,
                 vector_store_config=self.vector_store_config,
             )