From b6d54ed8ab3d69a7ce4f70894e30acdc8c9c5a1d Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 19 Aug 2024 00:19:53 +0800 Subject: [PATCH] feat(core): Add file server for DB-GPT --- .env.template | 5 + .mypy.ini | 3 + dbgpt/_private/config.py | 7 + dbgpt/app/component_configs.py | 2 +- .../initialization/db_model_initialization.py | 2 + .../initialization/serve_initialization.py | 36 +- dbgpt/component.py | 1 + dbgpt/configs/model_config.py | 1 + dbgpt/core/interface/file.py | 791 ++++++++++++++++++ dbgpt/core/interface/tests/test_file.py | 506 +++++++++++ dbgpt/serve/file/__init__.py | 2 + dbgpt/serve/file/api/__init__.py | 2 + dbgpt/serve/file/api/endpoints.py | 159 ++++ dbgpt/serve/file/api/schemas.py | 43 + dbgpt/serve/file/config.py | 68 ++ dbgpt/serve/file/dependencies.py | 1 + dbgpt/serve/file/models/__init__.py | 2 + dbgpt/serve/file/models/file_adapter.py | 66 ++ dbgpt/serve/file/models/models.py | 87 ++ dbgpt/serve/file/serve.py | 113 +++ dbgpt/serve/file/service/__init__.py | 0 dbgpt/serve/file/service/service.py | 106 +++ dbgpt/serve/file/tests/__init__.py | 0 dbgpt/serve/file/tests/test_endpoints.py | 124 +++ dbgpt/serve/file/tests/test_models.py | 99 +++ dbgpt/serve/file/tests/test_service.py | 78 ++ 26 files changed, 2301 insertions(+), 3 deletions(-) create mode 100644 dbgpt/core/interface/file.py create mode 100644 dbgpt/core/interface/tests/test_file.py create mode 100644 dbgpt/serve/file/__init__.py create mode 100644 dbgpt/serve/file/api/__init__.py create mode 100644 dbgpt/serve/file/api/endpoints.py create mode 100644 dbgpt/serve/file/api/schemas.py create mode 100644 dbgpt/serve/file/config.py create mode 100644 dbgpt/serve/file/dependencies.py create mode 100644 dbgpt/serve/file/models/__init__.py create mode 100644 dbgpt/serve/file/models/file_adapter.py create mode 100644 dbgpt/serve/file/models/models.py create mode 100644 dbgpt/serve/file/serve.py create mode 100644 dbgpt/serve/file/service/__init__.py create mode 100644 dbgpt/serve/file/service/service.py create mode 100644 dbgpt/serve/file/tests/__init__.py create mode 100644 dbgpt/serve/file/tests/test_endpoints.py create mode 100644 dbgpt/serve/file/tests/test_models.py create mode 100644 dbgpt/serve/file/tests/test_service.py diff --git a/.env.template b/.env.template index 44aa2d710..2a281e698 100644 --- a/.env.template +++ b/.env.template @@ -277,6 +277,11 @@ DBGPT_LOG_LEVEL=INFO # ENCRYPT KEY - The key used to encrypt and decrypt the data # ENCRYPT_KEY=your_secret_key +#*******************************************************************# +#** File Server **# +#*******************************************************************# +## The local storage path of the file server, the default is pilot/data/file_server +# FILE_SERVER_LOCAL_STORAGE_PATH = #*******************************************************************# #** Application Config **# diff --git a/.mypy.ini b/.mypy.ini index e2c2bc3ab..52ae00c35 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -115,3 +115,6 @@ ignore_missing_imports = True [mypy-networkx.*] ignore_missing_imports = True + +[mypy-pypdf.*] +ignore_missing_imports = True diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 2dbfac0f0..18e972a4c 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -316,6 +316,13 @@ class Config(metaclass=Singleton): # experimental financial report model configuration self.FIN_REPORT_MODEL = os.getenv("FIN_REPORT_MODEL", None) + # file server configuration + # The host of the current file server, if None, get the 
host automatically + self.FILE_SERVER_HOST = os.getenv("FILE_SERVER_HOST") + self.FILE_SERVER_LOCAL_STORAGE_PATH = os.getenv( + "FILE_SERVER_LOCAL_STORAGE_PATH" + ) + @property def local_db_manager(self) -> "ConnectorManager": from dbgpt.datasource.manages import ConnectorManager diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index 3ef08d4bc..29c9e59be 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -59,7 +59,7 @@ def initialize_components( _initialize_agent(system_app) _initialize_openapi(system_app) # Register serve apps - register_serve_apps(system_app, CFG) + register_serve_apps(system_app, CFG, param.port) def _initialize_model_cache(system_app: SystemApp): diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py index b8808c400..969340c44 100644 --- a/dbgpt/app/initialization/db_model_initialization.py +++ b/dbgpt/app/initialization/db_model_initialization.py @@ -8,6 +8,7 @@ from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity from dbgpt.model.cluster.registry_impl.db_storage import ModelInstanceEntity from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity +from dbgpt.serve.file.models.models import ServeEntity as FileServeEntity from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity @@ -19,6 +20,7 @@ from dbgpt.storage.chat_history.chat_history_db import ( _MODELS = [ PluginHubEntity, + FileServeEntity, MyPluginEntity, PromptManageEntity, KnowledgeSpaceEntity, diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index f0b9c9e42..7838644e0 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -2,7 +2,7 @@ from dbgpt._private.config import Config from dbgpt.component import SystemApp -def register_serve_apps(system_app: SystemApp, cfg: Config): +def register_serve_apps(system_app: SystemApp, cfg: Config, webserver_port: int): """Register serve apps""" system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE) if cfg.API_KEYS: @@ -47,6 +47,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(FlowServe) + # ################################ AWEL Flow Serve Register End ######################################## + # ################################ Rag Serve Register Begin ###################################### from dbgpt.serve.rag.serve import ( @@ -57,6 +59,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(RagServe) + # ################################ Rag Serve Register End ######################################## + # ################################ Datasource Serve Register Begin ###################################### from dbgpt.serve.datasource.serve import ( @@ -66,4 +70,32 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(DatasourceServe) - # ################################ AWEL Flow Serve Register End ######################################## + + # ################################ Datasource Serve Register End ######################################## + + # ################################ File Serve 
Register Begin ###################################### + + from dbgpt.configs.model_config import FILE_SERVER_LOCAL_STORAGE_PATH + from dbgpt.serve.file.serve import ( + SERVE_CONFIG_KEY_PREFIX as FILE_SERVE_CONFIG_KEY_PREFIX, + ) + from dbgpt.serve.file.serve import Serve as FileServe + + local_storage_path = ( + cfg.FILE_SERVER_LOCAL_STORAGE_PATH or FILE_SERVER_LOCAL_STORAGE_PATH + ) + # Set config + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}local_storage_path", local_storage_path + ) + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_port", webserver_port + ) + if cfg.FILE_SERVER_HOST: + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_host", cfg.FILE_SERVER_HOST + ) + # Register serve app + system_app.register(FileServe) + + # ################################ File Serve Register End ######################################## diff --git a/dbgpt/component.py b/dbgpt/component.py index cb88a61ec..da3c5e753 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -90,6 +90,7 @@ class ComponentType(str, Enum): AGENT_MANAGER = "dbgpt_agent_manager" RESOURCE_MANAGER = "dbgpt_resource_manager" VARIABLES_PROVIDER = "dbgpt_variables_provider" + FILE_STORAGE_CLIENT = "dbgpt_file_storage_client" _EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT" diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index 4d02a2730..e4abac3e7 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -14,6 +14,7 @@ DATASETS_DIR = os.path.join(PILOT_PATH, "datasets") DATA_DIR = os.path.join(PILOT_PATH, "data") PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache") +FILE_SERVER_LOCAL_STORAGE_PATH = os.path.join(DATA_DIR, "file_server") _DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel") # Global language setting LOCALES_DIR = os.path.join(ROOT_PATH, "i18n/locales") diff --git a/dbgpt/core/interface/file.py b/dbgpt/core/interface/file.py new file mode 100644 index 000000000..5bd6cf842 --- /dev/null +++ b/dbgpt/core/interface/file.py @@ -0,0 +1,791 @@ +"""File storage interface.""" + +import dataclasses +import hashlib +import io +import os +import uuid +from abc import ABC, abstractmethod +from io import BytesIO +from typing import Any, BinaryIO, Dict, List, Optional, Tuple +from urllib.parse import parse_qs, urlencode, urlparse + +import requests + +from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt.util.tracer import root_tracer, trace + +from .storage import ( + InMemoryStorage, + QuerySpec, + ResourceIdentifier, + StorageError, + StorageInterface, + StorageItem, +) + +_SCHEMA = "dbgpt-fs" + + +@dataclasses.dataclass +class FileMetadataIdentifier(ResourceIdentifier): + """File metadata identifier.""" + + file_id: str + bucket: str + + def to_dict(self) -> Dict: + """Convert the identifier to a dictionary.""" + return {"file_id": self.file_id, "bucket": self.bucket} + + @property + def str_identifier(self) -> str: + """Get the string identifier. 
+ + Returns: + str: The string identifier + """ + return f"{self.bucket}/{self.file_id}" + + +@dataclasses.dataclass +class FileMetadata(StorageItem): + """File metadata for storage.""" + + file_id: str + bucket: str + file_name: str + file_size: int + storage_type: str + storage_path: str + uri: str + custom_metadata: Dict[str, Any] + file_hash: str + _identifier: FileMetadataIdentifier = dataclasses.field(init=False) + + def __post_init__(self): + """Post init method.""" + self._identifier = FileMetadataIdentifier( + file_id=self.file_id, bucket=self.bucket + ) + + @property + def identifier(self) -> ResourceIdentifier: + """Get the resource identifier.""" + return self._identifier + + def merge(self, other: "StorageItem") -> None: + """Merge the metadata with another item.""" + if not isinstance(other, FileMetadata): + raise StorageError("Cannot merge different types of items") + self._from_object(other) + + def to_dict(self) -> Dict: + """Convert the metadata to a dictionary.""" + return { + "file_id": self.file_id, + "bucket": self.bucket, + "file_name": self.file_name, + "file_size": self.file_size, + "storage_type": self.storage_type, + "storage_path": self.storage_path, + "uri": self.uri, + "custom_metadata": self.custom_metadata, + "file_hash": self.file_hash, + } + + def _from_object(self, obj: "FileMetadata") -> None: + self.file_id = obj.file_id + self.bucket = obj.bucket + self.file_name = obj.file_name + self.file_size = obj.file_size + self.storage_type = obj.storage_type + self.storage_path = obj.storage_path + self.uri = obj.uri + self.custom_metadata = obj.custom_metadata + self.file_hash = obj.file_hash + self._identifier = obj._identifier + + +class FileStorageURI: + """File storage URI.""" + + def __init__( + self, + storage_type: str, + bucket: str, + file_id: str, + version: Optional[str] = None, + custom_params: Optional[Dict[str, Any]] = None, + ): + """Initialize the file storage URI.""" + self.scheme = _SCHEMA + self.storage_type = storage_type + self.bucket = bucket + self.file_id = file_id + self.version = version + self.custom_params = custom_params or {} + + @classmethod + def parse(cls, uri: str) -> "FileStorageURI": + """Parse the URI string.""" + parsed = urlparse(uri) + if parsed.scheme != _SCHEMA: + raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'") + path_parts = parsed.path.strip("/").split("/") + if len(path_parts) < 2: + raise ValueError("Invalid URI path. Must contain bucket and file ID") + storage_type = parsed.netloc + bucket = path_parts[0] + file_id = path_parts[1] + version = path_parts[2] if len(path_parts) > 2 else None + custom_params = parse_qs(parsed.query) + return cls(storage_type, bucket, file_id, version, custom_params) + + def __str__(self) -> str: + """Get the string representation of the URI.""" + base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}" + if self.version: + base_uri += f"/{self.version}" + if self.custom_params: + query_string = urlencode(self.custom_params, doseq=True) + base_uri += f"?{query_string}" + return base_uri + + +class StorageBackend(ABC): + """Storage backend interface.""" + + storage_type: str = "__base__" + + @abstractmethod + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the storage backend. 
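As a quick illustration of the URI layout defined by FileStorageURI above, the scheme, storage type, bucket and file ID round-trip through str() and parse(); the concrete values below are placeholders, not output produced by this patch:

.. code-block:: python

    from dbgpt.core.interface.file import FileStorageURI

    uri = FileStorageURI("local", "my-bucket", "a1b2c3", custom_params={"owner": ["alice"]})
    assert str(uri) == "dbgpt-fs://local/my-bucket/a1b2c3?owner=alice"

    parsed = FileStorageURI.parse(str(uri))
    assert (parsed.storage_type, parsed.bucket, parsed.file_id) == ("local", "my-bucket", "a1b2c3")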
+ + Args: + bucket (str): The bucket name + file_id (str): The file ID + file_data (BinaryIO): The file data + + Returns: + str: The storage path + """ + + @abstractmethod + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the storage backend. + + Args: + fm (FileMetadata): The file metadata + + Returns: + BinaryIO: The file data + """ + + @abstractmethod + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the storage backend. + + Args: + fm (FileMetadata): The file metadata + + Returns: + bool: True if the file was deleted, False otherwise + """ + + @property + @abstractmethod + def save_chunk_size(self) -> int: + """Get the save chunk size. + + Returns: + int: The save chunk size + """ + + +class LocalFileStorage(StorageBackend): + """Local file storage backend.""" + + storage_type: str = "local" + + def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024): + """Initialize the local file storage backend.""" + self.base_path = base_path + self._save_chunk_size = save_chunk_size + os.makedirs(self.base_path, exist_ok=True) + + @property + def save_chunk_size(self) -> int: + """Get the save chunk size.""" + return self._save_chunk_size + + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the local storage backend.""" + bucket_path = os.path.join(self.base_path, bucket) + os.makedirs(bucket_path, exist_ok=True) + file_path = os.path.join(bucket_path, file_id) + with open(file_path, "wb") as f: + while True: + chunk = file_data.read(self.save_chunk_size) + if not chunk: + break + f.write(chunk) + return file_path + + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the local storage backend.""" + bucket_path = os.path.join(self.base_path, fm.bucket) + file_path = os.path.join(bucket_path, fm.file_id) + return open(file_path, "rb") # noqa: SIM115 + + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the local storage backend.""" + bucket_path = os.path.join(self.base_path, fm.bucket) + file_path = os.path.join(bucket_path, fm.file_id) + if os.path.exists(file_path): + os.remove(file_path) + return True + return False + + +class FileStorageSystem: + """File storage system.""" + + def __init__( + self, + storage_backends: Dict[str, StorageBackend], + metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None, + check_hash: bool = True, + ): + """Initialize the file storage system.""" + metadata_storage = metadata_storage or InMemoryStorage() + self.storage_backends = storage_backends + self.metadata_storage = metadata_storage + self.check_hash = check_hash + self._save_chunk_size = min( + backend.save_chunk_size for backend in storage_backends.values() + ) + + def _calculate_file_hash(self, file_data: BinaryIO) -> str: + """Calculate the MD5 hash of the file data.""" + if not self.check_hash: + return "-1" + hasher = hashlib.md5() + file_data.seek(0) + while chunk := file_data.read(self._save_chunk_size): + hasher.update(chunk) + file_data.seek(0) + return hasher.hexdigest() + + @trace("file_storage_system.save_file") + def save_file( + self, + bucket: str, + file_name: str, + file_data: BinaryIO, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Save the file data to the storage backend.""" + file_id = str(uuid.uuid4()) + backend = self.storage_backends.get(storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {storage_type}") + + with 
root_tracer.start_span( + "file_storage_system.save_file.backend_save", + metadata={ + "bucket": bucket, + "file_id": file_id, + "file_name": file_name, + "storage_type": storage_type, + }, + ): + storage_path = backend.save(bucket, file_id, file_data) + file_data.seek(0, 2) # Move to the end of the file + file_size = file_data.tell() # Get the file size + file_data.seek(0) # Reset file pointer + + with root_tracer.start_span( + "file_storage_system.save_file.calculate_hash", + ): + file_hash = self._calculate_file_hash(file_data) + uri = FileStorageURI( + storage_type, bucket, file_id, custom_params=custom_metadata + ) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name=file_name, + file_size=file_size, + storage_type=storage_type, + storage_path=storage_path, + uri=str(uri), + custom_metadata=custom_metadata or {}, + file_hash=file_hash, + ) + + self.metadata_storage.save(metadata) + return str(uri) + + @trace("file_storage_system.get_file") + def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage backend.""" + parsed_uri = FileStorageURI.parse(uri) + metadata = self.metadata_storage.load( + FileMetadataIdentifier( + file_id=parsed_uri.file_id, bucket=parsed_uri.bucket + ), + FileMetadata, + ) + if not metadata: + raise FileNotFoundError(f"No metadata found for URI: {uri}") + + backend = self.storage_backends.get(metadata.storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {metadata.storage_type}") + + with root_tracer.start_span( + "file_storage_system.get_file.backend_load", + metadata={ + "bucket": metadata.bucket, + "file_id": metadata.file_id, + "file_name": metadata.file_name, + "storage_type": metadata.storage_type, + }, + ): + file_data = backend.load(metadata) + + with root_tracer.start_span( + "file_storage_system.get_file.verify_hash", + ): + calculated_hash = self._calculate_file_hash(file_data) + if calculated_hash != "-1" and calculated_hash != metadata.file_hash: + raise ValueError("File integrity check failed. Hash mismatch.") + + return file_data, metadata + + def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]: + """Get the file metadata. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + Optional[FileMetadata]: The file metadata + """ + fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket) + return self.metadata_storage.load(fid, FileMetadata) + + def delete_file(self, uri: str) -> bool: + """Delete the file data from the storage backend. 
+ + Args: + uri (str): The file URI + + Returns: + bool: True if the file was deleted, False otherwise + """ + parsed_uri = FileStorageURI.parse(uri) + fid = FileMetadataIdentifier( + file_id=parsed_uri.file_id, bucket=parsed_uri.bucket + ) + metadata = self.metadata_storage.load(fid, FileMetadata) + if not metadata: + return False + + backend = self.storage_backends.get(metadata.storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {metadata.storage_type}") + + if backend.delete(metadata): + try: + self.metadata_storage.delete(fid) + return True + except Exception: + # If the metadata deletion fails, log the error and return False + return False + return False + + def list_files( + self, bucket: str, filters: Optional[Dict[str, Any]] = None + ) -> List[FileMetadata]: + """List the files in the bucket.""" + filters = filters or {} + filters["bucket"] = bucket + return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata) + + +class FileStorageClient(BaseComponent): + """File storage client component.""" + + name = ComponentType.FILE_STORAGE_CLIENT.value + + def __init__( + self, + system_app: Optional[SystemApp] = None, + storage_system: Optional[FileStorageSystem] = None, + ): + """Initialize the file storage client.""" + super().__init__(system_app=system_app) + if not storage_system: + from pathlib import Path + + base_path = Path.home() / ".cache" / "dbgpt" / "files" + storage_system = FileStorageSystem( + { + LocalFileStorage.storage_type: LocalFileStorage( + base_path=str(base_path) + ) + } + ) + + self.system_app = system_app + self._storage_system = storage_system + + def init_app(self, system_app: SystemApp): + """Initialize the application.""" + self.system_app = system_app + + @property + def storage_system(self) -> FileStorageSystem: + """Get the file storage system.""" + if not self._storage_system: + raise ValueError("File storage system not initialized") + return self._storage_system + + def upload_file( + self, + bucket: str, + file_path: str, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Upload a file to the storage system. + + Args: + bucket (str): The bucket name + file_path (str): The file path + storage_type (str): The storage type + custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to + None. + + Returns: + str: The file URI + """ + with open(file_path, "rb") as file: + return self.save_file( + bucket, os.path.basename(file_path), file, storage_type, custom_metadata + ) + + def save_file( + self, + bucket: str, + file_name: str, + file_data: BinaryIO, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Save the file data to the storage system. + + Args: + bucket (str): The bucket name + file_name (str): The file name + file_data (BinaryIO): The file data + storage_type (str): The storage type + custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to + None. + + Returns: + str: The file URI + """ + return self.storage_system.save_file( + bucket, file_name, file_data, storage_type, custom_metadata + ) + + def download_file(self, uri: str, destination_path: str) -> None: + """Download a file from the storage system. 
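A minimal sketch of how the FileStorageClient API above is meant to be used with the built-in local backend; the bucket name, file path and custom metadata are placeholders:

.. code-block:: python

    from dbgpt.core.interface.file import FileStorageClient

    # With no arguments the client falls back to LocalFileStorage under ~/.cache/dbgpt/files
    client = FileStorageClient()
    uri = client.upload_file(
        bucket="my-bucket",
        file_path="/tmp/report.pdf",
        storage_type="local",
        custom_metadata={"source": "example"},
    )
    file_data, metadata = client.get_file(uri)
    print(metadata.file_name, metadata.file_size)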
+ + Args: + uri (str): The file URI + destination_path (str): The destination + + Raises: + FileNotFoundError: If the file is not found + """ + file_data, _ = self.storage_system.get_file(uri) + with open(destination_path, "wb") as f: + f.write(file_data.read()) + + def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage system. + + Args: + uri (str): The file URI + + Returns: + Tuple[BinaryIO, FileMetadata]: The file data and metadata + """ + return self.storage_system.get_file(uri) + + def get_file_by_id( + self, bucket: str, file_id: str + ) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage system by ID. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + Tuple[BinaryIO, FileMetadata]: The file data and metadata + """ + metadata = self.storage_system.get_file_metadata(bucket, file_id) + if not metadata: + raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}") + return self.get_file(metadata.uri) + + def delete_file(self, uri: str) -> bool: + """Delete the file data from the storage system. + + Args: + uri (str): The file URI + + Returns: + bool: True if the file was deleted, False otherwise + """ + return self.storage_system.delete_file(uri) + + def delete_file_by_id(self, bucket: str, file_id: str) -> bool: + """Delete the file data from the storage system by ID. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + bool: True if the file was deleted, False otherwise + """ + metadata = self.storage_system.get_file_metadata(bucket, file_id) + if not metadata: + raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}") + return self.delete_file(metadata.uri) + + def list_files( + self, bucket: str, filters: Optional[Dict[str, Any]] = None + ) -> List[FileMetadata]: + """List the files in the bucket. + + Args: + bucket (str): The bucket name + filters (Dict[str, Any], optional): Filters. Defaults to None. 
+ + Returns: + List[FileMetadata]: The list of file metadata + """ + return self.storage_system.list_files(bucket, filters) + + +class SimpleDistributedStorage(StorageBackend): + """Simple distributed storage backend.""" + + storage_type: str = "distributed" + + def __init__( + self, + node_address: str, + local_storage_path: str, + save_chunk_size: int = 1024 * 1024, + transfer_chunk_size: int = 1024 * 1024, + transfer_timeout: int = 360, + api_prefix: str = "/api/v2/serve/file/files", + ): + """Initialize the simple distributed storage backend.""" + self.node_address = node_address + self.local_storage_path = local_storage_path + os.makedirs(self.local_storage_path, exist_ok=True) + self._save_chunk_size = save_chunk_size + self._transfer_chunk_size = transfer_chunk_size + self._transfer_timeout = transfer_timeout + self._api_prefix = api_prefix + + @property + def save_chunk_size(self) -> int: + """Get the save chunk size.""" + return self._save_chunk_size + + def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str: + node_id = hashlib.md5(node_address.encode()).hexdigest() + return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}") + + def _parse_node_address(self, fm: FileMetadata) -> str: + storage_path = fm.storage_path + if not storage_path.startswith("distributed://"): + raise ValueError("Invalid storage path") + return storage_path.split("//")[1].split("/")[0] + + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the distributed storage backend. + + Just save the file locally. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + file_data (BinaryIO): The file data + + Returns: + str: The storage path + """ + file_path = self._get_file_path(bucket, file_id, self.node_address) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "wb") as f: + while True: + chunk = file_data.read(self.save_chunk_size) + if not chunk: + break + f.write(chunk) + + return f"distributed://{self.node_address}/{bucket}/{file_id}" + + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the distributed storage backend. + + If the file is stored on the local node, load it from the local storage. + + Args: + fm (FileMetadata): The file metadata + + Returns: + BinaryIO: The file data + """ + file_id = fm.file_id + bucket = fm.bucket + node_address = self._parse_node_address(fm) + file_path = self._get_file_path(bucket, file_id, node_address) + + # TODO: check if the file is cached in local storage + if node_address == self.node_address: + if os.path.exists(file_path): + return open(file_path, "rb") # noqa: SIM115 + else: + raise FileNotFoundError(f"File {file_id} not found on the local node") + else: + response = requests.get( + f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}", + timeout=self._transfer_timeout, + stream=True, + ) + response.raise_for_status() + # TODO: cache the file in local storage + return StreamedBytesIO( + response.iter_content(chunk_size=self._transfer_chunk_size) + ) + + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the distributed storage backend. + + If the file is stored on the local node, delete it from the local storage. + If the file is stored on a remote node, send a delete request to the remote + node. 
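The storage path returned by save() doubles as the routing key for SimpleDistributedStorage: it records which node owns the bytes, and load() on any other node falls back to an HTTP fetch from that node. A sketch under assumed addresses and paths:

.. code-block:: python

    from dbgpt.core.interface.file import SimpleDistributedStorage

    storage = SimpleDistributedStorage(
        node_address="10.0.0.5:5670",
        local_storage_path="/data/file_server",
    )
    with open("/tmp/report.pdf", "rb") as f:
        path = storage.save("my-bucket", "file-123", f)
    # path == "distributed://10.0.0.5:5670/my-bucket/file-123"; a node with a different
    # node_address loads the same file via
    # GET http://10.0.0.5:5670/api/v2/serve/file/files/my-bucket/file-123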
+ + Args: + fm (FileMetadata): The file metadata + + Returns: + bool: True if the file was deleted, False otherwise + """ + file_id = fm.file_id + bucket = fm.bucket + node_address = self._parse_node_address(fm) + file_path = self._get_file_path(bucket, file_id, node_address) + if node_address == self.node_address: + if os.path.exists(file_path): + os.remove(file_path) + return True + return False + else: + try: + response = requests.delete( + f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}", + timeout=self._transfer_timeout, + ) + response.raise_for_status() + return True + except Exception: + return False + + +class StreamedBytesIO(io.BytesIO): + """A BytesIO subclass that can be used with streaming responses. + + Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13 + """ + + def __init__(self, request_iterator): + """Initialize the StreamedBytesIO instance.""" + super().__init__() + self._bytes = BytesIO() + self._iterator = request_iterator + + def _load_all(self): + self._bytes.seek(0, io.SEEK_END) + for chunk in self._iterator: + self._bytes.write(chunk) + + def _load_until(self, goal_position): + current_position = self._bytes.seek(0, io.SEEK_END) + while current_position < goal_position: + try: + current_position += self._bytes.write(next(self._iterator)) + except StopIteration: + break + + def tell(self) -> int: + """Get the current position.""" + return self._bytes.tell() + + def read(self, size: Optional[int] = None) -> bytes: + """Read the data from the stream. + + Args: + size (Optional[int], optional): The number of bytes to read. Defaults to + None. + + Returns: + bytes: The read data + """ + left_off_at = self._bytes.tell() + if size is None: + self._load_all() + else: + goal_position = left_off_at + size + self._load_until(goal_position) + + self._bytes.seek(left_off_at) + return self._bytes.read(size) + + def seek(self, position: int, whence: int = io.SEEK_SET): + """Seek to a position in the stream. + + Args: + position (int): The position + whence (int, optional): The reference point. 
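StreamedBytesIO above only pulls chunks from the underlying iterator as read() demands them, which is what lets the distributed backend stream remote files without buffering them fully. A small self-contained sketch:

.. code-block:: python

    from dbgpt.core.interface.file import StreamedBytesIO

    buf = StreamedBytesIO(iter([b"hello ", b"world"]))
    assert buf.read(6) == b"hello "  # consumes only the first chunk
    assert buf.read() == b"world"    # remaining chunks are loaded on demand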
Defaults to io.SEEK + + Raises: + ValueError: If the reference point is invalid + """ + if whence == io.SEEK_END: + self._load_all() + else: + self._bytes.seek(position, whence) + + def __enter__(self): + """Enter the context manager.""" + return self + + def __exit__(self, ext_type, value, tb): + """Exit the context manager.""" + self._bytes.close() diff --git a/dbgpt/core/interface/tests/test_file.py b/dbgpt/core/interface/tests/test_file.py new file mode 100644 index 000000000..f6e462944 --- /dev/null +++ b/dbgpt/core/interface/tests/test_file.py @@ -0,0 +1,506 @@ +import hashlib +import io +import os +from unittest import mock + +import pytest + +from ..file import ( + FileMetadata, + FileMetadataIdentifier, + FileStorageClient, + FileStorageSystem, + InMemoryStorage, + LocalFileStorage, + SimpleDistributedStorage, +) + + +@pytest.fixture +def temp_test_file_dir(tmpdir): + return str(tmpdir) + + +@pytest.fixture +def temp_storage_path(tmpdir): + return str(tmpdir) + + +@pytest.fixture +def local_storage_backend(temp_storage_path): + return LocalFileStorage(temp_storage_path) + + +@pytest.fixture +def distributed_storage_backend(temp_storage_path): + node_address = "127.0.0.1:8000" + return SimpleDistributedStorage(node_address, temp_storage_path) + + +@pytest.fixture +def file_storage_system(local_storage_backend): + backends = {"local": local_storage_backend} + metadata_storage = InMemoryStorage() + return FileStorageSystem(backends, metadata_storage) + + +@pytest.fixture +def file_storage_client(file_storage_system): + return FileStorageClient(storage_system=file_storage_system) + + +@pytest.fixture +def sample_file_path(temp_test_file_dir): + file_path = os.path.join(temp_test_file_dir, "sample.txt") + with open(file_path, "wb") as f: + f.write(b"Sample file content") + return file_path + + +@pytest.fixture +def sample_file_data(): + return io.BytesIO(b"Sample file content for distributed storage") + + +def test_save_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + assert uri.startswith("dbgpt-fs://local/test-bucket/") + assert os.path.exists(sample_file_path) + + +def test_get_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + file_data, metadata = file_storage_client.storage_system.get_file(uri) + assert file_data.read() == b"Sample file content" + assert metadata.file_name == "sample.txt" + assert metadata.bucket == bucket + + +def test_delete_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + assert len(file_storage_client.list_files(bucket=bucket)) == 1 + result = file_storage_client.delete_file(uri) + assert result is True + assert len(file_storage_client.list_files(bucket=bucket)) == 0 + + +def test_list_files(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri1 = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + files = file_storage_client.list_files(bucket=bucket) + assert len(files) == 1 + + +def test_save_file_unsupported_storage(file_storage_system, sample_file_path): + bucket = "test-bucket" + with pytest.raises(ValueError): + file_storage_system.save_file( + bucket=bucket, + file_name="unsupported.txt", + 
file_data=io.BytesIO(b"Unsupported storage"), + storage_type="unsupported", + ) + + +def test_get_file_not_found(file_storage_system): + with pytest.raises(FileNotFoundError): + file_storage_system.get_file("dbgpt-fs://local/test-bucket/nonexistent") + + +def test_delete_file_not_found(file_storage_system): + result = file_storage_system.delete_file("dbgpt-fs://local/test-bucket/nonexistent") + assert result is False + + +def test_metadata_management(file_storage_system): + bucket = "test-bucket" + file_id = "test_file" + metadata = file_storage_system.metadata_storage.save( + FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=100, + storage_type="local", + storage_path="/path/to/test.txt", + uri="dbgpt-fs://local/test-bucket/test_file", + custom_metadata={"key": "value"}, + file_hash="hash", + ) + ) + + loaded_metadata = file_storage_system.metadata_storage.load( + FileMetadataIdentifier(file_id=file_id, bucket=bucket), FileMetadata + ) + assert loaded_metadata.file_name == "test.txt" + assert loaded_metadata.custom_metadata["key"] == "value" + assert loaded_metadata.bucket == bucket + + +def test_concurrent_save_and_delete(file_storage_client, sample_file_path): + bucket = "test-bucket" + + # Simulate concurrent file save and delete operations + def save_file(): + return file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + def delete_file(uri): + return file_storage_client.delete_file(uri) + + uri = save_file() + + # Simulate concurrent operations + save_file() + delete_file(uri) + assert len(file_storage_client.list_files(bucket=bucket)) == 1 + + +def test_large_file_handling(file_storage_client, temp_storage_path): + bucket = "test-bucket" + large_file_path = os.path.join(temp_storage_path, "large_sample.bin") + with open(large_file_path, "wb") as f: + f.write(os.urandom(10 * 1024 * 1024)) # 10 MB file + + uri = file_storage_client.upload_file( + bucket=bucket, + file_path=large_file_path, + storage_type="local", + custom_metadata={"description": "Large file test"}, + ) + file_data, metadata = file_storage_client.storage_system.get_file(uri) + assert file_data.read() == open(large_file_path, "rb").read() + assert metadata.file_name == "large_sample.bin" + assert metadata.bucket == bucket + + +def test_file_hash_verification_success(file_storage_client, sample_file_path): + bucket = "test-bucket" + # Upload file and + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + file_data, metadata = file_storage_client.storage_system.get_file(uri) + file_hash = metadata.file_hash + calculated_hash = file_storage_client.storage_system._calculate_file_hash(file_data) + + assert ( + file_hash == calculated_hash + ), "File hash should match after saving and loading" + + +def test_file_hash_verification_failure(file_storage_client, sample_file_path): + bucket = "test-bucket" + # Upload file and + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + # Modify the file content manually to simulate file tampering + storage_system = file_storage_client.storage_system + metadata = storage_system.metadata_storage.load( + FileMetadataIdentifier(file_id=uri.split("/")[-1], bucket=bucket), FileMetadata + ) + with open(metadata.storage_path, "wb") as f: + f.write(b"Tampered content") + + # Get file should raise an exception due to hash mismatch + with pytest.raises(ValueError, match="File integrity check 
failed. Hash mismatch."): + storage_system.get_file(uri) + + +def test_file_isolation_across_buckets(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the same file to two different buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Verify both URIs are different and point to different files + assert uri1 != uri2 + + file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1) + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + + assert file_data1.read() == b"Sample file content" + assert file_data2.read() == b"Sample file content" + assert metadata1.bucket == bucket1 + assert metadata2.bucket == bucket2 + + +def test_list_files_in_specific_bucket(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload a file to both buckets + file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # List files in bucket1 and bucket2 + files_in_bucket1 = file_storage_client.list_files(bucket=bucket1) + files_in_bucket2 = file_storage_client.list_files(bucket=bucket2) + + assert len(files_in_bucket1) == 1 + assert len(files_in_bucket2) == 1 + assert files_in_bucket1[0].bucket == bucket1 + assert files_in_bucket2[0].bucket == bucket2 + + +def test_delete_file_in_one_bucket_does_not_affect_other_bucket( + file_storage_client, sample_file_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the same file to two different buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Delete the file in bucket1 + file_storage_client.delete_file(uri1) + + # Check that the file in bucket1 is deleted + assert len(file_storage_client.list_files(bucket=bucket1)) == 0 + + # Check that the file in bucket2 is still there + assert len(file_storage_client.list_files(bucket=bucket2)) == 1 + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + assert file_data2.read() == b"Sample file content" + + +def test_file_hash_verification_in_different_buckets( + file_storage_client, sample_file_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the file to both buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1) + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + + # Verify that file hashes are the same for the same content + file_hash1 = file_storage_client.storage_system._calculate_file_hash(file_data1) + file_hash2 = file_storage_client.storage_system._calculate_file_hash(file_data2) + + assert file_hash1 == metadata1.file_hash + assert file_hash2 == metadata2.file_hash + assert file_hash1 == file_hash2 + + +def test_file_download_from_different_buckets( + file_storage_client, sample_file_path, temp_storage_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + 
+ # Upload the file to both buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Download files to different locations + download_path1 = os.path.join(temp_storage_path, "downloaded_bucket1.txt") + download_path2 = os.path.join(temp_storage_path, "downloaded_bucket2.txt") + + file_storage_client.download_file(uri1, download_path1) + file_storage_client.download_file(uri2, download_path2) + + # Verify contents of downloaded files + assert open(download_path1, "rb").read() == b"Sample file content" + assert open(download_path2, "rb").read() == b"Sample file content" + + +def test_delete_all_files_in_bucket(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload files to both buckets + file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Delete all files in bucket1 + for file in file_storage_client.list_files(bucket=bucket1): + file_storage_client.delete_file(file.uri) + + # Verify bucket1 is empty + assert len(file_storage_client.list_files(bucket=bucket1)) == 0 + + # Verify bucket2 still has files + assert len(file_storage_client.list_files(bucket=bucket2)) == 1 + + +def test_simple_distributed_storage_save_file( + distributed_storage_backend, sample_file_data, temp_storage_path +): + bucket = "test-bucket" + file_id = "test_file" + file_path = distributed_storage_backend.save(bucket, file_id, sample_file_data) + + expected_path = os.path.join( + temp_storage_path, + bucket, + f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}", + ) + assert file_path == f"distributed://127.0.0.1:8000/{bucket}/{file_id}" + assert os.path.exists(expected_path) + + +def test_simple_distributed_storage_load_file_local( + distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + distributed_storage_backend.save(bucket, file_id, sample_file_data) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + file_data = distributed_storage_backend.load(metadata) + assert file_data.read() == b"Sample file content for distributed storage" + + +@mock.patch("requests.get") +def test_simple_distributed_storage_load_file_remote( + mock_get, distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + remote_node_address = "127.0.0.2:8000" + + # Mock the response from remote node + mock_response = mock.Mock() + mock_response.iter_content = mock.Mock( + return_value=iter([b"Sample file content for distributed storage"]) + ) + mock_response.raise_for_status = mock.Mock(return_value=None) + mock_get.return_value = mock_response + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}", + uri=f"distributed://{remote_node_address}/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) 
+ + file_data = distributed_storage_backend.load(metadata) + assert file_data.read() == b"Sample file content for distributed storage" + mock_get.assert_called_once_with( + f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}", + stream=True, + timeout=360, + ) + + +def test_simple_distributed_storage_delete_file_local( + distributed_storage_backend, sample_file_data, temp_storage_path +): + bucket = "test-bucket" + file_id = "test_file" + distributed_storage_backend.save(bucket, file_id, sample_file_data) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + result = distributed_storage_backend.delete(metadata) + file_path = os.path.join( + temp_storage_path, + bucket, + f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}", + ) + assert result is True + assert not os.path.exists(file_path) + + +@mock.patch("requests.delete") +def test_simple_distributed_storage_delete_file_remote( + mock_delete, distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + remote_node_address = "127.0.0.2:8000" + + mock_response = mock.Mock() + mock_response.raise_for_status = mock.Mock(return_value=None) + mock_delete.return_value = mock_response + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}", + uri=f"distributed://{remote_node_address}/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + result = distributed_storage_backend.delete(metadata) + assert result is True + mock_delete.assert_called_once_with( + f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}", + timeout=360, + ) diff --git a/dbgpt/serve/file/__init__.py b/dbgpt/serve/file/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/api/__init__.py b/dbgpt/serve/file/api/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/api/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/api/endpoints.py b/dbgpt/serve/file/api/endpoints.py new file mode 100644 index 000000000..edf1d2d98 --- /dev/null +++ b/dbgpt/serve/file/api/endpoints.py @@ -0,0 +1,159 @@ +import logging +from functools import cache +from typing import List, Optional +from urllib.parse import quote + +from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +from starlette.responses import StreamingResponse + +from dbgpt.component import SystemApp +from dbgpt.serve.core import Result, blocking_func_to_async +from dbgpt.util import PaginationResult + +from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..service.service import Service +from .schemas import ServeRequest, ServerResponse, UploadFileResponse + +router = APIRouter() +logger = 
logging.getLogger(__name__) + +# Add your API endpoints here + +global_system_app: Optional[SystemApp] = None + + +def get_service() -> Service: + """Get the service instance""" + return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + + +get_bearer_token = HTTPBearer(auto_error=False) + + +@cache +def _parse_api_keys(api_keys: str) -> List[str]: + """Parse the string api keys to a list + + Args: + api_keys (str): The string api keys + + Returns: + List[str]: The list of api keys + """ + if not api_keys: + return [] + return [key.strip() for key in api_keys.split(",")] + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), + service: Service = Depends(get_service), +) -> Optional[str]: + """Check the api key + + If the api key is not set, allow all. + + Your can pass the token in you request header like this: + + .. code-block:: python + + import requests + + client_api_key = "your_api_key" + headers = {"Authorization": "Bearer " + client_api_key} + res = requests.get("http://test/hello", headers=headers) + assert res.status_code == 200 + + """ + if service.config.api_keys: + api_keys = _parse_api_keys(service.config.api_keys) + if auth is None or (token := auth.credentials) not in api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +@router.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "ok"} + + +@router.get("/test_auth", dependencies=[Depends(check_api_key)]) +async def test_auth(): + """Test auth endpoint""" + return {"status": "ok"} + + +@router.post( + "/files/{bucket}", + response_model=Result[List[UploadFileResponse]], + dependencies=[Depends(check_api_key)], +) +async def upload_files( + bucket: str, files: List[UploadFile], service: Service = Depends(get_service) +) -> Result[List[UploadFileResponse]]: + """Upload files by a list of UploadFile.""" + logger.info(f"upload_files: bucket={bucket}, files={files}") + results = await blocking_func_to_async( + global_system_app, service.upload_files, bucket, "distributed", files + ) + return Result.succ(results) + + +@router.get("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)]) +async def download_file( + bucket: str, file_id: str, service: Service = Depends(get_service) +): + """Download a file by file_id.""" + logger.info(f"download_file: bucket={bucket}, file_id={file_id}") + file_data, file_metadata = await blocking_func_to_async( + global_system_app, service.download_file, bucket, file_id + ) + file_name_encoded = quote(file_metadata.file_name) + + def file_iterator(raw_iter): + with raw_iter: + while chunk := raw_iter.read( + service.config.file_server_download_chunk_size + ): + yield chunk + + response = StreamingResponse( + file_iterator(file_data), media_type="application/octet-stream" + ) + response.headers[ + "Content-Disposition" + ] = f"attachment; filename={file_name_encoded}" + return response + + +@router.delete("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)]) +async def delete_file( + bucket: str, file_id: str, service: Service = Depends(get_service) +): + """Delete a file by file_id.""" + await blocking_func_to_async( + global_system_app, service.delete_file, bucket, file_id + ) + return Result.succ(None) + + +def init_endpoints(system_app: SystemApp) -> None: + 
"""Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/file/api/schemas.py b/dbgpt/serve/file/api/schemas.py new file mode 100644 index 000000000..911f71db3 --- /dev/null +++ b/dbgpt/serve/file/api/schemas.py @@ -0,0 +1,43 @@ +# Define your Pydantic schemas here +from typing import Any, Dict + +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict + +from ..config import SERVE_APP_NAME_HUMP + + +class ServeRequest(BaseModel): + """File request model""" + + # TODO define your own fields here + + model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +class ServerResponse(BaseModel): + """File response model""" + + # TODO define your own fields here + + model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +class UploadFileResponse(BaseModel): + """Upload file response model""" + + file_name: str = Field(..., title="The name of the uploaded file") + file_id: str = Field(..., title="The ID of the uploaded file") + bucket: str = Field(..., title="The bucket of the uploaded file") + uri: str = Field(..., title="The URI of the uploaded file") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) diff --git a/dbgpt/serve/file/config.py b/dbgpt/serve/file/config.py new file mode 100644 index 000000000..1ab1afede --- /dev/null +++ b/dbgpt/serve/file/config.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.serve.core import BaseServeConfig + +APP_NAME = "file" +SERVE_APP_NAME = "dbgpt_serve_file" +SERVE_APP_NAME_HUMP = "dbgpt_serve_File" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.file." 
+SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpt_serve_file" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) + check_hash: Optional[bool] = field( + default=True, metadata={"help": "Check the hash of the file when downloading"} + ) + file_server_host: Optional[str] = field( + default=None, metadata={"help": "The host of the file server"} + ) + file_server_port: Optional[int] = field( + default=5670, metadata={"help": "The port of the file server"} + ) + file_server_download_chunk_size: Optional[int] = field( + default=1024 * 1024, + metadata={"help": "The chunk size when downloading the file"}, + ) + file_server_save_chunk_size: Optional[int] = field( + default=1024 * 1024, metadata={"help": "The chunk size when saving the file"} + ) + file_server_transfer_chunk_size: Optional[int] = field( + default=1024 * 1024, + metadata={"help": "The chunk size when transferring the file"}, + ) + file_server_transfer_timeout: Optional[int] = field( + default=360, metadata={"help": "The timeout when transferring the file"} + ) + local_storage_path: Optional[str] = field( + default=None, metadata={"help": "The local storage path"} + ) + + def get_node_address(self) -> str: + """Get the node address""" + file_server_host = self.file_server_host + if not file_server_host: + from dbgpt.util.net_utils import _get_ip_address + + file_server_host = _get_ip_address() + file_server_port = self.file_server_port or 5670 + return f"{file_server_host}:{file_server_port}" + + def get_local_storage_path(self) -> str: + """Get the local storage path""" + local_storage_path = self.local_storage_path + if not local_storage_path: + from pathlib import Path + + base_path = Path.home() / ".cache" / "dbgpt" / "files" + local_storage_path = str(base_path) + return local_storage_path diff --git a/dbgpt/serve/file/dependencies.py b/dbgpt/serve/file/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ b/dbgpt/serve/file/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/file/models/__init__.py b/dbgpt/serve/file/models/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/models/file_adapter.py b/dbgpt/serve/file/models/file_adapter.py new file mode 100644 index 000000000..a8ab36465 --- /dev/null +++ b/dbgpt/serve/file/models/file_adapter.py @@ -0,0 +1,66 @@ +import json +from typing import Type + +from sqlalchemy.orm import Session + +from dbgpt.core.interface.file import FileMetadata, FileMetadataIdentifier +from dbgpt.core.interface.storage import StorageItemAdapter + +from .models import ServeEntity + + +class FileMetadataAdapter(StorageItemAdapter[FileMetadata, ServeEntity]): + """File metadata adapter. + + Convert between storage format and database model. 
+ """ + + def to_storage_format(self, item: FileMetadata) -> ServeEntity: + """Convert to storage format.""" + custom_metadata = ( + json.dumps(item.custom_metadata, ensure_ascii=False) + if item.custom_metadata + else None + ) + return ServeEntity( + bucket=item.bucket, + file_id=item.file_id, + file_name=item.file_name, + file_size=item.file_size, + storage_type=item.storage_type, + storage_path=item.storage_path, + uri=item.uri, + custom_metadata=custom_metadata, + file_hash=item.file_hash, + ) + + def from_storage_format(self, model: ServeEntity) -> FileMetadata: + """Convert from storage format.""" + custom_metadata = ( + json.loads(model.custom_metadata) if model.custom_metadata else None + ) + return FileMetadata( + bucket=model.bucket, + file_id=model.file_id, + file_name=model.file_name, + file_size=model.file_size, + storage_type=model.storage_type, + storage_path=model.storage_path, + uri=model.uri, + custom_metadata=custom_metadata, + file_hash=model.file_hash, + ) + + def get_query_for_identifier( + self, + storage_format: Type[ServeEntity], + resource_id: FileMetadataIdentifier, + **kwargs, + ): + """Get query for identifier.""" + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + return session.query(storage_format).filter( + storage_format.file_id == resource_id.file_id + ) diff --git a/dbgpt/serve/file/models/models.py b/dbgpt/serve/file/models/models.py new file mode 100644 index 000000000..62dd1ef80 --- /dev/null +++ b/dbgpt/serve/file/models/models.py @@ -0,0 +1,87 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" + +from datetime import datetime +from typing import Any, Dict, Union + +from sqlalchemy import Column, DateTime, Index, Integer, String, Text + +from dbgpt.storage.metadata import BaseDao, Model, db + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVER_APP_TABLE_NAME, ServeConfig + + +class ServeEntity(Model): + __tablename__ = SERVER_APP_TABLE_NAME + id = Column(Integer, primary_key=True, comment="Auto increment id") + + bucket = Column(String(255), nullable=False, comment="Bucket name") + file_id = Column(String(255), nullable=False, comment="File id") + file_name = Column(String(256), nullable=False, comment="File name") + file_size = Column(Integer, nullable=True, comment="File size") + storage_type = Column(String(32), nullable=False, comment="Storage type") + storage_path = Column(String(512), nullable=False, comment="Storage path") + uri = Column(String(512), nullable=False, comment="File URI") + custom_metadata = Column( + Text, nullable=True, comment="Custom metadata, JSON format" + ) + file_hash = Column(String(128), nullable=True, comment="File hash") + + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + def __repr__(self): + return ( + f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', " + f"gmt_modified='{self.gmt_modified}')" + ) + + +class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): + """The DAO class for File""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[ServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = ( 
+ request.to_dict() if isinstance(request, ServeRequest) else request + ) + entity = ServeEntity(**request_dict) + # TODO implement your own logic here, transfer the request_dict to an entity + return entity + + def to_request(self, entity: ServeEntity) -> ServeRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + # TODO implement your own logic here, transfer the entity to a request + return ServeRequest() + + def to_response(self, entity: ServeEntity) -> ServerResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + # TODO implement your own logic here, transfer the entity to a response + return ServerResponse() diff --git a/dbgpt/serve/file/serve.py b/dbgpt/serve/file/serve.py new file mode 100644 index 000000000..559509573 --- /dev/null +++ b/dbgpt/serve/file/serve.py @@ -0,0 +1,113 @@ +import logging +from typing import List, Optional, Union + +from sqlalchemy import URL + +from dbgpt.component import SystemApp +from dbgpt.core.interface.file import FileStorageClient +from dbgpt.serve.core import BaseServe +from dbgpt.storage.metadata import DatabaseManager + +from .api.endpoints import init_endpoints, router +from .config import ( + APP_NAME, + SERVE_APP_NAME, + SERVE_APP_NAME_HUMP, + SERVE_CONFIG_KEY_PREFIX, + ServeConfig, +) + +logger = logging.getLogger(__name__) + + +class Serve(BaseServe): + """Serve component for DB-GPT""" + + name = SERVE_APP_NAME + + def __init__( + self, + system_app: SystemApp, + api_prefix: Optional[str] = f"/api/v2/serve/{APP_NAME}", + api_tags: Optional[List[str]] = None, + db_url_or_db: Union[str, URL, DatabaseManager] = None, + try_create_tables: Optional[bool] = False, + ): + if api_tags is None: + api_tags = [SERVE_APP_NAME_HUMP] + super().__init__( + system_app, api_prefix, api_tags, db_url_or_db, try_create_tables + ) + self._db_manager: Optional[DatabaseManager] = None + + self._db_manager: Optional[DatabaseManager] = None + self._file_storage_client: Optional[FileStorageClient] = None + self._serve_config: Optional[ServeConfig] = None + + def init_app(self, system_app: SystemApp): + if self._app_has_initiated: + return + self._system_app = system_app + self._system_app.app.include_router( + router, prefix=self._api_prefix, tags=self._api_tags + ) + init_endpoints(self._system_app) + self._app_has_initiated = True + + def on_init(self): + """Called when init the application. + + You can do some initialization here. 
You can't get other components here because they may be not initialized yet + """ + # import your own module here to ensure the module is loaded before the application starts + from .models.models import ServeEntity + + def before_start(self): + """Called before the start of the application.""" + from dbgpt.core.interface.file import ( + FileStorageSystem, + SimpleDistributedStorage, + ) + from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage + from dbgpt.util.serialization.json_serialization import JsonSerializer + + from .models.file_adapter import FileMetadataAdapter + from .models.models import ServeEntity + + self._serve_config = ServeConfig.from_app_config( + self._system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + + self._db_manager = self.create_or_get_db_manager() + serializer = JsonSerializer() + storage = SQLAlchemyStorage( + self._db_manager, + ServeEntity, + FileMetadataAdapter(), + serializer, + ) + simple_distributed_storage = SimpleDistributedStorage( + node_address=self._serve_config.get_node_address(), + local_storage_path=self._serve_config.get_local_storage_path(), + save_chunk_size=self._serve_config.file_server_save_chunk_size, + transfer_chunk_size=self._serve_config.file_server_transfer_chunk_size, + transfer_timeout=self._serve_config.file_server_transfer_timeout, + ) + storage_backends = { + simple_distributed_storage.storage_type: simple_distributed_storage, + } + fs = FileStorageSystem( + storage_backends, + metadata_storage=storage, + check_hash=self._serve_config.check_hash, + ) + self._file_storage_client = FileStorageClient( + system_app=self._system_app, storage_system=fs + ) + + @property + def file_storage_client(self) -> FileStorageClient: + """Returns the file storage client.""" + if not self._file_storage_client: + raise ValueError("File storage client is not initialized") + return self._file_storage_client diff --git a/dbgpt/serve/file/service/__init__.py b/dbgpt/serve/file/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/file/service/service.py b/dbgpt/serve/file/service/service.py new file mode 100644 index 000000000..d4d0118f3 --- /dev/null +++ b/dbgpt/serve/file/service/service.py @@ -0,0 +1,106 @@ +import logging +from typing import BinaryIO, List, Optional, Tuple + +from fastapi import UploadFile + +from dbgpt.component import BaseComponent, SystemApp +from dbgpt.core.interface.file import FileMetadata, FileStorageClient, FileStorageURI +from dbgpt.serve.core import BaseService +from dbgpt.storage.metadata import BaseDao +from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.util.tracer import root_tracer, trace + +from ..api.schemas import ServeRequest, ServerResponse, UploadFileResponse +from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..models.models import ServeDao, ServeEntity + +logger = logging.getLogger(__name__) + + +class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): + """The service class for File""" + + name = SERVE_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: ServeDao = dao + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + super().init_app(system_app) + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + 
) + self._dao = self._dao or ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + @property + def file_storage_client(self) -> FileStorageClient: + """Returns the internal FileStorageClient. + + Returns: + FileStorageClient: The internal FileStorageClient + """ + file_storage_client = FileStorageClient.get_instance( + self._system_app, default_component=None + ) + if file_storage_client: + return file_storage_client + else: + from ..serve import Serve + + file_storage_client = Serve.get_instance( + self._system_app + ).file_storage_client + self._system_app.register_instance(file_storage_client) + return file_storage_client + + @trace("upload_files") + def upload_files( + self, bucket: str, storage_type: str, files: List[UploadFile] + ) -> List[UploadFileResponse]: + """Upload files by a list of UploadFile.""" + results = [] + for file in files: + file_name = file.filename + logger.info(f"Uploading file {file_name} to bucket {bucket}") + uri = self.file_storage_client.save_file( + bucket, file_name, file_data=file.file, storage_type=storage_type + ) + parsed_uri = FileStorageURI.parse(uri) + logger.info(f"Uploaded file {file_name} to bucket {bucket}, uri={uri}") + results.append( + UploadFileResponse( + file_name=file_name, + file_id=parsed_uri.file_id, + bucket=bucket, + uri=uri, + ) + ) + return results + + @trace("download_file") + def download_file(self, bucket: str, file_id: str) -> Tuple[BinaryIO, FileMetadata]: + """Download a file by file_id.""" + return self.file_storage_client.get_file_by_id(bucket, file_id) + + def delete_file(self, bucket: str, file_id: str) -> None: + """Delete a file by file_id.""" + self.file_storage_client.delete_file_by_id(bucket, file_id) diff --git a/dbgpt/serve/file/tests/__init__.py b/dbgpt/serve/file/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/file/tests/test_endpoints.py b/dbgpt/serve/file/tests/test_endpoints.py new file mode 100644 index 000000000..ba7b4f0cd --- /dev/null +++ b/dbgpt/serve/file/tests/test_endpoints.py @@ -0,0 +1,124 @@ +import pytest +from fastapi import FastAPI +from httpx import AsyncClient + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import asystem_app, client +from dbgpt.storage.metadata import db +from dbgpt.util import PaginationResult + +from ..api.endpoints import init_endpoints, router +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +def client_init_caller(app: FastAPI, system_app: SystemApp): + app.include_router(router) + init_endpoints(system_app) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, asystem_app, has_auth", + [ + ( + { + "app_caller": client_init_caller, + "client_api_key": "test_token1", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + True, + ), + ( + { + "app_caller": client_init_caller, + "client_api_key": "error_token", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + False, + ), + ], + indirect=["client", "asystem_app"], +) +async def 
test_api_health(client: AsyncClient, asystem_app, has_auth: bool): + response = await client.get("/test_auth") + if has_auth: + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": { + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + } + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_health(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_create(client: AsyncClient): + # TODO: add your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_update(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query_by_page(client: AsyncClient): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/file/tests/test_models.py b/dbgpt/serve/file/tests/test_models.py new file mode 100644 index 000000000..8b66e9f97 --- /dev/null +++ b/dbgpt/serve/file/tests/test_models.py @@ -0,0 +1,99 @@ +import pytest + +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import ServeConfig +from ..models.models import ServeDao, ServeEntity + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +@pytest.fixture +def server_config(): + # TODO : build your server config + return ServeConfig() + + +@pytest.fixture +def dao(server_config): + return ServeDao(server_config) + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +def test_table_exist(): + assert ServeEntity.__tablename__ in db.metadata.tables + + +def test_entity_create(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_unique_key(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_get(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_update(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_delete(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_all(): + # TODO: implement your test case + pass + + +def test_dao_create(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_get_one(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_get_dao_get_list(dao): + # TODO: implement your test case + pass + + +def test_dao_update(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_delete(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def 
test_dao_get_list_page(dao): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/file/tests/test_service.py b/dbgpt/serve/file/tests/test_service.py new file mode 100644 index 000000000..00177924d --- /dev/null +++ b/dbgpt/serve/file/tests/test_service.py @@ -0,0 +1,78 @@ +from typing import List + +import pytest + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import system_app +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +def test_service_create(service: Service, default_entity_dict): + # TODO: implement your test case + # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + # ... + pass + + +def test_service_update(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_delete(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get_list(service: Service): + # TODO: implement your test case + pass + + +def test_service_get_list_by_page(service: Service): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic
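
Configuration note (illustrative, not part of the patch): the new ServeConfig is populated from SystemApp config keys under the "dbgpt.serve.file." prefix (for example "dbgpt.serve.file.api_keys", as the endpoint tests above set), and get_node_address() / get_local_storage_path() provide fallbacks when nothing is configured. A minimal sketch of that fallback behaviour, assuming an environment with no file-server settings:

    # Minimal sketch of the defaults defined in dbgpt/serve/file/config.py.
    from dbgpt.serve.file.config import ServeConfig

    cfg = ServeConfig()  # nothing configured, as in the test fixture above

    # file_server_host is unset, so this falls back to the address detected by
    # dbgpt.util.net_utils._get_ip_address() plus the default port 5670,
    # e.g. "192.168.1.10:5670" (the concrete address depends on the machine).
    print(cfg.get_node_address())

    # local_storage_path is unset, so this falls back to ~/.cache/dbgpt/files.
    print(cfg.get_local_storage_path())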
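
Usage note (illustrative, not part of the patch): once the file serve is registered and started, other DB-GPT components reach it through Serve.file_storage_client, exactly as Service.file_storage_client does above. The sketch below round-trips a file through that client using only the calls visible in this diff (save_file, get_file_by_id, delete_file_by_id, FileStorageURI.parse); the bucket name, the local path and the "distributed" storage type string are assumptions made for illustration:

    # Illustrative sketch; "my-bucket", the local path and the storage type
    # value are assumptions, everything else comes from this patch.
    from dbgpt.component import SystemApp
    from dbgpt.core.interface.file import FileStorageURI
    from dbgpt.serve.file.serve import Serve


    def store_and_read(system_app: SystemApp) -> bytes:
        # Obtain the file serve component registered at startup.
        file_serve = Serve.get_instance(system_app)
        client = file_serve.file_storage_client

        # Save a local file into a bucket; save_file returns a file URI.
        with open("/tmp/report.pdf", "rb") as f:
            uri = client.save_file(
                "my-bucket", "report.pdf", file_data=f, storage_type="distributed"
            )

        # Recover the file_id from the URI and read the file back through the
        # metadata storage + SimpleDistributedStorage backend wired up above.
        file_id = FileStorageURI.parse(uri).file_id
        file_data, metadata = client.get_file_by_id("my-bucket", file_id)
        content = file_data.read()

        # Remove both the stored data and its metadata record.
        client.delete_file_by_id("my-bucket", file_id)
        return content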