Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-02 01:27:14 +00:00
refactor: Refactor proxy LLM (#1064)
@@ -13,6 +13,7 @@ BAICHUAN_DEFAULT_MODEL = "Baichuan2-Turbo-192k"
def baichuan_generate_stream(
    model: ProxyModel, tokenizer=None, params=None, device=None, context_len=4096
):
    # TODO: Support new Baichuan ProxyLLMClient
    url = "https://api.baichuan-ai.com/v1/chat/completions"

    model_params = model.get_params()
@@ -9,6 +9,7 @@ from dbgpt.model.proxy.llms.proxy_model import ProxyModel
def bard_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    # TODO: Support new Bard ProxyLLMClient
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")
@@ -1,259 +1,231 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import importlib.metadata as metadata
import logging
import os
from typing import List
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union

import httpx

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import (
    MessageConverter,
    ModelMetadata,
    ModelOutput,
    ModelRequest,
    ModelRequestContext,
)
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.utils.chatgpt_utils import OpenAIParameters

if TYPE_CHECKING:
    from httpx._types import ProxiesTypes
    from openai import AsyncAzureOpenAI, AsyncOpenAI

    ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]

logger = logging.getLogger(__name__)
def _initialize_openai(params: ProxyModelParameters):
    try:
        import openai
    except ImportError as exc:
        raise ValueError(
            "Could not import python package: openai "
            "Please install openai by command `pip install openai`"
        ) from exc

    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

    api_base = params.proxy_api_base or os.getenv(
        "OPENAI_API_BASE",
        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
    )
    api_key = params.proxy_api_key or os.getenv(
        "OPENAI_API_KEY",
        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
    )
    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")

    if not api_base and params.proxy_server_url:
        # Adapt the previous proxy_server_url configuration
        api_base = params.proxy_server_url.split("/chat/completions")[0]
    if api_type:
        openai.api_type = api_type
    if api_base:
        openai.api_base = api_base
    if api_key:
        openai.api_key = api_key
    if api_version:
        openai.api_version = api_version
    if params.http_proxy:
        openai.proxy = params.http_proxy

    openai_params = {
        "api_type": api_type,
        "api_base": api_base,
        "api_version": api_version,
        "proxy": params.http_proxy,
    }

    return openai_params
def _initialize_openai_v1(params: ProxyModelParameters):
    try:
        from openai import OpenAI
    except ImportError as exc:
        raise ValueError(
            "Could not import python package: openai "
            "Please install openai by command `pip install openai`"
        )

    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")

    base_url = params.proxy_api_base or os.getenv(
        "OPENAI_API_BASE",
        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
    )
    api_key = params.proxy_api_key or os.getenv(
        "OPENAI_API_KEY",
        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
    )
    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")

    if not base_url and params.proxy_server_url:
        # Adapt the previous proxy_server_url configuration
        base_url = params.proxy_server_url.split("/chat/completions")[0]

    proxies = params.http_proxy
    openai_params = {
        "api_key": api_key,
        "base_url": base_url,
    }
    return openai_params, api_type, api_version, proxies
def __convert_2_gpt_messages(messages: List[ModelMessage]):
    gpt_messages = []
    last_usr_message = ""
    system_messages = []

    # TODO: We can't change message order in low level
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
            last_usr_message = message.content
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI or message.role == "assistant":
            last_ai_message = message.content
            gpt_messages.append({"role": "user", "content": last_usr_message})
            gpt_messages.append({"role": "assistant", "content": last_ai_message})

    if len(system_messages) > 0:
        if len(system_messages) < 2:
            gpt_messages.insert(0, {"role": "system", "content": system_messages[0]})
            gpt_messages.append({"role": "user", "content": last_usr_message})
        else:
            gpt_messages.append({"role": "user", "content": system_messages[1]})
    else:
        last_message = messages[-1]
        if last_message.role == ModelMessageRoleType.HUMAN:
            gpt_messages.append({"role": "user", "content": last_message.content})

    return gpt_messages
def _build_request(model: ProxyModel, params):
    model_params = model.get_params()
    logger.info(f"Model: {model}, model_params: {model_params}")

    messages: List[ModelMessage] = params["messages"]

    # history = __convert_2_gpt_messages(messages)
    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
    history = ModelMessage.to_openai_messages(
        messages, convert_to_compatible_format=convert_to_compatible_format
    )
    payloads = {
        "temperature": params.get("temperature"),
        "max_tokens": params.get("max_new_tokens"),
        "stream": True,
    }
    proxyllm_backend = model_params.proxyllm_backend

    if metadata.version("openai") >= "1.0.0":
        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
            model_params
        )
        proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
        payloads["model"] = proxyllm_backend
    else:
        openai_params = _initialize_openai(model_params)
        if openai_params["api_type"] == "azure":
            # engine = "deployment_name"
            proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
            payloads["engine"] = proxyllm_backend
        else:
            proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
            payloads["model"] = proxyllm_backend

    logger.info(f"Send request to real model {proxyllm_backend}")
    return history, payloads
def chatgpt_generate_stream(
async def chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    if metadata.version("openai") >= "1.0.0":
        model_params = model.get_params()
        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
            model_params
    client: OpenAILLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    async for r in client.generate_stream(request):
        yield r
class OpenAILLMClient(ProxyLLMClient):
    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_type: Optional[str] = None,
        api_version: Optional[str] = None,
        model: Optional[str] = None,
        proxies: Optional["ProxiesTypes"] = None,
        timeout: Optional[int] = 240,
        model_alias: Optional[str] = "chatgpt_proxyllm",
        context_length: Optional[int] = 8192,
        openai_client: Optional["ClientType"] = None,
        openai_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        try:
            import openai
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: openai "
                "Please install openai by command `pip install openai`"
            ) from exc
        self._openai_version = metadata.version("openai")
        self._openai_less_then_v1 = not self._openai_version >= "1.0.0"
        self._init_params = OpenAIParameters(
            api_type=api_type,
            api_base=api_base,
            api_key=api_key,
            api_version=api_version,
            proxies=proxies,
            full_url=kwargs.get("full_url"),
        )
        history, payloads = _build_request(model, params)
        if api_type == "azure":
            from openai import AzureOpenAI

            client = AzureOpenAI(
                api_key=openai_params["api_key"],
                api_version=api_version,
                azure_endpoint=openai_params["base_url"],
                http_client=httpx.Client(proxies=proxies),
        self._model = model
        self._proxies = proxies
        self._timeout = timeout
        self._model_alias = model_alias
        self._context_length = context_length
        self._api_type = api_type
        self._client = openai_client
        self._openai_kwargs = openai_kwargs or {}
        super().__init__(model_names=[model_alias], context_length=context_length)

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "OpenAILLMClient":
        return cls(
            api_key=model_params.proxy_api_key,
            api_base=model_params.proxy_api_base,
            api_type=model_params.proxy_api_type,
            api_version=model_params.proxy_api_version,
            model=model_params.proxyllm_backend,
            proxies=model_params.http_proxy,
            model_alias=model_params.model_name,
            context_length=max(model_params.max_context_size, 8192),
            full_url=model_params.proxy_server_url,
        )

    @property
    def client(self) -> ClientType:
        if self._openai_less_then_v1:
            raise ValueError(
                "Current model (Load by OpenAILLMClient) requires openai.__version__>=1.0.0"
            )
        else:
            from openai import OpenAI
        if self._client is None:
            from dbgpt.model.utils.chatgpt_utils import _build_openai_client

            client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies))
            res = client.chat.completions.create(messages=history, **payloads)
            self._api_type, self._client = _build_openai_client(
                init_params=self._init_params
            )
        return self._client
    @property
    def default_model(self) -> str:
        model = self._model
        if not model:
            model = "gpt-35-turbo" if self._api_type == "azure" else "gpt-3.5-turbo"
        return model

    def _build_request(
        self, request: ModelRequest, stream: Optional[bool] = False
    ) -> Dict[str, Any]:
        payload = {"stream": stream}
        model = request.model or self.default_model
        if self._openai_less_then_v1 and self._api_type == "azure":
            payload["engine"] = model
        else:
            payload["model"] = model
        # Apply openai kwargs
        for k, v in self._openai_kwargs.items():
            payload[k] = v
        if request.temperature:
            payload["temperature"] = request.temperature
        if request.max_new_tokens:
            payload["max_tokens"] = request.max_new_tokens
        return payload
    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        request = self.local_covert_message(request, message_converter)
        messages = request.to_common_messages()
        payload = self._build_request(request)
        logger.info(
            f"Send request to openai({self._openai_version}), payload: {payload}\n\n messages:\n{messages}"
        )
        try:
            if self._openai_less_then_v1:
                return await self.generate_less_then_v1(messages, payload)
            else:
                return await self.generate_v1(messages, payload)
        except Exception as e:
            return ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )

    async def generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> AsyncIterator[ModelOutput]:
        request = self.local_covert_message(request, message_converter)
        messages = request.to_common_messages()
        payload = self._build_request(request, stream=True)
        logger.info(
            f"Send request to openai({self._openai_version}), payload: {payload}\n\n messages:\n{messages}"
        )
        if self._openai_less_then_v1:
            async for r in self.generate_stream_less_then_v1(messages, payload):
                yield r
        else:
            async for r in self.generate_stream_v1(messages, payload):
                yield r

    async def generate_v1(
        self, messages: List[Dict[str, Any]], payload: Dict[str, Any]
    ) -> ModelOutput:
        chat_completion = await self.client.chat.completions.create(
            messages=messages, **payload
        )
        text = chat_completion.choices[0].message.content
        usage = chat_completion.usage.dict()
        return ModelOutput(text=text, error_code=0, usage=usage)

    async def generate_less_then_v1(
        self, messages: List[Dict[str, Any]], payload: Dict[str, Any]
    ) -> ModelOutput:
        import openai

        chat_completion = await openai.ChatCompletion.acreate(
            messages=messages, **payload
        )
        text = chat_completion.choices[0].message.content
        usage = chat_completion.usage.to_dict()
        return ModelOutput(text=text, error_code=0, usage=usage)
    async def generate_stream_v1(
        self, messages: List[Dict[str, Any]], payload: Dict[str, Any]
    ) -> AsyncIterator[ModelOutput]:
        chat_completion = await self.client.chat.completions.create(
            messages=messages, **payload
        )
        text = ""
        for r in res:
            # logger.info(str(r))
        # Azure OpenAI responses may have an empty choices body in the first chunk;
        # skip it to avoid an index-out-of-range error
        async for r in chat_completion:
            if len(r.choices) == 0:
                continue
            if r.choices[0].delta.content is not None:
                content = r.choices[0].delta.content
                text += content
                yield text
        yield ModelOutput(text=text, error_code=0)

    else:

    async def generate_stream_less_then_v1(
        self, messages: List[Dict[str, Any]], payload: Dict[str, Any]
    ) -> AsyncIterator[ModelOutput]:
        import openai

        history, payloads = _build_request(model, params)

        res = openai.ChatCompletion.create(messages=history, **payloads)

        text = ""
        for r in res:
            if len(r.choices) == 0:
                continue
            if r["choices"][0]["delta"].get("content") is not None:
                content = r["choices"][0]["delta"]["content"]
                text += content
                yield text
async def async_chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    if metadata.version("openai") >= "1.0.0":
        model_params = model.get_params()
        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
            model_params
        )
        history, payloads = _build_request(model, params)
        if api_type == "azure":
            from openai import AsyncAzureOpenAI

            client = AsyncAzureOpenAI(
                api_key=openai_params["api_key"],
                api_version=api_version,
                azure_endpoint=openai_params["base_url"],
                http_client=httpx.AsyncClient(proxies=proxies),
            )
        else:
            from openai import AsyncOpenAI

            client = AsyncOpenAI(
                **openai_params, http_client=httpx.AsyncClient(proxies=proxies)
            )

        res = await client.chat.completions.create(messages=history, **payloads)
        text = ""
        for r in res:
            if not r.get("choices"):
                continue
            if r.choices[0].delta.content is not None:
                content = r.choices[0].delta.content
                text += content
                yield text
    else:
        import openai

        history, payloads = _build_request(model, params)

        res = await openai.ChatCompletion.acreate(messages=history, **payloads)

        res = await openai.ChatCompletion.acreate(messages=messages, **payload)
        text = ""
        async for r in res:
            if not r.get("choices"):
@@ -261,4 +233,21 @@ async def async_chatgpt_generate_stream(
            if r["choices"][0]["delta"].get("content") is not None:
                content = r["choices"][0]["delta"]["content"]
                text += content
                yield text
        yield ModelOutput(text=text, error_code=0)
    async def models(self) -> List[ModelMetadata]:
        model_metadata = ModelMetadata(
            model=self._model_alias,
            context_length=await self.get_context_length(),
        )
        return [model_metadata]

    async def get_context_length(self) -> int:
        """Get the context length of the model.

        Returns:
            int: The context length.
        # TODO: This is a temporary solution. We should have a better way to get the
        # context length, e.g. get the real context length from the OpenAI API.
        """
        return self._context_length
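For reference, a minimal usage sketch of the refactored OpenAI client (illustrative, not part of the commit). It assumes the class lives in dbgpt.model.proxy.llms.chatgpt, that ModelMessage accepts role/content keyword arguments, and uses a placeholder API key and model name:

import asyncio

from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient


async def main():
    # Placeholder credentials and model name; replace with real values.
    client = OpenAILLMClient(api_key="sk-...", model="gpt-3.5-turbo")
    request = ModelRequest.build_request(
        client.default_model,
        messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
        temperature=0.7,
        context=ModelRequestContext(stream=True),
        max_new_tokens=256,
    )
    # Non-streaming call: returns a single ModelOutput.
    output = await client.generate(request)
    print(output.text)
    # Streaming call: yields incremental ModelOutput chunks.
    async for chunk in client.generate_stream(request):
        print(chunk.text)


asyncio.run(main())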
@@ -1,72 +1,54 @@
from typing import Any, Dict, List, Tuple
import os
from concurrent.futures import Executor
from typing import Any, Dict, Iterator, List, Optional, Tuple

from dbgpt.core.interface.message import ModelMessage, parse_model_messages
from dbgpt.core import (
    MessageConverter,
    ModelMessage,
    ModelOutput,
    ModelRequest,
    ModelRequestContext,
)
from dbgpt.core.interface.message import parse_model_messages
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

GEMINI_DEFAULT_MODEL = "gemini-pro"

safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
]
def gemini_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    """Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    # TODO proxy model use unified config?
    proxy_api_key = model_params.proxy_api_key
    proxyllm_backend = GEMINI_DEFAULT_MODEL or model_params.proxyllm_backend

    generation_config = {
        "temperature": 0.7,
        "top_p": 1,
        "top_k": 1,
        "max_output_tokens": 2048,
    }

    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {
            "category": "HARM_CATEGORY_HATE_SPEECH",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
        {
            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
        {
            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
    ]

    import google.generativeai as genai

    if model_params.proxy_api_base:
        from google.api_core import client_options

        client_opts = client_options.ClientOptions(
            api_endpoint=model_params.proxy_api_base
        )
        genai.configure(
            api_key=proxy_api_key, transport="rest", client_options=client_opts
        )
    else:
        genai.configure(api_key=proxy_api_key)
    model = genai.GenerativeModel(
        model_name=proxyllm_backend,
        generation_config=generation_config,
        safety_settings=safety_settings,
    client: GeminiLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    messages: List[ModelMessage] = params["messages"]
    user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
    chat = model.start_chat(history=gemini_hist)
    response = chat.send_message(user_prompt, stream=True)
    text = ""
    for chunk in response:
        text += chunk.text
        print(text)
        yield text
    for r in client.sync_generate_stream(request):
        yield r
def _transform_to_gemini_messages(
@@ -97,12 +79,104 @@ def _transform_to_gemini_messages(
        {"role": "model", "parts": {"text": "Hi there!"}},
    ]
    """
    # TODO: raise an error if messages contain a system message
    user_prompt, system_messages, history_messages = parse_model_messages(messages)
    if system_messages:
        user_prompt = "".join(system_messages) + "\n" + user_prompt
        raise ValueError("Gemini does not support system role")
    gemini_hist = []
    if history_messages:
        for user_message, model_message in history_messages:
            gemini_hist.append({"role": "user", "parts": {"text": user_message}})
            gemini_hist.append({"role": "model", "parts": {"text": model_message}})
    return user_prompt, gemini_hist
class GeminiLLMClient(ProxyLLMClient):
    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model_alias: Optional[str] = "gemini_proxyllm",
        context_length: Optional[int] = 8192,
        executor: Optional[Executor] = None,
    ):
        try:
            import google.generativeai as genai

        except ImportError as exc:
            raise ValueError(
                "Could not import python package: generativeai "
                "Please install google-generativeai by command `pip install google-generativeai`"
            ) from exc
        if not model:
            model = GEMINI_DEFAULT_MODEL
        self._api_key = api_key if api_key else os.getenv("GEMINI_PROXY_API_KEY")
        self._api_base = api_base if api_base else os.getenv("GEMINI_PROXY_API_BASE")
        self._model = model
        self.default_model = self._model
        if not self._api_key:
            raise RuntimeError("api_key can't be empty")

        if self._api_base:
            from google.api_core import client_options

            client_opts = client_options.ClientOptions(api_endpoint=self._api_base)
            genai.configure(
                api_key=self._api_key, transport="rest", client_options=client_opts
            )
        else:
            genai.configure(api_key=self._api_key)
        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "GeminiLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            api_key=model_params.proxy_api_key,
            api_base=model_params.proxy_api_base,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        request = self.local_covert_message(request, message_converter)
        try:
            import google.generativeai as genai

            generation_config = {
                "temperature": request.temperature,
                "top_p": 1,
                "top_k": 1,
                "max_output_tokens": request.max_new_tokens,
            }
            model = genai.GenerativeModel(
                model_name=self._model,
                generation_config=generation_config,
                safety_settings=safety_settings,
            )
            user_prompt, gemini_hist = _transform_to_gemini_messages(request.messages)
            chat = model.start_chat(history=gemini_hist)
            response = chat.send_message(user_prompt, stream=True)
            text = ""
            for chunk in response:
                text += chunk.text
                yield ModelOutput(text=text, error_code=0)
        except Exception as e:
            return ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )
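For reference, a minimal sketch of driving one of the new synchronous clients (illustrative, not part of the commit). GeminiLLMClient is shown here; SparkLLMClient, TongyiLLMClient, WenxinLLMClient and ZhipuLLMClient expose the same sync_generate_stream interface. The module path dbgpt.model.proxy.llms.gemini and the ModelMessage keyword arguments are assumptions, and the API key is a placeholder:

from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.gemini import GeminiLLMClient

# Placeholder key; GeminiLLMClient otherwise falls back to the GEMINI_PROXY_API_KEY env var.
client = GeminiLLMClient(api_key="your-gemini-key")
request = ModelRequest.build_request(
    client.default_model,
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
    temperature=0.7,
    context=ModelRequestContext(stream=True),
    max_new_tokens=256,
)
# Each yielded ModelOutput carries the accumulated text so far.
for output in client.sync_generate_stream(request):
    print(output.text)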
@@ -4,6 +4,7 @@ import logging
from typing import TYPE_CHECKING, List, Optional, Union

from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper

if TYPE_CHECKING:
@@ -13,9 +14,14 @@ logger = logging.getLogger(__name__)


class ProxyModel:
    def __init__(self, model_params: ProxyModelParameters) -> None:
    def __init__(
        self,
        model_params: ProxyModelParameters,
        proxy_llm_client: Optional[ProxyLLMClient] = None,
    ) -> None:
        self._model_params = model_params
        self._tokenizer = ProxyTokenizerWrapper()
        self.proxy_llm_client = proxy_llm_client

    def get_params(self) -> ProxyModelParameters:
        return self._model_params
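For reference, a minimal wiring sketch for the new ProxyModel / ProxyLLMClient split (illustrative, not part of the commit). It assumes model_params is an already-populated ProxyModelParameters and that OpenAILLMClient lives in dbgpt.model.proxy.llms.chatgpt; the legacy *_generate_stream entry points now simply read model.proxy_llm_client and delegate to it:

from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

# model_params: a ProxyModelParameters instance populated elsewhere (e.g. from the model config).
client = OpenAILLMClient.new_client(model_params)
model = ProxyModel(model_params, proxy_llm_client=client)
assert model.proxy_llm_client is client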
@@ -2,15 +2,19 @@ import base64
import hashlib
import hmac
import json
import os
from concurrent.futures import Executor
from datetime import datetime
from time import mktime
from typing import List
from typing import Iterator, Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

from websockets.sync.client import connect

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

SPARK_DEFAULT_API_VERSION = "v3"
@@ -34,63 +38,21 @@ def checklen(text):
def spark_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    model_params = model.get_params()
    proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
    proxy_api_key = model_params.proxy_api_key
    proxy_api_secret = model_params.proxy_api_secret
    proxy_app_id = model_params.proxy_api_app_id

    if proxy_api_version == SPARK_DEFAULT_API_VERSION:
        url = "ws://spark-api.xf-yun.com/v3.1/chat"
        domain = "generalv3"
    else:
        url = "ws://spark-api.xf-yun.com/v2.1/chat"
        domain = "generalv2"

    messages: List[ModelMessage] = params["messages"]

    last_user_input = None
    for index in range(len(messages) - 1, -1, -1):
        print(f"index: {index}")
        if messages[index].role == ModelMessageRoleType.HUMAN:
            last_user_input = {"role": "user", "content": messages[index].content}
            del messages[index]
            break

    # TODO: Support convert_to_compatible_format config
    convert_to_compatible_format = params.get("convert_to_compatible_format", False)

    history = []
    # Add history conversation
    for message in messages:
        # There is no system role in the Spark LLM
        if (
            message.role == ModelMessageRoleType.HUMAN
            or message.role == ModelMessageRoleType.SYSTEM
        ):
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    question = checklen(history + [last_user_input])

    print('last_user_input.get("content")', last_user_input.get("content"))
    data = {
        "header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))},
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": context_len,
                "auditing": "default",
                "temperature": params.get("temperature"),
            }
        },
        "payload": {"message": {"text": question}},
    }

    spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
    request_url = spark_api.gen_url()
    return get_response(request_url, data)
    client: SparkLLMClient = model.proxy_llm_client
    context = ModelRequestContext(
        stream=True,
        user_name=params.get("user_name"),
        request_id=params.get("request_id"),
    )
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in client.sync_generate_stream(request):
        yield r
def get_response(request_url, data):
@@ -107,8 +69,8 @@ def get_response(request_url, data):
                    result += text[0]["content"]
                    if choices.get("status") == 2:
                        break
        except Exception:
            break
        except Exception as e:
            raise e
    yield result


@@ -155,3 +117,103 @@ class SparkAPI:
        url = self.spark_url + "?" + urlencode(v)
        # Print the URL used when establishing the connection; when following this demo,
        # uncomment the print above and check that the URL generated with the same
        # parameters matches the one generated by your own code.
        return url
class SparkLLMClient(ProxyLLMClient):
    def __init__(
        self,
        model: Optional[str] = None,
        app_id: Optional[str] = None,
        api_key: Optional[str] = None,
        api_secret: Optional[str] = None,
        api_base: Optional[str] = None,
        api_domain: Optional[str] = None,
        model_version: Optional[str] = None,
        model_alias: Optional[str] = "spark_proxyllm",
        context_length: Optional[int] = 4096,
        executor: Optional[Executor] = None,
    ):
        if not model_version:
            model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
        if not api_base:
            if model_version == SPARK_DEFAULT_API_VERSION:
                api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
                domain = "generalv3"
            else:
                api_base = "ws://spark-api.xf-yun.com/v2.1/chat"
                domain = "generalv2"
            if not api_domain:
                api_domain = domain
        self._model = model
        self.default_model = self._model
        self._model_version = model_version
        self._api_base = api_base
        self._domain = api_domain
        self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
        self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
        self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")

        if not self._app_id:
            raise ValueError("app_id can't be empty")
        if not self._api_key:
            raise ValueError("api_key can't be empty")
        if not self._api_secret:
            raise ValueError("api_secret can't be empty")

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "SparkLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            app_id=model_params.proxy_api_app_id,
            api_key=model_params.proxy_api_key,
            api_secret=model_params.proxy_api_secret,
            api_base=model_params.proxy_api_base,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        request = self.local_covert_message(request, message_converter)
        messages = request.to_common_messages(support_system_role=False)
        request_id = request.context.request_id or "1"
        data = {
            "header": {"app_id": self._app_id, "uid": request_id},
            "parameter": {
                "chat": {
                    "domain": self._domain,
                    "random_threshold": 0.5,
                    "max_tokens": request.max_new_tokens,
                    "auditing": "default",
                    "temperature": request.temperature,
                }
            },
            "payload": {"message": {"text": messages}},
        }

        spark_api = SparkAPI(
            self._app_id, self._api_key, self._api_secret, self._api_base
        )
        request_url = spark_api.gen_url()
        try:
            for text in get_response(request_url, data):
                yield ModelOutput(text=text, error_code=0)
        except Exception as e:
            return ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )
@@ -1,79 +1,109 @@
import logging
from typing import List
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

logger = logging.getLogger(__name__)
def __convert_2_tongyi_messages(messages: List[ModelMessage]):
    chat_round = 0
    tongyi_messages = []

    last_usr_message = ""
    system_messages = []

    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            last_usr_message = message.content
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            last_ai_message = message.content
            tongyi_messages.append({"role": "user", "content": last_usr_message})
            tongyi_messages.append({"role": "assistant", "content": last_ai_message})
    if len(system_messages) > 0:
        if len(system_messages) < 2:
            tongyi_messages.insert(0, {"role": "system", "content": system_messages[0]})
            tongyi_messages.append({"role": "user", "content": last_usr_message})
        else:
            tongyi_messages.append({"role": "user", "content": system_messages[1]})
    else:
        last_message = messages[-1]
        if last_message.role == ModelMessageRoleType.HUMAN:
            tongyi_messages.append({"role": "user", "content": last_message.content})

    return tongyi_messages
def tongyi_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    import dashscope
    from dashscope import Generation

    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    proxy_api_key = model_params.proxy_api_key
    dashscope.api_key = proxy_api_key

    proxyllm_backend = model_params.proxyllm_backend
    if not proxyllm_backend:
        proxyllm_backend = Generation.Models.qwen_turbo  # qwen_turbo by default

    messages: List[ModelMessage] = params["messages"]
    convert_to_compatible_format = params.get("convert_to_compatible_format", False)

    if convert_to_compatible_format:
        history = __convert_2_tongyi_messages(messages)
    else:
        history = ModelMessage.to_openai_messages(messages)
    gen = Generation()
    res = gen.call(
        proxyllm_backend,
        messages=history,
        top_p=params.get("top_p", 0.8),
        stream=True,
        result_format="message",
    client: TongyiLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in client.sync_generate_stream(request):
        yield r

    for r in res:
        if r:
            if r["status_code"] == 200:
                content = r["output"]["choices"][0]["message"].get("content")
                yield content
            else:
                content = r["code"] + ":" + r["message"]
                yield content
class TongyiLLMClient(ProxyLLMClient):
    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        api_region: Optional[str] = None,
        model_alias: Optional[str] = "tongyi_proxyllm",
        context_length: Optional[int] = 4096,
        executor: Optional[Executor] = None,
    ):
        try:
            import dashscope
            from dashscope import Generation
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: dashscope "
                "Please install dashscope by command `pip install dashscope`"
            ) from exc
        if not model:
            model = Generation.Models.qwen_turbo
        if api_key:
            dashscope.api_key = api_key
        if api_region:
            dashscope.api_region = api_region
        self._model = model
        self.default_model = self._model

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "TongyiLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            api_key=model_params.proxy_api_key,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        from dashscope import Generation

        request = self.local_covert_message(request, message_converter)

        messages = request.to_common_messages()

        model = request.model or self._model
        try:
            gen = Generation()
            res = gen.call(
                model,
                messages=messages,
                top_p=0.8,
                stream=True,
                result_format="message",
            )
            for r in res:
                if r:
                    if r["status_code"] == 200:
                        content = r["output"]["choices"][0]["message"].get("content")
                        yield ModelOutput(text=content, error_code=0)
                    else:
                        content = r["code"] + ":" + r["message"]
                        yield ModelOutput(text=content, error_code=-1)
        except Exception as e:
            return ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )
@@ -1,12 +1,36 @@
import json
from typing import List
import logging
import os
from concurrent.futures import Executor
from typing import Iterator, List, Optional

import requests
from cachetools import TTLCache, cached

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import (
    MessageConverter,
    ModelMessage,
    ModelMessageRoleType,
    ModelOutput,
    ModelRequest,
    ModelRequestContext,
)
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

# https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
MODEL_VERSION_MAPPING = {
    "ERNIE-Bot-4.0": "completions_pro",
    "ERNIE-Bot-8K": "ernie_bot_8k",
    "ERNIE-Bot": "completions",
    "ERNIE-Bot-turbo": "eb-instant",
}

_DEFAULT_MODEL = "ERNIE-Bot"

logger = logging.getLogger(__name__)


@cached(TTLCache(1, 1800))
def _build_access_token(api_key: str, secret_key: str) -> str:
@@ -49,94 +73,128 @@ def _to_wenxin_messages(messages: List[ModelMessage]):
    return wenxin_messages, str_system_message
def __convert_2_wenxin_messages(messages: List[ModelMessage]):
    wenxin_messages = []

    last_usr_message = ""
    system_messages = []

    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            last_usr_message = message.content
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            last_ai_message = message.content
            wenxin_messages.append({"role": "user", "content": last_usr_message})
            wenxin_messages.append({"role": "assistant", "content": last_ai_message})

    # Build the last user message
    if len(system_messages) > 0:
        if len(system_messages) > 1:
            end_message = system_messages[-1]
        else:
            last_message = messages[-1]
            if last_message.role == ModelMessageRoleType.HUMAN:
                end_message = system_messages[-1] + "\n" + last_message.content
            else:
                end_message = system_messages[-1]
    else:
        last_message = messages[-1]
        end_message = last_message.content
    wenxin_messages.append({"role": "user", "content": end_message})
    str_system_message = system_messages[0] if len(system_messages) > 0 else ""
    return wenxin_messages, str_system_message
def wenxin_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    MODEL_VERSION = {
        "ERNIE-Bot": "completions",
        "ERNIE-Bot-turbo": "eb-instant",
    }
    client: WenxinLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in client.sync_generate_stream(request):
        yield r

    model_params = model.get_params()
    model_name = model_params.proxyllm_backend
    model_version = MODEL_VERSION.get(model_name)
    if not model_version:
        yield f"Unsupported model version {model_name}"

    proxy_api_key = model_params.proxy_api_key
    proxy_api_secret = model_params.proxy_api_secret
    access_token = _build_access_token(proxy_api_key, proxy_api_secret)

    headers = {"Content-Type": "application/json", "Accept": "application/json"}

    proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}"

    if not access_token:
        yield "Failed to get access token. Please set the correct api_key and secret key."

    messages: List[ModelMessage] = params["messages"]

    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
    if convert_to_compatible_format:
        history, system_message = __convert_2_wenxin_messages(messages)
    else:
        history, system_message = _to_wenxin_messages(messages)
    payload = {
        "messages": history,
        "system": system_message,
        "temperature": params.get("temperature"),
        "stream": True,
    }

    text = ""
    res = requests.post(proxy_server_url, headers=headers, json=payload, stream=True)
    print(f"Send request to {proxy_server_url} with real model {model_name}")
    for line in res.iter_lines():
        if line:
            if not line.startswith(b"data: "):
                error_message = line.decode("utf-8")
                yield error_message
class WenxinLLMClient(ProxyLLMClient):
    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        api_secret: Optional[str] = None,
        model_version: Optional[str] = None,
        model_alias: Optional[str] = "wenxin_proxyllm",
        context_length: Optional[int] = 8192,
        executor: Optional[Executor] = None,
    ):
        if not model:
            model = _DEFAULT_MODEL
        if not api_key:
            api_key = os.getenv("WEN_XIN_API_KEY")
        if not api_secret:
            api_secret = os.getenv("WEN_XIN_API_SECRET")
        if not model_version:
            if model:
                model_version = MODEL_VERSION_MAPPING.get(model)
            else:
                json_data = line.split(b": ", 1)[1]
                decoded_line = json_data.decode("utf-8")
                if decoded_line.lower() != "[DONE]".lower():
                    obj = json.loads(json_data)
                    if obj["result"] is not None:
                        content = obj["result"]
                        text += content
                        yield text
                model_version = os.getenv("WEN_XIN_MODEL_VERSION")
        if not api_key:
            raise ValueError("api_key can't be empty")
        if not api_secret:
            raise ValueError("api_secret can't be empty")
        if not model_version:
            raise ValueError("model_version can't be empty")
        self._model = model
        self._api_key = api_key
        self._api_secret = api_secret
        self._model_version = model_version
        self.default_model = self._model

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "WenxinLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            api_key=model_params.proxy_api_key,
            api_secret=model_params.proxy_api_secret,
            model_version=model_params.proxy_api_version,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        request = self.local_covert_message(request, message_converter)

        try:
            access_token = _build_access_token(self._api_key, self._api_secret)

            headers = {"Content-Type": "application/json", "Accept": "application/json"}

            proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{self._model_version}?access_token={access_token}"

            if not access_token:
                raise RuntimeError(
                    "Failed to get access token. Please set the correct api_key and secret key."
                )

            history, system_message = _to_wenxin_messages(request.get_messages())
            payload = {
                "messages": history,
                "system": system_message,
                "temperature": request.temperature,
                "stream": True,
            }

            text = ""
            res = requests.post(
                proxy_server_url, headers=headers, json=payload, stream=True
            )
            logger.info(
                f"Send request to {proxy_server_url} with real model {self._model}, model version {self._model_version}"
            )
            for line in res.iter_lines():
                if line:
                    if not line.startswith(b"data: "):
                        error_message = line.decode("utf-8")
                        yield ModelOutput(text=error_message, error_code=1)
                    else:
                        json_data = line.split(b": ", 1)[1]
                        decoded_line = json_data.decode("utf-8")
                        if decoded_line.lower() != "[DONE]".lower():
                            obj = json.loads(json_data)
                            if obj["result"] is not None:
                                content = obj["result"]
                                text += content
                                yield ModelOutput(text=text, error_code=0)
        except Exception as e:
            return ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )
@@ -1,46 +1,14 @@
from typing import List
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

CHATGLM_DEFAULT_MODEL = "chatglm_pro"


def __convert_2_zhipu_messages(messages: List[ModelMessage]):
    chat_round = 0
    wenxin_messages = []

    last_usr_message = ""
    system_messages = []

    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            last_usr_message = message.content
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            last_ai_message = message.content
            wenxin_messages.append({"role": "user", "content": last_usr_message})
            wenxin_messages.append({"role": "assistant", "content": last_ai_message})

    # Build the last user message
    if len(system_messages) > 0:
        if len(system_messages) > 1:
            end_message = system_messages[-1]
        else:
            last_message = messages[-1]
            if last_message.role == ModelMessageRoleType.HUMAN:
                end_message = system_messages[-1] + "\n" + last_message.content
            else:
                end_message = system_messages[-1]
    else:
        last_message = messages[-1]
        end_message = last_message.content
    wenxin_messages.append({"role": "user", "content": end_message})
    return wenxin_messages, system_messages
def zhipu_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
@@ -48,27 +16,93 @@ def zhipu_generate_stream(
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    # TODO proxy model use unified config?
    proxy_api_key = model_params.proxy_api_key
    proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend

    import zhipuai

    zhipuai.api_key = proxy_api_key

    messages: List[ModelMessage] = params["messages"]

    # TODO: Support convert_to_compatible_format config; Zhipu does not support system messages
    convert_to_compatible_format = params.get("convert_to_compatible_format", False)

    history, systems = __convert_2_zhipu_messages(messages)
    res = zhipuai.model_api.sse_invoke(
        model=proxyllm_backend,
        prompt=history,
    # convert_to_compatible_format = params.get("convert_to_compatible_format", False)
    # history, systems = __convert_2_zhipu_messages(messages)
    client: ZhipuLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        top_p=params.get("top_p"),
        incremental=False,
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in res.events():
        if r.event == "add":
            yield r.data
    for r in client.sync_generate_stream(request):
        yield r
class ZhipuLLMClient(ProxyLLMClient):
    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        model_alias: Optional[str] = "zhipu_proxyllm",
        context_length: Optional[int] = 8192,
        executor: Optional[Executor] = None,
    ):
        try:
            import zhipuai

        except ImportError as exc:
            raise ValueError(
                "Could not import python package: zhipuai "
                "Please install zhipuai by command `pip install zhipuai`"
            ) from exc
        if not model:
            model = CHATGLM_DEFAULT_MODEL
        if api_key:
            zhipuai.api_key = api_key
        self._model = model
        self.default_model = self._model

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "ZhipuLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            api_key=model_params.proxy_api_key,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        import zhipuai

        request = self.local_covert_message(request, message_converter)

        messages = request.to_common_messages(support_system_role=False)

        model = request.model or self._model
        try:
            res = zhipuai.model_api.sse_invoke(
                model=model,
                prompt=messages,
                temperature=request.temperature,
                # top_p=params.get("top_p"),
                incremental=False,
            )
            for r in res.events():
                if r.event == "add":
                    yield ModelOutput(text=r.data, error_code=0)
                elif r.event == "error":
                    yield ModelOutput(text=r.data, error_code=1)
        except Exception as e:
            return ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )