fix all mypy errors

This commit is contained in:
olgavrou 2023-08-29 03:59:01 -04:00
parent dd6fff1c62
commit a11ad11d06
2 changed files with 34 additions and 25 deletions

View File

@ -161,6 +161,9 @@ TEvent = TypeVar("TEvent", bound=Event)
class Policy(ABC):
def __init__(self, **kwargs: Any):
pass
@abstractmethod
def predict(self, event: TEvent) -> Any:
...
@ -233,7 +236,7 @@ class SelectionScorer(ABC, BaseModel):
class AutoSelectionScorer(SelectionScorer, BaseModel):
llm_chain: Union[LLMChain, None] = None
llm_chain: LLMChain
prompt: Union[BasePromptTemplate, None] = None
scoring_criteria_template_str: Optional[str] = None
@ -309,7 +312,7 @@ class RLChain(Generic[TEvent], Chain):
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
- vw_cmd (List[str], optional): Command line arguments for the VW model.
- policy (VwPolicy): Policy used by the chain.
- policy (Type[VwPolicy]): Policy used by the chain.
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
- metrics_step (int): Step for the metrics tracker. Default is -1.
@ -322,7 +325,7 @@ class RLChain(Generic[TEvent], Chain):
output_key: str = "result" #: :meta private:
prompt: BasePromptTemplate
selection_scorer: Union[SelectionScorer, None]
policy: Policy
active_policy: Policy
auto_embed: bool = True
selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on"
@ -347,7 +350,9 @@ class RLChain(Generic[TEvent], Chain):
reinforcement learning will be done in the RL chain \
unless update_with_delayed_score is called."
)
self.policy = policy(
if self.active_policy is None:
self.active_policy = policy(
model_repo=ModelRepository(
model_save_dir, with_history=True, reset=reset_model
),
@ -355,6 +360,7 @@ class RLChain(Generic[TEvent], Chain):
feature_embedder=feature_embedder,
vw_logger=VwLogger(vw_logs),
)
self.metrics = MetricsTracker(step=metrics_step)
class Config:
@ -427,8 +433,8 @@ class RLChain(Generic[TEvent], Chain):
if self.metrics:
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score)
self.policy.learn(event=event)
self.policy.log(event=event)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)
def set_auto_embed(self, auto_embed: bool) -> None:
"""
@ -447,7 +453,7 @@ class RLChain(Generic[TEvent], Chain):
inputs = prepare_inputs_for_autoembed(inputs=inputs)
event: TEvent = self._call_before_predict(inputs=inputs)
prediction = self.policy.predict(event=event)
prediction = self.active_policy.predict(event=event)
if self.metrics:
self.metrics.on_decision()
@ -484,8 +490,8 @@ class RLChain(Generic[TEvent], Chain):
if self.metrics:
self.metrics.on_feedback(score)
event = self._call_after_scoring_before_learning(score=score, event=event)
self.policy.learn(event=event)
self.policy.log(event=event)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)
return {self.output_key: {"response": output, "selection_metadata": event}}
@ -493,7 +499,7 @@ class RLChain(Generic[TEvent], Chain):
"""
This function should be called to save the state of the learned policy model.
"""
self.policy.save()
self.active_policy.save()
@property
def _chain_type(self) -> str:
@ -509,7 +515,7 @@ def is_stringtype_instance(item: Any) -> bool:
def embed_string_type(
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
) -> Dict[str, str]:
) -> Dict[str, Union[str, List[str]]]:
"""Helper function to embed a string or an _Embed object."""
join_char = ""
keep_str = ""
@ -533,9 +539,9 @@ def embed_string_type(
return {namespace: keep_str + join_char.join(map(str, encoded))}
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
"""Helper function to embed a dictionary item."""
inner_dict: Dict[str, Union[str, List[str]]] = {}
inner_dict: Dict[str, Any] = {}
for ns, embed_item in item.items():
if isinstance(embed_item, list):
inner_dict[ns] = []
@ -560,9 +566,7 @@ def embed_list_type(
def embed(
to_embed: Union[
Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
],
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
model: Any,
namespace: Optional[str] = None,
) -> List[Dict[str, Union[str, List[str]]]]:

View File

@ -54,11 +54,16 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
to_select_from_var_name, to_select_from = next(
iter(event.to_select_from.items()), (None, None)
)
action_embs = (
(
base.embed(to_select_from, self.model, to_select_from_var_name)
if event.to_select_from
else None
)
if to_select_from
else None
)
if not context_emb or not action_embs:
raise ValueError(