mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 13:18:12 +00:00
fix all mypy errors and some renaming and refactoring
This commit is contained in:
parent
a11ad11d06
commit
0b8691c6e5
@ -295,7 +295,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RLChain(Generic[TEvent], Chain):
|
class RLChain(Chain, Generic[TEvent]):
|
||||||
"""
|
"""
|
||||||
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
|
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
|
||||||
|
|
||||||
@ -320,12 +320,24 @@ class RLChain(Generic[TEvent], Chain):
|
|||||||
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
|
class _NoOpPolicy(Policy):
|
||||||
|
"""Placeholder policy that does nothing"""
|
||||||
|
|
||||||
|
def predict(self, event: TEvent) -> Any:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def learn(self, event: TEvent) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def log(self, event: TEvent) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
llm_chain: Chain
|
llm_chain: Chain
|
||||||
|
|
||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
prompt: BasePromptTemplate
|
prompt: BasePromptTemplate
|
||||||
selection_scorer: Union[SelectionScorer, None]
|
selection_scorer: Union[SelectionScorer, None]
|
||||||
active_policy: Policy
|
active_policy: Policy = _NoOpPolicy()
|
||||||
auto_embed: bool = True
|
auto_embed: bool = True
|
||||||
selected_input_key = "rl_chain_selected"
|
selected_input_key = "rl_chain_selected"
|
||||||
selected_based_on_input_key = "rl_chain_selected_based_on"
|
selected_based_on_input_key = "rl_chain_selected_based_on"
|
||||||
@ -351,7 +363,7 @@ class RLChain(Generic[TEvent], Chain):
|
|||||||
unless update_with_delayed_score is called."
|
unless update_with_delayed_score is called."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.active_policy is None:
|
if isinstance(self.active_policy, RLChain._NoOpPolicy):
|
||||||
self.active_policy = policy(
|
self.active_policy = policy(
|
||||||
model_repo=ModelRepository(
|
model_repo=ModelRepository(
|
||||||
model_save_dir, with_history=True, reset=reset_model
|
model_save_dir, with_history=True, reset=reset_model
|
||||||
|
@ -17,7 +17,36 @@ logger = logging.getLogger(__name__)
|
|||||||
SENTINEL = object()
|
SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
|
class PickBestSelected(base.Selected):
|
||||||
|
index: Optional[int]
|
||||||
|
probability: Optional[float]
|
||||||
|
score: Optional[float]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
index: Optional[int] = None,
|
||||||
|
probability: Optional[float] = None,
|
||||||
|
score: Optional[float] = None,
|
||||||
|
):
|
||||||
|
self.index = index
|
||||||
|
self.probability = probability
|
||||||
|
self.score = score
|
||||||
|
|
||||||
|
|
||||||
|
class PickBestEvent(base.Event[PickBestSelected]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
to_select_from: Dict[str, Any],
|
||||||
|
based_on: Dict[str, Any],
|
||||||
|
selected: Optional[PickBestSelected] = None,
|
||||||
|
):
|
||||||
|
super().__init__(inputs=inputs, selected=selected)
|
||||||
|
self.to_select_from = to_select_from
|
||||||
|
self.based_on = based_on
|
||||||
|
|
||||||
|
|
||||||
|
class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
||||||
"""
|
"""
|
||||||
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
|
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
|
||||||
|
|
||||||
@ -35,7 +64,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
|
|||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def format(self, event: PickBest.Event) -> str:
|
def format(self, event: PickBestEvent) -> str:
|
||||||
"""
|
"""
|
||||||
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
|
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
|
||||||
"""
|
"""
|
||||||
@ -93,7 +122,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
|
|||||||
return example_string[:-1]
|
return example_string[:-1]
|
||||||
|
|
||||||
|
|
||||||
class PickBest(base.RLChain[PickBest.Event]):
|
class PickBest(base.RLChain[PickBestEvent]):
|
||||||
"""
|
"""
|
||||||
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
||||||
|
|
||||||
@ -121,33 +150,6 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
|
feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
|
||||||
""" # noqa E501
|
""" # noqa E501
|
||||||
|
|
||||||
class Selected(base.Selected):
|
|
||||||
index: Optional[int]
|
|
||||||
probability: Optional[float]
|
|
||||||
score: Optional[float]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
index: Optional[int] = None,
|
|
||||||
probability: Optional[float] = None,
|
|
||||||
score: Optional[float] = None,
|
|
||||||
):
|
|
||||||
self.index = index
|
|
||||||
self.probability = probability
|
|
||||||
self.score = score
|
|
||||||
|
|
||||||
class Event(base.Event[PickBest.Selected]):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, Any],
|
|
||||||
to_select_from: Dict[str, Any],
|
|
||||||
based_on: Dict[str, Any],
|
|
||||||
selected: Optional[PickBest.Selected] = None,
|
|
||||||
):
|
|
||||||
super().__init__(inputs=inputs, selected=selected)
|
|
||||||
self.to_select_from = to_select_from
|
|
||||||
self.based_on = based_on
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -176,7 +178,7 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
|
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
|
||||||
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
||||||
if not actions:
|
if not actions:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -199,12 +201,15 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
to base the selected of ToSelectFrom on."
|
to base the selected of ToSelectFrom on."
|
||||||
)
|
)
|
||||||
|
|
||||||
event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context)
|
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _call_after_predict_before_llm(
|
def _call_after_predict_before_llm(
|
||||||
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
|
self,
|
||||||
) -> Tuple[Dict[str, Any], Event]:
|
inputs: Dict[str, Any],
|
||||||
|
event: PickBestEvent,
|
||||||
|
prediction: List[Tuple[int, float]],
|
||||||
|
) -> Tuple[Dict[str, Any], PickBestEvent]:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
prob_sum = sum(prob for _, prob in prediction)
|
prob_sum = sum(prob for _, prob in prediction)
|
||||||
@ -214,7 +219,7 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
sampled_ap = prediction[sampled_index]
|
sampled_ap = prediction[sampled_index]
|
||||||
sampled_action = sampled_ap[0]
|
sampled_action = sampled_ap[0]
|
||||||
sampled_prob = sampled_ap[1]
|
sampled_prob = sampled_ap[1]
|
||||||
selected = PickBest.Selected(index=sampled_action, probability=sampled_prob)
|
selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
|
||||||
event.selected = selected
|
event.selected = selected
|
||||||
|
|
||||||
# only one key, value pair in event.to_select_from
|
# only one key, value pair in event.to_select_from
|
||||||
@ -224,8 +229,8 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
return next_chain_inputs, event
|
return next_chain_inputs, event
|
||||||
|
|
||||||
def _call_after_llm_before_scoring(
|
def _call_after_llm_before_scoring(
|
||||||
self, llm_response: str, event: Event
|
self, llm_response: str, event: PickBestEvent
|
||||||
) -> Tuple[Dict[str, Any], Event]:
|
) -> Tuple[Dict[str, Any], PickBestEvent]:
|
||||||
next_chain_inputs = event.inputs.copy()
|
next_chain_inputs = event.inputs.copy()
|
||||||
# only one key, value pair in event.to_select_from
|
# only one key, value pair in event.to_select_from
|
||||||
value = next(iter(event.to_select_from.values()))
|
value = next(iter(event.to_select_from.values()))
|
||||||
@ -243,8 +248,8 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
return next_chain_inputs, event
|
return next_chain_inputs, event
|
||||||
|
|
||||||
def _call_after_scoring_before_learning(
|
def _call_after_scoring_before_learning(
|
||||||
self, event: Event, score: Optional[float]
|
self, event: PickBestEvent, score: Optional[float]
|
||||||
) -> Event:
|
) -> PickBestEvent:
|
||||||
if event.selected:
|
if event.selected:
|
||||||
event.selected.score = score
|
event.selected.score = score
|
||||||
return event
|
return event
|
||||||
|
@ -11,7 +11,7 @@ encoded_text = "[ e n c o d e d ] "
|
|||||||
def test_pickbest_textembedder_missing_context_throws():
|
def test_pickbest_textembedder_missing_context_throws():
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_action = {"action": ["0", "1", "2"]}
|
named_action = {"action": ["0", "1", "2"]}
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_action, based_on={}
|
inputs={}, to_select_from=named_action, based_on={}
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -21,7 +21,7 @@ def test_pickbest_textembedder_missing_context_throws():
|
|||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_missing_actions_throws():
|
def test_pickbest_textembedder_missing_actions_throws():
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from={}, based_on={"context": "context"}
|
inputs={}, to_select_from={}, based_on={"context": "context"}
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -33,7 +33,7 @@ def test_pickbest_textembedder_no_label_no_emb():
|
|||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
|
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -45,8 +45,8 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
|
|||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={},
|
inputs={},
|
||||||
to_select_from=named_actions,
|
to_select_from=named_actions,
|
||||||
based_on={"context": "context"},
|
based_on={"context": "context"},
|
||||||
@ -63,8 +63,8 @@ def test_pickbest_textembedder_w_full_label_no_emb():
|
|||||||
expected = (
|
expected = (
|
||||||
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
||||||
)
|
)
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={},
|
inputs={},
|
||||||
to_select_from=named_actions,
|
to_select_from=named_actions,
|
||||||
based_on={"context": "context"},
|
based_on={"context": "context"},
|
||||||
@ -90,8 +90,8 @@ def test_pickbest_textembedder_w_full_label_w_emb():
|
|||||||
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
||||||
context = {"context": rl_chain.Embed(ctx_str_1)}
|
context = {"context": rl_chain.Embed(ctx_str_1)}
|
||||||
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -114,8 +114,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
|||||||
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
||||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
||||||
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -128,7 +128,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
|||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context
|
inputs={}, to_select_from=named_actions, based_on=context
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -141,8 +141,8 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
|||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -155,8 +155,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
|||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -186,8 +186,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
|||||||
}
|
}
|
||||||
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||||
|
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -219,8 +219,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
|||||||
}
|
}
|
||||||
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||||
|
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -253,8 +253,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
|||||||
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
|
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
|
||||||
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
|
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||||
|
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -290,8 +290,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
|||||||
}
|
}
|
||||||
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||||
|
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -315,7 +315,7 @@ def test_raw_features_underscored():
|
|||||||
expected_no_embed = (
|
expected_no_embed = (
|
||||||
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
|
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
|
||||||
)
|
)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context
|
inputs={}, to_select_from=named_actions, based_on=context
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -325,7 +325,7 @@ def test_raw_features_underscored():
|
|||||||
named_actions = {"action": rl_chain.Embed([str1])}
|
named_actions = {"action": rl_chain.Embed([str1])}
|
||||||
context = {"context": rl_chain.Embed(ctx_str)}
|
context = {"context": rl_chain.Embed(ctx_str)}
|
||||||
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
|
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context
|
inputs={}, to_select_from=named_actions, based_on=context
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -335,7 +335,7 @@ def test_raw_features_underscored():
|
|||||||
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
|
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
|
||||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
|
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
|
||||||
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
|
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context
|
inputs={}, to_select_from=named_actions, based_on=context
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
|
Loading…
Reference in New Issue
Block a user