feat:docxLoader and ppt_loader

This commit is contained in:
aries_ckt 2023-07-25 15:21:53 +08:00
parent 48fc8c47ac
commit 487f91a1ec
4 changed files with 54 additions and 4 deletions

View File

@ -0,0 +1,26 @@
from typing import List, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
import docx
class DocxLoader(BaseLoader):
    """Load a ``.docx`` file into a single langchain ``Document``."""

    def __init__(self, file_path: str, encoding: Optional[str] = None):
        """Initialize with file path.

        Args:
            file_path: Path of the ``.docx`` file; handed to ``docx.Document``
                when ``load`` is called.
            encoding: Stored on the instance for interface compatibility;
                ``load`` does not read it.
        """
        self.file_path = file_path
        self.encoding = encoding

    def load(self) -> List[Document]:
        """Load from file path.

        Returns:
            A one-element list holding a ``Document`` whose ``page_content``
            is the text of every paragraph in the file and whose metadata
            ``"source"`` is ``self.file_path``.
        """
        doc = docx.Document(self.file_path)
        # Join with a newline rather than the empty string: with no separator
        # the last word of one paragraph fuses with the first word of the
        # next, which corrupts the text handed to downstream text splitters.
        content = "\n".join(para.text for para in doc.paragraphs)
        return [Document(page_content=content, metadata={"source": self.file_path})]

View File

@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
from typing import List, Optional
from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document
from langchain.text_splitter import (
SpacyTextSplitter,
@ -11,6 +10,7 @@ from langchain.text_splitter import (
)
from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.ppt_loader import PPTLoader
class PPTEmbedding(SourceEmbedding):
@ -36,7 +36,7 @@ class PPTEmbedding(SourceEmbedding):
def read(self):
"""Load from ppt path."""
if self.source_reader is None:
self.source_reader = UnstructuredPowerPointLoader(self.file_path)
self.source_reader = PPTLoader(self.file_path)
if self.text_splitter is None:
try:
self.text_splitter = SpacyTextSplitter(

View File

@ -0,0 +1,24 @@
from typing import List, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from pptx import Presentation
class PPTLoader(BaseLoader):
    """Load PPT files, producing one ``Document`` per non-empty text shape."""

    def __init__(self, file_path: str, encoding: Optional[str] = None):
        """Initialize with file path.

        Args:
            file_path: Path of the presentation; handed to ``Presentation``
                when ``load`` is called.
            encoding: Stored on the instance for interface compatibility;
                ``load`` does not read it.
        """
        self.file_path = file_path
        self.encoding = encoding

    def load(self) -> List[Document]:
        """Load from file path.

        Walks every shape of every slide and emits a ``Document`` for each
        shape that exposes a non-empty ``text`` attribute.

        Returns:
            The collected documents, in slide order then shape order. Each
            document's metadata ``"source"`` is the owning slide's
            ``slide_id``.
        """
        pr = Presentation(self.file_path)
        docs = []
        for slide in pr.slides:
            for shape in slide.shapes:
                # Shapes without a ``text`` attribute are skipped. Read the
                # attribute once and test it for truthiness: the previous
                # ``is not ""`` compared identity against a literal, which is
                # implementation-dependent and raises a SyntaxWarning on
                # Python 3.8+.
                text = getattr(shape, "text", "")
                if text:
                    # NOTE(review): "source" is the slide id, not the file
                    # path; kept as-is so existing consumers see the same
                    # metadata.
                    docs.append(
                        Document(
                            page_content=text,
                            metadata={"source": slide.slide_id},
                        )
                    )
        return docs

View File

@ -3,7 +3,6 @@
from typing import List, Optional
from langchain.schema import Document
from langchain.document_loaders import Docx2txtLoader
from langchain.text_splitter import (
SpacyTextSplitter,
RecursiveCharacterTextSplitter,
@ -11,6 +10,7 @@ from langchain.text_splitter import (
)
from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.docx_loader import DocxLoader
class WordEmbedding(SourceEmbedding):
@ -36,7 +36,7 @@ class WordEmbedding(SourceEmbedding):
def read(self):
"""Load from word path."""
if self.source_reader is None:
self.source_reader = Docx2txtLoader(self.file_path)
self.source_reader = DocxLoader(self.file_path)
if self.text_splitter is None:
try:
self.text_splitter = SpacyTextSplitter(