Mirror of https://github.com/csunny/DB-GPT.git
refactor: Refactor proxy LLM (#1064)
@@ -178,14 +178,15 @@ class FalconChatAdapter(BaseChatAdpter):
         return falcon_generate_output
 
 
-class ProxyllmChatAdapter(BaseChatAdpter):
-    def match(self, model_path: str):
-        return "proxyllm" in model_path
-
-    def get_generate_stream_func(self, model_path: str):
-        from dbgpt.model.llm_out.proxy_llm import proxyllm_generate_stream
-
-        return proxyllm_generate_stream
+#
+# class ProxyllmChatAdapter(BaseChatAdpter):
+#     def match(self, model_path: str):
+#         return "proxyllm" in model_path
+#
+#     def get_generate_stream_func(self, model_path: str):
+#         from dbgpt.model.llm_out.proxy_llm import proxyllm_generate_stream
+#
+#         return proxyllm_generate_stream
 
 
 class GorillaChatAdapter(BaseChatAdpter):
@@ -286,6 +287,6 @@ register_llm_model_chat_adapter(LlamaCppChatAdapter)
 register_llm_model_chat_adapter(InternLMChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
-register_llm_model_chat_adapter(ProxyllmChatAdapter)
+# register_llm_model_chat_adapter(ProxyllmChatAdapter)
 
 register_llm_model_chat_adapter(BaseChatAdpter)
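For context, the adapter layer resolves a model by substring-matching the model path against each registered adapter, with BaseChatAdpter registered last as the catch-all; the commit above removes proxyllm from this path. A minimal self-contained sketch of that lookup pattern (the registry list and lookup helper here are illustrative stand-ins, not the module's actual internals):

```python
from typing import List


class BaseChatAdpter:
    """Fallback adapter; matches any model path (registered last)."""

    def match(self, model_path: str) -> bool:
        return True


class ProxyllmChatAdapter(BaseChatAdpter):
    def match(self, model_path: str) -> bool:
        return "proxyllm" in model_path


# Illustrative registry; the real module keeps its own internal list.
_adapters: List[BaseChatAdpter] = []


def register_llm_model_chat_adapter(cls) -> None:
    _adapters.append(cls())


def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
    # First registered adapter whose match() accepts the path wins.
    for adapter in _adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"No adapter found for {model_path}")


register_llm_model_chat_adapter(ProxyllmChatAdapter)
register_llm_model_chat_adapter(BaseChatAdpter)  # catch-all last

assert isinstance(get_llm_chat_adapter("proxyllm-xxx"), ProxyllmChatAdapter)
```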
@@ -3,7 +3,7 @@ import datetime
 import logging
 import traceback
 from abc import ABC, abstractmethod
-from typing import Any, Dict
+from typing import Any, AsyncIterator, Dict
 
 from dbgpt._private.config import Config
 from dbgpt._private.pydantic import Extra
@@ -11,12 +11,13 @@ from dbgpt.app.scene.base import AppScenePromptTemplateAdapter, ChatScene
 from dbgpt.app.scene.operator.app_operator import (
     AppChatComposerOperator,
     ChatComposerInput,
+    build_cached_chat_operator,
 )
 from dbgpt.component import ComponentType
-from dbgpt.core.awel import DAG, BaseOperator, InputOperator, SimpleCallDataInputSource
+from dbgpt.core import LLMClient, ModelOutput, ModelRequest, ModelRequestContext
 from dbgpt.core.interface.message import StorageConversation
+from dbgpt.model import DefaultLLMClient
 from dbgpt.model.cluster import WorkerManagerFactory
-from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator
 from dbgpt.serve.conversation.serve import Serve as ConversationServe
 from dbgpt.util import get_or_create_event_loop
 from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
@@ -64,6 +65,10 @@ class BaseChat(ABC):
     keep_start_rounds: int = 0
     keep_end_rounds: int = 0
 
+    # Some model not support system role, this config is used to control whether to
+    # convert system message to human message
+    auto_convert_message: bool = True
+
     class Config:
         """Configuration for this pydantic object."""
 
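The new auto_convert_message flag is forwarded to DefaultLLMClient further down; the intent stated in the comment is that models lacking a system role get system messages downgraded to human messages. A toy sketch of that conversion (a hypothetical helper, not the client's actual internals):

```python
from typing import List

from dbgpt.core import ModelMessage


def convert_system_to_human(messages: List[ModelMessage]) -> List[ModelMessage]:
    """Re-tag system messages as human for models without a system role."""
    return [
        ModelMessage(role="human", content=m.content) if m.role == "system" else m
        for m in messages
    ]
```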
@@ -124,14 +129,15 @@ class BaseChat(ABC):
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
         ).create()
 
-        self._model_operator: BaseOperator = _build_model_operator()
-        self._model_stream_operator: BaseOperator = _build_model_operator(
-            is_stream=True, dag_name="llm_stream_model_dag"
-        )
+        # self._model_operator: BaseOperator = _build_model_operator()
+        # self._model_stream_operator: BaseOperator = _build_model_operator(
+        #     is_stream=True, dag_name="llm_stream_model_dag"
+        # )
 
         # In v1, we will transform the message to compatible format of specific model
         # In the future, we will upgrade the message version to v2, and the message will be compatible with all models
         self._message_version = chat_param.get("message_version", "v2")
+        self._chat_param = chat_param
 
     class Config:
         """Configuration for this pydantic object."""
@@ -153,6 +159,27 @@ class BaseChat(ABC):
         a dictionary to be formatted by prompt template
         """
 
+    @property
+    def llm_client(self) -> LLMClient:
+        """Return the LLM client."""
+        worker_manager = CFG.SYSTEM_APP.get_component(
+            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+        ).create()
+        return DefaultLLMClient(
+            worker_manager, auto_convert_message=self.auto_convert_message
+        )
+
+    async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
+        llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
+        return await llm_task.call(call_data={"data": request})
+
+    async def call_streaming_operator(
+        self, request: ModelRequest
+    ) -> AsyncIterator[ModelOutput]:
+        llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
+        async for out in await llm_task.call_stream(call_data={"data": request}):
+            yield out
+
     def do_action(self, prompt_response):
         return prompt_response
 
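With these helpers, every scene funnels through the same cached operator pair: one call for a full ModelOutput, one for a stream. A hedged usage sketch (assumes a running DB-GPT SystemApp; `answer` and the single-message request are illustrative, since real scenes compose messages via AppChatComposerOperator as shown below):

```python
from dbgpt.core import ModelMessage, ModelRequest, ModelRequestContext


async def answer(chat, question: str):
    # Illustrative one-message request; chat is a BaseChat subclass instance.
    request = ModelRequest.build_request(
        model=chat.llm_model,
        messages=[ModelMessage(role="human", content=question)],
        context=ModelRequestContext(stream=False),
    )
    # Non-streaming path: a single ModelOutput comes back.
    output = await chat.call_llm_operator(request)
    # Streaming path: incremental ModelOutput chunks.
    async for chunk in chat.call_streaming_operator(request):
        print(chunk.text)  # assumption: .text carries the partial text
    return output
```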
@@ -185,7 +212,7 @@ class BaseChat(ABC):
             speak_to_user = prompt_define_response
         return speak_to_user
 
-    async def __call_base(self):
+    async def _build_model_request(self) -> ModelRequest:
         input_values = await self.generate_input_values()
         # Load history
         self.history_messages = self.current_message.get_history_message()
@@ -195,21 +222,6 @@
             "%Y-%m-%d %H:%M:%S"
         )
         self.current_message.tokens = 0
-        # TODO: support handle history message by tokens
-        # if self.prompt_template.template:
-        #     metadata = {
-        #         "template_scene": self.prompt_template.template_scene,
-        #         "input_values": input_values,
-        #     }
-        #     with root_tracer.start_span(
-        #         "BaseChat.__call_base.prompt_template.format", metadata=metadata
-        #     ):
-        #         current_prompt = self.prompt_template.format(**input_values)
-        #     ### prompt context token adapt according to llm max context length
-        #     current_prompt = await self.prompt_context_token_adapt(
-        #         prompt=current_prompt
-        #     )
-        #     self.current_message.add_system_message(current_prompt)
 
         keep_start_rounds = (
             self.keep_start_rounds
@@ -219,11 +231,25 @@
         keep_end_rounds = (
             self.keep_end_rounds if self.prompt_template.need_historical_messages else 0
         )
+        req_ctx = ModelRequestContext(
+            stream=self.prompt_template.stream_out,
+            user_name=self._chat_param.get("user_name"),
+            sys_code=self._chat_param.get("sys_code"),
+            chat_mode=self.chat_mode.value(),
+            span_id=root_tracer.get_current_span_id(),
+        )
         node = AppChatComposerOperator(
+            model=self.llm_model,
+            temperature=float(self.prompt_template.temperature),
+            max_new_tokens=int(self.prompt_template.max_new_tokens),
             prompt=self.prompt_template.prompt,
+            message_version=self._message_version,
+            echo=self.llm_echo,
+            streaming=self.prompt_template.stream_out,
             keep_start_rounds=keep_start_rounds,
             keep_end_rounds=keep_end_rounds,
             str_history=self.prompt_template.str_history,
+            request_context=req_ctx,
        )
         node_input = {
             "data": ChatComposerInput(
@@ -231,22 +257,9 @@
             )
         }
-        # llm_messages = self.generate_llm_messages()
-        llm_messages = await node.call(call_data=node_input)
-        if not CFG.NEW_SERVER_MODE:
-            # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
-            # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
-            llm_messages = list(map(lambda m: m.dict(), llm_messages))
-
-        payload = {
-            "model": self.llm_model,
-            "prompt": "",
-            "messages": llm_messages,
-            "temperature": float(self.prompt_template.temperature),
-            "max_new_tokens": int(self.prompt_template.max_new_tokens),
-            "echo": self.llm_echo,
-            "version": self._message_version,
-        }
-        return payload
+        model_request: ModelRequest = await node.call(call_data=node_input)
+        model_request.context.cache_enable = self.model_cache_enable
+        return model_request
 
     def stream_plugin_call(self, text):
         return text
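The hand-assembled payload dict becomes a typed ModelRequest, and cache control moves from a dict key onto the request context. A hedged sketch of the new shape (all values illustrative):

```python
from dbgpt.core import ModelMessage, ModelRequest, ModelRequestContext

# Old: payload = {"model": ..., "messages": [...], "temperature": ...}
#      plus payload["model_cache_enable"] = ... set at each call site.
# New: one typed request; cache_enable lives on the context.
request = ModelRequest.build_request(
    model="proxyllm",
    messages=[ModelMessage(role="human", content="hello")],
    context=ModelRequestContext(stream=False),
    temperature=0.6,
    max_new_tokens=1024,
)
request.context.cache_enable = True
```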
@@ -271,23 +284,19 @@
 
     async def stream_call(self):
         # TODO Retry when server connection error
-        payload = await self.__call_base()
+        payload = await self._build_model_request()
 
-        self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"payload request: \n{payload}")
         ai_response_text = ""
         span = root_tracer.start_span(
-            "BaseChat.stream_call", metadata=self._get_span_metadata(payload)
+            "BaseChat.stream_call", metadata=payload.to_dict()
         )
-        payload["span_id"] = span.span_id
-        payload["model_cache_enable"] = self.model_cache_enable
+        payload.span_id = span.span_id
         try:
-            async for output in await self._model_stream_operator.call_stream(
-                call_data={"data": payload}
-            ):
+            async for output in self.call_streaming_operator(payload):
                 # Plugin research in result generation
                 msg = self.prompt_template.output_parser.parse_model_stream_resp_ex(
-                    output, self.skip_echo_len
+                    output, 0
                 )
                 view_msg = self.stream_plugin_call(msg)
                 view_msg = view_msg.replace("\n", "\\n")
@@ -308,19 +317,16 @@
         self.current_message.end_current_round()
 
     async def nostream_call(self):
-        payload = await self.__call_base()
+        payload = await self._build_model_request()
+        span = root_tracer.start_span(
+            "BaseChat.nostream_call", metadata=payload.to_dict()
+        )
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
-        span = root_tracer.start_span(
-            "BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
-        )
-        payload["span_id"] = span.span_id
-        payload["model_cache_enable"] = self.model_cache_enable
+        payload.span_id = span.span_id
         try:
             with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
-                model_output = await self._model_operator.call(
-                    call_data={"data": payload}
-                )
+                model_output = await self.call_llm_operator(payload)
 
             ### output parse
             ai_response_text = (
@@ -380,13 +386,12 @@
         return self.current_ai_response()
 
     async def get_llm_response(self):
-        payload = await self.__call_base()
+        payload = await self._build_model_request()
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
-        payload["model_cache_enable"] = self.model_cache_enable
        prompt_define_response = None
         try:
-            model_output = await self._model_operator.call(call_data={"data": payload})
+            model_output = await self.call_llm_operator(payload)
             ### output parse
             ai_response_text = (
                 self.prompt_template.output_parser.parse_model_nostream_resp(
@@ -544,89 +549,3 @@ class BaseChat(ABC):
             for dict_item in antv_charts
             for key, value in dict_item.items()
         )
-
-
-def _build_model_operator(
-    is_stream: bool = False, dag_name: str = "llm_model_dag"
-) -> BaseOperator:
-    """Builds and returns a model processing workflow (DAG) operator.
-
-    This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
-    It includes caching and branching logic to either fetch results from a cache or process
-    data using the model. It supports both streaming and non-streaming modes.
-
-    .. code-block:: python
-
-        input_node >> cache_check_branch_node
-        cache_check_branch_node >> model_node >> save_cached_node >> join_node
-        cache_check_branch_node >> cached_node >> join_node
-
-    equivalent to::
-
-                          -> model_node -> save_cached_node ->
-                         /                                    \
-        input_node -> cache_check_branch_node                   ---> join_node
-                         \                                    /
-                          -> cached_node ------------------- ->
-
-    Args:
-        is_stream (bool): Flag to determine if the operator should process data in streaming mode.
-        dag_name (str): Name of the DAG.
-
-    Returns:
-        BaseOperator: The final operator in the constructed DAG, typically a join node.
-    """
-    from dbgpt.core.awel import JoinOperator
-    from dbgpt.model.cluster import WorkerManagerFactory
-    from dbgpt.model.operator.model_operator import (
-        CachedModelOperator,
-        CachedModelStreamOperator,
-        ModelCacheBranchOperator,
-        ModelSaveCacheOperator,
-        ModelStreamSaveCacheOperator,
-    )
-    from dbgpt.storage.cache import CacheManager
-
-    # Fetch worker and cache managers from the system configuration
-    worker_manager = CFG.SYSTEM_APP.get_component(
-        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
-    ).create()
-    cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
-        ComponentType.MODEL_CACHE_MANAGER, CacheManager
-    )
-    # Define task names for the model and cache nodes
-    model_task_name = "llm_model_node"
-    cache_task_name = "llm_model_cache_node"
-
-    with DAG(dag_name):
-        # Create an input node
-        input_node = InputOperator(SimpleCallDataInputSource())
-        # Determine if the workflow should operate in streaming mode
-        if is_stream:
-            model_node = ModelStreamOperator(worker_manager, task_name=model_task_name)
-            cached_node = CachedModelStreamOperator(
-                cache_manager, task_name=cache_task_name
-            )
-            save_cached_node = ModelStreamSaveCacheOperator(cache_manager)
-        else:
-            model_node = ModelOperator(worker_manager, task_name=model_task_name)
-            cached_node = CachedModelOperator(cache_manager, task_name=cache_task_name)
-            save_cached_node = ModelSaveCacheOperator(cache_manager)
-
-        # Create a branch node to decide between fetching from cache or processing with the model
-        cache_check_branch_node = ModelCacheBranchOperator(
-            cache_manager,
-            model_task_name="llm_model_node",
-            cache_task_name="llm_model_cache_node",
-        )
-        # Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output
-        join_node = JoinOperator(
-            combine_function=lambda model_out, cache_out: cache_out or model_out
-        )
-
-        # Define the workflow structure using the >> operator
-        input_node >> cache_check_branch_node
-        cache_check_branch_node >> model_node >> save_cached_node >> join_node
-        cache_check_branch_node >> cached_node >> join_node
-
-    return join_node
@@ -80,10 +80,6 @@ class ChatKnowledge(BaseChat):
             vector_store_config=config,
         )
         query_rewrite = None
-        self.worker_manager = CFG.SYSTEM_APP.get_component(
-            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
-        ).create()
-        self.llm_client = DefaultLLMClient(worker_manager=self.worker_manager)
         if CFG.KNOWLEDGE_SEARCH_REWRITE:
             query_rewrite = QueryRewrite(
                 llm_client=self.llm_client,
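ChatKnowledge can drop its hand-rolled client because BaseChat now exposes llm_client as a property (see the base_chat.py hunk above); `self.llm_client` below resolves to that property. The wiring it replaces looks like this sketch, which mirrors the removed lines (`make_default_client` is an illustrative name):

```python
from dbgpt.component import ComponentType
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory


def make_default_client(system_app) -> DefaultLLMClient:
    # Same wiring as BaseChat.llm_client: resolve the worker manager
    # component, then wrap it in a DefaultLLMClient.
    worker_manager = system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    ).create()
    return DefaultLLMClient(worker_manager)
```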
@@ -1,17 +1,36 @@
 import dataclasses
 from typing import Any, Dict, List, Optional
 
-from dbgpt.core import BaseMessage, ChatPromptTemplate, ModelMessage
+from dbgpt import SystemApp
+from dbgpt.component import ComponentType
+from dbgpt.core import (
+    BaseMessage,
+    ChatPromptTemplate,
+    LLMClient,
+    ModelRequest,
+    ModelRequestContext,
+)
 from dbgpt.core.awel import (
     DAG,
     BaseOperator,
     InputOperator,
+    JoinOperator,
     MapOperator,
     SimpleCallDataInputSource,
 )
 from dbgpt.core.operator import (
     BufferedConversationMapperOperator,
     HistoryPromptBuilderOperator,
+    LLMBranchOperator,
 )
+from dbgpt.model.operator import LLMOperator, StreamingLLMOperator
+from dbgpt.storage.cache.operator import (
+    CachedModelOperator,
+    CachedModelStreamOperator,
+    CacheManager,
+    ModelCacheBranchOperator,
+    ModelSaveCacheOperator,
+    ModelStreamSaveCacheOperator,
+)
 
 
@@ -23,7 +42,7 @@ class ChatComposerInput:
     prompt_dict: Dict[str, Any]
 
 
-class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]]):
+class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
     """App chat composer operator.
 
     TODO: Support more history merge mode.
@@ -31,29 +50,56 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]]):
 
     def __init__(
         self,
+        model: str,
+        temperature: float,
+        max_new_tokens: int,
         prompt: ChatPromptTemplate,
+        message_version: str = "v2",
+        echo: bool = False,
+        streaming: bool = True,
         history_key: str = "chat_history",
         history_merge_mode: str = "window",
         keep_start_rounds: Optional[int] = None,
         keep_end_rounds: Optional[int] = None,
         str_history: bool = False,
+        request_context: ModelRequestContext = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
+        if not request_context:
+            request_context = ModelRequestContext(stream=streaming)
         self._prompt_template = prompt
         self._history_key = history_key
         self._history_merge_mode = history_merge_mode
         self._keep_start_rounds = keep_start_rounds
         self._keep_end_rounds = keep_end_rounds
         self._str_history = str_history
+        self._model_name = model
+        self._temperature = temperature
+        self._max_new_tokens = max_new_tokens
+        self._message_version = message_version
+        self._echo = echo
+        self._streaming = streaming
+        self._request_context = request_context
         self._sub_compose_dag = self._build_composer_dag()
 
-    async def map(self, input_value: ChatComposerInput) -> List[ModelMessage]:
+    async def map(self, input_value: ChatComposerInput) -> ModelRequest:
         end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
         # Sub dag, use the same dag context in the parent dag
-        return await end_node.call(
+        messages = await end_node.call(
             call_data={"data": input_value}, dag_ctx=self.current_dag_context
         )
+        span_id = self._request_context.span_id
+        model_request = ModelRequest.build_request(
+            model=self._model_name,
+            messages=messages,
+            context=self._request_context,
+            temperature=self._temperature,
+            max_new_tokens=self._max_new_tokens,
+            span_id=span_id,
+            echo=self._echo,
+        )
+        return model_request
 
     def _build_composer_dag(self) -> DAG:
         with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
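AppChatComposerOperator now returns a ready-to-send ModelRequest instead of a bare message list. A hedged sketch of driving it standalone, mirroring the call site in BaseChat._build_model_request above (`compose_request` is illustrative, and ChatComposerInput field names beyond prompt_dict are assumptions, since only prompt_dict is visible in this diff):

```python
from dbgpt.app.scene.operator.app_operator import (
    AppChatComposerOperator,
    ChatComposerInput,
)
from dbgpt.core import ModelRequest, ModelRequestContext


async def compose_request(prompt_template, messages, prompt_dict) -> ModelRequest:
    node = AppChatComposerOperator(
        model="proxyllm",
        temperature=0.6,
        max_new_tokens=1024,
        prompt=prompt_template,
        request_context=ModelRequestContext(stream=False),
    )
    # Assumed field names; BaseChat passes {"data": ChatComposerInput(...)}.
    node_input = {"data": ChatComposerInput(messages=messages, prompt_dict=prompt_dict)}
    return await node.call(call_data=node_input)
```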
@@ -83,3 +129,79 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, List[ModelMessage]]):
             )
 
         return composer_dag
+
+
+def build_cached_chat_operator(
+    llm_client: LLMClient,
+    is_streaming: bool,
+    system_app: SystemApp,
+    cache_manager: Optional[CacheManager] = None,
+):
+    """Builds and returns a model processing workflow (DAG) operator.
+
+    This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
+    It includes caching and branching logic to either fetch results from a cache or process
+    data using the model. It supports both streaming and non-streaming modes.
+
+    .. code-block:: python
+
+        input_task >> cache_check_branch_task
+        cache_check_branch_task >> llm_task >> save_cache_task >> join_task
+        cache_check_branch_task >> cache_task >> join_task
+
+    equivalent to::
+
+                          -> llm_task -> save_cache_task ->
+                         /                                 \
+        input_task -> cache_check_branch_task                ---> join_task
+                         \                                 /
+                          -> cache_task ------------------- ->
+
+    Args:
+        llm_client (LLMClient): The LLM client for processing data using the model.
+        is_streaming (bool): Whether the model is a streaming model.
+        system_app (SystemApp): The system app.
+        cache_manager (CacheManager, optional): The cache manager for managing cache operations. Defaults to None.
+
+    Returns:
+        BaseOperator: The final operator in the constructed DAG, typically a join node.
+    """
+    # Define task names for the model and cache nodes
+    model_task_name = "llm_model_node"
+    cache_task_name = "llm_model_cache_node"
+    if not cache_manager:
+        cache_manager: CacheManager = system_app.get_component(
+            ComponentType.MODEL_CACHE_MANAGER, CacheManager
+        )
+
+    with DAG("dbgpt_awel_app_model_infer_with_cached") as dag:
+        # Create an input task
+        input_task = InputOperator(SimpleCallDataInputSource())
+        # Create a branch task to decide between fetching from cache or processing with the model
+        if is_streaming:
+            llm_task = StreamingLLMOperator(llm_client, task_name=model_task_name)
+            cache_task = CachedModelStreamOperator(
+                cache_manager, task_name=cache_task_name
+            )
+            save_cache_task = ModelStreamSaveCacheOperator(cache_manager)
+        else:
+            llm_task = LLMOperator(llm_client, task_name=model_task_name)
+            cache_task = CachedModelOperator(cache_manager, task_name=cache_task_name)
+            save_cache_task = ModelSaveCacheOperator(cache_manager)
+
+        # Create a branch node to decide between fetching from cache or processing with the model
+        cache_check_branch_task = ModelCacheBranchOperator(
+            cache_manager,
+            model_task_name=model_task_name,
+            cache_task_name=cache_task_name,
+        )
+        # Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output
+        join_task = JoinOperator(
+            combine_function=lambda model_out, cache_out: cache_out or model_out
+        )
+
+        # Define the workflow structure using the >> operator
+        input_task >> cache_check_branch_task
+        cache_check_branch_task >> llm_task >> save_cache_task >> join_task
+        cache_check_branch_task >> cache_task >> join_task
+    return join_task
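BaseChat.call_llm_operator and call_streaming_operator (earlier in this diff) are thin wrappers over this builder. A hedged end-to-end sketch of both modes (`infer` is an illustrative helper; assumes a configured SystemApp with a model cache manager component registered):

```python
from dbgpt.app.scene.operator.app_operator import build_cached_chat_operator


async def infer(llm_client, system_app, request, streaming: bool = False):
    if streaming:
        llm_task = build_cached_chat_operator(llm_client, True, system_app)
        # Stream path: yields incremental ModelOutput chunks.
        async for output in await llm_task.call_stream(call_data={"data": request}):
            yield output
    else:
        llm_task = build_cached_chat_operator(llm_client, False, system_app)
        # On a cache hit the cached ModelOutput is joined through; otherwise
        # the LLM runs and the result is saved back through the cache task.
        yield await llm_task.call(call_data={"data": request})
```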