mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-27 21:00:36 +00:00
feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.interface.storage import (
|
||||
ResourceIdentifier,
|
||||
StorageItem,
|
||||
StorageInterface,
|
||||
InMemoryStorage,
|
||||
ResourceIdentifier,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,6 +112,7 @@ class ModelMessage(BaseModel):
|
||||
"""Similar to openai's message format"""
|
||||
role: str
|
||||
content: str
|
||||
round_index: Optional[int] = 0
|
||||
|
||||
@staticmethod
|
||||
def from_openai_messages(
|
||||
@@ -443,6 +444,7 @@ class OnceConversation:
|
||||
self.tokens = conversation.tokens
|
||||
self.user_name = conversation.user_name
|
||||
self.sys_code = conversation.sys_code
|
||||
self._message_index = conversation._message_index
|
||||
|
||||
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
|
||||
"""Get the messages by round index
|
||||
@@ -470,6 +472,7 @@ class OnceConversation:
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
conversation = OnceConversation()
|
||||
conversation.start_new_round()
|
||||
conversation.add_user_message("hello, this is the first round")
|
||||
@@ -485,11 +488,17 @@ class OnceConversation:
|
||||
conversation.end_current_round()
|
||||
|
||||
assert len(conversation.get_messages_with_round(1)) == 2
|
||||
assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round"
|
||||
assert (
|
||||
conversation.get_messages_with_round(1)[0].content
|
||||
== "hello, this is the third round"
|
||||
)
|
||||
assert conversation.get_messages_with_round(1)[1].content == "hi"
|
||||
|
||||
assert len(conversation.get_messages_with_round(2)) == 4
|
||||
assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round"
|
||||
assert (
|
||||
conversation.get_messages_with_round(2)[0].content
|
||||
== "hello, this is the second round"
|
||||
)
|
||||
assert conversation.get_messages_with_round(2)[1].content == "hi"
|
||||
|
||||
Args:
|
||||
@@ -517,6 +526,7 @@ class OnceConversation:
|
||||
Examples:
|
||||
If you not need the history messages, you can override this method like this:
|
||||
.. code-block:: python
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
messages = []
|
||||
for message in self.get_latest_round():
|
||||
@@ -528,6 +538,7 @@ class OnceConversation:
|
||||
|
||||
If you want to add the one round history messages, you can override this method like this:
|
||||
.. code-block:: python
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
messages = []
|
||||
latest_round_index = self.chat_order
|
||||
@@ -537,7 +548,9 @@ class OnceConversation:
|
||||
for message in self.get_messages_by_round(round_index):
|
||||
if message.pass_to_model:
|
||||
messages.append(
|
||||
ModelMessage(role=message.type, content=message.content)
|
||||
ModelMessage(
|
||||
role=message.type, content=message.content
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
@@ -548,7 +561,11 @@ class OnceConversation:
|
||||
for message in self.messages:
|
||||
if message.pass_to_model:
|
||||
messages.append(
|
||||
ModelMessage(role=message.type, content=message.content)
|
||||
ModelMessage(
|
||||
role=message.type,
|
||||
content=message.content,
|
||||
round_index=message.round_index,
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
@@ -780,6 +797,9 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
)
|
||||
messages = [message.to_message() for message in message_list]
|
||||
conversation.messages = messages
|
||||
# This index is used to save the message to the storage(Has not been saved)
|
||||
# The new message append to the messages, so the index is len(messages)
|
||||
conversation._message_index = len(messages)
|
||||
self._message_ids = message_ids
|
||||
self._has_stored_message_index = len(messages) - 1
|
||||
self.from_conversation(conversation)
|
||||
|
||||
Reference in New Issue
Block a user