mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-23 01:49:58 +00:00
feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -1,14 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, List, Any, Union, AsyncIterator
|
||||
import time
|
||||
from dataclasses import dataclass, asdict, field
|
||||
import copy
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.util import BaseParameters
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
from dbgpt.util.model_utils import GPUInfo
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.core.awel import MapOperator, StreamifyAbsOperator
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -97,6 +96,28 @@ class ModelInferenceMetrics:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="beta")
|
||||
class ModelRequestContext:
|
||||
stream: Optional[bool] = False
|
||||
"""Whether to return a stream of responses."""
|
||||
|
||||
user_name: Optional[str] = None
|
||||
"""The user name of the model request."""
|
||||
|
||||
sys_code: Optional[str] = None
|
||||
"""The system code of the model request."""
|
||||
|
||||
conv_uid: Optional[str] = None
|
||||
"""The conversation id of the model inference."""
|
||||
|
||||
span_id: Optional[str] = None
|
||||
"""The span id of the model inference."""
|
||||
|
||||
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
|
||||
"""The extra information of the model inference."""
|
||||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="beta")
|
||||
class ModelOutput:
|
||||
@@ -145,6 +166,27 @@ class ModelRequest:
|
||||
span_id: Optional[str] = None
|
||||
"""The span id of the model inference."""
|
||||
|
||||
context: Optional[ModelRequestContext] = field(
|
||||
default_factory=lambda: ModelRequestContext()
|
||||
)
|
||||
"""The context of the model inference."""
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
"""Whether to return a stream of responses."""
|
||||
return self.context and self.context.stream
|
||||
|
||||
def copy(self):
|
||||
new_request = copy.deepcopy(self)
|
||||
# Transform messages to List[ModelMessage]
|
||||
new_request.messages = list(
|
||||
map(
|
||||
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
|
||||
new_request.messages,
|
||||
)
|
||||
)
|
||||
return new_request
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
new_reqeust = copy.deepcopy(self)
|
||||
new_reqeust.messages = list(
|
||||
@@ -161,6 +203,17 @@ class ModelRequest:
|
||||
)
|
||||
)
|
||||
|
||||
def get_single_user_message(self) -> Optional[ModelMessage]:
|
||||
"""Get the single user message.
|
||||
|
||||
Returns:
|
||||
Optional[ModelMessage]: The single user message.
|
||||
"""
|
||||
messages = self._get_messages()
|
||||
if len(messages) != 1 and messages[0].role != ModelMessageRoleType.HUMAN:
|
||||
raise ValueError("The messages is not a single user message")
|
||||
return messages[0]
|
||||
|
||||
@staticmethod
|
||||
def _build(model: str, prompt: str, **kwargs):
|
||||
return ModelRequest(
|
||||
@@ -178,11 +231,22 @@ class ModelRequest:
|
||||
List[Dict[str, Any]]: The messages in the format of OpenAI API.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.core.interface.message import (
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
)
|
||||
|
||||
messages = [
|
||||
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
|
||||
ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot.")
|
||||
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are your"),
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.AI, content="Hi, I'm a robot."
|
||||
),
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.HUMAN, content="Who are your"
|
||||
),
|
||||
]
|
||||
openai_messages = ModelRequest.to_openai_messages(messages)
|
||||
assert openai_messages == [
|
||||
@@ -272,63 +336,3 @@ class LLMClient(ABC):
|
||||
Returns:
|
||||
int: The number of tokens.
|
||||
"""
|
||||
|
||||
|
||||
class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
|
||||
def __init__(self, model: str, **kwargs):
|
||||
self._model = model
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: str) -> ModelRequest:
|
||||
return ModelRequest._build(self._model, input_value)
|
||||
|
||||
|
||||
class BaseLLM:
|
||||
"""The abstract operator for a LLM."""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
self._llm_client = llm_client
|
||||
|
||||
@property
|
||||
def llm_client(self) -> LLMClient:
|
||||
"""Return the LLM client."""
|
||||
if not self._llm_client:
|
||||
raise ValueError("llm_client is not set")
|
||||
return self._llm_client
|
||||
|
||||
|
||||
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
||||
"""The operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate a no streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, request: ModelRequest) -> ModelOutput:
|
||||
return await self.llm_client.generate(request)
|
||||
|
||||
|
||||
class StreamingLLMOperator(
|
||||
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
|
||||
):
|
||||
"""The streaming operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
StreamifyAbsOperator.__init__(self, **kwargs)
|
||||
|
||||
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
|
||||
async for output in self.llm_client.generate_stream(request):
|
||||
yield output
|
||||
|
Reference in New Issue
Block a user