From a11ad11d063e8f5553bd25b8bd74e629e1e31dd6 Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Tue, 29 Aug 2023 03:59:01 -0400
Subject: [PATCH] fix all mypy errors

---
 .../langchain/chains/rl_chain/base.py         | 50 ++++++++++---------
 .../chains/rl_chain/pick_best_chain.py        |  9 +++-
 2 files changed, 34 insertions(+), 25 deletions(-)

diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py
index d97dd255afe..22ff60a403e 100644
--- a/libs/langchain/langchain/chains/rl_chain/base.py
+++ b/libs/langchain/langchain/chains/rl_chain/base.py
@@ -161,6 +161,9 @@ TEvent = TypeVar("TEvent", bound=Event)
 
 
 class Policy(ABC):
+    def __init__(self, **kwargs: Any):
+        pass
+
     @abstractmethod
     def predict(self, event: TEvent) -> Any:
         ...
@@ -233,7 +236,7 @@ class SelectionScorer(ABC, BaseModel):
 
 
 class AutoSelectionScorer(SelectionScorer, BaseModel):
-    llm_chain: Union[LLMChain, None] = None
+    llm_chain: LLMChain
     prompt: Union[BasePromptTemplate, None] = None
     scoring_criteria_template_str: Optional[str] = None
 
@@ -309,7 +312,7 @@ class RLChain(Generic[TEvent], Chain):
     - model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
     - reset_model (bool): If set to True, the model starts training from scratch. Default is False.
     - vw_cmd (List[str], optional): Command line arguments for the VW model.
-    - policy (VwPolicy): Policy used by the chain.
+    - policy (Type[VwPolicy]): Policy used by the chain.
     - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
     - metrics_step (int): Step for the metrics tracker. Default is -1.
 
@@ -322,7 +325,7 @@ class RLChain(Generic[TEvent], Chain):
     output_key: str = "result"  #: :meta private:
     prompt: BasePromptTemplate
     selection_scorer: Union[SelectionScorer, None]
-    policy: Policy
+    active_policy: Policy
     auto_embed: bool = True
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"
@@ -347,14 +350,17 @@ class RLChain(Generic[TEvent], Chain):
                 reinforcement learning will be done in the RL chain \
                 unless update_with_delayed_score is called."
             )
-        self.policy = policy(
-            model_repo=ModelRepository(
-                model_save_dir, with_history=True, reset=reset_model
-            ),
-            vw_cmd=vw_cmd or [],
-            feature_embedder=feature_embedder,
-            vw_logger=VwLogger(vw_logs),
-        )
+
+        if self.active_policy is None:
+            self.active_policy = policy(
+                model_repo=ModelRepository(
+                    model_save_dir, with_history=True, reset=reset_model
+                ),
+                vw_cmd=vw_cmd or [],
+                feature_embedder=feature_embedder,
+                vw_logger=VwLogger(vw_logs),
+            )
+
         self.metrics = MetricsTracker(step=metrics_step)
 
     class Config:
@@ -427,8 +433,8 @@ class RLChain(Generic[TEvent], Chain):
         if self.metrics:
             self.metrics.on_feedback(score)
         self._call_after_scoring_before_learning(event=event, score=score)
-        self.policy.learn(event=event)
-        self.policy.log(event=event)
+        self.active_policy.learn(event=event)
+        self.active_policy.log(event=event)
 
     def set_auto_embed(self, auto_embed: bool) -> None:
         """
@@ -447,7 +453,7 @@ class RLChain(Generic[TEvent], Chain):
 
         inputs = prepare_inputs_for_autoembed(inputs=inputs)
         event: TEvent = self._call_before_predict(inputs=inputs)
-        prediction = self.policy.predict(event=event)
+        prediction = self.active_policy.predict(event=event)
         if self.metrics:
             self.metrics.on_decision()
 
@@ -484,8 +490,8 @@ class RLChain(Generic[TEvent], Chain):
         if self.metrics:
             self.metrics.on_feedback(score)
         event = self._call_after_scoring_before_learning(score=score, event=event)
-        self.policy.learn(event=event)
-        self.policy.log(event=event)
+        self.active_policy.learn(event=event)
+        self.active_policy.log(event=event)
 
         return {self.output_key: {"response": output, "selection_metadata": event}}
 
@@ -493,7 +499,7 @@ class RLChain(Generic[TEvent], Chain):
         """
         This function should be called to save the state of the learned policy model.
         """
-        self.policy.save()
+        self.active_policy.save()
 
     @property
     def _chain_type(self) -> str:
@@ -509,7 +515,7 @@ def is_stringtype_instance(item: Any) -> bool:
 
 def embed_string_type(
     item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
-) -> Dict[str, str]:
+) -> Dict[str, Union[str, List[str]]]:
     """Helper function to embed a string or an _Embed object."""
     join_char = ""
     keep_str = ""
@@ -533,9 +539,9 @@ def embed_string_type(
     return {namespace: keep_str + join_char.join(map(str, encoded))}
 
 
-def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
+def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
     """Helper function to embed a dictionary item."""
-    inner_dict: Dict[str, Union[str, List[str]]] = {}
+    inner_dict: Dict[str, Any] = {}
     for ns, embed_item in item.items():
         if isinstance(embed_item, list):
             inner_dict[ns] = []
@@ -560,9 +566,7 @@ def embed_list_type(
 
 
 def embed(
-    to_embed: Union[
-        Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
-    ],
+    to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
     model: Any,
     namespace: Optional[str] = None,
 ) -> List[Dict[str, Union[str, List[str]]]]:
diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py
index ca920522680..691e0a99ce1 100644
--- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py
+++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py
@@ -54,9 +54,14 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
         to_select_from_var_name, to_select_from = next(
             iter(event.to_select_from.items()), (None, None)
         )
+
        action_embs = (
-            base.embed(to_select_from, self.model, to_select_from_var_name)
-            if event.to_select_from
+            (
+                base.embed(to_select_from, self.model, to_select_from_var_name)
+                if event.to_select_from
+                else None
+            )
+            if to_select_from
             else None
         )
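
Below is a minimal, hypothetical sketch (not code from this patch) of the
Optional-narrowing pattern that the final pick_best_chain.py hunk applies.
mypy types both results of `next(iter(d.items()), (None, None))` as Optional,
so the unpacked value must be guarded directly before use. The names `pick`,
`mapping`, and the `embed` stub here are illustrative stand-ins, not the
library's API.

    from typing import Dict, List, Optional

    def embed(values: List[str], namespace: Optional[str]) -> List[str]:
        # Stand-in for base.embed; the real function has a different signature.
        return [f"{namespace}:{v}" for v in values]

    def pick(mapping: Dict[str, List[str]]) -> Optional[List[str]]:
        # Both unpacked names are Optional because the (None, None) default
        # is returned when `mapping` is empty.
        var_name, to_select_from = next(iter(mapping.items()), (None, None))
        # Guarding on the unpacked value itself (the patch's added
        # `if to_select_from` branch) lets mypy narrow
        # Optional[List[str]] to List[str] before the call.
        return embed(to_select_from, var_name) if to_select_from else None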