feat(model): Support Qwen3 models (#2664)

Fangyin Cheng
2025-04-29 09:55:28 +08:00
committed by GitHub
parent bcb43266cf
commit 2abd68d6c0
11 changed files with 257 additions and 11 deletions

View File

@@ -155,6 +155,9 @@ hf_kimi = [
     "blobfile",
     "transformers<4.51.3",
 ]
+hf_qwen3 = [
+    "transformers>=4.51.0",
+]

 [build-system]
 requires = ["hatchling"]
@@ -174,6 +177,10 @@ conflicts = [
         { extra = "hf_glm4" },
         { extra = "hf_kimi" },
     ],
+    [
+        { extra = "hf_qwen3" },
+        { extra = "hf_kimi" },
+    ],
 ]
 [tool.hatch.build.targets.wheel]
 packages = ["src/dbgpt"]
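
The optional dependency group only declares the version floor; the adapter re-checks it at runtime (see check_transformer_version below). As a rough standalone sketch of the same guard, assuming the `packaging` library is available in the environment:

# Illustrative check, not part of this commit.
import transformers
from packaging.version import Version

if Version(transformers.__version__) < Version("4.51.0"):
    raise RuntimeError("Qwen3 requires transformers>=4.51.0; please upgrade")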

View File

@@ -245,7 +245,10 @@ class MediaContent:
     @classmethod
     def to_chat_completion_message(
-        cls, role, content: Union[str, "MediaContent", List["MediaContent"]]
+        cls,
+        role,
+        content: Union[str, "MediaContent", List["MediaContent"]],
+        support_media_content: bool = True,
     ) -> ChatCompletionMessageParam:
         """Convert the media contents to chat completion message."""
         if not content:
@@ -255,6 +258,14 @@
         if isinstance(content, MediaContent):
             content = [content]
         new_content = [cls._parse_single_media_content(c) for c in content]
+        if not support_media_content:
+            text_content = [
+                c["text"] for c in new_content if c["type"] == "text" and "text" in c
+            ]
+            if not text_content:
+                raise ValueError("No text content found in the media contents")
+            # Media content is not supported, so pass the plain text as the content
+            new_content = text_content[0]
         return {
             "role": role,
             "content": new_content,

View File

@@ -407,6 +407,7 @@ class ModelMessage(BaseModel):
         messages: List["ModelMessage"],
         convert_to_compatible_format: bool = False,
         support_system_role: bool = True,
+        support_media_content: bool = True,
     ) -> List[Dict[str, str]]:
         """Convert to common message format.
@@ -418,6 +419,7 @@
             messages (List["ModelMessage"]): The model messages
             convert_to_compatible_format (bool): Whether to convert to compatible format
             support_system_role (bool): Whether to support system role
+            support_media_content (bool): Whether to support media content

         Returns:
             List[Dict[str, str]]: The common messages
@@ -430,7 +432,11 @@
         for message in messages:
             if message.role == ModelMessageRoleType.HUMAN:
                 history.append(
-                    MediaContent.to_chat_completion_message("user", message.content)
+                    MediaContent.to_chat_completion_message(
+                        "user",
+                        message.content,
+                        support_media_content=support_media_content,
+                    )
                 )
             elif message.role == ModelMessageRoleType.SYSTEM:
                 if not support_system_role:
@@ -440,6 +446,7 @@
                     MediaContent.to_chat_completion_message(
                         "system",
                         message.content,
+                        support_media_content=support_media_content,
                     )
                 )
             elif message.role == ModelMessageRoleType.AI:
@@ -447,6 +454,7 @@
                 MediaContent.to_chat_completion_message(
                     "assistant",
                     message.content,
+                    support_media_content=support_media_content,
                 )
             )
         else:
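
Roughly how callers are expected to use the new flag when flattening a conversation; the ModelMessage constructor shape here is assumed from the role constants above rather than verified against the full class:

# Illustrative usage only, not part of this commit.
messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello, Qwen3!"),
]
common = ModelMessage.to_common_messages(messages, support_media_content=False)
# Expected shape: plain string content per entry, e.g.
# [{"role": "system", "content": "You are a helpful assistant."},
#  {"role": "user", "content": "Hello, Qwen3!"}]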

View File

@@ -367,7 +367,10 @@ class LLMModelAdapter(ABC):
         return roles

     def transform_model_messages(
-        self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
+        self,
+        messages: List[ModelMessage],
+        convert_to_compatible_format: bool = False,
+        support_media_content: bool = True,
     ) -> List[Dict[str, str]]:
         """Transform the model messages
@@ -392,6 +395,7 @@
             messages (List[ModelMessage]): The model messages
             convert_to_compatible_format (bool, optional): Whether to convert to
                 compatible format. Defaults to False.
+            support_media_content (bool, optional): Whether to support media content

         Returns:
             List[Dict[str, str]]: The transformed model messages
@@ -399,14 +403,20 @@
         logger.info(f"support_system_message: {self.support_system_message}")
         if not self.support_system_message and convert_to_compatible_format:
             # We will not do any transform in the future
-            return self._transform_to_no_system_messages(messages)
+            return self._transform_to_no_system_messages(
+                messages, support_media_content=support_media_content
+            )
         else:
             return ModelMessage.to_common_messages(
-                messages, convert_to_compatible_format=convert_to_compatible_format
+                messages,
+                convert_to_compatible_format=convert_to_compatible_format,
+                support_media_content=support_media_content,
             )

     def _transform_to_no_system_messages(
-        self, messages: List[ModelMessage]
+        self,
+        messages: List[ModelMessage],
+        support_media_content: bool = True,
     ) -> List[Dict[str, str]]:
         """Transform the model messages to no system messages
@@ -433,7 +443,9 @@
         Returns:
             List[Dict[str, str]]: The transformed model messages
         """
-        openai_messages = ModelMessage.to_common_messages(messages)
+        openai_messages = ModelMessage.to_common_messages(
+            messages, support_media_content=support_media_content
+        )
         system_messages = []
         return_messages = []
         for message in openai_messages:

View File

@@ -600,6 +600,63 @@ class QwenMoeAdapter(NewHFChatModelAdapter):
             )


+class Qwen3Adapter(QwenAdapter):
+    support_4bit: bool = True
+    support_8bit: bool = True
+
+    def do_match(self, lower_model_name_or_path: Optional[str] = None):
+        return lower_model_name_or_path and (
+            "qwen3" in lower_model_name_or_path
+            and "base" not in lower_model_name_or_path
+        )
+
+    def check_transformer_version(self, current_version: str) -> None:
+        if not current_version >= "4.51.0":
+            raise ValueError(
+                "Qwen3 requires transformers.__version__>=4.51.0, please upgrade your"
+                " transformers package."
+            )
+
+    def is_reasoning_model(
+        self,
+        deploy_model_params: LLMDeployModelParameters,
+        lower_model_name_or_path: Optional[str] = None,
+    ) -> bool:
+        if (
+            deploy_model_params.reasoning_model is not None
+            and deploy_model_params.reasoning_model is False
+        ):
+            return False
+        return True
+
+    def get_str_prompt(
+        self,
+        params: Dict,
+        messages: List[ModelMessage],
+        tokenizer: Any,
+        prompt_template: str = None,
+        convert_to_compatible_format: bool = False,
+    ) -> Optional[str]:
+        from transformers import AutoTokenizer
+
+        if not tokenizer:
+            raise ValueError("tokenizer is None")
+        tokenizer: AutoTokenizer = tokenizer
+
+        is_reasoning_model = params.get("is_reasoning_model", True)
+        messages = self.transform_model_messages(
+            messages, convert_to_compatible_format, support_media_content=False
+        )
+        logger.debug(f"The messages after transform: \n{messages}")
+        str_prompt = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=is_reasoning_model,
+        )
+        return str_prompt
+
+
 class QwenOmniAdapter(NewHFChatModelAdapter):
     def do_match(self, lower_model_name_or_path: Optional[str] = None):
         return lower_model_name_or_path and (
@@ -997,6 +1054,7 @@ register_model_adapter(Gemma2Adapter)
 register_model_adapter(StarlingLMAdapter)
 register_model_adapter(QwenAdapter)
 register_model_adapter(QwenMoeAdapter)
+register_model_adapter(Qwen3Adapter)
 register_model_adapter(QwenOmniAdapter)
 register_model_adapter(Llama3Adapter)
 register_model_adapter(Llama31Adapter)
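
A quick, illustrative check (not part of the commit) of which names the new adapter claims: do_match receives an already-lowercased model name or path, and base (non-chat) checkpoints are deliberately excluded:

# Hypothetical matching check; the model name strings are examples only.
adapter = Qwen3Adapter()
print(bool(adapter.do_match("qwen/qwen3-8b")))             # True
print(bool(adapter.do_match("qwen/qwen3-30b-a3b")))        # True
print(bool(adapter.do_match("qwen/qwen3-8b-base")))        # False: base model
print(bool(adapter.do_match("qwen/qwen2.5-7b-instruct")))  # False: not Qwen3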

View File

@@ -369,6 +369,13 @@ class DefaultModelWorker(ModelWorker):
     def _prepare_generate_stream(
         self, params: Dict, span_operation_name: str, is_stream=True
     ):
+        if self.llm_adapter.is_reasoning_model(
+            self._model_params, self.model_name.lower()
+        ):
+            params["is_reasoning_model"] = True
+        else:
+            params["is_reasoning_model"] = False
+
         params, model_context = self.llm_adapter.model_adaptation(
             params,
             self.model_name,
@@ -427,10 +434,6 @@
             span_params["messages"] = list(
                 map(lambda m: m.dict(), span_params["messages"])
             )
-        if self.llm_adapter.is_reasoning_model(
-            self._model_params, self.model_name.lower()
-        ):
-            params["is_reasoning_model"] = True

         metadata = {
             "is_async_func": self.support_async(),