fix lock, imports, deps, test w deps, typo, formatting

This commit is contained in:
olgavrou 2023-08-18 05:45:21 -04:00
parent e9423300d9
commit 1ae5a9c7a3
9 changed files with 153 additions and 505 deletions

View File

@ -2,25 +2,22 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import vowpal_wabbit_next as vw
from langchain.chains.rl_chain.vw_logger import VwLogger
from langchain.chains.rl_chain.model_repository import ModelRepository
from langchain.chains.rl_chain.metrics import MetricsTracker
from langchain.prompts import BasePromptTemplate
from langchain.pydantic_v1 import Extra, BaseModel, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.rl_chain.metrics import MetricsTracker
from langchain.chains.rl_chain.model_repository import ModelRepository
from langchain.chains.rl_chain.vw_logger import VwLogger
from langchain.prompts import ( from langchain.prompts import (
BasePromptTemplate,
ChatPromptTemplate, ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
) )
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -87,7 +84,9 @@ def EmbedAndKeep(anything):
# helper functions # helper functions
def parse_lines(parser: vw.TextFormatParser, input_str: str) -> List[vw.Example]: def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
import vowpal_wabbit_next as vw
return [parser.parse_line(line) for line in input_str.split("\n")] return [parser.parse_line(line) for line in input_str.split("\n")]
@ -100,7 +99,8 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
if not to_select_from: if not to_select_from:
raise ValueError( raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." "No variables using 'ToSelectFrom' found in the inputs. \
Please include at least one variable containing a list to select from."
) )
based_on = { based_on = {
@ -173,14 +173,17 @@ class VwPolicy(Policy):
self.vw_logger = vw_logger self.vw_logger = vw_logger
def predict(self, event: Event) -> Any: def predict(self, event: Event) -> Any:
import vowpal_wabbit_next as vw
text_parser = vw.TextFormatParser(self.workspace) text_parser = vw.TextFormatParser(self.workspace)
return self.workspace.predict_one( return self.workspace.predict_one(
parse_lines(text_parser, self.feature_embedder.format(event)) parse_lines(text_parser, self.feature_embedder.format(event))
) )
def learn(self, event: Event): def learn(self, event: Event):
vw_ex = self.feature_embedder.format(event) import vowpal_wabbit_next as vw
vw_ex = self.feature_embedder.format(event)
text_parser = vw.TextFormatParser(self.workspace) text_parser = vw.TextFormatParser(self.workspace)
multi_ex = parse_lines(text_parser, vw_ex) multi_ex = parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex) self.workspace.learn_one(multi_ex)
@ -216,7 +219,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
@staticmethod @staticmethod
def get_default_system_prompt() -> SystemMessagePromptTemplate: def get_default_system_prompt() -> SystemMessagePromptTemplate:
return SystemMessagePromptTemplate.from_template( return SystemMessagePromptTemplate.from_template(
"PLEASE RESPOND ONLY WITH A SIGNLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\ "PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response." You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response."
) )

View File

@ -1,4 +1,3 @@
import pandas as pd
from typing import Optional from typing import Optional
@ -23,5 +22,7 @@ class MetricsTracker:
if self._step > 0 and self._i % self._step == 0: if self._step > 0 and self._i % self._step == 0:
self._history.append({"step": self._i, "score": self.score}) self._history.append({"step": self._i, "score": self.score})
def to_pandas(self) -> pd.DataFrame: def to_pandas(self) -> "pd.DataFrame":
import pandas as pd
return pd.DataFrame(self._history) return pd.DataFrame(self._history)

View File

@ -1,11 +1,10 @@
from pathlib import Path
import shutil
import datetime import datetime
import vowpal_wabbit_next as vw
from typing import Union, Sequence
import os
import glob import glob
import logging import logging
import os
import shutil
from pathlib import Path
from typing import Sequence, Union
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,14 +34,18 @@ class ModelRepository:
def has_history(self) -> bool: def has_history(self) -> bool:
return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0 return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0
def save(self, workspace: vw.Workspace) -> None: def save(self, workspace: "vw.Workspace") -> None:
import vowpal_wabbit_next as vw
with open(self.model_path, "wb") as f: with open(self.model_path, "wb") as f:
logger.info(f"storing rl_chain model in: {self.model_path}") logger.info(f"storing rl_chain model in: {self.model_path}")
f.write(workspace.serialize()) f.write(workspace.serialize())
if self.with_history: # write history if self.with_history: # write history
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw") shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
def load(self, commandline: Sequence[str]) -> vw.Workspace: def load(self, commandline: Sequence[str]) -> "vw.Workspace":
import vowpal_wabbit_next as vw
model_data = None model_data = None
if self.model_path.exists(): if self.model_path.exists():
with open(self.model_path, "rb") as f: with open(self.model_path, "rb") as f:

View File

@ -1,19 +1,15 @@
from __future__ import annotations from __future__ import annotations
import langchain.chains.rl_chain.base as base import logging
from typing import Any, Dict, List, Optional, Tuple
import langchain.chains.rl_chain.base as base
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from sentence_transformers import SentenceTransformer
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# sentinel object used to distinguish between user didn't supply anything or user explicitly supplied None # sentinel object used to distinguish between user didn't supply anything or user explicitly supplied None
@ -23,7 +19,7 @@ SENTINEL = object()
class PickBestFeatureEmbedder(base.Embedder): class PickBestFeatureEmbedder(base.Embedder):
""" """
Contextual Bandit Text Embedder class that embeds the based_on and to_select_from into a format that can be used by VW Contextual Bandit Text Embedder class that embeds the based_on and to_select_from into a format that can be used by VW
Attributes: Attributes:
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer. model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
""" """
@ -32,6 +28,8 @@ class PickBestFeatureEmbedder(base.Embedder):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if model is None: if model is None:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("bert-base-nli-mean-tokens") model = SentenceTransformer("bert-base-nli-mean-tokens")
self.model = model self.model = model
@ -67,7 +65,7 @@ class PickBestFeatureEmbedder(base.Embedder):
) )
example_string = "" example_string = ""
example_string += f"shared " example_string += "shared "
for context_item in context_emb: for context_item in context_emb:
for ns, based_on in context_item.items(): for ns, based_on in context_item.items():
example_string += f"|{ns} {' '.join(based_on) if isinstance(based_on, list) else based_on} " example_string += f"|{ns} {' '.join(based_on) if isinstance(based_on, list) else based_on} "
@ -190,6 +188,8 @@ class PickBest(base.RLChain):
def _call_after_predict_before_llm( def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]] self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
) -> Tuple[Dict[str, Any], PickBest.Event]: ) -> Tuple[Dict[str, Any], PickBest.Event]:
import numpy as np
prob_sum = sum(prob for _, prob in prediction) prob_sum = sum(prob for _, prob in prediction)
probabilities = [prob / prob_sum for _, prob in prediction] probabilities = [prob / prob_sum for _, prob in prediction]
## sample from the pmf ## sample from the pmf
@ -237,7 +237,7 @@ class PickBest(base.RLChain):
Attributes: Attributes:
inputs: (Dict, required) The inputs to the chain. The inputs must contain a input variables that are wrapped in BasedOn and ToSelectFrom. BasedOn is the based_on that will be used for selecting an ToSelectFrom action that will be passed to the LLM prompt. inputs: (Dict, required) The inputs to the chain. The inputs must contain a input variables that are wrapped in BasedOn and ToSelectFrom. BasedOn is the based_on that will be used for selecting an ToSelectFrom action that will be passed to the LLM prompt.
run_manager: (CallbackManagerForChainRun, optional) The callback manager to use for this run. If not provided, a default callback manager is used. run_manager: (CallbackManagerForChainRun, optional) The callback manager to use for this run. If not provided, a default callback manager is used.
Returns: Returns:
A dictionary containing: A dictionary containing:
- `response`: The response generated by the LLM (Language Model). - `response`: The response generated by the LLM (Language Model).

