feat(model): Support GLM4.1-vl model (#2806)

parent 00af86d89e
commit d27fdb7928
@@ -149,6 +149,9 @@ hf_qwen_omni = [
 hf_glm4 = [
     "transformers>=4.51.3",
 ]
+hf_glm4_1vl = [
+    "transformers>=4.53.0",
+]
 hf_kimi = [
     "tiktoken",
     "blobfile",
@@ -179,6 +182,10 @@ conflicts = [
         { extra = "hf_glm4" },
         { extra = "hf_kimi" },
     ],
+    [
+        { extra = "hf_glm4_1vl" },
+        { extra = "hf_kimi" },
+    ],
     [
         { extra = "hf_qwen3" },
         { extra = "hf_kimi" },
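
Note: the new `hf_glm4_1vl` extra pins a newer transformers floor than `hf_glm4`, which is why a new conflict group pairs it against `hf_kimi`. A quick runtime check against that floor (a sketch; it uses `packaging.version`, which transformers already depends on, rather than a plain string compare):

from packaging.version import Version

import transformers

# Real version comparison: a lexicographic string compare would mis-order
# versions such as "4.100.0" vs "4.54.0".
if Version(transformers.__version__) < Version("4.53.0"):
    raise RuntimeError(
        f"transformers {transformers.__version__} is too old for GLM-4.1V; "
        "install the hf_glm4_1vl extra or upgrade transformers."
    )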
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union

 from dbgpt.core.schema.types import (
     ChatCompletionContentPartParam,
@@ -249,6 +249,7 @@ class MediaContent:
         role,
         content: Union[str, "MediaContent", List["MediaContent"]],
         support_media_content: bool = True,
+        type_mapping: Optional[Dict[str, str]] = None,
     ) -> ChatCompletionMessageParam:
         """Convert the media contents to chat completion message."""
         if not content:
@@ -257,7 +258,10 @@ class MediaContent:
             return {"role": role, "content": content}
         if isinstance(content, MediaContent):
             content = [content]
-        new_content = [cls._parse_single_media_content(c) for c in content]
+        new_content = [
+            cls._parse_single_media_content(c, type_mapping=type_mapping)
+            for c in content
+        ]
         if not support_media_content:
             text_content = [
                 c["text"] for c in new_content if c["type"] == "text" and "text" in c
@@ -275,40 +279,54 @@ class MediaContent:
     def _parse_single_media_content(
         cls,
         content: "MediaContent",
+        type_mapping: Optional[Dict[str, str]] = None,
     ) -> ChatCompletionContentPartParam:
         """Parse a single content."""
+        if type_mapping is None:
+            type_mapping = {}
         if content.type == MediaContentType.TEXT:
+            real_type = type_mapping.get("text", "text")
             return {
-                "text": str(content.object.data),
-                "type": "text",
+                real_type: str(content.object.data),
+                "type": real_type,
             }
         elif content.type == MediaContentType.IMAGE:
             if content.object.format.startswith("url"):
-                return {
-                    "image_url": {
-                        "url": content.object.data,
-                    },
-                    "type": "image_url",
-                }
+                # Compatibility for most image url formats
+                real_type = type_mapping.get("image_url", "image_url")
+                if real_type == "image_url":
+                    return {
+                        "image_url": {
+                            "url": content.object.data,
+                        },
+                        "type": "image_url",
+                    }
+                else:
+                    return {
+                        real_type: content.object.data,
+                        "type": real_type,
+                    }
             else:
                 raise ValueError(f"Unsupported image format: {content.object.format}")
         elif content.type == MediaContentType.AUDIO:
             if content.object.format.startswith("base64"):
+                real_type = type_mapping.get("input_audio", "input_audio")
                 return {
-                    "input_audio": {
+                    real_type: {
                         "data": content.object.data,
                     },
-                    "type": "input_audio",
+                    "type": real_type,
                 }
             else:
                 raise ValueError(f"Unsupported audio format: {content.object.format}")
         elif content.type == MediaContentType.VIDEO:
             if content.object.format.startswith("url"):
+                real_type = type_mapping.get("video_url", "video_url")
                 return {
-                    "video_url": {
+                    real_type: {
                         "url": content.object.data,
                     },
-                    "type": "video_url",
+                    "type": real_type,
                 }
             else:
                 raise ValueError(f"Unsupported video format: {content.object.format}")
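
The `type_mapping` hook above lets callers rename OpenAI-style content-part keys (`text`, `image_url`, `input_audio`, `video_url`) to whatever a model's chat template expects. A minimal standalone sketch of the image branch (stdlib only; the function name is illustrative, not part of the diff):

from typing import Dict, Optional


def parse_image_url_part(
    url: str, type_mapping: Optional[Dict[str, str]] = None
) -> dict:
    # Mirrors the remapping logic above for an image URL part.
    type_mapping = type_mapping or {}
    real_type = type_mapping.get("image_url", "image_url")
    if real_type == "image_url":
        # Default: OpenAI-compatible nested {"url": ...} object.
        return {"image_url": {"url": url}, "type": "image_url"}
    # Remapped: flat key, e.g. {"image": url, "type": "image"}.
    return {real_type: url, "type": real_type}


assert parse_image_url_part("https://e.com/a.png") == {
    "image_url": {"url": "https://e.com/a.png"},
    "type": "image_url",
}
assert parse_image_url_part("https://e.com/a.png", {"image_url": "image"}) == {
    "image": "https://e.com/a.png",
    "type": "image",
}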
@@ -408,6 +408,7 @@ class ModelMessage(BaseModel):
         convert_to_compatible_format: bool = False,
         support_system_role: bool = True,
         support_media_content: bool = True,
+        type_mapping: Optional[Dict[str, str]] = None,
     ) -> List[Dict[str, str]]:
         """Cover to common message format.

@@ -420,6 +421,8 @@ class ModelMessage(BaseModel):
             convert_to_compatible_format (bool): Whether to convert to compatible format
             support_system_role (bool): Whether to support system role
             support_media_content (bool): Whether to support media content
+            type_mapping (Optional[Dict[str, str]]): A mapping of content part types
+                to model-specific type keys, for compatibility with different models.

         Returns:
             List[Dict[str, str]]: The common messages
@@ -436,6 +439,7 @@ class ModelMessage(BaseModel):
                         "user",
                         message.content,
                         support_media_content=support_media_content,
+                        type_mapping=type_mapping,
                     )
                 )
             elif message.role == ModelMessageRoleType.SYSTEM:
@@ -447,6 +451,7 @@ class ModelMessage(BaseModel):
                         "system",
                         message.content,
                         support_media_content=support_media_content,
+                        type_mapping=type_mapping,
                     )
                 )
             elif message.role == ModelMessageRoleType.AI:
@@ -455,6 +460,7 @@ class ModelMessage(BaseModel):
                         "assistant",
                         message.content,
                         support_media_content=support_media_content,
+                        type_mapping=type_mapping,
                     )
                 )
             else:
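
Each role branch now forwards the mapping, so one call covers a mixed history. A hedged usage sketch (assuming `ModelMessage` and `ModelMessageRoleType` are importable from `dbgpt.core`, as elsewhere in the project; plain-string contents pass through untouched, only media parts are remapped):

from dbgpt.core import ModelMessage, ModelMessageRoleType

history = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are helpful."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello!"),
]
common = ModelMessage.to_common_messages(
    history, type_mapping={"image_url": "image"}
)
# -> [{"role": "system", "content": "You are helpful."},
#     {"role": "user", "content": "Hello!"}]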
@@ -381,6 +381,7 @@ class LLMModelAdapter(ABC):
         messages: List[ModelMessage],
         convert_to_compatible_format: bool = False,
         support_media_content: bool = True,
+        type_mapping: Optional[Dict[str, str]] = None,
     ) -> List[Dict[str, str]]:
         """Transform the model messages

@@ -414,19 +415,23 @@ class LLMModelAdapter(ABC):
         if not self.support_system_message and convert_to_compatible_format:
             # We will not do any transform in the future
             return self._transform_to_no_system_messages(
-                messages, support_media_content=support_media_content
+                messages,
+                support_media_content=support_media_content,
+                type_mapping=type_mapping,
             )
         else:
             return ModelMessage.to_common_messages(
                 messages,
                 convert_to_compatible_format=convert_to_compatible_format,
                 support_media_content=support_media_content,
+                type_mapping=type_mapping,
             )

     def _transform_to_no_system_messages(
         self,
         messages: List[ModelMessage],
         support_media_content: bool = True,
+        type_mapping: Optional[Dict[str, str]] = None,
     ) -> List[Dict[str, str]]:
         """Transform the model messages to no system messages

@@ -454,7 +459,9 @@ class LLMModelAdapter(ABC):
             List[Dict[str, str]]: The transformed model messages
         """
         openai_messages = ModelMessage.to_common_messages(
-            messages, support_media_content=support_media_content
+            messages,
+            support_media_content=support_media_content,
+            type_mapping=type_mapping,
         )
         system_messages = []
         return_messages = []
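
`_transform_to_no_system_messages` forwards the mapping too, so models that reject the system role still receive remapped media parts. For orientation, folding system messages into the first user turn typically looks like this (a simplified stdlib sketch of the general strategy, not the exact DB-GPT implementation):

from typing import Dict, List


def fold_system_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    # Collect system contents, then prepend them to the first user message.
    system_parts = [m["content"] for m in messages if m["role"] == "system"]
    folded = [m for m in messages if m["role"] != "system"]
    if system_parts and folded and folded[0]["role"] == "user":
        folded[0] = {
            "role": "user",
            "content": "\n".join(system_parts) + "\n" + folded[0]["content"],
        }
    return folded


print(fold_system_messages([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
]))
# -> [{'role': 'user', 'content': 'You are terse.\nHi'}]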
@@ -1034,6 +1034,100 @@ class GLM40414Adapter(NewHFChatModelAdapter):
         return lower_model_name_or_path and "z1" in lower_model_name_or_path


+class GLM41VAdapter(GLM40414Adapter):
+    """
+    https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking
+
+    Please make sure your transformers version is >= 4.54.0; you can install it with
+    the following command:
+    uv pip install git+https://github.com/huggingface/transformers.git
+    """
+
+    def do_match(self, lower_model_name_or_path: Optional[str] = None):
+        return lower_model_name_or_path and "glm-4.1v" in lower_model_name_or_path
+
+    def is_reasoning_model(
+        self,
+        deploy_model_params: LLMDeployModelParameters,
+        lower_model_name_or_path: Optional[str] = None,
+    ) -> bool:
+        if (
+            deploy_model_params.reasoning_model is not None
+            and deploy_model_params.reasoning_model
+        ):
+            return True
+        return lower_model_name_or_path and "thinking" in lower_model_name_or_path
+
+    def check_transformer_version(self, current_version: str) -> None:
+        if not current_version >= "4.54.0":
+            raise ValueError(
+                "GLM-4.1V requires transformers.__version__ >= 4.54.0, please "
+                "upgrade your transformers package. "
+                "You can install transformers with "
+                "`uv pip install git+https://github.com/huggingface/transformers.git`"
+            )
+
+    def load(self, model_path: str, from_pretrained_kwargs: dict):
+        try:
+            from transformers import (
+                Glm4vForConditionalGeneration,
+            )
+        except ImportError as exc:
+            raise ValueError(
+                "Could not import glm-4.1v model, please upgrade your "
+                "transformers package to 4.54.0 or later."
+            ) from exc
+
+        logger.info(
+            f"Load model from {model_path}, from_pretrained_kwargs: "
+            f"{from_pretrained_kwargs}"
+        )
+
+        revision = from_pretrained_kwargs.get("revision", "main")
+        trust_remote_code = from_pretrained_kwargs.get(
+            "trust_remote_code", self.trust_remote_code
+        )
+        low_cpu_mem_usage = from_pretrained_kwargs.get("low_cpu_mem_usage", False)
+        if "trust_remote_code" not in from_pretrained_kwargs:
+            from_pretrained_kwargs["trust_remote_code"] = trust_remote_code
+        if "low_cpu_mem_usage" not in from_pretrained_kwargs:
+            from_pretrained_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+        model = Glm4vForConditionalGeneration.from_pretrained(
+            model_path, **from_pretrained_kwargs
+        )
+        tokenizer = self.load_tokenizer(
+            model_path,
+            revision,
+            use_fast=self.use_fast_tokenizer(),
+            trust_remote_code=trust_remote_code,
+        )
+        return model, tokenizer
+
+    def get_str_prompt(
+        self,
+        params: Dict,
+        messages: List[ModelMessage],
+        tokenizer: Any,
+        prompt_template: str = None,
+        convert_to_compatible_format: bool = False,
+    ) -> Optional[str]:
+        from transformers import AutoTokenizer
+
+        if not tokenizer:
+            raise ValueError("tokenizer is None")
+        tokenizer: AutoTokenizer = tokenizer
+        type_mapping = {
+            "image_url": "image",
+        }
+        messages = self.transform_model_messages(
+            messages, convert_to_compatible_format, type_mapping=type_mapping
+        )
+        logger.debug(f"The messages after transform: \n{messages}")
+        str_prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        return str_prompt
+
+
 class Codegeex4Adapter(GLM4Adapter):
     """
     https://huggingface.co/THUDM/codegeex4-all-9b
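
With `type_mapping={"image_url": "image"}`, the messages handed to `apply_chat_template` carry flat `image` parts instead of OpenAI's nested `image_url` objects. The shape, inferred from the mapping logic in this commit (the URL is illustrative):

# Example message shape produced for GLM-4.1V before templating.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://example.com/cat.png"},
            {"type": "text", "text": "What is in this picture?"},
        ],
    }
]
# str_prompt = tokenizer.apply_chat_template(
#     messages, tokenize=False, add_generation_prompt=True
# )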
@@ -1150,6 +1244,7 @@ register_model_adapter(SQLCoderAdapter)
 register_model_adapter(OpenChatAdapter)
 register_model_adapter(GLM4Adapter, supported_models=COMMON_HF_GLM_MODELS)
 register_model_adapter(GLM40414Adapter)
+register_model_adapter(GLM41VAdapter)
 register_model_adapter(Codegeex4Adapter)
 register_model_adapter(Qwen2Adapter, supported_models=COMMON_HF_QWEN25_MODELS)
 register_model_adapter(Qwen2VLAdapter)