From d6a23ead3b0155a24feea65d44dbf2b9cdd1eff7 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 3 Aug 2023 13:02:26 +0800 Subject: [PATCH] fix:csv_loader bug 1.add new_csv_loader,override load() 2.add loader dir Close #396 --- pilot/embedding_engine/csv_embedding.py | 6 +- .../{ => loader}/chn_document_splitter.py | 0 pilot/embedding_engine/loader/csv_loader.py | 76 +++++++++++++++++++ .../{ => loader}/docx_loader.py | 0 .../{ => loader}/pdf_loader.py | 0 .../{ => loader}/ppt_loader.py | 0 pilot/embedding_engine/ppt_embedding.py | 2 +- pilot/embedding_engine/word_embedding.py | 2 +- 8 files changed, 81 insertions(+), 5 deletions(-) rename pilot/embedding_engine/{ => loader}/chn_document_splitter.py (100%) create mode 100644 pilot/embedding_engine/loader/csv_loader.py rename pilot/embedding_engine/{ => loader}/docx_loader.py (100%) rename pilot/embedding_engine/{ => loader}/pdf_loader.py (100%) rename pilot/embedding_engine/{ => loader}/ppt_loader.py (100%) diff --git a/pilot/embedding_engine/csv_embedding.py b/pilot/embedding_engine/csv_embedding.py index 216670b94..5f58fac9d 100644 --- a/pilot/embedding_engine/csv_embedding.py +++ b/pilot/embedding_engine/csv_embedding.py @@ -1,6 +1,5 @@ -from typing import Dict, List, Optional +from typing import List, Optional -from langchain.document_loaders import CSVLoader from langchain.schema import Document from langchain.text_splitter import ( TextSplitter, @@ -9,6 +8,7 @@ from langchain.text_splitter import ( ) from pilot.embedding_engine import SourceEmbedding, register +from pilot.embedding_engine.loader.csv_loader import NewCSVLoader class CSVEmbedding(SourceEmbedding): @@ -34,7 +34,7 @@ class CSVEmbedding(SourceEmbedding): def read(self): """Load from csv path.""" if self.source_reader is None: - self.source_reader = CSVLoader(self.file_path) + self.source_reader = NewCSVLoader(self.file_path) if self.text_splitter is None: try: self.text_splitter = SpacyTextSplitter( diff --git a/pilot/embedding_engine/chn_document_splitter.py b/pilot/embedding_engine/loader/chn_document_splitter.py similarity index 100% rename from pilot/embedding_engine/chn_document_splitter.py rename to pilot/embedding_engine/loader/chn_document_splitter.py diff --git a/pilot/embedding_engine/loader/csv_loader.py b/pilot/embedding_engine/loader/csv_loader.py new file mode 100644 index 000000000..bdb58112e --- /dev/null +++ b/pilot/embedding_engine/loader/csv_loader.py @@ -0,0 +1,76 @@ +"""Loads a CSV file into a list of documents. + +Each document represents one row of the CSV file. Every row is converted into a +key/value pair and outputted to a new line in the document's page_content. + +The source for each document loaded from csv is set to the value of the +`file_path` argument for all doucments by default. +You can override this by setting the `source_column` argument to the +name of a column in the CSV file. +The source of each document will then be set to the value of the column +with the name specified in `source_column`. + +Output Example: + .. code-block:: txt + + column1: value1 + column2: value2 + column3: value3 +""" +from typing import Optional, Dict, List +import csv +from langchain.document_loaders.base import BaseLoader +from langchain.schema import Document + + +class NewCSVLoader(BaseLoader): + def __init__( + self, + file_path: str, + source_column: Optional[str] = None, + csv_args: Optional[Dict] = None, + encoding: Optional[str] = None, + ): + """ + + Args: + file_path: The path to the CSV file. + source_column: The name of the column in the CSV file to use as the source. + Optional. Defaults to None. + csv_args: A dictionary of arguments to pass to the csv.DictReader. + Optional. Defaults to None. + encoding: The encoding of the CSV file. Optional. Defaults to None. + """ + self.file_path = file_path + self.source_column = source_column + self.encoding = encoding + self.csv_args = csv_args or {} + + def load(self) -> List[Document]: + """Load data into document objects.""" + + docs = [] + with open(self.file_path, newline="", encoding=self.encoding) as csvfile: + csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore + for i, row in enumerate(csv_reader): + strs = [] + for k, v in row.items(): + if k is None or v is None: + continue + strs.append(f"{k.strip()}: {v.strip()}") + content = "\n".join(strs) + try: + source = ( + row[self.source_column] + if self.source_column is not None + else self.file_path + ) + except KeyError: + raise ValueError( + f"Source column '{self.source_column}' not found in CSV file." + ) + metadata = {"source": source, "row": i} + doc = Document(page_content=content, metadata=metadata) + docs.append(doc) + + return docs diff --git a/pilot/embedding_engine/docx_loader.py b/pilot/embedding_engine/loader/docx_loader.py similarity index 100% rename from pilot/embedding_engine/docx_loader.py rename to pilot/embedding_engine/loader/docx_loader.py diff --git a/pilot/embedding_engine/pdf_loader.py b/pilot/embedding_engine/loader/pdf_loader.py similarity index 100% rename from pilot/embedding_engine/pdf_loader.py rename to pilot/embedding_engine/loader/pdf_loader.py diff --git a/pilot/embedding_engine/ppt_loader.py b/pilot/embedding_engine/loader/ppt_loader.py similarity index 100% rename from pilot/embedding_engine/ppt_loader.py rename to pilot/embedding_engine/loader/ppt_loader.py diff --git a/pilot/embedding_engine/ppt_embedding.py b/pilot/embedding_engine/ppt_embedding.py index 7c691662d..9058c092f 100644 --- a/pilot/embedding_engine/ppt_embedding.py +++ b/pilot/embedding_engine/ppt_embedding.py @@ -10,7 +10,7 @@ from langchain.text_splitter import ( ) from pilot.embedding_engine import SourceEmbedding, register -from pilot.embedding_engine.ppt_loader import PPTLoader +from pilot.embedding_engine.loader.ppt_loader import PPTLoader class PPTEmbedding(SourceEmbedding): diff --git a/pilot/embedding_engine/word_embedding.py b/pilot/embedding_engine/word_embedding.py index 833be4012..aba50fe24 100644 --- a/pilot/embedding_engine/word_embedding.py +++ b/pilot/embedding_engine/word_embedding.py @@ -10,7 +10,7 @@ from langchain.text_splitter import ( ) from pilot.embedding_engine import SourceEmbedding, register -from pilot.embedding_engine.docx_loader import DocxLoader +from pilot.embedding_engine.loader.docx_loader import DocxLoader class WordEmbedding(SourceEmbedding):