File diff suppressed because it is too large Load Diff

View File

@ -338,6 +338,7 @@ extended_testing = [
"xmltodict", "xmltodict",
"faiss-cpu", "faiss-cpu",
"openapi-schema-pydantic", "openapi-schema-pydantic",
"vowpal-wabbit-next"
] ]
[tool.ruff] [tool.ruff]

View File

@ -1,13 +1,15 @@
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
import langchain.chains.rl_chain.base as rl_chain
from test_utils import MockEncoder
import pytest import pytest
from langchain.prompts.prompt import PromptTemplate from test_utils import MockEncoder
import langchain.chains.rl_chain.base as rl_chain
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
from langchain.chat_models import FakeListChatModel from langchain.chat_models import FakeListChatModel
from langchain.prompts.prompt import PromptTemplate
encoded_text = "[ e n c o d e d ] " encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def setup(): def setup():
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm""" _PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE) PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
@ -16,6 +18,7 @@ def setup():
return llm, PROMPT return llm, PROMPT
@pytest.mark.requires("vowpal_wabbit_next")
def test_multiple_ToSelectFrom_throws(): def test_multiple_ToSelectFrom_throws():
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@ -28,6 +31,7 @@ def test_multiple_ToSelectFrom_throws():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_missing_basedOn_from_throws(): def test_missing_basedOn_from_throws():
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@ -36,6 +40,7 @@ def test_missing_basedOn_from_throws():
chain.run(action=rl_chain.ToSelectFrom(actions)) chain.run(action=rl_chain.ToSelectFrom(actions))
@pytest.mark.requires("vowpal_wabbit_next")
def test_ToSelectFrom_not_a_list_throws(): def test_ToSelectFrom_not_a_list_throws():
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@ -47,6 +52,7 @@ def test_ToSelectFrom_not_a_list_throws():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_update_with_delayed_score_with_auto_validator_throws(): def test_update_with_delayed_score_with_auto_validator_throws():
llm, PROMPT = setup() llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that # this LLM returns a number so that the auto validator will return that
@ -68,6 +74,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
chain.update_with_delayed_score(event=selection_metadata, score=100) chain.update_with_delayed_score(event=selection_metadata, score=100)
@pytest.mark.requires("vowpal_wabbit_next")
def test_update_with_delayed_score_force(): def test_update_with_delayed_score_force():
llm, PROMPT = setup() llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that # this LLM returns a number so that the auto validator will return that
@ -91,6 +98,7 @@ def test_update_with_delayed_score_force():
assert selection_metadata.selected.score == 100.0 assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_update_with_delayed_score(): def test_update_with_delayed_score():
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm( chain = pick_best_chain.PickBest.from_llm(
@ -108,6 +116,7 @@ def test_update_with_delayed_score():
assert selection_metadata.selected.score == 100.0 assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_user_defined_scorer(): def test_user_defined_scorer():
llm, PROMPT = setup() llm, PROMPT = setup()
@ -129,6 +138,7 @@ def test_user_defined_scorer():
assert selection_metadata.selected.score == 200.0 assert selection_metadata.selected.score == 200.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_default_embeddings(): def test_default_embeddings():
llm, PROMPT = setup() llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -162,6 +172,7 @@ def test_default_embeddings():
assert vw_str == expected assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_default_embeddings_off(): def test_default_embeddings_off():
llm, PROMPT = setup() llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -187,6 +198,7 @@ def test_default_embeddings_off():
assert vw_str == expected assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_default_embeddings_mixed_w_explicit_user_embeddings(): def test_default_embeddings_mixed_w_explicit_user_embeddings():
llm, PROMPT = setup() llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -221,6 +233,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
assert vw_str == expected assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_default_no_scorer_specified(): def test_default_no_scorer_specified():
_, PROMPT = setup() _, PROMPT = setup()
chain_llm = FakeListChatModel(responses=[100]) chain_llm = FakeListChatModel(responses=[100])
@ -235,6 +248,7 @@ def test_default_no_scorer_specified():
assert selection_metadata.selected.score == 100.0 assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_explicitly_no_scorer(): def test_explicitly_no_scorer():
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm( chain = pick_best_chain.PickBest.from_llm(
@ -250,6 +264,7 @@ def test_explicitly_no_scorer():
assert selection_metadata.selected.score == None assert selection_metadata.selected.score == None
@pytest.mark.requires("vowpal_wabbit_next")
def test_auto_scorer_with_user_defined_llm(): def test_auto_scorer_with_user_defined_llm():
llm, PROMPT = setup() llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300]) scorer_llm = FakeListChatModel(responses=[300])
@ -268,15 +283,14 @@ def test_auto_scorer_with_user_defined_llm():
assert selection_metadata.selected.score == 300.0 assert selection_metadata.selected.score == 300.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_calling_chain_w_reserved_inputs_throws(): def test_calling_chain_w_reserved_inputs_throws():
llm, PROMPT = setup() llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
with pytest.raises(ValueError): with pytest.raises(ValueError):
chain.run( chain.run(
User=rl_chain.BasedOn("Context"), User=rl_chain.BasedOn("Context"),
rl_chain_selected_based_on=rl_chain.ToSelectFrom( rl_chain_selected_based_on=rl_chain.ToSelectFrom(["0", "1", "2"]),
["0", "1", "2"]
),
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):

View File

@ -1,12 +1,13 @@
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain import pytest
import langchain.chains.rl_chain.base as rl_chain
from test_utils import MockEncoder from test_utils import MockEncoder
import pytest import langchain.chains.rl_chain.base as rl_chain
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
encoded_text = "[ e n c o d e d ] " encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_context_throws(): def test_pickbest_textembedder_missing_context_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_action = {"action": ["0", "1", "2"]} named_action = {"action": ["0", "1", "2"]}
@ -17,6 +18,7 @@ def test_pickbest_textembedder_missing_context_throws():
feature_embedder.format(event) feature_embedder.format(event)
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_actions_throws(): def test_pickbest_textembedder_missing_actions_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
event = pick_best_chain.PickBest.Event( event = pick_best_chain.PickBest.Event(
@ -26,6 +28,7 @@ def test_pickbest_textembedder_missing_actions_throws():
feature_embedder.format(event) feature_embedder.format(event)
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_no_label_no_emb(): def test_pickbest_textembedder_no_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]} named_actions = {"action1": ["0", "1", "2"]}
@ -37,6 +40,7 @@ def test_pickbest_textembedder_no_label_no_emb():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_label_no_score_no_emb(): def test_pickbest_textembedder_w_label_no_score_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]} named_actions = {"action1": ["0", "1", "2"]}
@ -52,6 +56,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_no_emb(): def test_pickbest_textembedder_w_full_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]} named_actions = {"action1": ["0", "1", "2"]}
@ -69,6 +74,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_emb(): def test_pickbest_textembedder_w_full_label_w_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0" str1 = "0"
@ -92,6 +98,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_embed_and_keep(): def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0" str1 = "0"
@ -115,6 +122,7 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_no_label_no_emb(): def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
@ -127,6 +135,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_label_no_emb(): def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
@ -140,6 +149,7 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb(): def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
@ -153,6 +163,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb(): def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -168,9 +179,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1) encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2) encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = { named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])
}
context = { context = {
"context1": rl_chain.Embed(ctx_str_1), "context1": rl_chain.Embed(ctx_str_1),
"context2": rl_chain.Embed(ctx_str_2), "context2": rl_chain.Embed(ctx_str_2),
@ -185,6 +194,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep(): def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -201,9 +211,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2) encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = { named_actions = {
"action1": rl_chain.EmbedAndKeep( "action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
[{"a": str1, "b": str1}, str2, str3]
)
} }
context = { context = {
"context1": rl_chain.EmbedAndKeep(ctx_str_1), "context1": rl_chain.EmbedAndKeep(ctx_str_1),
@ -219,6 +227,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb(): def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -252,6 +261,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep(): def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -288,6 +298,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
assert vw_ex_str == expected assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_raw_features_underscored(): def test_raw_features_underscored():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "this is a long string" str1 = "this is a long string"

View File

@ -1,16 +1,18 @@
import langchain.chains.rl_chain.base as base import pytest
from test_utils import MockEncoder from test_utils import MockEncoder
import pytest import langchain.chains.rl_chain.base as base
encoded_text = "[ e n c o d e d ] " encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_no_emb(): def test_simple_context_str_no_emb():
expected = [{"a_namespace": "test"}] expected = [{"a_namespace": "test"}]
assert base.embed("test", MockEncoder(), "a_namespace") == expected assert base.embed("test", MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_emb(): def test_simple_context_str_w_emb():
str1 = "test" str1 = "test"
encoded_str1 = " ".join(char for char in str1) encoded_str1 = " ".join(char for char in str1)
@ -25,6 +27,7 @@ def test_simple_context_str_w_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_nested_emb(): def test_simple_context_str_w_nested_emb():
# nested embeddings, innermost wins # nested embeddings, innermost wins
str1 = "test" str1 = "test"
@ -42,11 +45,13 @@ def test_simple_context_str_w_nested_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_no_emb(): def test_context_w_namespace_no_emb():
expected = [{"test_namespace": "test"}] expected = [{"test_namespace": "test"}]
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb(): def test_context_w_namespace_w_emb():
str1 = "test" str1 = "test"
encoded_str1 = " ".join(char for char in str1) encoded_str1 = " ".join(char for char in str1)
@ -61,6 +66,7 @@ def test_context_w_namespace_w_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb2(): def test_context_w_namespace_w_emb2():
str1 = "test" str1 = "test"
encoded_str1 = " ".join(char for char in str1) encoded_str1 = " ".join(char for char in str1)
@ -75,6 +81,7 @@ def test_context_w_namespace_w_emb2():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_some_emb(): def test_context_w_namespace_w_some_emb():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -103,6 +110,7 @@ def test_context_w_namespace_w_some_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_no_emb(): def test_simple_action_strlist_no_emb():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -111,6 +119,7 @@ def test_simple_action_strlist_no_emb():
assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_emb(): def test_simple_action_strlist_w_emb():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -138,6 +147,7 @@ def test_simple_action_strlist_w_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_some_emb(): def test_simple_action_strlist_w_some_emb():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -170,6 +180,7 @@ def test_simple_action_strlist_w_some_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_no_emb(): def test_action_w_namespace_no_emb():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -192,6 +203,7 @@ def test_action_w_namespace_no_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb(): def test_action_w_namespace_w_emb():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -233,6 +245,7 @@ def test_action_w_namespace_w_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb2(): def test_action_w_namespace_w_emb2():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -278,6 +291,7 @@ def test_action_w_namespace_w_emb2():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_some_emb(): def test_action_w_namespace_w_some_emb():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -318,6 +332,7 @@ def test_action_w_namespace_w_some_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict(): def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -368,6 +383,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_no_emb(): def test_one_namespace_w_list_of_features_no_emb():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -375,6 +391,7 @@ def test_one_namespace_w_list_of_features_no_emb():
assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_w_some_emb(): def test_one_namespace_w_list_of_features_w_some_emb():
str1 = "test1" str1 = "test1"
str2 = "test2" str2 = "test2"
@ -386,21 +403,25 @@ def test_one_namespace_w_list_of_features_w_some_emb():
) )
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_list_features_throws(): def test_nested_list_features_throws():
with pytest.raises(ValueError): with pytest.raises(ValueError):
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder()) base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_dict_in_list_throws(): def test_dict_in_list_throws():
with pytest.raises(ValueError): with pytest.raises(ValueError):
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder()) base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_dict_throws(): def test_nested_dict_throws():
with pytest.raises(ValueError): with pytest.raises(ValueError):
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder()) base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_list_of_tuples_throws(): def test_list_of_tuples_throws():
with pytest.raises(ValueError): with pytest.raises(ValueError):
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder()) base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())