feature:ppt embedding

This commit is contained in:
aries-ckt 2023-06-12 20:57:00 +08:00
parent 46af29dd36
commit 6a7c4aa5f6
7 changed files with 70 additions and 43 deletions

View File

@ -38,7 +38,7 @@ class ChatUrlKnowledge(BaseChat):
)
self.url = url
vector_store_config = {
"vector_store_name": url,
"vector_store_name": url.replace(":", ""),
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(

View File

@ -1,11 +1,13 @@
from typing import Optional
from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.ppt_embedding import PPTEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding
from pilot.source_embedding.word_embedding import WordEmbedding
from pilot.vector_store.connector import VectorStoreConnector
@ -19,6 +21,8 @@ KnowledgeEmbeddingType = {
".doc": (WordEmbedding, {}),
".docx": (WordEmbedding, {}),
".csv": (CSVEmbedding, {}),
".ppt": (PPTEmbedding, {}),
".pptx": (PPTEmbedding, {}),
}
@ -42,8 +46,12 @@ class KnowledgeEmbedding:
self.knowledge_embedding_client = self.init_knowledge_embedding()
self.knowledge_embedding_client.source_embedding()
def knowledge_embedding_batch(self):
self.knowledge_embedding_client.batch_embedding()
def knowledge_embedding_batch(self, docs):
# docs = self.knowledge_embedding_client.read_batch()
self.knowledge_embedding_client.index_to_store(docs)
def read(self):
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
if self.file_type == "url":
@ -68,7 +76,11 @@ class KnowledgeEmbedding:
vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config
)
return vector_client.similar_search(text, topk)
try:
ans = vector_client.similar_search(text, topk)
except NotEnoughElementsException:
ans = vector_client.similar_search(text, 1)
return ans
def vector_exist(self):
vector_client = VectorStoreConnector(

View File

@ -5,8 +5,8 @@ from typing import List
import markdown
from bs4 import BeautifulSoup
from langchain.document_loaders import TextLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter
from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
@ -30,32 +30,8 @@ class MarkdownEmbedding(SourceEmbedding):
def read(self):
"""Load from markdown path."""
loader = EncodeTextLoader(self.file_path)
text_splitter = CHNDocumentSplitter(
pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
)
return loader.load_and_split(text_splitter)
@register
def read_batch(self):
"""Load from markdown path."""
docments = []
for root, _, files in os.walk(self.file_path, topdown=False):
for file in files:
filename = os.path.join(root, file)
loader = TextLoader(filename)
# text_splitor = CHNDocumentSplitter(chunk_size=1000, chunk_overlap=20, length_function=len)
# docs = loader.load_and_split()
docs = loader.load()
# 更新metadata数据
new_docs = []
for doc in docs:
doc.metadata = {
"source": doc.metadata["source"].replace(self.file_path, "")
}
print("doc is embedding ... ", doc.metadata)
new_docs.append(doc)
docments += new_docs
return docments
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
return loader.load_and_split(textsplitter)
@register
def data_process(self, documents: List[Document]):

View File

@ -29,7 +29,7 @@ class PDFEmbedding(SourceEmbedding):
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# )
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
)
return loader.load_and_split(textsplitter)

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter
from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
CFG = Config()
class PPTEmbedding(SourceEmbedding):
"""ppt embedding for read ppt document."""
def __init__(self, file_path, vector_store_config):
"""Initialize with pdf path."""
super().__init__(file_path, vector_store_config)
self.file_path = file_path
self.vector_store_config = vector_store_config
@register
def read(self):
"""Load from ppt path."""
loader = UnstructuredPowerPointLoader(self.file_path)
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
return loader.load_and_split(textsplitter)
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
documents[i].page_content = d.page_content.replace("\n", "")
i += 1
return documents

View File

@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from chromadb.errors import NotEnoughElementsException
from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector
@ -62,7 +64,11 @@ class SourceEmbedding(ABC):
@register
def similar_search(self, doc, topk):
"""vector store similarity_search"""
return self.vector_client.similar_search(doc, topk)
try:
ans = self.vector_client.similar_search(doc, topk)
except NotEnoughElementsException:
ans = self.vector_client.similar_search(doc, 1)
return ans
def vector_name_exist(self):
return self.vector_client.vector_name_exists()
@ -79,14 +85,11 @@ class SourceEmbedding(ABC):
if "index_to_store" in registered_methods:
self.index_to_store(text)
def batch_embedding(self):
if "read_batch" in registered_methods:
text = self.read_batch()
def read_batch(self):
if "read" in registered_methods:
text = self.read()
if "data_process" in registered_methods:
text = self.data_process(text)
if "text_split" in registered_methods:
self.text_split(text)
if "text_to_vector" in registered_methods:
self.text_to_vector(text)
if "index_to_store" in registered_methods:
self.index_to_store(text)
return text

View File

@ -23,7 +23,7 @@ class LocalKnowledgeInit:
self.vector_store_config = vector_store_config
self.model_name = LLM_MODEL_CONFIG["text2vec"]
def knowledge_persist(self, file_path, append_mode):
def knowledge_persist(self, file_path):
"""knowledge persist"""
for root, _, files in os.walk(file_path, topdown=False):
for file in files:
@ -41,7 +41,6 @@ class LocalKnowledgeInit:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--vector_name", type=str, default="default")
parser.add_argument("--append", type=bool, default=False)
args = parser.parse_args()
vector_name = args.vector_name
append_mode = args.append
@ -49,5 +48,5 @@ if __name__ == "__main__":
vector_store_config = {"vector_store_name": vector_name}
print(vector_store_config)
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
kv.knowledge_persist(file_path=DATASETS_DIR)
print("your knowledge embedding success...")