mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
DirectoryLoader slicing (#8994)
DirectoryLoader can now return a random sample of files in a directory. Parameters added are: sample_size randomize_sample sample_seed @rlancemartin, @eyurtsev --------- Co-authored-by: Andrew Oseen <amovfx@protonmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
d248481f13
commit
bbbd2b076f
@ -1,6 +1,7 @@
|
||||
"""Load documents from a directory."""
|
||||
import concurrent
|
||||
import logging
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
|
||||
@ -39,6 +40,10 @@ class DirectoryLoader(BaseLoader):
|
||||
show_progress: bool = False,
|
||||
use_multithreading: bool = False,
|
||||
max_concurrency: int = 4,
|
||||
*,
|
||||
sample_size: int = 0,
|
||||
randomize_sample: bool = False,
|
||||
sample_seed: Union[int, None] = None,
|
||||
):
|
||||
"""Initialize with a path to directory and how to glob over it.
|
||||
|
||||
@ -55,6 +60,10 @@ class DirectoryLoader(BaseLoader):
|
||||
show_progress: Whether to show a progress bar. Defaults to False.
|
||||
use_multithreading: Whether to use multithreading. Defaults to False.
|
||||
max_concurrency: The maximum number of threads to use. Defaults to 4.
|
||||
sample_size: The maximum number of files you would like to load from the
|
||||
directory.
|
||||
randomize_sample: Suffle the files to get a random sample.
|
||||
sample_seed: set the seed of the random shuffle for reporoducibility.
|
||||
"""
|
||||
if loader_kwargs is None:
|
||||
loader_kwargs = {}
|
||||
@ -68,6 +77,9 @@ class DirectoryLoader(BaseLoader):
|
||||
self.show_progress = show_progress
|
||||
self.use_multithreading = use_multithreading
|
||||
self.max_concurrency = max_concurrency
|
||||
self.sample_size = sample_size
|
||||
self.randomize_sample = randomize_sample
|
||||
self.sample_seed = sample_seed
|
||||
|
||||
def load_file(
|
||||
self, item: Path, path: Path, docs: List[Document], pbar: Optional[Any]
|
||||
@ -107,6 +119,14 @@ class DirectoryLoader(BaseLoader):
|
||||
docs: List[Document] = []
|
||||
items = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob))
|
||||
|
||||
if self.sample_size > 0:
|
||||
if self.randomize_sample:
|
||||
randomizer = (
|
||||
random.Random(self.sample_seed) if self.sample_seed else random
|
||||
)
|
||||
randomizer.shuffle(items) # type: ignore
|
||||
items = items[: min(len(items), self.sample_size)]
|
||||
|
||||
pbar = None
|
||||
if self.show_progress:
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user