feat(model): Support moonshot proxy LLM (#1404)

Fangyin Cheng 2024-04-10 23:41:50 +08:00 committed by GitHub
parent 1c6a897137
commit 7d6dfd9ea8
9 changed files with 186 additions and 0 deletions

.env.template

@@ -214,6 +214,10 @@ TONGYI_PROXY_API_KEY={your-tongyi-sk}
#YI_API_BASE=https://api.lingyiwanwu.com/v1
#YI_API_KEY={your-yi-api-key}
## Moonshot Proxyllm, https://platform.moonshot.cn/docs/
# MOONSHOT_MODEL_VERSION=moonshot-v1-8k
# MOONSHOT_API_BASE=https://api.moonshot.cn/v1
# MOONSHOT_API_KEY={your-moonshot-api-key}
#*******************************************************************#
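
For context, switching DB-GPT onto this proxy would look roughly like the fragment below. This is an illustrative sketch, not part of the commit: it assumes LLM_MODEL is the template's usual model-selection key, and the API key value stays a placeholder.

LLM_MODEL=moonshot_proxyllm
MOONSHOT_MODEL_VERSION=moonshot-v1-8k
MOONSHOT_API_BASE=https://api.moonshot.cn/v1
MOONSHOT_API_KEY={your-moonshot-api-key}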

setup.cfg

@@ -94,3 +94,6 @@ ignore_missing_imports = True
[mypy-unstructured.*]
ignore_missing_imports = True

[mypy-rich.*]
ignore_missing_imports = True

dbgpt/_private/config.py

@@ -118,6 +118,16 @@ class Config(metaclass=Singleton):
            os.environ["yi_proxyllm_proxy_api_base"] = os.getenv(
                "YI_API_BASE", "https://api.lingyiwanwu.com/v1"
            )
        # Moonshot proxy
        self.moonshot_proxy_api_key = os.getenv("MOONSHOT_API_KEY")
        if self.moonshot_proxy_api_key:
            os.environ["moonshot_proxyllm_proxy_api_key"] = self.moonshot_proxy_api_key
            os.environ["moonshot_proxyllm_proxyllm_backend"] = os.getenv(
                "MOONSHOT_MODEL_VERSION", "moonshot-v1-8k"
            )
            os.environ["moonshot_proxyllm_api_base"] = os.getenv(
                "MOONSHOT_API_BASE", "https://api.moonshot.cn/v1"
            )
        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
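
As a quick illustration of the wiring above (not part of the commit): once Config runs with MOONSHOT_API_KEY set, the moonshot_proxyllm-scoped variables are exported into the process environment, where the proxy model layer can pick them up.

import os

# Illustrative check after Config() has initialized with MOONSHOT_API_KEY set;
# the exact consumers of these derived variables live in the model layer.
print(os.environ["moonshot_proxyllm_proxyllm_backend"])  # "moonshot-v1-8k" by default
print(os.environ["moonshot_proxyllm_api_base"])          # "https://api.moonshot.cn/v1"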

dbgpt/configs/model_config.py

@@ -67,6 +67,8 @@ LLM_MODEL_CONFIG = {
    "spark_proxyllm": "spark_proxyllm",
    # https://platform.lingyiwanwu.com/docs/
    "yi_proxyllm": "yi_proxyllm",
    # https://platform.moonshot.cn/docs/
    "moonshot_proxyllm": "moonshot_proxyllm",
    "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
    "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
    "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),

dbgpt/core/interface/llm.py

@@ -857,3 +857,31 @@ class LLMClient(ABC):
        if not model_metadata:
            raise ValueError(f"Model {model} not found")
        return model_metadata

    def __call__(self, *args, **kwargs) -> ModelOutput:
        """Return the model output.

        Call the LLM client to generate the response for the given message.

        Please do not use this method in the production environment, it is only used
        for debugging.
        """
        from dbgpt.util import get_or_create_event_loop

        messages = kwargs.get("messages")
        model = kwargs.get("model")
        if messages:
            del kwargs["messages"]
            model_messages = ModelMessage.from_openai_messages(messages)
        else:
            model_messages = [ModelMessage.build_human_message(args[0])]
        if not model:
            if hasattr(self, "default_model"):
                model = getattr(self, "default_model")
            else:
                raise ValueError("The default model is not set")
        if "model" in kwargs:
            del kwargs["model"]
        req = ModelRequest.build_request(model, model_messages, **kwargs)
        loop = get_or_create_event_loop()
        return loop.run_until_complete(self.generate(req))
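
A hypothetical debugging session using this helper might look like the sketch below; client stands for any concrete LLMClient implementation, and the model name is only an example.

# Debug-only: __call__ wraps the async generate() in an event loop.
# A bare string becomes a single human message:
output = client("Hello")

# Or pass OpenAI-style messages plus an explicit model:
output = client(
    messages=[{"role": "user", "content": "Hello"}],
    model="moonshot-v1-8k",
)
print(output)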

dbgpt/model/adapter/proxy_adapter.py

@@ -252,6 +252,31 @@ class YiProxyLLMModelAdapter(ProxyLLMModelAdapter):
        return yi_generate_stream


class MoonshotProxyLLMModelAdapter(ProxyLLMModelAdapter):
    """Moonshot proxy LLM model adapter.

    See Also: `Moonshot Documentation <https://platform.moonshot.cn/docs/>`_
    """

    def support_async(self) -> bool:
        return True

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path in ["moonshot_proxyllm"]

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient

        return MoonshotLLMClient

    def get_async_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.moonshot import moonshot_generate_stream

        return moonshot_generate_stream


register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)
@@ -261,3 +286,4 @@ register_model_adapter(SparkProxyLLMModelAdapter)
register_model_adapter(BardProxyLLMModelAdapter)
register_model_adapter(BaichuanProxyLLMModelAdapter)
register_model_adapter(YiProxyLLMModelAdapter)
register_model_adapter(MoonshotProxyLLMModelAdapter)
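
To make the do_match/registration flow concrete, here is a self-contained toy version of the dispatch pattern; the class and function names are stand-ins, not DB-GPT's real registry.

from typing import List, Optional

class ToyAdapter:
    """Stand-in for ProxyLLMModelAdapter's matching interface."""

    def do_match(self, lower_model_name_or_path: Optional[str] = None) -> bool:
        raise NotImplementedError

class ToyMoonshotAdapter(ToyAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None) -> bool:
        return lower_model_name_or_path in ["moonshot_proxyllm"]

_TOY_ADAPTERS: List[ToyAdapter] = []

def toy_register(adapter_cls) -> None:
    _TOY_ADAPTERS.append(adapter_cls())

def toy_match(model_name: str) -> ToyAdapter:
    # First registered adapter whose do_match accepts the lowered name wins.
    lower_name = model_name.lower()
    for adapter in _TOY_ADAPTERS:
        if adapter.do_match(lower_name):
            return adapter
    raise ValueError(f"No adapter matched {model_name!r}")

toy_register(ToyMoonshotAdapter)
print(type(toy_match("moonshot_proxyllm")).__name__)  # ToyMoonshotAdapter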

dbgpt/model/proxy/__init__.py

@@ -10,6 +10,7 @@ def __lazy_import(name):
        "WenxinLLMClient": "dbgpt.model.proxy.llms.wenxin",
        "ZhipuLLMClient": "dbgpt.model.proxy.llms.zhipu",
        "YiLLMClient": "dbgpt.model.proxy.llms.yi",
        "MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot",
    }

    if name in module_path:
@@ -31,4 +32,5 @@ __all__ = [
    "WenxinLLMClient",
    "SparkLLMClient",
    "YiLLMClient",
    "MoonshotLLMClient",
]
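
The __lazy_import mapping above defers importing each client module until the attribute is first requested. A minimal self-contained sketch of this pattern (module-level __getattr__, PEP 562) is shown below; it mirrors the structure of this __init__.py but is not its verbatim implementation.

import importlib

_MODULE_PATH = {
    "MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot",
}

def __getattr__(name: str):
    # Import the backing module only on first attribute access.
    if name in _MODULE_PATH:
        module = importlib.import_module(_MODULE_PATH[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")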

dbgpt/model/proxy/llms/chatgpt.py

@@ -99,6 +99,8 @@ class OpenAILLMClient(ProxyLLMClient):
            ) from exc
        self._openai_version = metadata.version("openai")
        self._openai_less_then_v1 = not self._openai_version >= "1.0.0"
        self.check_sdk_version(self._openai_version)

        self._init_params = OpenAIParameters(
            api_type=api_type,
            api_base=api_base,
@@ -141,6 +143,14 @@ class OpenAILLMClient(ProxyLLMClient):
            full_url=model_params.proxy_server_url,
        )

    def check_sdk_version(self, version: str) -> None:
        """Check the sdk version of the client.

        Raises:
            ValueError: If check failed.
        """
        pass

    @property
    def client(self) -> ClientType:
        if self._openai_less_then_v1:

dbgpt/model/proxy/llms/moonshot.py

@@ -0,0 +1,101 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast

from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

from .chatgpt import OpenAILLMClient

if TYPE_CHECKING:
    from httpx._types import ProxiesTypes
    from openai import AsyncAzureOpenAI, AsyncOpenAI

    ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]

_MOONSHOT_DEFAULT_MODEL = "moonshot-v1-8k"


async def moonshot_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    client: MoonshotLLMClient = cast(MoonshotLLMClient, model.proxy_llm_client)
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    async for r in client.generate_stream(request):
        yield r


class MoonshotLLMClient(OpenAILLMClient):
    """Moonshot LLM Client.

    Moonshot's API is compatible with OpenAI's API, so we inherit from OpenAILLMClient.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_type: Optional[str] = None,
        api_version: Optional[str] = None,
        model: Optional[str] = _MOONSHOT_DEFAULT_MODEL,
        proxies: Optional["ProxiesTypes"] = None,
        timeout: Optional[int] = 240,
        model_alias: Optional[str] = "moonshot_proxyllm",
        context_length: Optional[int] = None,
        openai_client: Optional["ClientType"] = None,
        openai_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        api_base = (
            api_base or os.getenv("MOONSHOT_API_BASE") or "https://api.moonshot.cn/v1"
        )
        api_key = api_key or os.getenv("MOONSHOT_API_KEY")
        model = model or _MOONSHOT_DEFAULT_MODEL
        if not context_length:
            if "128k" in model:
                context_length = 1024 * 128
            elif "32k" in model:
                context_length = 1024 * 32
            else:
                # 8k
                context_length = 1024 * 8

        if not api_key:
            raise ValueError(
                "Moonshot API key is required, please set 'MOONSHOT_API_KEY' in "
                "environment variable or pass it to the client."
            )
        super().__init__(
            api_key=api_key,
            api_base=api_base,
            api_type=api_type,
            api_version=api_version,
            model=model,
            proxies=proxies,
            timeout=timeout,
            model_alias=model_alias,
            context_length=context_length,
            openai_client=openai_client,
            openai_kwargs=openai_kwargs,
            **kwargs,
        )

    def check_sdk_version(self, version: str) -> None:
        if not version >= "1.0":
            raise ValueError(
                "Moonshot API requires openai>=1.0, please upgrade it by "
                "`pip install --upgrade 'openai>=1.0'`"
            )

    @property
    def default_model(self) -> str:
        model = self._model
        if not model:
            model = _MOONSHOT_DEFAULT_MODEL
        return model
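
Finally, a usage sketch for the new client (not part of the commit): it assumes MOONSHOT_API_KEY is set in the environment and leans on the debug-only LLMClient.__call__ helper added earlier in this commit.

from dbgpt.model.proxy import MoonshotLLMClient

# Assumes MOONSHOT_API_KEY is exported; otherwise pass api_key=... explicitly.
client = MoonshotLLMClient(model="moonshot-v1-32k")  # infers context_length = 32 * 1024
print(client.default_model)  # moonshot-v1-32k

# Debug-only synchronous call via LLMClient.__call__:
output = client("Hello, Moonshot!")
print(output)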