feat(model): Support Qwen3 embeddings (#2772)

This commit is contained in:
Fangyin Cheng 2025-06-14 23:34:48 +08:00 committed by GitHub
parent cc81e7af09
commit bb947b6af7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 307 additions and 22 deletions

View File

@ -30,8 +30,12 @@ provider = "hf"
# reasoning_model = false
[[models.embeddings]]
name = "BAAI/bge-large-zh-v1.5"
name = "Qwen/Qwen3-Embedding-0.6B"
provider = "hf"
# If not provided, the model will be downloaded from the Hugging Face model hub.
# Uncomment the following line to specify the model path in the local file system.
# path = "the-model-path-in-the-local-file-system"
[[models.rerankers]]
name = "Qwen/Qwen3-Reranker-0.6B"
provider = "qwen"

View File

@ -34,7 +34,6 @@ api_url = "https://cloud.infini-ai.com/maas/v1/embeddings"
api_key = "${env:INFINIAI_API_KEY}"
[[models.rerankers]]
type = "reranker"
name = "bge-reranker-v2-m3"
provider = "proxy/infiniai"
api_key = "${env:INFINIAI_API_KEY}"

View File

@ -38,7 +38,6 @@ provider = "proxy/siliconflow"
api_key = "${env:SILICONFLOW_API_KEY}"
[[models.rerankers]]
type = "reranker"
name = "BAAI/bge-reranker-v2-m3"
provider = "proxy/siliconflow"
api_key = "${env:SILICONFLOW_API_KEY}"

View File

@ -34,7 +34,6 @@ api_url = "https://api.siliconflow.cn/v1/embeddings"
api_key = "${env:SILICONFLOW_API_KEY}"
[[models.rerankers]]
type = "reranker"
name = "BAAI/bge-reranker-v2-m3"
provider = "proxy/siliconflow"
api_key = "${env:SILICONFLOW_API_KEY}"

View File

@ -18,6 +18,48 @@ def _register_reranker_common_hf_models(models: List[EmbeddingModelMetadata]) ->
RERANKER_COMMON_HF_MODELS.extend(models)
# Metadata for the Qwen3 embedding checkpoints hosted on Hugging Face.
# The three entries differ only in parameter count and output dimension;
# every one of them declares a 32k-token context window.
EMBED_COMMON_HF_QWEN_MODELS = [
    # 0.6B parameters, 1024-dimensional vectors.
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Embedding-0.6B"],
        dimension=1024,
        context_length=32 * 1024,  # 32k context length
        description=_(
            "Qwen3-Embedding-0.6B is a multilingual embedding model trained by "
            "Qwen team, supporting more than 100 languages. It has 0.6B parameters "
            "and a context length of 32k tokens and the dimension is 1024."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Embedding-0.6B",
    ),
    # 4B parameters, 2560-dimensional vectors.
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Embedding-4B"],
        dimension=2560,
        context_length=32 * 1024,  # 32k context length
        description=_(
            "Qwen3-Embedding-4B is a multilingual embedding model trained by "
            "Qwen team, supporting more than 100 languages. It has 4B parameters "
            "and a context length of 32k tokens and the dimension is 2560."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Embedding-4B",
    ),
    # 8B parameters, 4096-dimensional vectors.
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Embedding-8B"],
        dimension=4096,
        context_length=32 * 1024,  # 32k context length
        description=_(
            "Qwen3-Embedding-8B is a multilingual embedding model trained by "
            "Qwen team, supporting more than 100 languages. It has 8B parameters "
            "and a context length of 32k tokens and the dimension is 4096."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Embedding-8B",
    ),
]
EMBED_COMMON_HF_BGE_MODELS = [
EmbeddingModelMetadata(
model=["BAAI/bge-m3"],
@ -58,6 +100,43 @@ EMBED_COMMON_HF_JINA_MODELS = [
),
]
# Metadata for the Qwen3 reranker checkpoints hosted on Hugging Face.
# Each entry is flagged with ``is_reranker=True`` and declares a 32k-token
# context window; no embedding dimension is given for rerankers.
RERANKER_COMMON_HF_QWEN_MODELS = [
    # 0.6B parameters.
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Reranker-0.6B"],
        is_reranker=True,
        context_length=32 * 1024,  # 32k context length
        description=_(
            "Qwen3-Reranker-0.6B is a multilingual reranker model trained by "
            "Qwen team, supporting more than 100 languages. It has 0.6B parameters "
            "and a context length of 32k tokens."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Reranker-0.6B",
    ),
    # 4B parameters.
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Reranker-4B"],
        is_reranker=True,
        context_length=32 * 1024,  # 32k context length
        description=_(
            "Qwen3-Reranker-4B is a multilingual reranker model trained by "
            "Qwen team, supporting more than 100 languages. It has 4B parameters "
            "and a context length of 32k tokens."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Reranker-4B",
    ),
    # 8B parameters.
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Reranker-8B"],
        is_reranker=True,
        context_length=32 * 1024,  # 32k context length
        description=_(
            "Qwen3-Reranker-8B is a multilingual reranker model trained by "
            "Qwen team, supporting more than 100 languages. It has 8B parameters "
            "and a context length of 32k tokens."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Reranker-8B",
    ),
]
RERANKER_COMMON_HF_BGE_MODELS = [
EmbeddingModelMetadata(
model=["BAAI/bge-reranker-v2-m3"],
@ -91,6 +170,7 @@ RERANKER_COMMON_HF_JINA_MODELS = [
_register_embed_common_hf_models(EMBED_COMMON_HF_BGE_MODELS)
_register_embed_common_hf_models(EMBED_COMMON_HF_JINA_MODELS)
_register_embed_common_hf_models(EMBED_COMMON_HF_QWEN_MODELS)
# Register reranker models
_register_reranker_common_hf_models(RERANKER_COMMON_HF_BGE_MODELS)

View File

@ -14,6 +14,7 @@ from dbgpt.model.adapter.base import register_embedding_adapter
from dbgpt.model.adapter.embed_metadata import (
EMBED_COMMON_HF_BGE_MODELS,
EMBED_COMMON_HF_JINA_MODELS,
EMBED_COMMON_HF_QWEN_MODELS,
)
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
@ -958,7 +959,8 @@ register_embedding_adapter(
languages=["zh"],
),
]
+ EMBED_COMMON_HF_JINA_MODELS,
+ EMBED_COMMON_HF_JINA_MODELS
+ EMBED_COMMON_HF_QWEN_MODELS,
)
register_embedding_adapter(
HuggingFaceInstructEmbeddings,

View File

@ -9,10 +9,14 @@ import numpy as np
import requests
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
from dbgpt.configs.model_config import get_device
from dbgpt.core import RerankEmbeddings
from dbgpt.core.interface.parameter import RerankerDeployModelParameters
from dbgpt.model.adapter.base import register_embedding_adapter
from dbgpt.model.adapter.embed_metadata import RERANKER_COMMON_HF_MODELS
from dbgpt.model.adapter.embed_metadata import (
RERANKER_COMMON_HF_MODELS,
RERANKER_COMMON_HF_QWEN_MODELS,
)
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
@ -122,7 +126,7 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
) -> "RerankEmbeddings":
"""Create a rerank model from parameters."""
return cls(
model_name=parameters.real_provider_model_name,
model_name=parameters.real_model_path,
max_length=parameters.max_length,
model_kwargs=parameters.real_model_kwargs,
)
@ -161,6 +165,201 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
return rank_scores # type: ignore
@dataclass
class QwenRerankEmbeddingsParameters(RerankerDeployModelParameters):
    """Qwen Rerank Embeddings Parameters.

    Deployment parameters for the ``"qwen"`` reranker provider: an optional
    local model path and an optional device override.
    """

    provider: str = "qwen"

    # Optional local model location; ``None`` means no local path was given.
    path: Optional[str] = field(
        default=None,
        metadata={
            "order": -800,
            "help": _(
                "The path of the model, if you want to deploy a local model. Defaults "
                "to 'Qwen/Qwen3-Reranker-0.6B'."
            ),
        },
    )

    # Optional device override; ``None`` defers to the parent class' choice.
    device: Optional[str] = field(
        default=None,
        metadata={
            "order": -700,
            "help": _(
                "Device to run model. If None, the device is automatically determined"
            ),
        },
    )

    @property
    def real_provider_model_name(self) -> str:
        """Return ``path`` when it is set (truthy), otherwise ``name``."""
        if self.path:
            return self.path
        return self.name

    @property
    def real_model_path(self) -> Optional[str]:
        """Return ``path`` after passing it through ``_resolve_root_path``.

        If deploy model is not local, return None.
        """
        resolved_path = self._resolve_root_path(self.path)
        return resolved_path

    @property
    def real_device(self) -> Optional[str]:
        """Return the explicit ``device`` if set, else the parent's device."""
        if self.device:
            return self.device
        return super().real_device
class QwenRerankEmbeddings(BaseModel, RerankEmbeddings):
    """Qwen Rerank Embeddings.

    Scores (query, document) pairs with a Qwen3 reranker loaded as a causal
    language model. Each pair is wrapped in a chat-style prompt asking for a
    "yes"/"no" judgement; the score is the probability of "yes" after a
    softmax over the "yes" and "no" logits at the last position.

    .. code-block:: python

        from dbgpt.rag.embedding.rerank import QwenRerankEmbeddings

        reranker = QwenRerankEmbeddings(model_name="Qwen/Qwen3-Reranker-0.6B")
        query = "Apple"
        documents = [
            "apple",
            "banana",
            "fruit",
            "vegetable",
            "苹果",
            "Mac OS",
            "乔布斯",
            "iPhone",
        ]
        scores = reranker.predict(query, documents)
        print(scores)
    """

    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())

    model_name: str = "Qwen/Qwen3-Reranker-0.6B"
    max_length: int = 8192
    model: Any  #: :meta private:
    tokenizer: Any  #: :meta private:
    token_false_id: int  #: :meta private:
    token_true_id: int  #: :meta private:
    prefix_tokens: List[int]  #: :meta private:
    suffix_tokens: List[int]  #: :meta private:
    task: str = (
        "Given a web search query, retrieve relevant passages that answer the "
        "query"
    )  #: :meta private:
    device: Optional[str] = None  #: :meta private:

    def __init__(self, **kwargs: Any):
        """Load the tokenizer and model and precompute the prompt tokens.

        Args:
            **kwargs: Field values for this model. ``model_name`` selects the
                checkpoint (a missing or ``None`` value falls back to
                ``"Qwen/Qwen3-Reranker-0.6B"``); ``device`` selects where the
                model is moved, defaulting to ``get_device()`` when the key
                is absent.

        Raises:
            ImportError: If ``transformers`` or ``torch`` is not installed.
        """
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError as exc:
            raise ImportError(
                "please `pip install transformers`",
            ) from exc
        try:
            import torch  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "please `pip install torch`",
            ) from exc

        # ``dict.get(key, default)`` returns None when the key is present with
        # a None value, so use ``or`` and write the resolved name back: the
        # ``model_name`` field is declared as ``str``.
        model_name = kwargs.get("model_name") or "Qwen/Qwen3-Reranker-0.6B"
        kwargs["model_name"] = model_name

        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(model_name).eval()
        device = kwargs.get("device", get_device())
        if device:
            model = model.to(device=device)

        # Vocabulary ids of the two answer tokens compared in compute_logits.
        token_false_id = tokenizer.convert_tokens_to_ids("no")
        token_true_id = tokenizer.convert_tokens_to_ids("yes")

        # Fixed chat wrapper placed around every formatted pair.
        prefix = (
            "<|im_start|>system\nJudge whether the Document meets the requirements"
            " based on the Query and the Instruct provided. Note that the answer can "
            'only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
        )
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
        suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

        kwargs["model"] = model
        kwargs["tokenizer"] = tokenizer
        kwargs["token_false_id"] = token_false_id
        kwargs["token_true_id"] = token_true_id
        kwargs["prefix_tokens"] = prefix_tokens
        kwargs["suffix_tokens"] = suffix_tokens
        super().__init__(**kwargs)

    @classmethod
    def param_class(cls) -> Type[QwenRerankEmbeddingsParameters]:
        """Get the parameter class."""
        return QwenRerankEmbeddingsParameters

    @classmethod
    def from_parameters(
        cls, parameters: QwenRerankEmbeddingsParameters
    ) -> "RerankEmbeddings":
        """Create a rerank model from parameters.

        The resolved local path is preferred. When it is empty, fall back to
        ``real_provider_model_name`` (``path or name``) so that the configured
        model name is used instead of passing ``None``.
        """
        model_name = parameters.real_model_path or parameters.real_provider_model_name
        return cls(
            model_name=model_name,
            device=parameters.real_device,
        )

    def format_instruction(
        self, instruction: Optional[str], query: str, doc: str
    ) -> str:
        """Build the ``<Instruct>/<Query>/<Document>`` text for one pair.

        Args:
            instruction: Task description; ``None`` selects the default web
                search instruction.
            query: The search query.
            doc: The candidate document.

        Returns:
            The formatted text, without the chat prefix and suffix.
        """
        if instruction is None:
            instruction = (
                "Given a web search query, retrieve relevant passages that "
                "answer the query"
            )
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs: List[str]) -> Any:
        """Tokenize formatted pairs and wrap them with the chat prefix/suffix.

        Args:
            pairs: Texts produced by ``format_instruction``.

        Returns:
            The padded tokenizer output as PyTorch tensors, moved to the
            device of ``self.model``.
        """
        # Reserve room for the prefix/suffix so the wrapped sequence still
        # fits within ``max_length``.
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation="longest_first",
            return_attention_mask=False,
            max_length=self.max_length
            - len(self.prefix_tokens)
            - len(self.suffix_tokens),
        )
        for i, ele in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.tokenizer.pad(
            inputs, padding=True, return_tensors="pt", max_length=self.max_length
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    def compute_logits(self, inputs: Any, **kwargs: Any) -> List[float]:
        """Run the model and turn the yes/no logits into relevance scores.

        Args:
            inputs: The output of ``process_inputs``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            One score per input row: the "yes" probability from a softmax
            over the "no" and "yes" logits at the last position.
        """
        import torch

        with torch.no_grad():
            # Only the logits at the final position are read.
            batch_scores = self.model(**inputs).logits[:, -1, :]
            true_vector = batch_scores[:, self.token_true_id]
            false_vector = batch_scores[:, self.token_false_id]
            batch_scores = torch.stack([false_vector, true_vector], dim=1)
            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
            scores = batch_scores[:, 1].exp().tolist()
        return scores

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Score every candidate document against the query.

        Args:
            query: The search query.
            candidates: The documents to score.

        Returns:
            One score per candidate, in the same order. An empty candidate
            list yields an empty list.
        """
        # Nothing to score: skip tokenization and the forward pass entirely.
        if not candidates:
            return []
        pairs = [
            self.format_instruction(self.task, query, doc) for doc in candidates
        ]
        # Tokenize the input texts
        inputs = self.process_inputs(pairs)
        return self.compute_logits(inputs)
@dataclass
class OpenAPIRerankerDeployModelParameters(RerankerDeployModelParameters):
"""OpenAPI Reranker Deploy Model Parameters."""
@ -258,7 +457,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
return cls(
api_url=parameters.api_url,
api_key=parameters.api_key,
model_name=parameters.real_provider_model_name,
model_name=parameters.real_model_path,
timeout=parameters.timeout,
)
@ -579,3 +778,6 @@ register_embedding_adapter(
register_embedding_adapter(
InfiniAIRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
register_embedding_adapter(
QwenRerankEmbeddings, supported_models=RERANKER_COMMON_HF_QWEN_MODELS
)

View File

@ -1,22 +1,22 @@
"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_starrocks.py
docker run -p 9030:9030 -p 8030:8030 -p 8040:8040 -itd --name quickstart starrocks/allin1-ubuntu
mysql -P 9030 -h 127.0.0.1 -u root --prompt="StarRocks > "
Welcome to the MySQL monitor. Commands end with ; or \g.
Your MySQL connection id is 184
Server version: 5.1.0 3.1.5-5d8438a
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_starrocks.py
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
docker run -p 9030:9030 -p 8030:8030 -p 8040:8040 -itd --name quickstart starrocks/allin1-ubuntu
Oracle is a registered trademark of Oracle Corporation and/or its
affiliates. Other names may be trademarks of their respective
owners.
mysql -P 9030 -h 127.0.0.1 -u root --prompt="StarRocks > "
Welcome to the MySQL monitor. Commands end with ; or \g.
Your MySQL connection id is 184
Server version: 5.1.0 3.1.5-5d8438a
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
> create database test;
Oracle is a registered trademark of Oracle Corporation and/or its
affiliates. Other names may be trademarks of their respective
owners.
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
> create database test;
"""
import pytest