mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
fix: Fixed multi-turn dialogue bug (#1259)
This commit is contained in:
parent
74ec8e52cd
commit
872b5745d3
@ -96,6 +96,13 @@ KNOWLEDGE_SEARCH_REWRITE=False
|
||||
# proxy_openai_proxy_api_key={your-openai-sk}
|
||||
# proxy_openai_proxy_backend=text-embedding-ada-002
|
||||
|
||||
## Common HTTP embedding model
|
||||
# EMBEDDING_MODEL=proxy_http_openapi
|
||||
# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
|
||||
# proxy_http_openapi_proxy_api_key=1dce29a6d66b4e2dbfec67044edbb924
|
||||
# proxy_http_openapi_proxy_backend=text2vec
|
||||
|
||||
|
||||
|
||||
#*******************************************************************#
|
||||
#** DB-GPT METADATA DATABASE SETTINGS **#
|
||||
|
@ -172,6 +172,8 @@ EMBEDDING_MODEL_CONFIG = {
|
||||
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
|
||||
"proxy_openai": "proxy_openai",
|
||||
"proxy_azure": "proxy_azure",
|
||||
# Common HTTP embedding model
|
||||
"proxy_http_openapi": "proxy_http_openapi",
|
||||
}
|
||||
|
||||
|
||||
|
@ -370,9 +370,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
):
|
||||
"""Create a new BufferedConversationMapperOperator."""
|
||||
# Validate the input parameters
|
||||
if keep_start_rounds is not None and keep_start_rounds < 0:
|
||||
if keep_start_rounds is None:
|
||||
keep_start_rounds = 0
|
||||
if keep_end_rounds is None:
|
||||
keep_end_rounds = 0
|
||||
if keep_start_rounds < 0:
|
||||
raise ValueError("keep_start_rounds must be non-negative")
|
||||
if keep_end_rounds is not None and keep_end_rounds < 0:
|
||||
if keep_end_rounds < 0:
|
||||
raise ValueError("keep_end_rounds must be non-negative")
|
||||
|
||||
self._keep_start_rounds = keep_start_rounds
|
||||
@ -420,7 +424,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
... ],
|
||||
... ]
|
||||
|
||||
# Test keeping only the first 2 rounds
|
||||
>>> # Test keeping only the first 2 rounds
|
||||
>>> operator = BufferedConversationMapperOperator(keep_start_rounds=2)
|
||||
>>> assert operator._filter_round_messages(messages) == [
|
||||
... [
|
||||
@ -433,7 +437,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
... ],
|
||||
... ]
|
||||
|
||||
# Test keeping only the last 2 rounds
|
||||
>>> # Test keeping only the last 2 rounds
|
||||
>>> operator = BufferedConversationMapperOperator(keep_end_rounds=2)
|
||||
>>> assert operator._filter_round_messages(messages) == [
|
||||
... [
|
||||
@ -446,7 +450,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
... ],
|
||||
... ]
|
||||
|
||||
# Test keeping the first 2 and last 1 rounds
|
||||
>>> # Test keeping the first 2 and last 1 rounds
|
||||
>>> operator = BufferedConversationMapperOperator(
|
||||
... keep_start_rounds=2, keep_end_rounds=1
|
||||
... )
|
||||
@ -465,24 +469,11 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
... ],
|
||||
... ]
|
||||
|
||||
# Test without specifying start or end rounds (keep all rounds)
|
||||
>>> # Test without specifying start or end rounds (keep 0 rounds)
|
||||
>>> operator = BufferedConversationMapperOperator()
|
||||
>>> assert operator._filter_round_messages(messages) == [
|
||||
... [
|
||||
... HumanMessage(content="Hi", round_index=1),
|
||||
... AIMessage(content="Hello!", round_index=1),
|
||||
... ],
|
||||
... [
|
||||
... HumanMessage(content="How are you?", round_index=2),
|
||||
... AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
... ],
|
||||
... [
|
||||
... HumanMessage(content="What's new today?", round_index=3),
|
||||
... AIMessage(content="Lots of things!", round_index=3),
|
||||
... ],
|
||||
... ]
|
||||
>>> assert operator._filter_round_messages(messages) == []
|
||||
|
||||
# Test end rounds is zero
|
||||
>>> # Test end rounds is zero
|
||||
>>> operator = BufferedConversationMapperOperator(
|
||||
... keep_start_rounds=1, keep_end_rounds=0
|
||||
... )
|
||||
@ -503,12 +494,7 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
|
||||
"""
|
||||
total_rounds = len(messages_by_round)
|
||||
if (
|
||||
self._keep_start_rounds is not None
|
||||
and self._keep_end_rounds is not None
|
||||
and self._keep_start_rounds > 0
|
||||
and self._keep_end_rounds > 0
|
||||
):
|
||||
if self._keep_start_rounds > 0 and self._keep_end_rounds > 0:
|
||||
if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
|
||||
# Avoid overlapping when the sum of start and end rounds exceeds total
|
||||
# rounds
|
||||
@ -517,12 +503,12 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
messages_by_round[: self._keep_start_rounds]
|
||||
+ messages_by_round[-self._keep_end_rounds :]
|
||||
)
|
||||
elif self._keep_start_rounds is not None:
|
||||
elif self._keep_start_rounds:
|
||||
return messages_by_round[: self._keep_start_rounds]
|
||||
elif self._keep_end_rounds is not None:
|
||||
elif self._keep_end_rounds:
|
||||
return messages_by_round[-self._keep_end_rounds :]
|
||||
else:
|
||||
return messages_by_round
|
||||
return []
|
||||
|
||||
|
||||
EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]]
|
||||
|
0
dbgpt/core/interface/operators/tests/__init__.py
Normal file
0
dbgpt/core/interface/operators/tests/__init__.py
Normal file
155
dbgpt/core/interface/operators/tests/test_message_operator.py
Normal file
155
dbgpt/core/interface/operators/tests/test_message_operator.py
Normal file
@ -0,0 +1,155 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.message import AIMessage, BaseMessage, HumanMessage
|
||||
from dbgpt.core.operators import BufferedConversationMapperOperator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def messages() -> List[BaseMessage]:
|
||||
return [
|
||||
HumanMessage(content="Hi", round_index=1),
|
||||
AIMessage(content="Hello!", round_index=1),
|
||||
HumanMessage(content="How are you?", round_index=2),
|
||||
AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
HumanMessage(content="What's new today?", round_index=3),
|
||||
AIMessage(content="Lots of things!", round_index=3),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffered_conversation_keep_start_rounds(messages: List[BaseMessage]):
|
||||
# Test keep_start_rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=2,
|
||||
keep_end_rounds=None,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="Hi", round_index=1),
|
||||
AIMessage(content="Hello!", round_index=1),
|
||||
HumanMessage(content="How are you?", round_index=2),
|
||||
AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
]
|
||||
# Test keep start 0 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=0,
|
||||
keep_end_rounds=None,
|
||||
)
|
||||
assert await operator.map_messages(messages) == []
|
||||
|
||||
# Test keep start 100 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=100,
|
||||
keep_end_rounds=None,
|
||||
)
|
||||
assert await operator.map_messages(messages) == messages
|
||||
|
||||
# Test keep start -1 rounds
|
||||
with pytest.raises(ValueError):
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=-1,
|
||||
keep_end_rounds=None,
|
||||
)
|
||||
await operator.map_messages(messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffered_conversation_keep_end_rounds(messages: List[BaseMessage]):
|
||||
# Test keep_end_rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=None,
|
||||
keep_end_rounds=2,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="How are you?", round_index=2),
|
||||
AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
HumanMessage(content="What's new today?", round_index=3),
|
||||
AIMessage(content="Lots of things!", round_index=3),
|
||||
]
|
||||
# Test keep end 0 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=0,
|
||||
keep_end_rounds=0,
|
||||
)
|
||||
assert await operator.map_messages(messages) == []
|
||||
|
||||
# Test keep end 100 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=None,
|
||||
keep_end_rounds=100,
|
||||
)
|
||||
assert await operator.map_messages(messages) == messages
|
||||
|
||||
# Test keep end -1 rounds
|
||||
with pytest.raises(ValueError):
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=None,
|
||||
keep_end_rounds=-1,
|
||||
)
|
||||
await operator.map_messages(messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffered_conversation_keep_start_end_rounds(messages: List[BaseMessage]):
|
||||
# Test keep_start_rounds and keep_end_rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=1,
|
||||
keep_end_rounds=1,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="Hi", round_index=1),
|
||||
AIMessage(content="Hello!", round_index=1),
|
||||
HumanMessage(content="What's new today?", round_index=3),
|
||||
AIMessage(content="Lots of things!", round_index=3),
|
||||
]
|
||||
# Test keep start 0 rounds and keep end 0 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=0,
|
||||
keep_end_rounds=0,
|
||||
)
|
||||
assert await operator.map_messages(messages) == []
|
||||
|
||||
# Test keep start 0 rounds and keep end 1 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=0,
|
||||
keep_end_rounds=1,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="What's new today?", round_index=3),
|
||||
AIMessage(content="Lots of things!", round_index=3),
|
||||
]
|
||||
|
||||
# Test keep start 2 rounds and keep end 0 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=2,
|
||||
keep_end_rounds=0,
|
||||
)
|
||||
assert await operator.map_messages(messages) == [
|
||||
HumanMessage(content="Hi", round_index=1),
|
||||
AIMessage(content="Hello!", round_index=1),
|
||||
HumanMessage(content="How are you?", round_index=2),
|
||||
AIMessage(content="I'm good, thanks!", round_index=2),
|
||||
]
|
||||
|
||||
# Test keep start 100 rounds and keep end 100 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=100,
|
||||
keep_end_rounds=100,
|
||||
)
|
||||
assert await operator.map_messages(messages) == messages
|
||||
|
||||
# Test keep start 2 round and keep end 2 rounds
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=2,
|
||||
keep_end_rounds=2,
|
||||
)
|
||||
assert await operator.map_messages(messages) == messages
|
||||
|
||||
# Test keep start -1 rounds and keep end -1 rounds
|
||||
with pytest.raises(ValueError):
|
||||
operator = BufferedConversationMapperOperator(
|
||||
keep_start_rounds=-1,
|
||||
keep_end_rounds=-1,
|
||||
)
|
||||
await operator.map_messages(messages)
|
@ -1,14 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union, cast
|
||||
|
||||
from dbgpt.model.parameter import BaseEmbeddingModelParameters
|
||||
from dbgpt.model.parameter import BaseEmbeddingModelParameters, ProxyEmbeddingParameters
|
||||
from dbgpt.util.parameter_utils import _get_dict_from_obj
|
||||
from dbgpt.util.system_utils import get_system_info
|
||||
from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.base import Embeddings as LangChainEmbeddings
|
||||
|
||||
from dbgpt.rag.embedding import Embeddings
|
||||
|
||||
|
||||
class EmbeddingLoader:
|
||||
@ -17,7 +19,7 @@ class EmbeddingLoader:
|
||||
|
||||
def load(
|
||||
self, model_name: str, param: BaseEmbeddingModelParameters
|
||||
) -> "Embeddings":
|
||||
) -> "Union[LangChainEmbeddings, Embeddings]":
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
|
||||
@ -32,6 +34,18 @@ class EmbeddingLoader:
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
return OpenAIEmbeddings(**param.build_kwargs())
|
||||
elif model_name in ["proxy_http_openapi"]:
|
||||
from dbgpt.rag.embedding import OpenAPIEmbeddings
|
||||
|
||||
proxy_param = cast(ProxyEmbeddingParameters, param)
|
||||
openapi_param = {}
|
||||
if proxy_param.proxy_server_url:
|
||||
openapi_param["api_url"] = proxy_param.proxy_server_url
|
||||
if proxy_param.proxy_api_key:
|
||||
openapi_param["api_key"] = proxy_param.proxy_api_key
|
||||
if proxy_param.proxy_backend:
|
||||
openapi_param["model_name"] = proxy_param.proxy_backend
|
||||
return OpenAPIEmbeddings(**openapi_param)
|
||||
else:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
|
@ -552,7 +552,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
||||
|
||||
|
||||
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
|
||||
ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
|
||||
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi",
|
||||
}
|
||||
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
||||
|
Loading…
Reference in New Issue
Block a user