feat(model): Support claude proxy models (#2155)

Fangyin Cheng 2024-11-26 19:47:28 +08:00 committed by GitHub
parent 9d8673a02f
commit 61509dc5ea
20 changed files with 508 additions and 157 deletions


@ -131,6 +131,15 @@ class Config(metaclass=Singleton):
os.environ["deepseek_proxyllm_api_base"] = os.getenv(
"DEEPSEEK_API_BASE", "https://api.deepseek.com/v1"
)
self.claude_proxy_api_key = os.getenv("ANTHROPIC_API_KEY")
if self.claude_proxy_api_key:
os.environ["claude_proxyllm_proxy_api_key"] = self.claude_proxy_api_key
os.environ["claude_proxyllm_proxyllm_backend"] = os.getenv(
"ANTHROPIC_MODEL_VERSION", "claude-3-5-sonnet-20241022"
)
os.environ["claude_proxyllm_api_base"] = os.getenv(
"ANTHROPIC_BASE_URL", "https://api.anthropic.com"
)
self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
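For reference, a minimal sketch of how a deployment might set these variables before starting DB-GPT; the values and the LLM_MODEL selector are illustrative assumptions, and only ANTHROPIC_API_KEY is strictly required (the other two fall back to the defaults read above):
# Hypothetical environment setup; adjust the placeholder values for your deployment.
import os

os.environ["LLM_MODEL"] = "claude_proxyllm"  # assumed model selector, matching the adapter below
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # required
os.environ["ANTHROPIC_MODEL_VERSION"] = "claude-3-5-sonnet-20241022"  # optional
os.environ["ANTHROPIC_BASE_URL"] = "https://api.anthropic.com"  # optional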


@ -6,7 +6,7 @@ import logging
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Tuple, Union
from cachetools import TTLCache
@ -394,6 +394,29 @@ class ModelRequest:
"""
return ModelMessage.messages_to_string(self.get_messages())
def split_messages(self) -> Tuple[List[Dict[str, Any]], List[str]]:
"""Split the messages.
Returns:
Tuple[List[Dict[str, Any]], List[str]]: The common messages and system
messages.
"""
messages = self.get_messages()
common_messages = []
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
common_messages.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
common_messages.append(
{"role": "assistant", "content": message.content}
)
else:
pass
return common_messages, system_messages
@dataclass
class ModelExtraMedata(BaseParameters):
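A rough illustration of what the new ModelRequest.split_messages returns for an OpenAI-style conversation, assuming the helpers shown elsewhere in this commit:
# Illustrative sketch only.
from dbgpt.core import ModelMessage, ModelRequest

messages = ModelMessage.from_openai_messages([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
])
req = ModelRequest.build_request("claude_proxyllm", messages)
common, system = req.split_messages()
# common -> [{"role": "user", "content": "Hello"}]
# system -> ["You are a helpful assistant."]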
@ -861,7 +884,9 @@ class LLMClient(ABC):
raise ValueError(f"Model {model} not found")
return model_metadata
def __call__(self, *args, **kwargs) -> ModelOutput:
def __call__(
self, *args, **kwargs
) -> Coroutine[Any, Any, ModelOutput] | ModelOutput:
"""Return the model output.
Call the LLM client to generate the response for the given message.
@ -869,22 +894,63 @@ class LLMClient(ABC):
Please do not use this method in the production environment, it is only used
for debugging.
"""
import asyncio
from dbgpt.util import get_or_create_event_loop
try:
# Check if we are in an event loop
loop = asyncio.get_running_loop()
# If we are in an event loop, use async call
if loop.is_running():
# We are in an async environment but this is a sync method, so
# return a coroutine object for the caller to await.
return self.async_call(*args, **kwargs)
else:
loop = get_or_create_event_loop()
return loop.run_until_complete(self.async_call(*args, **kwargs))
except RuntimeError:
# If we are not in an event loop, use sync call
loop = get_or_create_event_loop()
return loop.run_until_complete(self.async_call(*args, **kwargs))
async def async_call(self, *args, **kwargs) -> ModelOutput:
"""Return the model output asynchronously.
Please do not use this method in the production environment, it is only used
for debugging.
"""
req = self._build_call_request(*args, **kwargs)
return await self.generate(req)
async def async_call_stream(self, *args, **kwargs) -> AsyncIterator[ModelOutput]:
"""Return the model output stream asynchronously.
Please do not use this method in the production environment, it is only used
for debugging.
"""
req = self._build_call_request(*args, **kwargs)
async for output in self.generate_stream(req): # type: ignore
yield output
def _build_call_request(self, *args, **kwargs) -> ModelRequest:
"""Build the model request for the call method."""
messages = kwargs.get("messages")
model = kwargs.get("model")
if messages:
del kwargs["messages"]
model_messages = ModelMessage.from_openai_messages(messages)
else:
model_messages = [ModelMessage.build_human_message(args[0])]
if not model:
if hasattr(self, "default_model"):
model = getattr(self, "default_model")
else:
raise ValueError("The default model is not set")
if "model" in kwargs:
del kwargs["model"]
req = ModelRequest.build_request(model, model_messages, **kwargs)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.generate(req))
return ModelRequest.build_request(model, model_messages, **kwargs)
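A hedged, debug-only usage sketch of the call helpers above; `client` stands for any concrete LLMClient and the model name is a placeholder:
# Debug-only sketch; not for production, as the docstrings above note.
async def demo(client):
    # Inside an event loop, await the async variants directly.
    output = await client.async_call("Hello", model="claude_proxyllm")
    print(output.text)
    async for chunk in client.async_call_stream("Hello", model="claude_proxyllm"):
        print(chunk.text)

def sync_demo(client):
    # Outside an event loop, __call__ blocks and returns a ModelOutput.
    output = client("Hello", model="claude_proxyllm")
    print(output.text)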


@ -97,6 +97,26 @@ class OpenAIProxyLLMModelAdapter(ProxyLLMModelAdapter):
return chatgpt_generate_stream
class ClaudeProxyLLMModelAdapter(ProxyLLMModelAdapter):
def support_async(self) -> bool:
return True
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "claude_proxyllm"
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.claude import ClaudeLLMClient
return ClaudeLLMClient
def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.claude import claude_generate_stream
return claude_generate_stream
class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter):
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "tongyi_proxyllm"
@ -320,6 +340,7 @@ class DeepseekProxyLLMModelAdapter(ProxyLLMModelAdapter):
register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(ClaudeProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(OllamaLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)


@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import Future
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, Iterator, List, Optional
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelMetadata, ModelOutput
@ -113,7 +113,9 @@ class WorkerManager(ABC):
"""Shutdown model instance"""
@abstractmethod
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
async def generate_stream(
self, params: Dict, **kwargs
) -> AsyncIterator[ModelOutput]:
"""Generate stream result, chat scene"""
@abstractmethod


@ -9,7 +9,7 @@ import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Iterator
from typing import AsyncIterator, Awaitable, Callable, Iterator
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
@ -327,7 +327,7 @@ class LocalWorkerManager(WorkerManager):
async def generate_stream(
self, params: Dict, async_wrapper=None, **kwargs
) -> Iterator[ModelOutput]:
) -> AsyncIterator[ModelOutput]:
"""Generate stream result, chat scene"""
with root_tracer.start_span(
"WorkerManager.generate_stream", params.get("span_id")
@ -693,7 +693,9 @@ class WorkerManagerAdapter(WorkerManager):
worker_type, model_name, healthy_only
)
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
async def generate_stream(
self, params: Dict, **kwargs
) -> AsyncIterator[ModelOutput]:
async for output in self.worker_manager.generate_stream(params, **kwargs):
yield output


@ -1,9 +1,25 @@
"""Proxy models."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.model.proxy.llms.claude import ClaudeLLMClient
from dbgpt.model.proxy.llms.deepseek import DeepseekLLMClient
from dbgpt.model.proxy.llms.gemini import GeminiLLMClient
from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient
from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
from dbgpt.model.proxy.llms.spark import SparkLLMClient
from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient
from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient
from dbgpt.model.proxy.llms.yi import YiLLMClient
from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient
def __lazy_import(name):
module_path = {
"OpenAILLMClient": "dbgpt.model.proxy.llms.chatgpt",
"ClaudeLLMClient": "dbgpt.model.proxy.llms.claude",
"GeminiLLMClient": "dbgpt.model.proxy.llms.gemini",
"SparkLLMClient": "dbgpt.model.proxy.llms.spark",
"TongyiLLMClient": "dbgpt.model.proxy.llms.tongyi",
@ -28,6 +44,7 @@ def __getattr__(name):
__all__ = [
"OpenAILLMClient",
"ClaudeLLMClient",
"GeminiLLMClient",
"TongyiLLMClient",
"ZhipuLLMClient",


@ -34,6 +34,25 @@ class ProxyTokenizer(ABC):
List[int]: token count, -1 if failed
"""
def support_async(self) -> bool:
"""Check if the tokenizer supports asynchronous counting token.
Returns:
bool: True if supports, False otherwise
"""
return False
async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
"""Count token of given prompts asynchronously.
Args:
model_name (str): model name
prompts (List[str]): prompts to count token
Returns:
List[int]: token count, -1 if failed
"""
raise NotImplementedError()
class TiktokenProxyTokenizer(ProxyTokenizer):
def __init__(self):
@ -92,7 +111,7 @@ class ProxyLLMClient(LLMClient):
self.model_names = model_names
self.context_length = context_length
self.executor = executor or ThreadPoolExecutor()
self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()
self._proxy_tokenizer = proxy_tokenizer
def __getstate__(self):
"""Customize the serialization of the object"""
@ -105,6 +124,17 @@ class ProxyLLMClient(LLMClient):
self.__dict__.update(state)
self.executor = ThreadPoolExecutor()
@property
def proxy_tokenizer(self) -> ProxyTokenizer:
"""Get proxy tokenizer
Returns:
ProxyTokenizer: proxy tokenizer
"""
if not self._proxy_tokenizer:
self._proxy_tokenizer = TiktokenProxyTokenizer()
return self._proxy_tokenizer
@classmethod
@abstractmethod
def new_client(
@ -257,6 +287,9 @@ class ProxyLLMClient(LLMClient):
Returns:
int: token count, -1 if failed
"""
if self.proxy_tokenizer.support_async():
cnts = await self.proxy_tokenizer.count_token_async(model, [prompt])
return cnts[0]
counts = await blocking_func_to_async(
self.executor, self.proxy_tokenizer.count_token, model, [prompt]
)
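As an illustration, a tokenizer that opts into the async counting path might look like this (names are hypothetical; the real Claude implementation appears later in this commit as ClaudeProxyTokenizer):
# Hypothetical async-capable tokenizer sketch.
from typing import List

from dbgpt.model.proxy.base import ProxyTokenizer


class WordCountTokenizer(ProxyTokenizer):
    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        # Crude local estimate used as the synchronous fallback.
        return [len(p.split()) for p in prompts]

    def support_async(self) -> bool:
        return True

    async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
        # A real implementation could call a remote token-counting endpoint here.
        return self.count_token(model_name, prompts)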


@ -5,17 +5,11 @@ import logging
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union
from dbgpt.core import (
MessageConverter,
ModelMetadata,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
from dbgpt.util.i18n_utils import _
@ -32,15 +26,7 @@ async def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: OpenAILLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r
@ -191,6 +177,8 @@ class OpenAILLMClient(ProxyLLMClient):
payload["max_tokens"] = request.max_new_tokens
if request.stop:
payload["stop"] = request.stop
if request.top_p:
payload["top_p"] = request.top_p
return payload
async def generate(


@ -1,7 +1,275 @@
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
import logging
import os
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, cast
from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import (
ProxyLLMClient,
ProxyTokenizer,
TiktokenProxyTokenizer,
)
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
if TYPE_CHECKING:
from anthropic import AsyncAnthropic, ProxiesTypes
logger = logging.getLogger(__name__)
def claude_generate_stream(
async def claude_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
yield "claude LLM was not supported!"
) -> AsyncIterator[ModelOutput]:
client: ClaudeLLMClient = cast(ClaudeLLMClient, model.proxy_llm_client)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r
class ClaudeLLMClient(ProxyLLMClient):
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model: Optional[str] = None,
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "claude_proxyllm",
context_length: Optional[int] = 8192,
client: Optional["AsyncAnthropic"] = None,
claude_kwargs: Optional[Dict[str, Any]] = None,
proxy_tokenizer: Optional[ProxyTokenizer] = None,
):
try:
import anthropic
except ImportError as exc:
raise ValueError(
"Could not import python package: anthropic "
"Please install anthropic by command `pip install anthropic"
) from exc
if not model:
model = "claude-3-5-sonnet-20241022"
self._client = client
self._model = model
self._api_key = api_key
self._api_base = api_base or os.environ.get(
"ANTHROPIC_BASE_URL", "https://api.anthropic.com"
)
self._proxies = proxies
self._timeout = timeout
self._claude_kwargs = claude_kwargs or {}
self._model_alias = model_alias
self._proxy_tokenizer = proxy_tokenizer
super().__init__(
model_names=[model_alias],
context_length=context_length,
proxy_tokenizer=proxy_tokenizer,
)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "ClaudeProxyLLMClient":
return cls(
api_key=model_params.proxy_api_key,
api_base=model_params.proxy_api_base,
# api_type=model_params.proxy_api_type,
# api_version=model_params.proxy_api_version,
model=model_params.proxyllm_backend,
proxies=model_params.http_proxy,
model_alias=model_params.model_name,
context_length=max(model_params.max_context_size, 8192),
)
@property
def client(self) -> "AsyncAnthropic":
from anthropic import AsyncAnthropic
if self._client is None:
self._client = AsyncAnthropic(
api_key=self._api_key,
base_url=self._api_base,
proxies=self._proxies,
timeout=self._timeout,
)
return self._client
@property
def proxy_tokenizer(self) -> ProxyTokenizer:
if not self._proxy_tokenizer:
self._proxy_tokenizer = ClaudeProxyTokenizer(self.client)
return self._proxy_tokenizer
@property
def default_model(self) -> str:
"""Default model name"""
model = self._model
if not model:
model = "claude-3-5-sonnet-20241022"
return model
def _build_request(
self, request: ModelRequest, stream: Optional[bool] = False
) -> Dict[str, Any]:
payload = {"stream": stream}
model = request.model or self.default_model
payload["model"] = model
# Apply claude kwargs
for k, v in self._claude_kwargs.items():
payload[k] = v
if request.temperature:
payload["temperature"] = request.temperature
if request.max_new_tokens:
payload["max_tokens"] = request.max_new_tokens
if request.stop:
payload["stop"] = request.stop
if request.top_p:
payload["top_p"] = request.top_p
return payload
async def generate(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelOutput:
request = self.local_covert_message(request, message_converter)
messages, system_messages = request.split_messages()
payload = self._build_request(request)
logger.info(
f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
)
try:
if len(system_messages) > 1:
raise ValueError("Claude only supports single system message")
if system_messages:
payload["system"] = system_messages[0]
if "max_tokens" not in payload:
max_tokens = 1024
else:
max_tokens = payload["max_tokens"]
del payload["max_tokens"]
response = await self.client.messages.create(
max_tokens=max_tokens,
messages=messages,
**payload,
)
usage = None
finish_reason = response.stop_reason
if response.usage:
usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
}
response_content = response.content
if not response_content:
raise ValueError("Response content is empty")
return ModelOutput(
text=response_content[0].text,
error_code=0,
finish_reason=finish_reason,
usage=usage,
)
except Exception as e:
return ModelOutput(
text=f"**Claude Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
request = self.local_covert_message(request, message_converter)
messages, system_messages = request.split_messages()
payload = self._build_request(request, stream=True)
logger.info(
f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
)
try:
if len(system_messages) > 1:
raise ValueError("Claude only supports single system message")
if system_messages:
payload["system"] = system_messages[0]
if "max_tokens" not in payload:
max_tokens = 1024
else:
max_tokens = payload["max_tokens"]
del payload["max_tokens"]
if "stream" in payload:
del payload["stream"]
full_text = ""
async with self.client.messages.stream(
max_tokens=max_tokens,
messages=messages,
**payload,
) as stream:
async for text in stream.text_stream:
full_text += text
usage = {
"prompt_tokens": stream.current_message_snapshot.usage.input_tokens,
"completion_tokens": stream.current_message_snapshot.usage.output_tokens,
}
yield ModelOutput(text=full_text, error_code=0, usage=usage)
except Exception as e:
yield ModelOutput(
text=f"**Claude Generate Stream Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
async def models(self) -> List[ModelMetadata]:
model_metadata = ModelMetadata(
model=self._model_alias,
context_length=await self.get_context_length(),
)
return [model_metadata]
async def get_context_length(self) -> int:
"""Get the context length of the model.
Returns:
int: The context length.
# TODO: This is a temporary solution. We should have a better way to get the
# context length, e.g. fetch the real context length from the Claude API.
"""
return self.context_length
class ClaudeProxyTokenizer(ProxyTokenizer):
def __init__(self, client: "AsyncAnthropic", concurrency_limit: int = 10):
self.client = client
self.concurrency_limit = concurrency_limit
self._tiktoken_tokenizer = TiktokenProxyTokenizer()
def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
# Use tiktoken to count token in local environment
return self._tiktoken_tokenizer.count_token(model_name, prompts)
def support_async(self) -> bool:
return True
async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
"""Count token of given messages.
This is relying on the claude beta API, which is not available for some users.
"""
from dbgpt.util.chat_util import run_async_tasks
tasks = []
model_name = model_name or "claude-3-5-sonnet-20241022"
for prompt in prompts:
request = ModelRequest(
model=model_name, messages=[{"role": "user", "content": prompt}]
)
tasks.append(
self.client.beta.messages.count_tokens(
model=model_name,
messages=request.messages,
)
)
results = await run_async_tasks(tasks, self.concurrency_limit)
# The beta count_tokens API returns token-count objects; unwrap to plain ints.
return [r.input_tokens for r in results]
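An end-to-end usage sketch of the new client (API key and prompt are placeholders; requires `pip install anthropic`):
# Illustrative usage of ClaudeLLMClient.
import asyncio

from dbgpt.model.proxy.llms.claude import ClaudeLLMClient


async def main():
    client = ClaudeLLMClient(
        api_key="sk-ant-...",  # placeholder
        model="claude-3-5-sonnet-20241022",
    )
    output = await client.async_call("Say hello in one sentence.")
    print(output.text)


asyncio.run(main())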


@ -1,8 +1,7 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from .chatgpt import OpenAILLMClient
@ -20,15 +19,7 @@ async def deepseek_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: DeepseekLLMClient = cast(DeepseekLLMClient, model.proxy_llm_client)
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r


@ -2,17 +2,11 @@ import os
from concurrent.futures import Executor
from typing import Any, Dict, Iterator, List, Optional, Tuple
from dbgpt.core import (
MessageConverter,
ModelMessage,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelMessage, ModelOutput, ModelRequest
from dbgpt.core.interface.message import parse_model_messages
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
GEMINI_DEFAULT_MODEL = "gemini-pro"
@ -39,15 +33,7 @@ def gemini_generate_stream(
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
client: GeminiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r


@ -1,8 +1,8 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core import ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from .chatgpt import OpenAILLMClient
@ -19,15 +19,7 @@ async def moonshot_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: MoonshotLLMClient = cast(MoonshotLLMClient, model.proxy_llm_client)
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r


@ -2,10 +2,10 @@ import logging
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
logger = logging.getLogger(__name__)
@ -14,14 +14,7 @@ def ollama_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=4096
):
client: OllamaLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r


@ -1,8 +1,9 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
@ -41,3 +42,30 @@ class ProxyModel:
int: token count, -1 if failed
"""
return self._tokenizer.count_token(messages, model_name)
def parse_model_request(
params: Dict[str, Any], default_model: str, stream: bool = True
) -> ModelRequest:
"""Parse model request from params.
Args:
params (Dict[str, Any]): request params
default_model (str): default model name
stream (bool, optional): whether to stream the response. Defaults to True.
"""
context = ModelRequestContext(
stream=stream,
user_name=params.get("user_name"),
request_id=params.get("request_id"),
)
request = ModelRequest.build_request(
default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
top_p=params.get("top_p"),
)
return request
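For illustration, the refactored generate_stream functions above call this helper roughly as follows (the params dict mimics what the worker passes in; values are placeholders):
# Illustrative call; in practice the worker builds `params`.
from dbgpt.core import ModelMessage

params = {
    "messages": [ModelMessage.build_human_message("Hello")],
    "temperature": 0.7,
    "max_new_tokens": 512,
    "user_name": "demo",
}
request = parse_model_request(params, "claude_proxyllm", stream=True)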


@ -1,12 +1,12 @@
import json
import os
from concurrent.futures import Executor
from typing import AsyncIterator, Optional
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
def getlength(text):
@ -28,20 +28,8 @@ def spark_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: SparkLLMClient = model.proxy_llm_client
context = ModelRequestContext(
stream=True,
user_name=params.get("user_name"),
request_id=params.get("request_id"),
)
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.generate_stream(request):
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r
@ -141,11 +129,11 @@ class SparkLLMClient(ProxyLLMClient):
def default_model(self) -> str:
return self._model
def generate_stream(
def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
) -> Iterator[ModelOutput]:
"""
reference:
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E


@ -2,10 +2,10 @@ import logging
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
logger = logging.getLogger(__name__)
@ -14,15 +14,7 @@ def tongyi_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: TongyiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r


@ -2,19 +2,12 @@ import json
import logging
import os
from concurrent.futures import Executor
from typing import Iterator, List, Optional
from typing import Iterator, Optional
import requests
from cachetools import TTLCache, cached
from dbgpt.core import (
MessageConverter,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@ -51,26 +44,16 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
return res.json().get("access_token")
def _to_wenxin_messages(messages: List[ModelMessage]):
def _to_wenxin_messages(request: ModelRequest):
"""Convert messages to wenxin compatible format
See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
"""
wenxin_messages = []
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
wenxin_messages.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
wenxin_messages.append({"role": "assistant", "content": message.content})
else:
pass
messages, system_messages = request.split_messages()
if len(system_messages) > 1:
raise ValueError("Wenxin only support one system message")
str_system_message = system_messages[0] if len(system_messages) > 0 else ""
return wenxin_messages, str_system_message
return messages, str_system_message
def wenxin_generate_stream(
@ -167,7 +150,7 @@ class WenxinLLMClient(ProxyLLMClient):
"Failed to get access token. please set the correct api_key and secret key."
)
history, system_message = _to_wenxin_messages(request.get_messages())
history, system_message = _to_wenxin_messages(request)
payload = {
"messages": history,
"system": system_message,


@ -1,8 +1,7 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from .chatgpt import OpenAILLMClient
@ -19,15 +18,7 @@ async def yi_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: YiLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r


@ -2,10 +2,10 @@ import os
from concurrent.futures import Executor
from typing import Iterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
@ -21,15 +21,7 @@ def zhipu_generate_stream(
# convert_to_compatible_format = params.get("convert_to_compatible_format", False)
# history, systems = __convert_2_zhipu_messages(messages)
client: ZhipuLLMClient = model.proxy_llm_client
context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
request = parse_model_request(params, client.default_model, stream=True)
for r in client.sync_generate_stream(request):
yield r


@ -670,6 +670,13 @@ def openai_requires():
setup_spec.extras["openai"] += setup_spec.extras["rag"]
def proxy_requires():
"""
pip install "dbgpt[proxy]"
"""
setup_spec.extras["proxy"] = setup_spec.extras["openai"] + ["anthropic"]
def gpt4all_requires():
"""
pip install "dbgpt[gpt4all]"
@ -727,6 +734,7 @@ def default_requires():
setup_spec.extras["default"] += setup_spec.extras["datasource"]
setup_spec.extras["default"] += setup_spec.extras["torch"]
setup_spec.extras["default"] += setup_spec.extras["cache"]
setup_spec.extras["default"] += setup_spec.extras["proxy"]
setup_spec.extras["default"] += setup_spec.extras["code"]
if INCLUDE_QUANTIZATION:
# Add quantization extra to default, default is True
@ -763,6 +771,7 @@ cache_requires()
observability_requires()
openai_requires()
proxy_requires()
# must be last
default_requires()
all_requires()