feature:ppt embedding

This commit is contained in:
aries-ckt 2023-06-12 20:57:00 +08:00
parent 46af29dd36
commit 6a7c4aa5f6
7 changed files with 70 additions and 43 deletions

View File

@ -38,7 +38,7 @@ class ChatUrlKnowledge(BaseChat):
) )
self.url = url self.url = url
vector_store_config = { vector_store_config = {
"vector_store_name": url, "vector_store_name": url.replace(":", ""),
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
self.knowledge_embedding_client = KnowledgeEmbedding( self.knowledge_embedding_client = KnowledgeEmbedding(

View File

@ -1,11 +1,13 @@
from typing import Optional from typing import Optional
from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.ppt_embedding import PPTEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding from pilot.source_embedding.url_embedding import URLEmbedding
from pilot.source_embedding.word_embedding import WordEmbedding from pilot.source_embedding.word_embedding import WordEmbedding
from pilot.vector_store.connector import VectorStoreConnector from pilot.vector_store.connector import VectorStoreConnector
@ -19,6 +21,8 @@ KnowledgeEmbeddingType = {
".doc": (WordEmbedding, {}), ".doc": (WordEmbedding, {}),
".docx": (WordEmbedding, {}), ".docx": (WordEmbedding, {}),
".csv": (CSVEmbedding, {}), ".csv": (CSVEmbedding, {}),
".ppt": (PPTEmbedding, {}),
".pptx": (PPTEmbedding, {}),
} }
@ -42,8 +46,12 @@ class KnowledgeEmbedding:
self.knowledge_embedding_client = self.init_knowledge_embedding() self.knowledge_embedding_client = self.init_knowledge_embedding()
self.knowledge_embedding_client.source_embedding() self.knowledge_embedding_client.source_embedding()
def knowledge_embedding_batch(self): def knowledge_embedding_batch(self, docs):
self.knowledge_embedding_client.batch_embedding() # docs = self.knowledge_embedding_client.read_batch()
self.knowledge_embedding_client.index_to_store(docs)
def read(self):
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self): def init_knowledge_embedding(self):
if self.file_type == "url": if self.file_type == "url":
@ -68,7 +76,11 @@ class KnowledgeEmbedding:
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config CFG.VECTOR_STORE_TYPE, self.vector_store_config
) )
return vector_client.similar_search(text, topk) try:
ans = vector_client.similar_search(text, topk)
except NotEnoughElementsException:
ans = vector_client.similar_search(text, 1)
return ans
def vector_exist(self): def vector_exist(self):
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(

View File

@ -5,8 +5,8 @@ from typing import List
import markdown import markdown
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from langchain.document_loaders import TextLoader
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register from pilot.source_embedding import SourceEmbedding, register
@ -30,32 +30,8 @@ class MarkdownEmbedding(SourceEmbedding):
def read(self): def read(self):
"""Load from markdown path.""" """Load from markdown path."""
loader = EncodeTextLoader(self.file_path) loader = EncodeTextLoader(self.file_path)
text_splitter = CHNDocumentSplitter( textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE return loader.load_and_split(textsplitter)
)
return loader.load_and_split(text_splitter)
@register
def read_batch(self):
"""Load from markdown path."""
docments = []
for root, _, files in os.walk(self.file_path, topdown=False):
for file in files:
filename = os.path.join(root, file)
loader = TextLoader(filename)
# text_splitor = CHNDocumentSplitter(chunk_size=1000, chunk_overlap=20, length_function=len)
# docs = loader.load_and_split()
docs = loader.load()
# 更新metadata数据
new_docs = []
for doc in docs:
doc.metadata = {
"source": doc.metadata["source"].replace(self.file_path, "")
}
print("doc is embedding ... ", doc.metadata)
new_docs.append(doc)
docments += new_docs
return docments
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -29,7 +29,7 @@ class PDFEmbedding(SourceEmbedding):
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE # pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# ) # )
textsplitter = SpacyTextSplitter( textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200 pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
) )
return loader.load_and_split(textsplitter) return loader.load_and_split(textsplitter)

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter
from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
CFG = Config()
class PPTEmbedding(SourceEmbedding):
    """Embedding source that reads PowerPoint (.ppt/.pptx) documents."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with the ppt file path and the vector store config.

        Args:
            file_path: Path of the PowerPoint file to load.
            vector_store_config: Configuration forwarded to the base class
                and kept on the instance.
        """
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Load the ppt file and split it into chunks.

        Returns:
            The documents produced by ``loader.load_and_split`` using a spaCy
            splitter sized by ``CFG.KNOWLEDGE_CHUNK_SIZE`` with an overlap
            of 200.
        """
        loader = UnstructuredPowerPointLoader(self.file_path)
        textsplitter = SpacyTextSplitter(
            pipeline="zh_core_web_sm",
            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
            chunk_overlap=200,
        )
        return loader.load_and_split(textsplitter)

    @register
    def data_process(self, documents: List[Document]):
        """Remove newline characters from every document's page content.

        The documents are modified in place and the same collection is
        returned.

        Args:
            documents: Documents whose ``page_content`` should be cleaned.

        Returns:
            The ``documents`` argument, after cleaning.
        """
        # Iterate the documents directly; no manual index counter is needed
        # because each element is the very object being updated.
        for document in documents:
            document.page_content = document.page_content.replace("\n", "")
        return documents

View File

@ -2,6 +2,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Dict, List, Optional
from chromadb.errors import NotEnoughElementsException
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector from pilot.vector_store.connector import VectorStoreConnector
@ -62,7 +64,11 @@ class SourceEmbedding(ABC):
@register @register
def similar_search(self, doc, topk): def similar_search(self, doc, topk):
"""vector store similarity_search""" """vector store similarity_search"""
return self.vector_client.similar_search(doc, topk) try:
ans = self.vector_client.similar_search(doc, topk)
except NotEnoughElementsException:
ans = self.vector_client.similar_search(doc, 1)
return ans
def vector_name_exist(self): def vector_name_exist(self):
return self.vector_client.vector_name_exists() return self.vector_client.vector_name_exists()
@ -79,14 +85,11 @@ class SourceEmbedding(ABC):
if "index_to_store" in registered_methods: if "index_to_store" in registered_methods:
self.index_to_store(text) self.index_to_store(text)
def batch_embedding(self): def read_batch(self):
if "read_batch" in registered_methods: if "read" in registered_methods:
text = self.read_batch() text = self.read()
if "data_process" in registered_methods: if "data_process" in registered_methods:
text = self.data_process(text) text = self.data_process(text)
if "text_split" in registered_methods: if "text_split" in registered_methods:
self.text_split(text) self.text_split(text)
if "text_to_vector" in registered_methods: return text
self.text_to_vector(text)
if "index_to_store" in registered_methods:
self.index_to_store(text)

View File

@ -23,7 +23,7 @@ class LocalKnowledgeInit:
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.model_name = LLM_MODEL_CONFIG["text2vec"] self.model_name = LLM_MODEL_CONFIG["text2vec"]
def knowledge_persist(self, file_path, append_mode): def knowledge_persist(self, file_path):
"""knowledge persist""" """knowledge persist"""
for root, _, files in os.walk(file_path, topdown=False): for root, _, files in os.walk(file_path, topdown=False):
for file in files: for file in files:
@ -41,7 +41,6 @@ class LocalKnowledgeInit:
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--vector_name", type=str, default="default") parser.add_argument("--vector_name", type=str, default="default")
parser.add_argument("--append", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
vector_name = args.vector_name vector_name = args.vector_name
append_mode = args.append append_mode = args.append
@ -49,5 +48,5 @@ if __name__ == "__main__":
vector_store_config = {"vector_store_name": vector_name} vector_store_config = {"vector_store_name": vector_name}
print(vector_store_config) print(vector_store_config)
kv = LocalKnowledgeInit(vector_store_config=vector_store_config) kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode) kv.knowledge_persist(file_path=DATASETS_DIR)
print("your knowledge embedding success...") print("your knowledge embedding success...")