mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
move everything into experimental
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
import logging
|
||||
|
||||
from langchain_experimental.rl_chain.base import (
|
||||
AutoSelectionScorer,
|
||||
BasedOn,
|
||||
Embed,
|
||||
Embedder,
|
||||
Policy,
|
||||
SelectionScorer,
|
||||
ToSelectFrom,
|
||||
VwPolicy,
|
||||
embed,
|
||||
stringify_embedding,
|
||||
)
|
||||
from langchain_experimental.rl_chain.pick_best_chain import (
|
||||
PickBest,
|
||||
PickBestEvent,
|
||||
PickBestFeatureEmbedder,
|
||||
PickBestRandomPolicy,
|
||||
PickBestSelected,
|
||||
)
|
||||
|
||||
|
||||
def configure_logger() -> None:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
ch = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
ch.setFormatter(formatter)
|
||||
ch.setLevel(logging.INFO)
|
||||
logger.addHandler(ch)
|
||||
|
||||
|
||||
configure_logger()
|
||||
|
||||
__all__ = [
|
||||
"PickBest",
|
||||
"PickBestEvent",
|
||||
"PickBestSelected",
|
||||
"PickBestFeatureEmbedder",
|
||||
"PickBestRandomPolicy",
|
||||
"Embed",
|
||||
"BasedOn",
|
||||
"ToSelectFrom",
|
||||
"SelectionScorer",
|
||||
"AutoSelectionScorer",
|
||||
"Embedder",
|
||||
"Policy",
|
||||
"VwPolicy",
|
||||
"embed",
|
||||
"stringify_embedding",
|
||||
]
|
634
libs/experimental/langchain_experimental/rl_chain/base.py
Normal file
634
libs/experimental/langchain_experimental/rl_chain/base.py
Normal file
@@ -0,0 +1,634 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain_experimental.rl_chain.metrics import (
|
||||
MetricsTrackerAverage,
|
||||
MetricsTrackerRollingWindow,
|
||||
)
|
||||
from langchain_experimental.rl_chain.model_repository import ModelRepository
|
||||
from langchain_experimental.rl_chain.vw_logger import VwLogger
|
||||
from langchain.prompts import (
|
||||
BasePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _BasedOn:
|
||||
def __init__(self, value: Any):
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def BasedOn(anything: Any) -> _BasedOn:
|
||||
return _BasedOn(anything)
|
||||
|
||||
|
||||
class _ToSelectFrom:
|
||||
def __init__(self, value: Any):
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def ToSelectFrom(anything: Any) -> _ToSelectFrom:
|
||||
if not isinstance(anything, list):
|
||||
raise ValueError("ToSelectFrom must be a list to select from")
|
||||
return _ToSelectFrom(anything)
|
||||
|
||||
|
||||
class _Embed:
|
||||
def __init__(self, value: Any, keep: bool = False):
|
||||
self.value = value
|
||||
self.keep = keep
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def Embed(anything: Any, keep: bool = False) -> Any:
|
||||
if isinstance(anything, _ToSelectFrom):
|
||||
return ToSelectFrom(Embed(anything.value, keep=keep))
|
||||
elif isinstance(anything, _BasedOn):
|
||||
return BasedOn(Embed(anything.value, keep=keep))
|
||||
if isinstance(anything, list):
|
||||
return [Embed(v, keep=keep) for v in anything]
|
||||
elif isinstance(anything, dict):
|
||||
return {k: Embed(v, keep=keep) for k, v in anything.items()}
|
||||
elif isinstance(anything, _Embed):
|
||||
return anything
|
||||
return _Embed(anything, keep=keep)
|
||||
|
||||
|
||||
def EmbedAndKeep(anything: Any) -> Any:
|
||||
return Embed(anything, keep=True)
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def stringify_embedding(embedding: List) -> str:
|
||||
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
|
||||
|
||||
|
||||
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
|
||||
return [parser.parse_line(line) for line in input_str.split("\n")]
|
||||
|
||||
|
||||
def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
|
||||
to_select_from = {
|
||||
k: inputs[k].value
|
||||
for k in inputs.keys()
|
||||
if isinstance(inputs[k], _ToSelectFrom)
|
||||
}
|
||||
|
||||
if not to_select_from:
|
||||
raise ValueError(
|
||||
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." # noqa: E501
|
||||
)
|
||||
|
||||
based_on = {
|
||||
k: inputs[k].value if isinstance(inputs[k].value, list) else [inputs[k].value]
|
||||
for k in inputs.keys()
|
||||
if isinstance(inputs[k], _BasedOn)
|
||||
}
|
||||
|
||||
return based_on, to_select_from
|
||||
|
||||
|
||||
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
|
||||
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
|
||||
""" # noqa: E501
|
||||
|
||||
next_inputs = inputs.copy()
|
||||
for k, v in next_inputs.items():
|
||||
if isinstance(v, _ToSelectFrom) or isinstance(v, _BasedOn):
|
||||
if not isinstance(v.value, _Embed):
|
||||
next_inputs[k].value = EmbedAndKeep(v.value)
|
||||
return next_inputs
|
||||
|
||||
|
||||
# end helper functions
|
||||
|
||||
|
||||
class Selected(ABC):
|
||||
pass
|
||||
|
||||
|
||||
TSelected = TypeVar("TSelected", bound=Selected)
|
||||
|
||||
|
||||
class Event(Generic[TSelected], ABC):
|
||||
inputs: Dict[str, Any]
|
||||
selected: Optional[TSelected]
|
||||
|
||||
def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
|
||||
self.inputs = inputs
|
||||
self.selected = selected
|
||||
|
||||
|
||||
TEvent = TypeVar("TEvent", bound=Event)
|
||||
|
||||
|
||||
class Policy(Generic[TEvent], ABC):
|
||||
def __init__(self, **kwargs: Any):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, event: TEvent) -> Any:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, event: TEvent) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def log(self, event: TEvent) -> None:
|
||||
...
|
||||
|
||||
def save(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class VwPolicy(Policy):
|
||||
def __init__(
|
||||
self,
|
||||
model_repo: ModelRepository,
|
||||
vw_cmd: List[str],
|
||||
feature_embedder: Embedder,
|
||||
vw_logger: VwLogger,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_repo = model_repo
|
||||
self.workspace = self.model_repo.load(vw_cmd)
|
||||
self.feature_embedder = feature_embedder
|
||||
self.vw_logger = vw_logger
|
||||
|
||||
def predict(self, event: TEvent) -> Any:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
text_parser = vw.TextFormatParser(self.workspace)
|
||||
return self.workspace.predict_one(
|
||||
parse_lines(text_parser, self.feature_embedder.format(event))
|
||||
)
|
||||
|
||||
def learn(self, event: TEvent) -> None:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
vw_ex = self.feature_embedder.format(event)
|
||||
text_parser = vw.TextFormatParser(self.workspace)
|
||||
multi_ex = parse_lines(text_parser, vw_ex)
|
||||
self.workspace.learn_one(multi_ex)
|
||||
|
||||
def log(self, event: TEvent) -> None:
|
||||
if self.vw_logger.logging_enabled():
|
||||
vw_ex = self.feature_embedder.format(event)
|
||||
self.vw_logger.log(vw_ex)
|
||||
|
||||
def save(self) -> None:
|
||||
self.model_repo.save(self.workspace)
|
||||
|
||||
|
||||
class Embedder(Generic[TEvent], ABC):
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def format(self, event: TEvent) -> str:
|
||||
...
|
||||
|
||||
|
||||
class SelectionScorer(Generic[TEvent], ABC, BaseModel):
|
||||
"""Abstract method to grade the chosen selection or the response of the llm"""
|
||||
|
||||
@abstractmethod
|
||||
def score_response(
|
||||
self, inputs: Dict[str, Any], llm_response: str, event: TEvent
|
||||
) -> float:
|
||||
...
|
||||
|
||||
|
||||
class AutoSelectionScorer(SelectionScorer[Event], BaseModel):
|
||||
llm_chain: LLMChain
|
||||
prompt: Union[BasePromptTemplate, None] = None
|
||||
scoring_criteria_template_str: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def get_default_system_prompt() -> SystemMessagePromptTemplate:
|
||||
return SystemMessagePromptTemplate.from_template(
|
||||
"PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n \
|
||||
You are a strict judge that is called on to rank a response based on \
|
||||
given criteria. You must respond with your ranking by providing a \
|
||||
single float within the range [0, 1], 0 being very bad \
|
||||
response and 1 being very good response."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_default_prompt() -> ChatPromptTemplate:
|
||||
human_template = 'Given this based_on "{rl_chain_selected_based_on}" \
|
||||
as the most important attribute, rank how good or bad this text is: \
|
||||
"{rl_chain_selected}".'
|
||||
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
|
||||
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[default_system_prompt, human_message_prompt]
|
||||
)
|
||||
return chat_prompt
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
llm = values.get("llm")
|
||||
prompt = values.get("prompt")
|
||||
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
|
||||
if prompt is None and scoring_criteria_template_str is None:
|
||||
prompt = AutoSelectionScorer.get_default_prompt()
|
||||
elif prompt is None and scoring_criteria_template_str is not None:
|
||||
human_message_prompt = HumanMessagePromptTemplate.from_template(
|
||||
scoring_criteria_template_str
|
||||
)
|
||||
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[default_system_prompt, human_message_prompt]
|
||||
)
|
||||
values["prompt"] = prompt
|
||||
values["llm_chain"] = LLMChain(llm=llm, prompt=prompt)
|
||||
return values
|
||||
|
||||
def score_response(
|
||||
self, inputs: Dict[str, Any], llm_response: str, event: Event
|
||||
) -> float:
|
||||
ranking = self.llm_chain.predict(llm_response=llm_response, **inputs)
|
||||
ranking = ranking.strip()
|
||||
try:
|
||||
resp = float(ranking)
|
||||
return resp
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}" # noqa: E501
|
||||
)
|
||||
|
||||
|
||||
class RLChain(Chain, Generic[TEvent]):
|
||||
"""
|
||||
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
|
||||
|
||||
Attributes:
|
||||
- llm_chain (Chain): Represents the underlying Language Model chain.
|
||||
- prompt (BasePromptTemplate): The template for the base prompt.
|
||||
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
|
||||
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
|
||||
- auto_embed (bool): Determines if embedding should be automatic. Default is False.
|
||||
- metrics (Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]]): Tracker for metrics, can be set to None.
|
||||
|
||||
Initialization Attributes:
|
||||
- feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
|
||||
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
|
||||
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
|
||||
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
||||
- policy (Type[VwPolicy]): Policy used by the chain.
|
||||
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
||||
- metrics_step (int): Step for the metrics tracker. Default is -1. If set without metrics_window_size, average metrics will be tracked, otherwise rolling window metrics will be tracked.
|
||||
- metrics_window_size (int): Window size for the metrics tracker. Default is -1. If set, rolling window metrics will be tracked.
|
||||
|
||||
Notes:
|
||||
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
||||
""" # noqa: E501
|
||||
|
||||
class _NoOpPolicy(Policy):
|
||||
"""Placeholder policy that does nothing"""
|
||||
|
||||
def predict(self, event: TEvent) -> Any:
|
||||
return None
|
||||
|
||||
def learn(self, event: TEvent) -> None:
|
||||
pass
|
||||
|
||||
def log(self, event: TEvent) -> None:
|
||||
pass
|
||||
|
||||
llm_chain: Chain
|
||||
|
||||
output_key: str = "result" #: :meta private:
|
||||
prompt: BasePromptTemplate
|
||||
selection_scorer: Union[SelectionScorer, None]
|
||||
active_policy: Policy = _NoOpPolicy()
|
||||
auto_embed: bool = False
|
||||
selection_scorer_activated: bool = True
|
||||
selected_input_key = "rl_chain_selected"
|
||||
selected_based_on_input_key = "rl_chain_selected_based_on"
|
||||
metrics: Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_embedder: Embedder,
|
||||
model_save_dir: str = "./",
|
||||
reset_model: bool = False,
|
||||
vw_cmd: Optional[List[str]] = None,
|
||||
policy: Type[Policy] = VwPolicy,
|
||||
vw_logs: Optional[Union[str, os.PathLike]] = None,
|
||||
metrics_step: int = -1,
|
||||
metrics_window_size: int = -1,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.selection_scorer is None:
|
||||
logger.warning(
|
||||
"No selection scorer provided, which means that no \
|
||||
reinforcement learning will be done in the RL chain \
|
||||
unless update_with_delayed_score is called."
|
||||
)
|
||||
|
||||
if isinstance(self.active_policy, RLChain._NoOpPolicy):
|
||||
self.active_policy = policy(
|
||||
model_repo=ModelRepository(
|
||||
model_save_dir, with_history=True, reset=reset_model
|
||||
),
|
||||
vw_cmd=vw_cmd or [],
|
||||
feature_embedder=feature_embedder,
|
||||
vw_logger=VwLogger(vw_logs),
|
||||
)
|
||||
|
||||
if metrics_window_size > 0:
|
||||
self.metrics = MetricsTrackerRollingWindow(
|
||||
step=metrics_step, window_size=metrics_window_size
|
||||
)
|
||||
else:
|
||||
self.metrics = MetricsTrackerAverage(step=metrics_step)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
:meta private:
|
||||
"""
|
||||
return []
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def update_with_delayed_score(
|
||||
self, score: float, chain_response: Dict[str, Any], force_score: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Updates the learned policy with the score provided.
|
||||
Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
|
||||
""" # noqa: E501
|
||||
if self._can_use_selection_scorer() and not force_score:
|
||||
raise RuntimeError(
|
||||
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function." # noqa: E501
|
||||
)
|
||||
if self.metrics:
|
||||
self.metrics.on_feedback(score)
|
||||
event: TEvent = chain_response["selection_metadata"]
|
||||
self._call_after_scoring_before_learning(event=event, score=score)
|
||||
self.active_policy.learn(event=event)
|
||||
self.active_policy.log(event=event)
|
||||
|
||||
def deactivate_selection_scorer(self) -> None:
|
||||
"""
|
||||
Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
|
||||
""" # noqa: E501
|
||||
self.selection_scorer_activated = False
|
||||
|
||||
def activate_selection_scorer(self) -> None:
|
||||
"""
|
||||
Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
|
||||
""" # noqa: E501
|
||||
self.selection_scorer_activated = True
|
||||
|
||||
def save_progress(self) -> None:
|
||||
"""
|
||||
This function should be called to save the state of the learned policy model.
|
||||
""" # noqa: E501
|
||||
self.active_policy.save()
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
super()._validate_inputs(inputs)
|
||||
if (
|
||||
self.selected_input_key in inputs.keys()
|
||||
or self.selected_based_on_input_key in inputs.keys()
|
||||
):
|
||||
raise ValueError(
|
||||
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward." # noqa: E501
|
||||
)
|
||||
|
||||
def _can_use_selection_scorer(self) -> bool:
|
||||
"""
|
||||
Returns whether the chain can use the selection scorer to score responses or not.
|
||||
""" # noqa: E501
|
||||
return self.selection_scorer is not None and self.selection_scorer_activated
|
||||
|
||||
@abstractmethod
|
||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _call_after_predict_before_llm(
|
||||
self, inputs: Dict[str, Any], event: TEvent, prediction: Any
|
||||
) -> Tuple[Dict[str, Any], TEvent]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _call_after_llm_before_scoring(
|
||||
self, llm_response: str, event: TEvent
|
||||
) -> Tuple[Dict[str, Any], TEvent]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _call_after_scoring_before_learning(
|
||||
self, event: TEvent, score: Optional[float]
|
||||
) -> TEvent:
|
||||
...
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
|
||||
event: TEvent = self._call_before_predict(inputs=inputs)
|
||||
prediction = self.active_policy.predict(event=event)
|
||||
if self.metrics:
|
||||
self.metrics.on_decision()
|
||||
|
||||
next_chain_inputs, event = self._call_after_predict_before_llm(
|
||||
inputs=inputs, event=event, prediction=prediction
|
||||
)
|
||||
|
||||
t = self.llm_chain.run(**next_chain_inputs, callbacks=_run_manager.get_child())
|
||||
_run_manager.on_text(t, color="green", verbose=self.verbose)
|
||||
t = t.strip()
|
||||
|
||||
if self.verbose:
|
||||
_run_manager.on_text("\nCode: ", verbose=self.verbose)
|
||||
|
||||
output = t
|
||||
_run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
_run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
|
||||
next_chain_inputs, event = self._call_after_llm_before_scoring(
|
||||
llm_response=output, event=event
|
||||
)
|
||||
|
||||
score = None
|
||||
try:
|
||||
if self._can_use_selection_scorer():
|
||||
score = self.selection_scorer.score_response( # type: ignore
|
||||
inputs=next_chain_inputs, llm_response=output, event=event
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"The selection scorer was not able to score, \
|
||||
and the chain was not able to adjust to this response, error: {e}"
|
||||
)
|
||||
if self.metrics and score is not None:
|
||||
self.metrics.on_feedback(score)
|
||||
|
||||
event = self._call_after_scoring_before_learning(score=score, event=event)
|
||||
self.active_policy.learn(event=event)
|
||||
self.active_policy.log(event=event)
|
||||
|
||||
return {self.output_key: {"response": output, "selection_metadata": event}}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_personalizer_chain"
|
||||
|
||||
|
||||
def is_stringtype_instance(item: Any) -> bool:
|
||||
"""Helper function to check if an item is a string."""
|
||||
return isinstance(item, str) or (
|
||||
isinstance(item, _Embed) and isinstance(item.value, str)
|
||||
)
|
||||
|
||||
|
||||
def embed_string_type(
|
||||
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
|
||||
) -> Dict[str, Union[str, List[str]]]:
|
||||
"""Helper function to embed a string or an _Embed object."""
|
||||
keep_str = ""
|
||||
if isinstance(item, _Embed):
|
||||
encoded = stringify_embedding(model.encode(item.value))
|
||||
if item.keep:
|
||||
keep_str = item.value.replace(" ", "_") + " "
|
||||
elif isinstance(item, str):
|
||||
encoded = item.replace(" ", "_")
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(item)} for embedding")
|
||||
|
||||
if namespace is None:
|
||||
raise ValueError(
|
||||
"The default namespace must be provided when embedding a string or _Embed object." # noqa: E501
|
||||
)
|
||||
|
||||
return {namespace: keep_str + encoded}
|
||||
|
||||
|
||||
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
|
||||
"""Helper function to embed a dictionary item."""
|
||||
inner_dict: Dict = {}
|
||||
for ns, embed_item in item.items():
|
||||
if isinstance(embed_item, list):
|
||||
inner_dict[ns] = []
|
||||
for embed_list_item in embed_item:
|
||||
embedded = embed_string_type(embed_list_item, model, ns)
|
||||
inner_dict[ns].append(embedded[ns])
|
||||
else:
|
||||
inner_dict.update(embed_string_type(embed_item, model, ns))
|
||||
return inner_dict
|
||||
|
||||
|
||||
def embed_list_type(
|
||||
item: list, model: Any, namespace: Optional[str] = None
|
||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||
ret_list: List = []
|
||||
for embed_item in item:
|
||||
if isinstance(embed_item, dict):
|
||||
ret_list.append(embed_dict_type(embed_item, model))
|
||||
elif isinstance(embed_item, list):
|
||||
item_embedding = embed_list_type(embed_item, model, namespace)
|
||||
# Get the first key from the first dictionary
|
||||
first_key = next(iter(item_embedding[0]))
|
||||
# Group the values under that key
|
||||
grouping = {first_key: [item[first_key] for item in item_embedding]}
|
||||
ret_list.append(grouping)
|
||||
else:
|
||||
ret_list.append(embed_string_type(embed_item, model, namespace))
|
||||
return ret_list
|
||||
|
||||
|
||||
def embed(
|
||||
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
|
||||
model: Any,
|
||||
namespace: Optional[str] = None,
|
||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||
"""
|
||||
Embeds the actions or context using the SentenceTransformer model (or a model that has an `encode` function)
|
||||
|
||||
Attributes:
|
||||
to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries.
|
||||
namespace: (str, optional) The default namespace to use when dictionary or list of dictionaries not provided.
|
||||
model: (Any, required) The model to use for embedding
|
||||
Returns:
|
||||
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
|
||||
""" # noqa: E501
|
||||
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
|
||||
to_embed, str
|
||||
):
|
||||
return [embed_string_type(to_embed, model, namespace)]
|
||||
elif isinstance(to_embed, dict):
|
||||
return [embed_dict_type(to_embed, model)]
|
||||
elif isinstance(to_embed, list):
|
||||
return embed_list_type(to_embed, model, namespace)
|
||||
else:
|
||||
raise ValueError("Invalid input format for embedding")
|
66
libs/experimental/langchain_experimental/rl_chain/metrics.py
Normal file
66
libs/experimental/langchain_experimental/rl_chain/metrics.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Dict, List, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class MetricsTrackerAverage:
|
||||
def __init__(self, step: int):
|
||||
self.history: List[Dict[str, Union[int, float]]] = [{"step": 0, "score": 0}]
|
||||
self.step: int = step
|
||||
self.i: int = 0
|
||||
self.num: float = 0
|
||||
self.denom: float = 0
|
||||
|
||||
@property
|
||||
def score(self) -> float:
|
||||
return self.num / self.denom if self.denom > 0 else 0
|
||||
|
||||
def on_decision(self) -> None:
|
||||
self.denom += 1
|
||||
|
||||
def on_feedback(self, score: float) -> None:
|
||||
self.num += score or 0
|
||||
self.i += 1
|
||||
if self.step > 0 and self.i % self.step == 0:
|
||||
self.history.append({"step": self.i, "score": self.score})
|
||||
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
import pandas as pd
|
||||
|
||||
return pd.DataFrame(self.history)
|
||||
|
||||
|
||||
class MetricsTrackerRollingWindow:
|
||||
def __init__(self, window_size: int, step: int):
|
||||
self.history: List[Dict[str, Union[int, float]]] = [{"step": 0, "score": 0}]
|
||||
self.step: int = step
|
||||
self.i: int = 0
|
||||
self.window_size: int = window_size
|
||||
self.queue: deque = deque()
|
||||
self.sum: float = 0.0
|
||||
|
||||
@property
|
||||
def score(self) -> float:
|
||||
return self.sum / len(self.queue) if len(self.queue) > 0 else 0
|
||||
|
||||
def on_decision(self) -> None:
|
||||
pass
|
||||
|
||||
def on_feedback(self, value: float) -> None:
|
||||
self.sum += value
|
||||
self.queue.append(value)
|
||||
self.i += 1
|
||||
|
||||
if len(self.queue) > self.window_size:
|
||||
old_val = self.queue.popleft()
|
||||
self.sum -= old_val
|
||||
|
||||
if self.step > 0 and self.i % self.step == 0:
|
||||
self.history.append({"step": self.i, "score": self.sum / len(self.queue)})
|
||||
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
import pandas as pd
|
||||
|
||||
return pd.DataFrame(self.history)
|
@@ -0,0 +1,57 @@
|
||||
import datetime
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelRepository:
|
||||
def __init__(
|
||||
self,
|
||||
folder: Union[str, os.PathLike],
|
||||
with_history: bool = True,
|
||||
reset: bool = False,
|
||||
):
|
||||
self.folder = Path(folder)
|
||||
self.model_path = self.folder / "latest.vw"
|
||||
self.with_history = with_history
|
||||
if reset and self.has_history():
|
||||
logger.warning(
|
||||
"There is non empty history which is recommended to be cleaned up"
|
||||
)
|
||||
if self.model_path.exists():
|
||||
os.remove(self.model_path)
|
||||
|
||||
self.folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get_tag(self) -> str:
|
||||
return datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
|
||||
def has_history(self) -> bool:
|
||||
return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0
|
||||
|
||||
def save(self, workspace: "vw.Workspace") -> None:
|
||||
with open(self.model_path, "wb") as f:
|
||||
logger.info(f"storing rl_chain model in: {self.model_path}")
|
||||
f.write(workspace.serialize())
|
||||
if self.with_history: # write history
|
||||
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
|
||||
|
||||
def load(self, commandline: List[str]) -> "vw.Workspace":
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
model_data = None
|
||||
if self.model_path.exists():
|
||||
with open(self.model_path, "rb") as f:
|
||||
model_data = f.read()
|
||||
if model_data:
|
||||
logger.info(f"rl_chain model is loaded from: {self.model_path}")
|
||||
return vw.Workspace(commandline, model_data=model_data)
|
||||
return vw.Workspace(commandline)
|
@@ -0,0 +1,412 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import langchain_experimental.rl_chain.base as base
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# sentinel object used to distinguish between
|
||||
# user didn't supply anything or user explicitly supplied None
|
||||
SENTINEL = object()
|
||||
|
||||
|
||||
class PickBestSelected(base.Selected):
|
||||
index: Optional[int]
|
||||
probability: Optional[float]
|
||||
score: Optional[float]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: Optional[int] = None,
|
||||
probability: Optional[float] = None,
|
||||
score: Optional[float] = None,
|
||||
):
|
||||
self.index = index
|
||||
self.probability = probability
|
||||
self.score = score
|
||||
|
||||
|
||||
class PickBestEvent(base.Event[PickBestSelected]):
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
to_select_from: Dict[str, Any],
|
||||
based_on: Dict[str, Any],
|
||||
selected: Optional[PickBestSelected] = None,
|
||||
):
|
||||
super().__init__(inputs=inputs, selected=selected)
|
||||
self.to_select_from = to_select_from
|
||||
self.based_on = based_on
|
||||
|
||||
|
||||
class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
||||
"""
|
||||
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
|
||||
|
||||
Attributes:
|
||||
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
||||
""" # noqa E501
|
||||
|
||||
def __init__(
|
||||
self, auto_embed: bool, model: Optional[Any] = None, *args: Any, **kwargs: Any
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if model is None:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
model = SentenceTransformer("all-mpnet-base-v2")
|
||||
|
||||
self.model = model
|
||||
self.auto_embed = auto_embed
|
||||
|
||||
@staticmethod
|
||||
def _str(embedding: List[float]) -> str:
|
||||
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
|
||||
|
||||
def get_label(self, event: PickBestEvent) -> tuple:
|
||||
cost = None
|
||||
if event.selected:
|
||||
chosen_action = event.selected.index
|
||||
cost = (
|
||||
-1.0 * event.selected.score
|
||||
if event.selected.score is not None
|
||||
else None
|
||||
)
|
||||
prob = event.selected.probability
|
||||
return chosen_action, cost, prob
|
||||
else:
|
||||
return None, None, None
|
||||
|
||||
def get_context_and_action_embeddings(self, event: PickBestEvent) -> tuple:
|
||||
context_emb = base.embed(event.based_on, self.model) if event.based_on else None
|
||||
to_select_from_var_name, to_select_from = next(
|
||||
iter(event.to_select_from.items()), (None, None)
|
||||
)
|
||||
|
||||
action_embs = (
|
||||
(
|
||||
base.embed(to_select_from, self.model, to_select_from_var_name)
|
||||
if event.to_select_from
|
||||
else None
|
||||
)
|
||||
if to_select_from
|
||||
else None
|
||||
)
|
||||
|
||||
if not context_emb or not action_embs:
|
||||
raise ValueError(
|
||||
"Context and to_select_from must be provided in the inputs dictionary"
|
||||
)
|
||||
return context_emb, action_embs
|
||||
|
||||
def get_indexed_dot_product(self, context_emb: List, action_embs: List) -> Dict:
|
||||
import numpy as np
|
||||
|
||||
unique_contexts = set()
|
||||
for context_item in context_emb:
|
||||
for ns, ee in context_item.items():
|
||||
if isinstance(ee, list):
|
||||
for ea in ee:
|
||||
unique_contexts.add(f"{ns}={ea}")
|
||||
else:
|
||||
unique_contexts.add(f"{ns}={ee}")
|
||||
|
||||
encoded_contexts = self.model.encode(list(unique_contexts))
|
||||
context_embeddings = dict(zip(unique_contexts, encoded_contexts))
|
||||
|
||||
unique_actions = set()
|
||||
for action in action_embs:
|
||||
for ns, e in action.items():
|
||||
if isinstance(e, list):
|
||||
for ea in e:
|
||||
unique_actions.add(f"{ns}={ea}")
|
||||
else:
|
||||
unique_actions.add(f"{ns}={e}")
|
||||
|
||||
encoded_actions = self.model.encode(list(unique_actions))
|
||||
action_embeddings = dict(zip(unique_actions, encoded_actions))
|
||||
|
||||
action_matrix = np.stack([v for k, v in action_embeddings.items()])
|
||||
context_matrix = np.stack([v for k, v in context_embeddings.items()])
|
||||
dot_product_matrix = np.dot(context_matrix, action_matrix.T)
|
||||
|
||||
indexed_dot_product: Dict = {}
|
||||
|
||||
for i, context_key in enumerate(context_embeddings.keys()):
|
||||
indexed_dot_product[context_key] = {}
|
||||
for j, action_key in enumerate(action_embeddings.keys()):
|
||||
indexed_dot_product[context_key][action_key] = dot_product_matrix[i, j]
|
||||
|
||||
return indexed_dot_product
|
||||
|
||||
def format_auto_embed_on(self, event: PickBestEvent) -> str:
|
||||
chosen_action, cost, prob = self.get_label(event)
|
||||
context_emb, action_embs = self.get_context_and_action_embeddings(event)
|
||||
indexed_dot_product = self.get_indexed_dot_product(context_emb, action_embs)
|
||||
|
||||
action_lines = []
|
||||
for i, action in enumerate(action_embs):
|
||||
line_parts = []
|
||||
dot_prods = []
|
||||
if cost is not None and chosen_action == i:
|
||||
line_parts.append(f"{chosen_action}:{cost}:{prob}")
|
||||
for ns, action in action.items():
|
||||
line_parts.append(f"|{ns}")
|
||||
elements = action if isinstance(action, list) else [action]
|
||||
nsa = []
|
||||
for elem in elements:
|
||||
line_parts.append(f"{elem}")
|
||||
ns_a = f"{ns}={elem}"
|
||||
nsa.append(ns_a)
|
||||
for k, v in indexed_dot_product.items():
|
||||
dot_prods.append(v[ns_a])
|
||||
nsa_str = " ".join(nsa)
|
||||
line_parts.append(f"|# {nsa_str}")
|
||||
|
||||
line_parts.append(f"|dotprod {self._str(dot_prods)}")
|
||||
action_lines.append(" ".join(line_parts))
|
||||
|
||||
shared = []
|
||||
for item in context_emb:
|
||||
for ns, context in item.items():
|
||||
shared.append(f"|{ns}")
|
||||
elements = context if isinstance(context, list) else [context]
|
||||
nsc = []
|
||||
for elem in elements:
|
||||
shared.append(f"{elem}")
|
||||
nsc.append(f"{ns}={elem}")
|
||||
nsc_str = " ".join(nsc)
|
||||
shared.append(f"|@ {nsc_str}")
|
||||
|
||||
return "shared " + " ".join(shared) + "\n" + "\n".join(action_lines)
|
||||
|
||||
def format_auto_embed_off(self, event: PickBestEvent) -> str:
|
||||
"""
|
||||
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
|
||||
"""
|
||||
chosen_action, cost, prob = self.get_label(event)
|
||||
context_emb, action_embs = self.get_context_and_action_embeddings(event)
|
||||
|
||||
example_string = ""
|
||||
example_string += "shared "
|
||||
for context_item in context_emb:
|
||||
for ns, based_on in context_item.items():
|
||||
e = " ".join(based_on) if isinstance(based_on, list) else based_on
|
||||
example_string += f"|{ns} {e} "
|
||||
example_string += "\n"
|
||||
|
||||
for i, action in enumerate(action_embs):
|
||||
if cost is not None and chosen_action == i:
|
||||
example_string += f"{chosen_action}:{cost}:{prob} "
|
||||
for ns, action_embedding in action.items():
|
||||
e = (
|
||||
" ".join(action_embedding)
|
||||
if isinstance(action_embedding, list)
|
||||
else action_embedding
|
||||
)
|
||||
example_string += f"|{ns} {e} "
|
||||
example_string += "\n"
|
||||
# Strip the last newline
|
||||
return example_string[:-1]
|
||||
|
||||
def format(self, event: PickBestEvent) -> str:
|
||||
if self.auto_embed:
|
||||
return self.format_auto_embed_on(event)
|
||||
else:
|
||||
return self.format_auto_embed_off(event)
|
||||
|
||||
|
||||
class PickBestRandomPolicy(base.Policy[PickBestEvent]):
|
||||
def __init__(self, feature_embedder: base.Embedder, **kwargs: Any):
|
||||
self.feature_embedder = feature_embedder
|
||||
|
||||
def predict(self, event: PickBestEvent) -> List[Tuple[int, float]]:
|
||||
num_items = len(event.to_select_from)
|
||||
return [(i, 1.0 / num_items) for i in range(num_items)]
|
||||
|
||||
def learn(self, event: PickBestEvent) -> None:
|
||||
pass
|
||||
|
||||
def log(self, event: PickBestEvent) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class PickBest(base.RLChain[PickBestEvent]):
|
||||
"""
|
||||
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
||||
|
||||
Each invocation of the chain's `run()` method should be equipped with a set of potential actions (`ToSelectFrom`) and will result in the selection of a specific action based on the `BasedOn` input. This chosen action then informs the LLM (Language Model) prompt for the subsequent response generation.
|
||||
|
||||
The standard operation flow of this Chain includes:
|
||||
1. The Chain is invoked with inputs containing the `BasedOn` criteria and a list of potential actions (`ToSelectFrom`).
|
||||
2. An action is selected based on the `BasedOn` input.
|
||||
3. The LLM is called with the dynamic prompt, producing a response.
|
||||
4. If a `selection_scorer` is provided, it is used to score the selection.
|
||||
5. The internal Vowpal Wabbit model is updated with the `BasedOn` input, the chosen `ToSelectFrom` action, and the resulting score from the scorer.
|
||||
6. The final response is returned.
|
||||
|
||||
Expected input dictionary format:
|
||||
- At least one variable encapsulated within `BasedOn` to serve as the selection criteria.
|
||||
- A single list variable within `ToSelectFrom`, representing potential actions for the VW model. This list can take the form of:
|
||||
- A list of strings, e.g., `action = ToSelectFrom(["action1", "action2", "action3"])`
|
||||
- A list of list of strings e.g. `action = ToSelectFrom([["action1", "another identifier of action1"], ["action2", "another identifier of action2"]])`
|
||||
- A list of dictionaries, where each dictionary represents an action with namespace names as keys and corresponding action strings as values. For instance, `action = ToSelectFrom([{"namespace1": ["action1", "another identifier of action1"], "namespace2": "action2"}, {"namespace1": "action3", "namespace2": "action4"}])`.
|
||||
|
||||
Extends:
|
||||
RLChain
|
||||
|
||||
Attributes:
|
||||
feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
|
||||
""" # noqa E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
auto_embed = kwargs.get("auto_embed", False)
|
||||
|
||||
feature_embedder = kwargs.get("feature_embedder", None)
|
||||
if feature_embedder:
|
||||
if "auto_embed" in kwargs:
|
||||
logger.warning(
|
||||
"auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501
|
||||
)
|
||||
# turning auto_embed off for cli setting below
|
||||
auto_embed = False
|
||||
else:
|
||||
feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed)
|
||||
kwargs["feature_embedder"] = feature_embedder
|
||||
|
||||
vw_cmd = kwargs.get("vw_cmd", [])
|
||||
if vw_cmd:
|
||||
if "--cb_explore_adf" not in vw_cmd:
|
||||
raise ValueError(
|
||||
"If vw_cmd is specified, it must include --cb_explore_adf"
|
||||
)
|
||||
else:
|
||||
interactions = ["--interactions=::"]
|
||||
if auto_embed:
|
||||
interactions = [
|
||||
"--interactions=@#",
|
||||
"--ignore_linear=@",
|
||||
"--ignore_linear=#",
|
||||
]
|
||||
vw_cmd = interactions + [
|
||||
"--cb_explore_adf",
|
||||
"--coin",
|
||||
"--squarecb",
|
||||
"--quiet",
|
||||
]
|
||||
|
||||
kwargs["vw_cmd"] = vw_cmd
|
||||
logger.info(f"vw_cmd: {vw_cmd}")
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
|
||||
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
||||
if not actions:
|
||||
raise ValueError(
|
||||
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." # noqa E501
|
||||
)
|
||||
|
||||
if len(list(actions.values())) > 1:
|
||||
raise ValueError(
|
||||
"Only one variable using 'ToSelectFrom' can be provided in the inputs for the PickBest chain. Please provide only one variable containing a list to select from." # noqa E501
|
||||
)
|
||||
|
||||
if not context:
|
||||
raise ValueError(
|
||||
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on." # noqa E501
|
||||
)
|
||||
|
||||
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
|
||||
return event
|
||||
|
||||
def _call_after_predict_before_llm(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
event: PickBestEvent,
|
||||
prediction: List[Tuple[int, float]],
|
||||
) -> Tuple[Dict[str, Any], PickBestEvent]:
|
||||
import numpy as np
|
||||
|
||||
prob_sum = sum(prob for _, prob in prediction)
|
||||
probabilities = [prob / prob_sum for _, prob in prediction]
|
||||
## sample from the pmf
|
||||
sampled_index = np.random.choice(len(prediction), p=probabilities)
|
||||
sampled_ap = prediction[sampled_index]
|
||||
sampled_action = sampled_ap[0]
|
||||
sampled_prob = sampled_ap[1]
|
||||
selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
|
||||
event.selected = selected
|
||||
|
||||
# only one key, value pair in event.to_select_from
|
||||
key, value = next(iter(event.to_select_from.items()))
|
||||
next_chain_inputs = inputs.copy()
|
||||
next_chain_inputs.update({key: value[event.selected.index]})
|
||||
return next_chain_inputs, event
|
||||
|
||||
def _call_after_llm_before_scoring(
|
||||
self, llm_response: str, event: PickBestEvent
|
||||
) -> Tuple[Dict[str, Any], PickBestEvent]:
|
||||
next_chain_inputs = event.inputs.copy()
|
||||
# only one key, value pair in event.to_select_from
|
||||
value = next(iter(event.to_select_from.values()))
|
||||
v = (
|
||||
value[event.selected.index]
|
||||
if event.selected
|
||||
else event.to_select_from.values()
|
||||
)
|
||||
next_chain_inputs.update(
|
||||
{
|
||||
self.selected_based_on_input_key: str(event.based_on),
|
||||
self.selected_input_key: v,
|
||||
}
|
||||
)
|
||||
return next_chain_inputs, event
|
||||
|
||||
def _call_after_scoring_before_learning(
|
||||
self, event: PickBestEvent, score: Optional[float]
|
||||
) -> PickBestEvent:
|
||||
if event.selected:
|
||||
event.selected.score = score
|
||||
return event
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return super()._call(run_manager=run_manager, inputs=inputs)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "rl_chain_pick_best"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls: Type[PickBest],
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
|
||||
**kwargs: Any,
|
||||
) -> PickBest:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
if selection_scorer is SENTINEL:
|
||||
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
|
||||
|
||||
return PickBest(
|
||||
llm_chain=llm_chain,
|
||||
prompt=prompt,
|
||||
selection_scorer=selection_scorer,
|
||||
**kwargs,
|
||||
)
|
@@ -0,0 +1,18 @@
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class VwLogger:
|
||||
def __init__(self, path: Optional[Union[str, PathLike]]):
|
||||
self.path = Path(path) if path else None
|
||||
if self.path:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def log(self, vw_ex: str) -> None:
|
||||
if self.path:
|
||||
with open(self.path, "a") as f:
|
||||
f.write(f"{vw_ex}\n\n")
|
||||
|
||||
def logging_enabled(self) -> bool:
|
||||
return bool(self.path)
|
1101
libs/experimental/poetry.lock
generated
1101
libs/experimental/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,9 @@ langchain = ">=0.0.239"
|
||||
presidio-anonymizer = {version = "^2.2.33", optional = true}
|
||||
presidio-analyzer = {version = "^2.2.33", optional = true}
|
||||
faker = {version = "^19.3.1", optional = true}
|
||||
vowpal-wabbit-next = {version = "0.6.0", optional = true}
|
||||
sentence-transformers = {version = "^2", optional = true}
|
||||
pandas = {version = "^2.0.1", optional = true}
|
||||
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
@@ -42,6 +45,8 @@ extended_testing = [
|
||||
"presidio-anonymizer",
|
||||
"presidio-analyzer",
|
||||
"faker",
|
||||
"vowpal-wabbit-next",
|
||||
"sentence-transformers",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
@@ -0,0 +1,457 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from test_utils import MockEncoder, MockEncoderReturnsList
|
||||
|
||||
import langchain_experimental.rl_chain.base as rl_chain
|
||||
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
||||
from langchain.chat_models import FakeListChatModel
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
encoded_keyword = "[encoded]"
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def setup() -> tuple:
|
||||
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
|
||||
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
|
||||
|
||||
llm = FakeListChatModel(responses=["hey"])
|
||||
return llm, PROMPT
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_multiple_ToSelectFrom_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
another_action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_missing_basedOn_from_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(action=rl_chain.ToSelectFrom(actions))
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_ToSelectFrom_not_a_list_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = {"actions": ["0", "1", "2"]}
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == 3.0
|
||||
with pytest.raises(RuntimeError):
|
||||
chain.update_with_delayed_score(chain_response=response, score=100)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score_force() -> None:
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
auto_val_llm = FakeListChatModel(responses=["3"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == 3.0
|
||||
chain.update_with_delayed_score(
|
||||
chain_response=response, score=100, force_score=True
|
||||
)
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=None,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score is None
|
||||
chain.update_with_delayed_score(chain_response=response, score=100)
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_user_defined_scorer() -> None:
|
||||
llm, PROMPT = setup()
|
||||
|
||||
class CustomSelectionScorer(rl_chain.SelectionScorer):
|
||||
def score_response(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
llm_response: str,
|
||||
event: pick_best_chain.PickBestEvent,
|
||||
) -> float:
|
||||
score = 200
|
||||
return score
|
||||
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=CustomSelectionScorer(),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
actions = ["0", "1", "2"]
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == 200.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_everything_embedded() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
|
||||
expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """ # noqa
|
||||
|
||||
actions = [str1, str2, str3]
|
||||
|
||||
response = chain.run(
|
||||
User=rl_chain.EmbedAndKeep(rl_chain.BasedOn(ctx_str_1)),
|
||||
action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"]
|
||||
vw_str = feature_embedder.format(selection_metadata)
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_auto_embedder_is_off() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
ctx_str_1 = "context1"
|
||||
|
||||
expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """ # noqa
|
||||
|
||||
actions = [str1, str2, str3]
|
||||
|
||||
response = chain.run(
|
||||
User=pick_best_chain.base.BasedOn(ctx_str_1),
|
||||
action=pick_best_chain.base.ToSelectFrom(actions),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"]
|
||||
vw_str = feature_embedder.format(selection_metadata)
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_w_embeddings_off() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
ctx_str_1 = "context1"
|
||||
|
||||
expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """ # noqa
|
||||
|
||||
actions = [str1, str2, str3]
|
||||
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn(ctx_str_1),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"]
|
||||
vw_str = feature_embedder.format(selection_metadata)
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_w_embeddings_on() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=True, model=MockEncoderReturnsList()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
ctx_str_1 = "context1"
|
||||
dot_prod = "dotprod 0:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
|
||||
|
||||
expected = f"""shared |User {ctx_str_1} |@ User={ctx_str_1}\n|action {str1} |# action={str1} |{dot_prod}\n|action {str2} |# action={str2} |{dot_prod}""" # noqa
|
||||
|
||||
actions = [str1, str2]
|
||||
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn(ctx_str_1),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"]
|
||||
vw_str = feature_embedder.format(selection_metadata)
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=True, model=MockEncoderReturnsList()
|
||||
)
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
encoded_str2 = rl_chain.stringify_embedding([1.0, 2.0])
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding([1.0, 2.0])
|
||||
dot_prod = "dotprod 0:5.0 1:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
|
||||
|
||||
expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}""" # noqa
|
||||
|
||||
actions = [str1, rl_chain.Embed(str2)]
|
||||
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
|
||||
User2=rl_chain.BasedOn(ctx_str_2),
|
||||
action=rl_chain.ToSelectFrom(actions),
|
||||
)
|
||||
selection_metadata = response["selection_metadata"]
|
||||
vw_str = feature_embedder.format(selection_metadata)
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_no_scorer_specified() -> None:
|
||||
_, PROMPT = setup()
|
||||
chain_llm = FakeListChatModel(responses=["hey", "100"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=chain_llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
# chain llm used for both basic prompt and for scoring
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_explicitly_no_scorer() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=None,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
# chain llm used for both basic prompt and for scoring
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score is None
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_auto_scorer_with_user_defined_llm() -> None:
|
||||
llm, PROMPT = setup()
|
||||
scorer_llm = FakeListChatModel(responses=["300"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
response = chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
# chain llm used for both basic prompt and for scoring
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == 300.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_calling_chain_w_reserved_inputs_throws() -> None:
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
rl_chain_selected_based_on=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_activate_and_deactivate_scorer() -> None:
|
||||
_, PROMPT = setup()
|
||||
llm = FakeListChatModel(responses=["hey1", "hey2", "hey3"])
|
||||
scorer_llm = FakeListChatModel(responses=["300", "400"])
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
llm=llm,
|
||||
prompt=PROMPT,
|
||||
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
|
||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
),
|
||||
)
|
||||
response = chain.run(
|
||||
User=pick_best_chain.base.BasedOn("Context"),
|
||||
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
# chain llm used for both basic prompt and for scoring
|
||||
assert response["response"] == "hey1"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == 300.0
|
||||
|
||||
chain.deactivate_selection_scorer()
|
||||
response = chain.run(
|
||||
User=pick_best_chain.base.BasedOn("Context"),
|
||||
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
assert response["response"] == "hey2"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score is None
|
||||
|
||||
chain.activate_selection_scorer()
|
||||
response = chain.run(
|
||||
User=pick_best_chain.base.BasedOn("Context"),
|
||||
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
assert response["response"] == "hey3"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == 400.0
|
@@ -0,0 +1,370 @@
|
||||
import pytest
|
||||
from test_utils import MockEncoder
|
||||
|
||||
import langchain_experimental.rl_chain.base as rl_chain
|
||||
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
||||
|
||||
encoded_keyword = "[encoded]"
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_missing_context_throws() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_action = {"action": ["0", "1", "2"]}
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_action, based_on={}
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
feature_embedder.format(event)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_missing_actions_throws() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from={}, based_on={"context": "context"}
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
feature_embedder.format(event)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_no_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={},
|
||||
to_select_from=named_actions,
|
||||
based_on={"context": "context"},
|
||||
selected=selected,
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
expected = (
|
||||
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
||||
)
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={},
|
||||
to_select_from=named_actions,
|
||||
based_on={"context": "context"},
|
||||
selected=selected,
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
|
||||
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
||||
context = {"context": rl_chain.Embed(ctx_str_1)}
|
||||
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
|
||||
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
||||
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
context = {"context1": "context1", "context2": "context2"}
|
||||
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
||||
|
||||
named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
|
||||
context = {
|
||||
"context1": rl_chain.Embed(ctx_str_1),
|
||||
"context2": rl_chain.Embed(ctx_str_2),
|
||||
}
|
||||
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
|
||||
None
|
||||
):
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
||||
|
||||
named_actions = {
|
||||
"action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
|
||||
}
|
||||
context = {
|
||||
"context1": rl_chain.EmbedAndKeep(ctx_str_1),
|
||||
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
|
||||
}
|
||||
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
||||
|
||||
named_actions = {
|
||||
"action1": [
|
||||
{"a": str1, "b": rl_chain.Embed(str1)},
|
||||
str2,
|
||||
rl_chain.Embed(str3),
|
||||
]
|
||||
}
|
||||
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
|
||||
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
|
||||
str1 = "0"
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
||||
|
||||
named_actions = {
|
||||
"action1": [
|
||||
{"a": str1, "b": rl_chain.EmbedAndKeep(str1)},
|
||||
str2,
|
||||
rl_chain.EmbedAndKeep(str3),
|
||||
]
|
||||
}
|
||||
context = {
|
||||
"context1": ctx_str_1,
|
||||
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
|
||||
}
|
||||
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501
|
||||
|
||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_raw_features_underscored() -> None:
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||
auto_embed=False, model=MockEncoder()
|
||||
)
|
||||
str1 = "this is a long string"
|
||||
str1_underscored = str1.replace(" ", "_")
|
||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||
|
||||
ctx_str = "this is a long context"
|
||||
ctx_str_underscored = ctx_str.replace(" ", "_")
|
||||
encoded_ctx_str = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str))
|
||||
|
||||
# No embeddings
|
||||
named_actions = {"action": [str1]}
|
||||
context = {"context": ctx_str}
|
||||
expected_no_embed = (
|
||||
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
|
||||
)
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected_no_embed
|
||||
|
||||
# Just embeddings
|
||||
named_actions = {"action": rl_chain.Embed([str1])}
|
||||
context = {"context": rl_chain.Embed(ctx_str)}
|
||||
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected_embed
|
||||
|
||||
# Embeddings and raw features
|
||||
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
|
||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
|
||||
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501
|
||||
event = pick_best_chain.PickBestEvent(
|
||||
inputs={}, to_select_from=named_actions, based_on=context
|
||||
)
|
||||
vw_ex_str = feature_embedder.format(event)
|
||||
assert vw_ex_str == expected_embed_and_keep
|
@@ -0,0 +1,422 @@
|
||||
from typing import List, Union
|
||||
|
||||
import pytest
|
||||
from test_utils import MockEncoder
|
||||
|
||||
import langchain_experimental.rl_chain.base as base
|
||||
|
||||
encoded_keyword = "[encoded]"
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_no_emb() -> None:
|
||||
expected = [{"a_namespace": "test"}]
|
||||
assert base.embed("test", MockEncoder(), "a_namespace") == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_w_emb() -> None:
|
||||
str1 = "test"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
expected = [{"a_namespace": encoded_str1}]
|
||||
assert base.embed(base.Embed(str1), MockEncoder(), "a_namespace") == expected
|
||||
expected_embed_and_keep = [{"a_namespace": str1 + " " + encoded_str1}]
|
||||
assert (
|
||||
base.embed(base.EmbedAndKeep(str1), MockEncoder(), "a_namespace")
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_w_nested_emb() -> None:
|
||||
# nested embeddings, innermost wins
|
||||
str1 = "test"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
expected = [{"a_namespace": encoded_str1}]
|
||||
assert (
|
||||
base.embed(base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace")
|
||||
== expected
|
||||
)
|
||||
|
||||
expected2 = [{"a_namespace": str1 + " " + encoded_str1}]
|
||||
assert (
|
||||
base.embed(base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace")
|
||||
== expected2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_no_emb() -> None:
|
||||
expected = [{"test_namespace": "test"}]
|
||||
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_emb() -> None:
|
||||
str1 = "test"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
expected = [{"test_namespace": encoded_str1}]
|
||||
assert base.embed({"test_namespace": base.Embed(str1)}, MockEncoder()) == expected
|
||||
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
|
||||
assert (
|
||||
base.embed({"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder())
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_emb2() -> None:
|
||||
str1 = "test"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
expected = [{"test_namespace": encoded_str1}]
|
||||
assert base.embed(base.Embed({"test_namespace": str1}), MockEncoder()) == expected
|
||||
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
|
||||
assert (
|
||||
base.embed(base.EmbedAndKeep({"test_namespace": str1}), MockEncoder())
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
expected = [{"test_namespace": str1, "test_namespace2": encoded_str2}]
|
||||
assert (
|
||||
base.embed(
|
||||
{"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder()
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{
|
||||
"test_namespace": str1,
|
||||
"test_namespace2": str2 + " " + encoded_str2,
|
||||
}
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
{"test_namespace": str1, "test_namespace2": base.EmbedAndKeep(str2)},
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_no_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
|
||||
to_embed: List[Union[str, base._Embed]] = [str1, str2, str3]
|
||||
assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_w_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"a_namespace": encoded_str1},
|
||||
{"a_namespace": encoded_str2},
|
||||
{"a_namespace": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace")
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"a_namespace": str1 + " " + encoded_str1},
|
||||
{"a_namespace": str2 + " " + encoded_str2},
|
||||
{"a_namespace": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace")
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"a_namespace": str1},
|
||||
{"a_namespace": encoded_str2},
|
||||
{"a_namespace": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[str1, base.Embed(str2), base.Embed(str3)], MockEncoder(), "a_namespace"
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"a_namespace": str1},
|
||||
{"a_namespace": str2 + " " + encoded_str2},
|
||||
{"a_namespace": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[str1, base.EmbedAndKeep(str2), base.EmbedAndKeep(str3)],
|
||||
MockEncoder(),
|
||||
"a_namespace",
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_no_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
expected = [
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": str2},
|
||||
{"test_namespace": str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": str2},
|
||||
{"test_namespace": str3},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"test_namespace": encoded_str1},
|
||||
{"test_namespace": encoded_str2},
|
||||
{"test_namespace": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": base.Embed(str1)},
|
||||
{"test_namespace": base.Embed(str2)},
|
||||
{"test_namespace": base.Embed(str3)},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"test_namespace": str1 + " " + encoded_str1},
|
||||
{"test_namespace": str2 + " " + encoded_str2},
|
||||
{"test_namespace": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": base.EmbedAndKeep(str1)},
|
||||
{"test_namespace": base.EmbedAndKeep(str2)},
|
||||
{"test_namespace": base.EmbedAndKeep(str3)},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb2() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"test_namespace1": encoded_str1},
|
||||
{"test_namespace2": encoded_str2},
|
||||
{"test_namespace3": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
base.Embed(
|
||||
[
|
||||
{"test_namespace1": str1},
|
||||
{"test_namespace2": str2},
|
||||
{"test_namespace3": str3},
|
||||
]
|
||||
),
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"test_namespace1": str1 + " " + encoded_str1},
|
||||
{"test_namespace2": str2 + " " + encoded_str2},
|
||||
{"test_namespace3": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
base.EmbedAndKeep(
|
||||
[
|
||||
{"test_namespace1": str1},
|
||||
{"test_namespace2": str2},
|
||||
{"test_namespace3": str3},
|
||||
]
|
||||
),
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": encoded_str2},
|
||||
{"test_namespace": encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": base.Embed(str2)},
|
||||
{"test_namespace": base.Embed(str3)},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": str2 + " " + encoded_str2},
|
||||
{"test_namespace": str3 + " " + encoded_str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": str1},
|
||||
{"test_namespace": base.EmbedAndKeep(str2)},
|
||||
{"test_namespace": base.EmbedAndKeep(str3)},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
str3 = "test3"
|
||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
||||
expected = [
|
||||
{"test_namespace": encoded_str1, "test_namespace2": str1},
|
||||
{"test_namespace": encoded_str2, "test_namespace2": str2},
|
||||
{"test_namespace": encoded_str3, "test_namespace2": str3},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": base.Embed(str1), "test_namespace2": str1},
|
||||
{"test_namespace": base.Embed(str2), "test_namespace2": str2},
|
||||
{"test_namespace": base.Embed(str3), "test_namespace2": str3},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
expected_embed_and_keep = [
|
||||
{
|
||||
"test_namespace": str1 + " " + encoded_str1,
|
||||
"test_namespace2": str1,
|
||||
},
|
||||
{
|
||||
"test_namespace": str2 + " " + encoded_str2,
|
||||
"test_namespace2": str2,
|
||||
},
|
||||
{
|
||||
"test_namespace": str3 + " " + encoded_str3,
|
||||
"test_namespace2": str3,
|
||||
},
|
||||
]
|
||||
assert (
|
||||
base.embed(
|
||||
[
|
||||
{"test_namespace": base.EmbedAndKeep(str1), "test_namespace2": str1},
|
||||
{"test_namespace": base.EmbedAndKeep(str2), "test_namespace2": str2},
|
||||
{"test_namespace": base.EmbedAndKeep(str3), "test_namespace2": str3},
|
||||
],
|
||||
MockEncoder(),
|
||||
)
|
||||
== expected_embed_and_keep
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_one_namespace_w_list_of_features_no_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
expected = [{"test_namespace": [str1, str2]}]
|
||||
assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
||||
expected = [{"test_namespace": [str1, encoded_str2]}]
|
||||
assert (
|
||||
base.embed({"test_namespace": [str1, base.Embed(str2)]}, MockEncoder())
|
||||
== expected
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_nested_list_features_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_dict_in_list_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_nested_dict_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_list_of_tuples_throws() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
|
15
libs/experimental/tests/unit_tests/rl_chain/test_utils.py
Normal file
15
libs/experimental/tests/unit_tests/rl_chain/test_utils.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Any, List
|
||||
|
||||
|
||||
class MockEncoder:
|
||||
def encode(self, to_encode: str) -> str:
|
||||
return "[encoded]" + to_encode
|
||||
|
||||
|
||||
class MockEncoderReturnsList:
|
||||
def encode(self, to_encode: Any) -> List:
|
||||
if isinstance(to_encode, str):
|
||||
return [1.0, 2.0]
|
||||
elif isinstance(to_encode, List):
|
||||
return [[1.0, 2.0] for _ in range(len(to_encode))]
|
||||
raise ValueError("Invalid input type for unit test")
|
Reference in New Issue
Block a user