refactor: Refactor proxy LLM (#1064)

Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions

View File

@@ -13,6 +13,7 @@ BAICHUAN_DEFAULT_MODEL = "Baichuan2-Turbo-192k"
def baichuan_generate_stream(
model: ProxyModel, tokenizer=None, params=None, device=None, context_len=4096
):
# TODO: Support new Baichuan ProxyLLMClient
url = "https://api.baichuan-ai.com/v1/chat/completions"
model_params = model.get_params()

View File

@@ -9,6 +9,7 @@ from dbgpt.model.proxy.llms.proxy_model import ProxyModel
def bard_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
# TODO: Support new bard ProxyLLMClient
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")

View File

@@ -1,259 +1,231 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import importlib.metadata as metadata
import logging
import os
from typing import List
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union
import httpx
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import (
MessageConverter,
ModelMetadata,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
if TYPE_CHECKING:
from httpx._types import ProxiesTypes
from openai import AsyncAzureOpenAI, AsyncOpenAI
ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
logger = logging.getLogger(__name__)
def _initialize_openai(params: ProxyModelParameters):
try:
import openai
except ImportError as exc:
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai` "
) from exc
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
api_base = params.proxy_api_base or os.getenv(
"OPENAI_API_TYPE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = params.proxy_api_key or os.getenv(
"OPENAI_API_KEY",
os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
)
api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")
if not api_base and params.proxy_server_url:
# Adapt previous proxy_server_url configuration
api_base = params.proxy_server_url.split("/chat/completions")[0]
if api_type:
openai.api_type = api_type
if api_base:
openai.api_base = api_base
if api_key:
openai.api_key = api_key
if api_version:
openai.api_version = api_version
if params.http_proxy:
openai.proxy = params.http_proxy
openai_params = {
"api_type": api_type,
"api_base": api_base,
"api_version": api_version,
"proxy": params.http_proxy,
}
return openai_params
def _initialize_openai_v1(params: ProxyModelParameters):
try:
from openai import OpenAI
except ImportError as exc:
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai"
)
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
base_url = params.proxy_api_base or os.getenv(
"OPENAI_API_BASE",
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
)
api_key = params.proxy_api_key or os.getenv(
"OPENAI_API_KEY",
os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
)
api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")
if not base_url and params.proxy_server_url:
# Adapt previous proxy_server_url configuration
base_url = params.proxy_server_url.split("/chat/completions")[0]
proxies = params.http_proxy
openai_params = {
"api_key": api_key,
"base_url": base_url,
}
return openai_params, api_type, api_version, proxies
def __convert_2_gpt_messages(messages: List[ModelMessage]):
gpt_messages = []
last_usr_message = ""
system_messages = []
# TODO: We can't change the message order at this low level
for message in messages:
if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
last_usr_message = message.content
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI or message.role == "assistant":
last_ai_message = message.content
gpt_messages.append({"role": "user", "content": last_usr_message})
gpt_messages.append({"role": "assistant", "content": last_ai_message})
if len(system_messages) > 0:
if len(system_messages) < 2:
gpt_messages.insert(0, {"role": "system", "content": system_messages[0]})
gpt_messages.append({"role": "user", "content": last_usr_message})
else:
gpt_messages.append({"role": "user", "content": system_messages[1]})
else:
last_message = messages[-1]
if last_message.role == ModelMessageRoleType.HUMAN:
gpt_messages.append({"role": "user", "content": last_message.content})
return gpt_messages
def _build_request(model: ProxyModel, params):
model_params = model.get_params()
logger.info(f"Model: {model}, model_params: {model_params}")
messages: List[ModelMessage] = params["messages"]
# history = __convert_2_gpt_messages(messages)
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
history = ModelMessage.to_openai_messages(
messages, convert_to_compatible_format=convert_to_compatible_format
)
payloads = {
"temperature": params.get("temperature"),
"max_tokens": params.get("max_new_tokens"),
"stream": True,
}
proxyllm_backend = model_params.proxyllm_backend
if metadata.version("openai") >= "1.0.0":
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
model_params
)
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend
else:
openai_params = _initialize_openai(model_params)
if openai_params["api_type"] == "azure":
# engine = "deployment_name".
proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
payloads["engine"] = proxyllm_backend
else:
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend
logger.info(f"Send request to real model {proxyllm_backend}")
return history, payloads
def chatgpt_generate_stream(
async def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
if metadata.version("openai") >= "1.0.0":
model_params = model.get_params()
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
model_params
client: OpenAILLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
async for r in client.generate_stream(request):
yield r
class OpenAILLMClient(ProxyLLMClient):
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
model: Optional[str] = None,
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "chatgpt_proxyllm",
context_length: Optional[int] = 8192,
openai_client: Optional["ClientType"] = None,
openai_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
try:
import openai
except ImportError as exc:
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai"
) from exc
self._openai_version = metadata.version("openai")
self._openai_less_then_v1 = not self._openai_version >= "1.0.0"
self._init_params = OpenAIParameters(
api_type=api_type,
api_base=api_base,
api_key=api_key,
api_version=api_version,
proxies=proxies,
full_url=kwargs.get("full_url"),
)
history, payloads = _build_request(model, params)
if api_type == "azure":
from openai import AzureOpenAI
client = AzureOpenAI(
api_key=openai_params["api_key"],
api_version=api_version,
azure_endpoint=openai_params["base_url"],
http_client=httpx.Client(proxies=proxies),
self._model = model
self._proxies = proxies
self._timeout = timeout
self._model_alias = model_alias
self._context_length = context_length
self._api_type = api_type
self._client = openai_client
self._openai_kwargs = openai_kwargs or {}
super().__init__(model_names=[model_alias], context_length=context_length)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "OpenAILLMClient":
return cls(
api_key=model_params.proxy_api_key,
api_base=model_params.proxy_api_base,
api_type=model_params.proxy_api_type,
api_version=model_params.proxy_api_version,
model=model_params.proxyllm_backend,
proxies=model_params.http_proxy,
model_alias=model_params.model_name,
context_length=max(model_params.max_context_size, 8192),
full_url=model_params.proxy_server_url,
)
@property
def client(self) -> ClientType:
if self._openai_less_then_v1:
raise ValueError(
"Current model (Load by OpenAILLMClient) require openai.__version__>=1.0.0"
)
else:
from openai import OpenAI
if self._client is None:
from dbgpt.model.utils.chatgpt_utils import _build_openai_client
client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies))
res = client.chat.completions.create(messages=history, **payloads)
self._api_type, self._client = _build_openai_client(
init_params=self._init_params
)
return self._client
@property
def default_model(self) -> str:
model = self._model
if not model:
model = "gpt-35-turbo" if self._api_type == "azure" else "gpt-3.5-turbo"
return model
def _build_request(
self, request: ModelRequest, stream: Optional[bool] = False
) -> Dict[str, Any]:
payload = {"stream": stream}
model = request.model or self.default_model
if self._openai_less_then_v1 and self._api_type == "azure":
payload["engine"] = model
else:
payload["model"] = model
# Apply openai kwargs
for k, v in self._openai_kwargs.items():
payload[k] = v
if request.temperature:
payload["temperature"] = request.temperature
if request.max_new_tokens:
payload["max_tokens"] = request.max_new_tokens
return payload
async def generate(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelOutput:
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages()
payload = self._build_request(request)
logger.info(
f"Send request to openai({self._openai_version}), payload: {payload}\n\n messages:\n{messages}"
)
try:
if self._openai_less_then_v1:
return await self.generate_less_then_v1(messages, payload)
else:
return await self.generate_v1(messages, payload)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages()
payload = self._build_request(request, stream=True)
logger.info(
f"Send request to openai({self._openai_version}), payload: {payload}\n\n messages:\n{messages}"
)
if self._openai_less_then_v1:
async for r in self.generate_stream_less_then_v1(messages, payload):
yield r
else:
async for r in self.generate_stream_v1(messages, payload):
yield r
async def generate_v1(
self, messages: List[Dict[str, Any]], payload: Dict[str, Any]
) -> ModelOutput:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
)
text = chat_completion.choices[0].message.content
usage = chat_completion.usage.dict()
return ModelOutput(text=text, error_code=0, usage=usage)
async def generate_less_then_v1(
self, messages: List[Dict[str, Any]], payload: Dict[str, Any]
) -> ModelOutput:
import openai
chat_completion = await openai.ChatCompletion.acreate(
messages=messages, **payload
)
text = chat_completion.choices[0].message.content
usage = chat_completion.usage.to_dict()
return ModelOutput(text=text, error_code=0, usage=usage)
async def generate_stream_v1(
self, messages: List[Dict[str, Any]], payload: Dict[str, Any]
) -> AsyncIterator[ModelOutput]:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
)
text = ""
for r in res:
# logger.info(str(r))
# Azure OpenAI responses may have an empty choices body in the first chunk;
# skip it to avoid an index-out-of-range error.
async for r in chat_completion:
if len(r.choices) == 0:
continue
if r.choices[0].delta.content is not None:
content = r.choices[0].delta.content
text += content
yield text
yield ModelOutput(text=text, error_code=0)
else:
async def generate_stream_less_then_v1(
self, messages: List[Dict[str, Any]], payload: Dict[str, Any]
) -> AsyncIterator[ModelOutput]:
import openai
history, payloads = _build_request(model, params)
res = openai.ChatCompletion.create(messages=history, **payloads)
text = ""
for r in res:
if len(r.choices) == 0:
continue
if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"]
text += content
yield text
async def async_chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
if metadata.version("openai") >= "1.0.0":
model_params = model.get_params()
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
model_params
)
history, payloads = _build_request(model, params)
if api_type == "azure":
from openai import AsyncAzureOpenAI
client = AsyncAzureOpenAI(
api_key=openai_params["api_key"],
api_version=api_version,
azure_endpoint=openai_params["base_url"],
http_client=httpx.AsyncClient(proxies=proxies),
)
else:
from openai import AsyncOpenAI
client = AsyncOpenAI(
**openai_params, http_client=httpx.AsyncClient(proxies=proxies)
)
res = await client.chat.completions.create(messages=history, **payloads)
text = ""
for r in res:
if not r.get("choices"):
continue
if r.choices[0].delta.content is not None:
content = r.choices[0].delta.content
text += content
yield text
else:
import openai
history, payloads = _build_request(model, params)
res = await openai.ChatCompletion.acreate(messages=history, **payloads)
res = await openai.ChatCompletion.acreate(messages=messages, **payload)
text = ""
async for r in res:
if not r.get("choices"):
@@ -261,4 +233,21 @@ async def async_chatgpt_generate_stream(
if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"]
text += content
yield text
yield ModelOutput(text=text, error_code=0)
async def models(self) -> List[ModelMetadata]:
model_metadata = ModelMetadata(
model=self._model_alias,
context_length=await self.get_context_length(),
)
return [model_metadata]
async def get_context_length(self) -> int:
"""Get the context length of the model.
Returns:
int: The context length.
# TODO: This is a temporary solution. We should have a better way to get the
# context length, e.g. query the real context length from the OpenAI API.
"""
return self._context_length
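Below is a minimal usage sketch of the new async client (illustrative, not part of this commit): it assumes the OpenAILLMClient defined above is in scope, that the OpenAI API key is supplied via the usual OPENAI_API_KEY environment variable, and that ModelRequest.build_request accepts the same arguments used by chatgpt_generate_stream above.

import asyncio

from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest, ModelRequestContext


async def _demo():
    # Hypothetical wiring; credentials are resolved through _build_openai_client
    # (see the `client` property above).
    client = OpenAILLMClient(model="gpt-3.5-turbo")
    context = ModelRequestContext(stream=True)
    request = ModelRequest.build_request(
        client.default_model,
        messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
        temperature=0.7,
        context=context,
        max_new_tokens=64,
    )
    async for output in client.generate_stream(request):
        print(output.text)


asyncio.run(_demo())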

View File

@@ -1,72 +1,54 @@
from typing import Any, Dict, List, Tuple
import os
from concurrent.futures import Executor
from typing import Any, Dict, Iterator, List, Optional, Tuple
from dbgpt.core.interface.message import ModelMessage, parse_model_messages
from dbgpt.core import (
MessageConverter,
ModelMessage,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core.interface.message import parse_model_messages
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
GEMINI_DEFAULT_MODEL = "gemini-pro"
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
]
def gemini_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
"""Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
# TODO: should proxy models use a unified config?
proxy_api_key = model_params.proxy_api_key
proxyllm_backend = GEMINI_DEFAULT_MODEL or model_params.proxyllm_backend
generation_config = {
"temperature": 0.7,
"top_p": 1,
"top_k": 1,
"max_output_tokens": 2048,
}
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
},
]
import google.generativeai as genai
if model_params.proxy_api_base:
from google.api_core import client_options
client_opts = client_options.ClientOptions(
api_endpoint=model_params.proxy_api_base
)
genai.configure(
api_key=proxy_api_key, transport="rest", client_options=client_opts
)
else:
genai.configure(api_key=proxy_api_key)
model = genai.GenerativeModel(
model_name=proxyllm_backend,
generation_config=generation_config,
safety_settings=safety_settings,
client: GeminiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
messages: List[ModelMessage] = params["messages"]
user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
chat = model.start_chat(history=gemini_hist)
response = chat.send_message(user_prompt, stream=True)
text = ""
for chunk in response:
text += chunk.text
print(text)
yield text
for r in client.sync_generate_stream(request):
yield r
def _transform_to_gemini_messages(
@@ -97,12 +79,104 @@ def _transform_to_gemini_messages(
{"role": "model", "parts": {"text": "Hi there!"}},
]
"""
# TODO: raise an error if messages contain a system message
user_prompt, system_messages, history_messages = parse_model_messages(messages)
if system_messages:
user_prompt = "".join(system_messages) + "\n" + user_prompt
raise ValueError("Gemini does not support system role")
gemini_hist = []
if history_messages:
for user_message, model_message in history_messages:
gemini_hist.append({"role": "user", "parts": {"text": user_message}})
gemini_hist.append({"role": "model", "parts": {"text": model_message}})
return user_prompt, gemini_hist
class GeminiLLMClient(ProxyLLMClient):
def __init__(
self,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_alias: Optional[str] = "gemini_proxyllm",
context_length: Optional[int] = 8192,
executor: Optional[Executor] = None,
):
try:
import google.generativeai as genai
except ImportError as exc:
raise ValueError(
"Could not import python package: generativeai "
"Please install dashscope by command `pip install google-generativeai"
) from exc
if not model:
model = GEMINI_DEFAULT_MODEL
self._api_key = api_key if api_key else os.getenv("GEMINI_PROXY_API_KEY")
self._api_base = api_base if api_base else os.getenv("GEMINI_PROXY_API_BASE")
self._model = model
self.default_model = self._model
if not self._api_key:
raise RuntimeError("api_key can't be empty")
if self._api_base:
from google.api_core import client_options
client_opts = client_options.ClientOptions(api_endpoint=self._api_base)
genai.configure(
api_key=self._api_key, transport="rest", client_options=client_opts
)
else:
genai.configure(api_key=self._api_key)
super().__init__(
model_names=[model, model_alias],
context_length=context_length,
executor=executor,
)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "GeminiLLMClient":
return cls(
model=model_params.proxyllm_backend,
api_key=model_params.proxy_api_key,
api_base=model_params.proxy_api_base,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
)
def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
request = self.local_covert_message(request, message_converter)
try:
import google.generativeai as genai
generation_config = {
"temperature": request.temperature,
"top_p": 1,
"top_k": 1,
"max_output_tokens": request.max_new_tokens,
}
model = genai.GenerativeModel(
model_name=self._model,
generation_config=generation_config,
safety_settings=safety_settings,
)
user_prompt, gemini_hist = _transform_to_gemini_messages(request.messages)
chat = model.start_chat(history=gemini_hist)
response = chat.send_message(user_prompt, stream=True)
text = ""
for chunk in response:
text += chunk.text
yield ModelOutput(text=text, error_code=0)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
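For reference, an illustrative example (not from this commit) of what _transform_to_gemini_messages produces for a short conversation; note that after this change a system message now raises ValueError instead of being merged into the prompt.

from dbgpt.core import ModelMessage, ModelMessageRoleType

history = [
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
    ModelMessage(role=ModelMessageRoleType.AI, content="Hi there!"),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is DB-GPT?"),
]
user_prompt, gemini_hist = _transform_to_gemini_messages(history)
# user_prompt == "What is DB-GPT?"
# gemini_hist == [
#     {"role": "user", "parts": {"text": "Hello"}},
#     {"role": "model", "parts": {"text": "Hi there!"}},
# ]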

View File

@@ -4,6 +4,7 @@ import logging
from typing import TYPE_CHECKING, List, Optional, Union
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
if TYPE_CHECKING:
@@ -13,9 +14,14 @@ logger = logging.getLogger(__name__)
class ProxyModel:
def __init__(self, model_params: ProxyModelParameters) -> None:
def __init__(
self,
model_params: ProxyModelParameters,
proxy_llm_client: Optional[ProxyLLMClient] = None,
) -> None:
self._model_params = model_params
self._tokenizer = ProxyTokenizerWrapper()
self.proxy_llm_client = proxy_llm_client
def get_params(self) -> ProxyModelParameters:
return self._model_params
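The sketch below outlines the wiring this refactor enables (assumptions noted in comments; the OpenAILLMClient import and the parameter values are illustrative and not shown in this diff): a worker builds the concrete ProxyLLMClient from the legacy ProxyModelParameters via new_client() and passes it to ProxyModel, so the old *_generate_stream functions only delegate to model.proxy_llm_client.

from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

# Illustrative parameters; real values come from the model worker configuration.
model_params = ProxyModelParameters(
    model_name="chatgpt_proxyllm",
    model_path="chatgpt_proxyllm",
    proxy_server_url="https://api.openai.com/v1/chat/completions",
    proxy_api_key="YOUR_API_KEY",
    proxyllm_backend="gpt-3.5-turbo",
)

# Each ProxyLLMClient subclass exposes new_client(), so legacy parameters keep working.
client = OpenAILLMClient.new_client(model_params)  # class shown in the chatgpt diff above
model = ProxyModel(model_params, proxy_llm_client=client)
# chatgpt_generate_stream(...) and the other adapters now simply forward requests to
# model.proxy_llm_client instead of talking to the provider API directly.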

View File

@@ -2,15 +2,19 @@ import base64
import hashlib
import hmac
import json
import os
from concurrent.futures import Executor
from datetime import datetime
from time import mktime
from typing import List
from typing import Iterator, Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
from websockets.sync.client import connect
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v3"
@@ -34,63 +38,21 @@ def checklen(text):
def spark_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
model_params = model.get_params()
proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
proxy_api_key = model_params.proxy_api_key
proxy_api_secret = model_params.proxy_api_secret
proxy_app_id = model_params.proxy_api_app_id
if proxy_api_version == SPARK_DEFAULT_API_VERSION:
url = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
else:
url = "ws://spark-api.xf-yun.com/v2.1/chat"
domain = "generalv2"
messages: List[ModelMessage] = params["messages"]
last_user_input = None
for index in range(len(messages) - 1, -1, -1):
print(f"index: {index}")
if messages[index].role == ModelMessageRoleType.HUMAN:
last_user_input = {"role": "user", "content": messages[index].content}
del messages[index]
break
# TODO: Support convert_to_compatible_format config
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
history = []
# Add history conversation
for message in messages:
# There is no role for system in spark LLM
if message.role in (ModelMessageRoleType.HUMAN, ModelMessageRoleType.SYSTEM):
history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
else:
pass
question = checklen(history + [last_user_input])
print('last_user_input.get("content")', last_user_input.get("content"))
data = {
"header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": context_len,
"auditing": "default",
"temperature": params.get("temperature"),
}
},
"payload": {"message": {"text": question}},
}
spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
request_url = spark_api.gen_url()
return get_response(request_url, data)
client: SparkLLMClient = model.proxy_llm_client
context = ModelRequestContext(
stream=True,
user_name=params.get("user_name"),
request_id=params.get("request_id"),
)
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
for r in client.sync_generate_stream(request):
yield r
def get_response(request_url, data):
@@ -107,8 +69,8 @@ def get_response(request_url, data):
result += text[0]["content"]
if choices.get("status") == 2:
break
except Exception:
break
except Exception as e:
raise e
yield result
@@ -155,3 +117,103 @@ class SparkAPI:
url = self.spark_url + "?" + urlencode(v)
# Print the URL used when establishing the connection here; when following this demo,
# you can uncomment the print above to check whether the URL generated with the same
# parameters matches the one produced by your own code.
return url
class SparkLLMClient(ProxyLLMClient):
def __init__(
self,
model: Optional[str] = None,
app_id: Optional[str] = None,
api_key: Optional[str] = None,
api_secret: Optional[str] = None,
api_base: Optional[str] = None,
api_domain: Optional[str] = None,
model_version: Optional[str] = None,
model_alias: Optional[str] = "spark_proxyllm",
context_length: Optional[int] = 4096,
executor: Optional[Executor] = None,
):
if not model_version:
model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
if not api_base:
if model_version == SPARK_DEFAULT_API_VERSION:
api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
else:
api_base = "ws://spark-api.xf-yun.com/v2.1/chat"
domain = "generalv2"
if not api_domain:
api_domain = domain
self._model = model
self.default_model = self._model
self._model_version = model_version
self._api_base = api_base
self._domain = api_domain
self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")
if not self._app_id:
raise ValueError("app_id can't be empty")
if not self._api_key:
raise ValueError("api_key can't be empty")
if not self._api_secret:
raise ValueError("api_secret can't be empty")
super().__init__(
model_names=[model, model_alias],
context_length=context_length,
executor=executor,
)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "SparkLLMClient":
return cls(
model=model_params.proxyllm_backend,
app_id=model_params.proxy_api_app_id,
api_key=model_params.proxy_api_key,
api_secret=model_params.proxy_api_secret,
api_base=model_params.proxy_api_base,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
)
def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages(support_system_role=False)
request_id = request.context.request_id or "1"
data = {
"header": {"app_id": self._app_id, "uid": request_id},
"parameter": {
"chat": {
"domain": self._domain,
"random_threshold": 0.5,
"max_tokens": request.max_new_tokens,
"auditing": "default",
"temperature": request.temperature,
}
},
"payload": {"message": {"text": messages}},
}
spark_api = SparkAPI(
self._app_id, self._api_key, self._api_secret, self._api_base
)
request_url = spark_api.gen_url()
try:
for text in get_response(request_url, data):
yield ModelOutput(text=text, error_code=0)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
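A usage sketch for the synchronous Spark path (illustrative; it assumes the XUNFEI_SPARK_APPID, XUNFEI_SPARK_API_KEY and XUNFEI_SPARK_API_SECRET environment variables are set as read in __init__ above, and the model name passed here is only an identifier since the endpoint is selected by model_version):

from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest, ModelRequestContext

client = SparkLLMClient(model="spark_proxyllm", model_version="v3")
context = ModelRequestContext(stream=True, request_id="demo-request-1")
request = ModelRequest.build_request(
    client.default_model,
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
    temperature=0.5,
    context=context,
    max_new_tokens=512,
)
for output in client.sync_generate_stream(request):
    print(output.text)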

View File

@@ -1,79 +1,109 @@
import logging
from typing import List
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
logger = logging.getLogger(__name__)
def __convert_2_tongyi_messages(messages: List[ModelMessage]):
chat_round = 0
tongyi_messages = []
last_usr_message = ""
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
last_usr_message = message.content
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
last_ai_message = message.content
tongyi_messages.append({"role": "user", "content": last_usr_message})
tongyi_messages.append({"role": "assistant", "content": last_ai_message})
if len(system_messages) > 0:
if len(system_messages) < 2:
tongyi_messages.insert(0, {"role": "system", "content": system_messages[0]})
tongyi_messages.append({"role": "user", "content": last_usr_message})
else:
tongyi_messages.append({"role": "user", "content": system_messages[1]})
else:
last_message = messages[-1]
if last_message.role == ModelMessageRoleType.HUMAN:
tongyi_messages.append({"role": "user", "content": last_message.content})
return tongyi_messages
def tongyi_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
import dashscope
from dashscope import Generation
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
proxy_api_key = model_params.proxy_api_key
dashscope.api_key = proxy_api_key
proxyllm_backend = model_params.proxyllm_backend
if not proxyllm_backend:
proxyllm_backend = Generation.Models.qwen_turbo # By Default qwen_turbo
messages: List[ModelMessage] = params["messages"]
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
if convert_to_compatible_format:
history = __convert_2_tongyi_messages(messages)
else:
history = ModelMessage.to_openai_messages(messages)
gen = Generation()
res = gen.call(
proxyllm_backend,
messages=history,
top_p=params.get("top_p", 0.8),
stream=True,
result_format="message",
client: TongyiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
for r in client.sync_generate_stream(request):
yield r
for r in res:
if r:
if r["status_code"] == 200:
content = r["output"]["choices"][0]["message"].get("content")
yield content
else:
content = r["code"] + ":" + r["message"]
yield content
class TongyiLLMClient(ProxyLLMClient):
def __init__(
self,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_region: Optional[str] = None,
model_alias: Optional[str] = "tongyi_proxyllm",
context_length: Optional[int] = 4096,
executor: Optional[Executor] = None,
):
try:
import dashscope
from dashscope import Generation
except ImportError as exc:
raise ValueError(
"Could not import python package: dashscope "
"Please install dashscope by command `pip install dashscope"
) from exc
if not model:
model = Generation.Models.qwen_turbo
if api_key:
dashscope.api_key = api_key
if api_region:
dashscope.api_region = api_region
self._model = model
self.default_model = self._model
super().__init__(
model_names=[model, model_alias],
context_length=context_length,
executor=executor,
)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "TongyiLLMClient":
return cls(
model=model_params.proxyllm_backend,
api_key=model_params.proxy_api_key,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
)
def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
from dashscope import Generation
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages()
model = request.model or self._model
try:
gen = Generation()
res = gen.call(
model,
messages=messages,
top_p=0.8,
stream=True,
result_format="message",
)
for r in res:
if r:
if r["status_code"] == 200:
content = r["output"]["choices"][0]["message"].get("content")
yield ModelOutput(text=content, error_code=0)
else:
content = r["code"] + ":" + r["message"]
yield ModelOutput(text=content, error_code=-1)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)

View File

@@ -1,12 +1,36 @@
import json
from typing import List
import logging
import os
from concurrent.futures import Executor
from typing import Iterator, List, Optional
import requests
from cachetools import TTLCache, cached
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import (
MessageConverter,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
# https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
MODEL_VERSION_MAPPING = {
"ERNIE-Bot-4.0": "completions_pro",
"ERNIE-Bot-8K": "ernie_bot_8k",
"ERNIE-Bot": "completions",
"ERNIE-Bot-turbo": "eb-instant",
}
_DEFAULT_MODEL = "ERNIE-Bot"
logger = logging.getLogger(__name__)
@cached(TTLCache(1, 1800))
def _build_access_token(api_key: str, secret_key: str) -> str:
@@ -49,94 +73,128 @@ def _to_wenxin_messages(messages: List[ModelMessage]):
return wenxin_messages, str_system_message
def __convert_2_wenxin_messages(messages: List[ModelMessage]):
wenxin_messages = []
last_usr_message = ""
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
last_usr_message = message.content
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
last_ai_message = message.content
wenxin_messages.append({"role": "user", "content": last_usr_message})
wenxin_messages.append({"role": "assistant", "content": last_ai_message})
# build the last user message
if len(system_messages) > 0:
if len(system_messages) > 1:
end_message = system_messages[-1]
else:
last_message = messages[-1]
if last_message.role == ModelMessageRoleType.HUMAN:
end_message = system_messages[-1] + "\n" + last_message.content
else:
end_message = system_messages[-1]
else:
last_message = messages[-1]
end_message = last_message.content
wenxin_messages.append({"role": "user", "content": end_message})
str_system_message = system_messages[0] if len(system_messages) > 0 else ""
return wenxin_messages, str_system_message
def wenxin_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
MODEL_VERSION = {
"ERNIE-Bot": "completions",
"ERNIE-Bot-turbo": "eb-instant",
}
client: WenxinLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
for r in client.sync_generate_stream(request):
yield r
model_params = model.get_params()
model_name = model_params.proxyllm_backend
model_version = MODEL_VERSION.get(model_name)
if not model_version:
yield f"Unsupport model version {model_name}"
proxy_api_key = model_params.proxy_api_key
proxy_api_secret = model_params.proxy_api_secret
access_token = _build_access_token(proxy_api_key, proxy_api_secret)
headers = {"Content-Type": "application/json", "Accept": "application/json"}
proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}"
if not access_token:
yield "Failed to get access token. please set the correct api_key and secret key."
messages: List[ModelMessage] = params["messages"]
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
if convert_to_compatible_format:
history, system_message = __convert_2_wenxin_messages(messages)
else:
history, system_message = _to_wenxin_messages(messages)
payload = {
"messages": history,
"system": system_message,
"temperature": params.get("temperature"),
"stream": True,
}
text = ""
res = requests.post(proxy_server_url, headers=headers, json=payload, stream=True)
print(f"Send request to {proxy_server_url} with real model {model_name}")
for line in res.iter_lines():
if line:
if not line.startswith(b"data: "):
error_message = line.decode("utf-8")
yield error_message
class WenxinLLMClient(ProxyLLMClient):
def __init__(
self,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_secret: Optional[str] = None,
model_version: Optional[str] = None,
model_alias: Optional[str] = "wenxin_proxyllm",
context_length: Optional[int] = 8192,
executor: Optional[Executor] = None,
):
if not model:
model = _DEFAULT_MODEL
if not api_key:
api_key = os.getenv("WEN_XIN_API_KEY")
if not api_secret:
api_secret = os.getenv("WEN_XIN_API_SECRET")
if not model_version:
if model:
model_version = MODEL_VERSION_MAPPING.get(model)
else:
json_data = line.split(b": ", 1)[1]
decoded_line = json_data.decode("utf-8")
if decoded_line.lower() != "[DONE]".lower():
obj = json.loads(json_data)
if obj["result"] is not None:
content = obj["result"]
text += content
yield text
model_version = os.getenv("WEN_XIN_MODEL_VERSION")
if not api_key:
raise ValueError("api_key can't be empty")
if not api_secret:
raise ValueError("api_secret can't be empty")
if not model_version:
raise ValueError("model_version can't be empty")
self._model = model
self._api_key = api_key
self._api_secret = api_secret
self._model_version = model_version
self.default_model = self._model
super().__init__(
model_names=[model, model_alias],
context_length=context_length,
executor=executor,
)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "WenxinLLMClient":
return cls(
model=model_params.proxyllm_backend,
api_key=model_params.proxy_api_key,
api_secret=model_params.proxy_api_secret,
model_version=model_params.proxy_api_version,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
)
def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
request = self.local_covert_message(request, message_converter)
try:
access_token = _build_access_token(self._api_key, self._api_secret)
headers = {"Content-Type": "application/json", "Accept": "application/json"}
proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{self._model_version}?access_token={access_token}"
if not access_token:
raise RuntimeError(
"Failed to get access token. please set the correct api_key and secret key."
)
history, system_message = _to_wenxin_messages(request.get_messages())
payload = {
"messages": history,
"system": system_message,
"temperature": request.temperature,
"stream": True,
}
text = ""
res = requests.post(
proxy_server_url, headers=headers, json=payload, stream=True
)
logger.info(
f"Send request to {proxy_server_url} with real model {self._model}, model version {self._model_version}"
)
for line in res.iter_lines():
if line:
if not line.startswith(b"data: "):
error_message = line.decode("utf-8")
yield ModelOutput(text=error_message, error_code=1)
else:
json_data = line.split(b": ", 1)[1]
decoded_line = json_data.decode("utf-8")
if decoded_line.lower() != "[DONE]".lower():
obj = json.loads(json_data)
if obj["result"] is not None:
content = obj["result"]
text += content
yield ModelOutput(text=text, error_code=0)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
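For reference, a small sketch of how the new MODEL_VERSION_MAPPING feeds the request URL built in sync_generate_stream above (the access token shown is a placeholder, not a real credential):

model = "ERNIE-Bot-4.0"
model_version = MODEL_VERSION_MAPPING[model]  # -> "completions_pro"
access_token = "<value returned by _build_access_token(api_key, api_secret)>"
url = (
    "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
    f"{model_version}?access_token={access_token}"
)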

View File

@@ -1,46 +1,14 @@
from typing import List
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
def __convert_2_zhipu_messages(messages: List[ModelMessage]):
chat_round = 0
wenxin_messages = []
last_usr_message = ""
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
last_usr_message = message.content
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
last_ai_message = message.content
wenxin_messages.append({"role": "user", "content": last_usr_message})
wenxin_messages.append({"role": "assistant", "content": last_ai_message})
# build the last user message
if len(system_messages) > 0:
if len(system_messages) > 1:
end_message = system_messages[-1]
else:
last_message = messages[-1]
if last_message.role == ModelMessageRoleType.HUMAN:
end_message = system_messages[-1] + "\n" + last_message.content
else:
end_message = system_messages[-1]
else:
last_message = messages[-1]
end_message = last_message.content
wenxin_messages.append({"role": "user", "content": end_message})
return wenxin_messages, system_messages
def zhipu_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
@@ -48,27 +16,93 @@ def zhipu_generate_stream(
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
# TODO: should proxy models use a unified config?
proxy_api_key = model_params.proxy_api_key
proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
import zhipuai
zhipuai.api_key = proxy_api_key
messages: List[ModelMessage] = params["messages"]
# TODO: Support convert_to_compatible_format config; zhipu does not support system messages
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
history, systems = __convert_2_zhipu_messages(messages)
res = zhipuai.model_api.sse_invoke(
model=proxyllm_backend,
prompt=history,
# convert_to_compatible_format = params.get("convert_to_compatible_format", False)
# history, systems = __convert_2_zhipu_messages(messages)
client: ZhipuLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
top_p=params.get("top_p"),
incremental=False,
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
for r in res.events():
if r.event == "add":
yield r.data
for r in client.sync_generate_stream(request):
yield r
class ZhipuLLMClient(ProxyLLMClient):
def __init__(
self,
model: Optional[str] = None,
api_key: Optional[str] = None,
model_alias: Optional[str] = "zhipu_proxyllm",
context_length: Optional[int] = 8192,
executor: Optional[Executor] = None,
):
try:
import zhipuai
except ImportError as exc:
raise ValueError(
"Could not import python package: zhipuai "
"Please install dashscope by command `pip install zhipuai"
) from exc
if not model:
model = CHATGLM_DEFAULT_MODEL
if api_key:
zhipuai.api_key = api_key
self._model = model
self.default_model = self._model
super().__init__(
model_names=[model, model_alias],
context_length=context_length,
executor=executor,
)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "ZhipuLLMClient":
return cls(
model=model_params.proxyllm_backend,
api_key=model_params.proxy_api_key,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
)
def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
import zhipuai
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages(support_system_role=False)
model = request.model or self._model
try:
res = zhipuai.model_api.sse_invoke(
model=model,
prompt=messages,
temperature=request.temperature,
# top_p=params.get("top_p"),
incremental=False,
)
for r in res.events():
if r.event == "add":
yield ModelOutput(text=r.data, error_code=0)
elif r.event == "error":
yield ModelOutput(text=r.data, error_code=1)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)