mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 12:48:12 +00:00
Merge pull request #10 from VowpalWabbit/dot_prods_auto_embed
Dot prods auto embed
This commit is contained in:
commit
b59e2b5afa
@ -15,6 +15,7 @@ from langchain.chains.rl_chain.base import (
|
|||||||
from langchain.chains.rl_chain.pick_best_chain import (
|
from langchain.chains.rl_chain.pick_best_chain import (
|
||||||
PickBest,
|
PickBest,
|
||||||
PickBestEvent,
|
PickBestEvent,
|
||||||
|
PickBestFeatureEmbedder,
|
||||||
PickBestSelected,
|
PickBestSelected,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -37,6 +38,7 @@ __all__ = [
|
|||||||
"PickBest",
|
"PickBest",
|
||||||
"PickBestEvent",
|
"PickBestEvent",
|
||||||
"PickBestSelected",
|
"PickBestSelected",
|
||||||
|
"PickBestFeatureEmbedder",
|
||||||
"Embed",
|
"Embed",
|
||||||
"BasedOn",
|
"BasedOn",
|
||||||
"ToSelectFrom",
|
"ToSelectFrom",
|
||||||
|
@ -118,8 +118,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]
|
|||||||
|
|
||||||
if not to_select_from:
|
if not to_select_from:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No variables using 'ToSelectFrom' found in the inputs. \
|
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." # noqa: E501
|
||||||
Please include at least one variable containing a list to select from."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
based_on = {
|
based_on = {
|
||||||
@ -229,6 +228,9 @@ class VwPolicy(Policy):
|
|||||||
|
|
||||||
|
|
||||||
class Embedder(Generic[TEvent], ABC):
|
class Embedder(Generic[TEvent], ABC):
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format(self, event: TEvent) -> str:
|
def format(self, event: TEvent) -> str:
|
||||||
...
|
...
|
||||||
@ -300,9 +302,7 @@ class AutoSelectionScorer(SelectionScorer[Event], BaseModel):
|
|||||||
return resp
|
return resp
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"The auto selection scorer did not manage to score the response, \
|
f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}" # noqa: E501
|
||||||
there is always the option to try again or tweak the reward prompt.\
|
|
||||||
Error: {e}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -316,7 +316,7 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
|
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
|
||||||
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
|
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
|
||||||
- auto_embed (bool): Determines if embedding should be automatic. Default is False.
|
- auto_embed (bool): Determines if embedding should be automatic. Default is False.
|
||||||
- metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None.
|
- metrics (Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]]): Tracker for metrics, can be set to None.
|
||||||
|
|
||||||
Initialization Attributes:
|
Initialization Attributes:
|
||||||
- feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
|
- feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
|
||||||
@ -325,7 +325,8 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
||||||
- policy (Type[VwPolicy]): Policy used by the chain.
|
- policy (Type[VwPolicy]): Policy used by the chain.
|
||||||
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
||||||
- metrics_step (int): Step for the metrics tracker. Default is -1.
|
- metrics_step (int): Step for the metrics tracker. Default is -1. If set without metrics_window_size, average metrics will be tracked, otherwise rolling window metrics will be tracked.
|
||||||
|
- metrics_window_size (int): Window size for the metrics tracker. Default is -1. If set, rolling window metrics will be tracked.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
||||||
@ -423,8 +424,7 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
if self._can_use_selection_scorer() and not force_score:
|
if self._can_use_selection_scorer() and not force_score:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"The selection scorer is set, and force_score was not set to True. \
|
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function." # noqa: E501
|
||||||
Please set force_score=True to use this function."
|
|
||||||
)
|
)
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.on_feedback(score)
|
self.metrics.on_feedback(score)
|
||||||
@ -458,9 +458,7 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
or self.selected_based_on_input_key in inputs.keys()
|
or self.selected_based_on_input_key in inputs.keys()
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The rl chain does not accept '{self.selected_input_key}' \
|
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward." # noqa: E501
|
||||||
or '{self.selected_based_on_input_key}' as input keys, \
|
|
||||||
they are reserved for internal use during auto reward."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _can_use_selection_scorer(self) -> bool:
|
def _can_use_selection_scorer(self) -> bool:
|
||||||
@ -498,9 +496,6 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
|
|
||||||
if self.auto_embed:
|
|
||||||
inputs = prepare_inputs_for_autoembed(inputs=inputs)
|
|
||||||
|
|
||||||
event: TEvent = self._call_before_predict(inputs=inputs)
|
event: TEvent = self._call_before_predict(inputs=inputs)
|
||||||
prediction = self.active_policy.predict(event=event)
|
prediction = self.active_policy.predict(event=event)
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
@ -573,8 +568,7 @@ def embed_string_type(
|
|||||||
|
|
||||||
if namespace is None:
|
if namespace is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The default namespace must be \
|
"The default namespace must be provided when embedding a string or _Embed object." # noqa: E501
|
||||||
provided when embedding a string or _Embed object."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {namespace: keep_str + encoded}
|
return {namespace: keep_str + encoded}
|
||||||
|
@ -53,21 +53,24 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
|||||||
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
||||||
""" # noqa E501
|
""" # noqa E501
|
||||||
|
|
||||||
def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
|
def __init__(
|
||||||
|
self, auto_embed: bool, model: Optional[Any] = None, *args: Any, **kwargs: Any
|
||||||
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
model = SentenceTransformer("bert-base-nli-mean-tokens")
|
model = SentenceTransformer("all-mpnet-base-v2")
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.auto_embed = auto_embed
|
||||||
|
|
||||||
def format(self, event: PickBestEvent) -> str:
|
@staticmethod
|
||||||
"""
|
def _str(embedding: List[float]) -> str:
|
||||||
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
|
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
|
||||||
"""
|
|
||||||
|
|
||||||
|
def get_label(self, event: PickBestEvent) -> tuple:
|
||||||
cost = None
|
cost = None
|
||||||
if event.selected:
|
if event.selected:
|
||||||
chosen_action = event.selected.index
|
chosen_action = event.selected.index
|
||||||
@ -77,7 +80,11 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
prob = event.selected.probability
|
prob = event.selected.probability
|
||||||
|
return chosen_action, cost, prob
|
||||||
|
else:
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
def get_context_and_action_embeddings(self, event: PickBestEvent) -> tuple:
|
||||||
context_emb = base.embed(event.based_on, self.model) if event.based_on else None
|
context_emb = base.embed(event.based_on, self.model) if event.based_on else None
|
||||||
to_select_from_var_name, to_select_from = next(
|
to_select_from_var_name, to_select_from = next(
|
||||||
iter(event.to_select_from.items()), (None, None)
|
iter(event.to_select_from.items()), (None, None)
|
||||||
@ -97,6 +104,95 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Context and to_select_from must be provided in the inputs dictionary"
|
"Context and to_select_from must be provided in the inputs dictionary"
|
||||||
)
|
)
|
||||||
|
return context_emb, action_embs
|
||||||
|
|
||||||
|
def get_indexed_dot_product(self, context_emb: List, action_embs: List) -> Dict:
    """Compute the dot product of every context key with every action key.

    Each item of ``context_emb`` / ``action_embs`` is a mapping of
    namespace -> value (or list of values). Every distinct
    ``"<namespace>=<value>"`` string is encoded once with
    ``self.model.encode`` and the pairwise dot products are returned.

    Args:
        context_emb: List of ``{namespace: value | [values]}`` dicts.
        action_embs: List of ``{namespace: value | [values]}`` dicts.

    Returns:
        ``{context_key: {action_key: dot_product}}`` where both keys are
        ``"<namespace>=<value>"`` strings, in first-seen order. When there
        are no actions every context maps to an empty dict; when there are
        no contexts the result is empty.
    """
    import numpy as np

    def collect_keys(items: List) -> List[str]:
        # dict.fromkeys-style de-duplication keeps first-seen order. A set
        # would make the key order -- and therefore the positional dotprod
        # feature indices built from this mapping -- depend on string hash
        # randomization, i.e. differ from one process to the next.
        keys: Dict[str, None] = {}
        for item in items:
            for ns, value in item.items():
                values = value if isinstance(value, list) else [value]
                for single in values:
                    keys[f"{ns}={single}"] = None
        return list(keys)

    context_keys = collect_keys(context_emb)
    action_keys = collect_keys(action_embs)
    if not context_keys or not action_keys:
        # np.stack rejects an empty sequence, and there is nothing to multiply.
        return {context_key: {} for context_key in context_keys}

    # encode() is called once per side; rows come back in the order passed in.
    context_matrix = np.stack(list(self.model.encode(context_keys)))
    action_matrix = np.stack(list(self.model.encode(action_keys)))
    dot_product_matrix = np.dot(context_matrix, action_matrix.T)

    indexed_dot_product: Dict = {}
    for i, context_key in enumerate(context_keys):
        indexed_dot_product[context_key] = {
            action_key: dot_product_matrix[i, j]
            for j, action_key in enumerate(action_keys)
        }
    return indexed_dot_product
|
||||||
|
|
||||||
|
def format_auto_embed_on(self, event: PickBestEvent) -> str:
    """Render ``event`` as a VW example that carries dot-product features.

    The output is one ``shared`` line followed by one line per action.
    Each namespace is written as ``|<ns> <values...>`` followed by a
    ``|@`` (context) or ``|#`` (action) namespace listing the
    ``<ns>=<value>`` keys. Every action line ends with a ``|dotprod``
    namespace holding the dot products returned by
    ``self.get_indexed_dot_product`` for that action's keys.

    Args:
        event: The event to format; it is passed to ``self.get_label`` and
            ``self.get_context_and_action_embeddings``.

    Returns:
        The example string, lines joined with ``"\\n"`` and no trailing
        newline.
    """
    chosen_action, cost, prob = self.get_label(event)
    context_emb, action_embs = self.get_context_and_action_embeddings(event)
    indexed_dot_product = self.get_indexed_dot_product(context_emb, action_embs)

    action_lines = []
    for i, action_namespaces in enumerate(action_embs):
        line_parts = []
        dot_prods = []
        # Only the chosen action carries the action:cost:probability label.
        if cost is not None and chosen_action == i:
            line_parts.append(f"{chosen_action}:{cost}:{prob}")
        for ns, value in action_namespaces.items():
            line_parts.append(f"|{ns}")
            elements = value if isinstance(value, list) else [value]
            action_keys = []
            for elem in elements:
                line_parts.append(f"{elem}")
                action_key = f"{ns}={elem}"
                action_keys.append(action_key)
                # One dot product per context key, in the mapping's order.
                for per_action in indexed_dot_product.values():
                    dot_prods.append(per_action[action_key])
            action_keys_str = " ".join(action_keys)
            line_parts.append(f"|# {action_keys_str}")

        line_parts.append(f"|dotprod {self._str(dot_prods)}")
        action_lines.append(" ".join(line_parts))

    shared = []
    for item in context_emb:
        for ns, value in item.items():
            shared.append(f"|{ns}")
            elements = value if isinstance(value, list) else [value]
            context_keys = []
            for elem in elements:
                shared.append(f"{elem}")
                context_keys.append(f"{ns}={elem}")
            context_keys_str = " ".join(context_keys)
            shared.append(f"|@ {context_keys_str}")

    return "shared " + " ".join(shared) + "\n" + "\n".join(action_lines)
|
||||||
|
|
||||||
|
def format_auto_embed_off(self, event: PickBestEvent) -> str:
|
||||||
|
"""
|
||||||
|
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
|
||||||
|
"""
|
||||||
|
chosen_action, cost, prob = self.get_label(event)
|
||||||
|
context_emb, action_embs = self.get_context_and_action_embeddings(event)
|
||||||
|
|
||||||
example_string = ""
|
example_string = ""
|
||||||
example_string += "shared "
|
example_string += "shared "
|
||||||
@ -120,6 +216,12 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
|||||||
# Strip the last newline
|
# Strip the last newline
|
||||||
return example_string[:-1]
|
return example_string[:-1]
|
||||||
|
|
||||||
|
def format(self, event: PickBestEvent) -> str:
    """Format ``event`` for VW, choosing the layout from ``self.auto_embed``.

    Delegates to ``format_auto_embed_on`` when ``self.auto_embed`` is truthy
    and to ``format_auto_embed_off`` otherwise, returning whatever the chosen
    formatter returns.
    """
    formatter = (
        self.format_auto_embed_on if self.auto_embed else self.format_auto_embed_off
    )
    return formatter(event)
|
||||||
|
|
||||||
|
|
||||||
class PickBest(base.RLChain[PickBestEvent]):
|
class PickBest(base.RLChain[PickBestEvent]):
|
||||||
"""
|
"""
|
||||||
@ -154,26 +256,42 @@ class PickBest(base.RLChain[PickBestEvent]):
|
|||||||
*args: Any,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
vw_cmd = kwargs.get("vw_cmd", [])
|
auto_embed = kwargs.get("auto_embed", False)
|
||||||
if not vw_cmd:
|
|
||||||
vw_cmd = [
|
feature_embedder = kwargs.get("feature_embedder", None)
|
||||||
"--cb_explore_adf",
|
if feature_embedder:
|
||||||
"--quiet",
|
if "auto_embed" in kwargs:
|
||||||
"--interactions=::",
|
logger.warning(
|
||||||
"--coin",
|
"auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501
|
||||||
"--squarecb",
|
)
|
||||||
]
|
# turning auto_embed off for cli setting below
|
||||||
|
auto_embed = False
|
||||||
else:
|
else:
|
||||||
|
feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed)
|
||||||
|
kwargs["feature_embedder"] = feature_embedder
|
||||||
|
|
||||||
|
vw_cmd = kwargs.get("vw_cmd", [])
|
||||||
|
if vw_cmd:
|
||||||
if "--cb_explore_adf" not in vw_cmd:
|
if "--cb_explore_adf" not in vw_cmd:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If vw_cmd is specified, it must include --cb_explore_adf"
|
"If vw_cmd is specified, it must include --cb_explore_adf"
|
||||||
)
|
)
|
||||||
kwargs["vw_cmd"] = vw_cmd
|
else:
|
||||||
|
interactions = ["--interactions=::"]
|
||||||
|
if auto_embed:
|
||||||
|
interactions = [
|
||||||
|
"--interactions=@#",
|
||||||
|
"--ignore_linear=@",
|
||||||
|
"--ignore_linear=#",
|
||||||
|
]
|
||||||
|
vw_cmd = interactions + [
|
||||||
|
"--cb_explore_adf",
|
||||||
|
"--coin",
|
||||||
|
"--squarecb",
|
||||||
|
"--quiet",
|
||||||
|
]
|
||||||
|
|
||||||
feature_embedder = kwargs.get("feature_embedder", None)
|
kwargs["vw_cmd"] = vw_cmd
|
||||||
if not feature_embedder:
|
|
||||||
feature_embedder = PickBestFeatureEmbedder()
|
|
||||||
kwargs["feature_embedder"] = feature_embedder
|
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@ -181,23 +299,17 @@ class PickBest(base.RLChain[PickBestEvent]):
|
|||||||
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
||||||
if not actions:
|
if not actions:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No variables using 'ToSelectFrom' found in the inputs. \
|
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." # noqa E501
|
||||||
Please include at least one variable containing \
|
|
||||||
a list to select from."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(list(actions.values())) > 1:
|
if len(list(actions.values())) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Only one variable using 'ToSelectFrom' can be provided in the inputs \
|
"Only one variable using 'ToSelectFrom' can be provided in the inputs for the PickBest chain. Please provide only one variable containing a list to select from." # noqa E501
|
||||||
for the PickBest chain. Please provide only one variable \
|
|
||||||
containing a list to select from."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not context:
|
if not context:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No variables using 'BasedOn' found in the inputs. \
|
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on." # noqa E501
|
||||||
Please include at least one variable containing information \
|
|
||||||
to base the selected of ToSelectFrom on."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
|
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from test_utils import MockEncoder
|
from test_utils import MockEncoder, MockEncoderReturnsList
|
||||||
|
|
||||||
import langchain.chains.rl_chain.base as rl_chain
|
import langchain.chains.rl_chain.base as rl_chain
|
||||||
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
|
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
|
||||||
@ -26,7 +26,9 @@ def test_multiple_ToSelectFrom_throws() -> None:
|
|||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -43,7 +45,9 @@ def test_missing_basedOn_from_throws() -> None:
|
|||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -56,7 +60,9 @@ def test_ToSelectFrom_not_a_list_throws() -> None:
|
|||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
actions = {"actions": ["0", "1", "2"]}
|
actions = {"actions": ["0", "1", "2"]}
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -75,7 +81,9 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
@ -98,7 +106,9 @@ def test_update_with_delayed_score_force() -> None:
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
@ -121,7 +131,9 @@ def test_update_with_delayed_score() -> None:
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
selection_scorer=None,
|
selection_scorer=None,
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
@ -153,7 +165,9 @@ def test_user_defined_scorer() -> None:
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
selection_scorer=CustomSelectionScorer(),
|
selection_scorer=CustomSelectionScorer(),
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
actions = ["0", "1", "2"]
|
actions = ["0", "1", "2"]
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
@ -166,11 +180,13 @@ def test_user_defined_scorer() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_auto_embeddings_on() -> None:
|
def test_everything_embedded() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
|
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
|
||||||
)
|
)
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
@ -189,8 +205,8 @@ def test_auto_embeddings_on() -> None:
|
|||||||
actions = [str1, str2, str3]
|
actions = [str1, str2, str3]
|
||||||
|
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
User=rl_chain.BasedOn(ctx_str_1),
|
User=rl_chain.EmbedAndKeep(rl_chain.BasedOn(ctx_str_1)),
|
||||||
action=rl_chain.ToSelectFrom(actions),
|
action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
|
||||||
)
|
)
|
||||||
selection_metadata = response["selection_metadata"]
|
selection_metadata = response["selection_metadata"]
|
||||||
vw_str = feature_embedder.format(selection_metadata)
|
vw_str = feature_embedder.format(selection_metadata)
|
||||||
@ -200,7 +216,9 @@ def test_auto_embeddings_on() -> None:
|
|||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_auto_embedder_is_off() -> None:
|
def test_default_auto_embedder_is_off() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
|
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
|
||||||
)
|
)
|
||||||
@ -224,9 +242,11 @@ def test_default_auto_embedder_is_off() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_embeddings_off() -> None:
|
def test_default_w_embeddings_off() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
|
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
|
||||||
)
|
)
|
||||||
@ -250,29 +270,54 @@ def test_default_embeddings_off() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
def test_default_w_embeddings_on() -> None:
|
||||||
llm, PROMPT = setup()
|
llm, PROMPT = setup()
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=True, model=MockEncoderReturnsList()
|
||||||
|
)
|
||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
|
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
|
||||||
)
|
)
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
ctx_str_1 = "context1"
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
dot_prod = "dotprod 0:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
|
||||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
|
||||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
|
||||||
|
|
||||||
|
expected = f"""shared |User {ctx_str_1} |@ User={ctx_str_1}\n|action {str1} |# action={str1} |{dot_prod}\n|action {str2} |# action={str2} |{dot_prod}""" # noqa
|
||||||
|
|
||||||
|
actions = [str1, str2]
|
||||||
|
|
||||||
|
response = chain.run(
|
||||||
|
User=rl_chain.BasedOn(ctx_str_1),
|
||||||
|
action=rl_chain.ToSelectFrom(actions),
|
||||||
|
)
|
||||||
|
selection_metadata = response["selection_metadata"]
|
||||||
|
vw_str = feature_embedder.format(selection_metadata)
|
||||||
|
assert vw_str == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
|
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
||||||
|
llm, PROMPT = setup()
|
||||||
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=True, model=MockEncoderReturnsList()
|
||||||
|
)
|
||||||
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
|
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
|
||||||
|
)
|
||||||
|
|
||||||
|
str1 = "0"
|
||||||
|
str2 = "1"
|
||||||
|
encoded_str2 = rl_chain.stringify_embedding([1.0, 2.0])
|
||||||
ctx_str_1 = "context1"
|
ctx_str_1 = "context1"
|
||||||
ctx_str_2 = "context2"
|
ctx_str_2 = "context2"
|
||||||
|
encoded_ctx_str_1 = rl_chain.stringify_embedding([1.0, 2.0])
|
||||||
|
dot_prod = "dotprod 0:5.0 1:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
|
||||||
|
|
||||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}""" # noqa
|
||||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
|
||||||
|
|
||||||
expected = f"""shared |User {encoded_ctx_str_1} |User2 {ctx_str_2 + " " + encoded_ctx_str_2} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {encoded_str3} """ # noqa
|
actions = [str1, rl_chain.Embed(str2)]
|
||||||
|
|
||||||
actions = [str1, str2, rl_chain.Embed(str3)]
|
|
||||||
|
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
|
User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
|
||||||
@ -291,7 +336,9 @@ def test_default_no_scorer_specified() -> None:
|
|||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=chain_llm,
|
llm=chain_llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
User=rl_chain.BasedOn("Context"),
|
User=rl_chain.BasedOn("Context"),
|
||||||
@ -310,7 +357,9 @@ def test_explicitly_no_scorer() -> None:
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
selection_scorer=None,
|
selection_scorer=None,
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
User=rl_chain.BasedOn("Context"),
|
User=rl_chain.BasedOn("Context"),
|
||||||
@ -330,7 +379,9 @@ def test_auto_scorer_with_user_defined_llm() -> None:
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
|
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
User=rl_chain.BasedOn("Context"),
|
User=rl_chain.BasedOn("Context"),
|
||||||
@ -348,7 +399,9 @@ def test_calling_chain_w_reserved_inputs_throws() -> None:
|
|||||||
chain = pick_best_chain.PickBest.from_llm(
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chain.run(
|
chain.run(
|
||||||
@ -371,7 +424,9 @@ def test_activate_and_deactivate_scorer() -> None:
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=PROMPT,
|
prompt=PROMPT,
|
||||||
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
|
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
|
||||||
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
User=pick_best_chain.base.BasedOn("Context"),
|
User=pick_best_chain.base.BasedOn("Context"),
|
||||||
|
@ -9,7 +9,9 @@ encoded_keyword = "[encoded]"
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_missing_context_throws() -> None:
|
def test_pickbest_textembedder_missing_context_throws() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
named_action = {"action": ["0", "1", "2"]}
|
named_action = {"action": ["0", "1", "2"]}
|
||||||
event = pick_best_chain.PickBestEvent(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from=named_action, based_on={}
|
inputs={}, to_select_from=named_action, based_on={}
|
||||||
@ -20,7 +22,9 @@ def test_pickbest_textembedder_missing_context_throws() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_missing_actions_throws() -> None:
|
def test_pickbest_textembedder_missing_actions_throws() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
event = pick_best_chain.PickBestEvent(
|
event = pick_best_chain.PickBestEvent(
|
||||||
inputs={}, to_select_from={}, based_on={"context": "context"}
|
inputs={}, to_select_from={}, based_on={"context": "context"}
|
||||||
)
|
)
|
||||||
@ -30,7 +34,9 @@ def test_pickbest_textembedder_missing_actions_throws() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_no_label_no_emb() -> None:
|
def test_pickbest_textembedder_no_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||||
event = pick_best_chain.PickBestEvent(
|
event = pick_best_chain.PickBestEvent(
|
||||||
@ -42,7 +48,9 @@ def test_pickbest_textembedder_no_label_no_emb() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
|
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
|
||||||
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
|
||||||
@ -58,7 +66,9 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
|
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
named_actions = {"action1": ["0", "1", "2"]}
|
named_actions = {"action1": ["0", "1", "2"]}
|
||||||
expected = (
|
expected = (
|
||||||
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
|
||||||
@ -76,7 +86,9 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
|
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
str3 = "2"
|
||||||
@ -100,7 +112,9 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
|
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
str3 = "2"
|
||||||
@ -124,7 +138,9 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
|
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||||
@ -137,7 +153,9 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
|
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||||
@ -151,7 +169,9 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
|
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||||
context = {"context1": "context1", "context2": "context2"}
|
context = {"context1": "context1", "context2": "context2"}
|
||||||
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
|
||||||
@ -165,7 +185,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
@ -198,7 +220,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None
|
|||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
|
||||||
None
|
None
|
||||||
):
|
):
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
@ -231,7 +255,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
@ -263,7 +289,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
|
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
@ -298,7 +326,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep()
|
|||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_raw_features_underscored() -> None:
|
def test_raw_features_underscored() -> None:
|
||||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
|
||||||
|
auto_embed=False, model=MockEncoder()
|
||||||
|
)
|
||||||
str1 = "this is a long string"
|
str1 = "this is a long string"
|
||||||
str1_underscored = str1.replace(" ", "_")
|
str1_underscored = str1.replace(" ", "_")
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
||||||
|
@ -1,3 +1,15 @@
|
|||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
|
||||||
class MockEncoder:
    """Deterministic stand-in for an embedding model, used by the unit tests."""

    def encode(self, to_encode: str) -> str:
        """Return ``to_encode`` prefixed with the literal marker ``"[encoded]"``.

        Args:
            to_encode: The text to "embed".

        Returns:
            The string ``"[encoded]"`` concatenated with ``to_encode``.
        """
        return "[encoded]" + to_encode
|
||||||
|
|
||||||
|
|
||||||
|
class MockEncoderReturnsList:
    """Test double whose ``encode`` returns fixed, list-based fake embeddings."""

    def encode(self, to_encode: Any) -> List:
        """Return a fixed fake embedding for ``to_encode``.

        Args:
            to_encode: Either a single string or a list of items to embed.

        Returns:
            ``[1.0, 2.0]`` when given a string; when given a list, a list
            containing one ``[1.0, 2.0]`` entry per input element.

        Raises:
            ValueError: If ``to_encode`` is neither a ``str`` nor a ``list``.
        """
        if isinstance(to_encode, str):
            return [1.0, 2.0]
        # Check against the builtin ``list``; ``typing.List`` is only a
        # deprecated alias and is kept solely for the return annotation.
        if isinstance(to_encode, list):
            # Build a fresh inner list per element so entries are never shared.
            return [[1.0, 2.0] for _ in to_encode]
        raise ValueError("Invalid input type for unit test")
|
||||||
|
Loading…
Reference in New Issue
Block a user