feat(model): Add new LLMClient and new build tools (#967)
@@ -1,9 +1,12 @@
 from dbgpt.core.interface.llm import (
     ModelInferenceMetrics,
     ModelRequest,
     ModelOutput,
-    OpenAILLM,
-    BaseLLMOperator,
+    LLMClient,
+    LLMOperator,
+    StreamingLLMOperator,
     RequestBuildOperator,
+    ModelMetadata,
 )
 from dbgpt.core.interface.message import (
     ModelMessage,
@@ -37,11 +40,15 @@ from dbgpt.core.interface.storage import (

 __ALL__ = [
     "ModelInferenceMetrics",
     "ModelRequest",
     "ModelOutput",
-    "OpenAILLM",
-    "BaseLLMOperator",
     "Operator",
     "RequestBuildOperator",
+    "ModelMetadata",
+    "ModelMessage",
+    "LLMClient",
+    "LLMOperator",
+    "StreamingLLMOperator",
     "ModelMessageRoleType",
     "OnceConversation",
     "StorageConversation",
@@ -211,7 +211,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             AsyncIterator[OUT]: An asynchronous iterator over the output stream.
         """
-        out_ctx = await self._runner.execute_workflow(self, call_data)
+        out_ctx = await self._runner.execute_workflow(
+            self, call_data, streaming_call=True
+        )
         return out_ctx.current_task_context.task_output.output_stream

     def _blocking_call_stream(
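For orientation, a minimal consumption sketch (hypothetical, not part of this commit): `call_stream` now executes the workflow with `streaming_call=True` and returns the final task's output stream, which callers iterate with `async for`. Here `task` stands in for any AWEL operator whose workflow ends in a streaming output, and the `{"data": ...}` call_data shape mirrors the trigger code below.

# Hypothetical helper, assuming `task` is a streaming-capable AWEL operator.
async def consume_stream(task, data):
    stream = await task.call_stream(call_data={"data": data})  # AsyncIterator[OUT]
    async for chunk in stream:
        print(chunk)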
@@ -130,8 +130,9 @@ async def _trigger_dag(
         "Connection": "keep-alive",
         "Transfer-Encoding": "chunked",
     }
+    generator = await end_node.call_stream(call_data={"data": body})
     return StreamingResponse(
-        end_node.call_stream(call_data={"data": body}),
+        generator,
         headers=headers,
         media_type=media_type,
     )
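The change above matters because `StreamingResponse` consumes an iterator or async generator, while `call_stream` is a coroutine that only yields the async iterator once awaited. A generic FastAPI illustration of the same pattern (not DB-GPT code; the route and helper names are made up):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def open_stream():
    # Async factory that returns an async generator, mirroring call_stream.
    async def gen():
        for token in ("hello", " ", "world"):
            yield token

    return gen()


@app.get("/stream")
async def stream_endpoint():
    # Await the factory first; passing the un-awaited coroutine would fail.
    generator = await open_stream()
    return StreamingResponse(generator, media_type="text/plain")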
@@ -1,10 +1,10 @@
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Optional, Dict, List, Any, Union, AsyncIterator

 import time
-from dataclasses import dataclass, asdict
+from dataclasses import dataclass, asdict, field
 import copy

 from dbgpt.util import BaseParameters
 from dbgpt.util.annotations import PublicAPI
 from dbgpt.util.model_utils import GPUInfo
 from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
@@ -12,6 +12,7 @@ from dbgpt.core.awel import MapOperator, StreamifyAbsOperator


 @dataclass
+@PublicAPI(stability="beta")
 class ModelInferenceMetrics:
     """A class to represent metrics for assessing the inference performance of a LLM."""

@@ -97,6 +98,7 @@ class ModelInferenceMetrics:


 @dataclass
+@PublicAPI(stability="beta")
 class ModelOutput:
     """A class to represent the output of a LLM."""

@@ -118,6 +120,7 @@ _ModelMessageType = Union[ModelMessage, Dict[str, Any]]


 @dataclass
+@PublicAPI(stability="beta")
 class ModelRequest:
     model: str
     """The name of the model."""
@@ -142,7 +145,7 @@ class ModelRequest:
     span_id: Optional[str] = None
     """The span id of the model inference."""

-    def to_dict(self) -> Dict:
+    def to_dict(self) -> Dict[str, Any]:
         new_request = copy.deepcopy(self)
         new_request.messages = list(
             map(lambda m: m if isinstance(m, dict) else m.dict(), new_request.messages)
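A small illustration of the annotated `to_dict` (a sketch against the fields shown in these hunks; the model name is arbitrary): it deep-copies the request and converts any `ModelMessage` objects into plain dicts, so the result is JSON-serializable.

from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest

request = ModelRequest(
    model="my-model",
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi")],
)
payload = request.to_dict()
assert isinstance(payload["messages"][0], dict)  # messages serialized to dicts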
@@ -166,6 +169,110 @@ class ModelRequest:
             **kwargs,
         )

+    def to_openai_messages(self) -> List[Dict[str, Any]]:
+        """Convert the messages to the format of the OpenAI API.
+
+        This method moves the last user message to the end of the list.
+
+        Returns:
+            List[Dict[str, Any]]: The messages in the format of the OpenAI API.
+
+        Examples:
+            .. code-block:: python
+
+                messages = [
+                    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
+                    ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot."),
+                    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you?"),
+                ]
+                request = ModelRequest(model="test", messages=messages)
+                openai_messages = request.to_openai_messages()
+                assert openai_messages == [
+                    {"role": "user", "content": "Hi"},
+                    {"role": "assistant", "content": "Hi, I'm a robot."},
+                    {"role": "user", "content": "Who are you?"},
+                ]
+        """
+        messages = [
+            m if isinstance(m, ModelMessage) else ModelMessage(**m)
+            for m in self.messages
+        ]
+        return ModelMessage.to_openai_messages(messages)
+
+
+@dataclass
+@PublicAPI(stability="beta")
+class ModelMetadata(BaseParameters):
+    """A class to represent a LLM model."""
+
+    model: str = field(
+        metadata={"help": "Model name"},
+    )
+    context_length: Optional[int] = field(
+        default=4096,
+        metadata={"help": "Context length of model"},
+    )
+    chat_model: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether the model is a chat model"},
+    )
+    is_function_calling_model: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether the model is a function calling model"},
+    )
+    metadata: Optional[Dict[str, Any]] = field(
+        default_factory=dict,
+        metadata={"help": "Model metadata"},
+    )
+
+
+@PublicAPI(stability="beta")
+class LLMClient(ABC):
+    """An abstract class for LLM client."""
+
+    @abstractmethod
+    async def generate(self, request: ModelRequest) -> ModelOutput:
+        """Generate a response for a given model request.
+
+        Args:
+            request(ModelRequest): The model request.
+
+        Returns:
+            ModelOutput: The model output.
+        """
+
+    @abstractmethod
+    async def generate_stream(
+        self, request: ModelRequest
+    ) -> AsyncIterator[ModelOutput]:
+        """Generate a stream of responses for a given model request.
+
+        Args:
+            request(ModelRequest): The model request.
+
+        Returns:
+            AsyncIterator[ModelOutput]: The model output stream.
+        """
+
+    @abstractmethod
+    async def models(self) -> List[ModelMetadata]:
+        """Get all the models.
+
+        Returns:
+            List[ModelMetadata]: A list of model metadata.
+        """
+
+    @abstractmethod
+    async def count_token(self, model: str, prompt: str) -> int:
+        """Count the number of tokens in a given prompt.
+
+        Args:
+            model(str): The model name.
+            prompt(str): The prompt.
+
+        Returns:
+            int: The number of tokens.
+        """
+
+
 class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
     def __init__(self, model: str, **kwargs):
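To make the new abstraction concrete, here is a minimal, hypothetical `LLMClient` implementation (not part of this commit): it echoes the last message back and counts tokens by whitespace, which is enough for tests and for the operator sketch at the end of this diff.

import asyncio
from typing import AsyncIterator, List

from dbgpt.core import (
    LLMClient,
    ModelMessage,
    ModelMessageRoleType,
    ModelMetadata,
    ModelOutput,
    ModelRequest,
)


class EchoLLMClient(LLMClient):
    """A toy client that echoes the last message back."""

    async def generate(self, request: ModelRequest) -> ModelOutput:
        last = request.to_openai_messages()[-1]["content"]
        return ModelOutput(text=f"echo: {last}", error_code=0)

    async def generate_stream(
        self, request: ModelRequest
    ) -> AsyncIterator[ModelOutput]:
        last = request.to_openai_messages()[-1]["content"]
        text = ""
        for token in f"echo: {last}".split():
            text += token + " "
            yield ModelOutput(text=text, error_code=0)

    async def models(self) -> List[ModelMetadata]:
        return [ModelMetadata(model="echo-model", context_length=4096)]

    async def count_token(self, model: str, prompt: str) -> int:
        return len(prompt.split())


async def _demo():
    client = EchoLLMClient()
    request = ModelRequest(
        model="echo-model",
        messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi")],
    )
    print((await client.generate(request)).text)
    async for output in client.generate_stream(request):
        print(output.text)


if __name__ == "__main__":
    asyncio.run(_demo())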
@@ -176,85 +283,52 @@ class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
         return ModelRequest._build(self._model, input_value)


-class BaseLLMOperator(
-    MapOperator[ModelRequest, ModelOutput],
-    StreamifyAbsOperator[ModelRequest, ModelOutput],
-    ABC,
-):
-    """The abstract operator for a LLM."""
-
-
-@PublicAPI(stability="beta")
-class OpenAILLM(BaseLLMOperator):
-    """The operator for OpenAI LLM.
-
-    Examples:
-
-        .. code-block:: python
-
-            llm = OpenAILLM()
-            model_request = ModelRequest(model="gpt-3.5-turbo", messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")])
-            model_output = await llm.map(model_request)
-    """
-
-    def __int__(self):
-        try:
-            import openai
-        except ImportError as e:
-            raise ImportError("Please install openai package to use OpenAILLM") from e
-        import importlib.metadata as metadata
-
-        if not metadata.version("openai") >= "1.0.0":
-            raise ImportError("Please upgrade openai package to version 1.0.0 or above")
-
-    async def _send_request(
-        self, model_request: ModelRequest, stream: Optional[bool] = False
-    ):
-        import os
-        from openai import AsyncOpenAI
-
-        client = AsyncOpenAI(
-            api_key=os.environ.get("OPENAI_API_KEY"),
-            base_url=os.environ.get("OPENAI_API_BASE"),
-        )
-        messages = ModelMessage.to_openai_messages(model_request._get_messages())
-        payloads = {
-            "model": model_request.model,
-            "stream": stream,
-        }
-        if model_request.temperature is not None:
-            payloads["temperature"] = model_request.temperature
-        if model_request.max_new_tokens:
-            payloads["max_tokens"] = model_request.max_new_tokens
-
-        return await client.chat.completions.create(messages=messages, **payloads)
-
-    async def map(self, model_request: ModelRequest) -> ModelOutput:
-        try:
-            chat_completion = await self._send_request(model_request, stream=False)
-            text = chat_completion.choices[0].message.content
-            usage = chat_completion.usage.dict()
-            return ModelOutput(text=text, error_code=0, usage=usage)
-        except Exception as e:
-            return ModelOutput(
-                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=1,
-            )
-
-    async def streamify(
-        self, model_request: ModelRequest
-    ) -> AsyncIterator[ModelOutput]:
-        try:
-            chat_completion = await self._send_request(model_request, stream=True)
-            text = ""
-            for r in chat_completion:
-                if len(r.choices) == 0:
-                    continue
-                if r.choices[0].delta.content is not None:
-                    content = r.choices[0].delta.content
-                    text += content
-                    yield ModelOutput(text=text, error_code=0)
-        except Exception as e:
-            yield ModelOutput(
-                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=1,
-            )
+class BaseLLM:
+    """The abstract operator for a LLM."""
+
+    def __init__(self, llm_client: Optional[LLMClient] = None):
+        self._llm_client = llm_client
+
+    @property
+    def llm_client(self) -> LLMClient:
+        """Return the LLM client."""
+        if not self._llm_client:
+            raise ValueError("llm_client is not set")
+        return self._llm_client
+
+
+class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
+    """The operator for a LLM.
+
+    Args:
+        llm_client (LLMClient, optional): The LLM client. Defaults to None.
+
+    This operator will generate a no streaming response.
+    """
+
+    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+        super().__init__(llm_client=llm_client)
+        MapOperator.__init__(self, **kwargs)
+
+    async def map(self, request: ModelRequest) -> ModelOutput:
+        return await self.llm_client.generate(request)
+
+
+class StreamingLLMOperator(
+    BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
+):
+    """The streaming operator for a LLM.
+
+    Args:
+        llm_client (LLMClient, optional): The LLM client. Defaults to None.
+
+    This operator will generate streaming response.
+    """
+
+    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+        super().__init__(llm_client=llm_client)
+        StreamifyAbsOperator.__init__(self, **kwargs)
+
+    async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
+        async for output in self.llm_client.generate_stream(request):
+            yield output
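Finally, a hypothetical end-to-end sketch of how the new pieces compose (not from this commit): `RequestBuildOperator` turns a string into a `ModelRequest`, and `LLMOperator` delegates to whatever `LLMClient` it is given, e.g. the `EchoLLMClient` sketched above. The DAG wiring style and the `{"data": ...}` call_data convention are assumed from surrounding DB-GPT examples and may differ between versions.

import asyncio

from dbgpt.core import LLMOperator, RequestBuildOperator
from dbgpt.core.awel import DAG


async def main():
    with DAG("simple_llm_dag"):
        # EchoLLMClient is the toy client sketched after the LLMClient hunk above.
        request_task = RequestBuildOperator(model="echo-model")  # str -> ModelRequest
        llm_task = LLMOperator(llm_client=EchoLLMClient())  # ModelRequest -> ModelOutput
        request_task >> llm_task

    output = await llm_task.call(call_data={"data": "Hello, AWEL"})
    print(output.text)


if __name__ == "__main__":
    asyncio.run(main())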