feat(model): Support claude proxy models (#2155)

Author: Fangyin Cheng
Date: 2024-11-26 19:47:28 +08:00
Committed by: GitHub
Parent: 9d8673a02f
Commit: 61509dc5ea
20 changed files with 508 additions and 157 deletions

View File

@@ -97,6 +97,26 @@ class OpenAIProxyLLMModelAdapter(ProxyLLMModelAdapter):
return chatgpt_generate_stream
class ClaudeProxyLLMModelAdapter(ProxyLLMModelAdapter):
def support_async(self) -> bool:
return True
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "claude_proxyllm"
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.claude import ClaudeLLMClient
return ClaudeLLMClient
def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.claude import claude_generate_stream
return claude_generate_stream
class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter):
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "tongyi_proxyllm"
@@ -320,6 +340,7 @@ class DeepseekProxyLLMModelAdapter(ProxyLLMModelAdapter):
register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(ClaudeProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(OllamaLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)
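The new adapter is selected purely by the lowered model name. A minimal sketch of the lookup (the find_adapter helper and the proxy_adapter import path are illustrative assumptions, not the project's real registry code):

    from dbgpt.model.adapter.proxy_adapter import (  # import path assumed
        ClaudeProxyLLMModelAdapter,
        OpenAIProxyLLMModelAdapter,
    )

    def find_adapter(model_name_or_path: str):
        # Try each registered adapter's do_match against the lowered name.
        lower_name = model_name_or_path.lower()
        for adapter in (OpenAIProxyLLMModelAdapter(), ClaudeProxyLLMModelAdapter()):
            if adapter.do_match(lower_name):
                return adapter
        return None

    adapter = find_adapter("claude_proxyllm")  # -> ClaudeProxyLLMModelAdapter
    assert adapter.support_async()             # Claude streaming is async-only here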

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import Future
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, Iterator, List, Optional
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelMetadata, ModelOutput
@@ -113,7 +113,9 @@ class WorkerManager(ABC):
"""Shutdown model instance"""
@abstractmethod
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
async def generate_stream(
self, params: Dict, **kwargs
) -> AsyncIterator[ModelOutput]:
"""Generate stream result, chat scene"""
@abstractmethod
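With the interface now typed as AsyncIterator, callers consume the stream with async for. A hedged consumer sketch (the params dict shape and the worker_manager wiring are assumptions based on the surrounding code):

    import asyncio

    async def stream_chat(worker_manager, prompt: str) -> str:
        params = {
            "model": "claude_proxyllm",
            "messages": [{"role": "user", "content": prompt}],  # shape assumed
            "temperature": 0.7,
        }
        last_text = ""
        async for output in worker_manager.generate_stream(params):
            # Each ModelOutput carries the accumulated text of the stream so far.
            last_text = output.text
        return last_text

    # asyncio.run(stream_chat(worker_manager, "Hello"))  # worker_manager obtained elsewhere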

View File

@@ -9,7 +9,7 @@ import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Iterator
from typing import AsyncIterator, Awaitable, Callable, Iterator
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
@@ -327,7 +327,7 @@ class LocalWorkerManager(WorkerManager):
async def generate_stream(
self, params: Dict, async_wrapper=None, **kwargs
) -> Iterator[ModelOutput]:
) -> AsyncIterator[ModelOutput]:
"""Generate stream result, chat scene"""
with root_tracer.start_span(
"WorkerManager.generate_stream", params.get("span_id")
@@ -693,7 +693,9 @@ class WorkerManagerAdapter(WorkerManager):
worker_type, model_name, healthy_only
)
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
async def generate_stream(
self, params: Dict, **kwargs
) -> AsyncIterator[ModelOutput]:
async for output in self.worker_manager.generate_stream(params, **kwargs):
yield output

View File

@@ -1,9 +1,25 @@
"""Proxy models."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.model.proxy.llms.claude import ClaudeLLMClient
from dbgpt.model.proxy.llms.deepseek import DeepseekLLMClient
from dbgpt.model.proxy.llms.gemini import GeminiLLMClient
from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient
from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
from dbgpt.model.proxy.llms.spark import SparkLLMClient
from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient
from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient
from dbgpt.model.proxy.llms.yi import YiLLMClient
from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient
def __lazy_import(name):
module_path = {
"OpenAILLMClient": "dbgpt.model.proxy.llms.chatgpt",
"ClaudeLLMClient": "dbgpt.model.proxy.llms.claude",
"GeminiLLMClient": "dbgpt.model.proxy.llms.gemini",
"SparkLLMClient": "dbgpt.model.proxy.llms.spark",
"TongyiLLMClient": "dbgpt.model.proxy.llms.tongyi",
@@ -28,6 +44,7 @@ def __getattr__(name):
__all__ = [
"OpenAILLMClient",
"ClaudeLLMClient",
"GeminiLLMClient",
"TongyiLLMClient",
"ZhipuLLMClient",

View File

@@ -34,6 +34,25 @@ class ProxyTokenizer(ABC):
List[int]: token count, -1 if failed
"""
def support_async(self) -> bool:
"""Check if the tokenizer supports asynchronous counting token.
Returns:
bool: True if supports, False otherwise
"""
return False
async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
"""Count token of given prompts asynchronously.
Args:
model_name (str): model name
prompts (List[str]): prompts to count token
Returns:
List[int]: token count, -1 if failed
"""
raise NotImplementedError()
class TiktokenProxyTokenizer(ProxyTokenizer):
def __init__(self):
@@ -92,7 +111,7 @@ class ProxyLLMClient(LLMClient):
self.model_names = model_names
self.context_length = context_length
self.executor = executor or ThreadPoolExecutor()
self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()
self._proxy_tokenizer = proxy_tokenizer
def __getstate__(self):
"""Customize the serialization of the object"""
@@ -105,6 +124,17 @@ class ProxyLLMClient(LLMClient):
self.__dict__.update(state)
self.executor = ThreadPoolExecutor()
@property
def proxy_tokenizer(self) -> ProxyTokenizer:
"""Get proxy tokenizer
Returns:
ProxyTokenizer: proxy tokenizer
"""
if not self._proxy_tokenizer:
self._proxy_tokenizer = TiktokenProxyTokenizer()
return self._proxy_tokenizer
@classmethod
@abstractmethod
def new_client(
@@ -257,6 +287,9 @@ class ProxyLLMClient(LLMClient):
Returns:
int: token count, -1 if failed
"""
if self.proxy_tokenizer.support_async():
cnts = await self.proxy_tokenizer.count_token_async(model, [prompt])
return cnts[0]
counts = await blocking_func_to_async(
self.executor, self.proxy_tokenizer.count_token, model, [prompt]
)
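A tokenizer that opts into the new async path only needs to return True from support_async and implement count_token_async; the async counting code shown above then awaits it directly instead of dispatching the blocking count_token to the executor. A hedged sketch (the class name and the 4-characters-per-token heuristic are hypothetical):

    from typing import List

    from dbgpt.model.proxy.base import ProxyTokenizer

    class ApproxAsyncTokenizer(ProxyTokenizer):
        """Hypothetical tokenizer, used only to illustrate the async hook."""

        def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
            # Sync fallback: a rough 4-characters-per-token heuristic.
            return [max(1, len(p) // 4) for p in prompts]

        def support_async(self) -> bool:
            return True

        async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
            # A real implementation would call a provider token-counting endpoint here.
            return self.count_token(model_name, prompts)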

View File

@@ -5,17 +5,11 @@ import logging
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union
from dbgpt.core import (
MessageConverter,
ModelMetadata,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
from dbgpt.util.i18n_utils import _
@@ -32,15 +26,7 @@ async def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: OpenAILLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r
@@ -191,6 +177,8 @@ class OpenAILLMClient(ProxyLLMClient):
payload["max_tokens"] = request.max_new_tokens
if request.stop:
payload["stop"] = request.stop
if request.top_p:
payload["top_p"] = request.top_p
return payload
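For instance, a ModelRequest carrying temperature, max_new_tokens, stop and top_p now maps roughly to the following OpenAI payload (values and model name are illustrative):

    payload = {
        "model": "gpt-4o",          # request.model or the client's default model
        "temperature": 0.8,
        "max_tokens": 512,          # from request.max_new_tokens
        "stop": ["\nUser:"],
        "top_p": 0.9,               # newly forwarded by this change
    }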
async def generate(

View File

@@ -1,7 +1,275 @@
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
import logging
import os
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, cast
from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import (
ProxyLLMClient,
ProxyTokenizer,
TiktokenProxyTokenizer,
)
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
if TYPE_CHECKING:
from anthropic import AsyncAnthropic, ProxiesTypes
logger = logging.getLogger(__name__)
def claude_generate_stream(
async def claude_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
yield "claude LLM was not supported!"
) -> AsyncIterator[ModelOutput]:
client: ClaudeLLMClient = cast(ClaudeLLMClient, model.proxy_llm_client)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r
class ClaudeLLMClient(ProxyLLMClient):
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model: Optional[str] = None,
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "claude_proxyllm",
context_length: Optional[int] = 8192,
client: Optional["AsyncAnthropic"] = None,
claude_kwargs: Optional[Dict[str, Any]] = None,
proxy_tokenizer: Optional[ProxyTokenizer] = None,
):
try:
import anthropic
except ImportError as exc:
raise ValueError(
"Could not import python package: anthropic "
"Please install anthropic by command `pip install anthropic"
) from exc
if not model:
model = "claude-3-5-sonnet-20241022"
self._client = client
self._model = model
self._api_key = api_key
self._api_base = api_base or os.environ.get(
"ANTHROPIC_BASE_URL", "https://api.anthropic.com"
)
self._proxies = proxies
self._timeout = timeout
self._claude_kwargs = claude_kwargs or {}
self._model_alias = model_alias
self._proxy_tokenizer = proxy_tokenizer
super().__init__(
model_names=[model_alias],
context_length=context_length,
proxy_tokenizer=proxy_tokenizer,
)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "ClaudeProxyLLMClient":
return cls(
api_key=model_params.proxy_api_key,
api_base=model_params.proxy_api_base,
# api_type=model_params.proxy_api_type,
# api_version=model_params.proxy_api_version,
model=model_params.proxyllm_backend,
proxies=model_params.http_proxy,
model_alias=model_params.model_name,
context_length=max(model_params.max_context_size, 8192),
)
@property
def client(self) -> "AsyncAnthropic":
from anthropic import AsyncAnthropic
if self._client is None:
self._client = AsyncAnthropic(
api_key=self._api_key,
base_url=self._api_base,
proxies=self._proxies,
timeout=self._timeout,
)
return self._client
@property
def proxy_tokenizer(self) -> ProxyTokenizer:
if not self._proxy_tokenizer:
self._proxy_tokenizer = ClaudeProxyTokenizer(self.client)
return self._proxy_tokenizer
@property
def default_model(self) -> str:
"""Default model name"""
model = self._model
if not model:
model = "claude-3-5-sonnet-20241022"
return model
def _build_request(
self, request: ModelRequest, stream: Optional[bool] = False
) -> Dict[str, Any]:
payload = {"stream": stream}
model = request.model or self.default_model
payload["model"] = model
# Apply claude kwargs
for k, v in self._claude_kwargs.items():
payload[k] = v
if request.temperature:
payload["temperature"] = request.temperature
if request.max_new_tokens:
payload["max_tokens"] = request.max_new_tokens
if request.stop:
payload["stop"] = request.stop
if request.top_p:
payload["top_p"] = request.top_p
return payload
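The Claude payload is built the same way; note that max_tokens is later popped out of the payload and passed as an explicit keyword argument to client.messages.create / client.messages.stream. An illustrative result for a request with temperature, max_new_tokens and top_p set:

    payload = {
        "stream": True,                        # removed again before messages.stream(...)
        "model": "claude-3-5-sonnet-20241022",
        "temperature": 0.7,
        "max_tokens": 1024,                    # popped and passed as an explicit kwarg
        "top_p": 0.9,
    }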
async def generate(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelOutput:
request = self.local_covert_message(request, message_converter)
messages, system_messages = request.split_messages()
payload = self._build_request(request)
logger.info(
f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
)
try:
if len(system_messages) > 1:
raise ValueError("Claude only supports single system message")
if system_messages:
payload["system"] = system_messages[0]
if "max_tokens" not in payload:
max_tokens = 1024
else:
max_tokens = payload["max_tokens"]
del payload["max_tokens"]
response = await self.client.messages.create(
max_tokens=max_tokens,
messages=messages,
**payload,
)
usage = None
finish_reason = response.stop_reason
if response.usage:
usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
}
response_content = response.content
if not response_content:
raise ValueError("Response content is empty")
return ModelOutput(
text=response_content[0].text,
error_code=0,
finish_reason=finish_reason,
usage=usage,
)
except Exception as e:
return ModelOutput(
text=f"**Claude Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
request = self.local_covert_message(request, message_converter)
messages, system_messages = request.split_messages()
payload = self._build_request(request, stream=True)
logger.info(
f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
)
try:
if len(system_messages) > 1:
raise ValueError("Claude only supports single system message")
if system_messages:
payload["system"] = system_messages[0]
if "max_tokens" not in payload:
max_tokens = 1024
else:
max_tokens = payload["max_tokens"]
del payload["max_tokens"]
if "stream" in payload:
del payload["stream"]
full_text = ""
async with self.client.messages.stream(
max_tokens=max_tokens,
messages=messages,
**payload,
) as stream:
async for text in stream.text_stream:
full_text += text
usage = {
"prompt_tokens": stream.current_message_snapshot.usage.input_tokens,
"completion_tokens": stream.current_message_snapshot.usage.output_tokens,
}
yield ModelOutput(text=full_text, error_code=0, usage=usage)
except Exception as e:
yield ModelOutput(
text=f"**Claude Generate Stream Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def models(self) -> List[ModelMetadata]:
model_metadata = ModelMetadata(
model=self._model_alias,
context_length=await self.get_context_length(),
)
return [model_metadata]
async def get_context_length(self) -> int:
"""Get the context length of the model.
Returns:
int: The context length.
# TODO: This is a temporary solution. We should have a better way to get the
# context length, e.g. fetch the real context length from the Anthropic API.
"""
return self.context_length
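End-to-end, the client can also be used directly, without the worker layer. A hedged usage sketch (assumes the anthropic package is installed and ANTHROPIC_API_KEY is set; constructing messages via ModelMessage with role "human" is an assumption about the core API):

    import asyncio
    import os

    from dbgpt.core import ModelMessage, ModelRequest
    from dbgpt.model.proxy.llms.claude import ClaudeLLMClient

    async def main():
        client = ClaudeLLMClient(api_key=os.environ["ANTHROPIC_API_KEY"])
        request = ModelRequest.build_request(
            client.default_model,
            messages=[ModelMessage(role="human", content="Say hello in one sentence.")],
            max_new_tokens=128,
        )
        async for output in client.generate_stream(request):
            # Each ModelOutput carries the accumulated text so far.
            print(output.text)

    asyncio.run(main())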
class ClaudeProxyTokenizer(ProxyTokenizer):
def __init__(self, client: "AsyncAnthropic", concurrency_limit: int = 10):
self.client = client
self.concurrency_limit = concurrency_limit
self._tiktoken_tokenizer = TiktokenProxyTokenizer()
def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
# Use tiktoken to count token in local environment
return self._tiktoken_tokenizer.count_token(model_name, prompts)
def support_async(self) -> bool:
return True
async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
"""Count token of given messages.
This is relying on the claude beta API, which is not available for some users.
"""
from dbgpt.util.chat_util import run_async_tasks
tasks = []
model_name = model_name or "claude-3-5-sonnet-20241022"
for prompt in prompts:
request = ModelRequest(
model=model_name, messages=[{"role": "user", "content": prompt}]
)
tasks.append(
self.client.beta.messages.count_tokens(
model=model_name,
messages=request.messages,
)
)
results = await run_async_tasks(tasks, self.concurrency_limit)
return results
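The tokenizer's async path issues one count_tokens call per prompt and gathers them through run_async_tasks. A hedged usage sketch (client construction and prompts are illustrative):

    import asyncio

    from anthropic import AsyncAnthropic

    from dbgpt.model.proxy.llms.claude import ClaudeProxyTokenizer

    async def main():
        tokenizer = ClaudeProxyTokenizer(AsyncAnthropic())  # reads ANTHROPIC_API_KEY
        counts = await tokenizer.count_token_async(
            "claude-3-5-sonnet-20241022",
            ["Hello, Claude!", "Count my tokens, please."],
        )
        # One result per prompt, as returned by the beta counting API.
        print(counts)

    asyncio.run(main())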

View File

@@ -1,8 +1,7 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from .chatgpt import OpenAILLMClient
@@ -20,15 +19,7 @@ async def deepseek_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: DeepseekLLMClient = cast(DeepseekLLMClient, model.proxy_llm_client)
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r

View File

@@ -2,17 +2,11 @@ import os
from concurrent.futures import Executor
from typing import Any, Dict, Iterator, List, Optional, Tuple
from dbgpt.core import (
MessageConverter,
ModelMessage,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelMessage, ModelOutput, ModelRequest
from dbgpt.core.interface.message import parse_model_messages
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
GEMINI_DEFAULT_MODEL = "gemini-pro"
@@ -39,15 +33,7 @@ def gemini_generate_stream(
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
client: GeminiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r

View File

@@ -1,8 +1,8 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core import ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from .chatgpt import OpenAILLMClient
@@ -19,15 +19,7 @@ async def moonshot_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: MoonshotLLMClient = cast(MoonshotLLMClient, model.proxy_llm_client)
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r

View File

@@ -2,10 +2,10 @@ import logging
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
logger = logging.getLogger(__name__)
@@ -14,14 +14,7 @@ def ollama_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=4096
):
client: OllamaLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
@@ -41,3 +42,30 @@ class ProxyModel:
int: token count, -1 if failed
"""
return self._tokenizer.count_token(messages, model_name)
def parse_model_request(
params: Dict[str, Any], default_model: str, stream: bool = True
) -> ModelRequest:
"""Parse model request from params.
Args:
params (Dict[str, Any]): request params
default_model (str): default model name
stream (bool, optional): whether stream. Defaults to True.
"""
context = ModelRequestContext(
stream=stream,
user_name=params.get("user_name"),
request_id=params.get("request_id"),
)
request = ModelRequest.build_request(
default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
top_p=params.get("top_p"),
)
return request
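This helper is what every provider's *_generate_stream function now delegates to instead of building a ModelRequest by hand. A small usage sketch (the params values are illustrative):

    params = {
        "messages": [{"role": "user", "content": "Hello"}],  # as passed in by the worker layer
        "temperature": 0.5,
        "max_new_tokens": 256,
        "top_p": 0.9,
        "user_name": "alice",
        "request_id": "req-123",
    }
    request = parse_model_request(params, "claude_proxyllm", stream=True)
    # request.context.stream is True; temperature, max_new_tokens, stop and
    # top_p are copied straight from params.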

View File

@@ -1,12 +1,12 @@
import json
import os
from concurrent.futures import Executor
from typing import AsyncIterator, Optional
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
def getlength(text):
@@ -28,20 +28,8 @@ def spark_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: SparkLLMClient = model.proxy_llm_client
context = ModelRequestContext(
stream=True,
user_name=params.get("user_name"),
request_id=params.get("request_id"),
)
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.generate_stream(request):
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r
@@ -141,11 +129,11 @@ class SparkLLMClient(ProxyLLMClient):
def default_model(self) -> str:
return self._model
def generate_stream(
def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
) -> Iterator[ModelOutput]:
"""
reference:
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E

View File

@@ -2,10 +2,10 @@ import logging
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
logger = logging.getLogger(__name__)
@@ -14,15 +14,7 @@ def tongyi_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: TongyiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r

View File

@@ -2,19 +2,12 @@ import json
import logging
import os
from concurrent.futures import Executor
from typing import Iterator, List, Optional
from typing import Iterator, Optional
import requests
from cachetools import TTLCache, cached
from dbgpt.core import (
MessageConverter,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@@ -51,26 +44,16 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
return res.json().get("access_token")
def _to_wenxin_messages(messages: List[ModelMessage]):
def _to_wenxin_messages(request: ModelRequest):
"""Convert messages to wenxin compatible format
See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
"""
wenxin_messages = []
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
wenxin_messages.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
wenxin_messages.append({"role": "assistant", "content": message.content})
else:
pass
messages, system_messages = request.split_messages()
if len(system_messages) > 1:
raise ValueError("Wenxin only support one system message")
str_system_message = system_messages[0] if len(system_messages) > 0 else ""
return wenxin_messages, str_system_message
return messages, str_system_message
def wenxin_generate_stream(
@@ -167,7 +150,7 @@ class WenxinLLMClient(ProxyLLMClient):
"Failed to get access token. please set the correct api_key and secret key."
)
history, system_message = _to_wenxin_messages(request.get_messages())
history, system_message = _to_wenxin_messages(request)
payload = {
"messages": history,
"system": system_message,

View File

@@ -1,8 +1,7 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from .chatgpt import OpenAILLMClient
@@ -19,15 +18,7 @@ async def yi_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: YiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r

View File

@@ -2,10 +2,10 @@ import os
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
@@ -21,15 +21,7 @@ def zhipu_generate_stream(
# convert_to_compatible_format = params.get("convert_to_compatible_format", False)
# history, systems = __convert_2_zhipu_messages(messages)
client: ZhipuLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r