diff --git a/libs/langchain/langchain/storage/file_system.py b/libs/langchain/langchain/storage/file_system.py index 720acf085a1..e737309a182 100644 --- a/libs/langchain/langchain/storage/file_system.py +++ b/libs/langchain/langchain/storage/file_system.py @@ -1,3 +1,4 @@ +import os import re from pathlib import Path from typing import Iterator, List, Optional, Sequence, Tuple, Union @@ -42,7 +43,7 @@ class LocalFileStore(ByteStore): root_path (Union[str, Path]): The root path of the file store. All keys are interpreted as paths relative to this root. """ - self.root_path = Path(root_path) + self.root_path = Path(root_path).absolute() def _get_full_path(self, key: str) -> Path: """Get the full path for a given key relative to the root path. @@ -55,7 +56,15 @@ class LocalFileStore(ByteStore): """ if not re.match(r"^[a-zA-Z0-9_.\-/]+$", key): raise InvalidKeyException(f"Invalid characters in key: {key}") - return self.root_path / key + full_path = os.path.abspath(self.root_path / key) + common_path = os.path.commonpath([str(self.root_path), full_path]) + if common_path != str(self.root_path): + raise InvalidKeyException( + f"Invalid key: {key}. Key should be relative to the full path." + f"{self.root_path} vs. {common_path} and full path of {full_path}" + ) + + return Path(full_path) def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: """Get the values associated with the given keys. diff --git a/libs/langchain/tests/unit_tests/storage/test_filesystem.py b/libs/langchain/tests/unit_tests/storage/test_filesystem.py index 76991eb2948..4213006bf25 100644 --- a/libs/langchain/tests/unit_tests/storage/test_filesystem.py +++ b/libs/langchain/tests/unit_tests/storage/test_filesystem.py @@ -77,3 +77,26 @@ def test_yield_keys(file_store: LocalFileStore) -> None: # Assert that the yielded keys match the expected keys expected_keys = ["key1", os.path.join("subdir", "key2")] assert keys == expected_keys + + +def test_catches_forbidden_keys(file_store: LocalFileStore) -> None: + """Make sure we raise exception on keys that are not allowed; e.g., absolute path""" + with pytest.raises(InvalidKeyException): + file_store.mset([("/etc", b"value1")]) + with pytest.raises(InvalidKeyException): + list(file_store.yield_keys("/etc/passwd")) + with pytest.raises(InvalidKeyException): + file_store.mget(["/etc/passwd"]) + + # check relative paths + with pytest.raises(InvalidKeyException): + list(file_store.yield_keys("..")) + + with pytest.raises(InvalidKeyException): + file_store.mget(["../etc/passwd"]) + + with pytest.raises(InvalidKeyException): + file_store.mset([("../etc", b"value1")]) + + with pytest.raises(InvalidKeyException): + list(file_store.yield_keys("../etc/passwd"))