feat(core): Support multi round conversation operator (#986)

This commit is contained in:
Fangyin Cheng
2023-12-27 23:26:28 +08:00
committed by GitHub
parent 9aec636b02
commit b13d3f6d92
63 changed files with 2011 additions and 314 deletions

View File

@@ -8,38 +8,36 @@ The stability of this API cannot be guaranteed at present.
"""
from typing import List, Optional
from dbgpt.component import SystemApp
from .dag.base import DAGContext, DAG
from .dag.base import DAG, DAGContext
from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
MapOperator,
BranchFunc,
BranchOperator,
InputOperator,
BranchFunc,
JoinOperator,
MapOperator,
ReduceStreamOperator,
)
from .operator.stream_operator import (
StreamifyAbsOperator,
UnstreamifyAbsOperator,
TransformStreamAbsOperator,
UnstreamifyAbsOperator,
)
from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .runner.local_runner import DefaultWorkflowRunner
from .task.base import InputContext, InputSource, TaskContext, TaskOutput, TaskState
from .task.task_impl import (
SimpleInputSource,
SimpleCallDataInputSource,
DefaultTaskContext,
DefaultInputContext,
SimpleTaskOutput,
DefaultTaskContext,
SimpleCallDataInputSource,
SimpleInputSource,
SimpleStreamTaskOutput,
SimpleTaskOutput,
_is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner
__all__ = [
"initialize_awel",
@@ -73,16 +71,16 @@ __all__ = [
]
def initialize_awel(system_app: SystemApp, dag_filepath: str):
from .dag.dag_manager import DAGManager
def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager
from .dag.dag_manager import DAGManager
from .operator.base import initialize_runner
from .trigger.trigger_manager import DefaultTriggerManager
DAGVar.set_current_system_app(system_app)
system_app.register(DefaultTriggerManager)
dag_manager = DAGManager(system_app, dag_filepath)
dag_manager = DAGManager(system_app, dag_dirs)
system_app.register_instance(dag_manager)
initialize_runner(DefaultWorkflowRunner())
# Load all dags
@@ -90,7 +88,11 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str):
def setup_dev_environment(
dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
dags: List[DAG],
host: Optional[str] = "0.0.0.0",
port: Optional[int] = 5555,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
) -> None:
"""Setup a development environment for AWEL.
@@ -98,9 +100,16 @@ def setup_dev_environment(
"""
import uvicorn
from fastapi import FastAPI
from dbgpt.component import SystemApp
from .trigger.trigger_manager import DefaultTriggerManager
from dbgpt.util.utils import setup_logging
from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager
if not logger_filename:
logger_filename = "dbgpt_awel_dev.log"
setup_logging("dbgpt", logging_level=logging_level, logger_filename=logger_filename)
app = FastAPI()
system_app = SystemApp(app)

View File

@@ -1,15 +1,16 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import contextvars
import logging
import threading
import uuid
from abc import ABC, abstractmethod
from collections import deque
from functools import cache
from concurrent.futures import Executor
from functools import cache
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext, TaskOutput
@@ -502,6 +503,9 @@ class DAG:
def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag()
def __repr__(self):
return f"DAG(dag_id={self.dag_id})"
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()

View File

@@ -1,8 +1,10 @@
from typing import Dict, Optional
import logging
from typing import Dict, List
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG
from .loader import LocalFileDAGLoader
logger = logging.getLogger(__name__)
@@ -10,9 +12,9 @@ logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
name = ComponentType.AWEL_DAG_MANAGER
def __init__(self, system_app: SystemApp, dag_filepath: str):
def __init__(self, system_app: SystemApp, dag_dirs: List[str]):
super().__init__(system_app)
self.dag_loader = LocalFileDAGLoader(dag_filepath)
self.dag_loader = LocalFileDAGLoader(dag_dirs)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}

View File

@@ -1,10 +1,10 @@
import hashlib
import logging
import os
import sys
import traceback
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback
from .base import DAG
@@ -18,17 +18,19 @@ class DAGLoader(ABC):
class LocalFileDAGLoader(DAGLoader):
def __init__(self, filepath: str) -> None:
super().__init__()
self._filepath = filepath
def __init__(self, dag_dirs: List[str]) -> None:
self._dag_dirs = dag_dirs
def load_dags(self) -> List[DAG]:
if not os.path.exists(self._filepath):
return []
if os.path.isdir(self._filepath):
return _process_directory(self._filepath)
else:
return _process_file(self._filepath)
dags = []
for filepath in self._dag_dirs:
if not os.path.exists(filepath):
continue
if os.path.isdir(filepath):
dags += _process_directory(filepath)
else:
dags += _process_file(filepath)
return dags
def _process_directory(directory: str) -> List[DAG]:

View File

@@ -1,6 +1,8 @@
import pytest
import threading
import asyncio
import threading
import pytest
from ..base import DAG, DAGVar

View File

@@ -1,32 +1,32 @@
from abc import ABC, abstractmethod, ABCMeta
import asyncio
import functools
from abc import ABC, ABCMeta, abstractmethod
from inspect import signature
from types import FunctionType
from typing import (
List,
Generic,
TypeVar,
AsyncIterator,
Iterator,
Union,
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Optional,
TypeVar,
Union,
cast,
)
import functools
from inspect import signature
import asyncio
from dbgpt.component import SystemApp, ComponentType
from dbgpt.component import ComponentType, SystemApp
from dbgpt.util.executor_utils import (
ExecutorFactory,
DefaultExecutorFactory,
blocking_func_to_async,
BlockingFunction,
AsyncToSyncIterator,
BlockingFunction,
DefaultExecutorFactory,
ExecutorFactory,
blocking_func_to_async,
)
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import TaskOutput, OUT, T
from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
from ..task.base import OUT, T, TaskOutput
F = TypeVar("F", bound=FunctionType)

View File

@@ -1,27 +1,19 @@
import asyncio
import logging
from typing import (
Generic,
Dict,
List,
Union,
Callable,
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
List,
Optional,
Union,
)
import asyncio
import logging
from ..dag.base import DAGContext
from ..task.base import (
TaskContext,
TaskOutput,
IN,
OUT,
InputContext,
InputSource,
)
from ..task.base import IN, OUT, InputContext, InputSource, TaskContext, TaskOutput
from .base import BaseOperator
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Generic, AsyncIterator
from ..task.base import OUT, IN, TaskOutput, TaskContext
from typing import AsyncIterator, Generic
from ..dag.base import DAGContext
from ..task.base import IN, OUT, TaskContext, TaskOutput
from .base import BaseOperator

View File

@@ -1,10 +1,10 @@
import asyncio
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG, DAGLifecycle
import uuid
from typing import Dict, List, Optional, Set
from ..operator.base import BaseOperator, CALL_DATA
from ..dag.base import DAG, DAGLifecycle
from ..operator.base import CALL_DATA, BaseOperator
logger = logging.getLogger(__name__)

View File

@@ -1,9 +1,10 @@
from typing import Dict, Optional, Set, List
import logging
from typing import Dict, List, Optional, Set
from dbgpt.component import SystemApp
from ..dag.base import DAGContext, DAGVar
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput

View File

@@ -1,15 +1,15 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TypeVar,
Generic,
Optional,
AsyncIterator,
Union,
Callable,
Any,
AsyncIterator,
Callable,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
)
IN = TypeVar("IN")

View File

@@ -1,22 +1,22 @@
from abc import ABC, abstractmethod
from typing import (
Callable,
Coroutine,
Iterator,
AsyncIterator,
List,
Generic,
TypeVar,
Any,
Tuple,
Dict,
Union,
Optional,
)
import asyncio
import logging
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
)
from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState
logger = logging.getLogger(__name__)

View File

@@ -1,14 +1,16 @@
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncIterator, List
import pytest
import pytest_asyncio
from typing import AsyncIterator, List
from contextlib import contextmanager, asynccontextmanager
from .. import (
WorkflowRunner,
InputOperator,
DAGContext,
TaskState,
DefaultWorkflowRunner,
InputOperator,
SimpleInputSource,
TaskState,
WorkflowRunner,
)
from ..task.task_impl import _is_async_iterator

View File

@@ -1,24 +1,26 @@
import pytest
from typing import List
import pytest
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
DAGContext,
InputOperator,
JoinOperator,
MapOperator,
ReduceStreamOperator,
SimpleInputSource,
TaskState,
WorkflowRunner,
)
from .conftest import (
runner,
_is_async_iterator,
input_node,
input_nodes,
runner,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)

View File

@@ -1,24 +1,26 @@
import pytest
from typing import List
import pytest
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
DAGContext,
InputOperator,
JoinOperator,
MapOperator,
ReduceStreamOperator,
SimpleInputSource,
TaskState,
WorkflowRunner,
)
from .conftest import (
runner,
_is_async_iterator,
input_node,
input_nodes,
runner,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
from starlette.requests import Request
from starlette.responses import Response
from dbgpt._private.pydantic import BaseModel
import logging
from .base import Trigger
from dbgpt._private.pydantic import BaseModel
from ..dag.base import DAG
from ..operator.base import BaseOperator
from .base import Trigger
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI

View File

@@ -1,11 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING, Optional
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from fastapi import APIRouter
from dbgpt.component import SystemApp, BaseComponent, ComponentType
from dbgpt.component import BaseComponent, ComponentType, SystemApp
logger = logging.getLogger(__name__)

View File

@@ -114,6 +114,9 @@ class ModelRequestContext:
span_id: Optional[str] = None
"""The span id of the model inference."""
chat_mode: Optional[str] = None
"""The chat mode of the model inference."""
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
"""The extra information of the model inference."""
@@ -195,7 +198,13 @@ class ModelRequest:
# Skip None fields
return {k: v for k, v in asdict(new_reqeust).items() if v}
def _get_messages(self) -> List[ModelMessage]:
def get_messages(self) -> List[ModelMessage]:
"""Get the messages.
If the messages is not a list of ModelMessage, it will be converted to a list of ModelMessage.
Returns:
List[ModelMessage]: The messages.
"""
return list(
map(
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
@@ -209,7 +218,7 @@ class ModelRequest:
Returns:
Optional[ModelMessage]: The single user message.
"""
messages = self._get_messages()
messages = self.get_messages()
if len(messages) != 1 and messages[0].role != ModelMessageRoleType.HUMAN:
raise ValueError("The messages is not a single user message")
return messages[0]

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import MapOperator
@@ -176,6 +176,22 @@ class ModelMessage(BaseModel):
def build_human_message(content: str) -> "ModelMessage":
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
@staticmethod
def get_printable_message(messages: List["ModelMessage"]) -> str:
"""Get the printable message"""
str_msg = ""
for message in messages:
curr_message = (
f"(Round {message.round_index}) {message.role}: {message.content} "
)
str_msg += curr_message.rstrip() + "\n"
return str_msg
_SingleRoundMessage = List[ModelMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]]
def _message_to_dict(message: BaseMessage) -> Dict:
return message.to_dict()

View File

@@ -5,6 +5,7 @@ from typing import Any, AsyncIterator, List, Optional
from dbgpt.core import (
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
ModelRequest,
ModelRequestContext,
@@ -12,6 +13,7 @@ from dbgpt.core import (
StorageInterface,
)
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
from dbgpt.core.interface.message import _MultiRoundMessageMapper
class BaseConversationOperator(BaseOperator, ABC):
@@ -24,7 +26,7 @@ class BaseConversationOperator(BaseOperator, ABC):
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
self._storage = storage
@@ -88,7 +90,7 @@ class PreConversationOperator(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs
**kwargs,
):
super().__init__(storage=storage, message_storage=message_storage)
MapOperator.__init__(self, **kwargs)
@@ -109,7 +111,7 @@ class PreConversationOperator(
if not input_value.context.extra:
input_value.context.extra = {}
chat_mode = input_value.context.extra.get("chat_mode")
chat_mode = input_value.context.chat_mode
# Create a new storage conversation, this will load the conversation from storage, so we must do this async
storage_conv: StorageConversation = await self.blocking_func_to_async(
@@ -121,11 +123,8 @@ class PreConversationOperator(
conv_storage=self.storage,
message_storage=self.message_storage,
)
# The input message must be a single user message
single_human_message: ModelMessage = input_value.get_single_user_message()
storage_conv.start_new_round()
storage_conv.add_user_message(single_human_message.content)
input_messages = input_value.get_messages()
await self.save_to_storage(storage_conv, input_messages)
# Get all messages from current storage conversation, and overwrite the input value
messages: List[ModelMessage] = storage_conv.get_model_messages()
input_value.messages = messages
@@ -139,6 +138,42 @@ class PreConversationOperator(
)
return input_value
async def save_to_storage(
self, storage_conv: StorageConversation, input_messages: List[ModelMessage]
) -> None:
"""Save the messages to storage.
Args:
storage_conv (StorageConversation): The storage conversation.
input_messages (List[ModelMessage]): The input messages.
"""
# check first
self.check_messages(input_messages)
storage_conv.start_new_round()
for message in input_messages:
if message.role == ModelMessageRoleType.HUMAN:
storage_conv.add_user_message(message.content)
else:
storage_conv.add_system_message(message.content)
def check_messages(self, messages: List[ModelMessage]) -> None:
"""Check the messages.
Args:
messages (List[ModelMessage]): The messages.
Raises:
ValueError: If the messages is empty.
"""
if not messages:
raise ValueError("Input messages is empty")
for message in messages:
if message.role not in [
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.SYSTEM,
]:
raise ValueError(f"Message role {message.role} is not supported")
async def after_dag_end(self):
"""The callback after DAG end"""
# Save the storage conversation to storage after the whole DAG finished
@@ -198,8 +233,9 @@ class PostStreamingConversationOperator(
class ConversationMapperOperator(
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
):
def __init__(self, **kwargs):
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
MapOperator.__init__(self, **kwargs)
self._message_mapper = message_mapper
async def map(self, input_value: ModelRequest) -> ModelRequest:
"""Map the input value to a ModelRequest.
@@ -211,12 +247,12 @@ class ConversationMapperOperator(
ModelRequest: The mapped ModelRequest.
"""
input_value = input_value.copy()
messages: List[ModelMessage] = await self.map_messages(input_value.messages)
messages: List[ModelMessage] = self.map_messages(input_value.messages)
# Overwrite the input value
input_value.messages = messages
return input_value
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
"""Map the input messages to a list of ModelMessage.
Args:
@@ -225,7 +261,73 @@ class ConversationMapperOperator(
Returns:
List[ModelMessage]: The mapped ModelMessage.
"""
return messages
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
messages
)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
def map_multi_round_messages(
self, messages_by_round: List[List[ModelMessage]]
) -> List[ModelMessage]:
"""Map multi round messages to a list of ModelMessage
By default, just merge all multi round messages to a list of ModelMessage according origin order.
And you can overwrite this method to implement your own logic.
Examples:
Merge multi round messages to a list of ModelMessage according origin order.
.. code-block:: python
import asyncio
from dbgpt.core.operator import ConversationMapperOperator
messages_by_round = [
[
ModelMessage(role="human", content="Hi", round_index=1),
ModelMessage(role="ai", content="Hello!", round_index=1),
],
[
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
],
[
ModelMessage(role="human", content="Funny!", round_index=3),
],
]
operator = ConversationMapperOperator()
messages = operator.map_multi_round_messages(messages_by_round)
assert messages == [
ModelMessage(role="human", content="Hi", round_index=1),
ModelMessage(role="ai", content="Hello!", round_index=1),
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
ModelMessage(role="human", content="Funny!", round_index=3),
]
Map multi round messages to a list of ModelMessage just keep the last one round.
.. code-block:: python
class MyMapper(ConversationMapperOperator):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def map_multi_round_messages(self, messages_by_round: List[List[ModelMessage]]) -> List[ModelMessage]:
return messages_by_round[-1]
operator = MyMapper()
messages = operator.map_multi_round_messages(messages_by_round)
assert messages == [
ModelMessage(role="human", content="Funny!", round_index=3),
]
Args:
"""
# Just merge and return
# e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
return sum(messages_by_round, [])
def _split_messages_by_round(
self, messages: List[ModelMessage]
@@ -236,7 +338,7 @@ class ConversationMapperOperator(
messages (List[ModelMessage]): The input messages.
Returns:
List[List[ModelMessage]]: The splitted messages.
List[List[ModelMessage]]: The split messages.
"""
messages_by_round: List[List[ModelMessage]] = []
last_round_index = 0
@@ -263,15 +365,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
.. code-block:: python
import asyncio
from dbgpt.core import ModelMessage
from dbgpt.core.operator import BufferedConversationMapperOperator
# No history
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
operator = BufferedConversationMapperOperator(last_k_round=1)
messages = asyncio.run(operator.map_messages(messages))
assert messages == [ModelMessage(role="human", content="Hello", round_index=1)]
assert operator.map_messages(messages) == [ModelMessage(role="human", content="Hello", round_index=1)]
Transform with history messages
@@ -287,10 +387,9 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
ModelMessage(role="human", content="Funny!", round_index=3),
]
operator = BufferedConversationMapperOperator(last_k_round=1)
messages = asyncio.run(operator.map_messages(messages))
# Just keep the last one round, so the first round messages will be removed
# Note: The round index 3 is not a complete round
assert messages == [
assert operator.map_messages(messages) == [
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
@@ -298,24 +397,42 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
]
"""
def __init__(self, last_k_round: Optional[int] = 2, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
last_k_round: Optional[int] = 2,
message_mapper: _MultiRoundMessageMapper = None,
**kwargs,
):
self._last_k_round = last_k_round
if message_mapper:
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
"""Map the input messages to a list of ModelMessage.
def new_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
# Apply keep k round messages first, then apply the custom message mapper
messages_by_round = self._keep_last_round_messages(messages_by_round)
return message_mapper(messages_by_round)
else:
def new_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
messages_by_round = self._keep_last_round_messages(messages_by_round)
return sum(messages_by_round, [])
super().__init__(new_message_mapper, **kwargs)
def _keep_last_round_messages(
self, messages_by_round: List[List[ModelMessage]]
) -> List[List[ModelMessage]]:
"""Keep the last k round messages.
Args:
messages (List[ModelMessage]): The input messages.
messages_by_round (List[List[ModelMessage]]): The messages by round.
Returns:
List[ModelMessage]: The mapped ModelMessage.
List[List[ModelMessage]]: The latest round messages.
"""
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
messages
)
# Get the last k round messages
index = self._last_k_round + 1
messages_by_round = messages_by_round[-index:]
messages: List[ModelMessage] = sum(messages_by_round, [])
return messages
return messages_by_round[-index:]

View File

@@ -169,9 +169,7 @@ class StoragePromptTemplate(StorageItem):
def to_prompt_template(self) -> PromptTemplate:
"""Convert the storage prompt template to a prompt template."""
input_variables = (
None
if not self.input_variables
else self.input_variables.strip().split(",")
[] if not self.input_variables else self.input_variables.strip().split(",")
)
return PromptTemplate(
input_variables=input_variables,
@@ -458,6 +456,33 @@ class PromptManager:
)
self.storage.save(storage_prompt_template)
def query_or_save(
self, prompt_template: PromptTemplate, prompt_name: str, **kwargs
) -> StoragePromptTemplate:
"""Query a prompt template from storage, if not found, save it.
Args:
prompt_template (PromptTemplate): The prompt template to save.
prompt_name (str): The name of the prompt template.
kwargs (Dict): Other params to build the storage prompt template.
More details in :meth:`~StoragePromptTemplate.from_prompt_template`.
Returns:
StoragePromptTemplate: The storage prompt template.
"""
storage_prompt_template = StoragePromptTemplate.from_prompt_template(
prompt_template, prompt_name, **kwargs
)
exist_prompt_template = self.storage.load(
storage_prompt_template.identifier, StoragePromptTemplate
)
if exist_prompt_template:
return exist_prompt_template
self.save(prompt_template, prompt_name, **kwargs)
return self.storage.load(
storage_prompt_template.identifier, StoragePromptTemplate
)
def list(self, **kwargs) -> List[StoragePromptTemplate]:
"""List prompt templates from storage.