refactor: Refactor proxy LLM (#1064)
@@ -384,7 +384,7 @@ class DAGContext:
         return self._share_data.get(key)

     async def save_to_share_data(
-        self, key: str, data: Any, overwrite: Optional[str] = None
+        self, key: str, data: Any, overwrite: bool = False
     ) -> None:
         if key in self._share_data and not overwrite:
             raise ValueError(f"Share data key {key} already exists")
@@ -407,7 +407,7 @@ class DAGContext:
         return self.get_from_share_data(_build_task_key(task_name, key))

     async def save_task_share_data(
-        self, task_name: str, key: str, data: Any, overwrite: Optional[str] = None
+        self, task_name: str, key: str, data: Any, overwrite: bool = False
     ) -> None:
         """Save share data by task name and key
@@ -415,7 +415,7 @@ class DAGContext:
             task_name (str): The task name
             key (str): The share data key
             data (Any): The share data
-            overwrite (Optional[str], optional): Whether overwrite the share data if the key already exists.
+            overwrite (bool): Whether overwrite the share data if the key already exists.
                 Defaults to None.

         Raises:
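Note on the three DAGContext hunks above: `overwrite` becomes a real boolean that defaults to off, so writing an existing key now has to be explicit. A minimal usage sketch, assuming `ctx` is the DAGContext available to a running task (key names and values are made up):

    async def write_shared(ctx: DAGContext) -> None:
        await ctx.save_to_share_data("user_id", 42)
        # A second write to the same key raises ValueError unless overwrite is requested.
        await ctx.save_to_share_data("user_id", 43, overwrite=True)
        # Per-task variant with the same flag:
        await ctx.save_task_share_data("load_user_task", "profile", {"name": "alice"})
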
@@ -46,7 +46,7 @@ class WorkflowRunner(ABC, Generic[T]):
         node: "BaseOperator",
         call_data: Optional[CALL_DATA] = None,
         streaming_call: bool = False,
-        dag_ctx: Optional[DAGContext] = None,
+        exist_dag_ctx: Optional[DAGContext] = None,
     ) -> DAGContext:
         """Execute the workflow starting from a given operator.
@@ -54,7 +54,7 @@ class WorkflowRunner(ABC, Generic[T]):
             node (RunnableDAGNode): The starting node of the workflow to be executed.
             call_data (CALL_DATA): The data pass to root operator node.
             streaming_call (bool): Whether the call is a streaming call.
-            dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
+            exist_dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.

         Returns:
             DAGContext: The context after executing the workflow, containing the final state and data.
         """
@@ -190,7 +190,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             OUT: The output of the node after execution.
         """
-        out_ctx = await self._runner.execute_workflow(self, call_data, dag_ctx=dag_ctx)
+        out_ctx = await self._runner.execute_workflow(
+            self, call_data, exist_dag_ctx=dag_ctx
+        )
         return out_ctx.current_task_context.task_output.output

     def _blocking_call(
@@ -230,7 +232,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
             AsyncIterator[OUT]: An asynchronous iterator over the output stream.
         """
         out_ctx = await self._runner.execute_workflow(
-            self, call_data, streaming_call=True, dag_ctx=dag_ctx
+            self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
        )
        return out_ctx.current_task_context.task_output.output_stream
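The runner keyword is renamed from `dag_ctx` to `exist_dag_ctx`, while `BaseOperator.call` and `call_stream` keep their public `dag_ctx` argument and forward it under the new name. A rough sketch of the calling pattern this supports, assuming `sub_task` is any operator and `parent_ctx` is the DAGContext of a DAG that is already running (names and the call_data payload are illustrative only):

    async def run_nested(sub_task: BaseOperator, parent_ctx: DAGContext):
        # dag_ctx is forwarded to the runner as exist_dag_ctx, so the nested run
        # shares node outputs with parent_ctx instead of starting a fresh context.
        return await sub_task.call(call_data={"query": "hello"}, dag_ctx=parent_ctx)
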
@@ -9,6 +9,12 @@ from .base import BaseOperator
 class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+        call_data = curr_task_ctx.call_data
+        if call_data:
+            call_data = await curr_task_ctx._call_data_to_output()
+            output = await call_data.streamify(self.streamify)
+            curr_task_ctx.set_task_output(output)
+            return output
         output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
             self.streamify
         )
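The added branch lets a streamify operator that sits at the root of a DAG be driven directly by `call_data` instead of requiring an upstream task output. A hedged sketch of such an operator; `NumberProducerOperator` is an invented name, and the trigger line only indicates the intent, not an exact call signature taken from this commit:

    from typing import AsyncIterator

    class NumberProducerOperator(StreamifyAbsOperator[int, int]):
        """Hypothetical root operator: streams the numbers 0 .. n-1."""

        async def streamify(self, n: int) -> AsyncIterator[int]:
            for i in range(n):
                yield i

    # With the new call_data branch, something like this becomes possible:
    #     stream = await NumberProducerOperator().call_stream(call_data=3)
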
@@ -76,12 +76,12 @@ def _save_call_data(
         return id2call_data
     if len(root_nodes) == 1:
         node = root_nodes[0]
-        logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
+        logger.debug(f"Save call data to node {node.node_id}, call_data: {call_data}")
         id2call_data[node.node_id] = call_data
     else:
         for node in root_nodes:
             node_id = node.node_id
-            logger.info(
+            logger.debug(
                 f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
             )
             id2call_data[node_id] = call_data.get(node_id)
@@ -19,24 +19,24 @@ class DefaultWorkflowRunner(WorkflowRunner):
         node: BaseOperator,
         call_data: Optional[CALL_DATA] = None,
         streaming_call: bool = False,
-        dag_ctx: Optional[DAGContext] = None,
+        exist_dag_ctx: Optional[DAGContext] = None,
     ) -> DAGContext:
         # Save node output
         # dag = node.dag
         job_manager = JobManager.build_from_end_node(node, call_data)
-        if not dag_ctx:
+        if not exist_dag_ctx:
             # Create DAG context
             node_outputs: Dict[str, TaskContext] = {}
-            dag_ctx = DAGContext(
-                streaming_call=streaming_call,
-                node_to_outputs=node_outputs,
-                node_name_to_ids=job_manager._node_name_to_ids,
-            )
         else:
-            node_outputs = dag_ctx._node_to_outputs
-        logger.info(
-            f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
-        )
+            # Share node output with exist dag context
+            node_outputs = exist_dag_ctx._node_to_outputs
+        dag_ctx = DAGContext(
+            streaming_call=streaming_call,
+            node_to_outputs=node_outputs,
+            node_name_to_ids=job_manager._node_name_to_ids,
+        )
+        logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
+        logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
         skip_node_ids = set()
         system_app: SystemApp = DAGVar.get_current_system_app()
@@ -127,6 +127,11 @@ class ModelRequestContext:
     extra: Optional[Dict[str, Any]] = field(default_factory=dict)
     """The extra information of the model inference."""

+    request_id: Optional[str] = None
+    """The request id of the model inference."""
+
+    cache_enable: Optional[bool] = False
+    """Whether to enable the cache for the model inference"""


 @dataclass
 @PublicAPI(stability="beta")
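The added fields have defaults, so existing callers are unaffected. A small sketch of a context that opts into them (values invented, other field names as shown in this hunk and the build_request hunk below):

    import uuid

    ctx = ModelRequestContext(
        stream=False,
        extra={"tenant": "demo"},
        request_id=str(uuid.uuid4()),  # new field
        cache_enable=True,             # new field
    )
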
@@ -171,7 +176,7 @@ class ModelRequest:
     """The stop token ids of the model inference."""
     context_len: Optional[int] = None
     """The context length of the model inference."""
-    echo: Optional[bool] = True
+    echo: Optional[bool] = False
     """Whether to echo the input messages."""
     span_id: Optional[str] = None
     """The span id of the model inference."""
@@ -203,7 +208,12 @@ class ModelRequest:
             map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages)
         )
         # Skip None fields
-        return {k: v for k, v in asdict(new_reqeust).items() if v}
+        return {k: v for k, v in asdict(new_reqeust).items() if v is not None}
+
+    def to_trace_metadata(self):
+        metadata = self.to_dict()
+        metadata["prompt"] = self.messages_to_string()
+        return metadata

     def get_messages(self) -> List[ModelMessage]:
         """Get the messages.
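The filter change matters for falsy but meaningful values: `if v` silently dropped fields such as a temperature of 0.0 or `echo=False`, while `if v is not None` only drops fields that were never set. A sketch of the difference, assuming `temperature` is one of the optional ModelRequest fields and the model name is a placeholder:

    req = ModelRequest(model="vicuna-13b-v1.5", messages=[], temperature=0.0, echo=False)
    d = req.to_dict()
    # Before: temperature and echo disappeared from d because 0.0 and False are falsy.
    # After: they are kept; only fields that are still None are skipped.
    trace_meta = req.to_trace_metadata()  # to_dict() plus a flattened "prompt" string
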
@@ -234,10 +244,13 @@ class ModelRequest:
     def build_request(
         model: str,
         messages: List[ModelMessage],
-        context: Union[ModelRequestContext, Dict[str, Any], BaseModel],
+        context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None,
+        stream: Optional[bool] = False,
+        echo: Optional[bool] = False,
         **kwargs,
     ):
+        if not context:
+            context = ModelRequestContext(stream=stream)
         context_dict = None
         if isinstance(context, dict):
             context_dict = context
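With `context` now optional, a request can be assembled from just a model name and messages; `stream` is stored on the auto-created context and `echo` is forwarded to the request (see the next hunk). A hedged sketch, with a placeholder model name and the message types used elsewhere in this diff:

    messages = [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")]
    request = ModelRequest.build_request(
        "chatgpt_proxyllm",  # placeholder model name
        messages=messages,
        stream=True,         # stored on the auto-created ModelRequestContext
        echo=False,          # forwarded to the ModelRequest itself
    )
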
@@ -250,6 +263,7 @@ class ModelRequest:
             model=model,
             messages=messages,
             context=context,
+            echo=echo,
             **kwargs,
         )
@@ -261,14 +275,22 @@ class ModelRequest:
             **kwargs,
         )

-    def to_openai_messages(self) -> List[Dict[str, Any]]:
-        """Convert the messages to the format of OpenAI API.
+    def to_common_messages(
+        self, support_system_role: bool = True
+    ) -> List[Dict[str, Any]]:
+        """Convert the messages to the common format(like OpenAI API).

         This function will move last user message to the end of the list.

+        Args:
+            support_system_role (bool): Whether to support system role
+
         Returns:
             List[Dict[str, Any]]: The messages in the format of OpenAI API.

+        Raises:
+            ValueError: If the message role is not supported
+
         Examples:

         .. code-block:: python
@@ -298,7 +320,17 @@ class ModelRequest:
             m if isinstance(m, ModelMessage) else ModelMessage(**m)
             for m in self.messages
         ]
-        return ModelMessage.to_openai_messages(messages)
+        return ModelMessage.to_common_messages(
+            messages, support_system_role=support_system_role
+        )
+
+    def messages_to_string(self) -> str:
+        """Convert the messages to string.
+
+        Returns:
+            str: The messages in string format.
+        """
+        return ModelMessage.messages_to_string(self.get_messages())


 @dataclass
@@ -478,7 +510,7 @@ class DefaultMessageConverter(MessageConverter):
         if not model_metadata or not model_metadata.ext_metadata:
             logger.warning("No model metadata, skip message system message conversion")
             return messages
-        if model_metadata.ext_metadata.support_system_message:
+        if not model_metadata.ext_metadata.support_system_message:
             # 3. Convert the messages to no system message
             return self.convert_to_no_system_message(messages, model_metadata)
         return messages
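The one-character condition flip above is the actual bug fix: system messages should only be rewritten away when the model cannot accept a system role. The same decision, reordered purely for readability (names as in the hunk above):

    if model_metadata.ext_metadata.support_system_message:
        return messages  # model accepts a system role, leave the messages alone
    # otherwise rewrite the conversation so it no longer relies on a system role
    return self.convert_to_no_system_message(messages, model_metadata)
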
@@ -197,15 +197,24 @@ class ModelMessage(BaseModel):
         return result

     @staticmethod
-    def to_openai_messages(
-        messages: List["ModelMessage"], convert_to_compatible_format: bool = False
+    def to_common_messages(
+        messages: List["ModelMessage"],
+        convert_to_compatible_format: bool = False,
+        support_system_role: bool = True,
     ) -> List[Dict[str, str]]:
-        """Convert to OpenAI message format and
-        hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
+        """Convert to common message format(e.g. OpenAI message format) and
+        huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
+
+        Args:
+            messages (List["ModelMessage"]): The model messages
+            convert_to_compatible_format (bool): Whether to convert to compatible format
+            support_system_role (bool): Whether to support system role
+
+        Returns:
+            List[Dict[str, str]]: The common messages
+
+        Raises:
+            ValueError: If the message role is not supported
         """
         history = []
         # Add history conversation
@@ -213,6 +222,8 @@ class ModelMessage(BaseModel):
             if message.role == ModelMessageRoleType.HUMAN:
                 history.append({"role": "user", "content": message.content})
             elif message.role == ModelMessageRoleType.SYSTEM:
+                if not support_system_role:
+                    raise ValueError("Current model not support system role")
                 history.append({"role": "system", "content": message.content})
             elif message.role == ModelMessageRoleType.AI:
                 history.append({"role": "assistant", "content": message.content})
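From the caller's side the new guard looks like this; a small sketch using the same role constants (message contents invented):

    msgs = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
    ]
    ModelMessage.to_common_messages(msgs)  # system role kept as {"role": "system", ...}
    ModelMessage.to_common_messages(msgs, support_system_role=False)  # raises ValueError
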
@@ -250,6 +261,18 @@ class ModelMessage(BaseModel):

         return str_msg

+    @staticmethod
+    def messages_to_string(messages: List["ModelMessage"]) -> str:
+        """Convert messages to str
+
+        Args:
+            messages (List[ModelMessage]): The messages
+
+        Returns:
+            str: The str messages
+        """
+        return _messages_to_str(messages)
+

 _SingleRoundMessage = List[BaseMessage]
 _MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
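A sketch of the new convenience wrapper; the `Human`/`AI`/`System` prefixes come from the `_messages_to_str` defaults shown in the next hunk (messages invented):

    msgs = [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Hello, how can I help?"),
    ]
    text = ModelMessage.messages_to_string(msgs)
    # Expected to look roughly like:
    #   Human: Hi
    #   AI: Hello, how can I help?
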
@@ -264,7 +287,7 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:


 def _messages_to_str(
-    messages: List[BaseMessage],
+    messages: List[Union[BaseMessage, ModelMessage]],
     human_prefix: str = "Human",
     ai_prefix: str = "AI",
     system_prefix: str = "System",
@@ -272,7 +295,7 @@ def _messages_to_str(
     """Convert messages to str

     Args:
-        messages (List[BaseMessage]): The messages
+        messages (List[Union[BaseMessage, ModelMessage]]): The messages
         human_prefix (str): The human prefix
         ai_prefix (str): The ai prefix
         system_prefix (str): The system prefix
@@ -291,6 +314,8 @@ def _messages_to_str(
             role = system_prefix
         elif isinstance(message, ViewMessage):
             pass
+        elif isinstance(message, ModelMessage):
+            role = message.role
         else:
             raise ValueError(f"Got unsupported message type: {message}")
         if role:
@@ -44,7 +44,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
         model_context = data.get("model_context")
-        has_echo = True
+        has_echo = False
         if model_context and "prompt_echo_len_char" in model_context:
             prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
             has_echo = bool(model_context.get("echo", False))
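The parser now defaults to "no echo" and only treats output as echoed when the model context explicitly says so. A rough illustration of the two shapes of `data` it might receive (dictionary contents are invented):

    # No usable model_context -> has_echo stays False and the text is used as-is.
    data_plain = {"text": "answer", "error_code": 0, "model_context": None}

    # A context from a model that echoes the prompt -> has_echo becomes True and
    # prompt_echo_len_char indicates how much echoed prefix precedes the answer.
    data_echoed = {
        "text": "PROMPT ... answer",
        "error_code": 0,
        "model_context": {"prompt_echo_len_char": 120, "echo": True},
    }
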
@@ -421,13 +421,13 @@ def test_parse_model_messages_multiple_system_messages():
 def test_to_openai_messages(
     human_model_message, ai_model_message, system_model_message
 ):
-    none_messages = ModelMessage.to_openai_messages([])
+    none_messages = ModelMessage.to_common_messages([])
     assert none_messages == []

-    single_messages = ModelMessage.to_openai_messages([human_model_message])
+    single_messages = ModelMessage.to_common_messages([human_model_message])
     assert single_messages == [{"role": "user", "content": human_model_message.content}]

-    normal_messages = ModelMessage.to_openai_messages(
+    normal_messages = ModelMessage.to_common_messages(
         [
             system_model_message,
             human_model_message,
@@ -446,7 +446,7 @@ def test_to_openai_messages(
 def test_to_openai_messages_convert_to_compatible_format(
     human_model_message, ai_model_message, system_model_message
 ):
-    shuffle_messages = ModelMessage.to_openai_messages(
+    shuffle_messages = ModelMessage.to_common_messages(
         [
             system_model_message,
             human_model_message,