diff --git a/libs/community/langchain_community/document_loaders/csv_loader.py b/libs/community/langchain_community/document_loaders/csv_loader.py index efc8507aa6f..8b0e6b3ab9c 100644 --- a/libs/community/langchain_community/document_loaders/csv_loader.py +++ b/libs/community/langchain_community/document_loaders/csv_loader.py @@ -104,6 +104,8 @@ class CSVLoader(BaseLoader): csv_args: Optional[Dict] = None, encoding: Optional[str] = None, autodetect_encoding: bool = False, + *, + content_columns: Sequence[str] = (), ): """ @@ -116,6 +118,8 @@ class CSVLoader(BaseLoader): Optional. Defaults to None. encoding: The encoding of the CSV file. Optional. Defaults to None. autodetect_encoding: Whether to try to autodetect the file encoding. + content_columns: A sequence of column names to use for the document content. + If not present, use all columns that are not part of the metadata. """ self.file_path = file_path self.source_column = source_column @@ -123,6 +127,7 @@ class CSVLoader(BaseLoader): self.encoding = encoding self.csv_args = csv_args or {} self.autodetect_encoding = autodetect_encoding + self.content_columns = content_columns def lazy_load(self) -> Iterator[Document]: try: @@ -163,7 +168,11 @@ class CSVLoader(BaseLoader): if isinstance(v, str) else ','.join(map(str.strip, v)) if isinstance(v, list) else v}""" for k, v in row.items() - if k not in self.metadata_columns + if ( + k in self.content_columns + if self.content_columns + else k not in self.metadata_columns + ) ) metadata = {"source": source, "row": i} for col in self.metadata_columns: diff --git a/libs/community/tests/unit_tests/document_loaders/test_csv_loader.py b/libs/community/tests/unit_tests/document_loaders/test_csv_loader.py index a7ab65e35a4..75a3b08bf79 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_csv_loader.py +++ b/libs/community/tests/unit_tests/document_loaders/test_csv_loader.py @@ -108,6 +108,27 @@ class TestCSVLoader: # Assert assert result == expected_docs + 
def test_csv_loader_content_columns(self) -> None:
    """Check that ``content_columns`` limits which columns form the content.

    Loads the ``test_none_col.csv`` fixture requesting only ``column1`` and
    ``column3`` and expects each row's ``page_content`` to hold just those
    two fields, with ``source`` and ``row`` recorded in the metadata.
    """
    # Setup
    csv_path = self._get_csv_file_path("test_none_col.csv")
    selected_columns = ("column1", "column3")
    expected_contents = (
        "column1: value1\ncolumn3: value3",
        "column1: value6\ncolumn3: value8",
    )
    expected_docs = [
        Document(page_content=text, metadata={"source": csv_path, "row": row_index})
        for row_index, text in enumerate(expected_contents)
    ]

    # Exercise
    result = CSVLoader(file_path=csv_path, content_columns=selected_columns).load()

    # Assert
    assert result == expected_docs

# utility functions
def _get_csv_file_path(self, file_name: str) -> str:
    """Return, as a string, the resolved path of ``file_name`` under ``test_docs/csv``.

    The directory is located relative to the folder containing this module.
    """
    fixture_dir = Path(__file__).resolve().parent / "test_docs" / "csv"
    return str(fixture_dir / file_name)