feat: knowledge backend management

1. knowledge_type.py
2. knowledge backend API
This commit is contained in:
aries_ckt
2023-06-27 15:29:13 +08:00
parent 0dd2e5e12c
commit a06342425b
16 changed files with 153 additions and 100 deletions

View File

@@ -4,27 +4,11 @@ from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.ppt_embedding import PPTEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding
from pilot.source_embedding.word_embedding import WordEmbedding
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
from pilot.vector_store.connector import VectorStoreConnector
CFG = Config()
KnowledgeEmbeddingType = {
".txt": (MarkdownEmbedding, {}),
".md": (MarkdownEmbedding, {}),
".pdf": (PDFEmbedding, {}),
".doc": (WordEmbedding, {}),
".docx": (WordEmbedding, {}),
".csv": (CSVEmbedding, {}),
".ppt": (PPTEmbedding, {}),
".pptx": (PPTEmbedding, {}),
}
class KnowledgeEmbedding:
def __init__(
@@ -54,23 +38,7 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
if self.file_type == "url":
embedding = URLEmbedding(
file_path=self.file_path,
vector_store_config=self.vector_store_config,
)
return embedding
extension = "." + self.file_path.rsplit(".", 1)[-1]
if extension in KnowledgeEmbeddingType:
knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
embedding = knowledge_class(
self.file_path,
vector_store_config=self.vector_store_config,
**knowledge_args,
)
return embedding
raise ValueError(f"Unsupported knowledge file type '{extension}'")
return embedding
return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config)
def similar_search(self, text, topk):
vector_client = VectorStoreConnector(