diff --git a/configs/dbgpt-local-qwen3.example.toml b/configs/dbgpt-local-qwen3.example.toml index 31f644e70..7c205ccf2 100644 --- a/configs/dbgpt-local-qwen3.example.toml +++ b/configs/dbgpt-local-qwen3.example.toml @@ -30,8 +30,12 @@ provider = "hf" # reasoning_model = false [[models.embeddings]] -name = "BAAI/bge-large-zh-v1.5" +name = "Qwen/Qwen3-Embedding-0.6B" provider = "hf" # If not provided, the model will be downloaded from the Hugging Face model hub # uncomment the following line to specify the model path in the local file system # path = "the-model-path-in-the-local-file-system" + +[[models.rerankers]] +name = "Qwen/Qwen3-Reranker-0.6B" +provider = "qwen" \ No newline at end of file diff --git a/configs/dbgpt-proxy-infiniai.toml b/configs/dbgpt-proxy-infiniai.toml index 8c0109e59..ddfa1a9d2 100644 --- a/configs/dbgpt-proxy-infiniai.toml +++ b/configs/dbgpt-proxy-infiniai.toml @@ -34,7 +34,6 @@ api_url = "https://cloud.infini-ai.com/maas/v1/embeddings" api_key = "${env:INFINIAI_API_KEY}" [[models.rerankers]] -type = "reranker" name = "bge-reranker-v2-m3" provider = "proxy/infiniai" api_key = "${env:INFINIAI_API_KEY}" \ No newline at end of file diff --git a/configs/dbgpt-proxy-siliconflow-mysql.toml b/configs/dbgpt-proxy-siliconflow-mysql.toml index 60cba4210..1e166e6f8 100644 --- a/configs/dbgpt-proxy-siliconflow-mysql.toml +++ b/configs/dbgpt-proxy-siliconflow-mysql.toml @@ -38,7 +38,6 @@ provider = "proxy/siliconflow" api_key = "${env:SILICONFLOW_API_KEY}" [[models.rerankers]] -type = "reranker" name = "BAAI/bge-reranker-v2-m3" provider = "proxy/siliconflow" api_key = "${env:SILICONFLOW_API_KEY}" diff --git a/configs/dbgpt-proxy-siliconflow.toml b/configs/dbgpt-proxy-siliconflow.toml index ac1731903..9d8d63c6e 100644 --- a/configs/dbgpt-proxy-siliconflow.toml +++ b/configs/dbgpt-proxy-siliconflow.toml @@ -34,7 +34,6 @@ api_url = "https://api.siliconflow.cn/v1/embeddings" api_key = "${env:SILICONFLOW_API_KEY}" [[models.rerankers]] -type = "reranker" 
name = "BAAI/bge-reranker-v2-m3" provider = "proxy/siliconflow" api_key = "${env:SILICONFLOW_API_KEY}" diff --git a/packages/dbgpt-core/src/dbgpt/model/adapter/embed_metadata.py b/packages/dbgpt-core/src/dbgpt/model/adapter/embed_metadata.py index 59755a043..64906b8cc 100644 --- a/packages/dbgpt-core/src/dbgpt/model/adapter/embed_metadata.py +++ b/packages/dbgpt-core/src/dbgpt/model/adapter/embed_metadata.py @@ -18,6 +18,48 @@ def _register_reranker_common_hf_models(models: List[EmbeddingModelMetadata]) -> RERANKER_COMMON_HF_MODELS.extend(models) +EMBED_COMMON_HF_QWEN_MODELS = [ + EmbeddingModelMetadata( + model=[ + "Qwen/Qwen3-Embedding-0.6B", + ], + dimension=1024, + context_length=32 * 1024, # 32k context length + description=_( + "Qwen3-Embedding-0.6B is a multilingual embedding model trained by " + "Qwen team, supporting more than 100 languages. It has 0.6B parameters " + "and a context length of 32k tokens and the dimension is 1024." + ), + link="https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", + ), + EmbeddingModelMetadata( + model=[ + "Qwen/Qwen3-Embedding-4B", + ], + dimension=2560, + context_length=32 * 1024, # 32k context length + description=_( + "Qwen3-Embedding-4B is a multilingual embedding model trained by " + "Qwen team, supporting more than 100 languages. It has 4B parameters " + "and a context length of 32k tokens and the dimension is 2560." + ), + link="https://huggingface.co/Qwen/Qwen3-Embedding-4B", + ), + EmbeddingModelMetadata( + model=[ + "Qwen/Qwen3-Embedding-8B", + ], + dimension=4096, + context_length=32 * 1024, # 32k context length + description=_( + "Qwen3-Embedding-8B is a multilingual embedding model trained by " + "Qwen team, supporting more than 100 languages. It has 8B parameters " + "and a context length of 32k tokens and the dimension is 4096." 
+ ), + link="https://huggingface.co/Qwen/Qwen3-Embedding-8B", + ), +] + EMBED_COMMON_HF_BGE_MODELS = [ EmbeddingModelMetadata( model=["BAAI/bge-m3"], @@ -58,6 +100,43 @@ EMBED_COMMON_HF_JINA_MODELS = [ ), ] + +RERANKER_COMMON_HF_QWEN_MODELS = [ + EmbeddingModelMetadata( + model=["Qwen/Qwen3-Reranker-0.6B"], + context_length=32 * 1024, # 32k context length + description=_( + "Qwen3-Reranker-0.6B is a multilingual reranker model trained by " + "Qwen team, supporting more than 100 languages. It has 0.6B parameters " + "and a context length of 32k tokens." + ), + link="https://huggingface.co/Qwen/Qwen3-Reranker-0.6B", + is_reranker=True, + ), + EmbeddingModelMetadata( + model=["Qwen/Qwen3-Reranker-4B"], + context_length=32 * 1024, # 32k context length + description=_( + "Qwen3-Reranker-4B is a multilingual reranker model trained by " + "Qwen team, supporting more than 100 languages. It has 4B parameters " + "and a context length of 32k tokens." + ), + link="https://huggingface.co/Qwen/Qwen3-Reranker-4B", + is_reranker=True, + ), + EmbeddingModelMetadata( + model=["Qwen/Qwen3-Reranker-8B"], + context_length=32 * 1024, # 32k context length + description=_( + "Qwen3-Reranker-8B is a multilingual reranker model trained by " + "Qwen team, supporting more than 100 languages. It has 8B parameters " + "and a context length of 32k tokens." 
+ ), + link="https://huggingface.co/Qwen/Qwen3-Reranker-8B", + is_reranker=True, + ), +] + RERANKER_COMMON_HF_BGE_MODELS = [ EmbeddingModelMetadata( model=["BAAI/bge-reranker-v2-m3"], @@ -91,6 +170,7 @@ RERANKER_COMMON_HF_JINA_MODELS = [ _register_embed_common_hf_models(EMBED_COMMON_HF_BGE_MODELS) _register_embed_common_hf_models(EMBED_COMMON_HF_JINA_MODELS) +_register_embed_common_hf_models(EMBED_COMMON_HF_QWEN_MODELS) # Register reranker models _register_reranker_common_hf_models(RERANKER_COMMON_HF_BGE_MODELS) diff --git a/packages/dbgpt-core/src/dbgpt/rag/embedding/embeddings.py b/packages/dbgpt-core/src/dbgpt/rag/embedding/embeddings.py index 4ab13d242..b4235aa08 100644 --- a/packages/dbgpt-core/src/dbgpt/rag/embedding/embeddings.py +++ b/packages/dbgpt-core/src/dbgpt/rag/embedding/embeddings.py @@ -14,6 +14,7 @@ from dbgpt.model.adapter.base import register_embedding_adapter from dbgpt.model.adapter.embed_metadata import ( EMBED_COMMON_HF_BGE_MODELS, EMBED_COMMON_HF_JINA_MODELS, + EMBED_COMMON_HF_QWEN_MODELS, ) from dbgpt.util.i18n_utils import _ from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer @@ -958,7 +959,8 @@ register_embedding_adapter( languages=["zh"], ), ] - + EMBED_COMMON_HF_JINA_MODELS, + + EMBED_COMMON_HF_JINA_MODELS + + EMBED_COMMON_HF_QWEN_MODELS, ) register_embedding_adapter( HuggingFaceInstructEmbeddings, diff --git a/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py b/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py index caf986b7f..e0b755b34 100644 --- a/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py +++ b/packages/dbgpt-core/src/dbgpt/rag/embedding/rerank.py @@ -9,10 +9,14 @@ import numpy as np import requests from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field +from dbgpt.configs.model_config import get_device from dbgpt.core import RerankEmbeddings from dbgpt.core.interface.parameter import RerankerDeployModelParameters from dbgpt.model.adapter.base import register_embedding_adapter 
-from dbgpt.model.adapter.embed_metadata import RERANKER_COMMON_HF_MODELS +from dbgpt.model.adapter.embed_metadata import ( + RERANKER_COMMON_HF_MODELS, + RERANKER_COMMON_HF_QWEN_MODELS, +) from dbgpt.util.i18n_utils import _ from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer @@ -122,7 +126,7 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings): ) -> "RerankEmbeddings": """Create a rerank model from parameters.""" return cls( - model_name=parameters.real_provider_model_name, + model_name=parameters.real_model_path, max_length=parameters.max_length, model_kwargs=parameters.real_model_kwargs, ) @@ -161,6 +165,201 @@ class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings): return rank_scores # type: ignore +@dataclass +class QwenRerankEmbeddingsParameters(RerankerDeployModelParameters): + """Qwen Rerank Embeddings Parameters.""" + + provider: str = "qwen" + path: Optional[str] = field( + default=None, + metadata={ + "order": -800, + "help": _( + "The path of the model, if you want to deploy a local model. Defaults " + "to 'Qwen/Qwen3-Reranker-0.6B'." + ), + }, + ) + device: Optional[str] = field( + default=None, + metadata={ + "order": -700, + "help": _( + "Device to run model. If None, the device is automatically determined" + ), + }, + ) + + @property + def real_provider_model_name(self) -> str: + """Get the real provider model name.""" + return self.path or self.name + + @property + def real_model_path(self) -> Optional[str]: + """Get the real model path. + + If deploy model is not local, return None. + """ + return self._resolve_root_path(self.path) + + @property + def real_device(self) -> Optional[str]: + """Get the real device.""" + return self.device or super().real_device + + +class QwenRerankEmbeddings(BaseModel, RerankEmbeddings): + """Qwen Rerank Embeddings. + + .. 
code-block:: python + from dbgpt.rag.embedding.rerank import QwenRerankEmbeddings + + reranker = QwenRerankEmbeddings(model_name="Qwen/Qwen3-Reranker-0.6B") + query = "Apple" + documents = [ + "apple", + "banana", + "fruit", + "vegetable", + "苹果", + "Mac OS", + "乔布斯", + "iPhone", + ] + scores = reranker.predict(query, documents) + print(scores) + """ + + model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=()) + model_name: str = "Qwen/Qwen3-Reranker-0.6B" + max_length: int = 8192 + model: Any #: :meta private: + tokenizer: Any #: :meta private: + token_false_id: int #: :meta private: + token_true_id: int #: :meta private: + prefix_tokens: List[int] #: :meta private: + suffix_tokens: List[int] #: :meta private: + task: str = ( + "Given a web search query, retrieve relevant passages that answer the " + "query" + ) #: :meta private: + device: Optional[str] = None #: :meta private: + + def __init__(self, **kwargs: Any): + try: + from transformers import AutoModelForCausalLM, AutoTokenizer + except ImportError: + raise ImportError( + "please `pip install transformers`", + ) + + try: + import torch # noqa: F401 + except ImportError: + raise ImportError( + "please `pip install torch`", + ) + tokenizer = AutoTokenizer.from_pretrained( + kwargs.get("model_name", "Qwen/Qwen3-Reranker-0.6B"), padding_side="left" + ) + model = AutoModelForCausalLM.from_pretrained( + kwargs.get("model_name", "Qwen/Qwen3-Reranker-0.6B"), + ).eval() + device = kwargs.get("device", get_device()) + if device: + model = model.to(device=device) + + token_false_id = tokenizer.convert_tokens_to_ids("no") + token_true_id = tokenizer.convert_tokens_to_ids("yes") + prefix = ( + "<|im_start|>system\nJudge whether the Document meets the requirements" + " based on the Query and the Instruct provided. 
Note that the answer can " + 'only be "yes" or "no".<|im_end|>\n<|im_start|>user\n' + ) + suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" + prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False) + suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False) + + kwargs["model"] = model + kwargs["tokenizer"] = tokenizer + kwargs["token_false_id"] = token_false_id + kwargs["token_true_id"] = token_true_id + kwargs["prefix_tokens"] = prefix_tokens + kwargs["suffix_tokens"] = suffix_tokens + super().__init__(**kwargs) + + @classmethod + def param_class(cls) -> Type[QwenRerankEmbeddingsParameters]: + """Get the parameter class.""" + return QwenRerankEmbeddingsParameters + + @classmethod + def from_parameters( + cls, parameters: QwenRerankEmbeddingsParameters + ) -> "RerankEmbeddings": + """Create a rerank model from parameters.""" + return cls( + model_name=parameters.real_model_path or parameters.real_provider_model_name, + device=parameters.real_device, + ) + + def format_instruction(self, instruction, query, doc): + if instruction is None: + instruction = ( + "Given a web search query, retrieve relevant passages that " + "answer the query" + ) + output = ( + "<Instruct>: {instruction}\n<Query>: {query}\n<Doc>: {doc}".format( + instruction=instruction, query=query, doc=doc + ) + ) + return output + + def process_inputs(self, pairs): + inputs = self.tokenizer( + pairs, + padding=False, + truncation="longest_first", + return_attention_mask=False, + max_length=self.max_length + - len(self.prefix_tokens) + - len(self.suffix_tokens), + ) + for i, ele in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = self.prefix_tokens + ele + self.suffix_tokens + inputs = self.tokenizer.pad( + inputs, padding=True, return_tensors="pt", max_length=self.max_length + ) + for key in inputs: + inputs[key] = inputs[key].to(self.model.device) + return inputs + + def compute_logits(self, inputs, **kwargs): + import torch + + with torch.no_grad(): + batch_scores = self.model(**inputs).logits[:, -1, :] + true_vector = 
batch_scores[:, self.token_true_id] + false_vector = batch_scores[:, self.token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores + + def predict(self, query: str, candidates: List[str]) -> List[float]: + queries = [query] * len(candidates) + pairs = [ + self.format_instruction(self.task, query, doc) + for query, doc in zip(queries, candidates) + ] + # Tokenize the input texts + inputs = self.process_inputs(pairs) + scores = self.compute_logits(inputs) + return scores + + @dataclass class OpenAPIRerankerDeployModelParameters(RerankerDeployModelParameters): """OpenAPI Reranker Deploy Model Parameters.""" @@ -258,7 +457,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings): return cls( api_url=parameters.api_url, api_key=parameters.api_key, - model_name=parameters.real_provider_model_name, + model_name=parameters.real_model_path, timeout=parameters.timeout, ) @@ -579,3 +778,6 @@ register_embedding_adapter( register_embedding_adapter( InfiniAIRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS ) +register_embedding_adapter( + QwenRerankEmbeddings, supported_models=RERANKER_COMMON_HF_QWEN_MODELS +) diff --git a/tests/intetration_tests/datasource/test_conn_starrocks.py b/tests/intetration_tests/datasource/test_conn_starrocks.py index afc1b5dae..b358ccccc 100644 --- a/tests/intetration_tests/datasource/test_conn_starrocks.py +++ b/tests/intetration_tests/datasource/test_conn_starrocks.py @@ -1,22 +1,22 @@ """ - Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_starrocks.py - - docker run -p 9030:9030 -p 8030:8030 -p 8040:8040 -itd --name quickstart starrocks/allin1-ubuntu - - mysql -P 9030 -h 127.0.0.1 -u root --prompt="StarRocks > " - Welcome to the MySQL monitor. Commands end with ; or \g. 
- Your MySQL connection id is 184 - Server version: 5.1.0 3.1.5-5d8438a +Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_starrocks.py - Copyright (c) 2000, 2023, Oracle and/or its affiliates. +docker run -p 9030:9030 -p 8030:8030 -p 8040:8040 -itd --name quickstart starrocks/allin1-ubuntu - Oracle is a registered trademark of Oracle Corporation and/or its - affiliates. Other names may be trademarks of their respective - owners. +mysql -P 9030 -h 127.0.0.1 -u root --prompt="StarRocks > " +Welcome to the MySQL monitor. Commands end with ; or \g. +Your MySQL connection id is 184 +Server version: 5.1.0 3.1.5-5d8438a - Type 'help;' or '\h' for help. Type '\c' to clear the current input statement. +Copyright (c) 2000, 2023, Oracle and/or its affiliates. - > create database test; +Oracle is a registered trademark of Oracle Corporation and/or its +affiliates. Other names may be trademarks of their respective +owners. + +Type 'help;' or '\h' for help. Type '\c' to clear the current input statement. + +> create database test; """ import pytest