diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index e312ed249c8..fdcccdd8511 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -139,7 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
-    elif role == "system":
+    elif role == "system" or role == "developer":
         return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
     elif role == "function":
         return FunctionMessage(
@@ -180,7 +180,9 @@ def _format_message_content(content: Any) -> Any:
     return formatted_content
 
 
-def _convert_message_to_dict(message: BaseMessage) -> dict:
+def _convert_message_to_dict(
+    message: BaseMessage, *, system_message_role: str = "system"
+) -> dict:
     """Convert a LangChain message to a dictionary.
 
     Args:
@@ -233,7 +235,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
             )
             message_dict["audio"] = audio
     elif isinstance(message, SystemMessage):
-        message_dict["role"] = "system"
+        message_dict["role"] = system_message_role
     elif isinstance(message, FunctionMessage):
         message_dict["role"] = "function"
     elif isinstance(message, ToolMessage):
@@ -482,6 +484,7 @@ class BaseChatOpenAI(BaseChatModel):
     However this does not prevent a user from directly passing in the
     parameter during invocation.
     """
+    system_message_role: str = "system"
 
     model_config = ConfigDict(populate_by_name=True)
 
@@ -493,6 +496,20 @@ class BaseChatOpenAI(BaseChatModel):
         values = _build_model_kwargs(values, all_required_field_names)
         return values
 
+    @model_validator(mode="before")
+    @classmethod
+    def validate_system_message_role(cls, values: Dict[str, Any]) -> Any:
+        """Ensure that the system message role is correctly set for the model."""
+        if "system_message_role" in values:
+            return values
+
+        model = values.get("model_name") or values.get("model") or ""
+        if model.startswith("o1"):
+            values["system_message_role"] = "developer"
+        # otherwise the default is "system"
+
+        return values
+
     @model_validator(mode="before")
     @classmethod
     def validate_temperature(cls, values: Dict[str, Any]) -> Any:
@@ -701,7 +718,12 @@ class BaseChatOpenAI(BaseChatModel):
             kwargs["stop"] = stop
 
         return {
-            "messages": [_convert_message_to_dict(m) for m in messages],
+            "messages": [
+                _convert_message_to_dict(
+                    m, system_message_role=self.system_message_role
+                )
+                for m in messages
+            ],
             **self._default_params,
             **kwargs,
         }
@@ -936,7 +958,10 @@ class BaseChatOpenAI(BaseChatModel):
                 " for information on how messages are converted to tokens."
             )
         num_tokens = 0
-        messages_dict = [_convert_message_to_dict(m) for m in messages]
+        messages_dict = [
+            _convert_message_to_dict(m, system_message_role=self.system_message_role)
+            for m in messages
+        ]
         for message in messages_dict:
             num_tokens += tokens_per_message
             for key, value in message.items():
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 1204ccef87c..44ada843884 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -1097,3 +1097,31 @@ def test_o1_max_tokens() -> None:
         "how are you"
     )
     assert isinstance(response, AIMessage)
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "gpt-4o",
+        "gpt-4o-mini",
+        "o1",
+        # "o1-mini",  # supports neither "system" nor "developer"
+        "gpt-3.5-turbo",
+    ],
+)
+@pytest.mark.parametrize("role", ["system", "developer", None])
+def test_system_message_roles(model: str, role: Optional[str]) -> None:
+    init_kwargs = {"model": model}
+    if role is not None:
+        init_kwargs["system_message_role"] = role
+    llm = ChatOpenAI(**init_kwargs)  # type: ignore[arg-type]
+    if role is None:
+        if model.startswith("o1"):
+            assert llm.system_message_role == "developer"
+        else:
+            assert llm.system_message_role == "system"
+    history = [SystemMessage("You talk like a pirate"), HumanMessage("Hello there")]
+
+    out = llm.invoke(history)
+    assert isinstance(out, AIMessage)
+    assert out.content
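
Usage sketch (not part of the diff): a minimal, illustrative example of the behavior the change introduces. It assumes only what the diff shows — the new `system_message_role` constructor parameter and the "developer" default inferred for o1 models — and, like the integration test, the invoke call requires a live OpenAI API key.

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

# Explicit override: SystemMessage is serialized with the "developer" role.
llm = ChatOpenAI(model="gpt-4o", system_message_role="developer")

# For o1 models, validate_system_message_role infers "developer" automatically
# when no explicit role is passed.
o1 = ChatOpenAI(model="o1")
assert o1.system_message_role == "developer"

# The configured role is applied when messages are converted to dicts in
# _get_request_payload, so the request body carries something like
# {"role": "developer", "content": "You talk like a pirate"}.
response = llm.invoke(
    [SystemMessage("You talk like a pirate"), HumanMessage("Hello there")]
)
print(response.content)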