feat(core): Support multi round conversation operator (#986)

This commit is contained in:
Fangyin Cheng
2023-12-27 23:26:28 +08:00
committed by GitHub
parent 9aec636b02
commit b13d3f6d92
63 changed files with 2011 additions and 314 deletions

View File

@@ -175,6 +175,7 @@ CREATE TABLE IF NOT EXISTS `prompt_manage`
`model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)',
`prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)',
`prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)',
`prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',

View File

@@ -31,3 +31,13 @@ else:
validator,
PrivateAttr,
)
def model_to_json(model, **kwargs):
    """Convert a pydantic model to json, on either pydantic v1 or a later version.

    Args:
        model: The pydantic model instance to serialize.
        **kwargs: Keyword arguments forwarded to the underlying serializer:
            ``model.json`` when ``PYDANTIC_VERSION == 1``, otherwise
            ``model.model_dump_json``.

    Returns:
        Whatever the selected serializer returns.
    """
    if PYDANTIC_VERSION == 1:
        return model.json(**kwargs)
    # ``ensure_ascii`` is never forwarded on the non-v1 path.
    # NOTE(review): presumably ``model_dump_json`` rejected this keyword in the
    # pydantic release targeted here - confirm before forwarding it.
    kwargs.pop("ensure_ascii", None)
    return model.model_dump_json(**kwargs)

View File

@@ -40,7 +40,7 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
cfg = Config()
cfg.SYSTEM_APP = system_app
# Initialize db storage first
_initialize_db_storage(param)
_initialize_db_storage(param, system_app)
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
@@ -86,12 +86,14 @@ def _create_model_start_listener(system_app: SystemApp):
return startup_event
def _initialize_db_storage(param: "WebServerParameters"):
def _initialize_db_storage(param: "WebServerParameters", system_app: SystemApp):
"""Initialize the db storage.
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
"""
_initialize_db(try_to_create_db=not param.disable_alembic_upgrade)
_initialize_db(
try_to_create_db=not param.disable_alembic_upgrade, system_app=system_app
)
def _migration_db_storage(param: "WebServerParameters"):
@@ -114,7 +116,9 @@ def _migration_db_storage(param: "WebServerParameters"):
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
def _initialize_db(
try_to_create_db: Optional[bool] = False, system_app: Optional[SystemApp] = None
) -> str:
"""Initialize the database
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
@@ -147,7 +151,11 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
"pool_recycle": 3600,
"pool_pre_ping": True,
}
initialize_db(db_url, db_name, engine_args)
db = initialize_db(db_url, db_name, engine_args)
if system_app:
from dbgpt.storage.metadata import UnifiedDBManagerFactory
system_app.register(UnifiedDBManagerFactory, db)
return default_meta_data_path
@@ -273,3 +281,9 @@ class WebServerParameters(BaseParameters):
"help": "Whether to disable alembic to initialize and upgrade database metadata",
},
)
awel_dirs: Optional[str] = field(
default=None,
metadata={
"help": "The directories to search awel files, split by `,`",
},
)

View File

@@ -46,9 +46,9 @@ def initialize_components(
param, system_app, embedding_model_name, embedding_model_path
)
_initialize_model_cache(system_app)
_initialize_awel(system_app)
_initialize_awel(system_app, param)
# Register serve apps
register_serve_apps(system_app)
register_serve_apps(system_app, CFG)
def _initialize_model_cache(system_app: SystemApp):
@@ -64,8 +64,14 @@ def _initialize_model_cache(system_app: SystemApp):
initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
def _initialize_awel(system_app: SystemApp):
def _initialize_awel(system_app: SystemApp, param: WebServerParameters):
from dbgpt.core.awel import initialize_awel
from dbgpt.configs.model_config import _DAG_DEFINITION_DIR
initialize_awel(system_app, _DAG_DEFINITION_DIR)
# Add default dag definition dir
dag_dirs = [_DAG_DEFINITION_DIR]
if param.awel_dirs:
dag_dirs += param.awel_dirs.strip().split(",")
dag_dirs = [x.strip() for x in dag_dirs]
initialize_awel(system_app, dag_dirs)

View File

@@ -146,14 +146,13 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
mount_routers(app)
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
system_app.on_init()
# Before start, after initialize_components
# TODO: initialize_worker_manager_in_client as a component register in system_app
system_app.before_start()
# Migrate db storage, so your db models must be imported before this
_migration_db_storage(param)
model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
# TODO: initialize_worker_manager_in_client as a component register in system_app
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
@@ -186,6 +185,9 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
CFG.SERVER_LIGHT_MODE = True
mount_static_files(app)
# Before start, after on_init
system_app.before_start()
return param

View File

@@ -1,13 +1,28 @@
from dbgpt.component import SystemApp
from dbgpt._private.config import Config
def register_serve_apps(system_app: SystemApp):
def register_serve_apps(system_app: SystemApp, cfg: Config):
"""Register serve apps"""
from dbgpt.serve.prompt.serve import Serve as PromptServe, SERVE_CONFIG_KEY_PREFIX
system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE)
# ################################ Prompt Serve Register Begin ######################################
from dbgpt.serve.prompt.serve import (
Serve as PromptServe,
SERVE_CONFIG_KEY_PREFIX as PROMPT_SERVE_CONFIG_KEY_PREFIX,
)
# Replace old prompt serve
# Set config
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt")
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt")
system_app.config.set(f"{PROMPT_SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt")
system_app.config.set(f"{PROMPT_SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt")
# Register serve app
system_app.register(PromptServe, api_prefix="/prompt")
# ################################ Prompt Serve Register End ########################################
# ################################ Conversation Serve Register Begin ######################################
from dbgpt.serve.conversation.serve import Serve as ConversationServe
# Register serve app
system_app.register(ConversationServe)
# ################################ Conversation Serve Register End ########################################

View File

@@ -217,6 +217,10 @@ async def dialogue_list(
model_name = item.get("model_name", CFG.LLM_MODEL)
user_name = item.get("user_name")
sys_code = item.get("sys_code")
if not item.get("messages"):
# Skip the empty messages
# TODO support new conversation and message mode
continue
messages = json.loads(item.get("messages"))
last_round = max(messages, key=lambda x: x["chat_order"])

View File

@@ -17,10 +17,28 @@ logger = logging.getLogger(__name__)
class LifeCycle:
"""This class defines hooks for lifecycle events of a component."""
"""This class defines hooks for lifecycle events of a component.
Execution order of lifecycle hooks:
1. on_init
2. before_start(async_before_start)
3. after_start(async_after_start)
4. before_stop(async_before_stop)
"""
def on_init(self):
"""Called when the component is being initialized."""
pass
async def async_on_init(self):
"""Asynchronous version of on_init."""
pass
def before_start(self):
"""Called before the component starts."""
"""Called before the component starts.
This method is called after the component has been initialized and before it is started.
"""
pass
async def async_before_start(self):
@@ -59,6 +77,7 @@ class ComponentType(str, Enum):
RAG_GRAPH_DEFAULT = "dbgpt_rag_engine_default"
AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory"
@PublicAPI(stability="beta")
@@ -177,6 +196,16 @@ class SystemApp(LifeCycle):
raise TypeError(f"Component {name} is not of type {component_type}")
return component
def on_init(self):
    """Invoke the on_init hooks for all registered components.

    The hooks run one after another, in the iteration order of
    ``self.components``.
    """
    # Only the component objects are needed, so iterate the values view
    # instead of unpacking (and discarding) the keys.
    for component in self.components.values():
        component.on_init()
async def async_on_init(self):
    """Asynchronously invoke the on_init hooks for all registered components.

    One ``async_on_init`` coroutine is created per registered component and
    all of them are awaited together through ``asyncio.gather``.
    """
    # Only the component objects are needed, so iterate the values view.
    hooks = [component.async_on_init() for component in self.components.values()]
    await asyncio.gather(*hooks)
def before_start(self):
"""Invoke the before_start hooks for all registered components."""
for _, v in self.components.items():

View File

@@ -8,38 +8,36 @@ The stability of this API cannot be guaranteed at present.
"""
from typing import List, Optional
from dbgpt.component import SystemApp
from .dag.base import DAGContext, DAG
from .dag.base import DAG, DAGContext
from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
MapOperator,
BranchFunc,
BranchOperator,
InputOperator,
BranchFunc,
JoinOperator,
MapOperator,
ReduceStreamOperator,
)
from .operator.stream_operator import (
StreamifyAbsOperator,
UnstreamifyAbsOperator,
TransformStreamAbsOperator,
UnstreamifyAbsOperator,
)
from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .runner.local_runner import DefaultWorkflowRunner
from .task.base import InputContext, InputSource, TaskContext, TaskOutput, TaskState
from .task.task_impl import (
SimpleInputSource,
SimpleCallDataInputSource,
DefaultTaskContext,
DefaultInputContext,
SimpleTaskOutput,
DefaultTaskContext,
SimpleCallDataInputSource,
SimpleInputSource,
SimpleStreamTaskOutput,
SimpleTaskOutput,
_is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner
__all__ = [
"initialize_awel",
@@ -73,16 +71,16 @@ __all__ = [
]
def initialize_awel(system_app: SystemApp, dag_filepath: str):
from .dag.dag_manager import DAGManager
def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager
from .dag.dag_manager import DAGManager
from .operator.base import initialize_runner
from .trigger.trigger_manager import DefaultTriggerManager
DAGVar.set_current_system_app(system_app)
system_app.register(DefaultTriggerManager)
dag_manager = DAGManager(system_app, dag_filepath)
dag_manager = DAGManager(system_app, dag_dirs)
system_app.register_instance(dag_manager)
initialize_runner(DefaultWorkflowRunner())
# Load all dags
@@ -90,7 +88,11 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str):
def setup_dev_environment(
dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
dags: List[DAG],
host: Optional[str] = "0.0.0.0",
port: Optional[int] = 5555,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
) -> None:
"""Setup a development environment for AWEL.
@@ -98,9 +100,16 @@ def setup_dev_environment(
"""
import uvicorn
from fastapi import FastAPI
from dbgpt.component import SystemApp
from .trigger.trigger_manager import DefaultTriggerManager
from dbgpt.util.utils import setup_logging
from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager
if not logger_filename:
logger_filename = "dbgpt_awel_dev.log"
setup_logging("dbgpt", logging_level=logging_level, logger_filename=logger_filename)
app = FastAPI()
system_app = SystemApp(app)

View File

@@ -1,15 +1,16 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import contextvars
import logging
import threading
import uuid
from abc import ABC, abstractmethod
from collections import deque
from functools import cache
from concurrent.futures import Executor
from functools import cache
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext, TaskOutput
@@ -502,6 +503,9 @@ class DAG:
def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag()
def __repr__(self):
return f"DAG(dag_id={self.dag_id})"
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()

View File

@@ -1,8 +1,10 @@
from typing import Dict, Optional
import logging
from typing import Dict, List
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG
from .loader import LocalFileDAGLoader
logger = logging.getLogger(__name__)
@@ -10,9 +12,9 @@ logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
name = ComponentType.AWEL_DAG_MANAGER
def __init__(self, system_app: SystemApp, dag_filepath: str):
def __init__(self, system_app: SystemApp, dag_dirs: List[str]):
super().__init__(system_app)
self.dag_loader = LocalFileDAGLoader(dag_filepath)
self.dag_loader = LocalFileDAGLoader(dag_dirs)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}

View File

@@ -1,10 +1,10 @@
import hashlib
import logging
import os
import sys
import traceback
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback
from .base import DAG
@@ -18,17 +18,19 @@ class DAGLoader(ABC):
class LocalFileDAGLoader(DAGLoader):
def __init__(self, filepath: str) -> None:
super().__init__()
self._filepath = filepath
def __init__(self, dag_dirs: List[str]) -> None:
self._dag_dirs = dag_dirs
def load_dags(self) -> List[DAG]:
if not os.path.exists(self._filepath):
return []
if os.path.isdir(self._filepath):
return _process_directory(self._filepath)
else:
return _process_file(self._filepath)
dags = []
for filepath in self._dag_dirs:
if not os.path.exists(filepath):
continue
if os.path.isdir(filepath):
dags += _process_directory(filepath)
else:
dags += _process_file(filepath)
return dags
def _process_directory(directory: str) -> List[DAG]:

View File

@@ -1,6 +1,8 @@
import pytest
import threading
import asyncio
import threading
import pytest
from ..base import DAG, DAGVar

View File

