diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py
index ad4dceaf8f3..f6c4e15ba62 100644
--- a/libs/community/langchain_community/llms/replicate.py
+++ b/libs/community/langchain_community/llms/replicate.py
@@ -78,7 +78,7 @@ class Replicate(LLM):
     @classmethod
     def build_extra(cls, values: Dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in get_fields(cls).values()}
+        all_required_field_names = {field for field in get_fields(cls).keys()}
 
         input = values.pop("input", {})
         if input:
diff --git a/libs/community/tests/integration_tests/llms/test_replicate.py b/libs/community/tests/integration_tests/llms/test_replicate.py
index fa58bd19c2d..dbf4ba2ff11 100644
--- a/libs/community/tests/integration_tests/llms/test_replicate.py
+++ b/libs/community/tests/integration_tests/llms/test_replicate.py
@@ -1,16 +1,18 @@
 """Test Replicate API wrapper."""
 
-from langchain_core.callbacks import CallbackManager
-
 from langchain_community.llms.replicate import Replicate
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
-TEST_MODEL = "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5"  # noqa: E501
+TEST_MODEL_HELLO = (
+    "replicate/hello-world:"
+    + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
+)
+TEST_MODEL_LANG = "meta/meta-llama-3-8b-instruct"
 
 
 def test_replicate_call() -> None:
     """Test simple non-streaming call to Replicate."""
-    llm = Replicate(model=TEST_MODEL)
+    llm = Replicate(model=TEST_MODEL_HELLO)
     output = llm.invoke("What is LangChain")
     assert output
     assert isinstance(output, str)
@@ -19,9 +21,10 @@ def test_replicate_call() -> None:
 def test_replicate_streaming_call() -> None:
     """Test streaming call to Replicate."""
     callback_handler = FakeCallbackHandler()
-    callback_manager = CallbackManager([callback_handler])
 
-    llm = Replicate(streaming=True, callback_manager=callback_manager, model=TEST_MODEL)
+    llm = Replicate(
+        streaming=True, callbacks=[callback_handler], model=TEST_MODEL_HELLO
+    )
     output = llm.invoke("What is LangChain")
     assert output
     assert isinstance(output, str)
@@ -30,17 +33,17 @@ def test_replicate_streaming_call() -> None:
 def test_replicate_model_kwargs() -> None:
     """Test simple non-streaming call to Replicate."""
     llm = Replicate(  # type: ignore[call-arg]
-        model=TEST_MODEL, model_kwargs={"max_length": 100, "temperature": 0.01}
+        model=TEST_MODEL_LANG, model_kwargs={"max_new_tokens": 10, "temperature": 0.01}
     )
     long_output = llm.invoke("What is LangChain")
     llm = Replicate(  # type: ignore[call-arg]
-        model=TEST_MODEL, model_kwargs={"max_length": 10, "temperature": 0.01}
+        model=TEST_MODEL_LANG, model_kwargs={"max_new_tokens": 5, "temperature": 0.01}
     )
     short_output = llm.invoke("What is LangChain")
     assert len(short_output) < len(long_output)
-    assert llm.model_kwargs == {"max_length": 10, "temperature": 0.01}
+    assert llm.model_kwargs == {"max_new_tokens": 5, "temperature": 0.01}
 
 
 def test_replicate_input() -> None:
-    llm = Replicate(model=TEST_MODEL, input={"max_length": 10})
-    assert llm.model_kwargs == {"max_length": 10}
+    llm = Replicate(model=TEST_MODEL_LANG, input={"max_new_tokens": 10})
+    assert llm.model_kwargs == {"max_new_tokens": 10}