mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-28 21:02:08 +00:00
feature:ppt embedding
This commit is contained in:
parent
46af29dd36
commit
6a7c4aa5f6
@ -38,7 +38,7 @@ class ChatUrlKnowledge(BaseChat):
|
||||
)
|
||||
self.url = url
|
||||
vector_store_config = {
|
||||
"vector_store_name": url,
|
||||
"vector_store_name": url.replace(":", ""),
|
||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
}
|
||||
self.knowledge_embedding_client = KnowledgeEmbedding(
|
||||
|
@ -1,11 +1,13 @@
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.errors import NotEnoughElementsException
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
||||
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
|
||||
from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
||||
from pilot.source_embedding.ppt_embedding import PPTEmbedding
|
||||
from pilot.source_embedding.url_embedding import URLEmbedding
|
||||
from pilot.source_embedding.word_embedding import WordEmbedding
|
||||
from pilot.vector_store.connector import VectorStoreConnector
|
||||
@ -19,6 +21,8 @@ KnowledgeEmbeddingType = {
|
||||
".doc": (WordEmbedding, {}),
|
||||
".docx": (WordEmbedding, {}),
|
||||
".csv": (CSVEmbedding, {}),
|
||||
".ppt": (PPTEmbedding, {}),
|
||||
".pptx": (PPTEmbedding, {}),
|
||||
}
|
||||
|
||||
|
||||
@ -42,8 +46,12 @@ class KnowledgeEmbedding:
|
||||
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
||||
self.knowledge_embedding_client.source_embedding()
|
||||
|
||||
def knowledge_embedding_batch(self):
|
||||
self.knowledge_embedding_client.batch_embedding()
|
||||
def knowledge_embedding_batch(self, docs):
|
||||
# docs = self.knowledge_embedding_client.read_batch()
|
||||
self.knowledge_embedding_client.index_to_store(docs)
|
||||
|
||||
def read(self):
|
||||
return self.knowledge_embedding_client.read_batch()
|
||||
|
||||
def init_knowledge_embedding(self):
|
||||
if self.file_type == "url":
|
||||
@ -68,7 +76,11 @@ class KnowledgeEmbedding:
|
||||
vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
)
|
||||
return vector_client.similar_search(text, topk)
|
||||
try:
|
||||
ans = vector_client.similar_search(text, topk)
|
||||
except NotEnoughElementsException:
|
||||
ans = vector_client.similar_search(text, 1)
|
||||
return ans
|
||||
|
||||
def vector_exist(self):
|
||||
vector_client = VectorStoreConnector(
|
||||
|
@ -5,8 +5,8 @@ from typing import List
|
||||
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders import TextLoader
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import SpacyTextSplitter
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
@ -30,32 +30,8 @@ class MarkdownEmbedding(SourceEmbedding):
|
||||
def read(self):
|
||||
"""Load from markdown path."""
|
||||
loader = EncodeTextLoader(self.file_path)
|
||||
text_splitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
||||
)
|
||||
return loader.load_and_split(text_splitter)
|
||||
|
||||
@register
|
||||
def read_batch(self):
|
||||
"""Load from markdown path."""
|
||||
docments = []
|
||||
for root, _, files in os.walk(self.file_path, topdown=False):
|
||||
for file in files:
|
||||
filename = os.path.join(root, file)
|
||||
loader = TextLoader(filename)
|
||||
# text_splitor = CHNDocumentSplitter(chunk_size=1000, chunk_overlap=20, length_function=len)
|
||||
# docs = loader.load_and_split()
|
||||
docs = loader.load()
|
||||
# 更新metadata数据
|
||||
new_docs = []
|
||||
for doc in docs:
|
||||
doc.metadata = {
|
||||
"source": doc.metadata["source"].replace(self.file_path, "")
|
||||
}
|
||||
print("doc is embedding ... ", doc.metadata)
|
||||
new_docs.append(doc)
|
||||
docments += new_docs
|
||||
return docments
|
||||
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
|
||||
return loader.load_and_split(textsplitter)
|
||||
|
||||
@register
|
||||
def data_process(self, documents: List[Document]):
|
||||
|
@ -29,7 +29,7 @@ class PDFEmbedding(SourceEmbedding):
|
||||
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
||||
# )
|
||||
textsplitter = SpacyTextSplitter(
|
||||
pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
|
||||
pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
|
||||
)
|
||||
return loader.load_and_split(textsplitter)
|
||||
|
||||
|
37
pilot/source_embedding/ppt_embedding.py
Normal file
37
pilot/source_embedding/ppt_embedding.py
Normal file
@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders import UnstructuredPowerPointLoader
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import SpacyTextSplitter
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.source_embedding import SourceEmbedding, register
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class PPTEmbedding(SourceEmbedding):
|
||||
"""ppt embedding for read ppt document."""
|
||||
|
||||
def __init__(self, file_path, vector_store_config):
|
||||
"""Initialize with pdf path."""
|
||||
super().__init__(file_path, vector_store_config)
|
||||
self.file_path = file_path
|
||||
self.vector_store_config = vector_store_config
|
||||
|
||||
@register
|
||||
def read(self):
|
||||
"""Load from ppt path."""
|
||||
loader = UnstructuredPowerPointLoader(self.file_path)
|
||||
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
|
||||
return loader.load_and_split(textsplitter)
|
||||
|
||||
@register
|
||||
def data_process(self, documents: List[Document]):
|
||||
i = 0
|
||||
for d in documents:
|
||||
documents[i].page_content = d.page_content.replace("\n", "")
|
||||
i += 1
|
||||
return documents
|
@ -2,6 +2,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from chromadb.errors import NotEnoughElementsException
|
||||
from pilot.configs.config import Config
|
||||
from pilot.vector_store.connector import VectorStoreConnector
|
||||
|
||||
@ -62,7 +64,11 @@ class SourceEmbedding(ABC):
|
||||
@register
|
||||
def similar_search(self, doc, topk):
|
||||
"""vector store similarity_search"""
|
||||
return self.vector_client.similar_search(doc, topk)
|
||||
try:
|
||||
ans = self.vector_client.similar_search(doc, topk)
|
||||
except NotEnoughElementsException:
|
||||
ans = self.vector_client.similar_search(doc, 1)
|
||||
return ans
|
||||
|
||||
def vector_name_exist(self):
|
||||
return self.vector_client.vector_name_exists()
|
||||
@ -79,14 +85,11 @@ class SourceEmbedding(ABC):
|
||||
if "index_to_store" in registered_methods:
|
||||
self.index_to_store(text)
|
||||
|
||||
def batch_embedding(self):
|
||||
if "read_batch" in registered_methods:
|
||||
text = self.read_batch()
|
||||
def read_batch(self):
|
||||
if "read" in registered_methods:
|
||||
text = self.read()
|
||||
if "data_process" in registered_methods:
|
||||
text = self.data_process(text)
|
||||
if "text_split" in registered_methods:
|
||||
self.text_split(text)
|
||||
if "text_to_vector" in registered_methods:
|
||||
self.text_to_vector(text)
|
||||
if "index_to_store" in registered_methods:
|
||||
self.index_to_store(text)
|
||||
return text
|
||||
|
@ -23,7 +23,7 @@ class LocalKnowledgeInit:
|
||||
self.vector_store_config = vector_store_config
|
||||
self.model_name = LLM_MODEL_CONFIG["text2vec"]
|
||||
|
||||
def knowledge_persist(self, file_path, append_mode):
|
||||
def knowledge_persist(self, file_path):
|
||||
"""knowledge persist"""
|
||||
for root, _, files in os.walk(file_path, topdown=False):
|
||||
for file in files:
|
||||
@ -41,7 +41,6 @@ class LocalKnowledgeInit:
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--vector_name", type=str, default="default")
|
||||
parser.add_argument("--append", type=bool, default=False)
|
||||
args = parser.parse_args()
|
||||
vector_name = args.vector_name
|
||||
append_mode = args.append
|
||||
@ -49,5 +48,5 @@ if __name__ == "__main__":
|
||||
vector_store_config = {"vector_store_name": vector_name}
|
||||
print(vector_store_config)
|
||||
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
|
||||
kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
|
||||
kv.knowledge_persist(file_path=DATASETS_DIR)
|
||||
print("your knowledge embedding success...")
|
||||
|
Loading…
Reference in New Issue
Block a user