fix all mypy errors and some renaming and refactoring

This commit is contained in:
olgavrou 2023-08-29 05:19:19 -04:00
parent a11ad11d06
commit 0b8691c6e5
3 changed files with 86 additions and 69 deletions

View File

@@ -295,7 +295,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
)
class RLChain(Generic[TEvent], Chain):
class RLChain(Chain, Generic[TEvent]):
"""
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
@@ -320,12 +320,24 @@ class RLChain(Generic[TEvent], Chain):
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
""" # noqa: E501
class _NoOpPolicy(Policy):
"""Placeholder policy that does nothing"""
def predict(self, event: TEvent) -> Any:
return None
def learn(self, event: TEvent) -> None:
pass
def log(self, event: TEvent) -> None:
pass
llm_chain: Chain
output_key: str = "result" #: :meta private:
prompt: BasePromptTemplate
selection_scorer: Union[SelectionScorer, None]
active_policy: Policy
active_policy: Policy = _NoOpPolicy()
auto_embed: bool = True
selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on"
@@ -351,7 +363,7 @@ class RLChain(Generic[TEvent], Chain):
unless update_with_delayed_score is called."
)
if self.active_policy is None:
if isinstance(self.active_policy, RLChain._NoOpPolicy):
self.active_policy = policy(
model_repo=ModelRepository(
model_save_dir, with_history=True, reset=reset_model

View File

@@ -17,7 +17,36 @@ logger = logging.getLogger(__name__)
SENTINEL = object()
class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
class PickBestSelected(base.Selected):
index: Optional[int]
probability: Optional[float]
score: Optional[float]
def __init__(
self,
index: Optional[int] = None,
probability: Optional[float] = None,
score: Optional[float] = None,
):
self.index = index
self.probability = probability
self.score = score
class PickBestEvent(base.Event[PickBestSelected]):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBestSelected] = None,
):
super().__init__(inputs=inputs, selected=selected)
self.to_select_from = to_select_from
self.based_on = based_on
class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
"""
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
@@ -35,7 +64,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
self.model = model
def format(self, event: PickBest.Event) -> str:
def format(self, event: PickBestEvent) -> str:
"""
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
"""
@@ -93,7 +122,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
return example_string[:-1]
class PickBest(base.RLChain[PickBest.Event]):
class PickBest(base.RLChain[PickBestEvent]):
"""
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
@@ -121,33 +150,6 @@ class PickBest(base.RLChain[PickBest.Event]):
feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
""" # noqa E501
class Selected(base.Selected):
index: Optional[int]
probability: Optional[float]
score: Optional[float]
def __init__(
self,
index: Optional[int] = None,
probability: Optional[float] = None,
score: Optional[float] = None,
):
self.index = index
self.probability = probability
self.score = score
class Event(base.Event[PickBest.Selected]):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBest.Selected] = None,
):
super().__init__(inputs=inputs, selected=selected)
self.to_select_from = to_select_from
self.based_on = based_on
def __init__(
self,
*args: Any,
@@ -176,7 +178,7 @@ class PickBest(base.RLChain[PickBest.Event]):
super().__init__(*args, **kwargs)
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
raise ValueError(
@@ -199,12 +201,15 @@ class PickBest(base.RLChain[PickBest.Event]):
to base the selected of ToSelectFrom on."
)
event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context)
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
return event
def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
) -> Tuple[Dict[str, Any], Event]:
self,
inputs: Dict[str, Any],
event: PickBestEvent,
prediction: List[Tuple[int, float]],
) -> Tuple[Dict[str, Any], PickBestEvent]:
import numpy as np
prob_sum = sum(prob for _, prob in prediction)
@@ -214,7 +219,7 @@ class PickBest(base.RLChain[PickBest.Event]):
sampled_ap = prediction[sampled_index]
sampled_action = sampled_ap[0]
sampled_prob = sampled_ap[1]
selected = PickBest.Selected(index=sampled_action, probability=sampled_prob)
selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
event.selected = selected
# only one key, value pair in event.to_select_from
@@ -224,8 +229,8 @@ class PickBest(base.RLChain[PickBest.Event]):
return next_chain_inputs, event
def _call_after_llm_before_scoring(
self, llm_response: str, event: Event
) -> Tuple[Dict[str, Any], Event]:
self, llm_response: str, event: PickBestEvent
) -> Tuple[Dict[str, Any], PickBestEvent]:
next_chain_inputs = event.inputs.copy()
# only one key, value pair in event.to_select_from
value = next(iter(event.to_select_from.values()))
@@ -243,8 +248,8 @@ class PickBest(base.RLChain[PickBest.Event]):
return next_chain_inputs, event
def _call_after_scoring_before_learning(
self, event: Event, score: Optional[float]
) -> Event:
self, event: PickBestEvent, score: Optional[float]
) -> PickBestEvent:
if event.selected:
event.selected.score = score
return event

View File

@@ -11,7 +11,7 @@ encoded_text = "[ e n c o d e d ] "
def test_pickbest_textembedder_missing_context_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_action = {"action": ["0", "1", "2"]}
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_action, based_on={}
)
with pytest.raises(ValueError):
@@ -21,7 +21,7 @@ def test_pickbest_textembedder_missing_context_throws():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_actions_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from={}, based_on={"context": "context"}
)
with pytest.raises(ValueError):
@@ -33,7 +33,7 @@ def test_pickbest_textembedder_no_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
)
vw_ex_str = feature_embedder.format(event)
@@ -45,8 +45,8 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
event = pick_best_chain.PickBestEvent(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
@@ -63,8 +63,8 @@ def test_pickbest_textembedder_w_full_label_no_emb():
expected = (
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
)
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
@@ -90,8 +90,8 @@ def test_pickbest_textembedder_w_full_label_w_emb():
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
context = {"context": rl_chain.Embed(ctx_str_1)}
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@@ -114,8 +114,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@@ -128,7 +128,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
@@ -141,8 +141,8 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@@ -155,8 +155,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@@ -186,8 +186,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
}
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@@ -219,8 +219,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
}
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@@ -253,8 +253,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@@ -290,8 +290,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
}
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@@ -315,7 +315,7 @@ def test_raw_features_underscored():
expected_no_embed = (
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
)
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
@@ -325,7 +325,7 @@ def test_raw_features_underscored():
named_actions = {"action": rl_chain.Embed([str1])}
context = {"context": rl_chain.Embed(ctx_str)}
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
@@ -335,7 +335,7 @@ def test_raw_features_underscored():
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)