mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-10 05:19:44 +00:00
chore(build): Fix typo and new pre-commit config (#987)
This commit is contained in:
@@ -283,6 +283,7 @@ class ConversationMapperOperator(
|
||||
|
||||
import asyncio
|
||||
from dbgpt.core.operator import ConversationMapperOperator
|
||||
|
||||
messages_by_round = [
|
||||
[
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
@@ -290,7 +291,9 @@ class ConversationMapperOperator(
|
||||
],
|
||||
[
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(
|
||||
role="human", content="What's the error?", round_index=2
|
||||
),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
],
|
||||
[
|
||||
@@ -303,7 +306,9 @@ class ConversationMapperOperator(
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(role="human", content="What's the error?", round_index=2),
|
||||
ModelMessage(
|
||||
role="human", content="What's the error?", round_index=2
|
||||
),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
@@ -315,8 +320,13 @@ class ConversationMapperOperator(
|
||||
class MyMapper(ConversationMapperOperator):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def map_multi_round_messages(self, messages_by_round: List[List[ModelMessage]]) -> List[ModelMessage]:
|
||||
|
||||
def map_multi_round_messages(
|
||||
self, messages_by_round: List[List[ModelMessage]]
|
||||
) -> List[ModelMessage]:
|
||||
return messages_by_round[-1]
|
||||
|
||||
|
||||
operator = MyMapper()
|
||||
messages = operator.map_multi_round_messages(messages_by_round)
|
||||
assert messages == [
|
||||
@@ -371,7 +381,9 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
# No history
|
||||
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
operator = BufferedConversationMapperOperator(last_k_round=1)
|
||||
assert operator.map_messages(messages) == [ModelMessage(role="human", content="Hello", round_index=1)]
|
||||
assert operator.map_messages(messages) == [
|
||||
ModelMessage(role="human", content="Hello", round_index=1)
|
||||
]
|
||||
|
||||
Transform with history messages
|
||||
|
||||
|
Reference in New Issue
Block a user