"""File storage interface."""
|
|
|
|
import dataclasses
|
|
import hashlib
|
|
import io
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from abc import ABC, abstractmethod
|
|
from io import BytesIO
|
|
from typing import Any, BinaryIO, Dict, List, Optional, Tuple
|
|
from urllib.parse import parse_qs, urlencode, urlparse
|
|
|
|
import requests
|
|
|
|
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
|
from dbgpt.util.tracer import root_tracer, trace
|
|
|
|
from .storage import (
|
|
InMemoryStorage,
|
|
QuerySpec,
|
|
ResourceIdentifier,
|
|
StorageError,
|
|
StorageInterface,
|
|
StorageItem,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
_SCHEMA = "dbgpt-fs"
|
|
|
|
|
|
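# Note: URIs built by this module follow the layout
# `dbgpt-fs://<storage_type>/<bucket>/<file_id>[/<version>][?<custom_params>]`,
# e.g. `dbgpt-fs://local/my_bucket/3f2a?user_name=alice` (example values are
# illustrative). See `FileStorageURI` below.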


@dataclasses.dataclass
class FileMetadataIdentifier(ResourceIdentifier):
    """File metadata identifier."""

    file_id: str
    bucket: str

    def to_dict(self) -> Dict:
        """Convert the identifier to a dictionary."""
        return {"file_id": self.file_id, "bucket": self.bucket}

    @property
    def str_identifier(self) -> str:
        """Get the string identifier.

        Returns:
            str: The string identifier
        """
        return f"{self.bucket}/{self.file_id}"


@dataclasses.dataclass
class FileMetadata(StorageItem):
    """File metadata for storage."""

    file_id: str
    bucket: str
    file_name: str
    file_size: int
    storage_type: str
    storage_path: str
    uri: str
    custom_metadata: Dict[str, Any]
    file_hash: str
    user_name: Optional[str] = None
    sys_code: Optional[str] = None
    _identifier: FileMetadataIdentifier = dataclasses.field(init=False)

    def __post_init__(self):
        """Build the identifier and backfill user_name/sys_code from metadata."""
        self._identifier = FileMetadataIdentifier(
            file_id=self.file_id, bucket=self.bucket
        )
        custom_metadata = self.custom_metadata or {}
        if not self.user_name:
            self.user_name = custom_metadata.get("user_name")
        if not self.sys_code:
            self.sys_code = custom_metadata.get("sys_code")

    @property
    def identifier(self) -> ResourceIdentifier:
        """Get the resource identifier."""
        return self._identifier

    def merge(self, other: "StorageItem") -> None:
        """Merge the metadata with another item."""
        if not isinstance(other, FileMetadata):
            raise StorageError("Cannot merge different types of items")
        self._from_object(other)

    def to_dict(self) -> Dict:
        """Convert the metadata to a dictionary."""
        return {
            "file_id": self.file_id,
            "bucket": self.bucket,
            "file_name": self.file_name,
            "file_size": self.file_size,
            "storage_type": self.storage_type,
            "storage_path": self.storage_path,
            "uri": self.uri,
            "custom_metadata": self.custom_metadata,
            "file_hash": self.file_hash,
        }

    def _from_object(self, obj: "FileMetadata") -> None:
        self.file_id = obj.file_id
        self.bucket = obj.bucket
        self.file_name = obj.file_name
        self.file_size = obj.file_size
        self.storage_type = obj.storage_type
        self.storage_path = obj.storage_path
        self.uri = obj.uri
        self.custom_metadata = obj.custom_metadata
        self.file_hash = obj.file_hash
        self._identifier = obj._identifier


class FileStorageURI:
    """File storage URI."""

    def __init__(
        self,
        storage_type: str,
        bucket: str,
        file_id: str,
        version: Optional[str] = None,
        custom_params: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the file storage URI."""
        self.scheme = _SCHEMA
        self.storage_type = storage_type
        self.bucket = bucket
        self.file_id = file_id
        self.version = version
        self.custom_params = custom_params or {}

    @classmethod
    def is_local_file(cls, uri: str) -> bool:
        """Check if the URI is local."""
        parsed = urlparse(uri)
        if not parsed.scheme or parsed.scheme == "file":
            return True
        return False

    @classmethod
    def parse(cls, uri: str) -> "FileStorageURI":
        """Parse the URI string."""
        parsed = urlparse(uri)
        if parsed.scheme != _SCHEMA:
            raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'")
        path_parts = parsed.path.strip("/").split("/")
        if len(path_parts) < 2:
            raise ValueError("Invalid URI path. Must contain bucket and file ID")
        storage_type = parsed.netloc
        bucket = path_parts[0]
        file_id = path_parts[1]
        version = path_parts[2] if len(path_parts) > 2 else None
        custom_params = parse_qs(parsed.query)
        return cls(storage_type, bucket, file_id, version, custom_params)

    def __str__(self) -> str:
        """Get the string representation of the URI."""
        base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}"
        if self.version:
            base_uri += f"/{self.version}"
        if self.custom_params:
            query_string = urlencode(self.custom_params, doseq=True)
            base_uri += f"?{query_string}"
        return base_uri
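

# A minimal round-trip sketch (the values are illustrative, not real file IDs):
#
#     uri = FileStorageURI("local", "my_bucket", "3f2a", custom_params={"k": "v"})
#     str(uri)                               # 'dbgpt-fs://local/my_bucket/3f2a?k=v'
#     FileStorageURI.parse(str(uri)).bucket  # 'my_bucket'
#
# Note that `parse` uses `parse_qs`, so parsed `custom_params` values come back
# as lists (e.g. {'k': ['v']}), which `urlencode(..., doseq=True)` re-flattens.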


class StorageBackend(ABC):
    """Storage backend interface."""

    storage_type: str = "__base__"

    @abstractmethod
    def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
        """Save the file data to the storage backend.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID
            file_data (BinaryIO): The file data

        Returns:
            str: The storage path
        """

    @abstractmethod
    def load(self, fm: FileMetadata) -> BinaryIO:
        """Load the file data from the storage backend.

        Args:
            fm (FileMetadata): The file metadata

        Returns:
            BinaryIO: The file data
        """

    @abstractmethod
    def delete(self, fm: FileMetadata) -> bool:
        """Delete the file data from the storage backend.

        Args:
            fm (FileMetadata): The file metadata

        Returns:
            bool: True if the file was deleted, False otherwise
        """

    @property
    @abstractmethod
    def save_chunk_size(self) -> int:
        """Get the save chunk size.

        Returns:
            int: The save chunk size
        """


class LocalFileStorage(StorageBackend):
    """Local file storage backend."""

    storage_type: str = "local"

    def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024):
        """Initialize the local file storage backend."""
        self.base_path = base_path
        self._save_chunk_size = save_chunk_size
        os.makedirs(self.base_path, exist_ok=True)

    @property
    def save_chunk_size(self) -> int:
        """Get the save chunk size."""
        return self._save_chunk_size

    def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
        """Save the file data to the local storage backend."""
        bucket_path = os.path.join(self.base_path, bucket)
        os.makedirs(bucket_path, exist_ok=True)
        file_path = os.path.join(bucket_path, file_id)
        with open(file_path, "wb") as f:
            while True:
                chunk = file_data.read(self.save_chunk_size)
                if not chunk:
                    break
                f.write(chunk)
        return file_path

    def load(self, fm: FileMetadata) -> BinaryIO:
        """Load the file data from the local storage backend."""
        bucket_path = os.path.join(self.base_path, fm.bucket)
        file_path = os.path.join(bucket_path, fm.file_id)
        return open(file_path, "rb")  # noqa: SIM115

    def delete(self, fm: FileMetadata) -> bool:
        """Delete the file data from the local storage backend."""
        bucket_path = os.path.join(self.base_path, fm.bucket)
        file_path = os.path.join(bucket_path, fm.file_id)
        if os.path.exists(file_path):
            os.remove(file_path)
            return True
        return False
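

# A minimal sketch of using the backend directly (path and ID are
# illustrative; in practice FileStorageSystem generates the file_id):
#
#     backend = LocalFileStorage(base_path="/tmp/dbgpt_files")
#     with open("report.csv", "rb") as f:
#         path = backend.save("my_bucket", "file-123", f)
#     # Files land at <base_path>/<bucket>/<file_id>, e.g.
#     # /tmp/dbgpt_files/my_bucket/file-123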


class FileStorageSystem:
    """File storage system."""

    def __init__(
        self,
        storage_backends: Dict[str, StorageBackend],
        metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None,
        check_hash: bool = True,
    ):
        """Initialize the file storage system."""
        metadata_storage = metadata_storage or InMemoryStorage()
        self.storage_backends = storage_backends
        self.metadata_storage = metadata_storage
        self.check_hash = check_hash
        self._save_chunk_size = min(
            (backend.save_chunk_size for backend in storage_backends.values()),
            # Fall back to 1 MiB when no backends are registered, so min()
            # does not raise on an empty sequence.
            default=1024 * 1024,
        )

    def _calculate_file_hash(self, file_data: BinaryIO) -> str:
        """Calculate the MD5 hash of the file data."""
        if not self.check_hash:
            return "-1"
        hasher = hashlib.md5()
        file_data.seek(0)
        while chunk := file_data.read(self._save_chunk_size):
            hasher.update(chunk)
        file_data.seek(0)
        return hasher.hexdigest()

    @trace("file_storage_system.save_file")
    def save_file(
        self,
        bucket: str,
        file_name: str,
        file_data: BinaryIO,
        storage_type: str,
        custom_metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Save the file data to the storage backend."""
        file_id = str(uuid.uuid4())
        backend = self.storage_backends.get(storage_type)
        if not backend:
            raise ValueError(f"Unsupported storage type: {storage_type}")

        with root_tracer.start_span(
            "file_storage_system.save_file.backend_save",
            metadata={
                "bucket": bucket,
                "file_id": file_id,
                "file_name": file_name,
                "storage_type": storage_type,
            },
        ):
            storage_path = backend.save(bucket, file_id, file_data)
            file_data.seek(0, 2)  # Move to the end of the file
            file_size = file_data.tell()  # Get the file size
            file_data.seek(0)  # Reset file pointer

        # Filter out None values
        custom_metadata = (
            {k: v for k, v in custom_metadata.items() if v is not None}
            if custom_metadata
            else {}
        )

        with root_tracer.start_span(
            "file_storage_system.save_file.calculate_hash",
        ):
            file_hash = self._calculate_file_hash(file_data)
            uri = FileStorageURI(
                storage_type, bucket, file_id, custom_params=custom_metadata
            )

        metadata = FileMetadata(
            file_id=file_id,
            bucket=bucket,
            file_name=file_name,
            file_size=file_size,
            storage_type=storage_type,
            storage_path=storage_path,
            uri=str(uri),
            custom_metadata=custom_metadata,
            file_hash=file_hash,
        )

        self.metadata_storage.save(metadata)
        return str(uri)

    @trace("file_storage_system.get_file")
    def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
        """Get the file data from the storage backend."""
        if FileStorageURI.is_local_file(uri):
            local_file_name = uri.split("/")[-1]
            if not os.path.exists(uri):
                raise FileNotFoundError(f"File not found: {uri}")

            dummy_metadata = FileMetadata(
                file_id=local_file_name,
                bucket="dummy_bucket",
                file_name=local_file_name,
                file_size=-1,
                storage_type="local",
                storage_path=uri,
                uri=uri,
                custom_metadata={},
                file_hash="",
            )
            logger.info(f"Reading local file: {uri}")
            return open(uri, "rb"), dummy_metadata  # noqa: SIM115

        parsed_uri = FileStorageURI.parse(uri)
        metadata = self.metadata_storage.load(
            FileMetadataIdentifier(
                file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
            ),
            FileMetadata,
        )
        if not metadata:
            raise FileNotFoundError(f"No metadata found for URI: {uri}")

        backend = self.storage_backends.get(metadata.storage_type)
        if not backend:
            raise ValueError(f"Unsupported storage type: {metadata.storage_type}")

        with root_tracer.start_span(
            "file_storage_system.get_file.backend_load",
            metadata={
                "bucket": metadata.bucket,
                "file_id": metadata.file_id,
                "file_name": metadata.file_name,
                "storage_type": metadata.storage_type,
            },
        ):
            file_data = backend.load(metadata)

        with root_tracer.start_span(
            "file_storage_system.get_file.verify_hash",
        ):
            calculated_hash = self._calculate_file_hash(file_data)
            if calculated_hash != "-1" and calculated_hash != metadata.file_hash:
                raise ValueError("File integrity check failed. Hash mismatch.")

        return file_data, metadata

    def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]:
        """Get the file metadata.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID

        Returns:
            Optional[FileMetadata]: The file metadata
        """
        fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket)
        return self.metadata_storage.load(fid, FileMetadata)

    def delete_file(self, uri: str) -> bool:
        """Delete the file data from the storage backend.

        Args:
            uri (str): The file URI

        Returns:
            bool: True if the file was deleted, False otherwise
        """
        parsed_uri = FileStorageURI.parse(uri)
        fid = FileMetadataIdentifier(
            file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
        )
        metadata = self.metadata_storage.load(fid, FileMetadata)
        if not metadata:
            return False

        backend = self.storage_backends.get(metadata.storage_type)
        if not backend:
            raise ValueError(f"Unsupported storage type: {metadata.storage_type}")

        if backend.delete(metadata):
            try:
                self.metadata_storage.delete(fid)
                return True
            except Exception as e:
                # If the metadata deletion fails, log the error and return False
                logger.error(f"Failed to delete metadata for {uri}: {e}")
                return False
        return False

    def list_files(
        self, bucket: str, filters: Optional[Dict[str, Any]] = None
    ) -> List[FileMetadata]:
        """List the files in the bucket."""
        filters = filters or {}
        filters["bucket"] = bucket
        return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata)
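

# A minimal end-to-end sketch (bucket names and contents are illustrative):
#
#     fs = FileStorageSystem({"local": LocalFileStorage("/tmp/dbgpt_files")})
#     uri = fs.save_file("docs", "hello.txt", io.BytesIO(b"hello"), "local")
#     data, meta = fs.get_file(uri)  # verifies the MD5 hash when check_hash=True
#     fs.delete_file(uri)
#
# Metadata defaults to InMemoryStorage, so it is lost on restart unless a
# persistent StorageInterface implementation is passed as metadata_storage.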


class FileStorageClient(BaseComponent):
    """File storage client component."""

    name = ComponentType.FILE_STORAGE_CLIENT.value

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        storage_system: Optional[FileStorageSystem] = None,
    ):
        """Initialize the file storage client."""
        super().__init__(system_app=system_app)
        if not storage_system:
            from pathlib import Path

            base_path = Path.home() / ".cache" / "dbgpt" / "files"
            storage_system = FileStorageSystem(
                {
                    LocalFileStorage.storage_type: LocalFileStorage(
                        base_path=str(base_path)
                    )
                }
            )

        self.system_app = system_app
        self._storage_system = storage_system

    def init_app(self, system_app: SystemApp):
        """Initialize the application."""
        self.system_app = system_app

    @property
    def storage_system(self) -> FileStorageSystem:
        """Get the file storage system."""
        if not self._storage_system:
            raise ValueError("File storage system not initialized")
        return self._storage_system

    def upload_file(
        self,
        bucket: str,
        file_path: str,
        storage_type: str,
        custom_metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Upload a file to the storage system.

        Args:
            bucket (str): The bucket name
            file_path (str): The file path
            storage_type (str): The storage type
            custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
                None.

        Returns:
            str: The file URI
        """
        with open(file_path, "rb") as file:
            return self.save_file(
                bucket, os.path.basename(file_path), file, storage_type, custom_metadata
            )

    def save_file(
        self,
        bucket: str,
        file_name: str,
        file_data: BinaryIO,
        storage_type: str,
        custom_metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Save the file data to the storage system.

        Args:
            bucket (str): The bucket name
            file_name (str): The file name
            file_data (BinaryIO): The file data
            storage_type (str): The storage type
            custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
                None.

        Returns:
            str: The file URI
        """
        return self.storage_system.save_file(
            bucket, file_name, file_data, storage_type, custom_metadata
        )

    def download_file(self, uri: str, destination_path: str) -> None:
        """Download a file from the storage system.

        Args:
            uri (str): The file URI
            destination_path (str): The destination path

        Raises:
            FileNotFoundError: If the file is not found
        """
        file_data, _ = self.storage_system.get_file(uri)
        with open(destination_path, "wb") as f:
            f.write(file_data.read())

    def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
        """Get the file data from the storage system.

        Args:
            uri (str): The file URI

        Returns:
            Tuple[BinaryIO, FileMetadata]: The file data and metadata
        """
        return self.storage_system.get_file(uri)

    def get_file_by_id(
        self, bucket: str, file_id: str
    ) -> Tuple[BinaryIO, FileMetadata]:
        """Get the file data from the storage system by ID.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID

        Returns:
            Tuple[BinaryIO, FileMetadata]: The file data and metadata
        """
        metadata = self.storage_system.get_file_metadata(bucket, file_id)
        if not metadata:
            raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
        return self.get_file(metadata.uri)

    def delete_file(self, uri: str) -> bool:
        """Delete the file data from the storage system.

        Args:
            uri (str): The file URI

        Returns:
            bool: True if the file was deleted, False otherwise
        """
        return self.storage_system.delete_file(uri)

    def delete_file_by_id(self, bucket: str, file_id: str) -> bool:
        """Delete the file data from the storage system by ID.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID

        Returns:
            bool: True if the file was deleted, False otherwise
        """
        metadata = self.storage_system.get_file_metadata(bucket, file_id)
        if not metadata:
            raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
        return self.delete_file(metadata.uri)

    def list_files(
        self, bucket: str, filters: Optional[Dict[str, Any]] = None
    ) -> List[FileMetadata]:
        """List the files in the bucket.

        Args:
            bucket (str): The bucket name
            filters (Dict[str, Any], optional): Filters. Defaults to None.

        Returns:
            List[FileMetadata]: The list of file metadata
        """
        return self.storage_system.list_files(bucket, filters)
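

# A minimal sketch of the client (paths are illustrative). When no storage
# system is supplied, files default to ~/.cache/dbgpt/files on local disk:
#
#     client = FileStorageClient()
#     uri = client.upload_file("docs", "/path/to/report.pdf", "local")
#     client.download_file(uri, "/tmp/report.pdf")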


class SimpleDistributedStorage(StorageBackend):
    """Simple distributed storage backend."""

    storage_type: str = "distributed"

    def __init__(
        self,
        node_address: str,
        local_storage_path: str,
        save_chunk_size: int = 1024 * 1024,
        transfer_chunk_size: int = 1024 * 1024,
        transfer_timeout: int = 360,
        api_prefix: str = "/api/v2/serve/file/files",
    ):
        """Initialize the simple distributed storage backend."""
        self.node_address = node_address
        self.local_storage_path = local_storage_path
        os.makedirs(self.local_storage_path, exist_ok=True)
        self._save_chunk_size = save_chunk_size
        self._transfer_chunk_size = transfer_chunk_size
        self._transfer_timeout = transfer_timeout
        self._api_prefix = api_prefix

    @property
    def save_chunk_size(self) -> int:
        """Get the save chunk size."""
        return self._save_chunk_size

    def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str:
        node_id = hashlib.md5(node_address.encode()).hexdigest()
        return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}")

    def _parse_node_address(self, fm: FileMetadata) -> str:
        storage_path = fm.storage_path
        if not storage_path.startswith("distributed://"):
            raise ValueError("Invalid storage path")
        return storage_path.split("//")[1].split("/")[0]

    def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
        """Save the file data to the distributed storage backend.

        Just save the file locally.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID
            file_data (BinaryIO): The file data

        Returns:
            str: The storage path
        """
        file_path = self._get_file_path(bucket, file_id, self.node_address)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "wb") as f:
            while True:
                chunk = file_data.read(self.save_chunk_size)
                if not chunk:
                    break
                f.write(chunk)

        return f"distributed://{self.node_address}/{bucket}/{file_id}"

    def load(self, fm: FileMetadata) -> BinaryIO:
        """Load the file data from the distributed storage backend.

        If the file is stored on the local node, load it from the local storage.
        Otherwise, stream it over HTTP from the node that owns it.

        Args:
            fm (FileMetadata): The file metadata

        Returns:
            BinaryIO: The file data
        """
        file_id = fm.file_id
        bucket = fm.bucket
        node_address = self._parse_node_address(fm)
        file_path = self._get_file_path(bucket, file_id, node_address)

        # TODO: check if the file is cached in local storage
        if node_address == self.node_address:
            if os.path.exists(file_path):
                return open(file_path, "rb")  # noqa: SIM115
            else:
                raise FileNotFoundError(f"File {file_id} not found on the local node")
        else:
            response = requests.get(
                f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
                timeout=self._transfer_timeout,
                stream=True,
            )
            response.raise_for_status()
            # TODO: cache the file in local storage
            return StreamedBytesIO(
                response.iter_content(chunk_size=self._transfer_chunk_size)
            )

    def delete(self, fm: FileMetadata) -> bool:
        """Delete the file data from the distributed storage backend.

        If the file is stored on the local node, delete it from the local storage.
        If the file is stored on a remote node, send a delete request to the remote
        node.

        Args:
            fm (FileMetadata): The file metadata

        Returns:
            bool: True if the file was deleted, False otherwise
        """
        file_id = fm.file_id
        bucket = fm.bucket
        node_address = self._parse_node_address(fm)
        file_path = self._get_file_path(bucket, file_id, node_address)
        if node_address == self.node_address:
            if os.path.exists(file_path):
                os.remove(file_path)
                return True
            return False
        else:
            try:
                response = requests.delete(
                    f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
                    timeout=self._transfer_timeout,
                )
                response.raise_for_status()
                return True
            except Exception:
                return False
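

# Storage paths recorded by this backend encode the owning node, e.g.
# `distributed://127.0.0.1:5670/my_bucket/<file_id>` (address is illustrative).
# `load` and `delete` compare that node address with `self.node_address` to
# decide between a local filesystem access and an HTTP call to the owning node.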


class StreamedBytesIO(io.BytesIO):
    """A BytesIO subclass that can be used with streaming responses.

    Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13
    """

    def __init__(self, request_iterator):
        """Initialize the StreamedBytesIO instance."""
        super().__init__()
        self._bytes = BytesIO()
        self._iterator = request_iterator

    def _load_all(self):
        self._bytes.seek(0, io.SEEK_END)
        for chunk in self._iterator:
            self._bytes.write(chunk)

    def _load_until(self, goal_position):
        current_position = self._bytes.seek(0, io.SEEK_END)
        while current_position < goal_position:
            try:
                current_position += self._bytes.write(next(self._iterator))
            except StopIteration:
                break

    def tell(self) -> int:
        """Get the current position."""
        return self._bytes.tell()

    def read(self, size: Optional[int] = None) -> bytes:
        """Read the data from the stream.

        Args:
            size (Optional[int], optional): The number of bytes to read. Defaults to
                None.

        Returns:
            bytes: The read data
        """
        left_off_at = self._bytes.tell()
        if size is None:
            self._load_all()
        else:
            goal_position = left_off_at + size
            self._load_until(goal_position)

        self._bytes.seek(left_off_at)
        return self._bytes.read(size)

    def seek(self, position: int, whence: int = io.SEEK_SET):
        """Seek to a position in the stream.

        Args:
            position (int): The position
            whence (int, optional): The reference point. Defaults to io.SEEK_SET.

        Raises:
            ValueError: If the reference point is invalid
        """
        if whence == io.SEEK_END:
            # Seeking relative to the end requires the whole stream to be
            # buffered first; then delegate so non-zero offsets work too.
            self._load_all()
            self._bytes.seek(position, whence)
        else:
            self._bytes.seek(position, whence)

    def __enter__(self):
        """Enter the context manager."""
        return self

    def __exit__(self, ext_type, value, tb):
        """Exit the context manager."""
        self._bytes.close()
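

# A minimal sketch of the lazy-buffering behavior (the list iterator stands in
# for `response.iter_content(...)`):
#
#     stream = StreamedBytesIO(iter([b"ab", b"cd", b"ef"]))
#     stream.read(3)  # b'abc' (pulls only the chunks needed)
#     stream.read()   # b'def' (drains the rest of the iterator)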