From e195b78e1dbf0498a2475673ef1671d11c0730c4 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 14 Sep 2023 14:43:42 -0700 Subject: [PATCH] Fix replicate model kwargs (#10599) --- libs/langchain/langchain/llms/replicate.py | 21 +++++++++++++------ .../integration_tests/llms/test_replicate.py | 18 ++++++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/llms/replicate.py b/libs/langchain/langchain/llms/replicate.py index 04e3863ef24..5d407c40b41 100644 --- a/libs/langchain/langchain/llms/replicate.py +++ b/libs/langchain/langchain/llms/replicate.py @@ -23,7 +23,7 @@ class Replicate(LLM): You can find your token here: https://replicate.com/account The model param is required, but any other model parameters can also - be passed in with the format input={model_param: value, ...} + be passed in with the format model_kwargs={model_param: value, ...} Example: .. code-block:: python @@ -35,13 +35,12 @@ class Replicate(LLM): "stability-ai/stable-diffusion: " "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", ), - input={"image_dimensions": "512x512"} + model_kwargs={"image_dimensions": "512x512"} ) """ model: str - input: Dict[str, Any] = Field(default_factory=dict) - model_kwargs: Dict[str, Any] = Field(default_factory=dict) + model_kwargs: Dict[str, Any] = Field(default_factory=dict, alias="input") replicate_api_token: Optional[str] = None prompt_key: Optional[str] = None version_obj: Any = Field(default=None, exclude=True) @@ -59,6 +58,7 @@ class Replicate(LLM): class Config: """Configuration for this pydantic config.""" + allow_population_by_field_name = True extra = Extra.forbid @property @@ -74,7 +74,12 @@ class Replicate(LLM): """Build extra kwargs from additional params that were passed in.""" all_required_field_names = {field.alias for field in cls.__fields__.values()} - extra = values.get("model_kwargs", {}) + input = values.pop("input", {}) + if input: + logger.warning( + "Init param `input` is deprecated, please use `model_kwargs` instead." + ) + extra = {**values.get("model_kwargs", {}), **input} for field_name in list(values): if field_name not in all_required_field_names: if field_name in extra: @@ -202,7 +207,11 @@ class Replicate(LLM): self.prompt_key = input_properties[0][0] - input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs} + input_: Dict = { + self.prompt_key: prompt, + **self.model_kwargs, + **kwargs, + } return replicate_python.predictions.create( version=self.version_obj, input=input_ ) diff --git a/libs/langchain/tests/integration_tests/llms/test_replicate.py b/libs/langchain/tests/integration_tests/llms/test_replicate.py index 7e2666cbeab..9bc183bb8b0 100644 --- a/libs/langchain/tests/integration_tests/llms/test_replicate.py +++ b/libs/langchain/tests/integration_tests/llms/test_replicate.py @@ -24,3 +24,21 @@ def test_replicate_streaming_call() -> None: output = llm("What is LangChain") assert output assert isinstance(output, str) + + +def test_replicate_model_kwargs() -> None: + """Test simple non-streaming call to Replicate.""" + llm = Replicate( + model=TEST_MODEL, model_kwargs={"max_length": 100, "temperature": 0.01} + ) + long_output = llm("What is LangChain") + llm = Replicate( + model=TEST_MODEL, model_kwargs={"max_length": 10, "temperature": 0.01} + ) + short_output = llm("What is LangChain") + assert len(short_output) < len(long_output) + + +def test_replicate_input() -> None: + llm = Replicate(model=TEST_MODEL, input={"max_length": 10}) + assert llm.model_kwargs == {"max_length": 10}