fix mypy errors in tests

This commit is contained in:
olgavrou 2023-08-29 05:28:43 -04:00
parent 0b8691c6e5
commit b3c0728de2
3 changed files with 58 additions and 52 deletions

View File

@ -1,3 +1,5 @@
from typing import Any, Dict
import pytest
from test_utils import MockEncoder
@ -10,7 +12,7 @@ encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def setup():
def setup() -> tuple:
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
@ -19,7 +21,7 @@ def setup():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_multiple_ToSelectFrom_throws():
def test_multiple_ToSelectFrom_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"]
@ -32,7 +34,7 @@ def test_multiple_ToSelectFrom_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_missing_basedOn_from_throws():
def test_missing_basedOn_from_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"]
@ -41,7 +43,7 @@ def test_missing_basedOn_from_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_ToSelectFrom_not_a_list_throws():
def test_ToSelectFrom_not_a_list_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = {"actions": ["0", "1", "2"]}
@ -53,7 +55,7 @@ def test_ToSelectFrom_not_a_list_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_with_auto_validator_throws():
def test_update_with_delayed_score_with_auto_validator_throws() -> None:
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
@ -75,7 +77,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_force():
def test_update_with_delayed_score_force() -> None:
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
@ -99,7 +101,7 @@ def test_update_with_delayed_score_force():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score():
def test_update_with_delayed_score() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
@ -117,11 +119,11 @@ def test_update_with_delayed_score():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_user_defined_scorer():
def test_user_defined_scorer() -> None:
llm, PROMPT = setup()
class CustomSelectionScorer(rl_chain.SelectionScorer):
def score_response(self, inputs, llm_response: str) -> float:
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
score = 200
return score
@ -139,7 +141,7 @@ def test_user_defined_scorer():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings():
def test_default_embeddings() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
@ -173,7 +175,7 @@ def test_default_embeddings():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_off():
def test_default_embeddings_off() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
@ -199,7 +201,7 @@ def test_default_embeddings_off():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_mixed_w_explicit_user_embeddings():
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
@ -234,7 +236,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_no_scorer_specified():
def test_default_no_scorer_specified() -> None:
_, PROMPT = setup()
chain_llm = FakeListChatModel(responses=[100])
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
@ -249,7 +251,7 @@ def test_default_no_scorer_specified():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_explicitly_no_scorer():
def test_explicitly_no_scorer() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
@ -265,7 +267,7 @@ def test_explicitly_no_scorer():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_auto_scorer_with_user_defined_llm():
def test_auto_scorer_with_user_defined_llm() -> None:
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300])
chain = pick_best_chain.PickBest.from_llm(
@ -284,7 +286,7 @@ def test_auto_scorer_with_user_defined_llm():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_calling_chain_w_reserved_inputs_throws():
def test_calling_chain_w_reserved_inputs_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
with pytest.raises(ValueError):

View File

@ -8,7 +8,7 @@ encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_context_throws():
def test_pickbest_textembedder_missing_context_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_action = {"action": ["0", "1", "2"]}
event = pick_best_chain.PickBestEvent(
@ -19,7 +19,7 @@ def test_pickbest_textembedder_missing_context_throws():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_actions_throws():
def test_pickbest_textembedder_missing_actions_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from={}, based_on={"context": "context"}
@ -29,7 +29,7 @@ def test_pickbest_textembedder_missing_actions_throws():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_no_label_no_emb():
def test_pickbest_textembedder_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
@ -41,7 +41,7 @@ def test_pickbest_textembedder_no_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_label_no_score_no_emb():
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
@ -57,7 +57,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_no_emb():
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = (
@ -75,7 +75,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_emb():
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
@ -99,7 +99,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
@ -123,7 +123,7 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
@ -136,7 +136,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
@ -150,7 +150,7 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
@ -164,7 +164,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -195,7 +195,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
None
):
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -228,7 +230,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -262,7 +264,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep() -> (
None
):
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -299,7 +303,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
@pytest.mark.requires("vowpal_wabbit_next")
def test_raw_features_underscored():
def test_raw_features_underscored() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "this is a long string"
str1_underscored = str1.replace(" ", "_")

View File

@ -7,13 +7,13 @@ encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_no_emb():
def test_simple_context_str_no_emb() -> None:
expected = [{"a_namespace": "test"}]
assert base.embed("test", MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_emb():
def test_simple_context_str_w_emb() -> None:
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"a_namespace": encoded_text + encoded_str1}]
@ -28,7 +28,7 @@ def test_simple_context_str_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_nested_emb():
def test_simple_context_str_w_nested_emb() -> None:
# nested embeddings, innermost wins
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
@ -46,13 +46,13 @@ def test_simple_context_str_w_nested_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_no_emb():
def test_context_w_namespace_no_emb() -> None:
expected = [{"test_namespace": "test"}]
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb():
def test_context_w_namespace_w_emb() -> None:
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"test_namespace": encoded_text + encoded_str1}]
@ -67,7 +67,7 @@ def test_context_w_namespace_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb2():
def test_context_w_namespace_w_emb2() -> None:
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"test_namespace": encoded_text + encoded_str1}]
@ -82,7 +82,7 @@ def test_context_w_namespace_w_emb2():
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_some_emb():
def test_context_w_namespace_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
encoded_str2 = " ".join(char for char in str2)
@ -111,7 +111,7 @@ def test_context_w_namespace_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_no_emb():
def test_simple_action_strlist_no_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -120,7 +120,7 @@ def test_simple_action_strlist_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_emb():
def test_simple_action_strlist_w_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -148,7 +148,7 @@ def test_simple_action_strlist_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_some_emb():
def test_simple_action_strlist_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -181,7 +181,7 @@ def test_simple_action_strlist_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_no_emb():
def test_action_w_namespace_no_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -204,7 +204,7 @@ def test_action_w_namespace_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb():
def test_action_w_namespace_w_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -246,7 +246,7 @@ def test_action_w_namespace_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb2():
def test_action_w_namespace_w_emb2() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -292,7 +292,7 @@ def test_action_w_namespace_w_emb2():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_some_emb():
def test_action_w_namespace_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -333,7 +333,7 @@ def test_action_w_namespace_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -384,7 +384,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_no_emb():
def test_one_namespace_w_list_of_features_no_emb() -> None:
str1 = "test1"
str2 = "test2"
expected = [{"test_namespace": [str1, str2]}]
@ -392,7 +392,7 @@ def test_one_namespace_w_list_of_features_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_w_some_emb():
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
encoded_str2 = " ".join(char for char in str2)
@ -404,24 +404,24 @@ def test_one_namespace_w_list_of_features_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_list_features_throws():
def test_nested_list_features_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_dict_in_list_throws():
def test_dict_in_list_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_dict_throws():
def test_nested_dict_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_list_of_tuples_throws():
def test_list_of_tuples_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())