@@ -1,32 +1,32 @@
from abc import ABC, abstractmethod, ABCMeta
import asyncio
import functools
from abc import ABC, ABCMeta, abstractmethod
from inspect import signature
from types import FunctionType
from typing import (
List,
Generic,
TypeVar,
AsyncIterator,
Iterator,
Union,
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Optional,
TypeVar,
Union,
cast,
)
import functools
from inspect import signature
import asyncio
from dbgpt.component import SystemApp, ComponentType
from dbgpt.component import ComponentType, SystemApp
from dbgpt.util.executor_utils import (
ExecutorFactory,
DefaultExecutorFactory,
blocking_func_to_async,
BlockingFunction,
AsyncToSyncIterator,
BlockingFunction,
DefaultExecutorFactory,
ExecutorFactory,
blocking_func_to_async,
)
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import TaskOutput, OUT, T
from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
from ..task.base import OUT, T, TaskOutput
F = TypeVar("F", bound=FunctionType)

View File

@@ -1,27 +1,19 @@
import asyncio
import logging
from typing import (
Generic,
Dict,
List,
Union,
Callable,
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
List,
Optional,
Union,
)
import asyncio
import logging
from ..dag.base import DAGContext
from ..task.base import (
TaskContext,
TaskOutput,
IN,
OUT,
InputContext,
InputSource,
)
from ..task.base import IN, OUT, InputContext, InputSource, TaskContext, TaskOutput
from .base import BaseOperator
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Generic, AsyncIterator
from ..task.base import OUT, IN, TaskOutput, TaskContext
from typing import AsyncIterator, Generic
from ..dag.base import DAGContext
from ..task.base import IN, OUT, TaskContext, TaskOutput
from .base import BaseOperator

View File

@@ -1,10 +1,10 @@
import asyncio
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG, DAGLifecycle
import uuid
from typing import Dict, List, Optional, Set
from ..operator.base import BaseOperator, CALL_DATA
from ..dag.base import DAG, DAGLifecycle
from ..operator.base import CALL_DATA, BaseOperator
logger = logging.getLogger(__name__)

View File

@@ -1,9 +1,10 @@
from typing import Dict, Optional, Set, List
import logging
from typing import Dict, List, Optional, Set
from dbgpt.component import SystemApp
from ..dag.base import DAGContext, DAGVar
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput

View File

@@ -1,15 +1,15 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TypeVar,
Generic,
Optional,
AsyncIterator,
Union,
Callable,
Any,
AsyncIterator,
Callable,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
)
IN = TypeVar("IN")

View File

@@ -1,22 +1,22 @@
from abc import ABC, abstractmethod
from typing import (
Callable,
Coroutine,
Iterator,
AsyncIterator,
List,
Generic,
TypeVar,
Any,
Tuple,
Dict,
Union,
Optional,
)
import asyncio
import logging
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
)
from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState
logger = logging.getLogger(__name__)

View File

@@ -1,14 +1,16 @@
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncIterator, List
import pytest
import pytest_asyncio
from typing import AsyncIterator, List
from contextlib import contextmanager, asynccontextmanager
from .. import (
WorkflowRunner,
InputOperator,
DAGContext,
TaskState,
DefaultWorkflowRunner,
InputOperator,
SimpleInputSource,
TaskState,
WorkflowRunner,
)
from ..task.task_impl import _is_async_iterator

View File

@@ -1,24 +1,26 @@
import pytest
from typing import List
import pytest
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
DAGContext,
InputOperator,
JoinOperator,
MapOperator,
ReduceStreamOperator,
SimpleInputSource,
TaskState,
WorkflowRunner,
)
from .conftest import (
runner,
_is_async_iterator,
input_node,
input_nodes,
runner,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)

View File

