diff --git a/libs/experimental/langchain_experimental/rl_chain/base.py b/libs/experimental/langchain_experimental/rl_chain/base.py
index 9b3d7e018af..facf977450f 100644
--- a/libs/experimental/langchain_experimental/rl_chain/base.py
+++ b/libs/experimental/langchain_experimental/rl_chain/base.py
@@ -19,19 +19,20 @@ from typing import (
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
-from langchain_experimental.rl_chain.metrics import (
-    MetricsTrackerAverage,
-    MetricsTrackerRollingWindow,
-)
-from langchain_experimental.rl_chain.model_repository import ModelRepository
-from langchain_experimental.rl_chain.vw_logger import VwLogger
 from langchain.prompts import (
     BasePromptTemplate,
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
+
 from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_experimental.rl_chain.metrics import (
+    MetricsTrackerAverage,
+    MetricsTrackerRollingWindow,
+)
+from langchain_experimental.rl_chain.model_repository import ModelRepository
+from langchain_experimental.rl_chain.vw_logger import VwLogger
 
 if TYPE_CHECKING:
     import vowpal_wabbit_next as vw
diff --git a/libs/experimental/langchain_experimental/rl_chain/pick_best_chain.py b/libs/experimental/langchain_experimental/rl_chain/pick_best_chain.py
index 090db9d8633..c17a5f8bc22 100644
--- a/libs/experimental/langchain_experimental/rl_chain/pick_best_chain.py
+++ b/libs/experimental/langchain_experimental/rl_chain/pick_best_chain.py
@@ -3,12 +3,13 @@ from __future__ import annotations
 import logging
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
-import langchain_experimental.rl_chain.base as base
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.llm import LLMChain
 from langchain.prompts import BasePromptTemplate
 
+import langchain_experimental.rl_chain.base as base
+
 logger = logging.getLogger(__name__)
 
 # sentinel object used to distinguish between
diff --git a/libs/experimental/tests/integration_tests/chains/rl_chain/test_pick_best_chain_call.py b/libs/experimental/tests/integration_tests/chains/rl_chain/test_pick_best_chain_call.py
index add69a9c9e5..765e52e05eb 100644
--- a/libs/experimental/tests/integration_tests/chains/rl_chain/test_pick_best_chain_call.py
+++ b/libs/experimental/tests/integration_tests/chains/rl_chain/test_pick_best_chain_call.py
@@ -1,12 +1,12 @@
 from typing import Any, Dict
 
 import pytest
+from langchain.chat_models import FakeListChatModel
+from langchain.prompts.prompt import PromptTemplate
 from test_utils import MockEncoder, MockEncoderReturnsList
 
 import langchain_experimental.rl_chain.base as rl_chain
 import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
-from langchain.chat_models import FakeListChatModel
-from langchain.prompts.prompt import PromptTemplate
 
 encoded_keyword = "[encoded]"
 
@@ -90,11 +90,13 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
         User=rl_chain.BasedOn("Context"),
         action=rl_chain.ToSelectFrom(actions),
     )
-    assert response["response"] == "hey" # type: ignore
-    selection_metadata = response["selection_metadata"] # type: ignore
-    assert selection_metadata.selected.score == 3.0 # type: ignore
+    assert response["response"] == "hey"  # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    assert selection_metadata.selected.score == 3.0  # type: ignore
     with pytest.raises(RuntimeError):
-        chain.update_with_delayed_score(chain_response=response, score=100) # type: ignore
+        chain.update_with_delayed_score(
+            chain_response=response, score=100  # type: ignore
+        )
 
 
 @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
@@ -209,7 +211,7 @@ def test_everything_embedded() -> None:
         action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
     )
     selection_metadata = response["selection_metadata"]  # type: ignore
-    vw_str = feature_embedder.format(selection_metadata) # type: ignore
+    vw_str = feature_embedder.format(selection_metadata)  # type: ignore
 
     assert vw_str == expected
 
@@ -237,7 +239,7 @@ def test_default_auto_embedder_is_off() -> None:
         action=pick_best_chain.base.ToSelectFrom(actions),
     )
     selection_metadata = response["selection_metadata"]  # type: ignore
-    vw_str = feature_embedder.format(selection_metadata) # type: ignore
+    vw_str = feature_embedder.format(selection_metadata)  # type: ignore
 
     assert vw_str == expected
 
@@ -264,8 +266,8 @@ def test_default_w_embeddings_off() -> None:
         User=rl_chain.BasedOn(ctx_str_1),
         action=rl_chain.ToSelectFrom(actions),
     )
-    selection_metadata = response["selection_metadata"] # type: ignore
-    vw_str = feature_embedder.format(selection_metadata) # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    vw_str = feature_embedder.format(selection_metadata)  # type: ignore
 
     assert vw_str == expected
 
@@ -292,8 +294,8 @@ def test_default_w_embeddings_on() -> None:
         User=rl_chain.BasedOn(ctx_str_1),
         action=rl_chain.ToSelectFrom(actions),
     )
-    selection_metadata = response["selection_metadata"] # type: ignore
-    vw_str = feature_embedder.format(selection_metadata) # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    vw_str = feature_embedder.format(selection_metadata)  # type: ignore
 
     assert vw_str == expected
 
@@ -324,8 +326,8 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
         User2=rl_chain.BasedOn(ctx_str_2),
         action=rl_chain.ToSelectFrom(actions),
     )
-    selection_metadata = response["selection_metadata"] # type: ignore
-    vw_str = feature_embedder.format(selection_metadata) # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    vw_str = feature_embedder.format(selection_metadata)  # type: ignore
 
     assert vw_str == expected
 
@@ -345,9 +347,9 @@ def test_default_no_scorer_specified() -> None:
         action=rl_chain.ToSelectFrom(["0", "1", "2"]),
     )
     # chain llm used for both basic prompt and for scoring
-    assert response["response"] == "hey" # type: ignore
-    selection_metadata = response["selection_metadata"] # type: ignore
-    assert selection_metadata.selected.score == 100.0 # type: ignore
+    assert response["response"] == "hey"  # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    assert selection_metadata.selected.score == 100.0  # type: ignore
 
 
 @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
@@ -366,9 +368,9 @@ def test_explicitly_no_scorer() -> None:
         action=rl_chain.ToSelectFrom(["0", "1", "2"]),
     )
     # chain llm used for both basic prompt and for scoring
-    assert response["response"] == "hey" # type: ignore
-    selection_metadata = response["selection_metadata"] # type: ignore
-    assert selection_metadata.selected.score is None # type: ignore
+    assert response["response"] == "hey"  # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    assert selection_metadata.selected.score is None  # type: ignore
 
 
 @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
@@ -388,9 +390,9 @@ def test_auto_scorer_with_user_defined_llm() -> None:
         action=rl_chain.ToSelectFrom(["0", "1", "2"]),
     )
     # chain llm used for both basic prompt and for scoring
-    assert response["response"] == "hey" # type: ignore
-    selection_metadata = response["selection_metadata"] # type: ignore
-    assert selection_metadata.selected.score == 300.0 # type: ignore
+    assert response["response"] == "hey"  # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    assert selection_metadata.selected.score == 300.0  # type: ignore
 
 
 @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
@@ -434,24 +436,24 @@ def test_activate_and_deactivate_scorer() -> None:
         action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
     )
     # chain llm used for both basic prompt and for scoring
-    assert response["response"] == "hey1" # type: ignore
-    selection_metadata = response["selection_metadata"] # type: ignore
-    assert selection_metadata.selected.score == 300.0 # type: ignore
+    assert response["response"] == "hey1"  # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    assert selection_metadata.selected.score == 300.0  # type: ignore
 
     chain.deactivate_selection_scorer()
     response = chain.run(
         User=pick_best_chain.base.BasedOn("Context"),
         action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
     )
-    assert response["response"] == "hey2" # type: ignore
-    selection_metadata = response["selection_metadata"] # type: ignore
-    assert selection_metadata.selected.score is None # type: ignore
+    assert response["response"] == "hey2"  # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    assert selection_metadata.selected.score is None  # type: ignore
 
     chain.activate_selection_scorer()
     response = chain.run(
         User=pick_best_chain.base.BasedOn("Context"),
         action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
     )
-    assert response["response"] == "hey3" # type: ignore
-    selection_metadata = response["selection_metadata"] # type: ignore
-    assert selection_metadata.selected.score == 400.0 # type: ignore
+    assert response["response"] == "hey3"  # type: ignore
+    selection_metadata = response["selection_metadata"]  # type: ignore
+    assert selection_metadata.selected.score == 400.0  # type: ignore
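
Taken together, the hunks above are mechanical lint fixes: first-party `langchain_experimental` imports move into their own group after the third-party `langchain` group, and inline `# type: ignore` comments get the two spaces PEP 8 requires before the `#`. A minimal sketch of the resulting layout follows; it assumes the usual isort/ruff three-group convention (stdlib, third-party, first-party, blank-line separated) with `langchain_experimental` configured as the first-party package, and reuses module names from the hunks purely for illustration:

```python
from __future__ import annotations

# Group 1: standard library.
import logging
from typing import TYPE_CHECKING

# Group 2: third-party (langchain is third-party from this package's viewpoint).
from langchain.chains.base import Chain

# Group 3: first-party.
import langchain_experimental.rl_chain.base as base

if TYPE_CHECKING:
    # Heavy optional dependency, imported for type checking only.
    import vowpal_wabbit_next as vw  # two spaces before an inline comment (PEP 8)
```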