feat(awel): New MessageConverter and more AWEL operators (#1039)

Fangyin Cheng
2024-01-08 09:40:05 +08:00
committed by GitHub
parent 765fb181f6
commit e8861bd8fa
48 changed files with 2333 additions and 719 deletions

.gitignore (vendored)

@@ -179,4 +179,8 @@ thirdparty
# typescript
*.tsbuildinfo
/web/next-env.d.ts
# Ignore awel DAG visualization files
/examples/**/*.gv
/examples/**/*.gv.pdf


@@ -67,6 +67,11 @@ pre-commit: fmt test ## Run formatting and unit tests before committing
test: $(VENV)/.testenv ## Run unit tests
$(VENV_BIN)/pytest dbgpt
.PHONY: test-doc
test-doc: $(VENV)/.testenv ## Run doctests
# -k "not test_" skips tests that are not doctests.
$(VENV_BIN)/pytest --doctest-modules -k "not test_" dbgpt/core
.PHONY: coverage
coverage: setup ## Run tests and report coverage
$(VENV_BIN)/pytest dbgpt --cov=dbgpt


@@ -102,6 +102,11 @@ class BaseChat(ABC):
is_stream=True, dag_name="llm_stream_model_dag"
)
# Get the message version, default is v1 in app
# In v1, we will transform the message to compatible format of specific model
# In the future, we will upgrade the message version to v2, and the message will be compatible with all models
self._message_version = chat_param.get("message_version", "v1")
class Config:
"""Configuration for this pydantic object."""
@@ -185,6 +190,7 @@ class BaseChat(ABC):
"temperature": float(self.prompt_template.temperature), "temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens), "max_new_tokens": int(self.prompt_template.max_new_tokens),
"echo": self.llm_echo, "echo": self.llm_echo,
"version": self._message_version,
}
return payload


@@ -6,7 +6,10 @@ from dbgpt.core.interface.cache import (
CacheValue,
)
from dbgpt.core.interface.llm import (
DefaultMessageConverter,
LLMClient,
MessageConverter,
ModelExtraMedata,
ModelInferenceMetrics,
ModelMetadata,
ModelOutput,
@@ -14,19 +17,28 @@ from dbgpt.core.interface.llm import (
ModelRequestContext,
)
from dbgpt.core.interface.message import (
AIMessage,
BaseMessage,
ConversationIdentifier,
HumanMessage,
MessageIdentifier,
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
OnceConversation,
StorageConversation,
SystemMessage,
)
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
from dbgpt.core.interface.prompt import (
BasePromptTemplate,
ChatPromptTemplate,
HumanPromptTemplate,
MessagesPlaceholder,
PromptManager,
PromptTemplate,
StoragePromptTemplate,
SystemPromptTemplate,
)
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.core.interface.storage import (
@@ -49,14 +61,26 @@ __ALL__ = [
"ModelMessage", "ModelMessage",
"LLMClient", "LLMClient",
"ModelMessageRoleType", "ModelMessageRoleType",
"ModelExtraMedata",
"MessageConverter",
"DefaultMessageConverter",
"OnceConversation", "OnceConversation",
"StorageConversation", "StorageConversation",
"BaseMessage",
"SystemMessage",
"AIMessage",
"HumanMessage",
"MessageStorageItem", "MessageStorageItem",
"ConversationIdentifier", "ConversationIdentifier",
"MessageIdentifier", "MessageIdentifier",
"PromptTemplate", "PromptTemplate",
"PromptManager", "PromptManager",
"StoragePromptTemplate", "StoragePromptTemplate",
"BasePromptTemplate",
"ChatPromptTemplate",
"MessagesPlaceholder",
"SystemPromptTemplate",
"HumanPromptTemplate",
"BaseOutputParser", "BaseOutputParser",
"SQLOutputParser", "SQLOutputParser",
"Serializable", "Serializable",


@@ -7,6 +7,7 @@ The stability of this API cannot be guaranteed at present.
""" """
import logging
from typing import List, Optional from typing import List, Optional
from dbgpt.component import SystemApp from dbgpt.component import SystemApp
@@ -39,6 +40,8 @@ from .task.task_impl import (
)
from .trigger.http_trigger import HttpTrigger
logger = logging.getLogger(__name__)
__all__ = [
"initialize_awel",
"DAGContext",
@@ -89,14 +92,24 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
def setup_dev_environment(
dags: List[DAG],
- host: Optional[str] = "0.0.0.0",
host: Optional[str] = "127.0.0.1",
port: Optional[int] = 5555,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
show_dag_graph: Optional[bool] = True,
) -> None:
"""Setup a development environment for AWEL.
Just using in development environment, not production environment.
Args:
dags (List[DAG]): The DAGs.
host (Optional[str], optional): The host. Defaults to "127.0.0.1"
port (Optional[int], optional): The port. Defaults to 5555.
logging_level (Optional[str], optional): The logging level. Defaults to None.
logger_filename (Optional[str], optional): The logger filename. Defaults to None.
show_dag_graph (Optional[bool], optional): Whether show the DAG graph. Defaults to True.
If True, the DAG graph will be saved to a file and open it automatically.
""" """
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
@@ -118,6 +131,15 @@ def setup_dev_environment(
system_app.register_instance(trigger_manager)
for dag in dags:
if show_dag_graph:
try:
dag_graph_file = dag.visualize_dag()
if dag_graph_file:
logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}")
except Exception as e:
logger.warning(
f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`"
)
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger)
trigger_manager.after_register()
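For context, setup_dev_environment now defaults to 127.0.0.1 and can render each DAG with graphviz through the new show_dag_graph flag. A minimal sketch of a local run, assuming graphviz is installed; the two MapOperator lambdas are purely illustrative:

    from dbgpt.core.awel import DAG, MapOperator, setup_dev_environment

    with DAG("dev_env_example_dag") as dag:
        # Two trivial operators wired together, only to give the DAG some shape.
        double_task = MapOperator(lambda x: x * 2)
        add_one_task = MapOperator(lambda x: x + 1)
        double_task >> add_one_task

    # Starts a local uvicorn/FastAPI dev server on 127.0.0.1:5555 and, because
    # show_dag_graph defaults to True, tries to save a .gv rendering of the DAG.
    setup_dev_environment([dag], port=5555)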


@@ -6,8 +6,7 @@ import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
- from functools import cache
- from typing import Any, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
from dbgpt.component import SystemApp
@@ -177,7 +176,10 @@ class DAGLifecycle:
pass
async def after_dag_end(self):
- """The callback after DAG end"""
"""The callback after DAG end,
This method may be called multiple times, please make sure it is idempotent.
"""
pass
@@ -299,6 +301,20 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
self._downstream.append(node)
node._upstream.append(self)
def __repr__(self):
cls_name = self.__class__.__name__
if self.node_name and self.node_name:
return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
if self.node_id:
return f"{cls_name}(node_id={self.node_id})"
if self.node_name:
return f"{cls_name}(node_name={self.node_name})"
else:
return f"{cls_name}"
def __str__(self):
return self.__repr__()
def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"
@@ -496,6 +512,15 @@ class DAG:
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def print_tree(self) -> None:
"""Print the DAG tree"""
_print_format_dag_tree(self)
def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
"""Create the DAG graph"""
self.print_tree()
return _visualize_dag(self, view=view, **kwargs)
def __enter__(self):
DAGVar.enter_dag(self)
return self
@@ -516,3 +541,109 @@ def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode
for node in stream_nodes:
nodes = nodes.union(_get_nodes(node, is_upstream))
return nodes
def _print_format_dag_tree(dag: DAG) -> None:
for node in dag.root_nodes:
_print_dag(node)
def _print_dag(
node: DAGNode,
level: int = 0,
prefix: str = "",
last: bool = True,
level_dict: Dict[str, Any] = None,
):
if level_dict is None:
level_dict = {}
connector = " -> " if level != 0 else ""
new_prefix = prefix
if last:
if level != 0:
new_prefix += " "
print(prefix + connector + str(node))
else:
if level != 0:
new_prefix += "| "
print(prefix + connector + str(node))
level_dict[level] = level_dict.get(level, 0) + 1
num_children = len(node.downstream)
for i, child in enumerate(node.downstream):
_print_dag(child, level + 1, new_prefix, i == num_children - 1, level_dict)
def _print_dag_tree(root_nodes: List[DAGNode], level_sep: str = " ") -> None:
def _print_node(node: DAGNode, level: int) -> None:
print(f"{level_sep * level}{node}")
_apply_root_node(root_nodes, _print_node)
def _apply_root_node(
root_nodes: List[DAGNode],
func: Callable[[DAGNode, int], None],
) -> None:
for dag_node in root_nodes:
_handle_dag_nodes(False, 0, dag_node, func)
def _handle_dag_nodes(
is_down_to_up: bool,
level: int,
dag_node: DAGNode,
func: Callable[[DAGNode, int], None],
):
if not dag_node:
return
func(dag_node, level)
stream_nodes = dag_node.upstream if is_down_to_up else dag_node.downstream
level += 1
for node in stream_nodes:
_handle_dag_nodes(is_down_to_up, level, node, func)
def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
"""Visualize the DAG
Args:
dag (DAG): The DAG to visualize
view (bool, optional): Whether view the DAG graph. Defaults to True.
Returns:
Optional[str]: The filename of the DAG graph
"""
try:
from graphviz import Digraph
except ImportError:
logger.warn("Can't import graphviz, skip visualize DAG")
return None
dot = Digraph(name=dag.dag_id)
# Record the added edges to avoid adding duplicate edges
added_edges = set()
def add_edges(node: DAGNode):
if node.downstream:
for downstream_node in node.downstream:
# Check if the edge has been added
if (str(node), str(downstream_node)) not in added_edges:
dot.edge(str(node), str(downstream_node))
added_edges.add((str(node), str(downstream_node)))
add_edges(downstream_node)
for root in dag.root_nodes:
add_edges(root)
filename = f"dag-vis-{dag.dag_id}.gv"
if "filename" in kwargs:
filename = kwargs["filename"]
del kwargs["filename"]
if not "directory" in kwargs:
from dbgpt.configs.model_config import LOGDIR
kwargs["directory"] = LOGDIR
return dot.render(filename, view=view, **kwargs)
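As a quick illustration of the new helpers above: print_tree writes the operator tree to stdout and visualize_dag renders it with graphviz, returning None when graphviz cannot be imported. A rough sketch; the operators are placeholders:

    from dbgpt.core.awel import DAG, MapOperator

    with DAG("visualize_example_dag") as dag:
        parse_task = MapOperator(lambda raw: int(raw))
        square_task = MapOperator(lambda value: value * value)
        parse_task >> square_task

    # Print the DAG tree using the new __repr__/print_tree support.
    dag.print_tree()

    # Render the DAG; returns the generated file path, or None if graphviz
    # is not available.
    graph_file = dag.visualize_dag(view=False)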


@@ -46,6 +46,7 @@ class WorkflowRunner(ABC, Generic[T]):
node: "BaseOperator", node: "BaseOperator",
call_data: Optional[CALL_DATA] = None, call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False, streaming_call: bool = False,
dag_ctx: Optional[DAGContext] = None,
) -> DAGContext: ) -> DAGContext:
"""Execute the workflow starting from a given operator. """Execute the workflow starting from a given operator.
@@ -53,7 +54,7 @@ class WorkflowRunner(ABC, Generic[T]):
node (RunnableDAGNode): The starting node of the workflow to be executed.
call_data (CALL_DATA): The data pass to root operator node.
streaming_call (bool): Whether the call is a streaming call.
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
Returns:
DAGContext: The context after executing the workflow, containing the final state and data.
"""
@@ -174,18 +175,22 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
TaskOutput[OUT]: The task output after this node has been run.
"""
- async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
async def call(
self,
call_data: Optional[CALL_DATA] = None,
dag_ctx: Optional[DAGContext] = None,
) -> OUT:
"""Execute the node and return the output.
This method is a high-level wrapper for executing the node.
Args:
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
Returns:
OUT: The output of the node after execution.
"""
- out_ctx = await self._runner.execute_workflow(self, call_data)
out_ctx = await self._runner.execute_workflow(self, call_data, dag_ctx=dag_ctx)
return out_ctx.current_task_context.task_output.output
def _blocking_call(
@@ -209,7 +214,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
return loop.run_until_complete(self.call(call_data))
async def call_stream(
- self, call_data: Optional[CALL_DATA] = None
self,
call_data: Optional[CALL_DATA] = None,
dag_ctx: Optional[DAGContext] = None,
) -> AsyncIterator[OUT]:
"""Execute the node and return the output as a stream.
@@ -217,12 +224,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
out_ctx = await self._runner.execute_workflow(
- self, call_data, streaming_call=True
self, call_data, streaming_call=True, dag_ctx=dag_ctx
)
return out_ctx.current_task_context.task_output.output_stream


@@ -19,17 +19,21 @@ class DefaultWorkflowRunner(WorkflowRunner):
node: BaseOperator,
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
dag_ctx: Optional[DAGContext] = None,
) -> DAGContext:
# Save node output
# dag = node.dag
- node_outputs: Dict[str, TaskContext] = {}
job_manager = JobManager.build_from_end_node(node, call_data)
- # Create DAG context
- dag_ctx = DAGContext(
- streaming_call=streaming_call,
- node_to_outputs=node_outputs,
- node_name_to_ids=job_manager._node_name_to_ids,
- )
if not dag_ctx:
# Create DAG context
node_outputs: Dict[str, TaskContext] = {}
dag_ctx = DAGContext(
streaming_call=streaming_call,
node_to_outputs=node_outputs,
node_name_to_ids=job_manager._node_name_to_ids,
)
else:
node_outputs = dag_ctx._node_to_outputs
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
)
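The new dag_ctx parameter lets a nested DAG run inside the parent's DAGContext, so share data and node outputs are reused instead of being rebuilt. A rough sketch of the pattern, mirroring the sub-DAG call used later in this commit; the operator bodies are illustrative:

    from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource


    class SubDAGCallerOperator(MapOperator[str, str]):
        """Illustrative operator that runs a nested DAG with the parent's context."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            with DAG("nested_dag") as self._sub_dag:
                input_task = InputOperator(input_source=SimpleCallDataInputSource())
                input_task >> MapOperator(lambda x: x.upper())

        async def map(self, input_value: str) -> str:
            end_node = self._sub_dag.leaf_nodes[0]
            # Reuse the parent's DAGContext so both DAGs see the same share data.
            return await end_node.call(
                call_data={"data": input_value}, dag_ctx=self.current_dag_context
            )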


@@ -1,14 +1,21 @@
import collections
import copy
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from cachetools import TTLCache
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo
logger = logging.getLogger(__name__)
@dataclass
@PublicAPI(stability="beta")
@@ -223,6 +230,29 @@ class ModelRequest:
raise ValueError("The messages is not a single user message") raise ValueError("The messages is not a single user message")
return messages[0] return messages[0]
@staticmethod
def build_request(
model: str,
messages: List[ModelMessage],
context: Union[ModelRequestContext, Dict[str, Any], BaseModel],
stream: Optional[bool] = False,
**kwargs,
):
context_dict = None
if isinstance(context, dict):
context_dict = context
elif isinstance(context, BaseModel):
context_dict = context.dict()
if context_dict and "stream" not in context_dict:
context_dict["stream"] = stream
context = ModelRequestContext(**context_dict)
return ModelRequest(
model=model,
messages=messages,
context=context,
**kwargs,
)
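For orientation, ModelRequest.build_request accepts the request context as a ModelRequestContext, a plain dict, or a pydantic model, and fills in the stream flag when the context omits it. A small sketch; the model name and user name are made up:

    from dbgpt.core import ModelMessage, ModelRequest

    messages = [ModelMessage.build_human_message("What is AWEL?")]

    # The context is a dict here; "stream" is copied from the stream argument
    # because the dict does not define it.
    request = ModelRequest.build_request(
        "example-model",
        messages=messages,
        context={"user_name": "alice"},
        stream=False,
    )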
@staticmethod
def _build(model: str, prompt: str, **kwargs):
return ModelRequest(
@@ -271,6 +301,43 @@ class ModelRequest:
return ModelMessage.to_openai_messages(messages)
@dataclass
class ModelExtraMedata(BaseParameters):
"""A class to represent the extra metadata of a LLM."""
prompt_roles: Optional[List[str]] = field(
default_factory=lambda: [
ModelMessageRoleType.SYSTEM,
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
],
metadata={"help": "The roles of the prompt"},
)
prompt_sep: Optional[str] = field(
default="\n",
metadata={"help": "The separator of the prompt between multiple rounds"},
)
# You can see the chat template in your model repo tokenizer config,
# typically in the tokenizer_config.json
prompt_chat_template: Optional[str] = field(
default=None,
metadata={
"help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
},
)
@property
def support_system_message(self) -> bool:
"""Whether the model supports system message.
Returns:
bool: Whether the model supports system message.
"""
return ModelMessageRoleType.SYSTEM in self.prompt_roles
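A brief sketch of how this extra metadata could be declared for a model whose prompt template has no system role; the concrete values are illustrative:

    from dbgpt.core import ModelExtraMedata, ModelMetadata, ModelMessageRoleType

    # Hypothetical model that only understands human and AI turns.
    ext = ModelExtraMedata(
        prompt_roles=[ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI],
        prompt_sep="\n",
    )
    metadata = ModelMetadata(model="example-model", ext_metadata=ext)

    # Converters can check this flag when deciding how to handle system messages.
    assert not metadata.ext_metadata.support_system_message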
@dataclass
@PublicAPI(stability="beta")
class ModelMetadata(BaseParameters):
@@ -295,18 +362,294 @@ class ModelMetadata(BaseParameters):
default_factory=dict,
metadata={"help": "Model metadata"},
)
ext_metadata: Optional[ModelExtraMedata] = field(
default_factory=ModelExtraMedata,
metadata={"help": "Model extra metadata"},
)
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = False
) -> "ModelMetadata":
if "ext_metadata" in data:
data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
return cls(**data)
class MessageConverter(ABC):
"""An abstract class for message converter.
Different LLMs may have different message formats, this class is used to convert the messages
to the format of the LLM.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
>>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
>>> class RemoveSystemMessageConverter(MessageConverter):
... def convert(
... self,
... messages: List[ModelMessage],
... model_metadata: Optional[ModelMetadata] = None,
... ) -> List[ModelMessage]:
... # Convert the messages, merge system messages to the last user message.
... system_message = None
... other_messages = []
... sep = "\\n"
... for message in messages:
... if message.role == ModelMessageRoleType.SYSTEM:
... system_message = message
... else:
... other_messages.append(message)
... if system_message and other_messages:
... other_messages[-1].content = (
... system_message.content + sep + other_messages[-1].content
... )
... return other_messages
...
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
... ]
>>> converter = RemoveSystemMessageConverter()
>>> converted_messages = converter.convert(messages, None)
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\\nWho are you",
... ),
... ]
"""
@abstractmethod
def convert(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages.
Args:
messages(List[ModelMessage]): The messages.
model_metadata(ModelMetadata): The model metadata.
Returns:
List[ModelMessage]: The converted messages.
"""
class DefaultMessageConverter(MessageConverter):
"""The default message converter."""
def __init__(self, prompt_sep: Optional[str] = None):
self._prompt_sep = prompt_sep
def convert(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages.
There are three steps to convert the messages:
1. Just keep system, human and AI messages
2. Move the last user's message to the end of the list
3. Convert the messages to no system message if the model does not support system message
Args:
messages(List[ModelMessage]): The messages.
model_metadata(ModelMetadata): The model metadata.
Returns:
List[ModelMessage]: The converted messages.
"""
# 1. Just keep system, human and AI messages
messages = list(filter(lambda m: m.pass_to_model, messages))
# 2. Move the last user's message to the end of the list
messages = self.move_last_user_message_to_end(messages)
if not model_metadata or not model_metadata.ext_metadata:
logger.warning("No model metadata, skip message system message conversion")
return messages
if model_metadata.ext_metadata.support_system_message:
# 3. Convert the messages to no system message
return self.convert_to_no_system_message(messages, model_metadata)
return messages
def convert_to_no_system_message(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages to no system message.
Examples:
>>> # Convert the messages to no system message, just merge system messages to the last user message
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
... ModelMessageRoleType,
... )
>>> from dbgpt.core.interface.llm import (
... DefaultMessageConverter,
... ModelMetadata,
... )
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ]
>>> converter = DefaultMessageConverter()
>>> model_metadata = ModelMetadata(model="test")
>>> converted_messages = converter.convert_to_no_system_message(
... messages, model_metadata
... )
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\\nWho are you",
... ),
... ]
"""
if not model_metadata or not model_metadata.ext_metadata:
logger.warning("No model metadata, skip message conversion")
return messages
ext_metadata = model_metadata.ext_metadata
system_messages = []
result_messages = []
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
# Not support system message, append system message to the last user message
system_messages.append(message)
elif message.role in [
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
]:
result_messages.append(message)
prompt_sep = self._prompt_sep or ext_metadata.prompt_sep or "\n"
system_message_str = None
if len(system_messages) > 1:
logger.warning("Your system messages have more than one message")
system_message_str = prompt_sep.join([m.content for m in system_messages])
elif len(system_messages) == 1:
system_message_str = system_messages[0].content
if system_message_str and result_messages:
# Not support system messages, merge system messages to the last user message
result_messages[-1].content = (
system_message_str + prompt_sep + result_messages[-1].content
)
return result_messages
def move_last_user_message_to_end(
self, messages: List[ModelMessage]
) -> List[ModelMessage]:
"""Move the last user message to the end of the list.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
... ModelMessageRoleType,
... )
>>> from dbgpt.core.interface.llm import DefaultMessageConverter
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="What's your name"
... ),
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ]
>>> converter = DefaultMessageConverter()
>>> converted_messages = converter.move_last_user_message_to_end(messages)
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="What's your name"
... ),
... ]
Args:
messages(List[ModelMessage]): The messages.
Returns:
List[ModelMessage]: The converted messages.
"""
last_user_input_index = None
for i in range(len(messages) - 1, -1, -1):
if messages[i].role == ModelMessageRoleType.HUMAN:
last_user_input_index = i
break
if last_user_input_index is not None:
last_user_input = messages.pop(last_user_input_index)
messages.append(last_user_input)
return messages
@PublicAPI(stability="beta") @PublicAPI(stability="beta")
class LLMClient(ABC): class LLMClient(ABC):
"""An abstract class for LLM client.""" """An abstract class for LLM client."""
# Cache the model metadata for 60 seconds
_MODEL_CACHE_ = TTLCache(maxsize=100, ttl=60)
@property
def cache(self) -> collections.abc.MutableMapping:
"""The cache object to cache the model metadata.
You can override this property to use your own cache object.
Returns:
collections.abc.MutableMapping: The cache object.
"""
return self._MODEL_CACHE_
@abstractmethod
- async def generate(self, request: ModelRequest) -> ModelOutput:
async def generate(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelOutput:
"""Generate a response for a given model request.
Sometimes, different LLMs may have different message formats,
you can use the message converter to convert the messages to the format of the LLM.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
ModelOutput: The model output.
@@ -315,12 +658,18 @@ class LLMClient(ABC):
@abstractmethod
async def generate_stream(
- self, request: ModelRequest
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
"""Generate a stream of responses for a given model request.
Sometimes, different LLMs may have different message formats,
you can use the message converter to convert the messages to the format of the LLM.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
AsyncIterator[ModelOutput]: The model output stream.
@@ -345,3 +694,65 @@ class LLMClient(ABC):
Returns:
int: The number of tokens.
"""
async def covert_message(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelRequest:
"""Covert the message.
If no message converter is provided, the original request will be returned.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
ModelRequest: The converted model request.
"""
if not message_converter:
return request
new_request = request.copy()
model_metadata = await self.get_model_metadata(request.model)
new_messages = message_converter.convert(request.messages, model_metadata)
new_request.messages = new_messages
return new_request
async def cached_models(self) -> List[ModelMetadata]:
"""Get all the models from the cache or the llm server.
If the model metadata is not in the cache, it will be fetched from the llm server.
Returns:
List[ModelMetadata]: A list of model metadata.
"""
key = "____$llm_client_models$____"
if key not in self.cache:
models = await self.models()
self.cache[key] = models
for model in models:
model_metadata_key = (
f"____$llm_client_models_metadata_{model.model}$____"
)
self.cache[model_metadata_key] = model
return self.cache[key]
async def get_model_metadata(self, model: str) -> ModelMetadata:
"""Get the model metadata.
Args:
model(str): The model name.
Returns:
ModelMetadata: The model metadata.
Raises:
ValueError: If the model is not found.
"""
model_metadata_key = f"____$llm_client_models_metadata_{model}$____"
if model_metadata_key not in self.cache:
await self.cached_models()
model_metadata = self.cache.get(model_metadata_key)
if not model_metadata:
raise ValueError(f"Model {model} not found")
return model_metadata


@@ -5,7 +5,6 @@ from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
- from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.storage import (
InMemoryStorage,
ResourceIdentifier,
@@ -114,6 +113,50 @@ class ModelMessage(BaseModel):
content: str
round_index: Optional[int] = 0
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model
The view message will not be passed to the model
Returns:
bool: Whether the message will be passed to the model
"""
return self.role in [
ModelMessageRoleType.SYSTEM,
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
]
@staticmethod
def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]:
result = []
for message in messages:
content, round_index = message.content, message.round_index
if isinstance(message, HumanMessage):
result.append(
ModelMessage(
role=ModelMessageRoleType.HUMAN,
content=content,
round_index=round_index,
)
)
elif isinstance(message, AIMessage):
result.append(
ModelMessage(
role=ModelMessageRoleType.AI,
content=content,
round_index=round_index,
)
)
elif isinstance(message, SystemMessage):
result.append(
ModelMessage(
role=ModelMessageRoleType.SYSTEM, content=message.content
)
)
return result
@staticmethod
def from_openai_messages(
messages: Union[str, List[Dict[str, str]]]
@@ -142,9 +185,15 @@ class ModelMessage(BaseModel):
return result
@staticmethod
- def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
def to_openai_messages(
messages: List["ModelMessage"], convert_to_compatible_format: bool = False
) -> List[Dict[str, str]]:
"""Convert to OpenAI message format and
hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
Args:
messages (List["ModelMessage"]): The model messages
convert_to_compatible_format (bool): Whether to convert to compatible format
"""
history = []
# Add history conversation
@@ -157,15 +206,16 @@ class ModelMessage(BaseModel):
history.append({"role": "assistant", "content": message.content}) history.append({"role": "assistant", "content": message.content})
else: else:
pass pass
# Move the last user's information to the end if convert_to_compatible_format:
last_user_input_index = None # Move the last user's information to the end
for i in range(len(history) - 1, -1, -1): last_user_input_index = None
if history[i]["role"] == "user": for i in range(len(history) - 1, -1, -1):
last_user_input_index = i if history[i]["role"] == "user":
break last_user_input_index = i
if last_user_input_index: break
last_user_input = history.pop(last_user_input_index) if last_user_input_index:
history.append(last_user_input) last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)
return history return history
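A short sketch of the new flag: by default the OpenAI-style messages keep their original order, while convert_to_compatible_format=True restores the old behavior of moving the last user turn to the end.

    from dbgpt.core import ModelMessage, ModelMessageRoleType

    messages = [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
    ]

    # Default: keep the original order.
    openai_messages = ModelMessage.to_openai_messages(messages)

    # Legacy behavior: reorder so the last user input ends the list.
    compatible_messages = ModelMessage.to_openai_messages(
        messages, convert_to_compatible_format=True
    )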
@staticmethod
@@ -189,8 +239,8 @@ class ModelMessage(BaseModel):
return str_msg
- _SingleRoundMessage = List[ModelMessage]
_SingleRoundMessage = List[BaseMessage]
- _MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
def _message_to_dict(message: BaseMessage) -> Dict:
@@ -338,7 +388,8 @@ class OnceConversation:
"""Start a new round of conversation """Start a new round of conversation
Example: Example:
>>> conversation = OnceConversation()
>>> conversation = OnceConversation("chat_normal")
>>> # The chat order will be 0, then we start a new round of conversation >>> # The chat order will be 0, then we start a new round of conversation
>>> assert conversation.chat_order == 0 >>> assert conversation.chat_order == 0
>>> conversation.start_new_round() >>> conversation.start_new_round()
@@ -585,6 +636,28 @@ class OnceConversation:
)
return messages
def get_history_message(
self, include_system_message: bool = False
) -> List[BaseMessage]:
"""Get the history message
Not include the system messages.
Args:
include_system_message (bool): Whether to include the system message
Returns:
List[BaseMessage]: The history messages
"""
messages = []
for message in self.messages:
if message.pass_to_model:
if include_system_message:
messages.append(message)
elif message.type != "system":
messages.append(message)
return messages
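A minimal sketch of the new accessor; the conversation content is illustrative, and the add_user_message/add_ai_message/end_current_round helpers are assumed from the surrounding conversation interface:

    from dbgpt.core import OnceConversation

    conversation = OnceConversation("chat_normal")
    conversation.start_new_round()
    conversation.add_user_message("Hello")
    conversation.add_ai_message("Hi, how can I help?")
    conversation.end_current_round()

    # View messages and (by default) system messages are filtered out.
    history = conversation.get_history_message()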
class ConversationIdentifier(ResourceIdentifier):
"""Conversation identifier"""


@@ -0,0 +1,114 @@
import dataclasses
from typing import Any, Dict, List, Optional
from dbgpt.core import (
ChatPromptTemplate,
MessageStorageItem,
ModelMessage,
ModelRequest,
StorageConversation,
StorageInterface,
)
from dbgpt.core.awel import (
DAG,
BaseOperator,
InputOperator,
JoinOperator,
MapOperator,
SimpleCallDataInputSource,
)
from dbgpt.core.interface.operator.prompt_operator import HistoryPromptBuilderOperator
from .message_operator import (
BufferedConversationMapperOperator,
ChatHistoryLoadType,
PreChatHistoryLoadOperator,
)
@dataclasses.dataclass
class ChatComposerInput:
"""The composer input."""
prompt_dict: Dict[str, Any]
model_dict: Dict[str, Any]
context: ChatHistoryLoadType
class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
"""The chat history prompt composer operator.
For simple use, you can use this operator to compose the chat history prompt.
"""
def __init__(
self,
prompt_template: ChatPromptTemplate,
history_key: str = "chat_history",
last_k_round: int = 2,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
):
super().__init__(**kwargs)
self._prompt_template = prompt_template
self._history_key = history_key
self._last_k_round = last_k_round
self._storage = storage
self._message_storage = message_storage
self._sub_compose_dag = self._build_composer_dag()
async def map(self, input_value: ChatComposerInput) -> ModelRequest:
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
# Sub dag, use the same dag context in the parent dag
return await end_node.call(
call_data={"data": input_value}, dag_ctx=self.current_dag_context
)
def _build_composer_dag(self) -> DAG:
with DAG("dbgpt_awel_chat_history_prompt_composer") as composer_dag:
input_task = InputOperator(input_source=SimpleCallDataInputSource())
# Load and store chat history, default use InMemoryStorage.
chat_history_load_task = PreChatHistoryLoadOperator(
storage=self._storage, message_storage=self._message_storage
)
# History transform task, here we keep last 5 round messages
history_transform_task = BufferedConversationMapperOperator(
last_k_round=self._last_k_round
)
history_prompt_build_task = HistoryPromptBuilderOperator(
prompt=self._prompt_template, history_key=self._history_key
)
model_request_build_task = JoinOperator(self._build_model_request)
# Build composer dag
(
input_task
>> MapOperator(lambda x: x.context)
>> chat_history_load_task
>> history_transform_task
>> history_prompt_build_task
)
(
input_task
>> MapOperator(lambda x: x.prompt_dict)
>> history_prompt_build_task
)
history_prompt_build_task >> model_request_build_task
(
input_task
>> MapOperator(lambda x: x.model_dict)
>> model_request_build_task
)
return composer_dag
def _build_model_request(
self, messages: List[ModelMessage], model_dict: Dict[str, Any]
) -> ModelRequest:
return ModelRequest.build_request(messages=messages, **model_dict)
async def after_dag_end(self):
# Should call after_dag_end() of sub dag
await self._sub_compose_dag._after_dag_end()
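A hedged usage sketch of the composer operator defined above. The prompt-template construction style, the dbgpt.core.operator import path, and the storage defaults are assumptions; the {"data": ...} call convention mirrors the SimpleCallDataInputSource wiring used inside the composer itself:

    import asyncio

    from dbgpt.core import (
        ChatPromptTemplate,
        HumanPromptTemplate,
        MessagesPlaceholder,
        ModelRequestContext,
        SystemPromptTemplate,
    )
    from dbgpt.core.awel import DAG, InputOperator, SimpleCallDataInputSource
    from dbgpt.core.interface.storage import InMemoryStorage
    from dbgpt.core.operator import ChatComposerInput, ChatHistoryPromptComposerOperator

    prompt = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template("You are a helpful assistant."),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )

    with DAG("composer_example_dag"):
        input_task = InputOperator(input_source=SimpleCallDataInputSource())
        composer_task = ChatHistoryPromptComposerOperator(
            prompt_template=prompt,
            storage=InMemoryStorage(),
            message_storage=InMemoryStorage(),
        )
        input_task >> composer_task

    model_request = asyncio.run(
        composer_task.call(
            call_data={
                "data": ChatComposerInput(
                    prompt_dict={"user_input": "What is AWEL?"},
                    model_dict={"model": "example-model", "context": {"stream": False}},
                    context=ModelRequestContext(conv_uid="example-conv"),
                )
            }
        )
    )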


@@ -1,11 +1,12 @@
import dataclasses
from abc import ABC
- from typing import Any, AsyncIterator, Dict, Optional, Union
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import (
BranchFunc,
BranchOperator,
DAGContext,
MapOperator,
StreamifyAbsOperator,
)
@@ -22,20 +23,30 @@ RequestInput = Union[
str,
Dict[str, Any],
BaseModel,
ModelMessage,
List[ModelMessage],
]
- class RequestBuildOperator(MapOperator[RequestInput, ModelRequest], ABC):
class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
"""Build the model request from the input value."""
def __init__(self, model: Optional[str] = None, **kwargs):
self._model = model
super().__init__(**kwargs)
async def map(self, input_value: RequestInput) -> ModelRequest:
req_dict = {}
if not input_value:
raise ValueError("input_value is not set")
if isinstance(input_value, str):
req_dict = {"messages": [ModelMessage.build_human_message(input_value)]}
elif isinstance(input_value, dict):
req_dict = input_value
elif isinstance(input_value, ModelMessage):
req_dict = {"messages": [input_value]}
elif isinstance(input_value, list) and isinstance(input_value[0], ModelMessage):
req_dict = {"messages": input_value}
elif dataclasses.is_dataclass(input_value):
req_dict = dataclasses.asdict(input_value)
elif isinstance(input_value, BaseModel):
@@ -76,6 +87,7 @@ class BaseLLM:
"""The abstract operator for a LLM.""" """The abstract operator for a LLM."""
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name" SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
def __init__(self, llm_client: Optional[LLMClient] = None): def __init__(self, llm_client: Optional[LLMClient] = None):
self._llm_client = llm_client self._llm_client = llm_client
@@ -87,8 +99,16 @@ class BaseLLM:
raise ValueError("llm_client is not set") raise ValueError("llm_client is not set")
return self._llm_client return self._llm_client
async def save_model_output(
self, current_dag_context: DAGContext, model_output: ModelOutput
) -> None:
"""Save the model output to the share data."""
await current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
)
- class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
"""The operator for a LLM.
Args:
@@ -105,10 +125,12 @@ class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
- return await self.llm_client.generate(request)
model_output = await self.llm_client.generate(request)
await self.save_model_output(self.current_dag_context, model_output)
return model_output
- class StreamingLLMOperator(
class BaseStreamingLLMOperator(
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
):
"""The streaming operator for a LLM.
@@ -127,8 +149,12 @@ class StreamingLLMOperator(
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
model_output = None
async for output in self.llm_client.generate_stream(request):
model_output = output
yield output
if model_output:
await self.save_model_output(self.current_dag_context, model_output)
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
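Because both LLM operators now store the final ModelOutput in share data under SHARE_DATA_KEY_MODEL_OUTPUT, a downstream operator can read it back. A rough sketch; the operator itself is an assumption, only the key and get_from_share_data come from the codebase:

    from dbgpt.core import ModelOutput
    from dbgpt.core.awel import MapOperator


    class ModelOutputAuditOperator(MapOperator[str, str]):
        """Illustrative operator that inspects the model output saved upstream."""

        async def map(self, input_value: str) -> str:
            # The key mirrors BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT defined above.
            model_output: ModelOutput = await self.current_dag_context.get_from_share_data(
                "share_data_key_model_output"
            )
            # Log or post-process model_output.text here as needed.
            return input_value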


@@ -1,19 +1,17 @@
import uuid
- from abc import ABC, abstractmethod
from abc import ABC
- from typing import Any, AsyncIterator, List, Optional
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import (
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
- ModelOutput,
- ModelRequest,
ModelRequestContext,
StorageConversation,
StorageInterface,
)
- from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
from dbgpt.core.awel import BaseOperator, MapOperator
- from dbgpt.core.interface.message import _MultiRoundMessageMapper
from dbgpt.core.interface.message import BaseMessage, _MultiRoundMessageMapper
class BaseConversationOperator(BaseOperator, ABC):
@@ -21,32 +19,41 @@ class BaseConversationOperator(BaseOperator, ABC):
SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation" SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation"
SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request" SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request"
SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT = "share_data_key_model_request_context"
_check_storage: bool = True
def __init__( def __init__(
self, self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None, storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
check_storage: bool = True,
**kwargs, **kwargs,
): ):
self._check_storage = check_storage
super().__init__(**kwargs) super().__init__(**kwargs)
self._storage = storage self._storage = storage
self._message_storage = message_storage self._message_storage = message_storage
@property @property
def storage(self) -> StorageInterface[StorageConversation, Any]: def storage(self) -> Optional[StorageInterface[StorageConversation, Any]]:
"""Return the LLM client.""" """Return the LLM client."""
if not self._storage: if not self._storage:
raise ValueError("Storage is not set") if self._check_storage:
raise ValueError("Storage is not set")
return None
return self._storage return self._storage
@property @property
def message_storage(self) -> StorageInterface[MessageStorageItem, Any]: def message_storage(self) -> Optional[StorageInterface[MessageStorageItem, Any]]:
"""Return the LLM client.""" """Return the LLM client."""
if not self._message_storage: if not self._message_storage:
raise ValueError("Message storage is not set") if self._check_storage:
raise ValueError("Message storage is not set")
return None
return self._message_storage return self._message_storage
async def get_storage_conversation(self) -> StorageConversation: async def get_storage_conversation(self) -> Optional[StorageConversation]:
"""Get the storage conversation from share data. """Get the storage conversation from share data.
Returns: Returns:
@@ -58,104 +65,11 @@ class BaseConversationOperator(BaseOperator, ABC):
)
)
if not storage_conv:
- raise ValueError("Storage conversation is not set")
if self._check_storage:
raise ValueError("Storage conversation is not set")
return None
return storage_conv
- async def get_model_request(self) -> ModelRequest:
- """Get the model request from share data.
- Returns:
- ModelRequest: The model request.
- """
- model_request: ModelRequest = (
- await self.current_dag_context.get_from_share_data(
- self.SHARE_DATA_KEY_MODEL_REQUEST
- )
- )
- if not model_request:
- raise ValueError("Model request is not set")
- return model_request
- class PreConversationOperator(
- BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
- ):
- """The operator to prepare the storage conversation.
- In DB-GPT, conversation record and the messages in the conversation are stored in the storage,
- and they can store in different storage(for high performance).
- """
- def __init__(
- self,
- storage: Optional[StorageInterface[StorageConversation, Any]] = None,
- message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
- **kwargs,
- ):
- super().__init__(storage=storage, message_storage=message_storage)
- MapOperator.__init__(self, **kwargs)
- async def map(self, input_value: ModelRequest) -> ModelRequest:
- """Map the input value to a ModelRequest.
- Args:
- input_value (ModelRequest): The input value.
- Returns:
- ModelRequest: The mapped ModelRequest.
- """
- if input_value.context is None:
- input_value.context = ModelRequestContext()
- if not input_value.context.conv_uid:
- input_value.context.conv_uid = str(uuid.uuid4())
- if not input_value.context.extra:
- input_value.context.extra = {}
- chat_mode = input_value.context.chat_mode
- # Create a new storage conversation, this will load the conversation from storage, so we must do this async
- storage_conv: StorageConversation = await self.blocking_func_to_async(
- StorageConversation,
- conv_uid=input_value.context.conv_uid,
- chat_mode=chat_mode,
- user_name=input_value.context.user_name,
- sys_code=input_value.context.sys_code,
- conv_storage=self.storage,
- message_storage=self.message_storage,
- )
- input_messages = input_value.get_messages()
- await self.save_to_storage(storage_conv, input_messages)
- # Get all messages from current storage conversation, and overwrite the input value
- messages: List[ModelMessage] = storage_conv.get_model_messages()
- input_value.messages = messages
- # Save the storage conversation to share data, for the child operators
- await self.current_dag_context.save_to_share_data(
- self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
- )
- await self.current_dag_context.save_to_share_data(
- self.SHARE_DATA_KEY_MODEL_REQUEST, input_value
- )
- return input_value
- async def save_to_storage(
- self, storage_conv: StorageConversation, input_messages: List[ModelMessage]
- ) -> None:
- """Save the messages to storage.
- Args:
- storage_conv (StorageConversation): The storage conversation.
- input_messages (List[ModelMessage]): The input messages.
- """
- # check first
- self.check_messages(input_messages)
- storage_conv.start_new_round()
- for message in input_messages:
- if message.role == ModelMessageRoleType.HUMAN:
- storage_conv.add_user_message(message.content)
- else:
- storage_conv.add_system_message(message.content)
def check_messages(self, messages: List[ModelMessage]) -> None:
"""Check the messages.
@@ -174,164 +88,147 @@ class PreConversationOperator(
]:
raise ValueError(f"Message role {message.role} is not supported")
- async def after_dag_end(self):
- """The callback after DAG end"""
- # Save the storage conversation to storage after the whole DAG finished
- storage_conv: StorageConversation = await self.get_storage_conversation()
- # TODO dont save if the conversation has some internal error
- storage_conv.end_current_round()
- class PostConversationOperator(
- BaseConversationOperator, MapOperator[ModelOutput, ModelOutput]
- ):
- def __init__(self, **kwargs):
- MapOperator.__init__(self, **kwargs)
- async def map(self, input_value: ModelOutput) -> ModelOutput:
- """Map the input value to a ModelOutput.
- Args:
- input_value (ModelOutput): The input value.
- Returns:
- ModelOutput: The mapped ModelOutput.
- """
- # Get the storage conversation from share data
- storage_conv: StorageConversation = await self.get_storage_conversation()
- storage_conv.add_ai_message(input_value.text)
- return input_value
- class PostStreamingConversationOperator(
- BaseConversationOperator, TransformStreamAbsOperator[ModelOutput, ModelOutput]
- ):
- def __init__(self, **kwargs):
- TransformStreamAbsOperator.__init__(self, **kwargs)
- async def transform_stream(
- self, input_value: AsyncIterator[ModelOutput]
- ) -> ModelOutput:
- """Transform the input value to a ModelOutput.
- Args:
- input_value (ModelOutput): The input value.
- Returns:
- ModelOutput: The transformed ModelOutput.
- """
- full_text = ""
- async for model_output in input_value:
- # Now model_output.text if full text, if it is a delta text, we should merge all delta text to a full text
- full_text = model_output.text
- yield model_output
- # Get the storage conversation from share data
- storage_conv: StorageConversation = await self.get_storage_conversation()
- storage_conv.add_ai_message(full_text)
ChatHistoryLoadType = Union[ModelRequestContext, Dict[str, Any]]
class PreChatHistoryLoadOperator(
BaseConversationOperator, MapOperator[ChatHistoryLoadType, List[BaseMessage]]
):
"""The operator to prepare the storage conversation.
In DB-GPT, conversation record and the messages in the conversation are stored in the storage,
and they can store in different storage(for high performance).
This operator just load the conversation and messages from storage.
"""
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
include_system_message: bool = False,
**kwargs,
):
super().__init__(storage=storage, message_storage=message_storage)
MapOperator.__init__(self, **kwargs)
self._include_system_message = include_system_message
async def map(self, input_value: ChatHistoryLoadType) -> List[BaseMessage]:
"""Map the input value to a ModelRequest.
Args:
input_value (ChatHistoryLoadType): The input value.
Returns:
List[BaseMessage]: The messages stored in the storage.
"""
if not input_value:
raise ValueError("Model request context can't be None")
if isinstance(input_value, dict):
input_value = ModelRequestContext(**input_value)
if not input_value.conv_uid:
input_value.conv_uid = str(uuid.uuid4())
if not input_value.extra:
input_value.extra = {}
chat_mode = input_value.chat_mode
# Create a new storage conversation, this will load the conversation from storage, so we must do this async
storage_conv: StorageConversation = await self.blocking_func_to_async(
StorageConversation,
conv_uid=input_value.conv_uid,
chat_mode=chat_mode,
user_name=input_value.user_name,
sys_code=input_value.sys_code,
conv_storage=self.storage,
message_storage=self.message_storage,
)
# Save the storage conversation to share data, for the child operators
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
)
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT, input_value
)
# Get history messages from storage
history_messages: List[BaseMessage] = storage_conv.get_history_message(
include_system_message=self._include_system_message
)
return history_messages
class ConversationMapperOperator(
- BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
BaseConversationOperator, MapOperator[List[BaseMessage], List[BaseMessage]]
):
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
MapOperator.__init__(self, **kwargs)
self._message_mapper = message_mapper
- async def map(self, input_value: ModelRequest) -> ModelRequest:
- """Map the input value to a ModelRequest.
- Args:
- input_value (ModelRequest): The input value.
- Returns:
- ModelRequest: The mapped ModelRequest.
- """
- input_value = input_value.copy()
- messages: List[ModelMessage] = self.map_messages(input_value.messages)
- # Overwrite the input value
- input_value.messages = messages
- return input_value
- def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
- """Map the input messages to a list of ModelMessage.
- Args:
- messages (List[ModelMessage]): The input messages.
- Returns:
- List[ModelMessage]: The mapped ModelMessage.
- """
- messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]:
return self.map_messages(input_value)
def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
messages_by_round: List[List[BaseMessage]] = self._split_messages_by_round(
messages
)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
def map_multi_round_messages(
- self, messages_by_round: List[List[ModelMessage]]
- ) -> List[ModelMessage]:
- """Map multi round messages to a list of ModelMessage
- By default, just merge all multi round messages to a list of ModelMessage according origin order.
self, messages_by_round: List[List[BaseMessage]]
) -> List[BaseMessage]:
"""Map multi round messages to a list of BaseMessage.
By default, just merge all multi round messages to a list of BaseMessage according origin order.
And you can overwrite this method to implement your own logic.
Examples:
- Merge multi round messages to a list of ModelMessage according origin order.
Merge multi round messages to a list of BaseMessage according origin order.
- .. code-block:: python
- import asyncio
- from dbgpt.core.operator import ConversationMapperOperator
- messages_by_round = [
- [
- ModelMessage(role="human", content="Hi", round_index=1),
- ModelMessage(role="ai", content="Hello!", round_index=1),
- ],
- [
- ModelMessage(role="system", content="Error 404", round_index=2),
- ModelMessage(
- role="human", content="What's the error?", round_index=2
- ),
- ModelMessage(role="ai", content="Just a joke.", round_index=2),
- ],
- [
- ModelMessage(role="human", content="Funny!", round_index=3),
- ],
- ]
- operator = ConversationMapperOperator()
- messages = operator.map_multi_round_messages(messages_by_round)
- assert messages == [
- ModelMessage(role="human", content="Hi", round_index=1),
- ModelMessage(role="ai", content="Hello!", round_index=1),
- ModelMessage(role="system", content="Error 404", round_index=2),
- ModelMessage(
- role="human", content="What's the error?", round_index=2
- ),
- ModelMessage(role="ai", content="Just a joke.", round_index=2),
- ModelMessage(role="human", content="Funny!", round_index=3),
- ]
- Map multi round messages to a list of ModelMessage just keep the last one round.
- .. code-block:: python
- class MyMapper(ConversationMapperOperator):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- def map_multi_round_messages(
- self, messages_by_round: List[List[ModelMessage]]
- ) -> List[ModelMessage]:
- return messages_by_round[-1]
- operator = MyMapper()
- messages = operator.map_multi_round_messages(messages_by_round)
- assert messages == [
- ModelMessage(role="human", content="Funny!", round_index=3),
- ]
>>> from dbgpt.core.interface.message import (
... AIMessage,
... HumanMessage,
... SystemMessage,
... )
>>> messages_by_round = [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ],
... [
... HumanMessage(content="What's the error?", round_index=2),
... AIMessage(content="Just a joke.", round_index=2),
... ],
... ]
>>> operator = ConversationMapperOperator()
>>> messages = operator.map_multi_round_messages(messages_by_round)
>>> assert messages == [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... HumanMessage(content="What's the error?", round_index=2),
... AIMessage(content="Just a joke.", round_index=2),
... ]
Map multi round messages to a list of BaseMessage just keep the last one round.
>>> class MyMapper(ConversationMapperOperator):
... def __init__(self, **kwargs):
... super().__init__(**kwargs)
...
... def map_multi_round_messages(
... self, messages_by_round: List[List[BaseMessage]]
... ) -> List[BaseMessage]:
... return messages_by_round[-1]
...
>>> operator = MyMapper()
>>> messages = operator.map_multi_round_messages(messages_by_round)
>>> assert messages == [
... HumanMessage(content="What's the error?", round_index=2),
... AIMessage(content="Just a joke.", round_index=2),
... ]
Args:
"""
@@ -340,17 +237,17 @@ class ConversationMapperOperator(
         return sum(messages_by_round, [])

     def _split_messages_by_round(
-        self, messages: List[ModelMessage]
-    ) -> List[List[ModelMessage]]:
+        self, messages: List[BaseMessage]
+    ) -> List[List[BaseMessage]]:
         """Split the messages by round index.

         Args:
-            messages (List[ModelMessage]): The input messages.
+            messages (List[BaseMessage]): The messages.

         Returns:
-            List[List[ModelMessage]]: The split messages.
+            List[List[BaseMessage]]: The messages split by round.
         """
-        messages_by_round: List[List[ModelMessage]] = []
+        messages_by_round: List[List[BaseMessage]] = []
         last_round_index = 0
         for message in messages:
             if not message.round_index:
@@ -366,7 +263,7 @@ class ConversationMapperOperator(
 class BufferedConversationMapperOperator(ConversationMapperOperator):
     """The buffered conversation mapper operator.

-    This Operator must be used after the PreConversationOperator,
+    This Operator must be used after the PreChatHistoryLoadOperator,
     and it will map the messages in the storage conversation.

     Examples:
@@ -419,8 +316,8 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
         if message_mapper:

             def new_message_mapper(
-                messages_by_round: List[List[ModelMessage]],
-            ) -> List[ModelMessage]:
+                messages_by_round: List[List[BaseMessage]],
+            ) -> List[BaseMessage]:
                 # Apply keep k round messages first, then apply the custom message mapper
                 messages_by_round = self._keep_last_round_messages(messages_by_round)
                 return message_mapper(messages_by_round)

@@ -428,23 +325,23 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
         else:

             def new_message_mapper(
-                messages_by_round: List[List[ModelMessage]],
-            ) -> List[ModelMessage]:
+                messages_by_round: List[List[BaseMessage]],
+            ) -> List[BaseMessage]:
                 messages_by_round = self._keep_last_round_messages(messages_by_round)
                 return sum(messages_by_round, [])

         super().__init__(new_message_mapper, **kwargs)

     def _keep_last_round_messages(
-        self, messages_by_round: List[List[ModelMessage]]
-    ) -> List[List[ModelMessage]]:
+        self, messages_by_round: List[List[BaseMessage]]
+    ) -> List[List[BaseMessage]]:
         """Keep the last k round messages.

         Args:
-            messages_by_round (List[List[ModelMessage]]): The messages by round.
+            messages_by_round (List[List[BaseMessage]]): The messages by round.

         Returns:
-            List[List[ModelMessage]]: The latest round messages.
+            List[List[BaseMessage]]: The latest round messages.
         """
         index = self._last_k_round + 1
         return messages_by_round[-index:]
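To make the buffering behavior concrete, here is a minimal sketch that trims a plain message list to the latest rounds. It assumes a last_k_round constructor argument and calls map_messages directly, outside any DAG.

from typing import List

from dbgpt.core.interface.message import AIMessage, BaseMessage, HumanMessage
from dbgpt.core.operator import BufferedConversationMapperOperator

# Three rounds of history; with last_k_round=1 the mapper keeps the current
# round plus one earlier round (messages_by_round[-(1 + 1):]).
messages: List[BaseMessage] = [
    HumanMessage(content="Hi", round_index=1),
    AIMessage(content="Hello!", round_index=1),
    HumanMessage(content="What's the error?", round_index=2),
    AIMessage(content="Just a joke.", round_index=2),
    HumanMessage(content="Funny!", round_index=3),
]

operator = BufferedConversationMapperOperator(last_k_round=1)
kept = operator.map_messages(messages)
assert [m.content for m in kept] == ["What's the error?", "Just a joke.", "Funny!"]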

View File

@@ -0,0 +1,255 @@
from abc import ABC
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import (
BasePromptTemplate,
ChatPromptTemplate,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
StorageConversation,
)
from dbgpt.core.awel import JoinOperator, MapOperator
from dbgpt.core.interface.message import BaseMessage
from dbgpt.core.interface.operator.llm_operator import BaseLLM
from dbgpt.core.interface.operator.message_operator import BaseConversationOperator
from dbgpt.core.interface.prompt import HumanPromptTemplate, MessageType
from dbgpt.util.function_utils import rearrange_args_by_type
class BasePromptBuilderOperator(BaseConversationOperator, ABC):
"""The base prompt builder operator."""
def __init__(self, check_storage: bool, **kwargs):
super().__init__(check_storage=check_storage, **kwargs)
async def format_prompt(
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
"""Format the prompt.
Args:
prompt (ChatPromptTemplate): The prompt.
prompt_dict (Dict[str, Any]): The prompt dict.
Returns:
List[ModelMessage]: The formatted prompt.
"""
kwargs = {}
kwargs.update(prompt_dict)
pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
messages = prompt.format_messages(**pass_kwargs)
messages = ModelMessage.from_base_messages(messages)
# Start new round conversation, and save user message to storage
await self.start_new_round_conv(messages)
return messages
async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
"""Start a new round conversation.
Args:
messages (List[ModelMessage]): The messages.
"""
        last_user_message = None
        for message in messages[::-1]:
            if message.role == ModelMessageRoleType.HUMAN:
                last_user_message = message.content
                break
        if not last_user_message:
            raise ValueError("No user message")
        storage_conv: StorageConversation = await self.get_storage_conversation()
        if not storage_conv:
            return
        # Start new round
        storage_conv.start_new_round()
        storage_conv.add_user_message(last_user_message)
async def after_dag_end(self):
"""The callback after DAG end"""
# TODO remove this to start_new_round()
# Save the storage conversation to storage after the whole DAG finished
storage_conv: StorageConversation = await self.get_storage_conversation()
if not storage_conv:
return
model_output: ModelOutput = await self.current_dag_context.get_from_share_data(
BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT
)
if model_output:
# Save model output message to storage
storage_conv.add_ai_message(model_output.text)
# End current conversation round and flush to storage
storage_conv.end_current_round()
PromptTemplateType = Union[ChatPromptTemplate, BasePromptTemplate, MessageType, str]
class PromptBuilderOperator(
BasePromptBuilderOperator, MapOperator[Dict[str, Any], List[ModelMessage]]
):
"""The operator to build the prompt with static prompt.
Examples:
.. code-block:: python
import asyncio
from dbgpt.core.awel import DAG
from dbgpt.core import (
ModelMessage,
HumanMessage,
SystemMessage,
HumanPromptTemplate,
SystemPromptTemplate,
ChatPromptTemplate,
)
from dbgpt.core.operator import PromptBuilderOperator
with DAG("prompt_test") as dag:
str_prompt = PromptBuilderOperator(
"Please write a {dialect} SQL count the length of a field"
)
tp_prompt = PromptBuilderOperator(
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
)
)
chat_prompt = PromptBuilderOperator(
ChatPromptTemplate(
messages=[
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
)
]
)
)
with_sys_prompt = PromptBuilderOperator(
ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(
"You are a {dialect} SQL expert"
),
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
),
],
)
)
single_input = {"data": {"dialect": "mysql"}}
single_expected_messages = [
ModelMessage(
content="Please write a mysql SQL count the length of a field",
role="human",
)
]
with_sys_expected_messages = [
ModelMessage(content="You are a mysql SQL expert", role="system"),
ModelMessage(
content="Please write a mysql SQL count the length of a field",
role="human",
),
]
assert (
asyncio.run(str_prompt.call(call_data=single_input))
== single_expected_messages
)
assert (
asyncio.run(tp_prompt.call(call_data=single_input))
== single_expected_messages
)
assert (
asyncio.run(chat_prompt.call(call_data=single_input))
== single_expected_messages
)
assert (
asyncio.run(with_sys_prompt.call(call_data=single_input))
== with_sys_expected_messages
)
"""
def __init__(self, prompt: PromptTemplateType, **kwargs):
if isinstance(prompt, str):
prompt = ChatPromptTemplate(
messages=[HumanPromptTemplate.from_template(prompt)]
)
elif isinstance(prompt, BasePromptTemplate) and not isinstance(
prompt, ChatPromptTemplate
):
prompt = ChatPromptTemplate(
messages=[HumanPromptTemplate.from_template(prompt.template)]
)
elif isinstance(prompt, MessageType):
prompt = ChatPromptTemplate(messages=[prompt])
self._prompt = prompt
super().__init__(check_storage=False, **kwargs)
MapOperator.__init__(self, map_function=self.merge_prompt, **kwargs)
@rearrange_args_by_type
async def merge_prompt(self, prompt_dict: Dict[str, Any]) -> List[ModelMessage]:
return await self.format_prompt(self._prompt, prompt_dict)
class DynamicPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
"""The operator to build the prompt with dynamic prompt.
The prompt template is dynamic, and it is created by the parent operator.
"""
def __init__(self, **kwargs):
super().__init__(check_storage=False, **kwargs)
JoinOperator.__init__(self, combine_function=self.merge_prompt, **kwargs)
@rearrange_args_by_type
async def merge_prompt(
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
return await self.format_prompt(prompt, prompt_dict)
class HistoryPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
def __init__(
self, prompt: ChatPromptTemplate, history_key: Optional[str] = None, **kwargs
):
self._prompt = prompt
self._history_key = history_key
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
async def merge_history(
self, history: List[BaseMessage], prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
prompt_dict[self._history_key] = history
return await self.format_prompt(self._prompt, prompt_dict)
class HistoryDynamicPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
"""The operator to build the prompt with dynamic prompt.
The prompt template is dynamic, and it is created by the parent operator.
"""
def __init__(self, history_key: Optional[str] = None, **kwargs):
self._history_key = history_key
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
async def merge_history(
self,
prompt: ChatPromptTemplate,
history: List[BaseMessage],
prompt_dict: Dict[str, Any],
) -> List[ModelMessage]:
prompt_dict[self._history_key] = history
return await self.format_prompt(prompt, prompt_dict)

View File

@@ -1,11 +1,14 @@
from __future__ import annotations
import dataclasses import dataclasses
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Union
from dbgpt._private.pydantic import BaseModel from dbgpt._private.pydantic import BaseModel, root_validator
from dbgpt.core._private.example_base import ExampleSelector from dbgpt.core._private.example_base import ExampleSelector
from dbgpt.core.awel import MapOperator from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
from dbgpt.core.interface.output_parser import BaseOutputParser from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.core.interface.storage import ( from dbgpt.core.interface.storage import (
InMemoryStorage, InMemoryStorage,
@@ -38,15 +41,40 @@ _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
} }
class PromptTemplate(BaseModel, ABC): class BasePromptTemplate(BaseModel):
input_variables: List[str] input_variables: List[str]
"""A list of the names of the variables the prompt template expects.""" """A list of the names of the variables the prompt template expects."""
template: Optional[str]
"""The prompt template."""
template_format: Optional[str] = "f-string"
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.template:
return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)(
self.template, **kwargs
)
@classmethod
def from_template(
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
) -> BasePromptTemplate:
"""Create a prompt template from a template string."""
input_variables = get_template_vars(template, template_format)
return cls(
template=template,
input_variables=input_variables,
template_format=template_format,
**kwargs,
)
class PromptTemplate(BasePromptTemplate):
template_scene: Optional[str] template_scene: Optional[str]
template_define: Optional[str] template_define: Optional[str]
"""this template define""" """this template define"""
template: Optional[str]
"""The prompt template."""
template_format: str = "f-string"
"""strict template will check template args""" """strict template will check template args"""
template_is_strict: bool = True template_is_strict: bool = True
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@@ -86,12 +114,114 @@ class PromptTemplate(BaseModel, ABC):
self.template_is_strict self.template_is_strict
)(self.template, **kwargs) )(self.template, **kwargs)
@staticmethod
def from_template(template: str) -> "PromptTemplateOperator": class BaseChatPromptTemplate(BaseModel, ABC):
prompt: BasePromptTemplate
@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects."""
return self.prompt.input_variables
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs."""
@classmethod
def from_template(
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
) -> BaseChatPromptTemplate:
"""Create a prompt template from a template string.""" """Create a prompt template from a template string."""
return PromptTemplateOperator( prompt = BasePromptTemplate.from_template(template, template_format)
PromptTemplate(template=template, input_variables=[]) return cls(prompt=prompt, **kwargs)
)
class SystemPromptTemplate(BaseChatPromptTemplate):
"""The system prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
content = self.prompt.format(**kwargs)
return [SystemMessage(content=content)]
class HumanPromptTemplate(BaseChatPromptTemplate):
"""The human prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
content = self.prompt.format(**kwargs)
return [HumanMessage(content=content)]
class MessagesPlaceholder(BaseChatPromptTemplate):
"""The messages placeholder template.
Mostly used for the chat history.
"""
variable_name: str
prompt: BasePromptTemplate = None
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
messages = kwargs.get(self.variable_name, [])
if not isinstance(messages, list):
raise ValueError(
f"Unsupported messages type: {type(messages)}, should be list."
)
for message in messages:
if not isinstance(message, BaseMessage):
raise ValueError(
f"Unsupported message type: {type(message)}, should be BaseMessage."
)
return messages
@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects.
Returns:
List[str]: The input variables.
"""
return [self.variable_name]
MessageType = Union[BaseChatPromptTemplate, BaseMessage]
class ChatPromptTemplate(BasePromptTemplate):
messages: List[MessageType]
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs."""
result_messages = []
for message in self.messages:
if isinstance(message, BaseMessage):
result_messages.append(message)
elif isinstance(message, BaseChatPromptTemplate):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
result_messages.extend(message.format_messages(**pass_kwargs))
elif isinstance(message, MessagesPlaceholder):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
result_messages.extend(message.format_messages(**pass_kwargs))
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return result_messages
@root_validator(pre=True)
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre-fill the messages."""
input_variables = values.get("input_variables", {})
messages = values.get("messages", [])
if not input_variables:
input_variables = set()
for message in messages:
if isinstance(message, BaseChatPromptTemplate):
input_variables.update(message.input_variables)
values["input_variables"] = sorted(input_variables)
return values
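A small illustration of the pre_fill validator above: when input_variables is omitted, it is inferred from the member templates and returned sorted.

from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, SystemPromptTemplate

prompt = ChatPromptTemplate(
    messages=[
        SystemPromptTemplate.from_template("You are a {dialect} SQL expert"),
        HumanPromptTemplate.from_template("Answer {user_input} in {dialect}"),
    ]
)
assert prompt.input_variables == ["dialect", "user_input"]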
@dataclasses.dataclass @dataclasses.dataclass
@@ -547,10 +677,36 @@ class PromptManager:
self.storage.delete(identifier) self.storage.delete(identifier)
class PromptTemplateOperator(MapOperator[Dict, str]): def _get_string_template_vars(template_str: str) -> Set[str]:
def __init__(self, prompt_template: PromptTemplate, **kwargs: Any): """Get template variables from a template string."""
super().__init__(**kwargs) variables = set()
self._prompt_template = prompt_template formatter = Formatter()
async def map(self, input_value: Dict) -> str: for _, variable_name, _, _ in formatter.parse(template_str):
return self._prompt_template.format(**input_value) if variable_name:
variables.add(variable_name)
return variables
def _get_jinja2_template_vars(template_str: str) -> Set[str]:
"""Get template variables from a template string."""
from jinja2 import Environment, meta
env = Environment()
ast = env.parse(template_str)
variables = meta.find_undeclared_variables(ast)
return variables
def get_template_vars(
template_str: str, template_format: str = "f-string"
) -> List[str]:
"""Get template variables from a template string."""
if template_format == "f-string":
result = _get_string_template_vars(template_str)
elif template_format == "jinja2":
result = _get_jinja2_template_vars(template_str)
else:
raise ValueError(f"Unsupported template format: {template_format}")
return sorted(result)
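A tiny usage sketch of the helper above (the jinja2 form requires the jinja2 package to be installed):

from dbgpt.core.interface.prompt import get_template_vars

assert get_template_vars("Hello {name}, welcome to {place}") == ["name", "place"]
assert get_template_vars("Hello {{ name }}", template_format="jinja2") == ["name"]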

View File

@@ -413,13 +413,18 @@ def test_to_openai_messages(
{"role": "user", "content": human_model_message.content}, {"role": "user", "content": human_model_message.content},
] ]
def test_to_openai_messages_convert_to_compatible_format(
human_model_message, ai_model_message, system_model_message
):
shuffle_messages = ModelMessage.to_openai_messages( shuffle_messages = ModelMessage.to_openai_messages(
[ [
system_model_message, system_model_message,
human_model_message, human_model_message,
human_model_message, human_model_message,
ai_model_message, ai_model_message,
] ],
convert_to_compatible_format=True,
) )
assert shuffle_messages == [ assert shuffle_messages == [
{"role": "system", "content": system_model_message.content}, {"role": "system", "content": system_model_message.content},

View File

@@ -99,12 +99,6 @@ class TestPromptTemplate:
formatted_output = prompt.format(response="hello") formatted_output = prompt.format(response="hello")
assert "Response: " in formatted_output assert "Response: " in formatted_output
def test_from_template(self):
template_str = "Hello {name}"
prompt = PromptTemplate.from_template(template_str)
assert prompt._prompt_template.template == template_str
assert prompt._prompt_template.input_variables == []
def test_format_missing_variable(self): def test_format_missing_variable(self):
template_str = "Hello {name}" template_str = "Hello {name}"
prompt = PromptTemplate( prompt = PromptTemplate(

View File

@@ -1,31 +1,41 @@
from dbgpt.core.interface.operator.composer_operator import (
ChatComposerInput,
ChatHistoryPromptComposerOperator,
)
from dbgpt.core.interface.operator.llm_operator import ( from dbgpt.core.interface.operator.llm_operator import (
BaseLLM, BaseLLM,
BaseLLMOperator,
BaseStreamingLLMOperator,
LLMBranchOperator, LLMBranchOperator,
LLMOperator, RequestBuilderOperator,
RequestBuildOperator,
StreamingLLMOperator,
) )
from dbgpt.core.interface.operator.message_operator import ( from dbgpt.core.interface.operator.message_operator import (
BaseConversationOperator, BaseConversationOperator,
BufferedConversationMapperOperator, BufferedConversationMapperOperator,
ConversationMapperOperator, ConversationMapperOperator,
PostConversationOperator, PreChatHistoryLoadOperator,
PostStreamingConversationOperator, )
PreConversationOperator, from dbgpt.core.interface.operator.prompt_operator import (
DynamicPromptBuilderOperator,
HistoryDynamicPromptBuilderOperator,
HistoryPromptBuilderOperator,
PromptBuilderOperator,
) )
from dbgpt.core.interface.prompt import PromptTemplateOperator
__ALL__ = [ __ALL__ = [
"BaseLLM", "BaseLLM",
"LLMBranchOperator", "LLMBranchOperator",
"LLMOperator", "BaseLLMOperator",
"RequestBuildOperator", "RequestBuilderOperator",
"StreamingLLMOperator", "BaseStreamingLLMOperator",
"BaseConversationOperator", "BaseConversationOperator",
"BufferedConversationMapperOperator", "BufferedConversationMapperOperator",
"ConversationMapperOperator", "ConversationMapperOperator",
"PostConversationOperator", "PreChatHistoryLoadOperator",
"PostStreamingConversationOperator", "PromptBuilderOperator",
"PreConversationOperator", "DynamicPromptBuilderOperator",
"PromptTemplateOperator", "HistoryPromptBuilderOperator",
"HistoryDynamicPromptBuilderOperator",
"ChatComposerInput",
"ChatHistoryPromptComposerOperator",
] ]

View File

@@ -1,13 +1,7 @@
 from dbgpt.model.cluster.client import DefaultLLMClient
-from dbgpt.model.utils.chatgpt_utils import (
-    OpenAILLMClient,
-    OpenAIStreamingOperator,
-    MixinLLMOperator,
-)
+from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient

 __ALL__ = [
     "DefaultLLMClient",
     "OpenAILLMClient",
-    "OpenAIStreamingOperator",
-    "MixinLLMOperator",
 ]

View File

@@ -152,7 +152,7 @@ class LLMModelAdapter(ABC):
return "\n" return "\n"
def transform_model_messages( def transform_model_messages(
self, messages: List[ModelMessage] self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
"""Transform the model messages """Transform the model messages
@@ -174,15 +174,19 @@ class LLMModelAdapter(ABC):
] ]
Args: Args:
messages (List[ModelMessage]): The model messages messages (List[ModelMessage]): The model messages
convert_to_compatible_format (bool, optional): Whether to convert to compatible format. Defaults to False.
Returns: Returns:
List[Dict[str, str]]: The transformed model messages List[Dict[str, str]]: The transformed model messages
""" """
logger.info(f"support_system_message: {self.support_system_message}") logger.info(f"support_system_message: {self.support_system_message}")
if not self.support_system_message: if not self.support_system_message and convert_to_compatible_format:
# We will not do any transform in the future
return self._transform_to_no_system_messages(messages) return self._transform_to_no_system_messages(messages)
else: else:
return ModelMessage.to_openai_messages(messages) return ModelMessage.to_openai_messages(
messages, convert_to_compatible_format=convert_to_compatible_format
)
def _transform_to_no_system_messages( def _transform_to_no_system_messages(
self, messages: List[ModelMessage] self, messages: List[ModelMessage]
@@ -237,6 +241,7 @@ class LLMModelAdapter(ABC):
messages: List[ModelMessage], messages: List[ModelMessage],
tokenizer: Any, tokenizer: Any,
prompt_template: str = None, prompt_template: str = None,
convert_to_compatible_format: bool = False,
) -> Optional[str]: ) -> Optional[str]:
"""Get the string prompt from the given parameters and messages """Get the string prompt from the given parameters and messages
@@ -247,6 +252,7 @@ class LLMModelAdapter(ABC):
messages (List[ModelMessage]): The model messages messages (List[ModelMessage]): The model messages
tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer
prompt_template (str, optional): The prompt template. Defaults to None. prompt_template (str, optional): The prompt template. Defaults to None.
convert_to_compatible_format (bool, optional): Whether to convert to compatible format. Defaults to False.
Returns: Returns:
Optional[str]: The string prompt Optional[str]: The string prompt
@@ -262,6 +268,7 @@ class LLMModelAdapter(ABC):
model_context: Dict, model_context: Dict,
prompt_template: str = None, prompt_template: str = None,
): ):
convert_to_compatible_format = params.get("convert_to_compatible_format")
conv: ConversationAdapter = self.get_default_conv_template( conv: ConversationAdapter = self.get_default_conv_template(
model_name, model_path model_name, model_path
) )
@@ -277,6 +284,72 @@ class LLMModelAdapter(ABC):
return None, None, None return None, None, None
conv = conv.copy() conv = conv.copy()
if convert_to_compatible_format:
# In old version, we will convert the messages to compatible format
conv = self._set_conv_converted_messages(conv, messages)
else:
# In new version, we will use the messages directly
conv = self._set_conv_messages(conv, messages)
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
return new_prompt, conv.stop_str, conv.stop_token_ids
def _set_conv_messages(
self, conv: ConversationAdapter, messages: List[ModelMessage]
) -> ConversationAdapter:
"""Set the messages to the conversation template
Args:
conv (ConversationAdapter): The conversation template
messages (List[ModelMessage]): The model messages
Returns:
ConversationAdapter: The conversation template with messages
"""
system_messages = []
for message in messages:
if isinstance(message, ModelMessage):
role = message.role
content = message.content
elif isinstance(message, dict):
role = message["role"]
content = message["content"]
else:
raise ValueError(f"Invalid message type: {message}")
if role == ModelMessageRoleType.SYSTEM:
system_messages.append(content)
elif role == ModelMessageRoleType.HUMAN:
conv.append_message(conv.roles[0], content)
elif role == ModelMessageRoleType.AI:
conv.append_message(conv.roles[1], content)
else:
raise ValueError(f"Unknown role: {role}")
if len(system_messages) > 1:
raise ValueError(
f"Your system messages have more than one message: {system_messages}"
)
if system_messages:
conv.set_system_message(system_messages[0])
return conv
def _set_conv_converted_messages(
self, conv: ConversationAdapter, messages: List[ModelMessage]
) -> ConversationAdapter:
"""Set the messages to the conversation template
In the old version, we will convert the messages to compatible format.
This method will be deprecated in the future.
Args:
conv (ConversationAdapter): The conversation template
messages (List[ModelMessage]): The model messages
Returns:
ConversationAdapter: The conversation template with messages
"""
system_messages = [] system_messages = []
user_messages = [] user_messages = []
ai_messages = [] ai_messages = []
@@ -295,10 +368,8 @@ class LLMModelAdapter(ABC):
# Support for multiple system messages # Support for multiple system messages
system_messages.append(content) system_messages.append(content)
elif role == ModelMessageRoleType.HUMAN: elif role == ModelMessageRoleType.HUMAN:
# conv.append_message(conv.roles[0], content)
user_messages.append(content) user_messages.append(content)
elif role == ModelMessageRoleType.AI: elif role == ModelMessageRoleType.AI:
# conv.append_message(conv.roles[1], content)
ai_messages.append(content) ai_messages.append(content)
else: else:
raise ValueError(f"Unknown role: {role}") raise ValueError(f"Unknown role: {role}")
@@ -320,10 +391,7 @@ class LLMModelAdapter(ABC):
# TODO join all system messages may not be a good idea # TODO join all system messages may not be a good idea
conv.set_system_message("".join(can_use_systems)) conv.set_system_message("".join(can_use_systems))
# Add a blank message for the assistant. return conv
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
return new_prompt, conv.stop_str, conv.stop_token_ids
def model_adaptation( def model_adaptation(
self, self,
@@ -335,6 +403,15 @@ class LLMModelAdapter(ABC):
     ) -> Tuple[Dict, Dict]:
         """Params adaptation"""
         messages = params.get("messages")
+        convert_to_compatible_format = params.get("convert_to_compatible_format")
+        message_version = params.get("version", "v2").lower()
+        logger.info(f"Message version is {message_version}")
+        if convert_to_compatible_format is None:
+            # Support convert messages to compatible format when message version is v1
+            convert_to_compatible_format = message_version == "v1"
+        # Save to params
+        params["convert_to_compatible_format"] = convert_to_compatible_format
+
         # Some model context to dbgpt server
         model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
         if messages:
@@ -345,7 +422,9 @@ class LLMModelAdapter(ABC):
             ]
             params["messages"] = messages

-        new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
+        new_prompt = self.get_str_prompt(
+            params, messages, tokenizer, prompt_template, convert_to_compatible_format
+        )
         conv_stop_str, conv_stop_token_ids = None, None
         if not new_prompt:
             (
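The version-to-flag rule introduced above is small but easy to miss, so here is a hedged, standalone restatement of just that rule (not of the adapter itself):

def _resolve_compatible_flag(params: dict) -> bool:
    # Mirrors model_adaptation above: an explicit flag wins; otherwise only the
    # legacy "v1" message version turns the compatibility conversion on.
    convert_to_compatible_format = params.get("convert_to_compatible_format")
    if convert_to_compatible_format is None:
        convert_to_compatible_format = params.get("version", "v2").lower() == "v1"
    return convert_to_compatible_format


assert _resolve_compatible_flag({"version": "v1"}) is True
assert _resolve_compatible_flag({"version": "v2"}) is False
assert _resolve_compatible_flag({}) is False
assert _resolve_compatible_flag({"version": "v2", "convert_to_compatible_format": True}) is True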

View File

@@ -87,6 +87,7 @@ class NewHFChatModelAdapter(LLMModelAdapter, ABC):
messages: List[ModelMessage], messages: List[ModelMessage],
tokenizer: Any, tokenizer: Any,
prompt_template: str = None, prompt_template: str = None,
convert_to_compatible_format: bool = False,
) -> Optional[str]: ) -> Optional[str]:
from transformers import AutoTokenizer from transformers import AutoTokenizer
@@ -94,7 +95,7 @@ class NewHFChatModelAdapter(LLMModelAdapter, ABC):
raise ValueError("tokenizer is is None") raise ValueError("tokenizer is is None")
tokenizer: AutoTokenizer = tokenizer tokenizer: AutoTokenizer = tokenizer
messages = self.transform_model_messages(messages) messages = self.transform_model_messages(messages, convert_to_compatible_format)
logger.debug(f"The messages after transform: \n{messages}") logger.debug(f"The messages after transform: \n{messages}")
str_prompt = tokenizer.apply_chat_template( str_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True

View File

@@ -22,6 +22,8 @@ class PromptRequest(BaseModel):
     span_id: str = None
     metrics: bool = False
     """Whether to return metrics of inference"""
+    version: str = "v2"
+    """Message version, default to v2"""


 class EmbeddingsRequest(BaseModel):

View File

@@ -1,20 +1,35 @@
-from typing import AsyncIterator, List
 import asyncio
-from dbgpt.core.interface.llm import LLMClient, ModelRequest, ModelOutput, ModelMetadata
-from dbgpt.model.parameter import WorkerType
+from typing import AsyncIterator, List, Optional
+
+from dbgpt.core.interface.llm import (
+    LLMClient,
+    MessageConverter,
+    ModelMetadata,
+    ModelOutput,
+    ModelRequest,
+)
 from dbgpt.model.cluster.manager_base import WorkerManager
+from dbgpt.model.parameter import WorkerType


 class DefaultLLMClient(LLMClient):
     def __init__(self, worker_manager: WorkerManager):
         self._worker_manager = worker_manager

-    async def generate(self, request: ModelRequest) -> ModelOutput:
+    async def generate(
+        self,
+        request: ModelRequest,
+        message_converter: Optional[MessageConverter] = None,
+    ) -> ModelOutput:
+        request = await self.covert_message(request, message_converter)
         return await self._worker_manager.generate(request.to_dict())

     async def generate_stream(
-        self, request: ModelRequest
+        self,
+        request: ModelRequest,
+        message_converter: Optional[MessageConverter] = None,
     ) -> AsyncIterator[ModelOutput]:
+        request = await self.covert_message(request, message_converter)
         async for output in self._worker_manager.generate_stream(request.to_dict()):
             yield output
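A hedged usage sketch of the updated client API; it needs a running DB-GPT model cluster, the model name is only an example, and the optional message_converter argument may be omitted, in which case covert_message falls back to the default conversion:

from dbgpt.core import ModelMessage, ModelRequest
from dbgpt.model.cluster.client import DefaultLLMClient


async def ask(worker_manager) -> str:
    client = DefaultLLMClient(worker_manager)
    request = ModelRequest(
        model="vicuna-13b-v1.5",  # example model name
        messages=[ModelMessage(role="human", content="Hello")],
    )
    output = await client.generate(request)  # or client.generate(request, converter)
    return output.text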

View File

@@ -8,7 +8,12 @@ import traceback
from dbgpt.configs.model_config import get_device from dbgpt.configs.model_config import get_device
from dbgpt.model.adapter.base import LLMModelAdapter from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.core import ModelOutput, ModelInferenceMetrics, ModelMetadata from dbgpt.core import (
ModelOutput,
ModelInferenceMetrics,
ModelMetadata,
ModelExtraMedata,
)
from dbgpt.model.loader import ModelLoader, _get_model_real_path from dbgpt.model.loader import ModelLoader, _get_model_real_path
from dbgpt.model.parameter import ModelParameters from dbgpt.model.parameter import ModelParameters
from dbgpt.model.cluster.worker_base import ModelWorker from dbgpt.model.cluster.worker_base import ModelWorker
@@ -196,9 +201,13 @@ class DefaultModelWorker(ModelWorker):
raise NotImplementedError raise NotImplementedError
def get_model_metadata(self, params: Dict) -> ModelMetadata: def get_model_metadata(self, params: Dict) -> ModelMetadata:
ext_metadata = ModelExtraMedata(
prompt_sep=self.llm_adapter.get_default_message_separator()
)
return ModelMetadata( return ModelMetadata(
model=self.model_name, model=self.model_name,
context_length=self.context_len, context_length=self.context_len,
ext_metadata=ext_metadata,
) )
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata: async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:

View File

@@ -122,7 +122,7 @@ class RemoteModelWorker(ModelWorker):
             json=params,
             timeout=self.timeout,
         )
-        return ModelMetadata(**response.json())
+        return ModelMetadata.from_dict(response.json())

     def get_model_metadata(self, params: Dict) -> ModelMetadata:
         """Get model metadata"""

View File

@@ -0,0 +1,13 @@
from dbgpt.model.operator.llm_operator import (
LLMOperator,
MixinLLMOperator,
StreamingLLMOperator,
)
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
__ALL__ = [
"MixinLLMOperator",
"LLMOperator",
"StreamingLLMOperator",
"OpenAIStreamingOutputOperator",
]

View File

@@ -0,0 +1,75 @@
import logging
from abc import ABC
from typing import Optional
from dbgpt.component import ComponentType
from dbgpt.core import LLMClient
from dbgpt.core.awel import BaseOperator
from dbgpt.core.operator import BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator
from dbgpt.model.cluster import WorkerManagerFactory
logger = logging.getLogger(__name__)
class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
"""Mixin class for LLM operator.
This class extends BaseOperator by adding LLM capabilities.
"""
def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
super().__init__(default_client)
self._default_llm_client = default_client
@property
def llm_client(self) -> LLMClient:
if not self._llm_client:
worker_manager_factory: WorkerManagerFactory = (
self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY,
WorkerManagerFactory,
default_component=None,
)
)
if worker_manager_factory:
from dbgpt.model.cluster.client import DefaultLLMClient
self._llm_client = DefaultLLMClient(worker_manager_factory.create())
else:
if self._default_llm_client is None:
from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
self._default_llm_client = OpenAILLMClient()
logger.info(
f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
)
self._llm_client = self._default_llm_client
return self._llm_client
class LLMOperator(MixinLLMOperator, BaseLLMOperator):
"""Default LLM operator.
Args:
llm_client (Optional[LLMClient], optional): The LLM client. Defaults to None.
If llm_client is None, we will try to connect to the model serving cluster deploy by DB-GPT,
and if we can't connect to the model serving cluster, we will use the :class:`OpenAILLMClient` as the llm_client.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
BaseLLMOperator.__init__(self, llm_client, **kwargs)
class StreamingLLMOperator(MixinLLMOperator, BaseStreamingLLMOperator):
"""Default streaming LLM operator.
Args:
llm_client (Optional[LLMClient], optional): The LLM client. Defaults to None.
If llm_client is None, we will try to connect to the model serving cluster deployed by DB-GPT,
and if we can't connect to the model serving cluster, we will use the :class:`OpenAILLMClient` as the llm_client.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
BaseStreamingLLMOperator.__init__(self, llm_client, **kwargs)
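A hedged sketch of wiring these defaults into an AWEL DAG. InputOperator, SimpleCallDataInputSource, and the RequestBuilderOperator keyword argument are assumptions about the surrounding AWEL API, and running the DAG still requires a reachable model service (or OpenAI credentials for the fallback client):

import asyncio

from dbgpt.core.awel import DAG, InputOperator, SimpleCallDataInputSource
from dbgpt.core.operator import RequestBuilderOperator
from dbgpt.model.operator import LLMOperator

with DAG("simple_llm_dag"):
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    request_task = RequestBuilderOperator(model="vicuna-13b-v1.5")  # example model
    llm_task = LLMOperator()
    input_task >> request_task >> llm_task

# output = asyncio.run(llm_task.call(call_data="Hello"))
# print(output.text)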

View File

@@ -1,17 +1,18 @@
from typing import AsyncIterator, Dict, List, Union
import logging import logging
from typing import AsyncIterator, Dict, List, Union
from dbgpt.component import ComponentType
from dbgpt.core import ModelOutput
from dbgpt.core.awel import ( from dbgpt.core.awel import (
BranchFunc, BranchFunc,
StreamifyAbsOperator,
BranchOperator, BranchOperator,
MapOperator, MapOperator,
StreamifyAbsOperator,
TransformStreamAbsOperator, TransformStreamAbsOperator,
) )
from dbgpt.component import ComponentType
from dbgpt.core.awel.operator.base import BaseOperator from dbgpt.core.awel.operator.base import BaseOperator
from dbgpt.core import ModelOutput
from dbgpt.model.cluster import WorkerManager, WorkerManagerFactory from dbgpt.model.cluster import WorkerManager, WorkerManagerFactory
from dbgpt.storage.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -13,6 +13,8 @@ def bard_generate_stream(
proxy_api_key = model_params.proxy_api_key proxy_api_key = model_params.proxy_api_key
proxy_server_url = model_params.proxy_server_url proxy_server_url = model_params.proxy_server_url
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
history = [] history = []
messages: List[ModelMessage] = params["messages"] messages: List[ModelMessage] = params["messages"]
for message in messages: for message in messages:
@@ -25,14 +27,15 @@ def bard_generate_stream(
else: else:
pass pass
last_user_input_index = None if convert_to_compatible_format:
for i in range(len(history) - 1, -1, -1): last_user_input_index = None
if history[i]["role"] == "user": for i in range(len(history) - 1, -1, -1):
last_user_input_index = i if history[i]["role"] == "user":
break last_user_input_index = i
if last_user_input_index: break
last_user_input = history.pop(last_user_input_index) if last_user_input_index:
history.append(last_user_input) last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)
msgs = [] msgs = []
for msg in history: for msg in history:

View File

@@ -128,7 +128,10 @@ def _build_request(model: ProxyModel, params):
messages: List[ModelMessage] = params["messages"] messages: List[ModelMessage] = params["messages"]
# history = __convert_2_gpt_messages(messages) # history = __convert_2_gpt_messages(messages)
history = ModelMessage.to_openai_messages(messages) convert_to_compatible_format = params.get("convert_to_compatible_format", False)
history = ModelMessage.to_openai_messages(
messages, convert_to_compatible_format=convert_to_compatible_format
)
payloads = { payloads = {
"temperature": params.get("temperature"), "temperature": params.get("temperature"),
"max_tokens": params.get("max_new_tokens"), "max_tokens": params.get("max_new_tokens"),

View File

@@ -12,7 +12,6 @@ def gemini_generate_stream(
"""Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview""" """Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
model_params = model.get_params() model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}") print(f"Model: {model}, model_params: {model_params}")
global history
# TODO proxy model use unified config? # TODO proxy model use unified config?
proxy_api_key = model_params.proxy_api_key proxy_api_key = model_params.proxy_api_key

View File

@@ -56,6 +56,9 @@ def spark_generate_stream(
del messages[index] del messages[index]
break break
# TODO: Support convert_to_compatible_format config
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
history = [] history = []
# Add history conversation # Add history conversation
for message in messages: for message in messages:

View File

@@ -53,8 +53,12 @@ def tongyi_generate_stream(
     proxyllm_backend = Generation.Models.qwen_turbo  # By Default qwen_turbo

     messages: List[ModelMessage] = params["messages"]
+    convert_to_compatible_format = params.get("convert_to_compatible_format", False)

-    history = __convert_2_tongyi_messages(messages)
+    if convert_to_compatible_format:
+        history = __convert_2_tongyi_messages(messages)
+    else:
+        history = ModelMessage.to_openai_messages(messages)
     gen = Generation()
     res = gen.call(
         proxyllm_backend,

View File

@@ -25,8 +25,29 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
return res.json().get("access_token") return res.json().get("access_token")
def _to_wenxin_messages(messages: List[ModelMessage]):
"""Convert messages to wenxin compatible format
See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
"""
wenxin_messages = []
system_messages = []
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
wenxin_messages.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
system_messages.append(message.content)
elif message.role == ModelMessageRoleType.AI:
wenxin_messages.append({"role": "assistant", "content": message.content})
else:
pass
if len(system_messages) > 1:
raise ValueError("Wenxin only support one system message")
str_system_message = system_messages[0] if len(system_messages) > 0 else ""
return wenxin_messages, str_system_message
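For clarity, a small sketch of what the new helper above produces, assuming it is exercised from the same module (the helper is module-private):

from dbgpt.core.interface.message import ModelMessage

example = [
    ModelMessage(role="system", content="You are a helpful assistant."),
    ModelMessage(role="human", content="Hi"),
    ModelMessage(role="ai", content="Hello!"),
    ModelMessage(role="human", content="What is ERNIE Bot?"),
]
history, system_message = _to_wenxin_messages(example)
assert system_message == "You are a helpful assistant."
assert history == [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is ERNIE Bot?"},
]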
 def __convert_2_wenxin_messages(messages: List[ModelMessage]):
-    chat_round = 0
     wenxin_messages = []

     last_usr_message = ""
@@ -57,7 +78,8 @@ def __convert_2_wenxin_messages(messages: List[ModelMessage]):
         last_message = messages[-1]
         end_message = last_message.content
         wenxin_messages.append({"role": "user", "content": end_message})
-    return wenxin_messages, system_messages
+    str_system_message = system_messages[0] if len(system_messages) > 0 else ""
+    return wenxin_messages, str_system_message


 def wenxin_generate_stream(
@@ -87,13 +109,14 @@ def wenxin_generate_stream(
     messages: List[ModelMessage] = params["messages"]

-    history, systems = __convert_2_wenxin_messages(messages)
-    system = ""
-    if systems and len(systems) > 0:
-        system = systems[0]
+    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
+    if convert_to_compatible_format:
+        history, system_message = __convert_2_wenxin_messages(messages)
+    else:
+        history, system_message = _to_wenxin_messages(messages)
     payload = {
         "messages": history,
-        "system": system,
+        "system": system_message,
         "temperature": params.get("temperature"),
         "stream": True,
     }

View File

@@ -57,6 +57,10 @@ def zhipu_generate_stream(
zhipuai.api_key = proxy_api_key zhipuai.api_key = proxy_api_key
messages: List[ModelMessage] = params["messages"] messages: List[ModelMessage] = params["messages"]
# TODO: Support convert_to_compatible_format config, zhipu not support system message
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
history, systems = __convert_2_zhipu_messages(messages) history, systems = __convert_2_zhipu_messages(messages)
res = zhipuai.model_api.sse_invoke( res = zhipuai.model_api.sse_invoke(
model=proxyllm_backend, model=proxyllm_backend,

View File

@@ -20,8 +20,13 @@ from typing import (
from dbgpt.component import ComponentType from dbgpt.component import ComponentType
from dbgpt.core.operator import BaseLLM from dbgpt.core.operator import BaseLLM
from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator
from dbgpt.core.interface.llm import ModelMetadata, LLMClient from dbgpt.core.interface.llm import (
from dbgpt.core.interface.llm import ModelOutput, ModelRequest ModelOutput,
ModelRequest,
ModelMetadata,
LLMClient,
MessageConverter,
)
from dbgpt.model.cluster.client import DefaultLLMClient from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt._private.pydantic import model_to_json from dbgpt._private.pydantic import model_to_json
@@ -175,7 +180,13 @@ class OpenAILLMClient(LLMClient):
payload["max_tokens"] = request.max_new_tokens payload["max_tokens"] = request.max_new_tokens
return payload return payload
async def generate(self, request: ModelRequest) -> ModelOutput: async def generate(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelOutput:
request = await self.covert_message(request, message_converter)
messages = request.to_openai_messages() messages = request.to_openai_messages()
payload = self._build_request(request) payload = self._build_request(request)
logger.info( logger.info(
@@ -195,8 +206,11 @@ class OpenAILLMClient(LLMClient):
) )
async def generate_stream( async def generate_stream(
self, request: ModelRequest self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]: ) -> AsyncIterator[ModelOutput]:
request = await self.covert_message(request, message_converter)
messages = request.to_openai_messages() messages = request.to_openai_messages()
payload = self._build_request(request, True) payload = self._build_request(request, True)
logger.info( logger.info(
@@ -247,7 +261,7 @@ class OpenAILLMClient(LLMClient):
return self._tokenizer.count_token(prompt, model) return self._tokenizer.count_token(prompt, model)
class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]): class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]):
"""Transform ModelOutput to openai stream format.""" """Transform ModelOutput to openai stream format."""
async def transform_stream( async def transform_stream(
@@ -266,40 +280,6 @@ class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
yield output yield output
class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
"""Mixin class for LLM operator.
This class extends BaseOperator by adding LLM capabilities.
"""
def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
super().__init__(default_client)
self._default_llm_client = default_client
@property
def llm_client(self) -> LLMClient:
if not self._llm_client:
worker_manager_factory: WorkerManagerFactory = (
self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY,
WorkerManagerFactory,
default_component=None,
)
)
if worker_manager_factory:
self._llm_client = DefaultLLMClient(worker_manager_factory.create())
else:
if self._default_llm_client is None:
from dbgpt.model import OpenAILLMClient
self._default_llm_client = OpenAILLMClient()
logger.info(
f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
)
self._llm_client = self._default_llm_client
return self._llm_client
async def _to_openai_stream( async def _to_openai_stream(
output_iter: AsyncIterator[ModelOutput], output_iter: AsyncIterator[ModelOutput],
model: Optional[str] = None, model: Optional[str] = None,

View File

@@ -0,0 +1,71 @@
import logging
from typing import Any, Optional
from dbgpt.core import (
InMemoryStorage,
MessageStorageItem,
StorageConversation,
StorageInterface,
)
from dbgpt.core.operator import PreChatHistoryLoadOperator
from .serve import Serve
logger = logging.getLogger(__name__)
class ServePreChatHistoryLoadOperator(PreChatHistoryLoadOperator):
"""Pre-chat history load operator for DB-GPT serve component
Args:
storage (Optional[StorageInterface[StorageConversation, Any]], optional):
The conversation storage, store the conversation items. Defaults to None.
message_storage (Optional[StorageInterface[MessageStorageItem, Any]], optional):
The message storage, store the messages of one conversation. Defaults to None.
If the storage or message_storage is not None, the storage or message_storage will be used first.
Otherwise, we will try to get the current serve component from the system app,
and use the storage or message_storage of the serve component.
If we can't get the storage, we will use the InMemoryStorage as the storage or message_storage.
"""
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
):
super().__init__(storage, message_storage, **kwargs)
@property
def storage(self):
if self._storage:
return self._storage
storage = Serve.call_on_current_serve(
self.system_app, lambda serve: serve.conv_storage
)
if not storage:
logger.warning(
"Can't get the conversation storage from current serve component, "
"use the InMemoryStorage as the conversation storage."
)
self._storage = InMemoryStorage()
return self._storage
return storage
@property
def message_storage(self):
if self._message_storage:
return self._message_storage
storage = Serve.call_on_current_serve(
self.system_app,
lambda serve: serve.message_storage,
)
if not storage:
logger.warning(
"Can't get the message storage from current serve component, "
"use the InMemoryStorage as the message storage."
)
self._message_storage = InMemoryStorage()
return self._message_storage
return storage

View File

@@ -1,6 +1,6 @@
import logging import logging
from abc import ABC from abc import ABC
from typing import List, Optional, Union from typing import Any, Callable, List, Optional, Union
from sqlalchemy import URL from sqlalchemy import URL
@@ -60,3 +60,44 @@ class BaseServe(BaseComponent, ABC):
finally: finally:
self._not_create_table = False self._not_create_table = False
return init_db return init_db
@classmethod
def get_current_serve(cls, system_app: SystemApp) -> Optional["BaseServe"]:
"""Get the current serve component.
None if the serve component does not exist.
Args:
system_app (SystemApp): The system app
Returns:
Optional[BaseServe]: The current serve component.
"""
return system_app.get_component(cls.name, cls, default_component=None)
@classmethod
def call_on_current_serve(
cls,
system_app: SystemApp,
func: Callable[["BaseServe"], Optional[Any]],
default_value: Optional[Any] = None,
) -> Optional[Any]:
"""Call the function on the current serve component.
Return default_value if the serve component does not exist or the function returns None.
Args:
system_app (SystemApp): The system app
func (Callable[[BaseServe], Any]): The function to call
default_value (Optional[Any], optional): The default value. Defaults to None.
Returns:
Optional[Any]: The result of the function
"""
serve = cls.get_current_serve(system_app)
if not serve:
return default_value
result = func(serve)
if not result:
result = default_value
return result
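A hedged sketch of the intended call pattern, mirroring the conversation-storage lookup used by ServePreChatHistoryLoadOperator earlier in this change; the Serve import path is an assumption:

from typing import Any

from dbgpt.component import SystemApp
from dbgpt.core import InMemoryStorage


def get_conv_storage(system_app: SystemApp) -> Any:
    # The import path below is assumed; any concrete BaseServe subclass works.
    from dbgpt.serve.conversation.serve import Serve

    return Serve.call_on_current_serve(
        system_app,
        lambda serve: serve.conv_storage,
        default_value=InMemoryStorage(),
    )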

View File

@@ -0,0 +1,87 @@
from typing import Any, get_type_hints, get_origin, get_args
from functools import wraps
import inspect
import asyncio
def _is_instance_of_generic_type(obj, generic_type):
"""Check if an object is an instance of a generic type."""
if generic_type is Any:
return True # Any type is compatible with any object
origin = get_origin(generic_type)
if origin is None:
return isinstance(obj, generic_type) # Handle non-generic types
args = get_args(generic_type)
if not args:
return isinstance(obj, origin)
# Check if object matches the generic origin (like list, dict)
if not isinstance(obj, origin):
return False
# For each item in the object, check if it matches the corresponding type argument
for sub_obj, arg in zip(obj, args):
# Skip check if the type argument is Any
if arg is not Any and not isinstance(sub_obj, arg):
return False
return True
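To make the matching rules concrete, a few checks against the helper above (for generic containers, elements are compared pairwise against the type arguments, so only the leading elements are checked):

from typing import Any, List

assert _is_instance_of_generic_type("anything", Any)
assert _is_instance_of_generic_type([1, 2, 3], List[int])
assert not _is_instance_of_generic_type(1, List[int])
assert _is_instance_of_generic_type(1, int)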
def _sort_args(func, args, kwargs):
sig = inspect.signature(func)
type_hints = get_type_hints(func)
arg_types = [
type_hints[param_name]
for param_name in sig.parameters
if param_name != "return" and param_name != "self"
]
if "self" in sig.parameters:
self_arg = [args[0]]
other_args = args[1:]
else:
self_arg = []
other_args = args
    # Re-order positional arguments so each one lines up with the first annotated
    # parameter whose type hint it satisfies; an argument matching no hint raises
    # StopIteration from ``next``.
    sorted_args = sorted(
        other_args,
        key=lambda x: next(
            i for i, t in enumerate(arg_types) if _is_instance_of_generic_type(x, t)
        ),
    )
    return (*self_arg, *sorted_args), kwargs
def rearrange_args_by_type(func):
"""Decorator to rearrange the arguments of a function by type.
Examples:
.. code-block:: python
from dbgpt.util.function_utils import rearrange_args_by_type
@rearrange_args_by_type
def sync_regular_function(a: int, b: str, c: float):
return a, b, c
assert instance.sync_class_method(1, "b", 3.0) == (1, "b", 3.0)
assert instance.sync_class_method("b", 3.0, 1) == (1, "b", 3.0)
"""
@wraps(func)
def sync_wrapper(*args, **kwargs):
sorted_args, sorted_kwargs = _sort_args(func, args, kwargs)
return func(*sorted_args, **sorted_kwargs)
@wraps(func)
async def async_wrapper(*args, **kwargs):
sorted_args, sorted_kwargs = _sort_args(func, args, kwargs)
return await func(*sorted_args, **sorted_kwargs)
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
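
The tests later in this commit exercise exactly this pattern: AWEL join callbacks can receive upstream results in whatever order the runtime delivers them, provided the parameter types differ. A hedged sketch (names are illustrative, not from this module):

from typing import Any, Dict, List


@rearrange_args_by_type
async def compose_prompt(template: str, history: List[Any], prompt_dict: Dict[str, Any]):
    # Arguments are re-bound to the matching annotations, so upstream order is irrelevant.
    return template.format(**prompt_dict), len(history)


# Both awaits yield ("hi Bob", 2):
#   await compose_prompt("hi {name}", [1, 2], {"name": "Bob"})
#   await compose_prompt({"name": "Bob"}, "hi {name}", [1, 2])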

View File

@@ -10,11 +10,12 @@ needed), or truncating them so that they fit in a single LLM call.
 import logging
 from string import Formatter
-from typing import Callable, List, Optional, Sequence
+from typing import Callable, List, Optional, Sequence, Set
 from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
 from dbgpt.util.global_helper import globals_helper
+from dbgpt.core.interface.prompt import get_template_vars
 from dbgpt._private.llm_metadata import LLMMetadata
 from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter
@@ -230,15 +231,3 @@
     all_kwargs = {**partial_kargs, **empty_kwargs}
     prompt = template.format(**all_kwargs)
     return prompt
def get_template_vars(template_str: str) -> List[str]:
"""Get template variables from a template string."""
variables = []
formatter = Formatter()
for _, variable_name, _, _ in formatter.parse(template_str):
if variable_name:
variables.append(variable_name)
return variables
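
The helper itself moves to dbgpt.core.interface.prompt (see the import added above); its behavior is simply to collect the replacement fields that string.Formatter reports, for example:

assert get_template_vars(
    "Write a SQL of {dialect} to query all data of {table_name}."
) == ["dialect", "table_name"]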

View File

@@ -0,0 +1,120 @@
from typing import List, Dict, Any
import pytest
from dbgpt.util.function_utils import rearrange_args_by_type
class ChatPromptTemplate:
pass
class BaseMessage:
pass
class ModelMessage:
pass
class DummyClass:
@rearrange_args_by_type
async def class_method(self, a: int, b: str, c: float):
return a, b, c
@rearrange_args_by_type
async def merge_history(
self,
prompt: ChatPromptTemplate,
history: List[BaseMessage],
prompt_dict: Dict[str, Any],
) -> List[ModelMessage]:
return [type(prompt), type(history), type(prompt_dict)]
@rearrange_args_by_type
def sync_class_method(self, a: int, b: str, c: float):
return a, b, c
@rearrange_args_by_type
def sync_regular_function(a: int, b: str, c: float):
return a, b, c
@rearrange_args_by_type
async def regular_function(a: int, b: str, c: float):
return a, b, c
@pytest.mark.asyncio
async def test_class_method_correct_order():
instance = DummyClass()
result = await instance.class_method(1, "b", 3.0)
assert result == (1, "b", 3.0), "Class method failed with correct order"
@pytest.mark.asyncio
async def test_class_method_incorrect_order():
instance = DummyClass()
result = await instance.class_method("b", 3.0, 1)
assert result == (1, "b", 3.0), "Class method failed with incorrect order"
@pytest.mark.asyncio
async def test_regular_function_correct_order():
result = await regular_function(1, "b", 3.0)
assert result == (1, "b", 3.0), "Regular function failed with correct order"
@pytest.mark.asyncio
async def test_regular_function_incorrect_order():
result = await regular_function("b", 3.0, 1)
assert result == (1, "b", 3.0), "Regular function failed with incorrect order"
@pytest.mark.asyncio
async def test_merge_history_correct_order():
instance = DummyClass()
result = await instance.merge_history(
ChatPromptTemplate(), [BaseMessage()], {"key": "value"}
)
assert result == [ChatPromptTemplate, list, dict], "Failed with correct order"
@pytest.mark.asyncio
async def test_merge_history_incorrect_order_1():
instance = DummyClass()
result = await instance.merge_history(
[BaseMessage()], ChatPromptTemplate(), {"key": "value"}
)
assert result == [ChatPromptTemplate, list, dict], "Failed with incorrect order 1"
@pytest.mark.asyncio
async def test_merge_history_incorrect_order_2():
instance = DummyClass()
result = await instance.merge_history(
{"key": "value"}, [BaseMessage()], ChatPromptTemplate()
)
assert result == [ChatPromptTemplate, list, dict], "Failed with incorrect order 2"
def test_sync_class_method_correct_order():
instance = DummyClass()
result = instance.sync_class_method(1, "b", 3.0)
assert result == (1, "b", 3.0), "Sync class method failed with correct order"
def test_sync_class_method_incorrect_order():
instance = DummyClass()
result = instance.sync_class_method("b", 3.0, 1)
assert result == (1, "b", 3.0), "Sync class method failed with incorrect order"
def test_sync_regular_function_correct_order():
result = sync_regular_function(1, "b", 3.0)
assert result == (1, "b", 3.0), "Sync regular function failed with correct order"
def test_sync_regular_function_incorrect_order():
result = sync_regular_function("b", 3.0, 1)
assert result == (1, "b", 3.0), "Sync regular function failed with incorrect order"
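
One behavior the tests above do not pin down: parameters that share a type hint cannot be re-ordered, since every matching argument keys to the first parameter of that type and Python's stable sort keeps the original call order. An illustrative check (not part of this test file):

@rearrange_args_by_type
def subtract(a: int, b: int) -> int:
    return a - b


# Both ints map to the first int parameter; stable sorting preserves call order.
assert subtract(5, 3) == 2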

View File

@@ -26,7 +26,7 @@
 curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \
 -H "Content-Type: application/json" -d '{
     "command": "dbgpt_awel_data_analyst_code_fix",
-    "model": "gpt-3.5-turbo",
+    "model": "'"$MODEL"'",
     "stream": false,
     "context": {
         "conv_uid": "uuid_conv_copilot_1234",
@@ -37,43 +37,55 @@
""" """
import logging import logging
import os
from functools import cache from functools import cache
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel, Field from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import ( from dbgpt.core import (
InMemoryStorage, ChatPromptTemplate,
LLMClient, HumanPromptTemplate,
MessageStorageItem, MessagesPlaceholder,
ModelMessage, ModelMessage,
ModelMessageRoleType, ModelRequest,
ModelRequestContext,
PromptManager, PromptManager,
PromptTemplate, PromptTemplate,
StorageConversation, SystemPromptTemplate,
StorageInterface,
) )
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.core.operator import ( from dbgpt.core.operator import (
BufferedConversationMapperOperator, BufferedConversationMapperOperator,
HistoryDynamicPromptBuilderOperator,
LLMBranchOperator, LLMBranchOperator,
RequestBuilderOperator,
)
from dbgpt.model.operator import (
LLMOperator, LLMOperator,
PostConversationOperator, OpenAIStreamingOutputOperator,
PostStreamingConversationOperator,
PreConversationOperator,
RequestBuildOperator,
StreamingLLMOperator, StreamingLLMOperator,
) )
from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator from dbgpt.serve.conversation.operator import ServePreChatHistoryLoadOperator
from dbgpt.util.utils import colored
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PROMPT_LANG_ZH = "zh"
PROMPT_LANG_EN = "en"
CODE_DEFAULT = "dbgpt_awel_data_analyst_code_default"
CODE_FIX = "dbgpt_awel_data_analyst_code_fix" CODE_FIX = "dbgpt_awel_data_analyst_code_fix"
CODE_PERF = "dbgpt_awel_data_analyst_code_perf" CODE_PERF = "dbgpt_awel_data_analyst_code_perf"
CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain" CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain"
CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment" CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment"
CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate" CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate"
CODE_DEFAULT_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师。
你可以根据最佳实践来优化代码, 也可以对代码进行修复, 解释, 添加注释, 以及将代码翻译成其他语言。"""
CODE_DEFAULT_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst.
You can optimize the code according to best practices, or fix, explain, add comments to the code,
and you can also translate the code into other languages.
"""
CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师, CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,
这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。""" 这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。"""
CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst, CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
@@ -126,7 +138,9 @@ class ReqContext(BaseModel):
 class TriggerReqBody(BaseModel):
     messages: str = Field(..., description="User input messages")
-    command: Optional[str] = Field(default="fix", description="Command name")
+    command: Optional[str] = Field(
+        default=None, description="Command name, None if common chat"
+    )
     model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name")
     stream: Optional[bool] = Field(default=False, description="Whether return stream")
     language: Optional[str] = Field(default="hive", description="Language")
@@ -140,109 +154,89 @@ class TriggerReqBody(BaseModel):
@cache @cache
def load_or_save_prompt_template(pm: PromptManager): def load_or_save_prompt_template(pm: PromptManager):
ext_params = { zh_ext_params = {
"chat_scene": "chat_with_code", "chat_scene": "chat_with_code",
"sub_chat_scene": "data_analyst", "sub_chat_scene": "data_analyst",
"prompt_type": "common", "prompt_type": "common",
"prompt_language": PROMPT_LANG_ZH,
} }
en_ext_params = {
"chat_scene": "chat_with_code",
"sub_chat_scene": "data_analyst",
"prompt_type": "common",
"prompt_language": PROMPT_LANG_EN,
}
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_ZH),
input_variables=["language"], prompt_name=CODE_DEFAULT,
template=CODE_FIX_TEMPLATE_ZH, **zh_ext_params,
), )
pm.query_or_save(
PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_EN),
prompt_name=CODE_DEFAULT,
**en_ext_params,
)
pm.query_or_save(
PromptTemplate.from_template(CODE_FIX_TEMPLATE_ZH),
prompt_name=CODE_FIX, prompt_name=CODE_FIX,
prompt_language="zh", **zh_ext_params,
**ext_params,
) )
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_FIX_TEMPLATE_EN),
input_variables=["language"],
template=CODE_FIX_TEMPLATE_EN,
),
prompt_name=CODE_FIX, prompt_name=CODE_FIX,
prompt_language="en", **en_ext_params,
**ext_params,
) )
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_PERF_TEMPLATE_ZH),
input_variables=["language"],
template=CODE_PERF_TEMPLATE_ZH,
),
prompt_name=CODE_PERF, prompt_name=CODE_PERF,
prompt_language="zh", **zh_ext_params,
**ext_params,
) )
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_PERF_TEMPLATE_EN),
input_variables=["language"],
template=CODE_PERF_TEMPLATE_EN,
),
prompt_name=CODE_PERF, prompt_name=CODE_PERF,
prompt_language="en", **en_ext_params,
**ext_params,
) )
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_ZH),
input_variables=["language"],
template=CODE_EXPLAIN_TEMPLATE_ZH,
),
prompt_name=CODE_EXPLAIN, prompt_name=CODE_EXPLAIN,
prompt_language="zh", **zh_ext_params,
**ext_params,
) )
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_EN),
input_variables=["language"],
template=CODE_EXPLAIN_TEMPLATE_EN,
),
prompt_name=CODE_EXPLAIN, prompt_name=CODE_EXPLAIN,
prompt_language="en", **en_ext_params,
**ext_params,
) )
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_ZH),
input_variables=["language"],
template=CODE_COMMENT_TEMPLATE_ZH,
),
prompt_name=CODE_COMMENT, prompt_name=CODE_COMMENT,
prompt_language="zh", **zh_ext_params,
**ext_params,
) )
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_EN),
input_variables=["language"],
template=CODE_COMMENT_TEMPLATE_EN,
),
prompt_name=CODE_COMMENT, prompt_name=CODE_COMMENT,
prompt_language="en", **en_ext_params,
**ext_params,
) )
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_ZH),
input_variables=["source_language", "target_language"],
template=CODE_TRANSLATE_TEMPLATE_ZH,
),
prompt_name=CODE_TRANSLATE, prompt_name=CODE_TRANSLATE,
prompt_language="zh", **zh_ext_params,
**ext_params,
) )
pm.query_or_save( pm.query_or_save(
PromptTemplate( PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_EN),
input_variables=["source_language", "target_language"],
template=CODE_TRANSLATE_TEMPLATE_EN,
),
prompt_name=CODE_TRANSLATE, prompt_name=CODE_TRANSLATE,
prompt_language="en", **en_ext_params,
**ext_params,
) )
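
The net effect of this refactor is that PromptTemplate.from_template infers the input variables from the template text instead of requiring an explicit input_variables list, while formatting keeps working as before, for example:

fix_prompt = PromptTemplate.from_template(CODE_FIX_TEMPLATE_EN)
# {language} is discovered from the template string; no input_variables needed.
system_text = fix_prompt.format(language="hive")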
-class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
+class PromptTemplateBuilderOperator(MapOperator[TriggerReqBody, ChatPromptTemplate]):
+    """Build prompt template for chat with code."""

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._default_prompt_manager = PromptManager()

-    async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
+    async def map(self, input_value: TriggerReqBody) -> ChatPromptTemplate:
         from dbgpt.serve.prompt.serve import SERVE_APP_NAME as PROMPT_SERVE_APP_NAME
         from dbgpt.serve.prompt.serve import Serve as PromptServe
@@ -256,7 +250,24 @@ class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
         load_or_save_prompt_template(pm)
         user_language = self.system_app.config.get_current_lang(default="en")
if not input_value.command:
# No command, just chat, not include system prompt.
default_prompt_list = pm.prefer_query(
CODE_DEFAULT, prefer_prompt_language=user_language
)
default_prompt_template = (
default_prompt_list[0].to_prompt_template().template
)
prompt = ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(default_prompt_template),
MessagesPlaceholder(variable_name="chat_history"),
HumanPromptTemplate.from_template("{user_input}"),
]
)
return prompt
# Query prompt template from prompt manager by command name
prompt_list = pm.prefer_query( prompt_list = pm.prefer_query(
input_value.command, prefer_prompt_language=user_language input_value.command, prefer_prompt_language=user_language
) )
@@ -264,109 +275,38 @@ class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
error_msg = f"Prompt not found for command {input_value.command}, user_language: {user_language}" error_msg = f"Prompt not found for command {input_value.command}, user_language: {user_language}"
logger.error(error_msg) logger.error(error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
prompt = prompt_list[0].to_prompt_template() prompt_template = prompt_list[0].to_prompt_template()
if input_value.command == CODE_TRANSLATE:
format_params = {
"source_language": input_value.language,
"target_language": input_value.target_language,
}
else:
format_params = {"language": input_value.language}
system_message = prompt.format(**format_params) return ChatPromptTemplate(
messages = [ messages=[
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_message), SystemPromptTemplate.from_template(prompt_template.template),
ModelMessage(role=ModelMessageRoleType.HUMAN, content=input_value.messages), MessagesPlaceholder(variable_name="chat_history"),
] HumanPromptTemplate.from_template("{user_input}"),
context = input_value.context.dict() if input_value.context else {} ]
return {
"messages": messages,
"stream": input_value.stream,
"model": input_value.model,
"context": context,
}
class MyConversationOperator(PreConversationOperator):
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
):
super().__init__(storage, message_storage, **kwargs)
def _get_conversion_serve(self):
from dbgpt.serve.conversation.serve import (
SERVE_APP_NAME as CONVERSATION_SERVE_APP_NAME,
) )
from dbgpt.serve.conversation.serve import Serve as ConversationServe
conversation_serve: ConversationServe = self.system_app.get_component(
CONVERSATION_SERVE_APP_NAME, ConversationServe, default_component=None
)
return conversation_serve
@property
def storage(self):
if self._storage:
return self._storage
conversation_serve = self._get_conversion_serve()
if conversation_serve:
return conversation_serve.conv_storage
else:
logger.info("Conversation storage not found, use InMemoryStorage default")
self._storage = InMemoryStorage()
return self._storage
@property
def message_storage(self):
if self._message_storage:
return self._message_storage
conversation_serve = self._get_conversion_serve()
if conversation_serve:
return conversation_serve.message_storage
else:
logger.info("Message storage not found, use InMemoryStorage default")
self._message_storage = InMemoryStorage()
return self._message_storage
class MyLLMOperator(MixinLLMOperator, LLMOperator): def parse_prompt_args(req: TriggerReqBody) -> Dict[str, Any]:
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): prompt_args = {"user_input": req.messages}
super().__init__(llm_client) if not req.command:
LLMOperator.__init__(self, llm_client, **kwargs) return prompt_args
if req.command == CODE_TRANSLATE:
prompt_args["source_language"] = req.language
prompt_args["target_language"] = req.target_language
else:
prompt_args["language"] = req.language
return prompt_args
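
parse_prompt_args picks the template variables the selected prompt expects: the source/target pair for a translate request, the code language for the other commands, and only the user input for plain chat. An illustrative call (values chosen here, not from the file):

req = TriggerReqBody(
    messages="SELECT 1",
    command=CODE_TRANSLATE,
    language="hive",
    target_language="mysql",
)
assert parse_prompt_args(req) == {
    "user_input": "SELECT 1",
    "source_language": "hive",
    "target_language": "mysql",
}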
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator): async def build_model_request(
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): messages: List[ModelMessage], req_body: TriggerReqBody
super().__init__(llm_client) ) -> ModelRequest:
StreamingLLMOperator.__init__(self, llm_client, **kwargs) return ModelRequest.build_request(
model=req_body.model,
messages=messages,
def history_message_mapper( context=req_body.context,
messages_by_round: List[List[ModelMessage]], stream=req_body.stream,
) -> List[ModelMessage]: )
"""Mapper for history conversation.
If there are multi system messages, just keep the first system message.
"""
has_system_message = False
mapper_messages = []
for messages in messages_by_round:
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
if has_system_message:
continue
else:
mapper_messages.append(message)
has_system_message = True
else:
mapper_messages.append(message)
print("history_message_mapper start:" + "=" * 70)
print(colored(ModelMessage.get_printable_message(mapper_messages), "green"))
print("history_message_mapper end:" + "=" * 72)
return mapper_messages
with DAG("dbgpt_awel_data_analyst_assistant") as dag: with DAG("dbgpt_awel_data_analyst_assistant") as dag:
@@ -377,57 +317,59 @@ with DAG("dbgpt_awel_data_analyst_assistant") as dag:
streaming_predict_func=lambda x: x.stream, streaming_predict_func=lambda x: x.stream,
) )
copilot_task = CopilotOperator() prompt_template_load_task = PromptTemplateBuilderOperator()
request_handle_task = RequestBuildOperator() request_handle_task = RequestBuilderOperator()
# Pre-process conversation # Load and store chat history
pre_conversation_task = MyConversationOperator() chat_history_load_task = ServePreChatHistoryLoadOperator()
# Keep last k round conversation. last_k_round = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_LAST_K_ROUND", 5))
history_conversation_task = BufferedConversationMapperOperator( # History transform task, here we keep last k round messages
last_k_round=5, message_mapper=history_message_mapper history_transform_task = BufferedConversationMapperOperator(
last_k_round=last_k_round
)
history_prompt_build_task = HistoryDynamicPromptBuilderOperator(
history_key="chat_history"
) )
# Save conversation to storage. model_request_build_task = JoinOperator(build_model_request)
post_conversation_task = PostConversationOperator()
# Save streaming conversation to storage.
post_streaming_conversation_task = PostStreamingConversationOperator()
# Use LLMOperator to generate response. # Use BaseLLMOperator to generate response.
llm_task = MyLLMOperator(task_name="llm_task") llm_task = LLMOperator(task_name="llm_task")
streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task") streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
branch_task = LLMBranchOperator( branch_task = LLMBranchOperator(
stream_task_name="streaming_llm_task", no_stream_task_name="llm_task" stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
) )
model_parse_task = MapOperator(lambda out: out.to_dict()) model_parse_task = MapOperator(lambda out: out.to_dict())
openai_format_stream_task = OpenAIStreamingOperator() openai_format_stream_task = OpenAIStreamingOutputOperator()
result_join_task = JoinOperator( result_join_task = JoinOperator(
combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
) )
trigger >> prompt_template_load_task >> history_prompt_build_task
( (
trigger trigger
>> copilot_task >> MapOperator(
>> request_handle_task lambda req: ModelRequestContext(
>> pre_conversation_task conv_uid=req.context.conv_uid,
>> history_conversation_task stream=req.stream,
>> branch_task chat_mode=req.context.chat_mode,
)
)
>> chat_history_load_task
>> history_transform_task
>> history_prompt_build_task
) )
trigger >> MapOperator(parse_prompt_args) >> history_prompt_build_task
history_prompt_build_task >> model_request_build_task
trigger >> model_request_build_task
model_request_build_task >> branch_task
# The branch of no streaming response. # The branch of no streaming response.
( (branch_task >> llm_task >> model_parse_task >> result_join_task)
branch_task
>> llm_task
>> post_conversation_task
>> model_parse_task
>> result_join_task
)
# The branch of streaming response. # The branch of streaming response.
( (branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task)
branch_task
>> streaming_llm_task
>> post_streaming_conversation_task
>> openai_format_stream_task
>> result_join_task
)
if __name__ == "__main__": if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode: if dag.leaf_nodes[0].dev_mode:

View File

@@ -12,7 +12,7 @@
 # Fist round
 curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
 -H "Content-Type: application/json" -d '{
-    "model": "gpt-3.5-turbo",
+    "model": "'"$MODEL"'",
     "context": {
         "conv_uid": "uuid_conv_1234"
     },
@@ -22,7 +22,7 @@
 # Second round
 curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
 -H "Content-Type: application/json" -d '{
-    "model": "gpt-3.5-turbo",
+    "model": "'"$MODEL"'",
     "context": {
         "conv_uid": "uuid_conv_1234"
     },
@@ -34,7 +34,7 @@
 curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
 -H "Content-Type: application/json" -d '{
-    "model": "gpt-3.5-turbo",
+    "model": "'"$MODEL"'",
     "context": {
         "conv_uid": "uuid_conv_stream_1234"
     },
@@ -45,7 +45,7 @@
 # Second round
 curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
 -H "Content-Type: application/json" -d '{
-    "model": "gpt-3.5-turbo",
+    "model": "'"$MODEL"'",
     "context": {
         "conv_uid": "uuid_conv_stream_1234"
     },
@@ -59,19 +59,27 @@ import logging
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from dbgpt._private.pydantic import BaseModel, Field from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import InMemoryStorage, LLMClient from dbgpt.core import (
ChatPromptTemplate,
HumanPromptTemplate,
InMemoryStorage,
MessagesPlaceholder,
ModelMessage,
ModelRequest,
ModelRequestContext,
SystemPromptTemplate,
)
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.core.operator import ( from dbgpt.core.operator import (
BufferedConversationMapperOperator, ChatComposerInput,
ChatHistoryPromptComposerOperator,
LLMBranchOperator, LLMBranchOperator,
)
from dbgpt.model.operator import (
LLMOperator, LLMOperator,
PostConversationOperator, OpenAIStreamingOutputOperator,
PostStreamingConversationOperator,
PreConversationOperator,
RequestBuildOperator,
StreamingLLMOperator, StreamingLLMOperator,
) )
from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -100,16 +108,15 @@ class TriggerReqBody(BaseModel):
) )
class MyLLMOperator(MixinLLMOperator, LLMOperator): async def build_model_request(
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): messages: List[ModelMessage], req_body: TriggerReqBody
super().__init__(llm_client) ) -> ModelRequest:
LLMOperator.__init__(self, llm_client, **kwargs) return ModelRequest.build_request(
model=req_body.model,
messages=messages,
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator): context=req_body.context,
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): stream=req_body.stream,
super().__init__(llm_client) )
StreamingLLMOperator.__init__(self, llm_client, **kwargs)
with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag: with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
@@ -120,56 +127,53 @@ with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
request_body=TriggerReqBody, request_body=TriggerReqBody,
streaming_predict_func=lambda req: req.stream, streaming_predict_func=lambda req: req.stream,
) )
# Transform request body to model request. prompt = ChatPromptTemplate(
request_handle_task = RequestBuildOperator() messages=[
# Pre-process conversation, use InMemoryStorage to store conversation. SystemPromptTemplate.from_template("You are a helpful chatbot."),
pre_conversation_task = PreConversationOperator( MessagesPlaceholder(variable_name="chat_history"),
storage=InMemoryStorage(), message_storage=InMemoryStorage() HumanPromptTemplate.from_template("{user_input}"),
]
) )
# Keep last k round conversation.
history_conversation_task = BufferedConversationMapperOperator(last_k_round=5)
# Save conversation to storage. composer_operator = ChatHistoryPromptComposerOperator(
post_conversation_task = PostConversationOperator() prompt_template=prompt,
# Save streaming conversation to storage. last_k_round=5,
post_streaming_conversation_task = PostStreamingConversationOperator() storage=InMemoryStorage(),
message_storage=InMemoryStorage(),
)
# Use LLMOperator to generate response. # Use BaseLLMOperator to generate response.
llm_task = MyLLMOperator(task_name="llm_task") llm_task = LLMOperator(task_name="llm_task")
streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task") streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
branch_task = LLMBranchOperator( branch_task = LLMBranchOperator(
stream_task_name="streaming_llm_task", no_stream_task_name="llm_task" stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
) )
model_parse_task = MapOperator(lambda out: out.to_dict()) model_parse_task = MapOperator(lambda out: out.to_dict())
openai_format_stream_task = OpenAIStreamingOperator() openai_format_stream_task = OpenAIStreamingOutputOperator()
result_join_task = JoinOperator( result_join_task = JoinOperator(
combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
) )
( req_handle_task = MapOperator(
trigger lambda req: ChatComposerInput(
>> request_handle_task context=ModelRequestContext(
>> pre_conversation_task conv_uid=req.context.conv_uid, stream=req.stream
>> history_conversation_task ),
>> branch_task prompt_dict={"user_input": req.messages},
model_dict={
"model": req.model,
"context": req.context,
"stream": req.stream,
},
)
) )
trigger >> req_handle_task >> composer_operator >> branch_task
# The branch of no streaming response. # The branch of no streaming response.
( branch_task >> llm_task >> model_parse_task >> result_join_task
branch_task
>> llm_task
>> post_conversation_task
>> model_parse_task
>> result_join_task
)
# The branch of streaming response. # The branch of streaming response.
( branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task
branch_task
>> streaming_llm_task
>> post_streaming_conversation_task
>> openai_format_stream_task
>> result_join_task
)
if __name__ == "__main__": if __name__ == "__main__":
if multi_round_dag.leaf_nodes[0].dev_mode: if multi_round_dag.leaf_nodes[0].dev_mode:

View File

@@ -31,3 +31,11 @@ with DAG("simple_dag_example") as dag:
trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody) trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
map_node = RequestHandleOperator() map_node = RequestHandleOperator()
trigger >> map_node trigger >> map_node
if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode:
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([dag])
else:
pass

View File

@@ -8,9 +8,10 @@
     .. code-block:: shell
         DBGPT_SERVER="http://127.0.0.1:5555"
+        MODEL="gpt-3.5-turbo"
         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "proxyllm",
+            "model": "'"$MODEL"'",
             "messages": "hello"
         }'
@@ -19,7 +20,7 @@
         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "proxyllm",
+            "model": "'"$MODEL"'",
             "messages": "hello",
             "stream": true
         }'
@@ -29,7 +30,7 @@
         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \
         -H "Content-Type: application/json" -d '{
-            "model": "proxyllm",
+            "model": "'"$MODEL"'",
             "messages": "hello"
         }'
@@ -40,13 +41,13 @@ from typing import Any, Dict, List, Optional, Union
 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core import LLMClient
 from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
-from dbgpt.core.operator import (
-    LLMBranchOperator,
+from dbgpt.core.operator import LLMBranchOperator, RequestBuilderOperator
+from dbgpt.model.operator import (
     LLMOperator,
-    RequestBuildOperator,
+    MixinLLMOperator,
+    OpenAIStreamingOutputOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator

 logger = logging.getLogger(__name__)
@@ -59,18 +60,6 @@ class TriggerReqBody(BaseModel):
     stream: Optional[bool] = Field(default=False, description="Whether return stream")
-class MyLLMOperator(MixinLLMOperator, LLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        LLMOperator.__init__(self, llm_client, **kwargs)
-class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        StreamingLLMOperator.__init__(self, llm_client, **kwargs)
 class MyModelToolOperator(
     MixinLLMOperator, MapOperator[TriggerReqBody, Dict[str, Any]]
 ):
@@ -97,14 +86,14 @@ with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag:
         request_body=TriggerReqBody,
         streaming_predict_func=lambda req: req.stream,
     )
-    request_handle_task = RequestBuildOperator()
-    llm_task = MyLLMOperator(task_name="llm_task")
-    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
+    request_handle_task = RequestBuilderOperator()
+    llm_task = LLMOperator(task_name="llm_task")
+    streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
     branch_task = LLMBranchOperator(
         stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
     )
     model_parse_task = MapOperator(lambda out: out.to_dict())
-    openai_format_stream_task = OpenAIStreamingOperator()
+    openai_format_stream_task = OpenAIStreamingOutputOperator()
     result_join_task = JoinOperator(
         combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
     )

View File

@@ -1,16 +1,20 @@
 import asyncio

-from dbgpt.core import BaseOutputParser, PromptTemplate
+from dbgpt.core import BaseOutputParser
 from dbgpt.core.awel import DAG
-from dbgpt.core.operator import LLMOperator, RequestBuildOperator
+from dbgpt.core.operator import (
+    BaseLLMOperator,
+    PromptBuilderOperator,
+    RequestBuilderOperator,
+)
 from dbgpt.model import OpenAILLMClient

 with DAG("simple_sdk_llm_example_dag") as dag:
-    prompt_task = PromptTemplate.from_template(
+    prompt_task = PromptBuilderOperator(
         "Write a SQL of {dialect} to query all data of {table_name}."
     )
-    model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
-    llm_task = LLMOperator(OpenAILLMClient())
+    model_pre_handle_task = RequestBuilderOperator(model="gpt-3.5-turbo")
+    llm_task = BaseLLMOperator(OpenAILLMClient())
     out_parse_task = BaseOutputParser()
     prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task
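
The remainder of the file (outside this hunk) presumably still runs the DAG directly, as the asyncio import suggests; a rough sketch under that assumption, where the exact call_data shape is assumed rather than taken from the file:

if __name__ == "__main__":
    # Assumed input shape: the variables consumed by the prompt template above.
    output = asyncio.run(
        out_parse_task.call(
            call_data={"data": {"dialect": "mysql", "table_name": "user"}}
        )
    )
    print(output)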

View File

@@ -2,7 +2,7 @@ import asyncio
 import json
 from typing import Dict, List

-from dbgpt.core import PromptTemplate, SQLOutputParser
+from dbgpt.core import SQLOutputParser
 from dbgpt.core.awel import (
     DAG,
     InputOperator,
@@ -10,7 +10,11 @@ from dbgpt.core.awel import (
     MapOperator,
     SimpleCallDataInputSource,
 )
-from dbgpt.core.operator import LLMOperator, RequestBuildOperator
+from dbgpt.core.operator import (
+    BaseLLMOperator,
+    PromptBuilderOperator,
+    RequestBuilderOperator,
+)
 from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
 from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
 from dbgpt.model import OpenAILLMClient
@@ -116,9 +120,9 @@ with DAG("simple_sdk_llm_sql_example") as dag:
     retriever_task = DatasourceRetrieverOperator(connection=db_connection)
     # Merge the input data and the table structure information.
     prompt_input_task = JoinOperator(combine_function=_join_func)
-    prompt_task = PromptTemplate.from_template(_sql_prompt())
-    model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
-    llm_task = LLMOperator(OpenAILLMClient())
+    prompt_task = PromptBuilderOperator(_sql_prompt())
+    model_pre_handle_task = RequestBuilderOperator(model="gpt-3.5-turbo")
+    llm_task = BaseLLMOperator(OpenAILLMClient())
     out_parse_task = SQLOutputParser()
     sql_parse_task = MapOperator(map_function=lambda x: x["sql"])
     db_query_task = DatasourceOperator(connection=db_connection)

View File

@@ -411,6 +411,8 @@ def core_requires():
         "aiofiles",
         # for agent
         "GitPython",
+        # For AWEL dag visualization, graphviz is a small package, also we can move it to default.
+        "graphviz",
     ]
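
graphviz is added for the new AWEL DAG visualization. As a standalone illustration of what that dependency does (using the graphviz package directly rather than any AWEL helper), rendering a tiny graph looks like:

import graphviz

dot = graphviz.Digraph("awel_dag", comment="Sketch of a rendered AWEL DAG")
dot.edge("http_trigger", "request_handle_task")
dot.edge("request_handle_task", "llm_task")
# Writes awel_dag.gv plus a PDF render next to it.
dot.render("awel_dag.gv", view=False)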