feat(core): Support higher-order operators (#1984)

Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
This commit is contained in:
Fangyin Cheng
2024-09-09 10:15:37 +08:00
committed by GitHub
parent f6d5fc4595
commit 65c875db20
62 changed files with 6281 additions and 386 deletions

View File

@@ -27,7 +27,7 @@ from dbgpt.util.i18n_utils import _
name="auto_convert_message",
type=bool,
optional=True,
default=False,
default=True,
description=_(
"Whether to auto convert the messages that are not supported "
"by the LLM to a compatible format"
@@ -42,13 +42,13 @@ class DefaultLLMClient(LLMClient):
Args:
worker_manager (WorkerManager): worker manager instance.
auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to False.
auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to True.
"""
def __init__(
self,
worker_manager: Optional[WorkerManager] = None,
auto_convert_message: bool = False,
auto_convert_message: bool = True,
):
self._worker_manager = worker_manager
self._auto_covert_message = auto_convert_message
@@ -128,7 +128,7 @@ class DefaultLLMClient(LLMClient):
name="auto_convert_message",
type=bool,
optional=True,
default=False,
default=True,
description=_(
"Whether to auto convert the messages that are not supported "
"by the LLM to a compatible format"
@@ -158,7 +158,7 @@ class RemoteLLMClient(DefaultLLMClient):
def __init__(
self,
controller_address: str = "http://127.0.0.1:8000",
auto_convert_message: bool = False,
auto_convert_message: bool = True,
):
"""Initialize the RemoteLLMClient."""
from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager

View File

@@ -24,8 +24,13 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
This class extends BaseOperator by adding LLM capabilities.
"""
def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
super().__init__(default_client)
# Initialize the mixin: default_client and save_model_output are handed to
# super().__init__. Extra **kwargs are accepted but not forwarded by this method.
def __init__(
self,
default_client: Optional[LLMClient] = None,
save_model_output: bool = True,
**kwargs,
):
super().__init__(default_client, save_model_output=save_model_output)
@property
def llm_client(self) -> LLMClient:
@@ -95,8 +100,13 @@ class LLMOperator(MixinLLMOperator, BaseLLMOperator):
],
)
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
# Initialize the operator in two steps: super().__init__ receives the client and
# save_model_output, then BaseLLMOperator.__init__ is called explicitly with the
# client and the remaining keyword arguments.
# NOTE(review): save_model_output is not passed to BaseLLMOperator.__init__ —
# confirm that this is intended.
def __init__(
self,
llm_client: Optional[LLMClient] = None,
save_model_output: bool = True,
**kwargs,
):
super().__init__(llm_client, save_model_output=save_model_output)
BaseLLMOperator.__init__(self, llm_client, **kwargs)
@@ -144,6 +154,11 @@ class StreamingLLMOperator(MixinLLMOperator, BaseStreamingLLMOperator):
],
)
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
# Initialize the streaming operator in two steps: super().__init__ receives the
# client and save_model_output, then BaseStreamingLLMOperator.__init__ is called
# explicitly with the client and the remaining keyword arguments.
# NOTE(review): save_model_output is not passed to
# BaseStreamingLLMOperator.__init__ — confirm that this is intended.
def __init__(
self,
llm_client: Optional[LLMClient] = None,
save_model_output: bool = True,
**kwargs,
):
super().__init__(llm_client, save_model_output=save_model_output)
BaseStreamingLLMOperator.__init__(self, llm_client, **kwargs)

View File

@@ -94,6 +94,17 @@ class ProxyLLMClient(LLMClient):
self.executor = executor or ThreadPoolExecutor()
self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()
def __getstate__(self):
    """Return the state used to serialize this object.

    The ``executor`` attribute is left out of the returned copy of
    ``__dict__`` (presumably because a thread pool cannot be pickled);
    ``__setstate__`` assigns a new executor on deserialization.

    Returns:
        dict: A shallow copy of the instance attributes without ``executor``.
    """
    state = self.__dict__.copy()
    # Use a default so an instance that never had an executor assigned
    # (e.g. created without running ``__init__``) does not raise KeyError.
    state.pop("executor", None)
    return state
def __setstate__(self, state):
    """Restore the instance attributes from ``state`` and attach an executor.

    Args:
        state (dict): Attribute mapping to load into this instance.
    """
    vars(self).update(state)
    # The executor is always replaced with a newly created default pool
    # after the attributes have been loaded.
    self.executor = ThreadPoolExecutor()
@classmethod
@abstractmethod
def new_client(

View File

@@ -16,7 +16,13 @@ from typing import (
from dbgpt._private.pydantic import model_to_json
from dbgpt.core.awel import TransformStreamAbsOperator
from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
ViewMetadata,
)
from dbgpt.core.interface.llm import ModelOutput
from dbgpt.core.operators import BaseLLM
from dbgpt.util.i18n_utils import _
@@ -184,6 +190,7 @@ class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]
),
)
],
tags={"order": TAGS_ORDER_HIGH},
)
async def transform_stream(self, model_output: AsyncIterator[ModelOutput]):