diff --git a/libs/community/langchain_community/document_loaders/directory.py b/libs/community/langchain_community/document_loaders/directory.py index b20eff88759..d46d545656e 100644 --- a/libs/community/langchain_community/document_loaders/directory.py +++ b/libs/community/langchain_community/document_loaders/directory.py @@ -2,7 +2,7 @@ import concurrent import logging import random from pathlib import Path -from typing import Any, Callable, Iterator, List, Optional, Sequence, Type, Union +from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Type, Union from langchain_core.documents import Document @@ -32,7 +32,7 @@ class DirectoryLoader(BaseLoader): def __init__( self, path: str, - glob: str = "**/[!.]*", + glob: Union[List[str], Tuple[str], str] = "**/[!.]*", silent_errors: bool = False, load_hidden: bool = False, loader_cls: FILE_LOADER_TYPE = UnstructuredFileLoader, @@ -51,8 +51,8 @@ class DirectoryLoader(BaseLoader): Args: path: Path to directory. - glob: Glob pattern to use to find files. Defaults to "**/[!.]*" - (all files except hidden). + glob: A glob pattern or list of glob patterns to use to find files. + Defaults to "**/[!.]*" (all files except hidden). exclude: A pattern or list of patterns to exclude from results. Use glob syntax. silent_errors: Whether to silently ignore errors. Defaults to False. @@ -124,7 +124,20 @@ class DirectoryLoader(BaseLoader): if not p.is_dir(): raise ValueError(f"Expected directory, got file: '{self.path}'") - paths = p.rglob(self.glob) if self.recursive else p.glob(self.glob) + # glob multiple patterns if a list is provided, e.g., multiple file extensions + if isinstance(self.glob, (list, tuple)): + paths = [] + for pattern in self.glob: + paths.extend( + list(p.rglob(pattern) if self.recursive else p.glob(pattern)) + ) + elif isinstance(self.glob, str): + paths = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob)) + else: + raise TypeError( + f"Expected glob to be str or sequence of str, but got {type(self.glob)}" + ) + items = [ path for path in paths diff --git a/libs/community/tests/unit_tests/document_loaders/test_directory.py b/libs/community/tests/unit_tests/document_loaders/test_directory.py index 9523ebfa26a..2b2440e504f 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_directory.py +++ b/libs/community/tests/unit_tests/document_loaders/test_directory.py @@ -5,6 +5,7 @@ import pytest from langchain_core.documents import Document from langchain_community.document_loaders import DirectoryLoader +from langchain_community.document_loaders.text import TextLoader def test_raise_error_if_path_not_exist() -> None: @@ -23,7 +24,7 @@ def test_raise_error_if_path_is_not_directory() -> None: assert str(e.value) == f"Expected directory, got file: '{__file__}'" -class CustomLoader: +class CustomLoader(TextLoader): """Test loader. Mimics interface of existing file loader.""" def __init__(self, path: Path, **kwargs: Any) -> None: @@ -56,3 +57,44 @@ def test_exclude_ignores_matching_files(tmp_path: Path) -> None: def test_exclude_as_string_converts_to_sequence() -> None: loader = DirectoryLoader("./some_directory", exclude="*.py") assert loader.exclude == ("*.py",) + + +class CustomLoaderMetadataOnly(CustomLoader): + """Test loader that just returns the file path in metadata. For test_directory_loader_glob_multiple.""" # noqa: E501 + + def load(self) -> List[Document]: + metadata = {"source": self.path} + return [Document(page_content="", metadata=metadata)] + + def lazy_load(self) -> Iterator[Document]: + return iter(self.load()) + + +def test_directory_loader_glob_multiple() -> None: + """Verify that globbing multiple patterns in a list works correctly.""" + + path_to_examples = "tests/examples/" + list_extensions = [".rst", ".txt"] + list_globs = [f"**/*{ext}" for ext in list_extensions] + is_file_type_loaded = {ext: False for ext in list_extensions} + + loader = DirectoryLoader( + path=path_to_examples, glob=list_globs, loader_cls=CustomLoaderMetadataOnly + ) + + list_documents = loader.load() + + for doc in list_documents: + path_doc = Path(doc.metadata.get("source", "")) + ext_doc = path_doc.suffix + + if is_file_type_loaded.get(ext_doc, False): + continue + elif ext_doc in list_extensions: + is_file_type_loaded[ext_doc] = True + else: + # Loaded a filetype that was not specified in extensions list + assert False + + for ext in list_extensions: + assert is_file_type_loaded.get(ext, False)