"""File storage interface."""
import dataclasses
import hashlib
import io
import logging
import os
import uuid
from abc import ABC, abstractmethod
from io import BytesIO
from typing import Any, BinaryIO, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse
import requests
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.tracer import root_tracer, trace
from .storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageError,
StorageInterface,
StorageItem,
)
logger = logging.getLogger(__name__)
_SCHEMA = "dbgpt-fs"
@dataclasses.dataclass
class FileMetadataIdentifier(ResourceIdentifier):
"""File metadata identifier."""
file_id: str
bucket: str
def to_dict(self) -> Dict:
"""Convert the identifier to a dictionary."""
return {"file_id": self.file_id, "bucket": self.bucket}
@property
def str_identifier(self) -> str:
"""Get the string identifier.
Returns:
str: The string identifier
"""
return f"{self.bucket}/{self.file_id}"
@dataclasses.dataclass
class FileMetadata(StorageItem):
"""File metadata for storage."""
file_id: str
bucket: str
file_name: str
file_size: int
storage_type: str
storage_path: str
uri: str
custom_metadata: Dict[str, Any]
file_hash: str
user_name: Optional[str] = None
sys_code: Optional[str] = None
_identifier: FileMetadataIdentifier = dataclasses.field(init=False)
def __post_init__(self):
"""Post init method."""
self._identifier = FileMetadataIdentifier(
file_id=self.file_id, bucket=self.bucket
)
custom_metadata = self.custom_metadata or {}
if not self.user_name:
self.user_name = custom_metadata.get("user_name")
if not self.sys_code:
self.sys_code = custom_metadata.get("sys_code")
@property
def identifier(self) -> ResourceIdentifier:
"""Get the resource identifier."""
return self._identifier
def merge(self, other: "StorageItem") -> None:
"""Merge the metadata with another item."""
if not isinstance(other, FileMetadata):
raise StorageError("Cannot merge different types of items")
self._from_object(other)
def to_dict(self) -> Dict:
"""Convert the metadata to a dictionary."""
return {
"file_id": self.file_id,
"bucket": self.bucket,
"file_name": self.file_name,
"file_size": self.file_size,
"storage_type": self.storage_type,
"storage_path": self.storage_path,
"uri": self.uri,
"custom_metadata": self.custom_metadata,
"file_hash": self.file_hash,
}
def _from_object(self, obj: "FileMetadata") -> None:
self.file_id = obj.file_id
self.bucket = obj.bucket
self.file_name = obj.file_name
self.file_size = obj.file_size
self.storage_type = obj.storage_type
self.storage_path = obj.storage_path
self.uri = obj.uri
self.custom_metadata = obj.custom_metadata
self.file_hash = obj.file_hash
self._identifier = obj._identifier
class FileStorageURI:
"""File storage URI."""
def __init__(
self,
storage_type: str,
bucket: str,
file_id: str,
version: Optional[str] = None,
custom_params: Optional[Dict[str, Any]] = None,
):
"""Initialize the file storage URI."""
self.scheme = _SCHEMA
self.storage_type = storage_type
self.bucket = bucket
self.file_id = file_id
self.version = version
self.custom_params = custom_params or {}
@classmethod
def is_local_file(cls, uri: str) -> bool:
"""Check if the URI is local."""
parsed = urlparse(uri)
if not parsed.scheme or parsed.scheme == "file":
return True
return False
@classmethod
def parse(cls, uri: str) -> "FileStorageURI":
"""Parse the URI string."""
parsed = urlparse(uri)
if parsed.scheme != _SCHEMA:
raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'")
path_parts = parsed.path.strip("/").split("/")
if len(path_parts) < 2:
raise ValueError("Invalid URI path. Must contain bucket and file ID")
storage_type = parsed.netloc
bucket = path_parts[0]
file_id = path_parts[1]
version = path_parts[2] if len(path_parts) > 2 else None
custom_params = parse_qs(parsed.query)
return cls(storage_type, bucket, file_id, version, custom_params)
def __str__(self) -> str:
"""Get the string representation of the URI."""
base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}"
if self.version:
base_uri += f"/{self.version}"
if self.custom_params:
query_string = urlencode(self.custom_params, doseq=True)
base_uri += f"?{query_string}"
return base_uri
class StorageBackend(ABC):
"""Storage backend interface."""
storage_type: str = "__base__"
@abstractmethod
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
"""Save the file data to the storage backend.
Args:
bucket (str): The bucket name
file_id (str): The file ID
file_data (BinaryIO): The file data
Returns:
str: The storage path
"""
@abstractmethod
def load(self, fm: FileMetadata) -> BinaryIO:
"""Load the file data from the storage backend.
Args:
fm (FileMetadata): The file metadata
Returns:
BinaryIO: The file data
"""
@abstractmethod
def delete(self, fm: FileMetadata) -> bool:
"""Delete the file data from the storage backend.
Args:
fm (FileMetadata): The file metadata
Returns:
bool: True if the file was deleted, False otherwise
"""
@property
@abstractmethod
def save_chunk_size(self) -> int:
"""Get the save chunk size.
Returns:
int: The save chunk size
"""
class LocalFileStorage(StorageBackend):
"""Local file storage backend."""
storage_type: str = "local"
def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024):
"""Initialize the local file storage backend."""
self.base_path = base_path
self._save_chunk_size = save_chunk_size
os.makedirs(self.base_path, exist_ok=True)
@property
def save_chunk_size(self) -> int:
"""Get the save chunk size."""
return self._save_chunk_size
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
"""Save the file data to the local storage backend."""
bucket_path = os.path.join(self.base_path, bucket)
os.makedirs(bucket_path, exist_ok=True)
file_path = os.path.join(bucket_path, file_id)
with open(file_path, "wb") as f:
while True:
chunk = file_data.read(self.save_chunk_size)
if not chunk:
break
f.write(chunk)
return file_path
def load(self, fm: FileMetadata) -> BinaryIO:
"""Load the file data from the local storage backend."""
bucket_path = os.path.join(self.base_path, fm.bucket)
file_path = os.path.join(bucket_path, fm.file_id)
return open(file_path, "rb") # noqa: SIM115
def delete(self, fm: FileMetadata) -> bool:
"""Delete the file data from the local storage backend."""
bucket_path = os.path.join(self.base_path, fm.bucket)
file_path = os.path.join(bucket_path, fm.file_id)
if os.path.exists(file_path):
os.remove(file_path)
return True
return False
class FileStorageSystem:
"""File storage system."""
def __init__(
self,
storage_backends: Dict[str, StorageBackend],
metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None,
check_hash: bool = True,
):
"""Initialize the file storage system."""
metadata_storage = metadata_storage or InMemoryStorage()
self.storage_backends = storage_backends
self.metadata_storage = metadata_storage
self.check_hash = check_hash
self._save_chunk_size = min(
backend.save_chunk_size for backend in storage_backends.values()
)
def _calculate_file_hash(self, file_data: BinaryIO) -> str:
"""Calculate the MD5 hash of the file data."""
if not self.check_hash:
return "-1"
hasher = hashlib.md5()
file_data.seek(0)
while chunk := file_data.read(self._save_chunk_size):
hasher.update(chunk)
file_data.seek(0)
return hasher.hexdigest()
@trace("file_storage_system.save_file")
def save_file(
self,
bucket: str,
file_name: str,
file_data: BinaryIO,
storage_type: str,
custom_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save the file data to the storage backend."""
file_id = str(uuid.uuid4())
backend = self.storage_backends.get(storage_type)
if not backend:
raise ValueError(f"Unsupported storage type: {storage_type}")
with root_tracer.start_span(
"file_storage_system.save_file.backend_save",
metadata={
"bucket": bucket,
"file_id": file_id,
"file_name": file_name,
"storage_type": storage_type,
},
):
storage_path = backend.save(bucket, file_id, file_data)
file_data.seek(0, 2) # Move to the end of the file
file_size = file_data.tell() # Get the file size
file_data.seek(0) # Reset file pointer
# filter None value
custom_metadata = (
{k: v for k, v in custom_metadata.items() if v is not None}
if custom_metadata
else {}
)
with root_tracer.start_span(
"file_storage_system.save_file.calculate_hash",
):
file_hash = self._calculate_file_hash(file_data)
uri = FileStorageURI(
storage_type, bucket, file_id, custom_params=custom_metadata
)
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name=file_name,
file_size=file_size,
storage_type=storage_type,
storage_path=storage_path,
uri=str(uri),
custom_metadata=custom_metadata,
file_hash=file_hash,
)
self.metadata_storage.save(metadata)
return str(uri)
@trace("file_storage_system.get_file")
def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
"""Get the file data from the storage backend."""
        if FileStorageURI.is_local_file(uri):
            # Accept both plain paths and "file://" URIs by stripping the
            # scheme before touching the filesystem.
            local_path = urlparse(uri).path if uri.startswith("file://") else uri
            local_file_name = local_path.split("/")[-1]
            if not os.path.exists(local_path):
                raise FileNotFoundError(f"File not found: {local_path}")
            dummy_metadata = FileMetadata(
                file_id=local_file_name,
                bucket="dummy_bucket",
                file_name=local_file_name,
                file_size=-1,
                storage_type="local",
                storage_path=local_path,
                uri=uri,
                custom_metadata={},
                file_hash="",
            )
            logger.info(f"Reading local file: {local_path}")
            return open(local_path, "rb"), dummy_metadata  # noqa: SIM115
parsed_uri = FileStorageURI.parse(uri)
metadata = self.metadata_storage.load(
FileMetadataIdentifier(
file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
),
FileMetadata,
)
if not metadata:
raise FileNotFoundError(f"No metadata found for URI: {uri}")
backend = self.storage_backends.get(metadata.storage_type)
if not backend:
raise ValueError(f"Unsupported storage type: {metadata.storage_type}")
with root_tracer.start_span(
"file_storage_system.get_file.backend_load",
metadata={
"bucket": metadata.bucket,
"file_id": metadata.file_id,
"file_name": metadata.file_name,
"storage_type": metadata.storage_type,
},
):
file_data = backend.load(metadata)
with root_tracer.start_span(
"file_storage_system.get_file.verify_hash",
):
calculated_hash = self._calculate_file_hash(file_data)
if calculated_hash != "-1" and calculated_hash != metadata.file_hash:
raise ValueError("File integrity check failed. Hash mismatch.")
return file_data, metadata
def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]:
"""Get the file metadata.
Args:
bucket (str): The bucket name
file_id (str): The file ID
Returns:
Optional[FileMetadata]: The file metadata
"""
fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket)
return self.metadata_storage.load(fid, FileMetadata)
def delete_file(self, uri: str) -> bool:
"""Delete the file data from the storage backend.
Args:
uri (str): The file URI
Returns:
bool: True if the file was deleted, False otherwise
"""
parsed_uri = FileStorageURI.parse(uri)
fid = FileMetadataIdentifier(
file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
)
metadata = self.metadata_storage.load(fid, FileMetadata)
if not metadata:
return False
backend = self.storage_backends.get(metadata.storage_type)
if not backend:
raise ValueError(f"Unsupported storage type: {metadata.storage_type}")
if backend.delete(metadata):
try:
self.metadata_storage.delete(fid)
return True
            except Exception as e:
                # If the metadata deletion fails, log the error and return False
                logger.error(f"Failed to delete file metadata for {uri}: {e}")
                return False
return False
def list_files(
self, bucket: str, filters: Optional[Dict[str, Any]] = None
) -> List[FileMetadata]:
"""List the files in the bucket."""
filters = filters or {}
filters["bucket"] = bucket
return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata)
class FileStorageClient(BaseComponent):
"""File storage client component."""
name = ComponentType.FILE_STORAGE_CLIENT.value
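
    # Typical usage (sketch; the file paths are hypothetical):
    #
    #   client = FileStorageClient()  # defaults to local storage under ~/.cache
    #   uri = client.upload_file("my_bucket", "/tmp/report.txt", "local")
    #   client.download_file(uri, "/tmp/report_copy.txt")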
def __init__(
self,
system_app: Optional[SystemApp] = None,
storage_system: Optional[FileStorageSystem] = None,
):
"""Initialize the file storage client."""
super().__init__(system_app=system_app)
if not storage_system:
from pathlib import Path
base_path = Path.home() / ".cache" / "dbgpt" / "files"
storage_system = FileStorageSystem(
{
LocalFileStorage.storage_type: LocalFileStorage(
base_path=str(base_path)
)
}
)
self.system_app = system_app
self._storage_system = storage_system
def init_app(self, system_app: SystemApp):
"""Initialize the application."""
self.system_app = system_app
@property
def storage_system(self) -> FileStorageSystem:
"""Get the file storage system."""
if not self._storage_system:
raise ValueError("File storage system not initialized")
return self._storage_system
def upload_file(
self,
bucket: str,
file_path: str,
storage_type: str,
custom_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Upload a file to the storage system.
Args:
bucket (str): The bucket name
file_path (str): The file path
storage_type (str): The storage type
custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
None.
Returns:
str: The file URI
"""
with open(file_path, "rb") as file:
return self.save_file(
bucket, os.path.basename(file_path), file, storage_type, custom_metadata
)
def save_file(
self,
bucket: str,
file_name: str,
file_data: BinaryIO,
storage_type: str,
custom_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save the file data to the storage system.
Args:
bucket (str): The bucket name
file_name (str): The file name
file_data (BinaryIO): The file data
storage_type (str): The storage type
custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
None.
Returns:
str: The file URI
"""
return self.storage_system.save_file(
bucket, file_name, file_data, storage_type, custom_metadata
)
def download_file(self, uri: str, destination_path: str) -> None:
"""Download a file from the storage system.
Args:
uri (str): The file URI
            destination_path (str): The destination path
Raises:
FileNotFoundError: If the file is not found
"""
        file_data, _ = self.storage_system.get_file(uri)
        with open(destination_path, "wb") as f:
            # Copy in chunks so large files are not buffered fully in memory
            while chunk := file_data.read(1024 * 1024):
                f.write(chunk)
def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
"""Get the file data from the storage system.
Args:
uri (str): The file URI
Returns:
Tuple[BinaryIO, FileMetadata]: The file data and metadata
"""
return self.storage_system.get_file(uri)
def get_file_by_id(
self, bucket: str, file_id: str
) -> Tuple[BinaryIO, FileMetadata]:
"""Get the file data from the storage system by ID.
Args:
bucket (str): The bucket name
file_id (str): The file ID
Returns:
Tuple[BinaryIO, FileMetadata]: The file data and metadata
"""
metadata = self.storage_system.get_file_metadata(bucket, file_id)
if not metadata:
raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
return self.get_file(metadata.uri)
def delete_file(self, uri: str) -> bool:
"""Delete the file data from the storage system.
Args:
uri (str): The file URI
Returns:
bool: True if the file was deleted, False otherwise
"""
return self.storage_system.delete_file(uri)
def delete_file_by_id(self, bucket: str, file_id: str) -> bool:
"""Delete the file data from the storage system by ID.
Args:
bucket (str): The bucket name
file_id (str): The file ID
Returns:
bool: True if the file was deleted, False otherwise
"""
metadata = self.storage_system.get_file_metadata(bucket, file_id)
if not metadata:
raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
return self.delete_file(metadata.uri)
def list_files(
self, bucket: str, filters: Optional[Dict[str, Any]] = None
) -> List[FileMetadata]:
"""List the files in the bucket.
Args:
bucket (str): The bucket name
filters (Dict[str, Any], optional): Filters. Defaults to None.
Returns:
List[FileMetadata]: The list of file metadata
"""
return self.storage_system.list_files(bucket, filters)
class SimpleDistributedStorage(StorageBackend):
"""Simple distributed storage backend."""
storage_type: str = "distributed"
def __init__(
self,
node_address: str,
local_storage_path: str,
save_chunk_size: int = 1024 * 1024,
transfer_chunk_size: int = 1024 * 1024,
transfer_timeout: int = 360,
api_prefix: str = "/api/v2/serve/file/files",
):
"""Initialize the simple distributed storage backend."""
self.node_address = node_address
self.local_storage_path = local_storage_path
os.makedirs(self.local_storage_path, exist_ok=True)
self._save_chunk_size = save_chunk_size
self._transfer_chunk_size = transfer_chunk_size
self._transfer_timeout = transfer_timeout
self._api_prefix = api_prefix
@property
def save_chunk_size(self) -> int:
"""Get the save chunk size."""
return self._save_chunk_size
def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str:
node_id = hashlib.md5(node_address.encode()).hexdigest()
return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}")
def _parse_node_address(self, fm: FileMetadata) -> str:
storage_path = fm.storage_path
if not storage_path.startswith("distributed://"):
raise ValueError("Invalid storage path")
return storage_path.split("//")[1].split("/")[0]
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
"""Save the file data to the distributed storage backend.
Just save the file locally.
Args:
bucket (str): The bucket name
file_id (str): The file ID
file_data (BinaryIO): The file data
Returns:
str: The storage path
"""
file_path = self._get_file_path(bucket, file_id, self.node_address)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
while True:
chunk = file_data.read(self.save_chunk_size)
if not chunk:
break
f.write(chunk)
return f"distributed://{self.node_address}/{bucket}/{file_id}"
def load(self, fm: FileMetadata) -> BinaryIO:
"""Load the file data from the distributed storage backend.
If the file is stored on the local node, load it from the local storage.
Args:
fm (FileMetadata): The file metadata
Returns:
BinaryIO: The file data
"""
file_id = fm.file_id
bucket = fm.bucket
node_address = self._parse_node_address(fm)
file_path = self._get_file_path(bucket, file_id, node_address)
# TODO: check if the file is cached in local storage
if node_address == self.node_address:
if os.path.exists(file_path):
return open(file_path, "rb") # noqa: SIM115
else:
raise FileNotFoundError(f"File {file_id} not found on the local node")
else:
response = requests.get(
f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
timeout=self._transfer_timeout,
stream=True,
)
response.raise_for_status()
# TODO: cache the file in local storage
return StreamedBytesIO(
response.iter_content(chunk_size=self._transfer_chunk_size)
)
def delete(self, fm: FileMetadata) -> bool:
"""Delete the file data from the distributed storage backend.
If the file is stored on the local node, delete it from the local storage.
If the file is stored on a remote node, send a delete request to the remote
node.
Args:
fm (FileMetadata): The file metadata
Returns:
bool: True if the file was deleted, False otherwise
"""
file_id = fm.file_id
bucket = fm.bucket
node_address = self._parse_node_address(fm)
file_path = self._get_file_path(bucket, file_id, node_address)
if node_address == self.node_address:
if os.path.exists(file_path):
os.remove(file_path)
return True
return False
else:
try:
response = requests.delete(
f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
timeout=self._transfer_timeout,
)
response.raise_for_status()
return True
            except Exception as e:
                logger.warning(
                    f"Failed to delete remote file {bucket}/{file_id}: {e}"
                )
                return False
class StreamedBytesIO(io.BytesIO):
"""A BytesIO subclass that can be used with streaming responses.
Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13
"""
def __init__(self, request_iterator):
"""Initialize the StreamedBytesIO instance."""
super().__init__()
self._bytes = BytesIO()
self._iterator = request_iterator
def _load_all(self):
self._bytes.seek(0, io.SEEK_END)
for chunk in self._iterator:
self._bytes.write(chunk)
def _load_until(self, goal_position):
current_position = self._bytes.seek(0, io.SEEK_END)
while current_position < goal_position:
try:
current_position += self._bytes.write(next(self._iterator))
except StopIteration:
break
def tell(self) -> int:
"""Get the current position."""
return self._bytes.tell()
def read(self, size: Optional[int] = None) -> bytes:
"""Read the data from the stream.
Args:
size (Optional[int], optional): The number of bytes to read. Defaults to
None.
Returns:
bytes: The read data
"""
left_off_at = self._bytes.tell()
if size is None:
self._load_all()
else:
goal_position = left_off_at + size
self._load_until(goal_position)
self._bytes.seek(left_off_at)
return self._bytes.read(size)
    def seek(self, position: int, whence: int = io.SEEK_SET):
        """Seek to a position in the stream.

        Args:
            position (int): The position
            whence (int, optional): The reference point. Defaults to io.SEEK_SET.

        Raises:
            ValueError: If the reference point is invalid
        """
        if whence == io.SEEK_END:
            # The end of the stream is unknown until the iterator is
            # exhausted, so load everything before seeking relative to it.
            self._load_all()
            self._bytes.seek(position, whence)
        else:
            self._bytes.seek(position, whence)
def __enter__(self):
"""Enter the context manager."""
return self
def __exit__(self, ext_type, value, tb):
"""Exit the context manager."""
self._bytes.close()