Mirror of https://github.com/csunny/DB-GPT.git
feat(core): Support higher-order operators (#1984)
Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
@@ -27,7 +27,7 @@ from dbgpt.util.i18n_utils import _
             name="auto_convert_message",
             type=bool,
             optional=True,
-            default=False,
+            default=True,
             description=_(
                 "Whether to auto convert the messages that are not supported "
                 "by the LLM to a compatible format"
@@ -42,13 +42,13 @@ class DefaultLLMClient(LLMClient):
 
     Args:
         worker_manager (WorkerManager): worker manager instance.
-        auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to False.
+        auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to True.
     """
 
     def __init__(
         self,
         worker_manager: Optional[WorkerManager] = None,
-        auto_convert_message: bool = False,
+        auto_convert_message: bool = True,
     ):
         self._worker_manager = worker_manager
         self._auto_covert_message = auto_convert_message
@@ -128,7 +128,7 @@ class DefaultLLMClient(LLMClient):
             name="auto_convert_message",
             type=bool,
             optional=True,
-            default=False,
+            default=True,
             description=_(
                 "Whether to auto convert the messages that are not supported "
                 "by the LLM to a compatible format"
@@ -158,7 +158,7 @@ class RemoteLLMClient(DefaultLLMClient):
     def __init__(
         self,
         controller_address: str = "http://127.0.0.1:8000",
-        auto_convert_message: bool = False,
+        auto_convert_message: bool = True,
     ):
         """Initialize the RemoteLLMClient."""
         from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager
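Taken together, the hunks above flip the auto_convert_message default from False to True for both DefaultLLMClient and RemoteLLMClient, so callers now get automatic message conversion unless they opt out. A minimal sketch of restoring the old behavior; the import path is an assumption based on the classes these hunks touch:

from dbgpt.model.cluster.client import RemoteLLMClient  # assumed module path

# After this commit, messages the target LLM cannot handle are
# auto-converted to a compatible format by default. Pass
# auto_convert_message=False to keep the previous behavior.
client = RemoteLLMClient(
    controller_address="http://127.0.0.1:8000",
    auto_convert_message=False,
)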
@@ -24,8 +24,13 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
     This class extends BaseOperator by adding LLM capabilities.
     """
 
-    def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(default_client)
+    def __init__(
+        self,
+        default_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
+        super().__init__(default_client, save_model_output=save_model_output)
 
     @property
     def llm_client(self) -> LLMClient:
@@ -95,8 +100,13 @@ class LLMOperator(MixinLLMOperator, BaseLLMOperator):
         ],
     )
 
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
+        super().__init__(llm_client, save_model_output=save_model_output)
         BaseLLMOperator.__init__(self, llm_client, **kwargs)
 
 
@@ -144,6 +154,11 @@ class StreamingLLMOperator(MixinLLMOperator, BaseStreamingLLMOperator):
         ],
     )
 
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
+        super().__init__(llm_client, save_model_output=save_model_output)
         BaseStreamingLLMOperator.__init__(self, llm_client, **kwargs)
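The three operator hunks make the same change: the one-line constructors become explicit signatures accepting a new save_model_output flag, defaulting to True so existing DAGs keep their current behavior; the diff only shows the flag being forwarded to the BaseLLM initializer. A hedged sketch of constructing an operator with the flag disabled, assuming LLMOperator is importable from dbgpt.model.operators:

from dbgpt.core.awel import DAG
from dbgpt.model.operators import LLMOperator  # assumed module path

with DAG("save_output_example") as dag:
    # save_model_output is new in this commit and defaults to True;
    # passing False asks the base class not to save the model output
    # (the hunks only forward the flag, so the effect lives in BaseLLM).
    llm_task = LLMOperator(save_model_output=False)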
@@ -94,6 +94,17 @@ class ProxyLLMClient(LLMClient):
         self.executor = executor or ThreadPoolExecutor()
         self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()
 
+    def __getstate__(self):
+        """Customize the serialization of the object"""
+        state = self.__dict__.copy()
+        state.pop("executor")
+        return state
+
+    def __setstate__(self, state):
+        """Customize the deserialization of the object"""
+        self.__dict__.update(state)
+        self.executor = ThreadPoolExecutor()
+
     @classmethod
     @abstractmethod
     def new_client(
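The added __getstate__/__setstate__ pair makes ProxyLLMClient instances serializable: the thread pool, which holds locks and live threads that pickle cannot handle, is dropped before serialization and rebuilt on load. The same pattern in isolation, using a hypothetical stand-in class since ProxyLLMClient itself is abstract:

import pickle
from concurrent.futures import ThreadPoolExecutor


class PicklableClient:
    """Stand-in mirroring the pattern added to ProxyLLMClient."""

    def __init__(self, model: str = "proxy-model"):
        self.model = model
        self.executor = ThreadPoolExecutor()  # not picklable

    def __getstate__(self):
        # Drop the executor before pickling; everything else survives.
        state = self.__dict__.copy()
        state.pop("executor")
        return state

    def __setstate__(self, state):
        # Restore attributes, then recreate a fresh executor.
        self.__dict__.update(state)
        self.executor = ThreadPoolExecutor()


restored = pickle.loads(pickle.dumps(PicklableClient("my-model")))
assert restored.model == "my-model" and restored.executor is not None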
@@ -16,7 +16,13 @@ from typing import (
 
 from dbgpt._private.pydantic import model_to_json
 from dbgpt.core.awel import TransformStreamAbsOperator
-from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    IOField,
+    OperatorCategory,
+    OperatorType,
+    ViewMetadata,
+)
 from dbgpt.core.interface.llm import ModelOutput
 from dbgpt.core.operators import BaseLLM
 from dbgpt.util.i18n_utils import _
@@ -184,6 +190,7 @@ class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]
                 ),
             )
         ],
+        tags={"order": TAGS_ORDER_HIGH},
     )
 
     async def transform_stream(self, model_output: AsyncIterator[ModelOutput]):