mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-04 06:03:31 +00:00
commit
f7fb083aba
@ -13,7 +13,7 @@ from langchain.chains.rl_chain.base import (
|
|||||||
from langchain.chains.rl_chain.pick_best_chain import PickBest
|
from langchain.chains.rl_chain.pick_best_chain import PickBest
|
||||||
|
|
||||||
|
|
||||||
def configure_logger():
|
def configure_logger() -> None:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
ch = logging.StreamHandler()
|
ch = logging.StreamHandler()
|
||||||
|
@ -3,7 +3,18 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
@ -26,47 +37,47 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class _BasedOn:
|
class _BasedOn:
|
||||||
def __init__(self, value):
|
def __init__(self, value: Any):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return str(self.value)
|
return str(self.value)
|
||||||
|
|
||||||
__repr__ = __str__
|
__repr__ = __str__
|
||||||
|
|
||||||
|
|
||||||
def BasedOn(anything):
|
def BasedOn(anything: Any) -> _BasedOn:
|
||||||
return _BasedOn(anything)
|
return _BasedOn(anything)
|
||||||
|
|
||||||
|
|
||||||
class _ToSelectFrom:
|
class _ToSelectFrom:
|
||||||
def __init__(self, value):
|
def __init__(self, value: Any):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return str(self.value)
|
return str(self.value)
|
||||||
|
|
||||||
__repr__ = __str__
|
__repr__ = __str__
|
||||||
|
|
||||||
|
|
||||||
def ToSelectFrom(anything):
|
def ToSelectFrom(anything: Any) -> _ToSelectFrom:
|
||||||
if not isinstance(anything, list):
|
if not isinstance(anything, list):
|
||||||
raise ValueError("ToSelectFrom must be a list to select from")
|
raise ValueError("ToSelectFrom must be a list to select from")
|
||||||
return _ToSelectFrom(anything)
|
return _ToSelectFrom(anything)
|
||||||
|
|
||||||
|
|
||||||
class _Embed:
|
class _Embed:
|
||||||
def __init__(self, value, keep=False):
|
def __init__(self, value: Any, keep: bool = False):
|
||||||
self.value = value
|
self.value = value
|
||||||
self.keep = keep
|
self.keep = keep
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return str(self.value)
|
return str(self.value)
|
||||||
|
|
||||||
__repr__ = __str__
|
__repr__ = __str__
|
||||||
|
|
||||||
|
|
||||||
def Embed(anything, keep=False):
|
def Embed(anything: Any, keep: bool = False) -> Any:
|
||||||
if isinstance(anything, _ToSelectFrom):
|
if isinstance(anything, _ToSelectFrom):
|
||||||
return ToSelectFrom(Embed(anything.value, keep=keep))
|
return ToSelectFrom(Embed(anything.value, keep=keep))
|
||||||
elif isinstance(anything, _BasedOn):
|
elif isinstance(anything, _BasedOn):
|
||||||
@ -80,7 +91,7 @@ def Embed(anything, keep=False):
|
|||||||
return _Embed(anything, keep=keep)
|
return _Embed(anything, keep=keep)
|
||||||
|
|
||||||
|
|
||||||
def EmbedAndKeep(anything):
|
def EmbedAndKeep(anything: Any) -> Any:
|
||||||
return Embed(anything, keep=True)
|
return Embed(anything, keep=True)
|
||||||
|
|
||||||
|
|
||||||
@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp
|
|||||||
return [parser.parse_line(line) for line in input_str.split("\n")]
|
return [parser.parse_line(line) for line in input_str.split("\n")]
|
||||||
|
|
||||||
|
|
||||||
def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
|
def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
|
||||||
to_select_from = {
|
to_select_from = {
|
||||||
k: inputs[k].value
|
k: inputs[k].value
|
||||||
for k in inputs.keys()
|
for k in inputs.keys()
|
||||||
@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
|
|||||||
return based_on, to_select_from
|
return based_on, to_select_from
|
||||||
|
|
||||||
|
|
||||||
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
|
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
|
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
|
||||||
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
|
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
|
||||||
@ -134,29 +145,38 @@ class Selected(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Event(ABC):
|
TSelected = TypeVar("TSelected", bound=Selected)
|
||||||
inputs: Dict[str, Any]
|
|
||||||
selected: Optional[Selected]
|
|
||||||
|
|
||||||
def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None):
|
|
||||||
|
class Event(Generic[TSelected], ABC):
|
||||||
|
inputs: Dict[str, Any]
|
||||||
|
selected: Optional[TSelected]
|
||||||
|
|
||||||
|
def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
self.selected = selected
|
self.selected = selected
|
||||||
|
|
||||||
|
|
||||||
|
TEvent = TypeVar("TEvent", bound=Event)
|
||||||
|
|
||||||
|
|
||||||
class Policy(ABC):
|
class Policy(ABC):
|
||||||
@abstractmethod
|
def __init__(self, **kwargs: Any):
|
||||||
def predict(self, event: Event) -> Any:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def learn(self, event: Event):
|
def predict(self, event: TEvent) -> Any:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def log(self, event: Event):
|
def learn(self, event: TEvent) -> None:
|
||||||
pass
|
...
|
||||||
|
|
||||||
def save(self):
|
@abstractmethod
|
||||||
|
def log(self, event: TEvent) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def save(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -164,11 +184,11 @@ class VwPolicy(Policy):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_repo: ModelRepository,
|
model_repo: ModelRepository,
|
||||||
vw_cmd: Sequence[str],
|
vw_cmd: List[str],
|
||||||
feature_embedder: Embedder,
|
feature_embedder: Embedder,
|
||||||
vw_logger: VwLogger,
|
vw_logger: VwLogger,
|
||||||
*args,
|
*args: Any,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.model_repo = model_repo
|
self.model_repo = model_repo
|
||||||
@ -176,7 +196,7 @@ class VwPolicy(Policy):
|
|||||||
self.feature_embedder = feature_embedder
|
self.feature_embedder = feature_embedder
|
||||||
self.vw_logger = vw_logger
|
self.vw_logger = vw_logger
|
||||||
|
|
||||||
def predict(self, event: Event) -> Any:
|
def predict(self, event: TEvent) -> Any:
|
||||||
import vowpal_wabbit_next as vw
|
import vowpal_wabbit_next as vw
|
||||||
|
|
||||||
text_parser = vw.TextFormatParser(self.workspace)
|
text_parser = vw.TextFormatParser(self.workspace)
|
||||||
@ -184,7 +204,7 @@ class VwPolicy(Policy):
|
|||||||
parse_lines(text_parser, self.feature_embedder.format(event))
|
parse_lines(text_parser, self.feature_embedder.format(event))
|
||||||
)
|
)
|
||||||
|
|
||||||
def learn(self, event: Event):
|
def learn(self, event: TEvent) -> None:
|
||||||
import vowpal_wabbit_next as vw
|
import vowpal_wabbit_next as vw
|
||||||
|
|
||||||
vw_ex = self.feature_embedder.format(event)
|
vw_ex = self.feature_embedder.format(event)
|
||||||
@ -192,19 +212,19 @@ class VwPolicy(Policy):
|
|||||||
multi_ex = parse_lines(text_parser, vw_ex)
|
multi_ex = parse_lines(text_parser, vw_ex)
|
||||||
self.workspace.learn_one(multi_ex)
|
self.workspace.learn_one(multi_ex)
|
||||||
|
|
||||||
def log(self, event: Event):
|
def log(self, event: TEvent) -> None:
|
||||||
if self.vw_logger.logging_enabled():
|
if self.vw_logger.logging_enabled():
|
||||||
vw_ex = self.feature_embedder.format(event)
|
vw_ex = self.feature_embedder.format(event)
|
||||||
self.vw_logger.log(vw_ex)
|
self.vw_logger.log(vw_ex)
|
||||||
|
|
||||||
def save(self):
|
def save(self) -> None:
|
||||||
self.model_repo.save()
|
self.model_repo.save(self.workspace)
|
||||||
|
|
||||||
|
|
||||||
class Embedder(ABC):
|
class Embedder(Generic[TEvent], ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format(self, event: Event) -> str:
|
def format(self, event: TEvent) -> str:
|
||||||
pass
|
...
|
||||||
|
|
||||||
|
|
||||||
class SelectionScorer(ABC, BaseModel):
|
class SelectionScorer(ABC, BaseModel):
|
||||||
@ -212,11 +232,11 @@ class SelectionScorer(ABC, BaseModel):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
|
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
|
||||||
pass
|
...
|
||||||
|
|
||||||
|
|
||||||
class AutoSelectionScorer(SelectionScorer, BaseModel):
|
class AutoSelectionScorer(SelectionScorer, BaseModel):
|
||||||
llm_chain: Union[LLMChain, None] = None
|
llm_chain: LLMChain
|
||||||
prompt: Union[BasePromptTemplate, None] = None
|
prompt: Union[BasePromptTemplate, None] = None
|
||||||
scoring_criteria_template_str: Optional[str] = None
|
scoring_criteria_template_str: Optional[str] = None
|
||||||
|
|
||||||
@ -243,7 +263,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
|||||||
return chat_prompt
|
return chat_prompt
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def set_prompt_and_llm_chain(cls, values):
|
def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
llm = values.get("llm")
|
llm = values.get("llm")
|
||||||
prompt = values.get("prompt")
|
prompt = values.get("prompt")
|
||||||
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
|
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
|
||||||
@ -275,7 +295,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RLChain(Chain):
|
class RLChain(Chain, Generic[TEvent]):
|
||||||
"""
|
"""
|
||||||
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
|
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
|
||||||
|
|
||||||
@ -292,7 +312,7 @@ class RLChain(Chain):
|
|||||||
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
|
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
|
||||||
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
|
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
|
||||||
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
||||||
- policy (VwPolicy): Policy used by the chain.
|
- policy (Type[VwPolicy]): Policy used by the chain.
|
||||||
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
||||||
- metrics_step (int): Step for the metrics tracker. Default is -1.
|
- metrics_step (int): Step for the metrics tracker. Default is -1.
|
||||||
|
|
||||||
@ -300,12 +320,24 @@ class RLChain(Chain):
|
|||||||
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
|
class _NoOpPolicy(Policy):
|
||||||
|
"""Placeholder policy that does nothing"""
|
||||||
|
|
||||||
|
def predict(self, event: TEvent) -> Any:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def learn(self, event: TEvent) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def log(self, event: TEvent) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
llm_chain: Chain
|
llm_chain: Chain
|
||||||
|
|
||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
prompt: BasePromptTemplate
|
prompt: BasePromptTemplate
|
||||||
selection_scorer: Union[SelectionScorer, None]
|
selection_scorer: Union[SelectionScorer, None]
|
||||||
policy: Optional[Policy]
|
active_policy: Policy = _NoOpPolicy()
|
||||||
auto_embed: bool = True
|
auto_embed: bool = True
|
||||||
selected_input_key = "rl_chain_selected"
|
selected_input_key = "rl_chain_selected"
|
||||||
selected_based_on_input_key = "rl_chain_selected_based_on"
|
selected_based_on_input_key = "rl_chain_selected_based_on"
|
||||||
@ -314,14 +346,14 @@ class RLChain(Chain):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_embedder: Embedder,
|
feature_embedder: Embedder,
|
||||||
model_save_dir="./",
|
model_save_dir: str = "./",
|
||||||
reset_model=False,
|
reset_model: bool = False,
|
||||||
vw_cmd=None,
|
vw_cmd: Optional[List[str]] = None,
|
||||||
policy=VwPolicy,
|
policy: Type[Policy] = VwPolicy,
|
||||||
vw_logs: Optional[Union[str, os.PathLike]] = None,
|
vw_logs: Optional[Union[str, os.PathLike]] = None,
|
||||||
metrics_step=-1,
|
metrics_step: int = -1,
|
||||||
*args,
|
*args: Any,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
if self.selection_scorer is None:
|
if self.selection_scorer is None:
|
||||||
@ -330,14 +362,17 @@ class RLChain(Chain):
|
|||||||
reinforcement learning will be done in the RL chain \
|
reinforcement learning will be done in the RL chain \
|
||||||
unless update_with_delayed_score is called."
|
unless update_with_delayed_score is called."
|
||||||
)
|
)
|
||||||
self.policy = policy(
|
|
||||||
model_repo=ModelRepository(
|
if isinstance(self.active_policy, RLChain._NoOpPolicy):
|
||||||
model_save_dir, with_history=True, reset=reset_model
|
self.active_policy = policy(
|
||||||
),
|
model_repo=ModelRepository(
|
||||||
vw_cmd=vw_cmd or [],
|
model_save_dir, with_history=True, reset=reset_model
|
||||||
feature_embedder=feature_embedder,
|
),
|
||||||
vw_logger=VwLogger(vw_logs),
|
vw_cmd=vw_cmd or [],
|
||||||
)
|
feature_embedder=feature_embedder,
|
||||||
|
vw_logger=VwLogger(vw_logs),
|
||||||
|
)
|
||||||
|
|
||||||
self.metrics = MetricsTracker(step=metrics_step)
|
self.metrics = MetricsTracker(step=metrics_step)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -374,29 +409,29 @@ class RLChain(Chain):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
|
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _call_after_predict_before_llm(
|
def _call_after_predict_before_llm(
|
||||||
self, inputs: Dict[str, Any], event: Event, prediction: Any
|
self, inputs: Dict[str, Any], event: TEvent, prediction: Any
|
||||||
) -> Tuple[Dict[str, Any], Event]:
|
) -> Tuple[Dict[str, Any], TEvent]:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _call_after_llm_before_scoring(
|
def _call_after_llm_before_scoring(
|
||||||
self, llm_response: str, event: Event
|
self, llm_response: str, event: TEvent
|
||||||
) -> Tuple[Dict[str, Any], Event]:
|
) -> Tuple[Dict[str, Any], TEvent]:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _call_after_scoring_before_learning(
|
def _call_after_scoring_before_learning(
|
||||||
self, event: Event, score: Optional[float]
|
self, event: TEvent, score: Optional[float]
|
||||||
) -> Event:
|
) -> TEvent:
|
||||||
pass
|
...
|
||||||
|
|
||||||
def update_with_delayed_score(
|
def update_with_delayed_score(
|
||||||
self, score: float, event: Event, force_score=False
|
self, score: float, event: TEvent, force_score: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Updates the learned policy with the score provided.
|
Updates the learned policy with the score provided.
|
||||||
@ -407,10 +442,11 @@ class RLChain(Chain):
|
|||||||
"The selection scorer is set, and force_score was not set to True. \
|
"The selection scorer is set, and force_score was not set to True. \
|
||||||
Please set force_score=True to use this function."
|
Please set force_score=True to use this function."
|
||||||
)
|
)
|
||||||
self.metrics.on_feedback(score)
|
if self.metrics:
|
||||||
|
self.metrics.on_feedback(score)
|
||||||
self._call_after_scoring_before_learning(event=event, score=score)
|
self._call_after_scoring_before_learning(event=event, score=score)
|
||||||
self.policy.learn(event=event)
|
self.active_policy.learn(event=event)
|
||||||
self.policy.log(event=event)
|
self.active_policy.log(event=event)
|
||||||
|
|
||||||
def set_auto_embed(self, auto_embed: bool) -> None:
|
def set_auto_embed(self, auto_embed: bool) -> None:
|
||||||
"""
|
"""
|
||||||
@ -422,15 +458,16 @@ class RLChain(Chain):
|
|||||||
self,
|
self,
|
||||||
inputs: Dict[str, Any],
|
inputs: Dict[str, Any],
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, Any]:
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
|
|
||||||
if self.auto_embed:
|
if self.auto_embed:
|
||||||
inputs = prepare_inputs_for_autoembed(inputs=inputs)
|
inputs = prepare_inputs_for_autoembed(inputs=inputs)
|
||||||
|
|
||||||
event = self._call_before_predict(inputs=inputs)
|
event: TEvent = self._call_before_predict(inputs=inputs)
|
||||||
prediction = self.policy.predict(event=event)
|
prediction = self.active_policy.predict(event=event)
|
||||||
self.metrics.on_decision()
|
if self.metrics:
|
||||||
|
self.metrics.on_decision()
|
||||||
|
|
||||||
next_chain_inputs, event = self._call_after_predict_before_llm(
|
next_chain_inputs, event = self._call_after_predict_before_llm(
|
||||||
inputs=inputs, event=event, prediction=prediction
|
inputs=inputs, event=event, prediction=prediction
|
||||||
@ -462,10 +499,11 @@ class RLChain(Chain):
|
|||||||
f"The selection scorer was not able to score, \
|
f"The selection scorer was not able to score, \
|
||||||
and the chain was not able to adjust to this response, error: {e}"
|
and the chain was not able to adjust to this response, error: {e}"
|
||||||
)
|
)
|
||||||
self.metrics.on_feedback(score)
|
if self.metrics:
|
||||||
|
self.metrics.on_feedback(score)
|
||||||
event = self._call_after_scoring_before_learning(score=score, event=event)
|
event = self._call_after_scoring_before_learning(score=score, event=event)
|
||||||
self.policy.learn(event=event)
|
self.active_policy.learn(event=event)
|
||||||
self.policy.log(event=event)
|
self.active_policy.log(event=event)
|
||||||
|
|
||||||
return {self.output_key: {"response": output, "selection_metadata": event}}
|
return {self.output_key: {"response": output, "selection_metadata": event}}
|
||||||
|
|
||||||
@ -473,7 +511,7 @@ class RLChain(Chain):
|
|||||||
"""
|
"""
|
||||||
This function should be called to save the state of the learned policy model.
|
This function should be called to save the state of the learned policy model.
|
||||||
"""
|
"""
|
||||||
self.policy.save()
|
self.active_policy.save()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
@ -489,7 +527,7 @@ def is_stringtype_instance(item: Any) -> bool:
|
|||||||
|
|
||||||
def embed_string_type(
|
def embed_string_type(
|
||||||
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
|
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, Union[str, List[str]]]:
|
||||||
"""Helper function to embed a string or an _Embed object."""
|
"""Helper function to embed a string or an _Embed object."""
|
||||||
join_char = ""
|
join_char = ""
|
||||||
keep_str = ""
|
keep_str = ""
|
||||||
@ -513,9 +551,9 @@ def embed_string_type(
|
|||||||
return {namespace: keep_str + join_char.join(map(str, encoded))}
|
return {namespace: keep_str + join_char.join(map(str, encoded))}
|
||||||
|
|
||||||
|
|
||||||
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
|
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
|
||||||
"""Helper function to embed a dictionary item."""
|
"""Helper function to embed a dictionary item."""
|
||||||
inner_dict = {}
|
inner_dict: Dict[str, Any] = {}
|
||||||
for ns, embed_item in item.items():
|
for ns, embed_item in item.items():
|
||||||
if isinstance(embed_item, list):
|
if isinstance(embed_item, list):
|
||||||
inner_dict[ns] = []
|
inner_dict[ns] = []
|
||||||
@ -530,7 +568,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
|
|||||||
def embed_list_type(
|
def embed_list_type(
|
||||||
item: list, model: Any, namespace: Optional[str] = None
|
item: list, model: Any, namespace: Optional[str] = None
|
||||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||||
ret_list = []
|
ret_list: List[Dict[str, Union[str, List[str]]]] = []
|
||||||
for embed_item in item:
|
for embed_item in item:
|
||||||
if isinstance(embed_item, dict):
|
if isinstance(embed_item, dict):
|
||||||
ret_list.append(embed_dict_type(embed_item, model))
|
ret_list.append(embed_dict_type(embed_item, model))
|
||||||
@ -540,9 +578,7 @@ def embed_list_type(
|
|||||||
|
|
||||||
|
|
||||||
def embed(
|
def embed(
|
||||||
to_embed: Union[
|
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
|
||||||
Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
|
|
||||||
],
|
|
||||||
model: Any,
|
model: Any,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -6,11 +6,11 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class MetricsTracker:
|
class MetricsTracker:
|
||||||
def __init__(self, step: int):
|
def __init__(self, step: int):
|
||||||
self._history = []
|
self._history: List[Dict[str, Union[int, float]]] = []
|
||||||
self._step = step
|
self._step: int = step
|
||||||
self._i = 0
|
self._i: int = 0
|
||||||
self._num = 0
|
self._num: float = 0
|
||||||
self._denom = 0
|
self._denom: float = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def score(self) -> float:
|
def score(self) -> float:
|
||||||
|
@ -4,7 +4,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Sequence, Union
|
from typing import TYPE_CHECKING, List, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import vowpal_wabbit_next as vw
|
import vowpal_wabbit_next as vw
|
||||||
@ -22,7 +22,7 @@ class ModelRepository:
|
|||||||
self.folder = Path(folder)
|
self.folder = Path(folder)
|
||||||
self.model_path = self.folder / "latest.vw"
|
self.model_path = self.folder / "latest.vw"
|
||||||
self.with_history = with_history
|
self.with_history = with_history
|
||||||
if reset and self.has_history:
|
if reset and self.has_history():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"There is non empty history which is recommended to be cleaned up"
|
"There is non empty history which is recommended to be cleaned up"
|
||||||
)
|
)
|
||||||
@ -44,7 +44,7 @@ class ModelRepository:
|
|||||||
if self.with_history: # write history
|
if self.with_history: # write history
|
||||||
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
|
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
|
||||||
|
|
||||||
def load(self, commandline: Sequence[str]) -> "vw.Workspace":
|
def load(self, commandline: List[str]) -> "vw.Workspace":
|
||||||
import vowpal_wabbit_next as vw
|
import vowpal_wabbit_next as vw
|
||||||
|
|
||||||
model_data = None
|
model_data = None
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import langchain.chains.rl_chain.base as base
|
import langchain.chains.rl_chain.base as base
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.prompts import BasePromptTemplate
|
from langchain.prompts import BasePromptTemplate
|
||||||
|
|
||||||
@ -17,7 +16,36 @@ logger = logging.getLogger(__name__)
|
|||||||
SENTINEL = object()
|
SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
class PickBestFeatureEmbedder(base.Embedder):
|
class PickBestSelected(base.Selected):
|
||||||
|
index: Optional[int]
|
||||||
|
probability: Optional[float]
|
||||||
|
score: Optional[float]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
index: Optional[int] = None,
|
||||||
|
probability: Optional[float] = None,
|
||||||
|
score: Optional[float] = None,
|
||||||
|
):
|
||||||
|
self.index = index
|
||||||
|
self.probability = probability
|
||||||
|
self.score = score
|
||||||
|
|
||||||
|
|
||||||
|
class PickBestEvent(base.Event[PickBestSelected]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
to_select_from: Dict[str, Any],
|
||||||
|
based_on: Dict[str, Any],
|
||||||
|
selected: Optional[PickBestSelected] = None,
|
||||||
|
):
|
||||||
|
super().__init__(inputs=inputs, selected=selected)
|
||||||
|
self.to_select_from = to_select_from
|
||||||
|
self.based_on = based_on
|
||||||
|
|
||||||
|
|
||||||
|
class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
||||||
"""
|
"""
|
||||||
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
|
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
|
||||||
|
|
||||||
@ -25,7 +53,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
|||||||
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
||||||
""" # noqa E501
|
""" # noqa E501
|
||||||
|
|
||||||
def __init__(self, model: Optional[Any] = None, *args, **kwargs):
|
def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
@ -35,7 +63,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
|||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def format(self, event: PickBest.Event) -> str:
|
def format(self, event: PickBestEvent) -> str:
|
||||||
"""
|
"""
|
||||||
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
|
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
|
||||||
"""
|
"""
|
||||||
@ -54,9 +82,14 @@ class PickBestFeatureEmbedder(base.Embedder):
|
|||||||
to_select_from_var_name, to_select_from = next(
|
to_select_from_var_name, to_select_from = next(
|
||||||
iter(event.to_select_from.items()), (None, None)
|
iter(event.to_select_from.items()), (None, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
action_embs = (
|
action_embs = (
|
||||||
base.embed(to_select_from, self.model, to_select_from_var_name)
|
(
|
||||||
if event.to_select_from
|
base.embed(to_select_from, self.model, to_select_from_var_name)
|
||||||
|
if event.to_select_from
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if to_select_from
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -88,7 +121,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
|||||||
return example_string[:-1]
|
return example_string[:-1]
|
||||||
|
|
||||||
|
|
||||||
class PickBest(base.RLChain):
|
class PickBest(base.RLChain[PickBestEvent]):
|
||||||
"""
|
"""
|
||||||
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
||||||
|
|
||||||
@ -116,38 +149,10 @@ class PickBest(base.RLChain):
|
|||||||
feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
|
feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
|
||||||
""" # noqa E501
|
""" # noqa E501
|
||||||
|
|
||||||
class Selected(base.Selected):
|
|
||||||
index: Optional[int]
|
|
||||||
probability: Optional[float]
|
|
||||||
score: Optional[float]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
index: Optional[int] = None,
|
|
||||||
probability: Optional[float] = None,
|
|
||||||
score: Optional[float] = None,
|
|
||||||
):
|
|
||||||
self.index = index
|
|
||||||
self.probability = probability
|
|
||||||
self.score = score
|
|
||||||
|
|
||||||
class Event(base.Event):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, Any],
|
|
||||||
to_select_from: Dict[str, Any],
|
|
||||||
based_on: Dict[str, Any],
|
|
||||||
selected: Optional[PickBest.Selected] = None,
|
|
||||||
):
|
|
||||||
super().__init__(inputs=inputs, selected=selected)
|
|
||||||
self.to_select_from = to_select_from
|
|
||||||
self.based_on = based_on
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
|
*args: Any,
|
||||||
*args,
|
**kwargs: Any,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
vw_cmd = kwargs.get("vw_cmd", [])
|
vw_cmd = kwargs.get("vw_cmd", [])
|
||||||
if not vw_cmd:
|
if not vw_cmd:
|
||||||
@ -163,14 +168,16 @@ class PickBest(base.RLChain):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If vw_cmd is specified, it must include --cb_explore_adf"
|
"If vw_cmd is specified, it must include --cb_explore_adf"
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs["vw_cmd"] = vw_cmd
|
kwargs["vw_cmd"] = vw_cmd
|
||||||
|
|
||||||
|
feature_embedder = kwargs.get("feature_embedder", None)
|
||||||
if not feature_embedder:
|
if not feature_embedder:
|
||||||
feature_embedder = PickBestFeatureEmbedder()
|
feature_embedder = PickBestFeatureEmbedder()
|
||||||
|
kwargs["feature_embedder"] = feature_embedder
|
||||||
|
|
||||||
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event:
|
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
|
||||||
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
||||||
if not actions:
|
if not actions:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -193,12 +200,15 @@ class PickBest(base.RLChain):
|
|||||||
to base the selected of ToSelectFrom on."
|
to base the selected of ToSelectFrom on."
|
||||||
)
|
)
|
||||||
|
|
||||||
event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context)
|
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _call_after_predict_before_llm(
|
def _call_after_predict_before_llm(
|
||||||
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
|
self,
|
||||||
) -> Tuple[Dict[str, Any], PickBest.Event]:
|
inputs: Dict[str, Any],
|
||||||
|
event: PickBestEvent,
|
||||||
|
prediction: List[Tuple[int, float]],
|
||||||
|
) -> Tuple[Dict[str, Any], PickBestEvent]:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
prob_sum = sum(prob for _, prob in prediction)
|
prob_sum = sum(prob for _, prob in prediction)
|
||||||
@ -208,7 +218,7 @@ class PickBest(base.RLChain):
|
|||||||
sampled_ap = prediction[sampled_index]
|
sampled_ap = prediction[sampled_index]
|
||||||
sampled_action = sampled_ap[0]
|
sampled_action = sampled_ap[0]
|
||||||
sampled_prob = sampled_ap[1]
|
sampled_prob = sampled_ap[1]
|
||||||
selected = PickBest.Selected(index=sampled_action, probability=sampled_prob)
|
selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
|
||||||
event.selected = selected
|
event.selected = selected
|
||||||
|
|
||||||
# only one key, value pair in event.to_select_from
|
# only one key, value pair in event.to_select_from
|
||||||
@ -218,23 +228,29 @@ class PickBest(base.RLChain):
|
|||||||
return next_chain_inputs, event
|
return next_chain_inputs, event
|
||||||
|
|
||||||
def _call_after_llm_before_scoring(
|
def _call_after_llm_before_scoring(
|
||||||
self, llm_response: str, event: PickBest.Event
|
self, llm_response: str, event: PickBestEvent
|
||||||
) -> Tuple[Dict[str, Any], PickBest.Event]:
|
) -> Tuple[Dict[str, Any], PickBestEvent]:
|
||||||
next_chain_inputs = event.inputs.copy()
|
next_chain_inputs = event.inputs.copy()
|
||||||
# only one key, value pair in event.to_select_from
|
# only one key, value pair in event.to_select_from
|
||||||
value = next(iter(event.to_select_from.values()))
|
value = next(iter(event.to_select_from.values()))
|
||||||
|
v = (
|
||||||
|
value[event.selected.index]
|
||||||
|
if event.selected
|
||||||
|
else event.to_select_from.values()
|
||||||
|
)
|
||||||
next_chain_inputs.update(
|
next_chain_inputs.update(
|
||||||
{
|
{
|
||||||
self.selected_based_on_input_key: str(event.based_on),
|
self.selected_based_on_input_key: str(event.based_on),
|
||||||
self.selected_input_key: value[event.selected.index],
|
self.selected_input_key: v,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return next_chain_inputs, event
|
return next_chain_inputs, event
|
||||||
|
|
||||||
def _call_after_scoring_before_learning(
|
def _call_after_scoring_before_learning(
|
||||||
self, event: PickBest.Event, score: Optional[float]
|
self, event: PickBestEvent, score: Optional[float]
|
||||||
) -> Event:
|
) -> PickBestEvent:
|
||||||
event.selected.score = score
|
if event.selected:
|
||||||
|
event.selected.score = score
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
@ -249,34 +265,20 @@ class PickBest(base.RLChain):
|
|||||||
return "rl_chain_pick_best"
|
return "rl_chain_pick_best"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_chain(
|
def from_llm(
|
||||||
cls,
|
cls: Type[PickBest],
|
||||||
llm_chain: Chain,
|
llm: BaseLanguageModel,
|
||||||
prompt: BasePromptTemplate,
|
prompt: BasePromptTemplate,
|
||||||
selection_scorer=SENTINEL,
|
selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
) -> PickBest:
|
||||||
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
if selection_scorer is SENTINEL:
|
if selection_scorer is SENTINEL:
|
||||||
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
|
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
|
||||||
|
|
||||||
return PickBest(
|
return PickBest(
|
||||||
llm_chain=llm_chain,
|
llm_chain=llm_chain,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
selection_scorer=selection_scorer,
|
selection_scorer=selection_scorer,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_llm(
|
|
||||||
cls,
|
|
||||||
llm: BaseLanguageModel,
|
|
||||||
prompt: BasePromptTemplate,
|
|
||||||
selection_scorer=SENTINEL,
|
|
||||||
**kwargs: Any,
|
|
||||||
):
|
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
|
||||||
return PickBest.from_chain(
|
|
||||||
llm_chain=llm_chain,
|
|
||||||
prompt=prompt,
|
|
||||||
selection_scorer=selection_scorer,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
@ -9,10 +9,10 @@ class VwLogger:
|
|||||||
if self.path:
|
if self.path:
|
||||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def log(self, vw_ex: str):
|
def log(self, vw_ex: str) -> None:
|
||||||
if self.path:
|
if self.path:
|
||||||
with open(self.path, "a") as f:
|
with open(self.path, "a") as f:
|
||||||
f.write(f"{vw_ex}\n\n")
|
f.write(f"{vw_ex}\n\n")
|
||||||
|
|
||||||
def logging_enabled(self):
|
def logging_enabled(self) -> bool:
|
||||||
return bool(self.path)
|
return bool(self.path)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from test_utils import MockEncoder
|
from test_utils import MockEncoder
|
||||||
|
|
||||||
@ -10,7 +12,7 @@ encoded_text = "[ e n c o d e d ] "
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def setup():
|
def setup() -> tuple:
|
||||||
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
|
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
|
||||||
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
|
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
|
||||||
|
|
||||||
@ -19,7 +21,7 @@ def setup():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_multiple_ToSelectFrom_throws():
|
def test_multiple_ToSelectFrom_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
@ -32,7 +34,7 @@ def test_multiple_ToSelectFrom_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_missing_basedOn_from_throws():
|
def test_missing_basedOn_from_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
@ -41,7 +43,7 @@ def test_missing_basedOn_from_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_ToSelectFrom_not_a_list_throws():
|
def test_ToSelectFrom_not_a_list_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||||
actions = {"actions": ["0", "1", "2"]}
|
actions = {"actions": ["0", "1", "2"]}
|
||||||
@ -53,7 +55,7 @@ def test_ToSelectFrom_not_a_list_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_update_with_delayed_score_with_auto_validator_throws():
|
def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
# this LLM returns a number so that the auto validator will return that
|
# this LLM returns a number so that the auto validator will return that
|
||||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||||
@ -75,7 +77,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_update_with_delayed_score_force():
|
def test_update_with_delayed_score_force() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
# this LLM returns a number so that the auto validator will return that
|
# this LLM returns a number so that the auto validator will return that
|
||||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||||
@ -99,7 +101,7 @@ def test_update_with_delayed_score_force():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_update_with_delayed_score():
|
def test_update_with_delayed_score() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm, prompt=PROMPT, selection_scorer=None
|
llm=llm, prompt=PROMPT, selection_scorer=None
|
||||||
@ -117,11 +119,11 @@ def test_update_with_delayed_score():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_user_defined_scorer():
|
def test_user_defined_scorer() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
|
|
||||||
class CustomSelectionScorer(rl_chain.SelectionScorer):
|
class CustomSelectionScorer(rl_chain.SelectionScorer):
|
||||||
def score_response(self, inputs, llm_response: str) -> float:
|
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
|
||||||
score = 200
|
score = 200
|
||||||
return score
|
return score
|
||||||
|
|
||||||
@ -139,7 +141,7 @@ def test_user_defined_scorer():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_embeddings():
|
def test_default_embeddings() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
@ -173,7 +175,7 @@ def test_default_embeddings():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_embeddings_off():
|
def test_default_embeddings_off() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
@ -199,7 +201,7 @@ def test_default_embeddings_off():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
@ -234,7 +236,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_no_scorer_specified():
|
def test_default_no_scorer_specified() -> None:
|
||||||
_, PROMPT = setup()
|
_, PROMPT = setup()
|
||||||
chain_llm = FakeListChatModel(responses=[100])
|
chain_llm = FakeListChatModel(responses=[100])
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
|
||||||
@ -249,7 +251,7 @@ def test_default_no_scorer_specified():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_explicitly_no_scorer():
|
def test_explicitly_no_scorer() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm, prompt=PROMPT, selection_scorer=None
|
llm=llm, prompt=PROMPT, selection_scorer=None
|
||||||
@ -265,7 +267,7 @@ def test_explicitly_no_scorer():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_auto_scorer_with_user_defined_llm():
|
def test_auto_scorer_with_user_defined_llm() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
scorer_llm = FakeListChatModel(responses=[300])
|
scorer_llm = FakeListChatModel(responses=[300])
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
@ -284,7 +286,7 @@ def test_auto_scorer_with_user_defined_llm():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_calling_chain_w_reserved_inputs_throws():
|
def test_calling_chain_w_reserved_inputs_throws() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
@ -8,10 +8,10 @@ encoded_text = "[ e n c o d e d ] "
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_missing_context_throws():
|
def test_pickbest_textembedder_missing_context_throws() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_action = {"action": ["0", "1", "2"]}
|
named_action = {"action": ["0", "1", "2"]}
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_action, based_on={}
|
inputs={}, to_select_from=named_action, based_on={}
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -19,9 +19,9 @@ def test_pickbest_textembedder_missing_context_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_missing_actions_throws():
|
def test_pickbest_textembedder_missing_actions_throws() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from={}, based_on={"context": "context"}
|
inputs={}, to_select_from={}, based_on={"context": "context"}
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -29,11 +29,11 @@ def test_pickbest_textembedder_missing_actions_throws():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_no_label_no_emb():
|
def test_pickbest_textembedder_no_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
|
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -41,12 +41,12 @@ def test_pickbest_textembedder_no_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_label_no_score_no_emb():
|
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={},
|
inputs={},
|
||||||
to_select_from=named_actions,
|
to_select_from=named_actions,
|
||||||
based_on={"context": "context"},
|
based_on={"context": "context"},
|
||||||
@ -57,14 +57,14 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_full_label_no_emb():
|
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = (
|
expected = (
|
||||||
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
||||||
)
|
)
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={},
|
inputs={},
|
||||||
to_select_from=named_actions,
|
to_select_from=named_actions,
|
||||||
based_on={"context": "context"},
|
based_on={"context": "context"},
|
||||||
@ -75,7 +75,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_full_label_w_emb():
|
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
@ -90,8 +90,8 @@ def test_pickbest_textembedder_w_full_label_w_emb():
|
|||||||
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
||||||
context = {"context": rl_chain.Embed(ctx_str_1)}
|
context = {"context": rl_chain.Embed(ctx_str_1)}
|
||||||
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -99,7 +99,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
@ -114,8 +114,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
|||||||
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
||||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
||||||
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -123,12 +123,12 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context
|
inputs={}, to_select_from=named_actions, based_on=context
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -136,13 +136,13 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -150,13 +150,13 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -164,7 +164,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
@ -186,8 +186,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
|||||||
}
|
}
|
||||||
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||||
|
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -195,7 +195,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
|
||||||
|
None
|
||||||
|
):
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
@ -219,8 +221,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
|||||||
}
|
}
|
||||||
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||||
|
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -228,7 +230,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
@ -253,8 +255,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
|||||||
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
|
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
|
||||||
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
|
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||||
|
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -262,7 +264,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
@ -290,8 +292,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
|||||||
}
|
}
|
||||||
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||||
|
|
||||||
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -299,7 +301,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_raw_features_underscored():
|
def test_raw_features_underscored() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||||
str1 = "this is a long string"
|
str1 = "this is a long string"
|
||||||
str1_underscored = str1.replace(" ", "_")
|
str1_underscored = str1.replace(" ", "_")
|
||||||
@ -315,7 +317,7 @@ def test_raw_features_underscored():
|
|||||||
expected_no_embed = (
|
expected_no_embed = (
|
||||||
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
|
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
|
||||||
)
|
)
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context
|
inputs={}, to_select_from=named_actions, based_on=context
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -325,7 +327,7 @@ def test_raw_features_underscored():
|
|||||||
named_actions = {"action": rl_chain.Embed([str1])}
|
named_actions = {"action": rl_chain.Embed([str1])}
|
||||||
context = {"context": rl_chain.Embed(ctx_str)}
|
context = {"context": rl_chain.Embed(ctx_str)}
|
||||||
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
|
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context
|
inputs={}, to_select_from=named_actions, based_on=context
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
@ -335,7 +337,7 @@ def test_raw_features_underscored():
|
|||||||
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
|
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
|
||||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
|
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
|
||||||
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
|
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
|
||||||
event = pick_best_chain.PickBest.Event(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_actions, based_on=context
|
inputs={}, to_select_from=named_actions, based_on=context
|
||||||
)
|
)
|
||||||
vw_ex_str = feature_embedder.format(event)
|
vw_ex_str = feature_embedder.format(event)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from test_utils import MockEncoder
|
from test_utils import MockEncoder
|
||||||
|
|
||||||
@ -7,13 +9,13 @@ encoded_text = "[ e n c o d e d ] "
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_context_str_no_emb():
|
def test_simple_context_str_no_emb() -> None:
|
||||||
expected = [{"a_namespace": "test"}]
|
expected = [{"a_namespace": "test"}]
|
||||||
assert base.embed("test", MockEncoder(), "a_namespace") == expected
|
assert base.embed("test", MockEncoder(), "a_namespace") == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_context_str_w_emb():
|
def test_simple_context_str_w_emb() -> None:
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = " ".join(char for char in str1)
|
encoded_str1 = " ".join(char for char in str1)
|
||||||
expected = [{"a_namespace": encoded_text + encoded_str1}]
|
expected = [{"a_namespace": encoded_text + encoded_str1}]
|
||||||
@ -28,7 +30,7 @@ def test_simple_context_str_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_context_str_w_nested_emb():
|
def test_simple_context_str_w_nested_emb() -> None:
|
||||||
# nested embeddings, innermost wins
|
# nested embeddings, innermost wins
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = " ".join(char for char in str1)
|
encoded_str1 = " ".join(char for char in str1)
|
||||||
@ -46,13 +48,13 @@ def test_simple_context_str_w_nested_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_no_emb():
|
def test_context_w_namespace_no_emb() -> None:
|
||||||
expected = [{"test_namespace": "test"}]
|
expected = [{"test_namespace": "test"}]
|
||||||
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
|
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_w_emb():
|
def test_context_w_namespace_w_emb() -> None:
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = " ".join(char for char in str1)
|
encoded_str1 = " ".join(char for char in str1)
|
||||||
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
||||||
@ -67,7 +69,7 @@ def test_context_w_namespace_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_w_emb2():
|
def test_context_w_namespace_w_emb2() -> None:
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = " ".join(char for char in str1)
|
encoded_str1 = " ".join(char for char in str1)
|
||||||
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
expected = [{"test_namespace": encoded_text + encoded_str1}]
|
||||||
@ -82,7 +84,7 @@ def test_context_w_namespace_w_emb2():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_w_some_emb():
|
def test_context_w_namespace_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
encoded_str2 = " ".join(char for char in str2)
|
encoded_str2 = " ".join(char for char in str2)
|
||||||
@ -111,16 +113,17 @@ def test_context_w_namespace_w_some_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_action_strlist_no_emb():
|
def test_simple_action_strlist_no_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
|
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
|
||||||
assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected
|
to_embed: List[Union[str, base._Embed]] = [str1, str2, str3]
|
||||||
|
assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_action_strlist_w_emb():
|
def test_simple_action_strlist_w_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -148,7 +151,7 @@ def test_simple_action_strlist_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_action_strlist_w_some_emb():
|
def test_simple_action_strlist_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -181,7 +184,7 @@ def test_simple_action_strlist_w_some_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_no_emb():
|
def test_action_w_namespace_no_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -204,7 +207,7 @@ def test_action_w_namespace_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_w_emb():
|
def test_action_w_namespace_w_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -246,7 +249,7 @@ def test_action_w_namespace_w_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_w_emb2():
|
def test_action_w_namespace_w_emb2() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -292,7 +295,7 @@ def test_action_w_namespace_w_emb2():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_w_some_emb():
|
def test_action_w_namespace_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -333,7 +336,7 @@ def test_action_w_namespace_w_some_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
|
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
@ -384,7 +387,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_one_namespace_w_list_of_features_no_emb():
|
def test_one_namespace_w_list_of_features_no_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
expected = [{"test_namespace": [str1, str2]}]
|
expected = [{"test_namespace": [str1, str2]}]
|
||||||
@ -392,7 +395,7 @@ def test_one_namespace_w_list_of_features_no_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_one_namespace_w_list_of_features_w_some_emb():
|
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
encoded_str2 = " ".join(char for char in str2)
|
encoded_str2 = " ".join(char for char in str2)
|
||||||
@ -404,24 +407,24 @@ def test_one_namespace_w_list_of_features_w_some_emb():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_nested_list_features_throws():
|
def test_nested_list_features_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
|
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_dict_in_list_throws():
|
def test_dict_in_list_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
|
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_nested_dict_throws():
|
def test_nested_dict_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
|
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_list_of_tuples_throws():
|
def test_list_of_tuples_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
|
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
class MockEncoder:
|
class MockEncoder:
|
||||||
def encode(self, to_encode):
|
def encode(self, to_encode: str) -> str:
|
||||||
return "[encoded]" + to_encode
|
return "[encoded]" + to_encode
|
||||||
|
Loading…
Reference in New Issue
Block a user