RL Chain with VowpalWabbit (#10242)

- Description: This PR adds a new chain, `rl_chain.PickBest`, for learned
prompt variable injection; a detailed description and usage instructions can
be found in the example notebook added with this PR. It essentially adds a
[VowpalWabbit](https://github.com/VowpalWabbit/vowpal_wabbit) layer
before the LLM call in order to learn and personalize prompt variable
selections.

Most of the code exists to keep the API simple and to provide the defaults
and data wrangling that Vowpal Wabbit needs, so that users of the chain
don't have to worry about it.
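
For quick orientation, here is a minimal usage sketch assembled from the calls
exercised in the unit tests in this PR (it assumes the dependencies listed
below are installed; `FakeListChatModel` is used only to keep the example
self-contained):

```python
from langchain.chat_models import FakeListChatModel
from langchain.prompts.prompt import PromptTemplate

import langchain_experimental.rl_chain.base as rl_chain
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain

llm = FakeListChatModel(responses=["hey"])  # stand-in for a real chat model
prompt = PromptTemplate(input_variables=[], template="This is a dummy prompt")

# PickBest puts a VowpalWabbit contextual-bandit policy in front of the LLM call.
# selection_scorer=None defers scoring so a reward can be reported later.
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=prompt, selection_scorer=None)

# ToSelectFrom marks the candidate prompt-variable values; BasedOn marks the
# context the policy conditions its choice on.
response = chain.run(
    User=rl_chain.BasedOn("Context"),
    action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
print(response["response"])                 # the LLM's answer
metadata = response["selection_metadata"]   # which candidate was picked, and its score

# Once the outcome is known, feed a score back so the policy can learn from it.
chain.update_with_delayed_score(chain_response=response, score=1.0)
```

Scoring can also happen automatically, either with `rl_chain.AutoSelectionScorer`
(an LLM grades the selection) or with a custom `rl_chain.SelectionScorer`; both
paths are covered by the unit tests below.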

- Dependencies:
  - [vowpal-wabbit-next](https://pypi.org/project/vowpal-wabbit-next/)
  - sentence-transformers (already a dep)
  - numpy (already a dep)
- Tagging @ataymano, who contributed to this chain
- Tag maintainer: @baskaryan
- Twitter handle: @olgavrou


Added an example notebook and unit tests.
olgavrou authored on 2023-10-06 04:07:22 +03:00; committed by GitHub
parent 3a299b9680
commit 3b07c0cf3d
4 changed files with 1266 additions and 0 deletions


@@ -0,0 +1,459 @@
from typing import Any, Dict
import pytest
from langchain.chat_models import FakeListChatModel
from langchain.prompts.prompt import PromptTemplate
from test_utils import MockEncoder, MockEncoderReturnsList
import langchain_experimental.rl_chain.base as rl_chain
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
encoded_keyword = "[encoded]"
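# test_utils.MockEncoder returns "[encoded]" + <input>, so expected embeddings are
# rebuilt below as rl_chain.stringify_embedding(list(encoded_keyword + value)).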
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def setup() -> tuple:
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
llm = FakeListChatModel(responses=["hey"])
return llm, PROMPT
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_multiple_ToSelectFrom_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
another_action=rl_chain.ToSelectFrom(actions),
)
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_missing_basedOn_from_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
chain.run(action=rl_chain.ToSelectFrom(actions))
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_ToSelectFrom_not_a_list_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = {"actions": ["0", "1", "2"]}
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_with_auto_validator_throws() -> None:
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score == 3.0 # type: ignore
with pytest.raises(RuntimeError):
chain.update_with_delayed_score(
chain_response=response, score=100 # type: ignore
)
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_force() -> None:
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score == 3.0 # type: ignore
chain.update_with_delayed_score(
chain_response=response, score=100, force_score=True # type: ignore
)
assert selection_metadata.selected.score == 100.0 # type: ignore
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=None,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score is None # type: ignore
chain.update_with_delayed_score(chain_response=response, score=100) # type: ignore
assert selection_metadata.selected.score == 100.0 # type: ignore
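# Taken together with the two tests above: with selection_scorer=None the score
# stays None until update_with_delayed_score() supplies one, while an already
# scored event can only be overwritten by passing force_score=True.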
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_user_defined_scorer() -> None:
llm, PROMPT = setup()
class CustomSelectionScorer(rl_chain.SelectionScorer):
def score_response(
self,
inputs: Dict[str, Any],
llm_response: str,
event: pick_best_chain.PickBestEvent,
) -> float:
score = 200
return score
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=CustomSelectionScorer(),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score == 200.0 # type: ignore
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_everything_embedded() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
ctx_str_1 = "context1"
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """ # noqa
actions = [str1, str2, str3]
response = chain.run(
User=rl_chain.EmbedAndKeep(rl_chain.BasedOn(ctx_str_1)),
action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
)
selection_metadata = response["selection_metadata"] # type: ignore
vw_str = feature_embedder.format(selection_metadata) # type: ignore
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_auto_embedder_is_off() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
)
str1 = "0"
str2 = "1"
str3 = "2"
ctx_str_1 = "context1"
expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """ # noqa
actions = [str1, str2, str3]
response = chain.run(
User=pick_best_chain.base.BasedOn(ctx_str_1),
action=pick_best_chain.base.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"] # type: ignore
vw_str = feature_embedder.format(selection_metadata) # type: ignore
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_w_embeddings_off() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
)
str1 = "0"
str2 = "1"
str3 = "2"
ctx_str_1 = "context1"
expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """ # noqa
actions = [str1, str2, str3]
response = chain.run(
User=rl_chain.BasedOn(ctx_str_1),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"] # type: ignore
vw_str = feature_embedder.format(selection_metadata) # type: ignore
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_w_embeddings_on() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=True, model=MockEncoderReturnsList()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
)
str1 = "0"
str2 = "1"
ctx_str_1 = "context1"
dot_prod = "dotprod 0:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
expected = f"""shared |User {ctx_str_1} |@ User={ctx_str_1}\n|action {str1} |# action={str1} |{dot_prod}\n|action {str2} |# action={str2} |{dot_prod}""" # noqa
actions = [str1, str2]
response = chain.run(
User=rl_chain.BasedOn(ctx_str_1),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"] # type: ignore
vw_str = feature_embedder.format(selection_metadata) # type: ignore
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=True, model=MockEncoderReturnsList()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
)
str1 = "0"
str2 = "1"
encoded_str2 = rl_chain.stringify_embedding([1.0, 2.0])
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = rl_chain.stringify_embedding([1.0, 2.0])
dot_prod = "dotprod 0:5.0 1:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}""" # noqa
actions = [str1, rl_chain.Embed(str2)]
response = chain.run(
User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
User2=rl_chain.BasedOn(ctx_str_2),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"] # type: ignore
vw_str = feature_embedder.format(selection_metadata) # type: ignore
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_no_scorer_specified() -> None:
_, PROMPT = setup()
chain_llm = FakeListChatModel(responses=["hey", "100"])
chain = pick_best_chain.PickBest.from_llm(
llm=chain_llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
# chain llm used for both basic prompt and for scoring
assert response["response"] == "hey" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score == 100.0 # type: ignore
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_explicitly_no_scorer() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=None,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
# no scorer is set, so the llm response is returned but not scored
assert response["response"] == "hey" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score is None # type: ignore
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_auto_scorer_with_user_defined_llm() -> None:
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=["300"])
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
# the chain llm answers the prompt; the separate scorer_llm provides the score
assert response["response"] == "hey" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score == 300.0 # type: ignore
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_calling_chain_w_reserved_inputs_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),
rl_chain_selected_based_on=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_activate_and_deactivate_scorer() -> None:
_, PROMPT = setup()
llm = FakeListChatModel(responses=["hey1", "hey2", "hey3"])
scorer_llm = FakeListChatModel(responses=["300", "400"])
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
)
# the chain llm answers the prompt; scorer_llm scores it while the scorer is active
assert response["response"] == "hey1" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score == 300.0 # type: ignore
chain.deactivate_selection_scorer()
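# While the selection scorer is deactivated the chain still picks and responds,
# but the selection's score stays None until activate_selection_scorer() is called.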
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
)
assert response["response"] == "hey2" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score is None # type: ignore
chain.activate_selection_scorer()
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
)
assert response["response"] == "hey3" # type: ignore
selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score == 400.0 # type: ignore


@@ -0,0 +1,370 @@
import pytest
from test_utils import MockEncoder
import langchain_experimental.rl_chain.base as rl_chain
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
encoded_keyword = "[encoded]"
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_context_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_action = {"action": ["0", "1", "2"]}
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_action, based_on={}
)
with pytest.raises(ValueError):
feature_embedder.format(event)
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_actions_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from={}, based_on={"context": "context"}
)
with pytest.raises(ValueError):
feature_embedder.format(event)
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
event = pick_best_chain.PickBestEvent(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
selected=selected,
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": ["0", "1", "2"]}
expected = (
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
)
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
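# The chosen action's line gets the contextual-bandit label
# `<index>:<cost>:<probability>`, where cost is the negated score; with index=0,
# score=0.0 and probability=1.0 that is the "0:-0.0:1.0" in the expected string.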
event = pick_best_chain.PickBestEvent(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
selected=selected,
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
ctx_str_1 = "context1"
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
context = {"context": rl_chain.Embed(ctx_str_1)}
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
ctx_str_1 = "context1"
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
context = {
"context1": rl_chain.Embed(ctx_str_1),
"context2": rl_chain.Embed(ctx_str_2),
}
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
None
):
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
named_actions = {
"action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
}
context = {
"context1": rl_chain.EmbedAndKeep(ctx_str_1),
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
}
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
named_actions = {
"action1": [
{"a": str1, "b": rl_chain.Embed(str1)},
str2,
rl_chain.Embed(str3),
]
}
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
named_actions = {
"action1": [
{"a": str1, "b": rl_chain.EmbedAndKeep(str1)},
str2,
rl_chain.EmbedAndKeep(str3),
]
}
context = {
"context1": ctx_str_1,
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
}
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_raw_features_underscored() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "this is a long string"
str1_underscored = str1.replace(" ", "_")
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
ctx_str = "this is a long context"
ctx_str_underscored = ctx_str.replace(" ", "_")
encoded_ctx_str = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str))
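# Raw feature text has its spaces replaced with underscores (VW's text format
# treats whitespace as a feature separator); embedded values are emitted as
# returned by stringify_embedding.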
# No embeddings
named_actions = {"action": [str1]}
context = {"context": ctx_str}
expected_no_embed = (
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected_no_embed
# Just embeddings
named_actions = {"action": rl_chain.Embed([str1])}
context = {"context": rl_chain.Embed(ctx_str)}
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected_embed
# Embeddings and raw features
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected_embed_and_keep


@@ -0,0 +1,422 @@
from typing import List, Union
import pytest
from test_utils import MockEncoder
import langchain_experimental.rl_chain.base as base
encoded_keyword = "[encoded]"
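# base.embed() normalizes inputs into a list of {namespace: features} dicts:
# bare values need an explicit default namespace, dicts carry their own
# namespaces, and lists yield one dict per item. Embed(...) replaces the raw
# text with its encoding; EmbedAndKeep(...) emits the raw value followed by
# its encoding.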
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_no_emb() -> None:
expected = [{"a_namespace": "test"}]
assert base.embed("test", MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_emb() -> None:
str1 = "test"
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
expected = [{"a_namespace": encoded_str1}]
assert base.embed(base.Embed(str1), MockEncoder(), "a_namespace") == expected
expected_embed_and_keep = [{"a_namespace": str1 + " " + encoded_str1}]
assert (
base.embed(base.EmbedAndKeep(str1), MockEncoder(), "a_namespace")
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_nested_emb() -> None:
# nested embeddings, innermost wins
str1 = "test"
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
expected = [{"a_namespace": encoded_str1}]
assert (
base.embed(base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace")
== expected
)
expected2 = [{"a_namespace": str1 + " " + encoded_str1}]
assert (
base.embed(base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace")
== expected2
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_no_emb() -> None:
expected = [{"test_namespace": "test"}]
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb() -> None:
str1 = "test"
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
expected = [{"test_namespace": encoded_str1}]
assert base.embed({"test_namespace": base.Embed(str1)}, MockEncoder()) == expected
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
assert (
base.embed({"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder())
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb2() -> None:
str1 = "test"
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
expected = [{"test_namespace": encoded_str1}]
assert base.embed(base.Embed({"test_namespace": str1}), MockEncoder()) == expected
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
assert (
base.embed(base.EmbedAndKeep({"test_namespace": str1}), MockEncoder())
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
expected = [{"test_namespace": str1, "test_namespace2": encoded_str2}]
assert (
base.embed(
{"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder()
)
== expected
)
expected_embed_and_keep = [
{
"test_namespace": str1,
"test_namespace2": str2 + " " + encoded_str2,
}
]
assert (
base.embed(
{"test_namespace": str1, "test_namespace2": base.EmbedAndKeep(str2)},
MockEncoder(),
)
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_no_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
to_embed: List[Union[str, base._Embed]] = [str1, str2, str3]
assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
expected = [
{"a_namespace": encoded_str1},
{"a_namespace": encoded_str2},
{"a_namespace": encoded_str3},
]
assert (
base.embed(base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace")
== expected
)
expected_embed_and_keep = [
{"a_namespace": str1 + " " + encoded_str1},
{"a_namespace": str2 + " " + encoded_str2},
{"a_namespace": str3 + " " + encoded_str3},
]
assert (
base.embed(base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace")
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
expected = [
{"a_namespace": str1},
{"a_namespace": encoded_str2},
{"a_namespace": encoded_str3},
]
assert (
base.embed(
[str1, base.Embed(str2), base.Embed(str3)], MockEncoder(), "a_namespace"
)
== expected
)
expected_embed_and_keep = [
{"a_namespace": str1},
{"a_namespace": str2 + " " + encoded_str2},
{"a_namespace": str3 + " " + encoded_str3},
]
assert (
base.embed(
[str1, base.EmbedAndKeep(str2), base.EmbedAndKeep(str3)],
MockEncoder(),
"a_namespace",
)
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_no_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
expected = [
{"test_namespace": str1},
{"test_namespace": str2},
{"test_namespace": str3},
]
assert (
base.embed(
[
{"test_namespace": str1},
{"test_namespace": str2},
{"test_namespace": str3},
],
MockEncoder(),
)
== expected
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
expected = [
{"test_namespace": encoded_str1},
{"test_namespace": encoded_str2},
{"test_namespace": encoded_str3},
]
assert (
base.embed(
[
{"test_namespace": base.Embed(str1)},
{"test_namespace": base.Embed(str2)},
{"test_namespace": base.Embed(str3)},
],
MockEncoder(),
)
== expected
)
expected_embed_and_keep = [
{"test_namespace": str1 + " " + encoded_str1},
{"test_namespace": str2 + " " + encoded_str2},
{"test_namespace": str3 + " " + encoded_str3},
]
assert (
base.embed(
[
{"test_namespace": base.EmbedAndKeep(str1)},
{"test_namespace": base.EmbedAndKeep(str2)},
{"test_namespace": base.EmbedAndKeep(str3)},
],
MockEncoder(),
)
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb2() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
expected = [
{"test_namespace1": encoded_str1},
{"test_namespace2": encoded_str2},
{"test_namespace3": encoded_str3},
]
assert (
base.embed(
base.Embed(
[
{"test_namespace1": str1},
{"test_namespace2": str2},
{"test_namespace3": str3},
]
),
MockEncoder(),
)
== expected
)
expected_embed_and_keep = [
{"test_namespace1": str1 + " " + encoded_str1},
{"test_namespace2": str2 + " " + encoded_str2},
{"test_namespace3": str3 + " " + encoded_str3},
]
assert (
base.embed(
base.EmbedAndKeep(
[
{"test_namespace1": str1},
{"test_namespace2": str2},
{"test_namespace3": str3},
]
),
MockEncoder(),
)
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
expected = [
{"test_namespace": str1},
{"test_namespace": encoded_str2},
{"test_namespace": encoded_str3},
]
assert (
base.embed(
[
{"test_namespace": str1},
{"test_namespace": base.Embed(str2)},
{"test_namespace": base.Embed(str3)},
],
MockEncoder(),
)
== expected
)
expected_embed_and_keep = [
{"test_namespace": str1},
{"test_namespace": str2 + " " + encoded_str2},
{"test_namespace": str3 + " " + encoded_str3},
]
assert (
base.embed(
[
{"test_namespace": str1},
{"test_namespace": base.EmbedAndKeep(str2)},
{"test_namespace": base.EmbedAndKeep(str3)},
],
MockEncoder(),
)
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
expected = [
{"test_namespace": encoded_str1, "test_namespace2": str1},
{"test_namespace": encoded_str2, "test_namespace2": str2},
{"test_namespace": encoded_str3, "test_namespace2": str3},
]
assert (
base.embed(
[
{"test_namespace": base.Embed(str1), "test_namespace2": str1},
{"test_namespace": base.Embed(str2), "test_namespace2": str2},
{"test_namespace": base.Embed(str3), "test_namespace2": str3},
],
MockEncoder(),
)
== expected
)
expected_embed_and_keep = [
{
"test_namespace": str1 + " " + encoded_str1,
"test_namespace2": str1,
},
{
"test_namespace": str2 + " " + encoded_str2,
"test_namespace2": str2,
},
{
"test_namespace": str3 + " " + encoded_str3,
"test_namespace2": str3,
},
]
assert (
base.embed(
[
{"test_namespace": base.EmbedAndKeep(str1), "test_namespace2": str1},
{"test_namespace": base.EmbedAndKeep(str2), "test_namespace2": str2},
{"test_namespace": base.EmbedAndKeep(str3), "test_namespace2": str3},
],
MockEncoder(),
)
== expected_embed_and_keep
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_no_emb() -> None:
str1 = "test1"
str2 = "test2"
expected = [{"test_namespace": [str1, str2]}]
assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
expected = [{"test_namespace": [str1, encoded_str2]}]
assert (
base.embed({"test_namespace": [str1, base.Embed(str2)]}, MockEncoder())
== expected
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_list_features_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_dict_in_list_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_dict_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_list_of_tuples_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())


@@ -0,0 +1,15 @@
from typing import Any, List


class MockEncoder:
    def encode(self, to_encode: str) -> str:
        return "[encoded]" + to_encode


class MockEncoderReturnsList:
    def encode(self, to_encode: Any) -> List:
        if isinstance(to_encode, str):
            return [1.0, 2.0]
        elif isinstance(to_encode, List):
            return [[1.0, 2.0] for _ in range(len(to_encode))]
        raise ValueError("Invalid input type for unit test")