@@ -1,24 +1,26 @@
import pytest
from typing import List
import pytest
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
DAGContext,
InputOperator,
JoinOperator,
MapOperator,
ReduceStreamOperator,
SimpleInputSource,
TaskState,
WorkflowRunner,
)
from .conftest import (
runner,
_is_async_iterator,
input_node,
input_nodes,
runner,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
from starlette.requests import Request
from starlette.responses import Response
from dbgpt._private.pydantic import BaseModel
import logging
from .base import Trigger
from dbgpt._private.pydantic import BaseModel
from ..dag.base import DAG
from ..operator.base import BaseOperator
from .base import Trigger
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI

View File

@@ -1,11 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING, Optional
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from fastapi import APIRouter
from dbgpt.component import SystemApp, BaseComponent, ComponentType
from dbgpt.component import BaseComponent, ComponentType, SystemApp
logger = logging.getLogger(__name__)

View File

@@ -114,6 +114,9 @@ class ModelRequestContext:
span_id: Optional[str] = None
"""The span id of the model inference."""
chat_mode: Optional[str] = None
"""The chat mode of the model inference."""
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
"""The extra information of the model inference."""
@@ -195,7 +198,13 @@ class ModelRequest:
# Skip None fields
return {k: v for k, v in asdict(new_reqeust).items() if v}
def _get_messages(self) -> List[ModelMessage]:
def get_messages(self) -> List[ModelMessage]:
"""Get the messages.
If the messages is not a list of ModelMessage, it will be converted to a list of ModelMessage.
Returns:
List[ModelMessage]: The messages.
"""
return list(
map(
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
@@ -209,7 +218,7 @@ class ModelRequest:
Returns:
Optional[ModelMessage]: The single user message.
"""
messages = self._get_messages()
messages = self.get_messages()
if len(messages) != 1 and messages[0].role != ModelMessageRoleType.HUMAN:
raise ValueError("The messages is not a single user message")
return messages[0]

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import MapOperator
@@ -176,6 +176,22 @@ class ModelMessage(BaseModel):
def build_human_message(content: str) -> "ModelMessage":
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
@staticmethod
def get_printable_message(messages: List["ModelMessage"]) -> str:
    """Get the printable message.

    Every message becomes one line of the form
    ``(Round <round_index>) <role>: <content>``; trailing whitespace is
    stripped from the line before the newline is appended.

    Args:
        messages (List[ModelMessage]): The messages to render.

    Returns:
        str: All rendered lines concatenated, or an empty string when
            ``messages`` is empty.
    """
    # Join once instead of growing a string with ``+=`` on every iteration.
    return "".join(
        f"(Round {message.round_index}) {message.role}: {message.content} ".rstrip()
        + "\n"
        for message in messages
    )
_SingleRoundMessage = List[ModelMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]]
def _message_to_dict(message: BaseMessage) -> Dict:
return message.to_dict()

View File

@@ -5,6 +5,7 @@ from typing import Any, AsyncIterator, List, Optional
from dbgpt.core import (
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
ModelRequest,
ModelRequestContext,
@@ -12,6 +13,7 @@ from dbgpt.core import (
StorageInterface,
)
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
from dbgpt.core.interface.message import _MultiRoundMessageMapper
class BaseConversationOperator(BaseOperator, ABC):
@@ -24,7 +26,7 @@ class BaseConversationOperator(BaseOperator, ABC):
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
self._storage = storage
@@ -88,7 +90,7 @@ class PreConversationOperator(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs
**kwargs,
):
super().__init__(storage=storage, message_storage=message_storage)
MapOperator.__init__(self, **kwargs)
@@ -109,7 +111,7 @@ class PreConversationOperator(
if not input_value.context.extra:
input_value.context.extra = {}
chat_mode = input_value.context.extra.get("chat_mode")
chat_mode = input_value.context.chat_mode
# Create a new storage conversation, this will load the conversation from storage, so we must do this async
storage_conv: StorageConversation = await self.blocking_func_to_async(
@@ -121,11 +123,8 @@ class PreConversationOperator(
conv_storage=self.storage,
message_storage=self.message_storage,
)
# The input message must be a single user message
single_human_message: ModelMessage = input_value.get_single_user_message()
storage_conv.start_new_round()
storage_conv.add_user_message(single_human_message.content)
input_messages = input_value.get_messages()
await self.save_to_storage(storage_conv, input_messages)
# Get all messages from current storage conversation, and overwrite the input value
messages: List[ModelMessage] = storage_conv.get_model_messages()
input_value.messages = messages
@@ -139,6 +138,42 @@ class PreConversationOperator(
)
return input_value
async def save_to_storage(
    self, storage_conv: StorageConversation, input_messages: List[ModelMessage]
) -> None:
    """Start a new round and add the input messages to the storage conversation.

    The messages are validated with ``check_messages`` first. Human messages
    are added through ``add_user_message``; every other message is added
    through ``add_system_message``.

    Args:
        storage_conv (StorageConversation): The storage conversation.
        input_messages (List[ModelMessage]): The input messages.
    """
    # Validate before touching the conversation
    self.check_messages(input_messages)
    storage_conv.start_new_round()
    for msg in input_messages:
        if msg.role != ModelMessageRoleType.HUMAN:
            storage_conv.add_system_message(msg.content)
            continue
        storage_conv.add_user_message(msg.content)
def check_messages(self, messages: List[ModelMessage]) -> None:
    """Check the messages.

    Args:
        messages (List[ModelMessage]): The messages.

    Raises:
        ValueError: If the messages are empty, or if any message has a role
            other than human or system.
    """
    if not messages:
        raise ValueError("Input messages is empty")
    # The accepted roles do not change per message, so build them once.
    supported_roles = (ModelMessageRoleType.HUMAN, ModelMessageRoleType.SYSTEM)
    for message in messages:
        if message.role not in supported_roles:
            raise ValueError(f"Message role {message.role} is not supported")
async def after_dag_end(self):
"""The callback after DAG end"""
# Save the storage conversation to storage after the whole DAG finished
@@ -198,8 +233,9 @@ class PostStreamingConversationOperator(
class ConversationMapperOperator(
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
):
def __init__(self, **kwargs):
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
MapOperator.__init__(self, **kwargs)
self._message_mapper = message_mapper
async def map(self, input_value: ModelRequest) -> ModelRequest:
"""Map the input value to a ModelRequest.
@@ -211,12 +247,12 @@ class ConversationMapperOperator(
ModelRequest: The mapped ModelRequest.
"""
input_value = input_value.copy()
messages: List[ModelMessage] = await self.map_messages(input_value.messages)
messages: List[ModelMessage] = self.map_messages(input_value.messages)
# Overwrite the input value
input_value.messages = messages
return input_value
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
"""Map the input messages to a list of ModelMessage.
Args:
@@ -225,7 +261,73 @@ class ConversationMapperOperator(
Returns:
List[ModelMessage]: The mapped ModelMessage.
"""
return messages
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
messages
)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
def map_multi_round_messages(
    self, messages_by_round: List[List[ModelMessage]]
) -> List[ModelMessage]:
    """Map multi round messages to a list of ModelMessage.

    By default, all rounds are merged into one list of ModelMessage, keeping
    the original order. You can override this method to implement your own
    logic.

    Examples:

        Merge multi round messages to a list of ModelMessage according to the
        original order.

        .. code-block:: python

            from dbgpt.core import ModelMessage
            from dbgpt.core.operator import ConversationMapperOperator

            messages_by_round = [
                [
                    ModelMessage(role="human", content="Hi", round_index=1),
                    ModelMessage(role="ai", content="Hello!", round_index=1),
                ],
                [
                    ModelMessage(role="system", content="Error 404", round_index=2),
                    ModelMessage(role="human", content="What's the error?", round_index=2),
                    ModelMessage(role="ai", content="Just a joke.", round_index=2),
                ],
                [
                    ModelMessage(role="human", content="Funny!", round_index=3),
                ],
            ]
            operator = ConversationMapperOperator()
            messages = operator.map_multi_round_messages(messages_by_round)
            assert messages == [
                ModelMessage(role="human", content="Hi", round_index=1),
                ModelMessage(role="ai", content="Hello!", round_index=1),
                ModelMessage(role="system", content="Error 404", round_index=2),
                ModelMessage(role="human", content="What's the error?", round_index=2),
                ModelMessage(role="ai", content="Just a joke.", round_index=2),
                ModelMessage(role="human", content="Funny!", round_index=3),
            ]

        Map multi round messages to a list of ModelMessage, keeping only the
        last round.

        .. code-block:: python

            class MyMapper(ConversationMapperOperator):
                def __init__(self, **kwargs):
                    super().__init__(**kwargs)

                def map_multi_round_messages(self, messages_by_round: List[List[ModelMessage]]) -> List[ModelMessage]:
                    return messages_by_round[-1]

            operator = MyMapper()
            messages = operator.map_multi_round_messages(messages_by_round)
            assert messages == [
                ModelMessage(role="human", content="Funny!", round_index=3),
            ]

    Args:
        messages_by_round (List[List[ModelMessage]]): The messages grouped by
            round, one inner list per round.

    Returns:
        List[ModelMessage]: A new list holding the messages of every round,
            in their original order.
    """
    from itertools import chain

    # Flatten in one linear pass; ``sum(lists, [])`` re-copies the accumulated
    # list for every round.
    # e.g. list(chain.from_iterable([[1, 2], [3, 4], [5, 6]])) == [1, 2, 3, 4, 5, 6]
    return list(chain.from_iterable(messages_by_round))
def _split_messages_by_round(
self, messages: List[ModelMessage]
@@ -236,7 +338,7 @@ class ConversationMapperOperator(
messages (List[ModelMessage]): The input messages.
Returns:
List[List[ModelMessage]]: The splitted messages.
List[List[ModelMessage]]: The split messages.
"""
messages_by_round: List[List[ModelMessage]] = []
last_round_index = 0
@@ -263,15 +365,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
.. code-block:: python
import asyncio
from dbgpt.core import ModelMessage
from dbgpt.core.operator import BufferedConversationMapperOperator
# No history
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
operator = BufferedConversationMapperOperator(last_k_round=1)
messages = asyncio.run(operator.map_messages(messages))
assert messages == [ModelMessage(role="human", content="Hello", round_index=1)]
assert operator.map_messages(messages) == [ModelMessage(role="human", content="Hello", round_index=1)]
Transform with history messages
@@ -287,10 +387,9 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
ModelMessage(role="human", content="Funny!", round_index=3),
]
operator = BufferedConversationMapperOperator(last_k_round=1)
messages = asyncio.run(operator.map_messages(messages))
# Just keep the last one round, so the first round messages will be removed
# Note: The round index 3 is not a complete round
assert messages == [
assert operator.map_messages(messages) == [
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
@@ -298,24 +397,42 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
]
"""
def __init__(self, last_k_round: Optional[int] = 2, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
last_k_round: Optional[int] = 2,
message_mapper: _MultiRoundMessageMapper = None,
**kwargs,
):
self._last_k_round = last_k_round
if message_mapper:
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
"""Map the input messages to a list of ModelMessage.
def new_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
# Apply keep k round messages first, then apply the custom message mapper
messages_by_round = self._keep_last_round_messages(messages_by_round)
return message_mapper(messages_by_round)
else:
def new_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
messages_by_round = self._keep_last_round_messages(messages_by_round)
return sum(messages_by_round, [])
super().__init__(new_message_mapper, **kwargs)
def _keep_last_round_messages(
self, messages_by_round: List[List[ModelMessage]]
) -> List[List[ModelMessage]]:
"""Keep the last k round messages.
Args:
messages (List[ModelMessage]): The input messages.
messages_by_round (List[List[ModelMessage]]): The messages by round.
Returns:
List[ModelMessage]: The mapped ModelMessage.
List[List[ModelMessage]]: The latest round messages.
"""
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
messages
)
# Get the last k round messages
index = self._last_k_round + 1
messages_by_round = messages_by_round[-index:]
messages: List[ModelMessage] = sum(messages_by_round, [])
return messages
return messages_by_round[-index:]

View File

@@ -169,9 +169,7 @@ class StoragePromptTemplate(StorageItem):
def to_prompt_template(self) -> PromptTemplate:
"""Convert the storage prompt template to a prompt template."""
input_variables = (
None
if not self.input_variables
else self.input_variables.strip().split(",")
[] if not self.input_variables else self.input_variables.strip().split(",")
)
return PromptTemplate(
input_variables=input_variables,
@@ -458,6 +456,33 @@ class PromptManager:
)
self.storage.save(storage_prompt_template)
def query_or_save(
    self, prompt_template: PromptTemplate, prompt_name: str, **kwargs
) -> StoragePromptTemplate:
    """Return the stored template for ``prompt_name``, saving it first if absent.

    Args:
        prompt_template (PromptTemplate): The prompt template to look up or save.
        prompt_name (str): The name of the prompt template.
        kwargs (Dict): Other params to build the storage prompt template.
            More details in :meth:`~StoragePromptTemplate.from_prompt_template`.

    Returns:
        StoragePromptTemplate: The template loaded from storage (the existing
            one, or the one that has just been saved).
    """
    # Build the storage item only to derive the identifier used for the lookup.
    candidate = StoragePromptTemplate.from_prompt_template(
        prompt_template, prompt_name, **kwargs
    )
    template_id = candidate.identifier
    stored = self.storage.load(template_id, StoragePromptTemplate)
    if not stored:
        # Not found: persist it, then read back what the storage now holds.
        self.save(prompt_template, prompt_name, **kwargs)
        stored = self.storage.load(template_id, StoragePromptTemplate)
    return stored
def list(self, **kwargs) -> List[StoragePromptTemplate]:
"""List prompt templates from storage.

View File

@@ -16,7 +16,6 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseSettings
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponse,
@@ -42,6 +41,7 @@ from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
from dbgpt.model.cluster import ModelRegistry
from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from dbgpt.util.utils import setup_logging
from dbgpt._private.pydantic import BaseModel
logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class APIServerException(Exception):
self.message = message
class APISettings(BaseSettings):
class APISettings(BaseModel):
api_keys: Optional[List[str]] = None

View File

@@ -24,6 +24,7 @@ from dbgpt.core.interface.llm import ModelMetadata, LLMClient
from dbgpt.core.interface.llm import ModelOutput, ModelRequest
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt._private.pydantic import model_to_json
if TYPE_CHECKING:
import httpx
@@ -175,6 +176,9 @@ class OpenAILLMClient(LLMClient):
async def generate(self, request: ModelRequest) -> ModelOutput:
messages = request.to_openai_messages()
payload = self._build_request(request)
logger.info(
f"Send request to openai, payload: {payload}\n\n messages:\n{messages}"
)
try:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
@@ -193,6 +197,9 @@ class OpenAILLMClient(LLMClient):
) -> AsyncIterator[ModelOutput]:
messages = request.to_openai_messages()
payload = self._build_request(request, True)
logger.info(
f"Send request to openai, payload: {payload}\n\n messages:\n{messages}"
)
try:
chat_completion = await self.client.chat.completions.create(
messages=messages, **payload
@@ -321,7 +328,7 @@ async def _to_openai_stream(
chunk = ChatCompletionStreamResponse(
id=id, choices=[choice_data], model=model or ""
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
yield f"data: {model_to_json(chunk, exclude_unset=True, ensure_ascii=False)}\n\n"
previous_text = ""
finish_stream_events = []
@@ -356,7 +363,7 @@ async def _to_openai_stream(
if model_output.finish_reason is not None:
finish_stream_events.append(chunk)
continue
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
yield f"data: {model_to_json(chunk, exclude_unset=True, ensure_ascii=False)}\n\n"
for finish_chunk in finish_stream_events:
yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
yield f"data: {model_to_json(finish_chunk, exclude_none=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve conversation`

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve conversation`

View File

@@ -0,0 +1,176 @@
from typing import Optional, List
from functools import cache
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult
from .schemas import ServeRequest, ServerResponse
from ..service.service import Service
from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME
router = APIRouter()
# Add your API endpoints here
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
    """Return the ``Service`` component registered on the global system app.

    The component is looked up by ``SERVE_SERVICE_COMPONENT_NAME`` on
    ``global_system_app``, which is assigned by ``init_endpoints``. Until that
    has run ``global_system_app`` is still ``None`` and this call fails.
    """
    return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Validate the request's bearer token against the configured API keys.

    If no API keys are configured, every request is allowed and ``None`` is
    returned. Otherwise the bearer token must be one of the configured keys;
    it is returned on success and a 401 ``HTTPException`` is raised on failure
    (including when no ``Authorization`` header was sent).

    You can pass the token in your request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200
    """
    configured_keys = service.config.api_keys
    if not configured_keys:
        # api_keys not set; allow all
        return None

    accepted_keys = _parse_api_keys(configured_keys)
    if auth is not None and auth.credentials in accepted_keys:
        return auth.credentials

    raise HTTPException(
        status_code=401,
        detail={
            "error": {
                "message": "",
                "type": "invalid_request_error",
                "param": None,
                "code": "invalid_api_key",
            }
        },
    )
@router.get("/health")
async def health():
    """Health check endpoint.

    Declares no API-key dependency and always returns ``{"status": "ok"}``.
    """
    return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Auth probe endpoint.

    Guarded by the ``check_api_key`` dependency; returns ``{"status": "ok"}``
    only when that dependency did not reject the request.
    """
    return {"status": "ok"}
@router.post(
    "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
    request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Create a new Conversation entity

    Args:
        request (ServeRequest): The request
        service (Service): The service

    Returns:
        Result[ServerResponse]: The value returned by ``service.create``,
            wrapped with ``Result.succ``
    """
    return Result.succ(service.create(request))
@router.put(
    "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def update(
    request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Update a Conversation entity

    Args:
        request (ServeRequest): The request
        service (Service): The service

    Returns:
        Result[ServerResponse]: The value returned by ``service.update``,
            wrapped with ``Result.succ``
    """
    return Result.succ(service.update(request))
@router.post(
    "/query",
    response_model=Result[ServerResponse],
    dependencies=[Depends(check_api_key)],
)
async def query(
    request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Query a single Conversation entity

    Delegates to ``service.get``, so at most one entity is returned.

    Args:
        request (ServeRequest): The request
        service (Service): The service

    Returns:
        Result[ServerResponse]: The value returned by ``service.get``,
            wrapped with ``Result.succ``
    """
    return Result.succ(service.get(request))
@router.post(
    "/query_page",
    response_model=Result[PaginationResult[ServerResponse]],
    dependencies=[Depends(check_api_key)],
)
async def query_page(
    request: ServeRequest,
    page: Optional[int] = Query(default=1, description="current page"),
    page_size: Optional[int] = Query(default=20, description="page size"),
    service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
    """Query Conversation entities page by page

    Args:
        request (ServeRequest): The request
        page (int): The page number (query parameter, default 1)
        page_size (int): The page size (query parameter, default 20)
        service (Service): The service

    Returns:
        Result[PaginationResult[ServerResponse]]: The value returned by
            ``service.get_list_by_page``, wrapped with ``Result.succ``
    """
    return Result.succ(service.get_list_by_page(request, page, page_size))
def init_endpoints(system_app: SystemApp) -> None:
    """Initialize the endpoints

    Registers the ``Service`` component on ``system_app`` and stores the app in
    the module-level ``global_system_app`` so that ``get_service`` can resolve
    the service for each request.

    Args:
        system_app (SystemApp): The system app the endpoints belong to
    """
    global global_system_app
    system_app.register(Service)
    global_system_app = system_app

View File

@@ -0,0 +1,20 @@
# Define your Pydantic schemas here
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
    """Conversation request model

    Scaffold only: no fields are declared yet.
    """

    # TODO define your own fields here

    class Config:
        # Model title, built from SERVE_APP_NAME_HUMP.
        title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
class ServerResponse(BaseModel):
    """Conversation response model

    Scaffold only: no fields are declared yet.
    """

    # TODO define your own fields here

    class Config:
        # Model title, built from SERVE_APP_NAME_HUMP.
        title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"

View File

@@ -0,0 +1,23 @@
from typing import Optional
from dataclasses import dataclass, field
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "conversation"
SERVE_APP_NAME = "dbgpt_serve_conversation"
SERVE_APP_NAME_HUMP = "dbgpt_serve_Conversation"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.conversation."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpt_serve_conversation"
@dataclass
class ServeConfig(BaseServeConfig):
    """Parameters for the serve command"""

    # TODO: add your own parameters here

    # API keys as a single string, or None (the default).
    # NOTE(review): the help text says None allows all requests; the consumers
    # of this field are outside this module - presumably the string is
    # comma-separated, confirm against them.
    api_keys: Optional[str] = field(
        default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
    )

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve conversation`

View File

@@ -0,0 +1,68 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from typing import Union, Any, Dict
from datetime import datetime
from sqlalchemy import Column, Integer, String, Index, Text, DateTime
from dbgpt.storage.metadata import Model, BaseDao, db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import ServeConfig, SERVER_APP_TABLE_NAME
class ServeEntity(Model):
    """ORM entity for the Conversation serve table.

    Scaffold entity: only the primary key and the two bookkeeping timestamps
    are declared.
    """

    __tablename__ = SERVER_APP_TABLE_NAME
    id = Column(Integer, primary_key=True, comment="Auto increment id")

    # TODO: define your own fields here

    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    # ``onupdate`` makes SQLAlchemy refresh the timestamp on every UPDATE that
    # does not set it explicitly; with only ``default`` it would keep the
    # creation time forever.
    gmt_modified = Column(
        DateTime,
        default=datetime.now,
        onupdate=datetime.now,
        comment="Record update time",
    )

    def __repr__(self):
        """Return a debug representation with the id and both timestamps."""
        return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
    """Data access object for Conversation entities."""

    def __init__(self, serve_config: ServeConfig):
        """Create the DAO and keep the serve configuration for later use."""
        super().__init__()
        self._serve_config = serve_config

    def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
        """Build an entity from a request object or a plain dict.

        Args:
            request (Union[ServeRequest, Dict[str, Any]]): The request

        Returns:
            ServeEntity: The entity constructed from the request fields
        """
        if isinstance(request, ServeRequest):
            fields = request.dict()
        else:
            fields = request
        # TODO implement your own logic here, transfer the request_dict to an entity
        return ServeEntity(**fields)

    def to_request(self, entity: ServeEntity) -> ServeRequest:
        """Build a request from an entity.

        Args:
            entity (ServeEntity): The entity

        Returns:
            ServeRequest: The request (currently an empty placeholder)
        """
        # TODO implement your own logic here, transfer the entity to a request
        converted = ServeRequest()
        return converted

    def to_response(self, entity: ServeEntity) -> ServerResponse:
        """Build a response from an entity.

        Args:
            entity (ServeEntity): The entity

        Returns:
            ServerResponse: The response (currently an empty placeholder)
        """
        # TODO implement your own logic here, transfer the entity to a response
        converted = ServerResponse()
        return converted

View File

@@ -0,0 +1,99 @@
from typing import List, Optional, Union
import logging
from dbgpt.component import SystemApp
from sqlalchemy import URL
from dbgpt.core import StorageInterface
from dbgpt.storage.metadata import DatabaseManager
from dbgpt.serve.core import BaseServe
from .config import (
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
APP_NAME,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
logger = logging.getLogger(__name__)
class Serve(BaseServe):
    """Serve component for DB-GPT

    Manage DB-GPT conversation history and provide API for other components to access.

    TODO: Move some Http API in app to this component.
    """

    name = SERVE_APP_NAME

    def __init__(
        self,
        system_app: SystemApp,
        api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
        api_tags: Optional[List[str]] = None,
        db_url_or_db: Union[str, URL, DatabaseManager] = None,
        try_create_tables: Optional[bool] = False,
    ):
        """Set up the component; the storages stay ``None`` until ``before_start``."""
        super().__init__(
            system_app,
            api_prefix,
            # Fall back to the serve app name as the only API tag.
            [SERVE_APP_NAME_HUMP] if api_tags is None else api_tags,
            db_url_or_db,
            try_create_tables,
        )
        self._db_manager: Optional[DatabaseManager] = None
        self._conv_storage = None
        self._message_storage = None

    @property
    def conv_storage(self) -> StorageInterface:
        """The conversation storage (``None`` before ``before_start`` has run)."""
        return self._conv_storage

    @property
    def message_storage(self) -> StorageInterface:
        """The message storage (``None`` before ``before_start`` has run)."""
        return self._message_storage

    def init_app(self, system_app: SystemApp):
        """Bind the system app once; repeated calls are ignored."""
        if not self._app_has_initiated:
            self._system_app = system_app
            self._app_has_initiated = True

    def on_init(self):
        """Called when init the application.

        You can do some initialization here. You can't get other components here because they may be not initialized yet
        """
        # Load DB Model: the names are unused, the import itself is the point.
        from dbgpt.storage.chat_history.chat_history_db import (
            ChatHistoryEntity,
            ChatHistoryMessageEntity,
        )

    def before_start(self):
        """Called before the start of the application.

        Obtains the database manager and builds the conversation and message
        storages on top of it.
        """
        # TODO: Your code here
        from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
        from dbgpt.util.serialization.json_serialization import JsonSerializer
        from dbgpt.storage.chat_history.chat_history_db import (
            ChatHistoryEntity,
            ChatHistoryMessageEntity,
        )
        from dbgpt.storage.chat_history.storage_adapter import (
            DBStorageConversationItemAdapter,
            DBMessageStorageItemAdapter,
        )

        db_manager = self.create_or_get_db_manager()
        self._db_manager = db_manager
        self._conv_storage = SQLAlchemyStorage(
            db_manager,
            ChatHistoryEntity,
            DBStorageConversationItemAdapter(),
            JsonSerializer(),
        )
        self._message_storage = SQLAlchemyStorage(
            db_manager,
            ChatHistoryMessageEntity,
            DBMessageStorageItemAdapter(),
            JsonSerializer(),
        )

View File

@@ -0,0 +1,116 @@
from typing import Optional, List
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.serve.core import BaseService
from ..models.models import ServeDao, ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
    """Service layer for Conversation entities, backed by a ``ServeDao``."""

    name = SERVE_SERVICE_COMPONENT_NAME

    def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
        """Create the service.

        Args:
            system_app (SystemApp): The system app
            dao (Optional[ServeDao]): DAO to use; when omitted one is created
                in :meth:`init_app`
        """
        # Attributes are set before the base initializer runs.
        self._system_app = None
        self._serve_config: ServeConfig = None
        self._dao: ServeDao = dao
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp) -> None:
        """Load the serve config from the app config and make sure a DAO exists.

        Args:
            system_app (SystemApp): The system app
        """
        self._serve_config = ServeConfig.from_app_config(
            system_app.config, SERVE_CONFIG_KEY_PREFIX
        )
        if not self._dao:
            self._dao = ServeDao(self._serve_config)
        self._system_app = system_app

    @property
    def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
        """The DAO used by this service."""
        return self._dao

    @property
    def config(self) -> ServeConfig:
        """The serve configuration loaded in :meth:`init_app`."""
        return self._serve_config

    def update(self, request: ServeRequest) -> ServerResponse:
        """Update a Conversation entity.

        Args:
            request (ServeRequest): The request carrying the new values

        Returns:
            ServerResponse: The response
        """
        # TODO: implement your own logic here
        # Build the query request from the request
        criteria = {
            # "id": request.id
        }
        return self.dao.update(criteria, update_request=request)

    def get(self, request: ServeRequest) -> Optional[ServerResponse]:
        """Get one Conversation entity matching the request.

        Args:
            request (ServeRequest): The request, used as the query

        Returns:
            Optional[ServerResponse]: The response returned by the DAO
        """
        # TODO: implement your own logic here
        # Build the query request from the request
        return self.dao.get_one(request)

    def delete(self, request: ServeRequest) -> None:
        """Delete a Conversation entity.

        Args:
            request (ServeRequest): The request
        """
        # TODO: implement your own logic here
        # Build the query request from the request
        criteria = {
            # "id": request.id
        }
        self.dao.delete(criteria)

    def get_list(self, request: ServeRequest) -> List[ServerResponse]:
        """Get the Conversation entities matching the request.

        Args:
            request (ServeRequest): The request, used as the query

        Returns:
            List[ServerResponse]: The response
        """
        # TODO: implement your own logic here
        # Build the query request from the request
        return self.dao.get_list(request)

    def get_list_by_page(
        self, request: ServeRequest, page: int, page_size: int
    ) -> PaginationResult[ServerResponse]:
        """Get one page of the Conversation entities matching the request.

        Args:
            request (ServeRequest): The request, used as the query
            page (int): The page number
            page_size (int): The page size

        Returns:
            PaginationResult[ServerResponse]: The requested page
        """
        return self.dao.get_list_page(request, page, page_size)

View File

@@ -0,0 +1,124 @@
import pytest
from httpx import AsyncClient
from fastapi import FastAPI
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..config import SERVE_CONFIG_KEY_PREFIX
from ..api.endpoints import router, init_endpoints
from ..api.schemas import ServeRequest, ServerResponse
from dbgpt.serve.core.tests.conftest import client, asystem_app
@pytest.fixture(autouse=True)
def setup_and_teardown():
    """Point ``db`` at an in-memory SQLite database and create all tables before each test."""
    db.init_db("sqlite:///:memory:")
    db.create_all()

    # Nothing is torn down after the test.
    yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
    """Mount the serve router on ``app`` and initialise the endpoints with ``system_app``.

    Passed to the ``client`` fixture through the ``app_caller`` parameter.
    """
    app.include_router(router)
    init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, asystem_app, has_auth",
    [
        (
            {
                "app_caller": client_init_caller,
                "client_api_key": "test_token1",
            },
            {
                "app_config": {
                    f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
                }
            },
            True,
        ),
        (
            {
                "app_caller": client_init_caller,
                "client_api_key": "error_token",
            },
            {
                "app_config": {
                    f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
                }
            },
            False,
        ),
    ],
    indirect=["client", "asystem_app"],
)
async def test_api_auth(client: AsyncClient, asystem_app, has_auth: bool):
    """Check ``/test_auth`` with a valid and an invalid API key.

    This test used to be called ``test_api_health``; a later function of the
    same name in this module replaced it at import time, so it was never
    collected. The distinct name makes it run.
    """
    response = await client.get("/test_auth")
    if has_auth:
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}
    else:
        assert response.status_code == 401
        assert response.json() == {
            "detail": {
                "error": {
                    "message": "",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            }
        }
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
    """``/health`` answers 200 with ``{"status": "ok"}`` when no API key is sent."""
    response = await client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
    """Placeholder for the create-endpoint test; asserts nothing yet."""
    # TODO: add your test case
    pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
    """Placeholder for the update-endpoint test; asserts nothing yet."""
    # TODO: implement your test case
    pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
    """Placeholder for the query-endpoint test; asserts nothing yet."""
    # TODO: implement your test case
    pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
    """Placeholder for the paged-query-endpoint test; asserts nothing yet."""
    # TODO: implement your test case
    pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,109 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
from ..config import ServeConfig
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity, ServeDao
@pytest.fixture(autouse=True)
def setup_and_teardown():
    """Point ``db`` at an in-memory SQLite database and create all tables before each test."""
    db.init_db("sqlite:///:memory:")
    db.create_all()

    # Nothing is torn down after the test.
    yield
@pytest.fixture
def server_config():
    """Provide a ``ServeConfig`` built with default values only."""
    # TODO : build your server config
    return ServeConfig()
@pytest.fixture
def dao(server_config):
    """Provide a ``ServeDao`` built from the ``server_config`` fixture."""
    return ServeDao(server_config)
@pytest.fixture
def default_entity_dict():
    """Provide the keyword arguments used to build entities and requests (empty for now)."""
    # TODO: build your default entity dict
    return {}
def test_table_exist():
    """The entity's table name is registered in the database metadata."""
    assert ServeEntity.__tablename__ in db.metadata.tables
def test_entity_create(default_entity_dict):
    """A created entity can be loaded back from the database by primary key."""
    entity: ServeEntity = ServeEntity.create(**default_entity_dict)
    # TODO: implement your test case
    with db.session() as session:
        # Session.get() is the supported primary-key lookup; Query.get() is a
        # legacy API in SQLAlchemy 1.4/2.0.
        db_entity: ServeEntity = session.get(ServeEntity, entity.id)
        # Fail with a clear assertion instead of an AttributeError on None.
        assert db_entity is not None
        assert db_entity.id == entity.id
def test_entity_unique_key(default_entity_dict):
    """Placeholder for the unique-key test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_entity_get(default_entity_dict):
    """``ServeEntity.get`` returns the entity that was just created."""
    entity: ServeEntity = ServeEntity.create(**default_entity_dict)
    db_entity: ServeEntity = ServeEntity.get(entity.id)
    assert db_entity.id == entity.id
    # TODO: implement your test case
def test_entity_update(default_entity_dict):
    """Placeholder for the entity-update test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_entity_delete(default_entity_dict):
    """After ``delete()`` the entity can no longer be fetched by id."""
    # TODO: implement your test case
    entity: ServeEntity = ServeEntity.create(**default_entity_dict)
    entity.delete()
    db_entity: ServeEntity = ServeEntity.get(entity.id)
    assert db_entity is None
def test_entity_all():
    """Placeholder for the list-all-entities test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_dao_create(dao, default_entity_dict):
    """``dao.create`` returns a response for a request built from the default dict."""
    # TODO: implement your test case
    req = ServeRequest(**default_entity_dict)
    res: ServerResponse = dao.create(req)
    assert res is not None
def test_dao_get_one(dao, default_entity_dict):
    """After creating a record, ``dao.get_one`` finds one for the same request.

    The original body created the record but asserted nothing and never
    called ``get_one``, so it could not fail.
    """
    # TODO: extend with field-level checks once the schema has fields
    req = ServeRequest(**default_entity_dict)
    res: ServerResponse = dao.create(req)
    assert res is not None
    assert dao.get_one(req) is not None
def test_get_dao_get_list(dao):
    """Placeholder for the DAO get-list test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_dao_update(dao, default_entity_dict):
    """Placeholder for the DAO update test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_dao_delete(dao, default_entity_dict):
    """Placeholder for the DAO delete test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_dao_get_list_page(dao):
    """Placeholder for the DAO paged-list test; asserts nothing yet."""
    # TODO: implement your test case
    pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,76 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.serve.core.tests.conftest import system_app
from ..models.models import ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
    """Point ``db`` at an in-memory SQLite database and create all tables before each test."""
    db.init_db("sqlite:///:memory:")
    db.create_all()

    # Nothing is torn down after the test.
    yield
@pytest.fixture
def service(system_app: SystemApp):
    """Provide a ``Service`` on which ``init_app`` has been called with ``system_app``."""
    instance = Service(system_app)
    instance.init_app(system_app)
    return instance
@pytest.fixture
def default_entity_dict():
    """Provide the keyword arguments used to build requests (empty for now)."""
    # TODO: build your default entity dict
    return {}
@pytest.mark.parametrize(
    "system_app",
    [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
    indirect=True,
)
def test_config_exists(service: Service):
    """The parametrized app config is visible on the service's system app, and a serve config was built."""
    system_app: SystemApp = service._system_app
    assert system_app.config.get("DEBUG") is True
    assert system_app.config.get("dbgpt.serve.test_key") == "hello"
    assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
    """Placeholder for the service create test; asserts nothing yet."""
    # TODO: implement your test case
    # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
    # ...
    pass
def test_service_update(service: Service, default_entity_dict):
    """Placeholder for the service update test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_service_get(service: Service, default_entity_dict):
    """Placeholder for the service get test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_service_delete(service: Service, default_entity_dict):
    """Placeholder for the service delete test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_service_get_list(service: Service):
    """Placeholder for the service get-list test; asserts nothing yet."""
    # TODO: implement your test case
    pass
def test_service_get_list_by_page(service: Service):
    """Placeholder for the service paged-list test; asserts nothing yet."""
    # TODO: implement your test case
    pass
# Add more test cases according to your own logic

View File

@@ -1,5 +1,6 @@
from dbgpt.serve.core.schemas import Result
from dbgpt.serve.core.config import BaseServeConfig
from dbgpt.serve.core.service import BaseService
from dbgpt.serve.core.serve import BaseServe
__ALL__ = ["Result", "BaseServeConfig", "BaseService"]
# Public API of the package. The name must be lowercase ``__all__``: Python
# ignores ``__ALL__`` when resolving ``from dbgpt.serve.core import *``.
__all__ = ["Result", "BaseServeConfig", "BaseService", "BaseServe"]

60
dbgpt/serve/core/serve.py Normal file
View File

@@ -0,0 +1,60 @@
from abc import ABC
from typing import Optional, Union, List
import logging
from dbgpt.component import BaseComponent, SystemApp, ComponentType
from sqlalchemy import URL
from dbgpt.storage.metadata import DatabaseManager
logger = logging.getLogger(__name__)
class BaseServe(BaseComponent, ABC):
    """Base serve component for DB-GPT

    Stores the API prefix/tags and database settings shared by all serve
    components, and provides :meth:`create_or_get_db_manager` to resolve the
    database manager they should use.
    """

    name = "dbgpt_serve_base"

    def __init__(
        self,
        system_app: SystemApp,
        api_prefix: str,
        api_tags: List[str],
        db_url_or_db: Optional[Union[str, URL, DatabaseManager]] = None,
        try_create_tables: Optional[bool] = False,
    ):
        """Create the serve component.

        Args:
            system_app (SystemApp): The system app this component belongs to
            api_prefix (str): URL prefix for the component's routes
            api_tags (List[str]): API tags for the component's routes
            db_url_or_db: Database URL or manager used when no unified db
                manager factory is registered; ``None`` falls back to the
                default ``db``
            try_create_tables (Optional[bool]): Whether to try to create the
                tables the first time the database manager is resolved
        """
        self._system_app = system_app
        self._api_prefix = api_prefix
        self._api_tags = api_tags
        self._db_url_or_db = db_url_or_db
        self._try_create_tables = try_create_tables
        # True until table creation has been attempted once.
        self._not_create_table = True
        self._app_has_initiated = False

    def create_or_get_db_manager(self) -> DatabaseManager:
        """Create or get the database manager.

        This method must be called after the application is initialized

        Resolution order: the manager produced by the registered
        ``UnifiedDBManagerFactory`` (if any, and if truthy), otherwise one
        built from ``db_url_or_db`` or the default ``db``.

        Returns:
            DatabaseManager: The database manager
        """
        from dbgpt.storage.metadata import Model, db, UnifiedDBManagerFactory

        # If you need to use the database, you can get the database manager here
        db_manager_factory: Optional[
            UnifiedDBManagerFactory
        ] = self._system_app.get_component(
            ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY,
            UnifiedDBManagerFactory,
            default_component=None,
        )
        # Call create() exactly once and reuse its result, instead of once to
        # test it and a second time to use it.
        init_db = db_manager_factory.create() if db_manager_factory is not None else None
        if not init_db:
            init_db = self._db_url_or_db or db
            init_db = DatabaseManager.build_from(init_db, base=Model)

        if self._try_create_tables and self._not_create_table:
            try:
                init_db.create_all()
            except Exception as e:
                # Best effort: a failure here must not stop start-up.
                logger.warning(f"Failed to create tables: {e}")
            finally:
                # Only ever attempt table creation once per component.
                self._not_create_table = False
        return init_db

View File

@@ -8,15 +8,6 @@ from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.util import AppConfig
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
def create_system_app(param: Dict) -> SystemApp:
app_config = param.get("app_config", {})
@@ -24,7 +15,17 @@ def create_system_app(param: Dict) -> SystemApp:
app_config = AppConfig(configs=app_config)
elif not isinstance(app_config, AppConfig):
raise RuntimeError("app_config must be AppConfig or dict")
return SystemApp(app, app_config)
test_app = FastAPI()
test_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
return SystemApp(test_app, app_config)
@pytest_asyncio.fixture
@@ -51,9 +52,12 @@ async def client(request, asystem_app: SystemApp):
del param["api_keys"]
if client_api_key:
headers["Authorization"] = "Bearer " + client_api_key
async with AsyncClient(app=app, base_url=base_url, headers=headers) as client:
test_app = asystem_app.app
async with AsyncClient(app=test_app, base_url=base_url, headers=headers) as client:
for router in routers:
app.include_router(router)
test_app.include_router(router)
if app_caller:
app_caller(app, asystem_app)
app_caller(test_app, asystem_app)
yield client

View File

@@ -43,6 +43,15 @@ class ServeRequest(BaseModel):
"You are a data analysis expert.",
],
)
prompt_desc: Optional[str] = Field(
None,
description="The prompt description.",
examples=[
"This is a prompt for code assistant.",
"This is a prompt for joker.",
"This is a prompt for data analysis expert.",
],
)
user_name: Optional[str] = Field(
None,

View File

@@ -48,6 +48,7 @@ class ServeEntity(Model):
default="f-string",
comment="Prompt format(eg: f-string, jinja2)",
)
prompt_desc = Column(String(512), nullable=True, comment="Prompt description")
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -96,6 +97,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
prompt_type=entity.prompt_type,
prompt_name=entity.prompt_name,
content=entity.content,
prompt_desc=entity.prompt_desc,
user_name=entity.user_name,
sys_code=entity.sys_code,
)
@@ -119,6 +121,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
prompt_type=entity.prompt_type,
prompt_name=entity.prompt_name,
content=entity.content,
prompt_desc=entity.prompt_desc,
user_name=entity.user_name,
sys_code=entity.sys_code,
gmt_created=gmt_created_str,

View File

@@ -3,10 +3,11 @@ from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.component import SystemApp
from dbgpt.core import PromptManager
from ...storage.metadata import DatabaseManager
from dbgpt.storage.metadata import DatabaseManager
from dbgpt.serve.core import BaseServe
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
@@ -20,7 +21,7 @@ from .models.prompt_template_adapter import PromptTemplateAdapter
logger = logging.getLogger(__name__)
class Serve(BaseComponent):
class Serve(BaseServe):
"""Serve component
Examples:
@@ -37,6 +38,7 @@ class Serve(BaseComponent):
app = FastAPI()
system_app = SystemApp(app)
system_app.register(Serve, api_prefix="/api/v1/prompt")
system_app.on_init()
# Run before start hook
system_app.before_start()
@@ -61,6 +63,7 @@ class Serve(BaseComponent):
app = FastAPI()
system_app = SystemApp(app)
system_app.register(Serve, api_prefix="/api/v1/prompt", db_url_or_db="sqlite:///:memory:", try_create_tables=True)
system_app.on_init()
# Run before start hook
system_app.before_start()
@@ -81,31 +84,41 @@ class Serve(BaseComponent):
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
tags: Optional[List[str]] = None,
api_tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if tags is None:
tags = [SERVE_APP_NAME_HUMP]
self._system_app = None
self._api_prefix = api_prefix
self._tags = tags
if api_tags is None:
api_tags = [SERVE_APP_NAME_HUMP]
super().__init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._prompt_manager = None
self._db_url_or_db = db_url_or_db
self._try_create_tables = try_create_tables
self._db_manager: Optional[DatabaseManager] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._tags
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
@property
def prompt_manager(self) -> PromptManager:
"""Get the prompt manager of the serve app with db storage"""
return self._prompt_manager
def on_init(self):
"""Called before the start of the application.
You can do some initialization here.
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity
def before_start(self):
"""Called before the start of the application.
@@ -113,23 +126,16 @@ class Serve(BaseComponent):
"""
# import your own module here to ensure the module is loaded before the application starts
from dbgpt.core.interface.prompt import PromptManager
from dbgpt.storage.metadata import Model, db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from .models.models import ServeEntity
init_db = self._db_url_or_db or db
init_db = DatabaseManager.build_from(init_db, base=Model)
if self._try_create_tables:
try:
init_db.create_all()
except Exception as e:
logger.warning(f"Failed to create tables: {e}")
self._db_manager = self.create_or_get_db_manager()
storage_adapter = PromptTemplateAdapter()
serializer = JsonSerializer()
storage = SQLAlchemyStorage(
init_db,
self._db_manager,
ServeEntity,
storage_adapter,
serializer,

View File

@@ -1,6 +1,9 @@
from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from typing import List, Optional, Union
import logging
from dbgpt.component import SystemApp
from sqlalchemy import URL
from dbgpt.storage.metadata import DatabaseManager
from dbgpt.serve.core import BaseServe
from .api.endpoints import router, init_endpoints
from .config import (
SERVE_APP_NAME,
@@ -10,8 +13,10 @@ from .config import (
ServeConfig,
)
logger = logging.getLogger(__name__)
class Serve(BaseComponent):
class Serve(BaseServe):
"""Serve component for DB-GPT"""
name = SERVE_APP_NAME
@@ -20,25 +25,36 @@ class Serve(BaseComponent):
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
tags: Optional[List[str]] = None,
api_tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if tags is None:
tags = [SERVE_APP_NAME_HUMP]
self._system_app = None
self._api_prefix = api_prefix
self._tags = tags
if api_tags is None:
api_tags = [SERVE_APP_NAME_HUMP]
super().__init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._tags
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
def before_start(self):
"""Called before the start of the application.
def on_init(self):
"""Called when init the application.
You can do some initialization here.
You can do some initialization here. You can't get other components here because they may be not initialized yet
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity
def before_start(self):
"""Called before the start of the application."""
# TODO: Your code here
self._db_manager = self.create_or_get_db_manager()

View File

@@ -5,6 +5,7 @@ from dbgpt.storage.metadata.db_manager import (
create_model,
BaseModel,
)
from dbgpt.storage.metadata.db_factory import UnifiedDBManagerFactory
from dbgpt.storage.metadata._base_dao import BaseDao
__ALL__ = [
@@ -14,4 +15,5 @@ __ALL__ = [
"create_model",
"BaseModel",
"BaseDao",
"UnifiedDBManagerFactory",
]

View File

@@ -0,0 +1,21 @@
from dbgpt.component import SystemApp, BaseComponent, ComponentType
from .db_manager import DatabaseManager
class UnifiedDBManagerFactory(BaseComponent):
    """Component wrapper around a single, externally created ``DatabaseManager``.

    Its component ``name`` is ``ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY``.
    The factory never builds a manager itself; it only returns the instance it
    was constructed with, after checking that it is usable.
    """

    name = ComponentType.UNIFIED_METADATA_DB_MANAGER_FACTORY

    def __init__(self, system_app: SystemApp, db_manager: DatabaseManager):
        """Store the database manager that :meth:`create` will return.

        Args:
            system_app: The system app, forwarded to ``BaseComponent``.
            db_manager: The database manager to expose through this factory.
        """
        super().__init__(system_app)
        self._db_manager = db_manager

    def init_app(self, system_app: SystemApp):
        """Do nothing; this component needs no app-level initialization."""

    def create(self) -> DatabaseManager:
        """Return the wrapped database manager.

        Raises:
            RuntimeError: If no manager was supplied, or the supplied manager
                reports ``is_initialized`` as false.
        """
        manager = self._db_manager
        # A missing manager and an uninitialized one are reported identically.
        if not manager or not manager.is_initialized:
            raise RuntimeError("db_manager is not initialized")
        return manager

View File

@@ -183,6 +183,11 @@ class DatabaseManager:
"""Get the engine.""" ""
return self._engine
@property
def is_initialized(self) -> bool:
    """Whether the database manager is initialized.

    Returns:
        bool: True only when both ``_engine`` and ``_session`` are set
        (i.e. neither is ``None``).
    """
    if self._engine is None:
        return False
    return self._session is not None
@contextmanager
def session(self) -> Session:
"""Get the session with context manager.
@@ -200,7 +205,7 @@ class DatabaseManager:
RuntimeError: The database manager is not initialized.
Exception: Any exception.
"""
if not self._session:
if not self.is_initialized:
raise RuntimeError("The database manager is not initialized.")
session = self._session()
try:

View File

@@ -1,3 +1,4 @@
import os
from functools import cache
from typing import Any, Dict, Optional
@@ -33,3 +34,19 @@ class AppConfig:
prefix (str): The prefix of config
"""
return {k: v for k, v in self.configs.items() if k.startswith(prefix)}
def get_current_lang(self, default: Optional[str] = None) -> str:
    """Get current language

    The ``LANG`` environment variable is inspected first: when it is set and
    starts with ``"zh"`` the fallback language is ``"zh"``, otherwise the
    fallback is ``default``. That fallback is then passed as the second
    argument to ``self.get`` for the ``dbgpt.app.global.language`` key
    (presumably returned when the key is not configured).

    Args:
        default (Optional[str], optional): The default language if not found. Defaults to None.

    Returns:
        str: The language of user running environment
    """
    lang_from_env = os.getenv("LANG")
    if lang_from_env and lang_from_env.startswith("zh"):
        fallback_lang = "zh"
    else:
        fallback_lang = default
    return self.get("dbgpt.app.global.language", fallback_lang)

View File

@@ -44,11 +44,20 @@ async def _do_chat_completion(
decoded_line = line.split("data: ", 1)[1]
if decoded_line.lower().strip() != "[DONE]".lower():
obj = json.loads(decoded_line)
if obj["choices"][0]["delta"].get("content") is not None:
text = obj["choices"][0]["delta"].get("content")
if "error_code" in obj and obj["error_code"] != 0:
if caller:
await caller(text)
yield text
await caller(obj.get("text"))
yield obj.get("text")
else:
if (
"choices" in obj
and obj["choices"][0]["delta"].get("content")
is not None
):
text = obj["choices"][0]["delta"].get("content")
if caller:
await caller(text)
yield text
await asyncio.sleep(0.02)

View File

@@ -139,7 +139,7 @@ class FileSpanStorage(SpanStorage):
def _write_to_file(self, spans: List[Span]):
self._roll_over_if_needed()
with open(self.filename, "a") as file:
with open(self.filename, "a", encoding="utf8") as file:
for span in spans:
span_data = span.to_dict()
try:

View File

@@ -10,6 +10,14 @@ import asyncio
from dbgpt.configs.model_config import LOGDIR
# ``termcolor`` is optional: when it cannot be imported, fall back to a
# stand-in with the same name so callers can use ``colored`` unconditionally.
try:
    from termcolor import colored
except ImportError:

    def colored(x, *args, **kwargs):
        """Return ``x`` unchanged, ignoring any color/style arguments."""
        return x
server_error_msg = (
"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)

View File

@@ -0,0 +1,438 @@
"""AWEL: Data analyst assistant.
DB-GPT will automatically load and execute the current file after startup.
Examples:
.. code-block:: shell
# Run this file in your terminal with dev mode.
# First terminal
export OPENAI_API_KEY=xxx
export OPENAI_API_BASE=https://api.openai.com/v1
python examples/awel/simple_chat_history_example.py
Code fix command, return no streaming response
.. code-block:: shell
# Open a new terminal
# Second terminal
DBGPT_SERVER="http://127.0.0.1:5555"
MODEL="gpt-3.5-turbo"
# First round
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \
-H "Content-Type: application/json" -d '{
"command": "dbgpt_awel_data_analyst_code_fix",
"model": "gpt-3.5-turbo",
"stream": false,
"context": {
"conv_uid": "uuid_conv_copilot_1234",
"chat_mode": "chat_with_code"
},
"messages": "SELECT * FRM orders WHERE order_amount > 500;"
}'
"""
import logging
from functools import cache
from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
InMemoryStorage,
LLMClient,
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
PromptManager,
PromptTemplate,
StorageConversation,
StorageInterface,
)
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.core.operator import (
BufferedConversationMapperOperator,
LLMBranchOperator,
LLMOperator,
PostConversationOperator,
PostStreamingConversationOperator,
PreConversationOperator,
RequestBuildOperator,
StreamingLLMOperator,
)
from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator
from dbgpt.util.utils import colored
logger = logging.getLogger(__name__)
# Command names accepted in the request body's ``command`` field. The same
# values are used as ``prompt_name`` when the templates below are registered.
CODE_FIX = "dbgpt_awel_data_analyst_code_fix"
CODE_PERF = "dbgpt_awel_data_analyst_code_perf"
CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain"
CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment"
CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate"

# System-prompt templates, one Chinese (``_ZH``) and one English (``_EN``)
# variant per command. All but the translate templates take a single
# ``{language}`` placeholder.

# Code fix: review the code, fix errors and explain each correction.
CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,
这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。"""
CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
here is a snippet of code of {language}. Please review the code following best practices to identify and fix all errors.
Provide the corrected code and include a line-by-line explanation of all the fixes you've made, please use the same language as the user."""

# Code performance: optimize the code and explain each optimization.
CODE_PERF_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,这里有一段 {language} 代码。
请你按照最佳实践来优化这段代码。请在代码中加入注释点明所做的更改,并解释每项优化的原因,以便提高代码的维护性和性能,请使用和用户相同的语言进行回答。"""
CODE_PERF_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
you are provided with a snippet of code of {language}. Please optimize the code according to best practices.
Include comments to highlight the changes made and explain the reasons for each optimization for better maintenance and performance,
please use the same language as the user."""

# Code explain: explain the code line by line.
CODE_EXPLAIN_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,
现在给你的是一份 {language} 代码。请你逐行解释代码的含义,请使用和用户相同的语言进行回答。"""
CODE_EXPLAIN_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
you are provided with a snippet of code of {language}. Please explain the meaning of the code line by line,
please use the same language as the user."""

# Code comment: add explanatory comments to each line.
CODE_COMMENT_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,现在给你的是一份 {language} 代码。
请你为每一行代码添加注释,解释每个部分的作用,请使用和用户相同的语言进行回答。"""
CODE_COMMENT_TEMPLATE_EN = """As an experienced Data Warehouse Developer and Data Analyst.
Below is a snippet of code written in {language}.
Please provide line-by-line comments explaining what each section of the code does, please use the same language as the user."""

# Code translate: takes ``{source_language}`` and ``{target_language}``
# placeholders instead of ``{language}``.
CODE_TRANSLATE_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,现在手头有一份用{source_language}语言编写的代码片段。
请你将这段代码准确无误地翻译成{target_language}语言,确保语法和功能在翻译后的代码中得到正确体现,请使用和用户相同的语言进行回答。"""
CODE_TRANSLATE_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
you're presented with a snippet of code written in {source_language}.
Please translate this code into {target_language} ensuring that the syntax and functionalities are accurately reflected in the translated code,
please use the same language as the user."""
class ReqContext(BaseModel):
    """Optional context attached to a copilot request.

    Every field is optional; only ``chat_mode`` has a non-``None`` default.
    """

    user_name: Optional[str] = Field(
        default=None, description="The user name of the model request."
    )
    sys_code: Optional[str] = Field(
        default=None, description="The system code of the model request."
    )
    conv_uid: Optional[str] = Field(
        default=None, description="The conversation uid of the model request."
    )
    chat_mode: Optional[str] = Field(
        default="chat_with_code", description="The chat mode of the model request."
    )
class TriggerReqBody(BaseModel):
    """Request body accepted by the data analyst copilot HTTP trigger.

    ``command`` selects which prompt is used and must be one of the
    ``dbgpt_awel_data_analyst_code_*`` command names (``CODE_FIX``,
    ``CODE_PERF``, ``CODE_EXPLAIN``, ``CODE_COMMENT``, ``CODE_TRANSLATE``).
    ``language`` fills the ``{language}`` placeholder of the prompt, or
    ``{source_language}`` for the translate command, in which case
    ``target_language`` fills ``{target_language}``.
    """

    messages: str = Field(..., description="User input messages")
    # The default must be a registered prompt name. The previous default "fix"
    # matched none of them, so omitting ``command`` always failed the lookup.
    command: Optional[str] = Field(default=CODE_FIX, description="Command name")
    model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name")
    stream: Optional[bool] = Field(default=False, description="Whether return stream")
    language: Optional[str] = Field(default="hive", description="Language")
    target_language: Optional[str] = Field(
        default="hive", description="Target language, use in translate"
    )
    context: Optional[ReqContext] = Field(
        default=None, description="The context of the model request."
    )
@cache
def load_or_save_prompt_template(pm: PromptManager):
    """Register every copilot prompt template with ``pm`` via ``query_or_save``.

    Each command is registered twice, first in Chinese (``"zh"``) and then in
    English (``"en"``). Because of ``@cache`` the body runs only once per
    distinct ``pm`` argument.

    Args:
        pm: The prompt manager that receives the templates.
    """
    shared_params = {
        "chat_scene": "chat_with_code",
        "sub_chat_scene": "data_analyst",
        "prompt_type": "common",
    }
    language_only = ("language",)
    translate_vars = ("source_language", "target_language")

    # (prompt name, prompt language, template text, input variables),
    # listed in registration order.
    registrations = (
        (CODE_FIX, "zh", CODE_FIX_TEMPLATE_ZH, language_only),
        (CODE_FIX, "en", CODE_FIX_TEMPLATE_EN, language_only),
        (CODE_PERF, "zh", CODE_PERF_TEMPLATE_ZH, language_only),
        (CODE_PERF, "en", CODE_PERF_TEMPLATE_EN, language_only),
        (CODE_EXPLAIN, "zh", CODE_EXPLAIN_TEMPLATE_ZH, language_only),
        (CODE_EXPLAIN, "en", CODE_EXPLAIN_TEMPLATE_EN, language_only),
        (CODE_COMMENT, "zh", CODE_COMMENT_TEMPLATE_ZH, language_only),
        (CODE_COMMENT, "en", CODE_COMMENT_TEMPLATE_EN, language_only),
        (CODE_TRANSLATE, "zh", CODE_TRANSLATE_TEMPLATE_ZH, translate_vars),
        (CODE_TRANSLATE, "en", CODE_TRANSLATE_TEMPLATE_EN, translate_vars),
    )
    for name, lang, template_text, variables in registrations:
        pm.query_or_save(
            PromptTemplate(
                # A fresh list per template, so no two templates share one.
                input_variables=list(variables),
                template=template_text,
            ),
            prompt_name=name,
            prompt_language=lang,
            **shared_params,
        )
class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
    """Turn a copilot request into system + human messages for the model.

    The system message comes from the prompt registered for the requested
    command, formatted with the request's language parameters.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Used only when no prompt serve component is found in the system app.
        self._default_prompt_manager = PromptManager()

    async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
        """Build the model request dict for ``input_value``.

        Returns:
            A dict with the keys ``messages``, ``stream``, ``model`` and
            ``context``.

        Raises:
            ValueError: If no prompt is found for the requested command.
        """
        from dbgpt.serve.prompt.serve import SERVE_APP_NAME as PROMPT_SERVE_APP_NAME
        from dbgpt.serve.prompt.serve import Serve as PromptServe

        prompt_serve = self.system_app.get_component(
            PROMPT_SERVE_APP_NAME, PromptServe, default_component=None
        )
        pm = (
            prompt_serve.prompt_manager
            if prompt_serve
            else self._default_prompt_manager
        )
        load_or_save_prompt_template(pm)

        user_language = self.system_app.config.get_current_lang(default="en")
        candidates = pm.prefer_query(
            input_value.command, prefer_prompt_language=user_language
        )
        if not candidates:
            error_msg = f"Prompt not found for command {input_value.command}, user_language: {user_language}"
            logger.error(error_msg)
            raise ValueError(error_msg)
        template = candidates[0].to_prompt_template()

        # The translate prompt uses a source/target pair; all others use
        # a single ``language`` placeholder.
        format_params = (
            {
                "source_language": input_value.language,
                "target_language": input_value.target_language,
            }
            if input_value.command == CODE_TRANSLATE
            else {"language": input_value.language}
        )
        system_message = template.format(**format_params)

        if input_value.context:
            context = input_value.context.dict()
        else:
            context = {}

        return {
            "messages": [
                ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_message),
                ModelMessage(
                    role=ModelMessageRoleType.HUMAN, content=input_value.messages
                ),
            ],
            "stream": input_value.stream,
            "model": input_value.model,
            "context": context,
        }
class MyConversationOperator(PreConversationOperator):
    """Pre-conversation operator that resolves its storages lazily.

    Explicitly supplied storages win. Otherwise the conversation serve
    component's storages are used when that component is available, and an
    ``InMemoryStorage`` is created and kept as a last resort.
    """

    def __init__(
        self,
        storage: Optional[StorageInterface[StorageConversation, Any]] = None,
        message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
        **kwargs,
    ):
        super().__init__(storage, message_storage, **kwargs)

    def _get_conversion_serve(self):
        """Look up the conversation serve component, or ``None`` if absent."""
        from dbgpt.serve.conversation.serve import (
            SERVE_APP_NAME as CONVERSATION_SERVE_APP_NAME,
        )
        from dbgpt.serve.conversation.serve import Serve as ConversationServe

        serve: ConversationServe = self.system_app.get_component(
            CONVERSATION_SERVE_APP_NAME, ConversationServe, default_component=None
        )
        return serve

    @property
    def storage(self):
        """Conversation storage used by this operator."""
        if self._storage:
            return self._storage
        serve = self._get_conversion_serve()
        if serve:
            return serve.conv_storage
        logger.info("Conversation storage not found, use InMemoryStorage default")
        self._storage = InMemoryStorage()
        return self._storage

    @property
    def message_storage(self):
        """Message storage used by this operator."""
        if self._message_storage:
            return self._message_storage
        serve = self._get_conversion_serve()
        if serve:
            return serve.message_storage
        logger.info("Message storage not found, use InMemoryStorage default")
        self._message_storage = InMemoryStorage()
        return self._message_storage
class MyLLMOperator(MixinLLMOperator, LLMOperator):
    """LLM operator combining ``MixinLLMOperator`` with ``LLMOperator``."""

    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        # Two explicit initializer calls: ``super().__init__`` receives only
        # the client, then ``LLMOperator.__init__`` also receives ``kwargs``
        # (e.g. ``task_name``).
        # NOTE(review): presumably the mixin's __init__ does not chain to
        # LLMOperator itself; confirm before reordering or merging these calls.
        super().__init__(llm_client)
        LLMOperator.__init__(self, llm_client, **kwargs)
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
    """LLM operator combining ``MixinLLMOperator`` with ``StreamingLLMOperator``."""

    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
        # Two explicit initializer calls: ``super().__init__`` receives only
        # the client, then ``StreamingLLMOperator.__init__`` also receives
        # ``kwargs`` (e.g. ``task_name``).
        # NOTE(review): presumably the mixin's __init__ does not chain to
        # StreamingLLMOperator itself; confirm before reordering these calls.
        super().__init__(llm_client)
        StreamingLLMOperator.__init__(self, llm_client, **kwargs)
def history_message_mapper(
    messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
    """Flatten the per-round history into a single message list.

    Only the first system message is kept; later system messages are dropped.
    All other messages are kept in their original order. The resulting list
    is also printed to stdout between two marker lines.

    Args:
        messages_by_round: History messages grouped by conversation round.

    Returns:
        The flattened messages with at most one system message.
    """
    kept: List[ModelMessage] = []
    seen_system = False
    for round_messages in messages_by_round:
        for msg in round_messages:
            is_system = msg.role == ModelMessageRoleType.SYSTEM
            if is_system and seen_system:
                # A system message was already kept; skip the duplicate.
                continue
            if is_system:
                seen_system = True
            kept.append(msg)

    print("history_message_mapper start:" + "=" * 70)
    print(colored(ModelMessage.get_printable_message(kept), "green"))
    print("history_message_mapper end:" + "=" * 72)
    return kept
# Data analyst assistant DAG: an HTTP trigger feeds a shared pre-processing
# chain, which branches into a non-streaming and a streaming LLM path. Both
# paths end in the same join node.
with DAG("dbgpt_awel_data_analyst_assistant") as dag:
    # POST endpoint; the request body's ``stream`` field is passed as the
    # streaming predicate.
    trigger = HttpTrigger(
        "/examples/data_analyst/copilot",
        request_body=TriggerReqBody,
        methods="POST",
        streaming_predict_func=lambda x: x.stream,
    )

    # Builds the system/human messages from the request body.
    copilot_task = CopilotOperator()
    request_handle_task = RequestBuildOperator()

    # Pre-process conversation
    pre_conversation_task = MyConversationOperator()
    # Keep last k round conversation.
    history_conversation_task = BufferedConversationMapperOperator(
        last_k_round=5, message_mapper=history_message_mapper
    )

    # Save conversation to storage.
    post_conversation_task = PostConversationOperator()
    # Save streaming conversation to storage.
    post_streaming_conversation_task = PostStreamingConversationOperator()

    # Use LLMOperator to generate response.
    llm_task = MyLLMOperator(task_name="llm_task")
    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
    # The task names given here must match the ``task_name`` values above.
    branch_task = LLMBranchOperator(
        stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
    )
    # Converts the non-streaming model output to a dict via ``to_dict()``.
    model_parse_task = MapOperator(lambda out: out.to_dict())
    openai_format_stream_task = OpenAIStreamingOperator()
    # Returns the non-streaming output when it is truthy, otherwise the
    # streaming output (presumably only one branch yields a value per request).
    result_join_task = JoinOperator(
        combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
    )

    # Shared pre-processing chain, ending at the branch node.
    (
        trigger
        >> copilot_task
        >> request_handle_task
        >> pre_conversation_task
        >> history_conversation_task
        >> branch_task
    )
    # The branch of no streaming response.
    (
        branch_task
        >> llm_task
        >> post_conversation_task
        >> model_parse_task
        >> result_join_task
    )
    # The branch of streaming response.
    (
        branch_task
        >> streaming_llm_task
        >> post_streaming_conversation_task
        >> openai_format_stream_task
        >> result_join_task
    )
if __name__ == "__main__":
    # When run as a script, start the local dev environment only if the
    # DAG's first leaf node is in dev mode; otherwise do nothing.
    if dag.leaf_nodes[0].dev_mode:
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag])

View File

@@ -33,12 +33,15 @@
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"execution_count": null,
"metadata": {
"is_executing": true
},
"outputs": [],
"source": [
"os.environ[\"OPENAI_API_KEY\"] = \"sk-ElhG3036tcvECTOYO9nHqAkUIWtqT55JXs1cTjODz1bnUQDz\"\n",
"os.environ[\"OPENAI_API_BASE\"] = \"https://api.chatanywhere.tech/v1\""
"# Set your api key and api base url\n",
"# os.environ[\"OPENAI_API_KEY\"] = \"Your API\"\n",
"# os.environ[\"OPENAI_API_BASE\"] = \"https://api.openai.com/v1\""
]
},
{
@@ -50,12 +53,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mUser\u001b[0m (to plan_manager)-[]:\n",
"\u001B[33mUser\u001B[0m (to plan_manager)-[]:\n",
"\n",
"\"Obtain simple information about issues in the repository 'eosphoros-ai/DB-GPT' in the past three days and analyze the data. Create a Markdown table grouped by day and status.\"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mplan_manager\u001b[0m (to Planner)-[]:\n",
"\u001B[33mplan_manager\u001B[0m (to Planner)-[]:\n",
"\n",
"\"Obtain simple information about issues in the repository 'eosphoros-ai/DB-GPT' in the past three days and analyze the data. Create a Markdown table grouped by day and status.\"\n",
"\n",
@@ -72,15 +75,15 @@
" \"content\": \"Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\",\n",
" \"rely\": \"1\"\n",
"}]\n",
"\u001b[33mPlanner\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\u001B[33mPlanner\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\n",
"\"[{\\n \\\"serial_number\\\": \\\"1\\\",\\n \\\"agent\\\": \\\"CodeEngineer\\\",\\n \\\"content\\\": \\\"Write a Python script to use the GitHub API to retrieve issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. The script should extract the issue's creation date and status.\\\",\\n \\\"rely\\\": \\\"\\\"\\n},\\n{\\n \\\"serial_number\\\": \\\"2\\\",\\n \\\"agent\\\": \\\"CodeEngineer\\\",\\n \\\"content\\\": \\\"Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\\\",\\n \\\"rely\\\": \\\"1\\\"\\n}]\"\n",
"\u001b[32m>>>>>>>>Planner Review info: \n",
" Pass.None\u001b[0m\n",
"\u001b[34m>>>>>>>>Planner Action report: \n",
"\u001B[32m>>>>>>>>Planner Review info: \n",
" Pass.None\u001B[0m\n",
"\u001B[34m>>>>>>>>Planner Action report: \n",
"execution succeeded,\n",
"1,Write a Python script to use the GitHub API to retrieve issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. The script should extract the issue's creation date and status.\n",
"2,Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\u001b[0m\n",
"2,Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\u001B[0m\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@@ -96,7 +99,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n",
"\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n",
"\n",
"\"Write a Python script to use the GitHub API to retrieve issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. The script should extract the issue's creation date and status.\"\n",
"\n",
@@ -141,14 +144,14 @@
"```\n",
"\n",
"When you run this script, it will print out the creation date and status of issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. Make sure to have internet access while running the script as it needs to make requests to the GitHub API.\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\u001B[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001B[0m\n",
"\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\n",
"\"```python\\n# filename: github_issues.py\\nimport requests\\nfrom datetime import datetime, timedelta\\n\\n# Function to get issues created in the past three days from a GitHub repository\\ndef get_recent_issues(owner, repo):\\n # Calculate the date for three days ago\\n three_days_ago = datetime.utcnow() - timedelta(days=3)\\n # Format the date in ISO 8601 format\\n since = three_days_ago.isoformat()\\n\\n # GitHub API URL for issues of the given repository\\n issues_url = f\\\"https://api.github.com/repos/{owner}/{repo}/issues\\\"\\n \\n # Parameters to get issues created in the past three days\\n params = {\\n 'since': since,\\n 'state': 'all', # Get both open and closed issues\\n }\\n \\n # Send a request to the GitHub API\\n response = requests.get(issues_url, params=params)\\n # Raise an exception if the request was unsuccessful\\n response.raise_for_status()\\n \\n # Parse the JSON response\\n issues = response.json()\\n \\n # Extract and print the creation date and status of each issue\\n for issue in issues:\\n created_at = issue['created_at']\\n state = issue['state']\\n print(f\\\"Issue created at: {created_at}, Status: {state}\\\")\\n\\n# Replace 'owner' and 'repo' with the actual owner and repository name\\nget_recent_issues('eosphoros-ai', 'DB-GPT')\\n```\\n\\nWhen you run this script, it will print out the creation date and status of issues from the 'eosphoros-ai/DB-GPT' repository that were created in the past three days. Make sure to have internet access while running the script as it needs to make requests to the GitHub API.\"\n",
"\u001b[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001b[0m\n",
"\u001b[34m>>>>>>>>CodeEngineer Action report: \n",
"\u001B[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001B[0m\n",
"\u001B[34m>>>>>>>>CodeEngineer Action report: \n",
"execution failed,\n",
"exitcode: 1 (execution failed)\n",
" \n",
@@ -156,7 +159,7 @@
" File \"github_issues.py\", line 2, in <module>\n",
" import requests\n",
"ModuleNotFoundError: No module named 'requests'\n",
"\u001b[0m\n",
"\u001B[0m\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@@ -172,7 +175,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n",
"\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n",
"\n",
"\"exitcode: 1 (execution failed)\\n \\nTraceback (most recent call last):\\n File \\\"github_issues.py\\\", line 2, in <module>\\n import requests\\nModuleNotFoundError: No module named 'requests'\\n\"\n",
"\n",
@@ -225,17 +228,17 @@
"```\n",
"\n",
"Please run the above Python script after you have installed the `requests` module.\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is sh)...\u001b[0m\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 1 (inferred language is python)...\u001b[0m\n",
"\u001B[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is sh)...\u001B[0m\n",
"\u001B[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 1 (inferred language is python)...\u001B[0m\n",
"un_stream ai response: False\n",
"\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\n",
"\"It appears that the `requests` module is not installed on your system. You can install it using `pip`, the Python package manager. Here's the command to install the `requests` module:\\n\\n```sh\\npip install requests\\n```\\n\\nAfter installing the `requests` module, you can try running the script again. Here is the complete script including the installation of the `requests` module:\\n\\n```python\\n# filename: github_issues.py\\nimport requests\\nfrom datetime import datetime, timedelta\\n\\n# Function to get issues created in the past three days from a GitHub repository\\ndef get_recent_issues(owner, repo):\\n # Calculate the date for three days ago\\n three_days_ago = datetime.utcnow() - timedelta(days=3)\\n # Format the date in ISO 8601 format\\n since = three_days_ago.isoformat()\\n\\n # GitHub API URL for issues of the given repository\\n issues_url = f\\\"https://api.github.com/repos/{owner}/{repo}/issues\\\"\\n \\n # Parameters to get issues created in the past three days\\n params = {\\n 'since': since,\\n 'state': 'all', # Get both open and closed issues\\n }\\n \\n # Send a request to the GitHub API\\n response = requests.get(issues_url, params=params)\\n # Raise an exception if the request was unsuccessful\\n response.raise_for_status()\\n \\n # Parse the JSON response\\n issues = response.json()\\n \\n # Extract and print the creation date and status of each issue\\n for issue in issues:\\n created_at = issue['created_at']\\n state = issue['state']\\n print(f\\\"Issue created at: {created_at}, Status: {state}\\\")\\n\\n# Replace 'owner' and 'repo' with the actual owner and repository name\\nget_recent_issues('eosphoros-ai', 'DB-GPT')\\n```\\n\\nPlease run the above Python script after you have installed the `requests` module.\"\n",
"\u001b[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001b[0m\n",
"\u001b[34m>>>>>>>>CodeEngineer Action report: \n",
"\u001B[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001B[0m\n",
"\u001B[34m>>>>>>>>CodeEngineer Action report: \n",
"execution succeeded,\n",
"\n",
"Collecting requests\n",
@@ -302,7 +305,7 @@
"Issue created at: 2023-10-26T01:45:37Z, Status: open\n",
"Issue created at: 2023-08-30T06:02:32Z, Status: closed\n",
"Issue created at: 2023-06-18T01:11:22Z, Status: open\n",
"\u001b[0m\n",
"\u001B[0m\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@@ -318,7 +321,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n",
"\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n",
"\n",
"\"The execution result of the code you wrote is judged as not answering the task question. Please re-understand and complete the task.\"\n",
"\n",
@@ -376,15 +379,15 @@
"```\n",
"\n",
"When you run this script, it will check for the `requests` module and install it if necessary. Then it will retrieve the issues from the 'eosphoros-ai/DB-GPT' repository created in the past three days and print their creation date and status.\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001B[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001B[0m\n",
"un_stream ai response: False\n",
"\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\n",
"\"I apologize for the inconvenience. Let me provide you with a Python script that includes a check to ensure the `requests` module is installed. If it's not installed, the script will attempt to install it using `pip`. Please run the following Python script:\\n\\n```python\\n# filename: github_issues.py\\nimport subprocess\\nimport pkg_resources\\nfrom datetime import datetime, timedelta\\n\\n# Ensure the 'requests' package is installed\\nrequired_packages = {'requests'}\\ninstalled_packages = {pkg.key for pkg in pkg_resources.working_set}\\nmissing_packages = required_packages - installed_packages\\n\\nif missing_packages:\\n subprocess.check_call(['pip', 'install', *missing_packages])\\n\\nimport requests\\n\\n# Function to get issues created in the past three days from a GitHub repository\\ndef get_recent_issues(owner, repo):\\n # Calculate the date for three days ago\\n three_days_ago = datetime.utcnow() - timedelta(days=3)\\n # Format the date in ISO 8601 format\\n since = three_days_ago.isoformat()\\n\\n # GitHub API URL for issues of the given repository\\n issues_url = f\\\"https://api.github.com/repos/{owner}/{repo}/issues\\\"\\n \\n # Parameters to get issues created in the past three days\\n params = {\\n 'since': since,\\n 'state': 'all', # Get both open and closed issues\\n }\\n \\n # Send a request to the GitHub API\\n response = requests.get(issues_url, params=params)\\n # Raise an exception if the request was unsuccessful\\n response.raise_for_status()\\n \\n # Parse the JSON response\\n issues = response.json()\\n \\n # Extract and print the creation date and status of each issue\\n for issue in issues:\\n created_at = issue['created_at']\\n state = issue['state']\\n print(f\\\"Issue created at: {created_at}, Status: {state}\\\")\\n\\n# Replace 'owner' and 'repo' with the actual owner and repository name\\nget_recent_issues('eosphoros-ai', 'DB-GPT')\\n```\\n\\nWhen you run this script, it will check for the `requests` module and install it if 
necessary. Then it will retrieve the issues from the 'eosphoros-ai/DB-GPT' repository created in the past three days and print their creation date and status.\"\n",
"\u001b[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001b[0m\n",
"\u001b[34m>>>>>>>>CodeEngineer Action report: \n",
"\u001B[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001B[0m\n",
"\u001B[34m>>>>>>>>CodeEngineer Action report: \n",
"execution succeeded,\n",
"\n",
"/workspace/github_issues.py:3: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html\n",
@@ -421,7 +424,7 @@
"Issue created at: 2023-10-26T01:45:37Z, Status: open\n",
"Issue created at: 2023-08-30T06:02:32Z, Status: closed\n",
"Issue created at: 2023-06-18T01:11:22Z, Status: open\n",
"\u001b[0m\n",
"\u001B[0m\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@@ -437,7 +440,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n",
"\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n",
"\n",
"\"The execution result of the code you wrote is judged as not answering the task question. Please re-understand and complete the task.\"\n",
"\n",
@@ -484,15 +487,15 @@
"```\n",
"\n",
"This script retrieves all issues from the specified repository and then filters them based on the creation date to find the ones that were created in the past three days. It then prints the creation date and status of each issue that meets the criteria. Please execute this updated script.\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001B[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001B[0m\n",
"un_stream ai response: True\n",
"\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\n",
"\"I apologize for the oversight. Let's try a different approach to ensure we get the issues created in the past three days from the 'eosphoros-ai/DB-GPT' repository. I will provide you with an updated Python script that correctly retrieves and filters the issues based on their creation date.\\n\\n```python\\n# filename: github_issues.py\\nimport requests\\nfrom datetime import datetime, timedelta\\n\\n# Function to get issues created in the past three days from a GitHub repository\\ndef get_recent_issues(owner, repo):\\n # Calculate the date for three days ago\\n three_days_ago = datetime.utcnow() - timedelta(days=3)\\n # Format the date in ISO 8601 format\\n since = three_days_ago.isoformat()\\n\\n # GitHub API URL for issues of the given repository\\n issues_url = f\\\"https://api.github.com/repos/{owner}/{repo}/issues\\\"\\n\\n # Parameters to get all issues, GitHub API does not allow filtering by creation date directly\\n params = {\\n 'state': 'all', # Get both open and closed issues\\n }\\n\\n # Send a request to the GitHub API\\n response = requests.get(issues_url, params=params)\\n # Raise an exception if the request was unsuccessful\\n response.raise_for_status()\\n\\n # Parse the JSON response\\n issues = response.json()\\n\\n # Filter and print the creation date and status of each issue created in the past three days\\n for issue in issues:\\n created_at = datetime.strptime(issue['created_at'], \\\"%Y-%m-%dT%H:%M:%SZ\\\")\\n if created_at > three_days_ago:\\n state = issue['state']\\n print(f\\\"Issue created at: {created_at}, Status: {state}\\\")\\n\\n# Replace 'owner' and 'repo' with the actual owner and repository name\\nget_recent_issues('eosphoros-ai', 'DB-GPT')\\n```\\n\\nThis script retrieves all issues from the specified repository and then filters them based on the creation date to find the ones that were created in the past three days. It then prints the creation date and status of each issue that meets the criteria. 
Please execute this updated script.\"\n",
"\u001b[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001b[0m\n",
"\u001b[34m>>>>>>>>CodeEngineer Action report: \n",
"\u001B[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001B[0m\n",
"\u001B[34m>>>>>>>>CodeEngineer Action report: \n",
"execution succeeded,\n",
"\n",
"/workspace/github_issues.py:8: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n",
@@ -510,7 +513,7 @@
"Issue created at: 2023-12-25 08:35:48, Status: open\n",
"Issue created at: 2023-12-25 08:34:13, Status: closed\n",
"Issue created at: 2023-12-25 06:33:37, Status: closed\n",
"\u001b[0m\n",
"\u001B[0m\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@@ -526,7 +529,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mplan_manager\u001b[0m (to CodeEngineer)-[]:\n",
"\u001B[33mplan_manager\u001B[0m (to CodeEngineer)-[]:\n",
"\n",
"\"Read the result data of the dependent steps in the above historical message to complete the current goal:Write a Python script to process the retrieved issues data to group them by creation date and status. Then, format the grouped data into a Markdown table.\"\n",
"\n",
@@ -573,15 +576,15 @@
"```\n",
"\n",
"When you run this script, it will print out a Markdown table with the number of open and closed issues grouped by their creation dates.\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001B[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001B[0m\n",
"un_stream ai response: True\n",
"\u001b[33mCodeEngineer\u001b[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\u001B[33mCodeEngineer\u001B[0m (to plan_manager)-[gpt-4-vision-preview]:\n",
"\n",
"\"```python\\n# Let's process the retrieved issues data to group them by creation date and status.\\n# Then, we will format the grouped data into a Markdown table.\\n\\n# The data from the previous step\\nissues_data = [\\n {\\\"created_at\\\": \\\"2023-12-27 03:20:12\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 16:13:42\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 10:23:43\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 07:25:37\\\", \\\"status\\\": \\\"closed\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 05:01:26\\\", \\\"status\\\": \\\"closed\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 03:54:15\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-26 02:19:48\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 12:26:40\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 12:00:52\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 08:36:56\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 08:35:48\\\", \\\"status\\\": \\\"open\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 08:34:13\\\", \\\"status\\\": \\\"closed\\\"},\\n {\\\"created_at\\\": \\\"2023-12-25 06:33:37\\\", \\\"status\\\": \\\"closed\\\"},\\n]\\n\\n# Grouping the issues by creation date and status\\nfrom collections import defaultdict\\ngrouped_issues = defaultdict(lambda: defaultdict(list))\\n\\nfor issue in issues_data:\\n date = issue[\\\"created_at\\\"].split(' ')[0] # Extract just the date part\\n grouped_issues[date][issue[\\\"status\\\"]].append(issue)\\n\\n# Format the grouped data into a Markdown table\\nmarkdown_table = \\\"Date | Open Issues | Closed Issues\\\\n\\\"\\nmarkdown_table += \\\"--- | --- | ---\\\\n\\\"\\n\\nfor date, statuses in sorted(grouped_issues.items()):\\n open_issues = len(statuses[\\\"open\\\"]) if \\\"open\\\" in statuses else 0\\n closed_issues = 
len(statuses[\\\"closed\\\"]) if \\\"closed\\\" in statuses else 0\\n markdown_table += f\\\"{date} | {open_issues} | {closed_issues}\\\\n\\\"\\n\\nprint(markdown_table)\\n```\\n\\nWhen you run this script, it will print out a Markdown table with the number of open and closed issues grouped by their creation dates.\"\n",
"\u001b[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001b[0m\n",
"\u001b[34m>>>>>>>>CodeEngineer Action report: \n",
"\u001B[32m>>>>>>>>CodeEngineer Review info: \n",
" Pass.None\u001B[0m\n",
"\u001B[34m>>>>>>>>CodeEngineer Action report: \n",
"execution succeeded,\n",
"\n",
"Date | Open Issues | Closed Issues\n",
@@ -590,17 +593,17 @@
"2023-12-26 | 4 | 2\n",
"2023-12-27 | 1 | 0\n",
"\n",
"\u001b[0m\n",
"\u001B[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mplan_manager\u001b[0m (to User)-[None]:\n",
"\u001B[33mplan_manager\u001B[0m (to User)-[None]:\n",
"\n",
"\"Obtain simple information about issues in the repository 'eosphoros-ai/DB-GPT' in the past three days and analyze the data. Create a Markdown table grouped by day and status.\"\n",
"\u001b[32m>>>>>>>>plan_manager Review info: \n",
" Pass.None\u001b[0m\n",
"\u001b[34m>>>>>>>>plan_manager Action report: \n",
"\u001B[32m>>>>>>>>plan_manager Review info: \n",
" Pass.None\u001B[0m\n",
"\u001B[34m>>>>>>>>plan_manager Action report: \n",
"execution succeeded,\n",
"TERMINATE\u001b[0m\n",
"TERMINATE\u001B[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"```agent-messages\n",

View File

@@ -369,6 +369,7 @@ def core_requires():
setup_spec.extras["simple_framework"] = setup_spec.extras["core"] + [
"pydantic<2,>=1",
"httpx",
"jinja2",
"fastapi==0.98.0",
"shortuuid",
# change from fixed version 2.0.22 to variable version, because other dependencies are >=1.4, such as pydoris is <2