fix: Fixed multi-turn dialogue bug (#1259)

This commit is contained in:
Fangyin Cheng
2024-03-06 22:17:47 +08:00
committed by GitHub
parent 74ec8e52cd
commit 872b5745d3
7 changed files with 199 additions and 35 deletions

View File

@@ -370,9 +370,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
):
"""Create a new BufferedConversationMapperOperator."""
# Validate the input parameters
if keep_start_rounds is not None and keep_start_rounds < 0:
if keep_start_rounds is None:
keep_start_rounds = 0
if keep_end_rounds is None:
keep_end_rounds = 0
if keep_start_rounds < 0:
raise ValueError("keep_start_rounds must be non-negative")
if keep_end_rounds is not None and keep_end_rounds < 0:
if keep_end_rounds < 0:
raise ValueError("keep_end_rounds must be non-negative")
self._keep_start_rounds = keep_start_rounds
@@ -420,7 +424,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ],
... ]
# Test keeping only the first 2 rounds
>>> # Test keeping only the first 2 rounds
>>> operator = BufferedConversationMapperOperator(keep_start_rounds=2)
>>> assert operator._filter_round_messages(messages) == [
... [
@@ -433,7 +437,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ],
... ]
# Test keeping only the last 2 rounds
>>> # Test keeping only the last 2 rounds
>>> operator = BufferedConversationMapperOperator(keep_end_rounds=2)
>>> assert operator._filter_round_messages(messages) == [
... [
@@ -446,7 +450,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ],
... ]
# Test keeping the first 2 and last 1 rounds
>>> # Test keeping the first 2 and last 1 rounds
>>> operator = BufferedConversationMapperOperator(
... keep_start_rounds=2, keep_end_rounds=1
... )
@@ -465,24 +469,11 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ],
... ]
# Test without specifying start or end rounds (keep all rounds)
>>> # Test without specifying start or end rounds (keep 0 rounds)
>>> operator = BufferedConversationMapperOperator()
>>> assert operator._filter_round_messages(messages) == [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ],
... [
... HumanMessage(content="How are you?", round_index=2),
... AIMessage(content="I'm good, thanks!", round_index=2),
... ],
... [
... HumanMessage(content="What's new today?", round_index=3),
... AIMessage(content="Lots of things!", round_index=3),
... ],
... ]
>>> assert operator._filter_round_messages(messages) == []
# Test end rounds is zero
>>> # Test end rounds is zero
>>> operator = BufferedConversationMapperOperator(
... keep_start_rounds=1, keep_end_rounds=0
... )
@@ -503,12 +494,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
"""
total_rounds = len(messages_by_round)
if (
self._keep_start_rounds is not None
and self._keep_end_rounds is not None
and self._keep_start_rounds > 0
and self._keep_end_rounds > 0
):
if self._keep_start_rounds > 0 and self._keep_end_rounds > 0:
if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
# Avoid overlapping when the sum of start and end rounds exceeds total
# rounds
@@ -517,12 +503,12 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
messages_by_round[: self._keep_start_rounds]
+ messages_by_round[-self._keep_end_rounds :]
)
elif self._keep_start_rounds is not None:
elif self._keep_start_rounds:
return messages_by_round[: self._keep_start_rounds]
elif self._keep_end_rounds is not None:
elif self._keep_end_rounds:
return messages_by_round[-self._keep_end_rounds :]
else:
return messages_by_round
return []
EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]]

View File

@@ -0,0 +1,155 @@
from typing import List
import pytest
from dbgpt.core.interface.message import AIMessage, BaseMessage, HumanMessage
from dbgpt.core.operators import BufferedConversationMapperOperator
@pytest.fixture
def messages() -> List[BaseMessage]:
return [
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
@pytest.mark.asyncio
async def test_buffered_conversation_keep_start_rounds(messages: List[BaseMessage]):
# Test keep_start_rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=2,
keep_end_rounds=None,
)
assert await operator.map_messages(messages) == [
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
]
# Test keep start 0 rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=0,
keep_end_rounds=None,
)
assert await operator.map_messages(messages) == []
# Test keep start 100 rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=100,
keep_end_rounds=None,
)
assert await operator.map_messages(messages) == messages
# Test keep start -1 rounds
with pytest.raises(ValueError):
operator = BufferedConversationMapperOperator(
keep_start_rounds=-1,
keep_end_rounds=None,
)
await operator.map_messages(messages)
@pytest.mark.asyncio
async def test_buffered_conversation_keep_end_rounds(messages: List[BaseMessage]):
# Test keep_end_rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=None,
keep_end_rounds=2,
)
assert await operator.map_messages(messages) == [
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# Test keep end 0 rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=0,
keep_end_rounds=0,
)
assert await operator.map_messages(messages) == []
# Test keep end 100 rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=None,
keep_end_rounds=100,
)
assert await operator.map_messages(messages) == messages
# Test keep end -1 rounds
with pytest.raises(ValueError):
operator = BufferedConversationMapperOperator(
keep_start_rounds=None,
keep_end_rounds=-1,
)
await operator.map_messages(messages)
@pytest.mark.asyncio
async def test_buffered_conversation_keep_start_end_rounds(messages: List[BaseMessage]):
# Test keep_start_rounds and keep_end_rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=1,
keep_end_rounds=1,
)
assert await operator.map_messages(messages) == [
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# Test keep start 0 rounds and keep end 0 rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=0,
keep_end_rounds=0,
)
assert await operator.map_messages(messages) == []
# Test keep start 0 rounds and keep end 1 rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=0,
keep_end_rounds=1,
)
assert await operator.map_messages(messages) == [
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# Test keep start 2 rounds and keep end 0 rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=2,
keep_end_rounds=0,
)
assert await operator.map_messages(messages) == [
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
]
# Test keep start 100 rounds and keep end 100 rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=100,
keep_end_rounds=100,
)
assert await operator.map_messages(messages) == messages
# Test keep start 2 round and keep end 2 rounds
operator = BufferedConversationMapperOperator(
keep_start_rounds=2,
keep_end_rounds=2,
)
assert await operator.map_messages(messages) == messages
# Test keep start -1 rounds and keep end -1 rounds
with pytest.raises(ValueError):
operator = BufferedConversationMapperOperator(
keep_start_rounds=-1,
keep_end_rounds=-1,
)
await operator.map_messages(messages)