mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
rename rl_chain_base to base and update paths and imports
This commit is contained in:
parent
b422dc035f
commit
a6f9dccc35
@ -1,5 +1,5 @@
|
||||
from langchain.chains.rl_chain.pick_best_chain import PickBest
|
||||
from langchain.chains.rl_chain.rl_chain_base import (
|
||||
from langchain.chains.rl_chain.base import (
|
||||
Embed,
|
||||
BasedOn,
|
||||
ToSelectFrom,
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import langchain.chains.rl_chain.rl_chain_base as base
|
||||
import langchain.chains.rl_chain.base as base
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
|
@ -1,4 +1,5 @@
|
||||
import langchain.chains.rl_chain as rl_chain
|
||||
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
|
||||
import langchain.chains.rl_chain.base as rl_chain
|
||||
from test_utils import MockEncoder
|
||||
import pytest
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
@ -17,7 +18,7 @@ def setup():
|
||||
|
||||
def test_multiple_ToSelectFrom_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
actions = ["0", "1", "2"]
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
@ -29,7 +30,7 @@ def test_multiple_ToSelectFrom_throws():
|
||||
|
||||
def test_missing_basedOn_from_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
actions = ["0", "1", "2"]
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(action=rl_chain.ToSelectFrom(actions))
|
||||
@ -37,7 +38,7 @@ def test_missing_basedOn_from_throws():
|
||||
|
||||
def test_ToSelectFrom_not_a_list_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
actions = {"actions": ["0", "1", "2"]}
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
@ -50,7 +51,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||
chain = rl_chain.PickBest.from_llm(
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
||||
@ -71,7 +72,7 @@ def test_update_with_delayed_score_force():
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||
chain = rl_chain.PickBest.from_llm(
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
||||
@ -92,7 +93,7 @@ def test_update_with_delayed_score_force():
|
||||
|
||||
def test_update_with_delayed_score():
|
||||
llm, PROMPT = setup()
|
||||
chain = rl_chain.PickBest.from_llm(
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, selection_scorer=None
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
@ -115,7 +116,7 @@ def test_user_defined_scorer():
|
||||
score = 200
|
||||
return score
|
||||
|
||||
chain = rl_chain.PickBest.from_llm(
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer()
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
@ -130,8 +131,8 @@ def test_user_defined_scorer():
|
||||
|
||||
def test_default_embeddings():
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
chain = rl_chain.PickBest.from_llm(
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
|
||||
)
|
||||
|
||||
@ -163,8 +164,8 @@ def test_default_embeddings():
|
||||
|
||||
def test_default_embeddings_off():
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
chain = rl_chain.PickBest.from_llm(
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
|
||||
)
|
||||
|
||||
@ -188,8 +189,8 @@ def test_default_embeddings_off():
|
||||
|
||||
def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
chain = rl_chain.PickBest.from_llm(
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
|
||||
)
|
||||
|
||||
@ -223,7 +224,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
||||
def test_default_no_scorer_specified():
|
||||
_, PROMPT = setup()
|
||||
chain_llm = FakeListChatModel(responses=[100])
|
||||
chain = rl_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
@ -236,7 +237,7 @@ def test_default_no_scorer_specified():
|
||||
|
||||
def test_explicitly_no_scorer():
|
||||
llm, PROMPT = setup()
|
||||
chain = rl_chain.PickBest.from_llm(
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, selection_scorer=None
|
||||
)
|
||||
response = chain.run(
|
||||
@ -252,7 +253,7 @@ def test_explicitly_no_scorer():
|
||||
def test_auto_scorer_with_user_defined_llm():
|
||||
llm, PROMPT = setup()
|
||||
scorer_llm = FakeListChatModel(responses=[300])
|
||||
chain = rl_chain.PickBest.from_llm(
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
|
||||
@ -269,7 +270,7 @@ def test_auto_scorer_with_user_defined_llm():
|
||||
|
||||
def test_calling_chain_w_reserved_inputs_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
|
@ -1,4 +1,5 @@
|
||||
import langchain.chains.rl_chain as rl_chain
|
||||
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
|
||||
import langchain.chains.rl_chain.base as rl_chain
|
||||
from test_utils import MockEncoder
|
||||
|
||||
import pytest
|
||||
@ -7,9 +8,9 @@ encoded_text = "[ e n c o d e d ] "
|
||||
|
||||
|
||||
def test_pickbest_textembedder_missing_context_throws():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_action = {"action": ["0", "1", "2"]}
|
||||
event = rl_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_action, based_on={}
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
@ -17,8 +18,8 @@ def test_pickbest_textembedder_missing_context_throws():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_missing_actions_throws():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
event = rl_chain.PickBest.Event(
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from={}, based_on={"context": "context"}
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
@ -26,10 +27,10 @@ def test_pickbest_textembedder_missing_actions_throws():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_no_label_no_emb():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||
event = rl_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -37,11 +38,11 @@ def test_pickbest_textembedder_no_label_no_emb():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_w_label_no_score_no_emb():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={},
|
||||
to_select_from=named_actions,
|
||||
based_on={"context": "context"},
|
||||
@ -52,13 +53,13 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_w_full_label_no_emb():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = (
|
||||
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
||||
)
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={},
|
||||
to_select_from=named_actions,
|
||||
based_on={"context": "context"},
|
||||
@ -69,7 +70,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_w_full_label_w_emb():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
@ -83,8 +84,8 @@ def test_pickbest_textembedder_w_full_label_w_emb():
|
||||
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
||||
context = {"context": rl_chain.Embed(ctx_str_1)}
|
||||
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -92,7 +93,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
@ -106,8 +107,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
||||
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
||||
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -115,11 +116,11 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """
|
||||
event = rl_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -127,12 +128,12 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -140,12 +141,12 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -153,7 +154,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
@ -176,8 +177,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
||||
}
|
||||
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """
|
||||
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -185,7 +186,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
@ -210,8 +211,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
||||
}
|
||||
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """
|
||||
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -219,7 +220,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
||||
|
||||
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
@ -243,8 +244,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
||||
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
|
||||
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """
|
||||
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -252,7 +253,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
||||
|
||||
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
@ -279,8 +280,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
||||
}
|
||||
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """
|
||||
|
||||
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = rl_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -288,7 +289,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
||||
|
||||
|
||||
def test_raw_features_underscored():
|
||||
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
str1 = "this is a long string"
|
||||
str1_underscored = str1.replace(" ", "_")
|
||||
encoded_str1 = encoded_text + " ".join(char for char in str1)
|
||||
@ -303,7 +304,7 @@ def test_raw_features_underscored():
|
||||
expected_no_embed = (
|
||||
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
|
||||
)
|
||||
event = rl_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -313,7 +314,7 @@ def test_raw_features_underscored():
|
||||
named_actions = {"action": rl_chain.Embed([str1])}
|
||||
context = {"context": rl_chain.Embed(ctx_str)}
|
||||
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
|
||||
event = rl_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -323,7 +324,7 @@ def test_raw_features_underscored():
|
||||
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
|
||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
|
||||
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """
|
||||
event = rl_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
|
@ -1,4 +1,4 @@
|
||||
import langchain.chains.rl_chain as base
|
||||
import langchain.chains.rl_chain.base as base
|
||||
from test_utils import MockEncoder
|
||||
|
||||
import pytest
|
||||
|
Loading…
Reference in New Issue
Block a user