Replicate params fix (#10603)

This commit is contained in:
Bagatur 2023-09-14 15:04:42 -07:00 committed by GitHub
parent 50bb704da5
commit ecbb1ed8cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 2 deletions

View File

@ -79,7 +79,7 @@ class Replicate(LLM):
logger.warning( logger.warning(
"Init param `input` is deprecated, please use `model_kwargs` instead." "Init param `input` is deprecated, please use `model_kwargs` instead."
) )
extra = {**values.get("model_kwargs", {}), **input} extra = {**values.pop("model_kwargs", {}), **input}
for field_name in list(values): for field_name in list(values):
if field_name not in all_required_field_names: if field_name not in all_required_field_names:
if field_name in extra: if field_name in extra:
@ -96,7 +96,7 @@ class Replicate(LLM):
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
replicate_api_token = get_from_dict_or_env( replicate_api_token = get_from_dict_or_env(
values, "REPLICATE_API_TOKEN", "REPLICATE_API_TOKEN" values, "replicate_api_token", "REPLICATE_API_TOKEN"
) )
values["replicate_api_token"] = replicate_api_token values["replicate_api_token"] = replicate_api_token
return values return values

View File

@ -37,6 +37,7 @@ def test_replicate_model_kwargs() -> None:
) )
short_output = llm("What is LangChain") short_output = llm("What is LangChain")
assert len(short_output) < len(long_output) assert len(short_output) < len(long_output)
assert llm.model_kwargs == {"max_length": 10, "temperature": 0.01}
def test_replicate_input() -> None: def test_replicate_input() -> None: