feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
@@ -1,11 +1,9 @@
from dbgpt.core.interface.llm import (
    ModelInferenceMetrics,
    ModelRequest,
    ModelRequestContext,
    ModelOutput,
    LLMClient,
    LLMOperator,
    StreamingLLMOperator,
    RequestBuildOperator,
    ModelMetadata,
)
from dbgpt.core.interface.message import (
@@ -17,7 +15,11 @@ from dbgpt.core.interface.message import (
    ConversationIdentifier,
    MessageIdentifier,
)
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
from dbgpt.core.interface.prompt import (
    PromptTemplate,
    PromptManager,
    StoragePromptTemplate,
)
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.core.interface.cache import (
@@ -38,17 +40,15 @@ from dbgpt.core.interface.storage import (
    StorageError,
)


__ALL__ = [
    "ModelInferenceMetrics",
    "ModelRequest",
    "ModelRequestContext",
    "ModelOutput",
    "Operator",
    "RequestBuildOperator",
    "ModelMetadata",
    "ModelMessage",
    "LLMClient",
    "LLMOperator",
    "StreamingLLMOperator",
    "ModelMessageRoleType",
    "OnceConversation",
    "StorageConversation",
@@ -56,7 +56,8 @@ __ALL__ = [
    "ConversationIdentifier",
    "MessageIdentifier",
    "PromptTemplate",
    "PromptTemplateOperator",
    "PromptManager",
    "StoragePromptTemplate",
    "BaseOutputParser",
    "SQLOutputParser",
    "Serializable",
|
@@ -7,6 +7,7 @@ The stability of this API cannot be guaranteed at present.
|
||||
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from .dag.base import DAGContext, DAG
|
||||
@@ -68,6 +69,7 @@ __all__ = [
|
||||
"UnstreamifyAbsOperator",
|
||||
"TransformStreamAbsOperator",
|
||||
"HttpTrigger",
|
||||
"setup_dev_environment",
|
||||
]
|
||||
|
||||
|
||||
@@ -85,3 +87,29 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str):
|
||||
initialize_runner(DefaultWorkflowRunner())
|
||||
# Load all dags
|
||||
dag_manager.load_dags()
|
||||
|
||||
|
||||
def setup_dev_environment(
|
||||
dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
|
||||
) -> None:
|
||||
"""Setup a development environment for AWEL.
|
||||
|
||||
Only for use in a development environment, not in production.
|
||||
"""
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.component import SystemApp
|
||||
from .trigger.trigger_manager import DefaultTriggerManager
|
||||
from .dag.base import DAGVar
|
||||
|
||||
app = FastAPI()
|
||||
system_app = SystemApp(app)
|
||||
DAGVar.set_current_system_app(system_app)
|
||||
trigger_manager = DefaultTriggerManager()
|
||||
system_app.register_instance(trigger_manager)
|
||||
|
||||
for dag in dags:
|
||||
for trigger in dag.trigger_nodes:
|
||||
trigger_manager.register_trigger(trigger)
|
||||
trigger_manager.after_register()
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
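A minimal sketch of how setup_dev_environment might be used, assuming an HttpTrigger-based DAG; the endpoint path, request model, and echo operator below are hypothetical and only illustrate the flow (build the DAGs, pass them in, serve locally):

.. code-block:: python

    from dbgpt._private.pydantic import BaseModel
    from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment


    class EchoBody(BaseModel):
        name: str = "world"


    class EchoOperator(MapOperator[EchoBody, str]):
        async def map(self, body: EchoBody) -> str:
            # Echo the parsed request body back as plain text.
            return f"Hello, {body.name}!"


    with DAG("dev_echo_dag") as dag:
        # The endpoint path and request_body keyword follow the HttpTrigger
        # signature shown later in this change; the values here are examples.
        trigger = HttpTrigger("/api/echo", methods="POST", request_body=EchoBody)
        trigger >> EchoOperator()

    # Starts a FastAPI app on 0.0.0.0:5555, registers the trigger, and blocks.
    setup_dev_environment([dag], port=5555)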
|
@@ -11,7 +11,7 @@ from concurrent.futures import Executor
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from ..resource.base import ResourceGroup
|
||||
from ..task.base import TaskContext
|
||||
from ..task.base import TaskContext, TaskOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -168,7 +168,19 @@ class DAGVar:
|
||||
cls._executor = executor
|
||||
|
||||
|
||||
class DAGNode(DependencyMixin, ABC):
|
||||
class DAGLifecycle:
|
||||
"""The lifecycle of DAG"""
|
||||
|
||||
async def before_dag_run(self):
|
||||
"""The callback before DAG run"""
|
||||
pass
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
pass
|
||||
|
||||
|
||||
class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
resource_group: Optional[ResourceGroup] = None
|
||||
"""The resource group of current DAGNode"""
|
||||
|
||||
@@ -179,7 +191,7 @@ class DAGNode(DependencyMixin, ABC):
|
||||
node_name: Optional[str] = None,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._upstream: List["DAGNode"] = []
|
||||
@@ -198,10 +210,23 @@ class DAGNode(DependencyMixin, ABC):
|
||||
def node_id(self) -> str:
|
||||
return self._node_id
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dev_mode(self) -> bool:
|
||||
"""Whether current DAGNode is in dev mode"""
|
||||
|
||||
@property
|
||||
def system_app(self) -> SystemApp:
|
||||
return self._system_app
|
||||
|
||||
def set_system_app(self, system_app: SystemApp) -> None:
|
||||
"""Set system app for current DAGNode
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
self._system_app = system_app
|
||||
|
||||
def set_node_id(self, node_id: str) -> None:
|
||||
self._node_id = node_id
|
||||
|
||||
@@ -274,11 +299,41 @@ class DAGNode(DependencyMixin, ABC):
|
||||
node._upstream.append(self)
|
||||
|
||||
|
||||
def _build_task_key(task_name: str, key: str) -> str:
|
||||
return f"{task_name}___$$$$$$___{key}"
|
||||
|
||||
|
||||
class DAGContext:
|
||||
def __init__(self, streaming_call: bool = False) -> None:
|
||||
"""The context of current DAG, created when the DAG is running
|
||||
|
||||
Every time a DAG is triggered, a new DAGContext is created.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
streaming_call: bool = False,
|
||||
node_to_outputs: Dict[str, TaskContext] = None,
|
||||
node_name_to_ids: Dict[str, str] = None,
|
||||
) -> None:
|
||||
if not node_to_outputs:
|
||||
node_to_outputs = {}
|
||||
if not node_name_to_ids:
|
||||
node_name_to_ids = {}
|
||||
self._streaming_call = streaming_call
|
||||
self._curr_task_ctx = None
|
||||
self._share_data: Dict[str, Any] = {}
|
||||
self._node_to_outputs = node_to_outputs
|
||||
self._node_name_to_ids = node_name_to_ids
|
||||
|
||||
@property
|
||||
def _task_outputs(self) -> Dict[str, TaskContext]:
|
||||
"""The task outputs of current DAG
|
||||
|
||||
For internal use only, for now.
|
||||
Returns:
|
||||
Dict[str, TaskContext]: The task outputs of current DAG
|
||||
"""
|
||||
return self._node_to_outputs
|
||||
|
||||
@property
|
||||
def current_task_context(self) -> TaskContext:
|
||||
@@ -292,12 +347,69 @@ class DAGContext:
|
||||
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
|
||||
self._curr_task_ctx = _curr_task_ctx
|
||||
|
||||
async def get_share_data(self, key: str) -> Any:
|
||||
def get_task_output(self, task_name: str) -> TaskOutput:
|
||||
"""Get the task output by task name
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
|
||||
Returns:
|
||||
TaskOutput: The task output
|
||||
"""
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
node_id = self._node_name_to_ids.get(task_name)
|
||||
if not node_id:
raise ValueError(f"Task name {task_name} does not exist in DAG")
|
||||
return self._task_outputs.get(node_id).task_output
|
||||
|
||||
async def get_from_share_data(self, key: str) -> Any:
|
||||
return self._share_data.get(key)
|
||||
|
||||
async def save_to_share_data(self, key: str, data: Any) -> None:
|
||||
async def save_to_share_data(
|
||||
self, key: str, data: Any, overwrite: Optional[str] = None
|
||||
) -> None:
|
||||
if key in self._share_data and not overwrite:
|
||||
raise ValueError(f"Share data key {key} already exists")
|
||||
self._share_data[key] = data
|
||||
|
||||
async def get_task_share_data(self, task_name: str, key: str) -> Any:
|
||||
"""Get share data by task name and key
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
key (str): The share data key
|
||||
|
||||
Returns:
|
||||
Any: The share data
|
||||
"""
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
if key is None:
|
||||
raise ValueError("key can't be None")
|
||||
return await self.get_from_share_data(_build_task_key(task_name, key))
|
||||
|
||||
async def save_task_share_data(
|
||||
self, task_name: str, key: str, data: Any, overwrite: Optional[str] = None
|
||||
) -> None:
|
||||
"""Save share data by task name and key
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
key (str): The share data key
|
||||
data (Any): The share data
|
||||
overwrite (Optional[str], optional): Whether to overwrite the share data if the key already exists.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the share data key already exists and overwrite is not True
|
||||
"""
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
if key is None:
|
||||
raise ValueError("key can't be None")
|
||||
await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)
|
||||
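A hedged sketch of how operators might use the share-data API above (the operator names are illustrative): one operator saves a value into the DAG-wide share data and a downstream operator reads it back. The task-scoped variants (save_task_share_data / get_task_share_data) work the same way but namespace the key with a task name.

.. code-block:: python

    from dbgpt.core.awel import MapOperator


    class ProducerOperator(MapOperator[str, str]):
        async def map(self, value: str) -> str:
            # Raises ValueError if the key already exists and overwrite is falsy.
            await self.current_dag_context.save_to_share_data("user_input", value)
            return value


    class ConsumerOperator(MapOperator[str, str]):
        async def map(self, value: str) -> str:
            original = await self.current_dag_context.get_from_share_data("user_input")
            return f"{value} (original input: {original})"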
|
||||
|
||||
class DAG:
|
||||
def __init__(
|
||||
@@ -305,11 +417,20 @@ class DAG:
|
||||
) -> None:
|
||||
self._dag_id = dag_id
|
||||
self.node_map: Dict[str, DAGNode] = {}
|
||||
self._root_nodes: Set[DAGNode] = None
|
||||
self._leaf_nodes: Set[DAGNode] = None
|
||||
self._trigger_nodes: Set[DAGNode] = None
|
||||
self.node_name_to_node: Dict[str, DAGNode] = {}
|
||||
self._root_nodes: List[DAGNode] = None
|
||||
self._leaf_nodes: List[DAGNode] = None
|
||||
self._trigger_nodes: List[DAGNode] = None
|
||||
|
||||
def _append_node(self, node: DAGNode) -> None:
|
||||
if node.node_id in self.node_map:
|
||||
return
|
||||
if node.node_name:
|
||||
if node.node_name in self.node_name_to_node:
|
||||
raise ValueError(
|
||||
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
|
||||
)
|
||||
self.node_name_to_node[node.node_name] = node
|
||||
self.node_map[node.node_id] = node
|
||||
# clear cached nodes
|
||||
self._root_nodes = None
|
||||
@@ -336,22 +457,44 @@ class DAG:
|
||||
|
||||
@property
|
||||
def root_nodes(self) -> List[DAGNode]:
|
||||
"""The root nodes of current DAG
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The root nodes of current DAG, no repeat
|
||||
"""
|
||||
if not self._root_nodes:
|
||||
self._build()
|
||||
return self._root_nodes
|
||||
|
||||
@property
|
||||
def leaf_nodes(self) -> List[DAGNode]:
|
||||
"""The leaf nodes of current DAG
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The leaf nodes of current DAG, without duplicates
|
||||
"""
|
||||
if not self._leaf_nodes:
|
||||
self._build()
|
||||
return self._leaf_nodes
|
||||
|
||||
@property
|
||||
def trigger_nodes(self):
|
||||
def trigger_nodes(self) -> List[DAGNode]:
|
||||
"""The trigger nodes of current DAG
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The trigger nodes of current DAG, without duplicates
|
||||
"""
|
||||
if not self._trigger_nodes:
|
||||
self._build()
|
||||
return self._trigger_nodes
|
||||
|
||||
async def _after_dag_end(self) -> None:
|
||||
"""The callback after DAG end"""
|
||||
tasks = []
|
||||
for node in self.node_map.values():
|
||||
tasks.append(node.after_dag_end())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def __enter__(self):
|
||||
DAGVar.enter_dag(self)
|
||||
return self
|
||||
|
@@ -146,6 +146,16 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
def current_dag_context(self) -> DAGContext:
|
||||
return self._dag_ctx
|
||||
|
||||
@property
|
||||
def dev_mode(self) -> bool:
|
||||
"""Whether the operator is in dev mode.
|
||||
In production mode, the default runner is not None.
|
||||
|
||||
Returns:
|
||||
bool: Whether the operator is in dev mode. True if the default runner is None.
|
||||
"""
|
||||
return default_runner is None
|
||||
|
||||
async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
if not self.node_id:
|
||||
raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
|
||||
|
@@ -1,4 +1,14 @@
|
||||
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
|
||||
from typing import (
|
||||
Generic,
|
||||
Dict,
|
||||
List,
|
||||
Union,
|
||||
Callable,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Optional,
|
||||
)
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
@@ -162,7 +172,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
|
||||
self,
|
||||
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a BranchOperator with a branching function.
|
||||
@@ -203,7 +215,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
|
||||
branches = self._branches
|
||||
if not branches:
|
||||
branches = await self.branchs()
|
||||
branches = await self.branches()
|
||||
|
||||
branch_func_tasks = []
|
||||
branch_nodes: List[str] = []
|
||||
@@ -229,7 +241,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
|
||||
return parent_output
|
||||
|
||||
async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
|
||||
async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
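With the rename from branchs() to branches(), a subclass can either pass the branch mapping to the constructor or override branches(). A hedged sketch (the task names are placeholders): each predicate maps to the downstream task that should run when it returns True.

.. code-block:: python

    from typing import Dict, Union

    from dbgpt.core.awel import BaseOperator, BranchFunc, BranchOperator


    class EvenOddBranchOperator(BranchOperator[int, int]):
        async def branches(self) -> Dict[BranchFunc[int], Union[BaseOperator, str]]:
            async def is_even(n: int) -> bool:
                return n % 2 == 0

            async def is_odd(n: int) -> bool:
                return n % 2 != 0

            # Each predicate maps to the downstream task name that should run
            # when it returns True; the other branches are skipped.
            return {is_even: "even_task", is_odd: "odd_task"}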
|
||||
|
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from typing import List, Set, Optional, Dict
|
||||
import uuid
|
||||
import logging
|
||||
from ..dag.base import DAG
|
||||
from ..dag.base import DAG, DAGLifecycle
|
||||
|
||||
from ..operator.base import BaseOperator, CALL_DATA
|
||||
|
||||
@@ -18,18 +19,20 @@ class DAGInstance:
|
||||
self._dag = dag
|
||||
|
||||
|
||||
class JobManager:
|
||||
class JobManager(DAGLifecycle):
|
||||
def __init__(
|
||||
self,
|
||||
root_nodes: List[BaseOperator],
|
||||
all_nodes: List[BaseOperator],
|
||||
end_node: BaseOperator,
|
||||
id2call_data: Dict[str, Dict],
|
||||
node_name_to_ids: Dict[str, str],
|
||||
) -> None:
|
||||
self._root_nodes = root_nodes
|
||||
self._all_nodes = all_nodes
|
||||
self._end_node = end_node
|
||||
self._id2node_data = id2call_data
|
||||
self._node_name_to_ids = node_name_to_ids
|
||||
|
||||
@staticmethod
|
||||
def build_from_end_node(
|
||||
@@ -38,11 +41,31 @@ class JobManager:
|
||||
nodes = _build_from_end_node(end_node)
|
||||
root_nodes = _get_root_nodes(nodes)
|
||||
id2call_data = _save_call_data(root_nodes, call_data)
|
||||
return JobManager(root_nodes, nodes, end_node, id2call_data)
|
||||
|
||||
node_name_to_ids = {}
|
||||
for node in nodes:
|
||||
if node.node_name is not None:
|
||||
node_name_to_ids[node.node_name] = node.node_id
|
||||
|
||||
return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)
|
||||
|
||||
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
|
||||
return self._id2node_data.get(node_id)
|
||||
|
||||
async def before_dag_run(self):
|
||||
"""The callback before DAG run"""
|
||||
tasks = []
|
||||
for node in self._all_nodes:
|
||||
tasks.append(node.before_dag_run())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
tasks = []
|
||||
for node in self._all_nodes:
|
||||
tasks.append(node.after_dag_end())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
def _save_call_data(
|
||||
root_nodes: List[BaseOperator], call_data: CALL_DATA
|
||||
@@ -66,6 +89,7 @@ def _save_call_data(
|
||||
|
||||
|
||||
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
|
||||
"""Build all nodes from the end node."""
|
||||
nodes = []
|
||||
if isinstance(end_node, BaseOperator):
|
||||
task_id = end_node.node_id
|
||||
|
@@ -1,7 +1,8 @@
|
||||
from typing import Dict, Optional, Set, List
|
||||
import logging
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from dbgpt.component import SystemApp
|
||||
from ..dag.base import DAGContext, DAGVar
|
||||
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
|
||||
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
|
||||
from ..task.base import TaskContext, TaskState
|
||||
@@ -18,19 +19,29 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
call_data: Optional[CALL_DATA] = None,
|
||||
streaming_call: bool = False,
|
||||
) -> DAGContext:
|
||||
# Create DAG context
|
||||
dag_ctx = DAGContext(streaming_call=streaming_call)
|
||||
# Save node output
|
||||
# dag = node.dag
|
||||
node_outputs: Dict[str, TaskContext] = {}
|
||||
job_manager = JobManager.build_from_end_node(node, call_data)
|
||||
# Create DAG context
|
||||
dag_ctx = DAGContext(
|
||||
streaming_call=streaming_call,
|
||||
node_to_outputs=node_outputs,
|
||||
node_name_to_ids=job_manager._node_name_to_ids,
|
||||
)
|
||||
logger.info(
|
||||
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
|
||||
)
|
||||
dag = node.dag
|
||||
# Save node output
|
||||
node_outputs: Dict[str, TaskContext] = {}
|
||||
skip_node_ids = set()
|
||||
system_app: SystemApp = DAGVar.get_current_system_app()
|
||||
|
||||
await job_manager.before_dag_run()
|
||||
await self._execute_node(
|
||||
job_manager, node, dag_ctx, node_outputs, skip_node_ids
|
||||
job_manager, node, dag_ctx, node_outputs, skip_node_ids, system_app
|
||||
)
|
||||
if not streaming_call and node.dag:
|
||||
# streaming call not work for dag end
|
||||
await node.dag._after_dag_end()
|
||||
|
||||
return dag_ctx
|
||||
|
||||
@@ -41,6 +52,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
dag_ctx: DAGContext,
|
||||
node_outputs: Dict[str, TaskContext],
|
||||
skip_node_ids: Set[str],
|
||||
system_app: SystemApp,
|
||||
):
|
||||
# Skip run node
|
||||
if node.node_id in node_outputs:
|
||||
@@ -50,7 +62,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
for upstream_node in node.upstream:
|
||||
if isinstance(upstream_node, BaseOperator):
|
||||
await self._execute_node(
|
||||
job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
|
||||
job_manager,
|
||||
upstream_node,
|
||||
dag_ctx,
|
||||
node_outputs,
|
||||
skip_node_ids,
|
||||
system_app,
|
||||
)
|
||||
|
||||
inputs = [
|
||||
@@ -73,6 +90,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
logger.debug(
|
||||
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
|
||||
)
|
||||
if system_app is not None and node.system_app is None:
|
||||
node.set_system_app(system_app)
|
||||
|
||||
await node._run(dag_ctx)
|
||||
node_outputs[node.node_id] = dag_ctx.current_task_context
|
||||
task_ctx.set_current_state(TaskState.SUCCESS)
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
|
||||
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
@@ -13,7 +13,8 @@ from ..operator.base import BaseOperator
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
RequestBody = Union[Request, Type[BaseModel], str]
|
||||
RequestBody = Union[Type[Request], Type[BaseModel], str]
|
||||
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +26,7 @@ class HttpTrigger(Trigger):
|
||||
methods: Optional[Union[str, List[str]]] = "GET",
|
||||
request_body: Optional[RequestBody] = None,
|
||||
streaming_response: Optional[bool] = False,
|
||||
streaming_predict_func: Optional[StreamingPredictFunc] = None,
|
||||
response_model: Optional[Type] = None,
|
||||
response_headers: Optional[Dict[str, str]] = None,
|
||||
response_media_type: Optional[str] = None,
|
||||
@@ -39,6 +41,7 @@ class HttpTrigger(Trigger):
|
||||
self._methods = methods
|
||||
self._req_body = request_body
|
||||
self._streaming_response = streaming_response
|
||||
self._streaming_predict_func = streaming_predict_func
|
||||
self._response_model = response_model
|
||||
self._status_code = status_code
|
||||
self._router_tags = router_tags
|
||||
@@ -59,10 +62,13 @@ class HttpTrigger(Trigger):
|
||||
return await _parse_request_body(request, self._req_body)
|
||||
|
||||
async def route_function(body=Depends(_request_body_dependency)):
|
||||
streaming_response = self._streaming_response
|
||||
if self._streaming_predict_func:
|
||||
streaming_response = self._streaming_predict_func(body)
|
||||
return await _trigger_dag(
|
||||
body,
|
||||
self.dag,
|
||||
self._streaming_response,
|
||||
streaming_response,
|
||||
self._response_headers,
|
||||
self._response_media_type,
|
||||
)
|
||||
@@ -112,6 +118,7 @@ async def _trigger_dag(
|
||||
response_headers: Optional[Dict[str, str]] = None,
|
||||
response_media_type: Optional[str] = None,
|
||||
) -> Any:
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
end_node = dag.leaf_nodes
|
||||
@@ -131,8 +138,11 @@ async def _trigger_dag(
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
generator = await end_node.call_stream(call_data={"data": body})
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(end_node.dag._after_dag_end)
|
||||
return StreamingResponse(
|
||||
generator,
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
background=background_tasks,
|
||||
)
|
||||
|
@@ -1,8 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from typing import Any, TypeVar, Generic, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from dbgpt.core.interface.serialization import Serializable
|
||||
|
||||
|
@@ -1,14 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, List, Any, Union, AsyncIterator
|
||||
import time
|
||||
from dataclasses import dataclass, asdict, field
|
||||
import copy
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.util import BaseParameters
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
from dbgpt.util.model_utils import GPUInfo
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.core.awel import MapOperator, StreamifyAbsOperator
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -97,6 +96,28 @@ class ModelInferenceMetrics:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="beta")
|
||||
class ModelRequestContext:
|
||||
stream: Optional[bool] = False
|
||||
"""Whether to return a stream of responses."""
|
||||
|
||||
user_name: Optional[str] = None
|
||||
"""The user name of the model request."""
|
||||
|
||||
sys_code: Optional[str] = None
|
||||
"""The system code of the model request."""
|
||||
|
||||
conv_uid: Optional[str] = None
|
||||
"""The conversation id of the model inference."""
|
||||
|
||||
span_id: Optional[str] = None
|
||||
"""The span id of the model inference."""
|
||||
|
||||
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
|
||||
"""The extra information of the model inference."""
|
||||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="beta")
|
||||
class ModelOutput:
|
||||
@@ -145,6 +166,27 @@ class ModelRequest:
|
||||
span_id: Optional[str] = None
|
||||
"""The span id of the model inference."""
|
||||
|
||||
context: Optional[ModelRequestContext] = field(
|
||||
default_factory=lambda: ModelRequestContext()
|
||||
)
|
||||
"""The context of the model inference."""
|
||||
|
||||
@property
|
||||
def stream(self) -> bool:
|
||||
"""Whether to return a stream of responses."""
|
||||
return self.context and self.context.stream
|
||||
|
||||
def copy(self):
|
||||
new_request = copy.deepcopy(self)
|
||||
# Transform messages to List[ModelMessage]
|
||||
new_request.messages = list(
|
||||
map(
|
||||
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
|
||||
new_request.messages,
|
||||
)
|
||||
)
|
||||
return new_request
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
new_reqeust = copy.deepcopy(self)
|
||||
new_reqeust.messages = list(
|
||||
@@ -161,6 +203,17 @@ class ModelRequest:
|
||||
)
|
||||
)
|
||||
|
||||
def get_single_user_message(self) -> Optional[ModelMessage]:
|
||||
"""Get the single user message.
|
||||
|
||||
Returns:
|
||||
Optional[ModelMessage]: The single user message.
|
||||
"""
|
||||
messages = self._get_messages()
|
||||
if len(messages) != 1 or messages[0].role != ModelMessageRoleType.HUMAN:
raise ValueError("The messages are not a single user message")
|
||||
return messages[0]
|
||||
|
||||
@staticmethod
|
||||
def _build(model: str, prompt: str, **kwargs):
|
||||
return ModelRequest(
|
||||
@@ -178,11 +231,22 @@ class ModelRequest:
|
||||
List[Dict[str, Any]]: The messages in the format of OpenAI API.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.core.interface.message import (
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
)
|
||||
|
||||
messages = [
|
||||
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
|
||||
ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot.")
|
||||
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are your"),
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.AI, content="Hi, I'm a robot."
|
||||
),
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.HUMAN, content="Who are your"
|
||||
),
|
||||
]
|
||||
openai_messages = ModelRequest.to_openai_messages(messages)
|
||||
assert openai_messages == [
|
||||
@@ -272,63 +336,3 @@ class LLMClient(ABC):
|
||||
Returns:
|
||||
int: The number of tokens.
|
||||
"""
|
||||
|
||||
|
||||
class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
|
||||
def __init__(self, model: str, **kwargs):
|
||||
self._model = model
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: str) -> ModelRequest:
|
||||
return ModelRequest._build(self._model, input_value)
|
||||
|
||||
|
||||
class BaseLLM:
|
||||
"""The abstract operator for a LLM."""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
self._llm_client = llm_client
|
||||
|
||||
@property
|
||||
def llm_client(self) -> LLMClient:
|
||||
"""Return the LLM client."""
|
||||
if not self._llm_client:
|
||||
raise ValueError("llm_client is not set")
|
||||
return self._llm_client
|
||||
|
||||
|
||||
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
||||
"""The operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate a no streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, request: ModelRequest) -> ModelOutput:
|
||||
return await self.llm_client.generate(request)
|
||||
|
||||
|
||||
class StreamingLLMOperator(
|
||||
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
|
||||
):
|
||||
"""The streaming operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
StreamifyAbsOperator.__init__(self, **kwargs)
|
||||
|
||||
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
|
||||
async for output in self.llm_client.generate_stream(request):
|
||||
yield output
|
||||
|
@@ -1,16 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.interface.storage import (
|
||||
ResourceIdentifier,
|
||||
StorageItem,
|
||||
StorageInterface,
|
||||
InMemoryStorage,
|
||||
ResourceIdentifier,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,6 +112,7 @@ class ModelMessage(BaseModel):
|
||||
"""Similar to openai's message format"""
|
||||
role: str
|
||||
content: str
|
||||
round_index: Optional[int] = 0
|
||||
|
||||
@staticmethod
|
||||
def from_openai_messages(
|
||||
@@ -443,6 +444,7 @@ class OnceConversation:
|
||||
self.tokens = conversation.tokens
|
||||
self.user_name = conversation.user_name
|
||||
self.sys_code = conversation.sys_code
|
||||
self._message_index = conversation._message_index
|
||||
|
||||
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
|
||||
"""Get the messages by round index
|
||||
@@ -470,6 +472,7 @@ class OnceConversation:
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
conversation = OnceConversation()
|
||||
conversation.start_new_round()
|
||||
conversation.add_user_message("hello, this is the first round")
|
||||
@@ -485,11 +488,17 @@ class OnceConversation:
|
||||
conversation.end_current_round()
|
||||
|
||||
assert len(conversation.get_messages_with_round(1)) == 2
|
||||
assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round"
|
||||
assert (
|
||||
conversation.get_messages_with_round(1)[0].content
|
||||
== "hello, this is the third round"
|
||||
)
|
||||
assert conversation.get_messages_with_round(1)[1].content == "hi"
|
||||
|
||||
assert len(conversation.get_messages_with_round(2)) == 4
|
||||
assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round"
|
||||
assert (
|
||||
conversation.get_messages_with_round(2)[0].content
|
||||
== "hello, this is the second round"
|
||||
)
|
||||
assert conversation.get_messages_with_round(2)[1].content == "hi"
|
||||
|
||||
Args:
|
||||
@@ -517,6 +526,7 @@ class OnceConversation:
|
||||
Examples:
|
||||
If you do not need the history messages, you can override this method like this:
|
||||
.. code-block:: python
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
messages = []
|
||||
for message in self.get_latest_round():
|
||||
@@ -528,6 +538,7 @@ class OnceConversation:
|
||||
|
||||
If you want to add the one round history messages, you can override this method like this:
|
||||
.. code-block:: python
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
messages = []
|
||||
latest_round_index = self.chat_order
|
||||
@@ -537,7 +548,9 @@ class OnceConversation:
|
||||
for message in self.get_messages_by_round(round_index):
|
||||
if message.pass_to_model:
|
||||
messages.append(
|
||||
ModelMessage(role=message.type, content=message.content)
|
||||
ModelMessage(
|
||||
role=message.type, content=message.content
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
@@ -548,7 +561,11 @@ class OnceConversation:
|
||||
for message in self.messages:
|
||||
if message.pass_to_model:
|
||||
messages.append(
|
||||
ModelMessage(role=message.type, content=message.content)
|
||||
ModelMessage(
|
||||
role=message.type,
|
||||
content=message.content,
|
||||
round_index=message.round_index,
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
@@ -780,6 +797,9 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
)
|
||||
messages = [message.to_message() for message in message_list]
|
||||
conversation.messages = messages
|
||||
# This index is used to save the message to the storage(Has not been saved)
|
||||
# The new message append to the messages, so the index is len(messages)
|
||||
conversation._message_index = len(messages)
|
||||
self._message_ids = message_ids
|
||||
self._has_stored_message_index = len(messages) - 1
|
||||
self.from_conversation(conversation)
|
||||
|
dbgpt/core/interface/operator/__init__.py (new file, 0 lines)
dbgpt/core/interface/operator/llm_operator.py (new file, 166 lines)
@@ -0,0 +1,166 @@
|
||||
import dataclasses
|
||||
from abc import ABC
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.core.awel import (
|
||||
BranchFunc,
|
||||
BranchOperator,
|
||||
MapOperator,
|
||||
StreamifyAbsOperator,
|
||||
)
|
||||
from dbgpt.core.interface.llm import (
|
||||
LLMClient,
|
||||
ModelOutput,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
)
|
||||
from dbgpt.core.interface.message import ModelMessage
|
||||
|
||||
RequestInput = Union[
|
||||
ModelRequest,
|
||||
str,
|
||||
Dict[str, Any],
|
||||
BaseModel,
|
||||
]
|
||||
|
||||
|
||||
class RequestBuildOperator(MapOperator[RequestInput, ModelRequest], ABC):
|
||||
def __init__(self, model: Optional[str] = None, **kwargs):
|
||||
self._model = model
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: RequestInput) -> ModelRequest:
|
||||
req_dict = {}
|
||||
if isinstance(input_value, str):
|
||||
req_dict = {"messages": [ModelMessage.build_human_message(input_value)]}
|
||||
elif isinstance(input_value, dict):
|
||||
req_dict = input_value
|
||||
elif dataclasses.is_dataclass(input_value):
|
||||
req_dict = dataclasses.asdict(input_value)
|
||||
elif isinstance(input_value, BaseModel):
|
||||
req_dict = input_value.dict()
|
||||
elif isinstance(input_value, ModelRequest):
|
||||
if not input_value.model:
|
||||
input_value.model = self._model
|
||||
return input_value
|
||||
if "messages" not in req_dict:
|
||||
raise ValueError("messages is not set")
|
||||
messages = req_dict["messages"]
|
||||
if isinstance(messages, str):
|
||||
# Single message, transform to a list including one human message
|
||||
req_dict["messages"] = [ModelMessage.build_human_message(messages)]
|
||||
if "model" not in req_dict:
|
||||
req_dict["model"] = self._model
|
||||
if not req_dict["model"]:
|
||||
raise ValueError("model is not set")
|
||||
stream = False
|
||||
has_stream = False
|
||||
if "stream" in req_dict:
|
||||
has_stream = True
|
||||
stream = req_dict["stream"]
|
||||
del req_dict["stream"]
|
||||
if "context" not in req_dict:
|
||||
req_dict["context"] = ModelRequestContext(stream=stream)
|
||||
else:
|
||||
context_dict = req_dict["context"]
|
||||
if not isinstance(context_dict, dict):
|
||||
raise ValueError("context is not a dict")
|
||||
if has_stream:
|
||||
context_dict["stream"] = stream
|
||||
req_dict["context"] = ModelRequestContext(**context_dict)
|
||||
return ModelRequest(**req_dict)
|
||||
|
||||
|
||||
class BaseLLM:
|
||||
"""The abstract operator for a LLM."""
|
||||
|
||||
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
self._llm_client = llm_client
|
||||
|
||||
@property
|
||||
def llm_client(self) -> LLMClient:
|
||||
"""Return the LLM client."""
|
||||
if not self._llm_client:
|
||||
raise ValueError("llm_client is not set")
|
||||
return self._llm_client
|
||||
|
||||
|
||||
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
||||
"""The operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate a no streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, request: ModelRequest) -> ModelOutput:
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
||||
)
|
||||
return await self.llm_client.generate(request)
|
||||
|
||||
|
||||
class StreamingLLMOperator(
|
||||
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
|
||||
):
|
||||
"""The streaming operator for a LLM.
|
||||
|
||||
Args:
|
||||
llm_client (LLMClient, optional): The LLM client. Defaults to None.
|
||||
|
||||
This operator will generate streaming response.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||
super().__init__(llm_client=llm_client)
|
||||
StreamifyAbsOperator.__init__(self, **kwargs)
|
||||
|
||||
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
||||
)
|
||||
async for output in self.llm_client.generate_stream(request):
|
||||
yield output
|
||||
|
||||
|
||||
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
"""Branch operator for LLM.
|
||||
|
||||
This operator will branch the workflow based on the stream flag of the request.
|
||||
"""
|
||||
|
||||
def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not stream_task_name:
|
||||
raise ValueError("stream_task_name is not set")
|
||||
if not no_stream_task_name:
|
||||
raise ValueError("no_stream_task_name is not set")
|
||||
self._stream_task_name = stream_task_name
|
||||
self._no_stream_task_name = no_stream_task_name
|
||||
|
||||
async def branches(self) -> Dict[BranchFunc[ModelRequest], str]:
|
||||
"""
|
||||
Return a dict of branch function and task name.
|
||||
|
||||
Returns:
|
||||
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task name.
|
||||
the key is a predicate function, the value is the task name. If the predicate function returns True,
|
||||
we will run the corresponding task.
|
||||
"""
|
||||
|
||||
async def check_stream_true(r: ModelRequest) -> bool:
|
||||
# If stream is True, we will run the streaming task; otherwise, we will run the non-streaming task.
|
||||
return r.stream
|
||||
|
||||
return {
|
||||
check_stream_true: self._stream_task_name,
|
||||
lambda x: not x.stream: self._no_stream_task_name,
|
||||
}
|
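A hedged wiring sketch for LLMBranchOperator (the task names, model name, and the task_name keyword are assumptions, not part of this change): the branch routes a ModelRequest to a streaming or non-streaming LLM operator based on request.stream.

.. code-block:: python

    from dbgpt.core.awel import DAG
    from dbgpt.core.interface.operator.llm_operator import (
        LLMBranchOperator,
        LLMOperator,
        RequestBuildOperator,
        StreamingLLMOperator,
    )

    llm_client = None  # replace with any LLMClient implementation

    with DAG("llm_branch_dag"):
        build_request = RequestBuildOperator(model="vicuna-13b-v1.5")
        branch = LLMBranchOperator(
            stream_task_name="stream_llm_task", no_stream_task_name="llm_task"
        )
        # The task names must match the names the branch routes to.
        stream_llm = StreamingLLMOperator(llm_client, task_name="stream_llm_task")
        llm = LLMOperator(llm_client, task_name="llm_task")
        build_request >> branch
        branch >> stream_llm
        branch >> llm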
dbgpt/core/interface/operator/message_operator.py (new file, 321 lines)
@@ -0,0 +1,321 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncIterator, List, Optional
|
||||
|
||||
from dbgpt.core import (
|
||||
MessageStorageItem,
|
||||
ModelMessage,
|
||||
ModelOutput,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
StorageConversation,
|
||||
StorageInterface,
|
||||
)
|
||||
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
|
||||
|
||||
|
||||
class BaseConversationOperator(BaseOperator, ABC):
|
||||
"""Base class for conversation operators."""
|
||||
|
||||
SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation"
|
||||
SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._storage = storage
|
||||
self._message_storage = message_storage
|
||||
|
||||
@property
|
||||
def storage(self) -> StorageInterface[StorageConversation, Any]:
|
||||
"""Return the LLM client."""
|
||||
if not self._storage:
|
||||
raise ValueError("Storage is not set")
|
||||
return self._storage
|
||||
|
||||
@property
|
||||
def message_storage(self) -> StorageInterface[MessageStorageItem, Any]:
|
||||
"""Return the LLM client."""
|
||||
if not self._message_storage:
|
||||
raise ValueError("Message storage is not set")
|
||||
return self._message_storage
|
||||
|
||||
async def get_storage_conversation(self) -> StorageConversation:
|
||||
"""Get the storage conversation from share data.
|
||||
|
||||
Returns:
|
||||
StorageConversation: The storage conversation.
|
||||
"""
|
||||
storage_conv: StorageConversation = (
|
||||
await self.current_dag_context.get_from_share_data(
|
||||
self.SHARE_DATA_KEY_STORAGE_CONVERSATION
|
||||
)
|
||||
)
|
||||
if not storage_conv:
|
||||
raise ValueError("Storage conversation is not set")
|
||||
return storage_conv
|
||||
|
||||
async def get_model_request(self) -> ModelRequest:
|
||||
"""Get the model request from share data.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The model request.
|
||||
"""
|
||||
model_request: ModelRequest = (
|
||||
await self.current_dag_context.get_from_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_REQUEST
|
||||
)
|
||||
)
|
||||
if not model_request:
|
||||
raise ValueError("Model request is not set")
|
||||
return model_request
|
||||
|
||||
|
||||
class PreConversationOperator(
|
||||
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
|
||||
):
|
||||
"""The operator to prepare the storage conversation.
|
||||
|
||||
In DB-GPT, the conversation record and the messages in the conversation are persisted to storage,
and they can be kept in different storages (for higher performance).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(storage=storage, message_storage=message_storage)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
||||
"""Map the input value to a ModelRequest.
|
||||
|
||||
Args:
|
||||
input_value (ModelRequest): The input value.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The mapped ModelRequest.
|
||||
"""
|
||||
if input_value.context is None:
|
||||
input_value.context = ModelRequestContext()
|
||||
if not input_value.context.conv_uid:
|
||||
input_value.context.conv_uid = str(uuid.uuid4())
|
||||
if not input_value.context.extra:
|
||||
input_value.context.extra = {}
|
||||
|
||||
chat_mode = input_value.context.extra.get("chat_mode")
|
||||
|
||||
# Create a new storage conversation; this loads the conversation from storage, so run it asynchronously
|
||||
storage_conv: StorageConversation = await self.blocking_func_to_async(
|
||||
StorageConversation,
|
||||
conv_uid=input_value.context.conv_uid,
|
||||
chat_mode=chat_mode,
|
||||
user_name=input_value.context.user_name,
|
||||
sys_code=input_value.context.sys_code,
|
||||
conv_storage=self.storage,
|
||||
message_storage=self.message_storage,
|
||||
)
|
||||
# The input message must be a single user message
|
||||
single_human_message: ModelMessage = input_value.get_single_user_message()
|
||||
storage_conv.start_new_round()
|
||||
storage_conv.add_user_message(single_human_message.content)
|
||||
|
||||
# Get all messages from current storage conversation, and overwrite the input value
|
||||
messages: List[ModelMessage] = storage_conv.get_model_messages()
|
||||
input_value.messages = messages
|
||||
|
||||
# Save the storage conversation to share data, for the child operators
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
|
||||
)
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_REQUEST, input_value
|
||||
)
|
||||
return input_value
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
# Save the storage conversation to storage after the whole DAG finished
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
# TODO: don't save if the conversation has an internal error
|
||||
storage_conv.end_current_round()
|
||||
|
||||
|
||||
class PostConversationOperator(
|
||||
BaseConversationOperator, MapOperator[ModelOutput, ModelOutput]
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, input_value: ModelOutput) -> ModelOutput:
|
||||
"""Map the input value to a ModelOutput.
|
||||
|
||||
Args:
|
||||
input_value (ModelOutput): The input value.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The mapped ModelOutput.
|
||||
"""
|
||||
# Get the storage conversation from share data
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
storage_conv.add_ai_message(input_value.text)
|
||||
return input_value
|
||||
|
||||
|
||||
class PostStreamingConversationOperator(
|
||||
BaseConversationOperator, TransformStreamAbsOperator[ModelOutput, ModelOutput]
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
TransformStreamAbsOperator.__init__(self, **kwargs)
|
||||
|
||||
async def transform_stream(
|
||||
self, input_value: AsyncIterator[ModelOutput]
|
||||
) -> ModelOutput:
|
||||
"""Transform the input value to a ModelOutput.
|
||||
|
||||
Args:
|
||||
input_value (ModelOutput): The input value.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The transformed ModelOutput.
|
||||
"""
|
||||
full_text = ""
|
||||
async for model_output in input_value:
|
||||
# Currently model_output.text is the full text; if it were a delta, we would need to merge all deltas into the full text
|
||||
full_text = model_output.text
|
||||
yield model_output
|
||||
# Get the storage conversation from share data
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
storage_conv.add_ai_message(full_text)
|
||||
|
||||
|
||||
class ConversationMapperOperator(
|
||||
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
||||
"""Map the input value to a ModelRequest.
|
||||
|
||||
Args:
|
||||
input_value (ModelRequest): The input value.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The mapped ModelRequest.
|
||||
"""
|
||||
input_value = input_value.copy()
|
||||
messages: List[ModelMessage] = await self.map_messages(input_value.messages)
|
||||
# Overwrite the input value
|
||||
input_value.messages = messages
|
||||
return input_value
|
||||
|
||||
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
|
||||
"""Map the input messages to a list of ModelMessage.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The mapped ModelMessage.
|
||||
"""
|
||||
return messages
|
||||
|
||||
def _split_messages_by_round(
|
||||
self, messages: List[ModelMessage]
|
||||
) -> List[List[ModelMessage]]:
|
||||
"""Split the messages by round index.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
|
||||
Returns:
|
||||
List[List[ModelMessage]]: The messages split by round.
|
||||
"""
|
||||
messages_by_round: List[List[ModelMessage]] = []
|
||||
last_round_index = 0
|
||||
for message in messages:
|
||||
if not message.round_index:
|
||||
# Round index must be bigger than 0
|
||||
raise ValueError("Message round_index is not set")
|
||||
if message.round_index > last_round_index:
|
||||
last_round_index = message.round_index
|
||||
messages_by_round.append([])
|
||||
messages_by_round[-1].append(message)
|
||||
return messages_by_round
|
||||
|
||||
|
||||
class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
"""The buffered conversation mapper operator.
|
||||
|
||||
This Operator must be used after the PreConversationOperator,
|
||||
and it will map the messages in the storage conversation.
|
||||
|
||||
Examples:
|
||||
|
||||
Transform no history messages
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from dbgpt.core import ModelMessage
|
||||
from dbgpt.core.operator import BufferedConversationMapperOperator
|
||||
|
||||
# No history
|
||||
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
operator = BufferedConversationMapperOperator(last_k_round=1)
|
||||
messages = asyncio.run(operator.map_messages(messages))
|
||||
assert messages == [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
|
||||
Transform with history messages
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# With history
|
||||
messages = [
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
operator = BufferedConversationMapperOperator(last_k_round=1)
|
||||
messages = asyncio.run(operator.map_messages(messages))
|
||||
# Just keep the last one round, so the first round messages will be removed
|
||||
# Note: The round index 3 is not a complete round
|
||||
assert messages == [
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
"""
|
||||
|
||||
def __init__(self, last_k_round: Optional[int] = 2, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._last_k_round = last_k_round
|
||||
|
||||
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
|
||||
"""Map the input messages to a list of ModelMessage.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The mapped ModelMessage.
|
||||
"""
|
||||
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
|
||||
messages
|
||||
)
|
||||
# Get the last k round messages
|
||||
index = self._last_k_round + 1
|
||||
messages_by_round = messages_by_round[-index:]
|
||||
messages: List[ModelMessage] = sum(messages_by_round, [])
|
||||
return messages
|
@@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
import logging
|
||||
from abc import ABC
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, TypeVar, Union
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.core.awel import MapOperator
|
||||
|
||||
T = TypeVar("T")
|
||||
ResponseTye = Union[str, bytes, ModelOutput]
|
||||
|
@@ -1,12 +1,20 @@
|
||||
import dataclasses
|
||||
import json
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
from dbgpt.util.formatting import formatter, no_strict_formatter
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.core._private.example_base import ExampleSelector
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser
|
||||
from dbgpt.core._private.example_base import ExampleSelector
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
QuerySpec,
|
||||
ResourceIdentifier,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
)
|
||||
from dbgpt.util.formatting import formatter, no_strict_formatter
|
||||
|
||||
|
||||
def _jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
@@ -86,6 +94,434 @@ class PromptTemplate(BaseModel, ABC):
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PromptTemplateIdentifier(ResourceIdentifier):
|
||||
identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
|
||||
prompt_name: str
|
||||
prompt_language: Optional[str] = None
|
||||
sys_code: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.prompt_name is None:
|
||||
raise ValueError("prompt_name cannot be None")
|
||||
|
||||
if any(
|
||||
self.identifier_split in key
|
||||
for key in [
|
||||
self.prompt_name,
|
||||
self.prompt_language,
|
||||
self.sys_code,
|
||||
self.model,
|
||||
]
|
||||
if key is not None
|
||||
):
|
||||
raise ValueError(
|
||||
f"identifier_split {self.identifier_split} is not allowed in prompt_name, prompt_language, sys_code, model"
|
||||
)
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
return self.identifier_split.join(
|
||||
key
|
||||
for key in [
|
||||
self.prompt_name,
|
||||
self.prompt_language,
|
||||
self.sys_code,
|
||||
self.model,
|
||||
]
|
||||
if key is not None
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"prompt_name": self.prompt_name,
|
||||
"prompt_language": self.prompt_language,
|
||||
"sys_code": self.sys_code,
|
||||
"model": self.model,
|
||||
}
|
||||
|
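A small worked example of the identifier above (the values are placeholders): str_identifier joins the non-None parts with identifier_split, so omitted fields simply drop out.

.. code-block:: python

    from dbgpt.core.interface.prompt import PromptTemplateIdentifier

    ident = PromptTemplateIdentifier(prompt_name="hello", prompt_language="en")
    # sys_code and model are None, so only the first two parts are joined.
    assert ident.str_identifier == "hello___$$$$___en"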
||||
|
||||
@dataclasses.dataclass
|
||||
class StoragePromptTemplate(StorageItem):
|
||||
prompt_name: str
|
||||
content: Optional[str] = None
|
||||
prompt_language: Optional[str] = None
|
||||
prompt_format: Optional[str] = None
|
||||
input_variables: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
chat_scene: Optional[str] = None
|
||||
sub_chat_scene: Optional[str] = None
|
||||
prompt_type: Optional[str] = None
|
||||
user_name: Optional[str] = None
|
||||
sys_code: Optional[str] = None
|
||||
_identifier: PromptTemplateIdentifier = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self._identifier = PromptTemplateIdentifier(
|
||||
prompt_name=self.prompt_name,
|
||||
prompt_language=self.prompt_language,
|
||||
sys_code=self.sys_code,
|
||||
model=self.model,
|
||||
)
|
||||
self._check()  # Validate required fields after initialization
|
||||
|
||||
    def to_prompt_template(self) -> PromptTemplate:
        """Convert the storage prompt template to a prompt template."""
        input_variables = (
            None
            if not self.input_variables
            else self.input_variables.strip().split(",")
        )
        return PromptTemplate(
            input_variables=input_variables,
            template=self.content,
            template_scene=self.chat_scene,
            prompt_name=self.prompt_name,
            template_format=self.prompt_format,
        )

    @staticmethod
    def from_prompt_template(
        prompt_template: PromptTemplate,
        prompt_name: str,
        prompt_language: Optional[str] = None,
        prompt_type: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
        sub_chat_scene: Optional[str] = None,
        model: Optional[str] = None,
        **kwargs,
    ) -> "StoragePromptTemplate":
        """Convert a prompt template to a storage prompt template.

        Args:
            prompt_template (PromptTemplate): The prompt template to convert from.
            prompt_name (str): The name of the prompt.
            prompt_language (Optional[str], optional): The language of the prompt. Defaults to None. e.g. zh-cn, en.
            prompt_type (Optional[str], optional): The type of the prompt. Defaults to None. e.g. common, private.
            sys_code (Optional[str], optional): The system code of the prompt. Defaults to None.
            user_name (Optional[str], optional): The username of the prompt. Defaults to None.
            sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt. Defaults to None.
            model (Optional[str], optional): The model name of the prompt. Defaults to None.
            kwargs (Dict): Other params to build the storage prompt template.
        """
        input_variables = prompt_template.input_variables or kwargs.get(
            "input_variables"
        )
        if input_variables and isinstance(input_variables, list):
            input_variables = ",".join(input_variables)
        return StoragePromptTemplate(
            prompt_name=prompt_name,
            sys_code=sys_code,
            user_name=user_name,
            input_variables=input_variables,
            model=model,
            content=prompt_template.template or kwargs.get("content"),
            prompt_language=prompt_language,
            prompt_format=prompt_template.template_format
            or kwargs.get("prompt_format"),
            chat_scene=prompt_template.template_scene or kwargs.get("chat_scene"),
            sub_chat_scene=sub_chat_scene,
            prompt_type=prompt_type,
        )

    @property
    def identifier(self) -> PromptTemplateIdentifier:
        return self._identifier

    def merge(self, other: "StorageItem") -> None:
        """Merge the other item into the current item.

        Args:
            other (StorageItem): The other item to merge
        """
        if not isinstance(other, StoragePromptTemplate):
            raise ValueError(
                f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
            )
        self.from_object(other)

    def to_dict(self) -> Dict:
        return {
            "prompt_name": self.prompt_name,
            "content": self.content,
            "prompt_language": self.prompt_language,
            "prompt_format": self.prompt_format,
            "input_variables": self.input_variables,
            "model": self.model,
            "chat_scene": self.chat_scene,
            "sub_chat_scene": self.sub_chat_scene,
            "prompt_type": self.prompt_type,
            "user_name": self.user_name,
            "sys_code": self.sys_code,
        }

    def _check(self):
        if self.prompt_name is None:
            raise ValueError("prompt_name cannot be None")
        if self.content is None:
            raise ValueError("content cannot be None")

    def from_object(self, template: "StoragePromptTemplate") -> None:
        """Load the prompt template from an existing prompt template object.

        Args:
            template (StoragePromptTemplate): The prompt template to load from.
        """
        self.content = template.content
        self.prompt_format = template.prompt_format
        self.input_variables = template.input_variables
        self.model = template.model
        self.chat_scene = template.chat_scene
        self.sub_chat_scene = template.sub_chat_scene
        self.prompt_type = template.prompt_type
        self.user_name = template.user_name

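# A round-trip sketch for the two template types above (illustrative only; it
# assumes an f-string PromptTemplate as used in the examples elsewhere in this
# module):
#
#     template = PromptTemplate(
#         template="hello {input}", input_variables=["input"], template_format="f-string"
#     )
#     stored = StoragePromptTemplate.from_prompt_template(template, prompt_name="hello")
#     stored.input_variables  # "input" (the list is joined into a comma-separated string)
#     restored = stored.to_prompt_template()
#     restored.input_variables  # ["input"] (split back into a list)
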
class PromptManager:
    """The manager class for prompt templates.

    Simple wrapper for the storage interface.

    Examples:

        .. code-block:: python

            # Uses InMemoryStorage by default
            prompt_manager = PromptManager()
            prompt_template = PromptTemplate(
                template="hello {input}",
                input_variables=["input"],
                template_scene="chat_normal",
            )
            prompt_manager.save(prompt_template, prompt_name="hello")
            prompt_template_list = prompt_manager.list()
            prompt_template_list = prompt_manager.prefer_query("hello")

        With a custom storage interface.

        .. code-block:: python

            from dbgpt.core.interface.storage import InMemoryStorage

            prompt_manager = PromptManager(InMemoryStorage())
            prompt_template = PromptTemplate(
                template="hello {input}",
                input_variables=["input"],
                template_scene="chat_normal",
            )
            prompt_manager.save(prompt_template, prompt_name="hello")
            prompt_template_list = prompt_manager.list()
            prompt_template_list = prompt_manager.prefer_query("hello")
    """

    def __init__(
        self, storage: Optional[StorageInterface[StoragePromptTemplate, Any]] = None
    ):
        if storage is None:
            storage = InMemoryStorage()
        self._storage = storage

    @property
    def storage(self) -> StorageInterface[StoragePromptTemplate, Any]:
        """The storage interface for prompt templates."""
        return self._storage

    def prefer_query(
        self,
        prompt_name: str,
        sys_code: Optional[str] = None,
        prefer_prompt_language: Optional[str] = None,
        prefer_model: Optional[str] = None,
        **kwargs,
    ) -> List[StoragePromptTemplate]:
        """Query prompt templates from storage with preferred params.

        Sometimes we want to query prompt templates with preferred params (e.g. a specific language or model).
        This method first filters the results by the preferred params; if nothing matches,
        it falls back to all prompt templates for the given name.

        Examples:

            Query a prompt template.

            .. code-block:: python

                prompt_template_list = prompt_manager.prefer_query("hello")

            Query with sys_code and username.

            .. code-block:: python

                prompt_template_list = prompt_manager.prefer_query(
                    "hello", sys_code="sys_code", user_name="user_name"
                )

            Query with preferred prompt language.

            .. code-block:: python

                # First query by prompt name "hello" exactly.
                # Then filter by prompt language "zh-cn"; if none match, all templates for the name are returned.
                prompt_template_list = prompt_manager.prefer_query(
                    "hello", prefer_prompt_language="zh-cn"
                )

            Query with preferred model.

            .. code-block:: python

                # First query by prompt name "hello" exactly.
                # Then filter by model "vicuna-13b-v1.5"; if none match, all templates for the name are returned.
                prompt_template_list = prompt_manager.prefer_query(
                    "hello", prefer_model="vicuna-13b-v1.5"
                )

        Args:
            prompt_name (str): The name of the prompt template.
            sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
            prefer_prompt_language (Optional[str], optional): The preferred language of the prompt template. Defaults to None.
            prefer_model (Optional[str], optional): The preferred model of the prompt template. Defaults to None.
            kwargs (Dict): Other query params (non-None values are matched exactly).
        """
        query_spec = QuerySpec(
            conditions={
                "prompt_name": prompt_name,
                "sys_code": sys_code,
                **kwargs,
            }
        )
        queries: List[StoragePromptTemplate] = self.storage.query(
            query_spec, StoragePromptTemplate
        )
        if not queries:
            return []
        if prefer_prompt_language:
            prefer_prompt_language = prefer_prompt_language.lower()
            temp_queries = [
                query
                for query in queries
                if query.prompt_language
                and query.prompt_language.lower() == prefer_prompt_language
            ]
            if temp_queries:
                queries = temp_queries
        if prefer_model:
            prefer_model = prefer_model.lower()
            temp_queries = [
                query
                for query in queries
                if query.model and query.model.lower() == prefer_model
            ]
            if temp_queries:
                queries = temp_queries
        return queries

    def save(
        self, prompt_template: PromptTemplate, prompt_name: str, **kwargs
    ) -> None:
        """Save a prompt template to storage.

        Examples:

            .. code-block:: python

                prompt_template = PromptTemplate(
                    template="hello {input}",
                    input_variables=["input"],
                    template_scene="chat_normal",
                )
                prompt_manager.save(prompt_template, prompt_name="hello")

            Save with sys_code and username.

            .. code-block:: python

                prompt_template = PromptTemplate(
                    template="hello {input}",
                    input_variables=["input"],
                    template_scene="chat_normal",
                )
                prompt_manager.save(
                    prompt_template,
                    prompt_name="hello",
                    sys_code="sys_code",
                    user_name="user_name",
                )

        Args:
            prompt_template (PromptTemplate): The prompt template to save.
            prompt_name (str): The name of the prompt template.
            kwargs (Dict): Other params to build the storage prompt template.
                More details in :meth:`~StoragePromptTemplate.from_prompt_template`.
        """
        storage_prompt_template = StoragePromptTemplate.from_prompt_template(
            prompt_template, prompt_name, **kwargs
        )
        self.storage.save(storage_prompt_template)

    def list(self, **kwargs) -> List[StoragePromptTemplate]:
        """List prompt templates from storage.

        Examples:

            List all prompt templates.

            .. code-block:: python

                all_prompt_templates = prompt_manager.list()

            List with sys_code and username.

            .. code-block:: python

                templates = prompt_manager.list(
                    sys_code="sys_code", user_name="user_name"
                )

        Args:
            kwargs (Dict): Other query params.
        """
        query_spec = QuerySpec(conditions=kwargs)
        return self.storage.query(query_spec, StoragePromptTemplate)

    def delete(
        self,
        prompt_name: str,
        prompt_language: Optional[str] = None,
        sys_code: Optional[str] = None,
        model: Optional[str] = None,
    ) -> None:
        """Delete a prompt template from storage.

        Examples:

            Delete a prompt template.

            .. code-block:: python

                prompt_manager.delete("hello")

            Delete with sys_code.

            .. code-block:: python

                prompt_manager.delete("hello", sys_code="sys_code")

        Args:
            prompt_name (str): The name of the prompt template.
            prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
            sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
            model (Optional[str], optional): The model of the prompt template. Defaults to None.
        """
        identifier = PromptTemplateIdentifier(
            prompt_name=prompt_name,
            prompt_language=prompt_language,
            sys_code=sys_code,
            model=model,
        )
        self.storage.delete(identifier)

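# A minimal end-to-end sketch of PromptManager (illustrative; it uses only APIs
# defined in this module, with InMemoryStorage as the assumed backend):
#
#     manager = PromptManager(InMemoryStorage())
#     for lang in ("en", "zh-cn"):
#         manager.save(
#             PromptTemplate(template="hello {input}", input_variables=["input"]),
#             prompt_name="hello",
#             prompt_language=lang,
#         )
#     manager.prefer_query("hello", prefer_prompt_language="zh-cn")  # only the zh-cn template
#     manager.prefer_query("hello", prefer_prompt_language="fr")  # falls back to both templates
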
class PromptTemplateOperator(MapOperator[Dict, str]):
    def __init__(self, prompt_template: PromptTemplate, **kwargs: Any):
        super().__init__(**kwargs)

@@ -1,4 +1,5 @@
from abc import abstractmethod

from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT

@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Type, Dict
from typing import Dict, Type


class Serializable(ABC):

@@ -1,10 +1,10 @@
from typing import Generic, TypeVar, Type, Optional, Dict, Any, List
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar

from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.serialization.json_serialization import JsonSerializer
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.serialization.json_serialization import JsonSerializer


@PublicAPI(stability="beta")

@@ -1,7 +1,7 @@
import pytest

from dbgpt.core.interface.tests.conftest import in_memory_storage
from dbgpt.core.interface.message import *
from dbgpt.core.interface.tests.conftest import in_memory_storage


@pytest.fixture

dbgpt/core/interface/tests/test_prompt.py (new file, 320 lines)
@@ -0,0 +1,320 @@
import json

import pytest

from dbgpt.core.interface.prompt import (
    PromptManager,
    PromptTemplate,
    StoragePromptTemplate,
)
from dbgpt.core.interface.storage import QuerySpec
from dbgpt.core.interface.tests.conftest import in_memory_storage


@pytest.fixture
def sample_storage_prompt_template():
    return StoragePromptTemplate(
        prompt_name="test_prompt",
        content="Sample content, {var1}, {var2}",
        prompt_language="en",
        prompt_format="f-string",
        input_variables="var1,var2",
        model="model1",
        chat_scene="scene1",
        sub_chat_scene="subscene1",
        prompt_type="type1",
        user_name="user1",
        sys_code="code1",
    )


@pytest.fixture
def complex_storage_prompt_template():
    content = """Database name: {db_name} Table structure definition: {table_info} User Question:{user_input}"""
    return StoragePromptTemplate(
        prompt_name="chat_data_auto_execute_prompt",
        content=content,
        prompt_language="en",
        prompt_format="f-string",
        input_variables="db_name,table_info,user_input",
        model="vicuna-13b-v1.5",
        chat_scene="chat_data",
        sub_chat_scene="subscene1",
        prompt_type="common",
        user_name="zhangsan",
        sys_code="dbgpt",
    )


@pytest.fixture
def prompt_manager(in_memory_storage):
    return PromptManager(storage=in_memory_storage)


class TestPromptTemplate:
    @pytest.mark.parametrize(
        "template_str, input_vars, expected_output",
        [
            ("Hello {name}", {"name": "World"}, "Hello World"),
            ("{greeting}, {name}", {"greeting": "Hi", "name": "Alice"}, "Hi, Alice"),
        ],
    )
    def test_format_f_string(self, template_str, input_vars, expected_output):
        prompt = PromptTemplate(
            template=template_str,
            input_variables=list(input_vars.keys()),
            template_format="f-string",
        )
        formatted_output = prompt.format(**input_vars)
        assert formatted_output == expected_output

    @pytest.mark.parametrize(
        "template_str, input_vars, expected_output",
        [
            ("Hello {{ name }}", {"name": "World"}, "Hello World"),
            (
                "{{ greeting }}, {{ name }}",
                {"greeting": "Hi", "name": "Alice"},
                "Hi, Alice",
            ),
        ],
    )
    def test_format_jinja2(self, template_str, input_vars, expected_output):
        prompt = PromptTemplate(
            template=template_str,
            input_variables=list(input_vars.keys()),
            template_format="jinja2",
        )
        formatted_output = prompt.format(**input_vars)
        assert formatted_output == expected_output

    def test_format_with_response_format(self):
        template_str = "Response: {response}"
        prompt = PromptTemplate(
            template=template_str,
            input_variables=["response"],
            template_format="f-string",
            response_format=json.dumps({"message": "hello"}),
        )
        formatted_output = prompt.format(response="hello")
        assert "Response: " in formatted_output

    def test_from_template(self):
        template_str = "Hello {name}"
        prompt = PromptTemplate.from_template(template_str)
        assert prompt._prompt_template.template == template_str
        assert prompt._prompt_template.input_variables == []

    def test_format_missing_variable(self):
        template_str = "Hello {name}"
        prompt = PromptTemplate(
            template=template_str, input_variables=["name"], template_format="f-string"
        )
        with pytest.raises(KeyError):
            prompt.format()

    def test_format_extra_variable(self):
        template_str = "Hello {name}"
        prompt = PromptTemplate(
            template=template_str,
            input_variables=["name"],
            template_format="f-string",
            template_is_strict=False,
        )
        formatted_output = prompt.format(name="World", extra="unused")
        assert formatted_output == "Hello World"

    def test_format_complex(self, complex_storage_prompt_template):
        prompt = complex_storage_prompt_template.to_prompt_template()
        formatted_output = prompt.format(
            db_name="db1",
            table_info="create table users(id int, name varchar(20))",
            user_input="find all users whose name is 'Alice'",
        )
        assert (
            formatted_output
            == "Database name: db1 Table structure definition: create table users(id int, name varchar(20)) "
            "User Question:find all users whose name is 'Alice'"
        )


class TestStoragePromptTemplate:
    def test_constructor_and_properties(self):
        storage_item = StoragePromptTemplate(
            prompt_name="test",
            content="Hello {name}",
            prompt_language="en",
            prompt_format="f-string",
            input_variables="name",
            model="model1",
            chat_scene="chat",
            sub_chat_scene="sub_chat",
            prompt_type="type",
            user_name="user",
            sys_code="sys",
        )
        assert storage_item.prompt_name == "test"
        assert storage_item.content == "Hello {name}"
        assert storage_item.prompt_language == "en"
        assert storage_item.prompt_format == "f-string"
        assert storage_item.input_variables == "name"
        assert storage_item.model == "model1"

    def test_constructor_exceptions(self):
        with pytest.raises(ValueError):
            StoragePromptTemplate(prompt_name=None, content="Hello")

    def test_to_prompt_template(self, sample_storage_prompt_template):
        prompt_template = sample_storage_prompt_template.to_prompt_template()
        assert isinstance(prompt_template, PromptTemplate)
        assert prompt_template.template == "Sample content, {var1}, {var2}"
        assert prompt_template.input_variables == ["var1", "var2"]

    def test_from_prompt_template(self):
        prompt_template = PromptTemplate(
            template="Sample content, {var1}, {var2}",
            input_variables=["var1", "var2"],
            template_format="f-string",
        )
        storage_prompt_template = StoragePromptTemplate.from_prompt_template(
            prompt_template=prompt_template, prompt_name="test_prompt"
        )
        assert storage_prompt_template.prompt_name == "test_prompt"
        assert storage_prompt_template.content == "Sample content, {var1}, {var2}"
        assert storage_prompt_template.input_variables == "var1,var2"

    def test_merge(self, sample_storage_prompt_template):
        other = StoragePromptTemplate(
            prompt_name="other_prompt",
            content="Other content",
        )
        sample_storage_prompt_template.merge(other)
        assert sample_storage_prompt_template.content == "Other content"

    def test_to_dict(self, sample_storage_prompt_template):
        result = sample_storage_prompt_template.to_dict()
        assert result == {
            "prompt_name": "test_prompt",
            "content": "Sample content, {var1}, {var2}",
            "prompt_language": "en",
            "prompt_format": "f-string",
            "input_variables": "var1,var2",
            "model": "model1",
            "chat_scene": "scene1",
            "sub_chat_scene": "subscene1",
            "prompt_type": "type1",
            "user_name": "user1",
            "sys_code": "code1",
        }

    def test_save_and_load_storage(
        self, sample_storage_prompt_template, in_memory_storage
    ):
        in_memory_storage.save(sample_storage_prompt_template)
        loaded_item = in_memory_storage.load(
            sample_storage_prompt_template.identifier, StoragePromptTemplate
        )
        assert loaded_item.content == "Sample content, {var1}, {var2}"

    def test_check_exceptions(self):
        with pytest.raises(ValueError):
            StoragePromptTemplate(prompt_name=None, content="Hello")

    def test_from_object(self, sample_storage_prompt_template):
        other = StoragePromptTemplate(prompt_name="other", content="Other content")
        sample_storage_prompt_template.from_object(other)
        assert sample_storage_prompt_template.content == "Other content"
        assert sample_storage_prompt_template.input_variables != "var1,var2"
        # Prompt name should not be changed
        assert sample_storage_prompt_template.prompt_name == "test_prompt"
        assert sample_storage_prompt_template.sys_code == "code1"


class TestPromptManager:
    def test_save(self, prompt_manager, in_memory_storage):
        prompt_template = PromptTemplate(
            template="hello {input}",
            input_variables=["input"],
            template_scene="chat_normal",
        )
        prompt_manager.save(
            prompt_template,
            prompt_name="hello",
        )
        result = in_memory_storage.query(
            QuerySpec(conditions={"prompt_name": "hello"}), StoragePromptTemplate
        )
        assert len(result) == 1
        assert result[0].content == "hello {input}"

    def test_prefer_query_simple(self, prompt_manager, in_memory_storage):
        in_memory_storage.save(
            StoragePromptTemplate(prompt_name="test_prompt", content="test")
        )
        result = prompt_manager.prefer_query("test_prompt")
        assert len(result) == 1
        assert result[0].content == "test"

    def test_prefer_query_language(self, prompt_manager, in_memory_storage):
        for language in ["en", "zh"]:
            in_memory_storage.save(
                StoragePromptTemplate(
                    prompt_name="test_prompt",
                    content="test",
                    prompt_language=language,
                )
            )
        # Prefer zh, and zh exists, will return zh prompt template
        result = prompt_manager.prefer_query(
            "test_prompt", prefer_prompt_language="zh"
        )
        assert len(result) == 1
        assert result[0].content == "test"
        assert result[0].prompt_language == "zh"
        # Prefer language not exists, will return all prompt templates of this name
        result = prompt_manager.prefer_query(
            "test_prompt", prefer_prompt_language="not_exist"
        )
        assert len(result) == 2

    def test_prefer_query_model(self, prompt_manager, in_memory_storage):
        for model in ["model1", "model2"]:
            in_memory_storage.save(
                StoragePromptTemplate(
                    prompt_name="test_prompt", content="test", model=model
                )
            )
        # Prefer model1, and model1 exists, will return model1 prompt template
        result = prompt_manager.prefer_query("test_prompt", prefer_model="model1")
        assert len(result) == 1
        assert result[0].content == "test"
        assert result[0].model == "model1"
        # Prefer model not exists, will return all prompt templates of this name
        result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist")
        assert len(result) == 2

    def test_list(self, prompt_manager, in_memory_storage):
        prompt_manager.save(
            PromptTemplate(template="Hello {name}", input_variables=["name"]),
            prompt_name="name1",
        )
        prompt_manager.save(
            PromptTemplate(
                template="Write a SQL of {dialect} to query all data of {table_name}.",
                input_variables=["dialect", "table_name"],
            ),
            prompt_name="sql_template",
        )
        all_templates = prompt_manager.list()
        assert len(all_templates) == 2
        assert len(prompt_manager.list(prompt_name="name1")) == 1
        assert len(prompt_manager.list(prompt_name="not exist")) == 0

    def test_delete(self, prompt_manager, in_memory_storage):
        prompt_manager.save(
            PromptTemplate(template="Hello {name}", input_variables=["name"]),
            prompt_name="to_delete",
        )
        prompt_manager.delete("to_delete")
        result = in_memory_storage.query(
            QuerySpec(conditions={"prompt_name": "to_delete"}), StoragePromptTemplate
        )
        assert len(result) == 0

@@ -1,10 +1,12 @@
import pytest
from typing import Dict, Type, Union

import pytest

from dbgpt.core.interface.storage import (
    InMemoryStorage,
    QuerySpec,
    ResourceIdentifier,
    StorageError,
    QuerySpec,
    InMemoryStorage,
    StorageItem,
)
from dbgpt.util.serialization.json_serialization import JsonSerializer

dbgpt/core/operator/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from dbgpt.core.interface.operator.llm_operator import (
    BaseLLM,
    LLMBranchOperator,
    LLMOperator,
    RequestBuildOperator,
    StreamingLLMOperator,
)
from dbgpt.core.interface.operator.message_operator import (
    BaseConversationOperator,
    BufferedConversationMapperOperator,
    ConversationMapperOperator,
    PostConversationOperator,
    PostStreamingConversationOperator,
    PreConversationOperator,
)
from dbgpt.core.interface.prompt import PromptTemplateOperator

__ALL__ = [
    "BaseLLM",
    "LLMBranchOperator",
    "LLMOperator",
    "RequestBuildOperator",
    "StreamingLLMOperator",
    "BaseConversationOperator",
    "BufferedConversationMapperOperator",
    "ConversationMapperOperator",
    "PostConversationOperator",
    "PostStreamingConversationOperator",
    "PreConversationOperator",
    "PromptTemplateOperator",
]

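# A wiring sketch for the operators exported above (assumed AWEL usage, not part
# of this file; the DAG context manager and ">>" composition come from
# dbgpt.core.awel, and the operator arguments below are placeholders):
#
#     from dbgpt.core import PromptTemplate
#     from dbgpt.core.awel import DAG
#     from dbgpt.core.operator import LLMOperator, PromptTemplateOperator, RequestBuildOperator
#
#     with DAG("simple_llm_dag"):
#         prompt_task = PromptTemplateOperator(PromptTemplate.from_template("hello {input}"))
#         request_task = RequestBuildOperator(model="vicuna-13b-v1.5")  # model name is illustrative
#         llm_task = LLMOperator(...)  # needs an LLMClient; omitted here
#         prompt_task >> request_task >> llm_task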