Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-14 05:31:40 +00:00)
feat(core): AWEL flow 2.0 backend code (#1879)
Co-authored-by: yhjun1026 <460342015@qq.com>
834  dbgpt/core/interface/file.py  (new file)
@@ -0,0 +1,834 @@
"""File storage interface."""

import dataclasses
import hashlib
import io
import logging
import os
import uuid
from abc import ABC, abstractmethod
from io import BytesIO
from typing import Any, BinaryIO, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse

import requests

from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.tracer import root_tracer, trace

from .storage import (
    InMemoryStorage,
    QuerySpec,
    ResourceIdentifier,
    StorageError,
    StorageInterface,
    StorageItem,
)

logger = logging.getLogger(__name__)

_SCHEMA = "dbgpt-fs"


@dataclasses.dataclass
class FileMetadataIdentifier(ResourceIdentifier):
    """File metadata identifier."""

    file_id: str
    bucket: str

    def to_dict(self) -> Dict:
        """Convert the identifier to a dictionary."""
        return {"file_id": self.file_id, "bucket": self.bucket}

    @property
    def str_identifier(self) -> str:
        """Get the string identifier.

        Returns:
            str: The string identifier
        """
        return f"{self.bucket}/{self.file_id}"


@dataclasses.dataclass
class FileMetadata(StorageItem):
    """File metadata for storage."""

    file_id: str
    bucket: str
    file_name: str
    file_size: int
    storage_type: str
    storage_path: str
    uri: str
    custom_metadata: Dict[str, Any]
    file_hash: str
    user_name: Optional[str] = None
    sys_code: Optional[str] = None
    _identifier: FileMetadataIdentifier = dataclasses.field(init=False)

    def __post_init__(self):
        """Post init method."""
        self._identifier = FileMetadataIdentifier(
            file_id=self.file_id, bucket=self.bucket
        )
        custom_metadata = self.custom_metadata or {}
        if not self.user_name:
            self.user_name = custom_metadata.get("user_name")
        if not self.sys_code:
            self.sys_code = custom_metadata.get("sys_code")

    @property
    def identifier(self) -> ResourceIdentifier:
        """Get the resource identifier."""
        return self._identifier

    def merge(self, other: "StorageItem") -> None:
        """Merge the metadata with another item."""
        if not isinstance(other, FileMetadata):
            raise StorageError("Cannot merge different types of items")
        self._from_object(other)

    def to_dict(self) -> Dict:
        """Convert the metadata to a dictionary."""
        return {
            "file_id": self.file_id,
            "bucket": self.bucket,
            "file_name": self.file_name,
            "file_size": self.file_size,
            "storage_type": self.storage_type,
            "storage_path": self.storage_path,
            "uri": self.uri,
            "custom_metadata": self.custom_metadata,
            "file_hash": self.file_hash,
        }

    def _from_object(self, obj: "FileMetadata") -> None:
        self.file_id = obj.file_id
        self.bucket = obj.bucket
        self.file_name = obj.file_name
        self.file_size = obj.file_size
        self.storage_type = obj.storage_type
        self.storage_path = obj.storage_path
        self.uri = obj.uri
        self.custom_metadata = obj.custom_metadata
        self.file_hash = obj.file_hash
        self._identifier = obj._identifier


class FileStorageURI:
    """File storage URI."""

    def __init__(
        self,
        storage_type: str,
        bucket: str,
        file_id: str,
        version: Optional[str] = None,
        custom_params: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the file storage URI."""
        self.scheme = _SCHEMA
        self.storage_type = storage_type
        self.bucket = bucket
        self.file_id = file_id
        self.version = version
        self.custom_params = custom_params or {}

    @classmethod
    def is_local_file(cls, uri: str) -> bool:
        """Check if the URI is local."""
        parsed = urlparse(uri)
        if not parsed.scheme or parsed.scheme == "file":
            return True
        return False

    @classmethod
    def parse(cls, uri: str) -> "FileStorageURI":
        """Parse the URI string."""
        parsed = urlparse(uri)
        if parsed.scheme != _SCHEMA:
            raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'")
        path_parts = parsed.path.strip("/").split("/")
        if len(path_parts) < 2:
            raise ValueError("Invalid URI path. Must contain bucket and file ID")
        storage_type = parsed.netloc
        bucket = path_parts[0]
        file_id = path_parts[1]
        version = path_parts[2] if len(path_parts) > 2 else None
        custom_params = parse_qs(parsed.query)
        return cls(storage_type, bucket, file_id, version, custom_params)

    def __str__(self) -> str:
        """Get the string representation of the URI."""
        base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}"
        if self.version:
            base_uri += f"/{self.version}"
        if self.custom_params:
            query_string = urlencode(self.custom_params, doseq=True)
            base_uri += f"?{query_string}"
        return base_uri

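# Editor's note (not part of the original commit): an illustrative round trip
# of the URI scheme defined above; the bucket and file ID are made up.
#
#   uri = FileStorageURI.parse("dbgpt-fs://local/my-bucket/my-file-id")
#   assert uri.storage_type == "local"
#   assert uri.bucket == "my-bucket" and uri.file_id == "my-file-id"
#   assert str(uri) == "dbgpt-fs://local/my-bucket/my-file-id"
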

class StorageBackend(ABC):
    """Storage backend interface."""

    storage_type: str = "__base__"

    @abstractmethod
    def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
        """Save the file data to the storage backend.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID
            file_data (BinaryIO): The file data

        Returns:
            str: The storage path
        """

    @abstractmethod
    def load(self, fm: FileMetadata) -> BinaryIO:
        """Load the file data from the storage backend.

        Args:
            fm (FileMetadata): The file metadata

        Returns:
            BinaryIO: The file data
        """

    @abstractmethod
    def delete(self, fm: FileMetadata) -> bool:
        """Delete the file data from the storage backend.

        Args:
            fm (FileMetadata): The file metadata

        Returns:
            bool: True if the file was deleted, False otherwise
        """

    @property
    @abstractmethod
    def save_chunk_size(self) -> int:
        """Get the save chunk size.

        Returns:
            int: The save chunk size
        """

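# Editor's sketch (not part of the original commit): a minimal in-memory
# backend showing what the interface above requires from an implementation.
# The class name is hypothetical.
class _InMemoryBytesBackend(StorageBackend):
    storage_type: str = "memory"

    def __init__(self):
        self._files: Dict[str, bytes] = {}

    @property
    def save_chunk_size(self) -> int:
        return 1024 * 1024

    def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
        key = f"{bucket}/{file_id}"
        self._files[key] = file_data.read()
        return f"memory://{key}"

    def load(self, fm: FileMetadata) -> BinaryIO:
        return io.BytesIO(self._files[f"{fm.bucket}/{fm.file_id}"])

    def delete(self, fm: FileMetadata) -> bool:
        return self._files.pop(f"{fm.bucket}/{fm.file_id}", None) is not None
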

class LocalFileStorage(StorageBackend):
    """Local file storage backend."""

    storage_type: str = "local"

    def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024):
        """Initialize the local file storage backend."""
        self.base_path = base_path
        self._save_chunk_size = save_chunk_size
        os.makedirs(self.base_path, exist_ok=True)

    @property
    def save_chunk_size(self) -> int:
        """Get the save chunk size."""
        return self._save_chunk_size

    def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
        """Save the file data to the local storage backend."""
        bucket_path = os.path.join(self.base_path, bucket)
        os.makedirs(bucket_path, exist_ok=True)
        file_path = os.path.join(bucket_path, file_id)
        with open(file_path, "wb") as f:
            while True:
                chunk = file_data.read(self.save_chunk_size)
                if not chunk:
                    break
                f.write(chunk)
        return file_path

    def load(self, fm: FileMetadata) -> BinaryIO:
        """Load the file data from the local storage backend."""
        bucket_path = os.path.join(self.base_path, fm.bucket)
        file_path = os.path.join(bucket_path, fm.file_id)
        return open(file_path, "rb")  # noqa: SIM115

    def delete(self, fm: FileMetadata) -> bool:
        """Delete the file data from the local storage backend."""
        bucket_path = os.path.join(self.base_path, fm.bucket)
        file_path = os.path.join(bucket_path, fm.file_id)
        if os.path.exists(file_path):
            os.remove(file_path)
            return True
        return False


class FileStorageSystem:
    """File storage system."""

    def __init__(
        self,
        storage_backends: Dict[str, StorageBackend],
        metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None,
        check_hash: bool = True,
    ):
        """Initialize the file storage system."""
        metadata_storage = metadata_storage or InMemoryStorage()
        self.storage_backends = storage_backends
        self.metadata_storage = metadata_storage
        self.check_hash = check_hash
        self._save_chunk_size = min(
            backend.save_chunk_size for backend in storage_backends.values()
        )

    def _calculate_file_hash(self, file_data: BinaryIO) -> str:
        """Calculate the MD5 hash of the file data."""
        if not self.check_hash:
            return "-1"
        hasher = hashlib.md5()
        file_data.seek(0)
        while chunk := file_data.read(self._save_chunk_size):
            hasher.update(chunk)
        file_data.seek(0)
        return hasher.hexdigest()

    @trace("file_storage_system.save_file")
    def save_file(
        self,
        bucket: str,
        file_name: str,
        file_data: BinaryIO,
        storage_type: str,
        custom_metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Save the file data to the storage backend."""
        file_id = str(uuid.uuid4())
        backend = self.storage_backends.get(storage_type)
        if not backend:
            raise ValueError(f"Unsupported storage type: {storage_type}")

        with root_tracer.start_span(
            "file_storage_system.save_file.backend_save",
            metadata={
                "bucket": bucket,
                "file_id": file_id,
                "file_name": file_name,
                "storage_type": storage_type,
            },
        ):
            storage_path = backend.save(bucket, file_id, file_data)
            file_data.seek(0, 2)  # Move to the end of the file
            file_size = file_data.tell()  # Get the file size
            file_data.seek(0)  # Reset file pointer

        # Filter out None values
        custom_metadata = (
            {k: v for k, v in custom_metadata.items() if v is not None}
            if custom_metadata
            else {}
        )

        with root_tracer.start_span(
            "file_storage_system.save_file.calculate_hash",
        ):
            file_hash = self._calculate_file_hash(file_data)
            uri = FileStorageURI(
                storage_type, bucket, file_id, custom_params=custom_metadata
            )

        metadata = FileMetadata(
            file_id=file_id,
            bucket=bucket,
            file_name=file_name,
            file_size=file_size,
            storage_type=storage_type,
            storage_path=storage_path,
            uri=str(uri),
            custom_metadata=custom_metadata,
            file_hash=file_hash,
        )

        self.metadata_storage.save(metadata)
        return str(uri)

    @trace("file_storage_system.get_file")
    def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
        """Get the file data from the storage backend."""
        if FileStorageURI.is_local_file(uri):
            local_file_name = uri.split("/")[-1]
            if not os.path.exists(uri):
                raise FileNotFoundError(f"File not found: {uri}")

            dummy_metadata = FileMetadata(
                file_id=local_file_name,
                bucket="dummy_bucket",
                file_name=local_file_name,
                file_size=-1,
                storage_type="local",
                storage_path=uri,
                uri=uri,
                custom_metadata={},
                file_hash="",
            )
            logger.info(f"Reading local file: {uri}")
            return open(uri, "rb"), dummy_metadata  # noqa: SIM115

        parsed_uri = FileStorageURI.parse(uri)
        metadata = self.metadata_storage.load(
            FileMetadataIdentifier(
                file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
            ),
            FileMetadata,
        )
        if not metadata:
            raise FileNotFoundError(f"No metadata found for URI: {uri}")

        backend = self.storage_backends.get(metadata.storage_type)
        if not backend:
            raise ValueError(f"Unsupported storage type: {metadata.storage_type}")

        with root_tracer.start_span(
            "file_storage_system.get_file.backend_load",
            metadata={
                "bucket": metadata.bucket,
                "file_id": metadata.file_id,
                "file_name": metadata.file_name,
                "storage_type": metadata.storage_type,
            },
        ):
            file_data = backend.load(metadata)

        with root_tracer.start_span(
            "file_storage_system.get_file.verify_hash",
        ):
            calculated_hash = self._calculate_file_hash(file_data)
            if calculated_hash != "-1" and calculated_hash != metadata.file_hash:
                raise ValueError("File integrity check failed. Hash mismatch.")

        return file_data, metadata

    def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]:
        """Get the file metadata.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID

        Returns:
            Optional[FileMetadata]: The file metadata
        """
        fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket)
        return self.metadata_storage.load(fid, FileMetadata)

    def delete_file(self, uri: str) -> bool:
        """Delete the file data from the storage backend.

        Args:
            uri (str): The file URI

        Returns:
            bool: True if the file was deleted, False otherwise
        """
        parsed_uri = FileStorageURI.parse(uri)
        fid = FileMetadataIdentifier(
            file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
        )
        metadata = self.metadata_storage.load(fid, FileMetadata)
        if not metadata:
            return False

        backend = self.storage_backends.get(metadata.storage_type)
        if not backend:
            raise ValueError(f"Unsupported storage type: {metadata.storage_type}")

        if backend.delete(metadata):
            try:
                self.metadata_storage.delete(fid)
                return True
            except Exception:
                # If the metadata deletion fails, log the error and return False
                return False
        return False

    def list_files(
        self, bucket: str, filters: Optional[Dict[str, Any]] = None
    ) -> List[FileMetadata]:
        """List the files in the bucket."""
        filters = filters or {}
        filters["bucket"] = bucket
        return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata)

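# Editor's note (not part of the original commit): save_file generates a
# random UUID per file and returns a URI of the form
# "dbgpt-fs://<storage_type>/<bucket>/<file_id>"; get_file accepts either such
# a URI or a plain local path (see FileStorageURI.is_local_file above).
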

class FileStorageClient(BaseComponent):
    """File storage client component."""

    name = ComponentType.FILE_STORAGE_CLIENT.value

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        storage_system: Optional[FileStorageSystem] = None,
    ):
        """Initialize the file storage client."""
        super().__init__(system_app=system_app)
        if not storage_system:
            from pathlib import Path

            base_path = Path.home() / ".cache" / "dbgpt" / "files"
            storage_system = FileStorageSystem(
                {
                    LocalFileStorage.storage_type: LocalFileStorage(
                        base_path=str(base_path)
                    )
                }
            )

        self.system_app = system_app
        self._storage_system = storage_system

    def init_app(self, system_app: SystemApp):
        """Initialize the application."""
        self.system_app = system_app

    @property
    def storage_system(self) -> FileStorageSystem:
        """Get the file storage system."""
        if not self._storage_system:
            raise ValueError("File storage system not initialized")
        return self._storage_system

    def upload_file(
        self,
        bucket: str,
        file_path: str,
        storage_type: str,
        custom_metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Upload a file to the storage system.

        Args:
            bucket (str): The bucket name
            file_path (str): The file path
            storage_type (str): The storage type
            custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
                None.

        Returns:
            str: The file URI
        """
        with open(file_path, "rb") as file:
            return self.save_file(
                bucket, os.path.basename(file_path), file, storage_type, custom_metadata
            )

    def save_file(
        self,
        bucket: str,
        file_name: str,
        file_data: BinaryIO,
        storage_type: str,
        custom_metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Save the file data to the storage system.

        Args:
            bucket (str): The bucket name
            file_name (str): The file name
            file_data (BinaryIO): The file data
            storage_type (str): The storage type
            custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
                None.

        Returns:
            str: The file URI
        """
        return self.storage_system.save_file(
            bucket, file_name, file_data, storage_type, custom_metadata
        )

    def download_file(self, uri: str, destination_path: str) -> None:
        """Download a file from the storage system.

        Args:
            uri (str): The file URI
            destination_path (str): The destination path

        Raises:
            FileNotFoundError: If the file is not found
        """
        file_data, _ = self.storage_system.get_file(uri)
        with open(destination_path, "wb") as f:
            f.write(file_data.read())

    def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
        """Get the file data from the storage system.

        Args:
            uri (str): The file URI

        Returns:
            Tuple[BinaryIO, FileMetadata]: The file data and metadata
        """
        return self.storage_system.get_file(uri)

    def get_file_by_id(
        self, bucket: str, file_id: str
    ) -> Tuple[BinaryIO, FileMetadata]:
        """Get the file data from the storage system by ID.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID

        Returns:
            Tuple[BinaryIO, FileMetadata]: The file data and metadata
        """
        metadata = self.storage_system.get_file_metadata(bucket, file_id)
        if not metadata:
            raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
        return self.get_file(metadata.uri)

    def delete_file(self, uri: str) -> bool:
        """Delete the file data from the storage system.

        Args:
            uri (str): The file URI

        Returns:
            bool: True if the file was deleted, False otherwise
        """
        return self.storage_system.delete_file(uri)

    def delete_file_by_id(self, bucket: str, file_id: str) -> bool:
        """Delete the file data from the storage system by ID.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID

        Returns:
            bool: True if the file was deleted, False otherwise
        """
        metadata = self.storage_system.get_file_metadata(bucket, file_id)
        if not metadata:
            raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
        return self.delete_file(metadata.uri)

    def list_files(
        self, bucket: str, filters: Optional[Dict[str, Any]] = None
    ) -> List[FileMetadata]:
        """List the files in the bucket.

        Args:
            bucket (str): The bucket name
            filters (Dict[str, Any], optional): Filters. Defaults to None.

        Returns:
            List[FileMetadata]: The list of file metadata
        """
        return self.storage_system.list_files(bucket, filters)


class SimpleDistributedStorage(StorageBackend):
    """Simple distributed storage backend."""

    storage_type: str = "distributed"

    def __init__(
        self,
        node_address: str,
        local_storage_path: str,
        save_chunk_size: int = 1024 * 1024,
        transfer_chunk_size: int = 1024 * 1024,
        transfer_timeout: int = 360,
        api_prefix: str = "/api/v2/serve/file/files",
    ):
        """Initialize the simple distributed storage backend."""
        self.node_address = node_address
        self.local_storage_path = local_storage_path
        os.makedirs(self.local_storage_path, exist_ok=True)
        self._save_chunk_size = save_chunk_size
        self._transfer_chunk_size = transfer_chunk_size
        self._transfer_timeout = transfer_timeout
        self._api_prefix = api_prefix

    @property
    def save_chunk_size(self) -> int:
        """Get the save chunk size."""
        return self._save_chunk_size

    def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str:
        node_id = hashlib.md5(node_address.encode()).hexdigest()
        return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}")

    def _parse_node_address(self, fm: FileMetadata) -> str:
        storage_path = fm.storage_path
        if not storage_path.startswith("distributed://"):
            raise ValueError("Invalid storage path")
        return storage_path.split("//")[1].split("/")[0]

    def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
        """Save the file data to the distributed storage backend.

        Just save the file locally.

        Args:
            bucket (str): The bucket name
            file_id (str): The file ID
            file_data (BinaryIO): The file data

        Returns:
            str: The storage path
        """
        file_path = self._get_file_path(bucket, file_id, self.node_address)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "wb") as f:
            while True:
                chunk = file_data.read(self.save_chunk_size)
                if not chunk:
                    break
                f.write(chunk)

        return f"distributed://{self.node_address}/{bucket}/{file_id}"

    def load(self, fm: FileMetadata) -> BinaryIO:
        """Load the file data from the distributed storage backend.

        If the file is stored on the local node, load it from the local storage.

        Args:
            fm (FileMetadata): The file metadata

        Returns:
            BinaryIO: The file data
        """
        file_id = fm.file_id
        bucket = fm.bucket
        node_address = self._parse_node_address(fm)
        file_path = self._get_file_path(bucket, file_id, node_address)

        # TODO: check if the file is cached in local storage
        if node_address == self.node_address:
            if os.path.exists(file_path):
                return open(file_path, "rb")  # noqa: SIM115
            else:
                raise FileNotFoundError(f"File {file_id} not found on the local node")
        else:
            response = requests.get(
                f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
                timeout=self._transfer_timeout,
                stream=True,
            )
            response.raise_for_status()
            # TODO: cache the file in local storage
            return StreamedBytesIO(
                response.iter_content(chunk_size=self._transfer_chunk_size)
            )

    def delete(self, fm: FileMetadata) -> bool:
        """Delete the file data from the distributed storage backend.

        If the file is stored on the local node, delete it from the local storage.
        If the file is stored on a remote node, send a delete request to the remote
        node.

        Args:
            fm (FileMetadata): The file metadata

        Returns:
            bool: True if the file was deleted, False otherwise
        """
        file_id = fm.file_id
        bucket = fm.bucket
        node_address = self._parse_node_address(fm)
        file_path = self._get_file_path(bucket, file_id, node_address)
        if node_address == self.node_address:
            if os.path.exists(file_path):
                os.remove(file_path)
                return True
            return False
        else:
            try:
                response = requests.delete(
                    f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
                    timeout=self._transfer_timeout,
                )
                response.raise_for_status()
                return True
            except Exception:
                return False

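# Editor's note (not part of the original commit): a "distributed" storage
# path embeds the owning node, e.g.
# "distributed://127.0.0.1:8000/my-bucket/<file_id>". _parse_node_address
# recovers "127.0.0.1:8000" from it, and load/delete compare that address with
# self.node_address to choose between local disk access and an HTTP call to
# the remote node's file API.
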

class StreamedBytesIO(io.BytesIO):
    """A BytesIO subclass that can be used with streaming responses.

    Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13
    """

    def __init__(self, request_iterator):
        """Initialize the StreamedBytesIO instance."""
        super().__init__()
        self._bytes = BytesIO()
        self._iterator = request_iterator

    def _load_all(self):
        self._bytes.seek(0, io.SEEK_END)
        for chunk in self._iterator:
            self._bytes.write(chunk)

    def _load_until(self, goal_position):
        current_position = self._bytes.seek(0, io.SEEK_END)
        while current_position < goal_position:
            try:
                current_position += self._bytes.write(next(self._iterator))
            except StopIteration:
                break

    def tell(self) -> int:
        """Get the current position."""
        return self._bytes.tell()

    def read(self, size: Optional[int] = None) -> bytes:
        """Read the data from the stream.

        Args:
            size (Optional[int], optional): The number of bytes to read. Defaults to
                None.

        Returns:
            bytes: The read data
        """
        left_off_at = self._bytes.tell()
        if size is None:
            self._load_all()
        else:
            goal_position = left_off_at + size
            self._load_until(goal_position)

        self._bytes.seek(left_off_at)
        return self._bytes.read(size)

    def seek(self, position: int, whence: int = io.SEEK_SET):
        """Seek to a position in the stream.

        Args:
            position (int): The position
            whence (int, optional): The reference point. Defaults to io.SEEK_SET.

        Raises:
            ValueError: If the reference point is invalid
        """
        if whence == io.SEEK_END:
            self._load_all()
        else:
            self._bytes.seek(position, whence)

    def __enter__(self):
        """Enter the context manager."""
        return self

    def __exit__(self, ext_type, value, tb):
        """Exit the context manager."""
        self._bytes.close()
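Editor's note: a minimal usage sketch of the file-storage API added above; it is
not part of the commit. The bucket name and paths are illustrative, and the file
at /tmp/report.txt is assumed to exist. By default, FileStorageClient keeps files
under ~/.cache/dbgpt/files.

    from dbgpt.core.interface.file import FileStorageClient

    client = FileStorageClient()  # defaults to a LocalFileStorage backend

    # Returns a URI like "dbgpt-fs://local/demo-bucket/<uuid>"
    uri = client.upload_file(
        bucket="demo-bucket", file_path="/tmp/report.txt", storage_type="local"
    )

    file_data, metadata = client.get_file(uri)  # open stream + metadata
    print(metadata.file_name, metadata.file_size)
    client.download_file(uri, "/tmp/report_copy.txt")
    assert client.delete_file(uri)
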
@@ -24,6 +24,7 @@ from dbgpt.core.awel.flow import (
    OperatorType,
    Parameter,
    ViewMetadata,
    ui,
)
from dbgpt.core.interface.llm import (
    LLMClient,
@@ -69,6 +70,10 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest]):
            optional=True,
            default=None,
            description=_("The temperature of the model request."),
            ui=ui.UISlider(
                show_input=True,
                attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1),
            ),
        ),
        Parameter.build_from(
            _("Max New Tokens"),
@@ -1,4 +1,5 @@
"""The prompt operator."""

from abc import ABC
from typing import Any, Dict, List, Optional, Union
@@ -18,6 +19,7 @@ from dbgpt.core.awel.flow import (
    ResourceCategory,
    ViewMetadata,
    register_resource,
    ui,
)
from dbgpt.core.interface.message import BaseMessage
from dbgpt.core.interface.operators.llm_operator import BaseLLM
@@ -48,6 +50,7 @@ from dbgpt.util.i18n_utils import _
            optional=True,
            default="You are a helpful AI Assistant.",
            description=_("The system message."),
            ui=ui.DefaultUITextArea(),
        ),
        Parameter.build_from(
            label=_("Message placeholder"),
@@ -65,6 +68,7 @@ from dbgpt.util.i18n_utils import _
            default="{user_input}",
            placeholder="{user_input}",
            description=_("The human message."),
            ui=ui.DefaultUITextArea(),
        ),
    ],
)
@@ -3,13 +3,14 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast

from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.i18n_utils import _
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.serialization.json_serialization import JsonSerializer

from ..awel.flow import Parameter, ResourceCategory, register_resource


@PublicAPI(stability="beta")
class ResourceIdentifier(Serializable, ABC):
506  dbgpt/core/interface/tests/test_file.py  (new file)
@@ -0,0 +1,506 @@
import hashlib
import io
import os
from unittest import mock

import pytest

from ..file import (
    FileMetadata,
    FileMetadataIdentifier,
    FileStorageClient,
    FileStorageSystem,
    InMemoryStorage,
    LocalFileStorage,
    SimpleDistributedStorage,
)


@pytest.fixture
def temp_test_file_dir(tmpdir):
    return str(tmpdir)


@pytest.fixture
def temp_storage_path(tmpdir):
    return str(tmpdir)


@pytest.fixture
def local_storage_backend(temp_storage_path):
    return LocalFileStorage(temp_storage_path)


@pytest.fixture
def distributed_storage_backend(temp_storage_path):
    node_address = "127.0.0.1:8000"
    return SimpleDistributedStorage(node_address, temp_storage_path)


@pytest.fixture
def file_storage_system(local_storage_backend):
    backends = {"local": local_storage_backend}
    metadata_storage = InMemoryStorage()
    return FileStorageSystem(backends, metadata_storage)


@pytest.fixture
def file_storage_client(file_storage_system):
    return FileStorageClient(storage_system=file_storage_system)


@pytest.fixture
def sample_file_path(temp_test_file_dir):
    file_path = os.path.join(temp_test_file_dir, "sample.txt")
    with open(file_path, "wb") as f:
        f.write(b"Sample file content")
    return file_path


@pytest.fixture
def sample_file_data():
    return io.BytesIO(b"Sample file content for distributed storage")


def test_save_file(file_storage_client, sample_file_path):
    bucket = "test-bucket"
    uri = file_storage_client.upload_file(
        bucket=bucket, file_path=sample_file_path, storage_type="local"
    )
    assert uri.startswith("dbgpt-fs://local/test-bucket/")
    assert os.path.exists(sample_file_path)


def test_get_file(file_storage_client, sample_file_path):
    bucket = "test-bucket"
    uri = file_storage_client.upload_file(
        bucket=bucket, file_path=sample_file_path, storage_type="local"
    )
    file_data, metadata = file_storage_client.storage_system.get_file(uri)
    assert file_data.read() == b"Sample file content"
    assert metadata.file_name == "sample.txt"
    assert metadata.bucket == bucket


def test_delete_file(file_storage_client, sample_file_path):
    bucket = "test-bucket"
    uri = file_storage_client.upload_file(
        bucket=bucket, file_path=sample_file_path, storage_type="local"
    )
    assert len(file_storage_client.list_files(bucket=bucket)) == 1
    result = file_storage_client.delete_file(uri)
    assert result is True
    assert len(file_storage_client.list_files(bucket=bucket)) == 0


def test_list_files(file_storage_client, sample_file_path):
    bucket = "test-bucket"
    uri1 = file_storage_client.upload_file(
        bucket=bucket, file_path=sample_file_path, storage_type="local"
    )
    files = file_storage_client.list_files(bucket=bucket)
    assert len(files) == 1


def test_save_file_unsupported_storage(file_storage_system, sample_file_path):
    bucket = "test-bucket"
    with pytest.raises(ValueError):
        file_storage_system.save_file(
            bucket=bucket,
            file_name="unsupported.txt",
            file_data=io.BytesIO(b"Unsupported storage"),
            storage_type="unsupported",
        )


def test_get_file_not_found(file_storage_system):
    with pytest.raises(FileNotFoundError):
        file_storage_system.get_file("dbgpt-fs://local/test-bucket/nonexistent")


def test_delete_file_not_found(file_storage_system):
    result = file_storage_system.delete_file("dbgpt-fs://local/test-bucket/nonexistent")
    assert result is False


def test_metadata_management(file_storage_system):
    bucket = "test-bucket"
    file_id = "test_file"
    metadata = file_storage_system.metadata_storage.save(
        FileMetadata(
            file_id=file_id,
            bucket=bucket,
            file_name="test.txt",
            file_size=100,
            storage_type="local",
            storage_path="/path/to/test.txt",
            uri="dbgpt-fs://local/test-bucket/test_file",
            custom_metadata={"key": "value"},
            file_hash="hash",
        )
    )

    loaded_metadata = file_storage_system.metadata_storage.load(
        FileMetadataIdentifier(file_id=file_id, bucket=bucket), FileMetadata
    )
    assert loaded_metadata.file_name == "test.txt"
    assert loaded_metadata.custom_metadata["key"] == "value"
    assert loaded_metadata.bucket == bucket


def test_concurrent_save_and_delete(file_storage_client, sample_file_path):
    bucket = "test-bucket"

    # Simulate concurrent file save and delete operations
    def save_file():
        return file_storage_client.upload_file(
            bucket=bucket, file_path=sample_file_path, storage_type="local"
        )

    def delete_file(uri):
        return file_storage_client.delete_file(uri)

    uri = save_file()

    # Simulate concurrent operations
    save_file()
    delete_file(uri)
    assert len(file_storage_client.list_files(bucket=bucket)) == 1


def test_large_file_handling(file_storage_client, temp_storage_path):
    bucket = "test-bucket"
    large_file_path = os.path.join(temp_storage_path, "large_sample.bin")
    with open(large_file_path, "wb") as f:
        f.write(os.urandom(10 * 1024 * 1024))  # 10 MB file

    uri = file_storage_client.upload_file(
        bucket=bucket,
        file_path=large_file_path,
        storage_type="local",
        custom_metadata={"description": "Large file test"},
    )
    file_data, metadata = file_storage_client.storage_system.get_file(uri)
    assert file_data.read() == open(large_file_path, "rb").read()
    assert metadata.file_name == "large_sample.bin"
    assert metadata.bucket == bucket


def test_file_hash_verification_success(file_storage_client, sample_file_path):
    bucket = "test-bucket"
    # Upload the file and read it back
    uri = file_storage_client.upload_file(
        bucket=bucket, file_path=sample_file_path, storage_type="local"
    )

    file_data, metadata = file_storage_client.storage_system.get_file(uri)
    file_hash = metadata.file_hash
    calculated_hash = file_storage_client.storage_system._calculate_file_hash(file_data)

    assert (
        file_hash == calculated_hash
    ), "File hash should match after saving and loading"


def test_file_hash_verification_failure(file_storage_client, sample_file_path):
    bucket = "test-bucket"
    # Upload the file, then tamper with it on disk
    uri = file_storage_client.upload_file(
        bucket=bucket, file_path=sample_file_path, storage_type="local"
    )

    # Modify the file content manually to simulate file tampering
    storage_system = file_storage_client.storage_system
    metadata = storage_system.metadata_storage.load(
        FileMetadataIdentifier(file_id=uri.split("/")[-1], bucket=bucket), FileMetadata
    )
    with open(metadata.storage_path, "wb") as f:
        f.write(b"Tampered content")

    # Get file should raise an exception due to hash mismatch
    with pytest.raises(ValueError, match="File integrity check failed. Hash mismatch."):
        storage_system.get_file(uri)


def test_file_isolation_across_buckets(file_storage_client, sample_file_path):
    bucket1 = "bucket1"
    bucket2 = "bucket2"

    # Upload the same file to two different buckets
    uri1 = file_storage_client.upload_file(
        bucket=bucket1, file_path=sample_file_path, storage_type="local"
    )
    uri2 = file_storage_client.upload_file(
        bucket=bucket2, file_path=sample_file_path, storage_type="local"
    )

    # Verify both URIs are different and point to different files
    assert uri1 != uri2

    file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1)
    file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)

    assert file_data1.read() == b"Sample file content"
    assert file_data2.read() == b"Sample file content"
    assert metadata1.bucket == bucket1
    assert metadata2.bucket == bucket2


def test_list_files_in_specific_bucket(file_storage_client, sample_file_path):
    bucket1 = "bucket1"
    bucket2 = "bucket2"

    # Upload a file to both buckets
    file_storage_client.upload_file(
        bucket=bucket1, file_path=sample_file_path, storage_type="local"
    )
    file_storage_client.upload_file(
        bucket=bucket2, file_path=sample_file_path, storage_type="local"
    )

    # List files in bucket1 and bucket2
    files_in_bucket1 = file_storage_client.list_files(bucket=bucket1)
    files_in_bucket2 = file_storage_client.list_files(bucket=bucket2)

    assert len(files_in_bucket1) == 1
    assert len(files_in_bucket2) == 1
    assert files_in_bucket1[0].bucket == bucket1
    assert files_in_bucket2[0].bucket == bucket2


def test_delete_file_in_one_bucket_does_not_affect_other_bucket(
    file_storage_client, sample_file_path
):
    bucket1 = "bucket1"
    bucket2 = "bucket2"

    # Upload the same file to two different buckets
    uri1 = file_storage_client.upload_file(
        bucket=bucket1, file_path=sample_file_path, storage_type="local"
    )
    uri2 = file_storage_client.upload_file(
        bucket=bucket2, file_path=sample_file_path, storage_type="local"
    )

    # Delete the file in bucket1
    file_storage_client.delete_file(uri1)

    # Check that the file in bucket1 is deleted
    assert len(file_storage_client.list_files(bucket=bucket1)) == 0

    # Check that the file in bucket2 is still there
    assert len(file_storage_client.list_files(bucket=bucket2)) == 1
    file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
    assert file_data2.read() == b"Sample file content"


def test_file_hash_verification_in_different_buckets(
    file_storage_client, sample_file_path
):
    bucket1 = "bucket1"
    bucket2 = "bucket2"

    # Upload the file to both buckets
    uri1 = file_storage_client.upload_file(
        bucket=bucket1, file_path=sample_file_path, storage_type="local"
    )
    uri2 = file_storage_client.upload_file(
        bucket=bucket2, file_path=sample_file_path, storage_type="local"
    )

    file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1)
    file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)

    # Verify that file hashes are the same for the same content
    file_hash1 = file_storage_client.storage_system._calculate_file_hash(file_data1)
    file_hash2 = file_storage_client.storage_system._calculate_file_hash(file_data2)

    assert file_hash1 == metadata1.file_hash
    assert file_hash2 == metadata2.file_hash
    assert file_hash1 == file_hash2


def test_file_download_from_different_buckets(
    file_storage_client, sample_file_path, temp_storage_path
):
    bucket1 = "bucket1"
    bucket2 = "bucket2"

    # Upload the file to both buckets
    uri1 = file_storage_client.upload_file(
        bucket=bucket1, file_path=sample_file_path, storage_type="local"
    )
    uri2 = file_storage_client.upload_file(
        bucket=bucket2, file_path=sample_file_path, storage_type="local"
    )

    # Download files to different locations
    download_path1 = os.path.join(temp_storage_path, "downloaded_bucket1.txt")
    download_path2 = os.path.join(temp_storage_path, "downloaded_bucket2.txt")

    file_storage_client.download_file(uri1, download_path1)
    file_storage_client.download_file(uri2, download_path2)

    # Verify contents of downloaded files
    assert open(download_path1, "rb").read() == b"Sample file content"
    assert open(download_path2, "rb").read() == b"Sample file content"


def test_delete_all_files_in_bucket(file_storage_client, sample_file_path):
    bucket1 = "bucket1"
    bucket2 = "bucket2"

    # Upload files to both buckets
    file_storage_client.upload_file(
        bucket=bucket1, file_path=sample_file_path, storage_type="local"
    )
    file_storage_client.upload_file(
        bucket=bucket2, file_path=sample_file_path, storage_type="local"
    )

    # Delete all files in bucket1
    for file in file_storage_client.list_files(bucket=bucket1):
        file_storage_client.delete_file(file.uri)

    # Verify bucket1 is empty
    assert len(file_storage_client.list_files(bucket=bucket1)) == 0

    # Verify bucket2 still has files
    assert len(file_storage_client.list_files(bucket=bucket2)) == 1


def test_simple_distributed_storage_save_file(
    distributed_storage_backend, sample_file_data, temp_storage_path
):
    bucket = "test-bucket"
    file_id = "test_file"
    file_path = distributed_storage_backend.save(bucket, file_id, sample_file_data)

    expected_path = os.path.join(
        temp_storage_path,
        bucket,
        f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}",
    )
    assert file_path == f"distributed://127.0.0.1:8000/{bucket}/{file_id}"
    assert os.path.exists(expected_path)


def test_simple_distributed_storage_load_file_local(
    distributed_storage_backend, sample_file_data
):
    bucket = "test-bucket"
    file_id = "test_file"
    distributed_storage_backend.save(bucket, file_id, sample_file_data)

    metadata = FileMetadata(
        file_id=file_id,
        bucket=bucket,
        file_name="test.txt",
        file_size=len(sample_file_data.getvalue()),
        storage_type="distributed",
        storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
        uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
        custom_metadata={},
        file_hash="hash",
    )

    file_data = distributed_storage_backend.load(metadata)
    assert file_data.read() == b"Sample file content for distributed storage"


@mock.patch("requests.get")
def test_simple_distributed_storage_load_file_remote(
    mock_get, distributed_storage_backend, sample_file_data
):
    bucket = "test-bucket"
    file_id = "test_file"
    remote_node_address = "127.0.0.2:8000"

    # Mock the response from the remote node
    mock_response = mock.Mock()
    mock_response.iter_content = mock.Mock(
        return_value=iter([b"Sample file content for distributed storage"])
    )
    mock_response.raise_for_status = mock.Mock(return_value=None)
    mock_get.return_value = mock_response

    metadata = FileMetadata(
        file_id=file_id,
        bucket=bucket,
        file_name="test.txt",
        file_size=len(sample_file_data.getvalue()),
        storage_type="distributed",
        storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}",
        uri=f"distributed://{remote_node_address}/{bucket}/{file_id}",
        custom_metadata={},
        file_hash="hash",
    )

    file_data = distributed_storage_backend.load(metadata)
    assert file_data.read() == b"Sample file content for distributed storage"
    mock_get.assert_called_once_with(
        f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}",
        stream=True,
        timeout=360,
    )


def test_simple_distributed_storage_delete_file_local(
    distributed_storage_backend, sample_file_data, temp_storage_path
):
    bucket = "test-bucket"
    file_id = "test_file"
    distributed_storage_backend.save(bucket, file_id, sample_file_data)

    metadata = FileMetadata(
        file_id=file_id,
        bucket=bucket,
        file_name="test.txt",
        file_size=len(sample_file_data.getvalue()),
        storage_type="distributed",
        storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
        uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
        custom_metadata={},
        file_hash="hash",
    )

    result = distributed_storage_backend.delete(metadata)
    file_path = os.path.join(
        temp_storage_path,
        bucket,
        f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}",
    )
    assert result is True
    assert not os.path.exists(file_path)


@mock.patch("requests.delete")
def test_simple_distributed_storage_delete_file_remote(
    mock_delete, distributed_storage_backend, sample_file_data
):
    bucket = "test-bucket"
    file_id = "test_file"
    remote_node_address = "127.0.0.2:8000"

    mock_response = mock.Mock()
    mock_response.raise_for_status = mock.Mock(return_value=None)
    mock_delete.return_value = mock_response

    metadata = FileMetadata(
        file_id=file_id,
        bucket=bucket,
        file_name="test.txt",
        file_size=len(sample_file_data.getvalue()),
        storage_type="distributed",
        storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}",
        uri=f"distributed://{remote_node_address}/{bucket}/{file_id}",
        custom_metadata={},
        file_hash="hash",
    )

    result = distributed_storage_backend.delete(metadata)
    assert result is True
    mock_delete.assert_called_once_with(
        f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}",
        timeout=360,
    )
327  dbgpt/core/interface/tests/test_variables.py  (new file)
@@ -0,0 +1,327 @@
|
||||
import base64
|
||||
import os
|
||||
from itertools import product
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from ..variables import (
|
||||
FernetEncryption,
|
||||
InMemoryStorage,
|
||||
SimpleEncryption,
|
||||
StorageVariables,
|
||||
StorageVariablesProvider,
|
||||
VariablesIdentifier,
|
||||
build_variable_string,
|
||||
parse_variable,
|
||||
)
|
||||
|
||||
|
||||
def test_fernet_encryption():
|
||||
key = Fernet.generate_key()
|
||||
encryption = FernetEncryption(key)
|
||||
new_encryption = FernetEncryption(key)
|
||||
data = "test_data"
|
||||
salt = "test_salt"
|
||||
|
||||
encrypted_data = encryption.encrypt(data, salt)
|
||||
assert encrypted_data != data
|
||||
|
||||
decrypted_data = encryption.decrypt(encrypted_data, salt)
|
||||
assert decrypted_data == data
|
||||
assert decrypted_data == new_encryption.decrypt(encrypted_data, salt)
|
||||
|
||||
|
||||
def test_simple_encryption():
|
||||
key = base64.b64encode(os.urandom(32)).decode()
|
||||
encryption = SimpleEncryption(key)
|
||||
data = "test_data"
|
||||
salt = "test_salt"
|
||||
|
||||
encrypted_data = encryption.encrypt(data, salt)
|
||||
assert encrypted_data != data
|
||||
|
||||
decrypted_data = encryption.decrypt(encrypted_data, salt)
|
||||
assert decrypted_data == data
|
||||
|
||||
|
||||
def test_storage_variables_provider():
|
||||
storage = InMemoryStorage()
|
||||
encryption = SimpleEncryption()
|
||||
provider = StorageVariablesProvider(storage, encryption)
|
||||
|
||||
full_key = "${key:name@global}"
|
||||
value = "secret_value"
|
||||
value_type = "str"
|
||||
label = "test_label"
|
||||
|
||||
id = VariablesIdentifier.from_str_identifier(full_key)
|
||||
provider.save(
|
||||
StorageVariables.from_identifier(
|
||||
id, value, value_type, label, category="secret"
|
||||
)
|
||||
)
|
||||
|
||||
loaded_variable_value = provider.get(full_key)
|
||||
assert loaded_variable_value == value
|
||||
|
||||
|
||||
def test_variables_identifier():
|
||||
full_key = "${key:name@global:scope_key#sys_code%user_name}"
|
||||
identifier = VariablesIdentifier.from_str_identifier(full_key)
|
||||
|
||||
assert identifier.key == "key"
|
||||
assert identifier.name == "name"
|
||||
assert identifier.scope == "global"
|
||||
assert identifier.scope_key == "scope_key"
|
||||
assert identifier.sys_code == "sys_code"
|
||||
assert identifier.user_name == "user_name"
|
||||
|
||||
str_identifier = identifier.str_identifier
|
||||
assert str_identifier == full_key
|
||||
|
||||
|
||||
def test_storage_variables():
|
||||
key = "test_key"
|
||||
name = "test_name"
|
||||
label = "test_label"
|
||||
value = "test_value"
|
||||
value_type = "str"
|
||||
category = "common"
|
||||
scope = "global"
|
||||
|
||||
storage_variable = StorageVariables(
|
||||
key=key,
|
||||
name=name,
|
||||
label=label,
|
||||
value=value,
|
||||
value_type=value_type,
|
||||
category=category,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
assert storage_variable.key == key
|
||||
assert storage_variable.name == name
|
||||
assert storage_variable.label == label
|
||||
assert storage_variable.value == value
|
||||
assert storage_variable.value_type == value_type
|
||||
assert storage_variable.category == category
|
||||
assert storage_variable.scope == scope
|
||||
|
||||
dict_representation = storage_variable.to_dict()
|
||||
assert dict_representation["key"] == key
|
||||
assert dict_representation["name"] == name
|
||||
assert dict_representation["label"] == label
|
||||
assert dict_representation["value"] == value
|
||||
assert dict_representation["value_type"] == value_type
|
||||
assert dict_representation["category"] == category
|
||||
assert dict_representation["scope"] == scope
|
||||
|
||||
|
||||
def generate_test_cases(enable_escape=False):
|
||||
# Define possible values for each field, including special characters for escaping
|
||||
_EMPTY_ = "___EMPTY___"
|
||||
fields = {
|
||||
"name": [
|
||||
None,
|
||||
"test_name",
|
||||
"test:name" if enable_escape else _EMPTY_,
|
||||
"test::name" if enable_escape else _EMPTY_,
|
||||
"test#name" if enable_escape else _EMPTY_,
|
||||
"test##name" if enable_escape else _EMPTY_,
|
||||
"test::@@@#22name" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"scope": [
|
||||
None,
|
||||
"test_scope",
|
||||
"test@scope" if enable_escape else _EMPTY_,
|
||||
"test@@scope" if enable_escape else _EMPTY_,
|
||||
"test:scope" if enable_escape else _EMPTY_,
|
||||
"test:#:scope" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"scope_key": [
|
||||
None,
|
||||
"test_scope_key",
|
||||
"test:scope_key" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"sys_code": [
|
||||
None,
|
||||
"test_sys_code",
|
||||
"test#sys_code" if enable_escape else _EMPTY_,
|
||||
],
|
||||
"user_name": [
|
||||
None,
|
||||
"test_user_name",
|
||||
"test%user_name" if enable_escape else _EMPTY_,
|
||||
],
|
||||
}
|
||||
# Remove empty values
|
||||
fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()}
|
||||
|
||||
# Generate all possible combinations
|
||||
combinations = product(*fields.values())
|
||||
|
||||
test_cases = []
|
||||
for combo in combinations:
|
||||
name, scope, scope_key, sys_code, user_name = combo
|
||||
|
||||
var_str = build_variable_string(
|
||||
{
|
||||
"key": "test_key",
|
||||
"name": name,
|
||||
"scope": scope,
|
||||
"scope_key": scope_key,
|
||||
"sys_code": sys_code,
|
||||
"user_name": user_name,
|
||||
},
|
||||
enable_escape=enable_escape,
|
||||
)
|
||||
|
||||
# Construct the expected output
|
||||
expected = {
|
||||
"key": "test_key",
|
||||
"name": name,
|
||||
"scope": scope,
|
||||
"scope_key": scope_key,
|
||||
"sys_code": sys_code,
|
||||
"user_name": user_name,
|
||||
}
|
||||
|
||||
test_cases.append((var_str, expected, enable_escape))
|
||||
|
||||
return test_cases
|
||||
|
||||
|
||||
def test_parse_variables():
|
||||
# Run test cases without escape
|
||||
test_cases = generate_test_cases(enable_escape=False)
|
||||
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
|
||||
result = parse_variable(input_str, enable_escape=enable_escape)
|
||||
assert result == expected_output, f"Test case {i} failed without escape"
|
||||
|
||||
# Run test cases with escape
|
||||
test_cases = generate_test_cases(enable_escape=True)
|
||||
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
|
||||
print(f"input_str: {input_str}, expected_output: {expected_output}")
|
||||
result = parse_variable(input_str, enable_escape=enable_escape)
|
||||
assert result == expected_output, f"Test case {i} failed with escape"
|
||||
|
||||
|
||||


def generate_build_test_cases(enable_escape=False):
    # Define possible values for each field, including special characters for escaping
    _EMPTY_ = "___EMPTY___"
    fields = {
        "name": [
            None,
            "test_name",
            "test:name" if enable_escape else _EMPTY_,
            "test::name" if enable_escape else _EMPTY_,
            "test\name" if enable_escape else _EMPTY_,
            "test\\name" if enable_escape else _EMPTY_,
            "test\:\#\@\%name" if enable_escape else _EMPTY_,
            "test\::\###\@@\%%name" if enable_escape else _EMPTY_,
            "test\\::\\###\\@@\\%%name" if enable_escape else _EMPTY_,
            "test\:#:name" if enable_escape else _EMPTY_,
        ],
        "scope": [None, "test_scope", "test@scope" if enable_escape else _EMPTY_],
        "scope_key": [
            None,
            "test_scope_key",
            "test:scope_key" if enable_escape else _EMPTY_,
        ],
        "sys_code": [
            None,
            "test_sys_code",
            "test#sys_code" if enable_escape else _EMPTY_,
        ],
        "user_name": [
            None,
            "test_user_name",
            "test%user_name" if enable_escape else _EMPTY_,
        ],
    }
    # Remove empty values
    fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()}

    # Generate all possible combinations
    combinations = product(*fields.values())

    test_cases = []

    def escape_special_chars(s):
        if not enable_escape or s is None:
            return s
        return (
            s.replace(":", "\\:")
            .replace("@", "\\@")
            .replace("%", "\\%")
            .replace("#", "\\#")
        )

    for combo in combinations:
        name, scope, scope_key, sys_code, user_name = combo

        # Construct the input dictionary
        input_dict = {
            "key": "test_key",
            "name": name,
            "scope": scope,
            "scope_key": scope_key,
            "sys_code": sys_code,
            "user_name": user_name,
        }
        input_dict_with_escape = {
            k: escape_special_chars(v) for k, v in input_dict.items()
        }

        # Construct the expected variable string
        expected_str = "${test_key"
        if name:
            expected_str += f":{input_dict_with_escape['name']}"
        if scope or scope_key:
            expected_str += "@"
            if scope:
                expected_str += input_dict_with_escape["scope"]
            if scope_key:
                expected_str += f":{input_dict_with_escape['scope_key']}"
        if sys_code:
            expected_str += f"#{input_dict_with_escape['sys_code']}"
        if user_name:
            expected_str += f"%{input_dict_with_escape['user_name']}"
        expected_str += "}"

        test_cases.append((input_dict, expected_str, enable_escape))

    return test_cases


def test_build_variable_string():
    # Run test cases without escape
    test_cases = generate_build_test_cases(enable_escape=False)
    for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1):
        result = build_variable_string(input_dict, enable_escape=enable_escape)
        assert result == expected_str, f"Test case {i} failed without escape"

    # Run test cases with escape
    test_cases = generate_build_test_cases(enable_escape=True)
    for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1):
        print(f"input_dict: {input_dict}, expected_str: {expected_str}")
        result = build_variable_string(input_dict, enable_escape=enable_escape)
        assert result == expected_str, f"Test case {i} failed with escape"


def test_variable_string_round_trip():
    # Run test cases without escape
    test_cases = generate_test_cases(enable_escape=False)
    for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
        parsed_result = parse_variable(input_str, enable_escape=enable_escape)
        built_result = build_variable_string(parsed_result, enable_escape=enable_escape)
        assert (
            built_result == input_str
        ), f"Round trip test case {i} failed without escape"

    # Run test cases with escape
    test_cases = generate_test_cases(enable_escape=True)
    for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
        parsed_result = parse_variable(input_str, enable_escape=enable_escape)
        built_result = build_variable_string(parsed_result, enable_escape=enable_escape)
        assert built_result == input_str, f"Round trip test case {i} failed with escape"
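

# A concrete single-case illustration of the round-trip property exercised
# above (a minimal sketch added for clarity; it uses only the helpers under
# test, and the literal values are arbitrary):
def test_round_trip_single_escaped_case():
    raw = {
        "key": "test_key",
        "name": "test:name",
        "scope": "test_scope",
        "scope_key": None,
        "sys_code": None,
        "user_name": None,
    }
    built = build_variable_string(raw, enable_escape=True)
    # The ":" in the name is escaped so it cannot be mistaken for a delimiter
    assert built == "${test_key:test\\:name@test_scope}"
    assert parse_variable(built, enable_escape=True) == raw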
979
dbgpt/core/interface/variables.py
Normal file
@@ -0,0 +1,979 @@
"""Variables Module."""
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.util.executor_utils import (
|
||||
DefaultExecutorFactory,
|
||||
blocking_func_to_async,
|
||||
blocking_func_to_async_no_executor,
|
||||
)
|
||||
|
||||
from .storage import (
|
||||
InMemoryStorage,
|
||||
QuerySpec,
|
||||
ResourceIdentifier,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
)
|
||||
|
||||
_EMPTY_DEFAULT_VALUE = "_EMPTY_DEFAULT_VALUE"
|
||||
|
||||
BUILTIN_VARIABLES_CORE_FLOWS = "dbgpt.core.flow.flows"
|
||||
BUILTIN_VARIABLES_CORE_FLOW_NODES = "dbgpt.core.flow.nodes"
|
||||
BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables"
|
||||
BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets"
|
||||
BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms"
|
||||
BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings"
|
||||
BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers"
|
||||
BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources"
|
||||
BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents"
|
||||
BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES = "dbgpt.core.knowledge_spaces"
|
||||
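
# Illustrative note (an assumption for clarity, not from the original file):
# these keys form the "key" segment of a variable string, e.g.
# "${dbgpt.core.model.llms:gpt-4@global}" would reference a hypothetical
# builtin LLM named "gpt-4" in the global scope.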


class Encryption(ABC):
    """Encryption interface."""

    name: str = "__abstract__"

    @abstractmethod
    def encrypt(self, data: str, salt: str) -> str:
        """Encrypt the data."""

    @abstractmethod
    def decrypt(self, encrypted_data: str, salt: str) -> str:
        """Decrypt the data."""


def _generate_key_from_password(
    password: bytes, salt: Optional[Union[str, bytes]] = None
):
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

    if salt is None:
        salt = os.urandom(16)
    elif isinstance(salt, str):
        salt = salt.encode()
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=100000,
    )
    key = base64.urlsafe_b64encode(kdf.derive(password))
    return key, salt
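

# Minimal usage sketch (illustrative only; this function is not called by the
# module and requires the optional `cryptography` dependency):
def _example_key_derivation() -> None:
    """Show that PBKDF2 derivation is deterministic for a fixed password and
    salt, which is what lets the same key be re-derived at decrypt time."""
    key1, salt = _generate_key_from_password(b"password", "per-item-salt")
    # Passing the returned salt bytes back in yields the identical key
    key2, _ = _generate_key_from_password(b"password", salt)
    assert key1 == key2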


class FernetEncryption(Encryption):
    """Fernet encryption.

    A symmetric encryption algorithm that uses the same key for both encryption
    and decryption, powered by the cryptography library.
    """

    name = "fernet"

    def __init__(self, key: Optional[bytes] = None):
        """Initialize the fernet encryption."""
        if key is not None and isinstance(key, str):
            key = key.encode()
        try:
            from cryptography.fernet import Fernet
        except ImportError:
            raise ImportError(
                "cryptography is required for encryption, please install by running "
                "`pip install cryptography`"
            )
        if key is None:
            key = Fernet.generate_key()
        self.key = key

    def encrypt(self, data: str, salt: str) -> str:
        """Encrypt the data with the salt.

        Args:
            data (str): The data to encrypt.
            salt (str): The salt to use, which is used to derive the key.

        Returns:
            str: The encrypted data.
        """
        from cryptography.fernet import Fernet

        key, salt = _generate_key_from_password(self.key, salt)
        fernet = Fernet(key)
        encrypted_secret = fernet.encrypt(data.encode()).decode()
        return encrypted_secret

    def decrypt(self, encrypted_data: str, salt: str) -> str:
        """Decrypt the data with the salt.

        Args:
            encrypted_data (str): The encrypted data.
            salt (str): The salt to use, which is used to derive the key.

        Returns:
            str: The decrypted data.
        """
        from cryptography.fernet import Fernet

        key, salt = _generate_key_from_password(self.key, salt)
        fernet = Fernet(key)
        return fernet.decrypt(encrypted_data.encode()).decode()
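

# Minimal usage sketch (illustrative only; not called by the module, requires
# the optional `cryptography` dependency):
def _example_fernet_round_trip() -> None:
    """The same salt must be supplied to both encrypt and decrypt, because the
    salt feeds key derivation."""
    enc = FernetEncryption()
    token = enc.encrypt("my-secret", salt="per-item-salt")
    assert enc.decrypt(token, salt="per-item-salt") == "my-secret"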


class SimpleEncryption(Encryption):
    """Simple implementation of encryption.

    A simple encryption algorithm that uses a key to XOR the data.
    """

    name = "simple"

    def __init__(self, key: Optional[str] = None):
        """Initialize the simple encryption."""
        if key is None:
            key = base64.b64encode(os.urandom(32)).decode()
        self.key = key

    def _derive_key(self, salt: str) -> bytes:
        return hashlib.pbkdf2_hmac("sha256", self.key.encode(), salt.encode(), 100000)

    def encrypt(self, data: str, salt: str) -> str:
        """Encrypt the data with the salt."""
        key = self._derive_key(salt)
        encrypted = bytes(
            x ^ y for x, y in zip(data.encode(), key * (len(data) // len(key) + 1))
        )
        return base64.b64encode(encrypted).decode()

    def decrypt(self, encrypted_data: str, salt: str) -> str:
        """Decrypt the data with the salt."""
        key = self._derive_key(salt)
        data = base64.b64decode(encrypted_data)
        decrypted = bytes(
            x ^ y for x, y in zip(data, key * (len(data) // len(key) + 1))
        )
        return decrypted.decode()
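

# Minimal usage sketch (illustrative only; not called by the module):
def _example_simple_encryption_round_trip() -> None:
    """XOR with a PBKDF2-derived keystream is symmetric, so decryption repeats
    the same operation with the same key and salt."""
    enc = SimpleEncryption(key="fixed-key")
    token = enc.encrypt("hello", salt="per-item-salt")
    assert enc.decrypt(token, salt="per-item-salt") == "hello"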


@dataclasses.dataclass
class VariablesIdentifier(ResourceIdentifier):
    """The variables identifier."""

    identifier_split: str = dataclasses.field(default="@", init=False)

    key: str
    name: str
    scope: str = "global"
    scope_key: Optional[str] = None
    sys_code: Optional[str] = None
    user_name: Optional[str] = None

    def __post_init__(self):
        """Post init method."""
        if not self.key or not self.name or not self.scope:
            raise ValueError("Key, name, and scope are required.")

    @property
    def str_identifier(self) -> str:
        """Return the string identifier of the identifier."""
        return build_variable_string(
            {
                "key": self.key,
                "name": self.name,
                "scope": self.scope,
                "scope_key": self.scope_key,
                "sys_code": self.sys_code,
                "user_name": self.user_name,
            }
        )

    def to_dict(self) -> Dict:
        """Convert the identifier to a dict.

        Returns:
            Dict: The dict of the identifier.
        """
        return {
            "key": self.key,
            "name": self.name,
            "scope": self.scope,
            "scope_key": self.scope_key,
            "sys_code": self.sys_code,
            "user_name": self.user_name,
        }

    @classmethod
    def from_str_identifier(
        cls,
        str_identifier: str,
        default_identifier_map: Optional[Dict[str, str]] = None,
    ) -> "VariablesIdentifier":
        """Create a VariablesIdentifier from a string identifier.

        Args:
            str_identifier (str): The string identifier.
            default_identifier_map (Optional[Dict[str, str]]): The default identifier
                map, which contains the default values for the identifier. Defaults to
                None.

        Returns:
            VariablesIdentifier: The VariablesIdentifier.
        """
        variable_dict = parse_variable(str_identifier)
        if not variable_dict:
            raise ValueError("Invalid string identifier.")
        if not variable_dict.get("key"):
            raise ValueError("Invalid string identifier, must have key")
        if not variable_dict.get("name"):
            raise ValueError("Invalid string identifier, must have name")

        def _get_value(key, default_value: Optional[str] = None) -> Optional[str]:
            if variable_dict.get(key) is not None:
                return variable_dict.get(key)
            if default_identifier_map is not None and default_identifier_map.get(key):
                return default_identifier_map.get(key)
            return default_value

        return cls(
            key=variable_dict["key"],
            name=variable_dict["name"],
            scope=variable_dict["scope"],
            scope_key=_get_value("scope_key"),
            sys_code=_get_value("sys_code"),
            user_name=_get_value("user_name"),
        )
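

# Minimal usage sketch (illustrative only; not called by the module):
def _example_identifier_round_trip() -> None:
    """A full variable string parses into an identifier and rebuilds to the
    same string."""
    full_key = "${dbgpt.core.secrets:api_key@global#my_app%alice}"
    vid = VariablesIdentifier.from_str_identifier(full_key)
    assert (vid.key, vid.name, vid.scope) == ("dbgpt.core.secrets", "api_key", "global")
    assert vid.str_identifier == full_key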


@dataclasses.dataclass
class StorageVariables(StorageItem):
    """The storage variables."""

    key: str
    name: str
    label: str
    value: Any
    category: Literal["common", "secret"] = "common"
    scope: str = "global"
    value_type: Optional[str] = None
    scope_key: Optional[str] = None
    sys_code: Optional[str] = None
    user_name: Optional[str] = None
    encryption_method: Optional[str] = None
    salt: Optional[str] = None
    enabled: int = 1
    description: Optional[str] = None

    _identifier: VariablesIdentifier = dataclasses.field(init=False)

    def __post_init__(self):
        """Post init method."""
        self._identifier = VariablesIdentifier(
            key=self.key,
            name=self.name,
            scope=self.scope,
            scope_key=self.scope_key,
            sys_code=self.sys_code,
            user_name=self.user_name,
        )
        if not self.value_type:
            self.value_type = type(self.value).__name__

    @property
    def identifier(self) -> ResourceIdentifier:
        """Return the identifier."""
        return self._identifier

    def merge(self, other: "StorageItem") -> None:
        """Merge with another storage variables."""
        if not isinstance(other, StorageVariables):
            raise ValueError(f"Cannot merge with {type(other)}")
        self.from_object(other)

    def to_dict(self) -> Dict:
        """Convert the storage variables to a dict.

        Returns:
            Dict: The dict of the storage variables.
        """
        return {
            **self._identifier.to_dict(),
            "label": self.label,
            "value": self.value,
            "value_type": self.value_type,
            "category": self.category,
            "encryption_method": self.encryption_method,
            "salt": self.salt,
            "enabled": self.enabled,
            "description": self.description,
        }

    def from_object(self, other: "StorageVariables") -> None:
        """Copy the values from another storage variables object."""
        self.label = other.label
        self.value = other.value
        self.value_type = other.value_type
        self.category = other.category
        self.scope = other.scope
        self.scope_key = other.scope_key
        self.sys_code = other.sys_code
        self.user_name = other.user_name
        self.encryption_method = other.encryption_method
        self.salt = other.salt
        self.enabled = other.enabled
        self.description = other.description

    @classmethod
    def from_identifier(
        cls,
        identifier: VariablesIdentifier,
        value: Any,
        value_type: str,
        label: str = "",
        category: Literal["common", "secret"] = "common",
        encryption_method: Optional[str] = None,
        salt: Optional[str] = None,
    ) -> "StorageVariables":
        """Build storage variables from an identifier."""
        return cls(
            key=identifier.key,
            name=identifier.name,
            label=label,
            value=value,
            value_type=value_type,
            category=category,
            scope=identifier.scope,
            scope_key=identifier.scope_key,
            sys_code=identifier.sys_code,
            user_name=identifier.user_name,
            encryption_method=encryption_method,
            salt=salt,
        )
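

# Minimal usage sketch (illustrative only; not called by the module):
def _example_storage_variables() -> None:
    """Build a storage record from an identifier; the record exposes the same
    identifier, and value_type would be inferred if not supplied."""
    vid = VariablesIdentifier(key="dbgpt.core.variables", name="timeout", scope="global")
    var = StorageVariables.from_identifier(vid, value=30, value_type="int")
    assert var.identifier.str_identifier == vid.str_identifier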


class VariablesProvider(BaseComponent, ABC):
    """The variables provider interface."""

    name = ComponentType.VARIABLES_PROVIDER.value

    @abstractmethod
    def get(
        self,
        full_key: str,
        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
        default_identifier_map: Optional[Dict[str, str]] = None,
    ) -> Any:
        """Query variables from storage."""

    @abstractmethod
    def save(self, variables_item: StorageVariables) -> None:
        """Save variables to storage."""

    @abstractmethod
    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get variables by key."""

    async def async_get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get variables by key async."""
        raise NotImplementedError("Current variables provider does not support async.")

    def support_async(self) -> bool:
        """Whether the variables provider supports async."""
        return False

    def _convert_to_value_type(self, var: StorageVariables):
        """Convert the variable to its declared value type."""
        if var.value is None:
            return None
        if var.value_type == "str":
            return str(var.value)
        elif var.value_type == "int":
            return int(var.value)
        elif var.value_type == "float":
            return float(var.value)
        elif var.value_type == "bool":
            if var.value.lower() in ["true", "1"]:
                return True
            elif var.value.lower() in ["false", "0"]:
                return False
            else:
                return bool(var.value)
        else:
            return var.value


class VariablesPlaceHolder:
    """The variables placeholder."""

    def __init__(
        self,
        param_name: str,
        full_key: str,
        default_value: Any = _EMPTY_DEFAULT_VALUE,
    ):
        """Initialize the variables placeholder."""
        self.param_name = param_name
        self.full_key = full_key
        self.default_value = default_value

    def parse(
        self,
        variables_provider: VariablesProvider,
        ignore_not_found_error: bool = False,
        default_identifier_map: Optional[Dict[str, str]] = None,
    ):
        """Parse the variables."""
        try:
            return variables_provider.get(
                self.full_key,
                self.default_value,
                default_identifier_map=default_identifier_map,
            )
        except ValueError as e:
            if ignore_not_found_error:
                return None
            raise e

    def __repr__(self):
        """Return the representation of the variables placeholder."""
        return f"<VariablesPlaceHolder {self.param_name} {self.full_key}>"
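

# Minimal usage sketch (illustrative only; not called by the module, and it
# uses the StorageVariablesProvider defined later in this file):
def _example_placeholder_parse() -> None:
    """A placeholder defers resolution until a provider is available."""
    provider = StorageVariablesProvider()
    provider.save(
        StorageVariables(
            key="dbgpt.core.variables", name="timeout", label="Timeout", value=30
        )
    )
    placeholder = VariablesPlaceHolder(
        param_name="timeout", full_key="${dbgpt.core.variables:timeout@global}"
    )
    # The stored int round-trips through serialization and type conversion
    assert placeholder.parse(provider) == 30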


class StorageVariablesProvider(VariablesProvider):
    """The storage variables provider."""

    def __init__(
        self,
        storage: Optional[StorageInterface] = None,
        encryption: Optional[Encryption] = None,
        system_app: Optional[SystemApp] = None,
        key: Optional[str] = None,
    ):
        """Initialize the storage variables provider."""
        if storage is None:
            storage = InMemoryStorage()
        self.system_app = system_app
        self.encryption = encryption or SimpleEncryption(key)

        self.storage = storage
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        """Initialize the storage variables provider."""
        self.system_app = system_app

    def get(
        self,
        full_key: str,
        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
        default_identifier_map: Optional[Dict[str, str]] = None,
    ) -> Any:
        """Query variables from storage."""
        key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
        variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables)
        if variable is None:
            if default_value == _EMPTY_DEFAULT_VALUE:
                raise ValueError(f"Variable {full_key} not found")
            return default_value
        variable.value = self.deserialize_value(variable.value)
        if (
            variable.value is not None
            and variable.category == "secret"
            and variable.encryption_method
            and variable.salt
        ):
            variable.value = self.encryption.decrypt(variable.value, variable.salt)
        return self._convert_to_value_type(variable)

    def save(self, variables_item: StorageVariables) -> None:
        """Save variables to storage."""
        if variables_item.category == "secret":
            salt = base64.b64encode(os.urandom(16)).decode()
            variables_item.value = self.encryption.encrypt(
                str(variables_item.value), salt
            )
            variables_item.encryption_method = self.encryption.name
            variables_item.salt = salt
        # Replace value with a JSON-serializable object
        variables_item.value = self.serialize_value(variables_item.value)

        self.storage.save_or_update(variables_item)

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Query variables from storage."""
        # Try to get builtin variables
        is_builtin, builtin_variables = self._get_builtins_variables(
            key,
            scope=scope,
            scope_key=scope_key,
            sys_code=sys_code,
            user_name=user_name,
        )
        if is_builtin:
            return builtin_variables
        variables = self.storage.query(
            QuerySpec(
                conditions={
                    "key": key,
                    "scope": scope,
                    "scope_key": scope_key,
                    "sys_code": sys_code,
                    "user_name": user_name,
                    "enabled": 1,
                }
            ),
            StorageVariables,
        )
        for variable in variables:
            variable.value = self.deserialize_value(variable.value)
        return variables

    async def async_get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Query variables from storage async."""
        # Try to get builtin variables
        is_builtin, builtin_variables = await self._async_get_builtins_variables(
            key,
            scope=scope,
            scope_key=scope_key,
            sys_code=sys_code,
            user_name=user_name,
        )
        if is_builtin:
            return builtin_variables
        executor_factory: Optional[
            DefaultExecutorFactory
        ] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
        if executor_factory:
            return await blocking_func_to_async(
                executor_factory.create(),
                self.get_variables,
                key,
                scope,
                scope_key,
                sys_code,
                user_name,
            )
        else:
            return await blocking_func_to_async_no_executor(
                self.get_variables, key, scope, scope_key, sys_code, user_name
            )

    def _get_builtins_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> Tuple[bool, List[StorageVariables]]:
        """Get builtin variables."""
        if self.system_app is None:
            return False, []
        provider: BuiltinVariablesProvider = self.system_app.get_component(
            key,
            component_type=BuiltinVariablesProvider,
            default_component=None,
        )
        if not provider:
            return False, []
        return True, provider.get_variables(
            key,
            scope=scope,
            scope_key=scope_key,
            sys_code=sys_code,
            user_name=user_name,
        )

    async def _async_get_builtins_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> Tuple[bool, List[StorageVariables]]:
        """Get builtin variables async."""
        if self.system_app is None:
            return False, []
        provider: BuiltinVariablesProvider = self.system_app.get_component(
            key,
            component_type=BuiltinVariablesProvider,
            default_component=None,
        )
        if not provider:
            return False, []
        if not provider.support_async():
            return False, []
        return True, await provider.async_get_variables(
            key,
            scope=scope,
            scope_key=scope_key,
            sys_code=sys_code,
            user_name=user_name,
        )

    @classmethod
    def serialize_value(cls, value: Any) -> str:
        """Serialize the value to a JSON string."""
        value_dict = {"value": value}
        return json.dumps(value_dict, ensure_ascii=False)

    @classmethod
    def deserialize_value(cls, value: str) -> Any:
        """Deserialize the value from a JSON string."""
        value_dict = json.loads(value)
        return value_dict["value"]
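

# Minimal usage sketch (illustrative only; not called by the module):
def _example_secret_round_trip() -> None:
    """Secrets are encrypted with a fresh salt on save and transparently
    decrypted on get."""
    provider = StorageVariablesProvider(encryption=SimpleEncryption("fixed-key"))
    provider.save(
        StorageVariables(
            key="dbgpt.core.secrets",
            name="api_key",
            label="API Key",
            value="sk-123",
            category="secret",
        )
    )
    assert provider.get("${dbgpt.core.secrets:api_key@global}") == "sk-123"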


class BuiltinVariablesProvider(VariablesProvider, ABC):
    """The builtin variables provider.

    You can implement this class to provide builtin variables, such as LLMs,
    agents, datasources, knowledge bases, etc.
    """

    name = "dbgpt_variables_builtin"

    def __init__(self, system_app: Optional[SystemApp] = None):
        """Initialize the builtin variables provider."""
        self.system_app = system_app
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        """Initialize the builtin variables provider."""
        self.system_app = system_app

    def get(
        self,
        full_key: str,
        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
        default_identifier_map: Optional[Dict[str, str]] = None,
    ) -> Any:
        """Query variables from storage."""
        raise NotImplementedError("BuiltinVariablesProvider does not support get.")

    def save(self, variables_item: StorageVariables) -> None:
        """Save variables to storage."""
        raise NotImplementedError("BuiltinVariablesProvider does not support save.")
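

# Minimal subclass sketch (illustrative only; not registered anywhere): a
# builtin provider that serves a fixed variable list, where real providers
# would enumerate LLMs, agents, datasources, etc.
class _ExampleStaticVariablesProvider(BuiltinVariablesProvider):
    """Serve a single static variable for whatever key is requested."""

    name = "dbgpt.example.static"

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        return [
            StorageVariables(
                key=key, name="example", label="Example", value="static-value"
            )
        ]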


def parse_variable(
    variable_str: str,
    enable_escape: bool = True,
) -> Dict[str, Any]:
    """Parse the variable string.

    Examples:
        .. code-block:: python

            cases = [
                {
                    "full_key": "${test_key:test_name@test_scope:test_scope_key}",
                    "expected": {
                        "key": "test_key",
                        "name": "test_name",
                        "scope": "test_scope",
                        "scope_key": "test_scope_key",
                        "sys_code": None,
                        "user_name": None,
                    },
                },
                {
                    "full_key": "${test_key#test_sys_code}",
                    "expected": {
                        "key": "test_key",
                        "name": None,
                        "scope": None,
                        "scope_key": None,
                        "sys_code": "test_sys_code",
                        "user_name": None,
                    },
                },
                {
                    "full_key": "${test_key@:test_scope_key}",
                    "expected": {
                        "key": "test_key",
                        "name": None,
                        "scope": None,
                        "scope_key": "test_scope_key",
                        "sys_code": None,
                        "user_name": None,
                    },
                },
            ]
            for case in cases:
                assert parse_variable(case["full_key"]) == case["expected"]

    Args:
        variable_str (str): The variable string.
        enable_escape (bool): Whether to handle escaped characters.

    Returns:
        Dict[str, Any]: The parsed variable.
    """
    if not variable_str.startswith("${") or not variable_str.endswith("}"):
        raise ValueError(
            "Invalid variable format, must start with '${' and end with '}'"
        )

    # Remove the surrounding ${ and }
    content = variable_str[2:-1]

    # Define placeholders for escaped characters
    placeholders = {
        r"\@": "__ESCAPED_AT__",
        r"\#": "__ESCAPED_HASH__",
        r"\%": "__ESCAPED_PERCENT__",
        r"\:": "__ESCAPED_COLON__",
    }

    if enable_escape:
        # Replace escaped characters with placeholders
        for original, placeholder in placeholders.items():
            content = content.replace(original, placeholder)

    # Initialize the result dictionary
    result: Dict[str, Optional[str]] = {
        "key": None,
        "name": None,
        "scope": None,
        "scope_key": None,
        "sys_code": None,
        "user_name": None,
    }

    # Split the content by special characters
    parts = content.split("@")

    # Parse key and name
    key_name = parts[0].split("#")[0].split("%")[0]
    if ":" in key_name:
        result["key"], result["name"] = key_name.split(":", 1)
    else:
        result["key"] = key_name

    # Parse scope and scope_key
    if len(parts) > 1:
        scope_part = parts[1].split("#")[0].split("%")[0]
        if ":" in scope_part:
            result["scope"], result["scope_key"] = scope_part.split(":", 1)
        else:
            result["scope"] = scope_part

    # Parse sys_code
    if "#" in content:
        result["sys_code"] = content.split("#", 1)[1].split("%")[0]

    # Parse user_name
    if "%" in content:
        result["user_name"] = content.split("%", 1)[1]

    if enable_escape:
        # Replace placeholders back with escaped characters
        reverse_placeholders = {v: k[1:] for k, v in placeholders.items()}
        for key, value in result.items():
            if value:
                for placeholder, original in reverse_placeholders.items():
                    result[key] = result[key].replace(  # type: ignore
                        placeholder, original
                    )

    # Replace empty strings with None
    for key, value in result.items():
        if value == "":
            result[key] = None

    return result
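

# Minimal usage sketch (illustrative only; not called by the module): the
# docstring above covers unescaped strings, so this shows the escape path.
def _example_parse_escaped() -> None:
    """Escaped delimiters are restored as literal characters in the parsed
    fields instead of being treated as separators."""
    parsed = parse_variable("${test_key:test\\:name@test_scope}", enable_escape=True)
    assert parsed["name"] == "test:name"
    assert parsed["scope"] == "test_scope"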


def _is_variable_format(value: str) -> bool:
    if not value.startswith("${") or not value.endswith("}"):
        return False
    return True


def is_variable_string(variable_str: str) -> bool:
    """Check if the given string is a variable string.

    A valid variable string starts with "${", ends with "}", and contains both
    key and name.

    Args:
        variable_str (str): The string to check.

    Returns:
        bool: True if the string is a variable string, False otherwise.
    """
    if not variable_str or not isinstance(variable_str, str):
        return False
    if not _is_variable_format(variable_str):
        return False
    try:
        variable_dict = parse_variable(variable_str)
        if not variable_dict.get("key"):
            return False
        if not variable_dict.get("name"):
            return False
        return True
    except Exception:
        return False


def is_variable_list_string(variable_str: str) -> bool:
    """Check if the given string is a variable list string.

    A valid variable list string starts with "${", ends with "}", contains a
    key, and does not contain a name: it refers to the list of variables that
    share that key.

    Args:
        variable_str (str): The string to check.

    Returns:
        bool: True if the string is a variable list string, False otherwise.
    """
    if not _is_variable_format(variable_str):
        return False
    try:
        variable_dict = parse_variable(variable_str)
        if not variable_dict.get("key"):
            return False
        if variable_dict.get("name"):
            return False
        return True
    except Exception:
        return False
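

# Minimal usage sketch (illustrative only; not called by the module, the LLM
# name is hypothetical):
def _example_variable_string_checks() -> None:
    """A string with key and name is a single variable reference; one with
    only a key refers to the whole variable list under that key."""
    assert is_variable_string("${dbgpt.core.model.llms:gpt-4@global}")
    assert is_variable_list_string("${dbgpt.core.model.llms}")
    # A name disqualifies the string from being a list reference
    assert not is_variable_list_string("${dbgpt.core.model.llms:gpt-4}")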


def build_variable_string(
    variable_dict: Dict[str, Any],
    scope_sig: str = "@",
    sys_code_sig: str = "#",
    user_sig: str = "%",
    kv_sig: str = ":",
    enable_escape: bool = True,
) -> str:
    """Build a variable string from the given dictionary.

    Args:
        variable_dict (Dict[str, Any]): The dictionary containing the variable details.
        scope_sig (str): The scope signature.
        sys_code_sig (str): The sys code signature.
        user_sig (str): The user signature.
        kv_sig (str): The key-value split signature.
        enable_escape (bool): Whether to escape special characters.

    Returns:
        str: The formatted variable string.

    Examples:
        >>> build_variable_string(
        ...     {
        ...         "key": "test_key",
        ...         "name": "test_name",
        ...         "scope": "test_scope",
        ...         "scope_key": "test_scope_key",
        ...         "sys_code": "test_sys_code",
        ...         "user_name": "test_user",
        ...     }
        ... )
        '${test_key:test_name@test_scope:test_scope_key#test_sys_code%test_user}'

        >>> build_variable_string({"key": "test_key", "scope_key": "test_scope_key"})
        '${test_key@:test_scope_key}'

        >>> build_variable_string({"key": "test_key", "sys_code": "test_sys_code"})
        '${test_key#test_sys_code}'

        >>> build_variable_string({"key": "test_key"})
        '${test_key}'
    """
    special_chars = {scope_sig, sys_code_sig, user_sig, kv_sig}
    # Replace None with ""
    new_variable_dict = {key: value or "" for key, value in variable_dict.items()}

    # Check if the variable_dict contains any special characters
    for key, value in new_variable_dict.items():
        if value != "" and any(char in value for char in special_chars):
            if enable_escape:
                # Escape special characters
                new_variable_dict[key] = (
                    value.replace("@", "\\@")
                    .replace("#", "\\#")
                    .replace("%", "\\%")
                    .replace(":", "\\:")
                )
            else:
                raise ValueError(
                    f"{key} contains special characters, error value: {value}, special "
                    f"characters: {special_chars}"
                )

    key = new_variable_dict.get("key", "")
    name = new_variable_dict.get("name", "")
    scope = new_variable_dict.get("scope", "")
    scope_key = new_variable_dict.get("scope_key", "")
    sys_code = new_variable_dict.get("sys_code", "")
    user_name = new_variable_dict.get("user_name", "")

    # Construct the base of the variable string
    variable_str = f"${{{key}"

    # Add name if present
    if name:
        variable_str += f":{name}"

    # Add scope and scope_key if present
    if scope or scope_key:
        variable_str += f"@{scope}"
        if scope_key:
            variable_str += f":{scope_key}"

    # Add sys_code if present
    if sys_code:
        variable_str += f"#{sys_code}"

    # Add user_name if present
    if user_name:
        variable_str += f"%{user_name}"

    # Close the variable string
    variable_str += "}"

    return variable_str