mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
openai: developer message just openai
This commit is contained in:
parent
f723a8456e
commit
1cc21e60a9
@ -139,7 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
)
|
||||
elif role == "system":
|
||||
elif role == "system" or role == "developer":
|
||||
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
|
||||
elif role == "function":
|
||||
return FunctionMessage(
|
||||
@ -180,7 +180,9 @@ def _format_message_content(content: Any) -> Any:
|
||||
return formatted_content
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
def _convert_message_to_dict(
|
||||
message: BaseMessage, *, system_message_role: str = "system"
|
||||
) -> dict:
|
||||
"""Convert a LangChain message to a dictionary.
|
||||
|
||||
Args:
|
||||
@ -233,7 +235,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
)
|
||||
message_dict["audio"] = audio
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict["role"] = "system"
|
||||
message_dict["role"] = system_message_role
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict["role"] = "function"
|
||||
elif isinstance(message, ToolMessage):
|
||||
@ -482,6 +484,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
However this does not prevent a user from directly passed in the parameter during
|
||||
invocation.
|
||||
"""
|
||||
system_message_role: str = "system"
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
@ -493,6 +496,20 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_system_message_role(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Ensure that the system message role is correctly set for the model."""
|
||||
if "system_message_role" in values:
|
||||
return values
|
||||
|
||||
model = values.get("model_name") or values.get("model") or ""
|
||||
if model.startswith("o1"):
|
||||
values["system_message_role"] = "developer"
|
||||
# otherwise default is "system"
|
||||
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_temperature(cls, values: Dict[str, Any]) -> Any:
|
||||
@ -701,7 +718,12 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
kwargs["stop"] = stop
|
||||
|
||||
return {
|
||||
"messages": [_convert_message_to_dict(m) for m in messages],
|
||||
"messages": [
|
||||
_convert_message_to_dict(
|
||||
m, system_message_role=self.system_message_role
|
||||
)
|
||||
for m in messages
|
||||
],
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
}
|
||||
@ -936,7 +958,10 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
" for information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||
messages_dict = [
|
||||
_convert_message_to_dict(m, system_message_role=self.system_message_role)
|
||||
for m in messages
|
||||
]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
|
@ -1097,3 +1097,31 @@ def test_o1_max_tokens() -> None:
|
||||
"how are you"
|
||||
)
|
||||
assert isinstance(response, AIMessage)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"o1",
|
||||
# "o1-mini", neither supported
|
||||
"gpt-3.5-turbo",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("role", ["system", "developer", None])
|
||||
def test_system_message_roles(model: str, role: Optional[str]) -> None:
|
||||
init_kwargs = {"model": model}
|
||||
if role is not None:
|
||||
init_kwargs["system_message_role"] = role
|
||||
llm = ChatOpenAI(**init_kwargs) # type: ignore[arg-type]
|
||||
if role is None:
|
||||
if model.startswith("o1"):
|
||||
assert llm.system_message_role == "developer"
|
||||
else:
|
||||
assert llm.system_message_role == "system"
|
||||
history = [SystemMessage("You talk like a pirate"), HumanMessage("Hello there")]
|
||||
|
||||
out = llm.invoke(history)
|
||||
assert isinstance(out, AIMessage)
|
||||
assert out.content
|
||||
|
Loading…
Reference in New Issue
Block a user