mypy fixes and formatting

olgavrou 2023-08-28 06:58:33 -04:00
parent 7725192a0d
commit 6a1102d4c0
6 changed files with 108 additions and 88 deletions

View File

@@ -3,7 +3,18 @@ from __future__ import annotations
import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
@@ -26,47 +37,47 @@ logger = logging.getLogger(__name__)
class _BasedOn:
def __init__(self, value):
def __init__(self, value: Any):
self.value = value
def __str__(self):
def __str__(self) -> str:
return str(self.value)
__repr__ = __str__
def BasedOn(anything):
def BasedOn(anything: Any) -> _BasedOn:
return _BasedOn(anything)
class _ToSelectFrom:
def __init__(self, value):
def __init__(self, value: Any):
self.value = value
def __str__(self):
def __str__(self) -> str:
return str(self.value)
__repr__ = __str__
def ToSelectFrom(anything):
def ToSelectFrom(anything: Any) -> _ToSelectFrom:
if not isinstance(anything, list):
raise ValueError("ToSelectFrom must be a list to select from")
return _ToSelectFrom(anything)
class _Embed:
def __init__(self, value, keep=False):
def __init__(self, value: Any, keep: bool = False):
self.value = value
self.keep = keep
def __str__(self):
def __str__(self) -> str:
return str(self.value)
__repr__ = __str__
def Embed(anything, keep=False):
def Embed(anything: Any, keep: bool = False) -> Any:
if isinstance(anything, _ToSelectFrom):
return ToSelectFrom(Embed(anything.value, keep=keep))
elif isinstance(anything, _BasedOn):
@@ -80,7 +91,7 @@ def Embed(anything, keep=False):
return _Embed(anything, keep=keep)
def EmbedAndKeep(anything):
def EmbedAndKeep(anything: Any) -> Any:
return Embed(anything, keep=True)
@@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp
return [parser.parse_line(line) for line in input_str.split("\n")]
def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
to_select_from = {
k: inputs[k].value
for k in inputs.keys()
@@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
return based_on, to_select_from
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
Go over all the inputs; if a value is wrapped in _ToSelectFrom or _BasedOn and its inner value is not already an _Embed,
wrap it in EmbedAndKeep while retaining its _ToSelectFrom or _BasedOn status
@@ -134,29 +145,35 @@ class Selected(ABC):
pass
class Event(ABC):
inputs: Dict[str, Any]
selected: Optional[Selected]
TSelected = TypeVar("TSelected", bound=Selected)
def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None):
class Event(Generic[TSelected], ABC):
inputs: Dict[str, Any]
selected: Optional[TSelected]
def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
self.inputs = inputs
self.selected = selected
TEvent = TypeVar("TEvent", bound=Event)
class Policy(ABC):
@abstractmethod
def predict(self, event: Event) -> Any:
pass
def predict(self, event: TEvent) -> Any:
...
@abstractmethod
def learn(self, event: Event):
pass
def learn(self, event: TEvent) -> None:
...
@abstractmethod
def log(self, event: Event):
pass
def log(self, event: TEvent) -> None:
...
def save(self):
def save(self) -> None:
pass
@@ -164,11 +181,11 @@ class VwPolicy(Policy):
def __init__(
self,
model_repo: ModelRepository,
vw_cmd: Sequence[str],
vw_cmd: List[str],
feature_embedder: Embedder,
vw_logger: VwLogger,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.model_repo = model_repo
@@ -176,7 +193,7 @@ class VwPolicy(Policy):
self.feature_embedder = feature_embedder
self.vw_logger = vw_logger
def predict(self, event: Event) -> Any:
def predict(self, event: TEvent) -> Any:
import vowpal_wabbit_next as vw
text_parser = vw.TextFormatParser(self.workspace)
@@ -184,7 +201,7 @@ class VwPolicy(Policy):
parse_lines(text_parser, self.feature_embedder.format(event))
)
def learn(self, event: Event):
def learn(self, event: TEvent) -> None:
import vowpal_wabbit_next as vw
vw_ex = self.feature_embedder.format(event)
@@ -192,19 +209,19 @@ class VwPolicy(Policy):
multi_ex = parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)
def log(self, event: Event):
def log(self, event: TEvent) -> None:
if self.vw_logger.logging_enabled():
vw_ex = self.feature_embedder.format(event)
self.vw_logger.log(vw_ex)
def save(self):
self.model_repo.save()
def save(self) -> None:
self.model_repo.save(self.workspace)
class Embedder(ABC):
class Embedder(Generic[TEvent], ABC):
@abstractmethod
def format(self, event: Event) -> str:
pass
def format(self, event: TEvent) -> str:
...
class SelectionScorer(ABC, BaseModel):
@@ -212,7 +229,7 @@ class SelectionScorer(ABC, BaseModel):
@abstractmethod
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
pass
...
class AutoSelectionScorer(SelectionScorer, BaseModel):
@@ -243,7 +260,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
return chat_prompt
@root_validator(pre=True)
def set_prompt_and_llm_chain(cls, values):
def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
llm = values.get("llm")
prompt = values.get("prompt")
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
@@ -275,7 +292,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
)
class RLChain(Chain):
class RLChain(Generic[TEvent], Chain):
"""
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
@@ -305,7 +322,7 @@ class RLChain(Chain):
output_key: str = "result" #: :meta private:
prompt: BasePromptTemplate
selection_scorer: Union[SelectionScorer, None]
policy: Optional[Policy]
policy: Policy
auto_embed: bool = True
selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on"
@@ -314,14 +331,14 @@
def __init__(
self,
feature_embedder: Embedder,
model_save_dir="./",
reset_model=False,
vw_cmd=None,
policy=VwPolicy,
model_save_dir: str = "./",
reset_model: bool = False,
vw_cmd: Optional[List[str]] = None,
policy: Type[Policy] = VwPolicy,
vw_logs: Optional[Union[str, os.PathLike]] = None,
metrics_step=-1,
*args,
**kwargs,
metrics_step: int = -1,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
if self.selection_scorer is None:
@@ -374,29 +391,29 @@ class RLChain(Chain):
)
@abstractmethod
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
pass
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
...
@abstractmethod
def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: Any
) -> Tuple[Dict[str, Any], Event]:
pass
self, inputs: Dict[str, Any], event: TEvent, prediction: Any
) -> Tuple[Dict[str, Any], TEvent]:
...
@abstractmethod
def _call_after_llm_before_scoring(
self, llm_response: str, event: Event
) -> Tuple[Dict[str, Any], Event]:
pass
self, llm_response: str, event: TEvent
) -> Tuple[Dict[str, Any], TEvent]:
...
@abstractmethod
def _call_after_scoring_before_learning(
self, event: Event, score: Optional[float]
) -> Event:
pass
self, event: TEvent, score: Optional[float]
) -> TEvent:
...
def update_with_delayed_score(
self, score: float, event: Event, force_score=False
self, score: float, event: TEvent, force_score: bool = False
) -> None:
"""
Updates the learned policy with the score provided.
@@ -407,7 +424,8 @@ class RLChain(Chain):
"The selection scorer is set, and force_score was not set to True. \
Please set force_score=True to use this function."
)
self.metrics.on_feedback(score)
if self.metrics:
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score)
self.policy.learn(event=event)
self.policy.log(event=event)
@@ -422,15 +440,16 @@
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
if self.auto_embed:
inputs = prepare_inputs_for_autoembed(inputs=inputs)
event = self._call_before_predict(inputs=inputs)
event: TEvent = self._call_before_predict(inputs=inputs)
prediction = self.policy.predict(event=event)
self.metrics.on_decision()
if self.metrics:
self.metrics.on_decision()
next_chain_inputs, event = self._call_after_predict_before_llm(
inputs=inputs, event=event, prediction=prediction
@@ -462,7 +481,8 @@ class RLChain(Chain):
f"The selection scorer was not able to score, \
and the chain was not able to adjust to this response, error: {e}"
)
self.metrics.on_feedback(score)
if self.metrics:
self.metrics.on_feedback(score)
event = self._call_after_scoring_before_learning(score=score, event=event)
self.policy.learn(event=event)
self.policy.log(event=event)
@@ -515,7 +535,7 @@ def embed_string_type(
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
"""Helper function to embed a dictionary item."""
inner_dict = {}
inner_dict: Dict[str, Union[str, List[str]]] = {}
for ns, embed_item in item.items():
if isinstance(embed_item, list):
inner_dict[ns] = []
@@ -530,7 +550,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
def embed_list_type(
item: list, model: Any, namespace: Optional[str] = None
) -> List[Dict[str, Union[str, List[str]]]]:
ret_list = []
ret_list: List[Dict[str, Union[str, List[str]]]] = []
for embed_item in item:
if isinstance(embed_item, dict):
ret_list.append(embed_dict_type(embed_item, model))
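
Taken together, the typing changes in this file build a small generics hierarchy: Event is parameterized by the kind of Selected it carries, and Policy/Embedder are parameterized by the kind of Event they consume. A minimal standalone sketch of the pattern (simplified names; the concrete subclasses at the end are hypothetical, for illustration only):

from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, TypeVar

class Selected(ABC):
    pass

TSelected = TypeVar("TSelected", bound=Selected)

class Event(Generic[TSelected], ABC):
    # An event carries the raw chain inputs plus an optional selection result.
    def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
        self.inputs = inputs
        self.selected = selected

TEvent = TypeVar("TEvent", bound=Event)

class Embedder(Generic[TEvent], ABC):
    @abstractmethod
    def format(self, event: TEvent) -> str:
        ...

class MySelected(Selected):  # hypothetical concrete Selected
    def __init__(self, index: int):
        self.index = index

class MyEmbedder(Embedder[Event[MySelected]]):  # pins the event type for mypy
    def format(self, event: Event[MySelected]) -> str:
        return " ".join(sorted(event.inputs))

Pinning the type parameter in a subclass is what lets mypy check calls like feature_embedder.format(event) against the specific event type a given chain produces.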

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Union
if TYPE_CHECKING:
import pandas as pd
@@ -6,11 +6,11 @@ if TYPE_CHECKING:
class MetricsTracker:
def __init__(self, step: int):
self._history = []
self._step = step
self._i = 0
self._num = 0
self._denom = 0
self._history: List[Dict[str, Union[int, float]]] = []
self._step: int = step
self._i: int = 0
self._num: float = 0
self._denom: float = 0
@property
def score(self) -> float:
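
The hunk cuts off before the property body, but the annotated fields suggest a running ratio of accumulated feedback (_num) over decisions counted (_denom). A hedged sketch of how such a score property might look (the actual body is not shown in this diff):

class MetricsTrackerSketch:  # illustrative, not the module's class
    def __init__(self, step: int):
        self._num: float = 0
        self._denom: float = 0

    @property
    def score(self) -> float:
        # Guard the running average against division by zero
        # before any feedback has been recorded.
        return self._num / self._denom if self._denom > 0 else 0.0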

View File

@@ -4,7 +4,7 @@ import logging
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Sequence, Union
from typing import TYPE_CHECKING, List, Union
if TYPE_CHECKING:
import vowpal_wabbit_next as vw
@@ -22,7 +22,7 @@ class ModelRepository:
self.folder = Path(folder)
self.model_path = self.folder / "latest.vw"
self.with_history = with_history
if reset and self.has_history:
if reset and self.has_history():
logger.warning(
"There is non empty history which is recommended to be cleaned up"
)
@@ -44,7 +44,7 @@ class ModelRepository:
if self.with_history: # write history
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
def load(self, commandline: Sequence[str]) -> "vw.Workspace":
def load(self, commandline: List[str]) -> "vw.Workspace":
import vowpal_wabbit_next as vw
model_data = None
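
One change in this file is a behavioral bug fix rather than a type annotation: `if reset and self.has_history:` tested the bound method object itself, which is always truthy, so the reset branch ran even with no history on disk; the corrected `self.has_history()` actually evaluates it. A minimal demonstration of the difference:

class Repo:  # toy stand-in for ModelRepository
    def has_history(self) -> bool:
        return False

repo = Repo()
assert bool(repo.has_history) is True   # a bound method object is always truthy
assert repo.has_history() is False      # calling it returns the real answer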

View File

@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
SENTINEL = object()
class PickBestFeatureEmbedder(base.Embedder):
class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
"""
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
@@ -25,7 +25,7 @@ class PickBestFeatureEmbedder(base.Embedder):
model (Any, optional): The type of embeddings to be used for feature representation. Defaults to the BERT SentenceTransformer.
""" # noqa E501
def __init__(self, model: Optional[Any] = None, *args, **kwargs):
def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
if model is None:
@@ -88,7 +88,7 @@ class PickBestFeatureEmbedder(base.Embedder):
return example_string[:-1]
class PickBest(base.RLChain):
class PickBest(base.RLChain[PickBest.Event]):
"""
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for contextual reinforcement learning, with the goal of modifying the prompt before the LLM call.
@@ -131,7 +131,7 @@ class PickBest(base.RLChain):
self.probability = probability
self.score = score
class Event(base.Event):
class Event(base.Event[PickBest.Selected]):
def __init__(
self,
inputs: Dict[str, Any],
@@ -146,8 +146,8 @@
def __init__(
self,
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
):
vw_cmd = kwargs.get("vw_cmd", [])
if not vw_cmd:
@@ -170,7 +170,7 @@ class PickBest(base.RLChain):
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event:
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
raise ValueError(
@@ -198,7 +198,7 @@ class PickBest(base.RLChain):
def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
) -> Tuple[Dict[str, Any], PickBest.Event]:
) -> Tuple[Dict[str, Any], Event]:
import numpy as np
prob_sum = sum(prob for _, prob in prediction)
@@ -218,8 +218,8 @@
return next_chain_inputs, event
def _call_after_llm_before_scoring(
self, llm_response: str, event: PickBest.Event
) -> Tuple[Dict[str, Any], PickBest.Event]:
self, llm_response: str, event: Event
) -> Tuple[Dict[str, Any], Event]:
next_chain_inputs = event.inputs.copy()
# only one key, value pair in event.to_select_from
value = next(iter(event.to_select_from.values()))
@@ -232,7 +232,7 @@ class PickBest(base.RLChain):
return next_chain_inputs, event
def _call_after_scoring_before_learning(
self, event: PickBest.Event, score: Optional[float]
self, event: Event, score: Optional[float]
) -> Event:
event.selected.score = score
return event
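
For context on the prediction that flows through these hooks: earlier in this file, `_call_after_predict_before_llm` receives `prediction` as a list of (action index, probability) pairs and computes `prob_sum` before choosing an action. A hedged standalone sketch of that sampling step (everything beyond the `prob_sum` line shown in the hunk is an assumption, since the hunk is truncated):

import numpy as np

# Example (action index, probability) pairs, as a VW contextual-bandit
# prediction might look.
prediction = [(0, 0.7), (1, 0.2), (2, 0.1)]

prob_sum = sum(prob for _, prob in prediction)
probabilities = [prob / prob_sum for _, prob in prediction]  # renormalize defensively
sampled_index = int(np.random.choice(len(prediction), p=probabilities))
sampled_action, sampled_prob = prediction[sampled_index]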

View File

@@ -9,10 +9,10 @@ class VwLogger:
if self.path:
self.path.parent.mkdir(parents=True, exist_ok=True)
def log(self, vw_ex: str):
def log(self, vw_ex: str) -> None:
if self.path:
with open(self.path, "a") as f:
f.write(f"{vw_ex}\n\n")
def logging_enabled(self):
def logging_enabled(self) -> bool:
return bool(self.path)

View File

@@ -1,3 +1,3 @@
class MockEncoder:
def encode(self, to_encode):
def encode(self, to_encode: str) -> str:
return "[encoded]" + to_encode