openai: developer message just openai

Erick Friis 2024-12-18 11:54:29 -05:00
parent f723a8456e
commit 1cc21e60a9
2 changed files with 58 additions and 5 deletions


@@ -139,7 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
-    elif role == "system":
+    elif role == "system" or role == "developer":
         return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
     elif role == "function":
         return FunctionMessage(
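With the broadened branch above, a message dict whose role is "developer" is parsed back into a SystemMessage, exactly like "system". A minimal sketch of that behavior (not part of the commit, and it imports a private helper, so treat it as illustrative only):

from langchain_core.messages import SystemMessage

# Private helper shown in the hunk above; not public API.
from langchain_openai.chat_models.base import _convert_dict_to_message

# Both roles now normalize to the same LangChain message type.
for role in ("system", "developer"):
    msg = _convert_dict_to_message({"role": role, "content": "You are terse."})
    assert isinstance(msg, SystemMessage)
    assert msg.content == "You are terse."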
@@ -180,7 +180,9 @@ def _format_message_content(content: Any) -> Any:
     return formatted_content


-def _convert_message_to_dict(message: BaseMessage) -> dict:
+def _convert_message_to_dict(
+    message: BaseMessage, *, system_message_role: str = "system"
+) -> dict:
     """Convert a LangChain message to a dictionary.

     Args:
@@ -233,7 +235,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
             )
             message_dict["audio"] = audio
     elif isinstance(message, SystemMessage):
-        message_dict["role"] = "system"
+        message_dict["role"] = system_message_role
     elif isinstance(message, FunctionMessage):
         message_dict["role"] = "function"
     elif isinstance(message, ToolMessage):
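The role string is now injected at this single point, so each caller chooses it per conversion. A quick sketch of the new keyword (not part of the commit; it calls the private helper directly):

from langchain_core.messages import SystemMessage

from langchain_openai.chat_models.base import _convert_message_to_dict  # private helper

sys_msg = SystemMessage("You are a careful reviewer.")

# Default keeps the pre-existing behavior.
assert _convert_message_to_dict(sys_msg)["role"] == "system"

# Opting into the o1-style role only swaps the "role" value; content is untouched.
as_developer = _convert_message_to_dict(sys_msg, system_message_role="developer")
assert as_developer["role"] == "developer"
assert as_developer["content"] == "You are a careful reviewer."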
@@ -482,6 +484,7 @@ class BaseChatOpenAI(BaseChatModel):
     However this does not prevent a user from directly passed in the parameter during
     invocation.
     """
+    system_message_role: str = "system"

     model_config = ConfigDict(populate_by_name=True)
@@ -493,6 +496,20 @@ class BaseChatOpenAI(BaseChatModel):
         values = _build_model_kwargs(values, all_required_field_names)
         return values

+    @model_validator(mode="before")
+    @classmethod
+    def validate_system_message_role(cls, values: Dict[str, Any]) -> Any:
+        """Ensure that the system message role is correctly set for the model."""
+        if "system_message_role" in values:
+            return values
+        model = values.get("model_name") or values.get("model") or ""
+        if model.startswith("o1"):
+            values["system_message_role"] = "developer"
+        # otherwise default is "system"
+        return values
+
     @model_validator(mode="before")
     @classmethod
     def validate_temperature(cls, values: Dict[str, Any]) -> Any:
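The validator only fills in a default: an explicitly passed system_message_role always wins, and model names starting with "o1" default to "developer". A sketch of the intended behavior (a placeholder API key is set purely so the client can be constructed; nothing is invoked):

import os

from langchain_openai import ChatOpenAI

os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")  # construction only; no request is sent

assert ChatOpenAI(model="o1").system_message_role == "developer"
assert ChatOpenAI(model="gpt-4o").system_message_role == "system"

# An explicit value is respected regardless of the model prefix.
assert ChatOpenAI(model="o1", system_message_role="system").system_message_role == "system"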
@@ -701,7 +718,12 @@ class BaseChatOpenAI(BaseChatModel):
             kwargs["stop"] = stop
         return {
-            "messages": [_convert_message_to_dict(m) for m in messages],
+            "messages": [
+                _convert_message_to_dict(
+                    m, system_message_role=self.system_message_role
+                )
+                for m in messages
+            ],
             **self._default_params,
             **kwargs,
         }
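Downstream, the outgoing request reflects whatever role was configured or defaulted. A rough illustration via the payload builder this hunk sits in (hedged: _get_request_payload is internal, so its name and shape may change):

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="o1", api_key="sk-placeholder")  # role defaults to "developer"
payload = llm._get_request_payload(
    [SystemMessage("Answer in one sentence."), HumanMessage("What is a diff hunk?")]
)
# The SystemMessage is serialized with the configured role.
assert payload["messages"][0]["role"] == "developer"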
@@ -936,7 +958,10 @@ class BaseChatOpenAI(BaseChatModel):
                 " for information on how messages are converted to tokens."
             )
         num_tokens = 0
-        messages_dict = [_convert_message_to_dict(m) for m in messages]
+        messages_dict = [
+            _convert_message_to_dict(m, system_message_role=self.system_message_role)
+            for m in messages
+        ]
         for message in messages_dict:
             num_tokens += tokens_per_message
             for key, value in message.items():
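Token counting runs through the same conversion, so local estimates line up with the payload that would actually be sent. A small sketch (no API call is made; get_num_tokens_from_messages only handles gpt-3.5-/gpt-4-prefixed model names, so the role is set explicitly here):

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="gpt-4o", system_message_role="developer", api_key="sk-placeholder"
)
messages = [SystemMessage("Be terse."), HumanMessage("Hi")]

# Counting serializes messages with the configured role before tokenizing them.
assert llm.get_num_tokens_from_messages(messages) > 0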


@@ -1097,3 +1097,31 @@ def test_o1_max_tokens() -> None:
         "how are you"
     )
     assert isinstance(response, AIMessage)
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "gpt-4o",
+        "gpt-4o-mini",
+        "o1",
+        # "o1-mini", neither supported
+        "gpt-3.5-turbo",
+    ],
+)
+@pytest.mark.parametrize("role", ["system", "developer", None])
+def test_system_message_roles(model: str, role: Optional[str]) -> None:
+    init_kwargs = {"model": model}
+    if role is not None:
+        init_kwargs["system_message_role"] = role
+    llm = ChatOpenAI(**init_kwargs)  # type: ignore[arg-type]
+    if role is None:
+        if model.startswith("o1"):
+            assert llm.system_message_role == "developer"
+        else:
+            assert llm.system_message_role == "system"
+    history = [SystemMessage("You talk like a pirate"), HumanMessage("Hello there")]
+    out = llm.invoke(history)
+    assert isinstance(out, AIMessage)
+    assert out.content