Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-19 08:47:32 +00:00)
feat(model): Support glm4.5 models (#2867)
Parent 150e84ed18, commit 154b6927fa.
@@ -35,8 +35,6 @@ initialize_tracer(
 async def main():
     from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient
-
     agent_memory = AgentMemory()
-
     llm_client = SiliconFlowLLMClient(
         model_alias=os.getenv(
             "SILICONFLOW_MODEL_VERSION", "Qwen/Qwen2.5-Coder-32B-Instruct"
@@ -121,7 +121,7 @@ llama_cpp_server = [
     "llama-cpp-server-py>=0.1.4",
 ]
 proxy_ollama = ["ollama"]
-proxy_zhipuai = ["zhipuai>=2.1.5"]
+proxy_zhipuai = ["openai>=1.59.6"]
 proxy_tongyi = [
     # tongyi supported by openai package
     "openai",
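Note: the `proxy_zhipuai` extra now pulls in `openai` instead of `zhipuai`, because Zhipu serves an OpenAI-compatible API. A minimal sketch (not part of this diff) of what that compatibility means in practice; the base URL and environment variable name are assumptions for illustration:

    import asyncio
    import os

    from openai import AsyncOpenAI

    async def main():
        # Assumed OpenAI-compatible endpoint for Zhipu; adjust as needed.
        client = AsyncOpenAI(
            api_key=os.getenv("ZHIPUAI_API_KEY"),
            base_url="https://open.bigmodel.cn/api/paas/v4/",
        )
        resp = await client.chat.completions.create(
            model="glm-4.5",
            messages=[{"role": "user", "content": "Hello"}],
        )
        print(resp.choices[0].message.content)

    asyncio.run(main())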
@@ -2,9 +2,9 @@ import logging
 import os
 from concurrent.futures import Executor
 from dataclasses import dataclass, field
-from typing import Iterator, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union

-from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
+from dbgpt.core import ModelMetadata
 from dbgpt.core.awel.flow import (
     TAGS_ORDER_HIGH,
     ResourceCategory,
@@ -14,25 +14,30 @@ from dbgpt.core.interface.parameter import LLMDeployModelParameters
 from dbgpt.model.proxy.base import (
     AsyncGenerateStreamFunction,
     GenerateStreamFunction,
     ProxyLLMClient,
     register_proxy_model_adapter,
 )
 from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
 from dbgpt.util.i18n_utils import _

-from .chatgpt import OpenAICompatibleDeployModelParameters
+from .chatgpt import OpenAICompatibleDeployModelParameters, OpenAILLMClient

-_DEFAULT_MODEL = "glm-4-plus"
+if TYPE_CHECKING:
+    from httpx._types import ProxiesTypes
+    from openai import AsyncAzureOpenAI, AsyncOpenAI
+
+    ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]

 logger = logging.getLogger(__name__)

+_DEFAULT_MODEL = "glm-4.5"
+

 @auto_register_resource(
     label=_("Zhipu Proxy LLM"),
     category=ResourceCategory.LLM_CLIENT,
     tags={"order": TAGS_ORDER_HIGH},
     description=_("Zhipu proxy LLM configuration."),
-    documentation_url="https://open.bigmodel.cn/dev/api/normal-model/glm-4#overview",
+    documentation_url="https://docs.bigmodel.cn/cn/guide/start/model-overview",
     show_in_ui=False,
 )
 @dataclass
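The new `if TYPE_CHECKING:` block keeps `httpx` and `openai` as annotation-only imports, so the module still loads when those optional packages are absent. A minimal sketch of the pattern:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by static type checkers, never at runtime.
        from openai import AsyncOpenAI

    def client_name(client: "AsyncOpenAI") -> str:
        # The quoted annotation defers evaluation, so openai need not be installed.
        return type(client).__name__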
@@ -57,52 +62,32 @@ class ZhipuDeployModelParameters(OpenAICompatibleDeployModelParameters):
     )


-def zhipu_generate_stream(
+async def zhipu_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
-    """Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
-    model_params = model.get_params()
-    print(f"Model: {model}, model_params: {model_params}")
-
-    # TODO: Support convert_to_compatible_format config, zhipu not support system
-    # message
-    # convert_to_compatible_format = params.get("convert_to_compatible_format", False)
-    # history, systems = __convert_2_zhipu_messages(messages)
+    """Zhipu ai, see: https://docs.bigmodel.cn/cn/guide/start/model-overview"""
     client: ZhipuLLMClient = model.proxy_llm_client
     request = parse_model_request(params, client.default_model, stream=True)
-    for r in client.sync_generate_stream(request):
+    async for r in client.generate_stream(request):
         yield r


-class ZhipuLLMClient(ProxyLLMClient):
+class ZhipuLLMClient(OpenAILLMClient):
     def __init__(
         self,
         model: Optional[str] = _DEFAULT_MODEL,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        api_type: Optional[str] = None,
+        api_version: Optional[str] = None,
+        proxies: Optional["ProxiesTypes"] = None,
         model_alias: Optional[str] = _DEFAULT_MODEL,
         timeout: Optional[int] = 240,
         context_length: Optional[int] = 8192,
         executor: Optional[Executor] = None,
+        openai_client: Optional["ClientType"] = None,
+        openai_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ):
-        try:
-            from zhipuai import ZhipuAI
-
-        except ImportError as exc:
-            if (
-                "No module named" in str(exc)
-                or "cannot find module" in str(exc).lower()
-            ):
-                raise ValueError(
-                    "The python package 'zhipuai' is not installed. "
-                    "Please install it by running `pip install zhipuai`."
-                ) from exc
-            else:
-                raise ValueError(
-                    "Could not import python package: zhipuai "
-                    "This may be due to a version that is too low. "
-                    "Please upgrade the zhipuai package by running "
-                    "`pip install --upgrade zhipuai`."
-                ) from exc
         if not model:
             model = _DEFAULT_MODEL
         if not api_key:
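Because `zhipu_generate_stream` is now an async generator, callers iterate it with `async for` rather than `for`. A hedged sketch of the calling convention; the `model` and `params` arguments stand in for DB-GPT's real ProxyModel and request dict:

    import asyncio

    async def collect_outputs(model, params):
        outputs = []
        # Each yielded item is a ModelOutput carrying the partial text so far.
        async for out in zhipu_generate_stream(model, None, params, None):
            outputs.append(out)
        return outputs

    # outputs = asyncio.run(collect_outputs(model, params))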
@@ -111,13 +96,19 @@ class ZhipuLLMClient(ProxyLLMClient):

         api_key = self._resolve_env_vars(api_key)
         api_base = self._resolve_env_vars(api_base)
-        self._model = model
-        self.client = ZhipuAI(api_key=api_key, base_url=api_base)
-
         super().__init__(
-            model_names=[model, model_alias],
+            api_key=api_key,
+            api_base=api_base,
+            api_type=api_type,
+            api_version=api_version,
+            model=model,
+            proxies=proxies,
+            timeout=timeout,
             model_alias=model_alias,
             context_length=context_length,
             executor=executor,
+            openai_client=openai_client,
+            openai_kwargs=openai_kwargs,
+            **kwargs,
         )

     @classmethod
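With `__init__` now delegating to `OpenAILLMClient`, the hand-rolled `zhipuai` SDK wiring disappears and the client inherits the shared OpenAI-compatible transport. Illustrative construction under that change; the import path and environment variable are assumptions for the sketch:

    import os

    # Assumed import path for the module this diff edits.
    from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient

    client = ZhipuLLMClient(
        model="glm-4.5",
        api_key=os.getenv("ZHIPUAI_API_KEY"),
    )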
@@ -147,55 +138,20 @@ class ZhipuLLMClient(ProxyLLMClient):

     @property
     def default_model(self) -> str:
-        return self._model
-
-    def sync_generate_stream(
-        self,
-        request: ModelRequest,
-        message_converter: Optional[MessageConverter] = None,
-    ) -> Iterator[ModelOutput]:
-        request = self.local_covert_message(request, message_converter)
-
-        messages = request.to_common_messages(support_system_role=False)
-
-        model = request.model or self._model
-        try:
-            logger.debug(
-                f"Send request to zhipu ai, model: {model}, request: {request}"
-            )
-            response = self.client.chat.completions.create(
-                model=model,
-                messages=messages,
-                temperature=request.temperature,
-                max_tokens=request.max_new_tokens,
-                top_p=request.top_p,
-                stream=True,
-            )
-            partial_text = ""
-            for chunk in response:
-                if not chunk.choices or not chunk.choices[0].delta:
-                    continue
-                delta_content = chunk.choices[0].delta.content
-                finish_reason = chunk.choices[0].finish_reason
-                partial_text += delta_content
-                if logger.isEnabledFor(logging.DEBUG):
-                    print(delta_content, end="")
-                yield ModelOutput(
-                    text=partial_text, error_code=0, finish_reason=finish_reason
-                )
-            if not partial_text:
-                yield ModelOutput(text="**LLMServer Generate Empty.**", error_code=1)
-
-        except Exception as e:
-            yield ModelOutput(
-                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=1,
-            )
+        return _DEFAULT_MODEL


 register_proxy_model_adapter(
     ZhipuLLMClient,
     supported_models=[
+        ModelMetadata(
+            model=["glm-4.5", "glm-4.5-air", "glm-4.5-x", "glm-4.5-airx"],
+            context_length=128 * 1024,
+            max_output_length=96 * 1024,
+            description="GLM-4.5 by Zhipu AI",
+            link="https://docs.bigmodel.cn/cn/guide/start/model-overview",
+            function_calling=True,
+        ),
         ModelMetadata(
             model=["glm-4-plus", "glm-4-air", "glm-4-air-0111"],
             context_length=128 * 1024,
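For reference, the limits advertised in the new GLM-4.5 metadata work out as follows (a quick sanity check of the arithmetic, not part of the diff):

    # 128 * 1024 tokens of context and 96 * 1024 tokens of output.
    assert 128 * 1024 == 131_072
    assert 96 * 1024 == 98_304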