Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-24 23:54:14 +00:00

Commit 1ae5a9c7a3 (parent e9423300d9): fix lock, imports, deps, test w deps, typo, formatting
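
A pattern recurs throughout this diff: optional heavy dependencies (vowpal_wabbit_next, sentence_transformers, numpy, pandas) move from module level into the functions that need them, and VW types in signatures become string annotations, so the modules import cleanly when the extras are not installed. A minimal sketch of that pattern, assuming the same library; the helper name is illustrative and it only uses calls that appear in the hunks below:

    from typing import Any, List


    def predict_from_text(workspace: "vw.Workspace", lines: str) -> Any:
        # Deferred import: the optional package is only required when this
        # helper actually runs, not when the module is imported.
        import vowpal_wabbit_next as vw

        parser = vw.TextFormatParser(workspace)
        examples: List["vw.Example"] = [parser.parse_line(ln) for ln in lines.split("\n")]
        return workspace.predict_one(examples)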
@@ -2,25 +2,22 @@ from __future__ import annotations
 
 import logging
 import os
-from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
 from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
-import vowpal_wabbit_next as vw
-from langchain.chains.rl_chain.vw_logger import VwLogger
-from langchain.chains.rl_chain.model_repository import ModelRepository
-from langchain.chains.rl_chain.metrics import MetricsTracker
-from langchain.prompts import BasePromptTemplate
 
-from langchain.pydantic_v1 import Extra, BaseModel, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
+from langchain.chains.rl_chain.metrics import MetricsTracker
+from langchain.chains.rl_chain.model_repository import ModelRepository
+from langchain.chains.rl_chain.vw_logger import VwLogger
 from langchain.prompts import (
+    BasePromptTemplate,
     ChatPromptTemplate,
-    SystemMessagePromptTemplate,
     HumanMessagePromptTemplate,
+    SystemMessagePromptTemplate,
 )
+from langchain.pydantic_v1 import BaseModel, Extra, root_validator
 
 logger = logging.getLogger(__name__)
 
@@ -87,7 +84,9 @@ def EmbedAndKeep(anything):
 # helper functions
 
 
-def parse_lines(parser: vw.TextFormatParser, input_str: str) -> List[vw.Example]:
+def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
+    import vowpal_wabbit_next as vw
 
     return [parser.parse_line(line) for line in input_str.split("\n")]
 
 
@@ -100,7 +99,8 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
 
     if not to_select_from:
         raise ValueError(
-            "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
+            "No variables using 'ToSelectFrom' found in the inputs. \
+            Please include at least one variable containing a list to select from."
         )
 
     based_on = {
@@ -173,14 +173,17 @@ class VwPolicy(Policy):
         self.vw_logger = vw_logger
 
     def predict(self, event: Event) -> Any:
+        import vowpal_wabbit_next as vw
 
         text_parser = vw.TextFormatParser(self.workspace)
         return self.workspace.predict_one(
             parse_lines(text_parser, self.feature_embedder.format(event))
         )
 
     def learn(self, event: Event):
-        vw_ex = self.feature_embedder.format(event)
+        import vowpal_wabbit_next as vw
 
+        vw_ex = self.feature_embedder.format(event)
         text_parser = vw.TextFormatParser(self.workspace)
         multi_ex = parse_lines(text_parser, vw_ex)
         self.workspace.learn_one(multi_ex)
@@ -216,7 +219,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
     @staticmethod
     def get_default_system_prompt() -> SystemMessagePromptTemplate:
         return SystemMessagePromptTemplate.from_template(
-            "PLEASE RESPOND ONLY WITH A SIGNLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
+            "PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
            You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response."
        )
 

@@ -1,4 +1,3 @@
-import pandas as pd
 from typing import Optional
 
 
@@ -23,5 +22,7 @@ class MetricsTracker:
         if self._step > 0 and self._i % self._step == 0:
             self._history.append({"step": self._i, "score": self.score})
 
-    def to_pandas(self) -> pd.DataFrame:
+    def to_pandas(self) -> "pd.DataFrame":
+        import pandas as pd
 
         return pd.DataFrame(self._history)

@@ -1,11 +1,10 @@
-from pathlib import Path
-import shutil
 import datetime
-import vowpal_wabbit_next as vw
-from typing import Union, Sequence
-import os
 import glob
 import logging
+import os
+import shutil
+from pathlib import Path
+from typing import Sequence, Union
 
 logger = logging.getLogger(__name__)
 
@@ -35,14 +34,18 @@ class ModelRepository:
     def has_history(self) -> bool:
         return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0
 
-    def save(self, workspace: vw.Workspace) -> None:
+    def save(self, workspace: "vw.Workspace") -> None:
+        import vowpal_wabbit_next as vw
 
         with open(self.model_path, "wb") as f:
             logger.info(f"storing rl_chain model in: {self.model_path}")
             f.write(workspace.serialize())
         if self.with_history: # write history
             shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
 
-    def load(self, commandline: Sequence[str]) -> vw.Workspace:
+    def load(self, commandline: Sequence[str]) -> "vw.Workspace":
+        import vowpal_wabbit_next as vw
 
         model_data = None
         if self.model_path.exists():
            with open(self.model_path, "rb") as f:

@@ -1,19 +1,15 @@
 from __future__ import annotations
 
-import langchain.chains.rl_chain.base as base
+import logging
+from typing import Any, Dict, List, Optional, Tuple
 
+import langchain.chains.rl_chain.base as base
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from typing import Any, Dict, List, Optional, Tuple, Union
 
-import numpy as np
-from langchain.base_language import BaseLanguageModel
 from langchain.chains.llm import LLMChain
-from sentence_transformers import SentenceTransformer
 from langchain.prompts import BasePromptTemplate
 
-import logging
 
 logger = logging.getLogger(__name__)
 
 # sentinel object used to distinguish between user didn't supply anything or user explicitly supplied None
@@ -23,7 +19,7 @@ SENTINEL = object()
 class PickBestFeatureEmbedder(base.Embedder):
     """
     Contextual Bandit Text Embedder class that embeds the based_on and to_select_from into a format that can be used by VW
 
     Attributes:
         model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
     """
@@ -32,6 +28,8 @@ class PickBestFeatureEmbedder(base.Embedder):
         super().__init__(*args, **kwargs)
 
         if model is None:
+            from sentence_transformers import SentenceTransformer
 
             model = SentenceTransformer("bert-base-nli-mean-tokens")
 
         self.model = model
@@ -67,7 +65,7 @@ class PickBestFeatureEmbedder(base.Embedder):
         )
 
         example_string = ""
-        example_string += f"shared "
+        example_string += "shared "
         for context_item in context_emb:
             for ns, based_on in context_item.items():
                 example_string += f"|{ns} {' '.join(based_on) if isinstance(based_on, list) else based_on} "
@@ -190,6 +188,8 @@ class PickBest(base.RLChain):
     def _call_after_predict_before_llm(
         self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
     ) -> Tuple[Dict[str, Any], PickBest.Event]:
+        import numpy as np
 
         prob_sum = sum(prob for _, prob in prediction)
         probabilities = [prob / prob_sum for _, prob in prediction]
         ## sample from the pmf
@@ -237,7 +237,7 @@ class PickBest(base.RLChain):
         Attributes:
             inputs: (Dict, required) The inputs to the chain. The inputs must contain a input variables that are wrapped in BasedOn and ToSelectFrom. BasedOn is the based_on that will be used for selecting an ToSelectFrom action that will be passed to the LLM prompt.
            run_manager: (CallbackManagerForChainRun, optional) The callback manager to use for this run. If not provided, a default callback manager is used.
 
        Returns:
            A dictionary containing:
                - `response`: The response generated by the LLM (Language Model).
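
The _call_after_predict_before_llm hunk above defers `import numpy as np` and normalizes the (action, probability) pairs into a probability mass function before the "sample from the pmf" step. A hedged sketch of that sampling step with illustrative values; the actual sampling code is not shown in this diff:

    import numpy as np

    prediction = [(0, 0.5), (1, 1.5)]  # (action index, raw probability/score) pairs
    prob_sum = sum(prob for _, prob in prediction)
    probabilities = [prob / prob_sum for _, prob in prediction]
    # one way to sample an action index from the normalized pmf
    sampled_index = int(np.random.choice(len(probabilities), p=probabilities))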

libs/langchain/poetry.lock (generated, 500 changed lines): file diff suppressed because it is too large.

@@ -338,6 +338,7 @@ extended_testing = [
   "xmltodict",
   "faiss-cpu",
   "openapi-schema-pydantic",
+  "vowpal-wabbit-next"
 ]
 
 [tool.ruff]
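
The pyproject.toml hunk above adds "vowpal-wabbit-next" to the extended_testing extra, and the test hunks below add `@pytest.mark.requires("vowpal_wabbit_next")` to every RL chain test. A small sketch of how such a gated test is written; the exact skip/fail behavior of the marker is handled by the repository's test configuration and is assumed here rather than shown in this diff:

    import pytest


    @pytest.mark.requires("vowpal_wabbit_next")
    def test_optional_dependency_is_importable() -> None:
        # The marker gates this test on the optional package, which is pulled
        # in by the extended_testing extra.
        import vowpal_wabbit_next

        assert vowpal_wabbit_next.__name__ == "vowpal_wabbit_next"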

@@ -1,13 +1,15 @@
-import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
-import langchain.chains.rl_chain.base as rl_chain
-from test_utils import MockEncoder
 import pytest
-from langchain.prompts.prompt import PromptTemplate
+from test_utils import MockEncoder
 
+import langchain.chains.rl_chain.base as rl_chain
+import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
 from langchain.chat_models import FakeListChatModel
+from langchain.prompts.prompt import PromptTemplate
 
 encoded_text = "[ e n c o d e d ] "
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def setup():
     _PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
     PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
@@ -16,6 +18,7 @@ def setup():
     return llm, PROMPT
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_multiple_ToSelectFrom_throws():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@@ -28,6 +31,7 @@ def test_multiple_ToSelectFrom_throws():
         )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_missing_basedOn_from_throws():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@@ -36,6 +40,7 @@ def test_missing_basedOn_from_throws():
         chain.run(action=rl_chain.ToSelectFrom(actions))
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_ToSelectFrom_not_a_list_throws():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@@ -47,6 +52,7 @@ def test_ToSelectFrom_not_a_list_throws():
         )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_update_with_delayed_score_with_auto_validator_throws():
     llm, PROMPT = setup()
     # this LLM returns a number so that the auto validator will return that
@@ -68,6 +74,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
         chain.update_with_delayed_score(event=selection_metadata, score=100)
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_update_with_delayed_score_force():
     llm, PROMPT = setup()
     # this LLM returns a number so that the auto validator will return that
@@ -91,6 +98,7 @@ def test_update_with_delayed_score_force():
     assert selection_metadata.selected.score == 100.0
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_update_with_delayed_score():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(
@@ -108,6 +116,7 @@ def test_update_with_delayed_score():
     assert selection_metadata.selected.score == 100.0
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_user_defined_scorer():
     llm, PROMPT = setup()
 
@@ -129,6 +138,7 @@ def test_user_defined_scorer():
     assert selection_metadata.selected.score == 200.0
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_default_embeddings():
     llm, PROMPT = setup()
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@@ -162,6 +172,7 @@ def test_default_embeddings():
     assert vw_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_default_embeddings_off():
     llm, PROMPT = setup()
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@@ -187,6 +198,7 @@ def test_default_embeddings_off():
     assert vw_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_default_embeddings_mixed_w_explicit_user_embeddings():
     llm, PROMPT = setup()
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@@ -221,6 +233,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
     assert vw_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_default_no_scorer_specified():
     _, PROMPT = setup()
     chain_llm = FakeListChatModel(responses=[100])
@@ -235,6 +248,7 @@ def test_default_no_scorer_specified():
     assert selection_metadata.selected.score == 100.0
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_explicitly_no_scorer():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(
@@ -250,6 +264,7 @@ def test_explicitly_no_scorer():
     assert selection_metadata.selected.score == None
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_auto_scorer_with_user_defined_llm():
     llm, PROMPT = setup()
     scorer_llm = FakeListChatModel(responses=[300])
@@ -268,15 +283,14 @@ def test_auto_scorer_with_user_defined_llm():
     assert selection_metadata.selected.score == 300.0
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_calling_chain_w_reserved_inputs_throws():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
     with pytest.raises(ValueError):
         chain.run(
             User=rl_chain.BasedOn("Context"),
-            rl_chain_selected_based_on=rl_chain.ToSelectFrom(
+            rl_chain_selected_based_on=rl_chain.ToSelectFrom(["0", "1", "2"]),
-                ["0", "1", "2"]
-            ),
         )
 
     with pytest.raises(ValueError):
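
The chain-call tests above show how the chain is driven; a usage sketch assembled only from calls that appear in those tests, with illustrative values (the numeric fake response lets the default scorer parse a score, as the tests' comments note):

    import langchain.chains.rl_chain.base as rl_chain
    import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
    from langchain.chat_models import FakeListChatModel
    from langchain.prompts.prompt import PromptTemplate

    # Mirrors the tests' setup: a fake chat model and a dummy prompt.
    llm = FakeListChatModel(responses=[100])
    prompt = PromptTemplate(input_variables=[], template="dummy prompt, ignored by the fake llm")
    chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=prompt)

    # BasedOn wraps the context used for selection; ToSelectFrom wraps the
    # candidate actions the policy chooses among.
    chain.run(
        User=rl_chain.BasedOn("Tom"),
        action=rl_chain.ToSelectFrom(["an action", "another action"]),
    )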

@@ -1,12 +1,13 @@
-import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
+import pytest
-import langchain.chains.rl_chain.base as rl_chain
 from test_utils import MockEncoder
 
-import pytest
+import langchain.chains.rl_chain.base as rl_chain
+import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
 
 encoded_text = "[ e n c o d e d ] "
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_missing_context_throws():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_action = {"action": ["0", "1", "2"]}
@@ -17,6 +18,7 @@ def test_pickbest_textembedder_missing_context_throws():
         feature_embedder.format(event)
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_missing_actions_throws():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     event = pick_best_chain.PickBest.Event(
@@ -26,6 +28,7 @@ def test_pickbest_textembedder_missing_actions_throws():
         feature_embedder.format(event)
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_no_label_no_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_actions = {"action1": ["0", "1", "2"]}
@@ -37,6 +40,7 @@ def test_pickbest_textembedder_no_label_no_emb():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_w_label_no_score_no_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_actions = {"action1": ["0", "1", "2"]}
@@ -52,6 +56,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_w_full_label_no_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_actions = {"action1": ["0", "1", "2"]}
@@ -69,6 +74,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_w_full_label_w_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     str1 = "0"
@@ -92,6 +98,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     str1 = "0"
@@ -115,6 +122,7 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
@@ -127,6 +135,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
@@ -140,6 +149,7 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
@@ -153,6 +163,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
 
@@ -168,9 +179,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
     encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
     encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
 
-    named_actions = {
+    named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
-        "action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])
-    }
     context = {
         "context1": rl_chain.Embed(ctx_str_1),
         "context2": rl_chain.Embed(ctx_str_2),
@@ -185,6 +194,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
 
@@ -201,9 +211,7 @@ test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
     encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
 
     named_actions = {
-        "action1": rl_chain.EmbedAndKeep(
+        "action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
-            [{"a": str1, "b": str1}, str2, str3]
-        )
     }
     context = {
         "context1": rl_chain.EmbedAndKeep(ctx_str_1),
@@ -219,6 +227,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
 
@@ -252,6 +261,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
 
@@ -288,6 +298,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
     assert vw_ex_str == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_raw_features_underscored():
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     str1 = "this is a long string"
 
@@ -1,16 +1,18 @@
-import langchain.chains.rl_chain.base as base
+import pytest
 from test_utils import MockEncoder
 
-import pytest
+import langchain.chains.rl_chain.base as base
 
 encoded_text = "[ e n c o d e d ] "
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_simple_context_str_no_emb():
     expected = [{"a_namespace": "test"}]
     assert base.embed("test", MockEncoder(), "a_namespace") == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_simple_context_str_w_emb():
     str1 = "test"
     encoded_str1 = " ".join(char for char in str1)
@@ -25,6 +27,7 @@ def test_simple_context_str_w_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_simple_context_str_w_nested_emb():
     # nested embeddings, innermost wins
     str1 = "test"
@@ -42,11 +45,13 @@ def test_simple_context_str_w_nested_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_context_w_namespace_no_emb():
     expected = [{"test_namespace": "test"}]
     assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_context_w_namespace_w_emb():
     str1 = "test"
     encoded_str1 = " ".join(char for char in str1)
@@ -61,6 +66,7 @@ def test_context_w_namespace_w_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_context_w_namespace_w_emb2():
     str1 = "test"
     encoded_str1 = " ".join(char for char in str1)
@@ -75,6 +81,7 @@ def test_context_w_namespace_w_emb2():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_context_w_namespace_w_some_emb():
     str1 = "test1"
     str2 = "test2"
@@ -103,6 +110,7 @@ def test_context_w_namespace_w_some_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_simple_action_strlist_no_emb():
     str1 = "test1"
     str2 = "test2"
@@ -111,6 +119,7 @@ def test_simple_action_strlist_no_emb():
     assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_simple_action_strlist_w_emb():
     str1 = "test1"
     str2 = "test2"
@@ -138,6 +147,7 @@ def test_simple_action_strlist_w_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_simple_action_strlist_w_some_emb():
     str1 = "test1"
     str2 = "test2"
@@ -170,6 +180,7 @@ def test_simple_action_strlist_w_some_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_action_w_namespace_no_emb():
     str1 = "test1"
     str2 = "test2"
@@ -192,6 +203,7 @@ def test_action_w_namespace_no_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_action_w_namespace_w_emb():
     str1 = "test1"
     str2 = "test2"
@@ -233,6 +245,7 @@ def test_action_w_namespace_w_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_action_w_namespace_w_emb2():
     str1 = "test1"
     str2 = "test2"
@@ -278,6 +291,7 @@ def test_action_w_namespace_w_emb2():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_action_w_namespace_w_some_emb():
     str1 = "test1"
     str2 = "test2"
@@ -318,6 +332,7 @@ def test_action_w_namespace_w_some_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
     str1 = "test1"
     str2 = "test2"
@@ -368,6 +383,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_one_namespace_w_list_of_features_no_emb():
     str1 = "test1"
     str2 = "test2"
@@ -375,6 +391,7 @@ def test_one_namespace_w_list_of_features_no_emb():
     assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_one_namespace_w_list_of_features_w_some_emb():
     str1 = "test1"
     str2 = "test2"
@@ -386,21 +403,25 @@ def test_one_namespace_w_list_of_features_w_some_emb():
     )
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_nested_list_features_throws():
     with pytest.raises(ValueError):
         base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_dict_in_list_throws():
     with pytest.raises(ValueError):
         base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_nested_dict_throws():
     with pytest.raises(ValueError):
         base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
 
 
+@pytest.mark.requires("vowpal_wabbit_next")
 def test_list_of_tuples_throws():
     with pytest.raises(ValueError):
         base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())