Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-05 20:58:25 +00:00)
fix all mypy errors and some renaming and refactoring
parent a11ad11d06
commit 0b8691c6e5
@@ -295,7 +295,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
             )


-class RLChain(Generic[TEvent], Chain):
+class RLChain(Chain, Generic[TEvent]):
     """
     The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.

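The hunk above only reorders the base classes so that `Generic[TEvent]` comes after `Chain`. A minimal sketch of the resulting shape, using stand-in classes rather than the real langchain types; the rationale (keeping mypy happy with the generic base ordering) is inferred from the commit message, since the diff itself gives none:

```python
from typing import Any, Dict, Generic, TypeVar


class Event:
    """Stand-in for the rl_chain Event base class (illustrative only)."""


class Chain:
    """Stand-in for langchain's Chain base class (illustrative only)."""

    def run(self, **kwargs: Any) -> Dict[str, Any]:
        return kwargs


TEvent = TypeVar("TEvent", bound=Event)


# Old order in the diff: class RLChain(Generic[TEvent], Chain)
# New order:             class RLChain(Chain, Generic[TEvent])
class RLChain(Chain, Generic[TEvent]):
    def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
        raise NotImplementedError
```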
@@ -320,12 +320,24 @@ class RLChain(Generic[TEvent], Chain):
     The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
     """  # noqa: E501

+    class _NoOpPolicy(Policy):
+        """Placeholder policy that does nothing"""
+
+        def predict(self, event: TEvent) -> Any:
+            return None
+
+        def learn(self, event: TEvent) -> None:
+            pass
+
+        def log(self, event: TEvent) -> None:
+            pass
+
     llm_chain: Chain

     output_key: str = "result"  #: :meta private:
     prompt: BasePromptTemplate
     selection_scorer: Union[SelectionScorer, None]
-    active_policy: Policy
+    active_policy: Policy = _NoOpPolicy()
     auto_embed: bool = True
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"

@@ -351,7 +363,7 @@ class RLChain(Generic[TEvent], Chain):
                 unless update_with_delayed_score is called."
             )

-        if self.active_policy is None:
+        if isinstance(self.active_policy, RLChain._NoOpPolicy):
            self.active_policy = policy(
                model_repo=ModelRepository(
                    model_save_dir, with_history=True, reset=reset_model
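The recurring pattern in these two hunks is replacing an implicitly optional field (`active_policy: Policy` guarded by a `None` check) with a do-nothing default instance, so the attribute is always a `Policy` as far as mypy is concerned. A self-contained sketch of that pattern with simplified stand-in classes, not the actual rl_chain types:

```python
from typing import Any, Callable


class Policy:
    """Minimal stand-in for the rl_chain Policy interface (illustrative only)."""

    def predict(self, event: Any) -> Any:
        raise NotImplementedError

    def learn(self, event: Any) -> None:
        raise NotImplementedError


class NoOpPolicy(Policy):
    """Placeholder policy that does nothing; serves as a non-None default."""

    def predict(self, event: Any) -> Any:
        return None

    def learn(self, event: Any) -> None:
        pass


class VWPolicy(Policy):
    """Pretend 'real' policy; in rl_chain this would wrap a VW workspace."""

    def predict(self, event: Any) -> Any:
        return [(0, 1.0)]

    def learn(self, event: Any) -> None:
        pass


class ChainLike:
    # Always a Policy, never Optional[Policy]: callers need no None checks.
    active_policy: Policy = NoOpPolicy()

    def init_policy(self, factory: Callable[[], Policy]) -> None:
        # Mirrors the isinstance check in the hunk above (simplified).
        if isinstance(self.active_policy, NoOpPolicy):
            self.active_policy = factory()


chain = ChainLike()
chain.init_policy(VWPolicy)
print(chain.active_policy.predict(event=None))  # [(0, 1.0)]
```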
@@ -17,7 +17,36 @@ logger = logging.getLogger(__name__)
 SENTINEL = object()


-class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
+class PickBestSelected(base.Selected):
+    index: Optional[int]
+    probability: Optional[float]
+    score: Optional[float]
+
+    def __init__(
+        self,
+        index: Optional[int] = None,
+        probability: Optional[float] = None,
+        score: Optional[float] = None,
+    ):
+        self.index = index
+        self.probability = probability
+        self.score = score
+
+
+class PickBestEvent(base.Event[PickBestSelected]):
+    def __init__(
+        self,
+        inputs: Dict[str, Any],
+        to_select_from: Dict[str, Any],
+        based_on: Dict[str, Any],
+        selected: Optional[PickBestSelected] = None,
+    ):
+        super().__init__(inputs=inputs, selected=selected)
+        self.to_select_from = to_select_from
+        self.based_on = based_on
+
+
+class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
     """
     Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy

@@ -35,7 +64,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):

         self.model = model

-    def format(self, event: PickBest.Event) -> str:
+    def format(self, event: PickBestEvent) -> str:
         """
         Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
         """

@@ -93,7 +122,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
         return example_string[:-1]


-class PickBest(base.RLChain[PickBest.Event]):
+class PickBest(base.RLChain[PickBestEvent]):
     """
     `PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.

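The net effect of hoisting the nested classes is visible at every call site: `PickBest.Selected` and `PickBest.Event` become module-level `PickBestSelected` and `PickBestEvent`. A small sketch of constructing an event the way the updated tests further down do; the import path is a guess (the tests just use an already-imported `pick_best_chain` module):

```python
from langchain.chains.rl_chain import pick_best_chain  # import path assumed

# Before this commit:
#   selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
#   event = pick_best_chain.PickBest.Event(...)
# After this commit:
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
    inputs={},
    to_select_from={"action1": ["0", "1", "2"]},  # ToSelectFrom candidates
    based_on={"context": "context"},              # BasedOn context
    selected=selected,
)
print(selected.probability, list(event.to_select_from))  # 1.0 ['action1']
```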
@@ -121,33 +150,6 @@ class PickBest(base.RLChain[PickBest.Event]):
         feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
     """  # noqa E501

-    class Selected(base.Selected):
-        index: Optional[int]
-        probability: Optional[float]
-        score: Optional[float]
-
-        def __init__(
-            self,
-            index: Optional[int] = None,
-            probability: Optional[float] = None,
-            score: Optional[float] = None,
-        ):
-            self.index = index
-            self.probability = probability
-            self.score = score
-
-    class Event(base.Event[PickBest.Selected]):
-        def __init__(
-            self,
-            inputs: Dict[str, Any],
-            to_select_from: Dict[str, Any],
-            based_on: Dict[str, Any],
-            selected: Optional[PickBest.Selected] = None,
-        ):
-            super().__init__(inputs=inputs, selected=selected)
-            self.to_select_from = to_select_from
-            self.based_on = based_on
-
     def __init__(
         self,
         *args: Any,

@@ -176,7 +178,7 @@ class PickBest(base.RLChain[PickBest.Event]):

         super().__init__(*args, **kwargs)

-    def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
+    def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
         context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
         if not actions:
             raise ValueError(

@@ -199,12 +201,15 @@ class PickBest(base.RLChain[PickBest.Event]):
                 to base the selected of ToSelectFrom on."
             )

-        event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context)
+        event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
         return event

     def _call_after_predict_before_llm(
-        self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
-    ) -> Tuple[Dict[str, Any], Event]:
+        self,
+        inputs: Dict[str, Any],
+        event: PickBestEvent,
+        prediction: List[Tuple[int, float]],
+    ) -> Tuple[Dict[str, Any], PickBestEvent]:
         import numpy as np

         prob_sum = sum(prob for _, prob in prediction)

@@ -214,7 +219,7 @@ class PickBest(base.RLChain[PickBest.Event]):
         sampled_ap = prediction[sampled_index]
         sampled_action = sampled_ap[0]
         sampled_prob = sampled_ap[1]
-        selected = PickBest.Selected(index=sampled_action, probability=sampled_prob)
+        selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
         event.selected = selected

         # only one key, value pair in event.to_select_from

@@ -224,8 +229,8 @@ class PickBest(base.RLChain[PickBest.Event]):
         return next_chain_inputs, event

     def _call_after_llm_before_scoring(
-        self, llm_response: str, event: Event
-    ) -> Tuple[Dict[str, Any], Event]:
+        self, llm_response: str, event: PickBestEvent
+    ) -> Tuple[Dict[str, Any], PickBestEvent]:
         next_chain_inputs = event.inputs.copy()
         # only one key, value pair in event.to_select_from
         value = next(iter(event.to_select_from.values()))

@@ -243,8 +248,8 @@ class PickBest(base.RLChain[PickBest.Event]):
         return next_chain_inputs, event

     def _call_after_scoring_before_learning(
-        self, event: Event, score: Optional[float]
-    ) -> Event:
+        self, event: PickBestEvent, score: Optional[float]
+    ) -> PickBestEvent:
         if event.selected:
             event.selected.score = score
         return event
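The `_call_after_predict_before_llm` hunks show the VW prediction arriving as a `List[Tuple[int, float]]` of (action index, probability) pairs, a `prob_sum` being computed, and a `sampled_index` being used to pick one pair; the draw itself sits outside the hunk. A self-contained sketch of one plausible way that step works; the `np.random.choice` call is an assumption for illustration, not taken from this diff:

```python
from typing import List, Tuple

import numpy as np

# (action index, probability) pairs, as a contextual-bandit policy would return them.
prediction: List[Tuple[int, float]] = [(0, 0.7), (1, 0.2), (2, 0.1)]

# Normalize defensively, mirroring the prob_sum line in the hunk above.
prob_sum = sum(prob for _, prob in prediction)
probabilities = [prob / prob_sum for _, prob in prediction]

# Assumed sampling step: draw a position according to the probabilities.
sampled_index = np.random.choice(len(prediction), p=probabilities)

sampled_ap = prediction[sampled_index]
sampled_action = sampled_ap[0]  # index recorded on PickBestSelected
sampled_prob = sampled_ap[1]    # probability recorded on PickBestSelected
print(sampled_action, sampled_prob)
```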
@@ -11,7 +11,7 @@ encoded_text = "[ e n c o d e d ] "
 def test_pickbest_textembedder_missing_context_throws():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_action = {"action": ["0", "1", "2"]}
-    event = pick_best_chain.PickBest.Event(
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_action, based_on={}
     )
     with pytest.raises(ValueError):

@@ -21,7 +21,7 @@ def test_pickbest_textembedder_missing_context_throws():
 @pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_missing_actions_throws():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
-    event = pick_best_chain.PickBest.Event(
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from={}, based_on={"context": "context"}
     )
     with pytest.raises(ValueError):

@@ -33,7 +33,7 @@ def test_pickbest_textembedder_no_label_no_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_actions = {"action1": ["0", "1", "2"]}
     expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
-    event = pick_best_chain.PickBest.Event(
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on={"context": "context"}
     )
     vw_ex_str = feature_embedder.format(event)

@@ -45,8 +45,8 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_actions = {"action1": ["0", "1", "2"]}
     expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={},
         to_select_from=named_actions,
         based_on={"context": "context"},

@@ -63,8 +63,8 @@ def test_pickbest_textembedder_w_full_label_no_emb():
     expected = (
         """shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
     )
-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={},
         to_select_from=named_actions,
         based_on={"context": "context"},

@@ -90,8 +90,8 @@ def test_pickbest_textembedder_w_full_label_w_emb():
     named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
     context = {"context": rl_chain.Embed(ctx_str_1)}
     expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """  # noqa: E501
-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context, selected=selected
     )
     vw_ex_str = feature_embedder.format(event)

@@ -114,8 +114,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
     named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
     context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
     expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """  # noqa: E501
-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context, selected=selected
     )
     vw_ex_str = feature_embedder.format(event)

@@ -128,7 +128,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
     named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
     context = {"context1": "context1", "context2": "context2"}
     expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """  # noqa: E501
-    event = pick_best_chain.PickBest.Event(
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context
     )
     vw_ex_str = feature_embedder.format(event)

@@ -141,8 +141,8 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
     named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
     context = {"context1": "context1", "context2": "context2"}
     expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """  # noqa: E501
-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context, selected=selected
     )
     vw_ex_str = feature_embedder.format(event)

@@ -155,8 +155,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
     named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
     context = {"context1": "context1", "context2": "context2"}
     expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """  # noqa: E501
-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context, selected=selected
     )
     vw_ex_str = feature_embedder.format(event)

@@ -186,8 +186,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
     }
     expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """  # noqa: E501

-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context, selected=selected
     )
     vw_ex_str = feature_embedder.format(event)

@@ -219,8 +219,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
     }
     expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """  # noqa: E501

-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context, selected=selected
     )
     vw_ex_str = feature_embedder.format(event)

@@ -253,8 +253,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
     context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
     expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """  # noqa: E501

-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context, selected=selected
     )
     vw_ex_str = feature_embedder.format(event)

@@ -290,8 +290,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
     }
     expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """  # noqa: E501

-    selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
-    event = pick_best_chain.PickBest.Event(
+    selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context, selected=selected
     )
     vw_ex_str = feature_embedder.format(event)

@@ -315,7 +315,7 @@ def test_raw_features_underscored():
     expected_no_embed = (
         f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
     )
-    event = pick_best_chain.PickBest.Event(
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context
     )
     vw_ex_str = feature_embedder.format(event)

@@ -325,7 +325,7 @@ def test_raw_features_underscored():
     named_actions = {"action": rl_chain.Embed([str1])}
     context = {"context": rl_chain.Embed(ctx_str)}
     expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
-    event = pick_best_chain.PickBest.Event(
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context
     )
     vw_ex_str = feature_embedder.format(event)

@@ -335,7 +335,7 @@ def test_raw_features_underscored():
     named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
     context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
     expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """  # noqa: E501
-    event = pick_best_chain.PickBest.Event(
+    event = pick_best_chain.PickBestEvent(
         inputs={}, to_select_from=named_actions, based_on=context
     )
     vw_ex_str = feature_embedder.format(event)
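The expected strings in these tests pin down the text format that `PickBestFeatureEmbedder.format` emits: a `shared` line per `BasedOn` namespace, one line per `ToSelectFrom` candidate, and an `index:cost:probability` label (here `0:-0.0:1.0`, with cost being the negated score) prefixed to the chosen action once a selection is scored. A tiny sketch that rebuilds the simplest labelled example by hand; `build_vw_example` is a throwaway helper written for illustration, not an rl_chain function:

```python
from typing import Dict, List, Optional, Tuple


def build_vw_example(
    context: Dict[str, str],
    actions: Dict[str, List[str]],
    label: Optional[Tuple[int, float, float]] = None,  # (index, score, probability)
) -> str:
    """Throwaway illustration of the format the tests above expect."""
    shared = "shared " + "".join(f"|{ns} {val} " for ns, val in context.items())
    lines = [shared]
    for namespace, candidates in actions.items():
        for i, candidate in enumerate(candidates):
            prefix = ""
            if label is not None and i == label[0]:
                index, score, prob = label
                prefix = f"{index}:{-score}:{prob} "  # cost is the negated score
            lines.append(f"{prefix}|{namespace} {candidate} ")
    return "\n".join(lines)


expected = """shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
built = build_vw_example(
    {"context": "context"},
    {"action1": ["0", "1", "2"]},
    label=(0, 0.0, 1.0),
)
assert built == expected, (built, expected)
```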