From ce33c4fa407c5e1537858c7d2aadd9dde07f0602 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 8 Oct 2024 15:45:21 -0700 Subject: [PATCH] openai[patch]: default temp=1 for o1 (#27206) --- .../partners/openai/langchain_openai/chat_models/base.py | 9 +++++++++ .../openai/tests/unit_tests/chat_models/test_base.py | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index baaa74f637b..233fa8700a4 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -480,6 +480,15 @@ class BaseChatOpenAI(BaseChatModel): values = _build_model_kwargs(values, all_required_field_names) return values + @model_validator(mode="before") + @classmethod + def validate_temperature(cls, values: Dict[str, Any]) -> Any: + """Currently o1 models only allow temperature=1.""" + model = values.get("model_name") or values.get("model") or "" + if model.startswith("o1") and "temperature" not in values: + values["temperature"] = 1 + return values + @model_validator(mode="after") def validate_environment(self) -> Self: """Validate that api key and python package exists in environment.""" diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index cc03698e5ef..c6065a137e5 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -35,6 +35,13 @@ def test_openai_model_param() -> None: assert llm.model_name == "foo" +def test_openai_o1_temperature() -> None: + llm = ChatOpenAI(model="o1-preview") + assert llm.temperature == 1 + llm = ChatOpenAI(model_name="o1-mini") # type: ignore[call-arg] + assert llm.temperature == 1 + + def test_function_message_dict_to_function_message() -> None: content = json.dumps({"result": "Example #1"}) name = "test_function"