feat(core): More AWEL operators and new prompt manager API (#972)

Author: Fangyin Cheng
Co-authored-by: csunny <cfqsunny@163.com>
Date: 2023-12-25 20:03:22 +08:00 (committed by GitHub)
Parent: 048fb6c402
Commit: 69fb97e508
46 changed files with 2556 additions and 294 deletions


@@ -1,4 +1,13 @@
 from dbgpt.model.cluster.client import DefaultLLMClient
-from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
+from dbgpt.model.utils.chatgpt_utils import (
+    OpenAILLMClient,
+    OpenAIStreamingOperator,
+    MixinLLMOperator,
+)
 
-__ALL__ = ["DefaultLLMClient", "OpenAILLMClient"]
+__ALL__ = [
+    "DefaultLLMClient",
+    "OpenAILLMClient",
+    "OpenAIStreamingOperator",
+    "MixinLLMOperator",
+]
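With the widened export list above, downstream code can pull the new operators straight from the package root. A minimal sketch of the flattened import surface, assuming the re-exports land exactly as declared:

# Minimal sketch: the new re-exports named in the __ALL__ list above.
from dbgpt.model import (
    DefaultLLMClient,
    OpenAILLMClient,
    OpenAIStreamingOperator,
    MixinLLMOperator,
)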


@@ -171,7 +171,7 @@ class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
         self._model_task_name = model_task_name
         self._cache_task_name = cache_task_name
 
-    async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
+    async def branches(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
         """Defines branch logic based on cache availability.
 
         Returns:
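For context, a hedged sketch of the renamed branches() contract: map an async predicate over the input to the name of the downstream task that should run. The predicate and task names here are illustrative, not taken from this commit; BranchOperator, BranchFunc, and BaseOperator are assumed importable from dbgpt.core.awel as in the hunk above.

# Hedged sketch of the `branches` hook on a BranchOperator subclass.
from typing import Dict, Union

from dbgpt.core.awel import BaseOperator, BranchFunc, BranchOperator

class CacheBranch(BranchOperator[Dict, Dict]):
    async def branches(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
        async def cache_hit(data: Dict) -> bool:
            return bool(data.get("_cache_hit"))  # illustrative predicate

        async def cache_miss(data: Dict) -> bool:
            return not await cache_hit(data)

        # Each predicate routes to the task name that should execute.
        return {cache_hit: "cache_task", cache_miss: "model_task"}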
@@ -233,7 +233,7 @@ class ModelStreamSaveCacheOperator(
         outputs = []
         async for out in input_value:
             if not llm_cache_key:
-                llm_cache_key = await self.current_dag_context.get_share_data(
+                llm_cache_key = await self.current_dag_context.get_from_share_data(
                     _LLM_MODEL_INPUT_VALUE_KEY
                 )
             outputs.append(out)
@@ -265,7 +265,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
         Returns:
             ModelOutput: The same input model output.
         """
-        llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
+        llm_cache_key: LLMCacheKey = await self.current_dag_context.get_from_share_data(
             _LLM_MODEL_INPUT_VALUE_KEY
         )
         llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
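The rename get_share_data -> get_from_share_data pairs the read side with the save side of the DAG context. A hedged round-trip sketch, assuming a companion save_to_share_data method on the DAG context (not shown in this diff):

# Hedged sketch: one operator saves a value into the DAG's shared state,
# a downstream operator reads it back with the renamed accessor.
from dbgpt.core.awel import MapOperator

class SaveKey(MapOperator[str, str]):
    async def map(self, value: str) -> str:
        # save_to_share_data is assumed to exist on the DAG context
        await self.current_dag_context.save_to_share_data("demo_key", value)
        return value

class ReadKey(MapOperator[str, str]):
    async def map(self, value: str) -> str:
        saved = await self.current_dag_context.get_from_share_data("demo_key")
        return f"read back: {saved}"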


@@ -3,11 +3,27 @@ from __future__ import annotations
 import os
 import logging
 from dataclasses import dataclass
+from abc import ABC
 import importlib.metadata as metadata
-from typing import List, Dict, Any, Optional, TYPE_CHECKING, Union, AsyncIterator
+from typing import (
+    List,
+    Dict,
+    Any,
+    Optional,
+    TYPE_CHECKING,
+    Union,
+    AsyncIterator,
+    Callable,
+    Awaitable,
+)
+from dbgpt.component import ComponentType
+from dbgpt.core.operator import BaseLLM
+from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator
 from dbgpt.core.interface.llm import ModelMetadata, LLMClient
 from dbgpt.core.interface.llm import ModelOutput, ModelRequest
+from dbgpt.model.cluster.client import DefaultLLMClient
+from dbgpt.model.cluster import WorkerManagerFactory
 
 if TYPE_CHECKING:
     import httpx
@@ -176,13 +192,13 @@ class OpenAILLMClient(LLMClient):
         self, request: ModelRequest
     ) -> AsyncIterator[ModelOutput]:
         messages = request.to_openai_messages()
-        payload = self._build_request(request)
+        payload = self._build_request(request, True)
         try:
             chat_completion = await self.client.chat.completions.create(
                 messages=messages, **payload
             )
             text = ""
-            for r in chat_completion:
+            async for r in chat_completion:
                 if len(r.choices) == 0:
                     continue
                 if r.choices[0].delta.content is not None:
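The `async for` fix matters because the openai>=1.x async client returns an async iterator from a streaming chat.completions.create call; iterating it with a plain `for` raises a TypeError. A standalone sketch of the corrected consumption pattern (the model name is illustrative):

# Hedged sketch: consuming an openai>=1.x streaming completion.
from openai import AsyncOpenAI

async def collect_stream(client: AsyncOpenAI, messages: list) -> str:
    stream = await client.chat.completions.create(
        model="gpt-3.5-turbo",  # illustrative model name
        messages=messages,
        stream=True,
    )
    text = ""
    # The async client's stream must be iterated with `async for`.
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content is not None:
            text += chunk.choices[0].delta.content
    return text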
@@ -221,17 +237,74 @@ class OpenAILLMClient(LLMClient):
         raise NotImplementedError()
 
+
+class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
+    """Transform ModelOutput to openai stream format."""
+
+    async def transform_stream(
+        self, input_value: AsyncIterator[ModelOutput]
+    ) -> AsyncIterator[str]:
+        async def model_caller() -> str:
+            """Read the model name from share data.
+
+            In streaming mode, this transform_stream function is executed before
+            the parent operator (the streaming operator is triggered by the
+            downstream operator).
+            """
+            return await self.current_dag_context.get_from_share_data(
+                BaseLLM.SHARE_DATA_KEY_MODEL_NAME
+            )
+
+        async for output in _to_openai_stream(input_value, None, model_caller):
+            yield output
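A hedged wiring sketch for the new operator. The upstream task here is a stand-in for a real streaming LLM operator (which would also save the model name to share data before this operator reads it); StreamifyAbsOperator is assumed to be AWEL's stream-producing base class.

# Hedged sketch: a DAG that pipes ModelOutput chunks through
# OpenAIStreamingOperator to produce OpenAI-compatible "data: ..." SSE lines.
from dbgpt.core.awel import DAG, StreamifyAbsOperator
from dbgpt.core.interface.llm import ModelOutput
from dbgpt.model import OpenAIStreamingOperator

class FakeModelStream(StreamifyAbsOperator[str, ModelOutput]):
    async def streamify(self, prompt: str):
        # Stand-in for a real streaming LLM task.
        for word in ("hello", " world"):
            yield ModelOutput(text=word, error_code=0)

with DAG("openai_sse_dag"):
    model_task = FakeModelStream()
    sse_task = OpenAIStreamingOperator()
    model_task >> sse_task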
+
+
+class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
+    """Mixin class for LLM operator.
+
+    This class extends BaseOperator by adding LLM capabilities.
+    """
+
+    def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
+        super().__init__(default_client)
+        self._default_llm_client = default_client
+
+    @property
+    def llm_client(self) -> LLMClient:
+        if not self._llm_client:
+            worker_manager_factory: WorkerManagerFactory = (
+                self.system_app.get_component(
+                    ComponentType.WORKER_MANAGER_FACTORY,
+                    WorkerManagerFactory,
+                    default_component=None,
+                )
+            )
+            if worker_manager_factory:
+                self._llm_client = DefaultLLMClient(worker_manager_factory.create())
+            else:
+                if self._default_llm_client is None:
+                    from dbgpt.model import OpenAILLMClient
+
+                    self._default_llm_client = OpenAILLMClient()
+                logger.info(
+                    f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
+                )
+                self._llm_client = self._default_llm_client
+        return self._llm_client
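A hedged sketch of an operator built on the mixin: it gains a lazily resolved llm_client (the registered worker manager when present, otherwise the OpenAI fallback above). The explicit MapOperator.__init__ call and the LLMClient.generate(request) interface are assumptions about the surrounding dbgpt.core API.

# Hedged sketch: a map operator that gets `llm_client` via MixinLLMOperator.
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.llm import ModelOutput, ModelRequest
from dbgpt.model import MixinLLMOperator

class SimpleLLMCall(MixinLLMOperator, MapOperator[ModelRequest, ModelOutput]):
    def __init__(self, default_client=None, **kwargs):
        super().__init__(default_client)
        # Initialize the operator side of the MRO explicitly (assumed pattern).
        MapOperator.__init__(self, **kwargs)

    async def map(self, request: ModelRequest) -> ModelOutput:
        # llm_client resolves to DefaultLLMClient or the fallback client above.
        return await self.llm_client.generate(request)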
 async def _to_openai_stream(
-    model: str, output_iter: AsyncIterator[ModelOutput]
+    output_iter: AsyncIterator[ModelOutput],
+    model: Optional[str] = None,
+    model_caller: Optional[Callable[[], Union[Awaitable[str], str]]] = None,
 ) -> AsyncIterator[str]:
     """Convert the output_iter to openai stream format.
 
     Args:
-        model (str): The model name.
         output_iter (AsyncIterator[ModelOutput]): The output iterator.
+        model (Optional[str], optional): The model name. Defaults to None.
+        model_caller (Callable[[], Union[Awaitable[str], str]], optional): The model caller. Defaults to None.
     """
     import json
     import shortuuid
+    import asyncio
+
     from fastchat.protocol.openai_api_protocol import (
         ChatCompletionResponseStreamChoice,
         ChatCompletionStreamResponse,
@@ -245,12 +318,19 @@ async def _to_openai_stream(
         delta=DeltaMessage(role="assistant"),
         finish_reason=None,
     )
-    chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
+    chunk = ChatCompletionStreamResponse(
+        id=id, choices=[choice_data], model=model or ""
+    )
     yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
 
     previous_text = ""
     finish_stream_events = []
     async for model_output in output_iter:
+        if model_caller is not None:
+            if asyncio.iscoroutinefunction(model_caller):
+                model = await model_caller()
+            else:
+                model = model_caller()
         model_output: ModelOutput = model_output
         if model_output.error_code != 0:
             yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
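The model_caller handling above distinguishes coroutine functions from plain callables. Distilled into a standalone helper (the helper name is hypothetical, not part of the commit) so the resolution order is explicit:

# Hedged sketch: how the loop above resolves the model name per chunk.
import asyncio
from typing import Awaitable, Callable, Optional, Union

async def _resolve_model(
    model: Optional[str],
    model_caller: Optional[Callable[[], Union[Awaitable[str], str]]] = None,
) -> Optional[str]:
    if model_caller is None:
        return model  # fall back to the statically supplied name
    if asyncio.iscoroutinefunction(model_caller):
        return await model_caller()  # async caller: await its coroutine
    return model_caller()  # plain callable: call it directly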