feat(model): Support claude proxy models (#2155)

Fangyin Cheng
2024-11-26 19:47:28 +08:00
committed by GitHub
parent 9d8673a02f
commit 61509dc5ea
20 changed files with 508 additions and 157 deletions


@@ -6,7 +6,7 @@ import logging
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
-from typing import Any, AsyncIterator, Dict, List, Optional, Union
+from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Tuple, Union
from cachetools import TTLCache
@@ -394,6 +394,29 @@ class ModelRequest:
"""
return ModelMessage.messages_to_string(self.get_messages())
def split_messages(self) -> Tuple[List[Dict[str, Any]], List[str]]:
"""Split the messages.
Returns:
Tuple[List[Dict[str, Any]], List[str]]: The common messages and system
messages.
"""
messages = self.get_messages()
common_messages = []
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
common_messages.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
common_messages.append(
{"role": "assistant", "content": message.content}
)
else:
pass
return common_messages, system_messages
@dataclass
class ModelExtraMedata(BaseParameters):
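Not part of the diff, for context: a minimal usage sketch of the new split_messages helper above. It assumes ModelMessage, ModelMessageRoleType and ModelRequest are importable from dbgpt.core as elsewhere in the repo; the model name, message contents and the anthropic_client call are made-up examples. Claude-style APIs take the system prompt as a separate field rather than as a chat turn, which is what the split provides.

from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest

messages = [
    ModelMessage(
        role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."
    ),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is DB-GPT?"),
]
request = ModelRequest.build_request("claude-3-5-sonnet-20241022", messages)

common_messages, system_messages = request.split_messages()
# common_messages -> [{"role": "user", "content": "What is DB-GPT?"}]
# system_messages -> ["You are a helpful assistant."]

# A Claude proxy client can then send the two parts separately, e.g.
# (anthropic_client is hypothetical here):
# anthropic_client.messages.create(
#     model=request.model,
#     system="\n".join(system_messages),
#     messages=common_messages,
# )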
@@ -861,7 +884,9 @@ class LLMClient(ABC):
raise ValueError(f"Model {model} not found")
return model_metadata
def __call__(self, *args, **kwargs) -> ModelOutput:
def __call__(
self, *args, **kwargs
) -> Coroutine[Any, Any, ModelOutput] | ModelOutput:
"""Return the model output.
Call the LLM client to generate the response for the given message.
@@ -869,22 +894,63 @@ class LLMClient(ABC):
        Please do not use this method in the production environment, it is only used
        for debugging.
        """
        import asyncio

        from dbgpt.util import get_or_create_event_loop

        try:
            # Check if we are in an event loop
            loop = asyncio.get_running_loop()
            # If we are in an event loop, use async call
            if loop.is_running():
                # Because we are in an async environment, but this is a sync method,
                # we need to return a coroutine object for the caller to use await
                return self.async_call(*args, **kwargs)
            else:
                loop = get_or_create_event_loop()
                return loop.run_until_complete(self.async_call(*args, **kwargs))
        except RuntimeError:
            # If we are not in an event loop, use sync call
            loop = get_or_create_event_loop()
            return loop.run_until_complete(self.async_call(*args, **kwargs))

    async def async_call(self, *args, **kwargs) -> ModelOutput:
        """Return the model output asynchronously.

        Please do not use this method in the production environment, it is only used
        for debugging.
        """
        req = self._build_call_request(*args, **kwargs)
        return await self.generate(req)

    async def async_call_stream(self, *args, **kwargs) -> AsyncIterator[ModelOutput]:
        """Return the model output stream asynchronously.

        Please do not use this method in the production environment, it is only used
        for debugging.
        """
        req = self._build_call_request(*args, **kwargs)
        async for output in self.generate_stream(req):  # type: ignore
            yield output

    def _build_call_request(self, *args, **kwargs) -> ModelRequest:
        """Build the model request for the call method."""
        messages = kwargs.get("messages")
        model = kwargs.get("model")
        if messages:
            del kwargs["messages"]
            model_messages = ModelMessage.from_openai_messages(messages)
        else:
            model_messages = [ModelMessage.build_human_message(args[0])]
        if not model:
            if hasattr(self, "default_model"):
                model = getattr(self, "default_model")
            else:
                raise ValueError("The default model is not set")
        if "model" in kwargs:
            del kwargs["model"]
-        req = ModelRequest.build_request(model, model_messages, **kwargs)
-        loop = get_or_create_event_loop()
-        return loop.run_until_complete(self.generate(req))
+        return ModelRequest.build_request(model, model_messages, **kwargs)
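
Not part of the diff, for context: a sketch of the calling convention the reworked __call__, async_call and async_call_stream methods support. OpenAILLMClient is used only because it is an existing concrete LLMClient in the repo; a Claude proxy client built on the same base class would be called the same way. The model names are examples, the client is assumed to pick up its API key from the environment, and, as the docstrings note, these helpers are meant for debugging rather than production use.

import asyncio

from dbgpt.core import ModelOutput
from dbgpt.model.proxy import OpenAILLMClient

client = OpenAILLMClient()  # credentials are expected to come from the environment

# Outside an event loop, __call__ drives the loop itself and returns a ModelOutput.
output: ModelOutput = client("Hello, who are you?", model="gpt-4o-mini")  # example model
print(output.text)


async def main():
    # Inside a running loop, __call__ returns a coroutine that must be awaited;
    # async_call and async_call_stream are the explicit asynchronous forms.
    output = await client.async_call("Hello, who are you?", model="gpt-4o-mini")
    print(output.text)
    async for chunk in client.async_call_stream("Count to three.", model="gpt-4o-mini"):
        print(chunk.text)


asyncio.run(main())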