chore: Merge latest code

Author: Fangyin Cheng
Date: 2024-08-22 11:37:38 +08:00
Parent: 4f2c56d821
Commit: 6442920c63
26 changed files with 2301 additions and 3 deletions

View File

@@ -278,6 +278,11 @@ DBGPT_LOG_LEVEL=INFO
# ENCRYPT KEY - The key used to encrypt and decrypt the data
# ENCRYPT_KEY=your_secret_key
#*******************************************************************#
#** File Server **#
#*******************************************************************#
## The local storage path of the file server; defaults to pilot/data/file_server
# FILE_SERVER_LOCAL_STORAGE_PATH =
#*******************************************************************#
#** Application Config **#

View File

@@ -119,3 +119,6 @@ ignore_missing_imports = True
[mypy-networkx.*]
ignore_missing_imports = True
[mypy-pypdf.*]
ignore_missing_imports = True

View File

@@ -321,6 +321,13 @@ class Config(metaclass=Singleton):
os.getenv("USE_NEW_WEB_UI", "True").lower() == "true"
)
# File server configuration
# The host of the current file server; if None, the host is detected automatically
self.FILE_SERVER_HOST = os.getenv("FILE_SERVER_HOST")
self.FILE_SERVER_LOCAL_STORAGE_PATH = os.getenv(
"FILE_SERVER_LOCAL_STORAGE_PATH"
)
@property
def local_db_manager(self) -> "ConnectorManager":
from dbgpt.datasource.manages import ConnectorManager

View File

@@ -59,7 +59,7 @@ def initialize_components(
_initialize_agent(system_app)
_initialize_openapi(system_app)
# Register serve apps
register_serve_apps(system_app, CFG)
register_serve_apps(system_app, CFG, param.port)
def _initialize_model_cache(system_app: SystemApp):

View File

@@ -11,6 +11,7 @@ from dbgpt.serve.agent.app.recommend_question.recommend_question import (
)
from dbgpt.serve.agent.hub.db.my_plugin_db import MyPluginEntity
from dbgpt.serve.agent.hub.db.plugin_hub_db import PluginHubEntity
from dbgpt.serve.file.models.models import ServeEntity as FileServeEntity
from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity
from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
@@ -22,6 +23,7 @@ from dbgpt.storage.chat_history.chat_history_db import (
_MODELS = [
PluginHubEntity,
FileServeEntity,
MyPluginEntity,
PromptManageEntity,
KnowledgeSpaceEntity,

View File

@@ -2,7 +2,7 @@ from dbgpt._private.config import Config
from dbgpt.component import SystemApp
def register_serve_apps(system_app: SystemApp, cfg: Config):
def register_serve_apps(system_app: SystemApp, cfg: Config, webserver_port: int):
"""Register serve apps"""
system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE)
if cfg.API_KEYS:
@@ -47,6 +47,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
# Register serve app
system_app.register(FlowServe)
# ################################ AWEL Flow Serve Register End ########################################
# ################################ Rag Serve Register Begin ######################################
from dbgpt.serve.rag.serve import (
@@ -57,6 +59,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
# Register serve app
system_app.register(RagServe)
# ################################ Rag Serve Register End ########################################
# ################################ Datasource Serve Register Begin ######################################
from dbgpt.serve.datasource.serve import (
@@ -66,7 +70,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
# Register serve app
system_app.register(DatasourceServe)
# ################################ AWEL Flow Serve Register End ########################################
# ################################ Datasource Serve Register End ########################################
# ################################ Chat Feedback Serve Register End ########################################
from dbgpt.serve.feedback.serve import (
@@ -77,3 +82,30 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
# Register serve feedback
system_app.register(FeedbackServe)
# ################################ Chat Feedback Register End ########################################
# ################################ File Serve Register Begin ######################################
from dbgpt.configs.model_config import FILE_SERVER_LOCAL_STORAGE_PATH
from dbgpt.serve.file.serve import (
SERVE_CONFIG_KEY_PREFIX as FILE_SERVE_CONFIG_KEY_PREFIX,
)
from dbgpt.serve.file.serve import Serve as FileServe
local_storage_path = (
cfg.FILE_SERVER_LOCAL_STORAGE_PATH or FILE_SERVER_LOCAL_STORAGE_PATH
)
# Set config
system_app.config.set(
f"{FILE_SERVE_CONFIG_KEY_PREFIX}local_storage_path", local_storage_path
)
system_app.config.set(
f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_port", webserver_port
)
if cfg.FILE_SERVER_HOST:
system_app.config.set(
f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_host", cfg.FILE_SERVER_HOST
)
# Register serve app
system_app.register(FileServe)
# ################################ File Serve Register End ########################################
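
Once register_serve_apps has registered FileServe, downstream code can obtain the storage client through the serve component. A minimal lookup sketch, not part of the diff (the helper name is hypothetical); it mirrors the file_storage_client property added in dbgpt/serve/file/serve.py and the fallback used in the file service:

# Illustrative only: how another component could obtain the file storage client
# after register_serve_apps(system_app, CFG, param.port) has run.
from dbgpt.component import SystemApp
from dbgpt.core.interface.file import FileStorageClient
from dbgpt.serve.file.serve import Serve as FileServe

def get_file_storage_client(system_app: SystemApp) -> FileStorageClient:
    # Prefer a client already registered on the SystemApp.
    client = FileStorageClient.get_instance(system_app, default_component=None)
    if client:
        return client
    # Otherwise fall back to the one created by the file serve component.
    return FileServe.get_instance(system_app).file_storage_client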

View File

@@ -90,6 +90,7 @@ class ComponentType(str, Enum):
AGENT_MANAGER = "dbgpt_agent_manager"
RESOURCE_MANAGER = "dbgpt_resource_manager"
VARIABLES_PROVIDER = "dbgpt_variables_provider"
FILE_STORAGE_CLIENT = "dbgpt_file_storage_client"
_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"

View File

@@ -14,6 +14,7 @@ DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data")
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
FILE_SERVER_LOCAL_STORAGE_PATH = os.path.join(DATA_DIR, "file_server")
_DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel")
# Global language setting
LOCALES_DIR = os.path.join(ROOT_PATH, "i18n/locales")

View File

@@ -0,0 +1,791 @@
"""File storage interface."""
import dataclasses
import hashlib
import io
import os
import uuid
from abc import ABC, abstractmethod
from io import BytesIO
from typing import Any, BinaryIO, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse
import requests
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.tracer import root_tracer, trace
from .storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageError,
StorageInterface,
StorageItem,
)
_SCHEMA = "dbgpt-fs"
@dataclasses.dataclass
class FileMetadataIdentifier(ResourceIdentifier):
"""File metadata identifier."""
file_id: str
bucket: str
def to_dict(self) -> Dict:
"""Convert the identifier to a dictionary."""
return {"file_id": self.file_id, "bucket": self.bucket}
@property
def str_identifier(self) -> str:
"""Get the string identifier.
Returns:
str: The string identifier
"""
return f"{self.bucket}/{self.file_id}"
@dataclasses.dataclass
class FileMetadata(StorageItem):
"""File metadata for storage."""
file_id: str
bucket: str
file_name: str
file_size: int
storage_type: str
storage_path: str
uri: str
custom_metadata: Dict[str, Any]
file_hash: str
_identifier: FileMetadataIdentifier = dataclasses.field(init=False)
def __post_init__(self):
"""Post init method."""
self._identifier = FileMetadataIdentifier(
file_id=self.file_id, bucket=self.bucket
)
@property
def identifier(self) -> ResourceIdentifier:
"""Get the resource identifier."""
return self._identifier
def merge(self, other: "StorageItem") -> None:
"""Merge the metadata with another item."""
if not isinstance(other, FileMetadata):
raise StorageError("Cannot merge different types of items")
self._from_object(other)
def to_dict(self) -> Dict:
"""Convert the metadata to a dictionary."""
return {
"file_id": self.file_id,
"bucket": self.bucket,
"file_name": self.file_name,
"file_size": self.file_size,
"storage_type": self.storage_type,
"storage_path": self.storage_path,
"uri": self.uri,
"custom_metadata": self.custom_metadata,
"file_hash": self.file_hash,
}
def _from_object(self, obj: "FileMetadata") -> None:
self.file_id = obj.file_id
self.bucket = obj.bucket
self.file_name = obj.file_name
self.file_size = obj.file_size
self.storage_type = obj.storage_type
self.storage_path = obj.storage_path
self.uri = obj.uri
self.custom_metadata = obj.custom_metadata
self.file_hash = obj.file_hash
self._identifier = obj._identifier
class FileStorageURI:
"""File storage URI."""
def __init__(
self,
storage_type: str,
bucket: str,
file_id: str,
version: Optional[str] = None,
custom_params: Optional[Dict[str, Any]] = None,
):
"""Initialize the file storage URI."""
self.scheme = _SCHEMA
self.storage_type = storage_type
self.bucket = bucket
self.file_id = file_id
self.version = version
self.custom_params = custom_params or {}
@classmethod
def parse(cls, uri: str) -> "FileStorageURI":
"""Parse the URI string."""
parsed = urlparse(uri)
if parsed.scheme != _SCHEMA:
raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'")
path_parts = parsed.path.strip("/").split("/")
if len(path_parts) < 2:
raise ValueError("Invalid URI path. Must contain bucket and file ID")
storage_type = parsed.netloc
bucket = path_parts[0]
file_id = path_parts[1]
version = path_parts[2] if len(path_parts) > 2 else None
custom_params = parse_qs(parsed.query)
return cls(storage_type, bucket, file_id, version, custom_params)
def __str__(self) -> str:
"""Get the string representation of the URI."""
base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}"
if self.version:
base_uri += f"/{self.version}"
if self.custom_params:
query_string = urlencode(self.custom_params, doseq=True)
base_uri += f"?{query_string}"
return base_uri
class StorageBackend(ABC):
"""Storage backend interface."""
storage_type: str = "__base__"
@abstractmethod
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
"""Save the file data to the storage backend.
Args:
bucket (str): The bucket name
file_id (str): The file ID
file_data (BinaryIO): The file data
Returns:
str: The storage path
"""
@abstractmethod
def load(self, fm: FileMetadata) -> BinaryIO:
"""Load the file data from the storage backend.
Args:
fm (FileMetadata): The file metadata
Returns:
BinaryIO: The file data
"""
@abstractmethod
def delete(self, fm: FileMetadata) -> bool:
"""Delete the file data from the storage backend.
Args:
fm (FileMetadata): The file metadata
Returns:
bool: True if the file was deleted, False otherwise
"""
@property
@abstractmethod
def save_chunk_size(self) -> int:
"""Get the save chunk size.
Returns:
int: The save chunk size
"""
class LocalFileStorage(StorageBackend):
"""Local file storage backend."""
storage_type: str = "local"
def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024):
"""Initialize the local file storage backend."""
self.base_path = base_path
self._save_chunk_size = save_chunk_size
os.makedirs(self.base_path, exist_ok=True)
@property
def save_chunk_size(self) -> int:
"""Get the save chunk size."""
return self._save_chunk_size
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
"""Save the file data to the local storage backend."""
bucket_path = os.path.join(self.base_path, bucket)
os.makedirs(bucket_path, exist_ok=True)
file_path = os.path.join(bucket_path, file_id)
with open(file_path, "wb") as f:
while True:
chunk = file_data.read(self.save_chunk_size)
if not chunk:
break
f.write(chunk)
return file_path
def load(self, fm: FileMetadata) -> BinaryIO:
"""Load the file data from the local storage backend."""
bucket_path = os.path.join(self.base_path, fm.bucket)
file_path = os.path.join(bucket_path, fm.file_id)
return open(file_path, "rb") # noqa: SIM115
def delete(self, fm: FileMetadata) -> bool:
"""Delete the file data from the local storage backend."""
bucket_path = os.path.join(self.base_path, fm.bucket)
file_path = os.path.join(bucket_path, fm.file_id)
if os.path.exists(file_path):
os.remove(file_path)
return True
return False
class FileStorageSystem:
"""File storage system."""
def __init__(
self,
storage_backends: Dict[str, StorageBackend],
metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None,
check_hash: bool = True,
):
"""Initialize the file storage system."""
metadata_storage = metadata_storage or InMemoryStorage()
self.storage_backends = storage_backends
self.metadata_storage = metadata_storage
self.check_hash = check_hash
self._save_chunk_size = min(
backend.save_chunk_size for backend in storage_backends.values()
)
def _calculate_file_hash(self, file_data: BinaryIO) -> str:
"""Calculate the MD5 hash of the file data."""
if not self.check_hash:
return "-1"
hasher = hashlib.md5()
file_data.seek(0)
while chunk := file_data.read(self._save_chunk_size):
hasher.update(chunk)
file_data.seek(0)
return hasher.hexdigest()
@trace("file_storage_system.save_file")
def save_file(
self,
bucket: str,
file_name: str,
file_data: BinaryIO,
storage_type: str,
custom_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save the file data to the storage backend."""
file_id = str(uuid.uuid4())
backend = self.storage_backends.get(storage_type)
if not backend:
raise ValueError(f"Unsupported storage type: {storage_type}")
with root_tracer.start_span(
"file_storage_system.save_file.backend_save",
metadata={
"bucket": bucket,
"file_id": file_id,
"file_name": file_name,
"storage_type": storage_type,
},
):
storage_path = backend.save(bucket, file_id, file_data)
file_data.seek(0, 2) # Move to the end of the file
file_size = file_data.tell() # Get the file size
file_data.seek(0) # Reset file pointer
with root_tracer.start_span(
"file_storage_system.save_file.calculate_hash",
):
file_hash = self._calculate_file_hash(file_data)
uri = FileStorageURI(
storage_type, bucket, file_id, custom_params=custom_metadata
)
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name=file_name,
file_size=file_size,
storage_type=storage_type,
storage_path=storage_path,
uri=str(uri),
custom_metadata=custom_metadata or {},
file_hash=file_hash,
)
self.metadata_storage.save(metadata)
return str(uri)
@trace("file_storage_system.get_file")
def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
"""Get the file data from the storage backend."""
parsed_uri = FileStorageURI.parse(uri)
metadata = self.metadata_storage.load(
FileMetadataIdentifier(
file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
),
FileMetadata,
)
if not metadata:
raise FileNotFoundError(f"No metadata found for URI: {uri}")
backend = self.storage_backends.get(metadata.storage_type)
if not backend:
raise ValueError(f"Unsupported storage type: {metadata.storage_type}")
with root_tracer.start_span(
"file_storage_system.get_file.backend_load",
metadata={
"bucket": metadata.bucket,
"file_id": metadata.file_id,
"file_name": metadata.file_name,
"storage_type": metadata.storage_type,
},
):
file_data = backend.load(metadata)
with root_tracer.start_span(
"file_storage_system.get_file.verify_hash",
):
calculated_hash = self._calculate_file_hash(file_data)
if calculated_hash != "-1" and calculated_hash != metadata.file_hash:
raise ValueError("File integrity check failed. Hash mismatch.")
return file_data, metadata
def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]:
"""Get the file metadata.
Args:
bucket (str): The bucket name
file_id (str): The file ID
Returns:
Optional[FileMetadata]: The file metadata
"""
fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket)
return self.metadata_storage.load(fid, FileMetadata)
def delete_file(self, uri: str) -> bool:
"""Delete the file data from the storage backend.
Args:
uri (str): The file URI
Returns:
bool: True if the file was deleted, False otherwise
"""
parsed_uri = FileStorageURI.parse(uri)
fid = FileMetadataIdentifier(
file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
)
metadata = self.metadata_storage.load(fid, FileMetadata)
if not metadata:
return False
backend = self.storage_backends.get(metadata.storage_type)
if not backend:
raise ValueError(f"Unsupported storage type: {metadata.storage_type}")
if backend.delete(metadata):
try:
self.metadata_storage.delete(fid)
return True
except Exception:
# If the metadata deletion fails, return False
return False
return False
def list_files(
self, bucket: str, filters: Optional[Dict[str, Any]] = None
) -> List[FileMetadata]:
"""List the files in the bucket."""
filters = filters or {}
filters["bucket"] = bucket
return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata)
class FileStorageClient(BaseComponent):
"""File storage client component."""
name = ComponentType.FILE_STORAGE_CLIENT.value
def __init__(
self,
system_app: Optional[SystemApp] = None,
storage_system: Optional[FileStorageSystem] = None,
):
"""Initialize the file storage client."""
super().__init__(system_app=system_app)
if not storage_system:
from pathlib import Path
base_path = Path.home() / ".cache" / "dbgpt" / "files"
storage_system = FileStorageSystem(
{
LocalFileStorage.storage_type: LocalFileStorage(
base_path=str(base_path)
)
}
)
self.system_app = system_app
self._storage_system = storage_system
def init_app(self, system_app: SystemApp):
"""Initialize the application."""
self.system_app = system_app
@property
def storage_system(self) -> FileStorageSystem:
"""Get the file storage system."""
if not self._storage_system:
raise ValueError("File storage system not initialized")
return self._storage_system
def upload_file(
self,
bucket: str,
file_path: str,
storage_type: str,
custom_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Upload a file to the storage system.
Args:
bucket (str): The bucket name
file_path (str): The file path
storage_type (str): The storage type
custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
None.
Returns:
str: The file URI
"""
with open(file_path, "rb") as file:
return self.save_file(
bucket, os.path.basename(file_path), file, storage_type, custom_metadata
)
def save_file(
self,
bucket: str,
file_name: str,
file_data: BinaryIO,
storage_type: str,
custom_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save the file data to the storage system.
Args:
bucket (str): The bucket name
file_name (str): The file name
file_data (BinaryIO): The file data
storage_type (str): The storage type
custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
None.
Returns:
str: The file URI
"""
return self.storage_system.save_file(
bucket, file_name, file_data, storage_type, custom_metadata
)
def download_file(self, uri: str, destination_path: str) -> None:
"""Download a file from the storage system.
Args:
uri (str): The file URI
destination_path (str): The destination path on the local filesystem
Raises:
FileNotFoundError: If the file is not found
"""
file_data, _ = self.storage_system.get_file(uri)
with open(destination_path, "wb") as f:
f.write(file_data.read())
def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
"""Get the file data from the storage system.
Args:
uri (str): The file URI
Returns:
Tuple[BinaryIO, FileMetadata]: The file data and metadata
"""
return self.storage_system.get_file(uri)
def get_file_by_id(
self, bucket: str, file_id: str
) -> Tuple[BinaryIO, FileMetadata]:
"""Get the file data from the storage system by ID.
Args:
bucket (str): The bucket name
file_id (str): The file ID
Returns:
Tuple[BinaryIO, FileMetadata]: The file data and metadata
"""
metadata = self.storage_system.get_file_metadata(bucket, file_id)
if not metadata:
raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
return self.get_file(metadata.uri)
def delete_file(self, uri: str) -> bool:
"""Delete the file data from the storage system.
Args:
uri (str): The file URI
Returns:
bool: True if the file was deleted, False otherwise
"""
return self.storage_system.delete_file(uri)
def delete_file_by_id(self, bucket: str, file_id: str) -> bool:
"""Delete the file data from the storage system by ID.
Args:
bucket (str): The bucket name
file_id (str): The file ID
Returns:
bool: True if the file was deleted, False otherwise
"""
metadata = self.storage_system.get_file_metadata(bucket, file_id)
if not metadata:
raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
return self.delete_file(metadata.uri)
def list_files(
self, bucket: str, filters: Optional[Dict[str, Any]] = None
) -> List[FileMetadata]:
"""List the files in the bucket.
Args:
bucket (str): The bucket name
filters (Dict[str, Any], optional): Filters. Defaults to None.
Returns:
List[FileMetadata]: The list of file metadata
"""
return self.storage_system.list_files(bucket, filters)
class SimpleDistributedStorage(StorageBackend):
"""Simple distributed storage backend."""
storage_type: str = "distributed"
def __init__(
self,
node_address: str,
local_storage_path: str,
save_chunk_size: int = 1024 * 1024,
transfer_chunk_size: int = 1024 * 1024,
transfer_timeout: int = 360,
api_prefix: str = "/api/v2/serve/file/files",
):
"""Initialize the simple distributed storage backend."""
self.node_address = node_address
self.local_storage_path = local_storage_path
os.makedirs(self.local_storage_path, exist_ok=True)
self._save_chunk_size = save_chunk_size
self._transfer_chunk_size = transfer_chunk_size
self._transfer_timeout = transfer_timeout
self._api_prefix = api_prefix
@property
def save_chunk_size(self) -> int:
"""Get the save chunk size."""
return self._save_chunk_size
def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str:
node_id = hashlib.md5(node_address.encode()).hexdigest()
return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}")
def _parse_node_address(self, fm: FileMetadata) -> str:
storage_path = fm.storage_path
if not storage_path.startswith("distributed://"):
raise ValueError("Invalid storage path")
return storage_path.split("//")[1].split("/")[0]
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
"""Save the file data to the distributed storage backend.
Just save the file locally.
Args:
bucket (str): The bucket name
file_id (str): The file ID
file_data (BinaryIO): The file data
Returns:
str: The storage path
"""
file_path = self._get_file_path(bucket, file_id, self.node_address)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
while True:
chunk = file_data.read(self.save_chunk_size)
if not chunk:
break
f.write(chunk)
return f"distributed://{self.node_address}/{bucket}/{file_id}"
def load(self, fm: FileMetadata) -> BinaryIO:
"""Load the file data from the distributed storage backend.
If the file is stored on the local node, load it from the local storage.
Args:
fm (FileMetadata): The file metadata
Returns:
BinaryIO: The file data
"""
file_id = fm.file_id
bucket = fm.bucket
node_address = self._parse_node_address(fm)
file_path = self._get_file_path(bucket, file_id, node_address)
# TODO: check if the file is cached in local storage
if node_address == self.node_address:
if os.path.exists(file_path):
return open(file_path, "rb") # noqa: SIM115
else:
raise FileNotFoundError(f"File {file_id} not found on the local node")
else:
response = requests.get(
f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
timeout=self._transfer_timeout,
stream=True,
)
response.raise_for_status()
# TODO: cache the file in local storage
return StreamedBytesIO(
response.iter_content(chunk_size=self._transfer_chunk_size)
)
def delete(self, fm: FileMetadata) -> bool:
"""Delete the file data from the distributed storage backend.
If the file is stored on the local node, delete it from the local storage.
If the file is stored on a remote node, send a delete request to the remote
node.
Args:
fm (FileMetadata): The file metadata
Returns:
bool: True if the file was deleted, False otherwise
"""
file_id = fm.file_id
bucket = fm.bucket
node_address = self._parse_node_address(fm)
file_path = self._get_file_path(bucket, file_id, node_address)
if node_address == self.node_address:
if os.path.exists(file_path):
os.remove(file_path)
return True
return False
else:
try:
response = requests.delete(
f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
timeout=self._transfer_timeout,
)
response.raise_for_status()
return True
except Exception:
return False
class StreamedBytesIO(io.BytesIO):
"""A BytesIO subclass that can be used with streaming responses.
Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13
"""
def __init__(self, request_iterator):
"""Initialize the StreamedBytesIO instance."""
super().__init__()
self._bytes = BytesIO()
self._iterator = request_iterator
def _load_all(self):
self._bytes.seek(0, io.SEEK_END)
for chunk in self._iterator:
self._bytes.write(chunk)
def _load_until(self, goal_position):
current_position = self._bytes.seek(0, io.SEEK_END)
while current_position < goal_position:
try:
current_position += self._bytes.write(next(self._iterator))
except StopIteration:
break
def tell(self) -> int:
"""Get the current position."""
return self._bytes.tell()
def read(self, size: Optional[int] = None) -> bytes:
"""Read the data from the stream.
Args:
size (Optional[int], optional): The number of bytes to read. Defaults to
None.
Returns:
bytes: The read data
"""
left_off_at = self._bytes.tell()
if size is None:
self._load_all()
else:
goal_position = left_off_at + size
self._load_until(goal_position)
self._bytes.seek(left_off_at)
return self._bytes.read(size)
def seek(self, position: int, whence: int = io.SEEK_SET):
"""Seek to a position in the stream.
Args:
position (int): The position
whence (int, optional): The reference point. Defaults to io.SEEK_SET
Raises:
ValueError: If the reference point is invalid
"""
if whence == io.SEEK_END:
self._load_all()
else:
self._bytes.seek(position, whence)
def __enter__(self):
"""Enter the context manager."""
return self
def __exit__(self, ext_type, value, tb):
"""Exit the context manager."""
self._bytes.close()
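
A minimal usage sketch for the new interface (not part of the diff; file names and paths are illustrative). With no explicit storage system, FileStorageClient falls back to a LocalFileStorage rooted under ~/.cache/dbgpt/files:

from dbgpt.core.interface.file import FileStorageClient

client = FileStorageClient()  # defaults to a local backend under ~/.cache/dbgpt/files
uri = client.upload_file(
    bucket="demo-bucket",
    file_path="/tmp/report.txt",      # hypothetical input file
    storage_type="local",
    custom_metadata={"owner": "demo"},
)
file_data, metadata = client.get_file(uri)   # verifies the MD5 hash when check_hash is enabled
print(metadata.file_name, metadata.file_size)
client.delete_file(uri)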

View File

@@ -0,0 +1,506 @@
import hashlib
import io
import os
from unittest import mock
import pytest
from ..file import (
FileMetadata,
FileMetadataIdentifier,
FileStorageClient,
FileStorageSystem,
InMemoryStorage,
LocalFileStorage,
SimpleDistributedStorage,
)
@pytest.fixture
def temp_test_file_dir(tmpdir):
return str(tmpdir)
@pytest.fixture
def temp_storage_path(tmpdir):
return str(tmpdir)
@pytest.fixture
def local_storage_backend(temp_storage_path):
return LocalFileStorage(temp_storage_path)
@pytest.fixture
def distributed_storage_backend(temp_storage_path):
node_address = "127.0.0.1:8000"
return SimpleDistributedStorage(node_address, temp_storage_path)
@pytest.fixture
def file_storage_system(local_storage_backend):
backends = {"local": local_storage_backend}
metadata_storage = InMemoryStorage()
return FileStorageSystem(backends, metadata_storage)
@pytest.fixture
def file_storage_client(file_storage_system):
return FileStorageClient(storage_system=file_storage_system)
@pytest.fixture
def sample_file_path(temp_test_file_dir):
file_path = os.path.join(temp_test_file_dir, "sample.txt")
with open(file_path, "wb") as f:
f.write(b"Sample file content")
return file_path
@pytest.fixture
def sample_file_data():
return io.BytesIO(b"Sample file content for distributed storage")
def test_save_file(file_storage_client, sample_file_path):
bucket = "test-bucket"
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
assert uri.startswith("dbgpt-fs://local/test-bucket/")
assert os.path.exists(sample_file_path)
def test_get_file(file_storage_client, sample_file_path):
bucket = "test-bucket"
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
file_data, metadata = file_storage_client.storage_system.get_file(uri)
assert file_data.read() == b"Sample file content"
assert metadata.file_name == "sample.txt"
assert metadata.bucket == bucket
def test_delete_file(file_storage_client, sample_file_path):
bucket = "test-bucket"
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
assert len(file_storage_client.list_files(bucket=bucket)) == 1
result = file_storage_client.delete_file(uri)
assert result is True
assert len(file_storage_client.list_files(bucket=bucket)) == 0
def test_list_files(file_storage_client, sample_file_path):
bucket = "test-bucket"
uri1 = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
files = file_storage_client.list_files(bucket=bucket)
assert len(files) == 1
def test_save_file_unsupported_storage(file_storage_system, sample_file_path):
bucket = "test-bucket"
with pytest.raises(ValueError):
file_storage_system.save_file(
bucket=bucket,
file_name="unsupported.txt",
file_data=io.BytesIO(b"Unsupported storage"),
storage_type="unsupported",
)
def test_get_file_not_found(file_storage_system):
with pytest.raises(FileNotFoundError):
file_storage_system.get_file("dbgpt-fs://local/test-bucket/nonexistent")
def test_delete_file_not_found(file_storage_system):
result = file_storage_system.delete_file("dbgpt-fs://local/test-bucket/nonexistent")
assert result is False
def test_metadata_management(file_storage_system):
bucket = "test-bucket"
file_id = "test_file"
metadata = file_storage_system.metadata_storage.save(
FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=100,
storage_type="local",
storage_path="/path/to/test.txt",
uri="dbgpt-fs://local/test-bucket/test_file",
custom_metadata={"key": "value"},
file_hash="hash",
)
)
loaded_metadata = file_storage_system.metadata_storage.load(
FileMetadataIdentifier(file_id=file_id, bucket=bucket), FileMetadata
)
assert loaded_metadata.file_name == "test.txt"
assert loaded_metadata.custom_metadata["key"] == "value"
assert loaded_metadata.bucket == bucket
def test_concurrent_save_and_delete(file_storage_client, sample_file_path):
bucket = "test-bucket"
# Simulate concurrent file save and delete operations
def save_file():
return file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
def delete_file(uri):
return file_storage_client.delete_file(uri)
uri = save_file()
# Simulate concurrent operations
save_file()
delete_file(uri)
assert len(file_storage_client.list_files(bucket=bucket)) == 1
def test_large_file_handling(file_storage_client, temp_storage_path):
bucket = "test-bucket"
large_file_path = os.path.join(temp_storage_path, "large_sample.bin")
with open(large_file_path, "wb") as f:
f.write(os.urandom(10 * 1024 * 1024)) # 10 MB file
uri = file_storage_client.upload_file(
bucket=bucket,
file_path=large_file_path,
storage_type="local",
custom_metadata={"description": "Large file test"},
)
file_data, metadata = file_storage_client.storage_system.get_file(uri)
assert file_data.read() == open(large_file_path, "rb").read()
assert metadata.file_name == "large_sample.bin"
assert metadata.bucket == bucket
def test_file_hash_verification_success(file_storage_client, sample_file_path):
bucket = "test-bucket"
# Upload the file, then verify that the stored hash matches a recomputed one
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
file_data, metadata = file_storage_client.storage_system.get_file(uri)
file_hash = metadata.file_hash
calculated_hash = file_storage_client.storage_system._calculate_file_hash(file_data)
assert (
file_hash == calculated_hash
), "File hash should match after saving and loading"
def test_file_hash_verification_failure(file_storage_client, sample_file_path):
bucket = "test-bucket"
# Upload the file
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
# Modify the file content manually to simulate file tampering
storage_system = file_storage_client.storage_system
metadata = storage_system.metadata_storage.load(
FileMetadataIdentifier(file_id=uri.split("/")[-1], bucket=bucket), FileMetadata
)
with open(metadata.storage_path, "wb") as f:
f.write(b"Tampered content")
# Get file should raise an exception due to hash mismatch
with pytest.raises(ValueError, match="File integrity check failed. Hash mismatch."):
storage_system.get_file(uri)
def test_file_isolation_across_buckets(file_storage_client, sample_file_path):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload the same file to two different buckets
uri1 = file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
uri2 = file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# Verify both URIs are different and point to different files
assert uri1 != uri2
file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1)
file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
assert file_data1.read() == b"Sample file content"
assert file_data2.read() == b"Sample file content"
assert metadata1.bucket == bucket1
assert metadata2.bucket == bucket2
def test_list_files_in_specific_bucket(file_storage_client, sample_file_path):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload a file to both buckets
file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# List files in bucket1 and bucket2
files_in_bucket1 = file_storage_client.list_files(bucket=bucket1)
files_in_bucket2 = file_storage_client.list_files(bucket=bucket2)
assert len(files_in_bucket1) == 1
assert len(files_in_bucket2) == 1
assert files_in_bucket1[0].bucket == bucket1
assert files_in_bucket2[0].bucket == bucket2
def test_delete_file_in_one_bucket_does_not_affect_other_bucket(
file_storage_client, sample_file_path
):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload the same file to two different buckets
uri1 = file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
uri2 = file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# Delete the file in bucket1
file_storage_client.delete_file(uri1)
# Check that the file in bucket1 is deleted
assert len(file_storage_client.list_files(bucket=bucket1)) == 0
# Check that the file in bucket2 is still there
assert len(file_storage_client.list_files(bucket=bucket2)) == 1
file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
assert file_data2.read() == b"Sample file content"
def test_file_hash_verification_in_different_buckets(
file_storage_client, sample_file_path
):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload the file to both buckets
uri1 = file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
uri2 = file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1)
file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
# Verify that file hashes are the same for the same content
file_hash1 = file_storage_client.storage_system._calculate_file_hash(file_data1)
file_hash2 = file_storage_client.storage_system._calculate_file_hash(file_data2)
assert file_hash1 == metadata1.file_hash
assert file_hash2 == metadata2.file_hash
assert file_hash1 == file_hash2
def test_file_download_from_different_buckets(
file_storage_client, sample_file_path, temp_storage_path
):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload the file to both buckets
uri1 = file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
uri2 = file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# Download files to different locations
download_path1 = os.path.join(temp_storage_path, "downloaded_bucket1.txt")
download_path2 = os.path.join(temp_storage_path, "downloaded_bucket2.txt")
file_storage_client.download_file(uri1, download_path1)
file_storage_client.download_file(uri2, download_path2)
# Verify contents of downloaded files
assert open(download_path1, "rb").read() == b"Sample file content"
assert open(download_path2, "rb").read() == b"Sample file content"
def test_delete_all_files_in_bucket(file_storage_client, sample_file_path):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload files to both buckets
file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# Delete all files in bucket1
for file in file_storage_client.list_files(bucket=bucket1):
file_storage_client.delete_file(file.uri)
# Verify bucket1 is empty
assert len(file_storage_client.list_files(bucket=bucket1)) == 0
# Verify bucket2 still has files
assert len(file_storage_client.list_files(bucket=bucket2)) == 1
def test_simple_distributed_storage_save_file(
distributed_storage_backend, sample_file_data, temp_storage_path
):
bucket = "test-bucket"
file_id = "test_file"
file_path = distributed_storage_backend.save(bucket, file_id, sample_file_data)
expected_path = os.path.join(
temp_storage_path,
bucket,
f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}",
)
assert file_path == f"distributed://127.0.0.1:8000/{bucket}/{file_id}"
assert os.path.exists(expected_path)
def test_simple_distributed_storage_load_file_local(
distributed_storage_backend, sample_file_data
):
bucket = "test-bucket"
file_id = "test_file"
distributed_storage_backend.save(bucket, file_id, sample_file_data)
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=len(sample_file_data.getvalue()),
storage_type="distributed",
storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
custom_metadata={},
file_hash="hash",
)
file_data = distributed_storage_backend.load(metadata)
assert file_data.read() == b"Sample file content for distributed storage"
@mock.patch("requests.get")
def test_simple_distributed_storage_load_file_remote(
mock_get, distributed_storage_backend, sample_file_data
):
bucket = "test-bucket"
file_id = "test_file"
remote_node_address = "127.0.0.2:8000"
# Mock the response from remote node
mock_response = mock.Mock()
mock_response.iter_content = mock.Mock(
return_value=iter([b"Sample file content for distributed storage"])
)
mock_response.raise_for_status = mock.Mock(return_value=None)
mock_get.return_value = mock_response
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=len(sample_file_data.getvalue()),
storage_type="distributed",
storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}",
uri=f"distributed://{remote_node_address}/{bucket}/{file_id}",
custom_metadata={},
file_hash="hash",
)
file_data = distributed_storage_backend.load(metadata)
assert file_data.read() == b"Sample file content for distributed storage"
mock_get.assert_called_once_with(
f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}",
stream=True,
timeout=360,
)
def test_simple_distributed_storage_delete_file_local(
distributed_storage_backend, sample_file_data, temp_storage_path
):
bucket = "test-bucket"
file_id = "test_file"
distributed_storage_backend.save(bucket, file_id, sample_file_data)
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=len(sample_file_data.getvalue()),
storage_type="distributed",
storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
custom_metadata={},
file_hash="hash",
)
result = distributed_storage_backend.delete(metadata)
file_path = os.path.join(
temp_storage_path,
bucket,
f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}",
)
assert result is True
assert not os.path.exists(file_path)
@mock.patch("requests.delete")
def test_simple_distributed_storage_delete_file_remote(
mock_delete, distributed_storage_backend, sample_file_data
):
bucket = "test-bucket"
file_id = "test_file"
remote_node_address = "127.0.0.2:8000"
mock_response = mock.Mock()
mock_response.raise_for_status = mock.Mock(return_value=None)
mock_delete.return_value = mock_response
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=len(sample_file_data.getvalue()),
storage_type="distributed",
storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}",
uri=f"distributed://{remote_node_address}/{bucket}/{file_id}",
custom_metadata={},
file_hash="hash",
)
result = distributed_storage_backend.delete(metadata)
assert result is True
mock_delete.assert_called_once_with(
f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}",
timeout=360,
)

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve file`

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve file`

View File

@@ -0,0 +1,159 @@
import logging
from functools import cache
from typing import List, Optional
from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from starlette.responses import StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result, blocking_func_to_async
from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ServeRequest, ServerResponse, UploadFileResponse
router = APIRouter()
logger = logging.getLogger(__name__)
# Add your API endpoints here
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
"""Get the service instance"""
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
You can pass the token in your request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key}
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post(
"/files/{bucket}",
response_model=Result[List[UploadFileResponse]],
dependencies=[Depends(check_api_key)],
)
async def upload_files(
bucket: str, files: List[UploadFile], service: Service = Depends(get_service)
) -> Result[List[UploadFileResponse]]:
"""Upload files by a list of UploadFile."""
logger.info(f"upload_files: bucket={bucket}, files={files}")
results = await blocking_func_to_async(
global_system_app, service.upload_files, bucket, "distributed", files
)
return Result.succ(results)
@router.get("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)])
async def download_file(
bucket: str, file_id: str, service: Service = Depends(get_service)
):
"""Download a file by file_id."""
logger.info(f"download_file: bucket={bucket}, file_id={file_id}")
file_data, file_metadata = await blocking_func_to_async(
global_system_app, service.download_file, bucket, file_id
)
file_name_encoded = quote(file_metadata.file_name)
def file_iterator(raw_iter):
with raw_iter:
while chunk := raw_iter.read(
service.config.file_server_download_chunk_size
):
yield chunk
response = StreamingResponse(
file_iterator(file_data), media_type="application/octet-stream"
)
response.headers[
"Content-Disposition"
] = f"attachment; filename={file_name_encoded}"
return response
@router.delete("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)])
async def delete_file(
bucket: str, file_id: str, service: Service = Depends(get_service)
):
"""Delete a file by file_id."""
await blocking_func_to_async(
global_system_app, service.delete_file, bucket, file_id
)
return Result.succ(None)
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
system_app.register(Service)
global_system_app = system_app
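
A client-side sketch for the endpoints above (not part of the diff; host, port, bucket and file names are illustrative, and the exact JSON envelope of Result is assumed). The serve mounts this router under /api/v2/serve/file, so the full upload path is /api/v2/serve/file/files/{bucket}:

import requests

base = "http://127.0.0.1:5670/api/v2/serve/file"   # hypothetical webserver address
headers = {"Authorization": "Bearer your_api_key"}  # only needed when api_keys is configured

# Upload a file to a bucket (form field name "files" matches the endpoint signature)
with open("report.txt", "rb") as f:
    resp = requests.post(f"{base}/files/demo-bucket", files=[("files", f)], headers=headers)
payload = resp.json()              # Result envelope; the "data" key is an assumption here
file_id = payload["data"][0]["file_id"]

# Stream the file back and save it locally
resp = requests.get(f"{base}/files/demo-bucket/{file_id}", headers=headers, stream=True)
with open("report_copy.txt", "wb") as out:
    for chunk in resp.iter_content(chunk_size=1024 * 1024):
        out.write(chunk)

# Delete it
requests.delete(f"{base}/files/demo-bucket/{file_id}", headers=headers)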

View File

@@ -0,0 +1,43 @@
# Define your Pydantic schemas here
from typing import Any, Dict
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
"""File request model"""
# TODO define your own fields here
model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}")
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary"""
return model_to_dict(self, **kwargs)
class ServerResponse(BaseModel):
"""File response model"""
# TODO define your own fields here
model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary"""
return model_to_dict(self, **kwargs)
class UploadFileResponse(BaseModel):
"""Upload file response model"""
file_name: str = Field(..., title="The name of the uploaded file")
file_id: str = Field(..., title="The ID of the uploaded file")
bucket: str = Field(..., title="The bucket of the uploaded file")
uri: str = Field(..., title="The URI of the uploaded file")
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary"""
return model_to_dict(self, **kwargs)

View File

@@ -0,0 +1,68 @@
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "file"
SERVE_APP_NAME = "dbgpt_serve_file"
SERVE_APP_NAME_HUMP = "dbgpt_serve_File"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.file."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpt_serve_file"
@dataclass
class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)
check_hash: Optional[bool] = field(
default=True, metadata={"help": "Check the hash of the file when downloading"}
)
file_server_host: Optional[str] = field(
default=None, metadata={"help": "The host of the file server"}
)
file_server_port: Optional[int] = field(
default=5670, metadata={"help": "The port of the file server"}
)
file_server_download_chunk_size: Optional[int] = field(
default=1024 * 1024,
metadata={"help": "The chunk size when downloading the file"},
)
file_server_save_chunk_size: Optional[int] = field(
default=1024 * 1024, metadata={"help": "The chunk size when saving the file"}
)
file_server_transfer_chunk_size: Optional[int] = field(
default=1024 * 1024,
metadata={"help": "The chunk size when transferring the file"},
)
file_server_transfer_timeout: Optional[int] = field(
default=360, metadata={"help": "The timeout when transferring the file"}
)
local_storage_path: Optional[str] = field(
default=None, metadata={"help": "The local storage path"}
)
def get_node_address(self) -> str:
"""Get the node address"""
file_server_host = self.file_server_host
if not file_server_host:
from dbgpt.util.net_utils import _get_ip_address
file_server_host = _get_ip_address()
file_server_port = self.file_server_port or 5670
return f"{file_server_host}:{file_server_port}"
def get_local_storage_path(self) -> str:
"""Get the local storage path"""
local_storage_path = self.local_storage_path
if not local_storage_path:
from pathlib import Path
base_path = Path.home() / ".cache" / "dbgpt" / "files"
local_storage_path = str(base_path)
return local_storage_path
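
A short sketch of how the prefixed keys set in register_serve_apps end up in this ServeConfig (not part of the diff; the standalone SystemApp() construction and the storage path are illustrative):

from dbgpt.component import SystemApp
from dbgpt.serve.file.config import SERVE_CONFIG_KEY_PREFIX, ServeConfig

system_app = SystemApp()  # illustrative; normally created by the webserver
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}local_storage_path", "/data/dbgpt/file_server")
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}file_server_port", 5670)

config = ServeConfig.from_app_config(system_app.config, SERVE_CONFIG_KEY_PREFIX)
print(config.get_node_address())        # "<detected-ip>:5670" when file_server_host is unset
print(config.get_local_storage_path())  # "/data/dbgpt/file_server"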

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve file`

View File

@@ -0,0 +1,66 @@
import json
from typing import Type
from sqlalchemy.orm import Session
from dbgpt.core.interface.file import FileMetadata, FileMetadataIdentifier
from dbgpt.core.interface.storage import StorageItemAdapter
from .models import ServeEntity
class FileMetadataAdapter(StorageItemAdapter[FileMetadata, ServeEntity]):
"""File metadata adapter.
Convert between storage format and database model.
"""
def to_storage_format(self, item: FileMetadata) -> ServeEntity:
"""Convert to storage format."""
custom_metadata = (
json.dumps(item.custom_metadata, ensure_ascii=False)
if item.custom_metadata
else None
)
return ServeEntity(
bucket=item.bucket,
file_id=item.file_id,
file_name=item.file_name,
file_size=item.file_size,
storage_type=item.storage_type,
storage_path=item.storage_path,
uri=item.uri,
custom_metadata=custom_metadata,
file_hash=item.file_hash,
)
def from_storage_format(self, model: ServeEntity) -> FileMetadata:
"""Convert from storage format."""
custom_metadata = (
json.loads(model.custom_metadata) if model.custom_metadata else None
)
return FileMetadata(
bucket=model.bucket,
file_id=model.file_id,
file_name=model.file_name,
file_size=model.file_size,
storage_type=model.storage_type,
storage_path=model.storage_path,
uri=model.uri,
custom_metadata=custom_metadata,
file_hash=model.file_hash,
)
def get_query_for_identifier(
self,
storage_format: Type[ServeEntity],
resource_id: FileMetadataIdentifier,
**kwargs,
):
"""Get query for identifier."""
session: Session = kwargs.get("session")
if session is None:
raise Exception("session is None")
return session.query(storage_format).filter(
storage_format.file_id == resource_id.file_id
)
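
A round-trip sketch for the adapter above (not part of the diff; field values are illustrative). custom_metadata is stored as JSON text on the entity and parsed back into a dict when loading:

from dbgpt.core.interface.file import FileMetadata
from dbgpt.serve.file.models.file_adapter import FileMetadataAdapter

adapter = FileMetadataAdapter()
meta = FileMetadata(
    file_id="abc", bucket="demo-bucket", file_name="report.txt", file_size=19,
    storage_type="local", storage_path="/data/demo-bucket/abc",
    uri="dbgpt-fs://local/demo-bucket/abc",
    custom_metadata={"owner": "demo"}, file_hash="d41d8cd98f00b204e9800998ecf8427e",
)
entity = adapter.to_storage_format(meta)        # custom_metadata -> JSON string
restored = adapter.from_storage_format(entity)  # JSON string -> dict
assert restored.custom_metadata == {"owner": "demo"}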

View File

@@ -0,0 +1,87 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text
from dbgpt.storage.metadata import BaseDao, Model, db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
class ServeEntity(Model):
__tablename__ = SERVER_APP_TABLE_NAME
id = Column(Integer, primary_key=True, comment="Auto increment id")
bucket = Column(String(255), nullable=False, comment="Bucket name")
file_id = Column(String(255), nullable=False, comment="File id")
file_name = Column(String(256), nullable=False, comment="File name")
file_size = Column(Integer, nullable=True, comment="File size")
storage_type = Column(String(32), nullable=False, comment="Storage type")
storage_path = Column(String(512), nullable=False, comment="Storage path")
uri = Column(String(512), nullable=False, comment="File URI")
custom_metadata = Column(
Text, nullable=True, comment="Custom metadata, JSON format"
)
file_hash = Column(String(128), nullable=True, comment="File hash")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
def __repr__(self):
return (
f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', "
f"gmt_modified='{self.gmt_modified}')"
)
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for File"""
def __init__(self, serve_config: ServeConfig):
super().__init__()
self._serve_config = serve_config
def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
"""Convert the request to an entity
Args:
request (Union[ServeRequest, Dict[str, Any]]): The request
Returns:
T: The entity
"""
request_dict = (
request.to_dict() if isinstance(request, ServeRequest) else request
)
entity = ServeEntity(**request_dict)
# TODO implement your own logic here, transfer the request_dict to an entity
return entity
def to_request(self, entity: ServeEntity) -> ServeRequest:
"""Convert the entity to a request
Args:
entity (T): The entity
Returns:
REQ: The request
"""
# TODO implement your own logic here, transfer the entity to a request
return ServeRequest()
def to_response(self, entity: ServeEntity) -> ServerResponse:
"""Convert the entity to a response
Args:
entity (T): The entity
Returns:
RES: The response
"""
# TODO implement your own logic here, transfer the entity to a response
return ServerResponse()

dbgpt/serve/file/serve.py (new file, 113 lines)
View File

@@ -0,0 +1,113 @@
import logging
from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.core.interface.file import FileStorageClient
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
logger = logging.getLogger(__name__)
class Serve(BaseServe):
"""Serve component for DB-GPT"""
name = SERVE_APP_NAME
def __init__(
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v2/serve/{APP_NAME}",
api_tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if api_tags is None:
api_tags = [SERVE_APP_NAME_HUMP]
super().__init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
self._file_storage_client: Optional[FileStorageClient] = None
self._serve_config: Optional[ServeConfig] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
def on_init(self):
"""Called when init the application.
You can do some initialization here. You can't get other components here because they may not be initialized yet
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity
def before_start(self):
"""Called before the start of the application."""
from dbgpt.core.interface.file import (
FileStorageSystem,
SimpleDistributedStorage,
)
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from .models.file_adapter import FileMetadataAdapter
from .models.models import ServeEntity
self._serve_config = ServeConfig.from_app_config(
self._system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._db_manager = self.create_or_get_db_manager()
serializer = JsonSerializer()
storage = SQLAlchemyStorage(
self._db_manager,
ServeEntity,
FileMetadataAdapter(),
serializer,
)
simple_distributed_storage = SimpleDistributedStorage(
node_address=self._serve_config.get_node_address(),
local_storage_path=self._serve_config.get_local_storage_path(),
save_chunk_size=self._serve_config.file_server_save_chunk_size,
transfer_chunk_size=self._serve_config.file_server_transfer_chunk_size,
transfer_timeout=self._serve_config.file_server_transfer_timeout,
)
storage_backends = {
simple_distributed_storage.storage_type: simple_distributed_storage,
}
fs = FileStorageSystem(
storage_backends,
metadata_storage=storage,
check_hash=self._serve_config.check_hash,
)
self._file_storage_client = FileStorageClient(
system_app=self._system_app, storage_system=fs
)
@property
def file_storage_client(self) -> FileStorageClient:
"""Returns the file storage client."""
if not self._file_storage_client:
raise ValueError("File storage client is not initialized")
return self._file_storage_client
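
In before_start, the serve wires up a SimpleDistributedStorage whose storage_path encodes the owning node (distributed://<host>:<port>/<bucket>/<file_id>). A small sketch of how that path drives local versus remote loads (not part of the diff; addresses and paths are illustrative):

# Illustrative only: the node encoded in storage_path decides whether load()
# opens the local copy or streams the file from the owning node over HTTP.
from dbgpt.core.interface.file import FileMetadata, SimpleDistributedStorage

storage = SimpleDistributedStorage(
    node_address="127.0.0.1:5670",
    local_storage_path="/tmp/dbgpt-files",  # hypothetical path
)
meta = FileMetadata(
    file_id="abc", bucket="demo", file_name="report.txt", file_size=19,
    storage_type="distributed",
    storage_path="distributed://127.0.0.1:5670/demo/abc",  # same node -> local read
    uri="dbgpt-fs://distributed/demo/abc", custom_metadata={}, file_hash="-1",
)
# storage.load(meta) opens the local file; if storage_path pointed at another
# node, load() would GET http://<node>/api/v2/serve/file/files/demo/abc instead.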

View File

@@ -0,0 +1,106 @@
import logging
from typing import BinaryIO, List, Optional, Tuple
from fastapi import UploadFile
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.core.interface.file import FileMetadata, FileStorageClient, FileStorageURI
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.tracer import root_tracer, trace
from ..api.schemas import ServeRequest, ServerResponse, UploadFileResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
logger = logging.getLogger(__name__)
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for File"""
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
        self._serve_config: Optional[ServeConfig] = None
        self._dao: Optional[ServeDao] = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
super().init_app(system_app)
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
"""Returns the internal DAO."""
return self._dao
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
@property
def file_storage_client(self) -> FileStorageClient:
"""Returns the internal FileStorageClient.
Returns:
FileStorageClient: The internal FileStorageClient
"""
file_storage_client = FileStorageClient.get_instance(
self._system_app, default_component=None
)
if file_storage_client:
return file_storage_client
else:
from ..serve import Serve
file_storage_client = Serve.get_instance(
self._system_app
).file_storage_client
self._system_app.register_instance(file_storage_client)
return file_storage_client
@trace("upload_files")
def upload_files(
self, bucket: str, storage_type: str, files: List[UploadFile]
) -> List[UploadFileResponse]:
"""Upload files by a list of UploadFile."""
results = []
for file in files:
file_name = file.filename
logger.info(f"Uploading file {file_name} to bucket {bucket}")
uri = self.file_storage_client.save_file(
bucket, file_name, file_data=file.file, storage_type=storage_type
)
parsed_uri = FileStorageURI.parse(uri)
logger.info(f"Uploaded file {file_name} to bucket {bucket}, uri={uri}")
results.append(
UploadFileResponse(
file_name=file_name,
file_id=parsed_uri.file_id,
bucket=bucket,
uri=uri,
)
)
return results
@trace("download_file")
def download_file(self, bucket: str, file_id: str) -> Tuple[BinaryIO, FileMetadata]:
"""Download a file by file_id."""
return self.file_storage_client.get_file_by_id(bucket, file_id)
def delete_file(self, bucket: str, file_id: str) -> None:
"""Delete a file by file_id."""
self.file_storage_client.delete_file_by_id(bucket, file_id)
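A hedged sketch of driving this service directly from elsewhere in the application (not part of this change): the bucket name, storage_type value and module path are assumptions, while the upload_files and download_file signatures follow the methods above.

import io

from fastapi import UploadFile

from dbgpt.component import SystemApp
from dbgpt.serve.file.service.service import Service  # assumed module path


def upload_and_read_back(system_app: SystemApp) -> bytes:
    # The service is registered as a component, so it can be resolved from the app.
    service = Service.get_instance(system_app)
    upload = UploadFile(filename="report.txt", file=io.BytesIO(b"demo content"))
    # "demo-bucket" and "distributed" are hypothetical values for illustration.
    results = service.upload_files("demo-bucket", "distributed", [upload])
    file_data, _metadata = service.download_file("demo-bucket", results[0].file_id)
    return file_data.read()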

@@ -0,0 +1,124 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_auth(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
# TODO: add your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
# TODO: implement your test case
pass
# Add more test cases according to your own logic
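One of the TODO cases above could be filled in roughly as follows. This is only a sketch: the /files/demo-bucket path and the multipart field name "files" are hypothetical and must be aligned with the routes actually defined in api/endpoints.py.

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_upload_sketch(client: AsyncClient):
    # Hypothetical route and form field name; adjust to the real router definition.
    files = {"files": ("hello.txt", b"hello world", "text/plain")}
    response = await client.post("/files/demo-bucket", files=files)
    assert response.status_code == 200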

@@ -0,0 +1,99 @@
import pytest
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import ServeConfig
from ..models.models import ServeDao, ServeEntity
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def server_config():
    # TODO: build your server config
return ServeConfig()
@pytest.fixture
def dao(server_config):
return ServeDao(server_config)
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
def test_table_exist():
assert ServeEntity.__tablename__ in db.metadata.tables
def test_entity_create(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_unique_key(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_get(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_update(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_delete(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_all():
# TODO: implement your test case
pass
def test_dao_create(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_get_one(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_get_dao_get_list(dao):
# TODO: implement your test case
pass
def test_dao_update(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_delete(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_get_list_page(dao):
# TODO: implement your test case
pass
# Add more test cases according to your own logic
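A possible shape for the DAO create case above, assuming default_entity_dict is filled with valid file-metadata fields (the column set lives in models/models.py, which is not shown in this hunk) and that the stock BaseDao.create helper from the serve template is used.

def test_dao_create_sketch(dao, default_entity_dict):
    # Requires default_entity_dict to carry real ServeRequest fields.
    res: ServerResponse = dao.create(ServeRequest(**default_entity_dict))
    assert res is not None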

@@ -0,0 +1,78 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
# TODO: implement your test case
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
# ...
pass
def test_service_update(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_delete(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get_list(service: Service):
# TODO: implement your test case
pass
def test_service_get_list_by_page(service: Service):
# TODO: implement your test case
pass
# Add more test cases according to your own logic
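Following the hint in test_service_create above, a filled-in version might look like this; it assumes default_entity_dict provides valid ServeRequest fields and that the template's create path is kept for this serve app.

def test_service_create_sketch(service: Service, default_entity_dict):
    # Relies on default_entity_dict providing valid ServeRequest fields.
    entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
    assert entity is not None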