mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-28 21:02:08 +00:00
feature:ppt embedding
This commit is contained in:
parent
46af29dd36
commit
6a7c4aa5f6
@ -38,7 +38,7 @@ class ChatUrlKnowledge(BaseChat):
|
|||||||
)
|
)
|
||||||
self.url = url
|
self.url = url
|
||||||
vector_store_config = {
|
vector_store_config = {
|
||||||
"vector_store_name": url,
|
"vector_store_name": url.replace(":", ""),
|
||||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
}
|
}
|
||||||
self.knowledge_embedding_client = KnowledgeEmbedding(
|
self.knowledge_embedding_client = KnowledgeEmbedding(
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from chromadb.errors import NotEnoughElementsException
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
from pilot.source_embedding.csv_embedding import CSVEmbedding
|
||||||
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
|
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
|
||||||
from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
||||||
|
from pilot.source_embedding.ppt_embedding import PPTEmbedding
|
||||||
from pilot.source_embedding.url_embedding import URLEmbedding
|
from pilot.source_embedding.url_embedding import URLEmbedding
|
||||||
from pilot.source_embedding.word_embedding import WordEmbedding
|
from pilot.source_embedding.word_embedding import WordEmbedding
|
||||||
from pilot.vector_store.connector import VectorStoreConnector
|
from pilot.vector_store.connector import VectorStoreConnector
|
||||||
@ -19,6 +21,8 @@ KnowledgeEmbeddingType = {
|
|||||||
".doc": (WordEmbedding, {}),
|
".doc": (WordEmbedding, {}),
|
||||||
".docx": (WordEmbedding, {}),
|
".docx": (WordEmbedding, {}),
|
||||||
".csv": (CSVEmbedding, {}),
|
".csv": (CSVEmbedding, {}),
|
||||||
|
".ppt": (PPTEmbedding, {}),
|
||||||
|
".pptx": (PPTEmbedding, {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -42,8 +46,12 @@ class KnowledgeEmbedding:
|
|||||||
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
||||||
self.knowledge_embedding_client.source_embedding()
|
self.knowledge_embedding_client.source_embedding()
|
||||||
|
|
||||||
def knowledge_embedding_batch(self):
|
def knowledge_embedding_batch(self, docs):
|
||||||
self.knowledge_embedding_client.batch_embedding()
|
# docs = self.knowledge_embedding_client.read_batch()
|
||||||
|
self.knowledge_embedding_client.index_to_store(docs)
|
||||||
|
|
||||||
|
def read(self):
|
||||||
|
return self.knowledge_embedding_client.read_batch()
|
||||||
|
|
||||||
def init_knowledge_embedding(self):
|
def init_knowledge_embedding(self):
|
||||||
if self.file_type == "url":
|
if self.file_type == "url":
|
||||||
@ -68,7 +76,11 @@ class KnowledgeEmbedding:
|
|||||||
vector_client = VectorStoreConnector(
|
vector_client = VectorStoreConnector(
|
||||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||||
)
|
)
|
||||||
return vector_client.similar_search(text, topk)
|
try:
|
||||||
|
ans = vector_client.similar_search(text, topk)
|
||||||
|
except NotEnoughElementsException:
|
||||||
|
ans = vector_client.similar_search(text, 1)
|
||||||
|
return ans
|
||||||
|
|
||||||
def vector_exist(self):
|
def vector_exist(self):
|
||||||
vector_client = VectorStoreConnector(
|
vector_client = VectorStoreConnector(
|
||||||
|
@ -5,8 +5,8 @@ from typing import List
|
|||||||
|
|
||||||
import markdown
|
import markdown
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from langchain.document_loaders import TextLoader
|
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
|
from langchain.text_splitter import SpacyTextSplitter
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.source_embedding import SourceEmbedding, register
|
from pilot.source_embedding import SourceEmbedding, register
|
||||||
@ -30,32 +30,8 @@ class MarkdownEmbedding(SourceEmbedding):
|
|||||||
def read(self):
|
def read(self):
|
||||||
"""Load from markdown path."""
|
"""Load from markdown path."""
|
||||||
loader = EncodeTextLoader(self.file_path)
|
loader = EncodeTextLoader(self.file_path)
|
||||||
text_splitter = CHNDocumentSplitter(
|
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
|
||||||
pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
return loader.load_and_split(textsplitter)
|
||||||
)
|
|
||||||
return loader.load_and_split(text_splitter)
|
|
||||||
|
|
||||||
@register
|
|
||||||
def read_batch(self):
|
|
||||||
"""Load from markdown path."""
|
|
||||||
docments = []
|
|
||||||
for root, _, files in os.walk(self.file_path, topdown=False):
|
|
||||||
for file in files:
|
|
||||||
filename = os.path.join(root, file)
|
|
||||||
loader = TextLoader(filename)
|
|
||||||
# text_splitor = CHNDocumentSplitter(chunk_size=1000, chunk_overlap=20, length_function=len)
|
|
||||||
# docs = loader.load_and_split()
|
|
||||||
docs = loader.load()
|
|
||||||
# 更新metadata数据
|
|
||||||
new_docs = []
|
|
||||||
for doc in docs:
|
|
||||||
doc.metadata = {
|
|
||||||
"source": doc.metadata["source"].replace(self.file_path, "")
|
|
||||||
}
|
|
||||||
print("doc is embedding ... ", doc.metadata)
|
|
||||||
new_docs.append(doc)
|
|
||||||
docments += new_docs
|
|
||||||
return docments
|
|
||||||
|
|
||||||
@register
|
@register
|
||||||
def data_process(self, documents: List[Document]):
|
def data_process(self, documents: List[Document]):
|
||||||
|
@ -29,7 +29,7 @@ class PDFEmbedding(SourceEmbedding):
|
|||||||
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
|
||||||
# )
|
# )
|
||||||
textsplitter = SpacyTextSplitter(
|
textsplitter = SpacyTextSplitter(
|
||||||
pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
|
pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
|
||||||
)
|
)
|
||||||
return loader.load_and_split(textsplitter)
|
return loader.load_and_split(textsplitter)
|
||||||
|
|
||||||
|
37
pilot/source_embedding/ppt_embedding.py
Normal file
37
pilot/source_embedding/ppt_embedding.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.document_loaders import UnstructuredPowerPointLoader
|
||||||
|
from langchain.schema import Document
|
||||||
|
from langchain.text_splitter import SpacyTextSplitter
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.source_embedding import SourceEmbedding, register
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
class PPTEmbedding(SourceEmbedding):
|
||||||
|
"""ppt embedding for read ppt document."""
|
||||||
|
|
||||||
|
def __init__(self, file_path, vector_store_config):
|
||||||
|
"""Initialize with pdf path."""
|
||||||
|
super().__init__(file_path, vector_store_config)
|
||||||
|
self.file_path = file_path
|
||||||
|
self.vector_store_config = vector_store_config
|
||||||
|
|
||||||
|
@register
|
||||||
|
def read(self):
|
||||||
|
"""Load from ppt path."""
|
||||||
|
loader = UnstructuredPowerPointLoader(self.file_path)
|
||||||
|
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
|
||||||
|
return loader.load_and_split(textsplitter)
|
||||||
|
|
||||||
|
@register
|
||||||
|
def data_process(self, documents: List[Document]):
|
||||||
|
i = 0
|
||||||
|
for d in documents:
|
||||||
|
documents[i].page_content = d.page_content.replace("\n", "")
|
||||||
|
i += 1
|
||||||
|
return documents
|
@ -2,6 +2,8 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from chromadb.errors import NotEnoughElementsException
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.vector_store.connector import VectorStoreConnector
|
from pilot.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
@ -62,7 +64,11 @@ class SourceEmbedding(ABC):
|
|||||||
@register
|
@register
|
||||||
def similar_search(self, doc, topk):
|
def similar_search(self, doc, topk):
|
||||||
"""vector store similarity_search"""
|
"""vector store similarity_search"""
|
||||||
return self.vector_client.similar_search(doc, topk)
|
try:
|
||||||
|
ans = self.vector_client.similar_search(doc, topk)
|
||||||
|
except NotEnoughElementsException:
|
||||||
|
ans = self.vector_client.similar_search(doc, 1)
|
||||||
|
return ans
|
||||||
|
|
||||||
def vector_name_exist(self):
|
def vector_name_exist(self):
|
||||||
return self.vector_client.vector_name_exists()
|
return self.vector_client.vector_name_exists()
|
||||||
@ -79,14 +85,11 @@ class SourceEmbedding(ABC):
|
|||||||
if "index_to_store" in registered_methods:
|
if "index_to_store" in registered_methods:
|
||||||
self.index_to_store(text)
|
self.index_to_store(text)
|
||||||
|
|
||||||
def batch_embedding(self):
|
def read_batch(self):
|
||||||
if "read_batch" in registered_methods:
|
if "read" in registered_methods:
|
||||||
text = self.read_batch()
|
text = self.read()
|
||||||
if "data_process" in registered_methods:
|
if "data_process" in registered_methods:
|
||||||
text = self.data_process(text)
|
text = self.data_process(text)
|
||||||
if "text_split" in registered_methods:
|
if "text_split" in registered_methods:
|
||||||
self.text_split(text)
|
self.text_split(text)
|
||||||
if "text_to_vector" in registered_methods:
|
return text
|
||||||
self.text_to_vector(text)
|
|
||||||
if "index_to_store" in registered_methods:
|
|
||||||
self.index_to_store(text)
|
|
||||||
|
@ -23,7 +23,7 @@ class LocalKnowledgeInit:
|
|||||||
self.vector_store_config = vector_store_config
|
self.vector_store_config = vector_store_config
|
||||||
self.model_name = LLM_MODEL_CONFIG["text2vec"]
|
self.model_name = LLM_MODEL_CONFIG["text2vec"]
|
||||||
|
|
||||||
def knowledge_persist(self, file_path, append_mode):
|
def knowledge_persist(self, file_path):
|
||||||
"""knowledge persist"""
|
"""knowledge persist"""
|
||||||
for root, _, files in os.walk(file_path, topdown=False):
|
for root, _, files in os.walk(file_path, topdown=False):
|
||||||
for file in files:
|
for file in files:
|
||||||
@ -41,7 +41,6 @@ class LocalKnowledgeInit:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--vector_name", type=str, default="default")
|
parser.add_argument("--vector_name", type=str, default="default")
|
||||||
parser.add_argument("--append", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
vector_name = args.vector_name
|
vector_name = args.vector_name
|
||||||
append_mode = args.append
|
append_mode = args.append
|
||||||
@ -49,5 +48,5 @@ if __name__ == "__main__":
|
|||||||
vector_store_config = {"vector_store_name": vector_name}
|
vector_store_config = {"vector_store_name": vector_name}
|
||||||
print(vector_store_config)
|
print(vector_store_config)
|
||||||
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
|
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
|
||||||
kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
|
kv.knowledge_persist(file_path=DATASETS_DIR)
|
||||||
print("your knowledge embedding success...")
|
print("your knowledge embedding success...")
|
||||||
|
Loading…
Reference in New Issue
Block a user