openai: developer message just openai

This commit is contained in:
Erick Friis 2024-12-18 11:54:29 -05:00
parent f723a8456e
commit 1cc21e60a9
2 changed files with 58 additions and 5 deletions

View File

@ -139,7 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
tool_calls=tool_calls, tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls, invalid_tool_calls=invalid_tool_calls,
) )
elif role == "system": elif role == "system" or role == "developer":
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_) return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
elif role == "function": elif role == "function":
return FunctionMessage( return FunctionMessage(
@ -180,7 +180,9 @@ def _format_message_content(content: Any) -> Any:
return formatted_content return formatted_content
def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_message_to_dict(
message: BaseMessage, *, system_message_role: str = "system"
) -> dict:
"""Convert a LangChain message to a dictionary. """Convert a LangChain message to a dictionary.
Args: Args:
@ -233,7 +235,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
) )
message_dict["audio"] = audio message_dict["audio"] = audio
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
message_dict["role"] = "system" message_dict["role"] = system_message_role
elif isinstance(message, FunctionMessage): elif isinstance(message, FunctionMessage):
message_dict["role"] = "function" message_dict["role"] = "function"
elif isinstance(message, ToolMessage): elif isinstance(message, ToolMessage):
@ -482,6 +484,7 @@ class BaseChatOpenAI(BaseChatModel):
However this does not prevent a user from directly passing in the parameter during However this does not prevent a user from directly passing in the parameter during
invocation. invocation.
""" """
system_message_role: str = "system"
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)
@ -493,6 +496,20 @@ class BaseChatOpenAI(BaseChatModel):
values = _build_model_kwargs(values, all_required_field_names) values = _build_model_kwargs(values, all_required_field_names)
return values return values
@model_validator(mode="before")
@classmethod
def validate_system_message_role(cls, values: Dict[str, Any]) -> Any:
    """Infer the role used for system messages when not set explicitly.

    ``o1``-family models expect the ``"developer"`` role in place of
    ``"system"``; every other model keeps the default ``"system"``.
    """
    # An explicit, user-supplied role always wins — only infer when absent.
    if "system_message_role" not in values:
        model_name = values.get("model_name") or values.get("model") or ""
        if model_name.startswith("o1"):
            values["system_message_role"] = "developer"
        # otherwise the field default ("system") remains in effect
    return values
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_temperature(cls, values: Dict[str, Any]) -> Any: def validate_temperature(cls, values: Dict[str, Any]) -> Any:
@ -701,7 +718,12 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stop"] = stop kwargs["stop"] = stop
return { return {
"messages": [_convert_message_to_dict(m) for m in messages], "messages": [
_convert_message_to_dict(
m, system_message_role=self.system_message_role
)
for m in messages
],
**self._default_params, **self._default_params,
**kwargs, **kwargs,
} }
@ -936,7 +958,10 @@ class BaseChatOpenAI(BaseChatModel):
" for information on how messages are converted to tokens." " for information on how messages are converted to tokens."
) )
num_tokens = 0 num_tokens = 0
messages_dict = [_convert_message_to_dict(m) for m in messages] messages_dict = [
_convert_message_to_dict(m, system_message_role=self.system_message_role)
for m in messages
]
for message in messages_dict: for message in messages_dict:
num_tokens += tokens_per_message num_tokens += tokens_per_message
for key, value in message.items(): for key, value in message.items():

View File

@ -1097,3 +1097,31 @@ def test_o1_max_tokens() -> None:
"how are you" "how are you"
) )
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
@pytest.mark.parametrize(
    "model",
    [
        "gpt-4o",
        "gpt-4o-mini",
        "o1",
        # "o1-mini", neither supported
        "gpt-3.5-turbo",
    ],
)
@pytest.mark.parametrize("role", ["system", "developer", None])
def test_system_message_roles(model: str, role: Optional[str]) -> None:
    """Verify system-message role inference and a live chat completion.

    When ``role`` is None, the model-based default is checked: o1-family
    models infer ``"developer"``, all others ``"system"``.
    """
    kwargs: dict = {"model": model}
    if role is not None:
        kwargs["system_message_role"] = role
    llm = ChatOpenAI(**kwargs)  # type: ignore[arg-type]
    if role is None:
        expected = "developer" if model.startswith("o1") else "system"
        assert llm.system_message_role == expected
    conversation = [
        SystemMessage("You talk like a pirate"),
        HumanMessage("Hello there"),
    ]
    reply = llm.invoke(conversation)
    assert isinstance(reply, AIMessage)
    assert reply.content