diff --git a/configs/dbgpt-local-qwen3.example.toml b/configs/dbgpt-local-qwen3.example.toml
index 31f644e70..7c205ccf2 100644
--- a/configs/dbgpt-local-qwen3.example.toml
+++ b/configs/dbgpt-local-qwen3.example.toml
@@ -30,8 +30,12 @@ provider = "hf"
# reasoning_model = false
[[models.embeddings]]
-name = "BAAI/bge-large-zh-v1.5"
+name = "Qwen/Qwen3-Embedding-0.6B"
provider = "hf"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# path = "the-model-path-in-the-local-file-system"
+
+[[models.rerankers]]
+name = "Qwen/Qwen3-Reranker-0.6B"
+provider = "qwen"
\ No newline at end of file
diff --git a/configs/dbgpt-proxy-infiniai.toml b/configs/dbgpt-proxy-infiniai.toml
index 8c0109e59..ddfa1a9d2 100644
--- a/configs/dbgpt-proxy-infiniai.toml
+++ b/configs/dbgpt-proxy-infiniai.toml
@@ -34,7 +34,6 @@ api_url = "https://cloud.infini-ai.com/maas/v1/embeddings"
api_key = "${env:INFINIAI_API_KEY}"
[[models.rerankers]]
-type = "reranker"
name = "bge-reranker-v2-m3"
provider = "proxy/infiniai"
api_key = "${env:INFINIAI_API_KEY}"
\ No newline at end of file
diff --git a/configs/dbgpt-proxy-siliconflow-mysql.toml b/configs/dbgpt-proxy-siliconflow-mysql.toml
index 60cba4210..1e166e6f8 100644
--- a/configs/dbgpt-proxy-siliconflow-mysql.toml
+++ b/configs/dbgpt-proxy-siliconflow-mysql.toml
@@ -38,7 +38,6 @@ provider = "proxy/siliconflow"
api_key = "${env:SILICONFLOW_API_KEY}"
[[models.rerankers]]
-type = "reranker"
name = "BAAI/bge-reranker-v2-m3"
provider = "proxy/siliconflow"
api_key = "${env:SILICONFLOW_API_KEY}"
diff --git a/configs/dbgpt-proxy-siliconflow.toml b/configs/dbgpt-proxy-siliconflow.toml
index ac1731903..9d8d63c6e 100644
--- a/configs/dbgpt-proxy-siliconflow.toml
+++ b/configs/dbgpt-proxy-siliconflow.toml
@@ -34,7 +34,6 @@ api_url = "https://api.siliconflow.cn/v1/embeddings"
api_key = "${env:SILICONFLOW_API_KEY}"
[[models.rerankers]]
-type = "reranker"
name = "BAAI/bge-reranker-v2-m3"
provider = "proxy/siliconflow"
api_key = "${env:SILICONFLOW_API_KEY}"
diff --git a/packages/dbgpt-core/src/dbgpt/model/adapter/embed_metadata.py b/packages/dbgpt-core/src/dbgpt/model/adapter/embed_metadata.py
index 59755a043..64906b8cc 100644
--- a/packages/dbgpt-core/src/dbgpt/model/adapter/embed_metadata.py
+++ b/packages/dbgpt-core/src/dbgpt/model/adapter/embed_metadata.py
@@ -18,6 +18,48 @@ def _register_reranker_common_hf_models(models: List[EmbeddingModelMetadata]) ->
RERANKER_COMMON_HF_MODELS.extend(models)
+EMBED_COMMON_HF_QWEN_MODELS = [
+ EmbeddingModelMetadata(
+ model=[
+ "Qwen/Qwen3-Embedding-0.6B",
+ ],
+ dimension=1024,
+ context_length=32 * 1024, # 32k context length
+ description=_(
+ "Qwen3-Embedding-0.6B is a multilingual embedding model trained by "
+ "Qwen team, supporting more than 100 languages. It has 0.6B parameters "
+ "and a context length of 32k tokens and the dimension is 1024."
+ ),
+ link="https://huggingface.co/Qwen/Qwen3-Embedding-0.6B",
+ ),
+ EmbeddingModelMetadata(
+ model=[
+ "Qwen/Qwen3-Embedding-4B",
+ ],
+ dimension=2560,
+ context_length=32 * 1024, # 32k context length
+ description=_(
+ "Qwen3-Embedding-4B is a multilingual embedding model trained by "
+ "Qwen team, supporting more than 100 languages. It has 4B parameters "
+ "and a context length of 32k tokens and the dimension is 2560."
+ ),
+ link="https://huggingface.co/Qwen/Qwen3-Embedding-4B",
+ ),
+ EmbeddingModelMetadata(
+ model=[
+ "Qwen/Qwen3-Embedding-8B",
+ ],
+ dimension=4096,
+ context_length=32 * 1024, # 32k context length
+ description=_(
+ "Qwen3-Embedding-8B is a multilingual embedding model trained by "
+ "Qwen team, supporting more than 100 languages. It has 8B parameters "
+ "and a context length of 32k tokens and the dimension is 4096."
+ ),
+ link="https://huggingface.co/Qwen/Qwen3-Embedding-8B",
+ ),
+]
+
EMBED_COMMON_HF_BGE_MODELS = [
EmbeddingModelMetadata(
model=["BAAI/bge-m3"],
@@ -58,6 +100,43 @@ EMBED_COMMON_HF_JINA_MODELS = [
),
]
+
+RERANKER_COMMON_HF_QWEN_MODELS = [
+ EmbeddingModelMetadata(
+ model=["Qwen/Qwen3-Reranker-0.6B"],
+ context_length=32 * 1024, # 32k context length
+ description=_(
+ "Qwen3-Reranker-0.6B is a multilingual reranker model trained by "
+ "Qwen team, supporting more than 100 languages. It has 0.6B parameters "
+ "and a context length of 32k tokens."
+ ),
+ link="https://huggingface.co/Qwen/Qwen3-Reranker-0.6B",
+ is_reranker=True,
+ ),
+ EmbeddingModelMetadata(
+ model=["Qwen/Qwen3-Reranker-4B"],
+ context_length=32 * 1024, # 32k context length
+ description=_(
+ "Qwen3-Reranker-4B is a multilingual reranker model trained by "
+ "Qwen team, supporting more than 100 languages. It has 4B parameters "
+ "and a context length of 32k tokens."
+ ),
+ link="https://huggingface.co/Qwen/Qwen3-Reranker-4B",
+ is_reranker=True,
+ ),
+ EmbeddingModelMetadata(
+ model=["Qwen/Qwen3-Reranker-8B"],
+ context_length=32 * 1024, # 32k context length
+ description=_(
+ "Qwen3-Reranker-8B is a multilingual reranker model trained by "
+ "Qwen team, supporting more than 100 languages. It has 8B parameters "
+ "and a context length of 32k tokens."
+ ),
+ link="https://huggingface.co/Qwen/Qwen3-Reranker-8B",
+ is_reranker=True,
+ ),
+]
+
RERANKER_COMMON_HF_BGE_MODELS = [
EmbeddingModelMetadata(
model=["BAAI/bge-reranker-v2-m3"],
@@ -91,6 +170,7 @@ RERANKER_COMMON_HF_JINA_MODELS = [
_register_embed_common_hf_models(EMBED_COMMON_HF_BGE_MODELS)
_register_embed_common_hf_models(EMBED_COMMON_HF_JINA_MODELS)
+_register_embed_common_hf_models(EMBED_COMMON_HF_QWEN_MODELS)
# Register reranker models
_register_reranker_common_hf_models(RERANKER_COMMON_HF_BGE_MODELS)
diff --git a/packages/dbgpt-core/src/dbgpt/rag/embedding/embeddings.py b/packages/dbgpt-core/src/dbgpt/rag/embedding/embeddings.py
index 4ab13d242..b4235aa08 100644
--- a/packages/dbgpt-core/src/dbgpt/rag/embedding/embeddings.py
+++ b/packages/dbgpt-core/src/dbgpt/rag/embedding/embeddings.py
@@ -14,6 +14,7 @@ from dbgpt.model.adapter.base import register_embedding_adapter
from dbgpt.model.adapter.embed_metadata import (
EMBED_COMMON_HF_BGE_MODELS,
EMBED_COMMON_HF_JINA_MODELS,
+ EMBED_COMMON_HF_QWEN_MODELS,
)
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
@@ -958,7 +959,8 @@ register_embedding_adapter(
languages=["zh"],
),
]
- + EMBED_COMMON_HF_JINA_MODELS,
+ + EMBED_COMMON_HF_JINA_MODELS
+ + EMBED_COMMON_HF_QWEN_MODELS,
)
register_embedding_adapter(
HuggingFaceInstructEmbeddings,
diff --git a/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py b/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py
index caf986b7f..e0b755b34 100644
--- a/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py
+++ b/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py
@@ -9,10 +9,14 @@ import numpy as np
import requests
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
+from dbgpt.configs.model_config import get_device
from dbgpt.core import RerankEmbeddings
from dbgpt.core.interface.parameter import RerankerDeployModelParameters
from dbgpt.model.adapter.base import register_embedding_adapter
-from dbgpt.model.adapter.embed_metadata import RERANKER_COMMON_HF_MODELS
+from dbgpt.model.adapter.embed_metadata import (
+ RERANKER_COMMON_HF_MODELS,
+ RERANKER_COMMON_HF_QWEN_MODELS,
+)
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
@@ -122,7 +126,7 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
) -> "RerankEmbeddings":
"""Create a rerank model from parameters."""
return cls(
- model_name=parameters.real_provider_model_name,
+ model_name=parameters.real_model_path,
max_length=parameters.max_length,
model_kwargs=parameters.real_model_kwargs,
)
@@ -161,6 +165,201 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
return rank_scores # type: ignore
+@dataclass
+class QwenRerankEmbeddingsParameters(RerankerDeployModelParameters):
+ """Qwen Rerank Embeddings Parameters."""
+
+ provider: str = "qwen"
+ path: Optional[str] = field(
+ default=None,
+ metadata={
+ "order": -800,
+ "help": _(
+ "The path of the model, if you want to deploy a local model. Defaults "
+ "to 'Qwen/Qwen3-Reranker-0.6B'."
+ ),
+ },
+ )
+ device: Optional[str] = field(
+ default=None,
+ metadata={
+ "order": -700,
+ "help": _(
+ "Device to run model. If None, the device is automatically determined"
+ ),
+ },
+ )
+
+ @property
+ def real_provider_model_name(self) -> str:
+ """Get the real provider model name."""
+ return self.path or self.name
+
+ @property
+ def real_model_path(self) -> Optional[str]:
+ """Get the real model path.
+
+ If deploy model is not local, return None.
+ """
+ return self._resolve_root_path(self.path)
+
+ @property
+ def real_device(self) -> Optional[str]:
+ """Get the real device."""
+ return self.device or super().real_device
+
+
+class QwenRerankEmbeddings(BaseModel, RerankEmbeddings):
+ """Qwen Rerank Embeddings.
+
+ .. code-block:: python
+ from dbgpt.rag.embedding.rerank import QwenRerankEmbeddings
+
+ reranker = QwenRerankEmbeddings(model_name="Qwen/Qwen3-Reranker-0.6B")
+ query = "Apple"
+ documents = [
+ "apple",
+ "banana",
+ "fruit",
+ "vegetable",
+ "苹果",
+ "Mac OS",
+ "乔布斯",
+ "iPhone",
+ ]
+        scores = reranker.predict(query, documents)
+ print(scores)
+ """
+
+ model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
+ model_name: str = "Qwen/Qwen3-Reranker-0.6B"
+ max_length: int = 8192
+ model: Any #: :meta private:
+ tokenizer: Any #: :meta private:
+ token_false_id: int #: :meta private:
+ token_true_id: int #: :meta private:
+ prefix_tokens: List[int] #: :meta private:
+ suffix_tokens: List[int] #: :meta private:
+ task: str = (
+ "Given a web search query, retrieve relevant passages that answer the "
+ "query"
+ ) #: :meta private:
+ device: Optional[str] = None #: :meta private:
+
+ def __init__(self, **kwargs: Any):
+ try:
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ except ImportError:
+ raise ImportError(
+ "please `pip install transformers`",
+ )
+
+ try:
+ import torch # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "please `pip install torch`",
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ kwargs.get("model_name", "Qwen/Qwen3-Reranker-0.6B"), padding_side="left"
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ kwargs.get("model_name", "Qwen/Qwen3-Reranker-0.6B"),
+ ).eval()
+ device = kwargs.get("device", get_device())
+ if device:
+ model = model.to(device=device)
+
+ token_false_id = tokenizer.convert_tokens_to_ids("no")
+ token_true_id = tokenizer.convert_tokens_to_ids("yes")
+ prefix = (
+ "<|im_start|>system\nJudge whether the Document meets the requirements"
+ " based on the Query and the Instruct provided. Note that the answer can "
+ 'only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+ )
+ suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
+ prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
+ suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
+
+ kwargs["model"] = model
+ kwargs["tokenizer"] = tokenizer
+ kwargs["token_false_id"] = token_false_id
+ kwargs["token_true_id"] = token_true_id
+ kwargs["prefix_tokens"] = prefix_tokens
+ kwargs["suffix_tokens"] = suffix_tokens
+ super().__init__(**kwargs)
+
+ @classmethod
+ def param_class(cls) -> Type[QwenRerankEmbeddingsParameters]:
+ """Get the parameter class."""
+ return QwenRerankEmbeddingsParameters
+
+ @classmethod
+ def from_parameters(
+ cls, parameters: QwenRerankEmbeddingsParameters
+ ) -> "RerankEmbeddings":
+ """Create a rerank model from parameters."""
+ return cls(
+ model_name=parameters.real_model_path,
+ device=parameters.real_device,
+ )
+
+ def format_instruction(self, instruction, query, doc):
+ if instruction is None:
+ instruction = (
+ "Given a web search query, retrieve relevant passages that "
+ "answer the query"
+ )
+        output = (
+            "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
+                instruction=instruction, query=query, doc=doc
+            )
+        )
+ return output
+
+ def process_inputs(self, pairs):
+ inputs = self.tokenizer(
+ pairs,
+ padding=False,
+ truncation="longest_first",
+ return_attention_mask=False,
+ max_length=self.max_length
+ - len(self.prefix_tokens)
+ - len(self.suffix_tokens),
+ )
+ for i, ele in enumerate(inputs["input_ids"]):
+ inputs["input_ids"][i] = self.prefix_tokens + ele + self.suffix_tokens
+ inputs = self.tokenizer.pad(
+ inputs, padding=True, return_tensors="pt", max_length=self.max_length
+ )
+ for key in inputs:
+ inputs[key] = inputs[key].to(self.model.device)
+ return inputs
+
+ def compute_logits(self, inputs, **kwargs):
+ import torch
+
+ with torch.no_grad():
+ batch_scores = self.model(**inputs).logits[:, -1, :]
+ true_vector = batch_scores[:, self.token_true_id]
+ false_vector = batch_scores[:, self.token_false_id]
+ batch_scores = torch.stack([false_vector, true_vector], dim=1)
+ batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+ scores = batch_scores[:, 1].exp().tolist()
+ return scores
+
+ def predict(self, query: str, candidates: List[str]) -> List[float]:
+ queries = [query] * len(candidates)
+ pairs = [
+ self.format_instruction(self.task, query, doc)
+ for query, doc in zip(queries, candidates)
+ ]
+ # Tokenize the input texts
+ inputs = self.process_inputs(pairs)
+ scores = self.compute_logits(inputs)
+ return scores
+
+
@dataclass
class OpenAPIRerankerDeployModelParameters(RerankerDeployModelParameters):
"""OpenAPI Reranker Deploy Model Parameters."""
@@ -258,7 +457,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
return cls(
api_url=parameters.api_url,
api_key=parameters.api_key,
- model_name=parameters.real_provider_model_name,
+ model_name=parameters.real_model_path,
timeout=parameters.timeout,
)
@@ -579,3 +778,6 @@ register_embedding_adapter(
register_embedding_adapter(
InfiniAIRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
+register_embedding_adapter(
+ QwenRerankEmbeddings, supported_models=RERANKER_COMMON_HF_QWEN_MODELS
+)
diff --git a/tests/intetration_tests/datasource/test_conn_starrocks.py b/tests/intetration_tests/datasource/test_conn_starrocks.py
index afc1b5dae..b358ccccc 100644
--- a/tests/intetration_tests/datasource/test_conn_starrocks.py
+++ b/tests/intetration_tests/datasource/test_conn_starrocks.py
@@ -1,22 +1,22 @@
"""
- Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_starrocks.py
-
- docker run -p 9030:9030 -p 8030:8030 -p 8040:8040 -itd --name quickstart starrocks/allin1-ubuntu
-
- mysql -P 9030 -h 127.0.0.1 -u root --prompt="StarRocks > "
- Welcome to the MySQL monitor. Commands end with ; or \g.
- Your MySQL connection id is 184
- Server version: 5.1.0 3.1.5-5d8438a
+Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_starrocks.py
- Copyright (c) 2000, 2023, Oracle and/or its affiliates.
+docker run -p 9030:9030 -p 8030:8030 -p 8040:8040 -itd --name quickstart starrocks/allin1-ubuntu
- Oracle is a registered trademark of Oracle Corporation and/or its
- affiliates. Other names may be trademarks of their respective
- owners.
+mysql -P 9030 -h 127.0.0.1 -u root --prompt="StarRocks > "
+Welcome to the MySQL monitor. Commands end with ; or \g.
+Your MySQL connection id is 184
+Server version: 5.1.0 3.1.5-5d8438a
- Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
+Copyright (c) 2000, 2023, Oracle and/or its affiliates.
- > create database test;
+Oracle is a registered trademark of Oracle Corporation and/or its
+affiliates. Other names may be trademarks of their respective
+owners.
+
+Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
+
+> create database test;
"""
import pytest