mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-19 18:40:00 +00:00
feat(model): Support GLM4.1-vl model (#2806)
This commit is contained in:
parent
00af86d89e
commit
d27fdb7928
@ -149,6 +149,9 @@ hf_qwen_omni = [
|
||||
hf_glm4 = [
|
||||
"transformers>=4.51.3",
|
||||
]
|
||||
hf_glm4_1vl = [
|
||||
"transformers>=4.53.0",
|
||||
]
|
||||
hf_kimi = [
|
||||
"tiktoken",
|
||||
"blobfile",
|
||||
@ -179,6 +182,10 @@ conflicts = [
|
||||
{ extra = "hf_glm4" },
|
||||
{ extra = "hf_kimi" },
|
||||
],
|
||||
[
|
||||
{ extra = "hf_glm4_1vl" },
|
||||
{ extra = "hf_kimi" },
|
||||
],
|
||||
[
|
||||
{ extra = "hf_qwen3" },
|
||||
{ extra = "hf_kimi" },
|
||||
|
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Literal, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||
|
||||
from dbgpt.core.schema.types import (
|
||||
ChatCompletionContentPartParam,
|
||||
@ -249,6 +249,7 @@ class MediaContent:
|
||||
role,
|
||||
content: Union[str, "MediaContent", List["MediaContent"]],
|
||||
support_media_content: bool = True,
|
||||
type_mapping: Optional[Dict[str, str]] = None,
|
||||
) -> ChatCompletionMessageParam:
|
||||
"""Convert the media contents to chat completion message."""
|
||||
if not content:
|
||||
@ -257,7 +258,10 @@ class MediaContent:
|
||||
return {"role": role, "content": content}
|
||||
if isinstance(content, MediaContent):
|
||||
content = [content]
|
||||
new_content = [cls._parse_single_media_content(c) for c in content]
|
||||
new_content = [
|
||||
cls._parse_single_media_content(c, type_mapping=type_mapping)
|
||||
for c in content
|
||||
]
|
||||
if not support_media_content:
|
||||
text_content = [
|
||||
c["text"] for c in new_content if c["type"] == "text" and "text" in c
|
||||
@ -275,40 +279,54 @@ class MediaContent:
|
||||
def _parse_single_media_content(
|
||||
cls,
|
||||
content: "MediaContent",
|
||||
type_mapping: Optional[Dict[str, str]] = None,
|
||||
) -> ChatCompletionContentPartParam:
|
||||
"""Parse a single content."""
|
||||
if type_mapping is None:
|
||||
type_mapping = {}
|
||||
if content.type == MediaContentType.TEXT:
|
||||
real_type = type_mapping.get("text", "text")
|
||||
return {
|
||||
"text": str(content.object.data),
|
||||
"type": "text",
|
||||
real_type: str(content.object.data),
|
||||
"type": real_type,
|
||||
}
|
||||
elif content.type == MediaContentType.IMAGE:
|
||||
if content.object.format.startswith("url"):
|
||||
return {
|
||||
"image_url": {
|
||||
"url": content.object.data,
|
||||
},
|
||||
"type": "image_url",
|
||||
}
|
||||
# Compatibility for most image url formats
|
||||
real_type = type_mapping.get("image_url", "image_url")
|
||||
if real_type == "image_url":
|
||||
return {
|
||||
"image_url": {
|
||||
"url": content.object.data,
|
||||
},
|
||||
"type": "image_url",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
real_type: content.object.data,
|
||||
"type": real_type,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported image format: {content.object.format}")
|
||||
elif content.type == MediaContentType.AUDIO:
|
||||
if content.object.format.startswith("base64"):
|
||||
real_type = type_mapping.get("input_audio", "input_audio")
|
||||
return {
|
||||
"input_audio": {
|
||||
real_type: {
|
||||
"data": content.object.data,
|
||||
},
|
||||
"type": "input_audio",
|
||||
"type": real_type,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported audio format: {content.object.format}")
|
||||
elif content.type == MediaContentType.VIDEO:
|
||||
if content.object.format.startswith("url"):
|
||||
real_type = type_mapping.get("video_url", "video_url")
|
||||
return {
|
||||
"video_url": {
|
||||
real_type: {
|
||||
"url": content.object.data,
|
||||
},
|
||||
"type": "video_url",
|
||||
"type": real_type,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported video format: {content.object.format}")
|
||||
|
@ -408,6 +408,7 @@ class ModelMessage(BaseModel):
|
||||
convert_to_compatible_format: bool = False,
|
||||
support_system_role: bool = True,
|
||||
support_media_content: bool = True,
|
||||
type_mapping: Optional[Dict[str, str]] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Cover to common message format.
|
||||
|
||||
@ -420,6 +421,8 @@ class ModelMessage(BaseModel):
|
||||
convert_to_compatible_format (bool): Whether to convert to compatible format
|
||||
support_system_role (bool): Whether to support system role
|
||||
support_media_content (bool): Whether to support media content
|
||||
type_mapping (Optional[Dict[str, str]]): A mapping of role type to common
|
||||
message, for compatibility with different models.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: The common messages
|
||||
@ -436,6 +439,7 @@ class ModelMessage(BaseModel):
|
||||
"user",
|
||||
message.content,
|
||||
support_media_content=support_media_content,
|
||||
type_mapping=type_mapping,
|
||||
)
|
||||
)
|
||||
elif message.role == ModelMessageRoleType.SYSTEM:
|
||||
@ -447,6 +451,7 @@ class ModelMessage(BaseModel):
|
||||
"system",
|
||||
message.content,
|
||||
support_media_content=support_media_content,
|
||||
type_mapping=type_mapping,
|
||||
)
|
||||
)
|
||||
elif message.role == ModelMessageRoleType.AI:
|
||||
@ -455,6 +460,7 @@ class ModelMessage(BaseModel):
|
||||
"assistant",
|
||||
message.content,
|
||||
support_media_content=support_media_content,
|
||||
type_mapping=type_mapping,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -381,6 +381,7 @@ class LLMModelAdapter(ABC):
|
||||
messages: List[ModelMessage],
|
||||
convert_to_compatible_format: bool = False,
|
||||
support_media_content: bool = True,
|
||||
type_mapping: Optional[Dict[str, str]] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Transform the model messages
|
||||
|
||||
@ -414,19 +415,23 @@ class LLMModelAdapter(ABC):
|
||||
if not self.support_system_message and convert_to_compatible_format:
|
||||
# We will not do any transform in the future
|
||||
return self._transform_to_no_system_messages(
|
||||
messages, support_media_content=support_media_content
|
||||
messages,
|
||||
support_media_content=support_media_content,
|
||||
type_mapping=type_mapping,
|
||||
)
|
||||
else:
|
||||
return ModelMessage.to_common_messages(
|
||||
messages,
|
||||
convert_to_compatible_format=convert_to_compatible_format,
|
||||
support_media_content=support_media_content,
|
||||
type_mapping=type_mapping,
|
||||
)
|
||||
|
||||
def _transform_to_no_system_messages(
|
||||
self,
|
||||
messages: List[ModelMessage],
|
||||
support_media_content: bool = True,
|
||||
type_mapping: Optional[Dict[str, str]] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Transform the model messages to no system messages
|
||||
|
||||
@ -454,7 +459,9 @@ class LLMModelAdapter(ABC):
|
||||
List[Dict[str, str]]: The transformed model messages
|
||||
"""
|
||||
openai_messages = ModelMessage.to_common_messages(
|
||||
messages, support_media_content=support_media_content
|
||||
messages,
|
||||
support_media_content=support_media_content,
|
||||
type_mapping=type_mapping,
|
||||
)
|
||||
system_messages = []
|
||||
return_messages = []
|
||||
|
@ -1034,6 +1034,100 @@ class GLM40414Adapter(NewHFChatModelAdapter):
|
||||
return lower_model_name_or_path and "z1" in lower_model_name_or_path
|
||||
|
||||
|
||||
class GLM41VAdapter(GLM40414Adapter):
|
||||
"""
|
||||
https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking
|
||||
Please make sure your transformers version is >= 4.54.0, and you can install it with
|
||||
following command:
|
||||
uv pip install git+https://github.com/huggingface/transformers.git
|
||||
"""
|
||||
|
||||
def do_match(self, lower_model_name_or_path: Optional[str] = None):
|
||||
return lower_model_name_or_path and "glm-4.1v" in lower_model_name_or_path
|
||||
|
||||
def is_reasoning_model(
|
||||
self,
|
||||
deploy_model_params: LLMDeployModelParameters,
|
||||
lower_model_name_or_path: Optional[str] = None,
|
||||
) -> bool:
|
||||
if (
|
||||
deploy_model_params.reasoning_model is not None
|
||||
and deploy_model_params.reasoning_model
|
||||
):
|
||||
return True
|
||||
return lower_model_name_or_path and "thinking" in lower_model_name_or_path
|
||||
|
||||
def check_transformer_version(self, current_version: str) -> None:
|
||||
if not current_version >= "4.54.0":
|
||||
raise ValueError(
|
||||
"GLM-4.1V require transformers.__version__>= 4.54.0, please upgrade "
|
||||
"your transformers package."
|
||||
"And you can install transformers with "
|
||||
"`uv pip install git+https://github.com/huggingface/transformers.git`"
|
||||
)
|
||||
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
try:
|
||||
from transformers import (
|
||||
Glm4vForConditionalGeneration,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
"Could not import glm-4.1v model, please upgrade your "
|
||||
"transformers package to 4.54.0 or later."
|
||||
) from exc
|
||||
|
||||
logger.info(
|
||||
f"Load model from {model_path}, from_pretrained_kwargs: "
|
||||
f"{from_pretrained_kwargs}"
|
||||
)
|
||||
|
||||
revision = from_pretrained_kwargs.get("revision", "main")
|
||||
trust_remote_code = from_pretrained_kwargs.get(
|
||||
"trust_remote_code", self.trust_remote_code
|
||||
)
|
||||
low_cpu_mem_usage = from_pretrained_kwargs.get("low_cpu_mem_usage", False)
|
||||
if "trust_remote_code" not in from_pretrained_kwargs:
|
||||
from_pretrained_kwargs["trust_remote_code"] = trust_remote_code
|
||||
if "low_cpu_mem_usage" not in from_pretrained_kwargs:
|
||||
from_pretrained_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
model = Glm4vForConditionalGeneration.from_pretrained(
|
||||
model_path, **from_pretrained_kwargs
|
||||
)
|
||||
tokenizer = self.load_tokenizer(
|
||||
model_path,
|
||||
revision,
|
||||
use_fast=self.use_fast_tokenizer(),
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
def get_str_prompt(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
convert_to_compatible_format: bool = False,
|
||||
) -> Optional[str]:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if not tokenizer:
|
||||
raise ValueError("tokenizer is is None")
|
||||
tokenizer: AutoTokenizer = tokenizer
|
||||
type_mapping = {
|
||||
"image_url": "image",
|
||||
}
|
||||
messages = self.transform_model_messages(
|
||||
messages, convert_to_compatible_format, type_mapping=type_mapping
|
||||
)
|
||||
logger.debug(f"The messages after transform: \n{messages}")
|
||||
str_prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
return str_prompt
|
||||
|
||||
|
||||
class Codegeex4Adapter(GLM4Adapter):
|
||||
"""
|
||||
https://huggingface.co/THUDM/codegeex4-all-9b
|
||||
@ -1150,6 +1244,7 @@ register_model_adapter(SQLCoderAdapter)
|
||||
register_model_adapter(OpenChatAdapter)
|
||||
register_model_adapter(GLM4Adapter, supported_models=COMMON_HF_GLM_MODELS)
|
||||
register_model_adapter(GLM40414Adapter)
|
||||
register_model_adapter(GLM41VAdapter)
|
||||
register_model_adapter(Codegeex4Adapter)
|
||||
register_model_adapter(Qwen2Adapter, supported_models=COMMON_HF_QWEN25_MODELS)
|
||||
register_model_adapter(Qwen2VLAdapter)
|
||||
|
Loading…
Reference in New Issue
Block a user