Merge pull request #3 from VowpalWabbit/fix_linting

Fix mypy errors
olgavrou 2023-08-29 05:58:03 -04:00 committed by GitHub
commit f7fb083aba
10 changed files with 296 additions and 251 deletions

View File

@ -13,7 +13,7 @@ from langchain.chains.rl_chain.base import (
from langchain.chains.rl_chain.pick_best_chain import PickBest
def configure_logger():
def configure_logger() -> None:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()

View File

@ -3,7 +3,18 @@ from __future__ import annotations
import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
@ -26,47 +37,47 @@ logger = logging.getLogger(__name__)
class _BasedOn:
def __init__(self, value):
def __init__(self, value: Any):
self.value = value
def __str__(self):
def __str__(self) -> str:
return str(self.value)
__repr__ = __str__
def BasedOn(anything):
def BasedOn(anything: Any) -> _BasedOn:
return _BasedOn(anything)
class _ToSelectFrom:
def __init__(self, value):
def __init__(self, value: Any):
self.value = value
def __str__(self):
def __str__(self) -> str:
return str(self.value)
__repr__ = __str__
def ToSelectFrom(anything):
def ToSelectFrom(anything: Any) -> _ToSelectFrom:
if not isinstance(anything, list):
raise ValueError("ToSelectFrom must be a list to select from")
return _ToSelectFrom(anything)
class _Embed:
def __init__(self, value, keep=False):
def __init__(self, value: Any, keep: bool = False):
self.value = value
self.keep = keep
def __str__(self):
def __str__(self) -> str:
return str(self.value)
__repr__ = __str__
def Embed(anything, keep=False):
def Embed(anything: Any, keep: bool = False) -> Any:
if isinstance(anything, _ToSelectFrom):
return ToSelectFrom(Embed(anything.value, keep=keep))
elif isinstance(anything, _BasedOn):
@ -80,7 +91,7 @@ def Embed(anything, keep=False):
return _Embed(anything, keep=keep)
def EmbedAndKeep(anything):
def EmbedAndKeep(anything: Any) -> Any:
return Embed(anything, keep=True)
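For orientation, here is a rough sketch (not part of this commit) of how these wrapper helpers are applied to chain inputs, following the usage in the unit tests further down in this diff; the dictionary keys and values are illustrative.

    # Minimal sketch: wrapping chain inputs with the helpers defined above.
    from langchain.chains.rl_chain.base import BasedOn, Embed, ToSelectFrom

    inputs = {
        "meal": ToSelectFrom(["pizza", "salad", "soup"]),  # candidates the policy picks from
        "user": BasedOn("Tom"),                            # context the pick is based on
        "note": Embed("likes spicy food"),                 # force this value to be embedded
    }
    # ToSelectFrom rejects non-list values:
    # ToSelectFrom("pizza")  -> ValueError("ToSelectFrom must be a list to select from")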
@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp
return [parser.parse_line(line) for line in input_str.split("\n")]
def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
to_select_from = {
k: inputs[k].value
for k in inputs.keys()
@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
return based_on, to_select_from
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
Go over all the inputs and, if a value is wrapped in _ToSelectFrom or _BasedOn and its inner value is not already an _Embed,
then wrap it in EmbedAndKeep while retaining its _ToSelectFrom or _BasedOn status
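A small sketch of the transformation this docstring describes (not part of the commit; keys and values are illustrative):

    from langchain.chains.rl_chain.base import BasedOn, ToSelectFrom, prepare_inputs_for_autoembed

    inputs = {"user": BasedOn("Tom"), "meal": ToSelectFrom(["pizza", "salad"])}
    wrapped = prepare_inputs_for_autoembed(inputs)
    # Per the docstring above: values wrapped in _BasedOn / _ToSelectFrom whose inner
    # values are not already _Embed get wrapped with EmbedAndKeep, while keeping their
    # _BasedOn / _ToSelectFrom status, so the feature embedder later sees both the raw
    # text and its embedding.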
@ -134,29 +145,38 @@ class Selected(ABC):
pass
class Event(ABC):
inputs: Dict[str, Any]
selected: Optional[Selected]
TSelected = TypeVar("TSelected", bound=Selected)
def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None):
class Event(Generic[TSelected], ABC):
inputs: Dict[str, Any]
selected: Optional[TSelected]
def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
self.inputs = inputs
self.selected = selected
TEvent = TypeVar("TEvent", bound=Event)
class Policy(ABC):
@abstractmethod
def predict(self, event: Event) -> Any:
def __init__(self, **kwargs: Any):
pass
@abstractmethod
def learn(self, event: Event):
pass
def predict(self, event: TEvent) -> Any:
...
@abstractmethod
def log(self, event: Event):
pass
def learn(self, event: TEvent) -> None:
...
def save(self):
@abstractmethod
def log(self, event: TEvent) -> None:
...
def save(self) -> None:
pass
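Since Policy is now a plain abstract interface (predict/learn/log plus an optional save), a minimal custom policy only needs to implement those three methods. A toy sketch, not part of the commit:

    from typing import Any

    import langchain.chains.rl_chain.base as base

    class ConstantPolicy(base.Policy):
        """Toy policy: always 'predicts' the first action with probability 1.0."""

        def predict(self, event: base.Event) -> Any:
            # Same (action index, probability) shape that PickBest expects from VW.
            return [(0, 1.0)]

        def learn(self, event: base.Event) -> None:
            pass  # nothing to learn

        def log(self, event: base.Event) -> None:
            pass  # nothing to log

Because Policy.__init__ accepts arbitrary keyword arguments, a class like this could presumably be passed as the chain's policy argument in place of VwPolicy.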
@ -164,11 +184,11 @@ class VwPolicy(Policy):
def __init__(
self,
model_repo: ModelRepository,
vw_cmd: Sequence[str],
vw_cmd: List[str],
feature_embedder: Embedder,
vw_logger: VwLogger,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.model_repo = model_repo
@ -176,7 +196,7 @@ class VwPolicy(Policy):
self.feature_embedder = feature_embedder
self.vw_logger = vw_logger
def predict(self, event: Event) -> Any:
def predict(self, event: TEvent) -> Any:
import vowpal_wabbit_next as vw
text_parser = vw.TextFormatParser(self.workspace)
@ -184,7 +204,7 @@ class VwPolicy(Policy):
parse_lines(text_parser, self.feature_embedder.format(event))
)
def learn(self, event: Event):
def learn(self, event: TEvent) -> None:
import vowpal_wabbit_next as vw
vw_ex = self.feature_embedder.format(event)
@ -192,19 +212,19 @@ class VwPolicy(Policy):
multi_ex = parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)
def log(self, event: Event):
def log(self, event: TEvent) -> None:
if self.vw_logger.logging_enabled():
vw_ex = self.feature_embedder.format(event)
self.vw_logger.log(vw_ex)
def save(self):
self.model_repo.save()
def save(self) -> None:
self.model_repo.save(self.workspace)
class Embedder(ABC):
class Embedder(Generic[TEvent], ABC):
@abstractmethod
def format(self, event: Event) -> str:
pass
def format(self, event: TEvent) -> str:
...
class SelectionScorer(ABC, BaseModel):
@ -212,11 +232,11 @@ class SelectionScorer(ABC, BaseModel):
@abstractmethod
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
pass
...
class AutoSelectionScorer(SelectionScorer, BaseModel):
llm_chain: Union[LLMChain, None] = None
llm_chain: LLMChain
prompt: Union[BasePromptTemplate, None] = None
scoring_criteria_template_str: Optional[str] = None
@ -243,7 +263,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
return chat_prompt
@root_validator(pre=True)
def set_prompt_and_llm_chain(cls, values):
def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
llm = values.get("llm")
prompt = values.get("prompt")
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
@ -275,7 +295,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
)
class RLChain(Chain):
class RLChain(Chain, Generic[TEvent]):
"""
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
@ -292,7 +312,7 @@ class RLChain(Chain):
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
- vw_cmd (List[str], optional): Command line arguments for the VW model.
- policy (VwPolicy): Policy used by the chain.
- policy (Type[VwPolicy]): Policy used by the chain.
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
- metrics_step (int): Step for the metrics tracker. Default is -1.
@ -300,12 +320,24 @@ class RLChain(Chain):
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
""" # noqa: E501
class _NoOpPolicy(Policy):
"""Placeholder policy that does nothing"""
def predict(self, event: TEvent) -> Any:
return None
def learn(self, event: TEvent) -> None:
pass
def log(self, event: TEvent) -> None:
pass
llm_chain: Chain
output_key: str = "result" #: :meta private:
prompt: BasePromptTemplate
selection_scorer: Union[SelectionScorer, None]
policy: Optional[Policy]
active_policy: Policy = _NoOpPolicy()
auto_embed: bool = True
selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on"
@ -314,14 +346,14 @@ class RLChain(Chain):
def __init__(
self,
feature_embedder: Embedder,
model_save_dir="./",
reset_model=False,
vw_cmd=None,
policy=VwPolicy,
model_save_dir: str = "./",
reset_model: bool = False,
vw_cmd: Optional[List[str]] = None,
policy: Type[Policy] = VwPolicy,
vw_logs: Optional[Union[str, os.PathLike]] = None,
metrics_step=-1,
*args,
**kwargs,
metrics_step: int = -1,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
if self.selection_scorer is None:
@ -330,14 +362,17 @@ class RLChain(Chain):
reinforcement learning will be done in the RL chain \
unless update_with_delayed_score is called."
)
self.policy = policy(
model_repo=ModelRepository(
model_save_dir, with_history=True, reset=reset_model
),
vw_cmd=vw_cmd or [],
feature_embedder=feature_embedder,
vw_logger=VwLogger(vw_logs),
)
if isinstance(self.active_policy, RLChain._NoOpPolicy):
self.active_policy = policy(
model_repo=ModelRepository(
model_save_dir, with_history=True, reset=reset_model
),
vw_cmd=vw_cmd or [],
feature_embedder=feature_embedder,
vw_logger=VwLogger(vw_logs),
)
self.metrics = MetricsTracker(step=metrics_step)
class Config:
@ -374,29 +409,29 @@ class RLChain(Chain):
)
@abstractmethod
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
pass
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
...
@abstractmethod
def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: Any
) -> Tuple[Dict[str, Any], Event]:
pass
self, inputs: Dict[str, Any], event: TEvent, prediction: Any
) -> Tuple[Dict[str, Any], TEvent]:
...
@abstractmethod
def _call_after_llm_before_scoring(
self, llm_response: str, event: Event
) -> Tuple[Dict[str, Any], Event]:
pass
self, llm_response: str, event: TEvent
) -> Tuple[Dict[str, Any], TEvent]:
...
@abstractmethod
def _call_after_scoring_before_learning(
self, event: Event, score: Optional[float]
) -> Event:
pass
self, event: TEvent, score: Optional[float]
) -> TEvent:
...
def update_with_delayed_score(
self, score: float, event: Event, force_score=False
self, score: float, event: TEvent, force_score: bool = False
) -> None:
"""
Updates the learned policy with the score provided.
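A rough sketch of the delayed-scoring flow this method enables, modeled on test_update_with_delayed_score further down in this diff; the fake-LLM import path and the input keys are illustrative assumptions.

    from langchain.chains.rl_chain.base import BasedOn, ToSelectFrom
    from langchain.chains.rl_chain.pick_best_chain import PickBest
    from langchain.chat_models.fake import FakeListChatModel  # import path is an assumption
    from langchain.prompts import PromptTemplate

    llm = FakeListChatModel(responses=["hello"])
    prompt = PromptTemplate(input_variables=[], template="dummy prompt")

    # No selection scorer: scoring is deferred until external feedback arrives.
    chain = PickBest.from_llm(llm=llm, prompt=prompt, selection_scorer=None)

    response = chain.run(meal=ToSelectFrom(["pizza", "salad"]), user=BasedOn("Tom"))
    event = response["selection_metadata"]  # the event returned under the chain's output key

    # ... later, once feedback for this selection is available:
    chain.update_with_delayed_score(score=1.0, event=event)
    # If a selection scorer *is* configured, force_score=True must be passed to override it.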
@ -407,10 +442,11 @@ class RLChain(Chain):
"The selection scorer is set, and force_score was not set to True. \
Please set force_score=True to use this function."
)
self.metrics.on_feedback(score)
if self.metrics:
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score)
self.policy.learn(event=event)
self.policy.log(event=event)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)
def set_auto_embed(self, auto_embed: bool) -> None:
"""
@ -422,15 +458,16 @@ class RLChain(Chain):
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
if self.auto_embed:
inputs = prepare_inputs_for_autoembed(inputs=inputs)
event = self._call_before_predict(inputs=inputs)
prediction = self.policy.predict(event=event)
self.metrics.on_decision()
event: TEvent = self._call_before_predict(inputs=inputs)
prediction = self.active_policy.predict(event=event)
if self.metrics:
self.metrics.on_decision()
next_chain_inputs, event = self._call_after_predict_before_llm(
inputs=inputs, event=event, prediction=prediction
@ -462,10 +499,11 @@ class RLChain(Chain):
f"The selection scorer was not able to score, \
and the chain was not able to adjust to this response, error: {e}"
)
self.metrics.on_feedback(score)
if self.metrics:
self.metrics.on_feedback(score)
event = self._call_after_scoring_before_learning(score=score, event=event)
self.policy.learn(event=event)
self.policy.log(event=event)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)
return {self.output_key: {"response": output, "selection_metadata": event}}
@ -473,7 +511,7 @@ class RLChain(Chain):
"""
This function should be called to save the state of the learned policy model.
"""
self.policy.save()
self.active_policy.save()
@property
def _chain_type(self) -> str:
@ -489,7 +527,7 @@ def is_stringtype_instance(item: Any) -> bool:
def embed_string_type(
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
) -> Dict[str, str]:
) -> Dict[str, Union[str, List[str]]]:
"""Helper function to embed a string or an _Embed object."""
join_char = ""
keep_str = ""
@ -513,9 +551,9 @@ def embed_string_type(
return {namespace: keep_str + join_char.join(map(str, encoded))}
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
"""Helper function to embed a dictionary item."""
inner_dict = {}
inner_dict: Dict[str, Any] = {}
for ns, embed_item in item.items():
if isinstance(embed_item, list):
inner_dict[ns] = []
@ -530,7 +568,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
def embed_list_type(
item: list, model: Any, namespace: Optional[str] = None
) -> List[Dict[str, Union[str, List[str]]]]:
ret_list = []
ret_list: List[Dict[str, Union[str, List[str]]]] = []
for embed_item in item:
if isinstance(embed_item, dict):
ret_list.append(embed_dict_type(embed_item, model))
@ -540,9 +578,7 @@ def embed_list_type(
def embed(
to_embed: Union[
Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
],
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
model: Any,
namespace: Optional[str] = None,
) -> List[Dict[str, Union[str, List[str]]]]:
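For reference, here is what embed() produces for the simple cases exercised by the unit tests further down in this diff; MockEncoder is the test helper shown at the very end of the diff.

    import langchain.chains.rl_chain.base as base

    class MockEncoder:
        def encode(self, to_encode: str) -> str:
            return "[encoded]" + to_encode

    model = MockEncoder()

    # Plain strings are passed through into the given namespace:
    assert base.embed("test", model, "a_namespace") == [{"a_namespace": "test"}]

    # Values wrapped in Embed are run through the model first
    # (the mock "embedding" is the space-joined characters of the encoder output):
    assert base.embed(base.Embed("test"), model, "a_namespace") == [
        {"a_namespace": "[ e n c o d e d ] t e s t"}
    ]

    # Dicts map namespaces to values; lists produce one dict per item:
    assert base.embed({"test_namespace": "test"}, model) == [{"test_namespace": "test"}]
    assert base.embed(["test1", "test2"], model, "a_namespace") == [
        {"a_namespace": "test1"},
        {"a_namespace": "test2"},
    ]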

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Union
if TYPE_CHECKING:
import pandas as pd
@ -6,11 +6,11 @@ if TYPE_CHECKING:
class MetricsTracker:
def __init__(self, step: int):
self._history = []
self._step = step
self._i = 0
self._num = 0
self._denom = 0
self._history: List[Dict[str, Union[int, float]]] = []
self._step: int = step
self._i: int = 0
self._num: float = 0
self._denom: float = 0
@property
def score(self) -> float:

View File

@ -4,7 +4,7 @@ import logging
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Sequence, Union
from typing import TYPE_CHECKING, List, Union
if TYPE_CHECKING:
import vowpal_wabbit_next as vw
@ -22,7 +22,7 @@ class ModelRepository:
self.folder = Path(folder)
self.model_path = self.folder / "latest.vw"
self.with_history = with_history
if reset and self.has_history:
if reset and self.has_history():
logger.warning(
"There is non empty history which is recommended to be cleaned up"
)
@ -44,7 +44,7 @@ class ModelRepository:
if self.with_history: # write history
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
def load(self, commandline: Sequence[str]) -> "vw.Workspace":
def load(self, commandline: List[str]) -> "vw.Workspace":
import vowpal_wabbit_next as vw
model_data = None

View File

@ -1,12 +1,11 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import langchain.chains.rl_chain.base as base
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate
@ -17,7 +16,36 @@ logger = logging.getLogger(__name__)
SENTINEL = object()
class PickBestFeatureEmbedder(base.Embedder):
class PickBestSelected(base.Selected):
index: Optional[int]
probability: Optional[float]
score: Optional[float]
def __init__(
self,
index: Optional[int] = None,
probability: Optional[float] = None,
score: Optional[float] = None,
):
self.index = index
self.probability = probability
self.score = score
class PickBestEvent(base.Event[PickBestSelected]):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBestSelected] = None,
):
super().__init__(inputs=inputs, selected=selected)
self.to_select_from = to_select_from
self.based_on = based_on
class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
"""
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
@ -25,7 +53,7 @@ class PickBestFeatureEmbedder(base.Embedder):
model (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
""" # noqa E501
def __init__(self, model: Optional[Any] = None, *args, **kwargs):
def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
if model is None:
@ -35,7 +63,7 @@ class PickBestFeatureEmbedder(base.Embedder):
self.model = model
def format(self, event: PickBest.Event) -> str:
def format(self, event: PickBestEvent) -> str:
"""
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
"""
@ -54,9 +82,14 @@ class PickBestFeatureEmbedder(base.Embedder):
to_select_from_var_name, to_select_from = next(
iter(event.to_select_from.items()), (None, None)
)
action_embs = (
base.embed(to_select_from, self.model, to_select_from_var_name)
if event.to_select_from
(
base.embed(to_select_from, self.model, to_select_from_var_name)
if event.to_select_from
else None
)
if to_select_from
else None
)
@ -88,7 +121,7 @@ class PickBestFeatureEmbedder(base.Embedder):
return example_string[:-1]
class PickBest(base.RLChain):
class PickBest(base.RLChain[PickBestEvent]):
"""
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
@ -116,38 +149,10 @@ class PickBest(base.RLChain):
feature_embedder (PickBestFeatureEmbedder, optional): An advanced attribute responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
""" # noqa E501
class Selected(base.Selected):
index: Optional[int]
probability: Optional[float]
score: Optional[float]
def __init__(
self,
index: Optional[int] = None,
probability: Optional[float] = None,
score: Optional[float] = None,
):
self.index = index
self.probability = probability
self.score = score
class Event(base.Event):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBest.Selected] = None,
):
super().__init__(inputs=inputs, selected=selected)
self.to_select_from = to_select_from
self.based_on = based_on
def __init__(
self,
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
):
vw_cmd = kwargs.get("vw_cmd", [])
if not vw_cmd:
@ -163,14 +168,16 @@ class PickBest(base.RLChain):
raise ValueError(
"If vw_cmd is specified, it must include --cb_explore_adf"
)
kwargs["vw_cmd"] = vw_cmd
feature_embedder = kwargs.get("feature_embedder", None)
if not feature_embedder:
feature_embedder = PickBestFeatureEmbedder()
kwargs["feature_embedder"] = feature_embedder
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
super().__init__(*args, **kwargs)
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event:
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
raise ValueError(
@ -193,12 +200,15 @@ class PickBest(base.RLChain):
to base the selection of ToSelectFrom on."
)
event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context)
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
return event
def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
) -> Tuple[Dict[str, Any], PickBest.Event]:
self,
inputs: Dict[str, Any],
event: PickBestEvent,
prediction: List[Tuple[int, float]],
) -> Tuple[Dict[str, Any], PickBestEvent]:
import numpy as np
prob_sum = sum(prob for _, prob in prediction)
@ -208,7 +218,7 @@ class PickBest(base.RLChain):
sampled_ap = prediction[sampled_index]
sampled_action = sampled_ap[0]
sampled_prob = sampled_ap[1]
selected = PickBest.Selected(index=sampled_action, probability=sampled_prob)
selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
event.selected = selected
# only one key, value pair in event.to_select_from
@ -218,23 +228,29 @@ class PickBest(base.RLChain):
return next_chain_inputs, event
def _call_after_llm_before_scoring(
self, llm_response: str, event: PickBest.Event
) -> Tuple[Dict[str, Any], PickBest.Event]:
self, llm_response: str, event: PickBestEvent
) -> Tuple[Dict[str, Any], PickBestEvent]:
next_chain_inputs = event.inputs.copy()
# only one key, value pair in event.to_select_from
value = next(iter(event.to_select_from.values()))
v = (
value[event.selected.index]
if event.selected
else event.to_select_from.values()
)
next_chain_inputs.update(
{
self.selected_based_on_input_key: str(event.based_on),
self.selected_input_key: value[event.selected.index],
self.selected_input_key: v,
}
)
return next_chain_inputs, event
def _call_after_scoring_before_learning(
self, event: PickBest.Event, score: Optional[float]
) -> Event:
event.selected.score = score
self, event: PickBestEvent, score: Optional[float]
) -> PickBestEvent:
if event.selected:
event.selected.score = score
return event
def _call(
@ -249,34 +265,20 @@ class PickBest(base.RLChain):
return "rl_chain_pick_best"
@classmethod
def from_chain(
cls,
llm_chain: Chain,
def from_llm(
cls: Type[PickBest],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
selection_scorer=SENTINEL,
selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
**kwargs: Any,
):
) -> PickBest:
llm_chain = LLMChain(llm=llm, prompt=prompt)
if selection_scorer is SENTINEL:
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
return PickBest(
llm_chain=llm_chain,
prompt=prompt,
selection_scorer=selection_scorer,
**kwargs,
)
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
selection_scorer=SENTINEL,
**kwargs: Any,
):
llm_chain = LLMChain(llm=llm, prompt=prompt)
return PickBest.from_chain(
llm_chain=llm_chain,
prompt=prompt,
selection_scorer=selection_scorer,
**kwargs,
)

View File

@ -9,10 +9,10 @@ class VwLogger:
if self.path:
self.path.parent.mkdir(parents=True, exist_ok=True)
def log(self, vw_ex: str):
def log(self, vw_ex: str) -> None:
if self.path:
with open(self.path, "a") as f:
f.write(f"{vw_ex}\n\n")
def logging_enabled(self):
def logging_enabled(self) -> bool:
return bool(self.path)

View File

@ -1,3 +1,5 @@
from typing import Any, Dict
import pytest
from test_utils import MockEncoder
@ -10,7 +12,7 @@ encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def setup():
def setup() -> tuple:
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
@ -19,7 +21,7 @@ def setup():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_multiple_ToSelectFrom_throws():
def test_multiple_ToSelectFrom_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"]
@ -32,7 +34,7 @@ def test_multiple_ToSelectFrom_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_missing_basedOn_from_throws():
def test_missing_basedOn_from_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"]
@ -41,7 +43,7 @@ def test_missing_basedOn_from_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_ToSelectFrom_not_a_list_throws():
def test_ToSelectFrom_not_a_list_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = {"actions": ["0", "1", "2"]}
@ -53,7 +55,7 @@ def test_ToSelectFrom_not_a_list_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_with_auto_validator_throws():
def test_update_with_delayed_score_with_auto_validator_throws() -> None:
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
@ -75,7 +77,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_force():
def test_update_with_delayed_score_force() -> None:
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
@ -99,7 +101,7 @@ def test_update_with_delayed_score_force():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score():
def test_update_with_delayed_score() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
@ -117,11 +119,11 @@ def test_update_with_delayed_score():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_user_defined_scorer():
def test_user_defined_scorer() -> None:
llm, PROMPT = setup()
class CustomSelectionScorer(rl_chain.SelectionScorer):
def score_response(self, inputs, llm_response: str) -> float:
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
score = 200
return score
@ -139,7 +141,7 @@ def test_user_defined_scorer():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings():
def test_default_embeddings() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
@ -173,7 +175,7 @@ def test_default_embeddings():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_off():
def test_default_embeddings_off() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
@ -199,7 +201,7 @@ def test_default_embeddings_off():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_mixed_w_explicit_user_embeddings():
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
@ -234,7 +236,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_no_scorer_specified():
def test_default_no_scorer_specified() -> None:
_, PROMPT = setup()
chain_llm = FakeListChatModel(responses=[100])
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
@ -249,7 +251,7 @@ def test_default_no_scorer_specified():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_explicitly_no_scorer():
def test_explicitly_no_scorer() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
@ -265,7 +267,7 @@ def test_explicitly_no_scorer():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_auto_scorer_with_user_defined_llm():
def test_auto_scorer_with_user_defined_llm() -> None:
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300])
chain = pick_best_chain.PickBest.from_llm(
@ -284,7 +286,7 @@ def test_auto_scorer_with_user_defined_llm():
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_calling_chain_w_reserved_inputs_throws():
def test_calling_chain_w_reserved_inputs_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
with pytest.raises(ValueError):

View File

@ -8,10 +8,10 @@ encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_context_throws():
def test_pickbest_textembedder_missing_context_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_action = {"action": ["0", "1", "2"]}
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_action, based_on={}
)
with pytest.raises(ValueError):
@ -19,9 +19,9 @@ def test_pickbest_textembedder_missing_context_throws():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_actions_throws():
def test_pickbest_textembedder_missing_actions_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from={}, based_on={"context": "context"}
)
with pytest.raises(ValueError):
@ -29,11 +29,11 @@ def test_pickbest_textembedder_missing_actions_throws():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_no_label_no_emb():
def test_pickbest_textembedder_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
)
vw_ex_str = feature_embedder.format(event)
@ -41,12 +41,12 @@ def test_pickbest_textembedder_no_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_label_no_score_no_emb():
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
event = pick_best_chain.PickBestEvent(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
@ -57,14 +57,14 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_no_emb():
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = (
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
)
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
@ -75,7 +75,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_emb():
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
@ -90,8 +90,8 @@ def test_pickbest_textembedder_w_full_label_w_emb():
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
context = {"context": rl_chain.Embed(ctx_str_1)}
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -99,7 +99,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
@ -114,8 +114,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -123,12 +123,12 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
@ -136,13 +136,13 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -150,13 +150,13 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -164,7 +164,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -186,8 +186,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
}
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -195,7 +195,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
None
):
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -219,8 +221,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
}
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -228,7 +230,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -253,8 +255,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -262,7 +264,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -290,8 +292,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
}
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -299,7 +301,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
@pytest.mark.requires("vowpal_wabbit_next")
def test_raw_features_underscored():
def test_raw_features_underscored() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "this is a long string"
str1_underscored = str1.replace(" ", "_")
@ -315,7 +317,7 @@ def test_raw_features_underscored():
expected_no_embed = (
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
)
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
@ -325,7 +327,7 @@ def test_raw_features_underscored():
named_actions = {"action": rl_chain.Embed([str1])}
context = {"context": rl_chain.Embed(ctx_str)}
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
@ -335,7 +337,7 @@ def test_raw_features_underscored():
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
event = pick_best_chain.PickBest.Event(
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)

View File

@ -1,3 +1,5 @@
from typing import List, Union
import pytest
from test_utils import MockEncoder
@ -7,13 +9,13 @@ encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_no_emb():
def test_simple_context_str_no_emb() -> None:
expected = [{"a_namespace": "test"}]
assert base.embed("test", MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_emb():
def test_simple_context_str_w_emb() -> None:
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"a_namespace": encoded_text + encoded_str1}]
@ -28,7 +30,7 @@ def test_simple_context_str_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_nested_emb():
def test_simple_context_str_w_nested_emb() -> None:
# nested embeddings, innermost wins
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
@ -46,13 +48,13 @@ def test_simple_context_str_w_nested_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_no_emb():
def test_context_w_namespace_no_emb() -> None:
expected = [{"test_namespace": "test"}]
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb():
def test_context_w_namespace_w_emb() -> None:
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"test_namespace": encoded_text + encoded_str1}]
@ -67,7 +69,7 @@ def test_context_w_namespace_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb2():
def test_context_w_namespace_w_emb2() -> None:
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"test_namespace": encoded_text + encoded_str1}]
@ -82,7 +84,7 @@ def test_context_w_namespace_w_emb2():
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_some_emb():
def test_context_w_namespace_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
encoded_str2 = " ".join(char for char in str2)
@ -111,16 +113,17 @@ def test_context_w_namespace_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_no_emb():
def test_simple_action_strlist_no_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected
to_embed: List[Union[str, base._Embed]] = [str1, str2, str3]
assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_emb():
def test_simple_action_strlist_w_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -148,7 +151,7 @@ def test_simple_action_strlist_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_some_emb():
def test_simple_action_strlist_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -181,7 +184,7 @@ def test_simple_action_strlist_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_no_emb():
def test_action_w_namespace_no_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -204,7 +207,7 @@ def test_action_w_namespace_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb():
def test_action_w_namespace_w_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -246,7 +249,7 @@ def test_action_w_namespace_w_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb2():
def test_action_w_namespace_w_emb2() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -292,7 +295,7 @@ def test_action_w_namespace_w_emb2():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_some_emb():
def test_action_w_namespace_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -333,7 +336,7 @@ def test_action_w_namespace_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
str1 = "test1"
str2 = "test2"
str3 = "test3"
@ -384,7 +387,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_no_emb():
def test_one_namespace_w_list_of_features_no_emb() -> None:
str1 = "test1"
str2 = "test2"
expected = [{"test_namespace": [str1, str2]}]
@ -392,7 +395,7 @@ def test_one_namespace_w_list_of_features_no_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_w_some_emb():
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
str1 = "test1"
str2 = "test2"
encoded_str2 = " ".join(char for char in str2)
@ -404,24 +407,24 @@ def test_one_namespace_w_list_of_features_w_some_emb():
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_list_features_throws():
def test_nested_list_features_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_dict_in_list_throws():
def test_dict_in_list_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_dict_throws():
def test_nested_dict_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_list_of_tuples_throws():
def test_list_of_tuples_throws() -> None:
with pytest.raises(ValueError):
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())

View File

@ -1,3 +1,3 @@
class MockEncoder:
def encode(self, to_encode):
def encode(self, to_encode: str) -> str:
return "[encoded]" + to_encode