mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 03:50:42 +00:00
feat(core): Support multi round conversation operator (#986)
This commit is contained in:
@@ -8,38 +8,36 @@ The stability of this API cannot be guaranteed at present.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from .dag.base import DAGContext, DAG
|
||||
|
||||
from .dag.base import DAG, DAGContext
|
||||
from .operator.base import BaseOperator, WorkflowRunner
|
||||
from .operator.common_operator import (
|
||||
JoinOperator,
|
||||
ReduceStreamOperator,
|
||||
MapOperator,
|
||||
BranchFunc,
|
||||
BranchOperator,
|
||||
InputOperator,
|
||||
BranchFunc,
|
||||
JoinOperator,
|
||||
MapOperator,
|
||||
ReduceStreamOperator,
|
||||
)
|
||||
|
||||
from .operator.stream_operator import (
|
||||
StreamifyAbsOperator,
|
||||
UnstreamifyAbsOperator,
|
||||
TransformStreamAbsOperator,
|
||||
UnstreamifyAbsOperator,
|
||||
)
|
||||
|
||||
from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
|
||||
from .runner.local_runner import DefaultWorkflowRunner
|
||||
from .task.base import InputContext, InputSource, TaskContext, TaskOutput, TaskState
|
||||
from .task.task_impl import (
|
||||
SimpleInputSource,
|
||||
SimpleCallDataInputSource,
|
||||
DefaultTaskContext,
|
||||
DefaultInputContext,
|
||||
SimpleTaskOutput,
|
||||
DefaultTaskContext,
|
||||
SimpleCallDataInputSource,
|
||||
SimpleInputSource,
|
||||
SimpleStreamTaskOutput,
|
||||
SimpleTaskOutput,
|
||||
_is_async_iterator,
|
||||
)
|
||||
from .trigger.http_trigger import HttpTrigger
|
||||
from .runner.local_runner import DefaultWorkflowRunner
|
||||
|
||||
__all__ = [
|
||||
"initialize_awel",
|
||||
@@ -73,16 +71,16 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def initialize_awel(system_app: SystemApp, dag_filepath: str):
|
||||
from .dag.dag_manager import DAGManager
|
||||
def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
|
||||
from .dag.base import DAGVar
|
||||
from .trigger.trigger_manager import DefaultTriggerManager
|
||||
from .dag.dag_manager import DAGManager
|
||||
from .operator.base import initialize_runner
|
||||
from .trigger.trigger_manager import DefaultTriggerManager
|
||||
|
||||
DAGVar.set_current_system_app(system_app)
|
||||
|
||||
system_app.register(DefaultTriggerManager)
|
||||
dag_manager = DAGManager(system_app, dag_filepath)
|
||||
dag_manager = DAGManager(system_app, dag_dirs)
|
||||
system_app.register_instance(dag_manager)
|
||||
initialize_runner(DefaultWorkflowRunner())
|
||||
# Load all dags
|
||||
@@ -90,7 +88,11 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str):
|
||||
|
||||
|
||||
def setup_dev_environment(
|
||||
dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
|
||||
dags: List[DAG],
|
||||
host: Optional[str] = "0.0.0.0",
|
||||
port: Optional[int] = 5555,
|
||||
logging_level: Optional[str] = None,
|
||||
logger_filename: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Setup a development environment for AWEL.
|
||||
|
||||
@@ -98,9 +100,16 @@ def setup_dev_environment(
|
||||
"""
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from .trigger.trigger_manager import DefaultTriggerManager
|
||||
from dbgpt.util.utils import setup_logging
|
||||
|
||||
from .dag.base import DAGVar
|
||||
from .trigger.trigger_manager import DefaultTriggerManager
|
||||
|
||||
if not logger_filename:
|
||||
logger_filename = "dbgpt_awel_dev.log"
|
||||
setup_logging("dbgpt", logging_level=logging_level, logger_filename=logger_filename)
|
||||
|
||||
app = FastAPI()
|
||||
system_app = SystemApp(app)
|
||||
|
@@ -1,15 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, List, Sequence, Union, Any, Set
|
||||
import uuid
|
||||
import contextvars
|
||||
import threading
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from functools import cache
|
||||
from concurrent.futures import Executor
|
||||
from functools import cache
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Union
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from ..resource.base import ResourceGroup
|
||||
from ..task.base import TaskContext, TaskOutput
|
||||
|
||||
@@ -502,6 +503,9 @@ class DAG:
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
DAGVar.exit_dag()
|
||||
|
||||
def __repr__(self):
|
||||
return f"DAG(dag_id={self.dag_id})"
|
||||
|
||||
|
||||
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
|
||||
nodes = set()
|
||||
|
@@ -1,8 +1,10 @@
|
||||
from typing import Dict, Optional
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from .loader import DAGLoader, LocalFileDAGLoader
|
||||
|
||||
from .base import DAG
|
||||
from .loader import LocalFileDAGLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,9 +12,9 @@ logger = logging.getLogger(__name__)
|
||||
class DAGManager(BaseComponent):
|
||||
name = ComponentType.AWEL_DAG_MANAGER
|
||||
|
||||
def __init__(self, system_app: SystemApp, dag_filepath: str):
|
||||
def __init__(self, system_app: SystemApp, dag_dirs: List[str]):
|
||||
super().__init__(system_app)
|
||||
self.dag_loader = LocalFileDAGLoader(dag_filepath)
|
||||
self.dag_loader = LocalFileDAGLoader(dag_dirs)
|
||||
self.system_app = system_app
|
||||
self.dag_map: Dict[str, DAG] = {}
|
||||
|
||||
|
@@ -1,10 +1,10 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
import os
|
||||
import hashlib
|
||||
import sys
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from .base import DAG
|
||||
|
||||
@@ -18,17 +18,19 @@ class DAGLoader(ABC):
|
||||
|
||||
|
||||
class LocalFileDAGLoader(DAGLoader):
|
||||
def __init__(self, filepath: str) -> None:
|
||||
super().__init__()
|
||||
self._filepath = filepath
|
||||
def __init__(self, dag_dirs: List[str]) -> None:
|
||||
self._dag_dirs = dag_dirs
|
||||
|
||||
def load_dags(self) -> List[DAG]:
|
||||
if not os.path.exists(self._filepath):
|
||||
return []
|
||||
if os.path.isdir(self._filepath):
|
||||
return _process_directory(self._filepath)
|
||||
else:
|
||||
return _process_file(self._filepath)
|
||||
dags = []
|
||||
for filepath in self._dag_dirs:
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
if os.path.isdir(filepath):
|
||||
dags += _process_directory(filepath)
|
||||
else:
|
||||
dags += _process_file(filepath)
|
||||
return dags
|
||||
|
||||
|
||||
def _process_directory(directory: str) -> List[DAG]:
|
||||
|
@@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
import threading
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from ..base import DAG, DAGVar
|
||||
|
||||
|
||||
|
@@ -1,32 +1,32 @@
|
||||
from abc import ABC, abstractmethod, ABCMeta
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from inspect import signature
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
List,
|
||||
Generic,
|
||||
TypeVar,
|
||||
AsyncIterator,
|
||||
Iterator,
|
||||
Union,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
import functools
|
||||
from inspect import signature
|
||||
import asyncio
|
||||
from dbgpt.component import SystemApp, ComponentType
|
||||
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.util.executor_utils import (
|
||||
ExecutorFactory,
|
||||
DefaultExecutorFactory,
|
||||
blocking_func_to_async,
|
||||
BlockingFunction,
|
||||
AsyncToSyncIterator,
|
||||
BlockingFunction,
|
||||
DefaultExecutorFactory,
|
||||
ExecutorFactory,
|
||||
blocking_func_to_async,
|
||||
)
|
||||
|
||||
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
|
||||
from ..task.base import TaskOutput, OUT, T
|
||||
from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
|
||||
from ..task.base import OUT, T, TaskOutput
|
||||
|
||||
F = TypeVar("F", bound=FunctionType)
|
||||
|
||||
|
@@ -1,27 +1,19 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import (
|
||||
Generic,
|
||||
Dict,
|
||||
List,
|
||||
Union,
|
||||
Callable,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from ..task.base import (
|
||||
TaskContext,
|
||||
TaskOutput,
|
||||
IN,
|
||||
OUT,
|
||||
InputContext,
|
||||
InputSource,
|
||||
)
|
||||
|
||||
from ..task.base import IN, OUT, InputContext, InputSource, TaskContext, TaskOutput
|
||||
from .base import BaseOperator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@@ -1,7 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, AsyncIterator
|
||||
from ..task.base import OUT, IN, TaskOutput, TaskContext
|
||||
from typing import AsyncIterator, Generic
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from ..task.base import IN, OUT, TaskContext, TaskOutput
|
||||
from .base import BaseOperator
|
||||
|
||||
|
||||
|
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
from typing import List, Set, Optional, Dict
|
||||
import uuid
|
||||
import logging
|
||||
from ..dag.base import DAG, DAGLifecycle
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from ..operator.base import BaseOperator, CALL_DATA
|
||||
from ..dag.base import DAG, DAGLifecycle
|
||||
from ..operator.base import CALL_DATA, BaseOperator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -1,9 +1,10 @@
|
||||
from typing import Dict, Optional, Set, List
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from ..dag.base import DAGContext, DAGVar
|
||||
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
|
||||
from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner
|
||||
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
|
||||
from ..task.base import TaskContext, TaskState
|
||||
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
|
||||
|
@@ -1,15 +1,15 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TypeVar,
|
||||
Generic,
|
||||
Optional,
|
||||
AsyncIterator,
|
||||
Union,
|
||||
Callable,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
IN = TypeVar("IN")
|
||||
|
@@ -1,22 +1,22 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
List,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Any,
|
||||
Tuple,
|
||||
Dict,
|
||||
Union,
|
||||
Optional,
|
||||
)
|
||||
import asyncio
|
||||
import logging
|
||||
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -1,14 +1,16 @@
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import AsyncIterator, List
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from typing import AsyncIterator, List
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
|
||||
from .. import (
|
||||
WorkflowRunner,
|
||||
InputOperator,
|
||||
DAGContext,
|
||||
TaskState,
|
||||
DefaultWorkflowRunner,
|
||||
InputOperator,
|
||||
SimpleInputSource,
|
||||
TaskState,
|
||||
WorkflowRunner,
|
||||
)
|
||||
from ..task.task_impl import _is_async_iterator
|
||||
|
||||
|
@@ -1,24 +1,26 @@
|
||||
import pytest
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import (
|
||||
DAG,
|
||||
WorkflowRunner,
|
||||
DAGContext,
|
||||
TaskState,
|
||||
InputOperator,
|
||||
MapOperator,
|
||||
JoinOperator,
|
||||
BranchOperator,
|
||||
DAGContext,
|
||||
InputOperator,
|
||||
JoinOperator,
|
||||
MapOperator,
|
||||
ReduceStreamOperator,
|
||||
SimpleInputSource,
|
||||
TaskState,
|
||||
WorkflowRunner,
|
||||
)
|
||||
from .conftest import (
|
||||
runner,
|
||||
_is_async_iterator,
|
||||
input_node,
|
||||
input_nodes,
|
||||
runner,
|
||||
stream_input_node,
|
||||
stream_input_nodes,
|
||||
_is_async_iterator,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,24 +1,26 @@
|
||||
import pytest
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import (
|
||||
DAG,
|
||||
WorkflowRunner,
|
||||
DAGContext,
|
||||
TaskState,
|
||||
InputOperator,
|
||||
MapOperator,
|
||||
JoinOperator,
|
||||
BranchOperator,
|
||||
DAGContext,
|
||||
InputOperator,
|
||||
JoinOperator,
|
||||
MapOperator,
|
||||
ReduceStreamOperator,
|
||||
SimpleInputSource,
|
||||
TaskState,
|
||||
WorkflowRunner,
|
||||
)
|
||||
from .conftest import (
|
||||
runner,
|
||||
_is_async_iterator,
|
||||
input_node,
|
||||
input_nodes,
|
||||
runner,
|
||||
stream_input_node,
|
||||
stream_input_nodes,
|
||||
_is_async_iterator,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,14 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from .base import Trigger
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
from ..dag.base import DAG
|
||||
from ..operator.base import BaseOperator
|
||||
from .base import Trigger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
@@ -1,11 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TYPE_CHECKING, Optional
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter
|
||||
|
||||
from dbgpt.component import SystemApp, BaseComponent, ComponentType
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -114,6 +114,9 @@ class ModelRequestContext:
|
||||
span_id: Optional[str] = None
|
||||
"""The span id of the model inference."""
|
||||
|
||||
chat_mode: Optional[str] = None
|
||||
"""The chat mode of the model inference."""
|
||||
|
||||
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
|
||||
"""The extra information of the model inference."""
|
||||
|
||||
@@ -195,7 +198,13 @@ class ModelRequest:
|
||||
# Skip None fields
|
||||
return {k: v for k, v in asdict(new_reqeust).items() if v}
|
||||
|
||||
def _get_messages(self) -> List[ModelMessage]:
|
||||
def get_messages(self) -> List[ModelMessage]:
|
||||
"""Get the messages.
|
||||
|
||||
If the messages is not a list of ModelMessage, it will be converted to a list of ModelMessage.
|
||||
Returns:
|
||||
List[ModelMessage]: The messages.
|
||||
"""
|
||||
return list(
|
||||
map(
|
||||
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
|
||||
@@ -209,7 +218,7 @@ class ModelRequest:
|
||||
Returns:
|
||||
Optional[ModelMessage]: The single user message.
|
||||
"""
|
||||
messages = self._get_messages()
|
||||
messages = self.get_messages()
|
||||
if len(messages) != 1 and messages[0].role != ModelMessageRoleType.HUMAN:
|
||||
raise ValueError("The messages is not a single user message")
|
||||
return messages[0]
|
||||
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core.awel import MapOperator
|
||||
@@ -176,6 +176,22 @@ class ModelMessage(BaseModel):
|
||||
def build_human_message(content: str) -> "ModelMessage":
|
||||
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
|
||||
|
||||
@staticmethod
|
||||
def get_printable_message(messages: List["ModelMessage"]) -> str:
|
||||
"""Get the printable message"""
|
||||
str_msg = ""
|
||||
for message in messages:
|
||||
curr_message = (
|
||||
f"(Round {message.round_index}) {message.role}: {message.content} "
|
||||
)
|
||||
str_msg += curr_message.rstrip() + "\n"
|
||||
|
||||
return str_msg
|
||||
|
||||
|
||||
_SingleRoundMessage = List[ModelMessage]
|
||||
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]]
|
||||
|
||||
|
||||
def _message_to_dict(message: BaseMessage) -> Dict:
|
||||
return message.to_dict()
|
||||
|
@@ -5,6 +5,7 @@ from typing import Any, AsyncIterator, List, Optional
|
||||
from dbgpt.core import (
|
||||
MessageStorageItem,
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
ModelOutput,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
@@ -12,6 +13,7 @@ from dbgpt.core import (
|
||||
StorageInterface,
|
||||
)
|
||||
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
|
||||
from dbgpt.core.interface.message import _MultiRoundMessageMapper
|
||||
|
||||
|
||||
class BaseConversationOperator(BaseOperator, ABC):
|
||||
@@ -24,7 +26,7 @@ class BaseConversationOperator(BaseOperator, ABC):
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._storage = storage
|
||||
@@ -88,7 +90,7 @@ class PreConversationOperator(
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(storage=storage, message_storage=message_storage)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
@@ -109,7 +111,7 @@ class PreConversationOperator(
|
||||
if not input_value.context.extra:
|
||||
input_value.context.extra = {}
|
||||
|
||||
chat_mode = input_value.context.extra.get("chat_mode")
|
||||
chat_mode = input_value.context.chat_mode
|
||||
|
||||
# Create a new storage conversation, this will load the conversation from storage, so we must do this async
|
||||
storage_conv: StorageConversation = await self.blocking_func_to_async(
|
||||
@@ -121,11 +123,8 @@ class PreConversationOperator(
|
||||
conv_storage=self.storage,
|
||||
message_storage=self.message_storage,
|
||||
)
|
||||
# The input message must be a single user message
|
||||
single_human_message: ModelMessage = input_value.get_single_user_message()
|
||||
storage_conv.start_new_round()
|
||||
storage_conv.add_user_message(single_human_message.content)
|
||||
|
||||
input_messages = input_value.get_messages()
|
||||
await self.save_to_storage(storage_conv, input_messages)
|
||||
# Get all messages from current storage conversation, and overwrite the input value
|
||||
messages: List[ModelMessage] = storage_conv.get_model_messages()
|
||||
input_value.messages = messages
|
||||
@@ -139,6 +138,42 @@ class PreConversationOperator(
|
||||
)
|
||||
return input_value
|
||||
|
||||
async def save_to_storage(
|
||||
self, storage_conv: StorageConversation, input_messages: List[ModelMessage]
|
||||
) -> None:
|
||||
"""Save the messages to storage.
|
||||
|
||||
Args:
|
||||
storage_conv (StorageConversation): The storage conversation.
|
||||
input_messages (List[ModelMessage]): The input messages.
|
||||
"""
|
||||
# check first
|
||||
self.check_messages(input_messages)
|
||||
storage_conv.start_new_round()
|
||||
for message in input_messages:
|
||||
if message.role == ModelMessageRoleType.HUMAN:
|
||||
storage_conv.add_user_message(message.content)
|
||||
else:
|
||||
storage_conv.add_system_message(message.content)
|
||||
|
||||
def check_messages(self, messages: List[ModelMessage]) -> None:
|
||||
"""Check the messages.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The messages.
|
||||
|
||||
Raises:
|
||||
ValueError: If the messages is empty.
|
||||
"""
|
||||
if not messages:
|
||||
raise ValueError("Input messages is empty")
|
||||
for message in messages:
|
||||
if message.role not in [
|
||||
ModelMessageRoleType.HUMAN,
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
]:
|
||||
raise ValueError(f"Message role {message.role} is not supported")
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
# Save the storage conversation to storage after the whole DAG finished
|
||||
@@ -198,8 +233,9 @@ class PostStreamingConversationOperator(
|
||||
class ConversationMapperOperator(
|
||||
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
self._message_mapper = message_mapper
|
||||
|
||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
||||
"""Map the input value to a ModelRequest.
|
||||
@@ -211,12 +247,12 @@ class ConversationMapperOperator(
|
||||
ModelRequest: The mapped ModelRequest.
|
||||
"""
|
||||
input_value = input_value.copy()
|
||||
messages: List[ModelMessage] = await self.map_messages(input_value.messages)
|
||||
messages: List[ModelMessage] = self.map_messages(input_value.messages)
|
||||
# Overwrite the input value
|
||||
input_value.messages = messages
|
||||
return input_value
|
||||
|
||||
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
|
||||
def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
|
||||
"""Map the input messages to a list of ModelMessage.
|
||||
|
||||
Args:
|
||||
@@ -225,7 +261,73 @@ class ConversationMapperOperator(
|
||||
Returns:
|
||||
List[ModelMessage]: The mapped ModelMessage.
|
||||
"""
|
||||
return messages
|
||||
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
|
||||
messages
|
||||
)
|
||||
message_mapper = self._message_mapper or self.map_multi_round_messages
|
||||
return message_mapper(messages_by_round)
|
||||
|
||||
def map_multi_round_messages(
|
||||
self, messages_by_round: List[List[ModelMessage]]
|
||||
) -> List[ModelMessage]:
|
||||
"""Map multi round messages to a list of ModelMessage
|
||||
|
||||
By default, just merge all multi round messages to a list of ModelMessage according origin order.
|
||||
And you can overwrite this method to implement your own logic.
|
||||
|
||||
Examples:
|
||||
|
||||
Merge multi round messages to a list of ModelMessage according origin order.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from dbgpt.core.operator import ConversationMapperOperator
|
||||
messages_by_round = [
|
||||
[
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
||||
],
|
||||
[
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
],
|
||||
[
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
],
|
||||
]
|
||||
operator = ConversationMapperOperator()
|
||||
messages = operator.map_multi_round_messages(messages_by_round)
|
||||
assert messages == [
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
|
||||
Map multi round messages to a list of ModelMessage just keep the last one round.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyMapper(ConversationMapperOperator):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def map_multi_round_messages(self, messages_by_round: List[List[ModelMessage]]) -> List[ModelMessage]:
|
||||
return messages_by_round[-1]
|
||||
operator = MyMapper()
|
||||
messages = operator.map_multi_round_messages(messages_by_round)
|
||||
assert messages == [
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
|
||||
Args:
|
||||
"""
|
||||
# Just merge and return
|
||||
# e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
|
||||
return sum(messages_by_round, [])
|
||||
|
||||
def _split_messages_by_round(
|
||||
self, messages: List[ModelMessage]
|
||||
@@ -236,7 +338,7 @@ class ConversationMapperOperator(
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
|
||||
Returns:
|
||||
List[List[ModelMessage]]: The splitted messages.
|
||||
List[List[ModelMessage]]: The split messages.
|
||||
"""
|
||||
messages_by_round: List[List[ModelMessage]] = []
|
||||
last_round_index = 0
|
||||
@@ -263,15 +365,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from dbgpt.core import ModelMessage
|
||||
from dbgpt.core.operator import BufferedConversationMapperOperator
|
||||
|
||||
# No history
|
||||
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
operator = BufferedConversationMapperOperator(last_k_round=1)
|
||||
messages = asyncio.run(operator.map_messages(messages))
|
||||
assert messages == [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
assert operator.map_messages(messages) == [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
|
||||
Transform with history messages
|
||||
|
||||
@@ -287,10 +387,9 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
operator = BufferedConversationMapperOperator(last_k_round=1)
|
||||
messages = asyncio.run(operator.map_messages(messages))
|
||||
# Just keep the last one round, so the first round messages will be removed
|
||||
# Note: The round index 3 is not a complete round
|
||||
assert messages == [
|
||||
assert operator.map_messages(messages) == [
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
@@ -298,24 +397,42 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
]
|
||||
"""
|
||||
|
||||
def __init__(self, last_k_round: Optional[int] = 2, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
last_k_round: Optional[int] = 2,
|
||||
message_mapper: _MultiRoundMessageMapper = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._last_k_round = last_k_round
|
||||
if message_mapper:
|
||||
|
||||
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
|
||||
"""Map the input messages to a list of ModelMessage.
|
||||
def new_message_mapper(
|
||||
messages_by_round: List[List[ModelMessage]],
|
||||
) -> List[ModelMessage]:
|
||||
# Apply keep k round messages first, then apply the custom message mapper
|
||||
messages_by_round = self._keep_last_round_messages(messages_by_round)
|
||||
return message_mapper(messages_by_round)
|
||||
|
||||
else:
|
||||
|
||||
def new_message_mapper(
|
||||
messages_by_round: List[List[ModelMessage]],
|
||||
) -> List[ModelMessage]:
|
||||
messages_by_round = self._keep_last_round_messages(messages_by_round)
|
||||
return sum(messages_by_round, [])
|
||||
|
||||
super().__init__(new_message_mapper, **kwargs)
|
||||
|
||||
def _keep_last_round_messages(
|
||||
self, messages_by_round: List[List[ModelMessage]]
|
||||
) -> List[List[ModelMessage]]:
|
||||
"""Keep the last k round messages.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
messages_by_round (List[List[ModelMessage]]): The messages by round.
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The mapped ModelMessage.
|
||||
List[List[ModelMessage]]: The latest round messages.
|
||||
"""
|
||||
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
|
||||
messages
|
||||
)
|
||||
# Get the last k round messages
|
||||
index = self._last_k_round + 1
|
||||
messages_by_round = messages_by_round[-index:]
|
||||
messages: List[ModelMessage] = sum(messages_by_round, [])
|
||||
return messages
|
||||
return messages_by_round[-index:]
|
||||
|
@@ -169,9 +169,7 @@ class StoragePromptTemplate(StorageItem):
|
||||
def to_prompt_template(self) -> PromptTemplate:
|
||||
"""Convert the storage prompt template to a prompt template."""
|
||||
input_variables = (
|
||||
None
|
||||
if not self.input_variables
|
||||
else self.input_variables.strip().split(",")
|
||||
[] if not self.input_variables else self.input_variables.strip().split(",")
|
||||
)
|
||||
return PromptTemplate(
|
||||
input_variables=input_variables,
|
||||
@@ -458,6 +456,33 @@ class PromptManager:
|
||||
)
|
||||
self.storage.save(storage_prompt_template)
|
||||
|
||||
def query_or_save(
|
||||
self, prompt_template: PromptTemplate, prompt_name: str, **kwargs
|
||||
) -> StoragePromptTemplate:
|
||||
"""Query a prompt template from storage, if not found, save it.
|
||||
|
||||
Args:
|
||||
prompt_template (PromptTemplate): The prompt template to save.
|
||||
prompt_name (str): The name of the prompt template.
|
||||
kwargs (Dict): Other params to build the storage prompt template.
|
||||
More details in :meth:`~StoragePromptTemplate.from_prompt_template`.
|
||||
|
||||
Returns:
|
||||
StoragePromptTemplate: The storage prompt template.
|
||||
"""
|
||||
storage_prompt_template = StoragePromptTemplate.from_prompt_template(
|
||||
prompt_template, prompt_name, **kwargs
|
||||
)
|
||||
exist_prompt_template = self.storage.load(
|
||||
storage_prompt_template.identifier, StoragePromptTemplate
|
||||
)
|
||||
if exist_prompt_template:
|
||||
return exist_prompt_template
|
||||
self.save(prompt_template, prompt_name, **kwargs)
|
||||
return self.storage.load(
|
||||
storage_prompt_template.identifier, StoragePromptTemplate
|
||||
)
|
||||
|
||||
def list(self, **kwargs) -> List[StoragePromptTemplate]:
|
||||
"""List prompt templates from storage.
|
||||
|
||||
|
Reference in New Issue
Block a user