Merge pull request #3 from VowpalWabbit/fix_linting

Fix mypy errors
This commit is contained in:
olgavrou 2023-08-29 05:58:03 -04:00 committed by GitHub
commit f7fb083aba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 296 additions and 251 deletions

View File

@ -13,7 +13,7 @@ from langchain.chains.rl_chain.base import (
from langchain.chains.rl_chain.pick_best_chain import PickBest from langchain.chains.rl_chain.pick_best_chain import PickBest
def configure_logger(): def configure_logger() -> None:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
ch = logging.StreamHandler() ch = logging.StreamHandler()

View File

@ -3,7 +3,18 @@ from __future__ import annotations
import logging import logging
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -26,47 +37,47 @@ logger = logging.getLogger(__name__)
class _BasedOn: class _BasedOn:
def __init__(self, value): def __init__(self, value: Any):
self.value = value self.value = value
def __str__(self): def __str__(self) -> str:
return str(self.value) return str(self.value)
__repr__ = __str__ __repr__ = __str__
def BasedOn(anything): def BasedOn(anything: Any) -> _BasedOn:
return _BasedOn(anything) return _BasedOn(anything)
class _ToSelectFrom: class _ToSelectFrom:
def __init__(self, value): def __init__(self, value: Any):
self.value = value self.value = value
def __str__(self): def __str__(self) -> str:
return str(self.value) return str(self.value)
__repr__ = __str__ __repr__ = __str__
def ToSelectFrom(anything): def ToSelectFrom(anything: Any) -> _ToSelectFrom:
if not isinstance(anything, list): if not isinstance(anything, list):
raise ValueError("ToSelectFrom must be a list to select from") raise ValueError("ToSelectFrom must be a list to select from")
return _ToSelectFrom(anything) return _ToSelectFrom(anything)
class _Embed: class _Embed:
def __init__(self, value, keep=False): def __init__(self, value: Any, keep: bool = False):
self.value = value self.value = value
self.keep = keep self.keep = keep
def __str__(self): def __str__(self) -> str:
return str(self.value) return str(self.value)
__repr__ = __str__ __repr__ = __str__
def Embed(anything, keep=False): def Embed(anything: Any, keep: bool = False) -> Any:
if isinstance(anything, _ToSelectFrom): if isinstance(anything, _ToSelectFrom):
return ToSelectFrom(Embed(anything.value, keep=keep)) return ToSelectFrom(Embed(anything.value, keep=keep))
elif isinstance(anything, _BasedOn): elif isinstance(anything, _BasedOn):
@ -80,7 +91,7 @@ def Embed(anything, keep=False):
return _Embed(anything, keep=keep) return _Embed(anything, keep=keep)
def EmbedAndKeep(anything): def EmbedAndKeep(anything: Any) -> Any:
return Embed(anything, keep=True) return Embed(anything, keep=True)
@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp
return [parser.parse_line(line) for line in input_str.split("\n")] return [parser.parse_line(line) for line in input_str.split("\n")]
def get_based_on_and_to_select_from(inputs: Dict[str, Any]): def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
to_select_from = { to_select_from = {
k: inputs[k].value k: inputs[k].value
for k in inputs.keys() for k in inputs.keys()
@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
return based_on, to_select_from return based_on, to_select_from
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]): def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
""" """
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed, go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
@ -134,29 +145,38 @@ class Selected(ABC):
pass pass
class Event(ABC): TSelected = TypeVar("TSelected", bound=Selected)
inputs: Dict[str, Any]
selected: Optional[Selected]
def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None):
class Event(Generic[TSelected], ABC):
inputs: Dict[str, Any]
selected: Optional[TSelected]
def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
self.inputs = inputs self.inputs = inputs
self.selected = selected self.selected = selected
TEvent = TypeVar("TEvent", bound=Event)
class Policy(ABC): class Policy(ABC):
@abstractmethod def __init__(self, **kwargs: Any):
def predict(self, event: Event) -> Any:
pass pass
@abstractmethod @abstractmethod
def learn(self, event: Event): def predict(self, event: TEvent) -> Any:
pass ...
@abstractmethod @abstractmethod
def log(self, event: Event): def learn(self, event: TEvent) -> None:
pass ...
def save(self): @abstractmethod
def log(self, event: TEvent) -> None:
...
def save(self) -> None:
pass pass
@ -164,11 +184,11 @@ class VwPolicy(Policy):
def __init__( def __init__(
self, self,
model_repo: ModelRepository, model_repo: ModelRepository,
vw_cmd: Sequence[str], vw_cmd: List[str],
feature_embedder: Embedder, feature_embedder: Embedder,
vw_logger: VwLogger, vw_logger: VwLogger,
*args, *args: Any,
**kwargs, **kwargs: Any,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.model_repo = model_repo self.model_repo = model_repo
@ -176,7 +196,7 @@ class VwPolicy(Policy):
self.feature_embedder = feature_embedder self.feature_embedder = feature_embedder
self.vw_logger = vw_logger self.vw_logger = vw_logger
def predict(self, event: Event) -> Any: def predict(self, event: TEvent) -> Any:
import vowpal_wabbit_next as vw import vowpal_wabbit_next as vw
text_parser = vw.TextFormatParser(self.workspace) text_parser = vw.TextFormatParser(self.workspace)
@ -184,7 +204,7 @@ class VwPolicy(Policy):
parse_lines(text_parser, self.feature_embedder.format(event)) parse_lines(text_parser, self.feature_embedder.format(event))
) )
def learn(self, event: Event): def learn(self, event: TEvent) -> None:
import vowpal_wabbit_next as vw import vowpal_wabbit_next as vw
vw_ex = self.feature_embedder.format(event) vw_ex = self.feature_embedder.format(event)
@ -192,19 +212,19 @@ class VwPolicy(Policy):
multi_ex = parse_lines(text_parser, vw_ex) multi_ex = parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex) self.workspace.learn_one(multi_ex)
def log(self, event: Event): def log(self, event: TEvent) -> None:
if self.vw_logger.logging_enabled(): if self.vw_logger.logging_enabled():
vw_ex = self.feature_embedder.format(event) vw_ex = self.feature_embedder.format(event)
self.vw_logger.log(vw_ex) self.vw_logger.log(vw_ex)
def save(self): def save(self) -> None:
self.model_repo.save() self.model_repo.save(self.workspace)
class Embedder(ABC): class Embedder(Generic[TEvent], ABC):
@abstractmethod @abstractmethod
def format(self, event: Event) -> str: def format(self, event: TEvent) -> str:
pass ...
class SelectionScorer(ABC, BaseModel): class SelectionScorer(ABC, BaseModel):
@ -212,11 +232,11 @@ class SelectionScorer(ABC, BaseModel):
@abstractmethod @abstractmethod
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float: def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
pass ...
class AutoSelectionScorer(SelectionScorer, BaseModel): class AutoSelectionScorer(SelectionScorer, BaseModel):
llm_chain: Union[LLMChain, None] = None llm_chain: LLMChain
prompt: Union[BasePromptTemplate, None] = None prompt: Union[BasePromptTemplate, None] = None
scoring_criteria_template_str: Optional[str] = None scoring_criteria_template_str: Optional[str] = None
@ -243,7 +263,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
return chat_prompt return chat_prompt
@root_validator(pre=True) @root_validator(pre=True)
def set_prompt_and_llm_chain(cls, values): def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
llm = values.get("llm") llm = values.get("llm")
prompt = values.get("prompt") prompt = values.get("prompt")
scoring_criteria_template_str = values.get("scoring_criteria_template_str") scoring_criteria_template_str = values.get("scoring_criteria_template_str")
@ -275,7 +295,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
) )
class RLChain(Chain): class RLChain(Chain, Generic[TEvent]):
""" """
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning. The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
@ -292,7 +312,7 @@ class RLChain(Chain):
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory. - model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
- reset_model (bool): If set to True, the model starts training from scratch. Default is False. - reset_model (bool): If set to True, the model starts training from scratch. Default is False.
- vw_cmd (List[str], optional): Command line arguments for the VW model. - vw_cmd (List[str], optional): Command line arguments for the VW model.
- policy (VwPolicy): Policy used by the chain. - policy (Type[VwPolicy]): Policy used by the chain.
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs. - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
- metrics_step (int): Step for the metrics tracker. Default is -1. - metrics_step (int): Step for the metrics tracker. Default is -1.
@ -300,12 +320,24 @@ class RLChain(Chain):
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called. The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
""" # noqa: E501 """ # noqa: E501
class _NoOpPolicy(Policy):
"""Placeholder policy that does nothing"""
def predict(self, event: TEvent) -> Any:
return None
def learn(self, event: TEvent) -> None:
pass
def log(self, event: TEvent) -> None:
pass
llm_chain: Chain llm_chain: Chain
output_key: str = "result" #: :meta private: output_key: str = "result" #: :meta private:
prompt: BasePromptTemplate prompt: BasePromptTemplate
selection_scorer: Union[SelectionScorer, None] selection_scorer: Union[SelectionScorer, None]
policy: Optional[Policy] active_policy: Policy = _NoOpPolicy()
auto_embed: bool = True auto_embed: bool = True
selected_input_key = "rl_chain_selected" selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on" selected_based_on_input_key = "rl_chain_selected_based_on"
@ -314,14 +346,14 @@ class RLChain(Chain):
def __init__( def __init__(
self, self,
feature_embedder: Embedder, feature_embedder: Embedder,
model_save_dir="./", model_save_dir: str = "./",
reset_model=False, reset_model: bool = False,
vw_cmd=None, vw_cmd: Optional[List[str]] = None,
policy=VwPolicy, policy: Type[Policy] = VwPolicy,
vw_logs: Optional[Union[str, os.PathLike]] = None, vw_logs: Optional[Union[str, os.PathLike]] = None,
metrics_step=-1, metrics_step: int = -1,
*args, *args: Any,
**kwargs, **kwargs: Any,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if self.selection_scorer is None: if self.selection_scorer is None:
@ -330,14 +362,17 @@ class RLChain(Chain):
reinforcement learning will be done in the RL chain \ reinforcement learning will be done in the RL chain \
unless update_with_delayed_score is called." unless update_with_delayed_score is called."
) )
self.policy = policy(
model_repo=ModelRepository( if isinstance(self.active_policy, RLChain._NoOpPolicy):
model_save_dir, with_history=True, reset=reset_model self.active_policy = policy(
), model_repo=ModelRepository(
vw_cmd=vw_cmd or [], model_save_dir, with_history=True, reset=reset_model
feature_embedder=feature_embedder, ),
vw_logger=VwLogger(vw_logs), vw_cmd=vw_cmd or [],
) feature_embedder=feature_embedder,
vw_logger=VwLogger(vw_logs),
)
self.metrics = MetricsTracker(step=metrics_step) self.metrics = MetricsTracker(step=metrics_step)
class Config: class Config:
@ -374,29 +409,29 @@ class RLChain(Chain):
) )
@abstractmethod @abstractmethod
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event: def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
pass ...
@abstractmethod @abstractmethod
def _call_after_predict_before_llm( def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: Any self, inputs: Dict[str, Any], event: TEvent, prediction: Any
) -> Tuple[Dict[str, Any], Event]: ) -> Tuple[Dict[str, Any], TEvent]:
pass ...
@abstractmethod @abstractmethod
def _call_after_llm_before_scoring( def _call_after_llm_before_scoring(
self, llm_response: str, event: Event self, llm_response: str, event: TEvent
) -> Tuple[Dict[str, Any], Event]: ) -> Tuple[Dict[str, Any], TEvent]:
pass ...
@abstractmethod @abstractmethod
def _call_after_scoring_before_learning( def _call_after_scoring_before_learning(
self, event: Event, score: Optional[float] self, event: TEvent, score: Optional[float]
) -> Event: ) -> TEvent:
pass ...
def update_with_delayed_score( def update_with_delayed_score(
self, score: float, event: Event, force_score=False self, score: float, event: TEvent, force_score: bool = False
) -> None: ) -> None:
""" """
Updates the learned policy with the score provided. Updates the learned policy with the score provided.
@ -407,10 +442,11 @@ class RLChain(Chain):
"The selection scorer is set, and force_score was not set to True. \ "The selection scorer is set, and force_score was not set to True. \
Please set force_score=True to use this function." Please set force_score=True to use this function."
) )
self.metrics.on_feedback(score) if self.metrics:
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score) self._call_after_scoring_before_learning(event=event, score=score)
self.policy.learn(event=event) self.active_policy.learn(event=event)
self.policy.log(event=event) self.active_policy.log(event=event)
def set_auto_embed(self, auto_embed: bool) -> None: def set_auto_embed(self, auto_embed: bool) -> None:
""" """
@ -422,15 +458,16 @@ class RLChain(Chain):
self, self,
inputs: Dict[str, Any], inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
if self.auto_embed: if self.auto_embed:
inputs = prepare_inputs_for_autoembed(inputs=inputs) inputs = prepare_inputs_for_autoembed(inputs=inputs)
event = self._call_before_predict(inputs=inputs) event: TEvent = self._call_before_predict(inputs=inputs)
prediction = self.policy.predict(event=event) prediction = self.active_policy.predict(event=event)
self.metrics.on_decision() if self.metrics:
self.metrics.on_decision()
next_chain_inputs, event = self._call_after_predict_before_llm( next_chain_inputs, event = self._call_after_predict_before_llm(
inputs=inputs, event=event, prediction=prediction inputs=inputs, event=event, prediction=prediction
@ -462,10 +499,11 @@ class RLChain(Chain):
f"The selection scorer was not able to score, \ f"The selection scorer was not able to score, \
and the chain was not able to adjust to this response, error: {e}" and the chain was not able to adjust to this response, error: {e}"
) )
self.metrics.on_feedback(score) if self.metrics:
self.metrics.on_feedback(score)
event = self._call_after_scoring_before_learning(score=score, event=event) event = self._call_after_scoring_before_learning(score=score, event=event)
self.policy.learn(event=event) self.active_policy.learn(event=event)
self.policy.log(event=event) self.active_policy.log(event=event)
return {self.output_key: {"response": output, "selection_metadata": event}} return {self.output_key: {"response": output, "selection_metadata": event}}
@ -473,7 +511,7 @@ class RLChain(Chain):
""" """
This function should be called to save the state of the learned policy model. This function should be called to save the state of the learned policy model.
""" """
self.policy.save() self.active_policy.save()
@property @property
def _chain_type(self) -> str: def _chain_type(self) -> str:
@ -489,7 +527,7 @@ def is_stringtype_instance(item: Any) -> bool:
def embed_string_type( def embed_string_type(
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
) -> Dict[str, str]: ) -> Dict[str, Union[str, List[str]]]:
"""Helper function to embed a string or an _Embed object.""" """Helper function to embed a string or an _Embed object."""
join_char = "" join_char = ""
keep_str = "" keep_str = ""
@ -513,9 +551,9 @@ def embed_string_type(
return {namespace: keep_str + join_char.join(map(str, encoded))} return {namespace: keep_str + join_char.join(map(str, encoded))}
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]: def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
"""Helper function to embed a dictionary item.""" """Helper function to embed a dictionary item."""
inner_dict = {} inner_dict: Dict[str, Any] = {}
for ns, embed_item in item.items(): for ns, embed_item in item.items():
if isinstance(embed_item, list): if isinstance(embed_item, list):
inner_dict[ns] = [] inner_dict[ns] = []
@ -530,7 +568,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
def embed_list_type( def embed_list_type(
item: list, model: Any, namespace: Optional[str] = None item: list, model: Any, namespace: Optional[str] = None
) -> List[Dict[str, Union[str, List[str]]]]: ) -> List[Dict[str, Union[str, List[str]]]]:
ret_list = [] ret_list: List[Dict[str, Union[str, List[str]]]] = []
for embed_item in item: for embed_item in item:
if isinstance(embed_item, dict): if isinstance(embed_item, dict):
ret_list.append(embed_dict_type(embed_item, model)) ret_list.append(embed_dict_type(embed_item, model))
@ -540,9 +578,7 @@ def embed_list_type(
def embed( def embed(
to_embed: Union[ to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
],
model: Any, model: Any,
namespace: Optional[str] = None, namespace: Optional[str] = None,
) -> List[Dict[str, Union[str, List[str]]]]: ) -> List[Dict[str, Union[str, List[str]]]]:

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Dict, List, Optional, Union
if TYPE_CHECKING: if TYPE_CHECKING:
import pandas as pd import pandas as pd
@ -6,11 +6,11 @@ if TYPE_CHECKING:
class MetricsTracker: class MetricsTracker:
def __init__(self, step: int): def __init__(self, step: int):
self._history = [] self._history: List[Dict[str, Union[int, float]]] = []
self._step = step self._step: int = step
self._i = 0 self._i: int = 0
self._num = 0 self._num: float = 0
self._denom = 0 self._denom: float = 0
@property @property
def score(self) -> float: def score(self) -> float:

View File

@ -4,7 +4,7 @@ import logging
import os import os
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Sequence, Union from typing import TYPE_CHECKING, List, Union
if TYPE_CHECKING: if TYPE_CHECKING:
import vowpal_wabbit_next as vw import vowpal_wabbit_next as vw
@ -22,7 +22,7 @@ class ModelRepository:
self.folder = Path(folder) self.folder = Path(folder)
self.model_path = self.folder / "latest.vw" self.model_path = self.folder / "latest.vw"
self.with_history = with_history self.with_history = with_history
if reset and self.has_history: if reset and self.has_history():
logger.warning( logger.warning(
"There is non empty history which is recommended to be cleaned up" "There is non empty history which is recommended to be cleaned up"
) )
@ -44,7 +44,7 @@ class ModelRepository:
if self.with_history: # write history if self.with_history: # write history
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw") shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
def load(self, commandline: Sequence[str]) -> "vw.Workspace": def load(self, commandline: List[str]) -> "vw.Workspace":
import vowpal_wabbit_next as vw import vowpal_wabbit_next as vw
model_data = None model_data = None

View File

@ -1,12 +1,11 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Type, Union
import langchain.chains.rl_chain.base as base import langchain.chains.rl_chain.base as base
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
@ -17,7 +16,36 @@ logger = logging.getLogger(__name__)
SENTINEL = object() SENTINEL = object()
class PickBestFeatureEmbedder(base.Embedder): class PickBestSelected(base.Selected):
index: Optional[int]
probability: Optional[float]
score: Optional[float]
def __init__(
self,
index: Optional[int] = None,
probability: Optional[float] = None,
score: Optional[float] = None,
):
self.index = index
self.probability = probability
self.score = score
class PickBestEvent(base.Event[PickBestSelected]):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBestSelected] = None,
):
super().__init__(inputs=inputs, selected=selected)
self.to_select_from = to_select_from
self.based_on = based_on
class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
""" """
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
@ -25,7 +53,7 @@ class PickBestFeatureEmbedder(base.Embedder):
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer. model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
""" # noqa E501 """ # noqa E501
def __init__(self, model: Optional[Any] = None, *args, **kwargs): def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if model is None: if model is None:
@ -35,7 +63,7 @@ class PickBestFeatureEmbedder(base.Embedder):
self.model = model self.model = model
def format(self, event: PickBest.Event) -> str: def format(self, event: PickBestEvent) -> str:
""" """
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
""" """
@ -54,9 +82,14 @@ class PickBestFeatureEmbedder(base.Embedder):
to_select_from_var_name, to_select_from = next( to_select_from_var_name, to_select_from = next(
iter(event.to_select_from.items()), (None, None) iter(event.to_select_from.items()), (None, None)
) )
action_embs = ( action_embs = (
base.embed(to_select_from, self.model, to_select_from_var_name) (
if event.to_select_from base.embed(to_select_from, self.model, to_select_from_var_name)
if event.to_select_from
else None
)
if to_select_from
else None else None
) )
@ -88,7 +121,7 @@ class PickBestFeatureEmbedder(base.Embedder):
return example_string[:-1] return example_string[:-1]
class PickBest(base.RLChain): class PickBest(base.RLChain[PickBestEvent]):
""" """
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call. `PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
@ -116,38 +149,10 @@ class PickBest(base.RLChain):
feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized. feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
""" # noqa E501 """ # noqa E501
class Selected(base.Selected):
index: Optional[int]
probability: Optional[float]
score: Optional[float]
def __init__(
self,
index: Optional[int] = None,
probability: Optional[float] = None,
score: Optional[float] = None,
):
self.index = index
self.probability = probability
self.score = score
class Event(base.Event):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBest.Selected] = None,
):
super().__init__(inputs=inputs, selected=selected)
self.to_select_from = to_select_from
self.based_on = based_on
def __init__( def __init__(
self, self,
feature_embedder: Optional[PickBestFeatureEmbedder] = None, *args: Any,
*args, **kwargs: Any,
**kwargs,
): ):
vw_cmd = kwargs.get("vw_cmd", []) vw_cmd = kwargs.get("vw_cmd", [])
if not vw_cmd: if not vw_cmd:
@ -163,14 +168,16 @@ class PickBest(base.RLChain):
raise ValueError( raise ValueError(
"If vw_cmd is specified, it must include --cb_explore_adf" "If vw_cmd is specified, it must include --cb_explore_adf"
) )
kwargs["vw_cmd"] = vw_cmd kwargs["vw_cmd"] = vw_cmd
feature_embedder = kwargs.get("feature_embedder", None)
if not feature_embedder: if not feature_embedder:
feature_embedder = PickBestFeatureEmbedder() feature_embedder = PickBestFeatureEmbedder()
kwargs["feature_embedder"] = feature_embedder
super().__init__(feature_embedder=feature_embedder, *args, **kwargs) super().__init__(*args, **kwargs)
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event: def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs) context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions: if not actions:
raise ValueError( raise ValueError(
@ -193,12 +200,15 @@ class PickBest(base.RLChain):
to base the selected of ToSelectFrom on." to base the selected of ToSelectFrom on."
) )
event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context) event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
return event return event
def _call_after_predict_before_llm( def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]] self,
) -> Tuple[Dict[str, Any], PickBest.Event]: inputs: Dict[str, Any],
event: PickBestEvent,
prediction: List[Tuple[int, float]],
) -> Tuple[Dict[str, Any], PickBestEvent]:
import numpy as np import numpy as np
prob_sum = sum(prob for _, prob in prediction) prob_sum = sum(prob for _, prob in prediction)
@ -208,7 +218,7 @@ class PickBest(base.RLChain):
sampled_ap = prediction[sampled_index] sampled_ap = prediction[sampled_index]
sampled_action = sampled_ap[0] sampled_action = sampled_ap[0]
sampled_prob = sampled_ap[1] sampled_prob = sampled_ap[1]
selected = PickBest.Selected(index=sampled_action, probability=sampled_prob) selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
event.selected = selected event.selected = selected
# only one key, value pair in event.to_select_from # only one key, value pair in event.to_select_from
@ -218,23 +228,29 @@ class PickBest(base.RLChain):
return next_chain_inputs, event return next_chain_inputs, event
def _call_after_llm_before_scoring( def _call_after_llm_before_scoring(
self, llm_response: str, event: PickBest.Event self, llm_response: str, event: PickBestEvent
) -> Tuple[Dict[str, Any], PickBest.Event]: ) -> Tuple[Dict[str, Any], PickBestEvent]:
next_chain_inputs = event.inputs.copy() next_chain_inputs = event.inputs.copy()
# only one key, value pair in event.to_select_from # only one key, value pair in event.to_select_from
value = next(iter(event.to_select_from.values())) value = next(iter(event.to_select_from.values()))
v = (
value[event.selected.index]
if event.selected
else event.to_select_from.values()
)
next_chain_inputs.update( next_chain_inputs.update(
{ {
self.selected_based_on_input_key: str(event.based_on), self.selected_based_on_input_key: str(event.based_on),
self.selected_input_key: value[event.selected.index], self.selected_input_key: v,
} }
) )
return next_chain_inputs, event return next_chain_inputs, event
def _call_after_scoring_before_learning( def _call_after_scoring_before_learning(
self, event: PickBest.Event, score: Optional[float] self, event: PickBestEvent, score: Optional[float]
) -> Event: ) -> PickBestEvent:
event.selected.score = score if event.selected:
event.selected.score = score
return event return event
def _call( def _call(
@ -249,34 +265,20 @@ class PickBest(base.RLChain):
return "rl_chain_pick_best" return "rl_chain_pick_best"
@classmethod @classmethod
def from_chain( def from_llm(
cls, cls: Type[PickBest],
llm_chain: Chain, llm: BaseLanguageModel,
prompt: BasePromptTemplate, prompt: BasePromptTemplate,
selection_scorer=SENTINEL, selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
**kwargs: Any, **kwargs: Any,
): ) -> PickBest:
llm_chain = LLMChain(llm=llm, prompt=prompt)
if selection_scorer is SENTINEL: if selection_scorer is SENTINEL:
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm) selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
return PickBest( return PickBest(
llm_chain=llm_chain, llm_chain=llm_chain,
prompt=prompt, prompt=prompt,
selection_scorer=selection_scorer, selection_scorer=selection_scorer,
**kwargs, **kwargs,
) )
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
selection_scorer=SENTINEL,
**kwargs: Any,
):
llm_chain = LLMChain(llm=llm, prompt=prompt)
return PickBest.from_chain(
llm_chain=llm_chain,
prompt=prompt,
selection_scorer=selection_scorer,
**kwargs,
)

View File

@ -9,10 +9,10 @@ class VwLogger:
if self.path: if self.path:
self.path.parent.mkdir(parents=True, exist_ok=True) self.path.parent.mkdir(parents=True, exist_ok=True)
def log(self, vw_ex: str): def log(self, vw_ex: str) -> None:
if self.path: if self.path:
with open(self.path, "a") as f: with open(self.path, "a") as f:
f.write(f"{vw_ex}\n\n") f.write(f"{vw_ex}\n\n")
def logging_enabled(self): def logging_enabled(self) -> bool:
return bool(self.path) return bool(self.path)

View File

@ -1,3 +1,5 @@
from typing import Any, Dict
import pytest import pytest
from test_utils import MockEncoder from test_utils import MockEncoder
@ -10,7 +12,7 @@ encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def setup(): def setup() -> tuple:
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm""" _PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE) PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
@ -19,7 +21,7 @@ def setup():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_multiple_ToSelectFrom_throws(): def test_multiple_ToSelectFrom_throws() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"] actions = ["0", "1", "2"]
@ -32,7 +34,7 @@ def test_multiple_ToSelectFrom_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_missing_basedOn_from_throws(): def test_missing_basedOn_from_throws() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"] actions = ["0", "1", "2"]
@ -41,7 +43,7 @@ def test_missing_basedOn_from_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_ToSelectFrom_not_a_list_throws(): def test_ToSelectFrom_not_a_list_throws() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = {"actions": ["0", "1", "2"]} actions = {"actions": ["0", "1", "2"]}
@ -53,7 +55,7 @@ def test_ToSelectFrom_not_a_list_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_with_auto_validator_throws(): def test_update_with_delayed_score_with_auto_validator_throws() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that # this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"]) auto_val_llm = FakeListChatModel(responses=["3"])
@ -75,7 +77,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_force(): def test_update_with_delayed_score_force() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that # this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"]) auto_val_llm = FakeListChatModel(responses=["3"])
@ -99,7 +101,7 @@ def test_update_with_delayed_score_force():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score(): def test_update_with_delayed_score() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm( chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None llm=llm, prompt=PROMPT, selection_scorer=None
@ -117,11 +119,11 @@ def test_update_with_delayed_score():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_user_defined_scorer(): def test_user_defined_scorer() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
class CustomSelectionScorer(rl_chain.SelectionScorer): class CustomSelectionScorer(rl_chain.SelectionScorer):
def score_response(self, inputs, llm_response: str) -> float: def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
score = 200 score = 200
return score return score
@ -139,7 +141,7 @@ def test_user_defined_scorer():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings(): def test_default_embeddings() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm( chain = pick_best_chain.PickBest.from_llm(
@ -173,7 +175,7 @@ def test_default_embeddings():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_off(): def test_default_embeddings_off() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm( chain = pick_best_chain.PickBest.from_llm(
@ -199,7 +201,7 @@ def test_default_embeddings_off():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_mixed_w_explicit_user_embeddings(): def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm( chain = pick_best_chain.PickBest.from_llm(
@ -234,7 +236,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_no_scorer_specified(): def test_default_no_scorer_specified() -> None:
_, PROMPT = setup() _, PROMPT = setup()
chain_llm = FakeListChatModel(responses=[100]) chain_llm = FakeListChatModel(responses=[100])
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT) chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
@ -249,7 +251,7 @@ def test_default_no_scorer_specified():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_explicitly_no_scorer(): def test_explicitly_no_scorer() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm( chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None llm=llm, prompt=PROMPT, selection_scorer=None
@ -265,7 +267,7 @@ def test_explicitly_no_scorer():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_auto_scorer_with_user_defined_llm(): def test_auto_scorer_with_user_defined_llm() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300]) scorer_llm = FakeListChatModel(responses=[300])
chain = pick_best_chain.PickBest.from_llm( chain = pick_best_chain.PickBest.from_llm(
@ -284,7 +286,7 @@ def test_auto_scorer_with_user_defined_llm():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_calling_chain_w_reserved_inputs_throws(): def test_calling_chain_w_reserved_inputs_throws() -> None:
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
with pytest.raises(ValueError): with pytest.raises(ValueError):

View File

@ -8,10 +8,10 @@ encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_context_throws(): def test_pickbest_textembedder_missing_context_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_action = {"action": ["0", "1", "2"]} named_action = {"action": ["0", "1", "2"]}
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_action, based_on={} inputs={}, to_select_from=named_action, based_on={}
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -19,9 +19,9 @@ def test_pickbest_textembedder_missing_context_throws():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_actions_throws(): def test_pickbest_textembedder_missing_actions_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from={}, based_on={"context": "context"} inputs={}, to_select_from={}, based_on={"context": "context"}
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -29,11 +29,11 @@ def test_pickbest_textembedder_missing_actions_throws():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_no_label_no_emb(): def test_pickbest_textembedder_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]} named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on={"context": "context"} inputs={}, to_select_from=named_actions, based_on={"context": "context"}
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -41,12 +41,12 @@ def test_pickbest_textembedder_no_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_label_no_score_no_emb(): def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]} named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, inputs={},
to_select_from=named_actions, to_select_from=named_actions,
based_on={"context": "context"}, based_on={"context": "context"},
@ -57,14 +57,14 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_no_emb(): def test_pickbest_textembedder_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]} named_actions = {"action1": ["0", "1", "2"]}
expected = ( expected = (
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """ """shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
) )
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, inputs={},
to_select_from=named_actions, to_select_from=named_actions,
based_on={"context": "context"}, based_on={"context": "context"},
@ -75,7 +75,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_emb(): def test_pickbest_textembedder_w_full_label_w_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0" str1 = "0"
str2 = "1" str2 = "1"
@ -90,8 +90,8 @@ def test_pickbest_textembedder_w_full_label_w_emb():
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])} named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
context = {"context": rl_chain.Embed(ctx_str_1)} context = {"context": rl_chain.Embed(ctx_str_1)}
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501 expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected inputs={}, to_select_from=named_actions, based_on=context, selected=selected
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -99,7 +99,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_embed_and_keep(): def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0" str1 = "0"
str2 = "1" str2 = "1"
@ -114,8 +114,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])} named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)} context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected inputs={}, to_select_from=named_actions, based_on=context, selected=selected
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -123,12 +123,12 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_no_label_no_emb(): def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"} context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context inputs={}, to_select_from=named_actions, based_on=context
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -136,13 +136,13 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_label_no_emb(): def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"} context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected inputs={}, to_select_from=named_actions, based_on=context, selected=selected
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -150,13 +150,13 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb(): def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"} context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected inputs={}, to_select_from=named_actions, based_on=context, selected=selected
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -164,7 +164,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb(): def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0" str1 = "0"
@ -186,8 +186,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
} }
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501 expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected inputs={}, to_select_from=named_actions, based_on=context, selected=selected
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -195,7 +195,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep(): def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
None
):
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0" str1 = "0"
@ -219,8 +221,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
} }
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected inputs={}, to_select_from=named_actions, based_on=context, selected=selected
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -228,7 +230,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb(): def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0" str1 = "0"
@ -253,8 +255,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)} context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501 expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected inputs={}, to_select_from=named_actions, based_on=context, selected=selected
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -262,7 +264,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep(): def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0" str1 = "0"
@ -290,8 +292,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
} }
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected inputs={}, to_select_from=named_actions, based_on=context, selected=selected
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -299,7 +301,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_raw_features_underscored(): def test_raw_features_underscored() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "this is a long string" str1 = "this is a long string"
str1_underscored = str1.replace(" ", "_") str1_underscored = str1.replace(" ", "_")
@ -315,7 +317,7 @@ def test_raw_features_underscored():
expected_no_embed = ( expected_no_embed = (
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """ f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
) )
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context inputs={}, to_select_from=named_actions, based_on=context
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -325,7 +327,7 @@ def test_raw_features_underscored():
named_actions = {"action": rl_chain.Embed([str1])} named_actions = {"action": rl_chain.Embed([str1])}
context = {"context": rl_chain.Embed(ctx_str)} context = {"context": rl_chain.Embed(ctx_str)}
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """ expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context inputs={}, to_select_from=named_actions, based_on=context
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)
@ -335,7 +337,7 @@ def test_raw_features_underscored():
named_actions = {"action": rl_chain.EmbedAndKeep([str1])} named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str)} context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501 expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context inputs={}, to_select_from=named_actions, based_on=context
) )
vw_ex_str = feature_embedder.format(event) vw_ex_str = feature_embedder.format(event)

View File

@ -1,3 +1,5 @@
from typing import List, Union
import pytest import pytest
from test_utils import MockEncoder from test_utils import MockEncoder
@ -7,13 +9,13 @@ encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_no_emb(): def test_simple_context_str_no_emb() -> None:
expected = [{"a_namespace": "test"}] expected = [{"a_namespace": "test"}]
assert base.embed("test", MockEncoder(), "a_namespace") == expected assert base.embed("test", MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_emb(): def test_simple_context_str_w_emb() -> None:
str1 = "test" str1 = "test"
encoded_str1 = " ".join(char for char in str1) encoded_str1 = " ".join(char for char in str1)
expected = [{"a_namespace": encoded_text + encoded_str1}] expected = [{"a_namespace": encoded_text + encoded_str1}]
@ -28,7 +30,7 @@ def test_simple_context_str_w_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_nested_emb(): def test_simple_context_str_w_nested_emb() -> None:
# nested embeddings, innermost wins # nested embeddings, innermost wins
str1 = "test" str1 = "test"
encoded_str1 = " ".join(char for char in str1) encoded_str1 = " ".join(char for char in str1)
@ -46,13 +48,13 @@ def test_simple_context_str_w_nested_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_no_emb(): def test_context_w_namespace_no_emb() -> None:
expected = [{"test_namespace": "test"}] expected = [{"test_namespace": "test"}]
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb(): def test_context_w_namespace_w_emb() -> None:
str1 = "test" str1 = "test"
encoded_str1 = " ".join(char for char in str1) encoded_str1 = " ".join(char for char in str1)
expected = [{"test_namespace": encoded_text + encoded_str1}] expected = [{"test_namespace": encoded_text + encoded_str1}]
@ -67,7 +69,7 @@ def test_context_w_namespace_w_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb2(): def test_context_w_namespace_w_emb2() -> None:
str1 = "test" str1 = "test"
encoded_str1 = " ".join(char for char in str1) encoded_str1 = " ".join(char for char in str1)
expected = [{"test_namespace": encoded_text + encoded_str1}] expected = [{"test_namespace": encoded_text + encoded_str1}]
@ -82,7 +84,7 @@ def test_context_w_namespace_w_emb2():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_some_emb(): def test_context_w_namespace_w_some_emb() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
encoded_str2 = " ".join(char for char in str2) encoded_str2 = " ".join(char for char in str2)
@ -111,16 +113,17 @@ def test_context_w_namespace_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_no_emb(): def test_simple_action_strlist_no_emb() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
str3 = "test3" str3 = "test3"
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}] expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected to_embed: List[Union[str, base._Embed]] = [str1, str2, str3]
assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_emb(): def test_simple_action_strlist_w_emb() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
str3 = "test3" str3 = "test3"
@ -148,7 +151,7 @@ def test_simple_action_strlist_w_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_some_emb(): def test_simple_action_strlist_w_some_emb() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
str3 = "test3" str3 = "test3"
@ -181,7 +184,7 @@ def test_simple_action_strlist_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_no_emb(): def test_action_w_namespace_no_emb() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
str3 = "test3" str3 = "test3"
@ -204,7 +207,7 @@ def test_action_w_namespace_no_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb(): def test_action_w_namespace_w_emb() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
str3 = "test3" str3 = "test3"
@ -246,7 +249,7 @@ def test_action_w_namespace_w_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb2(): def test_action_w_namespace_w_emb2() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
str3 = "test3" str3 = "test3"
@ -292,7 +295,7 @@ def test_action_w_namespace_w_emb2():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_some_emb(): def test_action_w_namespace_w_some_emb() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
str3 = "test3" str3 = "test3"
@ -333,7 +336,7 @@ def test_action_w_namespace_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict(): def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
str3 = "test3" str3 = "test3"
@ -384,7 +387,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_no_emb(): def test_one_namespace_w_list_of_features_no_emb() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
expected = [{"test_namespace": [str1, str2]}] expected = [{"test_namespace": [str1, str2]}]
@ -392,7 +395,7 @@ def test_one_namespace_w_list_of_features_no_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_w_some_emb(): def test_one_namespace_w_list_of_features_w_some_emb() -> None:
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
encoded_str2 = " ".join(char for char in str2) encoded_str2 = " ".join(char for char in str2)
@ -404,24 +407,24 @@ def test_one_namespace_w_list_of_features_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_nested_list_features_throws(): def test_nested_list_features_throws() -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder()) base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_dict_in_list_throws(): def test_dict_in_list_throws() -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder()) base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_nested_dict_throws(): def test_nested_dict_throws() -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder()) base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next") @pytest.mark.requires("vowpal_wabbit_next")
def test_list_of_tuples_throws(): def test_list_of_tuples_throws() -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder()) base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())

View File

@ -1,3 +1,3 @@
class MockEncoder: class MockEncoder:
def encode(self, to_encode): def encode(self, to_encode: str) -> str:
return "[encoded]" + to_encode return "[encoded]" + to_encode