Mirror of https://github.com/csunny/DB-GPT.git
feat(model): Proxy multimodal supports (#2641)
@@ -43,6 +43,7 @@ class AliyunOSSStorage(StorageBackend):
         bucket_prefix: str = "dbgpt-fs-",
         bucket_mapper: Optional[Callable[[str], str]] = None,
         auto_create_bucket: bool = True,
+        default_public_url_expire: int = 3600,
     ):
         """Initialize the Aliyun OSS storage backend.
 
@@ -75,6 +76,7 @@ class AliyunOSSStorage(StorageBackend):
         self.bucket_prefix = bucket_prefix
         self.custom_bucket_mapper = bucket_mapper
         self.auto_create_bucket = auto_create_bucket
+        self.default_public_url_expire = default_public_url_expire
 
         # Initialize OSS authentication
         if use_environment_credentials:
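
For orientation, a hypothetical construction sketch follows. Only bucket_prefix, bucket_mapper, auto_create_bucket, use_environment_credentials, and the new default_public_url_expire appear in this diff; the endpoint and region keywords are assumptions inferred from the self.endpoint / self.region attributes used later in the file.

    # Hypothetical sketch -- endpoint/region parameter names are assumed, not
    # confirmed by this diff.
    storage = AliyunOSSStorage(
        endpoint="https://oss-cn-hangzhou.aliyuncs.com",  # assumed keyword
        region="cn-hangzhou",                             # assumed keyword
        use_environment_credentials=True,  # read OSS credentials from env vars
        bucket_prefix="dbgpt-fs-",
        auto_create_bucket=True,
        default_public_url_expire=7200,  # new: signed URLs default to 2 hours
    )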
@@ -223,13 +225,24 @@ class AliyunOSSStorage(StorageBackend):
                 f"Failed to get or create bucket for logical bucket {logical_bucket}"
             )
 
-    def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
+    def save(
+        self,
+        bucket: str,
+        file_id: str,
+        file_data: BinaryIO,
+        public_url: bool = False,
+        public_url_expire: Optional[int] = None,
+    ) -> str:
         """Save the file data to Aliyun OSS.
 
         Args:
             bucket (str): The logical bucket name
             file_id (str): The file ID
             file_data (BinaryIO): The file data
+            public_url (bool, optional): Whether to generate a public URL. Defaults to
+                False.
+            public_url_expire (Optional[int], optional): Expiration time for the public
+                URL in seconds. Defaults to None.
 
         Returns:
             str: The storage path (OSS URI)
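
A minimal usage sketch of the new save() signature (the keyword names are taken from the diff above; the logical bucket and file ID values are placeholders):

    import io

    data = io.BytesIO(b"hello")
    # With public_url=True, save() returns a signed HTTP URL on success;
    # public_url_expire=None falls back to default_public_url_expire.
    result = storage.save(
        "my-media", "file-123", data, public_url=True, public_url_expire=600
    )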
@@ -270,7 +283,29 @@ class AliyunOSSStorage(StorageBackend):
         # Format: oss://{actual_bucket_name}/{object_name}
         # We store both the actual bucket name and the object path in the URI
         # But we'll also keep the logical bucket in the external URI format
-        return f"oss://{bucket}/{file_id}?actual_bucket={actual_bucket_name}&object_name={object_name}"  # noqa
+        storage_path = f"oss://{bucket}/{file_id}?actual_bucket={actual_bucket_name}&object_name={object_name}"  # noqa
+
+        # Generate a public URL if requested
+        if public_url:
+            # Use provided expiration time or default
+            expire_seconds = public_url_expire or self.default_public_url_expire
+
+            # Generate a signed URL for public access
+            try:
+                url = oss_bucket.sign_url(
+                    "GET", object_name, expire_seconds, slash_safe=True
+                )
+                logger.info(
+                    f"Generated public URL for {object_name} with expiration "
+                    f"{expire_seconds} seconds"
+                )
+                return url
+            except oss2.exceptions.OssError as e:
+                logger.error(f"Failed to generate public URL for {object_name}: {e}")
+                # Fall back to returning the storage path
+                return storage_path
+
+        return storage_path
 
     def _get_file_size(self, file_data: BinaryIO) -> int:
         """Get file size without consuming the file object.
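
Note the failure mode here: if sign_url raises an OssError, save() logs it and returns the internal storage path instead of propagating. A caller that needs a shareable link therefore has to check which form came back; a sketch:

    # Sketch: the oss:// scheme marks the fallback storage path.
    result = storage.save("my-media", "file-123", data, public_url=True)
    if result.startswith("oss://"):
        logger.warning("Signing failed; only the internal URI is available")
    else:
        share_link = result  # signed HTTPS URL, valid for expire_seconds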
@@ -482,3 +517,58 @@ class AliyunOSSStorage(StorageBackend):
                 f" {e}"
             )
             return False
+
+    def get_public_url(
+        self, fm: FileMetadata, expire: Optional[int] = None
+    ) -> Optional[str]:
+        """Generate a public URL for an existing file.
+
+        Args:
+            fm (FileMetadata): The file metadata
+            expire (Optional[int], optional): Expiration time in seconds. Defaults to
+                class default.
+
+        Returns:
+            str: The generated public URL
+        """
+        # Parse the storage path
+        path_info = self._parse_storage_path(fm.storage_path)
+
+        # Get actual bucket and object name
+        actual_bucket_name = path_info["actual_bucket"]
+        object_name = path_info["object_name"]
+        logical_bucket = path_info["logical_bucket"]
+
+        # If we couldn't determine the actual bucket from the URI, try with the logical
+        # bucket
+        if not actual_bucket_name and logical_bucket:
+            actual_bucket_name = self._map_bucket_name(logical_bucket)
+
+        # Use the file_id as object name if object_name is still None
+        if not object_name:
+            object_name = fm.file_id
+            # If using fixed bucket, prefix with logical bucket
+            if self.fixed_bucket and logical_bucket:
+                object_name = f"{logical_bucket}/{fm.file_id}"
+
+        # Get the bucket object
+        try:
+            oss_bucket = oss2.Bucket(
+                self.auth, self.endpoint, actual_bucket_name, region=self.region
+            )
+
+            # Use provided expiration time or default
+            expire_seconds = expire or self.default_public_url_expire
+
+            # Generate signed URL
+            url = oss_bucket.sign_url(
+                "GET", object_name, expire_seconds, slash_safe=True
+            )
+            logger.info(
+                f"Generated public URL for {object_name} with expiration "
+                f"{expire_seconds} seconds"
+            )
+            return url
+        except oss2.exceptions.OssError as e:
+            logger.error(f"Failed to generate public URL for {fm.file_id}: {e}")
+            raise
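
A sketch of re-signing an already stored file; fm is assumed to be a FileMetadata retrieved from the file storage's metadata layer, with a storage_path produced by save(). Unlike save(), this method re-raises signing errors:

    try:
        url = storage.get_public_url(fm, expire=300)  # 5-minute link
    except oss2.exceptions.OssError:
        url = None  # get_public_url propagates the error; handle it here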
@@ -37,6 +37,7 @@ class S3Storage(StorageBackend):
         auto_create_bucket: bool = True,
         signature_version: Optional[str] = None,
         s3_config: Optional[Dict[str, Union[str, int]]] = None,
+        default_public_url_expire: int = 3600,  # Default to 1 hour
     ):
         """Initialize the S3 compatible storage backend.
 
@@ -62,6 +63,8 @@ class S3Storage(StorageBackend):
             signature_version (str, optional): S3 signature version to use.
             s3_config (Optional[Dict[str, Union[str, int]]], optional): Additional
                 S3 configuration options. Defaults to None.
+            default_public_url_expire (int, optional): Default expiration time for
+                public URL in seconds. Defaults to 3600 (1 hour).
         """
         self.endpoint_url = endpoint_url
         self.region_name = region_name
@@ -71,6 +74,7 @@ class S3Storage(StorageBackend):
         self.custom_bucket_mapper = bucket_mapper
         self.auto_create_bucket = auto_create_bucket
         self.signature_version = signature_version
+        self.default_public_url_expire = default_public_url_expire
 
         # Build S3 client configuration
         if not s3_config:
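
As with the OSS backend, a hypothetical construction sketch; endpoint_url, region_name, signature_version, and default_public_url_expire appear in this diff, while the MinIO endpoint value is just an illustrative assumption:

    s3_storage = S3Storage(
        endpoint_url="http://localhost:9000",  # e.g. local MinIO (assumption)
        region_name="us-east-1",
        signature_version="s3v4",
        default_public_url_expire=3600,  # new: 1-hour presigned URLs by default
    )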
@@ -251,16 +255,27 @@ class S3Storage(StorageBackend):
             logger.error(f"Failed to check bucket {bucket_name}: {e}")
             return False
 
-    def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str:
+    def save(
+        self,
+        bucket: str,
+        file_id: str,
+        file_data: BinaryIO,
+        public_url: bool = False,
+        public_url_expire: Optional[int] = None,
+    ) -> str:
         """Save the file data to S3.
 
         Args:
             bucket (str): The logical bucket name
             file_id (str): The file ID
             file_data (BinaryIO): The file data
+            public_url (bool, optional): Whether to generate a public URL. Defaults to
+                False.
+            public_url_expire (Optional[int], optional): Expiration time for the public
+                URL in seconds. Defaults to None.
 
         Returns:
-            str: The storage path (S3 URI)
+            str: The storage path (S3 URI) or the public URL if public_url is True
         """
         # Get the actual S3 bucket
         actual_bucket_name = self._map_bucket_name(bucket)
@@ -337,8 +352,32 @@ class S3Storage(StorageBackend):
             )
             raise
 
-        # Format: s3://{logical_bucket}/{file_id}?actual_bucket={actual_bucket_name}&object_key={object_key}  # noqa
-        return f"s3://{bucket}/{file_id}?actual_bucket={actual_bucket_name}&object_key={object_key}"  # noqa
+        # Standard storage path
+        storage_path = f"s3://{bucket}/{file_id}?actual_bucket={actual_bucket_name}&object_key={object_key}"  # noqa
+
+        # Generate a public URL if requested
+        if public_url:
+            # Use provided expiration time or default
+            expire_seconds = public_url_expire or self.default_public_url_expire
+
+            try:
+                # Generate a pre-signed URL for public access
+                url = self.s3_client.generate_presigned_url(
+                    "get_object",
+                    Params={"Bucket": actual_bucket_name, "Key": object_key},
+                    ExpiresIn=expire_seconds,
+                )
+                logger.info(
+                    f"Generated public URL for {object_key} with expiration "
+                    f"{expire_seconds} seconds"
+                )
+                return url
+            except ClientError as e:
+                logger.error(f"Failed to generate public URL for {object_key}: {e}")
+                # Fall back to returning the storage path
+                return storage_path
+
+        return storage_path
 
     def _get_file_size(self, file_data: BinaryIO) -> int:
         """Get file size without consuming the file object.
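
One subtlety shared by both backends: the fallback is written as `public_url_expire or self.default_public_url_expire`, so a falsy value like 0 silently becomes the class default rather than an immediately expiring URL. For reference, the equivalent standalone boto3 call (bucket and key are placeholders; the endpoint is an assumption):

    import boto3
    from botocore.exceptions import ClientError

    s3_client = boto3.client("s3", endpoint_url="http://localhost:9000")
    try:
        url = s3_client.generate_presigned_url(
            "get_object",
            Params={"Bucket": "dbgpt-fs-my-media", "Key": "file-123"},
            ExpiresIn=600,  # seconds
        )
    except ClientError:
        url = None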
@@ -470,6 +509,58 @@ class S3Storage(StorageBackend):
             "object_key": object_key,
         }
 
+    def get_public_url(
+        self, fm: FileMetadata, expire: Optional[int] = None
+    ) -> Optional[str]:
+        """Generate a public URL for an existing file.
+
+        Args:
+            fm (FileMetadata): The file metadata
+            expire (Optional[int], optional): Expiration time in seconds. Defaults to
+                class default.
+
+        Returns:
+            str: The generated public URL
+        """
+        # Parse the storage path
+        path_info = self._parse_storage_path(fm.storage_path)
+
+        # Get actual bucket and object key
+        actual_bucket_name = path_info["actual_bucket"]
+        object_key = path_info["object_key"]
+        logical_bucket = path_info["logical_bucket"]
+
+        # If we couldn't determine the actual bucket from the URI, try with the logical
+        # bucket
+        if not actual_bucket_name and logical_bucket:
+            actual_bucket_name = self._map_bucket_name(logical_bucket)
+
+        # Use the file_id as object key if object_key is still None
+        if not object_key:
+            object_key = fm.file_id
+            # If using fixed bucket, prefix with logical bucket
+            if self.fixed_bucket and logical_bucket:
+                object_key = f"{logical_bucket}/{fm.file_id}"
+
+        # Use provided expiration time or default
+        expire_seconds = expire or self.default_public_url_expire
+
+        try:
+            # Generate a pre-signed URL for public access
+            url = self.s3_client.generate_presigned_url(
+                "get_object",
+                Params={"Bucket": actual_bucket_name, "Key": object_key},
+                ExpiresIn=expire_seconds,
+            )
+            logger.info(
+                f"Generated public URL for {object_key} with expiration "
+                f"{expire_seconds} seconds"
+            )
+            return url
+        except ClientError as e:
+            logger.error(f"Failed to generate public URL for {fm.file_id}: {e}")
+            raise
+
     def load(self, fm: FileMetadata) -> BinaryIO:
         """Load the file data from S3.
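
For reference, the s3:// URI produced by save() carries everything get_public_url() needs as query parameters; an illustrative parse of that format (this is not the project's _parse_storage_path, just the same idea):

    from urllib.parse import urlparse, parse_qs

    uri = "s3://my-media/file-123?actual_bucket=dbgpt-fs-my-media&object_key=file-123"
    parsed = urlparse(uri)
    query = parse_qs(parsed.query)
    logical_bucket = parsed.netloc                         # "my-media"
    actual_bucket = query.get("actual_bucket", [None])[0]  # "dbgpt-fs-my-media"
    object_key = query.get("object_key", [None])[0]        # "file-123"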