mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
revert: accept integer temperature values in _get_ls_params (#35319)
This commit is contained in:
@@ -812,11 +812,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
ls_params["ls_model_name"] = self.model_name
|
||||
|
||||
# temperature
|
||||
if "temperature" in kwargs and isinstance(kwargs["temperature"], (int, float)):
|
||||
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
|
||||
ls_params["ls_temperature"] = kwargs["temperature"]
|
||||
elif hasattr(self, "temperature") and isinstance(
|
||||
self.temperature, (int, float)
|
||||
):
|
||||
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
|
||||
ls_params["ls_temperature"] = self.temperature
|
||||
|
||||
# max_tokens
|
||||
|
||||
@@ -351,11 +351,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
ls_params["ls_model_name"] = self.model_name
|
||||
|
||||
# temperature
|
||||
if "temperature" in kwargs and isinstance(kwargs["temperature"], (int, float)):
|
||||
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
|
||||
ls_params["ls_temperature"] = kwargs["temperature"]
|
||||
elif hasattr(self, "temperature") and isinstance(
|
||||
self.temperature, (int, float)
|
||||
):
|
||||
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
|
||||
ls_params["ls_temperature"] = self.temperature
|
||||
|
||||
# max_tokens
|
||||
|
||||
@@ -1213,40 +1213,6 @@ def test_get_ls_params() -> None:
|
||||
assert ls_params["ls_stop"] == ["stop"]
|
||||
|
||||
|
||||
def test_get_ls_params_int_temperature() -> None:
|
||||
class IntTempModel(BaseChatModel):
|
||||
model: str = "foo"
|
||||
temperature: int = 0
|
||||
max_tokens: int = 1024
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-chat-model"
|
||||
|
||||
llm = IntTempModel()
|
||||
|
||||
# Integer temperature from self attribute
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params["ls_temperature"] == 0
|
||||
|
||||
# Integer temperature from kwargs
|
||||
ls_params = llm._get_ls_params(temperature=1)
|
||||
assert ls_params["ls_temperature"] == 1
|
||||
|
||||
# Float temperature from kwargs still works
|
||||
ls_params = llm._get_ls_params(temperature=0.5)
|
||||
assert ls_params["ls_temperature"] == 0.5
|
||||
|
||||
|
||||
def test_model_profiles() -> None:
|
||||
model = GenericFakeChatModel(messages=iter([]))
|
||||
assert model.profile is None
|
||||
|
||||
@@ -277,38 +277,3 @@ def test_get_ls_params() -> None:
|
||||
|
||||
ls_params = llm._get_ls_params(stop=["stop"])
|
||||
assert ls_params["ls_stop"] == ["stop"]
|
||||
|
||||
|
||||
def test_get_ls_params_int_temperature() -> None:
|
||||
class IntTempModel(BaseLLM):
|
||||
model: str = "foo"
|
||||
temperature: int = 0
|
||||
max_tokens: int = 1024
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-model"
|
||||
|
||||
llm = IntTempModel()
|
||||
|
||||
# Integer temperature from self attribute
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params["ls_temperature"] == 0
|
||||
|
||||
# Integer temperature from kwargs
|
||||
ls_params = llm._get_ls_params(temperature=1)
|
||||
assert ls_params["ls_temperature"] == 1
|
||||
|
||||
# Float temperature from kwargs still works
|
||||
ls_params = llm._get_ls_params(temperature=0.5)
|
||||
assert ls_params["ls_temperature"] == 0.5
|
||||
|
||||
Reference in New Issue
Block a user