fix: Fixed multi-turn dialogue bug (#1259)

This commit is contained in:
Fangyin Cheng 2024-03-06 22:17:47 +08:00 committed by GitHub
parent 74ec8e52cd
commit 872b5745d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 199 additions and 35 deletions

View File

@ -96,6 +96,13 @@ KNOWLEDGE_SEARCH_REWRITE=False
# proxy_openai_proxy_api_key={your-openai-sk}
# proxy_openai_proxy_backend=text-embedding-ada-002
## Common HTTP embedding model
# EMBEDDING_MODEL=proxy_http_openapi
# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
# proxy_http_openapi_proxy_api_key=1dce29a6d66b4e2dbfec67044edbb924
# proxy_http_openapi_proxy_backend=text2vec
#*******************************************************************#
#** DB-GPT METADATA DATABASE SETTINGS **#

View File

@ -172,6 +172,8 @@ EMBEDDING_MODEL_CONFIG = {
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",
# Common HTTP embedding model
"proxy_http_openapi": "proxy_http_openapi",
}

View File

@ -370,9 +370,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
):
"""Create a new BufferedConversationMapperOperator."""
# Validate the input parameters
if keep_start_rounds is not None and keep_start_rounds < 0:
if keep_start_rounds is None:
keep_start_rounds = 0
if keep_end_rounds is None:
keep_end_rounds = 0
if keep_start_rounds < 0:
raise ValueError("keep_start_rounds must be non-negative")
if keep_end_rounds is not None and keep_end_rounds < 0:
if keep_end_rounds < 0:
raise ValueError("keep_end_rounds must be non-negative")
self._keep_start_rounds = keep_start_rounds
@ -420,7 +424,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ],
... ]
# Test keeping only the first 2 rounds
>>> # Test keeping only the first 2 rounds
>>> operator = BufferedConversationMapperOperator(keep_start_rounds=2)
>>> assert operator._filter_round_messages(messages) == [
... [
@ -433,7 +437,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ],
... ]
# Test keeping only the last 2 rounds
>>> # Test keeping only the last 2 rounds
>>> operator = BufferedConversationMapperOperator(keep_end_rounds=2)
>>> assert operator._filter_round_messages(messages) == [
... [
@ -446,7 +450,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ],
... ]
# Test keeping the first 2 and last 1 rounds
>>> # Test keeping the first 2 and last 1 rounds
>>> operator = BufferedConversationMapperOperator(
... keep_start_rounds=2, keep_end_rounds=1
... )
@ -465,24 +469,11 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
... ],
... ]
# Test without specifying start or end rounds (keep all rounds)
>>> # Test without specifying start or end rounds (keep 0 rounds)
>>> operator = BufferedConversationMapperOperator()
>>> assert operator._filter_round_messages(messages) == [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ],
... [
... HumanMessage(content="How are you?", round_index=2),
... AIMessage(content="I'm good, thanks!", round_index=2),
... ],
... [
... HumanMessage(content="What's new today?", round_index=3),
... AIMessage(content="Lots of things!", round_index=3),
... ],
... ]
>>> assert operator._filter_round_messages(messages) == []
# Test end rounds is zero
>>> # Test end rounds is zero
>>> operator = BufferedConversationMapperOperator(
... keep_start_rounds=1, keep_end_rounds=0
... )
@ -503,12 +494,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
"""
total_rounds = len(messages_by_round)
if (
self._keep_start_rounds is not None
and self._keep_end_rounds is not None
and self._keep_start_rounds > 0
and self._keep_end_rounds > 0
):
if self._keep_start_rounds > 0 and self._keep_end_rounds > 0:
if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
# Avoid overlapping when the sum of start and end rounds exceeds total
# rounds
@ -517,12 +503,12 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
messages_by_round[: self._keep_start_rounds]
+ messages_by_round[-self._keep_end_rounds :]
)
elif self._keep_start_rounds is not None:
elif self._keep_start_rounds:
return messages_by_round[: self._keep_start_rounds]
elif self._keep_end_rounds is not None:
elif self._keep_end_rounds:
return messages_by_round[-self._keep_end_rounds :]
else:
return messages_by_round
return []
EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]]

View File

@ -0,0 +1,155 @@
from typing import List
import pytest
from dbgpt.core.interface.message import AIMessage, BaseMessage, HumanMessage
from dbgpt.core.operators import BufferedConversationMapperOperator
@pytest.fixture
def messages() -> List[BaseMessage]:
    """Build a flat three-round conversation: one human/AI pair per round."""
    exchanges = [
        ("Hi", "Hello!"),
        ("How are you?", "I'm good, thanks!"),
        ("What's new today?", "Lots of things!"),
    ]
    conversation: List[BaseMessage] = []
    # Rounds are numbered from 1; each round is a human message followed by
    # the AI reply.
    for round_no, (question, answer) in enumerate(exchanges, start=1):
        conversation.append(HumanMessage(content=question, round_index=round_no))
        conversation.append(AIMessage(content=answer, round_index=round_no))
    return conversation
@pytest.mark.asyncio
async def test_buffered_conversation_keep_start_rounds(messages: List[BaseMessage]):
    """Check ``keep_start_rounds`` on its own (``keep_end_rounds`` is None)."""
    # Keeping the first 2 rounds is expected to drop round 3.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=2, keep_end_rounds=None
    )
    kept = await mapper.map_messages(messages)
    assert kept == [
        HumanMessage(content="Hi", round_index=1),
        AIMessage(content="Hello!", round_index=1),
        HumanMessage(content="How are you?", round_index=2),
        AIMessage(content="I'm good, thanks!", round_index=2),
    ]

    # Keeping 0 start rounds is expected to yield no messages.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=0, keep_end_rounds=None
    )
    kept = await mapper.map_messages(messages)
    assert kept == []

    # Asking for more rounds than exist is expected to return everything.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=100, keep_end_rounds=None
    )
    kept = await mapper.map_messages(messages)
    assert kept == messages

    # A negative round count is expected to raise ValueError.
    with pytest.raises(ValueError):
        mapper = BufferedConversationMapperOperator(
            keep_start_rounds=-1, keep_end_rounds=None
        )
        await mapper.map_messages(messages)
@pytest.mark.asyncio
async def test_buffered_conversation_keep_end_rounds(messages: List[BaseMessage]):
    """Check ``keep_end_rounds`` behaviour when trimming from the tail."""
    # Keeping the last 2 rounds is expected to drop round 1.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=None, keep_end_rounds=2
    )
    kept = await mapper.map_messages(messages)
    assert kept == [
        HumanMessage(content="How are you?", round_index=2),
        AIMessage(content="I'm good, thanks!", round_index=2),
        HumanMessage(content="What's new today?", round_index=3),
        AIMessage(content="Lots of things!", round_index=3),
    ]

    # Keeping 0 end rounds (with 0 start rounds) is expected to yield nothing.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=0, keep_end_rounds=0
    )
    kept = await mapper.map_messages(messages)
    assert kept == []

    # Asking for more rounds than exist is expected to return everything.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=None, keep_end_rounds=100
    )
    kept = await mapper.map_messages(messages)
    assert kept == messages

    # A negative round count is expected to raise ValueError.
    with pytest.raises(ValueError):
        mapper = BufferedConversationMapperOperator(
            keep_start_rounds=None, keep_end_rounds=-1
        )
        await mapper.map_messages(messages)
@pytest.mark.asyncio
async def test_buffered_conversation_keep_start_end_rounds(messages: List[BaseMessage]):
    """Check ``keep_start_rounds`` and ``keep_end_rounds`` used together."""
    # Start 1 + end 1: the first and last rounds are expected, the middle
    # round is dropped.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=1, keep_end_rounds=1
    )
    kept = await mapper.map_messages(messages)
    assert kept == [
        HumanMessage(content="Hi", round_index=1),
        AIMessage(content="Hello!", round_index=1),
        HumanMessage(content="What's new today?", round_index=3),
        AIMessage(content="Lots of things!", round_index=3),
    ]

    # Start 0 + end 0: nothing is expected back.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=0, keep_end_rounds=0
    )
    kept = await mapper.map_messages(messages)
    assert kept == []

    # Start 0 + end 1: only the last round is expected.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=0, keep_end_rounds=1
    )
    kept = await mapper.map_messages(messages)
    assert kept == [
        HumanMessage(content="What's new today?", round_index=3),
        AIMessage(content="Lots of things!", round_index=3),
    ]

    # Start 2 + end 0: only the first two rounds are expected.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=2, keep_end_rounds=0
    )
    kept = await mapper.map_messages(messages)
    assert kept == [
        HumanMessage(content="Hi", round_index=1),
        AIMessage(content="Hello!", round_index=1),
        HumanMessage(content="How are you?", round_index=2),
        AIMessage(content="I'm good, thanks!", round_index=2),
    ]

    # Start 100 + end 100: far more than available, so the full conversation
    # is expected.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=100, keep_end_rounds=100
    )
    kept = await mapper.map_messages(messages)
    assert kept == messages

    # Start 2 rounds + end 2 rounds on a 3-round conversation: the windows
    # overlap, and the result is expected to equal the original messages.
    mapper = BufferedConversationMapperOperator(
        keep_start_rounds=2, keep_end_rounds=2
    )
    kept = await mapper.map_messages(messages)
    assert kept == messages

    # Negative values for both parameters are expected to raise ValueError.
    with pytest.raises(ValueError):
        mapper = BufferedConversationMapperOperator(
            keep_start_rounds=-1, keep_end_rounds=-1
        )
        await mapper.map_messages(messages)

View File

@ -1,14 +1,16 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union, cast
from dbgpt.model.parameter import BaseEmbeddingModelParameters
from dbgpt.model.parameter import BaseEmbeddingModelParameters, ProxyEmbeddingParameters
from dbgpt.util.parameter_utils import _get_dict_from_obj
from dbgpt.util.system_utils import get_system_info
from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import Embeddings as LangChainEmbeddings
from dbgpt.rag.embedding import Embeddings
class EmbeddingLoader:
@ -17,7 +19,7 @@ class EmbeddingLoader:
def load(
self, model_name: str, param: BaseEmbeddingModelParameters
) -> "Embeddings":
) -> "Union[LangChainEmbeddings, Embeddings]":
metadata = {
"model_name": model_name,
"run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
@ -32,6 +34,18 @@ class EmbeddingLoader:
from langchain.embeddings import OpenAIEmbeddings
return OpenAIEmbeddings(**param.build_kwargs())
elif model_name in ["proxy_http_openapi"]:
from dbgpt.rag.embedding import OpenAPIEmbeddings
proxy_param = cast(ProxyEmbeddingParameters, param)
openapi_param = {}
if proxy_param.proxy_server_url:
openapi_param["api_url"] = proxy_param.proxy_server_url
if proxy_param.proxy_api_key:
openapi_param["api_key"] = proxy_param.proxy_api_key
if proxy_param.proxy_backend:
openapi_param["model_name"] = proxy_param.proxy_backend
return OpenAPIEmbeddings(**openapi_param)
else:
from langchain.embeddings import HuggingFaceEmbeddings

View File

@ -552,7 +552,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi",
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}