feat(model): Support claude proxy models (#2155)

Author: Fangyin Cheng
Date: 2024-11-26 19:47:28 +08:00
Committed by: GitHub
parent 9d8673a02f
commit 61509dc5ea
20 changed files with 508 additions and 157 deletions
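
Example usage (editor's sketch, not part of the commit): the snippet below exercises the new Claude proxy client through the helpers added in this diff. The module path dbgpt.model.proxy.llms.claude, the placeholder API key, the message dict shape, and the asyncio scaffolding are assumptions for illustration; the class, function, and parameter names are taken from the changes below.

import asyncio

from dbgpt.model.proxy.llms.claude import ClaudeLLMClient  # assumed module path
from dbgpt.model.proxy.llms.proxy_model import parse_model_request


async def main() -> None:
    # Placeholder key; api_base falls back to ANTHROPIC_BASE_URL or the public endpoint.
    client = ClaudeLLMClient(api_key="sk-ant-...")
    params = {
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "temperature": 0.5,
        "max_new_tokens": 128,
    }
    request = parse_model_request(params, client.default_model, stream=False)
    output = await client.generate(request)
    print(output.text)


asyncio.run(main())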

View File

@@ -5,17 +5,11 @@ import logging
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union
from dbgpt.core import (
MessageConverter,
ModelMetadata,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
from dbgpt.util.i18n_utils import _
@@ -32,15 +26,7 @@ async def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: OpenAILLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r
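# --- Editor's sketch (not part of this commit's diff) -----------------------
# The simplification above is the pattern this commit applies across the proxy
# providers: every *_generate_stream handler now parses the raw params and
# delegates to its client. A handler for a hypothetical new provider would
# look like this (placeholder name; imports mirror those at the top of the file):
async def some_provider_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    client = model.proxy_llm_client
    request = parse_model_request(params, client.default_model, stream=True)
    async for r in client.generate_stream(request):
        yield r
# --- End editor's sketch -----------------------------------------------------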
@@ -191,6 +177,8 @@ class OpenAILLMClient(ProxyLLMClient):
payload["max_tokens"] = request.max_new_tokens
if request.stop:
payload["stop"] = request.stop
if request.top_p:
payload["top_p"] = request.top_p
return payload
async def generate(

View File

@@ -1,7 +1,275 @@
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
import logging
import os
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, cast
from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import (
ProxyLLMClient,
ProxyTokenizer,
TiktokenProxyTokenizer,
)
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
if TYPE_CHECKING:
from anthropic import AsyncAnthropic, ProxiesTypes
logger = logging.getLogger(__name__)
def claude_generate_stream(
async def claude_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
yield "claude LLM was not supported!"
) -> AsyncIterator[ModelOutput]:
client: ClaudeLLMClient = cast(ClaudeLLMClient, model.proxy_llm_client)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r
class ClaudeLLMClient(ProxyLLMClient):
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model: Optional[str] = None,
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "claude_proxyllm",
context_length: Optional[int] = 8192,
client: Optional["AsyncAnthropic"] = None,
claude_kwargs: Optional[Dict[str, Any]] = None,
proxy_tokenizer: Optional[ProxyTokenizer] = None,
):
try:
import anthropic
except ImportError as exc:
raise ValueError(
"Could not import python package: anthropic. "
"Please install it with `pip install anthropic`."
) from exc
if not model:
model = "claude-3-5-sonnet-20241022"
self._client = client
self._model = model
self._api_key = api_key
self._api_base = api_base or os.environ.get(
"ANTHROPIC_BASE_URL", "https://api.anthropic.com"
)
self._proxies = proxies
self._timeout = timeout
self._claude_kwargs = claude_kwargs or {}
self._model_alias = model_alias
self._proxy_tokenizer = proxy_tokenizer
super().__init__(
model_names=[model_alias],
context_length=context_length,
proxy_tokenizer=proxy_tokenizer,
)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "ClaudeProxyLLMClient":
return cls(
api_key=model_params.proxy_api_key,
api_base=model_params.proxy_api_base,
# api_type=model_params.proxy_api_type,
# api_version=model_params.proxy_api_version,
model=model_params.proxyllm_backend,
proxies=model_params.http_proxy,
model_alias=model_params.model_name,
context_length=max(model_params.max_context_size, 8192),
)
@property
def client(self) -> "AsyncAnthropic":
from anthropic import AsyncAnthropic
if self._client is None:
self._client = AsyncAnthropic(
api_key=self._api_key,
base_url=self._api_base,
proxies=self._proxies,
timeout=self._timeout,
)
return self._client
@property
def proxy_tokenizer(self) -> ProxyTokenizer:
if not self._proxy_tokenizer:
self._proxy_tokenizer = ClaudeProxyTokenizer(self.client)
return self._proxy_tokenizer
@property
def default_model(self) -> str:
"""Default model name"""
model = self._model
if not model:
model = "claude-3-5-sonnet-20241022"
return model
def _build_request(
self, request: ModelRequest, stream: Optional[bool] = False
) -> Dict[str, Any]:
payload = {"stream": stream}
model = request.model or self.default_model
payload["model"] = model
# Apply claude kwargs
for k, v in self._claude_kwargs.items():
payload[k] = v
if request.temperature:
payload["temperature"] = request.temperature
if request.max_new_tokens:
payload["max_tokens"] = request.max_new_tokens
if request.stop:
payload["stop"] = request.stop
if request.top_p:
payload["top_p"] = request.top_p
return payload
async def generate(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelOutput:
request = self.local_covert_message(request, message_converter)
messages, system_messages = request.split_messages()
payload = self._build_request(request)
logger.info(
f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
)
try:
if len(system_messages) > 1:
raise ValueError("Claude only supports single system message")
if system_messages:
payload["system"] = system_messages[0]
if "max_tokens" not in payload:
max_tokens = 1024
else:
max_tokens = payload["max_tokens"]
del payload["max_tokens"]
response = await self.client.messages.create(
max_tokens=max_tokens,
messages=messages,
**payload,
)
usage = None
finish_reason = response.stop_reason
if response.usage:
usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
}
response_content = response.content
if not response_content:
raise ValueError("Response content is empty")
return ModelOutput(
text=response_content[0].text,
error_code=0,
finish_reason=finish_reason,
usage=usage,
)
except Exception as e:
return ModelOutput(
text=f"**Claude Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
request = self.local_covert_message(request, message_converter)
messages, system_messages = request.split_messages()
payload = self._build_request(request, stream=True)
logger.info(
f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
)
try:
if len(system_messages) > 1:
raise ValueError("Claude only supports single system message")
if system_messages:
payload["system"] = system_messages[0]
if "max_tokens" not in payload:
max_tokens = 1024
else:
max_tokens = payload["max_tokens"]
del payload["max_tokens"]
if "stream" in payload:
del payload["stream"]
full_text = ""
async with self.client.messages.stream(
max_tokens=max_tokens,
messages=messages,
**payload,
) as stream:
async for text in stream.text_stream:
full_text += text
usage = {
"prompt_tokens": stream.current_message_snapshot.usage.input_tokens,
"completion_tokens": stream.current_message_snapshot.usage.output_tokens,
}
yield ModelOutput(text=full_text, error_code=0, usage=usage)
except Exception as e:
yield ModelOutput(
text=f"**Claude Generate Stream Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def models(self) -> List[ModelMetadata]:
model_metadata = ModelMetadata(
model=self._model_alias,
context_length=await self.get_context_length(),
)
return [model_metadata]
async def get_context_length(self) -> int:
"""Get the context length of the model.
Returns:
int: The context length.
# TODO: This is a temporary solution. We should have a better way to get the context length.
eg. get real context length from the openai api.
"""
return self.context_length
class ClaudeProxyTokenizer(ProxyTokenizer):
def __init__(self, client: "AsyncAnthropic", concurrency_limit: int = 10):
self.client = client
self.concurrency_limit = concurrency_limit
self._tiktoken_tokenizer = TiktokenProxyTokenizer()
def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
# Use tiktoken to count token in local environment
return self._tiktoken_tokenizer.count_token(model_name, prompts)
def support_async(self) -> bool:
return True
async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
"""Count token of given messages.
This is relying on the claude beta API, which is not available for some users.
"""
from dbgpt.util.chat_util import run_async_tasks
tasks = []
model_name = model_name or "claude-3-5-sonnet-20241022"
for prompt in prompts:
request = ModelRequest(
model=model_name, messages=[{"role": "user", "content": prompt}]
)
tasks.append(
self.client.beta.messages.count_tokens(
model=model_name,
messages=request.messages,
)
)
results = await run_async_tasks(tasks, self.concurrency_limit)
return results
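# --- Editor's sketch (not part of this commit's diff) -----------------------
# Streaming counterpart of the client defined above. The module path, the
# placeholder API key and the asyncio scaffolding are assumptions; method and
# parameter names come from ClaudeLLMClient above. generate_stream yields
# ModelOutput objects whose text field is the accumulated response so far.
import asyncio

from dbgpt.model.proxy.llms.claude import ClaudeLLMClient  # assumed module path
from dbgpt.model.proxy.llms.proxy_model import parse_model_request


async def stream_demo() -> None:
    client = ClaudeLLMClient(api_key="sk-ant-...")  # placeholder key
    params = {
        "messages": [{"role": "user", "content": "Name three SQL join types."}],
        "max_new_tokens": 256,
    }
    request = parse_model_request(params, client.default_model, stream=True)
    last = None
    async for output in client.generate_stream(request):
        last = output
    if last is not None:
        print(last.text)
        print(last.usage)  # prompt/completion token counts reported by the stream


asyncio.run(stream_demo())
# --- End editor's sketch ------------------------------------------------------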

View File

@@ -1,8 +1,7 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from .chatgpt import OpenAILLMClient
@@ -20,15 +19,7 @@ async def deepseek_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: DeepseekLLMClient = cast(DeepseekLLMClient, model.proxy_llm_client)
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r

View File

@@ -2,17 +2,11 @@ import os
from concurrent.futures import Executor
from typing import Any, Dict, Iterator, List, Optional, Tuple
from dbgpt.core import (
MessageConverter,
ModelMessage,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelMessage, ModelOutput, ModelRequest
from dbgpt.core.interface.message import parse_model_messages
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
GEMINI_DEFAULT_MODEL = "gemini-pro"
@@ -39,15 +33,7 @@ def gemini_generate_stream(
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
client: GeminiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r

View File

@@ -1,8 +1,8 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core import ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from .chatgpt import OpenAILLMClient
@@ -19,15 +19,7 @@ async def moonshot_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: MoonshotLLMClient = cast(MoonshotLLMClient, model.proxy_llm_client)
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r

View File

@@ -2,10 +2,10 @@ import logging
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
logger = logging.getLogger(__name__)
@@ -14,14 +14,7 @@ def ollama_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=4096
):
client: OllamaLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
@@ -41,3 +42,30 @@ class ProxyModel:
int: token count, -1 if failed
"""
return self._tokenizer.count_token(messages, model_name)
def parse_model_request(
params: Dict[str, Any], default_model: str, stream: bool = True
) -> ModelRequest:
"""Parse model request from params.
Args:
params (Dict[str, Any]): request params
default_model (str): default model name
stream (bool, optional): whether to stream the output. Defaults to True.
"""
context = ModelRequestContext(
stream=stream,
user_name=params.get("user_name"),
request_id=params.get("request_id"),
)
request = ModelRequest.build_request(
default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
top_p=params.get("top_p"),
)
return request
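# --- Editor's example (not part of this commit's diff) -----------------------
# Illustrates what the helper above builds from raw worker params. The params
# dict is an assumption about a typical caller; its keys mirror exactly what
# parse_model_request reads.
params = {
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.7,
    "max_new_tokens": 512,
    "stop": ["</s>"],
    "top_p": 0.9,
    "user_name": "alice",
    "request_id": "req-123",
}
request = parse_model_request(params, "claude-3-5-sonnet-20241022", stream=True)
# request.model == "claude-3-5-sonnet-20241022"
# request.top_p == 0.9, now forwarded to provider payloads (see chatgpt/claude)
# request.context.stream is True, with user_name and request_id carried along
# --- End editor's example -----------------------------------------------------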

View File

@@ -1,12 +1,12 @@
import json
import os
from concurrent.futures import Executor
from typing import AsyncIterator, Optional
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
def getlength(text):
@@ -28,20 +28,8 @@ def spark_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: SparkLLMClient = model.proxy_llm_client
context = ModelRequestContext(
stream=True,
user_name=params.get("user_name"),
request_id=params.get("request_id"),
)
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.generate_stream(request):
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r
@@ -141,11 +129,11 @@ class SparkLLMClient(ProxyLLMClient):
def default_model(self) -> str:
return self._model
def generate_stream(
def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
) -> Iterator[ModelOutput]:
"""
reference:
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E

View File

@@ -2,10 +2,10 @@ import logging
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
logger = logging.getLogger(__name__)
@@ -14,15 +14,7 @@ def tongyi_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: TongyiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r

View File

@@ -2,19 +2,12 @@ import json
import logging
import os
from concurrent.futures import Executor
from typing import Iterator, List, Optional
from typing import Iterator, Optional
import requests
from cachetools import TTLCache, cached
from dbgpt.core import (
MessageConverter,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@@ -51,26 +44,16 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
return res.json().get("access_token")
def _to_wenxin_messages(messages: List[ModelMessage]):
def _to_wenxin_messages(request: ModelRequest):
"""Convert messages to wenxin compatible format
See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
"""
wenxin_messages = []
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
wenxin_messages.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
wenxin_messages.append({"role": "assistant", "content": message.content})
else:
pass
messages, system_messages = request.split_messages()
if len(system_messages) > 1:
raise ValueError("Wenxin only support one system message")
str_system_message = system_messages[0] if len(system_messages) > 0 else ""
return wenxin_messages, str_system_message
return messages, str_system_message
def wenxin_generate_stream(
@@ -167,7 +150,7 @@ class WenxinLLMClient(ProxyLLMClient):
"Failed to get access token. please set the correct api_key and secret key."
)
history, system_message = _to_wenxin_messages(request.get_messages())
history, system_message = _to_wenxin_messages(request)
payload = {
"messages": history,
"system": system_message,

View File

@@ -1,8 +1,7 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from .chatgpt import OpenAILLMClient
@@ -19,15 +18,7 @@ async def yi_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: YiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r

View File

@@ -2,10 +2,10 @@ import os
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
@@ -21,15 +21,7 @@ def zhipu_generate_stream(
# convert_to_compatible_format = params.get("convert_to_compatible_format", False)
# history, systems = __convert_2_zhipu_messages(messages)
client: ZhipuLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r