mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
openai: developer message just openai
This commit is contained in:
parent
f723a8456e
commit
1cc21e60a9
@ -139,7 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
invalid_tool_calls=invalid_tool_calls,
|
invalid_tool_calls=invalid_tool_calls,
|
||||||
)
|
)
|
||||||
elif role == "system":
|
elif role == "system" or role == "developer":
|
||||||
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
|
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
|
||||||
elif role == "function":
|
elif role == "function":
|
||||||
return FunctionMessage(
|
return FunctionMessage(
|
||||||
@ -180,7 +180,9 @@ def _format_message_content(content: Any) -> Any:
|
|||||||
return formatted_content
|
return formatted_content
|
||||||
|
|
||||||
|
|
||||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
def _convert_message_to_dict(
|
||||||
|
message: BaseMessage, *, system_message_role: str = "system"
|
||||||
|
) -> dict:
|
||||||
"""Convert a LangChain message to a dictionary.
|
"""Convert a LangChain message to a dictionary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -233,7 +235,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
)
|
)
|
||||||
message_dict["audio"] = audio
|
message_dict["audio"] = audio
|
||||||
elif isinstance(message, SystemMessage):
|
elif isinstance(message, SystemMessage):
|
||||||
message_dict["role"] = "system"
|
message_dict["role"] = system_message_role
|
||||||
elif isinstance(message, FunctionMessage):
|
elif isinstance(message, FunctionMessage):
|
||||||
message_dict["role"] = "function"
|
message_dict["role"] = "function"
|
||||||
elif isinstance(message, ToolMessage):
|
elif isinstance(message, ToolMessage):
|
||||||
@ -482,6 +484,7 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
However this does not prevent a user from directly passed in the parameter during
|
However this does not prevent a user from directly passed in the parameter during
|
||||||
invocation.
|
invocation.
|
||||||
"""
|
"""
|
||||||
|
system_message_role: str = "system"
|
||||||
|
|
||||||
model_config = ConfigDict(populate_by_name=True)
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
@ -493,6 +496,20 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
values = _build_model_kwargs(values, all_required_field_names)
|
values = _build_model_kwargs(values, all_required_field_names)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_system_message_role(cls, values: Dict[str, Any]) -> Any:
|
||||||
|
"""Ensure that the system message role is correctly set for the model."""
|
||||||
|
if "system_message_role" in values:
|
||||||
|
return values
|
||||||
|
|
||||||
|
model = values.get("model_name") or values.get("model") or ""
|
||||||
|
if model.startswith("o1"):
|
||||||
|
values["system_message_role"] = "developer"
|
||||||
|
# otherwise default is "system"
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_temperature(cls, values: Dict[str, Any]) -> Any:
|
def validate_temperature(cls, values: Dict[str, Any]) -> Any:
|
||||||
@ -701,7 +718,12 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
kwargs["stop"] = stop
|
kwargs["stop"] = stop
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [_convert_message_to_dict(m) for m in messages],
|
"messages": [
|
||||||
|
_convert_message_to_dict(
|
||||||
|
m, system_message_role=self.system_message_role
|
||||||
|
)
|
||||||
|
for m in messages
|
||||||
|
],
|
||||||
**self._default_params,
|
**self._default_params,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
@ -936,7 +958,10 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
" for information on how messages are converted to tokens."
|
" for information on how messages are converted to tokens."
|
||||||
)
|
)
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
messages_dict = [
|
||||||
|
_convert_message_to_dict(m, system_message_role=self.system_message_role)
|
||||||
|
for m in messages
|
||||||
|
]
|
||||||
for message in messages_dict:
|
for message in messages_dict:
|
||||||
num_tokens += tokens_per_message
|
num_tokens += tokens_per_message
|
||||||
for key, value in message.items():
|
for key, value in message.items():
|
||||||
|
@ -1097,3 +1097,31 @@ def test_o1_max_tokens() -> None:
|
|||||||
"how are you"
|
"how are you"
|
||||||
)
|
)
|
||||||
assert isinstance(response, AIMessage)
|
assert isinstance(response, AIMessage)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"gpt-4o",
|
||||||
|
"gpt-4o-mini",
|
||||||
|
"o1",
|
||||||
|
# "o1-mini", neither supported
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("role", ["system", "developer", None])
|
||||||
|
def test_system_message_roles(model: str, role: Optional[str]) -> None:
|
||||||
|
init_kwargs = {"model": model}
|
||||||
|
if role is not None:
|
||||||
|
init_kwargs["system_message_role"] = role
|
||||||
|
llm = ChatOpenAI(**init_kwargs) # type: ignore[arg-type]
|
||||||
|
if role is None:
|
||||||
|
if model.startswith("o1"):
|
||||||
|
assert llm.system_message_role == "developer"
|
||||||
|
else:
|
||||||
|
assert llm.system_message_role == "system"
|
||||||
|
history = [SystemMessage("You talk like a pirate"), HumanMessage("Hello there")]
|
||||||
|
|
||||||
|
out = llm.invoke(history)
|
||||||
|
assert isinstance(out, AIMessage)
|
||||||
|
assert out.content
|
||||||
|
Loading…
Reference in New Issue
Block a user