core[patch], openai[patch]: Handle OpenAI developer msg (#28794)

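- Convert OpenAI messages with the "developer" role to SystemMessage.
- Store additional_kwargs={"__openai_role__": "developer"} on the converted
message so that the correct role can be reconstructed if needed.
- Update ChatOpenAI to read "__openai_role__" from a SystemMessage's
additional_kwargs when building the request, falling back to "system".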

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Bagatur
2024-12-18 13:54:07 -08:00
committed by GitHub
parent 43b0736a51
commit 4a531437bb
5 changed files with 96 additions and 10 deletions

View File

@@ -139,8 +139,17 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
elif role in ("system", "developer"):
if role == "developer":
additional_kwargs = {"__openai_role__": role}
else:
additional_kwargs = {}
return SystemMessage(
content=_dict.get("content", ""),
name=name,
id=id_,
additional_kwargs=additional_kwargs,
)
elif role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
@@ -233,7 +242,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
)
message_dict["audio"] = audio
elif isinstance(message, SystemMessage):
message_dict["role"] = "system"
message_dict["role"] = message.additional_kwargs.get(
"__openai_role__", "system"
)
elif isinstance(message, FunctionMessage):
message_dict["role"] = "function"
elif isinstance(message, ToolMessage):
@@ -284,8 +295,14 @@ def _convert_delta_to_message_chunk(
id=id_,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content, id=id_)
elif role in ("system", "developer") or default_class == SystemMessageChunk:
if role == "developer":
additional_kwargs = {"__openai_role__": "developer"}
else:
additional_kwargs = {}
return SystemMessageChunk(
content=content, id=id_, additional_kwargs=additional_kwargs
)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
elif role == "tool" or default_class == ToolMessageChunk:

View File

@@ -1097,3 +1097,16 @@ def test_o1_max_tokens() -> None:
"how are you"
)
assert isinstance(response, AIMessage)
def test_developer_message() -> None:
    """Invoke the o1 model with a developer-role message and check the reply.

    The developer instruction asks for an all-caps answer, so the returned
    text must be unchanged by upper-casing.
    """
    conversation = [
        {"role": "developer", "content": "respond in all caps"},
        {"role": "user", "content": "HOW ARE YOU"},
    ]
    chat_model = ChatOpenAI(model="o1", max_tokens=10)  # type: ignore[call-arg]
    reply = chat_model.invoke(conversation)

    assert isinstance(reply, AIMessage)
    text = reply.content
    assert isinstance(text, str)
    # An all-caps string is a fixed point of str.upper().
    assert text.upper() == text

View File

@@ -100,6 +100,16 @@ def test__convert_dict_to_message_system() -> None:
assert _convert_message_to_dict(expected_output) == message
def test__convert_dict_to_message_developer() -> None:
    """Round-trip a developer-role dict through SystemMessage and back."""
    developer_dict = {"role": "developer", "content": "foo"}
    expected_message = SystemMessage(
        content="foo",
        additional_kwargs={"__openai_role__": "developer"},
    )

    # dict -> message: the original role is kept in additional_kwargs.
    converted = _convert_dict_to_message(developer_dict)
    assert converted == expected_message

    # message -> dict: the stored role is emitted instead of "system".
    assert _convert_message_to_dict(expected_message) == developer_dict
def test__convert_dict_to_message_system_with_name() -> None:
message = {"role": "system", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
@@ -850,3 +860,25 @@ def test_nested_structured_output_strict() -> None:
self_evaluation: SelfEvaluation
llm.with_structured_output(JokeWithEvaluation, method="json_schema")
def test__get_request_payload() -> None:
    """Check the roles written into the request payload.

    A plain SystemMessage maps to "system", a SystemMessage tagged with
    ``__openai_role__`` maps to that role, and a "human" dict maps to "user".
    """
    model_name = "gpt-4o-2024-08-06"
    chat_model = ChatOpenAI(model=model_name)

    conversation: list = [
        SystemMessage("hello"),
        SystemMessage("bye", additional_kwargs={"__openai_role__": "developer"}),
        {"role": "human", "content": "how are you"},
    ]
    expected_messages = [
        {"role": "system", "content": "hello"},
        {"role": "developer", "content": "bye"},
        {"role": "user", "content": "how are you"},
    ]
    expected_payload = {
        "messages": expected_messages,
        "model": model_name,
        "stream": False,
        "n": 1,
        "temperature": 0.7,
    }

    actual_payload = chat_model._get_request_payload(conversation)
    assert actual_payload == expected_payload