fix all mypy errors

olgavrou 2023-08-29 03:59:01 -04:00
parent dd6fff1c62
commit a11ad11d06
2 changed files with 34 additions and 25 deletions


@@ -161,6 +161,9 @@ TEvent = TypeVar("TEvent", bound=Event)
 class Policy(ABC):
+    def __init__(self, **kwargs: Any):
+        pass
+
     @abstractmethod
     def predict(self, event: TEvent) -> Any:
         ...
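The added no-op constructor lets Policy subclasses accept and forward arbitrary keyword arguments without tripping mypy. A minimal sketch of a hypothetical subclass, assuming the Policy ABC and TEvent type variable defined in this module (the name ConstantPolicy is illustrative, not part of the library):

class ConstantPolicy(Policy):
    def __init__(self, value: Any = 1.0, **kwargs: Any):
        super().__init__(**kwargs)  # forwards to the new no-op Policy.__init__
        self.value = value

    def predict(self, event: TEvent) -> Any:
        # a real policy would inspect the event; this one scores everything the same
        return self.value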
@@ -233,7 +236,7 @@ class SelectionScorer(ABC, BaseModel):
 class AutoSelectionScorer(SelectionScorer, BaseModel):
-    llm_chain: Union[LLMChain, None] = None
+    llm_chain: LLMChain
     prompt: Union[BasePromptTemplate, None] = None
     scoring_criteria_template_str: Optional[str] = None
@@ -309,7 +312,7 @@ class RLChain(Generic[TEvent], Chain):
     - model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
     - reset_model (bool): If set to True, the model starts training from scratch. Default is False.
     - vw_cmd (List[str], optional): Command line arguments for the VW model.
-    - policy (VwPolicy): Policy used by the chain.
+    - policy (Type[VwPolicy]): Policy used by the chain.
     - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
     - metrics_step (int): Step for the metrics tracker. Default is -1.
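Per the corrected docstring, the chain now takes the policy as a class (Type[VwPolicy]) rather than an instance and constructs it itself in __init__. An illustrative construction call, where SomeRLChainSubclass, prompt, and scoring_llm_chain are placeholders, not names from this diff:

chain = SomeRLChainSubclass(          # placeholder for a concrete RLChain subclass
    prompt=prompt,                    # BasePromptTemplate
    selection_scorer=AutoSelectionScorer(llm_chain=scoring_llm_chain),
    model_save_dir="./vw_model",      # str, optional
    reset_model=False,                # bool
    vw_cmd=[],                        # List[str], optional
    policy=VwPolicy,                  # the class itself, per Type[VwPolicy]
    vw_logs=None,                     # Optional[Union[str, os.PathLike]]
    metrics_step=-1,                  # int
)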
@@ -322,7 +325,7 @@ class RLChain(Generic[TEvent], Chain):
     output_key: str = "result"  #: :meta private:
     prompt: BasePromptTemplate
     selection_scorer: Union[SelectionScorer, None]
-    policy: Policy
+    active_policy: Policy
     auto_embed: bool = True
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"
@@ -347,14 +350,17 @@ class RLChain(Generic[TEvent], Chain):
                 reinforcement learning will be done in the RL chain \
                 unless update_with_delayed_score is called."
             )
-        self.policy = policy(
-            model_repo=ModelRepository(
-                model_save_dir, with_history=True, reset=reset_model
-            ),
-            vw_cmd=vw_cmd or [],
-            feature_embedder=feature_embedder,
-            vw_logger=VwLogger(vw_logs),
-        )
+        if self.active_policy is None:
+            self.active_policy = policy(
+                model_repo=ModelRepository(
+                    model_save_dir, with_history=True, reset=reset_model
+                ),
+                vw_cmd=vw_cmd or [],
+                feature_embedder=feature_embedder,
+                vw_logger=VwLogger(vw_logs),
+            )
         self.metrics = MetricsTracker(step=metrics_step)

     class Config:
@@ -427,8 +433,8 @@ class RLChain(Generic[TEvent], Chain):
         if self.metrics:
             self.metrics.on_feedback(score)
         self._call_after_scoring_before_learning(event=event, score=score)
-        self.policy.learn(event=event)
-        self.policy.log(event=event)
+        self.active_policy.learn(event=event)
+        self.active_policy.log(event=event)

     def set_auto_embed(self, auto_embed: bool) -> None:
         """
@@ -447,7 +453,7 @@ class RLChain(Generic[TEvent], Chain):
         inputs = prepare_inputs_for_autoembed(inputs=inputs)
         event: TEvent = self._call_before_predict(inputs=inputs)
-        prediction = self.policy.predict(event=event)
+        prediction = self.active_policy.predict(event=event)
         if self.metrics:
             self.metrics.on_decision()
@@ -484,8 +490,8 @@ class RLChain(Generic[TEvent], Chain):
         if self.metrics:
             self.metrics.on_feedback(score)
         event = self._call_after_scoring_before_learning(score=score, event=event)
-        self.policy.learn(event=event)
-        self.policy.log(event=event)
+        self.active_policy.learn(event=event)
+        self.active_policy.log(event=event)

         return {self.output_key: {"response": output, "selection_metadata": event}}
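Given output_key = "result" (declared above) and this return value, reading the chain's output might look like the sketch below; how the chain is invoked and what the inputs dict contains are assumptions, not part of the diff:

outputs = chain(inputs)  # Chain call returns the output dict
response = outputs["result"]["response"]          # the selected/generated output
event = outputs["result"]["selection_metadata"]   # the TEvent used for scoring and learning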
@@ -493,7 +499,7 @@ class RLChain(Generic[TEvent], Chain):
         """
         This function should be called to save the state of the learned policy model.
         """
-        self.policy.save()
+        self.active_policy.save()

     @property
     def _chain_type(self) -> str:
@@ -509,7 +515,7 @@ def is_stringtype_instance(item: Any) -> bool:

 def embed_string_type(
     item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
-) -> Dict[str, str]:
+) -> Dict[str, Union[str, List[str]]]:
     """Helper function to embed a string or an _Embed object."""
     join_char = ""
     keep_str = ""
@@ -533,9 +539,9 @@ def embed_string_type(
     return {namespace: keep_str + join_char.join(map(str, encoded))}

-def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
+def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
     """Helper function to embed a dictionary item."""
-    inner_dict: Dict[str, Union[str, List[str]]] = {}
+    inner_dict: Dict[str, Any] = {}
     for ns, embed_item in item.items():
         if isinstance(embed_item, list):
             inner_dict[ns] = []
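A standalone illustration (not the library code) of the kind of mypy error that plausibly motivates widening the annotation to Dict[str, Any]: with a Union value type, mypy cannot prove the value stored under a key supports list operations such as append.

from typing import Any, Dict, List, Union

strict: Dict[str, Union[str, List[str]]] = {}
strict["ns"] = []
# strict["ns"].append("x")  # mypy: Item "str" of the union has no attribute "append"

loose: Dict[str, Any] = {}
loose["ns"] = []
loose["ns"].append("x")      # accepted: values are typed as Any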
@@ -560,9 +566,7 @@ def embed_list_type(

 def embed(
-    to_embed: Union[
-        Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
-    ],
+    to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
     model: Any,
     namespace: Optional[str] = None,
 ) -> List[Dict[str, Union[str, List[str]]]]:
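The old annotation wrote Union(...) and _Embed(str) with parentheses, i.e. as calls, instead of subscripting with brackets; mypy rejects that, and calling Union also fails at runtime. A small standalone reminder of the difference (not library code):

from typing import List, Union

Valid = Union[str, int]            # subscription: a valid type expression
AlsoValid = List[Union[str, int]]  # nesting works the same way
# Invalid = Union(str, int)        # TypeError at import time; flagged by mypy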


@@ -54,9 +54,14 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
         to_select_from_var_name, to_select_from = next(
             iter(event.to_select_from.items()), (None, None)
         )
-        action_embs = (
-            base.embed(to_select_from, self.model, to_select_from_var_name)
-            if event.to_select_from
-            else None
-        )
+        action_embs = (
+            (
+                base.embed(to_select_from, self.model, to_select_from_var_name)
+                if event.to_select_from
+                else None
+            )
+            if to_select_from
+            else None
+        )
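Because to_select_from comes from next(iter(...), (None, None)), it can be None, so passing it straight to base.embed is what mypy was flagging; the new outer conditional narrows it before the call. A standalone sketch of the same narrowing pattern (names here are illustrative, not from the library):

from typing import Dict, List, Optional, Tuple

def first_item(d: Dict[str, List[str]]) -> Tuple[Optional[str], Optional[List[str]]]:
    # next(..., (None, None)) yields (None, None) when the dict is empty
    return next(iter(d.items()), (None, None))

name, values = first_item({})
embedded = (
    [v.upper() for v in values]  # values is narrowed to List[str] inside the guard
    if values
    else None
)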