feat(core): AWEL flow 2.0 backend code (#1879)

Co-authored-by: yhjun1026 <460342015@qq.com>
Author: Fangyin Cheng
Date: 2024-08-23 14:57:54 +08:00
Committed by: GitHub
Parent: 3a32344380
Commit: 9502251c08
67 changed files with 8289 additions and 190 deletions


@@ -0,0 +1,834 @@
"""File storage interface."""
import dataclasses
import hashlib
import io
import logging
import os
import uuid
from abc import ABC, abstractmethod
from io import BytesIO
from typing import Any, BinaryIO, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse
import requests
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.tracer import root_tracer, trace
from .storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageError,
StorageInterface,
StorageItem,
)
logger = logging.getLogger(__name__)
_SCHEMA = "dbgpt-fs"
@dataclasses.dataclass
class FileMetadataIdentifier(ResourceIdentifier):
"""File metadata identifier."""
file_id: str
bucket: str
def to_dict(self) -> Dict:
"""Convert the identifier to a dictionary."""
return {"file_id": self.file_id, "bucket": self.bucket}
@property
def str_identifier(self) -> str:
"""Get the string identifier.
Returns:
str: The string identifier
"""
return f"{self.bucket}/{self.file_id}"
@dataclasses.dataclass
class FileMetadata(StorageItem):
"""File metadata for storage."""
file_id: str
bucket: str
file_name: str
file_size: int
storage_type: str
storage_path: str
uri: str
custom_metadata: Dict[str, Any]
file_hash: str
user_name: Optional[str] = None
sys_code: Optional[str] = None
_identifier: FileMetadataIdentifier = dataclasses.field(init=False)
def __post_init__(self):
"""Post init method."""
self._identifier = FileMetadataIdentifier(
file_id=self.file_id, bucket=self.bucket
)
custom_metadata = self.custom_metadata or {}
if not self.user_name:
self.user_name = custom_metadata.get("user_name")
if not self.sys_code:
self.sys_code = custom_metadata.get("sys_code")
@property
def identifier(self) -> ResourceIdentifier:
"""Get the resource identifier."""
return self._identifier
def merge(self, other: "StorageItem") -> None:
"""Merge the metadata with another item."""
if not isinstance(other, FileMetadata):
raise StorageError("Cannot merge different types of items")
self._from_object(other)
def to_dict(self) -> Dict:
"""Convert the metadata to a dictionary."""
return {
"file_id": self.file_id,
"bucket": self.bucket,
"file_name": self.file_name,
"file_size": self.file_size,
"storage_type": self.storage_type,
"storage_path": self.storage_path,
"uri": self.uri,
"custom_metadata": self.custom_metadata,
"file_hash": self.file_hash,
}
def _from_object(self, obj: "FileMetadata") -> None:
self.file_id = obj.file_id
self.bucket = obj.bucket
self.file_name = obj.file_name
self.file_size = obj.file_size
self.storage_type = obj.storage_type
self.storage_path = obj.storage_path
self.uri = obj.uri
self.custom_metadata = obj.custom_metadata
self.file_hash = obj.file_hash
self._identifier = obj._identifier
class FileStorageURI:
"""File storage URI."""
def __init__(
self,
storage_type: str,
bucket: str,
file_id: str,
version: Optional[str] = None,
custom_params: Optional[Dict[str, Any]] = None,
):
"""Initialize the file storage URI."""
self.scheme = _SCHEMA
self.storage_type = storage_type
self.bucket = bucket
self.file_id = file_id
self.version = version
self.custom_params = custom_params or {}
@classmethod
def is_local_file(cls, uri: str) -> bool:
"""Check if the URI is local."""
parsed = urlparse(uri)
if not parsed.scheme or parsed.scheme == "file":
return True
return False
@classmethod
def parse(cls, uri: str) -> "FileStorageURI":
"""Parse the URI string."""
parsed = urlparse(uri)
if parsed.scheme != _SCHEMA:
raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'")
path_parts = parsed.path.strip("/").split("/")
if len(path_parts) < 2:
raise ValueError("Invalid URI path. Must contain bucket and file ID")
storage_type = parsed.netloc
bucket = path_parts[0]
file_id = path_parts[1]
version = path_parts[2] if len(path_parts) > 2 else None
custom_params = parse_qs(parsed.query)
return cls(storage_type, bucket, file_id, version, custom_params)
def __str__(self) -> str:
"""Get the string representation of the URI."""
base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}"
if self.version:
base_uri += f"/{self.version}"
if self.custom_params:
query_string = urlencode(self.custom_params, doseq=True)
base_uri += f"?{query_string}"
return base_uri
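
# Editor's sketch (not part of the original commit): a minimal round trip
# through FileStorageURI with illustrative bucket/file-id values. It shows the
# URI layout: dbgpt-fs://<storage_type>/<bucket>/<file_id>[/<version>][?<params>].
def _example_file_storage_uri_round_trip() -> None:
    uri = FileStorageURI("local", "my-bucket", "abc123", custom_params={"k": "v"})
    assert str(uri) == "dbgpt-fs://local/my-bucket/abc123?k=v"
    parsed = FileStorageURI.parse(str(uri))
    assert parsed.storage_type == "local"
    assert parsed.bucket == "my-bucket"
    assert parsed.file_id == "abc123"
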
class StorageBackend(ABC):
"""Storage backend interface."""
storage_type: str = "__base__"
@abstractmethod
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
"""Save the file data to the storage backend.
Args:
bucket (str): The bucket name
file_id (str): The file ID
file_data (BinaryIO): The file data
Returns:
str: The storage path
"""
@abstractmethod
def load(self, fm: FileMetadata) -> BinaryIO:
"""Load the file data from the storage backend.
Args:
fm (FileMetadata): The file metadata
Returns:
BinaryIO: The file data
"""
@abstractmethod
def delete(self, fm: FileMetadata) -> bool:
"""Delete the file data from the storage backend.
Args:
fm (FileMetadata): The file metadata
Returns:
bool: True if the file was deleted, False otherwise
"""
@property
@abstractmethod
def save_chunk_size(self) -> int:
"""Get the save chunk size.
Returns:
int: The save chunk size
"""
class LocalFileStorage(StorageBackend):
"""Local file storage backend."""
storage_type: str = "local"
def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024):
"""Initialize the local file storage backend."""
self.base_path = base_path
self._save_chunk_size = save_chunk_size
os.makedirs(self.base_path, exist_ok=True)
@property
def save_chunk_size(self) -> int:
"""Get the save chunk size."""
return self._save_chunk_size
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
"""Save the file data to the local storage backend."""
bucket_path = os.path.join(self.base_path, bucket)
os.makedirs(bucket_path, exist_ok=True)
file_path = os.path.join(bucket_path, file_id)
with open(file_path, "wb") as f:
while True:
chunk = file_data.read(self.save_chunk_size)
if not chunk:
break
f.write(chunk)
return file_path
def load(self, fm: FileMetadata) -> BinaryIO:
"""Load the file data from the local storage backend."""
bucket_path = os.path.join(self.base_path, fm.bucket)
file_path = os.path.join(bucket_path, fm.file_id)
return open(file_path, "rb") # noqa: SIM115
def delete(self, fm: FileMetadata) -> bool:
"""Delete the file data from the local storage backend."""
bucket_path = os.path.join(self.base_path, fm.bucket)
file_path = os.path.join(bucket_path, fm.file_id)
if os.path.exists(file_path):
os.remove(file_path)
return True
return False
class FileStorageSystem:
"""File storage system."""
def __init__(
self,
storage_backends: Dict[str, StorageBackend],
metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None,
check_hash: bool = True,
):
"""Initialize the file storage system."""
metadata_storage = metadata_storage or InMemoryStorage()
self.storage_backends = storage_backends
self.metadata_storage = metadata_storage
self.check_hash = check_hash
self._save_chunk_size = min(
backend.save_chunk_size for backend in storage_backends.values()
)
def _calculate_file_hash(self, file_data: BinaryIO) -> str:
"""Calculate the MD5 hash of the file data."""
if not self.check_hash:
return "-1"
hasher = hashlib.md5()
file_data.seek(0)
while chunk := file_data.read(self._save_chunk_size):
hasher.update(chunk)
file_data.seek(0)
return hasher.hexdigest()
@trace("file_storage_system.save_file")
def save_file(
self,
bucket: str,
file_name: str,
file_data: BinaryIO,
storage_type: str,
custom_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save the file data to the storage backend."""
file_id = str(uuid.uuid4())
backend = self.storage_backends.get(storage_type)
if not backend:
raise ValueError(f"Unsupported storage type: {storage_type}")
with root_tracer.start_span(
"file_storage_system.save_file.backend_save",
metadata={
"bucket": bucket,
"file_id": file_id,
"file_name": file_name,
"storage_type": storage_type,
},
):
storage_path = backend.save(bucket, file_id, file_data)
file_data.seek(0, 2) # Move to the end of the file
file_size = file_data.tell() # Get the file size
file_data.seek(0) # Reset file pointer
# filter None value
custom_metadata = (
{k: v for k, v in custom_metadata.items() if v is not None}
if custom_metadata
else {}
)
with root_tracer.start_span(
"file_storage_system.save_file.calculate_hash",
):
file_hash = self._calculate_file_hash(file_data)
uri = FileStorageURI(
storage_type, bucket, file_id, custom_params=custom_metadata
)
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name=file_name,
file_size=file_size,
storage_type=storage_type,
storage_path=storage_path,
uri=str(uri),
custom_metadata=custom_metadata,
file_hash=file_hash,
)
self.metadata_storage.save(metadata)
return str(uri)
@trace("file_storage_system.get_file")
def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
"""Get the file data from the storage backend."""
if FileStorageURI.is_local_file(uri):
local_file_name = uri.split("/")[-1]
if not os.path.exists(uri):
raise FileNotFoundError(f"File not found: {uri}")
dummy_metadata = FileMetadata(
file_id=local_file_name,
bucket="dummy_bucket",
file_name=local_file_name,
file_size=-1,
storage_type="local",
storage_path=uri,
uri=uri,
custom_metadata={},
file_hash="",
)
logger.info(f"Reading local file: {uri}")
return open(uri, "rb"), dummy_metadata # noqa: SIM115
parsed_uri = FileStorageURI.parse(uri)
metadata = self.metadata_storage.load(
FileMetadataIdentifier(
file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
),
FileMetadata,
)
if not metadata:
raise FileNotFoundError(f"No metadata found for URI: {uri}")
backend = self.storage_backends.get(metadata.storage_type)
if not backend:
raise ValueError(f"Unsupported storage type: {metadata.storage_type}")
with root_tracer.start_span(
"file_storage_system.get_file.backend_load",
metadata={
"bucket": metadata.bucket,
"file_id": metadata.file_id,
"file_name": metadata.file_name,
"storage_type": metadata.storage_type,
},
):
file_data = backend.load(metadata)
with root_tracer.start_span(
"file_storage_system.get_file.verify_hash",
):
calculated_hash = self._calculate_file_hash(file_data)
if calculated_hash != "-1" and calculated_hash != metadata.file_hash:
raise ValueError("File integrity check failed. Hash mismatch.")
return file_data, metadata
def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]:
"""Get the file metadata.
Args:
bucket (str): The bucket name
file_id (str): The file ID
Returns:
Optional[FileMetadata]: The file metadata
"""
fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket)
return self.metadata_storage.load(fid, FileMetadata)
def delete_file(self, uri: str) -> bool:
"""Delete the file data from the storage backend.
Args:
uri (str): The file URI
Returns:
bool: True if the file was deleted, False otherwise
"""
parsed_uri = FileStorageURI.parse(uri)
fid = FileMetadataIdentifier(
file_id=parsed_uri.file_id, bucket=parsed_uri.bucket
)
metadata = self.metadata_storage.load(fid, FileMetadata)
if not metadata:
return False
backend = self.storage_backends.get(metadata.storage_type)
if not backend:
raise ValueError(f"Unsupported storage type: {metadata.storage_type}")
if backend.delete(metadata):
try:
self.metadata_storage.delete(fid)
return True
            except Exception as e:
                # If the metadata deletion fails, log the error and return False
                logger.warning("Failed to delete metadata for %s: %s", uri, e)
                return False
return False
def list_files(
self, bucket: str, filters: Optional[Dict[str, Any]] = None
) -> List[FileMetadata]:
"""List the files in the bucket."""
filters = filters or {}
filters["bucket"] = bucket
return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata)
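
# Editor's sketch (not part of the original commit): saving an in-memory
# payload through a LocalFileStorage backend and reading it back. The bucket
# name and temporary directory are placeholders.
def _example_file_storage_system_usage() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        system = FileStorageSystem({"local": LocalFileStorage(tmp_dir)})
        uri = system.save_file(
            bucket="demo",
            file_name="hello.txt",
            file_data=io.BytesIO(b"hello"),
            storage_type="local",
        )
        file_data, metadata = system.get_file(uri)
        assert file_data.read() == b"hello"
        assert metadata.file_name == "hello.txt"
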
class FileStorageClient(BaseComponent):
"""File storage client component."""
name = ComponentType.FILE_STORAGE_CLIENT.value
def __init__(
self,
system_app: Optional[SystemApp] = None,
storage_system: Optional[FileStorageSystem] = None,
):
"""Initialize the file storage client."""
super().__init__(system_app=system_app)
if not storage_system:
from pathlib import Path
base_path = Path.home() / ".cache" / "dbgpt" / "files"
storage_system = FileStorageSystem(
{
LocalFileStorage.storage_type: LocalFileStorage(
base_path=str(base_path)
)
}
)
self.system_app = system_app
self._storage_system = storage_system
def init_app(self, system_app: SystemApp):
"""Initialize the application."""
self.system_app = system_app
@property
def storage_system(self) -> FileStorageSystem:
"""Get the file storage system."""
if not self._storage_system:
raise ValueError("File storage system not initialized")
return self._storage_system
def upload_file(
self,
bucket: str,
file_path: str,
storage_type: str,
custom_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Upload a file to the storage system.
Args:
bucket (str): The bucket name
file_path (str): The file path
storage_type (str): The storage type
custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
None.
Returns:
str: The file URI
"""
with open(file_path, "rb") as file:
return self.save_file(
bucket, os.path.basename(file_path), file, storage_type, custom_metadata
)
def save_file(
self,
bucket: str,
file_name: str,
file_data: BinaryIO,
storage_type: str,
custom_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save the file data to the storage system.
Args:
bucket (str): The bucket name
file_name (str): The file name
file_data (BinaryIO): The file data
storage_type (str): The storage type
custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to
None.
Returns:
str: The file URI
"""
return self.storage_system.save_file(
bucket, file_name, file_data, storage_type, custom_metadata
)
def download_file(self, uri: str, destination_path: str) -> None:
"""Download a file from the storage system.
Args:
uri (str): The file URI
            destination_path (str): The destination path
Raises:
FileNotFoundError: If the file is not found
"""
file_data, _ = self.storage_system.get_file(uri)
with open(destination_path, "wb") as f:
f.write(file_data.read())
def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]:
"""Get the file data from the storage system.
Args:
uri (str): The file URI
Returns:
Tuple[BinaryIO, FileMetadata]: The file data and metadata
"""
return self.storage_system.get_file(uri)
def get_file_by_id(
self, bucket: str, file_id: str
) -> Tuple[BinaryIO, FileMetadata]:
"""Get the file data from the storage system by ID.
Args:
bucket (str): The bucket name
file_id (str): The file ID
Returns:
Tuple[BinaryIO, FileMetadata]: The file data and metadata
"""
metadata = self.storage_system.get_file_metadata(bucket, file_id)
if not metadata:
raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
return self.get_file(metadata.uri)
def delete_file(self, uri: str) -> bool:
"""Delete the file data from the storage system.
Args:
uri (str): The file URI
Returns:
bool: True if the file was deleted, False otherwise
"""
return self.storage_system.delete_file(uri)
def delete_file_by_id(self, bucket: str, file_id: str) -> bool:
"""Delete the file data from the storage system by ID.
Args:
bucket (str): The bucket name
file_id (str): The file ID
Returns:
bool: True if the file was deleted, False otherwise
"""
metadata = self.storage_system.get_file_metadata(bucket, file_id)
if not metadata:
raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}")
return self.delete_file(metadata.uri)
def list_files(
self, bucket: str, filters: Optional[Dict[str, Any]] = None
) -> List[FileMetadata]:
"""List the files in the bucket.
Args:
bucket (str): The bucket name
filters (Dict[str, Any], optional): Filters. Defaults to None.
Returns:
List[FileMetadata]: The list of file metadata
"""
return self.storage_system.list_files(bucket, filters)
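
# Editor's sketch (not part of the original commit): the client-level flow of
# uploading a file from disk and downloading it back to a second path. All
# paths and names here are placeholders.
def _example_file_storage_client_usage() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        src_path = os.path.join(tmp_dir, "src.txt")
        with open(src_path, "wb") as f:
            f.write(b"payload")
        client = FileStorageClient(
            storage_system=FileStorageSystem({"local": LocalFileStorage(tmp_dir)})
        )
        uri = client.upload_file(
            bucket="demo", file_path=src_path, storage_type="local"
        )
        dst_path = os.path.join(tmp_dir, "dst.txt")
        client.download_file(uri, dst_path)
        with open(dst_path, "rb") as f:
            assert f.read() == b"payload"
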
class SimpleDistributedStorage(StorageBackend):
"""Simple distributed storage backend."""
storage_type: str = "distributed"
def __init__(
self,
node_address: str,
local_storage_path: str,
save_chunk_size: int = 1024 * 1024,
transfer_chunk_size: int = 1024 * 1024,
transfer_timeout: int = 360,
api_prefix: str = "/api/v2/serve/file/files",
):
"""Initialize the simple distributed storage backend."""
self.node_address = node_address
self.local_storage_path = local_storage_path
os.makedirs(self.local_storage_path, exist_ok=True)
self._save_chunk_size = save_chunk_size
self._transfer_chunk_size = transfer_chunk_size
self._transfer_timeout = transfer_timeout
self._api_prefix = api_prefix
@property
def save_chunk_size(self) -> int:
"""Get the save chunk size."""
return self._save_chunk_size
def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str:
node_id = hashlib.md5(node_address.encode()).hexdigest()
return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}")
def _parse_node_address(self, fm: FileMetadata) -> str:
storage_path = fm.storage_path
if not storage_path.startswith("distributed://"):
raise ValueError("Invalid storage path")
return storage_path.split("//")[1].split("/")[0]
def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
"""Save the file data to the distributed storage backend.
Just save the file locally.
Args:
bucket (str): The bucket name
file_id (str): The file ID
file_data (BinaryIO): The file data
Returns:
str: The storage path
"""
file_path = self._get_file_path(bucket, file_id, self.node_address)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
while True:
chunk = file_data.read(self.save_chunk_size)
if not chunk:
break
f.write(chunk)
return f"distributed://{self.node_address}/{bucket}/{file_id}"
def load(self, fm: FileMetadata) -> BinaryIO:
"""Load the file data from the distributed storage backend.
If the file is stored on the local node, load it from the local storage.
Args:
fm (FileMetadata): The file metadata
Returns:
BinaryIO: The file data
"""
file_id = fm.file_id
bucket = fm.bucket
node_address = self._parse_node_address(fm)
file_path = self._get_file_path(bucket, file_id, node_address)
# TODO: check if the file is cached in local storage
if node_address == self.node_address:
if os.path.exists(file_path):
return open(file_path, "rb") # noqa: SIM115
else:
raise FileNotFoundError(f"File {file_id} not found on the local node")
else:
response = requests.get(
f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
timeout=self._transfer_timeout,
stream=True,
)
response.raise_for_status()
# TODO: cache the file in local storage
return StreamedBytesIO(
response.iter_content(chunk_size=self._transfer_chunk_size)
)
def delete(self, fm: FileMetadata) -> bool:
"""Delete the file data from the distributed storage backend.
If the file is stored on the local node, delete it from the local storage.
If the file is stored on a remote node, send a delete request to the remote
node.
Args:
fm (FileMetadata): The file metadata
Returns:
bool: True if the file was deleted, False otherwise
"""
file_id = fm.file_id
bucket = fm.bucket
node_address = self._parse_node_address(fm)
file_path = self._get_file_path(bucket, file_id, node_address)
if node_address == self.node_address:
if os.path.exists(file_path):
os.remove(file_path)
return True
return False
else:
try:
response = requests.delete(
f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}",
timeout=self._transfer_timeout,
)
response.raise_for_status()
return True
except Exception:
return False
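
# Editor's sketch (not part of the original commit): how the distributed
# backend routes requests. The storage path encodes the owning node, and
# load()/delete() compare it against the local node address to choose between
# local disk access and an HTTP call. All addresses here are placeholders.
def _example_distributed_storage_routing() -> None:
    storage = SimpleDistributedStorage("127.0.0.1:8000", "/tmp/dbgpt-fs-demo")
    fm = FileMetadata(
        file_id="abc123",
        bucket="demo",
        file_name="demo.txt",
        file_size=-1,
        storage_type="distributed",
        storage_path="distributed://127.0.0.2:8000/demo/abc123",
        uri="distributed://127.0.0.2:8000/demo/abc123",
        custom_metadata={},
        file_hash="",
    )
    # The file lives on 127.0.0.2:8000, not on this node, so load() would
    # fetch it over HTTP from the owning node.
    assert storage._parse_node_address(fm) != storage.node_address
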
class StreamedBytesIO(io.BytesIO):
"""A BytesIO subclass that can be used with streaming responses.
Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13
"""
def __init__(self, request_iterator):
"""Initialize the StreamedBytesIO instance."""
super().__init__()
self._bytes = BytesIO()
self._iterator = request_iterator
def _load_all(self):
self._bytes.seek(0, io.SEEK_END)
for chunk in self._iterator:
self._bytes.write(chunk)
def _load_until(self, goal_position):
current_position = self._bytes.seek(0, io.SEEK_END)
while current_position < goal_position:
try:
current_position += self._bytes.write(next(self._iterator))
except StopIteration:
break
def tell(self) -> int:
"""Get the current position."""
return self._bytes.tell()
def read(self, size: Optional[int] = None) -> bytes:
"""Read the data from the stream.
Args:
size (Optional[int], optional): The number of bytes to read. Defaults to
None.
Returns:
bytes: The read data
"""
left_off_at = self._bytes.tell()
if size is None:
self._load_all()
else:
goal_position = left_off_at + size
self._load_until(goal_position)
self._bytes.seek(left_off_at)
return self._bytes.read(size)
def seek(self, position: int, whence: int = io.SEEK_SET):
"""Seek to a position in the stream.
Args:
position (int): The position
            whence (int, optional): The reference point. Defaults to io.SEEK_SET.
Raises:
ValueError: If the reference point is invalid
"""
        if whence == io.SEEK_END:
            # Load the whole stream first so that the end position is known.
            self._load_all()
        self._bytes.seek(position, whence)
def __enter__(self):
"""Enter the context manager."""
return self
def __exit__(self, ext_type, value, tb):
"""Exit the context manager."""
self._bytes.close()
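
# Editor's sketch (not part of the original commit): StreamedBytesIO consumes
# its chunk iterator lazily. read(n) pulls only as many chunks as needed to
# reach n bytes; read() drains the rest.
def _example_streamed_bytes_io() -> None:
    chunks = iter([b"hello ", b"streamed ", b"world"])
    with StreamedBytesIO(chunks) as stream:
        assert stream.read(5) == b"hello"
        assert stream.read() == b" streamed world"
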


@@ -24,6 +24,7 @@ from dbgpt.core.awel.flow import (
OperatorType,
Parameter,
ViewMetadata,
ui,
)
from dbgpt.core.interface.llm import (
LLMClient,
@@ -69,6 +70,10 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest]):
optional=True,
default=None,
description=_("The temperature of the model request."),
ui=ui.UISlider(
show_input=True,
attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1),
),
),
Parameter.build_from(
_("Max New Tokens"),


@@ -1,4 +1,5 @@
"""The prompt operator."""
from abc import ABC
from typing import Any, Dict, List, Optional, Union
@@ -18,6 +19,7 @@ from dbgpt.core.awel.flow import (
ResourceCategory,
ViewMetadata,
register_resource,
ui,
)
from dbgpt.core.interface.message import BaseMessage
from dbgpt.core.interface.operators.llm_operator import BaseLLM
@@ -48,6 +50,7 @@ from dbgpt.util.i18n_utils import _
optional=True,
default="You are a helpful AI Assistant.",
description=_("The system message."),
ui=ui.DefaultUITextArea(),
),
Parameter.build_from(
label=_("Message placeholder"),
@@ -65,6 +68,7 @@ from dbgpt.util.i18n_utils import _
default="{user_input}",
placeholder="{user_input}",
description=_("The human message."),
ui=ui.DefaultUITextArea(),
),
],
)


@@ -3,13 +3,14 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.i18n_utils import _
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.serialization.json_serialization import JsonSerializer
from ..awel.flow import Parameter, ResourceCategory, register_resource
@PublicAPI(stability="beta")
class ResourceIdentifier(Serializable, ABC):


@@ -0,0 +1,506 @@
import hashlib
import io
import os
from unittest import mock
import pytest
from ..file import (
FileMetadata,
FileMetadataIdentifier,
FileStorageClient,
FileStorageSystem,
InMemoryStorage,
LocalFileStorage,
SimpleDistributedStorage,
)
@pytest.fixture
def temp_test_file_dir(tmpdir):
return str(tmpdir)
@pytest.fixture
def temp_storage_path(tmpdir):
return str(tmpdir)
@pytest.fixture
def local_storage_backend(temp_storage_path):
return LocalFileStorage(temp_storage_path)
@pytest.fixture
def distributed_storage_backend(temp_storage_path):
node_address = "127.0.0.1:8000"
return SimpleDistributedStorage(node_address, temp_storage_path)
@pytest.fixture
def file_storage_system(local_storage_backend):
backends = {"local": local_storage_backend}
metadata_storage = InMemoryStorage()
return FileStorageSystem(backends, metadata_storage)
@pytest.fixture
def file_storage_client(file_storage_system):
return FileStorageClient(storage_system=file_storage_system)
@pytest.fixture
def sample_file_path(temp_test_file_dir):
file_path = os.path.join(temp_test_file_dir, "sample.txt")
with open(file_path, "wb") as f:
f.write(b"Sample file content")
return file_path
@pytest.fixture
def sample_file_data():
return io.BytesIO(b"Sample file content for distributed storage")
def test_save_file(file_storage_client, sample_file_path):
bucket = "test-bucket"
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
assert uri.startswith("dbgpt-fs://local/test-bucket/")
assert os.path.exists(sample_file_path)
def test_get_file(file_storage_client, sample_file_path):
bucket = "test-bucket"
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
file_data, metadata = file_storage_client.storage_system.get_file(uri)
assert file_data.read() == b"Sample file content"
assert metadata.file_name == "sample.txt"
assert metadata.bucket == bucket
def test_delete_file(file_storage_client, sample_file_path):
bucket = "test-bucket"
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
assert len(file_storage_client.list_files(bucket=bucket)) == 1
result = file_storage_client.delete_file(uri)
assert result is True
assert len(file_storage_client.list_files(bucket=bucket)) == 0
def test_list_files(file_storage_client, sample_file_path):
bucket = "test-bucket"
uri1 = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
files = file_storage_client.list_files(bucket=bucket)
assert len(files) == 1
def test_save_file_unsupported_storage(file_storage_system, sample_file_path):
bucket = "test-bucket"
with pytest.raises(ValueError):
file_storage_system.save_file(
bucket=bucket,
file_name="unsupported.txt",
file_data=io.BytesIO(b"Unsupported storage"),
storage_type="unsupported",
)
def test_get_file_not_found(file_storage_system):
with pytest.raises(FileNotFoundError):
file_storage_system.get_file("dbgpt-fs://local/test-bucket/nonexistent")
def test_delete_file_not_found(file_storage_system):
result = file_storage_system.delete_file("dbgpt-fs://local/test-bucket/nonexistent")
assert result is False
def test_metadata_management(file_storage_system):
bucket = "test-bucket"
file_id = "test_file"
metadata = file_storage_system.metadata_storage.save(
FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=100,
storage_type="local",
storage_path="/path/to/test.txt",
uri="dbgpt-fs://local/test-bucket/test_file",
custom_metadata={"key": "value"},
file_hash="hash",
)
)
loaded_metadata = file_storage_system.metadata_storage.load(
FileMetadataIdentifier(file_id=file_id, bucket=bucket), FileMetadata
)
assert loaded_metadata.file_name == "test.txt"
assert loaded_metadata.custom_metadata["key"] == "value"
assert loaded_metadata.bucket == bucket
def test_concurrent_save_and_delete(file_storage_client, sample_file_path):
bucket = "test-bucket"
    # Simulate interleaved save and delete operations (run sequentially here)
def save_file():
return file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
def delete_file(uri):
return file_storage_client.delete_file(uri)
uri = save_file()
    # Interleave a second save with a delete of the first upload
save_file()
delete_file(uri)
assert len(file_storage_client.list_files(bucket=bucket)) == 1
def test_large_file_handling(file_storage_client, temp_storage_path):
bucket = "test-bucket"
large_file_path = os.path.join(temp_storage_path, "large_sample.bin")
with open(large_file_path, "wb") as f:
f.write(os.urandom(10 * 1024 * 1024)) # 10 MB file
uri = file_storage_client.upload_file(
bucket=bucket,
file_path=large_file_path,
storage_type="local",
custom_metadata={"description": "Large file test"},
)
file_data, metadata = file_storage_client.storage_system.get_file(uri)
    with open(large_file_path, "rb") as f:
        assert file_data.read() == f.read()
assert metadata.file_name == "large_sample.bin"
assert metadata.bucket == bucket
def test_file_hash_verification_success(file_storage_client, sample_file_path):
bucket = "test-bucket"
    # Upload the file and get its URI
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
file_data, metadata = file_storage_client.storage_system.get_file(uri)
file_hash = metadata.file_hash
calculated_hash = file_storage_client.storage_system._calculate_file_hash(file_data)
assert (
file_hash == calculated_hash
), "File hash should match after saving and loading"
def test_file_hash_verification_failure(file_storage_client, sample_file_path):
bucket = "test-bucket"
    # Upload the file and get its URI
uri = file_storage_client.upload_file(
bucket=bucket, file_path=sample_file_path, storage_type="local"
)
# Modify the file content manually to simulate file tampering
storage_system = file_storage_client.storage_system
metadata = storage_system.metadata_storage.load(
FileMetadataIdentifier(file_id=uri.split("/")[-1], bucket=bucket), FileMetadata
)
with open(metadata.storage_path, "wb") as f:
f.write(b"Tampered content")
# Get file should raise an exception due to hash mismatch
with pytest.raises(ValueError, match="File integrity check failed. Hash mismatch."):
storage_system.get_file(uri)
def test_file_isolation_across_buckets(file_storage_client, sample_file_path):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload the same file to two different buckets
uri1 = file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
uri2 = file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# Verify both URIs are different and point to different files
assert uri1 != uri2
file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1)
file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
assert file_data1.read() == b"Sample file content"
assert file_data2.read() == b"Sample file content"
assert metadata1.bucket == bucket1
assert metadata2.bucket == bucket2
def test_list_files_in_specific_bucket(file_storage_client, sample_file_path):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload a file to both buckets
file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# List files in bucket1 and bucket2
files_in_bucket1 = file_storage_client.list_files(bucket=bucket1)
files_in_bucket2 = file_storage_client.list_files(bucket=bucket2)
assert len(files_in_bucket1) == 1
assert len(files_in_bucket2) == 1
assert files_in_bucket1[0].bucket == bucket1
assert files_in_bucket2[0].bucket == bucket2
def test_delete_file_in_one_bucket_does_not_affect_other_bucket(
file_storage_client, sample_file_path
):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload the same file to two different buckets
uri1 = file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
uri2 = file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# Delete the file in bucket1
file_storage_client.delete_file(uri1)
# Check that the file in bucket1 is deleted
assert len(file_storage_client.list_files(bucket=bucket1)) == 0
# Check that the file in bucket2 is still there
assert len(file_storage_client.list_files(bucket=bucket2)) == 1
file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
assert file_data2.read() == b"Sample file content"
def test_file_hash_verification_in_different_buckets(
file_storage_client, sample_file_path
):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload the file to both buckets
uri1 = file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
uri2 = file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1)
file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2)
# Verify that file hashes are the same for the same content
file_hash1 = file_storage_client.storage_system._calculate_file_hash(file_data1)
file_hash2 = file_storage_client.storage_system._calculate_file_hash(file_data2)
assert file_hash1 == metadata1.file_hash
assert file_hash2 == metadata2.file_hash
assert file_hash1 == file_hash2
def test_file_download_from_different_buckets(
file_storage_client, sample_file_path, temp_storage_path
):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload the file to both buckets
uri1 = file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
uri2 = file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# Download files to different locations
download_path1 = os.path.join(temp_storage_path, "downloaded_bucket1.txt")
download_path2 = os.path.join(temp_storage_path, "downloaded_bucket2.txt")
file_storage_client.download_file(uri1, download_path1)
file_storage_client.download_file(uri2, download_path2)
# Verify contents of downloaded files
    with open(download_path1, "rb") as f1, open(download_path2, "rb") as f2:
        assert f1.read() == b"Sample file content"
        assert f2.read() == b"Sample file content"
def test_delete_all_files_in_bucket(file_storage_client, sample_file_path):
bucket1 = "bucket1"
bucket2 = "bucket2"
# Upload files to both buckets
file_storage_client.upload_file(
bucket=bucket1, file_path=sample_file_path, storage_type="local"
)
file_storage_client.upload_file(
bucket=bucket2, file_path=sample_file_path, storage_type="local"
)
# Delete all files in bucket1
for file in file_storage_client.list_files(bucket=bucket1):
file_storage_client.delete_file(file.uri)
# Verify bucket1 is empty
assert len(file_storage_client.list_files(bucket=bucket1)) == 0
# Verify bucket2 still has files
assert len(file_storage_client.list_files(bucket=bucket2)) == 1
def test_simple_distributed_storage_save_file(
distributed_storage_backend, sample_file_data, temp_storage_path
):
bucket = "test-bucket"
file_id = "test_file"
file_path = distributed_storage_backend.save(bucket, file_id, sample_file_data)
expected_path = os.path.join(
temp_storage_path,
bucket,
f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}",
)
assert file_path == f"distributed://127.0.0.1:8000/{bucket}/{file_id}"
assert os.path.exists(expected_path)
def test_simple_distributed_storage_load_file_local(
distributed_storage_backend, sample_file_data
):
bucket = "test-bucket"
file_id = "test_file"
distributed_storage_backend.save(bucket, file_id, sample_file_data)
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=len(sample_file_data.getvalue()),
storage_type="distributed",
storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
custom_metadata={},
file_hash="hash",
)
file_data = distributed_storage_backend.load(metadata)
assert file_data.read() == b"Sample file content for distributed storage"
@mock.patch("requests.get")
def test_simple_distributed_storage_load_file_remote(
mock_get, distributed_storage_backend, sample_file_data
):
bucket = "test-bucket"
file_id = "test_file"
remote_node_address = "127.0.0.2:8000"
# Mock the response from remote node
mock_response = mock.Mock()
mock_response.iter_content = mock.Mock(
return_value=iter([b"Sample file content for distributed storage"])
)
mock_response.raise_for_status = mock.Mock(return_value=None)
mock_get.return_value = mock_response
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=len(sample_file_data.getvalue()),
storage_type="distributed",
storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}",
uri=f"distributed://{remote_node_address}/{bucket}/{file_id}",
custom_metadata={},
file_hash="hash",
)
file_data = distributed_storage_backend.load(metadata)
assert file_data.read() == b"Sample file content for distributed storage"
mock_get.assert_called_once_with(
f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}",
stream=True,
timeout=360,
)
def test_simple_distributed_storage_delete_file_local(
distributed_storage_backend, sample_file_data, temp_storage_path
):
bucket = "test-bucket"
file_id = "test_file"
distributed_storage_backend.save(bucket, file_id, sample_file_data)
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=len(sample_file_data.getvalue()),
storage_type="distributed",
storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}",
custom_metadata={},
file_hash="hash",
)
result = distributed_storage_backend.delete(metadata)
file_path = os.path.join(
temp_storage_path,
bucket,
f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}",
)
assert result is True
assert not os.path.exists(file_path)
@mock.patch("requests.delete")
def test_simple_distributed_storage_delete_file_remote(
mock_delete, distributed_storage_backend, sample_file_data
):
bucket = "test-bucket"
file_id = "test_file"
remote_node_address = "127.0.0.2:8000"
mock_response = mock.Mock()
mock_response.raise_for_status = mock.Mock(return_value=None)
mock_delete.return_value = mock_response
metadata = FileMetadata(
file_id=file_id,
bucket=bucket,
file_name="test.txt",
file_size=len(sample_file_data.getvalue()),
storage_type="distributed",
storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}",
uri=f"distributed://{remote_node_address}/{bucket}/{file_id}",
custom_metadata={},
file_hash="hash",
)
result = distributed_storage_backend.delete(metadata)
assert result is True
mock_delete.assert_called_once_with(
f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}",
timeout=360,
)


@@ -0,0 +1,327 @@
import base64
import os
from itertools import product
from cryptography.fernet import Fernet
from ..variables import (
FernetEncryption,
InMemoryStorage,
SimpleEncryption,
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
build_variable_string,
parse_variable,
)
def test_fernet_encryption():
key = Fernet.generate_key()
encryption = FernetEncryption(key)
new_encryption = FernetEncryption(key)
data = "test_data"
salt = "test_salt"
encrypted_data = encryption.encrypt(data, salt)
assert encrypted_data != data
decrypted_data = encryption.decrypt(encrypted_data, salt)
assert decrypted_data == data
assert decrypted_data == new_encryption.decrypt(encrypted_data, salt)
def test_simple_encryption():
key = base64.b64encode(os.urandom(32)).decode()
encryption = SimpleEncryption(key)
data = "test_data"
salt = "test_salt"
encrypted_data = encryption.encrypt(data, salt)
assert encrypted_data != data
decrypted_data = encryption.decrypt(encrypted_data, salt)
assert decrypted_data == data
def test_storage_variables_provider():
storage = InMemoryStorage()
encryption = SimpleEncryption()
provider = StorageVariablesProvider(storage, encryption)
full_key = "${key:name@global}"
value = "secret_value"
value_type = "str"
label = "test_label"
id = VariablesIdentifier.from_str_identifier(full_key)
provider.save(
StorageVariables.from_identifier(
id, value, value_type, label, category="secret"
)
)
loaded_variable_value = provider.get(full_key)
assert loaded_variable_value == value
def test_variables_identifier():
full_key = "${key:name@global:scope_key#sys_code%user_name}"
identifier = VariablesIdentifier.from_str_identifier(full_key)
assert identifier.key == "key"
assert identifier.name == "name"
assert identifier.scope == "global"
assert identifier.scope_key == "scope_key"
assert identifier.sys_code == "sys_code"
assert identifier.user_name == "user_name"
str_identifier = identifier.str_identifier
assert str_identifier == full_key
def test_storage_variables():
key = "test_key"
name = "test_name"
label = "test_label"
value = "test_value"
value_type = "str"
category = "common"
scope = "global"
storage_variable = StorageVariables(
key=key,
name=name,
label=label,
value=value,
value_type=value_type,
category=category,
scope=scope,
)
assert storage_variable.key == key
assert storage_variable.name == name
assert storage_variable.label == label
assert storage_variable.value == value
assert storage_variable.value_type == value_type
assert storage_variable.category == category
assert storage_variable.scope == scope
dict_representation = storage_variable.to_dict()
assert dict_representation["key"] == key
assert dict_representation["name"] == name
assert dict_representation["label"] == label
assert dict_representation["value"] == value
assert dict_representation["value_type"] == value_type
assert dict_representation["category"] == category
assert dict_representation["scope"] == scope
def generate_test_cases(enable_escape=False):
# Define possible values for each field, including special characters for escaping
_EMPTY_ = "___EMPTY___"
fields = {
"name": [
None,
"test_name",
"test:name" if enable_escape else _EMPTY_,
"test::name" if enable_escape else _EMPTY_,
"test#name" if enable_escape else _EMPTY_,
"test##name" if enable_escape else _EMPTY_,
"test::@@@#22name" if enable_escape else _EMPTY_,
],
"scope": [
None,
"test_scope",
"test@scope" if enable_escape else _EMPTY_,
"test@@scope" if enable_escape else _EMPTY_,
"test:scope" if enable_escape else _EMPTY_,
"test:#:scope" if enable_escape else _EMPTY_,
],
"scope_key": [
None,
"test_scope_key",
"test:scope_key" if enable_escape else _EMPTY_,
],
"sys_code": [
None,
"test_sys_code",
"test#sys_code" if enable_escape else _EMPTY_,
],
"user_name": [
None,
"test_user_name",
"test%user_name" if enable_escape else _EMPTY_,
],
}
# Remove empty values
fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()}
# Generate all possible combinations
combinations = product(*fields.values())
test_cases = []
for combo in combinations:
name, scope, scope_key, sys_code, user_name = combo
var_str = build_variable_string(
{
"key": "test_key",
"name": name,
"scope": scope,
"scope_key": scope_key,
"sys_code": sys_code,
"user_name": user_name,
},
enable_escape=enable_escape,
)
# Construct the expected output
expected = {
"key": "test_key",
"name": name,
"scope": scope,
"scope_key": scope_key,
"sys_code": sys_code,
"user_name": user_name,
}
test_cases.append((var_str, expected, enable_escape))
return test_cases
def test_parse_variables():
# Run test cases without escape
test_cases = generate_test_cases(enable_escape=False)
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
result = parse_variable(input_str, enable_escape=enable_escape)
assert result == expected_output, f"Test case {i} failed without escape"
# Run test cases with escape
test_cases = generate_test_cases(enable_escape=True)
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
print(f"input_str: {input_str}, expected_output: {expected_output}")
result = parse_variable(input_str, enable_escape=enable_escape)
assert result == expected_output, f"Test case {i} failed with escape"
def generate_build_test_cases(enable_escape=False):
# Define possible values for each field, including special characters for escaping
_EMPTY_ = "___EMPTY___"
fields = {
"name": [
None,
"test_name",
"test:name" if enable_escape else _EMPTY_,
"test::name" if enable_escape else _EMPTY_,
"test\name" if enable_escape else _EMPTY_,
"test\\name" if enable_escape else _EMPTY_,
"test\:\#\@\%name" if enable_escape else _EMPTY_,
"test\::\###\@@\%%name" if enable_escape else _EMPTY_,
"test\\::\\###\\@@\\%%name" if enable_escape else _EMPTY_,
"test\:#:name" if enable_escape else _EMPTY_,
],
"scope": [None, "test_scope", "test@scope" if enable_escape else _EMPTY_],
"scope_key": [
None,
"test_scope_key",
"test:scope_key" if enable_escape else _EMPTY_,
],
"sys_code": [
None,
"test_sys_code",
"test#sys_code" if enable_escape else _EMPTY_,
],
"user_name": [
None,
"test_user_name",
"test%user_name" if enable_escape else _EMPTY_,
],
}
# Remove empty values
fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()}
# Generate all possible combinations
combinations = product(*fields.values())
test_cases = []
def escape_special_chars(s):
if not enable_escape or s is None:
return s
return (
s.replace(":", "\\:")
.replace("@", "\\@")
.replace("%", "\\%")
.replace("#", "\\#")
)
for combo in combinations:
name, scope, scope_key, sys_code, user_name = combo
# Construct the input dictionary
input_dict = {
"key": "test_key",
"name": name,
"scope": scope,
"scope_key": scope_key,
"sys_code": sys_code,
"user_name": user_name,
}
input_dict_with_escape = {
k: escape_special_chars(v) for k, v in input_dict.items()
}
# Construct the expected variable string
expected_str = "${test_key"
if name:
expected_str += f":{input_dict_with_escape['name']}"
if scope or scope_key:
expected_str += "@"
if scope:
expected_str += input_dict_with_escape["scope"]
if scope_key:
expected_str += f":{input_dict_with_escape['scope_key']}"
if sys_code:
expected_str += f"#{input_dict_with_escape['sys_code']}"
if user_name:
expected_str += f"%{input_dict_with_escape['user_name']}"
expected_str += "}"
test_cases.append((input_dict, expected_str, enable_escape))
return test_cases
def test_build_variable_string():
# Run test cases without escape
test_cases = generate_build_test_cases(enable_escape=False)
for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1):
result = build_variable_string(input_dict, enable_escape=enable_escape)
assert result == expected_str, f"Test case {i} failed without escape"
# Run test cases with escape
test_cases = generate_build_test_cases(enable_escape=True)
for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1):
print(f"input_dict: {input_dict}, expected_str: {expected_str}")
result = build_variable_string(input_dict, enable_escape=enable_escape)
assert result == expected_str, f"Test case {i} failed with escape"
def test_variable_string_round_trip():
# Run test cases without escape
test_cases = generate_test_cases(enable_escape=False)
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
parsed_result = parse_variable(input_str, enable_escape=enable_escape)
built_result = build_variable_string(parsed_result, enable_escape=enable_escape)
assert (
built_result == input_str
), f"Round trip test case {i} failed without escape"
# Run test cases with escape
test_cases = generate_test_cases(enable_escape=True)
for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1):
parsed_result = parse_variable(input_str, enable_escape=enable_escape)
built_result = build_variable_string(parsed_result, enable_escape=enable_escape)
assert built_result == input_str, f"Round trip test case {i} failed with escape"


@@ -0,0 +1,979 @@
"""Variables Module."""
import base64
import dataclasses
import hashlib
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.executor_utils import (
DefaultExecutorFactory,
blocking_func_to_async,
blocking_func_to_async_no_executor,
)
from .storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
_EMPTY_DEFAULT_VALUE = "_EMPTY_DEFAULT_VALUE"
BUILTIN_VARIABLES_CORE_FLOWS = "dbgpt.core.flow.flows"
BUILTIN_VARIABLES_CORE_FLOW_NODES = "dbgpt.core.flow.nodes"
BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables"
BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets"
BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms"
BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings"
BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers"
BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources"
BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents"
BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES = "dbgpt.core.knowledge_spaces"
class Encryption(ABC):
"""Encryption interface."""
name: str = "__abstract__"
@abstractmethod
def encrypt(self, data: str, salt: str) -> str:
"""Encrypt the data."""
@abstractmethod
def decrypt(self, encrypted_data: str, salt: str) -> str:
"""Decrypt the data."""
def _generate_key_from_password(
password: bytes, salt: Optional[Union[str, bytes]] = None
):
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
if salt is None:
salt = os.urandom(16)
elif isinstance(salt, str):
salt = salt.encode()
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(password))
return key, salt
class FernetEncryption(Encryption):
"""Fernet encryption.
    A symmetric encryption scheme that uses the same key for both encryption
    and decryption, backed by the cryptography library.
"""
name = "fernet"
def __init__(self, key: Optional[bytes] = None):
"""Initialize the fernet encryption."""
if key is not None and isinstance(key, str):
key = key.encode()
try:
from cryptography.fernet import Fernet
except ImportError:
raise ImportError(
"cryptography is required for encryption, please install by running "
"`pip install cryptography`"
)
if key is None:
key = Fernet.generate_key()
self.key = key
def encrypt(self, data: str, salt: str) -> str:
"""Encrypt the data with the salt.
Args:
data (str): The data to encrypt.
salt (str): The salt to use, which is used to derive the key.
Returns:
str: The encrypted data.
"""
from cryptography.fernet import Fernet
key, salt = _generate_key_from_password(self.key, salt)
fernet = Fernet(key)
encrypted_secret = fernet.encrypt(data.encode()).decode()
return encrypted_secret
def decrypt(self, encrypted_data: str, salt: str) -> str:
"""Decrypt the data with the salt.
Args:
encrypted_data (str): The encrypted data.
salt (str): The salt to use, which is used to derive the key.
Returns:
str: The decrypted data.
"""
from cryptography.fernet import Fernet
key, salt = _generate_key_from_password(self.key, salt)
fernet = Fernet(key)
return fernet.decrypt(encrypted_data.encode()).decode()
class SimpleEncryption(Encryption):
"""Simple implementation of encryption.
A simple encryption algorithm that uses a key to XOR the data.
"""
name = "simple"
def __init__(self, key: Optional[str] = None):
"""Initialize the simple encryption."""
if key is None:
key = base64.b64encode(os.urandom(32)).decode()
self.key = key
def _derive_key(self, salt: str) -> bytes:
return hashlib.pbkdf2_hmac("sha256", self.key.encode(), salt.encode(), 100000)
def encrypt(self, data: str, salt: str) -> str:
"""Encrypt the data with the salt."""
key = self._derive_key(salt)
encrypted = bytes(
x ^ y for x, y in zip(data.encode(), key * (len(data) // len(key) + 1))
)
return base64.b64encode(encrypted).decode()
def decrypt(self, encrypted_data: str, salt: str) -> str:
"""Decrypt the data with the salt."""
key = self._derive_key(salt)
data = base64.b64decode(encrypted_data)
decrypted = bytes(
x ^ y for x, y in zip(data, key * (len(data) // len(key) + 1))
)
return decrypted.decode()
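
# Editor's sketch (not part of the original commit): both Encryption
# implementations are symmetric, so decryption needs the same key and the same
# per-record salt used for encryption. The key and salt here are placeholders.
def _example_encryption_round_trip() -> None:
    encryption = SimpleEncryption(key=base64.b64encode(os.urandom(32)).decode())
    token = encryption.encrypt("secret_value", salt="per-record-salt")
    assert token != "secret_value"
    assert encryption.decrypt(token, salt="per-record-salt") == "secret_value"
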
@dataclasses.dataclass
class VariablesIdentifier(ResourceIdentifier):
"""The variables identifier."""
identifier_split: str = dataclasses.field(default="@", init=False)
key: str
name: str
scope: str = "global"
scope_key: Optional[str] = None
sys_code: Optional[str] = None
user_name: Optional[str] = None
def __post_init__(self):
"""Post init method."""
if not self.key or not self.name or not self.scope:
raise ValueError("Key, name, and scope are required.")
@property
def str_identifier(self) -> str:
"""Return the string identifier of the identifier."""
return build_variable_string(
{
"key": self.key,
"name": self.name,
"scope": self.scope,
"scope_key": self.scope_key,
"sys_code": self.sys_code,
"user_name": self.user_name,
}
)
def to_dict(self) -> Dict:
"""Convert the identifier to a dict.
Returns:
Dict: The dict of the identifier.
"""
return {
"key": self.key,
"name": self.name,
"scope": self.scope,
"scope_key": self.scope_key,
"sys_code": self.sys_code,
"user_name": self.user_name,
}
@classmethod
def from_str_identifier(
cls,
str_identifier: str,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> "VariablesIdentifier":
"""Create a VariablesIdentifier from a string identifier.
Args:
str_identifier (str): The string identifier.
default_identifier_map (Optional[Dict[str, str]]): The default identifier
map, which contains the default values for the identifier. Defaults to
None.
Returns:
VariablesIdentifier: The VariablesIdentifier.
"""
variable_dict = parse_variable(str_identifier)
if not variable_dict:
raise ValueError("Invalid string identifier.")
if not variable_dict.get("key"):
raise ValueError("Invalid string identifier, must have key")
if not variable_dict.get("name"):
raise ValueError("Invalid string identifier, must have name")
def _get_value(key, default_value: Optional[str] = None) -> Optional[str]:
if variable_dict.get(key) is not None:
return variable_dict.get(key)
if default_identifier_map is not None and default_identifier_map.get(key):
return default_identifier_map.get(key)
return default_value
return cls(
key=variable_dict["key"],
name=variable_dict["name"],
scope=variable_dict["scope"],
scope_key=_get_value("scope_key"),
sys_code=_get_value("sys_code"),
user_name=_get_value("user_name"),
)
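
# Editor's sketch (not part of the original commit): the full variable-string
# layout handled by this identifier, using placeholder field values:
# ${key:name@scope:scope_key#sys_code%user_name}.
def _example_variables_identifier() -> None:
    full_key = "${my_key:my_name@global:my_scope#my_sys%alice}"
    vid = VariablesIdentifier.from_str_identifier(full_key)
    assert vid.key == "my_key"
    assert vid.name == "my_name"
    assert vid.scope == "global"
    assert vid.scope_key == "my_scope"
    # str_identifier rebuilds the same string from the parsed fields.
    assert vid.str_identifier == full_key
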
@dataclasses.dataclass
class StorageVariables(StorageItem):
"""The storage variables."""
key: str
name: str
label: str
value: Any
category: Literal["common", "secret"] = "common"
scope: str = "global"
value_type: Optional[str] = None
scope_key: Optional[str] = None
sys_code: Optional[str] = None
user_name: Optional[str] = None
encryption_method: Optional[str] = None
salt: Optional[str] = None
enabled: int = 1
description: Optional[str] = None
_identifier: VariablesIdentifier = dataclasses.field(init=False)
def __post_init__(self):
"""Post init method."""
self._identifier = VariablesIdentifier(
key=self.key,
name=self.name,
scope=self.scope,
scope_key=self.scope_key,
sys_code=self.sys_code,
user_name=self.user_name,
)
if not self.value_type:
self.value_type = type(self.value).__name__
@property
def identifier(self) -> ResourceIdentifier:
"""Return the identifier."""
return self._identifier
def merge(self, other: "StorageItem") -> None:
"""Merge with another storage variables."""
if not isinstance(other, StorageVariables):
raise ValueError(f"Cannot merge with {type(other)}")
self.from_object(other)
def to_dict(self) -> Dict:
"""Convert the storage variables to a dict.
Returns:
Dict: The dict of the storage variables.
"""
return {
**self._identifier.to_dict(),
"label": self.label,
"value": self.value,
"value_type": self.value_type,
"category": self.category,
"encryption_method": self.encryption_method,
"salt": self.salt,
"enabled": self.enabled,
"description": self.description,
}
def from_object(self, other: "StorageVariables") -> None:
"""Copy the values from another storage variables object."""
self.label = other.label
self.value = other.value
self.value_type = other.value_type
self.category = other.category
self.scope = other.scope
self.scope_key = other.scope_key
self.sys_code = other.sys_code
self.user_name = other.user_name
self.encryption_method = other.encryption_method
self.salt = other.salt
self.enabled = other.enabled
self.description = other.description
@classmethod
def from_identifier(
cls,
identifier: VariablesIdentifier,
value: Any,
value_type: str,
label: str = "",
category: Literal["common", "secret"] = "common",
encryption_method: Optional[str] = None,
salt: Optional[str] = None,
) -> "StorageVariables":
"""Copy the values from an identifier."""
return cls(
key=identifier.key,
name=identifier.name,
label=label,
value=value,
value_type=value_type,
category=category,
scope=identifier.scope,
scope_key=identifier.scope_key,
sys_code=identifier.sys_code,
user_name=identifier.user_name,
encryption_method=encryption_method,
salt=salt,
)
class VariablesProvider(BaseComponent, ABC):
"""The variables provider interface."""
name = ComponentType.VARIABLES_PROVIDER.value
@abstractmethod
def get(
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage."""
@abstractmethod
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
@abstractmethod
def get_variables(
self,
key: str,
scope: str = "global",
scope_key: Optional[str] = None,
sys_code: Optional[str] = None,
user_name: Optional[str] = None,
) -> List[StorageVariables]:
"""Get variables by key."""
async def async_get_variables(
self,
key: str,
scope: str = "global",
scope_key: Optional[str] = None,
sys_code: Optional[str] = None,
user_name: Optional[str] = None,
) -> List[StorageVariables]:
"""Get variables by key async."""
raise NotImplementedError("Current variables provider does not support async.")
def support_async(self) -> bool:
"""Whether the variables provider support async."""
return False
def _convert_to_value_type(self, var: StorageVariables):
"""Convert the variable to the value type."""
if var.value is None:
return None
if var.value_type == "str":
return str(var.value)
elif var.value_type == "int":
return int(var.value)
elif var.value_type == "float":
return float(var.value)
elif var.value_type == "bool":
if var.value.lower() in ["true", "1"]:
return True
elif var.value.lower() in ["false", "0"]:
return False
else:
return bool(var.value)
else:
return var.value
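

# Illustrative sketch: how ``_convert_to_value_type`` restores a typed value
# from its stored string form. ``StorageVariablesProvider`` is defined later
# in this module; the name is resolved at call time, so the forward reference
# is safe.
def _example_value_conversion() -> bool:
    provider = StorageVariablesProvider()
    var = StorageVariables(
        key="example_key", name="flag", label="Flag", value="true", value_type="bool"
    )
    return provider._convert_to_value_type(var)  # -> True
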
class VariablesPlaceHolder:
"""The variables place holder."""
def __init__(
self,
param_name: str,
full_key: str,
default_value: Any = _EMPTY_DEFAULT_VALUE,
):
"""Initialize the variables place holder."""
self.param_name = param_name
self.full_key = full_key
self.default_value = default_value
def parse(
self,
variables_provider: VariablesProvider,
ignore_not_found_error: bool = False,
default_identifier_map: Optional[Dict[str, str]] = None,
):
"""Parse the variables."""
try:
return variables_provider.get(
self.full_key,
self.default_value,
default_identifier_map=default_identifier_map,
)
except ValueError as e:
if ignore_not_found_error:
return None
raise e
def __repr__(self):
"""Return the representation of the variables place holder."""
return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
class StorageVariablesProvider(VariablesProvider):
"""The storage variables provider."""
def __init__(
self,
storage: Optional[StorageInterface] = None,
encryption: Optional[Encryption] = None,
system_app: Optional[SystemApp] = None,
key: Optional[str] = None,
):
"""Initialize the storage variables provider."""
if storage is None:
storage = InMemoryStorage()
self.system_app = system_app
self.encryption = encryption or SimpleEncryption(key)
self.storage = storage
super().__init__(system_app)
def init_app(self, system_app: SystemApp):
"""Initialize the storage variables provider."""
self.system_app = system_app
def get(
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage."""
key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables)
if variable is None:
if default_value == _EMPTY_DEFAULT_VALUE:
raise ValueError(f"Variable {full_key} not found")
return default_value
variable.value = self.deserialize_value(variable.value)
if (
variable.value is not None
and variable.category == "secret"
and variable.encryption_method
and variable.salt
):
variable.value = self.encryption.decrypt(variable.value, variable.salt)
return self._convert_to_value_type(variable)
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
if variables_item.category == "secret":
salt = base64.b64encode(os.urandom(16)).decode()
variables_item.value = self.encryption.encrypt(
str(variables_item.value), salt
)
variables_item.encryption_method = self.encryption.name
variables_item.salt = salt
        # Replace the value with a JSON-serializable string
variables_item.value = self.serialize_value(variables_item.value)
self.storage.save_or_update(variables_item)
def get_variables(
self,
key: str,
scope: str = "global",
scope_key: Optional[str] = None,
sys_code: Optional[str] = None,
user_name: Optional[str] = None,
) -> List[StorageVariables]:
"""Query variables from storage."""
# Try to get builtin variables
is_builtin, builtin_variables = self._get_builtins_variables(
key,
scope=scope,
scope_key=scope_key,
sys_code=sys_code,
user_name=user_name,
)
if is_builtin:
return builtin_variables
variables = self.storage.query(
QuerySpec(
conditions={
"key": key,
"scope": scope,
"scope_key": scope_key,
"sys_code": sys_code,
"user_name": user_name,
"enabled": 1,
}
),
StorageVariables,
)
for variable in variables:
variable.value = self.deserialize_value(variable.value)
return variables
async def async_get_variables(
self,
key: str,
scope: str = "global",
scope_key: Optional[str] = None,
sys_code: Optional[str] = None,
user_name: Optional[str] = None,
) -> List[StorageVariables]:
"""Query variables from storage async."""
# Try to get builtin variables
is_builtin, builtin_variables = await self._async_get_builtins_variables(
key,
scope=scope,
scope_key=scope_key,
sys_code=sys_code,
user_name=user_name,
)
if is_builtin:
return builtin_variables
executor_factory: Optional[
DefaultExecutorFactory
] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
if executor_factory:
return await blocking_func_to_async(
executor_factory.create(),
self.get_variables,
key,
scope,
scope_key,
sys_code,
user_name,
)
else:
return await blocking_func_to_async_no_executor(
self.get_variables, key, scope, scope_key, sys_code, user_name
)
def _get_builtins_variables(
self,
key: str,
scope: str = "global",
scope_key: Optional[str] = None,
sys_code: Optional[str] = None,
user_name: Optional[str] = None,
) -> Tuple[bool, List[StorageVariables]]:
"""Get builtin variables."""
if self.system_app is None:
return False, []
        provider: Optional[BuiltinVariablesProvider] = self.system_app.get_component(
key,
component_type=BuiltinVariablesProvider,
default_component=None,
)
if not provider:
return False, []
return True, provider.get_variables(
key,
scope=scope,
scope_key=scope_key,
sys_code=sys_code,
user_name=user_name,
)
async def _async_get_builtins_variables(
self,
key: str,
scope: str = "global",
scope_key: Optional[str] = None,
sys_code: Optional[str] = None,
user_name: Optional[str] = None,
) -> Tuple[bool, List[StorageVariables]]:
"""Get builtin variables."""
if self.system_app is None:
return False, []
        provider: Optional[BuiltinVariablesProvider] = self.system_app.get_component(
key,
component_type=BuiltinVariablesProvider,
default_component=None,
)
if not provider:
return False, []
if not provider.support_async():
return False, []
return True, await provider.async_get_variables(
key,
scope=scope,
scope_key=scope_key,
sys_code=sys_code,
user_name=user_name,
)
@classmethod
def serialize_value(cls, value: Any) -> str:
"""Serialize the value."""
value_dict = {"value": value}
return json.dumps(value_dict, ensure_ascii=False)
@classmethod
def deserialize_value(cls, value: str) -> Any:
"""Deserialize the value."""
value_dict = json.loads(value)
return value_dict["value"]
class BuiltinVariablesProvider(VariablesProvider, ABC):
"""The builtin variables provider.
You can implement this class to provide builtin variables. Such LLMs, agents,
datasource, knowledge base, etc.
"""
name = "dbgpt_variables_builtin"
def __init__(self, system_app: Optional[SystemApp] = None):
"""Initialize the builtin variables provider."""
self.system_app = system_app
super().__init__(system_app)
def init_app(self, system_app: SystemApp):
"""Initialize the builtin variables provider."""
self.system_app = system_app
def get(
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage."""
raise NotImplementedError("BuiltinVariablesProvider does not support get.")
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
raise NotImplementedError("BuiltinVariablesProvider does not support save.")
def parse_variable(
variable_str: str,
enable_escape: bool = True,
) -> Dict[str, Any]:
"""Parse the variable string.
Examples:
.. code-block:: python
cases = [
{
"full_key": "${test_key:test_name@test_scope:test_scope_key}",
"expected": {
"key": "test_key",
"name": "test_name",
"scope": "test_scope",
"scope_key": "test_scope_key",
"sys_code": None,
"user_name": None,
},
},
{
"full_key": "${test_key#test_sys_code}",
"expected": {
"key": "test_key",
"name": None,
"scope": None,
"scope_key": None,
"sys_code": "test_sys_code",
"user_name": None,
},
},
{
"full_key": "${test_key@:test_scope_key}",
"expected": {
"key": "test_key",
"name": None,
"scope": None,
"scope_key": "test_scope_key",
"sys_code": None,
"user_name": None,
},
},
]
for case in cases:
assert parse_variable(case["full_key"]) == case["expected"]

    Args:
        variable_str (str): The variable string.
        enable_escape (bool): Whether to handle escaped characters.

    Returns:
        Dict[str, Any]: The parsed variable.
    """
if not variable_str.startswith("${") or not variable_str.endswith("}"):
raise ValueError(
"Invalid variable format, must start with '${' and end with '}'"
)
# Remove the surrounding ${ and }
content = variable_str[2:-1]
# Define placeholders for escaped characters
placeholders = {
r"\@": "__ESCAPED_AT__",
r"\#": "__ESCAPED_HASH__",
r"\%": "__ESCAPED_PERCENT__",
r"\:": "__ESCAPED_COLON__",
}
if enable_escape:
# Replace escaped characters with placeholders
for original, placeholder in placeholders.items():
content = content.replace(original, placeholder)
# Initialize the result dictionary
result: Dict[str, Optional[str]] = {
"key": None,
"name": None,
"scope": None,
"scope_key": None,
"sys_code": None,
"user_name": None,
}
# Split the content by special characters
parts = content.split("@")
# Parse key and name
key_name = parts[0].split("#")[0].split("%")[0]
if ":" in key_name:
result["key"], result["name"] = key_name.split(":", 1)
else:
result["key"] = key_name
# Parse scope and scope_key
if len(parts) > 1:
scope_part = parts[1].split("#")[0].split("%")[0]
if ":" in scope_part:
result["scope"], result["scope_key"] = scope_part.split(":", 1)
else:
result["scope"] = scope_part
# Parse sys_code
if "#" in content:
result["sys_code"] = content.split("#", 1)[1].split("%")[0]
# Parse user_name
if "%" in content:
result["user_name"] = content.split("%", 1)[1]
if enable_escape:
# Replace placeholders back with escaped characters
reverse_placeholders = {v: k[1:] for k, v in placeholders.items()}
for key, value in result.items():
if value:
for placeholder, original in reverse_placeholders.items():
result[key] = result[key].replace( # type: ignore
placeholder, original
)
# Replace empty strings with None
for key, value in result.items():
if value == "":
result[key] = None
return result
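

# Escape-handling sketch (illustrative): special characters inside a segment
# can be backslash-escaped so they are treated as content, not separators.
def _example_parse_escaped() -> Dict[str, Any]:
    parsed = parse_variable(r"${example\:key@example_scope}")
    assert parsed["key"] == "example:key"
    assert parsed["scope"] == "example_scope"
    return parsed
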
def _is_variable_format(value: str) -> bool:
    return value.startswith("${") and value.endswith("}")


def is_variable_string(variable_str: str) -> bool:
    """Check if the given string is a variable string.

    A valid variable string must start with "${" and end with "}", and contain
    both a key and a name.

    Args:
        variable_str (str): The string to check.

    Returns:
        bool: True if the string is a variable string, False otherwise.
    """
if not variable_str or not isinstance(variable_str, str):
return False
if not _is_variable_format(variable_str):
return False
try:
variable_dict = parse_variable(variable_str)
if not variable_dict.get("key"):
return False
if not variable_dict.get("name"):
return False
return True
except Exception:
        return False


def is_variable_list_string(variable_str: str) -> bool:
    """Check if the given string is a variable list string.

    A valid variable list string must start with "${" and end with "}", contain
    a key, and not contain a name. Such a string refers to the list of all
    variables that share the same key.

    Args:
        variable_str (str): The string to check.

    Returns:
        bool: True if the string is a variable list string, False otherwise.
    """
if not _is_variable_format(variable_str):
return False
try:
variable_dict = parse_variable(variable_str)
if not variable_dict.get("key"):
return False
if variable_dict.get("name"):
return False
return True
except Exception:
return False
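

# Quick contrast (illustrative): a variable string names a single variable
# (key and name), while a variable list string names every variable that
# shares a key.
def _example_variable_checks() -> None:
    assert is_variable_string("${example_key:example_name}")
    assert not is_variable_list_string("${example_key:example_name}")
    assert is_variable_list_string("${example_key}")
    assert not is_variable_string("${example_key}")
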
def build_variable_string(
variable_dict: Dict[str, Any],
scope_sig: str = "@",
sys_code_sig: str = "#",
user_sig: str = "%",
kv_sig: str = ":",
enable_escape: bool = True,
) -> str:
"""Build a variable string from the given dictionary.
Args:
variable_dict (Dict[str, Any]): The dictionary containing the variable details.
scope_sig (str): The scope signature.
sys_code_sig (str): The sys code signature.
user_sig (str): The user signature.
kv_sig (str): The key-value split signature.
enable_escape (bool): Whether to escape special characters
Returns:
str: The formatted variable string.
Examples:
>>> build_variable_string(
... {
... "key": "test_key",
... "name": "test_name",
... "scope": "test_scope",
... "scope_key": "test_scope_key",
... "sys_code": "test_sys_code",
... "user_name": "test_user",
... }
... )
'${test_key:test_name@test_scope:test_scope_key#test_sys_code%test_user}'
>>> build_variable_string({"key": "test_key", "scope_key": "test_scope_key"})
'${test_key@:test_scope_key}'
>>> build_variable_string({"key": "test_key", "sys_code": "test_sys_code"})
'${test_key#test_sys_code}'
>>> build_variable_string({"key": "test_key"})
'${test_key}'
"""
special_chars = {scope_sig, sys_code_sig, user_sig, kv_sig}
# Replace None with ""
new_variable_dict = {key: value or "" for key, value in variable_dict.items()}
# Check if the variable_dict contains any special characters
for key, value in new_variable_dict.items():
if value != "" and any(char in value for char in special_chars):
if enable_escape:
# Escape special characters
new_variable_dict[key] = (
value.replace("@", "\\@")
.replace("#", "\\#")
.replace("%", "\\%")
.replace(":", "\\:")
)
else:
raise ValueError(
f"{key} contains special characters, error value: {value}, special "
f"characters: {special_chars}"
)
key = new_variable_dict.get("key", "")
name = new_variable_dict.get("name", "")
scope = new_variable_dict.get("scope", "")
scope_key = new_variable_dict.get("scope_key", "")
sys_code = new_variable_dict.get("sys_code", "")
user_name = new_variable_dict.get("user_name", "")
# Construct the base of the variable string
variable_str = f"${{{key}"
# Add name if present
if name:
variable_str += f":{name}"
# Add scope and scope_key if present
if scope or scope_key:
variable_str += f"@{scope}"
if scope_key:
variable_str += f":{scope_key}"
# Add sys_code if present
if sys_code:
variable_str += f"#{sys_code}"
# Add user_name if present
if user_name:
variable_str += f"%{user_name}"
# Close the variable string
variable_str += "}"
return variable_str
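

# Round-trip sketch (illustrative; values are made up): for well-formed
# inputs, ``build_variable_string`` and ``parse_variable`` are inverses.
def _example_round_trip() -> None:
    original = {
        "key": "example_key",
        "name": "example_name",
        "scope": "global",
        "scope_key": None,
        "sys_code": None,
        "user_name": None,
    }
    built = build_variable_string(original)
    assert built == "${example_key:example_name@global}"
    assert parse_variable(built) == original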