openai[patch]: update system role to developer for o-series models (#29785)

Some o-series models will raise a 400 error for `"role": "system"`
(`o1-mini` and `o1-preview` will raise, `o1` and `o3-mini` will not).

Here we update `ChatOpenAI` to update the role to `"developer"` for all
model names matching `^o\d`.

We only make this change on the ChatOpenAI class (not BaseChatOpenAI).
This commit is contained in:
ccurme 2025-02-24 08:59:46 -05:00 committed by GitHub
parent 8b511a3a78
commit 927ec20b69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 0 deletions

View File

@ -6,6 +6,7 @@ import base64
import json import json
import logging import logging
import os import os
import re
import sys import sys
import warnings import warnings
from io import BytesIO from io import BytesIO
@ -2011,6 +2012,12 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
# in September 2024 release # in September 2024 release
if "max_tokens" in payload: if "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens") payload["max_completion_tokens"] = payload.pop("max_tokens")
# Mutate system message role to "developer" for o-series models
if self.model_name and re.match(r"^o\d", self.model_name):
for message in payload.get("messages", []):
if message["role"] == "system":
message["role"] = "developer"
return payload return payload
def _should_stream_usage( def _should_stream_usage(

View File

@ -881,6 +881,20 @@ def test__get_request_payload() -> None:
payload = llm._get_request_payload(messages) payload = llm._get_request_payload(messages)
assert payload == expected assert payload == expected
# Test we coerce to developer role for o-series models
llm = ChatOpenAI(model="o3-mini")
payload = llm._get_request_payload(messages)
expected = {
"messages": [
{"role": "developer", "content": "hello"},
{"role": "developer", "content": "bye"},
{"role": "user", "content": "how are you"},
],
"model": "o3-mini",
"stream": False,
}
assert payload == expected
def test_init_o1() -> None: def test_init_o1() -> None:
with pytest.warns(None) as record: # type: ignore[call-overload] with pytest.warns(None) as record: # type: ignore[call-overload]