mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 11:39:03 +00:00
fix mypy errors in tests
This commit is contained in:
parent
0b8691c6e5
commit
b3c0728de2
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from test_utils import MockEncoder
|
from test_utils import MockEncoder
|
||||||
|
|
||||||
@ -10,7 +12,7 @@ encoded_text = "[ e n c o d e d ] "
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def setup():
|
def setup() -> tuple:
|
||||||
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
|
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
|
||||||
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
|
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
|
||||||
|
|
||||||
@ -19,7 +21,7 @@ def setup():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_multiple_ToSelectFrom_throws():
|
def test_multiple_ToSelectFrom_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
@ -32,7 +34,7 @@ def test_multiple_ToSelectFrom_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_missing_basedOn_from_throws():
|
def test_missing_basedOn_from_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
@ -41,7 +43,7 @@ def test_missing_basedOn_from_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_ToSelectFrom_not_a_list_throws():
|
def test_ToSelectFrom_not_a_list_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||||
actions = {"actions": ["0", "1", "2"]}
|
actions = {"actions": ["0", "1", "2"]}
|
||||||
@ -53,7 +55,7 @@ def test_ToSelectFrom_not_a_list_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_update_with_delayed_score_with_auto_validator_throws():
|
def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
# this LLM returns a number so that the auto validator will return that
|
# this LLM returns a number so that the auto validator will return that
|
||||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||||
@ -75,7 +77,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_update_with_delayed_score_force():
|
def test_update_with_delayed_score_force() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
# this LLM returns a number so that the auto validator will return that
|
# this LLM returns a number so that the auto validator will return that
|
||||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||||
@ -99,7 +101,7 @@ def test_update_with_delayed_score_force():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_update_with_delayed_score():
|
def test_update_with_delayed_score() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm, prompt=PROMPT, selection_scorer=None
|
llm=llm, prompt=PROMPT, selection_scorer=None
|
||||||
@ -117,11 +119,11 @@ def test_update_with_delayed_score():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_user_defined_scorer():
|
def test_user_defined_scorer() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
|
|
||||||
class CustomSelectionScorer(rl_chain.SelectionScorer):
|
class CustomSelectionScorer(rl_chain.SelectionScorer):
|
||||||
def score_response(self, inputs, llm_response: str) -> float:
|
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
|
||||||
score = 200
|
score = 200
|
||||||
return score
|
return score
|
||||||
|
|
||||||
@ -139,7 +141,7 @@ def test_user_defined_scorer():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_embeddings():
|
def test_default_embeddings() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
@ -173,7 +175,7 @@ def test_default_embeddings():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_embeddings_off():
|
def test_default_embeddings_off() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
@ -199,7 +201,7 @@ def test_default_embeddings_off():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
@ -234,7 +236,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_no_scorer_specified():
|
def test_default_no_scorer_specified() -> None:
|
||||||
_, PROMPT = setup()
|
_, PROMPT = setup()
|
||||||
chain_llm = FakeListChatModel(responses=[100])
|
chain_llm = FakeListChatModel(responses=[100])
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
|
||||||
@ -249,7 +251,7 @@ def test_default_no_scorer_specified():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_explicitly_no_scorer():
|
def test_explicitly_no_scorer() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm, prompt=PROMPT, selection_scorer=None
|
llm=llm, prompt=PROMPT, selection_scorer=None
|
||||||
@ -265,7 +267,7 @@ def test_explicitly_no_scorer():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_auto_scorer_with_user_defined_llm():
|
def test_auto_scorer_with_user_defined_llm() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
scorer_llm = FakeListChatModel(responses=[300])
|
scorer_llm = FakeListChatModel(responses=[300])
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
@ -284,7 +286,7 @@ def test_auto_scorer_with_user_defined_llm():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_calling_chain_w_reserved_inputs_throws():
|
def test_calling_chain_w_reserved_inputs_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
@ -8,7 +8,7 @@ encoded_text = "[ e n c o d e d ] "
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_missing_context_throws():
|
def test_pickbest_textembedder_missing_context_throws() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_action = {"action": ["0", "1", "2"]}
|
named_action = {"action": ["0", "1", "2"]}
|
||||||
event = pick_best_chain.PickBestEvent(
|
event = pick_best_chain.PickBestEvent(
|
||||||
@ -19,7 +19,7 @@ def test_pickbest_textembedder_missing_context_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_missing_actions_throws():
|
def test_pickbest_textembedder_missing_actions_throws() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
event = pick_best_chain.PickBestEvent(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from={}, based_on={"context": "context"}
|
inputs={}, to_select_from={}, based_on={"context": "context"}
|
||||||
@ -29,7 +29,7 @@ def test_pickbest_textembedder_missing_actions_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_no_label_no_emb():
|
def test_pickbest_textembedder_no_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||||
@ -41,7 +41,7 @@ def test_pickbest_textembedder_no_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_label_no_score_no_emb():
|
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||||
@ -57,7 +57,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_full_label_no_emb():
|
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = (
|
expected = (
|
||||||
@ -75,7 +75,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_full_label_w_emb():
|
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
@ -99,7 +99,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
@ -123,7 +123,7 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
@ -136,7 +136,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
@ -150,7 +150,7 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
@ -164,7 +164,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
@ -195,7 +195,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
|
||||||
|
None
|
||||||
|
):
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
@ -228,7 +230,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
@ -262,7 +264,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep() -> (
|
||||||
|
None
|
||||||
|
):
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
@ -299,7 +303,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_raw_features_underscored():
|
def test_raw_features_underscored() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
str1 = "this is a long string"
|
str1 = "this is a long string"
|
||||||
str1_underscored = str1.replace(" ", "_")
|
str1_underscored = str1.replace(" ", "_")
|
||||||
|
@ -7,13 +7,13 @@ encoded_text = "[ e n c o d e d ] "
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_context_str_no_emb():
|
def test_simple_context_str_no_emb() -> None:
|
||||||
expected = [{"a_namespace": "test"}]
|
expected = [{"a_namespace": "test"}]
|
||||||
assert base.embed("test", MockEncoder(), "a_namespace") == expected
|
assert base.embed("test", MockEncoder(), "a_namespace") == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_context_str_w_emb():
|
def test_simple_context_str_w_emb() -> None:
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = " ".join(char for char in str1)
|
encoded_str1 = " ".join(char for char in str1)
|
||||||
expected = [{"a_namespace": encoded_text + encoded_str1}]
|
expected = [{"a_namespace": encoded_text + encoded_str1}]
|
||||||
@ -28,7 +28,7 @@ def test_simple_context_str_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_context_str_w_nested_emb():
|
def test_simple_context_str_w_nested_emb() -> None:
|
||||||
# nested embeddings, innermost wins
|
# nested embeddings, innermost wins
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = " ".join(char for char in str1)
|
encoded_str1 = " ".join(char for char in str1)
|
||||||
@ -46,13 +46,13 @@ def test_simple_context_str_w_nested_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_no_emb():
|
def test_context_w_namespace_no_emb() -> None:
|
||||||
expected = [{"test_namespace": "test"}]
|
expected = [{"test_namespace": "test"}]
|
||||||
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
|
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_w_emb():
|
def test_context_w_namespace_w_emb() -> None:
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = " ".join(char for char in str1)
|
encoded_str1 = " ".join(char for char in str1)
|
||||||
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
||||||
@ -67,7 +67,7 @@ def test_context_w_namespace_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_w_emb2():
|
def test_context_w_namespace_w_emb2() -> None:
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = " ".join(char for char in str1)
|
encoded_str1 = " ".join(char for char in str1)
|
||||||
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
||||||
@ -82,7 +82,7 @@ def test_context_w_namespace_w_emb2():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_w_some_emb():
|
def test_context_w_namespace_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
encoded_str2 = " ".join(char for char in str2)
|
encoded_str2 = " ".join(char for char in str2)
|
||||||
@ -111,7 +111,7 @@ def test_context_w_namespace_w_some_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_action_strlist_no_emb():
|
def test_simple_action_strlist_no_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -120,7 +120,7 @@ def test_simple_action_strlist_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_action_strlist_w_emb():
|
def test_simple_action_strlist_w_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -148,7 +148,7 @@ def test_simple_action_strlist_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_action_strlist_w_some_emb():
|
def test_simple_action_strlist_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -181,7 +181,7 @@ def test_simple_action_strlist_w_some_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_no_emb():
|
def test_action_w_namespace_no_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -204,7 +204,7 @@ def test_action_w_namespace_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_w_emb():
|
def test_action_w_namespace_w_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -246,7 +246,7 @@ def test_action_w_namespace_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_w_emb2():
|
def test_action_w_namespace_w_emb2() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -292,7 +292,7 @@ def test_action_w_namespace_w_emb2():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_w_some_emb():
|
def test_action_w_namespace_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -333,7 +333,7 @@ def test_action_w_namespace_w_some_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
|
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -384,7 +384,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_one_namespace_w_list_of_features_no_emb():
|
def test_one_namespace_w_list_of_features_no_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
expected = [{"test_namespace": [str1, str2]}]
|
expected = [{"test_namespace": [str1, str2]}]
|
||||||
@ -392,7 +392,7 @@ def test_one_namespace_w_list_of_features_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_one_namespace_w_list_of_features_w_some_emb():
|
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
encoded_str2 = " ".join(char for char in str2)
|
encoded_str2 = " ".join(char for char in str2)
|
||||||
@ -404,24 +404,24 @@ def test_one_namespace_w_list_of_features_w_some_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_nested_list_features_throws():
|
def test_nested_list_features_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
|
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_dict_in_list_throws():
|
def test_dict_in_list_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
|
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_nested_dict_throws():
|
def test_nested_dict_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
|
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_list_of_tuples_throws():
|
def test_list_of_tuples_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
|
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
|
||||||
|
Loading…
Reference in New Issue
Block a user