mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-22 15:38:06 +00:00
Add progress bar to filesystemblob loader, update pytest config for unit tests (#4212)
This PR adds: * Option to show a tqdm progress bar when using the file system blob loader * Update pytest run configuration to be stricter * Adding a new marker that checks that required pkgs exist
This commit is contained in:
parent
f4c8502e61
commit
aa11f7c89b
@ -1,9 +1,38 @@
|
|||||||
"""Use to load blobs from the local file system."""
|
"""Use to load blobs from the local file system."""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, Optional, Sequence, Union
|
from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar, Union
|
||||||
|
|
||||||
from langchain.document_loaders.blob_loaders.schema import Blob, BlobLoader
|
from langchain.document_loaders.blob_loaders.schema import Blob, BlobLoader
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_iterator(
|
||||||
|
length_func: Callable[[], int], show_progress: bool = False
|
||||||
|
) -> Callable[[Iterable[T]], Iterator[T]]:
|
||||||
|
"""Create a function that optionally wraps an iterable in tqdm."""
|
||||||
|
if show_progress:
|
||||||
|
try:
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"You must install tqdm to use show_progress=True."
|
||||||
|
"You can install tqdm with `pip install tqdm`."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure to provide `total` here so that tqdm can show
|
||||||
|
# a progress bar that takes into account the total number of files.
|
||||||
|
def _with_tqdm(iterable: Iterable[T]) -> Iterator[T]:
|
||||||
|
"""Wrap an iterable in a tqdm progress bar."""
|
||||||
|
return tqdm(iterable, total=length_func())
|
||||||
|
|
||||||
|
iterator = _with_tqdm
|
||||||
|
else:
|
||||||
|
iterator = iter # type: ignore
|
||||||
|
|
||||||
|
return iterator
|
||||||
|
|
||||||
|
|
||||||
# PUBLIC API
|
# PUBLIC API
|
||||||
|
|
||||||
|
|
||||||
@ -26,6 +55,7 @@ class FileSystemBlobLoader(BlobLoader):
|
|||||||
*,
|
*,
|
||||||
glob: str = "**/[!.]*",
|
glob: str = "**/[!.]*",
|
||||||
suffixes: Optional[Sequence[str]] = None,
|
suffixes: Optional[Sequence[str]] = None,
|
||||||
|
show_progress: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize with path to directory and how to glob over it.
|
"""Initialize with path to directory and how to glob over it.
|
||||||
|
|
||||||
@ -36,6 +66,9 @@ class FileSystemBlobLoader(BlobLoader):
|
|||||||
suffixes: Provide to keep only files with these suffixes
|
suffixes: Provide to keep only files with these suffixes
|
||||||
Useful when wanting to keep files with different suffixes
|
Useful when wanting to keep files with different suffixes
|
||||||
Suffixes must include the dot, e.g. ".txt"
|
Suffixes must include the dot, e.g. ".txt"
|
||||||
|
show_progress: If true, will show a progress bar as the files are loaded.
|
||||||
|
This forces an iteration through all matching files
|
||||||
|
to count them prior to loading them.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@ -60,14 +93,33 @@ class FileSystemBlobLoader(BlobLoader):
|
|||||||
self.path = _path
|
self.path = _path
|
||||||
self.glob = glob
|
self.glob = glob
|
||||||
self.suffixes = set(suffixes or [])
|
self.suffixes = set(suffixes or [])
|
||||||
|
self.show_progress = show_progress
|
||||||
|
|
||||||
def yield_blobs(
|
def yield_blobs(
|
||||||
self,
|
self,
|
||||||
) -> Iterable[Blob]:
|
) -> Iterable[Blob]:
|
||||||
"""Yield blobs that match the requested pattern."""
|
"""Yield blobs that match the requested pattern."""
|
||||||
|
iterator = _make_iterator(
|
||||||
|
length_func=self.count_matching_files, show_progress=self.show_progress
|
||||||
|
)
|
||||||
|
|
||||||
|
for path in iterator(self._yield_paths()):
|
||||||
|
yield Blob.from_path(path)
|
||||||
|
|
||||||
|
def _yield_paths(self) -> Iterable[Path]:
|
||||||
|
"""Yield paths that match the requested pattern."""
|
||||||
paths = self.path.glob(self.glob)
|
paths = self.path.glob(self.glob)
|
||||||
for path in paths:
|
for path in paths:
|
||||||
if path.is_file():
|
if path.is_file():
|
||||||
if self.suffixes and path.suffix not in self.suffixes:
|
if self.suffixes and path.suffix not in self.suffixes:
|
||||||
continue
|
continue
|
||||||
yield Blob.from_path(str(path))
|
yield path
|
||||||
|
|
||||||
|
def count_matching_files(self) -> int:
|
||||||
|
"""Count files that match the pattern without loading them."""
|
||||||
|
# Carry out a full iteration to count the files without
|
||||||
|
# materializing anything expensive in memory.
|
||||||
|
num = 0
|
||||||
|
for _ in self._yield_paths():
|
||||||
|
num += 1
|
||||||
|
return num
|
||||||
|
@ -184,3 +184,17 @@ omit = [
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.0.0"]
|
requires = ["poetry-core>=1.0.0"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
# --strict-markers will raise errors on unknown marks.
|
||||||
|
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||||
|
#
|
||||||
|
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||||
|
# --strict-config any warnings encountered while parsing the `pytest`
|
||||||
|
# section of the configuration file raise errors.
|
||||||
|
addopts = "--strict-markers --strict-config --durations=5"
|
||||||
|
# Registering custom markers.
|
||||||
|
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||||
|
markers = [
|
||||||
|
"requires: mark tests as requiring a specific library"
|
||||||
|
]
|
||||||
|
44
tests/unit_tests/conftest.py
Normal file
44
tests/unit_tests/conftest.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
"""Configuration for unit tests."""
|
||||||
|
from importlib import util
|
||||||
|
from typing import Dict, Sequence
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest import Config, Function
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
||||||
|
"""Add implementations for handling custom markers.
|
||||||
|
|
||||||
|
At the moment, this adds support for a custom `requires` marker.
|
||||||
|
|
||||||
|
The `requires` marker is used to denote tests that require one or more packages
|
||||||
|
to be installed to run. If the package is not installed, the test is skipped.
|
||||||
|
|
||||||
|
The `requires` marker syntax is:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@pytest.mark.requires("package1", "package2")
|
||||||
|
def test_something():
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
# Mapping from the name of a package to whether it is installed or not.
|
||||||
|
# Used to avoid repeated calls to `util.find_spec`
|
||||||
|
required_pkgs_info: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
requires_marker = item.get_closest_marker("requires")
|
||||||
|
if requires_marker is not None:
|
||||||
|
# Iterate through the list of required packages
|
||||||
|
required_pkgs = requires_marker.args
|
||||||
|
for pkg in required_pkgs:
|
||||||
|
# If we haven't yet checked whether the pkg is installed
|
||||||
|
# let's check it and store the result.
|
||||||
|
if pkg not in required_pkgs_info:
|
||||||
|
required_pkgs_info[pkg] = util.find_spec(pkg) is not None
|
||||||
|
|
||||||
|
if not required_pkgs_info[pkg]:
|
||||||
|
# If the package is not installed, we immediately break
|
||||||
|
# and mark the test as skipped.
|
||||||
|
item.add_marker(pytest.mark.skip(reason=f"requires pkg: `{pkg}`"))
|
||||||
|
break
|
@ -91,6 +91,8 @@ def test_file_names_exist(
|
|||||||
loader = FileSystemBlobLoader(toy_dir, glob=glob, suffixes=suffixes)
|
loader = FileSystemBlobLoader(toy_dir, glob=glob, suffixes=suffixes)
|
||||||
blobs = list(loader.yield_blobs())
|
blobs = list(loader.yield_blobs())
|
||||||
|
|
||||||
|
assert loader.count_matching_files() == len(relative_filenames)
|
||||||
|
|
||||||
file_names = sorted(str(blob.path) for blob in blobs)
|
file_names = sorted(str(blob.path) for blob in blobs)
|
||||||
|
|
||||||
expected_filenames = sorted(
|
expected_filenames = sorted(
|
||||||
@ -99,3 +101,11 @@ def test_file_names_exist(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert file_names == expected_filenames
|
assert file_names == expected_filenames
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("tqdm")
|
||||||
|
def test_show_progress(toy_dir: str) -> None:
|
||||||
|
"""Verify that file system loader works with a progress bar."""
|
||||||
|
loader = FileSystemBlobLoader(toy_dir)
|
||||||
|
blobs = list(loader.yield_blobs())
|
||||||
|
assert len(blobs) == loader.count_matching_files()
|
||||||
|
Loading…
Reference in New Issue
Block a user