Compare commits

...

5 Commits

Author SHA1 Message Date
Eugene Yurtsev
a5304aeb9b x 2024-12-18 11:19:51 -05:00
Eugene Yurtsev
6143127200 x 2024-12-18 11:17:07 -05:00
Eugene Yurtsev
e956663b94 x 2024-12-18 11:11:42 -05:00
Eugene Yurtsev
d1f49fe44a update 2024-12-18 11:11:35 -05:00
Eugene Yurtsev
45fe5878bf x 2024-12-18 11:07:39 -05:00
2 changed files with 64 additions and 1 deletions

View File

@@ -141,6 +141,13 @@ def _message_from_dict(message: dict) -> BaseMessage:
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "developer":
# The `developer` role is a new role that OpenAI has introduced to replace
# the `system` role.
# As of the time of writing, the developer role is mostly a drop-in replacement
# for the system role, so for now we will treat it as a system message.
# https://cdn.openai.com/spec/model-spec-2024-05-08.html
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
@@ -263,6 +270,13 @@ def _create_message_from_message_type(
message = AIMessage(content=content, **kwargs)
elif message_type == "system":
message = SystemMessage(content=content, **kwargs)
elif message_type == "developer":
# The `developer` role is a new role that OpenAI has introduced to replace
# the `system` role.
# As of the time of writing, the developer role is mostly a drop-in replacement
# for the system role, so for now we will treat it as a system message.
# https://cdn.openai.com/spec/model-spec-2024-05-08.html
message = SystemMessage(content=content, **kwargs)
elif message_type == "function":
message = FunctionMessage(content=content, **kwargs)
elif message_type == "tool":
@@ -273,7 +287,7 @@ def _create_message_from_message_type(
else:
msg = (
f"Unexpected message type: '{message_type}'. Use one of 'human',"
f" 'user', 'ai', 'assistant', 'function', 'tool', or 'system'."
f" 'user', 'ai', 'assistant', 'function', 'tool', 'system' or 'developer'."
)
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
raise ValueError(msg)

View File

@@ -404,6 +404,53 @@ def test_multiple_msg_with_name() -> None:
assert messages_from_dict(messages_to_dict(msgs)) == msgs
@pytest.mark.parametrize(
"input_message, expected_message",
[
(
{"type": "human", "data": {"content": "Hello!", "id": "human1"}},
HumanMessage(content="Hello!", id="human1"),
),
(
{"type": "ai", "data": {"content": "Hi!", "id": "ai1"}},
AIMessage(content="Hi!", id="ai1"),
),
(
{"type": "system", "data": {"content": "You are a helpful assistant."}},
SystemMessage(content="You are a helpful assistant."),
),
(
{"type": "developer", "data": {"content": "System-level control."}},
SystemMessage(content="System-level control."),
),
(
{"type": "function", "data": {"name": "greet", "content": "Hello!"}},
FunctionMessage(name="greet", content="Hello!"),
),
(
{
"type": "tool",
"data": {"tool_call_id": "tool1", "content": "Tool output"},
},
ToolMessage(tool_call_id="tool1", content="Tool output"),
),
(
{"type": "remove", "data": {"id": "remove1", "content": ""}},
RemoveMessage(id="remove1"),
),
(
{"type": "AIMessageChunk", "data": {"content": "AI chunk"}},
AIMessageChunk(content="AI chunk"),
),
],
)
def test_message_from_dict(input_message: dict, expected_message: BaseMessage):
"""Test that messages can be created from dictionaries."""
base_messages = messages_from_dict([input_message])
assert len(base_messages) == 1
assert base_messages[0] == expected_message
def test_message_chunk_to_message() -> None:
assert message_chunk_to_message(
AIMessageChunk(content="I am", additional_kwargs={"foo": "bar"})
@@ -699,6 +746,7 @@ def test_convert_to_messages() -> None:
actual = convert_to_messages(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "developer", "content": "You are a helpful assistant 2."},
{"role": "user", "content": "Hello!"},
{"role": "ai", "content": "Hi!", "id": "ai1"},
{"type": "human", "content": "Hello!", "name": "Jane", "id": "human1"},
@@ -733,6 +781,7 @@ def test_convert_to_messages() -> None:
)
expected = [
SystemMessage(content="You are a helpful assistant."),
SystemMessage(content="You are a helpful assistant 2."),
HumanMessage(content="Hello!"),
AIMessage(content="Hi!", id="ai1"),
HumanMessage(content="Hello!", name="Jane", id="human1"),