feat(core): More AWEL operators and new prompt manager API (#972)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Fangyin Cheng
2023-12-25 20:03:22 +08:00
committed by GitHub
parent 048fb6c402
commit 69fb97e508
46 changed files with 2556 additions and 294 deletions

View File

@@ -1,16 +1,16 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union, Optional
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.storage import (
ResourceIdentifier,
StorageItem,
StorageInterface,
InMemoryStorage,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
@@ -112,6 +112,7 @@ class ModelMessage(BaseModel):
"""Similar to openai's message format"""
role: str
content: str
round_index: Optional[int] = 0
@staticmethod
def from_openai_messages(
@@ -443,6 +444,7 @@ class OnceConversation:
self.tokens = conversation.tokens
self.user_name = conversation.user_name
self.sys_code = conversation.sys_code
self._message_index = conversation._message_index
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
"""Get the messages by round index
@@ -470,6 +472,7 @@ class OnceConversation:
Example:
.. code-block:: python
conversation = OnceConversation()
conversation.start_new_round()
conversation.add_user_message("hello, this is the first round")
@@ -485,11 +488,17 @@ class OnceConversation:
conversation.end_current_round()
assert len(conversation.get_messages_with_round(1)) == 2
assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round"
assert (
conversation.get_messages_with_round(1)[0].content
== "hello, this is the third round"
)
assert conversation.get_messages_with_round(1)[1].content == "hi"
assert len(conversation.get_messages_with_round(2)) == 4
assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round"
assert (
conversation.get_messages_with_round(2)[0].content
== "hello, this is the second round"
)
assert conversation.get_messages_with_round(2)[1].content == "hi"
Args:
@@ -517,6 +526,7 @@ class OnceConversation:
Examples:
If you not need the history messages, you can override this method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
messages = []
for message in self.get_latest_round():
@@ -528,6 +538,7 @@ class OnceConversation:
If you want to add the one round history messages, you can override this method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
messages = []
latest_round_index = self.chat_order
@@ -537,7 +548,9 @@ class OnceConversation:
for message in self.get_messages_by_round(round_index):
if message.pass_to_model:
messages.append(
ModelMessage(role=message.type, content=message.content)
ModelMessage(
role=message.type, content=message.content
)
)
return messages
@@ -548,7 +561,11 @@ class OnceConversation:
for message in self.messages:
if message.pass_to_model:
messages.append(
ModelMessage(role=message.type, content=message.content)
ModelMessage(
role=message.type,
content=message.content,
round_index=message.round_index,
)
)
return messages
@@ -780,6 +797,9 @@ class StorageConversation(OnceConversation, StorageItem):
)
messages = [message.to_message() for message in message_list]
conversation.messages = messages
# This index is used to save the message to the storage(Has not been saved)
# The new message append to the messages, so the index is len(messages)
conversation._message_index = len(messages)
self._message_ids = message_ids
self._has_stored_message_index = len(messages) - 1
self.from_conversation(conversation)