diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index e88779d18..a7876824b 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -131,6 +131,15 @@ class Config(metaclass=Singleton): os.environ["deepseek_proxyllm_api_base"] = os.getenv( "DEEPSEEK_API_BASE", "https://api.deepseek.com/v1" ) + self.claude_proxy_api_key = os.getenv("ANTHROPIC_API_KEY") + if self.claude_proxy_api_key: + os.environ["claude_proxyllm_proxy_api_key"] = self.claude_proxy_api_key + os.environ["claude_proxyllm_proxyllm_backend"] = os.getenv( + "ANTHROPIC_MODEL_VERSION", "claude-3-5-sonnet-20241022" + ) + os.environ["claude_proxyllm_api_base"] = os.getenv( + "ANTHROPIC_BASE_URL", "https://api.anthropic.com" + ) self.proxy_server_url = os.getenv("PROXY_SERVER_URL") diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index edac25b67..a6b3031e4 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -6,7 +6,7 @@ import logging import time from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, field -from typing import Any, AsyncIterator, Dict, List, Optional, Union +from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Tuple, Union from cachetools import TTLCache @@ -394,6 +394,29 @@ class ModelRequest: """ return ModelMessage.messages_to_string(self.get_messages()) + def split_messages(self) -> Tuple[List[Dict[str, Any]], List[str]]: + """Split the messages. + + Returns: + Tuple[List[Dict[str, Any]], List[str]]: The common messages and system + messages. + """ + messages = self.get_messages() + common_messages = [] + system_messages = [] + for message in messages: + if message.role == ModelMessageRoleType.HUMAN: + common_messages.append({"role": "user", "content": message.content}) + elif message.role == ModelMessageRoleType.SYSTEM: + system_messages.append(message.content) + elif message.role == ModelMessageRoleType.AI: + common_messages.append( + {"role": "assistant", "content": message.content} + ) + else: + pass + return common_messages, system_messages + @dataclass class ModelExtraMedata(BaseParameters): @@ -861,7 +884,9 @@ class LLMClient(ABC): raise ValueError(f"Model {model} not found") return model_metadata - def __call__(self, *args, **kwargs) -> ModelOutput: + def __call__( + self, *args, **kwargs + ) -> Coroutine[Any, Any, ModelOutput] | ModelOutput: """Return the model output. Call the LLM client to generate the response for the given message. @@ -869,22 +894,63 @@ class LLMClient(ABC): Please do not use this method in the production environment, it is only used for debugging. """ + import asyncio + from dbgpt.util import get_or_create_event_loop + try: + # Check if we are in an event loop + loop = asyncio.get_running_loop() + # If we are in an event loop, use async call + if loop.is_running(): + # Because we are in an async environment, but this is a sync method, + # we need to return a coroutine object for the caller to use await + return self.async_call(*args, **kwargs) + else: + loop = get_or_create_event_loop() + return loop.run_until_complete(self.async_call(*args, **kwargs)) + except RuntimeError: + # If we are not in an event loop, use sync call + loop = get_or_create_event_loop() + return loop.run_until_complete(self.async_call(*args, **kwargs)) + + async def async_call(self, *args, **kwargs) -> ModelOutput: + """Return the model output asynchronously. + + Please do not use this method in the production environment, it is only used + for debugging. 
+ """ + req = self._build_call_request(*args, **kwargs) + return await self.generate(req) + + async def async_call_stream(self, *args, **kwargs) -> AsyncIterator[ModelOutput]: + """Return the model output stream asynchronously. + + Please do not use this method in the production environment, it is only used + for debugging. + """ + req = self._build_call_request(*args, **kwargs) + async for output in self.generate_stream(req): # type: ignore + yield output + + def _build_call_request(self, *args, **kwargs) -> ModelRequest: + """Build the model request for the call method.""" messages = kwargs.get("messages") model = kwargs.get("model") + if messages: del kwargs["messages"] model_messages = ModelMessage.from_openai_messages(messages) else: model_messages = [ModelMessage.build_human_message(args[0])] + if not model: if hasattr(self, "default_model"): model = getattr(self, "default_model") else: raise ValueError("The default model is not set") + if "model" in kwargs: del kwargs["model"] - req = ModelRequest.build_request(model, model_messages, **kwargs) - loop = get_or_create_event_loop() - return loop.run_until_complete(self.generate(req)) + + return ModelRequest.build_request(model, model_messages, **kwargs) diff --git a/dbgpt/model/adapter/proxy_adapter.py b/dbgpt/model/adapter/proxy_adapter.py index a84f4f330..d211393b0 100644 --- a/dbgpt/model/adapter/proxy_adapter.py +++ b/dbgpt/model/adapter/proxy_adapter.py @@ -97,6 +97,26 @@ class OpenAIProxyLLMModelAdapter(ProxyLLMModelAdapter): return chatgpt_generate_stream +class ClaudeProxyLLMModelAdapter(ProxyLLMModelAdapter): + def support_async(self) -> bool: + return True + + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return lower_model_name_or_path == "claude_proxyllm" + + def get_llm_client_class( + self, params: ProxyModelParameters + ) -> Type[ProxyLLMClient]: + from dbgpt.model.proxy.llms.claude import ClaudeLLMClient + + return ClaudeLLMClient + + def get_async_generate_stream_function(self, model, model_path: str): + from dbgpt.model.proxy.llms.claude import claude_generate_stream + + return claude_generate_stream + + class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter): def do_match(self, lower_model_name_or_path: Optional[str] = None): return lower_model_name_or_path == "tongyi_proxyllm" @@ -320,6 +340,7 @@ class DeepseekProxyLLMModelAdapter(ProxyLLMModelAdapter): register_model_adapter(OpenAIProxyLLMModelAdapter) +register_model_adapter(ClaudeProxyLLMModelAdapter) register_model_adapter(TongyiProxyLLMModelAdapter) register_model_adapter(OllamaLLMModelAdapter) register_model_adapter(ZhipuProxyLLMModelAdapter) diff --git a/dbgpt/model/cluster/manager_base.py b/dbgpt/model/cluster/manager_base.py index 37b09b0fd..3a4594d56 100644 --- a/dbgpt/model/cluster/manager_base.py +++ b/dbgpt/model/cluster/manager_base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from concurrent.futures import Future from dataclasses import dataclass from datetime import datetime -from typing import Callable, Dict, Iterator, List, Optional +from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional from dbgpt.component import BaseComponent, ComponentType, SystemApp from dbgpt.core import ModelMetadata, ModelOutput @@ -113,7 +113,9 @@ class WorkerManager(ABC): """Shutdown model instance""" @abstractmethod - async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]: + async def generate_stream( + self, params: Dict, **kwargs + ) -> AsyncIterator[ModelOutput]: """Generate stream result, 
chat scene""" @abstractmethod diff --git a/dbgpt/model/cluster/worker/manager.py b/dbgpt/model/cluster/worker/manager.py index 32714a303..022854a85 100644 --- a/dbgpt/model/cluster/worker/manager.py +++ b/dbgpt/model/cluster/worker/manager.py @@ -9,7 +9,7 @@ import time import traceback from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict -from typing import Awaitable, Callable, Iterator +from typing import AsyncIterator, Awaitable, Callable, Iterator from fastapi import APIRouter from fastapi.responses import StreamingResponse @@ -327,7 +327,7 @@ class LocalWorkerManager(WorkerManager): async def generate_stream( self, params: Dict, async_wrapper=None, **kwargs - ) -> Iterator[ModelOutput]: + ) -> AsyncIterator[ModelOutput]: """Generate stream result, chat scene""" with root_tracer.start_span( "WorkerManager.generate_stream", params.get("span_id") @@ -693,7 +693,9 @@ class WorkerManagerAdapter(WorkerManager): worker_type, model_name, healthy_only ) - async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]: + async def generate_stream( + self, params: Dict, **kwargs + ) -> AsyncIterator[ModelOutput]: async for output in self.worker_manager.generate_stream(params, **kwargs): yield output diff --git a/dbgpt/model/proxy/__init__.py b/dbgpt/model/proxy/__init__.py index e3d3a21b4..716fa3475 100644 --- a/dbgpt/model/proxy/__init__.py +++ b/dbgpt/model/proxy/__init__.py @@ -1,9 +1,25 @@ """Proxy models.""" +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient + from dbgpt.model.proxy.llms.claude import ClaudeLLMClient + from dbgpt.model.proxy.llms.deepseek import DeepseekLLMClient + from dbgpt.model.proxy.llms.gemini import GeminiLLMClient + from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient + from dbgpt.model.proxy.llms.ollama import OllamaLLMClient + from dbgpt.model.proxy.llms.spark import SparkLLMClient + from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient + from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient + from dbgpt.model.proxy.llms.yi import YiLLMClient + from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient + def __lazy_import(name): module_path = { "OpenAILLMClient": "dbgpt.model.proxy.llms.chatgpt", + "ClaudeLLMClient": "dbgpt.model.proxy.llms.claude", "GeminiLLMClient": "dbgpt.model.proxy.llms.gemini", "SparkLLMClient": "dbgpt.model.proxy.llms.spark", "TongyiLLMClient": "dbgpt.model.proxy.llms.tongyi", @@ -28,6 +44,7 @@ def __getattr__(name): __all__ = [ "OpenAILLMClient", + "ClaudeLLMClient", "GeminiLLMClient", "TongyiLLMClient", "ZhipuLLMClient", diff --git a/dbgpt/model/proxy/base.py b/dbgpt/model/proxy/base.py index 2a1a3b6b8..ff135a616 100644 --- a/dbgpt/model/proxy/base.py +++ b/dbgpt/model/proxy/base.py @@ -34,6 +34,25 @@ class ProxyTokenizer(ABC): List[int]: token count, -1 if failed """ + def support_async(self) -> bool: + """Check if the tokenizer supports asynchronous counting token. + + Returns: + bool: True if supports, False otherwise + """ + return False + + async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]: + """Count token of given prompts asynchronously. 
+ Args: + model_name (str): model name + prompts (List[str]): prompts to count token + + Returns: + List[int]: token count, -1 if failed + """ + raise NotImplementedError() + class TiktokenProxyTokenizer(ProxyTokenizer): def __init__(self): @@ -92,7 +111,7 @@ class ProxyLLMClient(LLMClient): self.model_names = model_names self.context_length = context_length self.executor = executor or ThreadPoolExecutor() - self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer() + self._proxy_tokenizer = proxy_tokenizer def __getstate__(self): """Customize the serialization of the object""" @@ -105,6 +124,17 @@ class ProxyLLMClient(LLMClient): self.__dict__.update(state) self.executor = ThreadPoolExecutor() + @property + def proxy_tokenizer(self) -> ProxyTokenizer: + """Get proxy tokenizer + + Returns: + ProxyTokenizer: proxy tokenizer + """ + if not self._proxy_tokenizer: + self._proxy_tokenizer = TiktokenProxyTokenizer() + return self._proxy_tokenizer + @classmethod @abstractmethod def new_client( @@ -257,6 +287,9 @@ class ProxyLLMClient(LLMClient): Returns: int: token count, -1 if failed """ + if self.proxy_tokenizer.support_async(): + cnts = await self.proxy_tokenizer.count_token_async(model, [prompt]) + return cnts[0] counts = await blocking_func_to_async( self.executor, self.proxy_tokenizer.count_token, model, [prompt] ) diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py index 1b2c2135a..fd8b2e92f 100755 --- a/dbgpt/model/proxy/llms/chatgpt.py +++ b/dbgpt/model/proxy/llms/chatgpt.py @@ -5,17 +5,11 @@ import logging from concurrent.futures import Executor from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union -from dbgpt.core import ( - MessageConverter, - ModelMetadata, - ModelOutput, - ModelRequest, - ModelRequestContext, -) +from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.proxy.base import ProxyLLMClient -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request from dbgpt.model.utils.chatgpt_utils import OpenAIParameters from dbgpt.util.i18n_utils import _ @@ -32,15 +26,7 @@ async def chatgpt_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): client: OpenAILLMClient = model.proxy_llm_client - context = ModelRequestContext(stream=True, user_name=params.get("user_name")) - request = ModelRequest.build_request( - client.default_model, - messages=params["messages"], - temperature=params.get("temperature"), - context=context, - max_new_tokens=params.get("max_new_tokens"), - stop=params.get("stop"), - ) + request = parse_model_request(params, client.default_model, stream=True) async for r in client.generate_stream(request): yield r @@ -191,6 +177,8 @@ class OpenAILLMClient(ProxyLLMClient): payload["max_tokens"] = request.max_new_tokens if request.stop: payload["stop"] = request.stop + if request.top_p: + payload["top_p"] = request.top_p return payload async def generate( diff --git a/dbgpt/model/proxy/llms/claude.py b/dbgpt/model/proxy/llms/claude.py index 0d86e7937..e3355b2b8 100644 --- a/dbgpt/model/proxy/llms/claude.py +++ b/dbgpt/model/proxy/llms/claude.py @@ -1,7 +1,275 @@ -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +import logging +import os +from concurrent.futures import Executor +from typing import TYPE_CHECKING, Any, 
AsyncIterator, Dict, List, Optional, cast
+
+from dbgpt.core import MessageConverter, ModelMetadata, ModelOutput, ModelRequest
+from dbgpt.model.parameter import ProxyModelParameters
+from dbgpt.model.proxy.base import (
+    ProxyLLMClient,
+    ProxyTokenizer,
+    TiktokenProxyTokenizer,
+)
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
+
+if TYPE_CHECKING:
+    from anthropic import AsyncAnthropic, ProxiesTypes
+
+logger = logging.getLogger(__name__)
 
 
-def claude_generate_stream(
+async def claude_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
-):
-    yield "claude LLM was not supported!"
+) -> AsyncIterator[ModelOutput]:
+    client: ClaudeLLMClient = cast(ClaudeLLMClient, model.proxy_llm_client)
+    request = parse_model_request(params, client.default_model, stream=True)
+    async for r in client.generate_stream(request):
+        yield r
+
+
+class ClaudeLLMClient(ProxyLLMClient):
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        model: Optional[str] = None,
+        proxies: Optional["ProxiesTypes"] = None,
+        timeout: Optional[int] = 240,
+        model_alias: Optional[str] = "claude_proxyllm",
+        context_length: Optional[int] = 8192,
+        client: Optional["AsyncAnthropic"] = None,
+        claude_kwargs: Optional[Dict[str, Any]] = None,
+        proxy_tokenizer: Optional[ProxyTokenizer] = None,
+    ):
+        try:
+            import anthropic
+        except ImportError as exc:
+            raise ValueError(
+                "Could not import python package: anthropic. "
+                "Please install anthropic by command `pip install anthropic`"
+            ) from exc
+        if not model:
+            model = "claude-3-5-sonnet-20241022"
+        self._client = client
+        self._model = model
+        self._api_key = api_key
+        self._api_base = api_base or os.environ.get(
+            "ANTHROPIC_BASE_URL", "https://api.anthropic.com"
+        )
+        self._proxies = proxies
+        self._timeout = timeout
+        self._claude_kwargs = claude_kwargs or {}
+        self._model_alias = model_alias
+        self._proxy_tokenizer = proxy_tokenizer
+
+        super().__init__(
+            model_names=[model_alias],
+            context_length=context_length,
+            proxy_tokenizer=proxy_tokenizer,
+        )
+
+    @classmethod
+    def new_client(
+        cls,
+        model_params: ProxyModelParameters,
+        default_executor: Optional[Executor] = None,
+    ) -> "ClaudeLLMClient":
+        return cls(
+            api_key=model_params.proxy_api_key,
+            api_base=model_params.proxy_api_base,
+            # api_type=model_params.proxy_api_type,
+            # api_version=model_params.proxy_api_version,
+            model=model_params.proxyllm_backend,
+            proxies=model_params.http_proxy,
+            model_alias=model_params.model_name,
+            context_length=max(model_params.max_context_size, 8192),
+        )
+
+    @property
+    def client(self) -> "AsyncAnthropic":
+        from anthropic import AsyncAnthropic
+
+        if self._client is None:
+            self._client = AsyncAnthropic(
+                api_key=self._api_key,
+                base_url=self._api_base,
+                proxies=self._proxies,
+                timeout=self._timeout,
+            )
+        return self._client
+
+    @property
+    def proxy_tokenizer(self) -> ProxyTokenizer:
+        if not self._proxy_tokenizer:
+            self._proxy_tokenizer = ClaudeProxyTokenizer(self.client)
+        return self._proxy_tokenizer
+
+    @property
+    def default_model(self) -> str:
+        """Default model name"""
+        model = self._model
+        if not model:
+            model = "claude-3-5-sonnet-20241022"
+        return model
+
+    def _build_request(
+        self, request: ModelRequest, stream: Optional[bool] = False
+    ) -> Dict[str, Any]:
+        payload = {"stream": stream}
+        model = request.model or self.default_model
+        payload["model"] = model
+        # Apply claude kwargs
+        for k, v in self._claude_kwargs.items():
+            payload[k] = v
+        if request.temperature:
+            payload["temperature"] = request.temperature
+        if request.max_new_tokens:
+            payload["max_tokens"] = request.max_new_tokens
+        if request.stop:
+            payload["stop"] = request.stop
+        if request.top_p:
+            payload["top_p"] = request.top_p
+        return payload
+
+    async def generate(
+        self,
+        request: ModelRequest,
+        message_converter: Optional[MessageConverter] = None,
+    ) -> ModelOutput:
+        request = self.local_covert_message(request, message_converter)
+        messages, system_messages = request.split_messages()
+        payload = self._build_request(request)
+        logger.info(
+            f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
+        )
+        try:
+            if len(system_messages) > 1:
+                raise ValueError("Claude only supports a single system message")
+            if system_messages:
+                payload["system"] = system_messages[0]
+            if "max_tokens" not in payload:
+                max_tokens = 1024
+            else:
+                max_tokens = payload["max_tokens"]
+                del payload["max_tokens"]
+            response = await self.client.messages.create(
+                max_tokens=max_tokens,
+                messages=messages,
+                **payload,
+            )
+            usage = None
+            finish_reason = response.stop_reason
+            if response.usage:
+                usage = {
+                    "prompt_tokens": response.usage.input_tokens,
+                    "completion_tokens": response.usage.output_tokens,
+                }
+            response_content = response.content
+            if not response_content:
+                raise ValueError("Response content is empty")
+            return ModelOutput(
+                text=response_content[0].text,
+                error_code=0,
+                finish_reason=finish_reason,
+                usage=usage,
+            )
+        except Exception as e:
+            return ModelOutput(
+                text=f"**Claude Generate Error, Please Check Error Info.**: {e}",
+                error_code=1,
+            )
+
+    async def generate_stream(
+        self,
+        request: ModelRequest,
+        message_converter: Optional[MessageConverter] = None,
+    ) -> AsyncIterator[ModelOutput]:
+        request = self.local_covert_message(request, message_converter)
+        messages, system_messages = request.split_messages()
+        payload = self._build_request(request, stream=True)
+        logger.info(
+            f"Send request to claude, payload: {payload}\n\n messages:\n{messages}"
+        )
+        try:
+            if len(system_messages) > 1:
+                raise ValueError("Claude only supports a single system message")
+            if system_messages:
+                payload["system"] = system_messages[0]
+            if "max_tokens" not in payload:
+                max_tokens = 1024
+            else:
+                max_tokens = payload["max_tokens"]
+                del payload["max_tokens"]
+            if "stream" in payload:
+                del payload["stream"]
+            full_text = ""
+            async with self.client.messages.stream(
+                max_tokens=max_tokens,
+                messages=messages,
+                **payload,
+            ) as stream:
+                async for text in stream.text_stream:
+                    full_text += text
+                    usage = {
+                        "prompt_tokens": stream.current_message_snapshot.usage.input_tokens,
+                        "completion_tokens": stream.current_message_snapshot.usage.output_tokens,
+                    }
+                    yield ModelOutput(text=full_text, error_code=0, usage=usage)
+        except Exception as e:
+            yield ModelOutput(
+                text=f"**Claude Generate Stream Error, Please Check Error Info.**: {e}",
+                error_code=1,
+            )
+
+    async def models(self) -> List[ModelMetadata]:
+        model_metadata = ModelMetadata(
+            model=self._model_alias,
+            context_length=await self.get_context_length(),
+        )
+        return [model_metadata]
+
+    async def get_context_length(self) -> int:
+        """Get the context length of the model.
+
+        Returns:
+            int: The context length.
+        # TODO: This is a temporary solution. We should have a better way to get the context length.
+        e.g. get the real context length from the Anthropic API.
+ """ + return self.context_length + + +class ClaudeProxyTokenizer(ProxyTokenizer): + def __init__(self, client: "AsyncAnthropic", concurrency_limit: int = 10): + self.client = client + self.concurrency_limit = concurrency_limit + self._tiktoken_tokenizer = TiktokenProxyTokenizer() + + def count_token(self, model_name: str, prompts: List[str]) -> List[int]: + # Use tiktoken to count token in local environment + return self._tiktoken_tokenizer.count_token(model_name, prompts) + + def support_async(self) -> bool: + return True + + async def count_token_async(self, model_name: str, prompts: List[str]) -> List[int]: + """Count token of given messages. + + This is relying on the claude beta API, which is not available for some users. + """ + from dbgpt.util.chat_util import run_async_tasks + + tasks = [] + model_name = model_name or "claude-3-5-sonnet-20241022" + for prompt in prompts: + request = ModelRequest( + model=model_name, messages=[{"role": "user", "content": prompt}] + ) + tasks.append( + self.client.beta.messages.count_tokens( + model=model_name, + messages=request.messages, + ) + ) + results = await run_async_tasks(tasks, self.concurrency_limit) + return results diff --git a/dbgpt/model/proxy/llms/deepseek.py b/dbgpt/model/proxy/llms/deepseek.py index 6823acc51..0e93580b7 100644 --- a/dbgpt/model/proxy/llms/deepseek.py +++ b/dbgpt/model/proxy/llms/deepseek.py @@ -1,8 +1,7 @@ import os from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast -from dbgpt.core import ModelRequest, ModelRequestContext -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request from .chatgpt import OpenAILLMClient @@ -20,15 +19,7 @@ async def deepseek_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): client: DeepseekLLMClient = cast(DeepseekLLMClient, model.proxy_llm_client) - context = ModelRequestContext(stream=True, user_name=params.get("user_name")) - request = ModelRequest.build_request( - client.default_model, - messages=params["messages"], - temperature=params.get("temperature"), - context=context, - max_new_tokens=params.get("max_new_tokens"), - stop=params.get("stop"), - ) + request = parse_model_request(params, client.default_model, stream=True) async for r in client.generate_stream(request): yield r diff --git a/dbgpt/model/proxy/llms/gemini.py b/dbgpt/model/proxy/llms/gemini.py index f37f0b2d2..043ef4cd6 100644 --- a/dbgpt/model/proxy/llms/gemini.py +++ b/dbgpt/model/proxy/llms/gemini.py @@ -2,17 +2,11 @@ import os from concurrent.futures import Executor from typing import Any, Dict, Iterator, List, Optional, Tuple -from dbgpt.core import ( - MessageConverter, - ModelMessage, - ModelOutput, - ModelRequest, - ModelRequestContext, -) +from dbgpt.core import MessageConverter, ModelMessage, ModelOutput, ModelRequest from dbgpt.core.interface.message import parse_model_messages from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.proxy.base import ProxyLLMClient -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request GEMINI_DEFAULT_MODEL = "gemini-pro" @@ -39,15 +33,7 @@ def gemini_generate_stream( model_params = model.get_params() print(f"Model: {model}, model_params: {model_params}") client: GeminiLLMClient = model.proxy_llm_client - context = ModelRequestContext(stream=True, user_name=params.get("user_name")) - request = ModelRequest.build_request( - client.default_model, - 
messages=params["messages"], - temperature=params.get("temperature"), - context=context, - max_new_tokens=params.get("max_new_tokens"), - stop=params.get("stop"), - ) + request = parse_model_request(params, client.default_model, stream=True) for r in client.sync_generate_stream(request): yield r diff --git a/dbgpt/model/proxy/llms/moonshot.py b/dbgpt/model/proxy/llms/moonshot.py index ecf6474fd..0412e1e61 100644 --- a/dbgpt/model/proxy/llms/moonshot.py +++ b/dbgpt/model/proxy/llms/moonshot.py @@ -1,8 +1,8 @@ import os from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast -from dbgpt.core import ModelRequest, ModelRequestContext -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.core import ModelRequestContext +from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request from .chatgpt import OpenAILLMClient @@ -19,15 +19,7 @@ async def moonshot_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): client: MoonshotLLMClient = cast(MoonshotLLMClient, model.proxy_llm_client) - context = ModelRequestContext(stream=True, user_name=params.get("user_name")) - request = ModelRequest.build_request( - client.default_model, - messages=params["messages"], - temperature=params.get("temperature"), - context=context, - max_new_tokens=params.get("max_new_tokens"), - stop=params.get("stop"), - ) + request = parse_model_request(params, client.default_model, stream=True) async for r in client.generate_stream(request): yield r diff --git a/dbgpt/model/proxy/llms/ollama.py b/dbgpt/model/proxy/llms/ollama.py index d48a6b1b2..0f440c002 100644 --- a/dbgpt/model/proxy/llms/ollama.py +++ b/dbgpt/model/proxy/llms/ollama.py @@ -2,10 +2,10 @@ import logging from concurrent.futures import Executor from typing import Iterator, Optional -from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext +from dbgpt.core import MessageConverter, ModelOutput, ModelRequest from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.proxy.base import ProxyLLMClient -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request logger = logging.getLogger(__name__) @@ -14,14 +14,7 @@ def ollama_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=4096 ): client: OllamaLLMClient = model.proxy_llm_client - context = ModelRequestContext(stream=True, user_name=params.get("user_name")) - request = ModelRequest.build_request( - client.default_model, - messages=params["messages"], - temperature=params.get("temperature"), - context=context, - max_new_tokens=params.get("max_new_tokens"), - ) + request = parse_model_request(params, client.default_model, stream=True) for r in client.sync_generate_stream(request): yield r diff --git a/dbgpt/model/proxy/llms/proxy_model.py b/dbgpt/model/proxy/llms/proxy_model.py index 3ee3c67fd..7d0204af3 100644 --- a/dbgpt/model/proxy/llms/proxy_model.py +++ b/dbgpt/model/proxy/llms/proxy_model.py @@ -1,8 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from dbgpt.core import ModelRequest, ModelRequestContext from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.proxy.base import ProxyLLMClient from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper @@ -41,3 +42,30 @@ class ProxyModel: int: token count, -1 if failed """ return 
self._tokenizer.count_token(messages, model_name) + + +def parse_model_request( + params: Dict[str, Any], default_model: str, stream: bool = True +) -> ModelRequest: + """Parse model request from params. + + Args: + params (Dict[str, Any]): request params + default_model (str): default model name + stream (bool, optional): whether stream. Defaults to True. + """ + context = ModelRequestContext( + stream=stream, + user_name=params.get("user_name"), + request_id=params.get("request_id"), + ) + request = ModelRequest.build_request( + default_model, + messages=params["messages"], + temperature=params.get("temperature"), + context=context, + max_new_tokens=params.get("max_new_tokens"), + stop=params.get("stop"), + top_p=params.get("top_p"), + ) + return request diff --git a/dbgpt/model/proxy/llms/spark.py b/dbgpt/model/proxy/llms/spark.py index cd71f74d1..bc14c8b9c 100644 --- a/dbgpt/model/proxy/llms/spark.py +++ b/dbgpt/model/proxy/llms/spark.py @@ -1,12 +1,12 @@ import json import os from concurrent.futures import Executor -from typing import AsyncIterator, Optional +from typing import Iterator, Optional -from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext +from dbgpt.core import MessageConverter, ModelOutput, ModelRequest from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.proxy.base import ProxyLLMClient -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request def getlength(text): @@ -28,20 +28,8 @@ def spark_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): client: SparkLLMClient = model.proxy_llm_client - context = ModelRequestContext( - stream=True, - user_name=params.get("user_name"), - request_id=params.get("request_id"), - ) - request = ModelRequest.build_request( - client.default_model, - messages=params["messages"], - temperature=params.get("temperature"), - context=context, - max_new_tokens=params.get("max_new_tokens"), - stop=params.get("stop"), - ) - for r in client.generate_stream(request): + request = parse_model_request(params, client.default_model, stream=True) + for r in client.sync_generate_stream(request): yield r @@ -141,11 +129,11 @@ class SparkLLMClient(ProxyLLMClient): def default_model(self) -> str: return self._model - def generate_stream( + def sync_generate_stream( self, request: ModelRequest, message_converter: Optional[MessageConverter] = None, - ) -> AsyncIterator[ModelOutput]: + ) -> Iterator[ModelOutput]: """ reference: https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E diff --git a/dbgpt/model/proxy/llms/tongyi.py b/dbgpt/model/proxy/llms/tongyi.py index 0709f3fec..75d18fec6 100644 --- a/dbgpt/model/proxy/llms/tongyi.py +++ b/dbgpt/model/proxy/llms/tongyi.py @@ -2,10 +2,10 @@ import logging from concurrent.futures import Executor from typing import Iterator, Optional -from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext +from dbgpt.core import MessageConverter, ModelOutput, ModelRequest from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.proxy.base import ProxyLLMClient -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request logger = logging.getLogger(__name__) @@ -14,15 +14,7 @@ def tongyi_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): client: 
TongyiLLMClient = model.proxy_llm_client - context = ModelRequestContext(stream=True, user_name=params.get("user_name")) - request = ModelRequest.build_request( - client.default_model, - messages=params["messages"], - temperature=params.get("temperature"), - context=context, - max_new_tokens=params.get("max_new_tokens"), - stop=params.get("stop"), - ) + request = parse_model_request(params, client.default_model, stream=True) for r in client.sync_generate_stream(request): yield r diff --git a/dbgpt/model/proxy/llms/wenxin.py b/dbgpt/model/proxy/llms/wenxin.py index 8d797aeae..73f206592 100644 --- a/dbgpt/model/proxy/llms/wenxin.py +++ b/dbgpt/model/proxy/llms/wenxin.py @@ -2,19 +2,12 @@ import json import logging import os from concurrent.futures import Executor -from typing import Iterator, List, Optional +from typing import Iterator, Optional import requests from cachetools import TTLCache, cached -from dbgpt.core import ( - MessageConverter, - ModelMessage, - ModelMessageRoleType, - ModelOutput, - ModelRequest, - ModelRequestContext, -) +from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.proxy.base import ProxyLLMClient from dbgpt.model.proxy.llms.proxy_model import ProxyModel @@ -51,26 +44,16 @@ def _build_access_token(api_key: str, secret_key: str) -> str: return res.json().get("access_token") -def _to_wenxin_messages(messages: List[ModelMessage]): +def _to_wenxin_messages(request: ModelRequest): """Convert messages to wenxin compatible format See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11 """ - wenxin_messages = [] - system_messages = [] - for message in messages: - if message.role == ModelMessageRoleType.HUMAN: - wenxin_messages.append({"role": "user", "content": message.content}) - elif message.role == ModelMessageRoleType.SYSTEM: - system_messages.append(message.content) - elif message.role == ModelMessageRoleType.AI: - wenxin_messages.append({"role": "assistant", "content": message.content}) - else: - pass + messages, system_messages = request.split_messages() if len(system_messages) > 1: raise ValueError("Wenxin only support one system message") str_system_message = system_messages[0] if len(system_messages) > 0 else "" - return wenxin_messages, str_system_message + return messages, str_system_message def wenxin_generate_stream( @@ -167,7 +150,7 @@ class WenxinLLMClient(ProxyLLMClient): "Failed to get access token. please set the correct api_key and secret key." 
) - history, system_message = _to_wenxin_messages(request.get_messages()) + history, system_message = _to_wenxin_messages(request) payload = { "messages": history, "system": system_message, diff --git a/dbgpt/model/proxy/llms/yi.py b/dbgpt/model/proxy/llms/yi.py index 990b1e489..f0aa24934 100644 --- a/dbgpt/model/proxy/llms/yi.py +++ b/dbgpt/model/proxy/llms/yi.py @@ -1,8 +1,7 @@ import os from typing import TYPE_CHECKING, Any, Dict, Optional, Union -from dbgpt.core import ModelRequest, ModelRequestContext -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request from .chatgpt import OpenAILLMClient @@ -19,15 +18,7 @@ async def yi_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): client: YiLLMClient = model.proxy_llm_client - context = ModelRequestContext(stream=True, user_name=params.get("user_name")) - request = ModelRequest.build_request( - client.default_model, - messages=params["messages"], - temperature=params.get("temperature"), - context=context, - max_new_tokens=params.get("max_new_tokens"), - stop=params.get("stop"), - ) + request = parse_model_request(params, client.default_model, stream=True) async for r in client.generate_stream(request): yield r diff --git a/dbgpt/model/proxy/llms/zhipu.py b/dbgpt/model/proxy/llms/zhipu.py index 294d8d1a7..e96e19b3b 100644 --- a/dbgpt/model/proxy/llms/zhipu.py +++ b/dbgpt/model/proxy/llms/zhipu.py @@ -2,10 +2,10 @@ import os from concurrent.futures import Executor from typing import Iterator, Optional -from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext +from dbgpt.core import MessageConverter, ModelOutput, ModelRequest from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.proxy.base import ProxyLLMClient -from dbgpt.model.proxy.llms.proxy_model import ProxyModel +from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request CHATGLM_DEFAULT_MODEL = "chatglm_pro" @@ -21,15 +21,7 @@ def zhipu_generate_stream( # convert_to_compatible_format = params.get("convert_to_compatible_format", False) # history, systems = __convert_2_zhipu_messages(messages) client: ZhipuLLMClient = model.proxy_llm_client - context = ModelRequestContext(stream=True, user_name=params.get("user_name")) - request = ModelRequest.build_request( - client.default_model, - messages=params["messages"], - temperature=params.get("temperature"), - context=context, - max_new_tokens=params.get("max_new_tokens"), - stop=params.get("stop"), - ) + request = parse_model_request(params, client.default_model, stream=True) for r in client.sync_generate_stream(request): yield r diff --git a/setup.py b/setup.py index b5e092466..14b3ba720 100644 --- a/setup.py +++ b/setup.py @@ -670,6 +670,13 @@ def openai_requires(): setup_spec.extras["openai"] += setup_spec.extras["rag"] +def proxy_requires(): + """ + pip install "dbgpt[proxy]" + """ + setup_spec.extras["proxy"] = setup_spec.extras["openai"] + ["anthropic"] + + def gpt4all_requires(): """ pip install "dbgpt[gpt4all]" @@ -727,6 +734,7 @@ def default_requires(): setup_spec.extras["default"] += setup_spec.extras["datasource"] setup_spec.extras["default"] += setup_spec.extras["torch"] setup_spec.extras["default"] += setup_spec.extras["cache"] + setup_spec.extras["default"] += setup_spec.extras["proxy"] setup_spec.extras["default"] += setup_spec.extras["code"] if INCLUDE_QUANTIZATION: # Add quantization extra to default, default is True @@ -763,6 +771,7 @@ 
 cache_requires()
 observability_requires()
 openai_requires()
+proxy_requires()
 # must be last
 default_requires()
 all_requires()
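
Reviewer note (not part of the patch): a minimal usage sketch for the new Claude proxy client. It assumes `pip install anthropic` and a valid ANTHROPIC_API_KEY in the environment, and it only uses APIs added or touched by this diff (ClaudeLLMClient, ModelMessage, ModelRequest.build_request, and the debug-only async_call helper); for the server path, the new Config block reads ANTHROPIC_API_KEY / ANTHROPIC_MODEL_VERSION / ANTHROPIC_BASE_URL and the adapter matches the model name "claude_proxyllm". Treat the snippet as illustrative, not a supported entry point.

import asyncio
import os

from dbgpt.core import ModelMessage, ModelRequest
from dbgpt.model.proxy import ClaudeLLMClient


async def main():
    # The model defaults to claude-3-5-sonnet-20241022 when not specified.
    client = ClaudeLLMClient(api_key=os.getenv("ANTHROPIC_API_KEY"))

    messages = [ModelMessage.build_human_message("Say hello in one sentence.")]
    request = ModelRequest.build_request(
        client.default_model,
        messages=messages,
        temperature=0.7,
        max_new_tokens=128,
    )

    # One-shot generation.
    output = await client.generate(request)
    print(output.text)

    # Streaming generation; consume ModelOutput chunks as they arrive.
    async for chunk in client.generate_stream(request):
        print(chunk.text)

    # Debug-only helper added in this patch (not for production use).
    output = await client.async_call("What is DB-GPT?")
    print(output.text)


asyncio.run(main())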