DirectoryLoader slicing (#8994)

DirectoryLoader can now return a random sample of files in a directory.
Parameters added are:
sample_size
randomize_sample
sample_seed


@rlancemartin, @eyurtsev

---------

Co-authored-by: Andrew Oseen <amovfx@protonmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Kaizen 2023-08-09 16:05:16 -07:00 committed by GitHub
parent d248481f13
commit bbbd2b076f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
"""Load documents from a directory.""" """Load documents from a directory."""
import concurrent import concurrent
import logging import logging
import random
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
@ -39,6 +40,10 @@ class DirectoryLoader(BaseLoader):
show_progress: bool = False, show_progress: bool = False,
use_multithreading: bool = False, use_multithreading: bool = False,
max_concurrency: int = 4, max_concurrency: int = 4,
*,
sample_size: int = 0,
randomize_sample: bool = False,
sample_seed: Union[int, None] = None,
): ):
"""Initialize with a path to directory and how to glob over it. """Initialize with a path to directory and how to glob over it.
@ -55,6 +60,10 @@ class DirectoryLoader(BaseLoader):
show_progress: Whether to show a progress bar. Defaults to False. show_progress: Whether to show a progress bar. Defaults to False.
use_multithreading: Whether to use multithreading. Defaults to False. use_multithreading: Whether to use multithreading. Defaults to False.
max_concurrency: The maximum number of threads to use. Defaults to 4. max_concurrency: The maximum number of threads to use. Defaults to 4.
sample_size: The maximum number of files you would like to load from the
directory.
randomize_sample: Shuffle the files to get a random sample.
sample_seed: Set the seed of the random shuffle for reproducibility.
""" """
if loader_kwargs is None: if loader_kwargs is None:
loader_kwargs = {} loader_kwargs = {}
@ -68,6 +77,9 @@ class DirectoryLoader(BaseLoader):
self.show_progress = show_progress self.show_progress = show_progress
self.use_multithreading = use_multithreading self.use_multithreading = use_multithreading
self.max_concurrency = max_concurrency self.max_concurrency = max_concurrency
self.sample_size = sample_size
self.randomize_sample = randomize_sample
self.sample_seed = sample_seed
def load_file( def load_file(
self, item: Path, path: Path, docs: List[Document], pbar: Optional[Any] self, item: Path, path: Path, docs: List[Document], pbar: Optional[Any]
@ -107,6 +119,14 @@ class DirectoryLoader(BaseLoader):
docs: List[Document] = [] docs: List[Document] = []
items = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob)) items = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob))
if self.sample_size > 0:
if self.randomize_sample:
randomizer = (
random.Random(self.sample_seed) if self.sample_seed else random
)
randomizer.shuffle(items) # type: ignore
items = items[: min(len(items), self.sample_size)]
pbar = None pbar = None
if self.show_progress: if self.show_progress:
try: try: