feat(core): More AWEL operators and new prompt manager API (#972)

Co-authored-by: csunny <cfqsunny@163.com>
Author: Fangyin Cheng
Date: 2023-12-25 20:03:22 +08:00
Committed by: GitHub
Parent: 048fb6c402
Commit: 69fb97e508
46 changed files with 2556 additions and 294 deletions


@@ -1,14 +1,13 @@
import copy
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union

from dbgpt.core.awel import MapOperator, StreamifyAbsOperator
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo

@dataclass
@@ -97,6 +96,28 @@ class ModelInferenceMetrics:
        return asdict(self)


@dataclass
@PublicAPI(stability="beta")
class ModelRequestContext:
    stream: Optional[bool] = False
    """Whether to return a stream of responses."""

    user_name: Optional[str] = None
    """The user name of the model request."""

    sys_code: Optional[str] = None
    """The system code of the model request."""

    conv_uid: Optional[str] = None
    """The conversation id of the model inference."""

    span_id: Optional[str] = None
    """The span id of the model inference."""

    extra: Optional[Dict[str, Any]] = field(default_factory=dict)
    """The extra information of the model inference."""
@dataclass
@PublicAPI(stability="beta")
class ModelOutput:
@@ -145,6 +166,27 @@ class ModelRequest:
    span_id: Optional[str] = None
    """The span id of the model inference."""

    context: Optional[ModelRequestContext] = field(
        default_factory=lambda: ModelRequestContext()
    )
    """The context of the model inference."""

    @property
    def stream(self) -> bool:
        """Whether to return a stream of responses."""
        return bool(self.context and self.context.stream)

    def copy(self):
        new_request = copy.deepcopy(self)
        # Transform messages to List[ModelMessage]
        new_request.messages = list(
            map(
                lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
                new_request.messages,
            )
        )
        return new_request
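
A quick sketch (illustrative values) of the normalization `copy` performs: dict-form messages come back as `ModelMessage` instances.

    request = ModelRequest(
        model="vicuna-13b-v1.5",  # illustrative model name
        messages=[{"role": "human", "content": "Hi"}],
    )
    new_request = request.copy()
    assert isinstance(new_request.messages[0], ModelMessage)
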
    def to_dict(self) -> Dict[str, Any]:
        new_request = copy.deepcopy(self)
        new_request.messages = list(
@@ -161,6 +203,17 @@ class ModelRequest:
            )
        )
    def get_single_user_message(self) -> Optional[ModelMessage]:
        """Get the single user message.

        Returns:
            Optional[ModelMessage]: The single user message.
        """
        messages = self._get_messages()
        if len(messages) != 1 or messages[0].role != ModelMessageRoleType.HUMAN:
            raise ValueError("The messages are not a single user message")
        return messages[0]
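
A short sketch of the contract, using the `_build` helper below (model name illustrative): the request must contain exactly one message, and it must come from the human role.

    request = ModelRequest._build("vicuna-13b-v1.5", "What is AWEL?")
    message = request.get_single_user_message()
    assert message.role == ModelMessageRoleType.HUMAN
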
    @staticmethod
    def _build(model: str, prompt: str, **kwargs):
        return ModelRequest(
@@ -178,11 +231,22 @@ class ModelRequest:
            List[Dict[str, Any]]: The messages in the format of OpenAI API.

        Examples:
            .. code-block:: python

                from dbgpt.core.interface.message import (
                    ModelMessage,
                    ModelMessageRoleType,
                )

                messages = [
                    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
                    ModelMessage(
                        role=ModelMessageRoleType.AI, content="Hi, I'm a robot."
                    ),
                    ModelMessage(
                        role=ModelMessageRoleType.HUMAN, content="Who are you?"
                    ),
                ]
                openai_messages = ModelRequest.to_openai_messages(messages)
                assert openai_messages == [
@@ -272,63 +336,3 @@ class LLMClient(ABC):
        Returns:
            int: The number of tokens.
        """


class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
    """Build a ModelRequest from a prompt string for a given model."""

    def __init__(self, model: str, **kwargs):
        self._model = model
        super().__init__(**kwargs)

    async def map(self, input_value: str) -> ModelRequest:
        return ModelRequest._build(self._model, input_value)
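
As a sketch of how this operator composes with the LLM operators defined below in an AWEL DAG; the DAG name, model name, and client are illustrative, and the exact trigger/invocation wiring depends on the AWEL version:

    from dbgpt.core.awel import DAG

    with DAG("request_build_dag") as dag:
        build_task = RequestBuildOperator(model="vicuna-13b-v1.5")
        llm_task = LLMOperator(llm_client=my_llm_client)  # my_llm_client: an assumed LLMClient impl
        build_task >> llm_task  # prompt string -> ModelRequest -> ModelOutput
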


class BaseLLM:
    """The base class holding the LLM client for the LLM operators."""

    def __init__(self, llm_client: Optional[LLMClient] = None):
        self._llm_client = llm_client

    @property
    def llm_client(self) -> LLMClient:
        """Return the LLM client."""
        if not self._llm_client:
            raise ValueError("llm_client is not set")
        return self._llm_client


class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
    """The operator for an LLM that generates a non-streaming response.

    Args:
        llm_client (LLMClient, optional): The LLM client. Defaults to None.
    """

    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(llm_client=llm_client)
        MapOperator.__init__(self, **kwargs)

    async def map(self, request: ModelRequest) -> ModelOutput:
        return await self.llm_client.generate(request)
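
Outside of a full DAG run, the transform itself can also be awaited directly; a sketch assuming `client` is an `LLMClient` implementation and `request` is a `ModelRequest` (depending on the AWEL version, the operator may need to be created inside a DAG context):

    output = await LLMOperator(llm_client=client).map(request)
    print(output.text)  # ModelOutput carries the generated text
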


class StreamingLLMOperator(
    BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
):
    """The operator for an LLM that generates a streaming response.

    Args:
        llm_client (LLMClient, optional): The LLM client. Defaults to None.
    """

    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(llm_client=llm_client)
        StreamifyAbsOperator.__init__(self, **kwargs)

    async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
        async for output in self.llm_client.generate_stream(request):
            yield output
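
A matching sketch for the streaming side, again assuming `client` is an `LLMClient` implementation:

    async def stream_demo(client: LLMClient, request: ModelRequest) -> None:
        op = StreamingLLMOperator(llm_client=client)
        async for output in op.streamify(request):
            print(output.text, end="", flush=True)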