mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-19 10:37:13 +00:00
feat(model): Support Qwen3 embeddings (#2772)
This commit is contained in:
parent
cc81e7af09
commit
bb947b6af7
@ -30,8 +30,12 @@ provider = "hf"
|
||||
# reasoning_model = false
|
||||
|
||||
[[models.embeddings]]
|
||||
name = "BAAI/bge-large-zh-v1.5"
|
||||
name = "Qwen/Qwen3-Embedding-0.6B"
|
||||
provider = "hf"
|
||||
# If not provided, the model will be downloaded from the Hugging Face model hub
|
||||
# Uncomment the following line to specify the model path in the local file system
|
||||
# path = "the-model-path-in-the-local-file-system"
|
||||
|
||||
[[models.rerankers]]
|
||||
name = "Qwen/Qwen3-Reranker-0.6B"
|
||||
provider = "qwen"
|
@ -34,7 +34,6 @@ api_url = "https://cloud.infini-ai.com/maas/v1/embeddings"
|
||||
api_key = "${env:INFINIAI_API_KEY}"
|
||||
|
||||
[[models.rerankers]]
|
||||
type = "reranker"
|
||||
name = "bge-reranker-v2-m3"
|
||||
provider = "proxy/infiniai"
|
||||
api_key = "${env:INFINIAI_API_KEY}"
|
@ -38,7 +38,6 @@ provider = "proxy/siliconflow"
|
||||
api_key = "${env:SILICONFLOW_API_KEY}"
|
||||
|
||||
[[models.rerankers]]
|
||||
type = "reranker"
|
||||
name = "BAAI/bge-reranker-v2-m3"
|
||||
provider = "proxy/siliconflow"
|
||||
api_key = "${env:SILICONFLOW_API_KEY}"
|
||||
|
@ -34,7 +34,6 @@ api_url = "https://api.siliconflow.cn/v1/embeddings"
|
||||
api_key = "${env:SILICONFLOW_API_KEY}"
|
||||
|
||||
[[models.rerankers]]
|
||||
type = "reranker"
|
||||
name = "BAAI/bge-reranker-v2-m3"
|
||||
provider = "proxy/siliconflow"
|
||||
api_key = "${env:SILICONFLOW_API_KEY}"
|
||||
|
@ -18,6 +18,48 @@ def _register_reranker_common_hf_models(models: List[EmbeddingModelMetadata]) ->
|
||||
RERANKER_COMMON_HF_MODELS.extend(models)
|
||||
|
||||
|
||||
# Metadata for the Qwen3 embedding family hosted on Hugging Face. The three
# entries differ only in parameter count and output dimension; every one of
# them declares a 32k-token context window.
EMBED_COMMON_HF_QWEN_MODELS = [
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Embedding-0.6B"],
        dimension=1024,
        context_length=32 * 1024,  # 32k tokens
        # Keep the description as a literal inside ``_()`` so the i18n
        # extraction tooling can still find it.
        description=_(
            "Qwen3-Embedding-0.6B is a multilingual embedding model trained by "
            "Qwen team, supporting more than 100 languages. It has 0.6B parameters "
            "and a context length of 32k tokens and the dimension is 1024."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Embedding-0.6B",
    ),
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Embedding-4B"],
        dimension=2560,
        context_length=32 * 1024,  # 32k tokens
        description=_(
            "Qwen3-Embedding-4B is a multilingual embedding model trained by "
            "Qwen team, supporting more than 100 languages. It has 4B parameters "
            "and a context length of 32k tokens and the dimension is 2560."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Embedding-4B",
    ),
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Embedding-8B"],
        dimension=4096,
        context_length=32 * 1024,  # 32k tokens
        description=_(
            "Qwen3-Embedding-8B is a multilingual embedding model trained by "
            "Qwen team, supporting more than 100 languages. It has 8B parameters "
            "and a context length of 32k tokens and the dimension is 4096."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Embedding-8B",
    ),
]
|
||||
|
||||
EMBED_COMMON_HF_BGE_MODELS = [
|
||||
EmbeddingModelMetadata(
|
||||
model=["BAAI/bge-m3"],
|
||||
@ -58,6 +100,43 @@ EMBED_COMMON_HF_JINA_MODELS = [
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# Metadata for the Qwen3 reranker family hosted on Hugging Face. Every entry is
# flagged ``is_reranker=True`` and declares a 32k-token context window; no
# embedding dimension is given for these models.
RERANKER_COMMON_HF_QWEN_MODELS = [
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Reranker-0.6B"],
        context_length=32 * 1024,  # 32k tokens
        # Keep the description as a literal inside ``_()`` so the i18n
        # extraction tooling can still find it.
        description=_(
            "Qwen3-Reranker-0.6B is a multilingual reranker model trained by "
            "Qwen team, supporting more than 100 languages. It has 0.6B parameters "
            "and a context length of 32k tokens."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Reranker-0.6B",
        is_reranker=True,
    ),
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Reranker-4B"],
        context_length=32 * 1024,  # 32k tokens
        description=_(
            "Qwen3-Reranker-4B is a multilingual reranker model trained by "
            "Qwen team, supporting more than 100 languages. It has 4B parameters "
            "and a context length of 32k tokens."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Reranker-4B",
        is_reranker=True,
    ),
    EmbeddingModelMetadata(
        model=["Qwen/Qwen3-Reranker-8B"],
        context_length=32 * 1024,  # 32k tokens
        description=_(
            "Qwen3-Reranker-8B is a multilingual reranker model trained by "
            "Qwen team, supporting more than 100 languages. It has 8B parameters "
            "and a context length of 32k tokens."
        ),
        link="https://huggingface.co/Qwen/Qwen3-Reranker-8B",
        is_reranker=True,
    ),
]
|
||||
|
||||
RERANKER_COMMON_HF_BGE_MODELS = [
|
||||
EmbeddingModelMetadata(
|
||||
model=["BAAI/bge-reranker-v2-m3"],
|
||||
@ -91,6 +170,7 @@ RERANKER_COMMON_HF_JINA_MODELS = [
|
||||
|
||||
_register_embed_common_hf_models(EMBED_COMMON_HF_BGE_MODELS)
|
||||
_register_embed_common_hf_models(EMBED_COMMON_HF_JINA_MODELS)
|
||||
_register_embed_common_hf_models(EMBED_COMMON_HF_QWEN_MODELS)
|
||||
|
||||
# Register reranker models
|
||||
_register_reranker_common_hf_models(RERANKER_COMMON_HF_BGE_MODELS)
|
||||
|
@ -14,6 +14,7 @@ from dbgpt.model.adapter.base import register_embedding_adapter
|
||||
from dbgpt.model.adapter.embed_metadata import (
|
||||
EMBED_COMMON_HF_BGE_MODELS,
|
||||
EMBED_COMMON_HF_JINA_MODELS,
|
||||
EMBED_COMMON_HF_QWEN_MODELS,
|
||||
)
|
||||
from dbgpt.util.i18n_utils import _
|
||||
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
|
||||
@ -958,7 +959,8 @@ register_embedding_adapter(
|
||||
languages=["zh"],
|
||||
),
|
||||
]
|
||||
+ EMBED_COMMON_HF_JINA_MODELS,
|
||||
+ EMBED_COMMON_HF_JINA_MODELS
|
||||
+ EMBED_COMMON_HF_QWEN_MODELS,
|
||||
)
|
||||
register_embedding_adapter(
|
||||
HuggingFaceInstructEmbeddings,
|
||||
|
@ -9,10 +9,14 @@ import numpy as np
|
||||
import requests
|
||||
|
||||
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
|
||||
from dbgpt.configs.model_config import get_device
|
||||
from dbgpt.core import RerankEmbeddings
|
||||
from dbgpt.core.interface.parameter import RerankerDeployModelParameters
|
||||
from dbgpt.model.adapter.base import register_embedding_adapter
|
||||
from dbgpt.model.adapter.embed_metadata import RERANKER_COMMON_HF_MODELS
|
||||
from dbgpt.model.adapter.embed_metadata import (
|
||||
RERANKER_COMMON_HF_MODELS,
|
||||
RERANKER_COMMON_HF_QWEN_MODELS,
|
||||
)
|
||||
from dbgpt.util.i18n_utils import _
|
||||
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
|
||||
|
||||
@ -122,7 +126,7 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
) -> "RerankEmbeddings":
|
||||
"""Create a rerank model from parameters."""
|
||||
return cls(
|
||||
model_name=parameters.real_provider_model_name,
|
||||
model_name=parameters.real_model_path,
|
||||
max_length=parameters.max_length,
|
||||
model_kwargs=parameters.real_model_kwargs,
|
||||
)
|
||||
@ -161,6 +165,201 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
return rank_scores # type: ignore
|
||||
|
||||
|
||||
@dataclass
class QwenRerankEmbeddingsParameters(RerankerDeployModelParameters):
    """Deployment parameters for the Qwen reranker (provider ``"qwen"``)."""

    provider: str = "qwen"
    # Optional location of a locally stored model.
    path: Optional[str] = field(
        default=None,
        metadata={
            "order": -800,
            "help": _(
                "The path of the model, if you want to deploy a local model. Defaults "
                "to 'Qwen/Qwen3-Reranker-0.6B'."
            ),
        },
    )
    # Optional explicit device; see ``real_device`` for the fallback.
    device: Optional[str] = field(
        default=None,
        metadata={
            "order": -700,
            "help": _(
                "Device to run model. If None, the device is automatically determined"
            ),
        },
    )

    @property
    def real_provider_model_name(self) -> str:
        """Return the configured ``path`` when it is set, else the model ``name``."""
        if self.path:
            return self.path
        return self.name

    @property
    def real_model_path(self) -> Optional[str]:
        """Return ``path`` resolved through ``_resolve_root_path``.

        If the deployed model is not local, the result is None.
        """
        resolved = self._resolve_root_path(self.path)
        return resolved

    @property
    def real_device(self) -> Optional[str]:
        """Return the explicit ``device`` if set, else the parent's ``real_device``."""
        if self.device:
            return self.device
        return super().real_device
|
||||
|
||||
|
||||
class QwenRerankEmbeddings(BaseModel, RerankEmbeddings):
    """Qwen Rerank Embeddings.

    Reranker backed by a Qwen3-Reranker causal language model. Each
    (query, document) pair is rendered into a chat-style judging prompt, and the
    relevance score is the probability of the "yes" token against the "no"
    token at the last position of the model output.

    .. code-block:: python

        from dbgpt.rag.embedding.rerank import QwenRerankEmbeddings

        reranker = QwenRerankEmbeddings(model_name="Qwen/Qwen3-Reranker-0.6B")
        query = "Apple"
        candidates = [
            "apple",
            "banana",
            "fruit",
            "vegetable",
            "苹果",
            "Mac OS",
            "乔布斯",
            "iPhone",
        ]
        scores = reranker.predict(query, candidates)
        print(scores)
    """

    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
    model_name: str = "Qwen/Qwen3-Reranker-0.6B"
    max_length: int = 8192
    model: Any  #: :meta private:
    tokenizer: Any  #: :meta private:
    token_false_id: int  #: :meta private:
    token_true_id: int  #: :meta private:
    prefix_tokens: List[int]  #: :meta private:
    suffix_tokens: List[int]  #: :meta private:
    task: str = (
        "Given a web search query, retrieve relevant passages that answer the "
        "query"
    )  #: :meta private:
    device: Optional[str] = None  #: :meta private:

    def __init__(self, **kwargs: Any):
        """Load the tokenizer and model and precompute the prompt tokens.

        Args:
            **kwargs: Field values of this model. ``model_name`` (a Hugging Face
                model id or a local path) and ``device`` are read here; a
                missing or ``None`` ``model_name`` falls back to
                ``"Qwen/Qwen3-Reranker-0.6B"``.

        Raises:
            ImportError: If ``transformers`` or ``torch`` is not installed.
        """
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError as exc:
            raise ImportError(
                "please `pip install transformers`",
            ) from exc

        try:
            import torch  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "please `pip install torch`",
            ) from exc

        # ``kwargs.get(key, default)`` only covers a missing key. An explicit
        # ``model_name=None`` (possible via ``from_parameters`` when no local
        # path is configured) must fall back to the default as well, otherwise
        # both ``from_pretrained`` and the ``str`` field validation fail.
        model_name = kwargs.get("model_name") or "Qwen/Qwen3-Reranker-0.6B"
        kwargs["model_name"] = model_name

        # Left padding keeps the last real token of every sequence at index
        # -1, which is the position ``compute_logits`` reads.
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(model_name).eval()
        device = kwargs.get("device", get_device())
        if device:
            model = model.to(device=device)

        # The model answers the judging prompt with "yes" or "no"; only the
        # logits of these two tokens are used for scoring.
        token_false_id = tokenizer.convert_tokens_to_ids("no")
        token_true_id = tokenizer.convert_tokens_to_ids("yes")
        prefix = (
            "<|im_start|>system\nJudge whether the Document meets the requirements"
            " based on the Query and the Instruct provided. Note that the answer can "
            'only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
        )
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        # Tokenize the fixed prompt parts once instead of once per request.
        prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
        suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

        kwargs["model"] = model
        kwargs["tokenizer"] = tokenizer
        kwargs["token_false_id"] = token_false_id
        kwargs["token_true_id"] = token_true_id
        kwargs["prefix_tokens"] = prefix_tokens
        kwargs["suffix_tokens"] = suffix_tokens
        super().__init__(**kwargs)

    @classmethod
    def param_class(cls) -> Type[QwenRerankEmbeddingsParameters]:
        """Get the parameter class."""
        return QwenRerankEmbeddingsParameters

    @classmethod
    def from_parameters(
        cls, parameters: QwenRerankEmbeddingsParameters
    ) -> "RerankEmbeddings":
        """Create a rerank model from parameters.

        ``real_model_path`` is documented to be None when the model is not
        deployed from a local path; in that case the provider model name
        (``path`` or ``name``) is used, so a configuration that only sets
        ``name`` still loads the configured model.
        """
        return cls(
            model_name=(
                parameters.real_model_path or parameters.real_provider_model_name
            ),
            device=parameters.real_device,
        )

    def format_instruction(
        self, instruction: Optional[str], query: str, doc: str
    ) -> str:
        """Render one (instruction, query, document) triple as prompt text.

        Args:
            instruction: Task description; None selects the default web-search
                retrieval instruction.
            query: The search query.
            doc: The candidate document.

        Returns:
            The text placed between the prompt prefix and suffix.
        """
        if instruction is None:
            instruction = (
                "Given a web search query, retrieve relevant passages that "
                "answer the query"
            )
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs: List[str]) -> Any:
        """Tokenize formatted pairs into a padded batch on the model device.

        Args:
            pairs: Texts produced by ``format_instruction``.

        Returns:
            The tokenizer's padded batch as PyTorch tensors, moved to
            ``self.model.device``.
        """
        # Reserve room for the fixed prefix/suffix so that the wrapped
        # sequence never exceeds ``max_length``.
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation="longest_first",
            return_attention_mask=False,
            max_length=self.max_length
            - len(self.prefix_tokens)
            - len(self.suffix_tokens),
        )
        for i, ele in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.tokenizer.pad(
            inputs, padding=True, return_tensors="pt", max_length=self.max_length
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    def compute_logits(self, inputs: Any, **kwargs: Any) -> List[float]:
        """Score a prepared batch.

        Args:
            inputs: Batch returned by ``process_inputs``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            One score per sequence: the probability of the "yes" token after
            normalizing over the "yes" and "no" logits at the last position.
        """
        import torch

        with torch.no_grad():
            batch_scores = self.model(**inputs).logits[:, -1, :]
            true_vector = batch_scores[:, self.token_true_id]
            false_vector = batch_scores[:, self.token_false_id]
            # Normalize over the two candidate answers only.
            batch_scores = torch.stack([false_vector, true_vector], dim=1)
            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
            scores = batch_scores[:, 1].exp().tolist()
        return scores

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Score every candidate document against the query.

        Args:
            query: The search query.
            candidates: Candidate documents to rank.

        Returns:
            One relevance score per candidate, in the order of ``candidates``.
            An empty list when there are no candidates.
        """
        if not candidates:
            # Nothing to score; avoid handing an empty batch to the tokenizer.
            return []
        pairs = [
            self.format_instruction(self.task, query, doc) for doc in candidates
        ]
        # Tokenize the input texts
        inputs = self.process_inputs(pairs)
        return self.compute_logits(inputs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAPIRerankerDeployModelParameters(RerankerDeployModelParameters):
|
||||
"""OpenAPI Reranker Deploy Model Parameters."""
|
||||
@ -258,7 +457,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
return cls(
|
||||
api_url=parameters.api_url,
|
||||
api_key=parameters.api_key,
|
||||
model_name=parameters.real_provider_model_name,
|
||||
model_name=parameters.real_model_path,
|
||||
timeout=parameters.timeout,
|
||||
)
|
||||
|
||||
@ -579,3 +778,6 @@ register_embedding_adapter(
|
||||
register_embedding_adapter(
|
||||
InfiniAIRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
|
||||
)
|
||||
register_embedding_adapter(
|
||||
QwenRerankEmbeddings, supported_models=RERANKER_COMMON_HF_QWEN_MODELS
|
||||
)
|
||||
|
@ -1,22 +1,22 @@
|
||||
"""
|
||||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_starrocks.py
|
||||
|
||||
docker run -p 9030:9030 -p 8030:8030 -p 8040:8040 -itd --name quickstart starrocks/allin1-ubuntu
|
||||
|
||||
mysql -P 9030 -h 127.0.0.1 -u root --prompt="StarRocks > "
|
||||
Welcome to the MySQL monitor. Commands end with ; or \g.
|
||||
Your MySQL connection id is 184
|
||||
Server version: 5.1.0 3.1.5-5d8438a
|
||||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_starrocks.py
|
||||
|
||||
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
|
||||
docker run -p 9030:9030 -p 8030:8030 -p 8040:8040 -itd --name quickstart starrocks/allin1-ubuntu
|
||||
|
||||
Oracle is a registered trademark of Oracle Corporation and/or its
|
||||
affiliates. Other names may be trademarks of their respective
|
||||
owners.
|
||||
mysql -P 9030 -h 127.0.0.1 -u root --prompt="StarRocks > "
|
||||
Welcome to the MySQL monitor. Commands end with ; or \g.
|
||||
Your MySQL connection id is 184
|
||||
Server version: 5.1.0 3.1.5-5d8438a
|
||||
|
||||
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
|
||||
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
|
||||
|
||||
> create database test;
|
||||
Oracle is a registered trademark of Oracle Corporation and/or its
|
||||
affiliates. Other names may be trademarks of their respective
|
||||
owners.
|
||||
|
||||
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
|
||||
|
||||
> create database test;
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
Loading…
Reference in New Issue
Block a user