DirectoryLoader slicing (#8994)

DirectoryLoader can now return a random sample of files in a directory.
Parameters added are:
sample_size
randomize_sample
sample_seed


@rlancemartin, @eyurtsev

---------

Co-authored-by: Andrew Oseen <amovfx@protonmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Kaizen 2023-08-09 16:05:16 -07:00 committed by GitHub
parent d248481f13
commit bbbd2b076f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
"""Load documents from a directory."""
import concurrent
import logging
import random
from pathlib import Path
from typing import Any, List, Optional, Type, Union
@ -39,6 +40,10 @@ class DirectoryLoader(BaseLoader):
show_progress: bool = False,
use_multithreading: bool = False,
max_concurrency: int = 4,
*,
sample_size: int = 0,
randomize_sample: bool = False,
sample_seed: Union[int, None] = None,
):
"""Initialize with a path to directory and how to glob over it.
@ -55,6 +60,10 @@ class DirectoryLoader(BaseLoader):
show_progress: Whether to show a progress bar. Defaults to False.
use_multithreading: Whether to use multithreading. Defaults to False.
max_concurrency: The maximum number of threads to use. Defaults to 4.
sample_size: The maximum number of files you would like to load from the
directory.
randomize_sample: Suffle the files to get a random sample.
sample_seed: set the seed of the random shuffle for reporoducibility.
"""
if loader_kwargs is None:
loader_kwargs = {}
@ -68,6 +77,9 @@ class DirectoryLoader(BaseLoader):
self.show_progress = show_progress
self.use_multithreading = use_multithreading
self.max_concurrency = max_concurrency
self.sample_size = sample_size
self.randomize_sample = randomize_sample
self.sample_seed = sample_seed
def load_file(
self, item: Path, path: Path, docs: List[Document], pbar: Optional[Any]
@ -107,6 +119,14 @@ class DirectoryLoader(BaseLoader):
docs: List[Document] = []
items = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob))
if self.sample_size > 0:
if self.randomize_sample:
randomizer = (
random.Random(self.sample_seed) if self.sample_seed else random
)
randomizer.shuffle(items) # type: ignore
items = items[: min(len(items), self.sample_size)]
pbar = None
if self.show_progress:
try: