From 3b956b3a97a5fde589fafe976e7f7a64cdb4ca3f Mon Sep 17 00:00:00 2001
From: fayvor <682308+fayvor@users.noreply.github.com>
Date: Wed, 30 Oct 2024 12:07:08 -0400
Subject: [PATCH] community: Update Replicate LLM and fix tests (#27655)

**Description:**
- Fix bug in Replicate LLM class, where it was looking for parameter names in a place where they no longer exist in pydantic 2, resulting in the "Field required" validation error described in the issue.
- Fix Replicate LLM integration tests to:
  - Use active models on Replicate.
  - Use the correct model parameter `max_new_tokens` as shown in the [Replicate docs](https://replicate.com/docs/guides/language-models/how-to-use#minimum-and-maximum-new-tokens).
  - Use callbacks instead of deprecated callback_manager.

**Issue:** #26937

**Dependencies:** n/a

**Twitter handle:** n/a

---------

Signed-off-by: Fayvor Love
Co-authored-by: Chester Curme
---
 .../langchain_community/llms/replicate.py    |  2 +-
 .../integration_tests/llms/test_replicate.py | 25 +++++++++++--------
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py
index ad4dceaf8f3..f6c4e15ba62 100644
--- a/libs/community/langchain_community/llms/replicate.py
+++ b/libs/community/langchain_community/llms/replicate.py
@@ -78,7 +78,7 @@ class Replicate(LLM):
     @classmethod
     def build_extra(cls, values: Dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
-        all_required_field_names = {field.alias for field in get_fields(cls).values()}
+        all_required_field_names = {field for field in get_fields(cls).keys()}
 
         input = values.pop("input", {})
         if input:
diff --git a/libs/community/tests/integration_tests/llms/test_replicate.py b/libs/community/tests/integration_tests/llms/test_replicate.py
index fa58bd19c2d..dbf4ba2ff11 100644
--- a/libs/community/tests/integration_tests/llms/test_replicate.py
+++ b/libs/community/tests/integration_tests/llms/test_replicate.py
@@ -1,16 +1,18 @@
 """Test Replicate API wrapper."""
 
-from langchain_core.callbacks import CallbackManager
-
 from langchain_community.llms.replicate import Replicate
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
-TEST_MODEL = "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5"  # noqa: E501
+TEST_MODEL_HELLO = (
+    "replicate/hello-world:"
+    + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
+)
+TEST_MODEL_LANG = "meta/meta-llama-3-8b-instruct"
 
 
 def test_replicate_call() -> None:
     """Test simple non-streaming call to Replicate."""
-    llm = Replicate(model=TEST_MODEL)
+    llm = Replicate(model=TEST_MODEL_HELLO)
     output = llm.invoke("What is LangChain")
     assert output
     assert isinstance(output, str)
@@ -19,9 +21,10 @@ def test_replicate_call() -> None:
 def test_replicate_streaming_call() -> None:
     """Test streaming call to Replicate."""
     callback_handler = FakeCallbackHandler()
-    callback_manager = CallbackManager([callback_handler])
 
-    llm = Replicate(streaming=True, callback_manager=callback_manager, model=TEST_MODEL)
+    llm = Replicate(
+        streaming=True, callbacks=[callback_handler], model=TEST_MODEL_HELLO
+    )
     output = llm.invoke("What is LangChain")
     assert output
     assert isinstance(output, str)
@@ -30,17 +33,17 @@ def test_replicate_streaming_call() -> None:
 def test_replicate_model_kwargs() -> None:
     """Test simple non-streaming call to Replicate."""
     llm = Replicate(  # type: ignore[call-arg]
-        model=TEST_MODEL, model_kwargs={"max_length": 100, "temperature": 0.01}
+        model=TEST_MODEL_LANG, model_kwargs={"max_new_tokens": 10, "temperature": 0.01}
     )
     long_output = llm.invoke("What is LangChain")
     llm = Replicate(  # type: ignore[call-arg]
-        model=TEST_MODEL, model_kwargs={"max_length": 10, "temperature": 0.01}
+        model=TEST_MODEL_LANG, model_kwargs={"max_new_tokens": 5, "temperature": 0.01}
     )
     short_output = llm.invoke("What is LangChain")
     assert len(short_output) < len(long_output)
-    assert llm.model_kwargs == {"max_length": 10, "temperature": 0.01}
+    assert llm.model_kwargs == {"max_new_tokens": 5, "temperature": 0.01}
 
 
 def test_replicate_input() -> None:
-    llm = Replicate(model=TEST_MODEL, input={"max_length": 10})
-    assert llm.model_kwargs == {"max_length": 10}
+    llm = Replicate(model=TEST_MODEL_LANG, input={"max_new_tokens": 10})
+    assert llm.model_kwargs == {"max_new_tokens": 10}