Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-22 10:08:34 +00:00)

feat(model): Support claude proxy models (#2155)

This commit is contained in:
parent: 9d8673a02f
commit: 61509dc5ea
@@ -131,6 +131,15 @@ class Config(metaclass=Singleton):
            os.environ["deepseek_proxyllm_api_base"] = os.getenv(
                "DEEPSEEK_API_BASE", "https://api.deepseek.com/v1"
            )
        self.claude_proxy_api_key = os.getenv("ANTHROPIC_API_KEY")
        if self.claude_proxy_api_key:
            os.environ["claude_proxyllm_proxy_api_key"] = self.claude_proxy_api_key
            os.environ["claude_proxyllm_proxyllm_backend"] = os.getenv(
                "ANTHROPIC_MODEL_VERSION", "claude-3-5-sonnet-20241022"
            )
            os.environ["claude_proxyllm_api_base"] = os.getenv(
                "ANTHROPIC_BASE_URL", "https://api.anthropic.com"
            )

        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
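Taken together, the new Config branch means the Claude proxy can be enabled purely through environment variables: it only runs when ANTHROPIC_API_KEY is set, and then mirrors ANTHROPIC_MODEL_VERSION and ANTHROPIC_BASE_URL into the claude_proxyllm_* variables read by the proxy worker. A minimal sketch, not part of the commit (the import path for Config is assumed):

# Sketch: enabling the Claude proxy via environment variables only.
import os

os.environ["ANTHROPIC_API_KEY"] = "sk-ant-your-key"  # required; enables the new branch
# Optional overrides; the values below are the defaults the commit falls back to.
os.environ.setdefault("ANTHROPIC_MODEL_VERSION", "claude-3-5-sonnet-20241022")
os.environ.setdefault("ANTHROPIC_BASE_URL", "https://api.anthropic.com")

from dbgpt._private.config import Config  # assumed location of Config

cfg = Config()
assert cfg.claude_proxy_api_key == os.environ["ANTHROPIC_API_KEY"]
assert os.environ["claude_proxyllm_proxyllm_backend"] == "claude-3-5-sonnet-20241022"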
@@ -6,7 +6,7 @@ import logging
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Tuple, Union

from cachetools import TTLCache
@@ -394,6 +394,29 @@ class ModelRequest:
        """
        return ModelMessage.messages_to_string(self.get_messages())

    def split_messages(self) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Split the messages.

        Returns:
            Tuple[List[Dict[str, Any]], List[str]]: The common messages and system
                messages.
        """
        messages = self.get_messages()
        common_messages = []
        system_messages = []
        for message in messages:
            if message.role == ModelMessageRoleType.HUMAN:
                common_messages.append({"role": "user", "content": message.content})
            elif message.role == ModelMessageRoleType.SYSTEM:
                system_messages.append(message.content)
            elif message.role == ModelMessageRoleType.AI:
                common_messages.append(
                    {"role": "assistant", "content": message.content}
                )
            else:
                pass
        return common_messages, system_messages


@dataclass
class ModelExtraMedata(BaseParameters):
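The new ModelRequest.split_messages helper is what lets the Claude client pass system prompts separately from the chat turns. A small sketch of the expected behaviour (message construction assumed from the dbgpt.core interfaces used elsewhere in this diff):

# Sketch only: ModelMessage/ModelRequest construction is illustrative.
from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest

messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a SQL assistant."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Show me all users."),
]
request = ModelRequest(model="claude-3-5-sonnet-20241022", messages=messages)
common, system = request.split_messages()
# common -> [{"role": "user", "content": "Show me all users."}]
# system -> ["You are a SQL assistant."]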
@@ -861,7 +884,9 @@ class LLMClient(ABC):
            raise ValueError(f"Model {model} not found")
        return model_metadata

    def __call__(self, *args, **kwargs) -> ModelOutput:
    def __call__(
        self, *args, **kwargs
    ) -> Coroutine[Any, Any, ModelOutput] | ModelOutput:
        """Return the model output.

        Call the LLM client to generate the response for the given message.
@@ -869,22 +894,63 @@ class LLMClient(ABC):
        Please do not use this method in the production environment, it is only used
        for debugging.
        """
        import asyncio

        from dbgpt.util import get_or_create_event_loop

        try:
            # Check if we are in an event loop
            loop = asyncio.get_running_loop()
            # If we are in an event loop, use async call
            if loop.is_running():
                # Because we are in an async environment, but this is a sync method,
                # we need to return a coroutine object for the caller to use await
                return self.async_call(*args, **kwargs)
            else:
                loop = get_or_create_event_loop()
                return loop.run_until_complete(self.async_call(*args, **kwargs))
        except RuntimeError:
            # If we are not in an event loop, use sync call
            loop = get_or_create_event_loop()
            return loop.run_until_complete(self.async_call(*args, **kwargs))

    async def async_call(self, *args, **kwargs) -> ModelOutput:
        """Return the model output asynchronously.

        Please do not use this method in the production environment, it is only used
        for debugging.
        """
        req = self._build_call_request(*args, **kwargs)
        return await self.generate(req)

    async def async_call_stream(self, *args, **kwargs) -> AsyncIterator[ModelOutput]:
        """Return the model output stream asynchronously.

        Please do not use this method in the production environment, it is only used
        for debugging.
        """
        req = self._build_call_request(*args, **kwargs)
        async for output in self.generate_stream(req):  # type: ignore
            yield output

    def _build_call_request(self, *args, **kwargs) -> ModelRequest:
        """Build the model request for the call method."""
        messages = kwargs.get("messages")
        model = kwargs.get("model")

        if messages:
            del kwargs["messages"]
            model_messages = ModelMessage.from_openai_messages(messages)
        else:
            model_messages = [ModelMessage.build_human_message(args[0])]

        if not model:
            if hasattr(self, "default_model"):
                model = getattr(self, "default_model")
            else:
                raise ValueError("The default model is not set")

        if "model" in kwargs:
            del kwargs["model"]
        req = ModelRequest.build_request(model, model_messages, **kwargs)
        loop = get_or_create_event_loop()
        return loop.run_until_complete(self.generate(req))

        return ModelRequest.build_request(model, model_messages, **kwargs)
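With this change, LLMClient.__call__ no longer blocks when invoked from inside a running event loop; it hands back the coroutine from async_call instead, so the caller must await it. The docstring still flags these helpers as debug-only. A hedged usage sketch, assuming LLMClient is importable from dbgpt.core as it is elsewhere in the codebase:

# Sketch only: `client` stands for any concrete LLMClient implementation.
from dbgpt.core import LLMClient, ModelOutput


def sync_caller(client: LLMClient) -> ModelOutput:
    # No running event loop here, so __call__ drives the coroutine itself
    # and returns the finished ModelOutput.
    return client("Hello", model="claude-3-5-sonnet-20241022")


async def async_caller(client: LLMClient) -> ModelOutput:
    # Inside a running loop __call__ returns self.async_call(...), i.e. a
    # coroutine, so it must be awaited.
    return await client("Hello", model="claude-3-5-sonnet-20241022")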
@@ -97,6 +97,26 @@ class OpenAIProxyLLMModelAdapter(ProxyLLMModelAdapter):
        return chatgpt_generate_stream


class ClaudeProxyLLMModelAdapter(ProxyLLMModelAdapter):
    def support_async(self) -> bool:
        return True

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "claude_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.claude import ClaudeLLMClient

        return ClaudeLLMClient

    def get_async_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.claude import claude_generate_stream

        return claude_generate_stream


class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "tongyi_proxyllm"
@@ -320,6 +340,7 @@ class DeepseekProxyLLMModelAdapter(ProxyLLMModelAdapter):


register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(ClaudeProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(OllamaLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)
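The adapter is selected purely by the lower-cased model name, so configuring the model as claude_proxyllm routes requests through the new async Claude stream function. A small sketch of that matching logic (the module path is assumed from the surrounding hunk; it is not named explicitly in the diff):

# Sketch: module path assumed to be the proxy adapter module shown above.
from dbgpt.model.adapter.proxy_adapter import ClaudeProxyLLMModelAdapter

adapter = ClaudeProxyLLMModelAdapter()
assert adapter.do_match("claude_proxyllm")  # matched by model name
assert adapter.support_async()              # streamed via claude_generate_stream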
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import Future
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, Iterator, List, Optional
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional

from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelMetadata, ModelOutput
@@ -113,7 +113,9 @@ class WorkerManager(ABC):
        """Shutdown model instance"""

    @abstractmethod
    async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
    async def generate_stream(
        self, params: Dict, **kwargs
    ) -> AsyncIterator[ModelOutput]:
        """Generate stream result, chat scene"""

    @abstractmethod
@@ -9,7 +9,7 @@ import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Iterator
from typing import AsyncIterator, Awaitable, Callable, Iterator

from fastapi import APIRouter
from fastapi.responses import StreamingResponse
@@ -327,7 +327,7 @@ class LocalWorkerManager(WorkerManager):

    async def generate_stream(
        self, params: Dict, async_wrapper=None, **kwargs
    ) -> Iterator[ModelOutput]:
    ) -> AsyncIterator[ModelOutput]:
        """Generate stream result, chat scene"""
        with root_tracer.start_span(
            "WorkerManager.generate_stream", params.get("span_id")
@@ -693,7 +693,9 @@ class WorkerManagerAdapter(WorkerManager):
            worker_type, model_name, healthy_only
        )

    async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
    async def generate_stream(
        self, params: Dict, **kwargs
    ) -> AsyncIterator[ModelOutput]:
        async for output in self.worker_manager.generate_stream(params, **kwargs):
            yield output
@@ -1,9 +1,25 @@
"""Proxy models."""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
    from dbgpt.model.proxy.llms.claude import ClaudeLLMClient
    from dbgpt.model.proxy.llms.deepseek import DeepseekLLMClient
    from dbgpt.model.proxy.llms.gemini import GeminiLLMClient
    from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient
    from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
    from dbgpt.model.proxy.llms.spark import SparkLLMClient
    from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient
    from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient
    from dbgpt.model.proxy.llms.yi import YiLLMClient
    from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient


def __lazy_import(name):
    module_path = {
        "OpenAILLMClient": "dbgpt.model.proxy.llms.chatgpt",
        "ClaudeLLMClient": "dbgpt.model.proxy.llms.claude",
        "GeminiLLMClient": "dbgpt.model.proxy.llms.gemini",
        "SparkLLMClient": "dbgpt.model.proxy.llms.spark",
        "TongyiLLMClient": "dbgpt.model.proxy.llms.tongyi",
@@ -28,6 +44,7 @@ def __getattr__(name):

__all__ = [
    "OpenAILLMClient",
    "ClaudeLLMClient",
    "GeminiLLMClient",
    "TongyiLLMClient",
    "ZhipuLLMClient",
@@ -34,6 +34,25 @@ class ProxyTokenizer(ABC):
            List[int]: token count, -1 if failed
        """

    def support_async(self) -> bool:
        """Check if the tokenizer supports asynchronous counting token.

        Returns:
            bool: True if supports, False otherwise
        """
        return False

    async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
        """Count token of given prompts asynchronously.
        Args:
            model_name (str): model name
            prompts (List[str]): prompts to count token

        Returns:
            List[int]: token count, -1 if failed
        """
        raise NotImplementedError()


class TiktokenProxyTokenizer(ProxyTokenizer):
    def __init__(self):
@@ -92,7 +111,7 @@ class ProxyLLMClient(LLMClient):
        self.model_names = model_names
        self.context_length = context_length
        self.executor = executor or ThreadPoolExecutor()
        self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()
        self._proxy_tokenizer = proxy_tokenizer

    def __getstate__(self):
        """Customize the serialization of the object"""
@@ -105,6 +124,17 @@ class ProxyLLMClient(LLMClient):
        self.__dict__.update(state)
        self.executor = ThreadPoolExecutor()

    @property
    def proxy_tokenizer(self) -> ProxyTokenizer:
        """Get proxy tokenizer

        Returns:
            ProxyTokenizer: proxy tokenizer
        """
        if not self._proxy_tokenizer:
            self._proxy_tokenizer = TiktokenProxyTokenizer()
        return self._proxy_tokenizer

    @classmethod
    @abstractmethod
    def new_client(
@@ -257,6 +287,9 @@ class ProxyLLMClient(LLMClient):
        Returns:
            int: token count, -1 if failed
        """
        if self.proxy_tokenizer.support_async():
            cnts = await self.proxy_tokenizer.count_token_async(model, [prompt])
            return cnts[0]
        counts = await blocking_func_to_async(
            self.executor, self.proxy_tokenizer.count_token, model, [prompt]
        )
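ProxyLLMClient.count_tokens now prefers an asynchronous tokenizer when one advertises support_async, and only falls back to running the synchronous counter inside the executor. A hedged sketch of a custom tokenizer that opts into the async path (the class below is illustrative, not part of DB-GPT):

# Sketch: a tokenizer that takes the async branch added in this hunk.
from typing import List

from dbgpt.model.proxy.base import ProxyTokenizer


class LengthEstimateTokenizer(ProxyTokenizer):
    """Illustrative tokenizer: estimates tokens from character length."""

    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        return [max(1, len(p) // 4) for p in prompts]

    def support_async(self) -> bool:
        # Returning True makes ProxyLLMClient.count_tokens await
        # count_token_async instead of using blocking_func_to_async.
        return True

    async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
        return self.count_token(model_name, prompts)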
@@ -5,17 +5,11 @@ import logging
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union

from dbgpt.core import (
    MessageConverter,
    ModelMetadata,
    ModelOutput,
    ModelRequest,
    ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
from dbgpt.util.i18n_utils import _

@@ -32,15 +26,7 @@ async def chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    client: OpenAILLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
    )
    request = parse_model_request(params, client.default_model, stream=True)
    async for r in client.generate_stream(request):
        yield r

@@ -191,6 +177,8 @@ class OpenAILLMClient(ProxyLLMClient):
            payload["max_tokens"] = request.max_new_tokens
        if request.stop:
            payload["stop"] = request.stop
        if request.top_p:
            payload["top_p"] = request.top_p
        return payload

    async def generate(
@@ -1,7 +1,275 @@
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
import logging
import os
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, cast

from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import (
    ProxyLLMClient,
    ProxyTokenizer,
    TiktokenProxyTokenizer,
)
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

if TYPE_CHECKING:
    from anthropic import AsyncAnthropic, ProxiesTypes

logger = logging.getLogger(__name__)


def claude_generate_stream(
async def claude_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    yield "claude LLM was not supported!"
) -> AsyncIterator[ModelOutput]:
    client: ClaudeLLMClient = cast(ClaudeLLMClient, model.proxy_llm_client)
    request = parse_model_request(params, client.default_model, stream=True)
    async for r in client.generate_stream(request):
        yield r


class ClaudeLLMClient(ProxyLLMClient):
    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model: Optional[str] = None,
        proxies: Optional["ProxiesTypes"] = None,
        timeout: Optional[int] = 240,
        model_alias: Optional[str] = "claude_proxyllm",
        context_length: Optional[int] = 8192,
        client: Optional["AsyncAnthropic"] = None,
        claude_kwargs: Optional[Dict[str, Any]] = None,
        proxy_tokenizer: Optional[ProxyTokenizer] = None,
    ):
        try:
            import anthropic
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: anthropic "
                "Please install anthropic by command `pip install anthropic`"
            ) from exc
        if not model:
            model = "claude-3-5-sonnet-20241022"
        self._client = client
        self._model = model
        self._api_key = api_key
        self._api_base = api_base or os.environ.get(
            "ANTHROPIC_BASE_URL", "https://api.anthropic.com"
        )
        self._proxies = proxies
        self._timeout = timeout
        self._claude_kwargs = claude_kwargs or {}
        self._model_alias = model_alias
        self._proxy_tokenizer = proxy_tokenizer

        super().__init__(
            model_names=[model_alias],
            context_length=context_length,
            proxy_tokenizer=proxy_tokenizer,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "ClaudeLLMClient":
        return cls(
            api_key=model_params.proxy_api_key,
            api_base=model_params.proxy_api_base,
            # api_type=model_params.proxy_api_type,
            # api_version=model_params.proxy_api_version,
            model=model_params.proxyllm_backend,
            proxies=model_params.http_proxy,
            model_alias=model_params.model_name,
            context_length=max(model_params.max_context_size, 8192),
        )

    @property
    def client(self) -> "AsyncAnthropic":
        from anthropic import AsyncAnthropic

        if self._client is None:
            self._client = AsyncAnthropic(
                api_key=self._api_key,
                base_url=self._api_base,
                proxies=self._proxies,
                timeout=self._timeout,
            )
        return self._client

    @property
    def proxy_tokenizer(self) -> ProxyTokenizer:
        if not self._proxy_tokenizer:
            self._proxy_tokenizer = ClaudeProxyTokenizer(self.client)
        return self._proxy_tokenizer

    @property
    def default_model(self) -> str:
        """Default model name"""
        model = self._model
        if not model:
            model = "claude-3-5-sonnet-20241022"
        return model

    def _build_request(
        self, request: ModelRequest, stream: Optional[bool] = False
    ) -> Dict[str, Any]:
        payload = {"stream": stream}
        model = request.model or self.default_model
        payload["model"] = model
        # Apply claude kwargs
        for k, v in self._claude_kwargs.items():
            payload[k] = v
        if request.temperature:
            payload["temperature"] = request.temperature
        if request.max_new_tokens:
            payload["max_tokens"] = request.max_new_tokens
        if request.stop:
            payload["stop"] = request.stop
        if request.top_p:
            payload["top_p"] = request.top_p
        return payload

    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        request = self.local_covert_message(request, message_converter)
        messages, system_messages = request.split_messages()
        payload = self._build_request(request)
        logger.info(
            f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
        )
        try:
            if len(system_messages) > 1:
                raise ValueError("Claude only supports single system message")
            if system_messages:
                payload["system"] = system_messages[0]
            if "max_tokens" not in payload:
                max_tokens = 1024
            else:
                max_tokens = payload["max_tokens"]
                del payload["max_tokens"]
            response = await self.client.messages.create(
                max_tokens=max_tokens,
                messages=messages,
                **payload,
            )
            usage = None
            finish_reason = response.stop_reason
            if response.usage:
                usage = {
                    "prompt_tokens": response.usage.input_tokens,
                    "completion_tokens": response.usage.output_tokens,
                }
            response_content = response.content
            if not response_content:
                raise ValueError("Response content is empty")
            return ModelOutput(
                text=response_content[0].text,
                error_code=0,
                finish_reason=finish_reason,
                usage=usage,
            )
        except Exception as e:
            return ModelOutput(
                text=f"**Claude Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )

    async def generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> AsyncIterator[ModelOutput]:
        request = self.local_covert_message(request, message_converter)
        messages, system_messages = request.split_messages()
        payload = self._build_request(request, stream=True)
        logger.info(
            f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
        )
        try:
            if len(system_messages) > 1:
                raise ValueError("Claude only supports single system message")
            if system_messages:
                payload["system"] = system_messages[0]
            if "max_tokens" not in payload:
                max_tokens = 1024
            else:
                max_tokens = payload["max_tokens"]
                del payload["max_tokens"]
            if "stream" in payload:
                del payload["stream"]
            full_text = ""
            async with self.client.messages.stream(
                max_tokens=max_tokens,
                messages=messages,
                **payload,
            ) as stream:
                async for text in stream.text_stream:
                    full_text += text
                    usage = {
                        "prompt_tokens": stream.current_message_snapshot.usage.input_tokens,
                        "completion_tokens": stream.current_message_snapshot.usage.output_tokens,
                    }
                    yield ModelOutput(text=full_text, error_code=0, usage=usage)
        except Exception as e:
            yield ModelOutput(
                text=f"**Claude Generate Stream Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )

    async def models(self) -> List[ModelMetadata]:
        model_metadata = ModelMetadata(
            model=self._model_alias,
            context_length=await self.get_context_length(),
        )
        return [model_metadata]

    async def get_context_length(self) -> int:
        """Get the context length of the model.

        Returns:
            int: The context length.
        # TODO: This is a temporary solution. We should have a better way to get the context length.
        eg. get real context length from the openai api.
        """
        return self.context_length


class ClaudeProxyTokenizer(ProxyTokenizer):
    def __init__(self, client: "AsyncAnthropic", concurrency_limit: int = 10):
        self.client = client
        self.concurrency_limit = concurrency_limit
        self._tiktoken_tokenizer = TiktokenProxyTokenizer()

    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        # Use tiktoken to count token in local environment
        return self._tiktoken_tokenizer.count_token(model_name, prompts)

    def support_async(self) -> bool:
        return True

    async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
        """Count token of given messages.

        This is relying on the claude beta API, which is not available for some users.
        """
        from dbgpt.util.chat_util import run_async_tasks

        tasks = []
        model_name = model_name or "claude-3-5-sonnet-20241022"
        for prompt in prompts:
            request = ModelRequest(
                model=model_name, messages=[{"role": "user", "content": prompt}]
            )
            tasks.append(
                self.client.beta.messages.count_tokens(
                    model=model_name,
                    messages=request.messages,
                )
            )
        results = await run_async_tasks(tasks, self.concurrency_limit)
        return results
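For a quick end-to-end check of the new client, something like the following should work once the anthropic package is installed and ANTHROPIC_API_KEY is set (a debugging sketch that uses the async_call helpers added above; it is not part of the commit):

# Sketch: the AsyncAnthropic client reads ANTHROPIC_API_KEY when api_key is None.
import asyncio

from dbgpt.model.proxy import ClaudeLLMClient


async def main():
    client = ClaudeLLMClient(model="claude-3-5-sonnet-20241022")

    # async_call is the debug helper added to LLMClient in this commit.
    output = await client.async_call("Give me one sentence about DB-GPT.")
    print(output.text)

    # Streaming variant: yields progressively longer ModelOutput.text values.
    async for chunk in client.async_call_stream("Count from 1 to 5."):
        print(chunk.text)


asyncio.run(main())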
@@ -1,8 +1,7 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast

from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

from .chatgpt import OpenAILLMClient

@@ -20,15 +19,7 @@ async def deepseek_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    client: DeepseekLLMClient = cast(DeepseekLLMClient, model.proxy_llm_client)
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
    )
    request = parse_model_request(params, client.default_model, stream=True)
    async for r in client.generate_stream(request):
        yield r
@@ -2,17 +2,11 @@ import os
from concurrent.futures import Executor
from typing import Any, Dict, Iterator, List, Optional, Tuple

from dbgpt.core import (
    MessageConverter,
    ModelMessage,
    ModelOutput,
    ModelRequest,
    ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelMessage, ModelOutput, ModelRequest
from dbgpt.core.interface.message import parse_model_messages
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

GEMINI_DEFAULT_MODEL = "gemini-pro"

@@ -39,15 +33,7 @@ def gemini_generate_stream(
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")
    client: GeminiLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
    )
    request = parse_model_request(params, client.default_model, stream=True)
    for r in client.sync_generate_stream(request):
        yield r
@@ -1,8 +1,8 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast

from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core import ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

from .chatgpt import OpenAILLMClient

@@ -19,15 +19,7 @@ async def moonshot_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    client: MoonshotLLMClient = cast(MoonshotLLMClient, model.proxy_llm_client)
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
    )
    request = parse_model_request(params, client.default_model, stream=True)
    async for r in client.generate_stream(request):
        yield r
@@ -2,10 +2,10 @@ import logging
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

logger = logging.getLogger(__name__)

@@ -14,14 +14,7 @@ def ollama_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=4096
):
    client: OllamaLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    request = parse_model_request(params, client.default_model, stream=True)
    for r in client.sync_generate_stream(request):
        yield r
@@ -1,8 +1,9 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
@@ -41,3 +42,30 @@ class ProxyModel:
            int: token count, -1 if failed
        """
        return self._tokenizer.count_token(messages, model_name)


def parse_model_request(
    params: Dict[str, Any], default_model: str, stream: bool = True
) -> ModelRequest:
    """Parse model request from params.

    Args:
        params (Dict[str, Any]): request params
        default_model (str): default model name
        stream (bool, optional): whether stream. Defaults to True.
    """
    context = ModelRequestContext(
        stream=stream,
        user_name=params.get("user_name"),
        request_id=params.get("request_id"),
    )
    request = ModelRequest.build_request(
        default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
        top_p=params.get("top_p"),
    )
    return request
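parse_model_request is the helper that replaces the request-building boilerplate previously repeated in every proxy *_generate_stream function. A sketch of the params dict it expects, with field names taken from the function body (extra fields are simply ignored):

# Sketch: building a ModelRequest from worker params with the new helper.
from dbgpt.model.proxy.llms.proxy_model import parse_model_request

params = {
    "messages": [{"role": "user", "content": "Hello Claude"}],
    "temperature": 0.5,
    "max_new_tokens": 512,
    "user_name": "alice",     # optional, carried into ModelRequestContext
    "request_id": "req-123",  # optional
}

request = parse_model_request(params, "claude-3-5-sonnet-20241022", stream=True)
# request.model          -> "claude-3-5-sonnet-20241022"
# request.context.stream -> True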
@@ -1,12 +1,12 @@
import json
import os
from concurrent.futures import Executor
from typing import AsyncIterator, Optional
from typing import Iterator, Optional

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request


def getlength(text):
@@ -28,20 +28,8 @@ def spark_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    client: SparkLLMClient = model.proxy_llm_client
    context = ModelRequestContext(
        stream=True,
        user_name=params.get("user_name"),
        request_id=params.get("request_id"),
    )
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
    )
    for r in client.generate_stream(request):
    request = parse_model_request(params, client.default_model, stream=True)
    for r in client.sync_generate_stream(request):
        yield r


@@ -141,11 +129,11 @@ class SparkLLMClient(ProxyLLMClient):
    def default_model(self) -> str:
        return self._model

    def generate_stream(
    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> AsyncIterator[ModelOutput]:
    ) -> Iterator[ModelOutput]:
        """
        reference:
        https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
@@ -2,10 +2,10 @@ import logging
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

logger = logging.getLogger(__name__)

@@ -14,15 +14,7 @@ def tongyi_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    client: TongyiLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
    )
    request = parse_model_request(params, client.default_model, stream=True)
    for r in client.sync_generate_stream(request):
        yield r
@@ -2,19 +2,12 @@ import json
import logging
import os
from concurrent.futures import Executor
from typing import Iterator, List, Optional
from typing import Iterator, Optional

import requests
from cachetools import TTLCache, cached

from dbgpt.core import (
    MessageConverter,
    ModelMessage,
    ModelMessageRoleType,
    ModelOutput,
    ModelRequest,
    ModelRequestContext,
)
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@@ -51,26 +44,16 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
    return res.json().get("access_token")


def _to_wenxin_messages(messages: List[ModelMessage]):
def _to_wenxin_messages(request: ModelRequest):
    """Convert messages to wenxin compatible format

    See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
    """
    wenxin_messages = []
    system_messages = []
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            wenxin_messages.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            wenxin_messages.append({"role": "assistant", "content": message.content})
        else:
            pass
    messages, system_messages = request.split_messages()
    if len(system_messages) > 1:
        raise ValueError("Wenxin only support one system message")
    str_system_message = system_messages[0] if len(system_messages) > 0 else ""
    return wenxin_messages, str_system_message
    return messages, str_system_message


def wenxin_generate_stream(
@@ -167,7 +150,7 @@ class WenxinLLMClient(ProxyLLMClient):
                "Failed to get access token. please set the correct api_key and secret key."
            )

        history, system_message = _to_wenxin_messages(request.get_messages())
        history, system_message = _to_wenxin_messages(request)
        payload = {
            "messages": history,
            "system": system_message,
@@ -1,8 +1,7 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

from .chatgpt import OpenAILLMClient

@@ -19,15 +18,7 @@ async def yi_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    client: YiLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
    )
    request = parse_model_request(params, client.default_model, stream=True)
    async for r in client.generate_stream(request):
        yield r
@@ -2,10 +2,10 @@ import os
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

CHATGLM_DEFAULT_MODEL = "chatglm_pro"

@@ -21,15 +21,7 @@ def zhipu_generate_stream(
    # convert_to_compatible_format = params.get("convert_to_compatible_format", False)
    # history, systems = __convert_2_zhipu_messages(messages)
    client: ZhipuLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
        stop=params.get("stop"),
    )
    request = parse_model_request(params, client.default_model, stream=True)
    for r in client.sync_generate_stream(request):
        yield r
setup.py
@@ -670,6 +670,13 @@ def openai_requires():
    setup_spec.extras["openai"] += setup_spec.extras["rag"]


def proxy_requires():
    """
    pip install "dbgpt[proxy]"
    """
    setup_spec.extras["proxy"] = setup_spec.extras["openai"] + ["anthropic"]


def gpt4all_requires():
    """
    pip install "dbgpt[gpt4all]"
    """
@@ -727,6 +734,7 @@ def default_requires():
    setup_spec.extras["default"] += setup_spec.extras["datasource"]
    setup_spec.extras["default"] += setup_spec.extras["torch"]
    setup_spec.extras["default"] += setup_spec.extras["cache"]
    setup_spec.extras["default"] += setup_spec.extras["proxy"]
    setup_spec.extras["default"] += setup_spec.extras["code"]
    if INCLUDE_QUANTIZATION:
        # Add quantization extra to default, default is True
@@ -763,6 +771,7 @@ cache_requires()
observability_requires()

openai_requires()
proxy_requires()
# must be last
default_requires()
all_requires()