feat: add model provider InfiniAI (#2653)

Co-authored-by: yaozhuyu <yaozhuyu@infini-ai.com>
Author: paxionfruit
Date: 2025-04-27 16:22:31 +08:00 (committed by GitHub)
parent 1b77ed6319
commit 445076b433
5 changed files with 306 additions and 0 deletions


@@ -0,0 +1,40 @@
[system]
# Load language from the environment variable (set by the hook)
language = "${env:DBGPT_LANG:-zh}"
api_keys = []
encrypt_key = "your_secret_key"

# Server Configurations
[service.web]
host = "0.0.0.0"
port = 5670

[service.web.database]
type = "sqlite"
path = "pilot/meta_data/dbgpt.db"

[service.model.worker]
host = "127.0.0.1"

[rag.storage]
[rag.storage.vector]
type = "chroma"
persist_path = "pilot/data"

# Model Configurations
[models]
[[models.llms]]
name = "deepseek-v3"
provider = "proxy/infiniai"
api_key = "${env:INFINIAI_API_KEY}"

[[models.embeddings]]
name = "bge-m3"
provider = "proxy/openai"
api_url = "https://cloud.infini-ai.com/maas/v1/embeddings"
api_key = "${env:INFINIAI_API_KEY}"

[[models.rerankers]]
type = "reranker"
name = "bge-reranker-v2-m3"
provider = "proxy/infiniai"
api_key = "${env:INFINIAI_API_KEY}"
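
The `${env:NAME:-default}` placeholders in this config are resolved from the process environment when the file is loaded, so the file can ship without secrets. A minimal sketch of that substitution rule in Python, assuming only the `${env:NAME}` / `${env:NAME:-default}` syntax visible above (the `resolve_env_placeholders` helper is illustrative, not DB-GPT's actual loader):

import os
import re

_ENV_PATTERN = re.compile(r"\$\{env:([A-Za-z_][A-Za-z0-9_]*)(?::-([^}]*))?\}")

def resolve_env_placeholders(value: str) -> str:
    """Replace ${env:NAME} / ${env:NAME:-default} with environment values."""

    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        env_value = os.getenv(name)
        if env_value is not None:
            return env_value
        # Unset with no default: assumed to collapse to an empty string.
        return default if default is not None else ""

    return _ENV_PATTERN.sub(_sub, value)

# With DBGPT_LANG unset, the default after ":-" applies.
print(resolve_env_placeholders("${env:DBGPT_LANG:-zh}"))  # -> "zh"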


@@ -8,6 +8,7 @@ if TYPE_CHECKING:
    from dbgpt.model.proxy.llms.deepseek import DeepseekLLMClient
    from dbgpt.model.proxy.llms.gemini import GeminiLLMClient
    from dbgpt.model.proxy.llms.gitee import GiteeLLMClient
    from dbgpt.model.proxy.llms.infiniai import InfiniAILLMClient
    from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient
    from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
    from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient
@@ -33,6 +34,7 @@ def __lazy_import(name):
        "OllamaLLMClient": "dbgpt.model.proxy.llms.ollama",
        "DeepseekLLMClient": "dbgpt.model.proxy.llms.deepseek",
        "GiteeLLMClient": "dbgpt.model.proxy.llms.gitee",
        "InfiniAILLMClient": "dbgpt.model.proxy.llms.infiniai",
    }

    if name in module_path:
@@ -60,4 +62,5 @@ __all__ = [
    "OllamaLLMClient",
    "DeepseekLLMClient",
    "GiteeLLMClient",
    "InfiniAILLMClient",
]
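
For context, both hunks above feed DB-GPT's lazy-import machinery: `InfiniAILLMClient` is only imported when the name is first accessed. A minimal sketch of that pattern, assuming `__lazy_import` is wired up through a module-level `__getattr__` (PEP 562) backed by the `module_path` mapping shown in the diff; the real implementation may differ in details:

import importlib

# Excerpt of the name -> defining-module table from the diff above.
_MODULE_PATH = {
    "GiteeLLMClient": "dbgpt.model.proxy.llms.gitee",
    "InfiniAILLMClient": "dbgpt.model.proxy.llms.infiniai",
}

def __getattr__(name: str):
    """Import the defining module only when the name is first looked up."""
    if name in _MODULE_PATH:
        module = importlib.import_module(_MODULE_PATH[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")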


@@ -0,0 +1,187 @@
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union

from dbgpt.core import ModelMetadata
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    ResourceCategory,
    auto_register_resource,
)
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from dbgpt.util.i18n_utils import _

from ..base import (
    AsyncGenerateStreamFunction,
    GenerateStreamFunction,
    register_proxy_model_adapter,
)
from .chatgpt import OpenAICompatibleDeployModelParameters, OpenAILLMClient

if TYPE_CHECKING:
    from httpx._types import ProxiesTypes
    from openai import AsyncAzureOpenAI, AsyncOpenAI

    ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]

_INFINIAI_DEFAULT_MODEL = "deepseek-v3"


@auto_register_resource(
label=_("InfiniAI Proxy LLM"),
category=ResourceCategory.LLM_CLIENT,
tags={"order": TAGS_ORDER_HIGH},
description=_("InfiniAI proxy LLM configuration."),
documentation_url="https://docs.infini-ai.com/gen-studio/api/tutorial.html", # noqa
show_in_ui=False,
)
@dataclass
class InfiniAIDeployModelParameters(OpenAICompatibleDeployModelParameters):
"""Deploy model parameters for InfiniAI."""
provider: str = "proxy/infiniai"
api_base: Optional[str] = field(
default="${env:INFINIAI_API_BASE:-https://cloud.infini-ai.com/maas/v1}",
metadata={
"help": _("The base url of the InfiniAI API."),
},
)
api_key: Optional[str] = field(
default="${env:INFINIAI_API_KEY}",
metadata={
"help": _("The API key of the InfiniAI API."),
"tags": "privacy",
},
)


async def infiniai_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: InfiniAILLMClient = model.proxy_llm_client
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r


class InfiniAILLMClient(OpenAILLMClient):
"""InfiniAI LLM Client.
InfiniAI's API is compatible with OpenAI's API, so we inherit from
OpenAILLMClient.
"""
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
model: Optional[str] = _INFINIAI_DEFAULT_MODEL,
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = _INFINIAI_DEFAULT_MODEL,
context_length: Optional[int] = None,
openai_client: Optional["ClientType"] = None,
openai_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
api_base = (
api_base
or os.getenv("INFINIAI_API_BASE")
or "https://cloud.infini-ai.com/maas/v1"
)
api_key = api_key or os.getenv("INFINIAI_API_KEY")
model = model or _INFINIAI_DEFAULT_MODEL
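        # Infer the context window from the model name; fall back to 4k tokens.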
if not context_length:
if "200k" in model:
context_length = 200 * 1024
else:
context_length = 4096
if not api_key:
raise ValueError(
"InfiniAI API key is required, please set 'INFINIAI_API_KEY' in "
"environment or pass it as an argument."
)
super().__init__(
api_key=api_key,
api_base=api_base,
api_type=api_type,
api_version=api_version,
model=model,
proxies=proxies,
timeout=timeout,
model_alias=model_alias,
context_length=context_length,
openai_client=openai_client,
openai_kwargs=openai_kwargs,
**kwargs,
)

    @property
def default_model(self) -> str:
model = self._model
if not model:
model = _INFINIAI_DEFAULT_MODEL
return model

    @classmethod
def param_class(cls) -> Type[InfiniAIDeployModelParameters]:
return InfiniAIDeployModelParameters

    @classmethod
def generate_stream_function(
cls,
) -> Optional[Union[GenerateStreamFunction, AsyncGenerateStreamFunction]]:
return infiniai_generate_stream


register_proxy_model_adapter(
InfiniAILLMClient,
supported_models=[
ModelMetadata(
model=["deepseek-v3"],
context_length=64 * 1024,
max_output_length=8 * 1024,
description="DeepSeek-V3 by DeepSeek",
link="https://cloud.infini-ai.com/genstudio/model",
function_calling=True,
),
ModelMetadata(
model=["deepseek-r1"],
context_length=64 * 1024,
max_output_length=8 * 1024,
description="DeepSeek-V3 by DeepSeek",
link="https://cloud.infini-ai.com/genstudio/model",
function_calling=False,
),
ModelMetadata(
model=["qwq-32b"],
context_length=64 * 1024,
max_output_length=8 * 1024,
description="qwq By Qwen",
link="https://cloud.infini-ai.com/genstudio/model",
function_calling=True,
),
ModelMetadata(
model=[
"qwen2.5-72b-instruct",
"qwen2.5-32b-instruct",
"qwen2.5-14b-instruct",
"qwen2.5-7b-instruct",
"qwen2.5-coder-32b-instruct",
],
context_length=32 * 1024,
max_output_length=4 * 1024,
description="Qwen 2.5 By Qwen",
link="https://cloud.infini-ai.com/genstudio/model",
function_calling=True,
),
        # More models: see https://cloud.infiniai.cn/models
],
)
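
A quick smoke test for the new client, assuming `INFINIAI_API_KEY` is exported (the constructor raises ValueError otherwise) and the `openai` package is installed; no request is sent here:

from dbgpt.model.proxy.llms.infiniai import InfiniAILLMClient

client = InfiniAILLMClient(model="deepseek-v3")
print(client.default_model)  # -> "deepseek-v3"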


@@ -15,6 +15,7 @@ from .embeddings import (  # noqa: F401
)
from .rerank import (  # noqa: F401
    CrossEncoderRerankEmbeddings,
    InfiniAIRerankEmbeddings,
    OpenAPIRerankEmbeddings,
    SiliconFlowRerankEmbeddings,
)
@@ -31,5 +32,6 @@ __ALL__ = [
    "OpenAPIEmbeddings",
    "OpenAPIRerankEmbeddings",
    "SiliconFlowRerankEmbeddings",
    "InfiniAIRerankEmbeddings",
    "WrappedEmbeddingFactory",
]


@@ -493,6 +493,77 @@ class TeiRerankEmbeddings(OpenAPIRerankEmbeddings):
        return self._parse_results(response_data)


@dataclass
class InfiniAIRerankEmbeddingsParameters(OpenAPIRerankerDeployModelParameters):
    """InfiniAI Rerank Embeddings Parameters."""

provider: str = "proxy/infiniai"
api_url: str = field(
default="https://cloud.infini-ai.com/maas/v1/rerank",
metadata={
"help": _("The URL of the rerank API."),
},
)
api_key: Optional[str] = field(
default="${env:INFINIAI_API_KEY}",
metadata={
"help": _("The API key for the rerank API."),
},
)


class InfiniAIRerankEmbeddings(OpenAPIRerankEmbeddings):
    """InfiniAI Rerank Model.

    See `InfiniAI API
    <https://docs.infini-ai.com/gen-studio/api/tutorial-rerank.html>`_ for more details.
    """

    def __init__(self, **kwargs: Any):
        """Initialize the InfiniAIRerankEmbeddings."""
# If the API key is not provided, try to get it from the environment
if "api_key" not in kwargs:
kwargs["api_key"] = os.getenv("InfiniAI_API_KEY")
if "api_url" not in kwargs:
env_api_url = os.getenv("InfiniAI_API_BASE")
if env_api_url:
env_api_url = env_api_url.rstrip("/")
kwargs["api_url"] = env_api_url + "/rerank"
else:
kwargs["api_url"] = "https://cloud.infini-ai.com/maas/v1/rerank"
if "model_name" not in kwargs:
kwargs["model_name"] = "bge-reranker-v2-m3"
super().__init__(**kwargs)

    @classmethod
def param_class(cls) -> Type[InfiniAIRerankEmbeddingsParameters]:
"""Get the parameter class."""
return InfiniAIRerankEmbeddingsParameters

    def _parse_results(self, response: Dict[str, Any]) -> List[float]:
        """Parse the response from the API.

        Args:
            response: The response from the API.

        Returns:
            List[float]: The rank scores of the candidates.
        """
results = response.get("results")
if not results:
raise RuntimeError("Cannot find results in the response")
if not isinstance(results, list):
raise RuntimeError("Results should be a list")
        # Sort by index so the scores follow the original candidate order
results = sorted(results, key=lambda x: x.get("index", 0))
scores = [float(result.get("relevance_score")) for result in results]
return scores


register_embedding_adapter(
    CrossEncoderRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
@@ -505,3 +576,6 @@ register_embedding_adapter(
register_embedding_adapter(
    TeiRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
register_embedding_adapter(
InfiniAIRerankEmbeddings, supported_models=RERANKER_COMMON_HF_MODELS
)
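
Finally, `_parse_results` can be exercised offline with a hand-written payload shaped like the fields it reads (`results[].index`, `results[].relevance_score`). The import path assumes the `dbgpt.rag.embedding` package whose `__init__` is patched above, and the dummy key only satisfies the constructor; no request is made:

from dbgpt.rag.embedding import InfiniAIRerankEmbeddings

reranker = InfiniAIRerankEmbeddings(api_key="sk-dummy")

# Stand-in response, deliberately out of order; not captured API output.
fake_response = {
    "results": [
        {"index": 1, "relevance_score": 0.12},
        {"index": 0, "relevance_score": 0.87},
    ]
}
print(reranker._parse_results(fake_response))  # -> [0.87, 0.12]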