mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-22 07:27:45 +00:00
Add progress bar to filesystemblob loader, update pytest config for unit tests (#4212)
This PR adds: * Option to show a tqdm progress bar when using the file system blob loader * Update pytest run configuration to be stricter * Adding a new marker that checks that required pkgs exist
This commit is contained in:
parent
f4c8502e61
commit
aa11f7c89b
@ -1,9 +1,38 @@
|
||||
"""Use to load blobs from the local file system."""
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Sequence, Union
|
||||
from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar, Union
|
||||
|
||||
from langchain.document_loaders.blob_loaders.schema import Blob, BlobLoader
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _make_iterator(
|
||||
length_func: Callable[[], int], show_progress: bool = False
|
||||
) -> Callable[[Iterable[T]], Iterator[T]]:
|
||||
"""Create a function that optionally wraps an iterable in tqdm."""
|
||||
if show_progress:
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"You must install tqdm to use show_progress=True."
|
||||
"You can install tqdm with `pip install tqdm`."
|
||||
)
|
||||
|
||||
# Make sure to provide `total` here so that tqdm can show
|
||||
# a progress bar that takes into account the total number of files.
|
||||
def _with_tqdm(iterable: Iterable[T]) -> Iterator[T]:
|
||||
"""Wrap an iterable in a tqdm progress bar."""
|
||||
return tqdm(iterable, total=length_func())
|
||||
|
||||
iterator = _with_tqdm
|
||||
else:
|
||||
iterator = iter # type: ignore
|
||||
|
||||
return iterator
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
@ -26,6 +55,7 @@ class FileSystemBlobLoader(BlobLoader):
|
||||
*,
|
||||
glob: str = "**/[!.]*",
|
||||
suffixes: Optional[Sequence[str]] = None,
|
||||
show_progress: bool = False,
|
||||
) -> None:
|
||||
"""Initialize with path to directory and how to glob over it.
|
||||
|
||||
@ -36,6 +66,9 @@ class FileSystemBlobLoader(BlobLoader):
|
||||
suffixes: Provide to keep only files with these suffixes
|
||||
Useful when wanting to keep files with different suffixes
|
||||
Suffixes must include the dot, e.g. ".txt"
|
||||
show_progress: If true, will show a progress bar as the files are loaded.
|
||||
This forces an iteration through all matching files
|
||||
to count them prior to loading them.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -60,14 +93,33 @@ class FileSystemBlobLoader(BlobLoader):
|
||||
self.path = _path
|
||||
self.glob = glob
|
||||
self.suffixes = set(suffixes or [])
|
||||
self.show_progress = show_progress
|
||||
|
||||
def yield_blobs(
|
||||
self,
|
||||
) -> Iterable[Blob]:
|
||||
"""Yield blobs that match the requested pattern."""
|
||||
iterator = _make_iterator(
|
||||
length_func=self.count_matching_files, show_progress=self.show_progress
|
||||
)
|
||||
|
||||
for path in iterator(self._yield_paths()):
|
||||
yield Blob.from_path(path)
|
||||
|
||||
def _yield_paths(self) -> Iterable[Path]:
|
||||
"""Yield paths that match the requested pattern."""
|
||||
paths = self.path.glob(self.glob)
|
||||
for path in paths:
|
||||
if path.is_file():
|
||||
if self.suffixes and path.suffix not in self.suffixes:
|
||||
continue
|
||||
yield Blob.from_path(str(path))
|
||||
yield path
|
||||
|
||||
def count_matching_files(self) -> int:
|
||||
"""Count files that match the pattern without loading them."""
|
||||
# Carry out a full iteration to count the files without
|
||||
# materializing anything expensive in memory.
|
||||
num = 0
|
||||
for _ in self._yield_paths():
|
||||
num += 1
|
||||
return num
|
||||
|
@ -184,3 +184,17 @@ omit = [
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
addopts = "--strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library"
|
||||
]
|
||||
|
44
tests/unit_tests/conftest.py
Normal file
44
tests/unit_tests/conftest.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Configuration for unit tests."""
|
||||
from importlib import util
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import pytest
|
||||
from pytest import Config, Function
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
||||
"""Add implementations for handling custom markers.
|
||||
|
||||
At the moment, this adds support for a custom `requires` marker.
|
||||
|
||||
The `requires` marker is used to denote tests that require one or more packages
|
||||
to be installed to run. If the package is not installed, the test is skipped.
|
||||
|
||||
The `requires` marker syntax is:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@pytest.mark.requires("package1", "package2")
|
||||
def test_something():
|
||||
...
|
||||
"""
|
||||
# Mapping from the name of a package to whether it is installed or not.
|
||||
# Used to avoid repeated calls to `util.find_spec`
|
||||
required_pkgs_info: Dict[str, bool] = {}
|
||||
|
||||
for item in items:
|
||||
requires_marker = item.get_closest_marker("requires")
|
||||
if requires_marker is not None:
|
||||
# Iterate through the list of required packages
|
||||
required_pkgs = requires_marker.args
|
||||
for pkg in required_pkgs:
|
||||
# If we haven't yet checked whether the pkg is installed
|
||||
# let's check it and store the result.
|
||||
if pkg not in required_pkgs_info:
|
||||
required_pkgs_info[pkg] = util.find_spec(pkg) is not None
|
||||
|
||||
if not required_pkgs_info[pkg]:
|
||||
# If the package is not installed, we immediately break
|
||||
# and mark the test as skipped.
|
||||
item.add_marker(pytest.mark.skip(reason=f"requires pkg: `{pkg}`"))
|
||||
break
|
@ -91,6 +91,8 @@ def test_file_names_exist(
|
||||
loader = FileSystemBlobLoader(toy_dir, glob=glob, suffixes=suffixes)
|
||||
blobs = list(loader.yield_blobs())
|
||||
|
||||
assert loader.count_matching_files() == len(relative_filenames)
|
||||
|
||||
file_names = sorted(str(blob.path) for blob in blobs)
|
||||
|
||||
expected_filenames = sorted(
|
||||
@ -99,3 +101,11 @@ def test_file_names_exist(
|
||||
)
|
||||
|
||||
assert file_names == expected_filenames
|
||||
|
||||
|
||||
@pytest.mark.requires("tqdm")
|
||||
def test_show_progress(toy_dir: str) -> None:
|
||||
"""Verify that file system loader works with a progress bar."""
|
||||
loader = FileSystemBlobLoader(toy_dir)
|
||||
blobs = list(loader.yield_blobs())
|
||||
assert len(blobs) == loader.count_matching_files()
|
||||
|
Loading…
Reference in New Issue
Block a user