mirror of https://github.com/hwchase17/langchain.git
synced 2025-09-16 23:13:31 +00:00
move everything into experimental
@@ -0,0 +1,457 @@
from typing import Any, Dict

import pytest
from test_utils import MockEncoder, MockEncoderReturnsList

import langchain_experimental.rl_chain.base as rl_chain
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
from langchain.chat_models import FakeListChatModel
from langchain.prompts.prompt import PromptTemplate

encoded_keyword = "[encoded]"


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def setup() -> tuple:
    _PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
    PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)

    llm = FakeListChatModel(responses=["hey"])
    return llm, PROMPT


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_multiple_ToSelectFrom_throws() -> None:
    llm, PROMPT = setup()
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    actions = ["0", "1", "2"]
    with pytest.raises(ValueError):
        chain.run(
            User=rl_chain.BasedOn("Context"),
            action=rl_chain.ToSelectFrom(actions),
            another_action=rl_chain.ToSelectFrom(actions),
        )


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_missing_basedOn_from_throws() -> None:
    llm, PROMPT = setup()
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    actions = ["0", "1", "2"]
    with pytest.raises(ValueError):
        chain.run(action=rl_chain.ToSelectFrom(actions))


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_ToSelectFrom_not_a_list_throws() -> None:
    llm, PROMPT = setup()
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    actions = {"actions": ["0", "1", "2"]}
    with pytest.raises(ValueError):
        chain.run(
            User=rl_chain.BasedOn("Context"),
            action=rl_chain.ToSelectFrom(actions),
        )


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_with_auto_validator_throws() -> None:
    llm, PROMPT = setup()
    # this LLM returns a number so that the auto validator will return that
    auto_val_llm = FakeListChatModel(responses=["3"])
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    actions = ["0", "1", "2"]
    response = chain.run(
        User=rl_chain.BasedOn("Context"),
        action=rl_chain.ToSelectFrom(actions),
    )
    assert response["response"] == "hey"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score == 3.0
    with pytest.raises(RuntimeError):
        chain.update_with_delayed_score(chain_response=response, score=100)


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_force() -> None:
    llm, PROMPT = setup()
    # this LLM returns a number so that the auto validator will return that
    auto_val_llm = FakeListChatModel(responses=["3"])
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    actions = ["0", "1", "2"]
    response = chain.run(
        User=rl_chain.BasedOn("Context"),
        action=rl_chain.ToSelectFrom(actions),
    )
    assert response["response"] == "hey"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score == 3.0
    chain.update_with_delayed_score(
        chain_response=response, score=100, force_score=True
    )
    assert selection_metadata.selected.score == 100.0


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score() -> None:
    llm, PROMPT = setup()
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        selection_scorer=None,
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    actions = ["0", "1", "2"]
    response = chain.run(
        User=rl_chain.BasedOn("Context"),
        action=rl_chain.ToSelectFrom(actions),
    )
    assert response["response"] == "hey"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score is None
    chain.update_with_delayed_score(chain_response=response, score=100)
    assert selection_metadata.selected.score == 100.0


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_user_defined_scorer() -> None:
    llm, PROMPT = setup()

    class CustomSelectionScorer(rl_chain.SelectionScorer):
        def score_response(
            self,
            inputs: Dict[str, Any],
            llm_response: str,
            event: pick_best_chain.PickBestEvent,
        ) -> float:
            score = 200
            return score

    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        selection_scorer=CustomSelectionScorer(),
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    actions = ["0", "1", "2"]
    response = chain.run(
        User=rl_chain.BasedOn("Context"),
        action=rl_chain.ToSelectFrom(actions),
    )
    assert response["response"] == "hey"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score == 200.0


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_everything_embedded() -> None:
    llm, PROMPT = setup()
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
    )

    str1 = "0"
    str2 = "1"
    str3 = "2"
    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

    ctx_str_1 = "context1"

    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))

    expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """  # noqa

    actions = [str1, str2, str3]

    response = chain.run(
        User=rl_chain.EmbedAndKeep(rl_chain.BasedOn(ctx_str_1)),
        action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
    )
    selection_metadata = response["selection_metadata"]
    vw_str = feature_embedder.format(selection_metadata)
    assert vw_str == expected
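
# Note on the expected string above: MockEncoder (see test_utils.py) prefixes
# "[encoded]" to its input, and stringify_embedding is assumed here to space-join
# whatever it is given, so list("[encoded]0") would render as "[ e n c o d e d ] 0".
# EmbedAndKeep therefore yields the raw value followed by that character-level
# "embedding" in the shared |User namespace and in each |action namespace.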


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_auto_embedder_is_off() -> None:
    llm, PROMPT = setup()
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
    )

    str1 = "0"
    str2 = "1"
    str3 = "2"
    ctx_str_1 = "context1"

    expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """  # noqa

    actions = [str1, str2, str3]

    response = chain.run(
        User=pick_best_chain.base.BasedOn(ctx_str_1),
        action=pick_best_chain.base.ToSelectFrom(actions),
    )
    selection_metadata = response["selection_metadata"]
    vw_str = feature_embedder.format(selection_metadata)
    assert vw_str == expected


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_w_embeddings_off() -> None:
    llm, PROMPT = setup()
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
    )

    str1 = "0"
    str2 = "1"
    str3 = "2"
    ctx_str_1 = "context1"

    expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """  # noqa

    actions = [str1, str2, str3]

    response = chain.run(
        User=rl_chain.BasedOn(ctx_str_1),
        action=rl_chain.ToSelectFrom(actions),
    )
    selection_metadata = response["selection_metadata"]
    vw_str = feature_embedder.format(selection_metadata)
    assert vw_str == expected


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_w_embeddings_on() -> None:
    llm, PROMPT = setup()
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=True, model=MockEncoderReturnsList()
    )
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
    )

    str1 = "0"
    str2 = "1"
    ctx_str_1 = "context1"
    dot_prod = "dotprod 0:5.0"  # dot prod of [1.0, 2.0] and [1.0, 2.0]

    expected = f"""shared |User {ctx_str_1} |@ User={ctx_str_1}\n|action {str1} |# action={str1} |{dot_prod}\n|action {str2} |# action={str2} |{dot_prod}"""  # noqa

    actions = [str1, str2]

    response = chain.run(
        User=rl_chain.BasedOn(ctx_str_1),
        action=rl_chain.ToSelectFrom(actions),
    )
    selection_metadata = response["selection_metadata"]
    vw_str = feature_embedder.format(selection_metadata)
    assert vw_str == expected
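
# The "dotprod 0:5.0" feature follows from MockEncoderReturnsList, which encodes
# every string as [1.0, 2.0]: the dot product of the context and action vectors is
# 1.0 * 1.0 + 2.0 * 2.0 = 5.0, appended to each action line when auto_embed is on.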


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
    llm, PROMPT = setup()
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=True, model=MockEncoderReturnsList()
    )
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
    )

    str1 = "0"
    str2 = "1"
    encoded_str2 = rl_chain.stringify_embedding([1.0, 2.0])
    ctx_str_1 = "context1"
    ctx_str_2 = "context2"
    encoded_ctx_str_1 = rl_chain.stringify_embedding([1.0, 2.0])
    dot_prod = "dotprod 0:5.0 1:5.0"  # dot prod of [1.0, 2.0] and [1.0, 2.0]

    expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}"""  # noqa

    actions = [str1, rl_chain.Embed(str2)]

    response = chain.run(
        User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
        User2=rl_chain.BasedOn(ctx_str_2),
        action=rl_chain.ToSelectFrom(actions),
    )
    selection_metadata = response["selection_metadata"]
    vw_str = feature_embedder.format(selection_metadata)
    assert vw_str == expected


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_no_scorer_specified() -> None:
    _, PROMPT = setup()
    chain_llm = FakeListChatModel(responses=["hey", "100"])
    chain = pick_best_chain.PickBest.from_llm(
        llm=chain_llm,
        prompt=PROMPT,
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    response = chain.run(
        User=rl_chain.BasedOn("Context"),
        action=rl_chain.ToSelectFrom(["0", "1", "2"]),
    )
    # chain llm used for both basic prompt and for scoring
    assert response["response"] == "hey"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score == 100.0


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_explicitly_no_scorer() -> None:
    llm, PROMPT = setup()
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        selection_scorer=None,
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    response = chain.run(
        User=rl_chain.BasedOn("Context"),
        action=rl_chain.ToSelectFrom(["0", "1", "2"]),
    )
    # no selection scorer was provided, so no score is computed
    assert response["response"] == "hey"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score is None


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_auto_scorer_with_user_defined_llm() -> None:
    llm, PROMPT = setup()
    scorer_llm = FakeListChatModel(responses=["300"])
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    response = chain.run(
        User=rl_chain.BasedOn("Context"),
        action=rl_chain.ToSelectFrom(["0", "1", "2"]),
    )
    # the chain llm answers the prompt; the separate scorer llm provides the score
    assert response["response"] == "hey"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score == 300.0


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_calling_chain_w_reserved_inputs_throws() -> None:
    llm, PROMPT = setup()
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    with pytest.raises(ValueError):
        chain.run(
            User=rl_chain.BasedOn("Context"),
            rl_chain_selected_based_on=rl_chain.ToSelectFrom(["0", "1", "2"]),
        )

    with pytest.raises(ValueError):
        chain.run(
            User=rl_chain.BasedOn("Context"),
            rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
        )


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_activate_and_deactivate_scorer() -> None:
    _, PROMPT = setup()
    llm = FakeListChatModel(responses=["hey1", "hey2", "hey3"])
    scorer_llm = FakeListChatModel(responses=["300", "400"])
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm,
        prompt=PROMPT,
        selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
            auto_embed=False, model=MockEncoder()
        ),
    )
    response = chain.run(
        User=pick_best_chain.base.BasedOn("Context"),
        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
    )
    # scorer is active by default, so the separate scorer llm provides the score
    assert response["response"] == "hey1"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score == 300.0

    chain.deactivate_selection_scorer()
    response = chain.run(
        User=pick_best_chain.base.BasedOn("Context"),
        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
    )
    assert response["response"] == "hey2"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score is None

    chain.activate_selection_scorer()
    response = chain.run(
        User=pick_best_chain.base.BasedOn("Context"),
        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
    )
    assert response["response"] == "hey3"
    selection_metadata = response["selection_metadata"]
    assert selection_metadata.selected.score == 400.0
@@ -0,0 +1,370 @@
import pytest
from test_utils import MockEncoder

import langchain_experimental.rl_chain.base as rl_chain
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain

encoded_keyword = "[encoded]"


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_context_throws() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    named_action = {"action": ["0", "1", "2"]}
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_action, based_on={}
    )
    with pytest.raises(ValueError):
        feature_embedder.format(event)


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_actions_throws() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from={}, based_on={"context": "context"}
    )
    with pytest.raises(ValueError):
        feature_embedder.format(event)


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_no_label_no_emb() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    named_actions = {"action1": ["0", "1", "2"]}
    expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on={"context": "context"}
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    named_actions = {"action1": ["0", "1", "2"]}
    expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
    event = pick_best_chain.PickBestEvent(
        inputs={},
        to_select_from=named_actions,
        based_on={"context": "context"},
        selected=selected,
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    named_actions = {"action1": ["0", "1", "2"]}
    expected = (
        """shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
    )
    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
    event = pick_best_chain.PickBestEvent(
        inputs={},
        to_select_from=named_actions,
        based_on={"context": "context"},
        selected=selected,
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected
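
# The "0:-0.0:1.0" prefix on the chosen action line is the contextual-bandit label
# in VW's text format: action index 0, cost -score (here -0.0), probability 1.0.
# It only appears once a PickBestSelected carrying a score is attached to the event.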


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    str1 = "0"
    str2 = "1"
    str3 = "2"
    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

    ctx_str_1 = "context1"
    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))

    named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
    context = {"context": rl_chain.Embed(ctx_str_1)}
    expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """  # noqa: E501
    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    str1 = "0"
    str2 = "1"
    str3 = "2"
    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

    ctx_str_1 = "context1"
    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))

    named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
    context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
    expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """  # noqa: E501
    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
    context = {"context1": "context1", "context2": "context2"}
    expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """  # noqa: E501
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected
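
# When an action is given as a dict (here {"a": "0", "b": "0"}), each key becomes its
# own namespace on that action line ("|a 0 |b 0"), while plain strings stay in the
# default "action1" namespace; each based_on key becomes a namespace on the shared line.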


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
    context = {"context1": "context1", "context2": "context2"}
    expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """  # noqa: E501
    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
    context = {"context1": "context1", "context2": "context2"}
    expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """  # noqa: E501
    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )

    str1 = "0"
    str2 = "1"
    str3 = "2"
    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

    ctx_str_1 = "context1"
    ctx_str_2 = "context2"
    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
    encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

    named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
    context = {
        "context1": rl_chain.Embed(ctx_str_1),
        "context2": rl_chain.Embed(ctx_str_2),
    }
    expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """  # noqa: E501

    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
    None
):
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )

    str1 = "0"
    str2 = "1"
    str3 = "2"
    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

    ctx_str_1 = "context1"
    ctx_str_2 = "context2"
    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
    encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

    named_actions = {
        "action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
    }
    context = {
        "context1": rl_chain.EmbedAndKeep(ctx_str_1),
        "context2": rl_chain.EmbedAndKeep(ctx_str_2),
    }
    expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """  # noqa: E501

    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )

    str1 = "0"
    str2 = "1"
    str3 = "2"
    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

    ctx_str_1 = "context1"
    ctx_str_2 = "context2"
    encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

    named_actions = {
        "action1": [
            {"a": str1, "b": rl_chain.Embed(str1)},
            str2,
            rl_chain.Embed(str3),
        ]
    }
    context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
    expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """  # noqa: E501

    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )

    str1 = "0"
    str2 = "1"
    str3 = "2"
    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

    ctx_str_1 = "context1"
    ctx_str_2 = "context2"
    encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

    named_actions = {
        "action1": [
            {"a": str1, "b": rl_chain.EmbedAndKeep(str1)},
            str2,
            rl_chain.EmbedAndKeep(str3),
        ]
    }
    context = {
        "context1": ctx_str_1,
        "context2": rl_chain.EmbedAndKeep(ctx_str_2),
    }
    expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """  # noqa: E501

    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_raw_features_underscored() -> None:
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
        auto_embed=False, model=MockEncoder()
    )
    str1 = "this is a long string"
    str1_underscored = str1.replace(" ", "_")
    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))

    ctx_str = "this is a long context"
    ctx_str_underscored = ctx_str.replace(" ", "_")
    encoded_ctx_str = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str))

    # No embeddings
    named_actions = {"action": [str1]}
    context = {"context": ctx_str}
    expected_no_embed = (
        f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
    )
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected_no_embed

    # Just embeddings
    named_actions = {"action": rl_chain.Embed([str1])}
    context = {"context": rl_chain.Embed(ctx_str)}
    expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected_embed

    # Embeddings and raw features
    named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
    context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
    expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """  # noqa: E501
    event = pick_best_chain.PickBestEvent(
        inputs={}, to_select_from=named_actions, based_on=context
    )
    vw_ex_str = feature_embedder.format(event)
    assert vw_ex_str == expected_embed_and_keep
@@ -0,0 +1,422 @@
from typing import List, Union

import pytest
from test_utils import MockEncoder

import langchain_experimental.rl_chain.base as base

encoded_keyword = "[encoded]"


@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_no_emb() -> None:
    expected = [{"a_namespace": "test"}]
    assert base.embed("test", MockEncoder(), "a_namespace") == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_emb() -> None:
    str1 = "test"
    encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
    expected = [{"a_namespace": encoded_str1}]
    assert base.embed(base.Embed(str1), MockEncoder(), "a_namespace") == expected
    expected_embed_and_keep = [{"a_namespace": str1 + " " + encoded_str1}]
    assert (
        base.embed(base.EmbedAndKeep(str1), MockEncoder(), "a_namespace")
        == expected_embed_and_keep
    )
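
# base.embed returns a list of {namespace: feature} dicts. With MockEncoder, an
# Embed-wrapped value is replaced by the stringified embedding of "[encoded]" + value,
# while EmbedAndKeep keeps the raw value and appends that embedding after a space.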


@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_nested_emb() -> None:
    # nested embeddings, innermost wins
    str1 = "test"
    encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
    expected = [{"a_namespace": encoded_str1}]
    assert (
        base.embed(base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace")
        == expected
    )

    expected2 = [{"a_namespace": str1 + " " + encoded_str1}]
    assert (
        base.embed(base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace")
        == expected2
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_no_emb() -> None:
    expected = [{"test_namespace": "test"}]
    assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb() -> None:
    str1 = "test"
    encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
    expected = [{"test_namespace": encoded_str1}]
    assert base.embed({"test_namespace": base.Embed(str1)}, MockEncoder()) == expected
    expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
    assert (
        base.embed({"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder())
        == expected_embed_and_keep
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb2() -> None:
    str1 = "test"
    encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
    expected = [{"test_namespace": encoded_str1}]
    assert base.embed(base.Embed({"test_namespace": str1}), MockEncoder()) == expected
    expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
    assert (
        base.embed(base.EmbedAndKeep({"test_namespace": str1}), MockEncoder())
        == expected_embed_and_keep
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_some_emb() -> None:
    str1 = "test1"
    str2 = "test2"
    encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
    expected = [{"test_namespace": str1, "test_namespace2": encoded_str2}]
    assert (
        base.embed(
            {"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder()
        )
        == expected
    )
    expected_embed_and_keep = [
        {
            "test_namespace": str1,
            "test_namespace2": str2 + " " + encoded_str2,
        }
    ]
    assert (
        base.embed(
            {"test_namespace": str1, "test_namespace2": base.EmbedAndKeep(str2)},
            MockEncoder(),
        )
        == expected_embed_and_keep
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_no_emb() -> None:
    str1 = "test1"
    str2 = "test2"
    str3 = "test3"
    expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
    to_embed: List[Union[str, base._Embed]] = [str1, str2, str3]
    assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_emb() -> None:
    str1 = "test1"
    str2 = "test2"
    str3 = "test3"
    encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
    encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
    expected = [
        {"a_namespace": encoded_str1},
        {"a_namespace": encoded_str2},
        {"a_namespace": encoded_str3},
    ]
    assert (
        base.embed(base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace")
        == expected
    )
    expected_embed_and_keep = [
        {"a_namespace": str1 + " " + encoded_str1},
        {"a_namespace": str2 + " " + encoded_str2},
        {"a_namespace": str3 + " " + encoded_str3},
    ]
    assert (
        base.embed(base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace")
        == expected_embed_and_keep
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_some_emb() -> None:
    str1 = "test1"
    str2 = "test2"
    str3 = "test3"
    encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
    expected = [
        {"a_namespace": str1},
        {"a_namespace": encoded_str2},
        {"a_namespace": encoded_str3},
    ]
    assert (
        base.embed(
            [str1, base.Embed(str2), base.Embed(str3)], MockEncoder(), "a_namespace"
        )
        == expected
    )
    expected_embed_and_keep = [
        {"a_namespace": str1},
        {"a_namespace": str2 + " " + encoded_str2},
        {"a_namespace": str3 + " " + encoded_str3},
    ]
    assert (
        base.embed(
            [str1, base.EmbedAndKeep(str2), base.EmbedAndKeep(str3)],
            MockEncoder(),
            "a_namespace",
        )
        == expected_embed_and_keep
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_no_emb() -> None:
    str1 = "test1"
    str2 = "test2"
    str3 = "test3"
    expected = [
        {"test_namespace": str1},
        {"test_namespace": str2},
        {"test_namespace": str3},
    ]
    assert (
        base.embed(
            [
                {"test_namespace": str1},
                {"test_namespace": str2},
                {"test_namespace": str3},
            ],
            MockEncoder(),
        )
        == expected
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb() -> None:
    str1 = "test1"
    str2 = "test2"
    str3 = "test3"
    encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
    encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
    expected = [
        {"test_namespace": encoded_str1},
        {"test_namespace": encoded_str2},
        {"test_namespace": encoded_str3},
    ]
    assert (
        base.embed(
            [
                {"test_namespace": base.Embed(str1)},
                {"test_namespace": base.Embed(str2)},
                {"test_namespace": base.Embed(str3)},
            ],
            MockEncoder(),
        )
        == expected
    )
    expected_embed_and_keep = [
        {"test_namespace": str1 + " " + encoded_str1},
        {"test_namespace": str2 + " " + encoded_str2},
        {"test_namespace": str3 + " " + encoded_str3},
    ]
    assert (
        base.embed(
            [
                {"test_namespace": base.EmbedAndKeep(str1)},
                {"test_namespace": base.EmbedAndKeep(str2)},
                {"test_namespace": base.EmbedAndKeep(str3)},
            ],
            MockEncoder(),
        )
        == expected_embed_and_keep
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb2() -> None:
    str1 = "test1"
    str2 = "test2"
    str3 = "test3"
    encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
    encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
    expected = [
        {"test_namespace1": encoded_str1},
        {"test_namespace2": encoded_str2},
        {"test_namespace3": encoded_str3},
    ]
    assert (
        base.embed(
            base.Embed(
                [
                    {"test_namespace1": str1},
                    {"test_namespace2": str2},
                    {"test_namespace3": str3},
                ]
            ),
            MockEncoder(),
        )
        == expected
    )
    expected_embed_and_keep = [
        {"test_namespace1": str1 + " " + encoded_str1},
        {"test_namespace2": str2 + " " + encoded_str2},
        {"test_namespace3": str3 + " " + encoded_str3},
    ]
    assert (
        base.embed(
            base.EmbedAndKeep(
                [
                    {"test_namespace1": str1},
                    {"test_namespace2": str2},
                    {"test_namespace3": str3},
                ]
            ),
            MockEncoder(),
        )
        == expected_embed_and_keep
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_some_emb() -> None:
    str1 = "test1"
    str2 = "test2"
    str3 = "test3"
    encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
    expected = [
        {"test_namespace": str1},
        {"test_namespace": encoded_str2},
        {"test_namespace": encoded_str3},
    ]
    assert (
        base.embed(
            [
                {"test_namespace": str1},
                {"test_namespace": base.Embed(str2)},
                {"test_namespace": base.Embed(str3)},
            ],
            MockEncoder(),
        )
        == expected
    )
    expected_embed_and_keep = [
        {"test_namespace": str1},
        {"test_namespace": str2 + " " + encoded_str2},
        {"test_namespace": str3 + " " + encoded_str3},
    ]
    assert (
        base.embed(
            [
                {"test_namespace": str1},
                {"test_namespace": base.EmbedAndKeep(str2)},
                {"test_namespace": base.EmbedAndKeep(str3)},
            ],
            MockEncoder(),
        )
        == expected_embed_and_keep
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
    str1 = "test1"
    str2 = "test2"
    str3 = "test3"
    encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
    encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
    encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
    expected = [
        {"test_namespace": encoded_str1, "test_namespace2": str1},
        {"test_namespace": encoded_str2, "test_namespace2": str2},
        {"test_namespace": encoded_str3, "test_namespace2": str3},
    ]
    assert (
        base.embed(
            [
                {"test_namespace": base.Embed(str1), "test_namespace2": str1},
                {"test_namespace": base.Embed(str2), "test_namespace2": str2},
                {"test_namespace": base.Embed(str3), "test_namespace2": str3},
            ],
            MockEncoder(),
        )
        == expected
    )
    expected_embed_and_keep = [
        {
            "test_namespace": str1 + " " + encoded_str1,
            "test_namespace2": str1,
        },
        {
            "test_namespace": str2 + " " + encoded_str2,
            "test_namespace2": str2,
        },
        {
            "test_namespace": str3 + " " + encoded_str3,
            "test_namespace2": str3,
        },
    ]
    assert (
        base.embed(
            [
                {"test_namespace": base.EmbedAndKeep(str1), "test_namespace2": str1},
                {"test_namespace": base.EmbedAndKeep(str2), "test_namespace2": str2},
                {"test_namespace": base.EmbedAndKeep(str3), "test_namespace2": str3},
            ],
            MockEncoder(),
        )
        == expected_embed_and_keep
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_no_emb() -> None:
    str1 = "test1"
    str2 = "test2"
    expected = [{"test_namespace": [str1, str2]}]
    assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected


@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
    str1 = "test1"
    str2 = "test2"
    encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
    expected = [{"test_namespace": [str1, encoded_str2]}]
    assert (
        base.embed({"test_namespace": [str1, base.Embed(str2)]}, MockEncoder())
        == expected
    )


@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_list_features_throws() -> None:
    with pytest.raises(ValueError):
        base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())


@pytest.mark.requires("vowpal_wabbit_next")
def test_dict_in_list_throws() -> None:
    with pytest.raises(ValueError):
        base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())


@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_dict_throws() -> None:
    with pytest.raises(ValueError):
        base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())


@pytest.mark.requires("vowpal_wabbit_next")
def test_list_of_tuples_throws() -> None:
    with pytest.raises(ValueError):
        base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
libs/experimental/tests/unit_tests/rl_chain/test_utils.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from typing import Any, List


class MockEncoder:
    def encode(self, to_encode: str) -> str:
        return "[encoded]" + to_encode


class MockEncoderReturnsList:
    def encode(self, to_encode: Any) -> List:
        if isinstance(to_encode, str):
            return [1.0, 2.0]
        elif isinstance(to_encode, List):
            return [[1.0, 2.0] for _ in range(len(to_encode))]
        raise ValueError("Invalid input type for unit test")
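
# These mocks keep the unit tests deterministic: MockEncoder makes embedded strings
# predictable ("[encoded]" + input), and MockEncoderReturnsList returns fixed vectors
# so the dot-product features produced by the auto-embedder are stable across runs.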