mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
fix lock, imports, deps, test w deps, typo, formatting
This commit is contained in:
parent
e9423300d9
commit
1ae5a9c7a3
@ -2,25 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import vowpal_wabbit_next as vw
|
||||
from langchain.chains.rl_chain.vw_logger import VwLogger
|
||||
from langchain.chains.rl_chain.model_repository import ModelRepository
|
||||
from langchain.chains.rl_chain.metrics import MetricsTracker
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
|
||||
from langchain.pydantic_v1 import Extra, BaseModel, root_validator
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.rl_chain.metrics import MetricsTracker
|
||||
from langchain.chains.rl_chain.model_repository import ModelRepository
|
||||
from langchain.chains.rl_chain.vw_logger import VwLogger
|
||||
from langchain.prompts import (
|
||||
BasePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -87,7 +84,9 @@ def EmbedAndKeep(anything):
|
||||
# helper functions
|
||||
|
||||
|
||||
def parse_lines(parser: vw.TextFormatParser, input_str: str) -> List[vw.Example]:
|
||||
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
return [parser.parse_line(line) for line in input_str.split("\n")]
|
||||
|
||||
|
||||
@ -100,7 +99,8 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
|
||||
|
||||
if not to_select_from:
|
||||
raise ValueError(
|
||||
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
|
||||
"No variables using 'ToSelectFrom' found in the inputs. \
|
||||
Please include at least one variable containing a list to select from."
|
||||
)
|
||||
|
||||
based_on = {
|
||||
@ -173,14 +173,17 @@ class VwPolicy(Policy):
|
||||
self.vw_logger = vw_logger
|
||||
|
||||
def predict(self, event: Event) -> Any:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
text_parser = vw.TextFormatParser(self.workspace)
|
||||
return self.workspace.predict_one(
|
||||
parse_lines(text_parser, self.feature_embedder.format(event))
|
||||
)
|
||||
|
||||
def learn(self, event: Event):
|
||||
vw_ex = self.feature_embedder.format(event)
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
vw_ex = self.feature_embedder.format(event)
|
||||
text_parser = vw.TextFormatParser(self.workspace)
|
||||
multi_ex = parse_lines(text_parser, vw_ex)
|
||||
self.workspace.learn_one(multi_ex)
|
||||
@ -216,7 +219,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
||||
@staticmethod
|
||||
def get_default_system_prompt() -> SystemMessagePromptTemplate:
|
||||
return SystemMessagePromptTemplate.from_template(
|
||||
"PLEASE RESPOND ONLY WITH A SIGNLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
|
||||
"PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
|
||||
You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response."
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
import pandas as pd
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@ -23,5 +22,7 @@ class MetricsTracker:
|
||||
if self._step > 0 and self._i % self._step == 0:
|
||||
self._history.append({"step": self._i, "score": self.score})
|
||||
|
||||
def to_pandas(self) -> pd.DataFrame:
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
import pandas as pd
|
||||
|
||||
return pd.DataFrame(self._history)
|
||||
|
@ -1,11 +1,10 @@
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import datetime
|
||||
import vowpal_wabbit_next as vw
|
||||
from typing import Union, Sequence
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Sequence, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -35,14 +34,18 @@ class ModelRepository:
|
||||
def has_history(self) -> bool:
|
||||
return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0
|
||||
|
||||
def save(self, workspace: vw.Workspace) -> None:
|
||||
def save(self, workspace: "vw.Workspace") -> None:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
with open(self.model_path, "wb") as f:
|
||||
logger.info(f"storing rl_chain model in: {self.model_path}")
|
||||
f.write(workspace.serialize())
|
||||
if self.with_history: # write history
|
||||
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
|
||||
|
||||
def load(self, commandline: Sequence[str]) -> vw.Workspace:
|
||||
def load(self, commandline: Sequence[str]) -> "vw.Workspace":
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
model_data = None
|
||||
if self.model_path.exists():
|
||||
with open(self.model_path, "rb") as f:
|
||||
|
@ -1,19 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import langchain.chains.rl_chain.base as base
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import langchain.chains.rl_chain.base as base
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# sentinel object used to distinguish between user didn't supply anything or user explicitly supplied None
|
||||
@ -23,7 +19,7 @@ SENTINEL = object()
|
||||
class PickBestFeatureEmbedder(base.Embedder):
|
||||
"""
|
||||
Contextual Bandit Text Embedder class that embeds the based_on and to_select_from into a format that can be used by VW
|
||||
|
||||
|
||||
Attributes:
|
||||
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
||||
"""
|
||||
@ -32,6 +28,8 @@ class PickBestFeatureEmbedder(base.Embedder):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if model is None:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
model = SentenceTransformer("bert-base-nli-mean-tokens")
|
||||
|
||||
self.model = model
|
||||
@ -67,7 +65,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
||||
)
|
||||
|
||||
example_string = ""
|
||||
example_string += f"shared "
|
||||
example_string += "shared "
|
||||
for context_item in context_emb:
|
||||
for ns, based_on in context_item.items():
|
||||
example_string += f"|{ns} {' '.join(based_on) if isinstance(based_on, list) else based_on} "
|
||||
@ -190,6 +188,8 @@ class PickBest(base.RLChain):
|
||||
def _call_after_predict_before_llm(
|
||||
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
|
||||
) -> Tuple[Dict[str, Any], PickBest.Event]:
|
||||
import numpy as np
|
||||
|
||||
prob_sum = sum(prob for _, prob in prediction)
|
||||
probabilities = [prob / prob_sum for _, prob in prediction]
|
||||
## sample from the pmf
|
||||
@ -237,7 +237,7 @@ class PickBest(base.RLChain):
|
||||
Attributes:
|
||||
inputs: (Dict, required) The inputs to the chain. The inputs must contain a input variables that are wrapped in BasedOn and ToSelectFrom. BasedOn is the based_on that will be used for selecting an ToSelectFrom action that will be passed to the LLM prompt.
|
||||
run_manager: (CallbackManagerForChainRun, optional) The callback manager to use for this run. If not provided, a default callback manager is used.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- `response`: The response generated by the LLM (Language Model).
|
||||
|
500
libs/langchain/poetry.lock
generated
500
libs/langchain/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -338,6 +338,7 @@ extended_testing = [
|
||||
"xmltodict",
|
||||
"faiss-cpu",
|
||||
"openapi-schema-pydantic",
|
||||
"vowpal-wabbit-next"
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
@ -1,13 +1,15 @@
|
||||
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
|
||||
import langchain.chains.rl_chain.base as rl_chain
|
||||
from test_utils import MockEncoder
|
||||
import pytest
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from test_utils import MockEncoder
|
||||
|
||||
import langchain.chains.rl_chain.base as rl_chain
|
||||
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
|
||||
from langchain.chat_models import FakeListChatModel
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
encoded_text = "[ e n c o d e d ] "
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def setup():
|
||||
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
|
||||
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
|
||||
@ -16,6 +18,7 @@ def setup():
|
||||
return llm, PROMPT
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_multiple_ToSelectFrom_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
@ -28,6 +31,7 @@ def test_multiple_ToSelectFrom_throws():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_missing_basedOn_from_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
@ -36,6 +40,7 @@ def test_missing_basedOn_from_throws():
|
||||
chain.run(action=rl_chain.ToSelectFrom(actions))
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_ToSelectFrom_not_a_list_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
@ -47,6 +52,7 @@ def test_ToSelectFrom_not_a_list_throws():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_update_with_delayed_score_with_auto_validator_throws():
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
@ -68,6 +74,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
|
||||
chain.update_with_delayed_score(event=selection_metadata, score=100)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_update_with_delayed_score_force():
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
@ -91,6 +98,7 @@ def test_update_with_delayed_score_force():
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_update_with_delayed_score():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
@ -108,6 +116,7 @@ def test_update_with_delayed_score():
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_user_defined_scorer():
|
||||
llm, PROMPT = setup()
|
||||
|
||||
@ -129,6 +138,7 @@ def test_user_defined_scorer():
|
||||
assert selection_metadata.selected.score == 200.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_default_embeddings():
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
@ -162,6 +172,7 @@ def test_default_embeddings():
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_default_embeddings_off():
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
@ -187,6 +198,7 @@ def test_default_embeddings_off():
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
@ -221,6 +233,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_default_no_scorer_specified():
|
||||
_, PROMPT = setup()
|
||||
chain_llm = FakeListChatModel(responses=[100])
|
||||
@ -235,6 +248,7 @@ def test_default_no_scorer_specified():
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_explicitly_no_scorer():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
@ -250,6 +264,7 @@ def test_explicitly_no_scorer():
|
||||
assert selection_metadata.selected.score == None
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_auto_scorer_with_user_defined_llm():
|
||||
llm, PROMPT = setup()
|
||||
scorer_llm = FakeListChatModel(responses=[300])
|
||||
@ -268,15 +283,14 @@ def test_auto_scorer_with_user_defined_llm():
|
||||
assert selection_metadata.selected.score == 300.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_calling_chain_w_reserved_inputs_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
with pytest.raises(ValueError):
|
||||
chain.run(
|
||||
User=rl_chain.BasedOn("Context"),
|
||||
rl_chain_selected_based_on=rl_chain.ToSelectFrom(
|
||||
["0", "1", "2"]
|
||||
),
|
||||
rl_chain_selected_based_on=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -1,12 +1,13 @@
|
||||
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
|
||||
import langchain.chains.rl_chain.base as rl_chain
|
||||
import pytest
|
||||
from test_utils import MockEncoder
|
||||
|
||||
import pytest
|
||||
import langchain.chains.rl_chain.base as rl_chain
|
||||
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
|
||||
|
||||
encoded_text = "[ e n c o d e d ] "
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_missing_context_throws():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_action = {"action": ["0", "1", "2"]}
|
||||
@ -17,6 +18,7 @@ def test_pickbest_textembedder_missing_context_throws():
|
||||
feature_embedder.format(event)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_missing_actions_throws():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
event = pick_best_chain.PickBest.Event(
|
||||
@ -26,6 +28,7 @@ def test_pickbest_textembedder_missing_actions_throws():
|
||||
feature_embedder.format(event)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_no_label_no_emb():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
@ -37,6 +40,7 @@ def test_pickbest_textembedder_no_label_no_emb():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_label_no_score_no_emb():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
@ -52,6 +56,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_no_emb():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": ["0", "1", "2"]}
|
||||
@ -69,6 +74,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_w_emb():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
str1 = "0"
|
||||
@ -92,6 +98,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
str1 = "0"
|
||||
@ -115,6 +122,7 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
@ -127,6 +135,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
@ -140,6 +149,7 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
|
||||
@ -153,6 +163,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
@ -168,9 +179,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
||||
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
|
||||
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
|
||||
|
||||
named_actions = {
|
||||
"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])
|
||||
}
|
||||
named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
|
||||
context = {
|
||||
"context1": rl_chain.Embed(ctx_str_1),
|
||||
"context2": rl_chain.Embed(ctx_str_2),
|
||||
@ -185,6 +194,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
@ -201,9 +211,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
||||
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
|
||||
|
||||
named_actions = {
|
||||
"action1": rl_chain.EmbedAndKeep(
|
||||
[{"a": str1, "b": str1}, str2, str3]
|
||||
)
|
||||
"action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
|
||||
}
|
||||
context = {
|
||||
"context1": rl_chain.EmbedAndKeep(ctx_str_1),
|
||||
@ -219,6 +227,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
@ -252,6 +261,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
|
||||
@ -288,6 +298,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
||||
assert vw_ex_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_raw_features_underscored():
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
str1 = "this is a long string"
|
||||
|
@ -1,16 +1,18 @@
|
||||
import langchain.chains.rl_chain.base as base
|
||||
import pytest
|
||||
from test_utils import MockEncoder
|
||||
|
||||
import pytest
|
||||
import langchain.chains.rl_chain.base as base
|
||||
|
||||
encoded_text = "[ e n c o d e d ] "
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_no_emb():
|
||||
expected = [{"a_namespace": "test"}]
|
||||
assert base.embed("test", MockEncoder(), "a_namespace") == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_w_emb():
|
||||
str1 = "test"
|
||||
encoded_str1 = " ".join(char for char in str1)
|
||||
@ -25,6 +27,7 @@ def test_simple_context_str_w_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_context_str_w_nested_emb():
|
||||
# nested embeddings, innermost wins
|
||||
str1 = "test"
|
||||
@ -42,11 +45,13 @@ def test_simple_context_str_w_nested_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_no_emb():
|
||||
expected = [{"test_namespace": "test"}]
|
||||
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_emb():
|
||||
str1 = "test"
|
||||
encoded_str1 = " ".join(char for char in str1)
|
||||
@ -61,6 +66,7 @@ def test_context_w_namespace_w_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_emb2():
|
||||
str1 = "test"
|
||||
encoded_str1 = " ".join(char for char in str1)
|
||||
@ -75,6 +81,7 @@ def test_context_w_namespace_w_emb2():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_context_w_namespace_w_some_emb():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -103,6 +110,7 @@ def test_context_w_namespace_w_some_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_no_emb():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -111,6 +119,7 @@ def test_simple_action_strlist_no_emb():
|
||||
assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_w_emb():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -138,6 +147,7 @@ def test_simple_action_strlist_w_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_simple_action_strlist_w_some_emb():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -170,6 +180,7 @@ def test_simple_action_strlist_w_some_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_no_emb():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -192,6 +203,7 @@ def test_action_w_namespace_no_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -233,6 +245,7 @@ def test_action_w_namespace_w_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb2():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -278,6 +291,7 @@ def test_action_w_namespace_w_emb2():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_some_emb():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -318,6 +332,7 @@ def test_action_w_namespace_w_some_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -368,6 +383,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_one_namespace_w_list_of_features_no_emb():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -375,6 +391,7 @@ def test_one_namespace_w_list_of_features_no_emb():
|
||||
assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_one_namespace_w_list_of_features_w_some_emb():
|
||||
str1 = "test1"
|
||||
str2 = "test2"
|
||||
@ -386,21 +403,25 @@ def test_one_namespace_w_list_of_features_w_some_emb():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_nested_list_features_throws():
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_dict_in_list_throws():
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_nested_dict_throws():
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
def test_list_of_tuples_throws():
|
||||
with pytest.raises(ValueError):
|
||||
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
|
||||
|
Loading…
Reference in New Issue
Block a user