feat(model): Support Gitee models (#2257)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Aries-ckt committed 2024-12-31 19:24:24 +08:00 (via GitHub)
parent 6b4ccc8dfc
commit 576da34e92
5 changed files with 123 additions and 0 deletions

@@ -151,6 +151,15 @@ class Config(metaclass=Singleton):
        os.environ["siliconflow_proxyllm_api_base"] = os.getenv(
            "SILICONFLOW_API_BASE", "https://api.siliconflow.cn/v1"
        )
        self.gitee_proxy_api_key = os.getenv("GITEE_API_KEY")
        if self.gitee_proxy_api_key:
            os.environ["gitee_proxyllm_proxy_api_key"] = self.gitee_proxy_api_key
            os.environ["gitee_proxyllm_proxyllm_backend"] = os.getenv(
                "GITEE_MODEL_VERSION", "Qwen2.5-72B-Instruct"
            )
            os.environ["gitee_proxyllm_api_base"] = os.getenv(
                "GITEE_API_BASE", "https://ai.gitee.com/v1"
            )

        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
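
Reviewer note: a minimal sketch of wiring these settings up before Config is constructed. The key value is a placeholder and the Config import path is an assumption, not part of this diff:

import os

# GITEE_API_KEY is required to enable the Gitee proxy; the other two variables
# are optional and fall back to the defaults hard-coded in the diff above.
os.environ["GITEE_API_KEY"] = "your-gitee-api-key"  # placeholder, not a real key
os.environ["GITEE_MODEL_VERSION"] = "Qwen2.5-72B-Instruct"
os.environ["GITEE_API_BASE"] = "https://ai.gitee.com/v1"

from dbgpt._private.config import Config  # assumed import path

# Constructing Config re-exports the values above as the gitee_proxyllm_*
# variables consumed by the proxy model layer.
cfg = Config()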

@@ -81,6 +81,7 @@ LLM_MODEL_CONFIG = {
    "deepseek_proxyllm": "deepseek_proxyllm",
    # https://docs.siliconflow.cn/quickstart
    "siliconflow_proxyllm": "siliconflow_proxyllm",
    "gitee_proxyllm": "gitee_proxyllm",
    "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
    "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
    "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
@@ -307,6 +308,7 @@ EMBEDDING_MODEL_CONFIG = {
    "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
    # https://huggingface.co/BAAI/bge-m3, bge needs normalize_embeddings=True
    "bge-m3": os.path.join(MODEL_PATH, "bge-m3"),
    "bge-large-zh-v1.5": os.path.join(MODEL_PATH, "bge-large-zh-v1.5"),
    "gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
    "gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),

@@ -364,6 +364,31 @@ class SiliconFlowProxyLLMModelAdapter(ProxyLLMModelAdapter):
        return siliconflow_generate_stream


class GiteeProxyLLMModelAdapter(ProxyLLMModelAdapter):
    """Gitee proxy LLM model adapter.

    See Also: `Gitee Documentation <https://ai.gitee.com/docs/getting-started/intro>`_
    """

    def support_async(self) -> bool:
        return True

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "gitee_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.gitee import GiteeLLMClient

        return GiteeLLMClient

    def get_async_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.gitee import gitee_generate_stream

        return gitee_generate_stream


register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(ClaudeProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
@@ -378,3 +403,4 @@ register_model_adapter(YiProxyLLMModelAdapter)
register_model_adapter(MoonshotProxyLLMModelAdapter)
register_model_adapter(DeepseekProxyLLMModelAdapter)
register_model_adapter(SiliconFlowProxyLLMModelAdapter)
register_model_adapter(GiteeProxyLLMModelAdapter)
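
To make the matching behavior concrete, an illustrative check of the new adapter (import path assumed; the registry picks the first registered adapter whose do_match returns True for the lower-cased model name):

from dbgpt.model.adapter.proxy_adapter import GiteeProxyLLMModelAdapter  # assumed path

adapter = GiteeProxyLLMModelAdapter()
assert adapter.do_match("gitee_proxyllm")            # exact-name match only
assert not adapter.do_match("siliconflow_proxyllm")  # other proxies don't match
assert adapter.support_async()                       # streaming uses the async path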

@@ -7,6 +7,7 @@ if TYPE_CHECKING:
    from dbgpt.model.proxy.llms.claude import ClaudeLLMClient
    from dbgpt.model.proxy.llms.deepseek import DeepseekLLMClient
    from dbgpt.model.proxy.llms.gemini import GeminiLLMClient
    from dbgpt.model.proxy.llms.gitee import GiteeLLMClient
    from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient
    from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
    from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient
@@ -31,6 +32,7 @@ def __lazy_import(name):
        "MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot",
        "OllamaLLMClient": "dbgpt.model.proxy.llms.ollama",
        "DeepseekLLMClient": "dbgpt.model.proxy.llms.deepseek",
        "GiteeLLMClient": "dbgpt.model.proxy.llms.gitee",
    }

    if name in module_path:
@@ -57,4 +59,5 @@ __all__ = [
    "MoonshotLLMClient",
    "OllamaLLMClient",
    "DeepseekLLMClient",
    "GiteeLLMClient",
]
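
For background, a self-contained sketch of the lazy-import pattern this __init__ relies on: a module-level __getattr__ (PEP 562) defers importing each client module until the exported name is first accessed. Names below are illustrative, not the file's exact helper:

import importlib

_MODULE_PATH = {"GiteeLLMClient": "dbgpt.model.proxy.llms.gitee"}


def __getattr__(name: str):
    # Invoked only when `name` is not found by normal lookup; import on demand.
    if name in _MODULE_PATH:
        module = importlib.import_module(_MODULE_PATH[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")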

@@ -0,0 +1,83 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

from .chatgpt import OpenAILLMClient

if TYPE_CHECKING:
    from httpx._types import ProxiesTypes
    from openai import AsyncAzureOpenAI, AsyncOpenAI

    ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]

_GITEE_DEFAULT_MODEL = "Qwen2.5-72B-Instruct"


async def gitee_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    client: GiteeLLMClient = model.proxy_llm_client
    request = parse_model_request(params, client.default_model, stream=True)
    async for r in client.generate_stream(request):
        yield r
class GiteeLLMClient(OpenAILLMClient):
    """Gitee LLM Client.

    Gitee's API is compatible with OpenAI's API, so we inherit from OpenAILLMClient.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_type: Optional[str] = None,
        api_version: Optional[str] = None,
        model: Optional[str] = None,
        proxies: Optional["ProxiesTypes"] = None,
        timeout: Optional[int] = 240,
        model_alias: Optional[str] = "gitee_proxyllm",
        context_length: Optional[int] = None,
        openai_client: Optional["ClientType"] = None,
        openai_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        api_base = api_base or os.getenv("GITEE_API_BASE") or "https://ai.gitee.com/v1"
        api_key = api_key or os.getenv("GITEE_API_KEY")
        model = model or _GITEE_DEFAULT_MODEL
        if not context_length:
            # Heuristic: model names containing "200k" advertise a 200k-token
            # window; otherwise fall back to a conservative 4096 tokens.
            if "200k" in model:
                context_length = 200 * 1024
            else:
                context_length = 4096
        if not api_key:
            raise ValueError(
                "Gitee API key is required, please set 'GITEE_API_KEY' in environment "
                "or pass it as an argument."
            )
        super().__init__(
            api_key=api_key,
            api_base=api_base,
            api_type=api_type,
            api_version=api_version,
            model=model,
            proxies=proxies,
            timeout=timeout,
            model_alias=model_alias,
            context_length=context_length,
            openai_client=openai_client,
            openai_kwargs=openai_kwargs,
            **kwargs
        )
    @property
    def default_model(self) -> str:
        model = self._model
        if not model:
            model = _GITEE_DEFAULT_MODEL
        return model
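
Finally, a hedged end-to-end sketch of calling the new client directly. The ModelRequest/ModelMessage import paths and the build_request signature are assumptions based on DB-GPT's public API, not part of this diff, and GITEE_API_KEY must be exported first:

import asyncio

from dbgpt.core import ModelMessage, ModelRequest  # assumed import paths
from dbgpt.model.proxy.llms.gitee import GiteeLLMClient


async def main():
    client = GiteeLLMClient()  # reads GITEE_API_KEY; defaults to Qwen2.5-72B-Instruct
    request = ModelRequest.build_request(
        client.default_model,
        messages=[ModelMessage(role="human", content="Hello, Gitee AI!")],
    )
    async for output in client.generate_stream(request):
        print(output.text)


asyncio.run(main())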