From 6a1102d4c094b48afe43818a079f5758dc186a98 Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Mon, 28 Aug 2023 06:58:33 -0400
Subject: [PATCH] mypy fixes and formatting

---
 .../langchain/chains/rl_chain/base.py        | 150 ++++++++++--------
 .../langchain/chains/rl_chain/metrics.py     |  12 +-
 .../chains/rl_chain/model_repository.py      |   6 +-
 .../chains/rl_chain/pick_best_chain.py       |  22 +--
 .../langchain/chains/rl_chain/vw_logger.py   |   4 +-
 .../unit_tests/chains/rl_chain/test_utils.py |   2 +-
 6 files changed, 108 insertions(+), 88 deletions(-)

diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py
index 28baf898d2c..d97dd255afe 100644
--- a/libs/langchain/langchain/chains/rl_chain/base.py
+++ b/libs/langchain/langchain/chains/rl_chain/base.py
@@ -3,7 +3,18 @@ from __future__ import annotations
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
@@ -26,47 +37,47 @@ logger = logging.getLogger(__name__)
 
 
 class _BasedOn:
-    def __init__(self, value):
+    def __init__(self, value: Any):
         self.value = value
 
-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.value)
 
     __repr__ = __str__
 
 
-def BasedOn(anything):
+def BasedOn(anything: Any) -> _BasedOn:
     return _BasedOn(anything)
 
 
 class _ToSelectFrom:
-    def __init__(self, value):
+    def __init__(self, value: Any):
         self.value = value
 
-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.value)
 
     __repr__ = __str__
 
 
-def ToSelectFrom(anything):
+def ToSelectFrom(anything: Any) -> _ToSelectFrom:
     if not isinstance(anything, list):
         raise ValueError("ToSelectFrom must be a list to select from")
     return _ToSelectFrom(anything)
 
 
 class _Embed:
-    def __init__(self, value, keep=False):
+    def __init__(self, value: Any, keep: bool = False):
         self.value = value
         self.keep = keep
 
-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.value)
 
     __repr__ = __str__
 
 
-def Embed(anything, keep=False):
+def Embed(anything: Any, keep: bool = False) -> Any:
     if isinstance(anything, _ToSelectFrom):
         return ToSelectFrom(Embed(anything.value, keep=keep))
     elif isinstance(anything, _BasedOn):
@@ -80,7 +91,7 @@ def Embed(anything, keep=False):
     return _Embed(anything, keep=keep)
 
 
-def EmbedAndKeep(anything):
+def EmbedAndKeep(anything: Any) -> Any:
     return Embed(anything, keep=True)
 
 
@@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp
     return [parser.parse_line(line) for line in input_str.split("\n")]
 
 
-def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
+def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
     to_select_from = {
         k: inputs[k].value
         for k in inputs.keys()
@@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
     return based_on, to_select_from
 
 
-def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
+def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
     """
     go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed, then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
     """
@@ -134,29 +145,35 @@ class Selected(ABC):
     pass
 
 
-class Event(ABC):
-    inputs: Dict[str, Any]
-    selected: Optional[Selected]
+TSelected = TypeVar("TSelected", bound=Selected)
 
-    def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None):
+
+class Event(Generic[TSelected], ABC):
+    inputs: Dict[str, Any]
+    selected: Optional[TSelected]
+
+    def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
         self.inputs = inputs
         self.selected = selected
 
 
+TEvent = TypeVar("TEvent", bound=Event)
+
+
 class Policy(ABC):
     @abstractmethod
-    def predict(self, event: Event) -> Any:
-        pass
+    def predict(self, event: TEvent) -> Any:
+        ...
 
     @abstractmethod
-    def learn(self, event: Event):
-        pass
+    def learn(self, event: TEvent) -> None:
+        ...
 
     @abstractmethod
-    def log(self, event: Event):
-        pass
+    def log(self, event: TEvent) -> None:
+        ...
 
-    def save(self):
+    def save(self) -> None:
         pass
 
 
@@ -164,11 +181,11 @@ class VwPolicy(Policy):
     def __init__(
         self,
         model_repo: ModelRepository,
-        vw_cmd: Sequence[str],
+        vw_cmd: List[str],
         feature_embedder: Embedder,
         vw_logger: VwLogger,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
         self.model_repo = model_repo
@@ -176,7 +193,7 @@ class VwPolicy(Policy):
         self.feature_embedder = feature_embedder
         self.vw_logger = vw_logger
 
-    def predict(self, event: Event) -> Any:
+    def predict(self, event: TEvent) -> Any:
         import vowpal_wabbit_next as vw
 
         text_parser = vw.TextFormatParser(self.workspace)
@@ -184,7 +201,7 @@ class VwPolicy(Policy):
             parse_lines(text_parser, self.feature_embedder.format(event))
         )
 
-    def learn(self, event: Event):
+    def learn(self, event: TEvent) -> None:
         import vowpal_wabbit_next as vw
 
         vw_ex = self.feature_embedder.format(event)
@@ -192,19 +209,19 @@ class VwPolicy(Policy):
         text_parser = vw.TextFormatParser(self.workspace)
         multi_ex = parse_lines(text_parser, vw_ex)
         self.workspace.learn_one(multi_ex)
 
-    def log(self, event: Event):
+    def log(self, event: TEvent) -> None:
         if self.vw_logger.logging_enabled():
             vw_ex = self.feature_embedder.format(event)
             self.vw_logger.log(vw_ex)
 
-    def save(self):
-        self.model_repo.save()
+    def save(self) -> None:
+        self.model_repo.save(self.workspace)
 
 
-class Embedder(ABC):
+class Embedder(Generic[TEvent], ABC):
     @abstractmethod
-    def format(self, event: Event) -> str:
-        pass
+    def format(self, event: TEvent) -> str:
+        ...
 
 
 class SelectionScorer(ABC, BaseModel):
@@ -212,7 +229,7 @@
 
     @abstractmethod
     def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
-        pass
+        ...
 
 
 class AutoSelectionScorer(SelectionScorer, BaseModel):
@@ -243,7 +260,7 @@
         return chat_prompt
 
     @root_validator(pre=True)
-    def set_prompt_and_llm_chain(cls, values):
+    def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         llm = values.get("llm")
         prompt = values.get("prompt")
         scoring_criteria_template_str = values.get("scoring_criteria_template_str")
@@ -275,7 +292,7 @@
         )
 
 
-class RLChain(Chain):
+class RLChain(Generic[TEvent], Chain):
     """
     The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
 
@@ -305,7 +322,7 @@ class RLChain(Chain):
     output_key: str = "result"  #: :meta private:
     prompt: BasePromptTemplate
     selection_scorer: Union[SelectionScorer, None]
-    policy: Optional[Policy]
+    policy: Policy
     auto_embed: bool = True
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"
@@ -314,14 +331,14 @@
     def __init__(
         self,
         feature_embedder: Embedder,
-        model_save_dir="./",
-        reset_model=False,
-        vw_cmd=None,
-        policy=VwPolicy,
+        model_save_dir: str = "./",
+        reset_model: bool = False,
+        vw_cmd: Optional[List[str]] = None,
+        policy: Type[Policy] = VwPolicy,
         vw_logs: Optional[Union[str, os.PathLike]] = None,
-        metrics_step=-1,
-        *args,
-        **kwargs,
+        metrics_step: int = -1,
+        *args: Any,
+        **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
         if self.selection_scorer is None:
@@ -374,29 +391,29 @@
         )
 
     @abstractmethod
-    def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
-        pass
+    def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
+        ...
 
     @abstractmethod
     def _call_after_predict_before_llm(
-        self, inputs: Dict[str, Any], event: Event, prediction: Any
-    ) -> Tuple[Dict[str, Any], Event]:
-        pass
+        self, inputs: Dict[str, Any], event: TEvent, prediction: Any
+    ) -> Tuple[Dict[str, Any], TEvent]:
+        ...
 
     @abstractmethod
     def _call_after_llm_before_scoring(
-        self, llm_response: str, event: Event
-    ) -> Tuple[Dict[str, Any], Event]:
-        pass
+        self, llm_response: str, event: TEvent
+    ) -> Tuple[Dict[str, Any], TEvent]:
+        ...
 
     @abstractmethod
     def _call_after_scoring_before_learning(
-        self, event: Event, score: Optional[float]
-    ) -> Event:
-        pass
+        self, event: TEvent, score: Optional[float]
+    ) -> TEvent:
+        ...
 
     def update_with_delayed_score(
-        self, score: float, event: Event, force_score=False
+        self, score: float, event: TEvent, force_score: bool = False
     ) -> None:
         """
         Updates the learned policy with the score provided.
@@ -407,7 +424,8 @@
                 "The selection scorer is set, and force_score was not set to True. \
                     Please set force_score=True to use this function."
             )
-        self.metrics.on_feedback(score)
+        if self.metrics:
+            self.metrics.on_feedback(score)
         self._call_after_scoring_before_learning(event=event, score=score)
         self.policy.learn(event=event)
         self.policy.log(event=event)
@@ -422,15 +440,16 @@
         self,
         inputs: Dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> Dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         if self.auto_embed:
             inputs = prepare_inputs_for_autoembed(inputs=inputs)
 
-        event = self._call_before_predict(inputs=inputs)
+        event: TEvent = self._call_before_predict(inputs=inputs)
         prediction = self.policy.predict(event=event)
-        self.metrics.on_decision()
+        if self.metrics:
+            self.metrics.on_decision()
 
         next_chain_inputs, event = self._call_after_predict_before_llm(
             inputs=inputs, event=event, prediction=prediction
         )
@@ -462,7 +481,8 @@
                     f"The selection scorer was not able to score, \
                     and the chain was not able to adjust to this response, error: {e}"
                 )
-        self.metrics.on_feedback(score)
+        if self.metrics:
+            self.metrics.on_feedback(score)
         event = self._call_after_scoring_before_learning(score=score, event=event)
         self.policy.learn(event=event)
         self.policy.log(event=event)
@@ -515,7 +535,7 @@ def embed_string_type(
 
 def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
     """Helper function to embed a dictionary item."""
-    inner_dict = {}
+    inner_dict: Dict[str, Union[str, List[str]]] = {}
     for ns, embed_item in item.items():
         if isinstance(embed_item, list):
             inner_dict[ns] = []
@@ -530,7 +550,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
 def embed_list_type(
     item: list, model: Any, namespace: Optional[str] = None
 ) -> List[Dict[str, Union[str, List[str]]]]:
-    ret_list = []
+    ret_list: List[Dict[str, Union[str, List[str]]]] = []
     for embed_item in item:
         if isinstance(embed_item, dict):
             ret_list.append(embed_dict_type(embed_item, model))
diff --git a/libs/langchain/langchain/chains/rl_chain/metrics.py b/libs/langchain/langchain/chains/rl_chain/metrics.py
index b7ec949c9ea..4d6306f7760 100644
--- a/libs/langchain/langchain/chains/rl_chain/metrics.py
+++ b/libs/langchain/langchain/chains/rl_chain/metrics.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -6,11 +6,11 @@
 
 class MetricsTracker:
     def __init__(self, step: int):
-        self._history = []
-        self._step = step
-        self._i = 0
-        self._num = 0
-        self._denom = 0
+        self._history: List[Dict[str, Union[int, float]]] = []
+        self._step: int = step
+        self._i: int = 0
+        self._num: float = 0
+        self._denom: float = 0
 
     @property
     def score(self) -> float:
diff --git a/libs/langchain/langchain/chains/rl_chain/model_repository.py b/libs/langchain/langchain/chains/rl_chain/model_repository.py
index eea866d1cf3..87f162df0ab 100644
--- a/libs/langchain/langchain/chains/rl_chain/model_repository.py
+++ b/libs/langchain/langchain/chains/rl_chain/model_repository.py
@@ -4,7 +4,7 @@ import logging
 import os
 import shutil
 from pathlib import Path
-from typing import TYPE_CHECKING, Sequence, Union
+from typing import TYPE_CHECKING, List, Union
 
 if TYPE_CHECKING:
     import vowpal_wabbit_next as vw
@@ -22,7 +22,7 @@
         self.folder = Path(folder)
         self.model_path = self.folder / "latest.vw"
         self.with_history = with_history
-        if reset and self.has_history:
+        if reset and self.has_history():
             logger.warning(
                 "There is non empty history which is recommended to be cleaned up"
             )
@@ -44,7 +44,7 @@
         if self.with_history:  # write history
             shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
 
-    def load(self, commandline: Sequence[str]) -> "vw.Workspace":
+    def load(self, commandline: List[str]) -> "vw.Workspace":
         import vowpal_wabbit_next as vw
 
         model_data = None
diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py
index 6e1a1a5eff7..e60e685a0b5 100644
--- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py
+++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
 SENTINEL = object()
 
 
-class PickBestFeatureEmbedder(base.Embedder):
+class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
     """
     Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
 
     Attributes:
         model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
     """  # noqa E501
 
-    def __init__(self, model: Optional[Any] = None, *args, **kwargs):
+    def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)
 
         if model is None:
@@ -88,7 +88,7 @@
         return example_string[:-1]
 
 
-class PickBest(base.RLChain):
+class PickBest(base.RLChain[PickBest.Event]):
     """
     `PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
@@ -131,7 +131,7 @@ class PickBest(base.RLChain):
             self.probability = probability
             self.score = score
 
-    class Event(base.Event):
+    class Event(base.Event[PickBest.Selected]):
         def __init__(
             self,
             inputs: Dict[str, Any],
@@ -146,8 +146,8 @@
     def __init__(
         self,
         feature_embedder: Optional[PickBestFeatureEmbedder] = None,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         vw_cmd = kwargs.get("vw_cmd", [])
         if not vw_cmd:
@@ -170,7 +170,7 @@
 
         super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
 
-    def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event:
+    def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
         context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
         if not actions:
             raise ValueError(
@@ -198,7 +198,7 @@
 
     def _call_after_predict_before_llm(
         self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
-    ) -> Tuple[Dict[str, Any], PickBest.Event]:
+    ) -> Tuple[Dict[str, Any], Event]:
         import numpy as np
 
         prob_sum = sum(prob for _, prob in prediction)
@@ -218,8 +218,8 @@
         return next_chain_inputs, event
 
     def _call_after_llm_before_scoring(
-        self, llm_response: str, event: PickBest.Event
-    ) -> Tuple[Dict[str, Any], PickBest.Event]:
+        self, llm_response: str, event: Event
+    ) -> Tuple[Dict[str, Any], Event]:
         next_chain_inputs = event.inputs.copy()
         # only one key, value pair in event.to_select_from
         value = next(iter(event.to_select_from.values()))
@@ -232,7 +232,7 @@
         return next_chain_inputs, event
 
     def _call_after_scoring_before_learning(
-        self, event: PickBest.Event, score: Optional[float]
+        self, event: Event, score: Optional[float]
     ) -> Event:
         event.selected.score = score
         return event
diff --git a/libs/langchain/langchain/chains/rl_chain/vw_logger.py b/libs/langchain/langchain/chains/rl_chain/vw_logger.py
index 4fa47175395..e8d2e1541f1 100644
--- a/libs/langchain/langchain/chains/rl_chain/vw_logger.py
+++ b/libs/langchain/langchain/chains/rl_chain/vw_logger.py
@@ -9,10 +9,10 @@
         if self.path:
             self.path.parent.mkdir(parents=True, exist_ok=True)
 
-    def log(self, vw_ex: str):
+    def log(self, vw_ex: str) -> None:
         if self.path:
             with open(self.path, "a") as f:
                 f.write(f"{vw_ex}\n\n")
 
-    def logging_enabled(self):
+    def logging_enabled(self) -> bool:
         return bool(self.path)
diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py
index 6d54d20d921..625c37ee000 100644
--- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py
+++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py
@@ -1,3 +1,3 @@
 class MockEncoder:
-    def encode(self, to_encode):
+    def encode(self, to_encode: str) -> str:
         return "[encoded]" + to_encode
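
Note (not part of the patch itself): a minimal sketch of how the generics introduced above are intended to compose, assuming the module path libs/langchain/langchain/chains/rl_chain from this diff. The names MySelected, MyEvent, and MyEmbedder are hypothetical, for illustration only; base.Selected, base.Event, and base.Embedder come from the patch.

    from langchain.chains.rl_chain import base


    class MySelected(base.Selected):
        def __init__(self, choice: str):
            self.choice = choice


    class MyEvent(base.Event[MySelected]):
        # base.Event is now Generic[TSelected], so `selected` is typed as
        # Optional[MySelected] and mypy can check accesses on event.selected.
        pass


    class MyEmbedder(base.Embedder[MyEvent]):
        def format(self, event: MyEvent) -> str:
            # A real embedder would emit Vowpal Wabbit example lines here;
            # stringifying the inputs just keeps the sketch self-contained.
            return str(event.inputs)

Binding the Embedder to a concrete Event subclass is what lets mypy reject a mismatched embedder/event pairing at the RLChain boundary, which appears to be the point of replacing the plain ABCs with Generic[TSelected] / Generic[TEvent].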