langchain(patch): Restrict paths in LocalFileStore cache (#15065)

This PR restricts the paths that can be resolve using the local file system cache so that all paths must be contained within the root path.
This commit is contained in:
Eugene Yurtsev 2023-12-22 11:20:17 -05:00 committed by GitHub
parent 501cc8311d
commit aad3d8bd47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 2 deletions

View File

@ -1,3 +1,4 @@
import os
import re
from pathlib import Path
from typing import Iterator, List, Optional, Sequence, Tuple, Union
@ -42,7 +43,7 @@ class LocalFileStore(ByteStore):
root_path (Union[str, Path]): The root path of the file store. All keys are
interpreted as paths relative to this root.
"""
self.root_path = Path(root_path)
self.root_path = Path(root_path).absolute()
def _get_full_path(self, key: str) -> Path:
"""Get the full path for a given key relative to the root path.
@ -55,7 +56,15 @@ class LocalFileStore(ByteStore):
"""
if not re.match(r"^[a-zA-Z0-9_.\-/]+$", key):
raise InvalidKeyException(f"Invalid characters in key: {key}")
return self.root_path / key
full_path = os.path.abspath(self.root_path / key)
common_path = os.path.commonpath([str(self.root_path), full_path])
if common_path != str(self.root_path):
raise InvalidKeyException(
f"Invalid key: {key}. Key should be relative to the full path."
f"{self.root_path} vs. {common_path} and full path of {full_path}"
)
return Path(full_path)
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
"""Get the values associated with the given keys.

View File

@ -77,3 +77,26 @@ def test_yield_keys(file_store: LocalFileStore) -> None:
# Assert that the yielded keys match the expected keys
expected_keys = ["key1", os.path.join("subdir", "key2")]
assert keys == expected_keys
def test_catches_forbidden_keys(file_store: LocalFileStore) -> None:
"""Make sure we raise exception on keys that are not allowed; e.g., absolute path"""
with pytest.raises(InvalidKeyException):
file_store.mset([("/etc", b"value1")])
with pytest.raises(InvalidKeyException):
list(file_store.yield_keys("/etc/passwd"))
with pytest.raises(InvalidKeyException):
file_store.mget(["/etc/passwd"])
# check relative paths
with pytest.raises(InvalidKeyException):
list(file_store.yield_keys(".."))
with pytest.raises(InvalidKeyException):
file_store.mget(["../etc/passwd"])
with pytest.raises(InvalidKeyException):
file_store.mset([("../etc", b"value1")])
with pytest.raises(InvalidKeyException):
list(file_store.yield_keys("../etc/passwd"))