refactor: Refactor proxy LLM (#1064)

Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions
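The hunks below replace the old dict payload and the per-chat model DAG with typed ModelRequest objects driven through a shared cached-chat operator. A minimal sketch of the new calling convention, using only names introduced in this commit (the chat instance and its parameters are illustrative):

# Sketch only: how a BaseChat subclass drives the LLM after this refactor.
# `chat` stands in for any concrete chat instance built from chat_param.
async def run_once(chat):
    # Typed ModelRequest instead of the old dict payload
    request = await chat._build_model_request()
    # Non-streaming path: one (possibly cached) LLM call returning a ModelOutput
    return await chat.call_llm_operator(request)

async def run_stream(chat):
    request = await chat._build_model_request()
    # Streaming path: async iterator of ModelOutput chunks
    async for chunk in chat.call_streaming_operator(request):
        yield chunk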

View File

@@ -178,14 +178,15 @@ class FalconChatAdapter(BaseChatAdpter):
return falcon_generate_output
class ProxyllmChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "proxyllm" in model_path
def get_generate_stream_func(self, model_path: str):
from dbgpt.model.llm_out.proxy_llm import proxyllm_generate_stream
return proxyllm_generate_stream
#
# class ProxyllmChatAdapter(BaseChatAdpter):
# def match(self, model_path: str):
# return "proxyllm" in model_path
#
# def get_generate_stream_func(self, model_path: str):
# from dbgpt.model.llm_out.proxy_llm import proxyllm_generate_stream
#
# return proxyllm_generate_stream
class GorillaChatAdapter(BaseChatAdpter):
@@ -286,6 +287,6 @@ register_llm_model_chat_adapter(LlamaCppChatAdapter)
register_llm_model_chat_adapter(InternLMChatAdapter)
# Proxy model for testing and development; it's cheap for us for now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
# register_llm_model_chat_adapter(ProxyllmChatAdapter)
register_llm_model_chat_adapter(BaseChatAdpter)

View File

@@ -3,7 +3,7 @@ import datetime
import logging
import traceback
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, AsyncIterator, Dict
from dbgpt._private.config import Config
from dbgpt._private.pydantic import Extra
@@ -11,12 +11,13 @@ from dbgpt.app.scene.base import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.operator.app_operator import (
AppChatComposerOperator,
ChatComposerInput,
build_cached_chat_operator,
)
from dbgpt.component import ComponentType
from dbgpt.core.awel import DAG, BaseOperator, InputOperator, SimpleCallDataInputSource
from dbgpt.core import LLMClient, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core.interface.message import StorageConversation
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
@@ -64,6 +65,10 @@ class BaseChat(ABC):
keep_start_rounds: int = 0
keep_end_rounds: int = 0
# Some models do not support the system role; this config controls whether to
# convert system messages to human messages
auto_convert_message: bool = True
class Config:
"""Configuration for this pydantic object."""
@@ -124,14 +129,15 @@ class BaseChat(ABC):
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
self._model_operator: BaseOperator = _build_model_operator()
self._model_stream_operator: BaseOperator = _build_model_operator(
is_stream=True, dag_name="llm_stream_model_dag"
)
# self._model_operator: BaseOperator = _build_model_operator()
# self._model_stream_operator: BaseOperator = _build_model_operator(
# is_stream=True, dag_name="llm_stream_model_dag"
# )
# In v1, we transform the messages into the format expected by the specific model
# In the future, we will upgrade to message version v2, which will be compatible with all models
self._message_version = chat_param.get("message_version", "v2")
self._chat_param = chat_param
class Config:
"""Configuration for this pydantic object."""
@@ -153,6 +159,27 @@ class BaseChat(ABC):
a dictionary to be formatted by prompt template
"""
@property
def llm_client(self) -> LLMClient:
"""Return the LLM client."""
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return DefaultLLMClient(
worker_manager, auto_convert_message=self.auto_convert_message
)
async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
return await llm_task.call(call_data={"data": request})
async def call_streaming_operator(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
async for out in await llm_task.call_stream(call_data={"data": request}):
yield out
def do_action(self, prompt_response):
return prompt_response
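These helpers make the LLM client a derived property rather than per-subclass state. A hedged sketch of the equivalent standalone setup, using only components that appear in this diff (the function name is hypothetical):

from dbgpt.component import ComponentType
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory


def make_llm_client(system_app, auto_convert_message: bool = True) -> DefaultLLMClient:
    # Resolve the worker manager from the running SystemApp, then wrap it in the
    # default client. auto_convert_message converts system messages to human
    # messages for models that do not support the system role.
    worker_manager = system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    ).create()
    return DefaultLLMClient(worker_manager, auto_convert_message=auto_convert_message)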
@@ -185,7 +212,7 @@ class BaseChat(ABC):
speak_to_user = prompt_define_response
return speak_to_user
async def __call_base(self):
async def _build_model_request(self) -> ModelRequest:
input_values = await self.generate_input_values()
# Load history
self.history_messages = self.current_message.get_history_message()
@@ -195,21 +222,6 @@ class BaseChat(ABC):
"%Y-%m-%d %H:%M:%S"
)
self.current_message.tokens = 0
# TODO: support handle history message by tokens
# if self.prompt_template.template:
# metadata = {
# "template_scene": self.prompt_template.template_scene,
# "input_values": input_values,
# }
# with root_tracer.start_span(
# "BaseChat.__call_base.prompt_template.format", metadata=metadata
# ):
# current_prompt = self.prompt_template.format(**input_values)
# ### prompt context token adapt according to llm max context length
# current_prompt = await self.prompt_context_token_adapt(
# prompt=current_prompt
# )
# self.current_message.add_system_message(current_prompt)
keep_start_rounds = (
self.keep_start_rounds
@@ -219,11 +231,25 @@ class BaseChat(ABC):
keep_end_rounds = (
self.keep_end_rounds if self.prompt_template.need_historical_messages else 0
)
req_ctx = ModelRequestContext(
stream=self.prompt_template.stream_out,
user_name=self._chat_param.get("user_name"),
sys_code=self._chat_param.get("sys_code"),
chat_mode=self.chat_mode.value(),
span_id=root_tracer.get_current_span_id(),
)
node = AppChatComposerOperator(
model=self.llm_model,
temperature=float(self.prompt_template.temperature),
max_new_tokens=int(self.prompt_template.max_new_tokens),
prompt=self.prompt_template.prompt,
message_version=self._message_version,
echo=self.llm_echo,
streaming=self.prompt_template.stream_out,
keep_start_rounds=keep_start_rounds,
keep_end_rounds=keep_end_rounds,
str_history=self.prompt_template.str_history,
request_context=req_ctx,
)
node_input = {
"data": ChatComposerInput(
@@ -231,22 +257,9 @@ class BaseChat(ABC):
)
}
# llm_messages = self.generate_llm_messages()
llm_messages = await node.call(call_data=node_input)
if not CFG.NEW_SERVER_MODE:
# Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
# fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
llm_messages = list(map(lambda m: m.dict(), llm_messages))
payload = {
"model": self.llm_model,
"prompt": "",
"messages": llm_messages,
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
"echo": self.llm_echo,
"version": self._message_version,
}
return payload
model_request: ModelRequest = await node.call(call_data=node_input)
model_request.context.cache_enable = self.model_cache_enable
return model_request
def stream_plugin_call(self, text):
return text
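For reference, the ModelRequest returned above is assembled by AppChatComposerOperator (later in this diff) via ModelRequest.build_request. A hypothetical direct construction with hard-coded values, mirroring the fields the old dict payload carried:

from dbgpt.core import ModelMessage, ModelRequest, ModelRequestContext

# Hypothetical values; in the real flow they come from the prompt template,
# chat_param, and the history/prompt composer DAG.
context = ModelRequestContext(stream=False, chat_mode="chat_normal")
request = ModelRequest.build_request(
    model="proxyllm",
    messages=[ModelMessage(role="human", content="Hello")],
    context=context,
    temperature=0.5,
    max_new_tokens=1024,
    echo=False,
)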
@@ -271,23 +284,19 @@ class BaseChat(ABC):
async def stream_call(self):
# TODO: Retry on server connection errors
payload = await self.__call_base()
payload = await self._build_model_request()
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"payload request: \n{payload}")
ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
"BaseChat.stream_call", metadata=payload.to_dict()
)
payload["span_id"] = span.span_id
payload["model_cache_enable"] = self.model_cache_enable
payload.span_id = span.span_id
try:
async for output in await self._model_stream_operator.call_stream(
call_data={"data": payload}
):
async for output in self.call_streaming_operator(payload):
# Plugin research in result generation
msg = self.prompt_template.output_parser.parse_model_stream_resp_ex(
output, self.skip_echo_len
output, 0
)
view_msg = self.stream_plugin_call(msg)
view_msg = view_msg.replace("\n", "\\n")
@@ -308,19 +317,16 @@ class BaseChat(ABC):
self.current_message.end_current_round()
async def nostream_call(self):
payload = await self.__call_base()
payload = await self._build_model_request()
span = root_tracer.start_span(
"BaseChat.nostream_call", metadata=payload.to_dict()
)
logger.info(f"Request: \n{payload}")
ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
)
payload["span_id"] = span.span_id
payload["model_cache_enable"] = self.model_cache_enable
payload.span_id = span.span_id
try:
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
model_output = await self._model_operator.call(
call_data={"data": payload}
)
model_output = await self.call_llm_operator(payload)
### output parse
ai_response_text = (
@@ -380,13 +386,12 @@ class BaseChat(ABC):
return self.current_ai_response()
async def get_llm_response(self):
payload = await self.__call_base()
payload = await self._build_model_request()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
payload["model_cache_enable"] = self.model_cache_enable
prompt_define_response = None
try:
model_output = await self._model_operator.call(call_data={"data": payload})
model_output = await self.call_llm_operator(payload)
### output parse
ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(
@@ -544,89 +549,3 @@ class BaseChat(ABC):
for dict_item in antv_charts
for key, value in dict_item.items()
)
def _build_model_operator(
is_stream: bool = False, dag_name: str = "llm_model_dag"
) -> BaseOperator:
"""Builds and returns a model processing workflow (DAG) operator.
This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
It includes caching and branching logic to either fetch results from a cache or process
data using the model. It supports both streaming and non-streaming modes.
.. code-block:: python
input_node >> cache_check_branch_node
cache_check_branch_node >> model_node >> save_cached_node >> join_node
cache_check_branch_node >> cached_node >> join_node
equivalent to::
              -> model_node -> save_cached_node ->
             /                                    \
input_node -> cache_check_branch_node ---> join_node
             \                                    /
              -> cached_node ------------------- ->
Args:
is_stream (bool): Flag to determine if the operator should process data in streaming mode.
dag_name (str): Name of the DAG.
Returns:
BaseOperator: The final operator in the constructed DAG, typically a join node.
"""
from dbgpt.core.awel import JoinOperator
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.operator.model_operator import (
CachedModelOperator,
CachedModelStreamOperator,
ModelCacheBranchOperator,
ModelSaveCacheOperator,
ModelStreamSaveCacheOperator,
)
from dbgpt.storage.cache import CacheManager
# Fetch worker and cache managers from the system configuration
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CACHE_MANAGER, CacheManager
)
# Define task names for the model and cache nodes
model_task_name = "llm_model_node"
cache_task_name = "llm_model_cache_node"
with DAG(dag_name):
# Create an input node
input_node = InputOperator(SimpleCallDataInputSource())
# Determine if the workflow should operate in streaming mode
if is_stream:
model_node = ModelStreamOperator(worker_manager, task_name=model_task_name)
cached_node = CachedModelStreamOperator(
cache_manager, task_name=cache_task_name
)
save_cached_node = ModelStreamSaveCacheOperator(cache_manager)
else:
model_node = ModelOperator(worker_manager, task_name=model_task_name)
cached_node = CachedModelOperator(cache_manager, task_name=cache_task_name)
save_cached_node = ModelSaveCacheOperator(cache_manager)
# Create a branch node to decide between fetching from cache or processing with the model
cache_check_branch_node = ModelCacheBranchOperator(
cache_manager,
model_task_name="llm_model_node",
cache_task_name="llm_model_cache_node",
)
# Create a join node to merge outputs from the model and cache nodes, keeping only the first non-empty output
join_node = JoinOperator(
combine_function=lambda model_out, cache_out: cache_out or model_out
)
# Define the workflow structure using the >> operator
input_node >> cache_check_branch_node
cache_check_branch_node >> model_node >> save_cached_node >> join_node
cache_check_branch_node >> cached_node >> join_node
return join_node

View File

@@ -80,10 +80,6 @@ class ChatKnowledge(BaseChat):
vector_store_config=config,
)
query_rewrite = None
self.worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
self.llm_client = DefaultLLMClient(worker_manager=self.worker_manager)
if CFG.KNOWLEDGE_SEARCH_REWRITE:
query_rewrite = QueryRewrite(
llm_client=self.llm_client,

View File

@@ -1,17 +1,36 @@
import dataclasses
from typing import Any, Dict, List, Optional
from dbgpt.core import BaseMessage, ChatPromptTemplate, ModelMessage
from dbgpt import SystemApp
from dbgpt.component import ComponentType
from dbgpt.core import (
BaseMessage,
ChatPromptTemplate,
LLMClient,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core.awel import (
DAG,
BaseOperator,
InputOperator,
JoinOperator,
MapOperator,
SimpleCallDataInputSource,
)
from dbgpt.core.operator import (
BufferedConversationMapperOperator,
HistoryPromptBuilderOperator,
LLMBranchOperator,
)
from dbgpt.model.operator import LLMOperator, StreamingLLMOperator
from dbgpt.storage.cache.operator import (
CachedModelOperator,
CachedModelStreamOperator,
CacheManager,
ModelCacheBranchOperator,
ModelSaveCacheOperator,
ModelStreamSaveCacheOperator,
)
@@ -23,7 +42,7 @@ class ChatComposerInput:
prompt_dict: Dict[str, Any]
class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]]):
class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
"""App chat composer operator.
TODO: Support more history merge modes.
@@ -31,29 +50,56 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]]
def __init__(
self,
model: str,
temperature: float,
max_new_tokens: int,
prompt: ChatPromptTemplate,
message_version: str = "v2",
echo: bool = False,
streaming: bool = True,
history_key: str = "chat_history",
history_merge_mode: str = "window",
keep_start_rounds: Optional[int] = None,
keep_end_rounds: Optional[int] = None,
str_history: bool = False,
request_context: ModelRequestContext = None,
**kwargs,
):
super().__init__(**kwargs)
if not request_context:
request_context = ModelRequestContext(stream=streaming)
self._prompt_template = prompt
self._history_key = history_key
self._history_merge_mode = history_merge_mode
self._keep_start_rounds = keep_start_rounds
self._keep_end_rounds = keep_end_rounds
self._str_history = str_history
self._model_name = model
self._temperature = temperature
self._max_new_tokens = max_new_tokens
self._message_version = message_version
self._echo = echo
self._streaming = streaming
self._request_context = request_context
self._sub_compose_dag = self._build_composer_dag()
async def map(self, input_value: ChatComposerInput) -> List[ModelMessage]:
async def map(self, input_value: ChatComposerInput) -> ModelRequest:
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
# Sub DAG: reuse the same DAG context as the parent DAG
return await end_node.call(
messages = await end_node.call(
call_data={"data": input_value}, dag_ctx=self.current_dag_context
)
span_id = self._request_context.span_id
model_request = ModelRequest.build_request(
model=self._model_name,
messages=messages,
context=self._request_context,
temperature=self._temperature,
max_new_tokens=self._max_new_tokens,
span_id=span_id,
echo=self._echo,
)
return model_request
def _build_composer_dag(self) -> DAG:
with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
@@ -83,3 +129,79 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]]
)
return composer_dag
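A condensed usage sketch of the composer above; the parameter values are illustrative, and composer_input is assumed to be a ChatComposerInput built from the history messages and prompt variables:

async def compose_request(
    prompt: ChatPromptTemplate, composer_input: ChatComposerInput
) -> ModelRequest:
    # Illustrative configuration; real values come from the app scene and prompt template.
    composer = AppChatComposerOperator(
        model="proxyllm",
        temperature=0.6,
        max_new_tokens=1024,
        prompt=prompt,
        keep_start_rounds=0,
        keep_end_rounds=5,
        request_context=ModelRequestContext(stream=False),
    )
    # Runs the internal history/prompt composer DAG, then wraps the resulting
    # messages in a ModelRequest.
    return await composer.call(call_data={"data": composer_input})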
def build_cached_chat_operator(
llm_client: LLMClient,
is_streaming: bool,
system_app: SystemApp,
cache_manager: Optional[CacheManager] = None,
):
"""Builds and returns a model processing workflow (DAG) operator.
This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
It includes caching and branching logic to either fetch results from a cache or process
data using the model. It supports both streaming and non-streaming modes.
.. code-block:: python
input_task >> cache_check_branch_task
cache_check_branch_task >> llm_task >> save_cache_task >> join_task
cache_check_branch_task >> cache_task >> join_task
equivalent to::
              -> llm_task -> save_cache_task ->
             /                                 \
input_task -> cache_check_branch_task ---> join_task
             \                                 /
              -> cache_task ------------------- ->
Args:
llm_client (LLMClient): The LLM client for processing data using the model.
is_streaming (bool): Whether the model is a streaming model.
system_app (SystemApp): The system app.
cache_manager (CacheManager, optional): The cache manager for managing cache operations. Defaults to None.
Returns:
BaseOperator: The final operator in the constructed DAG, typically a join node.
"""
# Define task names for the model and cache nodes
model_task_name = "llm_model_node"
cache_task_name = "llm_model_cache_node"
if not cache_manager:
cache_manager: CacheManager = system_app.get_component(
ComponentType.MODEL_CACHE_MANAGER, CacheManager
)
with DAG("dbgpt_awel_app_model_infer_with_cached") as dag:
# Create an input task
input_task = InputOperator(SimpleCallDataInputSource())
# Create the model and cache tasks, choosing streaming or non-streaming variants
if is_streaming:
llm_task = StreamingLLMOperator(llm_client, task_name=model_task_name)
cache_task = CachedModelStreamOperator(
cache_manager, task_name=cache_task_name
)
save_cache_task = ModelStreamSaveCacheOperator(cache_manager)
else:
llm_task = LLMOperator(llm_client, task_name=model_task_name)
cache_task = CachedModelOperator(cache_manager, task_name=cache_task_name)
save_cache_task = ModelSaveCacheOperator(cache_manager)
# Create a branch node to decide between fetching from cache or processing with the model
cache_check_branch_task = ModelCacheBranchOperator(
cache_manager,
model_task_name=model_task_name,
cache_task_name=cache_task_name,
)
# Create a join task to merge outputs from the model and cache tasks, keeping only the first non-empty output
join_task = JoinOperator(
combine_function=lambda model_out, cache_out: cache_out or model_out
)
# Define the workflow structure using the >> operator
input_task >> cache_check_branch_task
cache_check_branch_task >> llm_task >> save_cache_task >> join_task
cache_check_branch_task >> cache_task >> join_task
return join_task
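A short usage sketch of this builder, matching how BaseChat.call_llm_operator and call_streaming_operator use it earlier in this commit; llm_client, system_app, and request are assumed to be prepared by the caller:

async def infer(llm_client, system_app, request):
    # Non-streaming: the join task yields a single ModelOutput (cache hit or fresh call).
    llm_task = build_cached_chat_operator(llm_client, False, system_app)
    output = await llm_task.call(call_data={"data": request})

    # Streaming: call_stream returns an async iterator of ModelOutput chunks.
    stream_task = build_cached_chat_operator(llm_client, True, system_app)
    async for chunk in await stream_task.call_stream(call_data={"data": request}):
        print(chunk)
    return output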