community[patch]: Implement lazy_load() for CSVLoader (#18391)

Covered by `test_csv_loader.py`
This commit is contained in:
Christophe Bornet 2024-03-01 20:17:08 +01:00 committed by GitHub
parent c54d6eb5da
commit 69be82c86d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1,6 +1,6 @@
import csv import csv
from io import TextIOWrapper from io import TextIOWrapper
from typing import Any, Dict, List, Optional, Sequence from typing import Any, Dict, Iterator, List, Optional, Sequence
from langchain_core.documents import Document from langchain_core.documents import Document
@@ -61,13 +61,10 @@ class CSVLoader(BaseLoader):
self.csv_args = csv_args or {} self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding self.autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]: def lazy_load(self) -> Iterator[Document]:
"""Load data into document objects."""
docs = []
try: try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile: with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
docs = self.__read_file(csvfile) yield from self.__read_file(csvfile)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
if self.autodetect_encoding: if self.autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path) detected_encodings = detect_file_encodings(self.file_path)
@@ -76,7 +73,7 @@ class CSVLoader(BaseLoader):
with open( with open(
self.file_path, newline="", encoding=encoding.encoding self.file_path, newline="", encoding=encoding.encoding
) as csvfile: ) as csvfile:
docs = self.__read_file(csvfile) yield from self.__read_file(csvfile)
break break
except UnicodeDecodeError: except UnicodeDecodeError:
continue continue
@@ -85,11 +82,7 @@ class CSVLoader(BaseLoader):
except Exception as e: except Exception as e:
raise RuntimeError(f"Error loading {self.file_path}") from e raise RuntimeError(f"Error loading {self.file_path}") from e
return docs def __read_file(self, csvfile: TextIOWrapper) -> Iterator[Document]:
def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) csv_reader = csv.DictReader(csvfile, **self.csv_args)
for i, row in enumerate(csv_reader): for i, row in enumerate(csv_reader):
try: try:
@@ -113,10 +106,7 @@ class CSVLoader(BaseLoader):
metadata[col] = row[col] metadata[col] = row[col]
except KeyError: except KeyError:
raise ValueError(f"Metadata column '{col}' not found in CSV file.") raise ValueError(f"Metadata column '{col}' not found in CSV file.")
doc = Document(page_content=content, metadata=metadata) yield Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs
class UnstructuredCSVLoader(UnstructuredFileLoader): class UnstructuredCSVLoader(UnstructuredFileLoader):