mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00
commit
f7fb083aba
@ -13,7 +13,7 @@ from langchain.chains.rl_chain.base import (
|
||||
from langchain.chains.rl_chain.pick_best_chain import PickBest
|
||||
|
||||
|
||||
def configure_logger():
|
||||
def configure_logger() -> None:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
ch = logging.StreamHandler()
|
||||
|
@ -3,7 +3,18 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
@ -26,47 +37,47 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _BasedOn:
|
||||
def __init__(self, value):
|
||||
def __init__(self, value: Any):
|
||||
self.value = value
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def BasedOn(anything):
|
||||
def BasedOn(anything: Any) -> _BasedOn:
|
||||
return _BasedOn(anything)
|
||||
|
||||
|
||||
class _ToSelectFrom:
|
||||
def __init__(self, value):
|
||||
def __init__(self, value: Any):
|
||||
self.value = value
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def ToSelectFrom(anything):
|
||||
def ToSelectFrom(anything: Any) -> _ToSelectFrom:
|
||||
if not isinstance(anything, list):
|
||||
raise ValueError("ToSelectFrom must be a list to select from")
|
||||
return _ToSelectFrom(anything)
|
||||
|
||||
|
||||
class _Embed:
|
||||
def __init__(self, value, keep=False):
|
||||
def __init__(self, value: Any, keep: bool = False):
|
||||
self.value = value
|
||||
self.keep = keep
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def Embed(anything, keep=False):
|
||||
def Embed(anything: Any, keep: bool = False) -> Any:
|
||||
if isinstance(anything, _ToSelectFrom):
|
||||
return ToSelectFrom(Embed(anything.value, keep=keep))
|
||||
elif isinstance(anything, _BasedOn):
|
||||
@ -80,7 +91,7 @@ def Embed(anything, keep=False):
|
||||
return _Embed(anything, keep=keep)
|
||||
|
||||
|
||||
def EmbedAndKeep(anything):
|
||||
def EmbedAndKeep(anything: Any) -> Any:
|
||||
return Embed(anything, keep=True)
|
||||
|
||||
|
||||
@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp
|
||||
return [parser.parse_line(line) for line in input_str.split("\n")]
|
||||
|
||||
|
||||
def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
|
||||
def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
|
||||
to_select_from = {
|
||||
k: inputs[k].value
|
||||
for k in inputs.keys()
|
||||
@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
|
||||
return based_on, to_select_from
|
||||
|
||||
|
||||
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
|
||||
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
|
||||
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
|
||||
@ -134,29 +145,38 @@ class Selected(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class Event(ABC):
|
||||
inputs: Dict[str, Any]
|
||||
selected: Optional[Selected]
|
||||
TSelected = TypeVar("TSelected", bound=Selected)
|
||||
|
||||
def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None):
|
||||
|
||||
class Event(Generic[TSelected], ABC):
|
||||
inputs: Dict[str, Any]
|
||||
selected: Optional[TSelected]
|
||||
|
||||
def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
|
||||
self.inputs = inputs
|
||||
self.selected = selected
|
||||
|
||||
|
||||
TEvent = TypeVar("TEvent", bound=Event)
|
||||
|
||||
|
||||
class Policy(ABC):
|
||||
@abstractmethod
|
||||
def predict(self, event: Event) -> Any:
|
||||
def __init__(self, **kwargs: Any):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, event: Event):
|
||||
pass
|
||||
def predict(self, event: TEvent) -> Any:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def log(self, event: Event):
|
||||
pass
|
||||
def learn(self, event: TEvent) -> None:
|
||||
...
|
||||
|
||||
def save(self):
|
||||
@abstractmethod
|
||||
def log(self, event: TEvent) -> None:
|
||||
...
|
||||
|
||||
def save(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -164,11 +184,11 @@ class VwPolicy(Policy):
|
||||
def __init__(
|
||||
self,
|
||||
model_repo: ModelRepository,
|
||||
vw_cmd: Sequence[str],
|
||||
vw_cmd: List[str],
|
||||
feature_embedder: Embedder,
|
||||
vw_logger: VwLogger,
|
||||
*args,
|
||||
**kwargs,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_repo = model_repo
|
||||
@ -176,7 +196,7 @@ class VwPolicy(Policy):
|
||||
self.feature_embedder = feature_embedder
|
||||
self.vw_logger = vw_logger
|
||||
|
||||
def predict(self, event: Event) -> Any:
|
||||
def predict(self, event: TEvent) -> Any:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
text_parser = vw.TextFormatParser(self.workspace)
|
||||
@ -184,7 +204,7 @@ class VwPolicy(Policy):
|
||||
parse_lines(text_parser, self.feature_embedder.format(event))
|
||||
)
|
||||
|
||||
def learn(self, event: Event):
|
||||
def learn(self, event: TEvent) -> None:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
vw_ex = self.feature_embedder.format(event)
|
||||
@ -192,19 +212,19 @@ class VwPolicy(Policy):
|
||||
multi_ex = parse_lines(text_parser, vw_ex)
|
||||
self.workspace.learn_one(multi_ex)
|
||||
|
||||
def log(self, event: Event):
|
||||
def log(self, event: TEvent) -> None:
|
||||
if self.vw_logger.logging_enabled():
|
||||
vw_ex = self.feature_embedder.format(event)
|
||||
self.vw_logger.log(vw_ex)
|
||||
|
||||
def save(self):
|
||||
self.model_repo.save()
|
||||
def save(self) -> None:
|
||||
self.model_repo.save(self.workspace)
|
||||
|
||||
|
||||
class Embedder(ABC):
|
||||
class Embedder(Generic[TEvent], ABC):
|
||||
@abstractmethod
|
||||
def format(self, event: Event) -> str:
|
||||
pass
|
||||
def format(self, event: TEvent) -> str:
|
||||
...
|
||||
|
||||
|
||||
class SelectionScorer(ABC, BaseModel):
|
||||
@ -212,11 +232,11 @@ class SelectionScorer(ABC, BaseModel):
|
||||
|
||||
@abstractmethod
|
||||
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
|
||||
pass
|
||||
...
|
||||
|
||||
|
||||
class AutoSelectionScorer(SelectionScorer, BaseModel):
|
||||
llm_chain: Union[LLMChain, None] = None
|
||||
llm_chain: LLMChain
|
||||
prompt: Union[BasePromptTemplate, None] = None
|
||||
scoring_criteria_template_str: Optional[str] = None
|
||||
|
||||
@ -243,7 +263,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
||||
return chat_prompt
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_prompt_and_llm_chain(cls, values):
|
||||
def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
llm = values.get("llm")
|
||||
prompt = values.get("prompt")
|
||||
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
|
||||
@ -275,7 +295,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class RLChain(Chain):
|
||||
class RLChain(Chain, Generic[TEvent]):
|
||||
"""
|
||||
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
|
||||
|
||||
@ -292,7 +312,7 @@ class RLChain(Chain):
|
||||
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
|
||||
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
|
||||
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
||||
- policy (VwPolicy): Policy used by the chain.
|
||||
- policy (Type[VwPolicy]): Policy used by the chain.
|
||||
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
||||
- metrics_step (int): Step for the metrics tracker. Default is -1.
|
||||
|
||||
@ -300,12 +320,24 @@ class RLChain(Chain):
|
||||
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
||||
""" # noqa: E501
|
||||
|
||||
class _NoOpPolicy(Policy):
|
||||
"""Placeholder policy that does nothing"""
|
||||
|
||||
def predict(self, event: TEvent) -> Any:
|
||||
return None
|
||||
|
||||
def learn(self, event: TEvent) -> None:
|
||||
pass
|
||||
|
||||
def log(self, event: TEvent) -> None:
|
||||
pass
|
||||
|
||||
llm_chain: Chain
|
||||
|
||||
output_key: str = "result" #: :meta private:
|
||||
prompt: BasePromptTemplate
|
||||
selection_scorer: Union[SelectionScorer, None]
|
||||
policy: Optional[Policy]
|
||||
active_policy: Policy = _NoOpPolicy()
|
||||
auto_embed: bool = True
|
||||
selected_input_key = "rl_chain_selected"
|
||||
selected_based_on_input_key = "rl_chain_selected_based_on"
|
||||
@ -314,14 +346,14 @@ class RLChain(Chain):
|
||||
def __init__(
|
||||
self,
|
||||
feature_embedder: Embedder,
|
||||
model_save_dir="./",
|
||||
reset_model=False,
|
||||
vw_cmd=None,
|
||||
policy=VwPolicy,
|
||||
model_save_dir: str = "./",
|
||||
reset_model: bool = False,
|
||||
vw_cmd: Optional[List[str]] = None,
|
||||
policy: Type[Policy] = VwPolicy,
|
||||
vw_logs: Optional[Union[str, os.PathLike]] = None,
|
||||
metrics_step=-1,
|
||||
*args,
|
||||
**kwargs,
|
||||
metrics_step: int = -1,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.selection_scorer is None:
|
||||
@ -330,14 +362,17 @@ class RLChain(Chain):
|
||||
reinforcement learning will be done in the RL chain \
|
||||
unless update_with_delayed_score is called."
|
||||
)
|
||||
self.policy = policy(
|
||||
model_repo=ModelRepository(
|
||||
model_save_dir, with_history=True, reset=reset_model
|
||||
),
|
||||
vw_cmd=vw_cmd or [],
|
||||
feature_embedder=feature_embedder,
|
||||
vw_logger=VwLogger(vw_logs),
|
||||
)
|
||||
|
||||
if isinstance(self.active_policy, RLChain._NoOpPolicy):
|
||||
self.active_policy = policy(
|
||||
model_repo=ModelRepository(
|
||||
model_save_dir, with_history=True, reset=reset_model
|
||||
),
|
||||
vw_cmd=vw_cmd or [],
|
||||
feature_embedder=feature_embedder,
|
||||
vw_logger=VwLogger(vw_logs),
|
||||
)
|
||||
|
||||
self.metrics = MetricsTracker(step=metrics_step)
|
||||
|
||||
class Config:
|
||||
@ -374,29 +409,29 @@ class RLChain(Chain):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
|
||||
pass
|
||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _call_after_predict_before_llm(
|
||||
self, inputs: Dict[str, Any], event: Event, prediction: Any
|
||||
) -> Tuple[Dict[str, Any], Event]:
|
||||
pass
|
||||
self, inputs: Dict[str, Any], event: TEvent, prediction: Any
|
||||
) -> Tuple[Dict[str, Any], TEvent]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _call_after_llm_before_scoring(
|
||||
self, llm_response: str, event: Event
|
||||
) -> Tuple[Dict[str, Any], Event]:
|
||||
pass
|
||||
self, llm_response: str, event: TEvent
|
||||
) -> Tuple[Dict[str, Any], TEvent]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _call_after_scoring_before_learning(
|
||||
self, event: Event, score: Optional[float]
|
||||
) -> Event:
|
||||
pass
|
||||
self, event: TEvent, score: Optional[float]
|
||||
) -> TEvent:
|
||||
...
|
||||
|
||||
def update_with_delayed_score(
|
||||
self, score: float, event: Event, force_score=False
|
||||
self, score: float, event: TEvent, force_score: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Updates the learned policy with the score provided.
|
||||
@ -407,10 +442,11 @@ class RLChain(Chain):
|
||||
"The selection scorer is set, and force_score was not set to True. \
|
||||
Please set force_score=True to use this function."
|
||||
)
|
||||
self.metrics.on_feedback(score)
|
||||
if self.metrics:
|
||||
self.metrics.on_feedback(score)
|
||||
self._call_after_scoring_before_learning(event=event, score=score)
|
||||
self.policy.learn(event=event)
|
||||
self.policy.log(event=event)
|
||||
self.active_policy.learn(event=event)
|
||||
self.active_policy.log(event=event)
|
||||
|
||||
def set_auto_embed(self, auto_embed: bool) -> None:
|
||||
"""
|
||||
@ -422,15 +458,16 @@ class RLChain(Chain):
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
|
||||
if self.auto_embed:
|
||||
inputs = prepare_inputs_for_autoembed(inputs=inputs)
|
||||
|
||||
event = self._call_before_predict(inputs=inputs)
|
||||
prediction = self.policy.predict(event=event)
|
||||
self.metrics.on_decision()
|
||||
event: TEvent = self._call_before_predict(inputs=inputs)
|
||||
prediction = self.active_policy.predict(event=event)
|
||||
if self.metrics:
|
||||
self.metrics.on_decision()
|
||||
|
||||
next_chain_inputs, event = self._call_after_predict_before_llm(
|
||||
inputs=inputs, event=event, prediction=prediction
|
||||
@ -462,10 +499,11 @@ class RLChain(Chain):
|
||||
f"The selection scorer was not able to score, \
|
||||
and the chain was not able to adjust to this response, error: {e}"
|
||||
)
|
||||
self.metrics.on_feedback(score)
|
||||
if self.metrics:
|
||||
self.metrics.on_feedback(score)
|
||||
event = self._call_after_scoring_before_learning(score=score, event=event)
|
||||
self.policy.learn(event=event)
|
||||
self.policy.log(event=event)
|
||||
self.active_policy.learn(event=event)
|
||||
self.active_policy.log(event=event)
|
||||
|
||||
return {self.output_key: {"response": output, "selection_metadata": event}}
|
||||
|
||||
@ -473,7 +511,7 @@ class RLChain(Chain):
|
||||
"""
|
||||
This function should be called to save the state of the learned policy model.
|
||||
"""
|
||||
self.policy.save()
|
||||
self.active_policy.save()
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
@ -489,7 +527,7 @@ def is_stringtype_instance(item: Any) -> bool:
|
||||
|
||||
def embed_string_type(
|
||||
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
) -> Dict[str, Union[str, List[str]]]:
|
||||
"""Helper function to embed a string or an _Embed object."""
|
||||
join_char = ""
|
||||
keep_str = ""
|
||||
@ -513,9 +551,9 @@ def embed_string_type(
|
||||
return {namespace: keep_str + join_char.join(map(str, encoded))}
|
||||
|
||||
|
||||
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
|
||||
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
|
||||
"""Helper function to embed a dictionary item."""
|
||||
inner_dict = {}
|
||||
inner_dict: Dict[str, Any] = {}
|
||||
for ns, embed_item in item.items():
|
||||
if isinstance(embed_item, list):
|
||||
inner_dict[ns] = []
|
||||
@ -530,7 +568,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
|
||||
def embed_list_type(
|
||||
item: list, model: Any, namespace: Optional[str] = None
|
||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||
ret_list = []
|
||||
ret_list: List[Dict[str, Union[str, List[str]]]] = []
|
||||
for embed_item in item:
|
||||
if isinstance(embed_item, dict):
|
||||
ret_list.append(embed_dict_type(embed_item, model))
|
||||
@ -540,9 +578,7 @@ def embed_list_type(
|
||||
|
||||
|
||||
def embed(
|
||||
to_embed: Union[
|
||||
Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
|
||||
],
|
||||
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
|
||||
model: Any,
|
||||
namespace: Optional[str] = None,
|
||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
@ -6,11 +6,11 @@ if TYPE_CHECKING:
|
||||
|
||||
class MetricsTracker:
|
||||
def __init__(self, step: int):
|
||||
self._history = []
|
||||
self._step = step
|
||||
self._i = 0
|
||||
self._num = 0
|
||||
self._denom = 0
|
||||
self._history: List[Dict[str, Union[int, float]]] = []
|
||||
self._step: int = step
|
||||
self._i: int = 0
|
||||
self._num: float = 0
|
||||
self._denom: float = 0
|
||||
|
||||
@property
|
||||
def score(self) -> float:
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Sequence, Union
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import vowpal_wabbit_next as vw
|
||||
@ -22,7 +22,7 @@ class ModelRepository:
|
||||
self.folder = Path(folder)
|
||||
self.model_path = self.folder / "latest.vw"
|
||||
self.with_history = with_history
|
||||
if reset and self.has_history:
|
||||
if reset and self.has_history():
|
||||
logger.warning(
|
||||
"There is non empty history which is recommended to be cleaned up"
|
||||
)
|
||||
@ -44,7 +44,7 @@ class ModelRepository:
|
||||
if self.with_history: # write history
|
||||
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
|
||||
|
||||
def load(self, commandline: Sequence[str]) -> "vw.Workspace":
|
||||
def load(self, commandline: List[str]) -> "vw.Workspace":
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
model_data = None
|
||||
|
@ -1,12 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import langchain.chains.rl_chain.base as base
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
|
||||
@ -17,7 +16,36 @@ logger = logging.getLogger(__name__)
|
||||
SENTINEL = object()
|
||||
|
||||
|
||||
class PickBestFeatureEmbedder(base.Embedder):
|
||||
class PickBestSelected(base.Selected):
|
||||
index: Optional[int]
|
||||
probability: Optional[float]
|
||||
score: Optional[float]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: Optional[int] = None,
|
||||
probability: Optional[float] = None,
|
||||
score: Optional[float] = None,
|
||||
):
|
||||
self.index = index
|
||||
self.probability = probability
|
||||
self.score = score
|
||||
|
||||
|
||||
class PickBestEvent(base.Event[PickBestSelected]):
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
to_select_from: Dict[str, Any],
|
||||
based_on: Dict[str, Any],
|
||||
selected: Optional[PickBestSelected] = None,
|
||||
):
|
||||
super().__init__(inputs=inputs, selected=selected)
|
||||
self.to_select_from = to_select_from
|
||||
self.based_on = based_on
|
||||
|
||||
|
||||
class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
||||
"""
|
||||
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
|
||||
|
||||
@ -25,7 +53,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
||||
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
||||
""" # noqa E501
|
||||
|
||||
def __init__(self, model: Optional[Any] = None, *args, **kwargs):
|
||||
def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if model is None:
|
||||
@ -35,7 +63,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
||||
|
||||
self.model = model
|
||||
|
||||
def format(self, event: PickBest.Event) -> str:
|
||||
def format(self, event: PickBestEvent) -> str:
|
||||
"""
|
||||
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
|
||||
"""
|
||||
@ -54,9 +82,14 @@ class PickBestFeatureEmbedder(base.Embedder):
|
||||
to_select_from_var_name, to_select_from = next(
|
||||
iter(event.to_select_from.items()), (None, None)
|
||||
)
|
||||
|
||||
action_embs = (
|
||||
base.embed(to_select_from, self.model, to_select_from_var_name)
|
||||
if event.to_select_from
|
||||
(
|
||||
base.embed(to_select_from, self.model, to_select_from_var_name)
|
||||
if event.to_select_from
|
||||
else None
|
||||
)
|
||||
if to_select_from
|
||||
else None
|
||||
)
|
||||
|
||||
@ -88,7 +121,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
||||
return example_string[:-1]
|
||||
|
||||
|
||||
class PickBest(base.RLChain):
|
||||
class PickBest(base.RLChain[PickBestEvent]):
|
||||
"""
|
||||
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
||||
|
||||
@ -116,38 +149,10 @@ class PickBest(base.RLChain):
|
||||
feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
|
||||
""" # noqa E501
|
||||
|
||||
class Selected(base.Selected):
|
||||
index: Optional[int]
|
||||
probability: Optional[float]
|
||||
score: Optional[float]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: Optional[int] = None,
|
||||
probability: Optional[float] = None,
|
||||
score: Optional[float] = None,
|
||||
):
|
||||
self.index = index
|
||||
self.probability = probability
|
||||
self.score = score
|
||||
|
||||
class Event(base.Event):
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
to_select_from: Dict[str, Any],
|
||||
based_on: Dict[str, Any],
|
||||
selected: Optional[PickBest.Selected] = None,
|
||||
):
|
||||
super().__init__(inputs=inputs, selected=selected)
|
||||
self.to_select_from = to_select_from
|
||||
self.based_on = based_on
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
vw_cmd = kwargs.get("vw_cmd", [])
|
||||
if not vw_cmd:
|
||||
@ -163,14 +168,16 @@ class PickBest(base.RLChain):
|
||||
raise ValueError(
|
||||
"If vw_cmd is specified, it must include --cb_explore_adf"
|
||||
)
|
||||
|
||||
kwargs["vw_cmd"] = vw_cmd
|
||||
|
||||
feature_embedder = kwargs.get("feature_embedder", None)
|
||||
if not feature_embedder:
|
||||
feature_embedder = PickBestFeatureEmbedder()
|
||||
kwargs["feature_embedder"] = feature_embedder
|
||||
|
||||
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event:
|
||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
|
||||
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
||||
if not actions:
|
||||
raise ValueError(
|
||||
@ -193,12 +200,15 @@ class PickBest(base.RLChain):
|
||||
to base the selected of ToSelectFrom on."
|
||||
)
|
||||
|
||||
event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context)
|
||||
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
|
||||
return event
|
||||
|
||||
def _call_after_predict_before_llm(
|
||||
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
|
||||
) -> Tuple[Dict[str, Any], PickBest.Event]:
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
event: PickBestEvent,
|
||||
prediction: List[Tuple[int, float]],
|
||||
) -> Tuple[Dict[str, Any], PickBestEvent]:
|
||||
import numpy as np
|
||||
|
||||
prob_sum = sum(prob for _, prob in prediction)
|
||||
@ -208,7 +218,7 @@ class PickBest(base.RLChain):
|
||||
sampled_ap = prediction[sampled_index]
|
||||
sampled_action = sampled_ap[0]
|
||||
sampled_prob = sampled_ap[1]
|
||||
selected = PickBest.Selected(index=sampled_action, probability=sampled_prob)
|
||||
selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
|
||||
event.selected = selected
|
||||
|
||||
# only one key, value pair in event.to_select_from
|
||||
@ -218,23 +228,29 @@ class PickBest(base.RLChain):
|
||||
return next_chain_inputs, event
|
||||
|
||||
def _call_after_llm_before_scoring(
|
||||
self, llm_response: str, event: PickBest.Event
|
||||
) -> Tuple[Dict[str, Any], PickBest.Event]:
|
||||
self, llm_response: str, event: PickBestEvent
|
||||
) -> Tuple[Dict[str, Any], PickBestEvent]:
|
||||
next_chain_inputs = event.inputs.copy()
|
||||
# only one key, value pair in event.to_select_from
|
||||
value = next(iter(event.to_select_from.values()))
|
||||
v = (
|
||||
value[event.selected.index]
|
||||
if event.selected
|
||||
else event.to_select_from.values()
|
||||
)
|
||||
next_chain_inputs.update(
|
||||
{
|
||||
self.selected_based_on_input_key: str(event.based_on),
|
||||
self.selected_input_key: value[event.selected.index],
|
||||
self.selected_input_key: v,
|
||||
}
|
||||
)
|
||||
return next_chain_inputs, event
|
||||
|
||||
def _call_after_scoring_before_learning(
|
||||
self, event: PickBest.Event, score: Optional[float]
|
||||
) -> Event:
|
||||
event.selected.score = score
|
||||
self, event: PickBestEvent, score: Optional[float]
|
||||
) -> PickBestEvent:
|
||||
if event.selected:
|
||||
event.selected.score = score
|
||||
return event
|
||||
|
||||
def _call(
|
||||
@ -249,34 +265,20 @@ class PickBest(base.RLChain):
|
||||
return "rl_chain_pick_best"
|
||||
|
||||
@classmethod
|
||||
def from_chain(
|
||||
cls,
|
||||
llm_chain: Chain,
|
||||
def from_llm(
|
||||
cls: Type[PickBest],
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
selection_scorer=SENTINEL,
|
||||
selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> PickBest:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
if selection_scorer is SENTINEL:
|
||||
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
|
||||
|
||||
return PickBest(
|
||||
llm_chain=llm_chain,
|
||||
prompt=prompt,
|
||||
selection_scorer=selection_scorer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
selection_scorer=SENTINEL,
|
||||
**kwargs: Any,
|
||||
):
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return PickBest.from_chain(
|
||||
llm_chain=llm_chain,
|
||||
prompt=prompt,
|
||||
selection_scorer=selection_scorer,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -9,10 +9,10 @@ class VwLogger:
|
||||
if self.path:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def log(self, vw_ex: str):
|
||||
def log(self, vw_ex: str) -> None:
|
||||
if self.path:
|
||||
with open(self.path, "a") as f:
|
||||
f.write(f"{vw_ex}\n\n")
|
||||
|
||||
def logging_enabled(self):
|
||||
def logging_enabled(self) -> bool:
|
||||
return bool(self.path)
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from test_utils import MockEncoder
|
||||
|
||||
@ -10,7 +12,7 @@ encoded_text = "[ e n c o d e d ] "
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def setup():
|
||||
def setup() -> tuple:
|
||||
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
|
||||
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
|
||||
|
||||
@ -19,7 +21,7 @@ def setup():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_multiple_ToSelectFrom_throws():
|
||||
def test_multiple_ToSelectFrom_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
actions = ["0", "1", "2"]
|
||||
@ -32,7 +34,7 @@ def test_multiple_ToSelectFrom_throws():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_missing_basedOn_from_throws():
|
||||
def test_missing_basedOn_from_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
actions = ["0", "1", "2"]
|
||||
@ -41,7 +43,7 @@ def test_missing_basedOn_from_throws():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_ToSelectFrom_not_a_list_throws():
|
||||
def test_ToSelectFrom_not_a_list_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
actions = {"actions": ["0", "1", "2"]}
|
||||
@ -53,7 +55,7 @@ def test_ToSelectFrom_not_a_list_throws():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score_with_auto_validator_throws():
|
||||
def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||
@ -75,7 +77,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score_force():
|
||||
def test_update_with_delayed_score_force() -> None:
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||
@ -99,7 +101,7 @@ def test_update_with_delayed_score_force():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score():
|
||||
def test_update_with_delayed_score() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, selection_scorer=None
|
||||
@ -117,11 +119,11 @@ def test_update_with_delayed_score():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_user_defined_scorer():
|
||||
def test_user_defined_scorer() -> None:
|
||||
llm, PROMPT = setup()
|
||||
|
||||
class CustomSelectionScorer(rl_chain.SelectionScorer):
|
||||
def score_response(self, inputs, llm_response: str) -> float:
|
||||
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
|
||||
score = 200
|
||||
return score
|
||||
|
||||
@ -139,7 +141,7 @@ def test_user_defined_scorer():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_embeddings():
|
||||
def test_default_embeddings() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
@ -173,7 +175,7 @@ def test_default_embeddings():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_embeddings_off():
|
||||
def test_default_embeddings_off() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
@ -199,7 +201,7 @@ def test_default_embeddings_off():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
||||
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
@ -234,7 +236,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_no_scorer_specified():
|
||||
def test_default_no_scorer_specified() -> None:
|
||||
_, PROMPT = setup()
|
||||
chain_llm = FakeListChatModel(responses=[100])
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
|
||||
@ -249,7 +251,7 @@ def test_default_no_scorer_specified():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_explicitly_no_scorer():
|
||||
def test_explicitly_no_scorer() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, selection_scorer=None
|
||||
@ -265,7 +267,7 @@ def test_explicitly_no_scorer():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_auto_scorer_with_user_defined_llm():
|
||||
def test_auto_scorer_with_user_defined_llm() -> None:
|
||||
llm, PROMPT = setup()
|
||||
scorer_llm = FakeListChatModel(responses=[300])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
@ -284,7 +286,7 @@ def test_auto_scorer_with_user_defined_llm():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_calling_chain_w_reserved_inputs_throws():
|
||||
def test_calling_chain_w_reserved_inputs_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -8,10 +8,10 @@ encoded_text = "[ e n c o d e d ] "
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_missing_context_throws():
|
||||
def test_pickbest_textembedder_missing_context_throws() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_action = {"action": ["0", "1", "2"]}
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_action, based_on={}
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
@ -19,9 +19,9 @@ def test_pickbest_textembedder_missing_context_throws():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_missing_actions_throws():
|
||||
def test_pickbest_textembedder_missing_actions_throws() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from={}, based_on={"context": "context"}
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
@ -29,11 +29,11 @@ def test_pickbest_textembedder_missing_actions_throws():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_no_label_no_emb():
|
||||
def test_pickbest_textembedder_no_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -41,12 +41,12 @@ def test_pickbest_textembedder_no_label_no_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_label_no_score_no_emb():
|
||||
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={},
|
||||
to_select_from=named_actions,
|
||||
based_on={"context": "context"},
|
||||
@ -57,14 +57,14 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_no_emb():
|
||||
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = (
|
||||
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
||||
)
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={},
|
||||
to_select_from=named_actions,
|
||||
based_on={"context": "context"},
|
||||
@ -75,7 +75,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_w_emb():
|
||||
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
@ -90,8 +90,8 @@ def test_pickbest_textembedder_w_full_label_w_emb():
|
||||
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
||||
context = {"context": rl_chain.Embed(ctx_str_1)}
|
||||
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -99,7 +99,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
||||
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
@ -114,8 +114,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
||||
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
||||
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -123,12 +123,12 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
||||
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -136,13 +136,13 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
||||
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -150,13 +150,13 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -164,7 +164,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
str1 = "0"
|
||||
@ -186,8 +186,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
||||
}
|
||||
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -195,7 +195,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
|
||||
None
|
||||
):
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
str1 = "0"
|
||||
@ -219,8 +221,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
||||
}
|
||||
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -228,7 +230,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
str1 = "0"
|
||||
@ -253,8 +255,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
||||
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
|
||||
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -262,7 +264,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
str1 = "0"
|
||||
@ -290,8 +292,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
||||
}
|
||||
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -299,7 +301,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_raw_features_underscored():
|
||||
def test_raw_features_underscored() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
str1 = "this is a long string"
|
||||
str1_underscored = str1.replace(" ", "_")
|
||||
@ -315,7 +317,7 @@ def test_raw_features_underscored():
|
||||
expected_no_embed = (
|
||||
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
|
||||
)
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -325,7 +327,7 @@ def test_raw_features_underscored():
|
||||
named_actions = {"action": rl_chain.Embed([str1])}
|
||||
context = {"context": rl_chain.Embed(ctx_str)}
|
||||
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
@ -335,7 +337,7 @@ def test_raw_features_underscored():
|
||||
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
|
||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
|
||||
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import List, Union
|
||||
|
||||
import pytest
|
||||
from test_utils import MockEncoder
|
||||
|
||||
@ -7,13 +9,13 @@ encoded_text = "[ e n c o d e d ] "
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_no_emb():
|
||||
def test_simple_context_str_no_emb() -> None:
|
||||
expected = [{"a_namespace": "test"}]
|
||||
assert base.embed("test", MockEncoder(), "a_namespace") == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_w_emb():
|
||||
def test_simple_context_str_w_emb() -> None:
|
||||
str1 = "test"
|
||||
encoded_str1 = " ".join(char for char in str1)
|
||||
expected = [{"a_namespace": encoded_text + encoded_str1}]
|
||||
@ -28,7 +30,7 @@ def test_simple_context_str_w_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_w_nested_emb():
|
||||
def test_simple_context_str_w_nested_emb() -> None:
|
||||
# nested embeddings, innermost wins
|
||||
str1 = "test"
|
||||
encoded_str1 = " ".join(char for char in str1)
|
||||
@ -46,13 +48,13 @@ def test_simple_context_str_w_nested_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_no_emb():
|
||||
def test_context_w_namespace_no_emb() -> None:
|
||||
expected = [{"test_namespace": "test"}]
|
||||
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_emb():
|
||||
def test_context_w_namespace_w_emb() -> None:
|
||||
str1 = "test"
|
||||
encoded_str1 = " ".join(char for char in str1)
|
||||
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
||||
@ -67,7 +69,7 @@ def test_context_w_namespace_w_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_emb2():
|
||||
def test_context_w_namespace_w_emb2() -> None:
|
||||
str1 = "test"
|
||||
encoded_str1 = " ".join(char for char in str1)
|
||||
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
||||
@ -82,7 +84,7 @@ def test_context_w_namespace_w_emb2():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_some_emb():
|
||||
def test_context_w_namespace_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
encoded_str2 = " ".join(char for char in str2)
|
||||
@ -111,16 +113,17 @@ def test_context_w_namespace_w_some_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_no_emb():
|
||||
def test_simple_action_strlist_no_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
|
||||
assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected
|
||||
to_embed: List[Union[str, base._Embed]] = [str1, str2, str3]
|
||||
assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_w_emb():
|
||||
def test_simple_action_strlist_w_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
@ -148,7 +151,7 @@ def test_simple_action_strlist_w_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_w_some_emb():
|
||||
def test_simple_action_strlist_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
@ -181,7 +184,7 @@ def test_simple_action_strlist_w_some_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_no_emb():
|
||||
def test_action_w_namespace_no_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
@ -204,7 +207,7 @@ def test_action_w_namespace_no_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb():
|
||||
def test_action_w_namespace_w_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
@ -246,7 +249,7 @@ def test_action_w_namespace_w_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb2():
|
||||
def test_action_w_namespace_w_emb2() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
@ -292,7 +295,7 @@ def test_action_w_namespace_w_emb2():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_some_emb():
|
||||
def test_action_w_namespace_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
@ -333,7 +336,7 @@ def test_action_w_namespace_w_some_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
|
||||
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
@ -384,7 +387,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_one_namespace_w_list_of_features_no_emb():
|
||||
def test_one_namespace_w_list_of_features_no_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
expected = [{"test_namespace": [str1, str2]}]
|
||||
@ -392,7 +395,7 @@ def test_one_namespace_w_list_of_features_no_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_one_namespace_w_list_of_features_w_some_emb():
|
||||
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
encoded_str2 = " ".join(char for char in str2)
|
||||
@ -404,24 +407,24 @@ def test_one_namespace_w_list_of_features_w_some_emb():
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_nested_list_features_throws():
|
||||
def test_nested_list_features_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_dict_in_list_throws():
|
||||
def test_dict_in_list_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_nested_dict_throws():
|
||||
def test_nested_dict_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_list_of_tuples_throws():
|
||||
def test_list_of_tuples_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
|
||||
|
@ -1,3 +1,3 @@
|
||||
class MockEncoder:
|
||||
def encode(self, to_encode):
|
||||
def encode(self, to_encode: str) -> str:
|
||||
return "[encoded]" + to_encode
|
||||
|
Loading…
Reference in New Issue
Block a user