fix: prevent to ingest local files (by default) (#2010)

* feat: prevent to local ingestion (by default) and add white-list

* docs: add local ingestion warning

* docs: add missing comment

* fix: update exception error

* fix: black
This commit is contained in:
Javier Martinez
2024-07-31 14:33:46 +02:00
committed by GitHub
parent 1020cd5328
commit e54a8fe043
5 changed files with 133 additions and 3 deletions

View File

@@ -7,12 +7,13 @@ from pathlib import Path
from private_gpt.di import global_injector
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.server.ingest.ingest_watcher import IngestWatcher
from private_gpt.settings.settings import Settings
logger = logging.getLogger(__name__)
class LocalIngestWorker:
def __init__(self, ingest_service: IngestService) -> None:
def __init__(self, ingest_service: IngestService, setting: Settings) -> None:
self.ingest_service = ingest_service
self.total_documents = 0
@@ -20,6 +21,24 @@ class LocalIngestWorker:
self._files_under_root_folder: list[Path] = []
self.is_local_ingestion_enabled = setting.data.local_ingestion.enabled
self.allowed_local_folders = setting.data.local_ingestion.allow_ingest_from
def _validate_folder(self, folder_path: Path) -> None:
if not self.is_local_ingestion_enabled:
raise ValueError(
"Local ingestion is disabled."
"You can enable it in settings `ingestion.enabled`"
)
# Allow all folders if wildcard is present
if "*" in self.allowed_local_folders:
return
for allowed_folder in self.allowed_local_folders:
if not folder_path.is_relative_to(allowed_folder):
raise ValueError(f"Folder {folder_path} is not allowed for ingestion")
def _find_all_files_in_folder(self, root_path: Path, ignored: list[str]) -> None:
"""Search all files under the root folder recursively.
@@ -28,6 +47,7 @@ class LocalIngestWorker:
for file_path in root_path.iterdir():
if file_path.is_file() and file_path.name not in ignored:
self.total_documents += 1
self._validate_folder(file_path)
self._files_under_root_folder.append(file_path)
elif file_path.is_dir() and file_path.name not in ignored:
self._find_all_files_in_folder(file_path, ignored)
@@ -92,13 +112,13 @@ if args.log_file:
logger.addHandler(file_handler)
if __name__ == "__main__":
root_path = Path(args.folder)
if not root_path.exists():
raise ValueError(f"Path {args.folder} does not exist")
ingest_service = global_injector.get(IngestService)
worker = LocalIngestWorker(ingest_service)
settings = global_injector.get(Settings)
worker = LocalIngestWorker(ingest_service, settings)
worker.ingest_folder(root_path, args.ignored)
if args.ignored: