feat(core): More AWEL operators and new prompt manager API (#972)

Co-authored-by: csunny <cfqsunny@163.com>
Author: Fangyin Cheng
Date: 2023-12-25 20:03:22 +08:00
Committed by: GitHub
Parent: 048fb6c402
Commit: 69fb97e508
46 changed files with 2556 additions and 294 deletions

View File

@@ -169,14 +169,18 @@ CREATE TABLE IF NOT EXISTS `prompt_manage`
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene',
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene',
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private',
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
`prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
`input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))',
`model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)',
`prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)',
`prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`),
KEY `gmt_created_idx` (`gmt_created`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';

View File

@@ -1,11 +1,9 @@
from dbgpt.core.interface.llm import (
ModelInferenceMetrics,
ModelRequest,
ModelRequestContext,
ModelOutput,
LLMClient,
LLMOperator,
StreamingLLMOperator,
RequestBuildOperator,
ModelMetadata,
)
from dbgpt.core.interface.message import (
@@ -17,7 +15,11 @@ from dbgpt.core.interface.message import (
ConversationIdentifier,
MessageIdentifier,
)
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
from dbgpt.core.interface.prompt import (
PromptTemplate,
PromptManager,
StoragePromptTemplate,
)
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.core.interface.cache import (
@@ -38,17 +40,15 @@ from dbgpt.core.interface.storage import (
StorageError,
)
__ALL__ = [
"ModelInferenceMetrics",
"ModelRequest",
"ModelRequestContext",
"ModelOutput",
"Operator",
"RequestBuildOperator",
"ModelMetadata",
"ModelMessage",
"LLMClient",
"LLMOperator",
"StreamingLLMOperator",
"ModelMessageRoleType",
"OnceConversation",
"StorageConversation",
@@ -56,7 +56,8 @@ __ALL__ = [
"ConversationIdentifier",
"MessageIdentifier",
"PromptTemplate",
"PromptTemplateOperator",
"PromptManager",
"StoragePromptTemplate",
"BaseOutputParser",
"SQLOutputParser",
"Serializable",

View File

@@ -7,6 +7,7 @@ The stability of this API cannot be guaranteed at present.
"""
from typing import List, Optional
from dbgpt.component import SystemApp
from .dag.base import DAGContext, DAG
@@ -68,6 +69,7 @@ __all__ = [
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"HttpTrigger",
"setup_dev_environment",
]
@@ -85,3 +87,29 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str):
initialize_runner(DefaultWorkflowRunner())
# Load all dags
dag_manager.load_dags()
def setup_dev_environment(
dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
) -> None:
"""Setup a development environment for AWEL.
Just using in development environment, not production environment.
"""
import uvicorn
from fastapi import FastAPI
from dbgpt.component import SystemApp
from .trigger.trigger_manager import DefaultTriggerManager
from .dag.base import DAGVar
app = FastAPI()
system_app = SystemApp(app)
DAGVar.set_current_system_app(system_app)
trigger_manager = DefaultTriggerManager()
system_app.register_instance(trigger_manager)
for dag in dags:
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger)
trigger_manager.after_register()
uvicorn.run(app, host=host, port=port)
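A minimal usage sketch of the new setup_dev_environment helper (hedged: the endpoint path, the TriggerReqBody request model, and RequestHandleOperator are illustrative, following the AWEL trigger patterns in this commit):

```python
from dbgpt._private.pydantic import BaseModel

from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment


class TriggerReqBody(BaseModel):
    # Illustrative request body; any pydantic model accepted by HttpTrigger works.
    name: str = "world"


class RequestHandleOperator(MapOperator[TriggerReqBody, str]):
    async def map(self, input_value: TriggerReqBody) -> str:
        return f"Hello, {input_value.name}!"


with DAG("dev_hello_dag") as dag:
    # The trigger parses the HTTP request body into TriggerReqBody and feeds it downstream.
    trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
    task = RequestHandleOperator()
    trigger >> task

# Starts a local FastAPI app serving the registered triggers; development use only.
setup_dev_environment([dag], host="127.0.0.1", port=5555)
```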

View File

@@ -11,7 +11,7 @@ from concurrent.futures import Executor
from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
from ..task.base import TaskContext, TaskOutput
logger = logging.getLogger(__name__)
@@ -168,7 +168,19 @@ class DAGVar:
cls._executor = executor
class DAGNode(DependencyMixin, ABC):
class DAGLifecycle:
"""The lifecycle of DAG"""
async def before_dag_run(self):
"""The callback before DAG run"""
pass
async def after_dag_end(self):
"""The callback after DAG end"""
pass
class DAGNode(DAGLifecycle, DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
@@ -179,7 +191,7 @@ class DAGNode(DependencyMixin, ABC):
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
**kwargs
**kwargs,
) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
@@ -198,10 +210,23 @@ class DAGNode(DependencyMixin, ABC):
def node_id(self) -> str:
return self._node_id
@property
@abstractmethod
def dev_mode(self) -> bool:
"""Whether current DAGNode is in dev mode"""
@property
def system_app(self) -> SystemApp:
return self._system_app
def set_system_app(self, system_app: SystemApp) -> None:
"""Set system app for current DAGNode
Args:
system_app (SystemApp): The system app
"""
self._system_app = system_app
def set_node_id(self, node_id: str) -> None:
self._node_id = node_id
@@ -274,11 +299,41 @@ class DAGNode(DependencyMixin, ABC):
node._upstream.append(self)
def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"
class DAGContext:
def __init__(self, streaming_call: bool = False) -> None:
"""The context of current DAG, created when the DAG is running
Every DAG has been triggered will create a new DAGContext.
"""
def __init__(
self,
streaming_call: bool = False,
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
node_name_to_ids: Optional[Dict[str, str]] = None,
) -> None:
if not node_to_outputs:
node_to_outputs = {}
if not node_name_to_ids:
node_name_to_ids = {}
self._streaming_call = streaming_call
self._curr_task_ctx = None
self._share_data: Dict[str, Any] = {}
self._node_to_outputs = node_to_outputs
self._node_name_to_ids = node_name_to_ids
@property
def _task_outputs(self) -> Dict[str, TaskContext]:
"""The task outputs of current DAG
Just use for internal for now.
Returns:
Dict[str, TaskContext]: The task outputs of current DAG
"""
return self._node_to_outputs
@property
def current_task_context(self) -> TaskContext:
@@ -292,12 +347,69 @@ class DAGContext:
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
self._curr_task_ctx = _curr_task_ctx
async def get_share_data(self, key: str) -> Any:
def get_task_output(self, task_name: str) -> TaskOutput:
"""Get the task output by task name
Args:
task_name (str): The task name
Returns:
TaskOutput: The task output
"""
if task_name is None:
raise ValueError("task_name can't be None")
node_id = self._node_name_to_ids.get(task_name)
if not node_id:
raise ValueError(f"Task name {task_name} does not exist in DAG")
return self._task_outputs.get(node_id).task_output
async def get_from_share_data(self, key: str) -> Any:
return self._share_data.get(key)
async def save_to_share_data(self, key: str, data: Any) -> None:
async def save_to_share_data(
self, key: str, data: Any, overwrite: Optional[bool] = False
) -> None:
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
self._share_data[key] = data
async def get_task_share_data(self, task_name: str, key: str) -> Any:
"""Get share data by task name and key
Args:
task_name (str): The task name
key (str): The share data key
Returns:
Any: The share data
"""
if task_name is None:
raise ValueError("task_name can't be None")
if key is None:
raise ValueError("key can't be None")
return await self.get_from_share_data(_build_task_key(task_name, key))
async def save_task_share_data(
self, task_name: str, key: str, data: Any, overwrite: Optional[bool] = False
) -> None:
"""Save share data by task name and key
Args:
task_name (str): The task name
key (str): The share data key
data (Any): The share data
overwrite (Optional[bool], optional): Whether to overwrite the share data if the key already exists.
Defaults to False.
Raises:
ValueError: If the share data key already exists and overwrite is not True
"""
if task_name is None:
raise ValueError("task_name can't be None")
if key is None:
raise ValueError("key can't be None")
await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)
class DAG:
def __init__(
@@ -305,11 +417,20 @@ class DAG:
) -> None:
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self._root_nodes: Set[DAGNode] = None
self._leaf_nodes: Set[DAGNode] = None
self._trigger_nodes: Set[DAGNode] = None
self.node_name_to_node: Dict[str, DAGNode] = {}
self._root_nodes: List[DAGNode] = None
self._leaf_nodes: List[DAGNode] = None
self._trigger_nodes: List[DAGNode] = None
def _append_node(self, node: DAGNode) -> None:
if node.node_id in self.node_map:
return
if node.node_name:
if node.node_name in self.node_name_to_node:
raise ValueError(
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
)
self.node_name_to_node[node.node_name] = node
self.node_map[node.node_id] = node
# clear cached nodes
self._root_nodes = None
@@ -336,22 +457,44 @@ class DAG:
@property
def root_nodes(self) -> List[DAGNode]:
"""The root nodes of current DAG
Returns:
List[DAGNode]: The root nodes of current DAG, no repeat
"""
if not self._root_nodes:
self._build()
return self._root_nodes
@property
def leaf_nodes(self) -> List[DAGNode]:
"""The leaf nodes of current DAG
Returns:
List[DAGNode]: The leaf nodes of current DAG, no repeat
"""
if not self._leaf_nodes:
self._build()
return self._leaf_nodes
@property
def trigger_nodes(self):
def trigger_nodes(self) -> List[DAGNode]:
"""The trigger nodes of current DAG
Returns:
List[DAGNode]: The trigger nodes of current DAG, no repeat
"""
if not self._trigger_nodes:
self._build()
return self._trigger_nodes
async def _after_dag_end(self) -> None:
"""The callback after DAG end"""
tasks = []
for node in self.node_map.values():
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def __enter__(self):
DAGVar.enter_dag(self)
return self
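A hedged sketch of the new DAGLifecycle hooks together with the task-scoped share data on DAGContext; AuditedOperator, the task name, and the key below are illustrative, not part of this commit:

```python
from dbgpt.core.awel import MapOperator


class AuditedOperator(MapOperator[str, str]):
    """Illustrative operator using the new lifecycle hooks and share data."""

    async def before_dag_run(self):
        # Called for every node before the DAG starts running.
        self._calls = 0

    async def map(self, input_value: str) -> str:
        self._calls += 1
        # Share data can be scoped by task name so other operators can read it back
        # via get_task_share_data(task_name, key).
        await self.current_dag_context.save_task_share_data(
            task_name="audited_task", key="last_input", data=input_value
        )
        return input_value.upper()

    async def after_dag_end(self):
        # Called for every node after the DAG finishes (for streaming HTTP responses
        # the trigger schedules this as a background task).
        self._calls = 0
```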

View File

@@ -146,6 +146,16 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
def current_dag_context(self) -> DAGContext:
return self._dag_ctx
@property
def dev_mode(self) -> bool:
"""Whether the operator is in dev mode.
In production mode, the default runner is not None.
Returns:
bool: Whether the operator is in dev mode. True if the default runner is None.
"""
return default_runner is None
async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
if not self.node_id:
raise ValueError(f"The DAG Node ID can't be empty, current node {self}")

View File

@@ -1,4 +1,14 @@
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
from typing import (
Generic,
Dict,
List,
Union,
Callable,
Any,
AsyncIterator,
Awaitable,
Optional,
)
import asyncio
import logging
@@ -162,7 +172,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
"""
def __init__(
self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
self,
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
**kwargs,
):
"""
Initializes a BranchDAGNode with a branching function.
@@ -203,7 +215,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
branches = self._branches
if not branches:
branches = await self.branchs()
branches = await self.branches()
branch_func_tasks = []
branch_nodes: List[str] = []
@@ -229,7 +241,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
return parent_output
async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
raise NotImplementedError

View File

@@ -1,7 +1,8 @@
import asyncio
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG
from ..dag.base import DAG, DAGLifecycle
from ..operator.base import BaseOperator, CALL_DATA
@@ -18,18 +19,20 @@ class DAGInstance:
self._dag = dag
class JobManager:
class JobManager(DAGLifecycle):
def __init__(
self,
root_nodes: List[BaseOperator],
all_nodes: List[BaseOperator],
end_node: BaseOperator,
id2call_data: Dict[str, Dict],
node_name_to_ids: Dict[str, str],
) -> None:
self._root_nodes = root_nodes
self._all_nodes = all_nodes
self._end_node = end_node
self._id2node_data = id2call_data
self._node_name_to_ids = node_name_to_ids
@staticmethod
def build_from_end_node(
@@ -38,11 +41,31 @@ class JobManager:
nodes = _build_from_end_node(end_node)
root_nodes = _get_root_nodes(nodes)
id2call_data = _save_call_data(root_nodes, call_data)
return JobManager(root_nodes, nodes, end_node, id2call_data)
node_name_to_ids = {}
for node in nodes:
if node.node_name is not None:
node_name_to_ids[node.node_name] = node.node_id
return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
return self._id2node_data.get(node_id)
async def before_dag_run(self):
"""The callback before DAG run"""
tasks = []
for node in self._all_nodes:
tasks.append(node.before_dag_run())
await asyncio.gather(*tasks)
async def after_dag_end(self):
"""The callback after DAG end"""
tasks = []
for node in self._all_nodes:
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def _save_call_data(
root_nodes: List[BaseOperator], call_data: CALL_DATA
@@ -66,6 +89,7 @@ def _save_call_data(
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
"""Build all nodes from the end node."""
nodes = []
if isinstance(end_node, BaseOperator):
task_id = end_node.node_id

View File

@@ -1,7 +1,8 @@
from typing import Dict, Optional, Set, List
import logging
from ..dag.base import DAGContext
from dbgpt.component import SystemApp
from ..dag.base import DAGContext, DAGVar
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
@@ -18,19 +19,29 @@ class DefaultWorkflowRunner(WorkflowRunner):
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
) -> DAGContext:
# Create DAG context
dag_ctx = DAGContext(streaming_call=streaming_call)
# Save node output
# dag = node.dag
node_outputs: Dict[str, TaskContext] = {}
job_manager = JobManager.build_from_end_node(node, call_data)
# Create DAG context
dag_ctx = DAGContext(
streaming_call=streaming_call,
node_to_outputs=node_outputs,
node_name_to_ids=job_manager._node_name_to_ids,
)
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
)
dag = node.dag
# Save node output
node_outputs: Dict[str, TaskContext] = {}
skip_node_ids = set()
system_app: SystemApp = DAGVar.get_current_system_app()
await job_manager.before_dag_run()
await self._execute_node(
job_manager, node, dag_ctx, node_outputs, skip_node_ids
job_manager, node, dag_ctx, node_outputs, skip_node_ids, system_app
)
if not streaming_call and node.dag:
# For streaming calls, the DAG-end callback cannot run here (the HTTP trigger schedules it as a background task)
await node.dag._after_dag_end()
return dag_ctx
@@ -41,6 +52,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
dag_ctx: DAGContext,
node_outputs: Dict[str, TaskContext],
skip_node_ids: Set[str],
system_app: SystemApp,
):
# Skip run node
if node.node_id in node_outputs:
@@ -50,7 +62,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
for upstream_node in node.upstream:
if isinstance(upstream_node, BaseOperator):
await self._execute_node(
job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
job_manager,
upstream_node,
dag_ctx,
node_outputs,
skip_node_ids,
system_app,
)
inputs = [
@@ -73,6 +90,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
logger.debug(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
)
if system_app is not None and node.system_app is None:
node.set_system_app(system_app)
await node._run(dag_ctx)
node_outputs[node.node_id] = dag_ctx.current_task_context
task_ctx.set_current_state(TaskState.SUCCESS)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable
from starlette.requests import Request
from starlette.responses import Response
from dbgpt._private.pydantic import BaseModel
@@ -13,7 +13,8 @@ from ..operator.base import BaseOperator
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI
RequestBody = Union[Request, Type[BaseModel], str]
RequestBody = Union[Type[Request], Type[BaseModel], str]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ class HttpTrigger(Trigger):
methods: Optional[Union[str, List[str]]] = "GET",
request_body: Optional[RequestBody] = None,
streaming_response: Optional[bool] = False,
streaming_predict_func: Optional[StreamingPredictFunc] = None,
response_model: Optional[Type] = None,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
@@ -39,6 +41,7 @@ class HttpTrigger(Trigger):
self._methods = methods
self._req_body = request_body
self._streaming_response = streaming_response
self._streaming_predict_func = streaming_predict_func
self._response_model = response_model
self._status_code = status_code
self._router_tags = router_tags
@@ -59,10 +62,13 @@ class HttpTrigger(Trigger):
return await _parse_request_body(request, self._req_body)
async def route_function(body=Depends(_request_body_dependency)):
streaming_response = self._streaming_response
if self._streaming_predict_func:
streaming_response = self._streaming_predict_func(body)
return await _trigger_dag(
body,
self.dag,
self._streaming_response,
streaming_response,
self._response_headers,
self._response_media_type,
)
@@ -112,6 +118,7 @@ async def _trigger_dag(
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
) -> Any:
from fastapi import BackgroundTasks
from fastapi.responses import StreamingResponse
end_node = dag.leaf_nodes
@@ -131,8 +138,11 @@ async def _trigger_dag(
"Transfer-Encoding": "chunked",
}
generator = await end_node.call_stream(call_data={"data": body})
background_tasks = BackgroundTasks()
background_tasks.add_task(end_node.dag._after_dag_end)
return StreamingResponse(
generator,
headers=headers,
media_type=media_type,
background=background_tasks,
)
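A hedged sketch of the new streaming_predict_func hook, which decides per request whether the DAG result is returned as a streaming response; ChatBody and the endpoint are illustrative:

```python
from dbgpt._private.pydantic import BaseModel

from dbgpt.core.awel import HttpTrigger


class ChatBody(BaseModel):
    prompt: str
    stream: bool = False


trigger = HttpTrigger(
    "/examples/chat",
    methods="POST",
    request_body=ChatBody,
    # Return a StreamingResponse only when the client asked for one.
    streaming_predict_func=lambda body: body.stream,
)
```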

View File

@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, TypeVar, Generic, Optional
from dataclasses import dataclass
from enum import Enum
from typing import Any, Generic, Optional, TypeVar
from dbgpt.core.interface.serialization import Serializable

View File

@@ -1,14 +1,13 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Any, Union, AsyncIterator
import time
from dataclasses import dataclass, asdict, field
import copy
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core.awel import MapOperator, StreamifyAbsOperator
@dataclass
@@ -97,6 +96,28 @@ class ModelInferenceMetrics:
return asdict(self)
@dataclass
@PublicAPI(stability="beta")
class ModelRequestContext:
stream: Optional[bool] = False
"""Whether to return a stream of responses."""
user_name: Optional[str] = None
"""The user name of the model request."""
sys_code: Optional[str] = None
"""The system code of the model request."""
conv_uid: Optional[str] = None
"""The conversation id of the model inference."""
span_id: Optional[str] = None
"""The span id of the model inference."""
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
"""The extra information of the model inference."""
@dataclass
@PublicAPI(stability="beta")
class ModelOutput:
@@ -145,6 +166,27 @@ class ModelRequest:
span_id: Optional[str] = None
"""The span id of the model inference."""
context: Optional[ModelRequestContext] = field(
default_factory=lambda: ModelRequestContext()
)
"""The context of the model inference."""
@property
def stream(self) -> bool:
"""Whether to return a stream of responses."""
return self.context and self.context.stream
def copy(self):
new_request = copy.deepcopy(self)
# Transform messages to List[ModelMessage]
new_request.messages = list(
map(
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
new_request.messages,
)
)
return new_request
def to_dict(self) -> Dict[str, Any]:
new_reqeust = copy.deepcopy(self)
new_reqeust.messages = list(
@@ -161,6 +203,17 @@ class ModelRequest:
)
)
def get_single_user_message(self) -> Optional[ModelMessage]:
"""Get the single user message.
Returns:
Optional[ModelMessage]: The single user message.
"""
messages = self._get_messages()
if len(messages) != 1 or messages[0].role != ModelMessageRoleType.HUMAN:
raise ValueError("The request must contain exactly one user message")
return messages[0]
@staticmethod
def _build(model: str, prompt: str, **kwargs):
return ModelRequest(
@@ -178,11 +231,22 @@ class ModelRequest:
List[Dict[str, Any]]: The messages in the format of OpenAI API.
Examples:
.. code-block:: python
from dbgpt.core.interface.message import (
ModelMessage,
ModelMessageRoleType,
)
messages = [
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot.")
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are your"),
ModelMessage(
role=ModelMessageRoleType.AI, content="Hi, I'm a robot."
),
ModelMessage(
role=ModelMessageRoleType.HUMAN, content="Who are your"
),
]
openai_messages = ModelRequest.to_openai_messages(messages)
assert openai_messages == [
@@ -272,63 +336,3 @@ class LLMClient(ABC):
Returns:
int: The number of tokens.
"""
class RequestBuildOperator(MapOperator[str, ModelRequest], ABC):
def __init__(self, model: str, **kwargs):
self._model = model
super().__init__(**kwargs)
async def map(self, input_value: str) -> ModelRequest:
return ModelRequest._build(self._model, input_value)
class BaseLLM:
"""The abstract operator for a LLM."""
def __init__(self, llm_client: Optional[LLMClient] = None):
self._llm_client = llm_client
@property
def llm_client(self) -> LLMClient:
"""Return the LLM client."""
if not self._llm_client:
raise ValueError("llm_client is not set")
return self._llm_client
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
"""The operator for a LLM.
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
This operator will generate a no streaming response.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
MapOperator.__init__(self, **kwargs)
async def map(self, request: ModelRequest) -> ModelOutput:
return await self.llm_client.generate(request)
class StreamingLLMOperator(
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
):
"""The streaming operator for a LLM.
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
This operator will generate streaming response.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
StreamifyAbsOperator.__init__(self, **kwargs)
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
async for output in self.llm_client.generate_stream(request):
yield output
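A hedged sketch of the ModelRequestContext added above and the derived stream property; the field values are illustrative:

```python
from dbgpt.core import (
    ModelMessage,
    ModelMessageRoleType,
    ModelRequest,
    ModelRequestContext,
)

# Attach conversation metadata to the request via the new context dataclass.
context = ModelRequestContext(stream=True, user_name="alice", conv_uid="conv-1")
request = ModelRequest(
    model="vicuna-13b-v1.5",
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi")],
    context=context,
)

assert request.stream  # True because context.stream is True
single = request.get_single_user_message()  # the single HUMAN message above
```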

View File

@@ -1,16 +1,16 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union, Optional
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.storage import (
ResourceIdentifier,
StorageItem,
StorageInterface,
InMemoryStorage,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
@@ -112,6 +112,7 @@ class ModelMessage(BaseModel):
"""Similar to openai's message format"""
role: str
content: str
round_index: Optional[int] = 0
@staticmethod
def from_openai_messages(
@@ -443,6 +444,7 @@ class OnceConversation:
self.tokens = conversation.tokens
self.user_name = conversation.user_name
self.sys_code = conversation.sys_code
self._message_index = conversation._message_index
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
"""Get the messages by round index
@@ -470,6 +472,7 @@ class OnceConversation:
Example:
.. code-block:: python
conversation = OnceConversation()
conversation.start_new_round()
conversation.add_user_message("hello, this is the first round")
@@ -485,11 +488,17 @@ class OnceConversation:
conversation.end_current_round()
assert len(conversation.get_messages_with_round(1)) == 2
assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round"
assert (
conversation.get_messages_with_round(1)[0].content
== "hello, this is the third round"
)
assert conversation.get_messages_with_round(1)[1].content == "hi"
assert len(conversation.get_messages_with_round(2)) == 4
assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round"
assert (
conversation.get_messages_with_round(2)[0].content
== "hello, this is the second round"
)
assert conversation.get_messages_with_round(2)[1].content == "hi"
Args:
@@ -517,6 +526,7 @@ class OnceConversation:
Examples:
If you not need the history messages, you can override this method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
messages = []
for message in self.get_latest_round():
@@ -528,6 +538,7 @@ class OnceConversation:
If you want to add the one round history messages, you can override this method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
messages = []
latest_round_index = self.chat_order
@@ -537,7 +548,9 @@ class OnceConversation:
for message in self.get_messages_by_round(round_index):
if message.pass_to_model:
messages.append(
ModelMessage(role=message.type, content=message.content)
ModelMessage(
role=message.type, content=message.content
)
)
return messages
@@ -548,7 +561,11 @@ class OnceConversation:
for message in self.messages:
if message.pass_to_model:
messages.append(
ModelMessage(role=message.type, content=message.content)
ModelMessage(
role=message.type,
content=message.content,
round_index=message.round_index,
)
)
return messages
@@ -780,6 +797,9 @@ class StorageConversation(OnceConversation, StorageItem):
)
messages = [message.to_message() for message in message_list]
conversation.messages = messages
# This index marks where unsaved messages begin (the messages loaded above are already stored).
# New messages are appended to the list, so the index starts at len(messages).
conversation._message_index = len(messages)
self._message_ids = message_ids
self._has_stored_message_index = len(messages) - 1
self.from_conversation(conversation)

View File

@@ -0,0 +1,166 @@
import dataclasses
from abc import ABC
from typing import Any, AsyncIterator, Dict, Optional, Union
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import (
BranchFunc,
BranchOperator,
MapOperator,
StreamifyAbsOperator,
)
from dbgpt.core.interface.llm import (
LLMClient,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core.interface.message import ModelMessage
RequestInput = Union[
ModelRequest,
str,
Dict[str, Any],
BaseModel,
]
class RequestBuildOperator(MapOperator[RequestInput, ModelRequest], ABC):
def __init__(self, model: Optional[str] = None, **kwargs):
self._model = model
super().__init__(**kwargs)
async def map(self, input_value: RequestInput) -> ModelRequest:
req_dict = {}
if isinstance(input_value, str):
req_dict = {"messages": [ModelMessage.build_human_message(input_value)]}
elif isinstance(input_value, dict):
req_dict = input_value
elif dataclasses.is_dataclass(input_value):
req_dict = dataclasses.asdict(input_value)
elif isinstance(input_value, BaseModel):
req_dict = input_value.dict()
elif isinstance(input_value, ModelRequest):
if not input_value.model:
input_value.model = self._model
return input_value
if "messages" not in req_dict:
raise ValueError("messages is not set")
messages = req_dict["messages"]
if isinstance(messages, str):
# Single message, transform to a list including one human message
req_dict["messages"] = [ModelMessage.build_human_message(messages)]
if "model" not in req_dict:
req_dict["model"] = self._model
if not req_dict["model"]:
raise ValueError("model is not set")
stream = False
has_stream = False
if "stream" in req_dict:
has_stream = True
stream = req_dict["stream"]
del req_dict["stream"]
if "context" not in req_dict:
req_dict["context"] = ModelRequestContext(stream=stream)
else:
context_dict = req_dict["context"]
if not isinstance(context_dict, dict):
raise ValueError("context is not a dict")
if has_stream:
context_dict["stream"] = stream
req_dict["context"] = ModelRequestContext(**context_dict)
return ModelRequest(**req_dict)
class BaseLLM:
"""The abstract operator for a LLM."""
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
def __init__(self, llm_client: Optional[LLMClient] = None):
self._llm_client = llm_client
@property
def llm_client(self) -> LLMClient:
"""Return the LLM client."""
if not self._llm_client:
raise ValueError("llm_client is not set")
return self._llm_client
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
"""The operator for a LLM.
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
This operator will generate a no streaming response.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
MapOperator.__init__(self, **kwargs)
async def map(self, request: ModelRequest) -> ModelOutput:
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
return await self.llm_client.generate(request)
class StreamingLLMOperator(
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
):
"""The streaming operator for a LLM.
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
This operator will generate streaming response.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client=llm_client)
StreamifyAbsOperator.__init__(self, **kwargs)
async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]:
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
async for output in self.llm_client.generate_stream(request):
yield output
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
"""Branch operator for LLM.
This operator will branch the workflow based on the stream flag of the request.
"""
def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
super().__init__(**kwargs)
if not stream_task_name:
raise ValueError("stream_task_name is not set")
if not no_stream_task_name:
raise ValueError("no_stream_task_name is not set")
self._stream_task_name = stream_task_name
self._no_stream_task_name = no_stream_task_name
async def branches(self) -> Dict[BranchFunc[ModelRequest], str]:
"""
Return a dict of branch function and task name.
Returns:
Dict[BranchFunc[ModelRequest], str]: A dict mapping branch functions to task names.
The key is a predicate function and the value is a task name; if the predicate returns True,
the corresponding task is run.
"""
async def check_stream_true(r: ModelRequest) -> bool:
# If stream is true, run the streaming task; otherwise, run the non-streaming task.
return r.stream
return {
check_stream_true: self._stream_task_name,
lambda x: not x.stream: self._no_stream_task_name,
}
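A hedged sketch of wiring these operators into an AWEL DAG. EchoLLMClient is a toy stand-in implementing the LLMClient interface from dbgpt/core/interface/llm.py, and the imports assume dbgpt.core.operator re-exports these classes (as the new tests in this commit do for BufferedConversationMapperOperator); the task names are illustrative but must match those given to LLMBranchOperator:

```python
from typing import AsyncIterator, List

from dbgpt.core import LLMClient, ModelMetadata, ModelOutput, ModelRequest
from dbgpt.core.awel import DAG
from dbgpt.core.operator import (
    LLMBranchOperator,
    LLMOperator,
    RequestBuildOperator,
    StreamingLLMOperator,
)


class EchoLLMClient(LLMClient):
    """Toy client that echoes the last message back; for illustration only."""

    async def generate(self, request: ModelRequest) -> ModelOutput:
        return ModelOutput(text=request.messages[-1].content, error_code=0)

    async def generate_stream(
        self, request: ModelRequest
    ) -> AsyncIterator[ModelOutput]:
        yield ModelOutput(text=request.messages[-1].content, error_code=0)

    async def models(self) -> List[ModelMetadata]:
        return []

    async def count_token(self, model: str, prompt: str) -> int:
        return len(prompt)


client = EchoLLMClient()

with DAG("llm_branch_dag") as dag:
    # Builds a ModelRequest from str/dict/ModelRequest input; a "stream" key in a dict
    # input ends up in ModelRequestContext and drives the branch below.
    request_task = RequestBuildOperator(model="echo-model")
    branch_task = LLMBranchOperator(
        stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
    )
    llm_task = LLMOperator(llm_client=client, task_name="llm_task")
    streaming_llm_task = StreamingLLMOperator(
        llm_client=client, task_name="streaming_llm_task"
    )
    request_task >> branch_task
    branch_task >> llm_task
    branch_task >> streaming_llm_task
```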

View File

@@ -0,0 +1,321 @@
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, List, Optional
from dbgpt.core import (
MessageStorageItem,
ModelMessage,
ModelOutput,
ModelRequest,
ModelRequestContext,
StorageConversation,
StorageInterface,
)
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
class BaseConversationOperator(BaseOperator, ABC):
"""Base class for conversation operators."""
SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation"
SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request"
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs
):
super().__init__(**kwargs)
self._storage = storage
self._message_storage = message_storage
@property
def storage(self) -> StorageInterface[StorageConversation, Any]:
"""Return the LLM client."""
if not self._storage:
raise ValueError("Storage is not set")
return self._storage
@property
def message_storage(self) -> StorageInterface[MessageStorageItem, Any]:
"""Return the LLM client."""
if not self._message_storage:
raise ValueError("Message storage is not set")
return self._message_storage
async def get_storage_conversation(self) -> StorageConversation:
"""Get the storage conversation from share data.
Returns:
StorageConversation: The storage conversation.
"""
storage_conv: StorageConversation = (
await self.current_dag_context.get_from_share_data(
self.SHARE_DATA_KEY_STORAGE_CONVERSATION
)
)
if not storage_conv:
raise ValueError("Storage conversation is not set")
return storage_conv
async def get_model_request(self) -> ModelRequest:
"""Get the model request from share data.
Returns:
ModelRequest: The model request.
"""
model_request: ModelRequest = (
await self.current_dag_context.get_from_share_data(
self.SHARE_DATA_KEY_MODEL_REQUEST
)
)
if not model_request:
raise ValueError("Model request is not set")
return model_request
class PreConversationOperator(
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
):
"""The operator to prepare the storage conversation.
In DB-GPT, the conversation record and the messages in the conversation are stored in storage,
and they can be stored in different storages (for higher performance).
"""
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs
):
super().__init__(storage=storage, message_storage=message_storage)
MapOperator.__init__(self, **kwargs)
async def map(self, input_value: ModelRequest) -> ModelRequest:
"""Map the input value to a ModelRequest.
Args:
input_value (ModelRequest): The input value.
Returns:
ModelRequest: The mapped ModelRequest.
"""
if input_value.context is None:
input_value.context = ModelRequestContext()
if not input_value.context.conv_uid:
input_value.context.conv_uid = str(uuid.uuid4())
if not input_value.context.extra:
input_value.context.extra = {}
chat_mode = input_value.context.extra.get("chat_mode")
# Create a new storage conversation; this loads the conversation from storage, so it must run asynchronously
storage_conv: StorageConversation = await self.blocking_func_to_async(
StorageConversation,
conv_uid=input_value.context.conv_uid,
chat_mode=chat_mode,
user_name=input_value.context.user_name,
sys_code=input_value.context.sys_code,
conv_storage=self.storage,
message_storage=self.message_storage,
)
# The input message must be a single user message
single_human_message: ModelMessage = input_value.get_single_user_message()
storage_conv.start_new_round()
storage_conv.add_user_message(single_human_message.content)
# Get all messages from current storage conversation, and overwrite the input value
messages: List[ModelMessage] = storage_conv.get_model_messages()
input_value.messages = messages
# Save the storage conversation to share data, for the child operators
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
)
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_REQUEST, input_value
)
return input_value
async def after_dag_end(self):
"""The callback after DAG end"""
# Save the storage conversation to storage after the whole DAG finished
storage_conv: StorageConversation = await self.get_storage_conversation()
# TODO: don't save if the conversation has an internal error
storage_conv.end_current_round()
class PostConversationOperator(
BaseConversationOperator, MapOperator[ModelOutput, ModelOutput]
):
def __init__(self, **kwargs):
MapOperator.__init__(self, **kwargs)
async def map(self, input_value: ModelOutput) -> ModelOutput:
"""Map the input value to a ModelOutput.
Args:
input_value (ModelOutput): The input value.
Returns:
ModelOutput: The mapped ModelOutput.
"""
# Get the storage conversation from share data
storage_conv: StorageConversation = await self.get_storage_conversation()
storage_conv.add_ai_message(input_value.text)
return input_value
class PostStreamingConversationOperator(
BaseConversationOperator, TransformStreamAbsOperator[ModelOutput, ModelOutput]
):
def __init__(self, **kwargs):
TransformStreamAbsOperator.__init__(self, **kwargs)
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[ModelOutput]:
"""Transform the input stream, saving the final answer after it has been fully yielded.
Args:
input_value (AsyncIterator[ModelOutput]): The input stream.
Returns:
AsyncIterator[ModelOutput]: The transformed stream.
"""
full_text = ""
async for model_output in input_value:
# Currently model_output.text is the full text; if it were delta text, all deltas would need to be merged into the full text
full_text = model_output.text
yield model_output
# Get the storage conversation from share data
storage_conv: StorageConversation = await self.get_storage_conversation()
storage_conv.add_ai_message(full_text)
class ConversationMapperOperator(
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
):
def __init__(self, **kwargs):
MapOperator.__init__(self, **kwargs)
async def map(self, input_value: ModelRequest) -> ModelRequest:
"""Map the input value to a ModelRequest.
Args:
input_value (ModelRequest): The input value.
Returns:
ModelRequest: The mapped ModelRequest.
"""
input_value = input_value.copy()
messages: List[ModelMessage] = await self.map_messages(input_value.messages)
# Overwrite the input value
input_value.messages = messages
return input_value
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
"""Map the input messages to a list of ModelMessage.
Args:
messages (List[ModelMessage]): The input messages.
Returns:
List[ModelMessage]: The mapped ModelMessage.
"""
return messages
def _split_messages_by_round(
self, messages: List[ModelMessage]
) -> List[List[ModelMessage]]:
"""Split the messages by round index.
Args:
messages (List[ModelMessage]): The input messages.
Returns:
List[List[ModelMessage]]: The messages grouped by round.
"""
messages_by_round: List[List[ModelMessage]] = []
last_round_index = 0
for message in messages:
if not message.round_index:
# Round index must be bigger than 0
raise ValueError("Message round_index is not set")
if message.round_index > last_round_index:
last_round_index = message.round_index
messages_by_round.append([])
messages_by_round[-1].append(message)
return messages_by_round
class BufferedConversationMapperOperator(ConversationMapperOperator):
"""The buffered conversation mapper operator.
This Operator must be used after the PreConversationOperator,
and it will map the messages in the storage conversation.
Examples:
Transform no history messages
.. code-block:: python
import asyncio
from dbgpt.core import ModelMessage
from dbgpt.core.operator import BufferedConversationMapperOperator
# No history
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
operator = BufferedConversationMapperOperator(last_k_round=1)
messages = asyncio.run(operator.map_messages(messages))
assert messages == [ModelMessage(role="human", content="Hello", round_index=1)]
Transform with history messages
.. code-block:: python
# With history
messages = [
ModelMessage(role="human", content="Hi", round_index=1),
ModelMessage(role="ai", content="Hello!", round_index=1),
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
ModelMessage(role="human", content="Funny!", round_index=3),
]
operator = BufferedConversationMapperOperator(last_k_round=1)
messages = asyncio.run(operator.map_messages(messages))
# Keep only the last round, so the first-round messages are removed
# Note: The round index 3 is not a complete round
assert messages == [
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
ModelMessage(role="human", content="Funny!", round_index=3),
]
"""
def __init__(self, last_k_round: Optional[int] = 2, **kwargs):
super().__init__(**kwargs)
self._last_k_round = last_k_round
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
"""Map the input messages to a list of ModelMessage.
Args:
messages (List[ModelMessage]): The input messages.
Returns:
List[ModelMessage]: The mapped ModelMessage.
"""
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
messages
)
# Keep the last k complete rounds plus the current (possibly incomplete) round
index = self._last_k_round + 1
messages_by_round = messages_by_round[-index:]
messages: List[ModelMessage] = sum(messages_by_round, [])
return messages
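A hedged sketch of a conversation-aware pipeline assembled from these operators, reusing the EchoLLMClient sketch above; the dbgpt.core.operator exports and the wiring order are assumptions based on this file, not a verbatim example from the commit:

```python
from dbgpt.core.awel import DAG
from dbgpt.core.interface.storage import InMemoryStorage
from dbgpt.core.operator import (
    BufferedConversationMapperOperator,
    LLMOperator,
    PostConversationOperator,
    PreConversationOperator,
    RequestBuildOperator,
)

# The conversation record and its messages may live in different storages.
conv_storage = InMemoryStorage()
message_storage = InMemoryStorage()

with DAG("conversation_dag") as dag:
    # Build a ModelRequest from the raw input (str/dict/ModelRequest).
    request_task = RequestBuildOperator(model="echo-model")
    # Load (or create) the conversation and merge its history into the request.
    pre_conv_task = PreConversationOperator(
        storage=conv_storage, message_storage=message_storage
    )
    # Keep only the last 2 complete rounds plus the current round.
    history_task = BufferedConversationMapperOperator(last_k_round=2)
    llm_task = LLMOperator(llm_client=client)  # e.g. the EchoLLMClient sketched above
    # Record the AI answer; the conversation is persisted in after_dag_end().
    post_conv_task = PostConversationOperator()

    request_task >> pre_conv_task
    pre_conv_task >> history_task
    history_task >> llm_task
    llm_task >> post_conv_task

# result = await post_conv_task.call(call_data={"data": "Hello"})  # run inside an event loop
```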

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
import json
from abc import ABC
import logging
from abc import ABC
from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union
from dbgpt.core.awel import MapOperator
from dbgpt.core import ModelOutput
from dbgpt.core.awel import MapOperator
T = TypeVar("T")
ResponseTye = Union[str, bytes, ModelOutput]

View File

@@ -1,12 +1,20 @@
import dataclasses
import json
from abc import ABC
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel
from dbgpt.util.formatting import formatter, no_strict_formatter
from dbgpt._private.pydantic import BaseModel
from dbgpt.core._private.example_base import ExampleSelector
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.core._private.example_base import ExampleSelector
from dbgpt.core.interface.storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
from dbgpt.util.formatting import formatter, no_strict_formatter
def _jinja2_formatter(template: str, **kwargs: Any) -> str:
@@ -86,6 +94,434 @@ class PromptTemplate(BaseModel, ABC):
)
@dataclasses.dataclass
class PromptTemplateIdentifier(ResourceIdentifier):
identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
prompt_name: str
prompt_language: Optional[str] = None
sys_code: Optional[str] = None
model: Optional[str] = None
def __post_init__(self):
if self.prompt_name is None:
raise ValueError("prompt_name cannot be None")
if any(
self.identifier_split in key
for key in [
self.prompt_name,
self.prompt_language,
self.sys_code,
self.model,
]
if key is not None
):
raise ValueError(
f"identifier_split {self.identifier_split} is not allowed in prompt_name, prompt_language, sys_code, model"
)
@property
def str_identifier(self) -> str:
return self.identifier_split.join(
key
for key in [
self.prompt_name,
self.prompt_language,
self.sys_code,
self.model,
]
if key is not None
)
def to_dict(self) -> Dict:
return {
"prompt_name": self.prompt_name,
"prompt_language": self.prompt_language,
"sys_code": self.sys_code,
"model": self.model,
}
@dataclasses.dataclass
class StoragePromptTemplate(StorageItem):
prompt_name: str
content: Optional[str] = None
prompt_language: Optional[str] = None
prompt_format: Optional[str] = None
input_variables: Optional[str] = None
model: Optional[str] = None
chat_scene: Optional[str] = None
sub_chat_scene: Optional[str] = None
prompt_type: Optional[str] = None
user_name: Optional[str] = None
sys_code: Optional[str] = None
_identifier: PromptTemplateIdentifier = dataclasses.field(init=False)
def __post_init__(self):
self._identifier = PromptTemplateIdentifier(
prompt_name=self.prompt_name,
prompt_language=self.prompt_language,
sys_code=self.sys_code,
model=self.model,
)
self._check()  # Validate required fields after initialization
def to_prompt_template(self) -> PromptTemplate:
"""Convert the storage prompt template to a prompt template."""
input_variables = (
None
if not self.input_variables
else self.input_variables.strip().split(",")
)
return PromptTemplate(
input_variables=input_variables,
template=self.content,
template_scene=self.chat_scene,
prompt_name=self.prompt_name,
template_format=self.prompt_format,
)
@staticmethod
def from_prompt_template(
prompt_template: PromptTemplate,
prompt_name: str,
prompt_language: Optional[str] = None,
prompt_type: Optional[str] = None,
sys_code: Optional[str] = None,
user_name: Optional[str] = None,
sub_chat_scene: Optional[str] = None,
model: Optional[str] = None,
**kwargs,
) -> "StoragePromptTemplate":
"""Convert a prompt template to a storage prompt template.
Args:
prompt_template (PromptTemplate): The prompt template to convert from.
prompt_name (str): The name of the prompt.
prompt_language (Optional[str], optional): The language of the prompt. Defaults to None. e.g. zh-cn, en.
prompt_type (Optional[str], optional): The type of the prompt. Defaults to None. e.g. common, private.
sys_code (Optional[str], optional): The system code of the prompt. Defaults to None.
user_name (Optional[str], optional): The username of the prompt. Defaults to None.
sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt. Defaults to None.
model (Optional[str], optional): The model name of the prompt. Defaults to None.
kwargs (Dict): Other params to build the storage prompt template.
"""
input_variables = prompt_template.input_variables or kwargs.get(
"input_variables"
)
if input_variables and isinstance(input_variables, list):
input_variables = ",".join(input_variables)
return StoragePromptTemplate(
prompt_name=prompt_name,
sys_code=sys_code,
user_name=user_name,
input_variables=input_variables,
model=model,
content=prompt_template.template or kwargs.get("content"),
prompt_language=prompt_language,
prompt_format=prompt_template.template_format
or kwargs.get("prompt_format"),
chat_scene=prompt_template.template_scene or kwargs.get("chat_scene"),
sub_chat_scene=sub_chat_scene,
prompt_type=prompt_type,
)
@property
def identifier(self) -> PromptTemplateIdentifier:
return self._identifier
def merge(self, other: "StorageItem") -> None:
"""Merge the other item into the current item.
Args:
other (StorageItem): The other item to merge
"""
if not isinstance(other, StoragePromptTemplate):
raise ValueError(
f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
)
self.from_object(other)
def to_dict(self) -> Dict:
return {
"prompt_name": self.prompt_name,
"content": self.content,
"prompt_language": self.prompt_language,
"prompt_format": self.prompt_format,
"input_variables": self.input_variables,
"model": self.model,
"chat_scene": self.chat_scene,
"sub_chat_scene": self.sub_chat_scene,
"prompt_type": self.prompt_type,
"user_name": self.user_name,
"sys_code": self.sys_code,
}
def _check(self):
if self.prompt_name is None:
raise ValueError("prompt_name cannot be None")
if self.content is None:
raise ValueError("content cannot be None")
def from_object(self, template: "StoragePromptTemplate") -> None:
"""Load the prompt template from an existing prompt template object.
Args:
template (PromptTemplate): The prompt template to load from.
"""
self.content = template.content
self.prompt_format = template.prompt_format
self.input_variables = template.input_variables
self.model = template.model
self.chat_scene = template.chat_scene
self.sub_chat_scene = template.sub_chat_scene
self.prompt_type = template.prompt_type
self.user_name = template.user_name
class PromptManager:
"""The manager class for prompt templates.
Simple wrapper for the storage interface.
Examples:
.. code-block:: python
# Default use InMemoryStorage
prompt_manager = PromptManager()
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
)
prompt_manager.save(prompt_template, prompt_name="hello")
prompt_template_list = prompt_manager.list()
prompt_template_list = prompt_manager.prefer_query("hello")
With a custom storage interface.
.. code-block:: python
from dbgpt.core.interface.storage import InMemoryStorage
prompt_manager = PromptManager(InMemoryStorage())
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
)
prompt_manager.save(prompt_template, prompt_name="hello")
prompt_template_list = prompt_manager.list()
prompt_template_list = prompt_manager.prefer_query("hello")
"""
def __init__(
self, storage: Optional[StorageInterface[StoragePromptTemplate, Any]] = None
):
if storage is None:
storage = InMemoryStorage()
self._storage = storage
@property
def storage(self) -> StorageInterface[StoragePromptTemplate, Any]:
"""The storage interface for prompt templates."""
return self._storage
def prefer_query(
self,
prompt_name: str,
sys_code: Optional[str] = None,
prefer_prompt_language: Optional[str] = None,
prefer_model: Optional[str] = None,
**kwargs,
) -> List[StoragePromptTemplate]:
"""Query prompt templates from storage with prefer params.
Sometimes we want to query prompt templates with preferred params (e.g. a specific language or model).
This method first queries by the exact conditions, then filters by the preferred params; if the filter matches nothing, all queried prompt templates are returned.
Examples:
Query a prompt template.
.. code-block:: python
prompt_template_list = prompt_manager.prefer_query("hello")
Query with sys_code and username.
.. code-block:: python
prompt_template_list = prompt_manager.prefer_query(
"hello", sys_code="sys_code", user_name="user_name"
)
Query with prefer prompt language.
.. code-block:: python
# First query with prompt name "hello" exactly.
# Then filter by prompt language "zh-cn"; if none match, all queried prompt templates are returned.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_prompt_language="zh-cn"
)
Query with prefer model.
.. code-block:: python
# First query with prompt name "hello" exactly.
# Then filter by model "vicuna-13b-v1.5"; if none match, all queried prompt templates are returned.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_model="vicuna-13b-v1.5"
)
Args:
prompt_name (str): The name of the prompt template.
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
prefer_prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
prefer_model (Optional[str], optional): The model of the prompt template. Defaults to None.
kwargs (Dict): Other query params (keys with non-None values are matched exactly).
"""
query_spec = QuerySpec(
conditions={
"prompt_name": prompt_name,
"sys_code": sys_code,
**kwargs,
}
)
queries: List[StoragePromptTemplate] = self.storage.query(
query_spec, StoragePromptTemplate
)
if not queries:
return []
if prefer_prompt_language:
prefer_prompt_language = prefer_prompt_language.lower()
temp_queries = [
query
for query in queries
if query.prompt_language
and query.prompt_language.lower() == prefer_prompt_language
]
if temp_queries:
queries = temp_queries
if prefer_model:
prefer_model = prefer_model.lower()
temp_queries = [
query
for query in queries
if query.model and query.model.lower() == prefer_model
]
if temp_queries:
queries = temp_queries
return queries
def save(self, prompt_template: PromptTemplate, prompt_name: str, **kwargs) -> None:
"""Save a prompt template to storage.
Examples:
.. code-block:: python
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
prompt_name="hello",
)
prompt_manager.save(prompt_template)
Save with sys_code and username.
.. code-block:: python
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
prompt_name="hello",
)
prompt_manager.save(
prompt_template, sys_code="sys_code", user_name="user_name"
)
Args:
prompt_template (PromptTemplate): The prompt template to save.
prompt_name (str): The name of the prompt template.
kwargs (Dict): Other params to build the storage prompt template.
More details in :meth:`~StoragePromptTemplate.from_prompt_template`.
"""
storage_prompt_template = StoragePromptTemplate.from_prompt_template(
prompt_template, prompt_name, **kwargs
)
self.storage.save(storage_prompt_template)
def list(self, **kwargs) -> List[StoragePromptTemplate]:
"""List prompt templates from storage.
Examples:
List all prompt templates.
.. code-block:: python
all_prompt_templates = prompt_manager.list()
List with sys_code and username.
.. code-block:: python
templates = prompt_manager.list(
sys_code="sys_code", user_name="user_name"
)
Args:
kwargs (Dict): Other query params.
"""
query_spec = QuerySpec(conditions=kwargs)
return self.storage.query(query_spec, StoragePromptTemplate)
def delete(
self,
prompt_name: str,
prompt_language: Optional[str] = None,
sys_code: Optional[str] = None,
model: Optional[str] = None,
) -> None:
"""Delete a prompt template from storage.
Examples:
Delete a prompt template.
.. code-block:: python
prompt_manager.delete("hello")
Delete with sys_code and username.
.. code-block:: python
prompt_manager.delete(
"hello", sys_code="sys_code", user_name="user_name"
)
Args:
prompt_name (str): The name of the prompt template.
prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
model (Optional[str], optional): The model of the prompt template. Defaults to None.
"""
identifier = PromptTemplateIdentifier(
prompt_name=prompt_name,
prompt_language=prompt_language,
sys_code=sys_code,
model=model,
)
self.storage.delete(identifier)
class PromptTemplateOperator(MapOperator[Dict, str]):
def __init__(self, prompt_template: PromptTemplate, **kwargs: Any):
super().__init__(**kwargs)
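An end-to-end sketch of the new PromptManager API, mirroring the docstring examples above: save the same prompt name in two languages, then let prefer_query apply the language preference (falling back to all matches when the preference cannot be satisfied):

```python
from dbgpt.core.interface.prompt import PromptManager, PromptTemplate

pm = PromptManager()  # defaults to InMemoryStorage

template = PromptTemplate(
    template="hello {input}",
    input_variables=["input"],
    template_scene="chat_normal",
)
# The composite unique key (name, sys_code, language, model) allows the same prompt
# name to be saved once per language/model.
pm.save(template, prompt_name="hello", prompt_language="en")
pm.save(template, prompt_name="hello", prompt_language="zh-cn")

# Only the zh-cn template is returned because the preference can be satisfied.
zh_templates = pm.prefer_query("hello", prefer_prompt_language="zh-cn")
# No "fr" template exists, so the filter falls back to both saved templates.
all_templates = pm.prefer_query("hello", prefer_prompt_language="fr")

pm.delete("hello", prompt_language="en")
```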

View File

@@ -1,4 +1,5 @@
from abc import abstractmethod
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, Dict
from typing import Dict, Type
class Serializable(ABC):

View File

@@ -1,10 +1,10 @@
from typing import Generic, TypeVar, Type, Optional, Dict, Any, List
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.serialization.json_serialization import JsonSerializer
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.serialization.json_serialization import JsonSerializer
@PublicAPI(stability="beta")

View File

@@ -1,7 +1,7 @@
import pytest
from dbgpt.core.interface.tests.conftest import in_memory_storage
from dbgpt.core.interface.message import *
from dbgpt.core.interface.tests.conftest import in_memory_storage
@pytest.fixture

View File

@@ -0,0 +1,320 @@
import json
import pytest
from dbgpt.core.interface.prompt import (
PromptManager,
PromptTemplate,
StoragePromptTemplate,
)
from dbgpt.core.interface.storage import QuerySpec
from dbgpt.core.interface.tests.conftest import in_memory_storage
@pytest.fixture
def sample_storage_prompt_template():
return StoragePromptTemplate(
prompt_name="test_prompt",
content="Sample content, {var1}, {var2}",
prompt_language="en",
prompt_format="f-string",
input_variables="var1,var2",
model="model1",
chat_scene="scene1",
sub_chat_scene="subscene1",
prompt_type="type1",
user_name="user1",
sys_code="code1",
)
@pytest.fixture
def complex_storage_prompt_template():
content = """Database name: {db_name} Table structure definition: {table_info} User Question:{user_input}"""
return StoragePromptTemplate(
prompt_name="chat_data_auto_execute_prompt",
content=content,
prompt_language="en",
prompt_format="f-string",
input_variables="db_name,table_info,user_input",
model="vicuna-13b-v1.5",
chat_scene="chat_data",
sub_chat_scene="subscene1",
prompt_type="common",
user_name="zhangsan",
sys_code="dbgpt",
)
@pytest.fixture
def prompt_manager(in_memory_storage):
return PromptManager(storage=in_memory_storage)
class TestPromptTemplate:
@pytest.mark.parametrize(
"template_str, input_vars, expected_output",
[
("Hello {name}", {"name": "World"}, "Hello World"),
("{greeting}, {name}", {"greeting": "Hi", "name": "Alice"}, "Hi, Alice"),
],
)
def test_format_f_string(self, template_str, input_vars, expected_output):
prompt = PromptTemplate(
template=template_str,
input_variables=list(input_vars.keys()),
template_format="f-string",
)
formatted_output = prompt.format(**input_vars)
assert formatted_output == expected_output
@pytest.mark.parametrize(
"template_str, input_vars, expected_output",
[
("Hello {{ name }}", {"name": "World"}, "Hello World"),
(
"{{ greeting }}, {{ name }}",
{"greeting": "Hi", "name": "Alice"},
"Hi, Alice",
),
],
)
def test_format_jinja2(self, template_str, input_vars, expected_output):
prompt = PromptTemplate(
template=template_str,
input_variables=list(input_vars.keys()),
template_format="jinja2",
)
formatted_output = prompt.format(**input_vars)
assert formatted_output == expected_output
def test_format_with_response_format(self):
template_str = "Response: {response}"
prompt = PromptTemplate(
template=template_str,
input_variables=["response"],
template_format="f-string",
response_format=json.dumps({"message": "hello"}),
)
formatted_output = prompt.format(response="hello")
assert "Response: " in formatted_output
def test_from_template(self):
template_str = "Hello {name}"
prompt = PromptTemplate.from_template(template_str)
assert prompt._prompt_template.template == template_str
assert prompt._prompt_template.input_variables == []
def test_format_missing_variable(self):
template_str = "Hello {name}"
prompt = PromptTemplate(
template=template_str, input_variables=["name"], template_format="f-string"
)
with pytest.raises(KeyError):
prompt.format()
def test_format_extra_variable(self):
template_str = "Hello {name}"
prompt = PromptTemplate(
template=template_str,
input_variables=["name"],
template_format="f-string",
template_is_strict=False,
)
formatted_output = prompt.format(name="World", extra="unused")
assert formatted_output == "Hello World"
def test_format_complex(self, complex_storage_prompt_template):
prompt = complex_storage_prompt_template.to_prompt_template()
formatted_output = prompt.format(
db_name="db1",
table_info="create table users(id int, name varchar(20))",
user_input="find all users whose name is 'Alice'",
)
assert (
formatted_output
== "Database name: db1 Table structure definition: create table users(id int, name varchar(20)) "
"User Question:find all users whose name is 'Alice'"
)
class TestStoragePromptTemplate:
def test_constructor_and_properties(self):
storage_item = StoragePromptTemplate(
prompt_name="test",
content="Hello {name}",
prompt_language="en",
prompt_format="f-string",
input_variables="name",
model="model1",
chat_scene="chat",
sub_chat_scene="sub_chat",
prompt_type="type",
user_name="user",
sys_code="sys",
)
assert storage_item.prompt_name == "test"
assert storage_item.content == "Hello {name}"
assert storage_item.prompt_language == "en"
assert storage_item.prompt_format == "f-string"
assert storage_item.input_variables == "name"
assert storage_item.model == "model1"
def test_constructor_exceptions(self):
with pytest.raises(ValueError):
StoragePromptTemplate(prompt_name=None, content="Hello")
def test_to_prompt_template(self, sample_storage_prompt_template):
prompt_template = sample_storage_prompt_template.to_prompt_template()
assert isinstance(prompt_template, PromptTemplate)
assert prompt_template.template == "Sample content, {var1}, {var2}"
assert prompt_template.input_variables == ["var1", "var2"]
def test_from_prompt_template(self):
prompt_template = PromptTemplate(
template="Sample content, {var1}, {var2}",
input_variables=["var1", "var2"],
template_format="f-string",
)
storage_prompt_template = StoragePromptTemplate.from_prompt_template(
prompt_template=prompt_template, prompt_name="test_prompt"
)
assert storage_prompt_template.prompt_name == "test_prompt"
assert storage_prompt_template.content == "Sample content, {var1}, {var2}"
assert storage_prompt_template.input_variables == "var1,var2"
def test_merge(self, sample_storage_prompt_template):
other = StoragePromptTemplate(
prompt_name="other_prompt",
content="Other content",
)
sample_storage_prompt_template.merge(other)
assert sample_storage_prompt_template.content == "Other content"
def test_to_dict(self, sample_storage_prompt_template):
result = sample_storage_prompt_template.to_dict()
assert result == {
"prompt_name": "test_prompt",
"content": "Sample content, {var1}, {var2}",
"prompt_language": "en",
"prompt_format": "f-string",
"input_variables": "var1,var2",
"model": "model1",
"chat_scene": "scene1",
"sub_chat_scene": "subscene1",
"prompt_type": "type1",
"user_name": "user1",
"sys_code": "code1",
}
def test_save_and_load_storage(
self, sample_storage_prompt_template, in_memory_storage
):
in_memory_storage.save(sample_storage_prompt_template)
loaded_item = in_memory_storage.load(
sample_storage_prompt_template.identifier, StoragePromptTemplate
)
assert loaded_item.content == "Sample content, {var1}, {var2}"
def test_check_exceptions(self):
with pytest.raises(ValueError):
StoragePromptTemplate(prompt_name=None, content="Hello")
def test_from_object(self, sample_storage_prompt_template):
other = StoragePromptTemplate(prompt_name="other", content="Other content")
sample_storage_prompt_template.from_object(other)
assert sample_storage_prompt_template.content == "Other content"
assert sample_storage_prompt_template.input_variables != "var1,var2"
# Prompt name should not be changed
assert sample_storage_prompt_template.prompt_name == "test_prompt"
assert sample_storage_prompt_template.sys_code == "code1"
class TestPromptManager:
def test_save(self, prompt_manager, in_memory_storage):
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="hello",
)
result = in_memory_storage.query(
QuerySpec(conditions={"prompt_name": "hello"}), StoragePromptTemplate
)
assert len(result) == 1
assert result[0].content == "hello {input}"
def test_prefer_query_simple(self, prompt_manager, in_memory_storage):
in_memory_storage.save(
StoragePromptTemplate(prompt_name="test_prompt", content="test")
)
result = prompt_manager.prefer_query("test_prompt")
assert len(result) == 1
assert result[0].content == "test"
def test_prefer_query_language(self, prompt_manager, in_memory_storage):
for language in ["en", "zh"]:
in_memory_storage.save(
StoragePromptTemplate(
prompt_name="test_prompt",
content="test",
prompt_language=language,
)
)
# Prefer zh; since a zh template exists, the zh prompt template is returned
result = prompt_manager.prefer_query("test_prompt", prefer_prompt_language="zh")
assert len(result) == 1
assert result[0].content == "test"
assert result[0].prompt_language == "zh"
# The preferred language does not exist, so all prompt templates with this name are returned
result = prompt_manager.prefer_query(
"test_prompt", prefer_prompt_language="not_exist"
)
assert len(result) == 2
def test_prefer_query_model(self, prompt_manager, in_memory_storage):
for model in ["model1", "model2"]:
in_memory_storage.save(
StoragePromptTemplate(
prompt_name="test_prompt", content="test", model=model
)
)
# Prefer model1; since model1 exists, the model1 prompt template is returned
result = prompt_manager.prefer_query("test_prompt", prefer_model="model1")
assert len(result) == 1
assert result[0].content == "test"
assert result[0].model == "model1"
# The preferred model does not exist, so all prompt templates with this name are returned
result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist")
assert len(result) == 2
def test_list(self, prompt_manager, in_memory_storage):
prompt_manager.save(
PromptTemplate(template="Hello {name}", input_variables=["name"]),
prompt_name="name1",
)
prompt_manager.save(
PromptTemplate(
template="Write a SQL of {dialect} to query all data of {table_name}.",
input_variables=["dialect", "table_name"],
),
prompt_name="sql_template",
)
all_templates = prompt_manager.list()
assert len(all_templates) == 2
assert len(prompt_manager.list(prompt_name="name1")) == 1
assert len(prompt_manager.list(prompt_name="not exist")) == 0
def test_delete(self, prompt_manager, in_memory_storage):
prompt_manager.save(
PromptTemplate(template="Hello {name}", input_variables=["name"]),
prompt_name="to_delete",
)
prompt_manager.delete("to_delete")
result = in_memory_storage.query(
QuerySpec(conditions={"prompt_name": "to_delete"}), StoragePromptTemplate
)
assert len(result) == 0

View File

@@ -1,10 +1,12 @@
import pytest
from typing import Dict, Type, Union
import pytest
from dbgpt.core.interface.storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageError,
QuerySpec,
InMemoryStorage,
StorageItem,
)
from dbgpt.util.serialization.json_serialization import JsonSerializer

View File

@@ -0,0 +1,31 @@
from dbgpt.core.interface.operator.llm_operator import (
BaseLLM,
LLMBranchOperator,
LLMOperator,
RequestBuildOperator,
StreamingLLMOperator,
)
from dbgpt.core.interface.operator.message_operator import (
BaseConversationOperator,
BufferedConversationMapperOperator,
ConversationMapperOperator,
PostConversationOperator,
PostStreamingConversationOperator,
PreConversationOperator,
)
from dbgpt.core.interface.prompt import PromptTemplateOperator
__ALL__ = [
"BaseLLM",
"LLMBranchOperator",
"LLMOperator",
"RequestBuildOperator",
"StreamingLLMOperator",
"BaseConversationOperator",
"BufferedConversationMapperOperator",
"ConversationMapperOperator",
"PostConversationOperator",
"PostStreamingConversationOperator",
"PreConversationOperator",
"PromptTemplateOperator",
]

View File

@@ -1,4 +1,13 @@
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
from dbgpt.model.utils.chatgpt_utils import (
OpenAILLMClient,
OpenAIStreamingOperator,
MixinLLMOperator,
)
__ALL__ = ["DefaultLLMClient", "OpenAILLMClient"]
__ALL__ = [
"DefaultLLMClient",
"OpenAILLMClient",
"OpenAIStreamingOperator",
"MixinLLMOperator",
]

View File

@@ -171,7 +171,7 @@ class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
self._model_task_name = model_task_name
self._cache_task_name = cache_task_name
async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
async def branches(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
"""Defines branch logic based on cache availability.
Returns:
@@ -233,7 +233,7 @@ class ModelStreamSaveCacheOperator(
outputs = []
async for out in input_value:
if not llm_cache_key:
llm_cache_key = await self.current_dag_context.get_share_data(
llm_cache_key = await self.current_dag_context.get_from_share_data(
_LLM_MODEL_INPUT_VALUE_KEY
)
outputs.append(out)
@@ -265,7 +265,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
Returns:
ModelOutput: The same input model output.
"""
llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
llm_cache_key: LLMCacheKey = await self.current_dag_context.get_from_share_data(
_LLM_MODEL_INPUT_VALUE_KEY
)
llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
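# A hedged sketch of the share-data API this change migrates to
# (save_to_share_data / get_from_share_data on the DAG context); the key name
# below is illustrative, not the _LLM_MODEL_INPUT_VALUE_KEY used here.
from dbgpt.core.awel import MapOperator


class SaveModelNameOperator(MapOperator[str, str]):
    async def map(self, model_name: str) -> str:
        # Store a value so downstream operators in the same DAG run can read it.
        await self.current_dag_context.save_to_share_data(
            "request_model_name", model_name
        )
        return model_name


class ReadModelNameOperator(MapOperator[str, str]):
    async def map(self, _: str) -> str:
        # Read the value written by the upstream operator.
        return await self.current_dag_context.get_from_share_data("request_model_name")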

View File

@@ -3,11 +3,27 @@ from __future__ import annotations
import os
import logging
from dataclasses import dataclass
from abc import ABC
import importlib.metadata as metadata
from typing import List, Dict, Any, Optional, TYPE_CHECKING, Union, AsyncIterator
from typing import (
List,
Dict,
Any,
Optional,
TYPE_CHECKING,
Union,
AsyncIterator,
Callable,
Awaitable,
)
from dbgpt.component import ComponentType
from dbgpt.core.operator import BaseLLM
from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator
from dbgpt.core.interface.llm import ModelMetadata, LLMClient
from dbgpt.core.interface.llm import ModelOutput, ModelRequest
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
if TYPE_CHECKING:
import httpx
@@ -176,13 +192,13 @@ class OpenAILLMClient(LLMClient):
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
messages = request.to_openai_messages()
payload = self._build_request(request)
payload = self._build_request(request, True)
try:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
)
text = ""
for r in chat_completion:
async for r in chat_completion:
if len(r.choices) == 0:
continue
if r.choices[0].delta.content is not None:
@@ -221,17 +237,74 @@ class OpenAILLMClient(LLMClient):
raise NotImplementedError()
class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
"""Transform ModelOutput to openai stream format."""
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
async def model_caller() -> str:
"""Read model name from share data.
In streaming mode, this transform_stream function will be executed
before the parent operator (the streaming operator is triggered by a downstream operator).
"""
return await self.current_dag_context.get_from_share_data(
BaseLLM.SHARE_DATA_KEY_MODEL_NAME
)
async for output in _to_openai_stream(input_value, None, model_caller):
yield output
class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
"""Mixin class for LLM operator.
This class extends BaseOperator by adding LLM capabilities.
"""
def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
super().__init__(default_client)
self._default_llm_client = default_client
@property
def llm_client(self) -> LLMClient:
if not self._llm_client:
worker_manager_factory: WorkerManagerFactory = (
self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY,
WorkerManagerFactory,
default_component=None,
)
)
if worker_manager_factory:
self._llm_client = DefaultLLMClient(worker_manager_factory.create())
else:
if self._default_llm_client is None:
from dbgpt.model import OpenAILLMClient
self._default_llm_client = OpenAILLMClient()
logger.info(
f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
)
self._llm_client = self._default_llm_client
return self._llm_client
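# A minimal sketch of combining this mixin with a concrete operator (the same
# pattern as MyLLMOperator in the AWEL examples in this change); the client is
# resolved lazily from the worker manager, or falls back to the default client.
from dbgpt.core import LLMClient
from dbgpt.core.operator import LLMOperator


class MyLLMOperator(MixinLLMOperator, LLMOperator):
    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(llm_client)
        LLMOperator.__init__(self, llm_client, **kwargs)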
async def _to_openai_stream(
model: str, output_iter: AsyncIterator[ModelOutput]
output_iter: AsyncIterator[ModelOutput],
model: Optional[str] = None,
model_caller: Callable[[], Union[Awaitable[str], str]] = None,
) -> AsyncIterator[str]:
"""Convert the output_iter to openai stream format.
Args:
model (str): The model name.
output_iter (AsyncIterator[ModelOutput]): The output iterator.
model (Optional[str], optional): The model name. Defaults to None.
model_caller (Callable[[], Union[Awaitable[str], str]], optional): The model caller. Defaults to None.
"""
import json
import shortuuid
import asyncio
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
@@ -245,12 +318,19 @@ async def _to_openai_stream(
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model)
chunk = ChatCompletionStreamResponse(
id=id, choices=[choice_data], model=model or ""
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
previous_text = ""
finish_stream_events = []
async for model_output in output_iter:
if model_caller is not None:
if asyncio.iscoroutinefunction(model_caller):
model = await model_caller()
else:
model = model_caller()
model_output: ModelOutput = model_output
if model_output.error_code != 0:
yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"

View File

@@ -1,15 +1,16 @@
from typing import Optional, List
from functools import cache
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult
from .schemas import ServeRequest, ServerResponse
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME
from .schemas import ServeRequest, ServerResponse
router = APIRouter()

View File

@@ -1,6 +1,8 @@
# Define your Pydantic schemas here
from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP

View File

@@ -1,9 +1,8 @@
from typing import Optional
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "prompt"
SERVE_APP_NAME = "dbgpt_serve_prompt"
SERVE_APP_NAME_HUMP = "dbgpt_serve_Prompt"

View File

@@ -1,33 +1,64 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from typing import Union, Any, Dict
from datetime import datetime
from sqlalchemy import Column, Integer, String, Index, Text, DateTime, UniqueConstraint
from dbgpt.storage.metadata import Model, BaseDao, db
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model, db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import ServeConfig, SERVER_APP_TABLE_NAME
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
class ServeEntity(Model):
__tablename__ = "prompt_manage"
__table_args__ = (
UniqueConstraint("prompt_name", "sys_code", name="uk_prompt_name_sys_code"),
UniqueConstraint(
"prompt_name",
"sys_code",
"prompt_language",
"model",
name="uk_prompt_name_sys_code",
),
)
id = Column(Integer, primary_key=True, comment="Auto increment id")
chat_scene = Column(String(100))
sub_chat_scene = Column(String(100))
prompt_type = Column(String(100))
prompt_name = Column(String(512))
content = Column(Text)
user_name = Column(String(128))
chat_scene = Column(String(100), comment="Chat scene")
sub_chat_scene = Column(String(100), comment="Sub chat scene")
prompt_type = Column(String(100), comment="Prompt type(eg: common, private)")
prompt_name = Column(String(256), comment="Prompt name")
content = Column(Text, comment="Prompt content")
input_variables = Column(
String(1024), nullable=True, comment="Prompt input variables(split by comma)"
)
model = Column(
String(128),
nullable=True,
comment="Prompt model name(we can use different models for different prompt",
)
prompt_language = Column(
String(32), index=True, nullable=True, comment="Prompt language(eg:en, zh-cn)"
)
prompt_format = Column(
String(32),
index=True,
nullable=True,
default="f-string",
comment="Prompt format(eg: f-string, jinja2)",
)
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
def __repr__(self):
return f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
return (
f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', "
f"prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',"
f"user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
)
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):

View File

@@ -0,0 +1,56 @@
from typing import Type
from sqlalchemy.orm import Session
from dbgpt.core.interface.prompt import PromptTemplateIdentifier, StoragePromptTemplate
from dbgpt.core.interface.storage import StorageItemAdapter
from .models import ServeEntity
class PromptTemplateAdapter(StorageItemAdapter[StoragePromptTemplate, ServeEntity]):
def to_storage_format(self, item: StoragePromptTemplate) -> ServeEntity:
return ServeEntity(
chat_scene=item.chat_scene,
sub_chat_scene=item.sub_chat_scene,
prompt_type=item.prompt_type,
prompt_name=item.prompt_name,
content=item.content,
input_variables=item.input_variables,
model=item.model,
prompt_language=item.prompt_language,
prompt_format=item.prompt_format,
user_name=item.user_name,
sys_code=item.sys_code,
)
def from_storage_format(self, model: ServeEntity) -> StoragePromptTemplate:
return StoragePromptTemplate(
chat_scene=model.chat_scene,
sub_chat_scene=model.sub_chat_scene,
prompt_type=model.prompt_type,
prompt_name=model.prompt_name,
content=model.content,
input_variables=model.input_variables,
model=model.model,
prompt_language=model.prompt_language,
prompt_format=model.prompt_format,
user_name=model.user_name,
sys_code=model.sys_code,
)
def get_query_for_identifier(
self,
storage_format: Type[ServeEntity],
resource_id: PromptTemplateIdentifier,
**kwargs,
):
session: Session = kwargs.get("session")
if session is None:
raise Exception("session is None")
query_obj = session.query(ServeEntity)
for key, value in resource_id.to_dict().items():
if value is None:
continue
query_obj = query_obj.filter(getattr(ServeEntity, key) == value)
return query_obj
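# A hedged wiring sketch (mirroring Serve.before_start and the tests in this
# change): the adapter translates between StoragePromptTemplate and the
# ServeEntity ORM model, so the prompt manager can persist templates via SQLAlchemy.
from dbgpt.core.interface.prompt import PromptManager
from dbgpt.storage.metadata import db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer

db.init_db("sqlite:///:memory:")
db.create_all()
storage = SQLAlchemyStorage(db, ServeEntity, PromptTemplateAdapter(), JsonSerializer())
prompt_manager = PromptManager(storage)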

View File

@@ -1,17 +1,80 @@
from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
import logging
from typing import List, Optional, Union
from .api.endpoints import router, init_endpoints
from sqlalchemy import URL
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.core import PromptManager
from ...storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
APP_NAME,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
from .models.prompt_template_adapter import PromptTemplateAdapter
logger = logging.getLogger(__name__)
class Serve(BaseComponent):
"""Serve component
Examples:
Register the serve component to the system app
.. code-block:: python
from fastapi import FastAPI
from dbgpt import SystemApp
from dbgpt.core import PromptTemplate
from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME
app = FastAPI()
system_app = SystemApp(app)
system_app.register(Serve, api_prefix="/api/v1/prompt")
# Run before start hook
system_app.before_start()
prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve)
# Get the prompt manager
prompt_manager = prompt_serve.prompt_manager
prompt_manager.save(
PromptTemplate(template="Hello {name}", input_variables=["name"]),
prompt_name="prompt_name",
)
With your database url
.. code-block:: python
from fastapi import FastAPI
from dbgpt import SystemApp
from dbgpt.core import PromptTemplate
from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME
app = FastAPI()
system_app = SystemApp(app)
system_app.register(Serve, api_prefix="/api/v1/prompt", db_url_or_db="sqlite:///:memory:", try_create_tables=True)
# Run before start hook
system_app.before_start()
prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve)
# Get the prompt manager
prompt_manager = prompt_serve.prompt_manager
prompt_manager.save(
PromptTemplate(template="Hello {name}", input_variables=["name"]),
prompt_name="prompt_name",
)
"""
name = SERVE_APP_NAME
def __init__(
@@ -19,12 +82,17 @@ class Serve(BaseComponent):
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if tags is None:
tags = [SERVE_APP_NAME_HUMP]
self._system_app = None
self._api_prefix = api_prefix
self._tags = tags
self._prompt_manager = None
self._db_url_or_db = db_url_or_db
self._try_create_tables = try_create_tables
def init_app(self, system_app: SystemApp):
self._system_app = system_app
@@ -33,10 +101,37 @@ class Serve(BaseComponent):
)
init_endpoints(self._system_app)
@property
def prompt_manager(self) -> PromptManager:
"""Get the prompt manager of the serve app with db storage"""
return self._prompt_manager
def before_start(self):
"""Called before the start of the application.
You can do some initialization here.
"""
# import your own module here to ensure the module is loaded before the application starts
from dbgpt.core.interface.prompt import PromptManager
from dbgpt.storage.metadata import Model, db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from .models.models import ServeEntity
init_db = self._db_url_or_db or db
init_db = DatabaseManager.build_from(init_db, base=Model)
if self._try_create_tables:
try:
init_db.create_all()
except Exception as e:
logger.warning(f"Failed to create tables: {e}")
storage_adapter = PromptTemplateAdapter()
serializer = JsonSerializer()
storage = SQLAlchemyStorage(
init_db,
ServeEntity,
storage_adapter,
serializer,
)
self._prompt_manager = PromptManager(storage)

View File

@@ -1,11 +1,13 @@
from typing import Optional, List
from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.serve.core import BaseService
from ..models.models import ServeDao, ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):

View File

@@ -1,15 +1,15 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from fastapi import FastAPI
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..config import SERVE_CONFIG_KEY_PREFIX
from ..api.endpoints import router, init_endpoints
from ..api.schemas import ServeRequest, ServerResponse
from dbgpt.serve.core.tests.conftest import client, asystem_app
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@pytest.fixture(autouse=True)

View File

@@ -1,9 +1,12 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
from ..config import ServeConfig
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity, ServeDao
from ..config import ServeConfig
from ..models.models import ServeDao, ServeEntity
@pytest.fixture(autouse=True)
@@ -34,6 +37,8 @@ def default_entity_dict():
"content": "Write a qsort function in python.",
"user_name": "zhangsan",
"sys_code": "dbgpt",
"prompt_language": "zh",
"model": "vicuna-13b-v1.5",
}
@@ -60,7 +65,14 @@ def test_entity_create(default_entity_dict):
def test_entity_unique_key(default_entity_dict):
ServeEntity.create(**default_entity_dict)
with pytest.raises(Exception):
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
ServeEntity.create(
**{
"prompt_name": "my_prompt_1",
"sys_code": "dbgpt",
"prompt_language": "zh",
"model": "vicuna-13b-v1.5",
}
)
def test_entity_get(default_entity_dict):

View File

@@ -0,0 +1,144 @@
import pytest
from dbgpt.core.interface.prompt import PromptManager, PromptTemplate
from dbgpt.storage.metadata import db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from ..models.prompt_template_adapter import PromptTemplateAdapter, ServeEntity
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def db_url():
"""Use in-memory SQLite database for testing"""
return "sqlite:///:memory:"
@pytest.fixture
def db_manager(db_url):
db.init_db(db_url)
db.create_all()
return db
@pytest.fixture
def storage_adapter():
return PromptTemplateAdapter()
@pytest.fixture
def storage(db_manager, serializer, storage_adapter):
storage = SQLAlchemyStorage(
db_manager,
ServeEntity,
storage_adapter,
serializer,
)
return storage
@pytest.fixture
def prompt_manager(storage):
return PromptManager(storage)
def test_save(prompt_manager: PromptManager):
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="hello",
)
with db.session() as session:
# Query from database
result = (
session.query(ServeEntity).filter(ServeEntity.prompt_name == "hello").all()
)
assert len(result) == 1
assert result[0].prompt_name == "hello"
assert result[0].content == "hello {input}"
assert result[0].input_variables == "input"
with db.session() as session:
assert session.query(ServeEntity).count() == 1
assert (
session.query(ServeEntity)
.filter(ServeEntity.prompt_name == "not exist prompt name")
.count()
== 0
)
def test_prefer_query_language(prompt_manager: PromptManager):
for language in ["en", "zh"]:
prompt_template = PromptTemplate(
template="test",
input_variables=[],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="test_prompt",
prompt_language=language,
)
# Prefer zh; since a zh template exists, the zh prompt template is returned
result = prompt_manager.prefer_query("test_prompt", prefer_prompt_language="zh")
assert len(result) == 1
assert result[0].content == "test"
assert result[0].prompt_language == "zh"
# The preferred language does not exist, so all prompt templates with this name are returned
result = prompt_manager.prefer_query(
"test_prompt", prefer_prompt_language="not_exist"
)
assert len(result) == 2
def test_prefer_query_model(prompt_manager: PromptManager):
for model in ["model1", "model2"]:
prompt_template = PromptTemplate(
template="test",
input_variables=[],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name="test_prompt",
model=model,
)
# Prefer model1; since model1 exists, the model1 prompt template is returned
result = prompt_manager.prefer_query("test_prompt", prefer_model="model1")
assert len(result) == 1
assert result[0].content == "test"
assert result[0].model == "model1"
# The preferred model does not exist, so all prompt templates with this name are returned
result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist")
assert len(result) == 2
def test_list(prompt_manager: PromptManager):
for i in range(10):
prompt_template = PromptTemplate(
template="test",
input_variables=[],
template_scene="chat_normal",
)
prompt_manager.save(
prompt_template,
prompt_name=f"test_prompt_{i}",
sys_code="dbgpt" if i % 2 == 0 else "not_dbgpt",
)
# Test list all
result = prompt_manager.list()
assert len(result) == 10
for i in range(10):
assert len(prompt_manager.list(prompt_name=f"test_prompt_{i}")) == 1
assert len(prompt_manager.list(sys_code="dbgpt")) == 5

View File

@@ -1,11 +1,13 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.serve.core.tests.conftest import system_app
from ..models.models import ServeEntity
import pytest
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service

View File

@@ -236,6 +236,7 @@ class DatabaseManager:
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
override_query_class: Optional[bool] = False,
):
"""Initialize the database manager.
@@ -244,15 +245,16 @@ class DatabaseManager:
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
"""
self._db_url = db_url
if query_class is not None:
self.Query = query_class
if base is not None:
self._base = base
if not hasattr(base, "query"):
if not hasattr(base, "query") or override_query_class:
base.query = _QueryObject(self)
if not getattr(base, "query_class", None):
if not getattr(base, "query_class", None) or override_query_class:
base.query_class = self.Query
self._engine = create_engine(db_url, **(engine_args or {}))
session_factory = sessionmaker(bind=self._engine)
@@ -299,6 +301,59 @@ class DatabaseManager:
def create_all(self):
self.Model.metadata.create_all(self._engine)
@staticmethod
def build_from(
db_url_or_db: Union[str, URL, DatabaseManager],
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
override_query_class: Optional[bool] = False,
) -> DatabaseManager:
"""Build the database manager from the db_url_or_db.
Examples:
Build from the database url.
.. code-block:: python
from dbgpt.storage.metadata import DatabaseManager
from sqlalchemy import Column, Integer, String
db = DatabaseManager.build_from("sqlite:///:memory:")
class User(db.Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
db.create_all()
with db.session() as session:
session.add(User(name="test", fullname="test"))
session.commit()
print(User.query.filter(User.name == "test").all())
Args:
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
Returns:
DatabaseManager: The database manager.
"""
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
db_manager = DatabaseManager()
db_manager.init_db(
db_url_or_db, engine_args, base, query_class, override_query_class
)
return db_manager
elif isinstance(db_url_or_db, DatabaseManager):
return db_url_or_db
else:
raise ValueError(
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
)
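# A small illustrative check of the two accepted input types: a database URL
# creates a new manager, while an existing DatabaseManager is returned as-is.
new_manager = DatabaseManager.build_from("sqlite:///:memory:")
same_manager = DatabaseManager.build_from(new_manager)
assert same_manager is new_manager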
db = DatabaseManager()
"""The global database manager.
@@ -375,14 +430,21 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
_db_manager: DatabaseManager = db_manager
@classmethod
def set_db_manager(cls, db_manager: DatabaseManager):
# TODO: It is hard to replace with the user's DB connection
cls._db_manager = db_manager
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
return db_manager._session().get(cls, ident)
return cls._db_manager._session().get(cls, ident)
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
session = db_manager._session()
session = self._db_manager._session()
session.add(self)
if commit:
session.commit()
@@ -390,7 +452,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
session = db_manager._session()
session = self._db_manager._session()
session.delete(self)
return commit and session.commit()

View File

@@ -34,15 +34,8 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
query_class=BaseQuery,
):
super().__init__(serializer=serializer, adapter=adapter)
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
db_manager = DatabaseManager()
db_manager.init_db(db_url_or_db, engine_args, base, query_class)
self.db_manager = db_manager
elif isinstance(db_url_or_db, DatabaseManager):
self.db_manager = db_url_or_db
else:
raise ValueError(
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
self.db_manager = DatabaseManager.build_from(
db_url_or_db, engine_args, base, query_class
)
self._model_class = model_class

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import pytest
import tempfile
from typing import Type
from dbgpt.storage.metadata.db_manager import (
DatabaseManager,
@@ -103,7 +104,6 @@ def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
db.create_all()
# Add data
with db.session() as session:
for i in range(30):
user = User(name=f"User {i}")
@@ -127,3 +127,29 @@ def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
User.query.paginate_query(page=0, per_page=10)
with pytest.raises(ValueError):
User.query.paginate_query(page=1, per_page=-1)
def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
assert db.metadata.tables == {}
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
with tempfile.NamedTemporaryFile(delete=True) as db_file:
filename = db_file.name
new_db = DatabaseManager.build_from(
f"sqlite:///{filename}", base=Model, override_query_class=True
)
Model.set_db_manager(new_db)
new_db.create_all()
db.create_all()
assert list(new_db.metadata.tables.keys())[0] == "user"
User.create(**{"name": "John Doe"})
with new_db.session() as session:
assert session.query(User).filter_by(name="John Doe").first() is not None
with db.session() as session:
assert session.query(User).filter_by(name="John Doe").first() is None
assert len(User.query.all()) == 1
assert User.query.filter(User.name == "John Doe").first().name == "John Doe"

View File

@@ -6,7 +6,8 @@
.. code-block:: shell
curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_chat \
DBGPT_SERVER="http://127.0.0.1:5000"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_chat \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"user_input": "hello"
@@ -52,3 +53,14 @@ with DAG("dbgpt_awel_simple_dag_example") as dag:
# type(out) == ModelOutput
model_parse_task = MapOperator(lambda out: out.to_dict())
trigger >> request_handle_task >> model_task >> model_parse_task
if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode:
# Development mode, you can run the dag locally for debugging.
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([dag], port=5555)
else:
# Production mode, DB-GPT will automatically load and execute the current file after startup.
pass

View File

@@ -0,0 +1,186 @@
"""AWEL: Simple chat with history example
DB-GPT will automatically load and execute the current file after startup.
Examples:
Call with non-streaming response.
.. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5000"
MODEL="gpt-3.5-turbo"
# First round
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "gpt-3.5-turbo",
"context": {
"conv_uid": "uuid_conv_1234"
},
"messages": "Who is elon musk?"
}'
# Second round
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "gpt-3.5-turbo",
"context": {
"conv_uid": "uuid_conv_1234"
},
"messages": "Is he rich?"
}'
Call with streaming response.
.. code-block:: shell
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "gpt-3.5-turbo",
"context": {
"conv_uid": "uuid_conv_stream_1234"
},
"stream": true,
"messages": "Who is elon musk?"
}'
# Second round
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "gpt-3.5-turbo",
"context": {
"conv_uid": "uuid_conv_stream_1234"
},
"stream": true,
"messages": "Is he rich?"
}'
"""
from typing import Dict, Any, Optional, Union, List
import logging
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import (
DAG,
HttpTrigger,
MapOperator,
JoinOperator,
)
from dbgpt.core import LLMClient, InMemoryStorage
from dbgpt.core.operator import (
LLMBranchOperator,
LLMOperator,
StreamingLLMOperator,
RequestBuildOperator,
PreConversationOperator,
PostConversationOperator,
PostStreamingConversationOperator,
BufferedConversationMapperOperator,
)
from dbgpt.model import OpenAIStreamingOperator, MixinLLMOperator
logger = logging.getLogger(__name__)
class ReqContext(BaseModel):
user_name: Optional[str] = Field(
None, description="The user name of the model request."
)
sys_code: Optional[str] = Field(
None, description="The system code of the model request."
)
conv_uid: Optional[str] = Field(
None, description="The conversation uid of the model request."
)
class TriggerReqBody(BaseModel):
messages: Union[str, List[Dict[str, str]]] = Field(
..., description="User input messages"
)
model: str = Field(..., description="Model name")
stream: Optional[bool] = Field(default=False, description="Whether return stream")
context: Optional[ReqContext] = Field(
default=None, description="The context of the model request."
)
class MyLLMOperator(MixinLLMOperator, LLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
LLMOperator.__init__(self, llm_client, **kwargs)
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
StreamingLLMOperator.__init__(self, llm_client, **kwargs)
with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_history/multi_round/chat/completions",
methods="POST",
request_body=TriggerReqBody,
streaming_predict_func=lambda req: req.stream,
)
# Transform request body to model request.
request_handle_task = RequestBuildOperator()
# Pre-process conversation, use InMemoryStorage to store conversation.
pre_conversation_task = PreConversationOperator(
storage=InMemoryStorage(), message_storage=InMemoryStorage()
)
# Keep the last k rounds of the conversation.
history_conversation_task = BufferedConversationMapperOperator(last_k_round=5)
# Save conversation to storage.
post_conversation_task = PostConversationOperator()
# Save streaming conversation to storage.
post_streaming_conversation_task = PostStreamingConversationOperator()
# Use LLMOperator to generate response.
llm_task = MyLLMOperator(task_name="llm_task")
streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
branch_task = LLMBranchOperator(
stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
)
model_parse_task = MapOperator(lambda out: out.to_dict())
openai_format_stream_task = OpenAIStreamingOperator()
result_join_task = JoinOperator(
combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
)
(
trigger
>> request_handle_task
>> pre_conversation_task
>> history_conversation_task
>> branch_task
)
# The branch of no streaming response.
(
branch_task
>> llm_task
>> post_conversation_task
>> model_parse_task
>> result_join_task
)
# The branch of streaming response.
(
branch_task
>> streaming_llm_task
>> post_streaming_conversation_task
>> openai_format_stream_task
>> result_join_task
)
if __name__ == "__main__":
if multi_round_dag.leaf_nodes[0].dev_mode:
# Development mode, you can run the dag locally for debugging.
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([multi_round_dag], port=5555)
else:
# Production mode, DB-GPT will automatically load and execute the current file after startup.
pass

View File

@@ -2,24 +2,31 @@
DB-GPT will automatically load and execute the current file after startup.
Example:
Examples:
Call with non-streaming response.
.. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5000"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate \
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"messages": "hello"
}'
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate_stream \
Call with streaming response.
.. code-block:: shell
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"messages": "hello",
"stream": true
}'
Call the model and count tokens.
.. code-block:: shell
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
@@ -27,20 +34,26 @@
}'
"""
from typing import Dict, Any, AsyncIterator, Optional, Union, List
from typing import Dict, Any, Optional, Union, List
import logging
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.component import ComponentType
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, TransformStreamAbsOperator
from dbgpt.core import (
ModelMessage,
LLMClient,
from dbgpt.core.awel import (
DAG,
HttpTrigger,
MapOperator,
JoinOperator,
)
from dbgpt.core import LLMClient
from dbgpt.core.operator import (
LLMBranchOperator,
LLMOperator,
StreamingLLMOperator,
ModelOutput,
ModelRequest,
RequestBuildOperator,
)
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model import OpenAIStreamingOperator, MixinLLMOperator
logger = logging.getLogger(__name__)
class TriggerReqBody(BaseModel):
@@ -51,58 +64,24 @@ class TriggerReqBody(BaseModel):
stream: Optional[bool] = Field(default=False, description="Whether return stream")
class RequestHandleOperator(MapOperator[TriggerReqBody, ModelRequest]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def map(self, input_value: TriggerReqBody) -> ModelRequest:
messages = [ModelMessage.build_human_message(input_value.messages)]
await self.current_dag_context.save_to_share_data(
"request_model_name", input_value.model
)
return ModelRequest(
model=input_value.model,
messages=messages,
echo=False,
)
class MyLLMOperator(MixinLLMOperator, LLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
LLMOperator.__init__(self, llm_client, **kwargs)
class LLMMixin:
@property
def llm_client(self) -> LLMClient:
if not self._llm_client:
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
self._llm_client = DefaultLLMClient(worker_manager)
return self._llm_client
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
StreamingLLMOperator.__init__(self, llm_client, **kwargs)
class MyLLMOperator(LLMMixin, LLMOperator):
def __init__(self, llm_client: LLMClient = None, **kwargs):
super().__init__(llm_client, **kwargs)
class MyStreamingLLMOperator(LLMMixin, StreamingLLMOperator):
def __init__(self, llm_client: LLMClient = None, **kwargs):
super().__init__(llm_client, **kwargs)
class MyLLMStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
from dbgpt.model.utils.chatgpt_utils import _to_openai_stream
model = await self.current_dag_context.get_share_data("request_model_name")
async for output in _to_openai_stream(model, input_value):
yield output
class MyModelToolOperator(LLMMixin, MapOperator[TriggerReqBody, Dict[str, Any]]):
def __init__(self, llm_client: LLMClient = None, **kwargs):
self._llm_client = llm_client
MapOperator.__init__(self, **kwargs)
class MyModelToolOperator(
MixinLLMOperator, MapOperator[TriggerReqBody, Dict[str, Any]]
):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
MapOperator.__init__(self, llm_client, **kwargs)
async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
prompt_tokens = await self.llm_client.count_token(
@@ -118,25 +97,27 @@ class MyModelToolOperator(LLMMixin, MapOperator[TriggerReqBody, Dict[str, Any]])
with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_client/generate", methods="POST", request_body=TriggerReqBody
)
request_handle_task = RequestHandleOperator()
model_task = MyLLMOperator()
model_parse_task = MapOperator(lambda out: out.to_dict())
trigger >> request_handle_task >> model_task >> model_parse_task
with DAG("dbgpt_awel_simple_llm_client_generate_stream") as client_generate_stream_dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_client/generate_stream",
"/examples/simple_client/chat/completions",
methods="POST",
request_body=TriggerReqBody,
streaming_response=True,
streaming_predict_func=lambda req: req.stream,
)
request_handle_task = RequestHandleOperator()
model_task = MyStreamingLLMOperator()
openai_format_stream_task = MyLLMStreamingOperator()
trigger >> request_handle_task >> model_task >> openai_format_stream_task
request_handle_task = RequestBuildOperator()
llm_task = MyLLMOperator(task_name="llm_task")
streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
branch_task = LLMBranchOperator(
stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
)
model_parse_task = MapOperator(lambda out: out.to_dict())
openai_format_stream_task = OpenAIStreamingOperator()
result_join_task = JoinOperator(
combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
)
trigger >> request_handle_task >> branch_task
branch_task >> llm_task >> model_parse_task >> result_join_task
branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task
with DAG("dbgpt_awel_simple_llm_client_count_token") as client_count_token_dag:
# Receive http request and trigger dag to run.
@@ -147,3 +128,15 @@ with DAG("dbgpt_awel_simple_llm_client_count_token") as client_count_token_dag:
)
model_task = MyModelToolOperator()
trigger >> model_task
if __name__ == "__main__":
if client_generate_dag.leaf_nodes[0].dev_mode:
# Development mode, you can run the dag locally for debugging.
from dbgpt.core.awel import setup_dev_environment
dags = [client_generate_dag, client_count_token_dag]
setup_dev_environment(dags, port=5555)
else:
# Production mode, DB-GPT will automatically load and execute the current file after startup.
pass

View File

@@ -2,9 +2,11 @@ import asyncio
from dbgpt.core.awel import DAG
from dbgpt.core import (
BaseOutputParser,
RequestBuildOperator,
PromptTemplate,
)
from dbgpt.core.operator import (
LLMOperator,
RequestBuildOperator,
)
from dbgpt.model import OpenAILLMClient

View File

@@ -10,9 +10,11 @@ from dbgpt.core.awel import (
)
from dbgpt.core import (
SQLOutputParser,
PromptTemplate,
)
from dbgpt.core.operator import (
LLMOperator,
RequestBuildOperator,
PromptTemplate,
)
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.datasource.operator.datasource_operator import DatasourceOperator