diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 7e11c491d6f..32f198532e6 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -812,9 +812,11 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): ls_params["ls_model_name"] = self.model_name # temperature - if "temperature" in kwargs and isinstance(kwargs["temperature"], float): + if "temperature" in kwargs and isinstance(kwargs["temperature"], (int, float)): ls_params["ls_temperature"] = kwargs["temperature"] - elif hasattr(self, "temperature") and isinstance(self.temperature, float): + elif hasattr(self, "temperature") and isinstance( + self.temperature, (int, float) + ): ls_params["ls_temperature"] = self.temperature # max_tokens diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 5aa287ada8e..fa034675d89 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -351,9 +351,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): ls_params["ls_model_name"] = self.model_name # temperature - if "temperature" in kwargs and isinstance(kwargs["temperature"], float): + if "temperature" in kwargs and isinstance(kwargs["temperature"], (int, float)): ls_params["ls_temperature"] = kwargs["temperature"] - elif hasattr(self, "temperature") and isinstance(self.temperature, float): + elif hasattr(self, "temperature") and isinstance( + self.temperature, (int, float) + ): ls_params["ls_temperature"] = self.temperature # max_tokens diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index eedb5d393a8..51fbecb8ffe 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -1213,6 +1213,40 @@ def test_get_ls_params() -> None: assert ls_params["ls_stop"] == ["stop"] +def test_get_ls_params_int_temperature() -> None: + class IntTempModel(BaseChatModel): + model: str = "foo" + temperature: int = 0 + max_tokens: int = 1024 + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + raise NotImplementedError + + @property + def _llm_type(self) -> str: + return "fake-chat-model" + + llm = IntTempModel() + + # Integer temperature from self attribute + ls_params = llm._get_ls_params() + assert ls_params["ls_temperature"] == 0 + + # Integer temperature from kwargs + ls_params = llm._get_ls_params(temperature=1) + assert ls_params["ls_temperature"] == 1 + + # Float temperature from kwargs still works + ls_params = llm._get_ls_params(temperature=0.5) + assert ls_params["ls_temperature"] == 0.5 + + def test_model_profiles() -> None: model = GenericFakeChatModel(messages=iter([])) assert model.profile is None diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py index e5547a617a7..fb2ae8dc321 100644 --- a/libs/core/tests/unit_tests/language_models/llms/test_base.py +++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py @@ -277,3 +277,38 @@ def test_get_ls_params() -> None: ls_params = llm._get_ls_params(stop=["stop"]) assert ls_params["ls_stop"] == ["stop"] + + +def test_get_ls_params_int_temperature() -> None: + class IntTempModel(BaseLLM): + model: str = "foo" + temperature: int = 0 + max_tokens: int = 1024 + + @override + def _generate( + self, + prompts: list[str], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> LLMResult: + raise NotImplementedError + + @property + def _llm_type(self) -> str: + return "fake-model" + + llm = IntTempModel() + + # Integer temperature from self attribute + ls_params = llm._get_ls_params() + assert ls_params["ls_temperature"] == 0 + + # Integer temperature from kwargs + ls_params = llm._get_ls_params(temperature=1) + assert ls_params["ls_temperature"] == 1 + + # Float temperature from kwargs still works + ls_params = llm._get_ls_params(temperature=0.5) + assert ls_params["ls_temperature"] == 0.5