From 6a7c4aa5f6596069e43767d71c614bbde6b63e0e Mon Sep 17 00:00:00 2001
From: aries-ckt <916701291@qq.com>
Date: Mon, 12 Jun 2023 20:57:00 +0800
Subject: [PATCH] feature: ppt embedding

---
 pilot/scene/chat_knowledge/url/chat.py        |  2 +-
 pilot/source_embedding/knowledge_embedding.py | 18 +++++++--
 pilot/source_embedding/markdown_embedding.py  | 30 ++-------------
 pilot/source_embedding/pdf_embedding.py       |  2 +-
 pilot/source_embedding/ppt_embedding.py       | 37 +++++++++++++++++++
 pilot/source_embedding/source_embedding.py    | 19 ++++++----
 tools/knowlege_init.py                        |  6 +--
 7 files changed, 70 insertions(+), 44 deletions(-)
 create mode 100644 pilot/source_embedding/ppt_embedding.py

diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py
index 88dc7ad0b..ce45602a2 100644
--- a/pilot/scene/chat_knowledge/url/chat.py
+++ b/pilot/scene/chat_knowledge/url/chat.py
@@ -38,7 +38,7 @@ class ChatUrlKnowledge(BaseChat):
         )
         self.url = url
         vector_store_config = {
-            "vector_store_name": url,
+            "vector_store_name": url.replace(":", ""),
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py
index 7ec0de76c..97b515897 100644
--- a/pilot/source_embedding/knowledge_embedding.py
+++ b/pilot/source_embedding/knowledge_embedding.py
@@ -1,11 +1,13 @@
 from typing import Optional
 
+from chromadb.errors import NotEnoughElementsException
 from langchain.embeddings import HuggingFaceEmbeddings
 
 from pilot.configs.config import Config
 from pilot.source_embedding.csv_embedding import CSVEmbedding
 from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
 from pilot.source_embedding.pdf_embedding import PDFEmbedding
+from pilot.source_embedding.ppt_embedding import PPTEmbedding
 from pilot.source_embedding.url_embedding import URLEmbedding
 from pilot.source_embedding.word_embedding import WordEmbedding
 from pilot.vector_store.connector import VectorStoreConnector
@@ -19,6 +21,8 @@ KnowledgeEmbeddingType = {
     ".doc": (WordEmbedding, {}),
     ".docx": (WordEmbedding, {}),
     ".csv": (CSVEmbedding, {}),
+    ".ppt": (PPTEmbedding, {}),
+    ".pptx": (PPTEmbedding, {}),
 }
 
 
@@ -42,8 +46,12 @@ class KnowledgeEmbedding:
         self.knowledge_embedding_client = self.init_knowledge_embedding()
         self.knowledge_embedding_client.source_embedding()
 
-    def knowledge_embedding_batch(self):
-        self.knowledge_embedding_client.batch_embedding()
+    def knowledge_embedding_batch(self, docs):
+        # docs = self.knowledge_embedding_client.read_batch()
+        self.knowledge_embedding_client.index_to_store(docs)
+
+    def read(self):
+        return self.knowledge_embedding_client.read_batch()
 
     def init_knowledge_embedding(self):
         if self.file_type == "url":
@@ -68,7 +76,11 @@ class KnowledgeEmbedding:
         vector_client = VectorStoreConnector(
             CFG.VECTOR_STORE_TYPE, self.vector_store_config
         )
-        return vector_client.similar_search(text, topk)
+        try:
+            ans = vector_client.similar_search(text, topk)
+        except NotEnoughElementsException:
+            ans = vector_client.similar_search(text, 1)
+        return ans
 
     def vector_exist(self):
         vector_client = VectorStoreConnector(
diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py
index e2851d122..5f6d9526d 100644
--- a/pilot/source_embedding/markdown_embedding.py
+++ b/pilot/source_embedding/markdown_embedding.py
@@ -5,8 +5,8 @@ from typing import List
 
 import markdown
 from bs4 import BeautifulSoup
-from langchain.document_loaders import TextLoader
 from langchain.schema import Document
+from langchain.text_splitter import SpacyTextSplitter
 
 from pilot.configs.config import Config
 from pilot.source_embedding import SourceEmbedding, register
@@ -30,32 +30,8 @@ class MarkdownEmbedding(SourceEmbedding):
     def read(self):
         """Load from markdown path."""
         loader = EncodeTextLoader(self.file_path)
-        text_splitter = CHNDocumentSplitter(
-            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
-        )
-        return loader.load_and_split(text_splitter)
-
-    @register
-    def read_batch(self):
-        """Load from markdown path."""
-        docments = []
-        for root, _, files in os.walk(self.file_path, topdown=False):
-            for file in files:
-                filename = os.path.join(root, file)
-                loader = TextLoader(filename)
-                # text_splitor = CHNDocumentSplitter(chunk_size=1000, chunk_overlap=20, length_function=len)
-                # docs = loader.load_and_split()
-                docs = loader.load()
-                # 更新metadata数据
-                new_docs = []
-                for doc in docs:
-                    doc.metadata = {
-                        "source": doc.metadata["source"].replace(self.file_path, "")
-                    }
-                    print("doc is embedding ... ", doc.metadata)
-                    new_docs.append(doc)
-                docments += new_docs
-        return docments
+        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
+        return loader.load_and_split(textsplitter)
 
     @register
     def data_process(self, documents: List[Document]):
diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py
index ae8dde974..66b0963d9 100644
--- a/pilot/source_embedding/pdf_embedding.py
+++ b/pilot/source_embedding/pdf_embedding.py
@@ -29,7 +29,7 @@ class PDFEmbedding(SourceEmbedding):
         #     pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         # )
         textsplitter = SpacyTextSplitter(
-            pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
+            pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
         )
         return loader.load_and_split(textsplitter)
 
diff --git a/pilot/source_embedding/ppt_embedding.py b/pilot/source_embedding/ppt_embedding.py
new file mode 100644
index 000000000..869e92395
--- /dev/null
+++ b/pilot/source_embedding/ppt_embedding.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from typing import List
+
+from langchain.document_loaders import UnstructuredPowerPointLoader
+from langchain.schema import Document
+from langchain.text_splitter import SpacyTextSplitter
+
+from pilot.configs.config import Config
+from pilot.source_embedding import SourceEmbedding, register
+
+CFG = Config()
+
+
+class PPTEmbedding(SourceEmbedding):
+    """ppt embedding for reading ppt documents."""
+
+    def __init__(self, file_path, vector_store_config):
+        """Initialize with ppt path."""
+        super().__init__(file_path, vector_store_config)
+        self.file_path = file_path
+        self.vector_store_config = vector_store_config
+
+    @register
+    def read(self):
+        """Load from ppt path."""
+        loader = UnstructuredPowerPointLoader(self.file_path)
+        textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
+        return loader.load_and_split(textsplitter)
+
+    @register
+    def data_process(self, documents: List[Document]):
+        i = 0
+        for d in documents:
+            documents[i].page_content = d.page_content.replace("\n", "")
+            i += 1
+        return documents
diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py
index 50c7044f9..3d881fcdf 100644
--- a/pilot/source_embedding/source_embedding.py
+++ b/pilot/source_embedding/source_embedding.py
@@ -2,6 +2,8 @@
 # -*- coding: utf-8 -*-
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional
+
+from chromadb.errors import NotEnoughElementsException
 
 from pilot.configs.config import Config
 from pilot.vector_store.connector import VectorStoreConnector
@@ -62,7 +64,11 @@ class SourceEmbedding(ABC):
     @register
     def similar_search(self, doc, topk):
         """vector store similarity_search"""
-        return self.vector_client.similar_search(doc, topk)
+        try:
+            ans = self.vector_client.similar_search(doc, topk)
+        except NotEnoughElementsException:
+            ans = self.vector_client.similar_search(doc, 1)
+        return ans
 
     def vector_name_exist(self):
         return self.vector_client.vector_name_exists()
@@ -79,14 +85,11 @@ class SourceEmbedding(ABC):
             if "index_to_store" in registered_methods:
                 self.index_to_store(text)
 
-    def batch_embedding(self):
-        if "read_batch" in registered_methods:
-            text = self.read_batch()
+    def read_batch(self):
+        if "read" in registered_methods:
+            text = self.read()
         if "data_process" in registered_methods:
             text = self.data_process(text)
         if "text_split" in registered_methods:
             self.text_split(text)
-        if "text_to_vector" in registered_methods:
-            self.text_to_vector(text)
-        if "index_to_store" in registered_methods:
-            self.index_to_store(text)
+        return text
diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py
index ff13865b4..26338df1c 100644
--- a/tools/knowlege_init.py
+++ b/tools/knowlege_init.py
@@ -23,7 +23,7 @@ class LocalKnowledgeInit:
         self.vector_store_config = vector_store_config
         self.model_name = LLM_MODEL_CONFIG["text2vec"]
 
-    def knowledge_persist(self, file_path, append_mode):
+    def knowledge_persist(self, file_path):
         """knowledge persist"""
         for root, _, files in os.walk(file_path, topdown=False):
             for file in files:
@@ -41,7 +41,5 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--vector_name", type=str, default="default")
-    parser.add_argument("--append", type=bool, default=False)
     args = parser.parse_args()
     vector_name = args.vector_name
-    append_mode = args.append
@@ -49,5 +47,5 @@
     vector_store_config = {"vector_store_name": vector_name}
     print(vector_store_config)
     kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
-    kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
+    kv.knowledge_persist(file_path=DATASETS_DIR)
     print("your knowledge embedding success...")
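
Note for reviewers: the new ".ppt" / ".pptx" entries plug into the existing suffix dispatch, where init_knowledge_embedding (its body is outside this diff) looks the embedding class up in KnowledgeEmbeddingType by file extension. The sketch below only illustrates the contract of that table, not the method's actual body; the helper name is hypothetical:

    import os

    from pilot.source_embedding.knowledge_embedding import KnowledgeEmbeddingType

    # Hypothetical illustration of the suffix dispatch implied by
    # KnowledgeEmbeddingType; the real init_knowledge_embedding also
    # short-circuits the "url" case shown in the hunk above.
    def pick_embedding_class(file_path):
        extension = os.path.splitext(file_path)[1]
        if extension not in KnowledgeEmbeddingType:
            raise ValueError("Unsupported knowledge document type: %s" % extension)
        embedding_class, extra_kwargs = KnowledgeEmbeddingType[extension]
        return embedding_class, extra_kwargs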
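
Note for reviewers: batch ingestion is now an explicit two-step flow on KnowledgeEmbedding: read() loads, splits, and post-processes the source into documents, and knowledge_embedding_batch(docs) indexes them into the vector store. A minimal end-to-end sketch follows, assuming the constructor takes file_path, model_name, and vector_store_config (as tools/knowlege_init.py suggests; the exact signature is outside this diff). The path, model key, and store name are illustrative values, not part of the patch:

    from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
    from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

    # Illustrative values; only the method names come from this patch.
    client = KnowledgeEmbedding(
        file_path="./datasets/intro.pptx",  # ".pptx" now dispatches to PPTEmbedding
        model_name="text2vec",
        vector_store_config={
            "vector_store_name": "default",
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        },
    )
    docs = client.read()                    # wraps read_batch(): read -> data_process
    client.knowledge_embedding_batch(docs)  # index_to_store(): persist into the vector store
    ans = client.similar_search("what is a vector store", 5)  # retries with topk=1 on NotEnoughElementsException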
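
Dependency note: the markdown, pdf, and new ppt embeddings all split text with SpacyTextSplitter(pipeline='zh_core_web_sm'), so the spaCy Chinese pipeline becomes a runtime requirement. It is not bundled with spaCy itself and has to be downloaded once per environment:

    # One-off environment check; the pipeline is installed with:
    #   python -m spacy download zh_core_web_sm
    import spacy

    spacy.load("zh_core_web_sm")  # raises OSError if the pipeline is missing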