mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 11:01:09 +00:00
fix(core): Move the last user's information to the end (#960)
This commit is contained in:
13
dbgpt/core/interface/message.py
Normal file → Executable file
13
dbgpt/core/interface/message.py
Normal file → Executable file
@@ -157,14 +157,13 @@ class ModelMessage(BaseModel):
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
# Move the last user's information to the end
|
# Move the last user's information to the end
|
||||||
temp_his = history[::-1]
|
last_user_input_index = None
|
||||||
last_user_input = None
|
for i in range(len(history) - 1, -1, -1):
|
||||||
for m in temp_his:
|
if history[i]["role"] == "user":
|
||||||
if m["role"] == "user":
|
last_user_input_index = i
|
||||||
last_user_input = m
|
|
||||||
break
|
break
|
||||||
if last_user_input:
|
if last_user_input_index:
|
||||||
history.remove(last_user_input)
|
last_user_input = history.pop(last_user_input_index)
|
||||||
history.append(last_user_input)
|
history.append(last_user_input)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
57
dbgpt/core/interface/tests/test_message.py
Normal file → Executable file
57
dbgpt/core/interface/tests/test_message.py
Normal file → Executable file
@@ -67,6 +67,23 @@ def conversation_with_messages():
|
|||||||
return conv
|
return conv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def human_model_message():
|
||||||
|
return ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ai_model_message():
|
||||||
|
return ModelMessage(role=ModelMessageRoleType.AI, content="Hi there")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def system_model_message():
|
||||||
|
return ModelMessage(
|
||||||
|
role=ModelMessageRoleType.SYSTEM, content="You are a helpful chatbot!"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_init(basic_conversation):
|
def test_init(basic_conversation):
|
||||||
assert basic_conversation.chat_mode == "chat_normal"
|
assert basic_conversation.chat_mode == "chat_normal"
|
||||||
assert basic_conversation.user_name == "user1"
|
assert basic_conversation.user_name == "user1"
|
||||||
@@ -305,3 +322,43 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
|
|||||||
assert new_conversation.messages[1].content == "AI response"
|
assert new_conversation.messages[1].content == "AI response"
|
||||||
assert isinstance(new_conversation.messages[0], HumanMessage)
|
assert isinstance(new_conversation.messages[0], HumanMessage)
|
||||||
assert isinstance(new_conversation.messages[1], AIMessage)
|
assert isinstance(new_conversation.messages[1], AIMessage)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_openai_messages(
|
||||||
|
human_model_message, ai_model_message, system_model_message
|
||||||
|
):
|
||||||
|
none_messages = ModelMessage.to_openai_messages([])
|
||||||
|
assert none_messages == []
|
||||||
|
|
||||||
|
single_messages = ModelMessage.to_openai_messages([human_model_message])
|
||||||
|
assert single_messages == [{"role": "user", "content": human_model_message.content}]
|
||||||
|
|
||||||
|
normal_messages = ModelMessage.to_openai_messages(
|
||||||
|
[
|
||||||
|
system_model_message,
|
||||||
|
human_model_message,
|
||||||
|
ai_model_message,
|
||||||
|
human_model_message,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert normal_messages == [
|
||||||
|
{"role": "system", "content": system_model_message.content},
|
||||||
|
{"role": "user", "content": human_model_message.content},
|
||||||
|
{"role": "assistant", "content": ai_model_message.content},
|
||||||
|
{"role": "user", "content": human_model_message.content},
|
||||||
|
]
|
||||||
|
|
||||||
|
shuffle_messages = ModelMessage.to_openai_messages(
|
||||||
|
[
|
||||||
|
system_model_message,
|
||||||
|
human_model_message,
|
||||||
|
human_model_message,
|
||||||
|
ai_model_message,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert shuffle_messages == [
|
||||||
|
{"role": "system", "content": system_model_message.content},
|
||||||
|
{"role": "user", "content": human_model_message.content},
|
||||||
|
{"role": "assistant", "content": ai_model_message.content},
|
||||||
|
{"role": "user", "content": human_model_message.content},
|
||||||
|
]
|
||||||
|
13
dbgpt/model/proxy/llms/bard.py
Normal file → Executable file
13
dbgpt/model/proxy/llms/bard.py
Normal file → Executable file
@@ -25,14 +25,13 @@ def bard_generate_stream(
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
temp_his = history[::-1]
|
last_user_input_index = None
|
||||||
last_user_input = None
|
for i in range(len(history) - 1, -1, -1):
|
||||||
for m in temp_his:
|
if history[i]["role"] == "user":
|
||||||
if m["role"] == "user":
|
last_user_input_index = i
|
||||||
last_user_input = m
|
|
||||||
break
|
break
|
||||||
if last_user_input:
|
if last_user_input_index:
|
||||||
history.remove(last_user_input)
|
last_user_input = history.pop(last_user_input_index)
|
||||||
history.append(last_user_input)
|
history.append(last_user_input)
|
||||||
|
|
||||||
msgs = []
|
msgs = []
|
||||||
|
13
dbgpt/model/proxy/llms/chatgpt.py
Normal file → Executable file
13
dbgpt/model/proxy/llms/chatgpt.py
Normal file → Executable file
@@ -110,14 +110,13 @@ def _build_request(model: ProxyModel, params):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Move the last user's information to the end
|
# Move the last user's information to the end
|
||||||
temp_his = history[::-1]
|
last_user_input_index = None
|
||||||
last_user_input = None
|
for i in range(len(history) - 1, -1, -1):
|
||||||
for m in temp_his:
|
if history[i]["role"] == "user":
|
||||||
if m["role"] == "user":
|
last_user_input_index = i
|
||||||
last_user_input = m
|
|
||||||
break
|
break
|
||||||
if last_user_input:
|
if last_user_input_index:
|
||||||
history.remove(last_user_input)
|
last_user_input = history.pop(last_user_input_index)
|
||||||
history.append(last_user_input)
|
history.append(last_user_input)
|
||||||
|
|
||||||
payloads = {
|
payloads = {
|
||||||
|
Reference in New Issue
Block a user