Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-04 02:25:08 +00:00)
feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
@@ -1,4 +1,13 @@
 from dbgpt.model.cluster.client import DefaultLLMClient
-from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
+from dbgpt.model.utils.chatgpt_utils import (
+    OpenAILLMClient,
+    OpenAIStreamingOperator,
+    MixinLLMOperator,
+)
 
-__ALL__ = ["DefaultLLMClient", "OpenAILLMClient"]
+__ALL__ = [
+    "DefaultLLMClient",
+    "OpenAILLMClient",
+    "OpenAIStreamingOperator",
+    "MixinLLMOperator",
+]
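
With the expanded __ALL__, the two new operator classes become importable directly from dbgpt.model. A minimal usage sketch; the no-argument constructors are an assumption (their defaults are not shown in this diff):

from dbgpt.model import MixinLLMOperator, OpenAILLMClient, OpenAIStreamingOperator

# Assumes the OpenAI API key is supplied via environment or config; arguments are illustrative.
client = OpenAILLMClient()
# Converts a stream of ModelOutput chunks into OpenAI-style SSE text (see the operator defined below).
sse_operator = OpenAIStreamingOperator()
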
@@ -171,7 +171,7 @@ class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
         self._model_task_name = model_task_name
         self._cache_task_name = cache_task_name
 
-    async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
+    async def branches(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
         """Defines branch logic based on cache availability.
 
         Returns:
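
The override is renamed from branchs to branches, so any BranchOperator subclass or caller must switch to the new spelling. A rough sketch of the pattern this method follows; the predicate and task names are invented for illustration, the import path mirrors the diff's own annotations, and whether AWEL branch predicates are sync or async is an assumption:

from typing import Dict, Union

from dbgpt.core.awel import BaseOperator, BranchFunc, BranchOperator


class ExampleCacheBranchOperator(BranchOperator[Dict, Dict]):
    def __init__(self, cache_task_name: str, model_task_name: str, **kwargs):
        super().__init__(**kwargs)
        self._cache_task_name = cache_task_name
        self._model_task_name = model_task_name

    async def branches(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
        # Map each predicate to the task that should run when it returns True.
        async def is_cache_hit(call_data: Dict) -> bool:
            return bool(call_data.get("cache_hit"))

        async def is_cache_miss(call_data: Dict) -> bool:
            return not call_data.get("cache_hit")

        return {
            is_cache_hit: self._cache_task_name,
            is_cache_miss: self._model_task_name,
        }
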
@@ -233,7 +233,7 @@ class ModelStreamSaveCacheOperator(
         outputs = []
         async for out in input_value:
             if not llm_cache_key:
-                llm_cache_key = await self.current_dag_context.get_share_data(
+                llm_cache_key = await self.current_dag_context.get_from_share_data(
                     _LLM_MODEL_INPUT_VALUE_KEY
                 )
             outputs.append(out)
@@ -265,7 +265,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
         Returns:
             ModelOutput: The same input model output.
         """
-        llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
+        llm_cache_key: LLMCacheKey = await self.current_dag_context.get_from_share_data(
             _LLM_MODEL_INPUT_VALUE_KEY
         )
         llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
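
Both cache operators above now read the cache key through get_from_share_data instead of get_share_data. A small sketch of the share-data pattern between two operators in one DAG run; the save_to_share_data counterpart, the key name, and the operator bodies are assumptions for illustration:

from dbgpt.core.awel import MapOperator


class ProducerOperator(MapOperator[str, str]):
    async def map(self, value: str) -> str:
        # Stash a value for downstream operators in the same DAG run (assumed counterpart API).
        await self.current_dag_context.save_to_share_data("example_share_key", value)
        return value


class ConsumerOperator(MapOperator[str, str]):
    async def map(self, value: str) -> str:
        # Read it back with the renamed accessor used in this commit.
        shared = await self.current_dag_context.get_from_share_data("example_share_key")
        return f"{value}:{shared}"
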
@@ -3,11 +3,27 @@ from __future__ import annotations
 import os
 import logging
 from dataclasses import dataclass
+from abc import ABC
 import importlib.metadata as metadata
-from typing import List, Dict, Any, Optional, TYPE_CHECKING, Union, AsyncIterator
+from typing import (
+    List,
+    Dict,
+    Any,
+    Optional,
+    TYPE_CHECKING,
+    Union,
+    AsyncIterator,
+    Callable,
+    Awaitable,
+)
 
+from dbgpt.component import ComponentType
+from dbgpt.core.operator import BaseLLM
+from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator
 from dbgpt.core.interface.llm import ModelMetadata, LLMClient
 from dbgpt.core.interface.llm import ModelOutput, ModelRequest
+from dbgpt.model.cluster.client import DefaultLLMClient
+from dbgpt.model.cluster import WorkerManagerFactory
 
 if TYPE_CHECKING:
     import httpx
@@ -176,13 +192,13 @@ class OpenAILLMClient(LLMClient):
         self, request: ModelRequest
     ) -> AsyncIterator[ModelOutput]:
         messages = request.to_openai_messages()
-        payload = self._build_request(request)
+        payload = self._build_request(request, True)
         try:
             chat_completion = await self.client.chat.completions.create(
                 messages=messages, **payload
             )
             text = ""
-            for r in chat_completion:
+            async for r in chat_completion:
                 if len(r.choices) == 0:
                     continue
                 if r.choices[0].delta.content is not None:
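
The second argument to _build_request presumably flags the payload as streaming, and the loop switches to async for because the OpenAI client now yields chunks asynchronously. A consumption sketch; the method name generate_stream is an assumption based on the LLMClient interface, since the hunk only shows the body:

from dbgpt.core.interface.llm import ModelRequest
from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient


async def collect_stream(client: OpenAILLMClient, request: ModelRequest) -> str:
    text = ""
    # Assumed method name; each ModelOutput chunk carries the accumulated text so far.
    async for output in client.generate_stream(request):
        text = output.text
    return text
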
@@ -221,17 +237,74 @@ class OpenAILLMClient(LLMClient):
         raise NotImplementedError()
 
 
+class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
+    """Transform ModelOutput to openai stream format."""
+
+    async def transform_stream(
+        self, input_value: AsyncIterator[ModelOutput]
+    ) -> AsyncIterator[str]:
+        async def model_caller() -> str:
+            """Read model name from share data.
+            In streaming mode, this transform_stream function is executed before the
+            parent operator (the streaming operator is triggered by the downstream operator).
+            """
+            return await self.current_dag_context.get_from_share_data(
+                BaseLLM.SHARE_DATA_KEY_MODEL_NAME
+            )
+
+        async for output in _to_openai_stream(input_value, None, model_caller):
+            yield output
+
+
+class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
+    """Mixin class for LLM operator.
+
+    This class extends BaseOperator by adding LLM capabilities.
+    """
+
+    def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
+        super().__init__(default_client)
+        self._default_llm_client = default_client
+
+    @property
+    def llm_client(self) -> LLMClient:
+        if not self._llm_client:
+            worker_manager_factory: WorkerManagerFactory = (
+                self.system_app.get_component(
+                    ComponentType.WORKER_MANAGER_FACTORY,
+                    WorkerManagerFactory,
+                    default_component=None,
+                )
+            )
+            if worker_manager_factory:
+                self._llm_client = DefaultLLMClient(worker_manager_factory.create())
+            else:
+                if self._default_llm_client is None:
+                    from dbgpt.model import OpenAILLMClient
+
+                    self._default_llm_client = OpenAILLMClient()
+                logger.info(
+                    f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
+                )
+                self._llm_client = self._default_llm_client
+        return self._llm_client
+
+
 async def _to_openai_stream(
-    model: str, output_iter: AsyncIterator[ModelOutput]
+    output_iter: AsyncIterator[ModelOutput],
+    model: Optional[str] = None,
+    model_caller: Callable[[], Union[Awaitable[str], str]] = None,
 ) -> AsyncIterator[str]:
     """Convert the output_iter to openai stream format.
 
     Args:
-        model (str): The model name.
         output_iter (AsyncIterator[ModelOutput]): The output iterator.
+        model (Optional[str], optional): The model name. Defaults to None.
+        model_caller (Callable[[], Union[Awaitable[str], str]], optional): The model caller. Defaults to None.
     """
     import json
     import shortuuid
+    import asyncio
     from fastchat.protocol.openai_api_protocol import (
         ChatCompletionResponseStreamChoice,
         ChatCompletionStreamResponse,
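
MixinLLMOperator gives any operator a lazily resolved llm_client: it prefers a DefaultLLMClient built from the registered WorkerManagerFactory and otherwise falls back to OpenAILLMClient. A sketch of a custom operator built on it; the class name, the explicit MapOperator.__init__ call, and the map body are illustrative assumptions:

from typing import Optional

from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.llm import LLMClient, ModelOutput, ModelRequest
from dbgpt.model.utils.chatgpt_utils import MixinLLMOperator


class ExampleLLMCallOperator(MixinLLMOperator, MapOperator[ModelRequest, ModelOutput]):
    def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(default_client)
        MapOperator.__init__(self, **kwargs)

    async def map(self, request: ModelRequest) -> ModelOutput:
        # Resolves to DefaultLLMClient when a WorkerManagerFactory is registered,
        # otherwise falls back to OpenAILLMClient (see the llm_client property above).
        return await self.llm_client.generate(request)
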
@@ -245,12 +318,19 @@ async def _to_openai_stream(
         delta=DeltaMessage(role="assistant"),
         finish_reason=None,
     )
-    chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
+    chunk = ChatCompletionStreamResponse(
+        id=id, choices=[choice_data], model=model or ""
+    )
     yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
 
     previous_text = ""
     finish_stream_events = []
     async for model_output in output_iter:
+        if model_caller is not None:
+            if asyncio.iscoroutinefunction(model_caller):
+                model = await model_caller()
+            else:
+                model = model_caller()
         model_output: ModelOutput = model_output
         if model_output.error_code != 0:
             yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
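
The asyncio.iscoroutinefunction check lets _to_openai_stream accept either a plain function or a coroutine function as model_caller. A self-contained sketch of that dispatch pattern (names here are illustrative, not from the commit):

import asyncio
from typing import Awaitable, Callable, Union


async def resolve_model(model_caller: Callable[[], Union[Awaitable[str], str]]) -> str:
    # Await the caller if it is a coroutine function, otherwise call it directly.
    if asyncio.iscoroutinefunction(model_caller):
        return await model_caller()
    return model_caller()


async def main():
    async def async_caller() -> str:
        return "vicuna-13b-v1.5"

    def sync_caller() -> str:
        return "gpt-3.5-turbo"

    print(await resolve_model(async_caller))  # vicuna-13b-v1.5
    print(await resolve_model(sync_caller))  # gpt-3.5-turbo


asyncio.run(main())
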