mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-02 21:23:32 +00:00
Replicate params fix (#10603)
This commit is contained in:
parent
50bb704da5
commit
ecbb1ed8cb
@ -79,7 +79,7 @@ class Replicate(LLM):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"Init param `input` is deprecated, please use `model_kwargs` instead."
|
"Init param `input` is deprecated, please use `model_kwargs` instead."
|
||||||
)
|
)
|
||||||
extra = {**values.get("model_kwargs", {}), **input}
|
extra = {**values.pop("model_kwargs", {}), **input}
|
||||||
for field_name in list(values):
|
for field_name in list(values):
|
||||||
if field_name not in all_required_field_names:
|
if field_name not in all_required_field_names:
|
||||||
if field_name in extra:
|
if field_name in extra:
|
||||||
@ -96,7 +96,7 @@ class Replicate(LLM):
|
|||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
replicate_api_token = get_from_dict_or_env(
|
replicate_api_token = get_from_dict_or_env(
|
||||||
values, "REPLICATE_API_TOKEN", "REPLICATE_API_TOKEN"
|
values, "replicate_api_token", "REPLICATE_API_TOKEN"
|
||||||
)
|
)
|
||||||
values["replicate_api_token"] = replicate_api_token
|
values["replicate_api_token"] = replicate_api_token
|
||||||
return values
|
return values
|
||||||
|
@ -37,6 +37,7 @@ def test_replicate_model_kwargs() -> None:
|
|||||||
)
|
)
|
||||||
short_output = llm("What is LangChain")
|
short_output = llm("What is LangChain")
|
||||||
assert len(short_output) < len(long_output)
|
assert len(short_output) < len(long_output)
|
||||||
|
assert llm.model_kwargs == {"max_length": 10, "temperature": 0.01}
|
||||||
|
|
||||||
|
|
||||||
def test_replicate_input() -> None:
|
def test_replicate_input() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user