feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
@@ -169,14 +169,18 @@ CREATE TABLE IF NOT EXISTS `prompt_manage`
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene',
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene',
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private',
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
`prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
`input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))',
`model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)',
`prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)',
`prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`),
KEY `gmt_created_idx` (`gmt_created`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';

@@ -1,11 +1,9 @@
|
||||
from dbgpt.core.interface.llm import (
|
||||
ModelInferenceMetrics,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
ModelOutput,
|
||||
LLMClient,
|
||||
LLMOperator,
|
||||
StreamingLLMOperator,
|
||||
RequestBuildOperator,
|
||||
ModelMetadata,
|
||||
)
|
||||
from dbgpt.core.interface.message import (
|
||||
@@ -17,7 +15,11 @@ from dbgpt.core.interface.message import (
|
||||
ConversationIdentifier,
|
||||
MessageIdentifier,
|
||||
)
|
||||
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
|
||||
from dbgpt.core.interface.prompt import (
|
||||
PromptTemplate,
|
||||
PromptManager,
|
||||
StoragePromptTemplate,
|
||||
)
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
from dbgpt.core.interface.cache import (
|
||||
@@ -38,17 +40,15 @@ from dbgpt.core.interface.storage import (
|
||||
StorageError,
|
||||
)
|
||||
|
||||
|
||||
__ALL__ = [
|
||||
"ModelInferenceMetrics",
|
||||
"ModelRequest",
|
||||
"ModelRequestContext",
|
||||
"ModelOutput",
|
||||
"Operator",
|
||||
"RequestBuildOperator",
|
||||
"ModelMetadata",
|
||||
"ModelMessage",
|
||||
"LLMClient",
|
||||
"LLMOperator",
|
||||
"StreamingLLMOperator",
|
||||
"ModelMessageRoleType",
|
||||
"OnceConversation",
|
||||
"StorageConversation",
|
||||
@@ -56,7 +56,8 @@ __ALL__ = [
|
||||
"ConversationIdentifier",
|
||||
"MessageIdentifier",
|
||||
"PromptTemplate",
|
||||
"PromptTemplateOperator",
|
||||
"PromptManager",
|
||||
"StoragePromptTemplate",
|
||||
"BaseOutputParser",
|
||||
"SQLOutputParser",
|
||||
"Serializable",
|
||||
|
@@ -7,6 +7,7 @@ The stability of this API cannot be guaranteed at present.
|
||||
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from .dag.base import DAGContext, DAG
|
||||
@@ -68,6 +69,7 @@ __all__ = [
|
||||
"UnstreamifyAbsOperator",
|
||||
"TransformStreamAbsOperator",
|
||||
"HttpTrigger",
|
||||
"setup_dev_environment",
|
||||
]
|
||||
|
||||
|
||||
@@ -85,3 +87,29 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str):
|
||||
initialize_runner(DefaultWorkflowRunner())
|
||||
# Load all dags
|
||||
dag_manager.load_dags()
|
||||
|
||||
|
||||
def setup_dev_environment(
|
||||
dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
|
||||
) -> None:
|
||||
"""Setup a development environment for AWEL.
|
||||
|
||||
Only for use in a development environment, never in production.
|
||||
"""
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.component import SystemApp
|
||||
from .trigger.trigger_manager import DefaultTriggerManager
|
||||
from .dag.base import DAGVar
|
||||
|
||||
app = FastAPI()
|
||||
system_app = SystemApp(app)
|
||||
DAGVar.set_current_system_app(system_app)
|
||||
trigger_manager = DefaultTriggerManager()
|
||||
system_app.register_instance(trigger_manager)
|
||||
|
||||
for dag in dags:
|
||||
for trigger in dag.trigger_nodes:
|
||||
trigger_manager.register_trigger(trigger)
|
||||
trigger_manager.after_register()
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
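For orientation, here is a minimal usage sketch of the setup_dev_environment helper introduced above. It is not part of this commit: the endpoint path argument to HttpTrigger and the ">>" wiring are assumptions based on how AWEL operators are used elsewhere in the repository.

# Minimal dev-environment sketch (assumptions noted above; not part of this diff).
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment


class EchoOperator(MapOperator[str, str]):
    async def map(self, value: str) -> str:
        # The trigger passes its parsed request body (None for a bare GET) as call data.
        return f"echo: {value}"


with DAG("dev_echo_dag") as dag:
    trigger = HttpTrigger("/echo", methods="GET")  # assumed: first argument is the endpoint path
    trigger >> EchoOperator()

# Starts a local FastAPI app with uvicorn and registers the DAG's HTTP triggers.
setup_dev_environment([dag], host="127.0.0.1", port=5555)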
@@ -11,7 +11,7 @@ from concurrent.futures import Executor
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from ..resource.base import ResourceGroup
|
||||
from ..task.base import TaskContext
|
||||
from ..task.base import TaskContext, TaskOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -168,7 +168,19 @@ class DAGVar:
|
||||
cls._executor = executor
|
||||
|
||||
|
||||
class DAGNode(DependencyMixin, ABC):
|
||||
class DAGLifecycle:
|
||||
"""The lifecycle of DAG"""
|
||||
|
||||
async def before_dag_run(self):
|
||||
"""The callback before DAG run"""
|
||||
pass
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
pass
|
||||
|
||||
|
||||
class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
resource_group: Optional[ResourceGroup] = None
|
||||
"""The resource group of current DAGNode"""
|
||||
|
||||
@@ -179,7 +191,7 @@ class DAGNode(DependencyMixin, ABC):
|
||||
node_name: Optional[str] = None,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._upstream: List["DAGNode"] = []
|
||||
@@ -198,10 +210,23 @@ class DAGNode(DependencyMixin, ABC):
|
||||
def node_id(self) -> str:
|
||||
return self._node_id
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dev_mode(self) -> bool:
|
||||
"""Whether current DAGNode is in dev mode"""
|
||||
|
||||
@property
|
||||
def system_app(self) -> SystemApp:
|
||||
return self._system_app
|
||||
|
||||
def set_system_app(self, system_app: SystemApp) -> None:
|
||||
"""Set system app for current DAGNode
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
self._system_app = system_app
|
||||
|
||||
def set_node_id(self, node_id: str) -> None:
|
||||
self._node_id = node_id
|
||||
|
||||
@@ -274,11 +299,41 @@ class DAGNode(DependencyMixin, ABC):
|
||||
node._upstream.append(self)
|
||||
|
||||
|
||||
def _build_task_key(task_name: str, key: str) -> str:
|
||||
return f"{task_name}___$$$$$$___{key}"
|
||||
|
||||
|
||||
class DAGContext:
|
||||
def __init__(self, streaming_call: bool = False) -> None:
|
||||
"""The context of current DAG, created when the DAG is running
|
||||
|
||||
Every time the DAG is triggered, a new DAGContext will be created.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
streaming_call: bool = False,
|
||||
node_to_outputs: Dict[str, TaskContext] = None,
|
||||
node_name_to_ids: Dict[str, str] = None,
|
||||
) -> None:
|
||||
if not node_to_outputs:
|
||||
node_to_outputs = {}
|
||||
if not node_name_to_ids:
|
||||
node_name_to_ids = {}
|
||||
self._streaming_call = streaming_call
|
||||
self._curr_task_ctx = None
|
||||
self._share_data: Dict[str, Any] = {}
|
||||
self._node_to_outputs = node_to_outputs
|
||||
self._node_name_to_ids = node_name_to_ids
|
||||
|
||||
@property
|
||||
def _task_outputs(self) -> Dict[str, TaskContext]:
|
||||
"""The task outputs of current DAG
|
||||
|
||||
For internal use only for now.
|
||||
Returns:
|
||||
Dict[str, TaskContext]: The task outputs of current DAG
|
||||
"""
|
||||
return self._node_to_outputs
|
||||
|
||||
@property
|
||||
def current_task_context(self) -> TaskContext:
|
||||
@@ -292,12 +347,69 @@ class DAGContext:
|
||||
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
|
||||
self._curr_task_ctx = _curr_task_ctx
|
||||
|
||||
async def get_share_data(self, key: str) -> Any:
|
||||
def get_task_output(self, task_name: str) -> TaskOutput:
|
||||
"""Get the task output by task name
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
|
||||
Returns:
|
||||
TaskOutput: The task output
|
||||
"""
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
node_id = self._node_name_to_ids.get(task_name)
|
||||
if not node_id:
|
||||
raise ValueError(f"Task name {task_name} not exists in DAG")
|
||||
return self._task_outputs.get(node_id).task_output
|
||||
|
||||
async def get_from_share_data(self, key: str) -> Any:
|
||||
return self._share_data.get(key)
|
||||
|
||||
async def save_to_share_data(self, key: str, data: Any) -> None:
|
||||
async def save_to_share_data(
|
||||
self, key: str, data: Any, overwrite: Optional[str] = None
|
||||
) -> None:
|
||||
if key in self._share_data and not overwrite:
|
||||
raise ValueError(f"Share data key {key} already exists")
|
||||
self._share_data[key] = data
|
||||
|
||||
async def get_task_share_data(self, task_name: str, key: str) -> Any:
|
||||
"""Get share data by task name and key
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
key (str): The share data key
|
||||
|
||||
Returns:
|
||||
Any: The share data
|
||||
"""
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
if key is None:
|
||||
raise ValueError("key can't be None")
|
||||
return await self.get_from_share_data(_build_task_key(task_name, key))
|
||||
|
||||
async def save_task_share_data(
|
||||
self, task_name: str, key: str, data: Any, overwrite: Optional[str] = None
|
||||
) -> None:
|
||||
"""Save share data by task name and key
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
key (str): The share data key
|
||||
data (Any): The share data
|
||||
overwrite (Optional[str], optional): Whether to overwrite the share data if the key already exists.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the share data key already exists and overwrite is not True
|
||||
"""
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
if key is None:
|
||||
raise ValueError("key can't be None")
|
||||
await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)
|
||||
|
||||
|
||||
class DAG:
|
||||
def __init__(
|
||||
@@ -305,11 +417,20 @@ class DAG:
|
||||
) -> None:
|
||||
self._dag_id = dag_id
|
||||
self.node_map: Dict[str, DAGNode] = {}
|
||||
self._root_nodes: Set[DAGNode] = None
|
||||
self._leaf_nodes: Set[DAGNode] = None
|
||||
self._trigger_nodes: Set[DAGNode] = None
|
||||
self.node_name_to_node: Dict[str, DAGNode] = {}
|
||||
self._root_nodes: List[DAGNode] = None
|
||||
self._leaf_nodes: List[DAGNode] = None
|
||||
self._trigger_nodes: List[DAGNode] = None
|
||||
|
||||
def _append_node(self, node: DAGNode) -> None:
|
||||
if node.node_id in self.node_map:
|
||||
return
|
||||
if node.node_name:
|
||||
if node.node_name in self.node_name_to_node:
|
||||
raise ValueError(
|
||||
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
|
||||
)
|
||||
self.node_name_to_node[node.node_name] = node
|
||||
self.node_map[node.node_id] = node
|
||||
# clear cached nodes
|
||||
self._root_nodes = None
|
||||
@@ -336,22 +457,44 @@ class DAG:
|
||||
|
||||
@property
|
||||
def root_nodes(self) -> List[DAGNode]:
|
||||
"""The root nodes of current DAG
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The root nodes of the current DAG, without duplicates
|
||||
"""
|
||||
if not self._root_nodes:
|
||||
self._build()
|
||||
return self._root_nodes
|
||||
|
||||
@property
|
||||
def leaf_nodes(self) -> List[DAGNode]:
|
||||
"""The leaf nodes of current DAG
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The leaf nodes of the current DAG, without duplicates
|
||||
"""
|
||||
if not self._leaf_nodes:
|
||||
self._build()
|
||||
return self._leaf_nodes
|
||||
|
||||
@property
|
||||
def trigger_nodes(self):
|
||||
def trigger_nodes(self) -> List[DAGNode]:
|
||||
"""The trigger nodes of current DAG
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The trigger nodes of the current DAG, without duplicates
|
||||
"""
|
||||
if not self._trigger_nodes:
|
||||
self._build()
|
||||
return self._trigger_nodes
|
||||
|
||||
async def _after_dag_end(self) -> None:
|
||||
"""The callback after DAG end"""
|
||||
tasks = []
|
||||
for node in self.node_map.values():
|
||||
tasks.append(node.after_dag_end())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def __enter__(self):
|
||||
DAGVar.enter_dag(self)
|
||||
return self
|
||||
|
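As a quick illustration of the share-data API that this commit adds to DAGContext, here is a hedged sketch; the two operators are invented for the example, and only the save_to_share_data / get_from_share_data / current_dag_context calls come from this diff.

# Sketch only: two hypothetical operators exchanging data through DAGContext share data.
from dbgpt.core.awel import MapOperator


class ProducerOperator(MapOperator[str, str]):
    async def map(self, value: str) -> str:
        # Save a value for downstream operators in the same DAG run.
        await self.current_dag_context.save_to_share_data("user_input", value)
        return value


class ConsumerOperator(MapOperator[str, str]):
    async def map(self, value: str) -> str:
        # Read what the upstream operator stored for this run.
        user_input = await self.current_dag_context.get_from_share_data("user_input")
        return f"consumed: {user_input}"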
@@ -146,6 +146,16 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
def current_dag_context(self) -> DAGContext:
|
||||
return self._dag_ctx
|
||||
|
||||
@property
|
||||
def dev_mode(self) -> bool:
|
||||
"""Whether the operator is in dev mode.
|
||||
In production mode, the default runner is not None.
|
||||
|
||||
Returns:
|
||||
bool: Whether the operator is in dev mode. True if the default runner is None.
|
||||
"""
|
||||
return default_runner is None
|
||||
|
||||
async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
if not self.node_id:
|
||||
raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
|
||||
|
@@ -1,4 +1,14 @@
|
||||
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
|
||||
from typing import (
|
||||
Generic,
|
||||
Dict,
|
||||
List,
|
||||
Union,
|
||||
Callable,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Optional,
|
||||
)
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
@@ -162,7 +172,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
|
||||
self,
|
||||
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a BranchOperator with a branching function.
|
||||
@@ -203,7 +215,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
|
||||
branches = self._branches
|
||||
if not branches:
|
||||
branches = await self.branchs()
|
||||
branches = await self.branches()
|
||||
|
||||
branch_func_tasks = []
|
||||
branch_nodes: List[str] = []
|
||||
@@ -229,7 +241,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
|
||||
return parent_output
|
||||
|
||||
async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
|
||||
async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from typing import List, Set, Optional, Dict
|
||||
import uuid
|
||||
import logging
|
||||
from ..dag.base import DAG
|
||||
from ..dag.base import DAG, DAGLifecycle
|
||||
|
||||
from ..operator.base import BaseOperator, CALL_DATA
|
||||
|
||||
@@ -18,18 +19,20 @@ class DAGInstance:
|
||||
self._dag = dag
|
||||
|
||||
|
||||
class JobManager:
|
||||
class JobManager(DAGLifecycle):
|
||||
def __init__(
|
||||
self,
|
||||
root_nodes: List[BaseOperator],
|
||||
all_nodes: List[BaseOperator],
|
||||
end_node: BaseOperator,
|
||||
id2call_data: Dict[str, Dict],
|
||||
node_name_to_ids: Dict[str, str],
|
||||
) -> None:
|
||||
self._root_nodes = root_nodes
|
||||
self._all_nodes = all_nodes
|
||||
self._end_node = end_node
|
||||
self._id2node_data = id2call_data
|
||||
self._node_name_to_ids = node_name_to_ids
|
||||
|
||||
@staticmethod
|
||||
def build_from_end_node(
|
||||
@@ -38,11 +41,31 @@ class JobManager:
|
||||
nodes = _build_from_end_node(end_node)
|
||||
root_nodes = _get_root_nodes(nodes)
|
||||
id2call_data = _save_call_data(root_nodes, call_data)
|
||||
return JobManager(root_nodes, nodes, end_node, id2call_data)
|
||||
|
||||
node_name_to_ids = {}
|
||||
for node in nodes:
|
||||
if node.node_name is not None:
|
||||
node_name_to_ids[node.node_name] = node.node_id
|
||||
|
||||
return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)
|
||||
|
||||
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
|
||||
return self._id2node_data.get(node_id)
|
||||
|
||||
async def before_dag_run(self):
|
||||
"""The callback before DAG run"""
|
||||
tasks = []
|
||||
for node in self._all_nodes:
|
||||
tasks.append(node.before_dag_run())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
tasks = []
|
||||
for node in self._all_nodes:
|
||||
tasks.append(node.after_dag_end())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
def _save_call_data(
|
||||
root_nodes: List[BaseOperator], call_data: CALL_DATA
|
||||
@@ -66,6 +89,7 @@ def _save_call_data(
|
||||
|
||||
|
||||
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
|
||||
"""Build all nodes from the end node."""
|
||||
nodes = []
|
||||
if isinstance(end_node, BaseOperator):
|
||||
task_id = end_node.node_id
|
||||
|
@@ -1,7 +1,8 @@
|
||||
from typing import Dict, Optional, Set, List
|
||||
import logging
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from dbgpt.component import SystemApp
|
||||
from ..dag.base import DAGContext, DAGVar
|
||||
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
|
||||
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
|
||||
from ..task.base import TaskContext, TaskState
|
||||
@@ -18,19 +19,29 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
call_data: Optional[CALL_DATA] = None,
|
||||
streaming_call: bool = False,
|
||||
) -> DAGContext:
|
||||
# Create DAG context
|
||||
dag_ctx = DAGContext(streaming_call=streaming_call)
|
||||
# Save node output
|
||||
# dag = node.dag
|
||||
node_outputs: Dict[str, TaskContext] = {}
|
||||
job_manager = JobManager.build_from_end_node(node, call_data)
|
||||
# Create DAG context
|
||||
dag_ctx = DAGContext(
|
||||
streaming_call=streaming_call,
|
||||
node_to_outputs=node_outputs,
|
||||
node_name_to_ids=job_manager._node_name_to_ids,
|
||||
)
|
||||
logger.info(
|
||||
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
|
||||
)
|
||||
dag = node.dag
|
||||
# Save node output
|
||||
node_outputs: Dict[str, TaskContext] = {}
|
||||
skip_node_ids = set()
|
||||
system_app: SystemApp = DAGVar.get_current_system_app()
|
||||
|
||||
await job_manager.before_dag_run()
|
||||
await self._execute_node(
|
||||
job_manager, node, dag_ctx, node_outputs, skip_node_ids
|
||||
job_manager, node, dag_ctx, node_outputs, skip_node_ids, system_app
|
||||
)
|
||||
if not streaming_call and node.dag:
|
||||
# For streaming calls, the DAG-end callback is not triggered here
|
||||
await node.dag._after_dag_end()
|
||||
|
||||
return dag_ctx
|
||||
|
||||
@@ -41,6 +52,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
dag_ctx: DAGContext,
|
||||
node_outputs: Dict[str, TaskContext],
|
||||
skip_node_ids: Set[str],
|
||||
system_app: SystemApp,
|
||||
):
|
||||
# Skip run node
|
||||
if node.node_id in node_outputs:
|
||||
@@ -50,7 +62,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
for upstream_node in node.upstream:
|
||||
if isinstance(upstream_node, BaseOperator):
|
||||
await self._execute_node(
|
||||
job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
|
||||
job_manager,
|
||||
upstream_node,
|
||||
dag_ctx,
|
||||
node_outputs,
|
||||
skip_node_ids,
|
||||
system_app,
|
||||
)
|
||||
|
||||
inputs = [
|
||||
@@ -73,6 +90,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
logger.debug(
|
||||
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
|
||||
)
|
||||
if system_app is not None and node.system_app is None:
|
||||
node.set_system_app(system_app)
|
||||
|
||||
await node._run(dag_ctx)
|
||||
node_outputs[node.node_id] = dag_ctx.current_task_context
|
||||
task_ctx.set_current_state(TaskState.SUCCESS)
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
|
||||
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
@@ -13,7 +13,8 @@ from ..operator.base import BaseOperator
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
RequestBody = Union[Request, Type[BaseModel], str]
|
||||
RequestBody = Union[Type[Request], Type[BaseModel], str]
|
||||
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +26,7 @@ class HttpTrigger(Trigger):
|
||||
methods: Optional[Union[str, List[str]]] = "GET",
|
||||
request_body: Optional[RequestBody] = None,
|
||||
streaming_response: Optional[bool] = False,
|
||||
streaming_predict_func: Optional[StreamingPredictFunc] = None,
|
||||
response_model: Optional[Type] = None,
|
||||
response_headers: Optional[Dict[str, str]] = None,
|
||||
response_media_type: Optional[str] = None,
|
||||
@@ -39,6 +41,7 @@ class HttpTrigger(Trigger):
|
||||
self._methods = methods
|
||||
self._req_body = request_body
|
||||
self._streaming_response = streaming_response
|
||||
self._streaming_predict_func = streaming_predict_func
|
||||
self._response_model = response_model
|
||||
self._status_code = status_code
|
||||
self._router_tags = router_tags
|
||||
@@ -59,10 +62,13 @@ class HttpTrigger(Trigger):
|
||||
return await _parse_request_body(request, self._req_body)
|
||||
|
||||
async def route_function(body=Depends(_request_body_dependency)):
|
||||
streaming_response = self._streaming_response
|
||||
if self._streaming_predict_func:
|
||||
streaming_response = self._streaming_predict_func(body)
|
||||
return await _trigger_dag(
|
||||
body,
|
||||
self.dag,
|
||||
self._streaming_response,
|
||||
streaming_response,
|
||||
self._response_headers,
|
||||
self._response_media_type,
|
||||
)
|
||||
@@ -112,6 +118,7 @@ async def _trigger_dag(
|
||||
response_headers: Optional[Dict[str, str]] = None,
|
||||
response_media_type: Optional[str] = None,
|
||||
) -> Any:
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
end_node = dag.leaf_nodes
|
||||
@@ -131,8 +138,11 @@ async def _trigger_dag(
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
generator = await end_node.call_stream(call_data={"data": body})
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(end_node.dag._after_dag_end)
|
||||
return StreamingResponse(
|
||||
generator,
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
background=background_tasks,
|
||||
)
|
||||
|
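To show where the new streaming_predict_func hook fits, here is a hypothetical trigger definition; the endpoint path argument and the ChatRequest model are assumptions, not part of this diff.

# Hypothetical trigger: streaming_predict_func decides per request whether the DAG
# result is returned as a StreamingResponse.
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import DAG, HttpTrigger


class ChatRequest(BaseModel):  # assumed request model
    prompt: str
    stream: bool = False


with DAG("http_chat_dag") as dag:
    trigger = HttpTrigger(
        "/chat",  # assumed: first argument is the endpoint path
        methods="POST",
        request_body=ChatRequest,
        streaming_predict_func=lambda req: req.stream,
    )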
@@ -1,8 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from typing import Any, TypeVar, Generic, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from dbgpt.core.interface.serialization import Serializable
|
||||
|
||||
|
@@ -1,14 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, List, Any, Union, AsyncIterator
|
||||
import time
|
||||
from dataclasses import dataclass, asdict, field
|
||||
import copy
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.util import BaseParameters
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
from dbgpt.util.model_utils import GPUInfo
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.core.awel import MapOperator, StreamifyAbsOperator
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -97,6 +96,28 @@ class ModelInferenceMetrics:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="beta")
|
||||
class ModelRequestContext:
|
||||
stream: Optional[bool] = False
|
||||
"""Whether to return a stream of responses."""
|
||||
|
||||
user_name: Optional[str] = None
|
||||
"""The user name of the model request."""
|
||||
|
||||
sys_code: Optional[str] = None
|
||||
"""The system code of the model request."""
|
||||
|
||||
conv_uid: Optional[str] = None
|
||||
"""The conversation id of the model inference."""
|
||||
|
||||
span_id: Optional[str] = None
|
||||
"""The span id of the model inference."""
|
||||
|
||||
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
|
||||
"""The extra information of the model inference."""
|
||||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="beta")
|
||||
class ModelOutput:
|
||||
@@ -145,6 +166,27 @@ class ModelRequest:
|
||||
span_id: Optional[str] = None
|
||||
"""The span id of the model inference."""
|
||||
|
||||
context: Optional[ModelRequestContext] = field(
|
||||
default_factory=lambda: ModelRequestContext()
|
||||
)
|
||||
"""The context of the model inference."""
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
"""Whether to return a stream of responses."""
|
||||
return self.context and self.context.stream
|
||||
|
||||
def copy(self):
|
||||
new_request = copy.deepcopy(self)
|
||||
# Transform messages to List[ModelMessage]
|
||||
new_request.messages = list(
|
||||
map(
|
||||
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
|
||||
new_request.messages,
|
||||
)
|
||||
)
|
||||
return new_request
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
new_reqeust = copy.deepcopy(self)
|
||||
new_reqeust.messages = list(
|
||||
@@ -161,6 +203,17 @@ class ModelRequest:
|
||||
)
|
||||
)
|
||||
|
||||
def get_single_user_message(self) -> Optional[ModelMessage]:
|
||||
"""Get the single user message.
|
||||
|
||||
Returns:
|
||||
Optional[ModelMessage]: The single user message.
|
||||
"""
|
||||
messages = self._get_messages()
|
||||
if len(messages) != 1 or messages[0].role != ModelMessageRoleType.HUMAN:
|
||||
raise ValueError("The messages is not a single user message")
|
||||
return messages[0]
|
||||
|
||||
@staticmethod
|
||||
def _build(model: str, prompt: str, **kwargs):
|
||||
return ModelRequest(
|
||||
@@ -178,11 +231,22 @@ class ModelRequest:
|
||||
List[Dict[str, Any]]: The messages in the format of OpenAI API.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.core.interface.message import (
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
)
|
||||
|
||||
messages = [
|
||||
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
|
||||
ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot.")
|
||||
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are your"),
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.AI, content="Hi, I'm a robot."
|
||||
),
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.HUMAN, content="Who are your"
|
||||
),
|
||||
]
|
||||
openai_messages = ModelRequest.to_openai_messages(messages)
|
||||
assert openai_messages == [
|
||||
@@ -272,63 +336,3 @@ class LLMClient(ABC):
|
||||
Returns:
|
||||
int: The number of tokens.
|
||||
"""
|
||||
|
||||
|
||||
class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
|
||||
def __init__(self, model: str, **kwargs):
|
||||
self._model = model
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: str) -> ModelRequest:
|
||||
return ModelRequest._build(self._model, input_value)
|
||||
|
||||
|
||||
class BaseLLM:
|
||||
"""The abstract operator for a LLM."""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
self._llm_client = llm_client
|
||||
|
||||
@property
|
||||
def llm_client(self) -> LLMClient:
|
||||
"""Return the LLM client."""
|
||||
if not self._llm_client:
|
||||
raise ValueError("llm_client is not set")
|
||||
return self._llm_client
|
||||
|
||||
|
||||
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
||||
"""The operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate a no streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, request: ModelRequest) -> ModelOutput:
|
||||
return await self.llm_client.generate(request)
|
||||
|
||||
|
||||
class StreamingLLMOperator(
|
||||
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
|
||||
):
|
||||
"""The streaming operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate a streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
StreamifyAbsOperator.__init__(self, **kwargs)
|
||||
|
||||
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
|
||||
async for output in self.llm_client.generate_stream(request):
|
||||
yield output
|
||||
|
@@ -1,16 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.interface.storage import (
|
||||
ResourceIdentifier,
|
||||
StorageItem,
|
||||
StorageInterface,
|
||||
InMemoryStorage,
|
||||
ResourceIdentifier,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,6 +112,7 @@ class ModelMessage(BaseModel):
|
||||
"""Similar to openai's message format"""
|
||||
role: str
|
||||
content: str
|
||||
round_index: Optional[int] = 0
|
||||
|
||||
@staticmethod
|
||||
def from_openai_messages(
|
||||
@@ -443,6 +444,7 @@ class OnceConversation:
|
||||
self.tokens = conversation.tokens
|
||||
self.user_name = conversation.user_name
|
||||
self.sys_code = conversation.sys_code
|
||||
self._message_index = conversation._message_index
|
||||
|
||||
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
|
||||
"""Get the messages by round index
|
||||
@@ -470,6 +472,7 @@ class OnceConversation:
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
conversation = OnceConversation()
|
||||
conversation.start_new_round()
|
||||
conversation.add_user_message("hello, this is the first round")
|
||||
@@ -485,11 +488,17 @@ class OnceConversation:
|
||||
conversation.end_current_round()
|
||||
|
||||
assert len(conversation.get_messages_with_round(1)) == 2
|
||||
assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round"
|
||||
assert (
|
||||
conversation.get_messages_with_round(1)[0].content
|
||||
== "hello, this is the third round"
|
||||
)
|
||||
assert conversation.get_messages_with_round(1)[1].content == "hi"
|
||||
|
||||
assert len(conversation.get_messages_with_round(2)) == 4
|
||||
assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round"
|
||||
assert (
|
||||
conversation.get_messages_with_round(2)[0].content
|
||||
== "hello, this is the second round"
|
||||
)
|
||||
assert conversation.get_messages_with_round(2)[1].content == "hi"
|
||||
|
||||
Args:
|
||||
@@ -517,6 +526,7 @@ class OnceConversation:
|
||||
Examples:
|
||||
If you not need the history messages, you can override this method like this:
|
||||
.. code-block:: python
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
messages = []
|
||||
for message in self.get_latest_round():
|
||||
@@ -528,6 +538,7 @@ class OnceConversation:
|
||||
|
||||
If you want to add the one round history messages, you can override this method like this:
|
||||
.. code-block:: python
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
messages = []
|
||||
latest_round_index = self.chat_order
|
||||
@@ -537,7 +548,9 @@ class OnceConversation:
|
||||
for message in self.get_messages_by_round(round_index):
|
||||
if message.pass_to_model:
|
||||
messages.append(
|
||||
ModelMessage(role=message.type, content=message.content)
|
||||
ModelMessage(
|
||||
role=message.type, content=message.content
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
@@ -548,7 +561,11 @@ class OnceConversation:
|
||||
for message in self.messages:
|
||||
if message.pass_to_model:
|
||||
messages.append(
|
||||
ModelMessage(role=message.type, content=message.content)
|
||||
ModelMessage(
|
||||
role=message.type,
|
||||
content=message.content,
|
||||
round_index=message.round_index,
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
@@ -780,6 +797,9 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
)
|
||||
messages = [message.to_message() for message in message_list]
|
||||
conversation.messages = messages
|
||||
# This index is used to save the message to the storage (it has not been saved yet)
|
||||
# New messages are appended to the message list, so the index is len(messages)
|
||||
conversation._message_index = len(messages)
|
||||
self._message_ids = message_ids
|
||||
self._has_stored_message_index = len(messages) - 1
|
||||
self.from_conversation(conversation)
|
||||
|
dbgpt/core/interface/operator/__init__.py (new file, 0 lines)
dbgpt/core/interface/operator/llm_operator.py (new file, 166 lines)
@@ -0,0 +1,166 @@
|
||||
import dataclasses
|
||||
from abc import ABC
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.core.awel import (
|
||||
BranchFunc,
|
||||
BranchOperator,
|
||||
MapOperator,
|
||||
StreamifyAbsOperator,
|
||||
)
|
||||
from dbgpt.core.interface.llm import (
|
||||
LLMClient,
|
||||
ModelOutput,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
)
|
||||
from dbgpt.core.interface.message import ModelMessage
|
||||
|
||||
RequestInput = Union[
|
||||
ModelRequest,
|
||||
str,
|
||||
Dict[str, Any],
|
||||
BaseModel,
|
||||
]
|
||||
|
||||
|
||||
class RequestBuildOperator(MapOperator[RequestInput, ModelRequest], ABC):
|
||||
def __init__(self, model: Optional[str] = None, **kwargs):
|
||||
self._model = model
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: RequestInput) -> ModelRequest:
|
||||
req_dict = {}
|
||||
if isinstance(input_value, str):
|
||||
req_dict = {"messages": [ModelMessage.build_human_message(input_value)]}
|
||||
elif isinstance(input_value, dict):
|
||||
req_dict = input_value
|
||||
elif dataclasses.is_dataclass(input_value):
|
||||
req_dict = dataclasses.asdict(input_value)
|
||||
elif isinstance(input_value, BaseModel):
|
||||
req_dict = input_value.dict()
|
||||
elif isinstance(input_value, ModelRequest):
|
||||
if not input_value.model:
|
||||
input_value.model = self._model
|
||||
return input_value
|
||||
if "messages" not in req_dict:
|
||||
raise ValueError("messages is not set")
|
||||
messages = req_dict["messages"]
|
||||
if isinstance(messages, str):
|
||||
# Single message, transform to a list including one human message
|
||||
req_dict["messages"] = [ModelMessage.build_human_message(messages)]
|
||||
if "model" not in req_dict:
|
||||
req_dict["model"] = self._model
|
||||
if not req_dict["model"]:
|
||||
raise ValueError("model is not set")
|
||||
stream = False
|
||||
has_stream = False
|
||||
if "stream" in req_dict:
|
||||
has_stream = True
|
||||
stream = req_dict["stream"]
|
||||
del req_dict["stream"]
|
||||
if "context" not in req_dict:
|
||||
req_dict["context"] = ModelRequestContext(stream=stream)
|
||||
else:
|
||||
context_dict = req_dict["context"]
|
||||
if not isinstance(context_dict, dict):
|
||||
raise ValueError("context is not a dict")
|
||||
if has_stream:
|
||||
context_dict["stream"] = stream
|
||||
req_dict["context"] = ModelRequestContext(**context_dict)
|
||||
return ModelRequest(**req_dict)
|
||||
|
||||
|
||||
class BaseLLM:
|
||||
"""The abstract operator for a LLM."""
|
||||
|
||||
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
self._llm_client = llm_client
|
||||
|
||||
@property
|
||||
def llm_client(self) -> LLMClient:
|
||||
"""Return the LLM client."""
|
||||
if not self._llm_client:
|
||||
raise ValueError("llm_client is not set")
|
||||
return self._llm_client
|
||||
|
||||
|
||||
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
||||
"""The operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate a non-streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, request: ModelRequest) -> ModelOutput:
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
||||
)
|
||||
return await self.llm_client.generate(request)
|
||||
|
||||
|
||||
class StreamingLLMOperator(
|
||||
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
|
||||
):
|
||||
"""The streaming operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate a streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
StreamifyAbsOperator.__init__(self, **kwargs)
|
||||
|
||||
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
||||
)
|
||||
async for output in self.llm_client.generate_stream(request):
|
||||
yield output
|
||||
|
||||
|
||||
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
"""Branch operator for LLM.
|
||||
|
||||
This operator will branch the workflow based on the stream flag of the request.
|
||||
"""
|
||||
|
||||
def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not stream_task_name:
|
||||
raise ValueError("stream_task_name is not set")
|
||||
if not no_stream_task_name:
|
||||
raise ValueError("no_stream_task_name is not set")
|
||||
self._stream_task_name = stream_task_name
|
||||
self._no_stream_task_name = no_stream_task_name
|
||||
|
||||
async def branches(self) -> Dict[BranchFunc[ModelRequest], str]:
|
||||
"""
|
||||
Return a dict of branch function and task name.
|
||||
|
||||
Returns:
|
||||
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task name.
|
||||
the key is a predicate function, the value is the task name. If the predicate function returns True,
|
||||
we will run the corresponding task.
|
||||
"""
|
||||
|
||||
async def check_stream_true(r: ModelRequest) -> bool:
|
||||
# If stream is true, we will run the streaming task. Otherwise, we will run the non-streaming task.
|
||||
return r.stream
|
||||
|
||||
return {
|
||||
check_stream_true: self._stream_task_name,
|
||||
lambda x: not x.stream: self._no_stream_task_name,
|
||||
}
|
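The next sketch shows how the new RequestBuildOperator normalizes its inputs. It assumes an operator can be constructed and its map() awaited outside a running DAG for a quick local check, and "example-model" is a placeholder name.

# Sketch of RequestBuildOperator input handling (assumptions noted above).
import asyncio

from dbgpt.core.interface.operator.llm_operator import RequestBuildOperator

builder = RequestBuildOperator(model="example-model")

# A bare string becomes a single human message with streaming disabled.
request = asyncio.run(builder.map("Hello"))
assert request.model == "example-model"
assert request.stream is False

# A dict may carry messages, a model override and a stream flag; the stream flag
# is moved into the ModelRequestContext.
streaming_request = asyncio.run(
    builder.map({"messages": "Hello", "model": "other-model", "stream": True})
)
assert streaming_request.model == "other-model"
assert streaming_request.stream is True

An LLMBranchOperator configured with stream_task_name and no_stream_task_name would then route streaming_request to the streaming task and request to the non-streaming one.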
dbgpt/core/interface/operator/message_operator.py (new file, 321 lines)
@@ -0,0 +1,321 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncIterator, List, Optional
|
||||
|
||||
from dbgpt.core import (
|
||||
MessageStorageItem,
|
||||
ModelMessage,
|
||||
ModelOutput,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
StorageConversation,
|
||||
StorageInterface,
|
||||
)
|
||||
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
|
||||
|
||||
|
||||
class BaseConversationOperator(BaseOperator, ABC):
|
||||
"""Base class for conversation operators."""
|
||||
|
||||
SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation"
|
||||
SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._storage = storage
|
||||
self._message_storage = message_storage
|
||||
|
||||
@property
|
||||
def storage(self) -> StorageInterface[StorageConversation, Any]:
|
||||
"""Return the LLM client."""
|
||||
if not self._storage:
|
||||
raise ValueError("Storage is not set")
|
||||
return self._storage
|
||||
|
||||
@property
|
||||
def message_storage(self) -> StorageInterface[MessageStorageItem, Any]:
|
||||
"""Return the LLM client."""
|
||||
if not self._message_storage:
|
||||
raise ValueError("Message storage is not set")
|
||||
return self._message_storage
|
||||
|
||||
async def get_storage_conversation(self) -> StorageConversation:
|
||||
"""Get the storage conversation from share data.
|
||||
|
||||
Returns:
|
||||
StorageConversation: The storage conversation.
|
||||
"""
|
||||
storage_conv: StorageConversation = (
|
||||
await self.current_dag_context.get_from_share_data(
|
||||
self.SHARE_DATA_KEY_STORAGE_CONVERSATION
|
||||
)
|
||||
)
|
||||
if not storage_conv:
|
||||
raise ValueError("Storage conversation is not set")
|
||||
return storage_conv
|
||||
|
||||
async def get_model_request(self) -> ModelRequest:
|
||||
"""Get the model request from share data.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The model request.
|
||||
"""
|
||||
model_request: ModelRequest = (
|
||||
await self.current_dag_context.get_from_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_REQUEST
|
||||
)
|
||||
)
|
||||
if not model_request:
|
||||
raise ValueError("Model request is not set")
|
||||
return model_request
|
||||
|
||||
|
||||
class PreConversationOperator(
|
||||
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
|
||||
):
|
||||
"""The operator to prepare the storage conversation.
|
||||
|
||||
In DB-GPT, the conversation record and the messages in the conversation are stored in storage,
|
||||
and they can be stored in different storage backends (for high performance).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(storage=storage, message_storage=message_storage)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
||||
"""Map the input value to a ModelRequest.
|
||||
|
||||
Args:
|
||||
input_value (ModelRequest): The input value.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The mapped ModelRequest.
|
||||
"""
|
||||
if input_value.context is None:
|
||||
input_value.context = ModelRequestContext()
|
||||
if not input_value.context.conv_uid:
|
||||
input_value.context.conv_uid = str(uuid.uuid4())
|
||||
if not input_value.context.extra:
|
||||
input_value.context.extra = {}
|
||||
|
||||
chat_mode = input_value.context.extra.get("chat_mode")
|
||||
|
||||
# Create a new storage conversation; this loads the conversation from storage, so it must be done asynchronously
|
||||
storage_conv: StorageConversation = await self.blocking_func_to_async(
|
||||
StorageConversation,
|
||||
conv_uid=input_value.context.conv_uid,
|
||||
chat_mode=chat_mode,
|
||||
user_name=input_value.context.user_name,
|
||||
sys_code=input_value.context.sys_code,
|
||||
conv_storage=self.storage,
|
||||
message_storage=self.message_storage,
|
||||
)
|
||||
# The input message must be a single user message
|
||||
single_human_message: ModelMessage = input_value.get_single_user_message()
|
||||
storage_conv.start_new_round()
|
||||
storage_conv.add_user_message(single_human_message.content)
|
||||
|
||||
# Get all messages from current storage conversation, and overwrite the input value
|
||||
messages: List[ModelMessage] = storage_conv.get_model_messages()
|
||||
input_value.messages = messages
|
||||
|
||||
# Save the storage conversation to share data, for the child operators
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
|
||||
)
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_REQUEST, input_value
|
||||
)
|
||||
return input_value
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
# Save the storage conversation to storage after the whole DAG finished
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
# TODO: don't save if the conversation has an internal error
|
||||
storage_conv.end_current_round()
|
||||
|
||||
|
||||
class PostConversationOperator(
|
||||
BaseConversationOperator, MapOperator[ModelOutput, ModelOutput]
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, input_value: ModelOutput) -> ModelOutput:
|
||||
"""Map the input value to a ModelOutput.
|
||||
|
||||
Args:
|
||||
input_value (ModelOutput): The input value.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The mapped ModelOutput.
|
||||
"""
|
||||
# Get the storage conversation from share data
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
storage_conv.add_ai_message(input_value.text)
|
||||
return input_value
|
||||
|
||||
|
||||
class PostStreamingConversationOperator(
|
||||
BaseConversationOperator, TransformStreamAbsOperator[ModelOutput, ModelOutput]
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
TransformStreamAbsOperator.__init__(self, **kwargs)
|
||||
|
||||
async def transform_stream(
|
||||
self, input_value: AsyncIterator[ModelOutput]
|
||||
) -> ModelOutput:
|
||||
"""Transform the input value to a ModelOutput.
|
||||
|
||||
Args:
|
||||
input_value (ModelOutput): The input value.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The transformed ModelOutput.
|
||||
"""
|
||||
full_text = ""
|
||||
async for model_output in input_value:
|
||||
# model_output.text is currently the full text; if it were delta text, we would need to merge all deltas into the full text
|
||||
full_text = model_output.text
|
||||
yield model_output
|
||||
# Get the storage conversation from share data
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
storage_conv.add_ai_message(full_text)
|
||||
|
||||
|
||||
class ConversationMapperOperator(
|
||||
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
||||
"""Map the input value to a ModelRequest.
|
||||
|
||||
Args:
|
||||
input_value (ModelRequest): The input value.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The mapped ModelRequest.
|
||||
"""
|
||||
input_value = input_value.copy()
|
||||
messages: List[ModelMessage] = await self.map_messages(input_value.messages)
|
||||
# Overwrite the input value
|
||||
input_value.messages = messages
|
||||
return input_value
|
||||
|
||||
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
|
||||
"""Map the input messages to a list of ModelMessage.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The mapped ModelMessage.
|
||||
"""
|
||||
return messages
|
||||
|
||||
def _split_messages_by_round(
|
||||
self, messages: List[ModelMessage]
|
||||
) -> List[List[ModelMessage]]:
|
||||
"""Split the messages by round index.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
|
||||
Returns:
|
||||
List[List[ModelMessage]]: The messages grouped by round.
|
||||
"""
|
||||
messages_by_round: List[List[ModelMessage]] = []
|
||||
last_round_index = 0
|
||||
for message in messages:
|
||||
if not message.round_index:
|
||||
# Round index must be bigger than 0
|
||||
raise ValueError("Message round_index is not set")
|
||||
if message.round_index > last_round_index:
|
||||
last_round_index = message.round_index
|
||||
messages_by_round.append([])
|
||||
messages_by_round[-1].append(message)
|
||||
return messages_by_round
|
||||
|
||||
|
||||
class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
"""The buffered conversation mapper operator.
|
||||
|
||||
This Operator must be used after the PreConversationOperator,
|
||||
and it will map the messages in the storage conversation.
|
||||
|
||||
Examples:
|
||||
|
||||
Transform no history messages
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from dbgpt.core import ModelMessage
|
||||
from dbgpt.core.operator import BufferedConversationMapperOperator
|
||||
|
||||
# No history
|
||||
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
operator = BufferedConversationMapperOperator(last_k_round=1)
|
||||
messages = asyncio.run(operator.map_messages(messages))
|
||||
assert messages == [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
|
||||
Transform with history messages
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# With history
|
||||
messages = [
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
operator = BufferedConversationMapperOperator(last_k_round=1)
|
||||
messages = asyncio.run(operator.map_messages(messages))
|
||||
# Keep only the last round, so the first-round messages will be removed
|
||||
# Note: The round index 3 is not a complete round
|
||||
assert messages == [
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
"""
|
||||
|
||||
def __init__(self, last_k_round: Optional[int] = 2, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._last_k_round = last_k_round
|
||||
|
||||
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
|
||||
"""Map the input messages to a list of ModelMessage.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The mapped ModelMessage.
|
||||
"""
|
||||
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
|
||||
messages
|
||||
)
|
||||
# Get the last k round messages
|
||||
index = self._last_k_round + 1
|
||||
messages_by_round = messages_by_round[-index:]
|
||||
messages: List[ModelMessage] = sum(messages_by_round, [])
|
||||
return messages
|
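To place the conversation operators above in context, here is a hypothetical pipeline sketch; the in-memory storage backends and the omitted LLM step are assumptions, not part of this diff.

# Hypothetical conversation pipeline; InMemoryStorage is a stand-in for real
# conversation/message storage and the LLM step is only indicated in a comment.
from dbgpt.core.awel import DAG
from dbgpt.core.interface.operator.message_operator import (
    BufferedConversationMapperOperator,
    PostConversationOperator,
    PreConversationOperator,
)
from dbgpt.core.interface.storage import InMemoryStorage

with DAG("conversation_dag") as dag:
    # Loads or creates the StorageConversation, records the user message and shares it.
    pre_conv = PreConversationOperator(
        storage=InMemoryStorage(), message_storage=InMemoryStorage()
    )
    # Keeps only the last two rounds of history in the ModelRequest.
    history_mapper = BufferedConversationMapperOperator(last_k_round=2)
    # Records the model's answer back into the conversation after the DAG ends.
    post_conv = PostConversationOperator()

    pre_conv >> history_mapper
    # history_mapper >> <an LLM operator> >> post_conv would complete the pipeline.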
@@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
import logging
|
||||
from abc import ABC
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, TypeVar, Union
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.core.awel import MapOperator
|
||||
|
||||
T = TypeVar("T")
|
||||
ResponseTye = Union[str, bytes, ModelOutput]
|
||||
|
@@ -1,12 +1,20 @@
|
||||
import dataclasses
|
||||
import json
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
from dbgpt.util.formatting import formatter, no_strict_formatter
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.core._private.example_base import ExampleSelector
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser
|
||||
from dbgpt.core._private.example_base import ExampleSelector
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
QuerySpec,
|
||||
ResourceIdentifier,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
)
|
||||
from dbgpt.util.formatting import formatter, no_strict_formatter
|
||||
|
||||
|
||||
def _jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
@@ -86,6 +94,434 @@ class PromptTemplate(BaseModel, ABC):
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PromptTemplateIdentifier(ResourceIdentifier):
|
||||
identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
|
||||
prompt_name: str
|
||||
prompt_language: Optional[str] = None
|
||||
sys_code: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.prompt_name is None:
|
||||
raise ValueError("prompt_name cannot be None")
|
||||
|
||||
if any(
|
||||
self.identifier_split in key
|
||||
for key in [
|
||||
self.prompt_name,
|
||||
self.prompt_language,
|
||||
self.sys_code,
|
||||
self.model,
|
||||
]
|
||||
if key is not None
|
||||
):
|
||||
raise ValueError(
|
||||
f"identifier_split {self.identifier_split} is not allowed in prompt_name, prompt_language, sys_code, model"
|
||||
)
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
return self.identifier_split.join(
|
||||
key
|
||||
for key in [
|
||||
self.prompt_name,
|
||||
self.prompt_language,
|
||||
self.sys_code,
|
||||
self.model,
|
||||
]
|
||||
if key is not None
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"prompt_name": self.prompt_name,
|
||||
"prompt_language": self.prompt_language,
|
||||
"sys_code": self.sys_code,
|
||||
"model": self.model,
|
||||
}
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StoragePromptTemplate(StorageItem):
|
||||
prompt_name: str
|
||||
content: Optional[str] = None
|
||||
prompt_language: Optional[str] = None
|
||||
prompt_format: Optional[str] = None
|
||||
input_variables: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
chat_scene: Optional[str] = None
|
||||
sub_chat_scene: Optional[str] = None
|
||||
prompt_type: Optional[str] = None
|
||||
user_name: Optional[str] = None
|
||||
sys_code: Optional[str] = None
|
||||
_identifier: PromptTemplateIdentifier = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self._identifier = PromptTemplateIdentifier(
|
||||
prompt_name=self.prompt_name,
|
||||
prompt_language=self.prompt_language,
|
||||
sys_code=self.sys_code,
|
||||
model=self.model,
|
||||
)
|
||||
self._check()  # Validate required fields after initialization
|
||||
|
||||
def to_prompt_template(self) -> PromptTemplate:
|
||||
"""Convert the storage prompt template to a prompt template."""
|
||||
input_variables = (
|
||||
None
|
||||
if not self.input_variables
|
||||
else self.input_variables.strip().split(",")
|
||||
)
|
||||
return PromptTemplate(
|
||||
input_variables=input_variables,
|
||||
template=self.content,
|
||||
template_scene=self.chat_scene,
|
||||
prompt_name=self.prompt_name,
|
||||
template_format=self.prompt_format,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_prompt_template(
|
||||
prompt_template: PromptTemplate,
|
||||
prompt_name: str,
|
||||
prompt_language: Optional[str] = None,
|
||||
prompt_type: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
sub_chat_scene: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> "StoragePromptTemplate":
|
||||
"""Convert a prompt template to a storage prompt template.
|
||||
|
||||
Args:
|
||||
prompt_template (PromptTemplate): The prompt template to convert from.
|
||||
prompt_name (str): The name of the prompt.
|
||||
prompt_language (Optional[str], optional): The language of the prompt. Defaults to None. e.g. zh-cn, en.
|
||||
prompt_type (Optional[str], optional): The type of the prompt. Defaults to None. e.g. common, private.
|
||||
sys_code (Optional[str], optional): The system code of the prompt. Defaults to None.
|
||||
user_name (Optional[str], optional): The username of the prompt. Defaults to None.
|
||||
sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt. Defaults to None.
|
||||
model (Optional[str], optional): The model name of the prompt. Defaults to None.
|
||||
kwargs (Dict): Other params to build the storage prompt template.
|
||||
"""
|
||||
input_variables = prompt_template.input_variables or kwargs.get(
|
||||
"input_variables"
|
||||
)
|
||||
if input_variables and isinstance(input_variables, list):
|
||||
input_variables = ",".join(input_variables)
|
||||
return StoragePromptTemplate(
|
||||
prompt_name=prompt_name,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
input_variables=input_variables,
|
||||
model=model,
|
||||
content=prompt_template.template or kwargs.get("content"),
|
||||
prompt_language=prompt_language,
|
||||
prompt_format=prompt_template.template_format
|
||||
or kwargs.get("prompt_format"),
|
||||
chat_scene=prompt_template.template_scene or kwargs.get("chat_scene"),
|
||||
sub_chat_scene=sub_chat_scene,
|
||||
prompt_type=prompt_type,
|
||||
)
|
||||
|
||||
@property
|
||||
def identifier(self) -> PromptTemplateIdentifier:
|
||||
return self._identifier
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
"""Merge the other item into the current item.
|
||||
|
||||
Args:
|
||||
other (StorageItem): The other item to merge
|
||||
"""
|
||||
if not isinstance(other, StoragePromptTemplate):
|
||||
raise ValueError(
|
||||
f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
|
||||
)
|
||||
self.from_object(other)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"prompt_name": self.prompt_name,
|
||||
"content": self.content,
|
||||
"prompt_language": self.prompt_language,
|
||||
"prompt_format": self.prompt_format,
|
||||
"input_variables": self.input_variables,
|
||||
"model": self.model,
|
||||
"chat_scene": self.chat_scene,
|
||||
"sub_chat_scene": self.sub_chat_scene,
|
||||
"prompt_type": self.prompt_type,
|
||||
"user_name": self.user_name,
|
||||
"sys_code": self.sys_code,
|
||||
}
|
||||
|
||||
def _check(self):
|
||||
if self.prompt_name is None:
|
||||
raise ValueError("prompt_name cannot be None")
|
||||
if self.content is None:
|
||||
raise ValueError("content cannot be None")
|
||||
|
||||
def from_object(self, template: "StoragePromptTemplate") -> None:
|
||||
"""Load the prompt template from an existing prompt template object.
|
||||
|
||||
Args:
|
||||
template (StoragePromptTemplate): The storage prompt template to load from.
|
||||
"""
|
||||
self.content = template.content
|
||||
self.prompt_format = template.prompt_format
|
||||
self.input_variables = template.input_variables
|
||||
self.model = template.model
|
||||
self.chat_scene = template.chat_scene
|
||||
self.sub_chat_scene = template.sub_chat_scene
|
||||
self.prompt_type = template.prompt_type
|
||||
self.user_name = template.user_name
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""The manager class for prompt templates.
|
||||
|
||||
Simple wrapper for the storage interface.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Uses InMemoryStorage by default
|
||||
prompt_manager = PromptManager()
|
||||
prompt_template = PromptTemplate(
|
||||
template="hello {input}",
|
||||
input_variables=["input"],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(prompt_template, prompt_name="hello")
|
||||
prompt_template_list = prompt_manager.list()
|
||||
prompt_template_list = prompt_manager.prefer_query("hello")
|
||||
|
||||
With a custom storage interface.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.core.interface.storage import InMemoryStorage
|
||||
|
||||
prompt_manager = PromptManager(InMemoryStorage())
|
||||
prompt_template = PromptTemplate(
|
||||
template="hello {input}",
|
||||
input_variables=["input"],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(prompt_template, prompt_name="hello")
|
||||
prompt_template_list = prompt_manager.list()
|
||||
prompt_template_list = prompt_manager.prefer_query("hello")
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, storage: Optional[StorageInterface[StoragePromptTemplate, Any]] = None
|
||||
):
|
||||
if storage is None:
|
||||
storage = InMemoryStorage()
|
||||
self._storage = storage
|
||||
|
||||
@property
|
||||
def storage(self) -> StorageInterface[StoragePromptTemplate, Any]:
|
||||
"""The storage interface for prompt templates."""
|
||||
return self._storage
|
||||
|
||||
def prefer_query(
|
||||
self,
|
||||
prompt_name: str,
|
||||
sys_code: Optional[str] = None,
|
||||
prefer_prompt_language: Optional[str] = None,
|
||||
prefer_model: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> List[StoragePromptTemplate]:
|
||||
"""Query prompt templates from storage with prefer params.
|
||||
|
||||
Sometimes we want to query prompt templates with preferred params (e.g. a specific language or model).
|
||||
This method filters the results by the preferred params first; if nothing matches a preference, it falls back to all prompt templates with the given name.
|
||||
|
||||
Examples:
|
||||
|
||||
Query a prompt template.
|
||||
.. code-block:: python
|
||||
|
||||
prompt_template_list = prompt_manager.prefer_query("hello")
|
||||
|
||||
Query with sys_code and username.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt_template_list = prompt_manager.prefer_query(
|
||||
"hello", sys_code="sys_code", user_name="user_name"
|
||||
)
|
||||
|
||||
Query with prefer prompt language.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# First query with prompt name "hello" exactly.
|
||||
# Then filter by prompt language "zh-cn"; if none match, all prompt templates are returned.
|
||||
prompt_template_list = prompt_manager.prefer_query(
|
||||
"hello", prefer_prompt_language="zh-cn"
|
||||
)
|
||||
|
||||
Query with prefer model.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# First query with prompt name "hello" exactly.
|
||||
# Then filter by model "vicuna-13b-v1.5"; if none match, all prompt templates are returned.
|
||||
prompt_template_list = prompt_manager.prefer_query(
|
||||
"hello", prefer_model="vicuna-13b-v1.5"
|
||||
)
|
||||
|
||||
Args:
|
||||
prompt_name (str): The name of the prompt template.
|
||||
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
|
||||
prefer_prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
|
||||
prefer_model (Optional[str], optional): The model of the prompt template. Defaults to None.
|
||||
kwargs (Dict): Other query params (keys with non-None values are matched exactly).
|
||||
"""
|
||||
query_spec = QuerySpec(
|
||||
conditions={
|
||||
"prompt_name": prompt_name,
|
||||
"sys_code": sys_code,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
queries: List[StoragePromptTemplate] = self.storage.query(
|
||||
query_spec, StoragePromptTemplate
|
||||
)
|
||||
if not queries:
|
||||
return []
|
||||
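# Apply the preferred filters one at a time; each filtered subset is kept only
# when it is non-empty, otherwise we fall back to the previous result set.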
if prefer_prompt_language:
|
||||
prefer_prompt_language = prefer_prompt_language.lower()
|
||||
temp_queries = [
|
||||
query
|
||||
for query in queries
|
||||
if query.prompt_language
|
||||
and query.prompt_language.lower() == prefer_prompt_language
|
||||
]
|
||||
if temp_queries:
|
||||
queries = temp_queries
|
||||
if prefer_model:
|
||||
prefer_model = prefer_model.lower()
|
||||
temp_queries = [
|
||||
query
|
||||
for query in queries
|
||||
if query.model and query.model.lower() == prefer_model
|
||||
]
|
||||
if temp_queries:
|
||||
queries = temp_queries
|
||||
return queries
|
||||
|
||||
def save(self, prompt_template: PromptTemplate, prompt_name: str, **kwargs) -> None:
|
||||
"""Save a prompt template to storage.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt_template = PromptTemplate(
|
||||
template="hello {input}",
|
||||
input_variables=["input"],
|
||||
template_scene="chat_normal",
|
||||
prompt_name="hello",
|
||||
)
|
||||
prompt_manager.save(prompt_template)
|
||||
|
||||
Save with sys_code and username.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt_template = PromptTemplate(
|
||||
template="hello {input}",
|
||||
input_variables=["input"],
|
||||
template_scene="chat_normal",
|
||||
prompt_name="hello",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template, sys_code="sys_code", user_name="user_name"
|
||||
)
|
||||
|
||||
Args:
|
||||
prompt_template (PromptTemplate): The prompt template to save.
|
||||
prompt_name (str): The name of the prompt template.
|
||||
kwargs (Dict): Other params to build the storage prompt template.
|
||||
More details in :meth:`~StoragePromptTemplate.from_prompt_template`.
|
||||
"""
|
||||
storage_prompt_template = StoragePromptTemplate.from_prompt_template(
|
||||
prompt_template, prompt_name, **kwargs
|
||||
)
|
||||
self.storage.save(storage_prompt_template)
|
||||
|
||||
def list(self, **kwargs) -> List[StoragePromptTemplate]:
|
||||
"""List prompt templates from storage.
|
||||
|
||||
Examples:
|
||||
|
||||
List all prompt templates.
|
||||
.. code-block:: python
|
||||
|
||||
all_prompt_templates = prompt_manager.list()
|
||||
|
||||
List with sys_code and username.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
templates = prompt_manager.list(
|
||||
sys_code="sys_code", user_name="user_name"
|
||||
)
|
||||
|
||||
Args:
|
||||
kwargs (Dict): Other query params.
|
||||
"""
|
||||
query_spec = QuerySpec(conditions=kwargs)
|
||||
return self.storage.query(query_spec, StoragePromptTemplate)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
prompt_name: str,
|
||||
prompt_language: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Delete a prompt template from storage.
|
||||
|
||||
Examples:
|
||||
|
||||
Delete a prompt template.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt_manager.delete("hello")
|
||||
|
||||
Delete with sys_code and username.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt_manager.delete(
|
||||
"hello", sys_code="sys_code", user_name="user_name"
|
||||
)
|
||||
|
||||
Args:
|
||||
prompt_name (str): The name of the prompt template.
|
||||
prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
|
||||
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
|
||||
model (Optional[str], optional): The model of the prompt template. Defaults to None.
|
||||
"""
|
||||
identifier = PromptTemplateIdentifier(
|
||||
prompt_name=prompt_name,
|
||||
prompt_language=prompt_language,
|
||||
sys_code=sys_code,
|
||||
model=model,
|
||||
)
|
||||
self.storage.delete(identifier)
|
||||
|
||||
|
||||
class PromptTemplateOperator(MapOperator[Dict, str]):
|
||||
def __init__(self, prompt_template: PromptTemplate, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.task.base import IN, OUT
|
||||
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type, Dict
|
||||
from typing import Dict, Type
|
||||
|
||||
|
||||
class Serializable(ABC):
|
||||
|
@@ -1,10 +1,10 @@
|
||||
from typing import Generic, TypeVar, Type, Optional, Dict, Any, List
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
|
||||
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.tests.conftest import in_memory_storage
|
||||
from dbgpt.core.interface.message import *
|
||||
from dbgpt.core.interface.tests.conftest import in_memory_storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
dbgpt/core/interface/tests/test_prompt.py (new file, 320 lines)
@@ -0,0 +1,320 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.prompt import (
|
||||
PromptManager,
|
||||
PromptTemplate,
|
||||
StoragePromptTemplate,
|
||||
)
|
||||
from dbgpt.core.interface.storage import QuerySpec
|
||||
from dbgpt.core.interface.tests.conftest import in_memory_storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_storage_prompt_template():
|
||||
return StoragePromptTemplate(
|
||||
prompt_name="test_prompt",
|
||||
content="Sample content, {var1}, {var2}",
|
||||
prompt_language="en",
|
||||
prompt_format="f-string",
|
||||
input_variables="var1,var2",
|
||||
model="model1",
|
||||
chat_scene="scene1",
|
||||
sub_chat_scene="subscene1",
|
||||
prompt_type="type1",
|
||||
user_name="user1",
|
||||
sys_code="code1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_storage_prompt_template():
|
||||
content = """Database name: {db_name} Table structure definition: {table_info} User Question:{user_input}"""
|
||||
return StoragePromptTemplate(
|
||||
prompt_name="chat_data_auto_execute_prompt",
|
||||
content=content,
|
||||
prompt_language="en",
|
||||
prompt_format="f-string",
|
||||
input_variables="db_name,table_info,user_input",
|
||||
model="vicuna-13b-v1.5",
|
||||
chat_scene="chat_data",
|
||||
sub_chat_scene="subscene1",
|
||||
prompt_type="common",
|
||||
user_name="zhangsan",
|
||||
sys_code="dbgpt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_manager(in_memory_storage):
|
||||
return PromptManager(storage=in_memory_storage)
|
||||
|
||||
|
||||
class TestPromptTemplate:
|
||||
@pytest.mark.parametrize(
|
||||
"template_str, input_vars, expected_output",
|
||||
[
|
||||
("Hello {name}", {"name": "World"}, "Hello World"),
|
||||
("{greeting}, {name}", {"greeting": "Hi", "name": "Alice"}, "Hi, Alice"),
|
||||
],
|
||||
)
|
||||
def test_format_f_string(self, template_str, input_vars, expected_output):
|
||||
prompt = PromptTemplate(
|
||||
template=template_str,
|
||||
input_variables=list(input_vars.keys()),
|
||||
template_format="f-string",
|
||||
)
|
||||
formatted_output = prompt.format(**input_vars)
|
||||
assert formatted_output == expected_output
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"template_str, input_vars, expected_output",
|
||||
[
|
||||
("Hello {{ name }}", {"name": "World"}, "Hello World"),
|
||||
(
|
||||
"{{ greeting }}, {{ name }}",
|
||||
{"greeting": "Hi", "name": "Alice"},
|
||||
"Hi, Alice",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_format_jinja2(self, template_str, input_vars, expected_output):
|
||||
prompt = PromptTemplate(
|
||||
template=template_str,
|
||||
input_variables=list(input_vars.keys()),
|
||||
template_format="jinja2",
|
||||
)
|
||||
formatted_output = prompt.format(**input_vars)
|
||||
assert formatted_output == expected_output
|
||||
|
||||
def test_format_with_response_format(self):
|
||||
template_str = "Response: {response}"
|
||||
prompt = PromptTemplate(
|
||||
template=template_str,
|
||||
input_variables=["response"],
|
||||
template_format="f-string",
|
||||
response_format=json.dumps({"message": "hello"}),
|
||||
)
|
||||
formatted_output = prompt.format(response="hello")
|
||||
assert "Response: " in formatted_output
|
||||
|
||||
def test_from_template(self):
|
||||
template_str = "Hello {name}"
|
||||
prompt = PromptTemplate.from_template(template_str)
|
||||
assert prompt._prompt_template.template == template_str
|
||||
assert prompt._prompt_template.input_variables == []
|
||||
|
||||
def test_format_missing_variable(self):
|
||||
template_str = "Hello {name}"
|
||||
prompt = PromptTemplate(
|
||||
template=template_str, input_variables=["name"], template_format="f-string"
|
||||
)
|
||||
with pytest.raises(KeyError):
|
||||
prompt.format()
|
||||
|
||||
def test_format_extra_variable(self):
|
||||
template_str = "Hello {name}"
|
||||
prompt = PromptTemplate(
|
||||
template=template_str,
|
||||
input_variables=["name"],
|
||||
template_format="f-string",
|
||||
template_is_strict=False,
|
||||
)
|
||||
formatted_output = prompt.format(name="World", extra="unused")
|
||||
assert formatted_output == "Hello World"
|
||||
|
||||
def test_format_complex(self, complex_storage_prompt_template):
|
||||
prompt = complex_storage_prompt_template.to_prompt_template()
|
||||
formatted_output = prompt.format(
|
||||
db_name="db1",
|
||||
table_info="create table users(id int, name varchar(20))",
|
||||
user_input="find all users whose name is 'Alice'",
|
||||
)
|
||||
assert (
|
||||
formatted_output
|
||||
== "Database name: db1 Table structure definition: create table users(id int, name varchar(20)) "
|
||||
"User Question:find all users whose name is 'Alice'"
|
||||
)
|
||||
|
||||
|
||||
class TestStoragePromptTemplate:
|
||||
def test_constructor_and_properties(self):
|
||||
storage_item = StoragePromptTemplate(
|
||||
prompt_name="test",
|
||||
content="Hello {name}",
|
||||
prompt_language="en",
|
||||
prompt_format="f-string",
|
||||
input_variables="name",
|
||||
model="model1",
|
||||
chat_scene="chat",
|
||||
sub_chat_scene="sub_chat",
|
||||
prompt_type="type",
|
||||
user_name="user",
|
||||
sys_code="sys",
|
||||
)
|
||||
assert storage_item.prompt_name == "test"
|
||||
assert storage_item.content == "Hello {name}"
|
||||
assert storage_item.prompt_language == "en"
|
||||
assert storage_item.prompt_format == "f-string"
|
||||
assert storage_item.input_variables == "name"
|
||||
assert storage_item.model == "model1"
|
||||
|
||||
def test_constructor_exceptions(self):
|
||||
with pytest.raises(ValueError):
|
||||
StoragePromptTemplate(prompt_name=None, content="Hello")
|
||||
|
||||
def test_to_prompt_template(self, sample_storage_prompt_template):
|
||||
prompt_template = sample_storage_prompt_template.to_prompt_template()
|
||||
assert isinstance(prompt_template, PromptTemplate)
|
||||
assert prompt_template.template == "Sample content, {var1}, {var2}"
|
||||
assert prompt_template.input_variables == ["var1", "var2"]
|
||||
|
||||
def test_from_prompt_template(self):
|
||||
prompt_template = PromptTemplate(
|
||||
template="Sample content, {var1}, {var2}",
|
||||
input_variables=["var1", "var2"],
|
||||
template_format="f-string",
|
||||
)
|
||||
storage_prompt_template = StoragePromptTemplate.from_prompt_template(
|
||||
prompt_template=prompt_template, prompt_name="test_prompt"
|
||||
)
|
||||
assert storage_prompt_template.prompt_name == "test_prompt"
|
||||
assert storage_prompt_template.content == "Sample content, {var1}, {var2}"
|
||||
assert storage_prompt_template.input_variables == "var1,var2"
|
||||
|
||||
def test_merge(self, sample_storage_prompt_template):
|
||||
other = StoragePromptTemplate(
|
||||
prompt_name="other_prompt",
|
||||
content="Other content",
|
||||
)
|
||||
sample_storage_prompt_template.merge(other)
|
||||
assert sample_storage_prompt_template.content == "Other content"
|
||||
|
||||
def test_to_dict(self, sample_storage_prompt_template):
|
||||
result = sample_storage_prompt_template.to_dict()
|
||||
assert result == {
|
||||
"prompt_name": "test_prompt",
|
||||
"content": "Sample content, {var1}, {var2}",
|
||||
"prompt_language": "en",
|
||||
"prompt_format": "f-string",
|
||||
"input_variables": "var1,var2",
|
||||
"model": "model1",
|
||||
"chat_scene": "scene1",
|
||||
"sub_chat_scene": "subscene1",
|
||||
"prompt_type": "type1",
|
||||
"user_name": "user1",
|
||||
"sys_code": "code1",
|
||||
}
|
||||
|
||||
def test_save_and_load_storage(
|
||||
self, sample_storage_prompt_template, in_memory_storage
|
||||
):
|
||||
in_memory_storage.save(sample_storage_prompt_template)
|
||||
loaded_item = in_memory_storage.load(
|
||||
sample_storage_prompt_template.identifier, StoragePromptTemplate
|
||||
)
|
||||
assert loaded_item.content == "Sample content, {var1}, {var2}"
|
||||
|
||||
def test_check_exceptions(self):
|
||||
with pytest.raises(ValueError):
|
||||
StoragePromptTemplate(prompt_name=None, content="Hello")
|
||||
|
||||
def test_from_object(self, sample_storage_prompt_template):
|
||||
other = StoragePromptTemplate(prompt_name="other", content="Other content")
|
||||
sample_storage_prompt_template.from_object(other)
|
||||
assert sample_storage_prompt_template.content == "Other content"
|
||||
assert sample_storage_prompt_template.input_variables != "var1,var2"
|
||||
# Prompt name should not be changed
|
||||
assert sample_storage_prompt_template.prompt_name == "test_prompt"
|
||||
assert sample_storage_prompt_template.sys_code == "code1"
|
||||
|
||||
|
||||
class TestPromptManager:
|
||||
def test_save(self, prompt_manager, in_memory_storage):
|
||||
prompt_template = PromptTemplate(
|
||||
template="hello {input}",
|
||||
input_variables=["input"],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template,
|
||||
prompt_name="hello",
|
||||
)
|
||||
result = in_memory_storage.query(
|
||||
QuerySpec(conditions={"prompt_name": "hello"}), StoragePromptTemplate
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "hello {input}"
|
||||
|
||||
def test_prefer_query_simple(self, prompt_manager, in_memory_storage):
|
||||
in_memory_storage.save(
|
||||
StoragePromptTemplate(prompt_name="test_prompt", content="test")
|
||||
)
|
||||
result = prompt_manager.prefer_query("test_prompt")
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "test"
|
||||
|
||||
def test_prefer_query_language(self, prompt_manager, in_memory_storage):
|
||||
for language in ["en", "zh"]:
|
||||
in_memory_storage.save(
|
||||
StoragePromptTemplate(
|
||||
prompt_name="test_prompt",
|
||||
content="test",
|
||||
prompt_language=language,
|
||||
)
|
||||
)
|
||||
# Prefer zh, and zh exists, will return zh prompt template
|
||||
result = prompt_manager.prefer_query("test_prompt", prefer_prompt_language="zh")
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "test"
|
||||
assert result[0].prompt_language == "zh"
|
||||
# Prefer language not exists, will return all prompt templates of this name
|
||||
result = prompt_manager.prefer_query(
|
||||
"test_prompt", prefer_prompt_language="not_exist"
|
||||
)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_prefer_query_model(self, prompt_manager, in_memory_storage):
|
||||
for model in ["model1", "model2"]:
|
||||
in_memory_storage.save(
|
||||
StoragePromptTemplate(
|
||||
prompt_name="test_prompt", content="test", model=model
|
||||
)
|
||||
)
|
||||
# Prefer model1, and model1 exists, will return model1 prompt template
|
||||
result = prompt_manager.prefer_query("test_prompt", prefer_model="model1")
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "test"
|
||||
assert result[0].model == "model1"
|
||||
# Prefer model not exists, will return all prompt templates of this name
|
||||
result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist")
|
||||
assert len(result) == 2
|
||||
|
||||
def test_list(self, prompt_manager, in_memory_storage):
|
||||
prompt_manager.save(
|
||||
PromptTemplate(template="Hello {name}", input_variables=["name"]),
|
||||
prompt_name="name1",
|
||||
)
|
||||
prompt_manager.save(
|
||||
PromptTemplate(
|
||||
template="Write a SQL of {dialect} to query all data of {table_name}.",
|
||||
input_variables=["dialect", "table_name"],
|
||||
),
|
||||
prompt_name="sql_template",
|
||||
)
|
||||
all_templates = prompt_manager.list()
|
||||
assert len(all_templates) == 2
|
||||
assert len(prompt_manager.list(prompt_name="name1")) == 1
|
||||
assert len(prompt_manager.list(prompt_name="not exist")) == 0
|
||||
|
||||
def test_delete(self, prompt_manager, in_memory_storage):
|
||||
prompt_manager.save(
|
||||
PromptTemplate(template="Hello {name}", input_variables=["name"]),
|
||||
prompt_name="to_delete",
|
||||
)
|
||||
prompt_manager.delete("to_delete")
|
||||
result = in_memory_storage.query(
|
||||
QuerySpec(conditions={"prompt_name": "to_delete"}), StoragePromptTemplate
|
||||
)
|
||||
assert len(result) == 0
|
@@ -1,10 +1,12 @@
|
||||
import pytest
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
QuerySpec,
|
||||
ResourceIdentifier,
|
||||
StorageError,
|
||||
QuerySpec,
|
||||
InMemoryStorage,
|
||||
StorageItem,
|
||||
)
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
dbgpt/core/operator/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
from dbgpt.core.interface.operator.llm_operator import (
|
||||
BaseLLM,
|
||||
LLMBranchOperator,
|
||||
LLMOperator,
|
||||
RequestBuildOperator,
|
||||
StreamingLLMOperator,
|
||||
)
|
||||
from dbgpt.core.interface.operator.message_operator import (
|
||||
BaseConversationOperator,
|
||||
BufferedConversationMapperOperator,
|
||||
ConversationMapperOperator,
|
||||
PostConversationOperator,
|
||||
PostStreamingConversationOperator,
|
||||
PreConversationOperator,
|
||||
)
|
||||
from dbgpt.core.interface.prompt import PromptTemplateOperator
|
||||
|
||||
__ALL__ = [
|
||||
"BaseLLM",
|
||||
"LLMBranchOperator",
|
||||
"LLMOperator",
|
||||
"RequestBuildOperator",
|
||||
"StreamingLLMOperator",
|
||||
"BaseConversationOperator",
|
||||
"BufferedConversationMapperOperator",
|
||||
"ConversationMapperOperator",
|
||||
"PostConversationOperator",
|
||||
"PostStreamingConversationOperator",
|
||||
"PreConversationOperator",
|
||||
"PromptTemplateOperator",
|
||||
]
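# Rough usage sketch (assumptions: an AWEL DAG context and an LLM client are
# available; the constructor arguments below are illustrative, not exact):
#
#   from dbgpt.core.awel import DAG
#   from dbgpt.core.operator import RequestBuildOperator, LLMOperator
#
#   with DAG("simple_llm_dag"):
#       request_build = RequestBuildOperator(model="vicuna-13b-v1.5")
#       llm = LLMOperator(llm_client=my_llm_client)
#       request_build >> llm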
|
@@ -1,4 +1,13 @@
|
||||
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||
from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
|
||||
from dbgpt.model.utils.chatgpt_utils import (
|
||||
OpenAILLMClient,
|
||||
OpenAIStreamingOperator,
|
||||
MixinLLMOperator,
|
||||
)
|
||||
|
||||
__ALL__ = ["DefaultLLMClient", "OpenAILLMClient"]
|
||||
__ALL__ = [
|
||||
"DefaultLLMClient",
|
||||
"OpenAILLMClient",
|
||||
"OpenAIStreamingOperator",
|
||||
"MixinLLMOperator",
|
||||
]
|
||||
|
@@ -171,7 +171,7 @@ class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
|
||||
self._model_task_name = model_task_name
|
||||
self._cache_task_name = cache_task_name
|
||||
|
||||
async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
|
||||
async def branches(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
|
||||
"""Defines branch logic based on cache availability.
|
||||
|
||||
Returns:
|
||||
@@ -233,7 +233,7 @@ class ModelStreamSaveCacheOperator(
|
||||
outputs = []
|
||||
async for out in input_value:
|
||||
if not llm_cache_key:
|
||||
llm_cache_key = await self.current_dag_context.get_share_data(
|
||||
llm_cache_key = await self.current_dag_context.get_from_share_data(
|
||||
_LLM_MODEL_INPUT_VALUE_KEY
|
||||
)
|
||||
outputs.append(out)
|
||||
@@ -265,7 +265,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
|
||||
Returns:
|
||||
ModelOutput: The same input model output.
|
||||
"""
|
||||
llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
|
||||
llm_cache_key: LLMCacheKey = await self.current_dag_context.get_from_share_data(
|
||||
_LLM_MODEL_INPUT_VALUE_KEY
|
||||
)
|
||||
llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
|
||||
|
@@ -3,11 +3,27 @@ from __future__ import annotations
|
||||
import os
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC
|
||||
import importlib.metadata as metadata
|
||||
from typing import List, Dict, Any, Optional, TYPE_CHECKING, Union, AsyncIterator
|
||||
from typing import (
|
||||
List,
|
||||
Dict,
|
||||
Any,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Awaitable,
|
||||
)
|
||||
|
||||
from dbgpt.component import ComponentType
|
||||
from dbgpt.core.operator import BaseLLM
|
||||
from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator
|
||||
from dbgpt.core.interface.llm import ModelMetadata, LLMClient
|
||||
from dbgpt.core.interface.llm import ModelOutput, ModelRequest
|
||||
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import httpx
|
||||
@@ -176,13 +192,13 @@ class OpenAILLMClient(LLMClient):
|
||||
self, request: ModelRequest
|
||||
) -> AsyncIterator[ModelOutput]:
|
||||
messages = request.to_openai_messages()
|
||||
payload = self._build_request(request)
|
||||
payload = self._build_request(request, True)
|
||||
try:
|
||||
chat_completion = await self.client.chat.completions.create(
|
||||
messages=messages, **payload
|
||||
)
|
||||
text = ""
|
||||
for r in chat_completion:
|
||||
async for r in chat_completion:
|
||||
if len(r.choices) == 0:
|
||||
continue
|
||||
if r.choices[0].delta.content is not None:
|
||||
@@ -221,17 +237,74 @@ class OpenAILLMClient(LLMClient):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
|
||||
"""Transform ModelOutput to openai stream format."""
|
||||
|
||||
async def transform_stream(
|
||||
self, input_value: AsyncIterator[ModelOutput]
|
||||
) -> AsyncIterator[str]:
|
||||
async def model_caller() -> str:
|
||||
"""Read model name from share data.
|
||||
In streaming mode, this transform_stream function will be executed
|
||||
before the parent operator (a streaming operator is triggered by its downstream operator).
|
||||
"""
|
||||
return await self.current_dag_context.get_from_share_data(
|
||||
BaseLLM.SHARE_DATA_KEY_MODEL_NAME
|
||||
)
|
||||
|
||||
async for output in _to_openai_stream(input_value, None, model_caller):
|
||||
yield output
|
||||
|
||||
|
||||
class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
|
||||
"""Mixin class for LLM operator.
|
||||
|
||||
This class extends BaseOperator by adding LLM capabilities.
|
||||
"""
|
||||
|
||||
def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(default_client)
|
||||
self._default_llm_client = default_client
|
||||
|
||||
@property
|
||||
def llm_client(self) -> LLMClient:
|
||||
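# Lazily resolve the client: prefer a registered worker manager factory from the
# system app; otherwise fall back to the provided default client (an OpenAILLMClient
# if none was given).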
if not self._llm_client:
|
||||
worker_manager_factory: WorkerManagerFactory = (
|
||||
self.system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY,
|
||||
WorkerManagerFactory,
|
||||
default_component=None,
|
||||
)
|
||||
)
|
||||
if worker_manager_factory:
|
||||
self._llm_client = DefaultLLMClient(worker_manager_factory.create())
|
||||
else:
|
||||
if self._default_llm_client is None:
|
||||
from dbgpt.model import OpenAILLMClient
|
||||
|
||||
self._default_llm_client = OpenAILLMClient()
|
||||
logger.info(
|
||||
f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
|
||||
)
|
||||
self._llm_client = self._default_llm_client
|
||||
return self._llm_client
|
||||
|
||||
|
||||
async def _to_openai_stream(
|
||||
model: str, output_iter: AsyncIterator[ModelOutput]
|
||||
output_iter: AsyncIterator[ModelOutput],
|
||||
model: Optional[str] = None,
|
||||
model_caller: Callable[[], Union[Awaitable[str], str]] = None,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Convert the output_iter to openai stream format.
|
||||
|
||||
Args:
|
||||
model (str): The model name.
|
||||
output_iter (AsyncIterator[ModelOutput]): The output iterator.
|
||||
model (Optional[str], optional): The model name. Defaults to None.
|
||||
model_caller (Callable[[], Union[Awaitable[str], str]], optional): Callable used to resolve the model name lazily. Defaults to None.
|
||||
"""
|
||||
import json
|
||||
import shortuuid
|
||||
import asyncio
|
||||
from fastchat.protocol.openai_api_protocol import (
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
@@ -245,12 +318,19 @@ async def _to_openai_stream(
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=id, choices=[choice_data], model=model or ""
|
||||
)
|
||||
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
|
||||
|
||||
previous_text = ""
|
||||
finish_stream_events = []
|
||||
async for model_output in output_iter:
|
||||
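# Resolve the model name lazily; model_caller may be either a sync or an async callable.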
if model_caller is not None:
|
||||
if asyncio.iscoroutinefunction(model_caller):
|
||||
model = await model_caller()
|
||||
else:
|
||||
model = model_caller()
|
||||
model_output: ModelOutput = model_output
|
||||
if model_output.error_code != 0:
|
||||
yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
|
||||
|
@@ -1,15 +1,16 @@
|
||||
from typing import Optional, List
|
||||
from functools import cache
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result
|
||||
from dbgpt.util import PaginationResult
|
||||
from .schemas import ServeRequest, ServerResponse
|
||||
|
||||
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME
|
||||
from .schemas import ServeRequest, ServerResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@@ -1,6 +1,8 @@
|
||||
# Define your Pydantic schemas here
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
|
||||
|
@@ -1,9 +1,8 @@
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.serve.core import BaseServeConfig
|
||||
|
||||
|
||||
APP_NAME = "prompt"
|
||||
SERVE_APP_NAME = "dbgpt_serve_prompt"
|
||||
SERVE_APP_NAME_HUMP = "dbgpt_serve_Prompt"
|
||||
|
@@ -1,33 +1,64 @@
|
||||
"""This is an auto-generated model file
|
||||
You can define your own models and DAOs here
|
||||
"""
|
||||
from typing import Union, Any, Dict
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Index, Text, DateTime, UniqueConstraint
|
||||
from dbgpt.storage.metadata import Model, BaseDao, db
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model, db
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import ServeConfig, SERVER_APP_TABLE_NAME
|
||||
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
|
||||
|
||||
|
||||
class ServeEntity(Model):
|
||||
__tablename__ = "prompt_manage"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("prompt_name", "sys_code", name="uk_prompt_name_sys_code"),
|
||||
UniqueConstraint(
|
||||
"prompt_name",
|
||||
"sys_code",
|
||||
"prompt_language",
|
||||
"model",
|
||||
name="uk_prompt_name_sys_code",
|
||||
),
|
||||
)
|
||||
id = Column(Integer, primary_key=True, comment="Auto increment id")
|
||||
|
||||
chat_scene = Column(String(100))
|
||||
sub_chat_scene = Column(String(100))
|
||||
prompt_type = Column(String(100))
|
||||
prompt_name = Column(String(512))
|
||||
content = Column(Text)
|
||||
user_name = Column(String(128))
|
||||
chat_scene = Column(String(100), comment="Chat scene")
|
||||
sub_chat_scene = Column(String(100), comment="Sub chat scene")
|
||||
prompt_type = Column(String(100), comment="Prompt type(eg: common, private)")
|
||||
prompt_name = Column(String(256), comment="Prompt name")
|
||||
content = Column(Text, comment="Prompt content")
|
||||
input_variables = Column(
|
||||
String(1024), nullable=True, comment="Prompt input variables(split by comma))"
|
||||
)
|
||||
model = Column(
|
||||
String(128),
|
||||
nullable=True,
|
||||
comment="Prompt model name(we can use different models for different prompt",
|
||||
)
|
||||
prompt_language = Column(
|
||||
String(32), index=True, nullable=True, comment="Prompt language(eg:en, zh-cn)"
|
||||
)
|
||||
prompt_format = Column(
|
||||
String(32),
|
||||
index=True,
|
||||
nullable=True,
|
||||
default="f-string",
|
||||
comment="Prompt format(eg: f-string, jinja2)",
|
||||
)
|
||||
user_name = Column(String(128), index=True, nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
|
||||
|
||||
def __repr__(self):
|
||||
return f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
return (
|
||||
f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', "
|
||||
f"prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',"
|
||||
f"user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
)
|
||||
|
||||
|
||||
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
|
dbgpt/serve/prompt/models/prompt_template_adapter.py (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dbgpt.core.interface.prompt import PromptTemplateIdentifier, StoragePromptTemplate
|
||||
from dbgpt.core.interface.storage import StorageItemAdapter
|
||||
|
||||
from .models import ServeEntity
|
||||
|
||||
|
||||
class PromptTemplateAdapter(StorageItemAdapter[StoragePromptTemplate, ServeEntity]):
|
||||
def to_storage_format(self, item: StoragePromptTemplate) -> ServeEntity:
|
||||
return ServeEntity(
|
||||
chat_scene=item.chat_scene,
|
||||
sub_chat_scene=item.sub_chat_scene,
|
||||
prompt_type=item.prompt_type,
|
||||
prompt_name=item.prompt_name,
|
||||
content=item.content,
|
||||
input_variables=item.input_variables,
|
||||
model=item.model,
|
||||
prompt_language=item.prompt_language,
|
||||
prompt_format=item.prompt_format,
|
||||
user_name=item.user_name,
|
||||
sys_code=item.sys_code,
|
||||
)
|
||||
|
||||
def from_storage_format(self, model: ServeEntity) -> StoragePromptTemplate:
|
||||
return StoragePromptTemplate(
|
||||
chat_scene=model.chat_scene,
|
||||
sub_chat_scene=model.sub_chat_scene,
|
||||
prompt_type=model.prompt_type,
|
||||
prompt_name=model.prompt_name,
|
||||
content=model.content,
|
||||
input_variables=model.input_variables,
|
||||
model=model.model,
|
||||
prompt_language=model.prompt_language,
|
||||
prompt_format=model.prompt_format,
|
||||
user_name=model.user_name,
|
||||
sys_code=model.sys_code,
|
||||
)
|
||||
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: Type[ServeEntity],
|
||||
resource_id: PromptTemplateIdentifier,
|
||||
**kwargs,
|
||||
):
|
||||
session: Session = kwargs.get("session")
|
||||
if session is None:
|
||||
raise Exception("session is None")
|
||||
query_obj = session.query(ServeEntity)
|
||||
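# Filter by every non-None field of the identifier (prompt_name, prompt_language, sys_code, model).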
for key, value in resource_id.to_dict().items():
|
||||
if value is None:
|
||||
continue
|
||||
query_obj = query_obj.filter(getattr(ServeEntity, key) == value)
|
||||
return query_obj
|
@@ -1,17 +1,80 @@
|
||||
from typing import List, Optional
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from .api.endpoints import router, init_endpoints
|
||||
from sqlalchemy import URL
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.core import PromptManager
|
||||
|
||||
from ...storage.metadata import DatabaseManager
|
||||
from .api.endpoints import init_endpoints, router
|
||||
from .config import (
|
||||
APP_NAME,
|
||||
SERVE_APP_NAME,
|
||||
SERVE_APP_NAME_HUMP,
|
||||
APP_NAME,
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
ServeConfig,
|
||||
)
|
||||
from .models.prompt_template_adapter import PromptTemplateAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Serve(BaseComponent):
|
||||
"""Serve component
|
||||
|
||||
Examples:
|
||||
|
||||
Register the serve component to the system app
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt import SystemApp
|
||||
from dbgpt.core import PromptTemplate
|
||||
from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME
|
||||
|
||||
app = FastAPI()
|
||||
system_app = SystemApp(app)
|
||||
system_app.register(Serve, api_prefix="/api/v1/prompt")
|
||||
# Run before start hook
|
||||
system_app.before_start()
|
||||
|
||||
prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve)
|
||||
|
||||
# Get the prompt manager
|
||||
prompt_manager = prompt_serve.prompt_manager
|
||||
prompt_manager.save(
|
||||
PromptTemplate(template="Hello {name}", input_variables=["name"]),
|
||||
prompt_name="prompt_name",
|
||||
)
|
||||
|
||||
With your own database URL.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt import SystemApp
|
||||
from dbgpt.core import PromptTemplate
|
||||
from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME
|
||||
|
||||
app = FastAPI()
|
||||
system_app = SystemApp(app)
|
||||
system_app.register(Serve, api_prefix="/api/v1/prompt", db_url_or_db="sqlite:///:memory:", try_create_tables=True)
|
||||
# Run before start hook
|
||||
system_app.before_start()
|
||||
|
||||
prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve)
|
||||
|
||||
# Get the prompt manager
|
||||
prompt_manager = prompt_serve.prompt_manager
|
||||
prompt_manager.save(
|
||||
PromptTemplate(template="Hello {name}", input_variables=["name"]),
|
||||
prompt_name="prompt_name",
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
name = SERVE_APP_NAME
|
||||
|
||||
def __init__(
|
||||
@@ -19,12 +82,17 @@ class Serve(BaseComponent):
|
||||
system_app: SystemApp,
|
||||
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
|
||||
tags: Optional[List[str]] = None,
|
||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||
try_create_tables: Optional[bool] = False,
|
||||
):
|
||||
if tags is None:
|
||||
tags = [SERVE_APP_NAME_HUMP]
|
||||
self._system_app = None
|
||||
self._api_prefix = api_prefix
|
||||
self._tags = tags
|
||||
self._prompt_manager = None
|
||||
self._db_url_or_db = db_url_or_db
|
||||
self._try_create_tables = try_create_tables
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
self._system_app = system_app
|
||||
@@ -33,10 +101,37 @@ class Serve(BaseComponent):
|
||||
)
|
||||
init_endpoints(self._system_app)
|
||||
|
||||
@property
|
||||
def prompt_manager(self) -> PromptManager:
|
||||
"""Get the prompt manager of the serve app with db storage"""
|
||||
return self._prompt_manager
|
||||
|
||||
def before_start(self):
|
||||
"""Called before the start of the application.
|
||||
|
||||
You can do some initialization here.
|
||||
"""
|
||||
# import your own module here to ensure the module is loaded before the application starts
|
||||
from dbgpt.core.interface.prompt import PromptManager
|
||||
from dbgpt.storage.metadata import Model, db
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from .models.models import ServeEntity
|
||||
|
||||
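# Build the DB-backed prompt storage: reuse the global db unless a custom URL or DatabaseManager was provided.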
init_db = self._db_url_or_db or db
|
||||
init_db = DatabaseManager.build_from(init_db, base=Model)
|
||||
if self._try_create_tables:
|
||||
try:
|
||||
init_db.create_all()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create tables: {e}")
|
||||
storage_adapter = PromptTemplateAdapter()
|
||||
serializer = JsonSerializer()
|
||||
storage = SQLAlchemyStorage(
|
||||
init_db,
|
||||
ServeEntity,
|
||||
storage_adapter,
|
||||
serializer,
|
||||
)
|
||||
self._prompt_manager = PromptManager(storage)
|
||||
|
@@ -1,11 +1,13 @@
|
||||
from typing import Optional, List
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
from dbgpt.serve.core import BaseService
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
|
||||
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
|
@@ -1,15 +1,15 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core.tests.conftest import asystem_app, client
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.util import PaginationResult
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
from ..api.endpoints import router, init_endpoints
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
|
||||
from dbgpt.serve.core.tests.conftest import client, asystem_app
|
||||
from ..api.endpoints import init_endpoints, router
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@@ -1,9 +1,12 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.storage.metadata import db
|
||||
from ..config import ServeConfig
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity, ServeDao
|
||||
from ..config import ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -34,6 +37,8 @@ def default_entity_dict():
|
||||
"content": "Write a qsort function in python.",
|
||||
"user_name": "zhangsan",
|
||||
"sys_code": "dbgpt",
|
||||
"prompt_language": "zh",
|
||||
"model": "vicuna-13b-v1.5",
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +65,14 @@ def test_entity_create(default_entity_dict):
|
||||
def test_entity_unique_key(default_entity_dict):
|
||||
ServeEntity.create(**default_entity_dict)
|
||||
with pytest.raises(Exception):
|
||||
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
|
||||
ServeEntity.create(
|
||||
**{
|
||||
"prompt_name": "my_prompt_1",
|
||||
"sys_code": "dbgpt",
|
||||
"prompt_language": "zh",
|
||||
"model": "vicuna-13b-v1.5",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_entity_get(default_entity_dict):
|
||||
|
dbgpt/serve/prompt/tests/test_prompt_template_adapter.py (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.prompt import PromptManager, PromptTemplate
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from ..models.prompt_template_adapter import PromptTemplateAdapter, ServeEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serializer():
|
||||
return JsonSerializer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_url():
|
||||
"""Use in-memory SQLite database for testing"""
|
||||
return "sqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_manager(db_url):
|
||||
db.init_db(db_url)
|
||||
db.create_all()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage_adapter():
|
||||
return PromptTemplateAdapter()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage(db_manager, serializer, storage_adapter):
|
||||
storage = SQLAlchemyStorage(
|
||||
db_manager,
|
||||
ServeEntity,
|
||||
storage_adapter,
|
||||
serializer,
|
||||
)
|
||||
return storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_manager(storage):
|
||||
return PromptManager(storage)
|
||||
|
||||
|
||||
def test_save(prompt_manager: PromptManager):
|
||||
prompt_template = PromptTemplate(
|
||||
template="hello {input}",
|
||||
input_variables=["input"],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template,
|
||||
prompt_name="hello",
|
||||
)
|
||||
|
||||
with db.session() as session:
|
||||
# Query from database
|
||||
result = (
|
||||
session.query(ServeEntity).filter(ServeEntity.prompt_name == "hello").all()
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].prompt_name == "hello"
|
||||
assert result[0].content == "hello {input}"
|
||||
assert result[0].input_variables == "input"
|
||||
with db.session() as session:
|
||||
assert session.query(ServeEntity).count() == 1
|
||||
assert (
|
||||
session.query(ServeEntity)
|
||||
.filter(ServeEntity.prompt_name == "not exist prompt name")
|
||||
.count()
|
||||
== 0
|
||||
)
|
||||
|
||||
|
||||
def test_prefer_query_language(prompt_manager: PromptManager):
|
||||
for language in ["en", "zh"]:
|
||||
prompt_template = PromptTemplate(
|
||||
template="test",
|
||||
input_variables=[],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template,
|
||||
prompt_name="test_prompt",
|
||||
prompt_language=language,
|
||||
)
|
||||
# Prefer zh, and zh exists, will return zh prompt template
|
||||
result = prompt_manager.prefer_query("test_prompt", prefer_prompt_language="zh")
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "test"
|
||||
assert result[0].prompt_language == "zh"
|
||||
# Prefer language not exists, will return all prompt templates of this name
|
||||
result = prompt_manager.prefer_query(
|
||||
"test_prompt", prefer_prompt_language="not_exist"
|
||||
)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_prefer_query_model(prompt_manager: PromptManager):
|
||||
for model in ["model1", "model2"]:
|
||||
prompt_template = PromptTemplate(
|
||||
template="test",
|
||||
input_variables=[],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template,
|
||||
prompt_name="test_prompt",
|
||||
model=model,
|
||||
)
|
||||
# Prefer model1, and model1 exists, will return model1 prompt template
|
||||
result = prompt_manager.prefer_query("test_prompt", prefer_model="model1")
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "test"
|
||||
assert result[0].model == "model1"
|
||||
# Prefer model not exists, will return all prompt templates of this name
|
||||
result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist")
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_list(prompt_manager: PromptManager):
|
||||
for i in range(10):
|
||||
prompt_template = PromptTemplate(
|
||||
template="test",
|
||||
input_variables=[],
|
||||
template_scene="chat_normal",
|
||||
)
|
||||
prompt_manager.save(
|
||||
prompt_template,
|
||||
prompt_name=f"test_prompt_{i}",
|
||||
sys_code="dbgpt" if i % 2 == 0 else "not_dbgpt",
|
||||
)
|
||||
# Test list all
|
||||
result = prompt_manager.list()
|
||||
assert len(result) == 10
|
||||
|
||||
for i in range(10):
|
||||
assert len(prompt_manager.list(prompt_name=f"test_prompt_{i}")) == 1
|
||||
assert len(prompt_manager.list(sys_code="dbgpt")) == 5
|
@@ -1,11 +1,13 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
|
||||
from ..models.models import ServeEntity
|
||||
import pytest
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
from dbgpt.storage.metadata import db
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity
|
||||
from ..service.service import Service
|
||||
|
||||
|
||||
|
@@ -236,6 +236,7 @@ class DatabaseManager:
|
||||
engine_args: Optional[Dict] = None,
|
||||
base: Optional[DeclarativeMeta] = None,
|
||||
query_class=BaseQuery,
|
||||
override_query_class: Optional[bool] = False,
|
||||
):
|
||||
"""Initialize the database manager.
|
||||
|
||||
@@ -244,15 +245,16 @@ class DatabaseManager:
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||
"""
|
||||
self._db_url = db_url
|
||||
if query_class is not None:
|
||||
self.Query = query_class
|
||||
if base is not None:
|
||||
self._base = base
|
||||
if not hasattr(base, "query"):
|
||||
if not hasattr(base, "query") or override_query_class:
|
||||
base.query = _QueryObject(self)
|
||||
if not getattr(base, "query_class", None):
|
||||
if not getattr(base, "query_class", None) or override_query_class:
|
||||
base.query_class = self.Query
|
||||
self._engine = create_engine(db_url, **(engine_args or {}))
|
||||
session_factory = sessionmaker(bind=self._engine)
|
||||
@@ -299,6 +301,59 @@ class DatabaseManager:
    def create_all(self):
        self.Model.metadata.create_all(self._engine)

    @staticmethod
    def build_from(
        db_url_or_db: Union[str, URL, DatabaseManager],
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
        query_class=BaseQuery,
        override_query_class: Optional[bool] = False,
    ) -> DatabaseManager:
        """Build the database manager from the db_url_or_db.

        Examples:

            Build from the database url.

            .. code-block:: python

                from dbgpt.storage.metadata import DatabaseManager
                from sqlalchemy import Column, Integer, String
                db = DatabaseManager.build_from("sqlite:///:memory:")
                class User(db.Model):
                    __tablename__ = "user"
                    id = Column(Integer, primary_key=True)
                    name = Column(String(50))
                    fullname = Column(String(50))
                db.create_all()
                with db.session() as session:
                    session.add(User(name="test", fullname="test"))
                    session.commit()
                    print(User.query.filter(User.name == "test").all())

        Args:
            db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager.
            engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
            base (Optional[DeclarativeMeta]): The base class. Defaults to None.
            query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
            override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.

        Returns:
            DatabaseManager: The database manager.
        """
        if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
            db_manager = DatabaseManager()
            db_manager.init_db(
                db_url_or_db, engine_args, base, query_class, override_query_class
            )
            return db_manager
        elif isinstance(db_url_or_db, DatabaseManager):
            return db_url_or_db
        else:
            raise ValueError(
                f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
            )


db = DatabaseManager()
"""The global database manager.
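
The docstring example above only covers building from a URL. The elif branch of build_from returns an already-initialized DatabaseManager unchanged, so several components can share a single engine. A minimal sketch of that pass-through, assuming an in-memory SQLite URL as the starting point:

.. code-block:: python

    from dbgpt.storage.metadata import DatabaseManager

    # Build once from a URL...
    shared_db = DatabaseManager.build_from("sqlite:///:memory:")

    # ...then pass the manager itself; build_from returns the same instance,
    # so no second engine is created.
    same_db = DatabaseManager.build_from(shared_db)
    assert same_db is shared_db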
@@ -375,14 +430,21 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
    class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
        """Mixin that adds convenience methods for CRUD (create, read, update, delete)"""

        _db_manager: DatabaseManager = db_manager

        @classmethod
        def set_db_manager(cls, db_manager: DatabaseManager):
            # TODO: It is hard to replace to user DB Connection
            cls._db_manager = db_manager

        @classmethod
        def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
            """Get a record by its primary key identifier."""
            return db_manager._session().get(cls, ident)
            return cls._db_manager._session().get(cls, ident)

        def save(self: T, commit: Optional[bool] = True) -> T:
            """Save the record."""
            session = db_manager._session()
            session = self._db_manager._session()
            session.add(self)
            if commit:
                session.commit()
@@ -390,7 +452,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:

        def delete(self: T, commit: Optional[bool] = True) -> None:
            """Remove the record from the database."""
            session = db_manager._session()
            session = self._db_manager._session()
            session.delete(self)
            return commit and session.commit()

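
With this change the generated model base carries per-record helpers (create, get, save, delete) bound to a class-level _db_manager, and set_db_manager() lets a model class be re-pointed at another manager. A minimal sketch of those helpers, reusing the build_from setup from the docstring above; the create() call follows the usage shown in the tests of this commit:

.. code-block:: python

    from sqlalchemy import Column, Integer, String
    from dbgpt.storage.metadata import DatabaseManager

    db = DatabaseManager.build_from("sqlite:///:memory:")


    class User(db.Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))


    db.create_all()

    User.create(name="John Doe")                          # insert a row, as in the tests below
    user = User.query.filter(User.name == "John Doe").first()
    fetched = User.get(user.id)                           # primary-key lookup via the bound manager
    fetched.name = "Jane Doe"
    fetched.save()                                        # commits by default (commit=True)
    fetched.delete()                                      # removes the row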
@@ -34,15 +34,8 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
        query_class=BaseQuery,
    ):
        super().__init__(serializer=serializer, adapter=adapter)
        if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
            db_manager = DatabaseManager()
            db_manager.init_db(db_url_or_db, engine_args, base, query_class)
            self.db_manager = db_manager
        elif isinstance(db_url_or_db, DatabaseManager):
            self.db_manager = db_url_or_db
        else:
            raise ValueError(
                f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
        self.db_manager = DatabaseManager.build_from(
            db_url_or_db, engine_args, base, query_class
        )
        self._model_class = model_class

@@ -1,5 +1,6 @@
from __future__ import annotations
import pytest
import tempfile
from typing import Type
from dbgpt.storage.metadata.db_manager import (
    DatabaseManager,
@@ -103,7 +104,6 @@ def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):

    db.create_all()

    # Add data
    with db.session() as session:
        for i in range(30):
            user = User(name=f"User {i}")
@@ -127,3 +127,29 @@ def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
        User.query.paginate_query(page=0, per_page=10)
    with pytest.raises(ValueError):
        User.query.paginate_query(page=1, per_page=-1)


def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
    assert db.metadata.tables == {}

    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    with tempfile.NamedTemporaryFile(delete=True) as db_file:
        filename = db_file.name
        new_db = DatabaseManager.build_from(
            f"sqlite:///{filename}", base=Model, override_query_class=True
        )
        Model.set_db_manager(new_db)
        new_db.create_all()
        db.create_all()
        assert list(new_db.metadata.tables.keys())[0] == "user"
        User.create(**{"name": "John Doe"})
        with new_db.session() as session:
            assert session.query(User).filter_by(name="John Doe").first() is not None
        with db.session() as session:
            assert session.query(User).filter_by(name="John Doe").first() is None
        assert len(User.query.all()) == 1
        assert User.query.filter(User.name == "John Doe").first().name == "John Doe"

@@ -6,7 +6,8 @@

    .. code-block:: shell

        curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_chat \
        DBGPT_SERVER="http://127.0.0.1:5000"
        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_chat \
        -H "Content-Type: application/json" -d '{
            "model": "proxyllm",
            "user_input": "hello"
@@ -52,3 +53,14 @@ with DAG("dbgpt_awel_simple_dag_example") as dag:
    # type(out) == ModelOutput
    model_parse_task = MapOperator(lambda out: out.to_dict())
    trigger >> request_handle_task >> model_task >> model_parse_task


if __name__ == "__main__":
    if dag.leaf_nodes[0].dev_mode:
        # Development mode, you can run the dag locally for debugging.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)
    else:
        # Production mode, DB-GPT will automatically load and execute the current file after startup.
        pass
186 examples/awel/simple_chat_history_example.py Normal file
@@ -0,0 +1,186 @@
"""AWEL: Simple chat with history example

DB-GPT will automatically load and execute the current file after startup.

Examples:

    Call with non-streaming response.
    .. code-block:: shell

        DBGPT_SERVER="http://127.0.0.1:5000"
        MODEL="gpt-3.5-turbo"
        # First round
        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
        -H "Content-Type: application/json" -d '{
            "model": "gpt-3.5-turbo",
            "context": {
                "conv_uid": "uuid_conv_1234"
            },
            "messages": "Who is elon musk?"
        }'

        # Second round
        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
        -H "Content-Type: application/json" -d '{
            "model": "gpt-3.5-turbo",
            "context": {
                "conv_uid": "uuid_conv_1234"
            },
            "messages": "Is he rich?"
        }'

    Call with streaming response.
    .. code-block:: shell

        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
        -H "Content-Type: application/json" -d '{
            "model": "gpt-3.5-turbo",
            "context": {
                "conv_uid": "uuid_conv_stream_1234"
            },
            "stream": true,
            "messages": "Who is elon musk?"
        }'

        # Second round
        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
        -H "Content-Type: application/json" -d '{
            "model": "gpt-3.5-turbo",
            "context": {
                "conv_uid": "uuid_conv_stream_1234"
            },
            "stream": true,
            "messages": "Is he rich?"
        }'


"""
from typing import Dict, Any, Optional, Union, List
import logging
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import (
    DAG,
    HttpTrigger,
    MapOperator,
    JoinOperator,
)
from dbgpt.core import LLMClient, InMemoryStorage
from dbgpt.core.operator import (
    LLMBranchOperator,
    LLMOperator,
    StreamingLLMOperator,
    RequestBuildOperator,
    PreConversationOperator,
    PostConversationOperator,
    PostStreamingConversationOperator,
    BufferedConversationMapperOperator,
)
from dbgpt.model import OpenAIStreamingOperator, MixinLLMOperator

logger = logging.getLogger(__name__)


class ReqContext(BaseModel):
    user_name: Optional[str] = Field(
        None, description="The user name of the model request."
    )

    sys_code: Optional[str] = Field(
        None, description="The system code of the model request."
    )
    conv_uid: Optional[str] = Field(
        None, description="The conversation uid of the model request."
    )


class TriggerReqBody(BaseModel):
    messages: Union[str, List[Dict[str, str]]] = Field(
        ..., description="User input messages"
    )
    model: str = Field(..., description="Model name")
    stream: Optional[bool] = Field(default=False, description="Whether return stream")
    context: Optional[ReqContext] = Field(
        default=None, description="The context of the model request."
    )


class MyLLMOperator(MixinLLMOperator, LLMOperator):
    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(llm_client)
        LLMOperator.__init__(self, llm_client, **kwargs)


class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(llm_client)
        StreamingLLMOperator.__init__(self, llm_client, **kwargs)


with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
    # Receive http request and trigger dag to run.
    trigger = HttpTrigger(
        "/examples/simple_history/multi_round/chat/completions",
        methods="POST",
        request_body=TriggerReqBody,
        streaming_predict_func=lambda req: req.stream,
    )
    # Transform request body to model request.
    request_handle_task = RequestBuildOperator()
    # Pre-process conversation, use InMemoryStorage to store conversation.
    pre_conversation_task = PreConversationOperator(
        storage=InMemoryStorage(), message_storage=InMemoryStorage()
    )
    # Keep last k round conversation.
    history_conversation_task = BufferedConversationMapperOperator(last_k_round=5)

    # Save conversation to storage.
    post_conversation_task = PostConversationOperator()
    # Save streaming conversation to storage.
    post_streaming_conversation_task = PostStreamingConversationOperator()

    # Use LLMOperator to generate response.
    llm_task = MyLLMOperator(task_name="llm_task")
    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
    branch_task = LLMBranchOperator(
        stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
    )
    model_parse_task = MapOperator(lambda out: out.to_dict())
    openai_format_stream_task = OpenAIStreamingOperator()
    result_join_task = JoinOperator(
        combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
    )

    (
        trigger
        >> request_handle_task
        >> pre_conversation_task
        >> history_conversation_task
        >> branch_task
    )

    # The branch of no streaming response.
    (
        branch_task
        >> llm_task
        >> post_conversation_task
        >> model_parse_task
        >> result_join_task
    )
    # The branch of streaming response.
    (
        branch_task
        >> streaming_llm_task
        >> post_streaming_conversation_task
        >> openai_format_stream_task
        >> result_join_task
    )

if __name__ == "__main__":
    if multi_round_dag.leaf_nodes[0].dev_mode:
        # Development mode, you can run the dag locally for debugging.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([multi_round_dag], port=5555)
    else:
        # Production mode, DB-GPT will automatically load and execute the current file after startup.
        pass
@@ -2,24 +2,31 @@

DB-GPT will automatically load and execute the current file after startup.

Example:
Examples:

    Call with non-streaming response.
    .. code-block:: shell

        DBGPT_SERVER="http://127.0.0.1:5000"
        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate \
        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
        -H "Content-Type: application/json" -d '{
            "model": "proxyllm",
            "messages": "hello"
        }'

        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate_stream \
    Call with streaming response.
    .. code-block:: shell

        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
        -H "Content-Type: application/json" -d '{
            "model": "proxyllm",
            "messages": "hello",
            "stream": true
        }'

    Call model and count token.
    .. code-block:: shell

        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \
        -H "Content-Type: application/json" -d '{
            "model": "proxyllm",
@@ -27,20 +34,26 @@
        }'

"""
from typing import Dict, Any, AsyncIterator, Optional, Union, List
from typing import Dict, Any, Optional, Union, List
import logging
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.component import ComponentType
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, TransformStreamAbsOperator
from dbgpt.core import (
    ModelMessage,
    LLMClient,
from dbgpt.core.awel import (
    DAG,
    HttpTrigger,
    MapOperator,
    JoinOperator,
)
from dbgpt.core import LLMClient

from dbgpt.core.operator import (
    LLMBranchOperator,
    LLMOperator,
    StreamingLLMOperator,
    ModelOutput,
    ModelRequest,
    RequestBuildOperator,
)
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model import OpenAIStreamingOperator, MixinLLMOperator

logger = logging.getLogger(__name__)


class TriggerReqBody(BaseModel):
@@ -51,58 +64,24 @@ class TriggerReqBody(BaseModel):
    stream: Optional[bool] = Field(default=False, description="Whether return stream")


class RequestHandleOperator(MapOperator[TriggerReqBody, ModelRequest]):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> ModelRequest:
        messages = [ModelMessage.build_human_message(input_value.messages)]
        await self.current_dag_context.save_to_share_data(
            "request_model_name", input_value.model
        )
        return ModelRequest(
            model=input_value.model,
            messages=messages,
            echo=False,
        )
class MyLLMOperator(MixinLLMOperator, LLMOperator):
    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(llm_client)
        LLMOperator.__init__(self, llm_client, **kwargs)


class LLMMixin:
    @property
    def llm_client(self) -> LLMClient:
        if not self._llm_client:
            worker_manager = self.system_app.get_component(
                ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
            ).create()
            self._llm_client = DefaultLLMClient(worker_manager)
        return self._llm_client
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(llm_client)
        StreamingLLMOperator.__init__(self, llm_client, **kwargs)


class MyLLMOperator(LLMMixin, LLMOperator):
    def __init__(self, llm_client: LLMClient = None, **kwargs):
        super().__init__(llm_client, **kwargs)


class MyStreamingLLMOperator(LLMMixin, StreamingLLMOperator):
    def __init__(self, llm_client: LLMClient = None, **kwargs):
        super().__init__(llm_client, **kwargs)


class MyLLMStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
    async def transform_stream(
        self, input_value: AsyncIterator[ModelOutput]
    ) -> AsyncIterator[str]:
        from dbgpt.model.utils.chatgpt_utils import _to_openai_stream

        model = await self.current_dag_context.get_share_data("request_model_name")
        async for output in _to_openai_stream(model, input_value):
            yield output


class MyModelToolOperator(LLMMixin, MapOperator[TriggerReqBody, Dict[str, Any]]):
    def __init__(self, llm_client: LLMClient = None, **kwargs):
        self._llm_client = llm_client
        MapOperator.__init__(self, **kwargs)
class MyModelToolOperator(
    MixinLLMOperator, MapOperator[TriggerReqBody, Dict[str, Any]]
):
    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(llm_client)
        MapOperator.__init__(self, llm_client, **kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
        prompt_tokens = await self.llm_client.count_token(
@@ -118,25 +97,27 @@ class MyModelToolOperator(LLMMixin, MapOperator[TriggerReqBody, Dict[str, Any]])
with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag:
    # Receive http request and trigger dag to run.
    trigger = HttpTrigger(
        "/examples/simple_client/generate", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    model_task = MyLLMOperator()
    model_parse_task = MapOperator(lambda out: out.to_dict())
    trigger >> request_handle_task >> model_task >> model_parse_task

with DAG("dbgpt_awel_simple_llm_client_generate_stream") as client_generate_stream_dag:
    # Receive http request and trigger dag to run.
    trigger = HttpTrigger(
        "/examples/simple_client/generate_stream",
        "/examples/simple_client/chat/completions",
        methods="POST",
        request_body=TriggerReqBody,
        streaming_response=True,
        streaming_predict_func=lambda req: req.stream,
    )
    request_handle_task = RequestHandleOperator()
    model_task = MyStreamingLLMOperator()
    openai_format_stream_task = MyLLMStreamingOperator()
    trigger >> request_handle_task >> model_task >> openai_format_stream_task
    request_handle_task = RequestBuildOperator()
    llm_task = MyLLMOperator(task_name="llm_task")
    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
    branch_task = LLMBranchOperator(
        stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
    )
    model_parse_task = MapOperator(lambda out: out.to_dict())
    openai_format_stream_task = OpenAIStreamingOperator()
    result_join_task = JoinOperator(
        combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
    )

    trigger >> request_handle_task >> branch_task
    branch_task >> llm_task >> model_parse_task >> result_join_task
    branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task


with DAG("dbgpt_awel_simple_llm_client_count_token") as client_count_token_dag:
    # Receive http request and trigger dag to run.
@@ -147,3 +128,15 @@ with DAG("dbgpt_awel_simple_llm_client_count_token") as client_count_token_dag:
    )
    model_task = MyModelToolOperator()
    trigger >> model_task


if __name__ == "__main__":
    if client_generate_dag.leaf_nodes[0].dev_mode:
        # Development mode, you can run the dag locally for debugging.
        from dbgpt.core.awel import setup_dev_environment

        dags = [client_generate_dag, client_count_token_dag]
        setup_dev_environment(dags, port=5555)
    else:
        # Production mode, DB-GPT will automatically load and execute the current file after startup.
        pass
@@ -2,9 +2,11 @@ import asyncio
from dbgpt.core.awel import DAG
from dbgpt.core import (
    BaseOutputParser,
    RequestBuildOperator,
    PromptTemplate,
)
from dbgpt.core.operator import (
    LLMOperator,
    RequestBuildOperator,
)
from dbgpt.model import OpenAILLMClient

@@ -10,9 +10,11 @@ from dbgpt.core.awel import (
)
from dbgpt.core import (
    SQLOutputParser,
    PromptTemplate,
)
from dbgpt.core.operator import (
    LLMOperator,
    RequestBuildOperator,
    PromptTemplate,
)
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.datasource.operator.datasource_operator import DatasourceOperator