class DocxLoader(BaseLoader):
    """Load the paragraph text of a docx file as a single Document."""

    def __init__(self, file_path: str, encoding: Optional[str] = None):
        """Initialize with file path.

        Args:
            file_path: Path passed to ``docx.Document`` when ``load`` runs.
            encoding: Stored on the instance; ``load`` does not read it.
        """
        self.file_path = file_path
        self.encoding = encoding

    def load(self) -> List[Document]:
        """Load from file path.

        Returns:
            A one-element list holding a Document whose ``page_content`` is
            the text of every paragraph, one paragraph per line, and whose
            ``metadata["source"]`` is the file path.
        """
        doc = docx.Document(self.file_path)
        # Join with a newline so the last word of one paragraph does not fuse
        # with the first word of the next; the text splitter downstream needs
        # that boundary. Iterating the paragraphs directly also avoids
        # re-reading ``doc.paragraphs`` on every loop step.
        content = "\n".join(para.text for para in doc.paragraphs)
        return [Document(page_content=content, metadata={"source": self.file_path})]
class PPTLoader(BaseLoader):
    """Load PPT files, producing one Document per non-empty text shape."""

    def __init__(self, file_path: str, encoding: Optional[str] = None):
        """Initialize with file path.

        Args:
            file_path: Path passed to ``Presentation`` when ``load`` runs.
            encoding: Stored on the instance; ``load`` does not read it.
        """
        self.file_path = file_path
        self.encoding = encoding

    def load(self) -> List[Document]:
        """Load from file path.

        Returns:
            One Document for each shape that exposes a non-empty ``text``
            attribute, in slide order and then shape order.
        """
        pr = Presentation(self.file_path)
        docs = []
        for slide in pr.slides:
            for shape in slide.shapes:
                # Not every shape exposes ``text``; treat a missing attribute
                # like empty text. Compare by value, not identity: the former
                # ``is not ""`` check relied on string interning.
                text = getattr(shape, "text", "")
                if not text:
                    continue
                # NOTE(review): "source" carries the slide id rather than the
                # file path used by the docx loader -- confirm that this is
                # what downstream consumers expect.
                docs.append(
                    Document(page_content=text, metadata={"source": slide.slide_id})
                )
        return docs