feat(model): Add new LLMClient and new build tools (#967)

Fangyin Cheng
2023-12-23 16:33:01 +08:00
committed by GitHub
parent 12234ae258
commit 0c46c339ca
30 changed files with 1072 additions and 133 deletions

View File

@@ -1,9 +1,12 @@
from dbgpt.core.interface.llm import (
ModelInferenceMetrics,
ModelRequest,
ModelOutput,
OpenAILLM,
BaseLLMOperator,
LLMClient,
LLMOperator,
StreamingLLMOperator,
RequestBuildOperator,
ModelMetadata,
)
from dbgpt.core.interface.message import (
ModelMessage,
@@ -37,11 +40,15 @@ from dbgpt.core.interface.storage import (
__ALL__ = [
"ModelInferenceMetrics",
"ModelRequest",
"ModelOutput",
"OpenAILLM",
"BaseLLMOperator",
"Operator",
"RequestBuildOperator",
"ModelMetadata",
"ModelMessage",
"LLMClient",
"LLMOperator",
"StreamingLLMOperator",
"ModelMessageRoleType",
"OnceConversation",
"StorageConversation",

View File

@@ -211,7 +211,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True
)
return out_ctx.current_task_context.task_output.output_stream
def _blocking_call_stream(
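The streaming entry point now passes streaming_call=True to the workflow runner, so the task output is exposed as an asynchronous stream. A hedged sketch of consuming it (the operator instance and call_data payload are illustrative):

# call_stream is a coroutine that resolves to an async iterator of outputs.
async def consume_stream(leaf_operator, prompt: str) -> None:
    stream = await leaf_operator.call_stream(call_data={"data": prompt})
    async for chunk in stream:
        print(chunk)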

View File

@@ -130,8 +130,9 @@ async def _trigger_dag(
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
generator = await end_node.call_stream(call_data={"data": body})
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
generator,
headers=headers,
media_type=media_type,
)
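The trigger now awaits call_stream first and hands the resulting async generator to StreamingResponse, instead of passing the un-awaited coroutine. A hedged sketch of the same pattern in a standalone endpoint (the endpoint name and media type are illustrative):

from fastapi.responses import StreamingResponse

async def stream_dag_endpoint(end_node, body: dict) -> StreamingResponse:
    # Await the coroutine to obtain the async generator, then stream it out.
    generator = await end_node.call_stream(call_data={"data": body})
    return StreamingResponse(generator, media_type="text/event-stream")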

View File

@@ -1,10 +1,10 @@
from abc import ABC
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Any, Union, AsyncIterator
import time
from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, field
import copy
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
@@ -12,6 +12,7 @@ from dbgpt.core.awel import MapOperator, StreamifyAbsOperator
@dataclass
@PublicAPI(stability="beta")
class ModelInferenceMetrics:
"""A class to represent metrics for assessing the inference performance of a LLM."""
@@ -97,6 +98,7 @@ class ModelInferenceMetrics:
@dataclass
@PublicAPI(stability="beta")
class ModelOutput:
"""A class to represent the output of a LLM.""" ""
@@ -118,6 +120,7 @@ _ModelMessageType = Union[ModelMessage, Dict[str, Any]]
@dataclass
@PublicAPI(stability="beta")
class ModelRequest:
model: str
"""The name of the model."""
@@ -142,7 +145,7 @@ class ModelRequest:
span_id: Optional[str] = None
"""The span id of the model inference."""
def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
new_reqeust = copy.deepcopy(self)
new_reqeust.messages = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages)
@@ -166,6 +169,110 @@ class ModelRequest:
**kwargs,
)
def to_openai_messages(self) -> List[Dict[str, Any]]:
"""Convert the messages to the format of OpenAI API.
This method moves the last user message to the end of the list.
Returns:
List[Dict[str, Any]]: The messages in the format of OpenAI API.
Examples:
.. code-block:: python
messages = [
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot."),
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
]
request = ModelRequest(model="gpt-3.5-turbo", messages=messages)
openai_messages = request.to_openai_messages()
assert openai_messages == [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hi, I'm a robot."},
{"role": "user", "content": "Who are your"},
]
"""
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in self.messages
]
return ModelMessage.to_openai_messages(messages)
@dataclass
@PublicAPI(stability="beta")
class ModelMetadata(BaseParameters):
"""A class to represent a LLM model."""
model: str = field(
metadata={"help": "Model name"},
)
context_length: Optional[int] = field(
default=4096,
metadata={"help": "Context length of model"},
)
chat_model: Optional[bool] = field(
default=True,
metadata={"help": "Whether the model is a chat model"},
)
is_function_calling_model: Optional[bool] = field(
default=False,
metadata={"help": "Whether the model is a function calling model"},
)
metadata: Optional[Dict[str, Any]] = field(
default_factory=dict,
metadata={"help": "Model metadata"},
)
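A minimal sketch of constructing the new ModelMetadata, assuming it behaves like a plain dataclass (the model name and values are illustrative):

metadata = ModelMetadata(
    model="vicuna-13b-v1.5",
    context_length=4096,
    chat_model=True,
    is_function_calling_model=False,
)
assert metadata.context_length == 4096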
@PublicAPI(stability="beta")
class LLMClient(ABC):
"""An abstract class for LLM client."""
@abstractmethod
async def generate(self, request: ModelRequest) -> ModelOutput:
"""Generate a response for a given model request.
Args:
request(ModelRequest): The model request.
Returns:
ModelOutput: The model output.
"""
@abstractmethod
async def generate_stream(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
"""Generate a stream of responses for a given model request.
Args:
request(ModelRequest): The model request.
Returns:
AsyncIterator[ModelOutput]: The model output stream.
"""
@abstractmethod
async def models(self) -> List[ModelMetadata]:
"""Get all the models.
Returns:
List[ModelMetadata]: A list of model metadata.
"""
@abstractmethod
async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.
Args:
model(str): The model name.
prompt(str): The prompt.
Returns:
int: The number of tokens.
"""
class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
def __init__(self, model: str, **kwargs):
@@ -176,85 +283,52 @@ class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
return ModelRequest._build(self._model, input_value)
class BaseLLMOperator(
MapOperator[ModelRequest, ModelOutput],
StreamifyAbsOperator[ModelRequest, ModelOutput],
ABC,
):
class BaseLLM:
"""The abstract operator for a LLM."""
def __init__(self, llm_client: Optional[LLMClient] = None):
self._llm_client = llm_client
@PublicAPI(stability="beta")
class OpenAILLM(BaseLLMOperator):
"""The operator for OpenAI LLM.
@property
def llm_client(self) -> LLMClient:
"""Return the LLM client."""
if not self._llm_client:
raise ValueError("llm_client is not set")
return self._llm_client
Examples:
.. code-block:: python
llm = OpenAILLM()
model_request = ModelRequest(model="gpt-3.5-turbo", messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")])
model_output = await llm.map(model_request)
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
"""The operator for a LLM.
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
This operator will generate a non-streaming response.
"""
def __int__(self):
try:
import openai
except ImportError as e:
raise ImportError("Please install openai package to use OpenAILLM") from e
import importlib.metadata as metadata
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
MapOperator.__init__(self, **kwargs)
if not metadata.version("openai") >= "1.0.0":
raise ImportError("Please upgrade openai package to version 1.0.0 or above")
async def map(self, request: ModelRequest) -> ModelOutput:
return await self.llm_client.generate(request)
async def _send_request(
self, model_request: ModelRequest, stream: Optional[bool] = False
):
import os
from openai import AsyncOpenAI
client = AsyncOpenAI(
api_key=os.environ.get("OPENAI_API_KEY"),
base_url=os.environ.get("OPENAI_API_BASE"),
)
messages = ModelMessage.to_openai_messages(model_request._get_messages())
payloads = {
"model": model_request.model,
"stream": stream,
}
if model_request.temperature is not None:
payloads["temperature"] = model_request.temperature
if model_request.max_new_tokens:
payloads["max_tokens"] = model_request.max_new_tokens
class StreamingLLMOperator(
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
):
"""The streaming operator for a LLM.
return await client.chat.completions.create(messages=messages, **payloads)
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
async def map(self, model_request: ModelRequest) -> ModelOutput:
try:
chat_completion = await self._send_request(model_request, stream=False)
text = chat_completion.choices[0].message.content
usage = chat_completion.usage.dict()
return ModelOutput(text=text, error_code=0, usage=usage)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
This operator will generate a streaming response.
"""
async def streamify(
self, model_request: ModelRequest
) -> AsyncIterator[ModelOutput]:
try:
chat_completion = await self._send_request(model_request, stream=True)
text = ""
for r in chat_completion:
if len(r.choices) == 0:
continue
if r.choices[0].delta.content is not None:
content = r.choices[0].delta.content
text += content
yield ModelOutput(text=text, error_code=0)
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
StreamifyAbsOperator.__init__(self, **kwargs)
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
async for output in self.llm_client.generate_stream(request):
yield output
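A hedged end-to-end sketch of wiring the new operators into an AWEL DAG; the DAG name, model name, and EchoLLMClient (from the sketch above) are illustrative, and call() is assumed to run the DAG with call_data fed to the root node:

import asyncio

from dbgpt.core import LLMOperator, RequestBuildOperator
from dbgpt.core.awel import DAG


async def main() -> None:
    with DAG("llm_client_example_dag"):
        request_builder = RequestBuildOperator(model="gpt-3.5-turbo")
        llm_task = LLMOperator(llm_client=EchoLLMClient())
        request_builder >> llm_task

    # Calling the leaf operator runs the whole DAG; call_data feeds the root.
    output = await llm_task.call(call_data="What is AWEL?")
    print(output.text)


asyncio.run(main())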