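"""Claude (Anthropic) proxy LLM client and tokenizer for DB-GPT."""
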
import logging
import os
from concurrent.futures import Executor
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, cast

from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import (
    ProxyLLMClient,
    ProxyTokenizer,
    TiktokenProxyTokenizer,
)
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

if TYPE_CHECKING:
    from anthropic import AsyncAnthropic, ProxiesTypes

logger = logging.getLogger(__name__)


async def claude_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
) -> AsyncIterator[ModelOutput]:
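    """Adapt a DB-GPT worker streaming call to the Claude proxy client.

    Looks up the ``ClaudeLLMClient`` bound to ``model``, builds a streaming
    request from ``params``, and re-yields each partial ``ModelOutput``.
    """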
    client: ClaudeLLMClient = cast(ClaudeLLMClient, model.proxy_llm_client)
    request = parse_model_request(params, client.default_model, stream=True)
    async for r in client.generate_stream(request):
        yield r


class ClaudeLLMClient(ProxyLLMClient):
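    """Proxy LLM client that forwards DB-GPT requests to Anthropic's Claude API."""
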
    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model: Optional[str] = None,
        proxies: Optional["ProxiesTypes"] = None,
        timeout: Optional[int] = 240,
        model_alias: Optional[str] = "claude_proxyllm",
        context_length: Optional[int] = 8192,
        client: Optional["AsyncAnthropic"] = None,
        claude_kwargs: Optional[Dict[str, Any]] = None,
        proxy_tokenizer: Optional[ProxyTokenizer] = None,
    ):
        try:
            import anthropic  # noqa: F401
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: anthropic. "
                "Please install it with `pip install anthropic`."
            ) from exc
        if not model:
            model = "claude-3-5-sonnet-20241022"
        self._client = client
        self._model = model
        self._api_key = api_key
        self._api_base = api_base or os.environ.get(
            "ANTHROPIC_BASE_URL", "https://api.anthropic.com"
        )
        self._proxies = proxies
        self._timeout = timeout
        self._claude_kwargs = claude_kwargs or {}
        self._model_alias = model_alias
        self._proxy_tokenizer = proxy_tokenizer

        super().__init__(
            model_names=[model_alias],
            context_length=context_length,
            proxy_tokenizer=proxy_tokenizer,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "ClaudeLLMClient":
        return cls(
            api_key=model_params.proxy_api_key,
            api_base=model_params.proxy_api_base,
            model=model_params.proxyllm_backend,
            proxies=model_params.http_proxy,
            model_alias=model_params.model_name,
            context_length=max(model_params.max_context_size, 8192),
        )

    @property
    def client(self) -> "AsyncAnthropic":
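        """Lazily create and cache the underlying ``AsyncAnthropic`` client."""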
        from anthropic import AsyncAnthropic

        if self._client is None:
            self._client = AsyncAnthropic(
                api_key=self._api_key,
                base_url=self._api_base,
                proxies=self._proxies,
                timeout=self._timeout,
            )
        return self._client

    @property
    def proxy_tokenizer(self) -> ProxyTokenizer:
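        """Return the tokenizer, creating a Claude-backed one on first use."""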
        if not self._proxy_tokenizer:
            self._proxy_tokenizer = ClaudeProxyTokenizer(self.client)
        return self._proxy_tokenizer

    @property
    def default_model(self) -> str:
        """Default model name."""
        model = self._model
        if not model:
            model = "claude-3-5-sonnet-20241022"
        return model

    def _build_request(
        self, request: ModelRequest, stream: Optional[bool] = False
    ) -> Dict[str, Any]:
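        """Translate a ``ModelRequest`` into keyword arguments for the Claude API."""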
        payload = {"stream": stream}
        model = request.model or self.default_model
        payload["model"] = model
        # Apply extra Claude kwargs configured for this client
        for k, v in self._claude_kwargs.items():
            payload[k] = v
        if request.temperature:
            payload["temperature"] = request.temperature
        if request.max_new_tokens:
            payload["max_tokens"] = request.max_new_tokens
        if request.stop:
            # The Anthropic SDK expects ``stop_sequences`` (a list), not ``stop``
            stop = request.stop
            payload["stop_sequences"] = stop if isinstance(stop, list) else [stop]
        if request.top_p:
            payload["top_p"] = request.top_p
        return payload

    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
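        """Generate a complete (non-streaming) response from Claude."""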
        request = self.local_covert_message(request, message_converter)
        messages, system_messages = request.split_messages()
        payload = self._build_request(request)
        logger.info(
            f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
        )
        try:
            if len(system_messages) > 1:
                raise ValueError("Claude only supports a single system message")
            if system_messages:
                payload["system"] = system_messages[0]
            # ``max_tokens`` is required by the Anthropic messages API, so pop it
            # from the payload and pass it explicitly, defaulting to 1024.
            if "max_tokens" not in payload:
                max_tokens = 1024
            else:
                max_tokens = payload["max_tokens"]
                del payload["max_tokens"]
            response = await self.client.messages.create(
                max_tokens=max_tokens,
                messages=messages,
                **payload,
            )
            usage = None
            finish_reason = response.stop_reason
            if response.usage:
                usage = {
                    "prompt_tokens": response.usage.input_tokens,
                    "completion_tokens": response.usage.output_tokens,
                }
            response_content = response.content
            if not response_content:
                raise ValueError("Response content is empty")
            return ModelOutput(
                text=response_content[0].text,
                error_code=0,
                finish_reason=finish_reason,
                usage=usage,
            )
        except Exception as e:
            return ModelOutput(
                text=f"**Claude Generate Error, Please Check Error Info.**: {e}",
                error_code=1,
            )

    async def generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> AsyncIterator[ModelOutput]:
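        """Stream a response from Claude, yielding the accumulated text so far."""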
        request = self.local_covert_message(request, message_converter)
        messages, system_messages = request.split_messages()
        payload = self._build_request(request, stream=True)
        logger.info(
            f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
        )
        try:
            if len(system_messages) > 1:
                raise ValueError("Claude only supports a single system message")
            if system_messages:
                payload["system"] = system_messages[0]
            if "max_tokens" not in payload:
                max_tokens = 1024
            else:
                max_tokens = payload["max_tokens"]
                del payload["max_tokens"]
            # ``messages.stream`` always streams, so the flag must not be
            # passed through to the SDK.
            if "stream" in payload:
                del payload["stream"]
            full_text = ""
            async with self.client.messages.stream(
                max_tokens=max_tokens,
                messages=messages,
                **payload,
            ) as stream:
                async for text in stream.text_stream:
                    full_text += text
                    usage = {
                        "prompt_tokens": stream.current_message_snapshot.usage.input_tokens,
                        "completion_tokens": stream.current_message_snapshot.usage.output_tokens,
                    }
                    yield ModelOutput(text=full_text, error_code=0, usage=usage)
        except Exception as e:
            yield ModelOutput(
                text=f"**Claude Generate Stream Error, Please Check Error Info.**: {e}",
                error_code=1,
            )

    async def models(self) -> List[ModelMetadata]:
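        """List metadata for the model served by this client."""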
        model_metadata = ModelMetadata(
            model=self._model_alias,
            context_length=await self.get_context_length(),
        )
        return [model_metadata]

    async def get_context_length(self) -> int:
        """Get the context length of the model.

        Returns:
            int: The context length.

        TODO: This is a temporary solution; we should have a better way to get
        the context length, e.g. fetch the real context length from the
        Anthropic API.
        """
        return self.context_length


class ClaudeProxyTokenizer(ProxyTokenizer):
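    """Tokenizer that counts tokens locally (tiktoken) or via the Claude API."""
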
    def __init__(self, client: "AsyncAnthropic", concurrency_limit: int = 10):
        self.client = client
        self.concurrency_limit = concurrency_limit
        self._tiktoken_tokenizer = TiktokenProxyTokenizer()

    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        # Use tiktoken to count tokens in the local environment. This is only
        # an approximation for Claude models, which use a different vocabulary.
        return self._tiktoken_tokenizer.count_token(model_name, prompts)

    def support_async(self) -> bool:
        return True

    async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]:
        """Count tokens of the given prompts via the Claude API.

        This relies on the Claude beta token-counting API, which is not
        available for some users.
        """
        from dbgpt.util.chat_util import run_async_tasks

        tasks = []
        model_name = model_name or "claude-3-5-sonnet-20241022"
        for prompt in prompts:
            request = ModelRequest(
                model=model_name, messages=[{"role": "user", "content": prompt}]
            )
            tasks.append(
                self.client.beta.messages.count_tokens(
                    model=model_name,
                    messages=request.messages,
                )
            )
        results = await run_async_tasks(tasks, self.concurrency_limit)
        # Each result is a token-count response object; extract the integer
        # input token counts so the return type matches ``List[int]``.
        return [r.input_tokens for r in results]