feat(model): Support claude proxy models (#2155)
@@ -6,7 +6,7 @@ import logging
 import time
 from abc import ABC, abstractmethod
 from dataclasses import asdict, dataclass, field
-from typing import Any, AsyncIterator, Dict, List, Optional, Union
+from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Tuple, Union
 
 from cachetools import TTLCache
 
@@ -394,6 +394,29 @@ class ModelRequest:
         """
         return ModelMessage.messages_to_string(self.get_messages())
 
+    def split_messages(self) -> Tuple[List[Dict[str, Any]], List[str]]:
+        """Split the messages.
+
+        Returns:
+            Tuple[List[Dict[str, Any]], List[str]]: The common messages and system
+                messages.
+        """
+        messages = self.get_messages()
+        common_messages = []
+        system_messages = []
+        for message in messages:
+            if message.role == ModelMessageRoleType.HUMAN:
+                common_messages.append({"role": "user", "content": message.content})
+            elif message.role == ModelMessageRoleType.SYSTEM:
+                system_messages.append(message.content)
+            elif message.role == ModelMessageRoleType.AI:
+                common_messages.append(
+                    {"role": "assistant", "content": message.content}
+                )
+            else:
+                pass
+        return common_messages, system_messages
+
 
 @dataclass
 class ModelExtraMedata(BaseParameters):
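
The new `split_messages` helper matters for this commit because Anthropic-style APIs take the system prompt as a separate parameter rather than as an entry in the message list. A minimal sketch of the resulting shapes, assuming the usual `dbgpt.core` exports and that `build_request` keeps the message list as given (the model name is only an example):

    from dbgpt.core import ModelMessage, ModelRequest

    messages = [
        ModelMessage(role="system", content="You are a helpful assistant."),
        ModelMessage(role="human", content="What is DB-GPT?"),
    ]
    req = ModelRequest.build_request("claude-3-5-sonnet-20241022", messages)

    common_messages, system_messages = req.split_messages()
    # common_messages -> [{"role": "user", "content": "What is DB-GPT?"}]
    # system_messages -> ["You are a helpful assistant."]
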
@@ -861,7 +884,9 @@ class LLMClient(ABC):
             raise ValueError(f"Model {model} not found")
         return model_metadata
 
-    def __call__(self, *args, **kwargs) -> ModelOutput:
+    def __call__(
+        self, *args, **kwargs
+    ) -> Coroutine[Any, Any, ModelOutput] | ModelOutput:
         """Return the model output.
 
         Call the LLM client to generate the response for the given message.
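
With the widened signature, `__call__` now returns either a `ModelOutput` (no event loop running) or a coroutine (called from inside a running loop). A context-agnostic caller can tell the two apart; the helper below is hypothetical and not part of this diff, and `client` stands for any concrete `LLMClient`:

    import asyncio

    def call_blocking(client, prompt: str):
        # Hypothetical helper: normalize __call__'s union return type for
        # callers that must stay synchronous.
        result = client(prompt)
        if asyncio.iscoroutine(result):
            # We are inside a running event loop; a sync caller cannot block
            # on it here, so close the coroutine and surface the misuse.
            result.close()
            raise RuntimeError("use 'await client.async_call(...)' in async code")
        return result
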
@@ -869,22 +894,63 @@ class LLMClient(ABC):
         Please do not use this method in the production environment, it is only used
         for debugging.
         """
-        from dbgpt.util import get_or_create_event_loop
-
-        messages = kwargs.get("messages")
-        model = kwargs.get("model")
-        if messages:
-            del kwargs["messages"]
-            model_messages = ModelMessage.from_openai_messages(messages)
-        else:
-            model_messages = [ModelMessage.build_human_message(args[0])]
-        if not model:
-            if hasattr(self, "default_model"):
-                model = getattr(self, "default_model")
-            else:
-                raise ValueError("The default model is not set")
-        if "model" in kwargs:
-            del kwargs["model"]
-        req = ModelRequest.build_request(model, model_messages, **kwargs)
-        loop = get_or_create_event_loop()
-        return loop.run_until_complete(self.generate(req))
+        import asyncio
+
+        from dbgpt.util import get_or_create_event_loop
+
+        try:
+            # Check if we are in an event loop
+            loop = asyncio.get_running_loop()
+            # If we are in an event loop, use async call
+            if loop.is_running():
+                # Because we are in an async environment, but this is a sync method,
+                # we need to return a coroutine object for the caller to use await
+                return self.async_call(*args, **kwargs)
+            else:
+                loop = get_or_create_event_loop()
+                return loop.run_until_complete(self.async_call(*args, **kwargs))
+        except RuntimeError:
+            # If we are not in an event loop, use sync call
+            loop = get_or_create_event_loop()
+            return loop.run_until_complete(self.async_call(*args, **kwargs))
+
+    async def async_call(self, *args, **kwargs) -> ModelOutput:
+        """Return the model output asynchronously.
+
+        Please do not use this method in the production environment, it is only used
+        for debugging.
+        """
+        req = self._build_call_request(*args, **kwargs)
+        return await self.generate(req)
+
+    async def async_call_stream(self, *args, **kwargs) -> AsyncIterator[ModelOutput]:
+        """Return the model output stream asynchronously.
+
+        Please do not use this method in the production environment, it is only used
+        for debugging.
+        """
+        req = self._build_call_request(*args, **kwargs)
+        async for output in self.generate_stream(req):  # type: ignore
+            yield output
+
+    def _build_call_request(self, *args, **kwargs) -> ModelRequest:
+        """Build the model request for the call method."""
+        messages = kwargs.get("messages")
+        model = kwargs.get("model")
+
+        if messages:
+            del kwargs["messages"]
+            model_messages = ModelMessage.from_openai_messages(messages)
+        else:
+            model_messages = [ModelMessage.build_human_message(args[0])]
+
+        if not model:
+            if hasattr(self, "default_model"):
+                model = getattr(self, "default_model")
+            else:
+                raise ValueError("The default model is not set")
+
+        if "model" in kwargs:
+            del kwargs["model"]
+
+        return ModelRequest.build_request(model, model_messages, **kwargs)
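
Taken together, the refactor routes all three debugging entry points through `_build_call_request`, which accepts either a positional prompt or OpenAI-style `messages`. A usage sketch, assuming one of the existing proxy clients (the class, constructor arguments, and model names are illustrative, not part of this diff):

    import asyncio

    from dbgpt.model.proxy import OpenAILLMClient  # any LLMClient subclass

    client = OpenAILLMClient(model="gpt-4o")

    # Sync context: __call__ drives an event loop itself and blocks.
    output = client("Say hello")  # positional prompt
    output = client(messages=[{"role": "user", "content": "Say hello"}])
    print(output.text)

    # Async context: __call__ would hand back a coroutine, so use the
    # async variants directly.
    async def main():
        out = await client.async_call("Say hello")
        print(out.text)
        async for chunk in client.async_call_stream("Count to three"):
            print(chunk.text)

    asyncio.run(main())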