mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-19 00:58:32 +00:00
RL Chain with VowpalWabbit (#10242)
- Description: This PR adds a new chain `rl_chain.PickBest` for learned prompt variable injection, detailed description and usage can be found in the example notebook added. It essentially adds a [VowpalWabbit](https://github.com/VowpalWabbit/vowpal_wabbit) layer before the llm call in order to learn or personalize prompt variable selections. Most of the code is to make the API simple and provide lots of defaults and data wrangling that is needed to use Vowpal Wabbit, so that the user of the chain doesn't have to worry about it. - Dependencies: [vowpal-wabbit-next](https://pypi.org/project/vowpal-wabbit-next/), - sentence-transformers (already a dep) - numpy (already a dep) - tagging @ataymano who contributed to this chain - Tag maintainer: @baskaryan - Twitter handle: @olgavrou Added example notebook and unit tests
This commit is contained in:
@@ -0,0 +1,459 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from langchain.chat_models import FakeListChatModel
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from test_utils import MockEncoder, MockEncoderReturnsList
|
||||
|
||||
import langchain_experimental.rl_chain.base as rl_chain
|
||||
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
||||
|
||||
encoded_keyword = "[encoded]"
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def setup() -> tuple:
|
||||
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
|
||||
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
|
||||
|
||||
llm = FakeListChatModel(responses=["hey"])
|
||||
return llm, PROMPT
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_multiple_ToSelectFrom_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
another_action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_missing_basedOn_from_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(action=rl_chain.ToSelectFrom(actions))
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_ToSelectFrom_not_a_list_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = {"actions": ["0", "1", "2"]}
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
assert response["response"] == "hey" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score == 3.0 # type: ignore
|
||||
with pytest.raises(RuntimeError):
|
||||
chain.update_with_delayed_score(
|
||||
chain_response=response, score=100 # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score_force() -> None:
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
assert response["response"] == "hey" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score == 3.0 # type: ignore
|
||||
chain.update_with_delayed_score(
|
||||
chain_response=response, score=100, force_score=True # type: ignore
|
||||
)
|
||||
assert selection_metadata.selected.score == 100.0 # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=None,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
assert response["response"] == "hey" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score is None # type: ignore
|
||||
chain.update_with_delayed_score(chain_response=response, score=100) # type: ignore
|
||||
assert selection_metadata.selected.score == 100.0 # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_user_defined_scorer() -> None:
|
||||
llm, PROMPT = setup()
|
||||
|
||||
class CustomSelectionScorer(rl_chain.SelectionScorer):
|
||||
def score_response(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
llm_response: str,
|
||||
event: pick_best_chain.PickBestEvent,
|
||||
) -> float:
|
||||
score = 200
|
||||
return score
|
||||
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=CustomSelectionScorer(),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
assert response["response"] == "hey" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score == 200.0 # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_everything_embedded() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
|
||||
expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """ # noqa
|
||||
|
||||
actions = [str1, str2, str3]
|
||||
|
||||
response = chain.run(
|
||||
User=rl_chain.EmbedAndKeep(rl_chain.BasedOn(ctx_str_1)),
|
||||
action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_auto_embedder_is_off() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
ctx_str_1 = "context1"
|
||||
|
||||
expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """ # noqa
|
||||
|
||||
actions = [str1, str2, str3]
|
||||
|
||||
response = chain.run(
|
||||
User=pick_best_chain.base.BasedOn(ctx_str_1),
|
||||
action=pick_best_chain.base.ToSelectFrom(actions),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_w_embeddings_off() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
ctx_str_1 = "context1"
|
||||
|
||||
expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """ # noqa
|
||||
|
||||
actions = [str1, str2, str3]
|
||||
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn(ctx_str_1),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_w_embeddings_on() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=True, model=MockEncoderReturnsList()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
ctx_str_1 = "context1"
|
||||
dot_prod = "dotprod 0:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
|
||||
|
||||
expected = f"""shared |User {ctx_str_1} |@ User={ctx_str_1}\n|action {str1} |# action={str1} |{dot_prod}\n|action {str2} |# action={str2} |{dot_prod}""" # noqa
|
||||
|
||||
actions = [str1, str2]
|
||||
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn(ctx_str_1),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=True, model=MockEncoderReturnsList()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
encoded_str2 = rl_chain.stringify_embedding([1.0, 2.0])
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding([1.0, 2.0])
|
||||
dot_prod = "dotprod 0:5.0 1:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
|
||||
|
||||
expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}""" # noqa
|
||||
|
||||
actions = [str1, rl_chain.Embed(str2)]
|
||||
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
|
||||
User2=rl_chain.BasedOn(ctx_str_2),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_no_scorer_specified() -> None:
|
||||
_, PROMPT = setup()
|
||||
chain_llm = FakeListChatModel(responses=["hey", "100"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=chain_llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
# chain llm used for both basic prompt and for scoring
|
||||
assert response["response"] == "hey" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score == 100.0 # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_explicitly_no_scorer() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=None,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
# chain llm used for both basic prompt and for scoring
|
||||
assert response["response"] == "hey" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score is None # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_auto_scorer_with_user_defined_llm() -> None:
|
||||
llm, PROMPT = setup()
|
||||
scorer_llm = FakeListChatModel(responses=["300"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
# chain llm used for both basic prompt and for scoring
|
||||
assert response["response"] == "hey" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score == 300.0 # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_calling_chain_w_reserved_inputs_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
rl_chain_selected_based_on=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_activate_and_deactivate_scorer() -> None:
|
||||
_, PROMPT = setup()
|
||||
llm = FakeListChatModel(responses=["hey1", "hey2", "hey3"])
|
||||
scorer_llm = FakeListChatModel(responses=["300", "400"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
response = chain.run(
|
||||
User=pick_best_chain.base.BasedOn("Context"),
|
||||
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
# chain llm used for both basic prompt and for scoring
|
||||
assert response["response"] == "hey1" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score == 300.0 # type: ignore
|
||||
|
||||
chain.deactivate_selection_scorer()
|
||||
response = chain.run(
|
||||
User=pick_best_chain.base.BasedOn("Context"),
|
||||
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
assert response["response"] == "hey2" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score is None # type: ignore
|
||||
|
||||
chain.activate_selection_scorer()
|
||||
response = chain.run(
|
||||
User=pick_best_chain.base.BasedOn("Context"),
|
||||
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
assert response["response"] == "hey3" # type: ignore
|
||||
selection_metadata = response["selection_metadata"] # type: ignore
|
||||
assert selection_metadata.selected.score == 400.0 # type: ignore
|
@@ -0,0 +1,370 @@
|
||||
import pytest
|
||||
from test_utils import MockEncoder
|
||||
|
||||
import langchain_experimental.rl_chain.base as rl_chain
|
||||
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
||||
|
||||
encoded_keyword = "[encoded]"
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_missing_context_throws() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_action = {"action": ["0", "1", "2"]}
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_action, based_on={}
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
feature_embedder.format(event)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_missing_actions_throws() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from={}, based_on={"context": "context"}
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
feature_embedder.format(event)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_no_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={},
|
||||
to_select_from=named_actions,
|
||||
based_on={"context": "context"},
|
||||
selected=selected,
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = (
|
||||
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
||||
)
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={},
|
||||
to_select_from=named_actions,
|
||||
based_on={"context": "context"},
|
||||
selected=selected,
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
|
||||
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
||||
context = {"context": rl_chain.Embed(ctx_str_1)}
|
||||
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
|
||||
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
||||
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
||||
|
||||
named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
|
||||
context = {
|
||||
"context1": rl_chain.Embed(ctx_str_1),
|
||||
"context2": rl_chain.Embed(ctx_str_2),
|
||||
}
|
||||
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
|
||||
None
|
||||
):
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
||||
|
||||
named_actions = {
|
||||
"action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
|
||||
}
|
||||
context = {
|
||||
"context1": rl_chain.EmbedAndKeep(ctx_str_1),
|
||||
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
|
||||
}
|
||||
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
||||
|
||||
named_actions = {
|
||||
"action1": [
|
||||
{"a": str1, "b": rl_chain.Embed(str1)},
|
||||
str2,
|
||||
rl_chain.Embed(str3),
|
||||
]
|
||||
}
|
||||
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
|
||||
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
||||
|
||||
named_actions = {
|
||||
"action1": [
|
||||
{"a": str1, "b": rl_chain.EmbedAndKeep(str1)},
|
||||
str2,
|
||||
rl_chain.EmbedAndKeep(str3),
|
||||
]
|
||||
}
|
||||
context = {
|
||||
"context1": ctx_str_1,
|
||||
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
|
||||
}
|
||||
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_raw_features_underscored() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
str1 = "this is a long string"
|
||||
str1_underscored = str1.replace(" ", "_")
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
|
||||
ctx_str = "this is a long context"
|
||||
ctx_str_underscored = ctx_str.replace(" ", "_")
|
||||
encoded_ctx_str = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str))
|
||||
|
||||
# No embeddings
|
||||
named_actions = {"action": [str1]}
|
||||
context = {"context": ctx_str}
|
||||
expected_no_embed = (
|
||||
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
|
||||
)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected_no_embed
|
||||
|
||||
# Just embeddings
|
||||
named_actions = {"action": rl_chain.Embed([str1])}
|
||||
context = {"context": rl_chain.Embed(ctx_str)}
|
||||
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected_embed
|
||||
|
||||
# Embeddings and raw features
|
||||
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
|
||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
|
||||
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected_embed_and_keep
|
@@ -0,0 +1,422 @@
|
||||
from typing import List, Union
|
||||
|
||||
import pytest
|
||||
from test_utils import MockEncoder
|
||||
|
||||
import langchain_experimental.rl_chain.base as base
|
||||
|
||||
encoded_keyword = "[encoded]"
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_no_emb() -> None:
|
||||
expected = [{"a_namespace": "test"}]
|
||||
assert base.embed("test", MockEncoder(), "a_namespace") == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_w_emb() -> None:
|
||||
str1 = "test"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
expected = [{"a_namespace": encoded_str1}]
|
||||
assert base.embed(base.Embed(str1), MockEncoder(), "a_namespace") == expected
|
||||
expected_embed_and_keep = [{"a_namespace": str1 + " " + encoded_str1}]
|
||||
assert (
|
||||
base.embed(base.EmbedAndKeep(str1), MockEncoder(), "a_namespace")
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_w_nested_emb() -> None:
|
||||
# nested embeddings, innermost wins
|
||||
str1 = "test"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
expected = [{"a_namespace": encoded_str1}]
|
||||
assert (
|
||||
base.embed(base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace")
|
||||
== expected
|
||||
)
|
||||
|
||||
expected2 = [{"a_namespace": str1 + " " + encoded_str1}]
|
||||
assert (
|
||||
base.embed(base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace")
|
||||
== expected2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_no_emb() -> None:
|
||||
expected = [{"test_namespace": "test"}]
|
||||
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_emb() -> None:
|
||||
str1 = "test"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
expected = [{"test_namespace": encoded_str1}]
|
||||
assert base.embed({"test_namespace": base.Embed(str1)}, MockEncoder()) == expected
|
||||
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
|
||||
assert (
|
||||
base.embed({"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder())
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_emb2() -> None:
|
||||
str1 = "test"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
expected = [{"test_namespace": encoded_str1}]
|
||||
assert base.embed(base.Embed({"test_namespace": str1}), MockEncoder()) == expected
|
||||
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
|
||||
assert (
|
||||
base.embed(base.EmbedAndKeep({"test_namespace": str1}), MockEncoder())
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
expected = [{"test_namespace": str1, "test_namespace2": encoded_str2}]
|
||||
assert (
|
||||
base.embed(
|
||||
{"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder()
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{
|
||||
"test_namespace": str1,
|
||||
"test_namespace2": str2 + " " + encoded_str2,
|
||||
}
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
{"test_namespace": str1, "test_namespace2": base.EmbedAndKeep(str2)},
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_no_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
|
||||
to_embed: List[Union[str, base._Embed]] = [str1, str2, str3]
|
||||
assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_w_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"a_namespace": encoded_str1},
|
||||
{"a_namespace": encoded_str2},
|
||||
{"a_namespace": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace")
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"a_namespace": str1 + " " + encoded_str1},
|
||||
{"a_namespace": str2 + " " + encoded_str2},
|
||||
{"a_namespace": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace")
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"a_namespace": str1},
|
||||
{"a_namespace": encoded_str2},
|
||||
{"a_namespace": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[str1, base.Embed(str2), base.Embed(str3)], MockEncoder(), "a_namespace"
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"a_namespace": str1},
|
||||
{"a_namespace": str2 + " " + encoded_str2},
|
||||
{"a_namespace": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[str1, base.EmbedAndKeep(str2), base.EmbedAndKeep(str3)],
|
||||
MockEncoder(),
|
||||
"a_namespace",
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_no_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
expected = [
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": str2},
|
||||
{"test_namespace": str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": str2},
|
||||
{"test_namespace": str3},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"test_namespace": encoded_str1},
|
||||
{"test_namespace": encoded_str2},
|
||||
{"test_namespace": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": base.Embed(str1)},
|
||||
{"test_namespace": base.Embed(str2)},
|
||||
{"test_namespace": base.Embed(str3)},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"test_namespace": str1 + " " + encoded_str1},
|
||||
{"test_namespace": str2 + " " + encoded_str2},
|
||||
{"test_namespace": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": base.EmbedAndKeep(str1)},
|
||||
{"test_namespace": base.EmbedAndKeep(str2)},
|
||||
{"test_namespace": base.EmbedAndKeep(str3)},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb2() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"test_namespace1": encoded_str1},
|
||||
{"test_namespace2": encoded_str2},
|
||||
{"test_namespace3": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
base.Embed(
|
||||
[
|
||||
{"test_namespace1": str1},
|
||||
{"test_namespace2": str2},
|
||||
{"test_namespace3": str3},
|
||||
]
|
||||
),
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"test_namespace1": str1 + " " + encoded_str1},
|
||||
{"test_namespace2": str2 + " " + encoded_str2},
|
||||
{"test_namespace3": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
base.EmbedAndKeep(
|
||||
[
|
||||
{"test_namespace1": str1},
|
||||
{"test_namespace2": str2},
|
||||
{"test_namespace3": str3},
|
||||
]
|
||||
),
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": encoded_str2},
|
||||
{"test_namespace": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": base.Embed(str2)},
|
||||
{"test_namespace": base.Embed(str3)},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": str2 + " " + encoded_str2},
|
||||
{"test_namespace": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": base.EmbedAndKeep(str2)},
|
||||
{"test_namespace": base.EmbedAndKeep(str3)},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"test_namespace": encoded_str1, "test_namespace2": str1},
|
||||
{"test_namespace": encoded_str2, "test_namespace2": str2},
|
||||
{"test_namespace": encoded_str3, "test_namespace2": str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": base.Embed(str1), "test_namespace2": str1},
|
||||
{"test_namespace": base.Embed(str2), "test_namespace2": str2},
|
||||
{"test_namespace": base.Embed(str3), "test_namespace2": str3},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{
|
||||
"test_namespace": str1 + " " + encoded_str1,
|
||||
"test_namespace2": str1,
|
||||
},
|
||||
{
|
||||
"test_namespace": str2 + " " + encoded_str2,
|
||||
"test_namespace2": str2,
|
||||
},
|
||||
{
|
||||
"test_namespace": str3 + " " + encoded_str3,
|
||||
"test_namespace2": str3,
|
||||
},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": base.EmbedAndKeep(str1), "test_namespace2": str1},
|
||||
{"test_namespace": base.EmbedAndKeep(str2), "test_namespace2": str2},
|
||||
{"test_namespace": base.EmbedAndKeep(str3), "test_namespace2": str3},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_one_namespace_w_list_of_features_no_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
expected = [{"test_namespace": [str1, str2]}]
|
||||
assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
expected = [{"test_namespace": [str1, encoded_str2]}]
|
||||
assert (
|
||||
base.embed({"test_namespace": [str1, base.Embed(str2)]}, MockEncoder())
|
||||
== expected
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_nested_list_features_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_dict_in_list_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_nested_dict_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_list_of_tuples_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
|
15
libs/experimental/tests/unit_tests/rl_chain/test_utils.py
Normal file
15
libs/experimental/tests/unit_tests/rl_chain/test_utils.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Any, List
|
||||
|
||||
|
||||
class MockEncoder:
|
||||
def encode(self, to_encode: str) -> str:
|
||||
return "[encoded]" + to_encode
|
||||
|
||||
|
||||
class MockEncoderReturnsList:
|
||||
def encode(self, to_encode: Any) -> List:
|
||||
if isinstance(to_encode, str):
|
||||
return [1.0, 2.0]
|
||||
elif isinstance(to_encode, List):
|
||||
return [[1.0, 2.0] for _ in range(len(to_encode))]
|
||||
raise ValueError("Invalid input type for unit test")
|
Reference in New Issue
Block a user