feat(core): Add file server for DB-GPT
@@ -316,6 +316,13 @@ class Config(metaclass=Singleton):
        # experimental financial report model configuration
        self.FIN_REPORT_MODEL = os.getenv("FIN_REPORT_MODEL", None)

        # file server configuration
        # The host of the current file server, if None, get the host automatically
        self.FILE_SERVER_HOST = os.getenv("FILE_SERVER_HOST")
        self.FILE_SERVER_LOCAL_STORAGE_PATH = os.getenv(
            "FILE_SERVER_LOCAL_STORAGE_PATH"
        )

    @property
    def local_db_manager(self) -> "ConnectorManager":
        from dbgpt.datasource.manages import ConnectorManager
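For illustration only (the variable names come from the new config code above; the values are made up), these settings could be supplied through the environment before DB-GPT starts:

import os

os.environ["FILE_SERVER_HOST"] = "192.168.1.10"  # omit to let DB-GPT detect the host
os.environ["FILE_SERVER_LOCAL_STORAGE_PATH"] = "/data/dbgpt/file_server"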
@@ -59,7 +59,7 @@ def initialize_components(
    _initialize_agent(system_app)
    _initialize_openapi(system_app)
    # Register serve apps
    register_serve_apps(system_app, CFG)
    register_serve_apps(system_app, CFG, param.port)


def _initialize_model_cache(system_app: SystemApp):
@@ -8,6 +8,7 @@ from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
from dbgpt.model.cluster.registry_impl.db_storage import ModelInstanceEntity
from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity
from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity
from dbgpt.serve.file.models.models import ServeEntity as FileServeEntity
from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity
from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
@@ -19,6 +20,7 @@ from dbgpt.storage.chat_history.chat_history_db import (

_MODELS = [
    PluginHubEntity,
    FileServeEntity,
    MyPluginEntity,
    PromptManageEntity,
    KnowledgeSpaceEntity,
@@ -2,7 +2,7 @@ from dbgpt._private.config import Config
from dbgpt.component import SystemApp


def register_serve_apps(system_app: SystemApp, cfg: Config):
def register_serve_apps(system_app: SystemApp, cfg: Config, webserver_port: int):
    """Register serve apps"""
    system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE)
    if cfg.API_KEYS:
@@ -47,6 +47,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
    # Register serve app
    system_app.register(FlowServe)

    # ################################ AWEL Flow Serve Register End ########################################

    # ################################ Rag Serve Register Begin ######################################

    from dbgpt.serve.rag.serve import (
@@ -57,6 +59,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
    # Register serve app
    system_app.register(RagServe)

    # ################################ Rag Serve Register End ########################################

    # ################################ Datasource Serve Register Begin ######################################

    from dbgpt.serve.datasource.serve import (
@@ -66,4 +70,32 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):

    # Register serve app
    system_app.register(DatasourceServe)
    # ################################ AWEL Flow Serve Register End ########################################

    # ################################ Datasource Serve Register End ########################################

    # ################################ File Serve Register Begin ######################################

    from dbgpt.configs.model_config import FILE_SERVER_LOCAL_STORAGE_PATH
    from dbgpt.serve.file.serve import (
        SERVE_CONFIG_KEY_PREFIX as FILE_SERVE_CONFIG_KEY_PREFIX,
    )
    from dbgpt.serve.file.serve import Serve as FileServe

    local_storage_path = (
        cfg.FILE_SERVER_LOCAL_STORAGE_PATH or FILE_SERVER_LOCAL_STORAGE_PATH
    )
    # Set config
    system_app.config.set(
        f"{FILE_SERVE_CONFIG_KEY_PREFIX}local_storage_path", local_storage_path
    )
    system_app.config.set(
        f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_port", webserver_port
    )
    if cfg.FILE_SERVER_HOST:
        system_app.config.set(
            f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_host", cfg.FILE_SERVER_HOST
        )
    # Register serve app
    system_app.register(FileServe)

    # ################################ File Serve Register End ########################################
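A hedged sketch of how other code might obtain the file storage client once FileServe is registered, given a SystemApp instance. That FileServe exposes a FileStorageClient under ComponentType.FILE_STORAGE_CLIENT is an assumption for illustration, not something shown in this diff:

from dbgpt.component import ComponentType
from dbgpt.core.interface.file import FileStorageClient

fs_client = system_app.get_component(
    ComponentType.FILE_STORAGE_CLIENT, FileStorageClient
)
uri = fs_client.upload_file("my-bucket", "/tmp/report.txt", storage_type="distributed")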
@@ -90,6 +90,7 @@ class ComponentType(str, Enum):
    AGENT_MANAGER = "dbgpt_agent_manager"
    RESOURCE_MANAGER = "dbgpt_resource_manager"
    VARIABLES_PROVIDER = "dbgpt_variables_provider"
    FILE_STORAGE_CLIENT = "dbgpt_file_storage_client"


_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
@@ -14,6 +14,7 @@ DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data")
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
FILE_SERVER_LOCAL_STORAGE_PATH = os.path.join(DATA_DIR, "file_server")
_DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel")
# Global language setting
LOCALES_DIR = os.path.join(ROOT_PATH, "i18n/locales")
dbgpt/core/interface/file.py  (new file, 791 lines)
@@ -0,0 +1,791 @@
|
||||
"""File storage interface."""
|
||||
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from typing import Any, BinaryIO, Dict, List, Optional, Tuple
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.util.tracer import root_tracer, trace
|
||||
|
||||
from .storage import (
|
||||
InMemoryStorage,
|
||||
QuerySpec,
|
||||
ResourceIdentifier,
|
||||
StorageError,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
)
|
||||
|
||||
_SCHEMA = "dbgpt-fs"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FileMetadataIdentifier(ResourceIdentifier):
|
||||
"""File metadata identifier."""
|
||||
|
||||
file_id: str
|
||||
bucket: str
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the identifier to a dictionary."""
|
||||
return {"file_id": self.file_id, "bucket": self.bucket}
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
"""Get the string identifier.
|
||||
|
||||
Returns:
|
||||
str: The string identifier
|
||||
"""
|
||||
return f"{self.bucket}/{self.file_id}"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FileMetadata(StorageItem):
|
||||
"""File metadata for storage."""
|
||||
|
||||
file_id: str
|
||||
bucket: str
|
||||
file_name: str
|
||||
file_size: int
|
||||
storage_type: str
|
||||
storage_path: str
|
||||
uri: str
|
||||
custom_metadata: Dict[str, Any]
|
||||
file_hash: str
|
||||
_identifier: FileMetadataIdentifier = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post init method."""
|
||||
self._identifier = FileMetadataIdentifier(
|
||||
file_id=self.file_id, bucket=self.bucket
|
||||
)
|
||||
|
||||
@property
|
||||
def identifier(self) -> ResourceIdentifier:
|
||||
"""Get the resource identifier."""
|
||||
return self._identifier
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
"""Merge the metadata with another item."""
|
||||
if not isinstance(other, FileMetadata):
|
||||
raise StorageError("Cannot merge different types of items")
|
||||
self._from_object(other)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the metadata to a dictionary."""
|
||||
return {
|
||||
"file_id": self.file_id,
|
||||
"bucket": self.bucket,
|
||||
"file_name": self.file_name,
|
||||
"file_size": self.file_size,
|
||||
"storage_type": self.storage_type,
|
||||
"storage_path": self.storage_path,
|
||||
"uri": self.uri,
|
||||
"custom_metadata": self.custom_metadata,
|
||||
"file_hash": self.file_hash,
|
||||
}
|
||||
|
||||
def _from_object(self, obj: "FileMetadata") -> None:
|
||||
self.file_id = obj.file_id
|
||||
self.bucket = obj.bucket
|
||||
self.file_name = obj.file_name
|
||||
self.file_size = obj.file_size
|
||||
self.storage_type = obj.storage_type
|
||||
self.storage_path = obj.storage_path
|
||||
self.uri = obj.uri
|
||||
self.custom_metadata = obj.custom_metadata
|
||||
self.file_hash = obj.file_hash
|
||||
self._identifier = obj._identifier
|
||||
|
||||
|
||||
class FileStorageURI:
|
||||
"""File storage URI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_type: str,
|
||||
bucket: str,
|
||||
file_id: str,
|
||||
version: Optional[str] = None,
|
||||
custom_params: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Initialize the file storage URI."""
|
||||
self.scheme = _SCHEMA
|
||||
self.storage_type = storage_type
|
||||
self.bucket = bucket
|
||||
self.file_id = file_id
|
||||
self.version = version
|
||||
self.custom_params = custom_params or {}
|
||||
|
||||
@classmethod
|
||||
def parse(cls, uri: str) -> "FileStorageURI":
|
||||
"""Parse the URI string."""
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme != _SCHEMA:
|
||||
raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'")
|
||||
path_parts = parsed.path.strip("/").split("/")
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError("Invalid URI path. Must contain bucket and file ID")
|
||||
storage_type = parsed.netloc
|
||||
bucket = path_parts[0]
|
||||
file_id = path_parts[1]
|
||||
version = path_parts[2] if len(path_parts) > 2 else None
|
||||
custom_params = parse_qs(parsed.query)
|
||||
return cls(storage_type, bucket, file_id, version, custom_params)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Get the string representation of the URI."""
|
||||
base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}"
|
||||
if self.version:
|
||||
base_uri += f"/{self.version}"
|
||||
if self.custom_params:
|
||||
query_string = urlencode(self.custom_params, doseq=True)
|
||||
base_uri += f"?{query_string}"
|
||||
return base_uri
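A small usage sketch of the URI class above (identifiers are made up):

uri = FileStorageURI("local", "my-bucket", "123e4567", custom_params={"v": "1"})
assert str(uri) == "dbgpt-fs://local/my-bucket/123e4567?v=1"
parsed = FileStorageURI.parse(str(uri))
assert parsed.bucket == "my-bucket" and parsed.file_id == "123e4567"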
|
||||
|
||||
|
||||
class StorageBackend(ABC):
|
||||
"""Storage backend interface."""
|
||||
|
||||
storage_type: str = "__base__"
|
||||
|
||||
@abstractmethod
|
||||
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
|
||||
"""Save the file data to the storage backend.
|
||||
|
||||
Args:
|
||||
bucket (str): The bucket name
|
||||
file_id (str): The file ID
|
||||
file_data (BinaryIO): The file data
|
||||
|
||||
Returns:
|
||||
str: The storage path
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load(self, fm: FileMetadata) -> BinaryIO:
|
||||
"""Load the file data from the storage backend.
|
||||
|
||||
Args:
|
||||
fm (FileMetadata): The file metadata
|
||||
|
||||
Returns:
|
||||
BinaryIO: The file data
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, fm: FileMetadata) -> bool:
|
||||
"""Delete the file data from the storage backend.
|
||||
|
||||
Args:
|
||||
fm (FileMetadata): The file metadata
|
||||
|
||||
Returns:
|
||||
bool: True if the file was deleted, False otherwise
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def save_chunk_size(self) -> int:
|
||||
"""Get the save chunk size.
|
||||
|
||||
Returns:
|
||||
int: The save chunk size
|
||||
"""
|
||||
|
||||
|
||||
class LocalFileStorage(StorageBackend):
|
||||
"""Local file storage backend."""
|
||||
|
||||
storage_type: str = "local"
|
||||
|
||||
def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024):
|
||||
"""Initialize the local file storage backend."""
|
||||
self.base_path = base_path
|
||||
self._save_chunk_size = save_chunk_size
|
||||
os.makedirs(self.base_path, exist_ok=True)
|
||||
|
||||
@property
|
||||
def save_chunk_size(self) -> int:
|
||||
"""Get the save chunk size."""
|
||||
return self._save_chunk_size
|
||||
|
||||
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
|
||||
"""Save the file data to the local storage backend."""
|
||||
bucket_path = os.path.join(self.base_path, bucket)
|
||||
os.makedirs(bucket_path, exist_ok=True)
|
||||
file_path = os.path.join(bucket_path, file_id)
|
||||
with open(file_path, "wb") as f:
|
||||
while True:
|
||||
chunk = file_data.read(self.save_chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
return file_path
|
||||
|
||||
def load(self, fm: FileMetadata) -> BinaryIO:
|
||||
"""Load the file data from the local storage backend."""
|
||||
bucket_path = os.path.join(self.base_path, fm.bucket)
|
||||
file_path = os.path.join(bucket_path, fm.file_id)
|
||||
return open(file_path, "rb") # noqa: SIM115
|
||||
|
||||
def delete(self, fm: FileMetadata) -> bool:
|
||||
"""Delete the file data from the local storage backend."""
|
||||
bucket_path = os.path.join(self.base_path, fm.bucket)
|
||||
file_path = os.path.join(bucket_path, fm.file_id)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
return True
|
||||
return False
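A minimal sketch of the local backend in isolation (the base path and file names are hypothetical):

import io

backend = LocalFileStorage(base_path="/tmp/dbgpt_fs_demo")
storage_path = backend.save("demo-bucket", "file-1", io.BytesIO(b"hello"))
meta = FileMetadata(
    file_id="file-1", bucket="demo-bucket", file_name="hello.txt", file_size=5,
    storage_type="local", storage_path=storage_path,
    uri="dbgpt-fs://local/demo-bucket/file-1", custom_metadata={}, file_hash="-1",
)
assert backend.load(meta).read() == b"hello"
assert backend.delete(meta) is True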
|
||||
|
||||
|
||||
class FileStorageSystem:
|
||||
"""File storage system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_backends: Dict[str, StorageBackend],
|
||||
metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None,
|
||||
check_hash: bool = True,
|
||||
):
|
||||
"""Initialize the file storage system."""
|
||||
metadata_storage = metadata_storage or InMemoryStorage()
|
||||
self.storage_backends = storage_backends
|
||||
self.metadata_storage = metadata_storage
|
||||
self.check_hash = check_hash
|
||||
self._save_chunk_size = min(
|
||||
backend.save_chunk_size for backend in storage_backends.values()
|
||||
)
|
||||
|
||||
def _calculate_file_hash(self, file_data: BinaryIO) -> str:
|
||||
"""Calculate the MD5 hash of the file data."""
|
||||
if not self.check_hash:
|
||||
return "-1"
|
||||
hasher = hashlib.md5()
|
||||
file_data.seek(0)
|
||||
while chunk := file_data.read(self._save_chunk_size):
|
||||
hasher.update(chunk)
|
||||
file_data.seek(0)
|
||||
return hasher.hexdigest()
|
||||
|
||||
@trace("file_storage_system.save_file")
|
||||
def save_file(
|
||||
self,
|
||||
bucket: str,
|
||||
file_name: str,
|
||||
file_data: BinaryIO,
|
||||
storage_type: str,
|
||||
custom_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Save the file data to the storage backend."""
|
||||
file_id = str(uuid.uuid4())
|
||||
backend = self.storage_backends.get(storage_type)
|
||||
if not backend:
|
||||
raise ValueError(f"Unsupported storage type: {storage_type}")
|
||||
|
||||
with root_tracer.start_span(
|
||||
"file_storage_system.save_file.backend_save",
|
||||
metadata={
|
||||
"bucket": bucket,
|
||||
"file_id": file_id,
|
||||
"file_name": file_name,
|
||||
"storage_type": storage_type,
|
||||
},
|
||||
):
|
||||
storage_path = backend.save(bucket, file_id, file_data)
|
||||
file_data.seek(0, 2) # Move to the end of the file
|
||||
file_size = file_data.tell() # Get the file size
|
||||
file_data.seek(0) # Reset file pointer
|
||||
|
||||
with root_tracer.start_span(
|
||||
"file_storage_system.save_file.calculate_hash",
|
||||
):
|
||||
file_hash = self._calculate_file_hash(file_data)
|
||||
uri = FileStorageURI(
|
||||
storage_type, bucket, file_id, custom_params=custom_metadata
|
||||
)
|
||||
|
||||
metadata = FileMetadata(
|
||||
file_id=file_id,
|
||||
bucket=bucket,
|
||||
file_name=file_name,
|
||||
file_size=file_size,
|
||||
storage_type=storage_type,
|
||||
storage_path=storage_path,
|
||||
uri=str(uri),
|
||||
custom_metadata=custom_metadata or {},
|
||||
file_hash=file_hash,
|
||||
)
|
||||
|
||||
self.metadata_storage.save(metadata)
|
||||
return str(uri)
|
||||
|
||||
@trace("file_storage_system.get_file")
|
||||
def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
|
||||
"""Get the file data from the storage backend."""
|
||||
parsed_uri = FileStorageURI.parse(uri)
|
||||
metadata = self.metadata_storage.load(
|
||||
FileMetadataIdentifier(
|
||||
file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
|
||||
),
|
||||
FileMetadata,
|
||||
)
|
||||
if not metadata:
|
||||
raise FileNotFoundError(f"No metadata found for URI: {uri}")
|
||||
|
||||
backend = self.storage_backends.get(metadata.storage_type)
|
||||
if not backend:
|
||||
raise ValueError(f"Unsupported storage type: {metadata.storage_type}")
|
||||
|
||||
with root_tracer.start_span(
|
||||
"file_storage_system.get_file.backend_load",
|
||||
metadata={
|
||||
"bucket": metadata.bucket,
|
||||
"file_id": metadata.file_id,
|
||||
"file_name": metadata.file_name,
|
||||
"storage_type": metadata.storage_type,
|
||||
},
|
||||
):
|
||||
file_data = backend.load(metadata)
|
||||
|
||||
with root_tracer.start_span(
|
||||
"file_storage_system.get_file.verify_hash",
|
||||
):
|
||||
calculated_hash = self._calculate_file_hash(file_data)
|
||||
if calculated_hash != "-1" and calculated_hash != metadata.file_hash:
|
||||
raise ValueError("File integrity check failed. Hash mismatch.")
|
||||
|
||||
return file_data, metadata
|
||||
|
||||
def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]:
|
||||
"""Get the file metadata.
|
||||
|
||||
Args:
|
||||
bucket (str): The bucket name
|
||||
file_id (str): The file ID
|
||||
|
||||
Returns:
|
||||
Optional[FileMetadata]: The file metadata
|
||||
"""
|
||||
fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket)
|
||||
return self.metadata_storage.load(fid, FileMetadata)
|
||||
|
||||
def delete_file(self, uri: str) -> bool:
|
||||
"""Delete the file data from the storage backend.
|
||||
|
||||
Args:
|
||||
uri (str): The file URI
|
||||
|
||||
Returns:
|
||||
bool: True if the file was deleted, False otherwise
|
||||
"""
|
||||
parsed_uri = FileStorageURI.parse(uri)
|
||||
fid = FileMetadataIdentifier(
|
||||
file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
|
||||
)
|
||||
metadata = self.metadata_storage.load(fid, FileMetadata)
|
||||
if not metadata:
|
||||
return False
|
||||
|
||||
backend = self.storage_backends.get(metadata.storage_type)
|
||||
if not backend:
|
||||
raise ValueError(f"Unsupported storage type: {metadata.storage_type}")
|
||||
|
||||
if backend.delete(metadata):
|
||||
try:
|
||||
self.metadata_storage.delete(fid)
|
||||
return True
|
||||
except Exception:
|
||||
# If the metadata deletion fails, swallow the error and return False
|
||||
return False
|
||||
return False
|
||||
|
||||
def list_files(
|
||||
self, bucket: str, filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[FileMetadata]:
|
||||
"""List the files in the bucket."""
|
||||
filters = filters or {}
|
||||
filters["bucket"] = bucket
|
||||
return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata)
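Putting the pieces together, a sketch (paths made up) of a FileStorageSystem backed by the local backend and the default in-memory metadata store:

import io

system = FileStorageSystem({"local": LocalFileStorage("/tmp/dbgpt_fs_demo")})
uri = system.save_file("demo", "notes.txt", io.BytesIO(b"hi"), storage_type="local")
data, meta = system.get_file(uri)
assert data.read() == b"hi" and meta.file_name == "notes.txt"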
|
||||
|
||||
|
||||
class FileStorageClient(BaseComponent):
|
||||
"""File storage client component."""
|
||||
|
||||
name = ComponentType.FILE_STORAGE_CLIENT.value
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
storage_system: Optional[FileStorageSystem] = None,
|
||||
):
|
||||
"""Initialize the file storage client."""
|
||||
super().__init__(system_app=system_app)
|
||||
if not storage_system:
|
||||
from pathlib import Path
|
||||
|
||||
base_path = Path.home() / ".cache" / "dbgpt" / "files"
|
||||
storage_system = FileStorageSystem(
|
||||
{
|
||||
LocalFileStorage.storage_type: LocalFileStorage(
|
||||
base_path=str(base_path)
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
self.system_app = system_app
|
||||
self._storage_system = storage_system
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the application."""
|
||||
self.system_app = system_app
|
||||
|
||||
@property
|
||||
def storage_system(self) -> FileStorageSystem:
|
||||
"""Get the file storage system."""
|
||||
if not self._storage_system:
|
||||
raise ValueError("File storage system not initialized")
|
||||
return self._storage_system
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
bucket: str,
|
||||
file_path: str,
|
||||
storage_type: str,
|
||||
custom_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Upload a file to the storage system.
|
||||
|
||||
Args:
|
||||
bucket (str): The bucket name
|
||||
file_path (str): The file path
|
||||
storage_type (str): The storage type
|
||||
custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
|
||||
None.
|
||||
|
||||
Returns:
|
||||
str: The file URI
|
||||
"""
|
||||
with open(file_path, "rb") as file:
|
||||
return self.save_file(
|
||||
bucket, os.path.basename(file_path), file, storage_type, custom_metadata
|
||||
)
|
||||
|
||||
def save_file(
|
||||
self,
|
||||
bucket: str,
|
||||
file_name: str,
|
||||
file_data: BinaryIO,
|
||||
storage_type: str,
|
||||
custom_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Save the file data to the storage system.
|
||||
|
||||
Args:
|
||||
bucket (str): The bucket name
|
||||
file_name (str): The file name
|
||||
file_data (BinaryIO): The file data
|
||||
storage_type (str): The storage type
|
||||
custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
|
||||
None.
|
||||
|
||||
Returns:
|
||||
str: The file URI
|
||||
"""
|
||||
return self.storage_system.save_file(
|
||||
bucket, file_name, file_data, storage_type, custom_metadata
|
||||
)
|
||||
|
||||
def download_file(self, uri: str, destination_path: str) -> None:
|
||||
"""Download a file from the storage system.
|
||||
|
||||
Args:
|
||||
uri (str): The file URI
|
||||
destination_path (str): The destination
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file is not found
|
||||
"""
|
||||
file_data, _ = self.storage_system.get_file(uri)
|
||||
with open(destination_path, "wb") as f:
|
||||
f.write(file_data.read())
|
||||
|
||||
def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
|
||||
"""Get the file data from the storage system.
|
||||
|
||||
Args:
|
||||
uri (str): The file URI
|
||||
|
||||
Returns:
|
||||
Tuple[BinaryIO, FileMetadata]: The file data and metadata
|
||||
"""
|
||||
return self.storage_system.get_file(uri)
|
||||
|
||||
def get_file_by_id(
|
||||
self, bucket: str, file_id: str
|
||||
) -> Tuple[BinaryIO, FileMetadata]:
|
||||
"""Get the file data from the storage system by ID.
|
||||
|
||||
Args:
|
||||
bucket (str): The bucket name
|
||||
file_id (str): The file ID
|
||||
|
||||
Returns:
|
||||
Tuple[BinaryIO, FileMetadata]: The file data and metadata
|
||||
"""
|
||||
metadata = self.storage_system.get_file_metadata(bucket, file_id)
|
||||
if not metadata:
|
||||
raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
|
||||
return self.get_file(metadata.uri)
|
||||
|
||||
def delete_file(self, uri: str) -> bool:
|
||||
"""Delete the file data from the storage system.
|
||||
|
||||
Args:
|
||||
uri (str): The file URI
|
||||
|
||||
Returns:
|
||||
bool: True if the file was deleted, False otherwise
|
||||
"""
|
||||
return self.storage_system.delete_file(uri)
|
||||
|
||||
def delete_file_by_id(self, bucket: str, file_id: str) -> bool:
|
||||
"""Delete the file data from the storage system by ID.
|
||||
|
||||
Args:
|
||||
bucket (str): The bucket name
|
||||
file_id (str): The file ID
|
||||
|
||||
Returns:
|
||||
bool: True if the file was deleted, False otherwise
|
||||
"""
|
||||
metadata = self.storage_system.get_file_metadata(bucket, file_id)
|
||||
if not metadata:
|
||||
raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
|
||||
return self.delete_file(metadata.uri)
|
||||
|
||||
def list_files(
|
||||
self, bucket: str, filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[FileMetadata]:
|
||||
"""List the files in the bucket.
|
||||
|
||||
Args:
|
||||
bucket (str): The bucket name
|
||||
filters (Dict[str, Any], optional): Filters. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[FileMetadata]: The list of file metadata
|
||||
"""
|
||||
return self.storage_system.list_files(bucket, filters)
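A hedged end-to-end sketch of the client component defined above (bucket and file path are made up); constructed with no arguments it falls back to a local backend under ~/.cache/dbgpt/files:

client = FileStorageClient()
uri = client.upload_file(bucket="reports", file_path="/tmp/report.txt", storage_type="local")
data, meta = client.get_file(uri)
print(meta.file_name, meta.file_size)
assert client.delete_file(uri) is True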
|
||||
|
||||
|
||||
class SimpleDistributedStorage(StorageBackend):
|
||||
"""Simple distributed storage backend."""
|
||||
|
||||
storage_type: str = "distributed"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_address: str,
|
||||
local_storage_path: str,
|
||||
save_chunk_size: int = 1024 * 1024,
|
||||
transfer_chunk_size: int = 1024 * 1024,
|
||||
transfer_timeout: int = 360,
|
||||
api_prefix: str = "/api/v2/serve/file/files",
|
||||
):
|
||||
"""Initialize the simple distributed storage backend."""
|
||||
self.node_address = node_address
|
||||
self.local_storage_path = local_storage_path
|
||||
os.makedirs(self.local_storage_path, exist_ok=True)
|
||||
self._save_chunk_size = save_chunk_size
|
||||
self._transfer_chunk_size = transfer_chunk_size
|
||||
self._transfer_timeout = transfer_timeout
|
||||
self._api_prefix = api_prefix
|
||||
|
||||
@property
|
||||
def save_chunk_size(self) -> int:
|
||||
"""Get the save chunk size."""
|
||||
return self._save_chunk_size
|
||||
|
||||
def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str:
|
||||
node_id = hashlib.md5(node_address.encode()).hexdigest()
|
||||
return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}")
|
||||
|
||||
def _parse_node_address(self, fm: FileMetadata) -> str:
|
||||
storage_path = fm.storage_path
|
||||
if not storage_path.startswith("distributed://"):
|
||||
raise ValueError("Invalid storage path")
|
||||
return storage_path.split("//")[1].split("/")[0]
|
||||
|
||||
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
|
||||
"""Save the file data to the distributed storage backend.
|
||||
|
||||
Just save the file locally.
|
||||
|
||||
Args:
|
||||
bucket (str): The bucket name
|
||||
file_id (str): The file ID
|
||||
file_data (BinaryIO): The file data
|
||||
|
||||
Returns:
|
||||
str: The storage path
|
||||
"""
|
||||
file_path = self._get_file_path(bucket, file_id, self.node_address)
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
with open(file_path, "wb") as f:
|
||||
while True:
|
||||
chunk = file_data.read(self.save_chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
|
||||
return f"distributed://{self.node_address}/{bucket}/{file_id}"
|
||||
|
||||
def load(self, fm: FileMetadata) -> BinaryIO:
|
||||
"""Load the file data from the distributed storage backend.
|
||||
|
||||
If the file is stored on the local node, load it from the local storage.
|
||||
|
||||
Args:
|
||||
fm (FileMetadata): The file metadata
|
||||
|
||||
Returns:
|
||||
BinaryIO: The file data
|
||||
"""
|
||||
file_id = fm.file_id
|
||||
bucket = fm.bucket
|
||||
node_address = self._parse_node_address(fm)
|
||||
file_path = self._get_file_path(bucket, file_id, node_address)
|
||||
|
||||
# TODO: check if the file is cached in local storage
|
||||
if node_address == self.node_address:
|
||||
if os.path.exists(file_path):
|
||||
return open(file_path, "rb") # noqa: SIM115
|
||||
else:
|
||||
raise FileNotFoundError(f"File {file_id} not found on the local node")
|
||||
else:
|
||||
response = requests.get(
|
||||
f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
|
||||
timeout=self._transfer_timeout,
|
||||
stream=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
# TODO: cache the file in local storage
|
||||
return StreamedBytesIO(
|
||||
response.iter_content(chunk_size=self._transfer_chunk_size)
|
||||
)
|
||||
|
||||
def delete(self, fm: FileMetadata) -> bool:
|
||||
"""Delete the file data from the distributed storage backend.
|
||||
|
||||
If the file is stored on the local node, delete it from the local storage.
|
||||
If the file is stored on a remote node, send a delete request to the remote
|
||||
node.
|
||||
|
||||
Args:
|
||||
fm (FileMetadata): The file metadata
|
||||
|
||||
Returns:
|
||||
bool: True if the file was deleted, False otherwise
|
||||
"""
|
||||
file_id = fm.file_id
|
||||
bucket = fm.bucket
|
||||
node_address = self._parse_node_address(fm)
|
||||
file_path = self._get_file_path(bucket, file_id, node_address)
|
||||
if node_address == self.node_address:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
response = requests.delete(
|
||||
f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
|
||||
timeout=self._transfer_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception:
|
||||
return False
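For illustration (node addresses are made up): the storage path returned by save() records the owning node, which is what load() and delete() later use to choose between a local read and an HTTP request to the remote node:

import io

storage = SimpleDistributedStorage("10.0.0.5:5670", "/tmp/dbgpt_dist_demo")
storage_path = storage.save("demo", "f1", io.BytesIO(b"data"))
assert storage_path == "distributed://10.0.0.5:5670/demo/f1"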
|
||||
|
||||
|
||||
class StreamedBytesIO(io.BytesIO):
|
||||
"""A BytesIO subclass that can be used with streaming responses.
|
||||
|
||||
Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13
|
||||
"""
|
||||
|
||||
def __init__(self, request_iterator):
|
||||
"""Initialize the StreamedBytesIO instance."""
|
||||
super().__init__()
|
||||
self._bytes = BytesIO()
|
||||
self._iterator = request_iterator
|
||||
|
||||
def _load_all(self):
|
||||
self._bytes.seek(0, io.SEEK_END)
|
||||
for chunk in self._iterator:
|
||||
self._bytes.write(chunk)
|
||||
|
||||
def _load_until(self, goal_position):
|
||||
current_position = self._bytes.seek(0, io.SEEK_END)
|
||||
while current_position < goal_position:
|
||||
try:
|
||||
current_position += self._bytes.write(next(self._iterator))
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
def tell(self) -> int:
|
||||
"""Get the current position."""
|
||||
return self._bytes.tell()
|
||||
|
||||
def read(self, size: Optional[int] = None) -> bytes:
|
||||
"""Read the data from the stream.
|
||||
|
||||
Args:
|
||||
size (Optional[int], optional): The number of bytes to read. Defaults to
|
||||
None.
|
||||
|
||||
Returns:
|
||||
bytes: The read data
|
||||
"""
|
||||
left_off_at = self._bytes.tell()
|
||||
if size is None:
|
||||
self._load_all()
|
||||
else:
|
||||
goal_position = left_off_at + size
|
||||
self._load_until(goal_position)
|
||||
|
||||
self._bytes.seek(left_off_at)
|
||||
return self._bytes.read(size)
|
||||
|
||||
def seek(self, position: int, whence: int = io.SEEK_SET):
|
||||
"""Seek to a position in the stream.
|
||||
|
||||
Args:
|
||||
position (int): The position
|
||||
whence (int, optional): The reference point. Defaults to io.SEEK_SET.
|
||||
|
||||
Raises:
|
||||
ValueError: If the reference point is invalid
|
||||
"""
|
||||
if whence == io.SEEK_END:
|
||||
self._load_all()
|
||||
else:
|
||||
self._bytes.seek(position, whence)
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter the context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, ext_type, value, tb):
|
||||
"""Exit the context manager."""
|
||||
self._bytes.close()
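A small sketch of how StreamedBytesIO wraps a chunk iterator (such as response.iter_content above) and only pulls chunks as they are read:

chunks = iter([b"part1-", b"part2"])
stream = StreamedBytesIO(chunks)
assert stream.read(6) == b"part1-"  # only the first chunk is consumed here
assert stream.read() == b"part2"    # the remainder is loaded on demand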
|
dbgpt/core/interface/tests/test_file.py  (new file, 506 lines)
@@ -0,0 +1,506 @@
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from ..file import (
|
||||
FileMetadata,
|
||||
FileMetadataIdentifier,
|
||||
FileStorageClient,
|
||||
FileStorageSystem,
|
||||
InMemoryStorage,
|
||||
LocalFileStorage,
|
||||
SimpleDistributedStorage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_test_file_dir(tmpdir):
|
||||
return str(tmpdir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_storage_path(tmpdir):
|
||||
return str(tmpdir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_storage_backend(temp_storage_path):
|
||||
return LocalFileStorage(temp_storage_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def distributed_storage_backend(temp_storage_path):
|
||||
node_address = "127.0.0.1:8000"
|
||||
return SimpleDistributedStorage(node_address, temp_storage_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_storage_system(local_storage_backend):
|
||||
backends = {"local": local_storage_backend}
|
||||
metadata_storage = InMemoryStorage()
|
||||
return FileStorageSystem(backends, metadata_storage)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_storage_client(file_storage_system):
|
||||
return FileStorageClient(storage_system=file_storage_system)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_file_path(temp_test_file_dir):
|
||||
file_path = os.path.join(temp_test_file_dir, "sample.txt")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(b"Sample file content")
|
||||
return file_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_file_data():
|
||||
return io.BytesIO(b"Sample file content for distributed storage")
|
||||
|
||||
|
||||
def test_save_file(file_storage_client, sample_file_path):
|
||||
bucket = "test-bucket"
|
||||
uri = file_storage_client.upload_file(
|
||||
bucket=bucket, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
assert uri.startswith("dbgpt-fs://local/test-bucket/")
|
||||
assert os.path.exists(sample_file_path)
|
||||
|
||||
|
||||
def test_get_file(file_storage_client, sample_file_path):
|
||||
bucket = "test-bucket"
|
||||
uri = file_storage_client.upload_file(
|
||||
bucket=bucket, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
file_data, metadata = file_storage_client.storage_system.get_file(uri)
|
||||
assert file_data.read() == b"Sample file content"
|
||||
assert metadata.file_name == "sample.txt"
|
||||
assert metadata.bucket == bucket
|
||||
|
||||
|
||||
def test_delete_file(file_storage_client, sample_file_path):
|
||||
bucket = "test-bucket"
|
||||
uri = file_storage_client.upload_file(
|
||||
bucket=bucket, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
assert len(file_storage_client.list_files(bucket=bucket)) == 1
|
||||
result = file_storage_client.delete_file(uri)
|
||||
assert result is True
|
||||
assert len(file_storage_client.list_files(bucket=bucket)) == 0
|
||||
|
||||
|
||||
def test_list_files(file_storage_client, sample_file_path):
|
||||
bucket = "test-bucket"
|
||||
uri1 = file_storage_client.upload_file(
|
||||
bucket=bucket, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
files = file_storage_client.list_files(bucket=bucket)
|
||||
assert len(files) == 1
|
||||
|
||||
|
||||
def test_save_file_unsupported_storage(file_storage_system, sample_file_path):
|
||||
bucket = "test-bucket"
|
||||
with pytest.raises(ValueError):
|
||||
file_storage_system.save_file(
|
||||
bucket=bucket,
|
||||
file_name="unsupported.txt",
|
||||
file_data=io.BytesIO(b"Unsupported storage"),
|
||||
storage_type="unsupported",
|
||||
)
|
||||
|
||||
|
||||
def test_get_file_not_found(file_storage_system):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
file_storage_system.get_file("dbgpt-fs://local/test-bucket/nonexistent")
|
||||
|
||||
|
||||
def test_delete_file_not_found(file_storage_system):
|
||||
result = file_storage_system.delete_file("dbgpt-fs://local/test-bucket/nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_metadata_management(file_storage_system):
|
||||
bucket = "test-bucket"
|
||||
file_id = "test_file"
|
||||
metadata = file_storage_system.metadata_storage.save(
|
||||
FileMetadata(
|
||||
file_id=file_id,
|
||||
bucket=bucket,
|
||||
file_name="test.txt",
|
||||
file_size=100,
|
||||
storage_type="local",
|
||||
storage_path="/path/to/test.txt",
|
||||
uri="dbgpt-fs://local/test-bucket/test_file",
|
||||
custom_metadata={"key": "value"},
|
||||
file_hash="hash",
|
||||
)
|
||||
)
|
||||
|
||||
loaded_metadata = file_storage_system.metadata_storage.load(
|
||||
FileMetadataIdentifier(file_id=file_id, bucket=bucket), FileMetadata
|
||||
)
|
||||
assert loaded_metadata.file_name == "test.txt"
|
||||
assert loaded_metadata.custom_metadata["key"] == "value"
|
||||
assert loaded_metadata.bucket == bucket
|
||||
|
||||
|
||||
def test_concurrent_save_and_delete(file_storage_client, sample_file_path):
|
||||
bucket = "test-bucket"
|
||||
|
||||
# Simulate concurrent file save and delete operations
|
||||
def save_file():
|
||||
return file_storage_client.upload_file(
|
||||
bucket=bucket, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
|
||||
def delete_file(uri):
|
||||
return file_storage_client.delete_file(uri)
|
||||
|
||||
uri = save_file()
|
||||
|
||||
# Simulate concurrent operations
|
||||
save_file()
|
||||
delete_file(uri)
|
||||
assert len(file_storage_client.list_files(bucket=bucket)) == 1
|
||||
|
||||
|
||||
def test_large_file_handling(file_storage_client, temp_storage_path):
|
||||
bucket = "test-bucket"
|
||||
large_file_path = os.path.join(temp_storage_path, "large_sample.bin")
|
||||
with open(large_file_path, "wb") as f:
|
||||
f.write(os.urandom(10 * 1024 * 1024)) # 10 MB file
|
||||
|
||||
uri = file_storage_client.upload_file(
|
||||
bucket=bucket,
|
||||
file_path=large_file_path,
|
||||
storage_type="local",
|
||||
custom_metadata={"description": "Large file test"},
|
||||
)
|
||||
file_data, metadata = file_storage_client.storage_system.get_file(uri)
|
||||
assert file_data.read() == open(large_file_path, "rb").read()
|
||||
assert metadata.file_name == "large_sample.bin"
|
||||
assert metadata.bucket == bucket
|
||||
|
||||
|
||||
def test_file_hash_verification_success(file_storage_client, sample_file_path):
|
||||
bucket = "test-bucket"
|
||||
# Upload the file and read back its stored hash
|
||||
uri = file_storage_client.upload_file(
|
||||
bucket=bucket, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
|
||||
file_data, metadata = file_storage_client.storage_system.get_file(uri)
|
||||
file_hash = metadata.file_hash
|
||||
calculated_hash = file_storage_client.storage_system._calculate_file_hash(file_data)
|
||||
|
||||
assert (
|
||||
file_hash == calculated_hash
|
||||
), "File hash should match after saving and loading"
|
||||
|
||||
|
||||
def test_file_hash_verification_failure(file_storage_client, sample_file_path):
|
||||
bucket = "test-bucket"
|
||||
# Upload the file
|
||||
uri = file_storage_client.upload_file(
|
||||
bucket=bucket, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
|
||||
# Modify the file content manually to simulate file tampering
|
||||
storage_system = file_storage_client.storage_system
|
||||
metadata = storage_system.metadata_storage.load(
|
||||
FileMetadataIdentifier(file_id=uri.split("/")[-1], bucket=bucket), FileMetadata
|
||||
)
|
||||
with open(metadata.storage_path, "wb") as f:
|
||||
f.write(b"Tampered content")
|
||||
|
||||
# Get file should raise an exception due to hash mismatch
|
||||
with pytest.raises(ValueError, match="File integrity check failed. Hash mismatch."):
|
||||
storage_system.get_file(uri)
|
||||
|
||||
|
||||
def test_file_isolation_across_buckets(file_storage_client, sample_file_path):
|
||||
bucket1 = "bucket1"
|
||||
bucket2 = "bucket2"
|
||||
|
||||
# Upload the same file to two different buckets
|
||||
uri1 = file_storage_client.upload_file(
|
||||
bucket=bucket1, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
uri2 = file_storage_client.upload_file(
|
||||
bucket=bucket2, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
|
||||
# Verify both URIs are different and point to different files
|
||||
assert uri1 != uri2
|
||||
|
||||
file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1)
|
||||
file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
|
||||
|
||||
assert file_data1.read() == b"Sample file content"
|
||||
assert file_data2.read() == b"Sample file content"
|
||||
assert metadata1.bucket == bucket1
|
||||
assert metadata2.bucket == bucket2
|
||||
|
||||
|
||||
def test_list_files_in_specific_bucket(file_storage_client, sample_file_path):
|
||||
bucket1 = "bucket1"
|
||||
bucket2 = "bucket2"
|
||||
|
||||
# Upload a file to both buckets
|
||||
file_storage_client.upload_file(
|
||||
bucket=bucket1, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
file_storage_client.upload_file(
|
||||
bucket=bucket2, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
|
||||
# List files in bucket1 and bucket2
|
||||
files_in_bucket1 = file_storage_client.list_files(bucket=bucket1)
|
||||
files_in_bucket2 = file_storage_client.list_files(bucket=bucket2)
|
||||
|
||||
assert len(files_in_bucket1) == 1
|
||||
assert len(files_in_bucket2) == 1
|
||||
assert files_in_bucket1[0].bucket == bucket1
|
||||
assert files_in_bucket2[0].bucket == bucket2
|
||||
|
||||
|
||||
def test_delete_file_in_one_bucket_does_not_affect_other_bucket(
|
||||
file_storage_client, sample_file_path
|
||||
):
|
||||
bucket1 = "bucket1"
|
||||
bucket2 = "bucket2"
|
||||
|
||||
# Upload the same file to two different buckets
|
||||
uri1 = file_storage_client.upload_file(
|
||||
bucket=bucket1, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
uri2 = file_storage_client.upload_file(
|
||||
bucket=bucket2, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
|
||||
# Delete the file in bucket1
|
||||
file_storage_client.delete_file(uri1)
|
||||
|
||||
# Check that the file in bucket1 is deleted
|
||||
assert len(file_storage_client.list_files(bucket=bucket1)) == 0
|
||||
|
||||
# Check that the file in bucket2 is still there
|
||||
assert len(file_storage_client.list_files(bucket=bucket2)) == 1
|
||||
file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
|
||||
assert file_data2.read() == b"Sample file content"
|
||||
|
||||
|
||||
def test_file_hash_verification_in_different_buckets(
|
||||
file_storage_client, sample_file_path
|
||||
):
|
||||
bucket1 = "bucket1"
|
||||
bucket2 = "bucket2"
|
||||
|
||||
# Upload the file to both buckets
|
||||
uri1 = file_storage_client.upload_file(
|
||||
bucket=bucket1, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
uri2 = file_storage_client.upload_file(
|
||||
bucket=bucket2, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
|
||||
file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1)
|
||||
file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
|
||||
|
||||
# Verify that file hashes are the same for the same content
|
||||
file_hash1 = file_storage_client.storage_system._calculate_file_hash(file_data1)
|
||||
file_hash2 = file_storage_client.storage_system._calculate_file_hash(file_data2)
|
||||
|
||||
assert file_hash1 == metadata1.file_hash
|
||||
assert file_hash2 == metadata2.file_hash
|
||||
assert file_hash1 == file_hash2
|
||||
|
||||
|
||||
def test_file_download_from_different_buckets(
|
||||
file_storage_client, sample_file_path, temp_storage_path
|
||||
):
|
||||
bucket1 = "bucket1"
|
||||
bucket2 = "bucket2"
|
||||
|
||||
# Upload the file to both buckets
|
||||
uri1 = file_storage_client.upload_file(
|
||||
bucket=bucket1, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
uri2 = file_storage_client.upload_file(
|
||||
bucket=bucket2, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
|
||||
# Download files to different locations
|
||||
download_path1 = os.path.join(temp_storage_path, "downloaded_bucket1.txt")
|
||||
download_path2 = os.path.join(temp_storage_path, "downloaded_bucket2.txt")
|
||||
|
||||
file_storage_client.download_file(uri1, download_path1)
|
||||
file_storage_client.download_file(uri2, download_path2)
|
||||
|
||||
# Verify contents of downloaded files
|
||||
assert open(download_path1, "rb").read() == b"Sample file content"
|
||||
assert open(download_path2, "rb").read() == b"Sample file content"
|
||||
|
||||
|
||||
def test_delete_all_files_in_bucket(file_storage_client, sample_file_path):
|
||||
bucket1 = "bucket1"
|
||||
bucket2 = "bucket2"
|
||||
|
||||
# Upload files to both buckets
|
||||
file_storage_client.upload_file(
|
||||
bucket=bucket1, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
file_storage_client.upload_file(
|
||||
bucket=bucket2, file_path=sample_file_path, storage_type="local"
|
||||
)
|
||||
|
||||
# Delete all files in bucket1
|
||||
for file in file_storage_client.list_files(bucket=bucket1):
|
||||
file_storage_client.delete_file(file.uri)
|
||||
|
||||
# Verify bucket1 is empty
|
||||
assert len(file_storage_client.list_files(bucket=bucket1)) == 0
|
||||
|
||||
# Verify bucket2 still has files
|
||||
assert len(file_storage_client.list_files(bucket=bucket2)) == 1
|
||||
|
||||
|
||||
def test_simple_distributed_storage_save_file(
|
||||
distributed_storage_backend, sample_file_data, temp_storage_path
|
||||
):
|
||||
bucket = "test-bucket"
|
||||
file_id = "test_file"
|
||||
file_path = distributed_storage_backend.save(bucket, file_id, sample_file_data)
|
||||
|
||||
expected_path = os.path.join(
|
||||
temp_storage_path,
|
||||
bucket,
|
||||
f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}",
|
||||
)
|
||||
assert file_path == f"distributed://127.0.0.1:8000/{bucket}/{file_id}"
|
||||
assert os.path.exists(expected_path)
|
||||
|
||||
|
||||
def test_simple_distributed_storage_load_file_local(
|
||||
distributed_storage_backend, sample_file_data
|
||||
):
|
||||
bucket = "test-bucket"
|
||||
file_id = "test_file"
|
||||
distributed_storage_backend.save(bucket, file_id, sample_file_data)
|
||||
|
||||
metadata = FileMetadata(
|
||||
file_id=file_id,
|
||||
bucket=bucket,
|
||||
file_name="test.txt",
|
||||
file_size=len(sample_file_data.getvalue()),
|
||||
storage_type="distributed",
|
||||
storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
|
||||
uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
|
||||
custom_metadata={},
|
||||
file_hash="hash",
|
||||
)
|
||||
|
||||
file_data = distributed_storage_backend.load(metadata)
|
||||
assert file_data.read() == b"Sample file content for distributed storage"
|
||||
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_simple_distributed_storage_load_file_remote(
|
||||
mock_get, distributed_storage_backend, sample_file_data
|
||||
):
|
||||
bucket = "test-bucket"
|
||||
file_id = "test_file"
|
||||
remote_node_address = "127.0.0.2:8000"
|
||||
|
||||
# Mock the response from remote node
|
||||
mock_response = mock.Mock()
|
||||
mock_response.iter_content = mock.Mock(
|
||||
return_value=iter([b"Sample file content for distributed storage"])
|
||||
)
|
||||
mock_response.raise_for_status = mock.Mock(return_value=None)
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
metadata = FileMetadata(
|
||||
file_id=file_id,
|
||||
bucket=bucket,
|
||||
file_name="test.txt",
|
||||
file_size=len(sample_file_data.getvalue()),
|
||||
storage_type="distributed",
|
||||
storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}",
|
||||
uri=f"distributed://{remote_node_address}/{bucket}/{file_id}",
|
||||
custom_metadata={},
|
||||
file_hash="hash",
|
||||
)
|
||||
|
||||
file_data = distributed_storage_backend.load(metadata)
|
||||
assert file_data.read() == b"Sample file content for distributed storage"
|
||||
mock_get.assert_called_once_with(
|
||||
f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}",
|
||||
stream=True,
|
||||
timeout=360,
|
||||
)
|
||||
|
||||
|
||||
def test_simple_distributed_storage_delete_file_local(
|
||||
distributed_storage_backend, sample_file_data, temp_storage_path
|
||||
):
|
||||
bucket = "test-bucket"
|
||||
file_id = "test_file"
|
||||
distributed_storage_backend.save(bucket, file_id, sample_file_data)
|
||||
|
||||
metadata = FileMetadata(
|
||||
file_id=file_id,
|
||||
bucket=bucket,
|
||||
file_name="test.txt",
|
||||
file_size=len(sample_file_data.getvalue()),
|
||||
storage_type="distributed",
|
||||
storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
|
||||
uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
|
||||
custom_metadata={},
|
||||
file_hash="hash",
|
||||
)
|
||||
|
||||
result = distributed_storage_backend.delete(metadata)
|
||||
file_path = os.path.join(
|
||||
temp_storage_path,
|
||||
bucket,
|
||||
f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}",
|
||||
)
|
||||
assert result is True
|
||||
assert not os.path.exists(file_path)
|
||||
|
||||
|
||||
@mock.patch("requests.delete")
|
||||
def test_simple_distributed_storage_delete_file_remote(
|
||||
mock_delete, distributed_storage_backend, sample_file_data
|
||||
):
|
||||
bucket = "test-bucket"
|
||||
file_id = "test_file"
|
||||
remote_node_address = "127.0.0.2:8000"
|
||||
|
||||
mock_response = mock.Mock()
|
||||
mock_response.raise_for_status = mock.Mock(return_value=None)
|
||||
mock_delete.return_value = mock_response
|
||||
|
||||
metadata = FileMetadata(
|
||||
file_id=file_id,
|
||||
bucket=bucket,
|
||||
file_name="test.txt",
|
||||
file_size=len(sample_file_data.getvalue()),
|
||||
storage_type="distributed",
|
||||
storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}",
|
||||
uri=f"distributed://{remote_node_address}/{bucket}/{file_id}",
|
||||
custom_metadata={},
|
||||
file_hash="hash",
|
||||
)
|
||||
|
||||
result = distributed_storage_backend.delete(metadata)
|
||||
assert result is True
|
||||
mock_delete.assert_called_once_with(
|
||||
f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}",
|
||||
timeout=360,
|
||||
)
|
dbgpt/serve/file/__init__.py  (new file, 2 lines)
@@ -0,0 +1,2 @@
|
||||
# This is an auto-generated __init__.py file
|
||||
# generated by `dbgpt new serve file`
|
dbgpt/serve/file/api/__init__.py  (new file, 2 lines)
@@ -0,0 +1,2 @@
|
||||
# This is an auto-generated __init__.py file
|
||||
# generated by `dbgpt new serve file`
|
dbgpt/serve/file/api/endpoints.py  (new file, 159 lines)
@@ -0,0 +1,159 @@
|
||||
import logging
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result, blocking_func_to_async
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from .schemas import ServeRequest, ServerResponse, UploadFileResponse
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add your API endpoints here
|
||||
|
||||
global_system_app: Optional[SystemApp] = None
|
||||
|
||||
|
||||
def get_service() -> Service:
|
||||
"""Get the service instance"""
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@cache
|
||||
def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
"""Parse the string api keys to a list
|
||||
|
||||
Args:
|
||||
api_keys (str): The string api keys
|
||||
|
||||
Returns:
|
||||
List[str]: The list of api keys
|
||||
"""
|
||||
if not api_keys:
|
||||
return []
|
||||
return [key.strip() for key in api_keys.split(",")]
|
||||
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
|
||||
If the api key is not set, allow all.
|
||||
|
||||
You can pass the token in your request header like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import requests
|
||||
|
||||
client_api_key = "your_api_key"
|
||||
headers = {"Authorization": "Bearer " + client_api_key}
|
||||
res = requests.get("http://test/hello", headers=headers)
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
},
|
||||
)
|
||||
return token
|
||||
else:
|
||||
# api_keys not set; allow all
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
|
||||
async def test_auth():
|
||||
"""Test auth endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/files/{bucket}",
|
||||
response_model=Result[List[UploadFileResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def upload_files(
|
||||
bucket: str, files: List[UploadFile], service: Service = Depends(get_service)
|
||||
) -> Result[List[UploadFileResponse]]:
|
||||
"""Upload files by a list of UploadFile."""
|
||||
logger.info(f"upload_files: bucket={bucket}, files={files}")
|
||||
results = await blocking_func_to_async(
|
||||
global_system_app, service.upload_files, bucket, "distributed", files
|
||||
)
|
||||
return Result.succ(results)
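A client-side sketch for this endpoint; the host, port, bucket, and mount prefix are assumptions (5670 is the default file_server_port in the serve config below, and the /api/v2/serve/file/files prefix mirrors the default api_prefix used by SimpleDistributedStorage):

import requests

with open("report.pdf", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:5670/api/v2/serve/file/files/my-bucket",
        files=[("files", ("report.pdf", f))],
        timeout=60,
    )
resp.raise_for_status()
print(resp.json())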
|
||||
|
||||
|
||||
@router.get("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)])
|
||||
async def download_file(
|
||||
bucket: str, file_id: str, service: Service = Depends(get_service)
|
||||
):
|
||||
"""Download a file by file_id."""
|
||||
logger.info(f"download_file: bucket={bucket}, file_id={file_id}")
|
||||
file_data, file_metadata = await blocking_func_to_async(
|
||||
global_system_app, service.download_file, bucket, file_id
|
||||
)
|
||||
file_name_encoded = quote(file_metadata.file_name)
|
||||
|
||||
def file_iterator(raw_iter):
|
||||
with raw_iter:
|
||||
while chunk := raw_iter.read(
|
||||
service.config.file_server_download_chunk_size
|
||||
):
|
||||
yield chunk
|
||||
|
||||
response = StreamingResponse(
|
||||
file_iterator(file_data), media_type="application/octet-stream"
|
||||
)
|
||||
response.headers[
|
||||
"Content-Disposition"
|
||||
] = f"attachment; filename={file_name_encoded}"
|
||||
return response
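And a matching download sketch (again with an assumed address and a hypothetical file ID), streaming the response to disk in chunks:

import requests

file_id = "123e4567"  # hypothetical ID returned by the upload endpoint
url = f"http://127.0.0.1:5670/api/v2/serve/file/files/my-bucket/{file_id}"
with requests.get(url, stream=True, timeout=60) as resp:
    resp.raise_for_status()
    with open("downloaded.bin", "wb") as f:
        for chunk in resp.iter_content(chunk_size=1024 * 1024):
            f.write(chunk)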
|
||||
|
||||
|
||||
@router.delete("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)])
|
||||
async def delete_file(
|
||||
bucket: str, file_id: str, service: Service = Depends(get_service)
|
||||
):
|
||||
"""Delete a file by file_id."""
|
||||
await blocking_func_to_async(
|
||||
global_system_app, service.delete_file, bucket, file_id
|
||||
)
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
global global_system_app
|
||||
system_app.register(Service)
|
||||
global_system_app = system_app
|
dbgpt/serve/file/api/schemas.py  (new file, 43 lines)
@@ -0,0 +1,43 @@
# Define your Pydantic schemas here
from typing import Any, Dict

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict

from ..config import SERVE_APP_NAME_HUMP


class ServeRequest(BaseModel):
    """File request model"""

    # TODO define your own fields here

    model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}")

    def to_dict(self, **kwargs) -> Dict[str, Any]:
        """Convert the model to a dictionary"""
        return model_to_dict(self, **kwargs)


class ServerResponse(BaseModel):
    """File response model"""

    # TODO define your own fields here

    model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")

    def to_dict(self, **kwargs) -> Dict[str, Any]:
        """Convert the model to a dictionary"""
        return model_to_dict(self, **kwargs)


class UploadFileResponse(BaseModel):
    """Upload file response model"""

    file_name: str = Field(..., title="The name of the uploaded file")
    file_id: str = Field(..., title="The ID of the uploaded file")
    bucket: str = Field(..., title="The bucket of the uploaded file")
    uri: str = Field(..., title="The URI of the uploaded file")

    def to_dict(self, **kwargs) -> Dict[str, Any]:
        """Convert the model to a dictionary"""
        return model_to_dict(self, **kwargs)
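For reference, a minimal sketch of how UploadFileResponse round-trips to a plain dict; the field values and the URI string are made-up examples, not output produced by this commit.

from dbgpt.serve.file.api.schemas import UploadFileResponse

resp = UploadFileResponse(
    file_name="report.pdf",
    file_id="f-123",
    bucket="my-bucket",
    uri="dbgpt-fs://distributed/127.0.0.1:5670/my-bucket/f-123",  # illustrative URI only
)
assert resp.to_dict()["file_id"] == "f-123"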
dbgpt/serve/file/config.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from dataclasses import dataclass, field
from typing import Optional

from dbgpt.serve.core import BaseServeConfig

APP_NAME = "file"
SERVE_APP_NAME = "dbgpt_serve_file"
SERVE_APP_NAME_HUMP = "dbgpt_serve_File"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.file."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpt_serve_file"


@dataclass
class ServeConfig(BaseServeConfig):
    """Parameters for the serve command"""

    # TODO: add your own parameters here
    api_keys: Optional[str] = field(
        default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
    )
    check_hash: Optional[bool] = field(
        default=True, metadata={"help": "Check the hash of the file when downloading"}
    )
    file_server_host: Optional[str] = field(
        default=None, metadata={"help": "The host of the file server"}
    )
    file_server_port: Optional[int] = field(
        default=5670, metadata={"help": "The port of the file server"}
    )
    file_server_download_chunk_size: Optional[int] = field(
        default=1024 * 1024,
        metadata={"help": "The chunk size when downloading the file"},
    )
    file_server_save_chunk_size: Optional[int] = field(
        default=1024 * 1024, metadata={"help": "The chunk size when saving the file"}
    )
    file_server_transfer_chunk_size: Optional[int] = field(
        default=1024 * 1024,
        metadata={"help": "The chunk size when transferring the file"},
    )
    file_server_transfer_timeout: Optional[int] = field(
        default=360, metadata={"help": "The timeout when transferring the file"}
    )
    local_storage_path: Optional[str] = field(
        default=None, metadata={"help": "The local storage path"}
    )

    def get_node_address(self) -> str:
        """Get the node address"""
        file_server_host = self.file_server_host
        if not file_server_host:
            from dbgpt.util.net_utils import _get_ip_address

            file_server_host = _get_ip_address()
        file_server_port = self.file_server_port or 5670
        return f"{file_server_host}:{file_server_port}"

    def get_local_storage_path(self) -> str:
        """Get the local storage path"""
        local_storage_path = self.local_storage_path
        if not local_storage_path:
            from pathlib import Path

            base_path = Path.home() / ".cache" / "dbgpt" / "files"
            local_storage_path = str(base_path)
        return local_storage_path
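A small sketch of the two helper methods above, using hand-built values; the host and path are assumptions, and in a running server the config is normally populated from app config under the dbgpt.serve.file. prefix rather than constructed by hand.

from dbgpt.serve.file.config import ServeConfig

cfg = ServeConfig(file_server_host="192.168.1.10", local_storage_path="/data/dbgpt/files")
print(cfg.get_node_address())        # "192.168.1.10:5670"; falls back to the local IP when the host is unset
print(cfg.get_local_storage_path())  # "/data/dbgpt/files"; defaults to ~/.cache/dbgpt/files when unset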
dbgpt/serve/file/dependencies.py (new file, 1 line)
@@ -0,0 +1 @@
# Define your dependencies here
dbgpt/serve/file/models/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve file`
dbgpt/serve/file/models/file_adapter.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import json
from typing import Type

from sqlalchemy.orm import Session

from dbgpt.core.interface.file import FileMetadata, FileMetadataIdentifier
from dbgpt.core.interface.storage import StorageItemAdapter

from .models import ServeEntity


class FileMetadataAdapter(StorageItemAdapter[FileMetadata, ServeEntity]):
    """File metadata adapter.

    Convert between storage format and database model.
    """

    def to_storage_format(self, item: FileMetadata) -> ServeEntity:
        """Convert to storage format."""
        custom_metadata = (
            json.dumps(item.custom_metadata, ensure_ascii=False)
            if item.custom_metadata
            else None
        )
        return ServeEntity(
            bucket=item.bucket,
            file_id=item.file_id,
            file_name=item.file_name,
            file_size=item.file_size,
            storage_type=item.storage_type,
            storage_path=item.storage_path,
            uri=item.uri,
            custom_metadata=custom_metadata,
            file_hash=item.file_hash,
        )

    def from_storage_format(self, model: ServeEntity) -> FileMetadata:
        """Convert from storage format."""
        custom_metadata = (
            json.loads(model.custom_metadata) if model.custom_metadata else None
        )
        return FileMetadata(
            bucket=model.bucket,
            file_id=model.file_id,
            file_name=model.file_name,
            file_size=model.file_size,
            storage_type=model.storage_type,
            storage_path=model.storage_path,
            uri=model.uri,
            custom_metadata=custom_metadata,
            file_hash=model.file_hash,
        )

    def get_query_for_identifier(
        self,
        storage_format: Type[ServeEntity],
        resource_id: FileMetadataIdentifier,
        **kwargs,
    ):
        """Get query for identifier."""
        session: Session = kwargs.get("session")
        if session is None:
            raise Exception("session is None")
        return session.query(storage_format).filter(
            storage_format.file_id == resource_id.file_id
        )
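The adapter above serializes custom_metadata to JSON text on the way into the database and parses it back on the way out. A small round-trip sketch (not part of the diff) that only exercises the two conversion methods; all field values, including the URI, are illustrative.

from dbgpt.core.interface.file import FileMetadata
from dbgpt.serve.file.models.file_adapter import FileMetadataAdapter

adapter = FileMetadataAdapter()
meta = FileMetadata(
    bucket="my-bucket",
    file_id="f-123",
    file_name="report.pdf",
    file_size=1024,
    storage_type="distributed",
    storage_path="/data/dbgpt/files/my-bucket/f-123",
    uri="dbgpt-fs://distributed/127.0.0.1:5670/my-bucket/f-123",  # illustrative URI only
    custom_metadata={"owner": "alice"},
    file_hash="abc123",
)
entity = adapter.to_storage_format(meta)      # custom_metadata becomes a JSON string
restored = adapter.from_storage_format(entity)
assert restored.custom_metadata == {"owner": "alice"}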
dbgpt/serve/file/models/models.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""

from datetime import datetime
from typing import Any, Dict, Union

from sqlalchemy import Column, DateTime, Index, Integer, String, Text

from dbgpt.storage.metadata import BaseDao, Model, db

from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig


class ServeEntity(Model):
    __tablename__ = SERVER_APP_TABLE_NAME
    id = Column(Integer, primary_key=True, comment="Auto increment id")

    bucket = Column(String(255), nullable=False, comment="Bucket name")
    file_id = Column(String(255), nullable=False, comment="File id")
    file_name = Column(String(256), nullable=False, comment="File name")
    file_size = Column(Integer, nullable=True, comment="File size")
    storage_type = Column(String(32), nullable=False, comment="Storage type")
    storage_path = Column(String(512), nullable=False, comment="Storage path")
    uri = Column(String(512), nullable=False, comment="File URI")
    custom_metadata = Column(
        Text, nullable=True, comment="Custom metadata, JSON format"
    )
    file_hash = Column(String(128), nullable=True, comment="File hash")

    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")

    def __repr__(self):
        return (
            f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', "
            f"gmt_modified='{self.gmt_modified}')"
        )


class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
    """The DAO class for File"""

    def __init__(self, serve_config: ServeConfig):
        super().__init__()
        self._serve_config = serve_config

    def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
        """Convert the request to an entity

        Args:
            request (Union[ServeRequest, Dict[str, Any]]): The request

        Returns:
            T: The entity
        """
        request_dict = (
            request.to_dict() if isinstance(request, ServeRequest) else request
        )
        entity = ServeEntity(**request_dict)
        # TODO implement your own logic here, transfer the request_dict to an entity
        return entity

    def to_request(self, entity: ServeEntity) -> ServeRequest:
        """Convert the entity to a request

        Args:
            entity (T): The entity

        Returns:
            REQ: The request
        """
        # TODO implement your own logic here, transfer the entity to a request
        return ServeRequest()

    def to_response(self, entity: ServeEntity) -> ServerResponse:
        """Convert the entity to a response

        Args:
            entity (T): The entity

        Returns:
            RES: The response
        """
        # TODO implement your own logic here, transfer the entity to a response
        return ServerResponse()
dbgpt/serve/file/serve.py (new file, 113 lines)
@@ -0,0 +1,113 @@
import logging
from typing import List, Optional, Union

from sqlalchemy import URL

from dbgpt.component import SystemApp
from dbgpt.core.interface.file import FileStorageClient
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager

from .api.endpoints import init_endpoints, router
from .config import (
    APP_NAME,
    SERVE_APP_NAME,
    SERVE_APP_NAME_HUMP,
    SERVE_CONFIG_KEY_PREFIX,
    ServeConfig,
)

logger = logging.getLogger(__name__)


class Serve(BaseServe):
    """Serve component for DB-GPT"""

    name = SERVE_APP_NAME

    def __init__(
        self,
        system_app: SystemApp,
        api_prefix: Optional[str] = f"/api/v2/serve/{APP_NAME}",
        api_tags: Optional[List[str]] = None,
        db_url_or_db: Union[str, URL, DatabaseManager] = None,
        try_create_tables: Optional[bool] = False,
    ):
        if api_tags is None:
            api_tags = [SERVE_APP_NAME_HUMP]
        super().__init__(
            system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
        )
        self._db_manager: Optional[DatabaseManager] = None
        self._file_storage_client: Optional[FileStorageClient] = None
        self._serve_config: Optional[ServeConfig] = None

    def init_app(self, system_app: SystemApp):
        if self._app_has_initiated:
            return
        self._system_app = system_app
        self._system_app.app.include_router(
            router, prefix=self._api_prefix, tags=self._api_tags
        )
        init_endpoints(self._system_app)
        self._app_has_initiated = True

    def on_init(self):
        """Called when init the application.

        You can do some initialization here. You can't get other components here
        because they may not be initialized yet.
        """
        # Import your own module here to ensure the module is loaded before the
        # application starts
        from .models.models import ServeEntity

    def before_start(self):
        """Called before the start of the application."""
        from dbgpt.core.interface.file import (
            FileStorageSystem,
            SimpleDistributedStorage,
        )
        from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
        from dbgpt.util.serialization.json_serialization import JsonSerializer

        from .models.file_adapter import FileMetadataAdapter
        from .models.models import ServeEntity

        self._serve_config = ServeConfig.from_app_config(
            self._system_app.config, SERVE_CONFIG_KEY_PREFIX
        )

        self._db_manager = self.create_or_get_db_manager()
        serializer = JsonSerializer()
        storage = SQLAlchemyStorage(
            self._db_manager,
            ServeEntity,
            FileMetadataAdapter(),
            serializer,
        )
        simple_distributed_storage = SimpleDistributedStorage(
            node_address=self._serve_config.get_node_address(),
            local_storage_path=self._serve_config.get_local_storage_path(),
            save_chunk_size=self._serve_config.file_server_save_chunk_size,
            transfer_chunk_size=self._serve_config.file_server_transfer_chunk_size,
            transfer_timeout=self._serve_config.file_server_transfer_timeout,
        )
        storage_backends = {
            simple_distributed_storage.storage_type: simple_distributed_storage,
        }
        fs = FileStorageSystem(
            storage_backends,
            metadata_storage=storage,
            check_hash=self._serve_config.check_hash,
        )
        self._file_storage_client = FileStorageClient(
            system_app=self._system_app, storage_system=fs
        )

    @property
    def file_storage_client(self) -> FileStorageClient:
        """Returns the file storage client."""
        if not self._file_storage_client:
            raise ValueError("File storage client is not initialized")
        return self._file_storage_client
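A hedged sketch of wiring the component by hand; in a running DB-GPT webserver the Serve component is registered for you during startup, so treat this as illustration only (the bare SystemApp below is an assumption, and the storage client only exists after before_start() has run in the app lifecycle).

from dbgpt.component import SystemApp
from dbgpt.serve.file.serve import Serve as FileServe

system_app = SystemApp()
system_app.register(FileServe)
file_serve = FileServe.get_instance(system_app)
# After before_start() has run:
# fs_client = file_serve.file_storage_client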
dbgpt/serve/file/service/__init__.py (new file, empty)
dbgpt/serve/file/service/service.py (new file, 106 lines)
@@ -0,0 +1,106 @@
import logging
from typing import BinaryIO, List, Optional, Tuple

from fastapi import UploadFile

from dbgpt.component import BaseComponent, SystemApp
from dbgpt.core.interface.file import FileMetadata, FileStorageClient, FileStorageURI
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.tracer import root_tracer, trace

from ..api.schemas import ServeRequest, ServerResponse, UploadFileResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity

logger = logging.getLogger(__name__)


class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
    """The service class for File"""

    name = SERVE_SERVICE_COMPONENT_NAME

    def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
        self._system_app = None
        self._serve_config: ServeConfig = None
        self._dao: ServeDao = dao
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp) -> None:
        """Initialize the service

        Args:
            system_app (SystemApp): The system app
        """
        super().init_app(system_app)
        self._serve_config = ServeConfig.from_app_config(
            system_app.config, SERVE_CONFIG_KEY_PREFIX
        )
        self._dao = self._dao or ServeDao(self._serve_config)
        self._system_app = system_app

    @property
    def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
        """Returns the internal DAO."""
        return self._dao

    @property
    def config(self) -> ServeConfig:
        """Returns the internal ServeConfig."""
        return self._serve_config

    @property
    def file_storage_client(self) -> FileStorageClient:
        """Returns the internal FileStorageClient.

        Returns:
            FileStorageClient: The internal FileStorageClient
        """
        file_storage_client = FileStorageClient.get_instance(
            self._system_app, default_component=None
        )
        if file_storage_client:
            return file_storage_client
        else:
            from ..serve import Serve

            file_storage_client = Serve.get_instance(
                self._system_app
            ).file_storage_client
            self._system_app.register_instance(file_storage_client)
            return file_storage_client

    @trace("upload_files")
    def upload_files(
        self, bucket: str, storage_type: str, files: List[UploadFile]
    ) -> List[UploadFileResponse]:
        """Upload files by a list of UploadFile."""
        results = []
        for file in files:
            file_name = file.filename
            logger.info(f"Uploading file {file_name} to bucket {bucket}")
            uri = self.file_storage_client.save_file(
                bucket, file_name, file_data=file.file, storage_type=storage_type
            )
            parsed_uri = FileStorageURI.parse(uri)
            logger.info(f"Uploaded file {file_name} to bucket {bucket}, uri={uri}")
            results.append(
                UploadFileResponse(
                    file_name=file_name,
                    file_id=parsed_uri.file_id,
                    bucket=bucket,
                    uri=uri,
                )
            )
        return results

    @trace("download_file")
    def download_file(self, bucket: str, file_id: str) -> Tuple[BinaryIO, FileMetadata]:
        """Download a file by file_id."""
        return self.file_storage_client.get_file_by_id(bucket, file_id)

    def delete_file(self, bucket: str, file_id: str) -> None:
        """Delete a file by file_id."""
        self.file_storage_client.delete_file_by_id(bucket, file_id)
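A short sketch of calling the service outside the HTTP layer; it assumes a SystemApp on which the Service component has already been registered (as init_endpoints does above), and the helper name is hypothetical.

from dbgpt.serve.file.config import SERVE_SERVICE_COMPONENT_NAME
from dbgpt.serve.file.service.service import Service

def remove_file(system_app, bucket: str, file_id: str) -> None:
    # Look up the registered Service component and delete the file synchronously
    service: Service = system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
    service.delete_file(bucket, file_id)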
dbgpt/serve/file/tests/__init__.py (new file, empty)
dbgpt/serve/file/tests/test_endpoints.py (new file, 124 lines)
@@ -0,0 +1,124 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient

from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult

from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX


@pytest.fixture(autouse=True)
def setup_and_teardown():
    db.init_db("sqlite:///:memory:")
    db.create_all()

    yield


def client_init_caller(app: FastAPI, system_app: SystemApp):
    app.include_router(router)
    init_endpoints(system_app)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, asystem_app, has_auth",
    [
        (
            {
                "app_caller": client_init_caller,
                "client_api_key": "test_token1",
            },
            {
                "app_config": {
                    f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
                }
            },
            True,
        ),
        (
            {
                "app_caller": client_init_caller,
                "client_api_key": "error_token",
            },
            {
                "app_config": {
                    f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
                }
            },
            False,
        ),
    ],
    indirect=["client", "asystem_app"],
)
async def test_api_auth(client: AsyncClient, asystem_app, has_auth: bool):
    response = await client.get("/test_auth")
    if has_auth:
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}
    else:
        assert response.status_code == 401
        assert response.json() == {
            "detail": {
                "error": {
                    "message": "",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            }
        }


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
    response = await client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
    # TODO: add your test case
    pass


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
    # TODO: implement your test case
    pass


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
    # TODO: implement your test case
    pass


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
    # TODO: implement your test case
    pass


# Add more test cases according to your own logic
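One possible round-trip test that could back the TODOs above, sketched under the assumption that the test app also wires up the file Serve component and its storage backend (the bare router alone cannot resolve the storage client), and that responses use the usual Result envelope; treat it as illustrative rather than a drop-in test.

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_upload_and_download(client: AsyncClient):
    # Upload a small in-memory file, then fetch it back by the returned file_id
    files = [("files", ("hello.txt", b"hello world", "text/plain"))]
    response = await client.post("/files/test-bucket", files=files)
    assert response.status_code == 200
    file_id = response.json()["data"][0]["file_id"]

    download = await client.get(f"/files/test-bucket/{file_id}")
    assert download.status_code == 200
    assert download.content == b"hello world"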
dbgpt/serve/file/tests/test_models.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import pytest

from dbgpt.storage.metadata import db

from ..api.schemas import ServeRequest, ServerResponse
from ..config import ServeConfig
from ..models.models import ServeDao, ServeEntity


@pytest.fixture(autouse=True)
def setup_and_teardown():
    db.init_db("sqlite:///:memory:")
    db.create_all()

    yield


@pytest.fixture
def server_config():
    # TODO: build your server config
    return ServeConfig()


@pytest.fixture
def dao(server_config):
    return ServeDao(server_config)


@pytest.fixture
def default_entity_dict():
    # TODO: build your default entity dict
    return {}


def test_table_exist():
    assert ServeEntity.__tablename__ in db.metadata.tables


def test_entity_create(default_entity_dict):
    # TODO: implement your test case
    pass


def test_entity_unique_key(default_entity_dict):
    # TODO: implement your test case
    pass


def test_entity_get(default_entity_dict):
    # TODO: implement your test case
    pass


def test_entity_update(default_entity_dict):
    # TODO: implement your test case
    pass


def test_entity_delete(default_entity_dict):
    # TODO: implement your test case
    pass


def test_entity_all():
    # TODO: implement your test case
    pass


def test_dao_create(dao, default_entity_dict):
    # TODO: implement your test case
    pass


def test_dao_get_one(dao, default_entity_dict):
    # TODO: implement your test case
    pass


def test_get_dao_get_list(dao):
    # TODO: implement your test case
    pass


def test_dao_update(dao, default_entity_dict):
    # TODO: implement your test case
    pass


def test_dao_delete(dao, default_entity_dict):
    # TODO: implement your test case
    pass


def test_dao_get_list_page(dao):
    # TODO: implement your test case
    pass


# Add more test cases according to your own logic
dbgpt/serve/file/tests/test_service.py (new file, 78 lines)
@@ -0,0 +1,78 @@
from typing import List

import pytest

from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db

from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service


@pytest.fixture(autouse=True)
def setup_and_teardown():
    db.init_db("sqlite:///:memory:")
    db.create_all()
    yield


@pytest.fixture
def service(system_app: SystemApp):
    instance = Service(system_app)
    instance.init_app(system_app)
    return instance


@pytest.fixture
def default_entity_dict():
    # TODO: build your default entity dict
    return {}


@pytest.mark.parametrize(
    "system_app",
    [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
    indirect=True,
)
def test_config_exists(service: Service):
    system_app: SystemApp = service._system_app
    assert system_app.config.get("DEBUG") is True
    assert system_app.config.get("dbgpt.serve.test_key") == "hello"
    assert service.config is not None


def test_service_create(service: Service, default_entity_dict):
    # TODO: implement your test case
    # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
    # ...
    pass


def test_service_update(service: Service, default_entity_dict):
    # TODO: implement your test case
    pass


def test_service_get(service: Service, default_entity_dict):
    # TODO: implement your test case
    pass


def test_service_delete(service: Service, default_entity_dict):
    # TODO: implement your test case
    pass


def test_service_get_list(service: Service):
    # TODO: implement your test case
    pass


def test_service_get_list_by_page(service: Service):
    # TODO: implement your test case
    pass


# Add more test cases according to your own logic