mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 18:40:10 +00:00
fix: Fixed multi-turn dialogue bug (#1259)
This commit is contained in:
@@ -370,9 +370,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
):
|
||||
"""Create a new BufferedConversationMapperOperator."""
|
||||
# Validate the input parameters
|
||||
if keep_start_rounds is not None and keep_start_rounds < 0:
|
||||
if keep_start_rounds is None:
|
||||
keep_start_rounds = 0
|
||||
if keep_end_rounds is None:
|
||||
keep_end_rounds = 0
|
||||
if keep_start_rounds < 0:
|
||||
raise ValueError("keep_start_rounds must be non-negative")
|
||||
if keep_end_rounds is not None and keep_end_rounds < 0:
|
||||
if keep_end_rounds < 0:
|
||||
raise ValueError("keep_end_rounds must be non-negative")
|
||||
|
||||
self._keep_start_rounds = keep_start_rounds
|
||||
@@ -420,7 +424,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
... ],
|
||||
... ]
|
||||
|
||||
# Test keeping only the first 2 rounds
|
||||
>>> # Test keeping only the first 2 rounds
|
||||
>>> operator = BufferedConversationMapperOperator(keep_start_rounds=2)
|
||||
>>> assert operator._filter_round_messages(messages) == [
|
||||
... [
|
||||
@@ -433,7 +437,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
... ],
|
||||
... ]
|
||||
|
||||
# Test keeping only the last 2 rounds
|
||||
>>> # Test keeping only the last 2 rounds
|
||||
>>> operator = BufferedConversationMapperOperator(keep_end_rounds=2)
|
||||
>>> assert operator._filter_round_messages(messages) == [
|
||||
... [
|
||||
@@ -446,7 +450,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
... ],
|
||||
... ]
|
||||
|
||||
# Test keeping the first 2 and last 1 rounds
|
||||
>>> # Test keeping the first 2 and last 1 rounds
|
||||
>>> operator = BufferedConversationMapperOperator(
|
||||
... keep_start_rounds=2, keep_end_rounds=1
|
||||
... )
|
||||
@@ -465,24 +469,11 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
... ],
|
||||
... ]
|
||||
|
||||
# Test without specifying start or end rounds (keep all rounds)
|
||||
>>> # Test without specifying start or end rounds (keep 0 rounds)
|
||||
>>> operator = BufferedConversationMapperOperator()
|
||||
>>> assert operator._filter_round_messages(messages) == [
|
||||
... [
|
||||
... HumanMessage(content="Hi", round_index=1),
|
||||
... AIMessage(content="Hello!", round_index=1),
|
||||
... ],
|
||||
... [
|
||||
... HumanMessage(content="How are you?", round_index=2),
|
||||
... AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
... ],
|
||||
... [
|
||||
... HumanMessage(content="What's new today?", round_index=3),
|
||||
... AIMessage(content="Lots of things!", round_index=3),
|
||||
... ],
|
||||
... ]
|
||||
>>> assert operator._filter_round_messages(messages) == []
|
||||
|
||||
# Test end rounds is zero
|
||||
>>> # Test end rounds is zero
|
||||
>>> operator = BufferedConversationMapperOperator(
|
||||
... keep_start_rounds=1, keep_end_rounds=0
|
||||
... )
|
||||
@@ -503,12 +494,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
|
||||
"""
|
||||
total_rounds = len(messages_by_round)
|
||||
if (
|
||||
self._keep_start_rounds is not None
|
||||
and self._keep_end_rounds is not None
|
||||
and self._keep_start_rounds > 0
|
||||
and self._keep_end_rounds > 0
|
||||
):
|
||||
if self._keep_start_rounds > 0 and self._keep_end_rounds > 0:
|
||||
if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
|
||||
# Avoid overlapping when the sum of start and end rounds exceeds total
|
||||
# rounds
|
||||
@@ -517,12 +503,12 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
messages_by_round[: self._keep_start_rounds]
|
||||
+ messages_by_round[-self._keep_end_rounds :]
|
||||
)
|
||||
elif self._keep_start_rounds is not None:
|
||||
elif self._keep_start_rounds:
|
||||
return messages_by_round[: self._keep_start_rounds]
|
||||
elif self._keep_end_rounds is not None:
|
||||
elif self._keep_end_rounds:
|
||||
return messages_by_round[-self._keep_end_rounds :]
|
||||
else:
|
||||
return messages_by_round
|
||||
return []
|
||||
|
||||
|
||||
EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]]
|
||||
|
0
dbgpt/core/interface/operators/tests/__init__.py
Normal file
0
dbgpt/core/interface/operators/tests/__init__.py
Normal file
155
dbgpt/core/interface/operators/tests/test_message_operator.py
Normal file
155
dbgpt/core/interface/operators/tests/test_message_operator.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.message import AIMessage, BaseMessage, HumanMessage
|
||||
from dbgpt.core.operators import BufferedConversationMapperOperator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def messages() -> List[BaseMessage]:
|
||||
return [
|
||||
HumanMessage(content="Hi", round_index=1),
|
||||
AIMessage(content="Hello!", round_index=1),
|
||||
HumanMessage(content="How are you?", round_index=2),
|
||||
AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
HumanMessage(content="What's new today?", round_index=3),
|
||||
AIMessage(content="Lots of things!", round_index=3),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffered_conversation_keep_start_rounds(messages: List[BaseMessage]):
|
||||
# Test keep_start_rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=2,
|
||||
keep_end_rounds=None,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="Hi", round_index=1),
|
||||
AIMessage(content="Hello!", round_index=1),
|
||||
HumanMessage(content="How are you?", round_index=2),
|
||||
AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
]
|
||||
# Test keep start 0 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=0,
|
||||
keep_end_rounds=None,
|
||||
)
|
||||
assert await operator.map_messages(messages) == []
|
||||
|
||||
# Test keep start 100 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=100,
|
||||
keep_end_rounds=None,
|
||||
)
|
||||
assert await operator.map_messages(messages) == messages
|
||||
|
||||
# Test keep start -1 rounds
|
||||
with pytest.raises(ValueError):
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=-1,
|
||||
keep_end_rounds=None,
|
||||
)
|
||||
await operator.map_messages(messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffered_conversation_keep_end_rounds(messages: List[BaseMessage]):
|
||||
# Test keep_end_rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=None,
|
||||
keep_end_rounds=2,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="How are you?", round_index=2),
|
||||
AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
HumanMessage(content="What's new today?", round_index=3),
|
||||
AIMessage(content="Lots of things!", round_index=3),
|
||||
]
|
||||
# Test keep end 0 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=0,
|
||||
keep_end_rounds=0,
|
||||
)
|
||||
assert await operator.map_messages(messages) == []
|
||||
|
||||
# Test keep end 100 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=None,
|
||||
keep_end_rounds=100,
|
||||
)
|
||||
assert await operator.map_messages(messages) == messages
|
||||
|
||||
# Test keep end -1 rounds
|
||||
with pytest.raises(ValueError):
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=None,
|
||||
keep_end_rounds=-1,
|
||||
)
|
||||
await operator.map_messages(messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffered_conversation_keep_start_end_rounds(messages: List[BaseMessage]):
|
||||
# Test keep_start_rounds and keep_end_rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=1,
|
||||
keep_end_rounds=1,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="Hi", round_index=1),
|
||||
AIMessage(content="Hello!", round_index=1),
|
||||
HumanMessage(content="What's new today?", round_index=3),
|
||||
AIMessage(content="Lots of things!", round_index=3),
|
||||
]
|
||||
# Test keep start 0 rounds and keep end 0 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=0,
|
||||
keep_end_rounds=0,
|
||||
)
|
||||
assert await operator.map_messages(messages) == []
|
||||
|
||||
# Test keep start 0 rounds and keep end 1 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=0,
|
||||
keep_end_rounds=1,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="What's new today?", round_index=3),
|
||||
AIMessage(content="Lots of things!", round_index=3),
|
||||
]
|
||||
|
||||
# Test keep start 2 rounds and keep end 0 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=2,
|
||||
keep_end_rounds=0,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="Hi", round_index=1),
|
||||
AIMessage(content="Hello!", round_index=1),
|
||||
HumanMessage(content="How are you?", round_index=2),
|
||||
AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
]
|
||||
|
||||
# Test keep start 100 rounds and keep end 100 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=100,
|
||||
keep_end_rounds=100,
|
||||
)
|
||||
assert await operator.map_messages(messages) == messages
|
||||
|
||||
# Test keep start 2 round and keep end 2 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=2,
|
||||
keep_end_rounds=2,
|
||||
)
|
||||
assert await operator.map_messages(messages) == messages
|
||||
|
||||
# Test keep start -1 rounds and keep end -1 rounds
|
||||
with pytest.raises(ValueError):
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=-1,
|
||||
keep_end_rounds=-1,
|
||||
)
|
||||
await operator.map_messages(messages)
|
Reference in New Issue
Block a user