mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 00:48:45 +00:00
experimental[patch]: refactor rl chain structure (#25398)
can't have a class and function with same name but different capitalization in same file for api reference building
This commit is contained in:
parent
94c9cb7321
commit
414154fa59
@ -19,9 +19,8 @@ from langchain_experimental.rl_chain.base import (
|
|||||||
SelectionScorer,
|
SelectionScorer,
|
||||||
ToSelectFrom,
|
ToSelectFrom,
|
||||||
VwPolicy,
|
VwPolicy,
|
||||||
embed,
|
|
||||||
stringify_embedding,
|
|
||||||
)
|
)
|
||||||
|
from langchain_experimental.rl_chain.helpers import embed, stringify_embedding
|
||||||
from langchain_experimental.rl_chain.pick_best_chain import (
|
from langchain_experimental.rl_chain.pick_best_chain import (
|
||||||
PickBest,
|
PickBest,
|
||||||
PickBestEvent,
|
PickBestEvent,
|
||||||
|
@ -27,6 +27,7 @@ from langchain_core.prompts import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from langchain_experimental.pydantic_v1 import BaseModel, root_validator
|
from langchain_experimental.pydantic_v1 import BaseModel, root_validator
|
||||||
|
from langchain_experimental.rl_chain.helpers import _Embed
|
||||||
from langchain_experimental.rl_chain.metrics import (
|
from langchain_experimental.rl_chain.metrics import (
|
||||||
MetricsTrackerAverage,
|
MetricsTrackerAverage,
|
||||||
MetricsTrackerRollingWindow,
|
MetricsTrackerRollingWindow,
|
||||||
@ -74,17 +75,6 @@ def ToSelectFrom(anything: Any) -> _ToSelectFrom:
|
|||||||
return _ToSelectFrom(anything)
|
return _ToSelectFrom(anything)
|
||||||
|
|
||||||
|
|
||||||
class _Embed:
|
|
||||||
def __init__(self, value: Any, keep: bool = False):
|
|
||||||
self.value = value
|
|
||||||
self.keep = keep
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return str(self.value)
|
|
||||||
|
|
||||||
__repr__ = __str__
|
|
||||||
|
|
||||||
|
|
||||||
def Embed(anything: Any, keep: bool = False) -> Any:
|
def Embed(anything: Any, keep: bool = False) -> Any:
|
||||||
"""Wrap a value to indicate that it should be embedded."""
|
"""Wrap a value to indicate that it should be embedded."""
|
||||||
|
|
||||||
@ -110,12 +100,6 @@ def EmbedAndKeep(anything: Any) -> Any:
|
|||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
|
|
||||||
def stringify_embedding(embedding: List) -> str:
|
|
||||||
"""Convert an embedding to a string."""
|
|
||||||
|
|
||||||
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
|
|
||||||
|
|
||||||
|
|
||||||
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
|
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
|
||||||
"""Parse the input string into a list of examples."""
|
"""Parse the input string into a list of examples."""
|
||||||
|
|
||||||
@ -559,97 +543,3 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
return "llm_personalizer_chain"
|
return "llm_personalizer_chain"
|
||||||
|
|
||||||
|
|
||||||
def is_stringtype_instance(item: Any) -> bool:
|
|
||||||
"""Check if an item is a string."""
|
|
||||||
|
|
||||||
return isinstance(item, str) or (
|
|
||||||
isinstance(item, _Embed) and isinstance(item.value, str)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def embed_string_type(
|
|
||||||
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
|
|
||||||
) -> Dict[str, Union[str, List[str]]]:
|
|
||||||
"""Embed a string or an _Embed object."""
|
|
||||||
|
|
||||||
keep_str = ""
|
|
||||||
if isinstance(item, _Embed):
|
|
||||||
encoded = stringify_embedding(model.encode(item.value))
|
|
||||||
if item.keep:
|
|
||||||
keep_str = item.value.replace(" ", "_") + " "
|
|
||||||
elif isinstance(item, str):
|
|
||||||
encoded = item.replace(" ", "_")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported type {type(item)} for embedding")
|
|
||||||
|
|
||||||
if namespace is None:
|
|
||||||
raise ValueError(
|
|
||||||
"The default namespace must be provided when embedding a string or _Embed object." # noqa: E501
|
|
||||||
)
|
|
||||||
|
|
||||||
return {namespace: keep_str + encoded}
|
|
||||||
|
|
||||||
|
|
||||||
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
|
|
||||||
"""Embed a dictionary item."""
|
|
||||||
inner_dict: Dict = {}
|
|
||||||
for ns, embed_item in item.items():
|
|
||||||
if isinstance(embed_item, list):
|
|
||||||
inner_dict[ns] = []
|
|
||||||
for embed_list_item in embed_item:
|
|
||||||
embedded = embed_string_type(embed_list_item, model, ns)
|
|
||||||
inner_dict[ns].append(embedded[ns])
|
|
||||||
else:
|
|
||||||
inner_dict.update(embed_string_type(embed_item, model, ns))
|
|
||||||
return inner_dict
|
|
||||||
|
|
||||||
|
|
||||||
def embed_list_type(
|
|
||||||
item: list, model: Any, namespace: Optional[str] = None
|
|
||||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
|
||||||
"""Embed a list item."""
|
|
||||||
|
|
||||||
ret_list: List = []
|
|
||||||
for embed_item in item:
|
|
||||||
if isinstance(embed_item, dict):
|
|
||||||
ret_list.append(embed_dict_type(embed_item, model))
|
|
||||||
elif isinstance(embed_item, list):
|
|
||||||
item_embedding = embed_list_type(embed_item, model, namespace)
|
|
||||||
# Get the first key from the first dictionary
|
|
||||||
first_key = next(iter(item_embedding[0]))
|
|
||||||
# Group the values under that key
|
|
||||||
grouping = {first_key: [item[first_key] for item in item_embedding]}
|
|
||||||
ret_list.append(grouping)
|
|
||||||
else:
|
|
||||||
ret_list.append(embed_string_type(embed_item, model, namespace))
|
|
||||||
return ret_list
|
|
||||||
|
|
||||||
|
|
||||||
def embed(
|
|
||||||
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
|
|
||||||
model: Any,
|
|
||||||
namespace: Optional[str] = None,
|
|
||||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
|
||||||
"""
|
|
||||||
Embed the actions or context using the SentenceTransformer model
|
|
||||||
(or a model that has an `encode` function).
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries.
|
|
||||||
namespace: (str, optional) The default namespace to use when dictionary or list of dictionaries not provided.
|
|
||||||
model: (Any, required) The model to use for embedding
|
|
||||||
Returns:
|
|
||||||
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
|
|
||||||
""" # noqa: E501
|
|
||||||
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
|
|
||||||
to_embed, str
|
|
||||||
):
|
|
||||||
return [embed_string_type(to_embed, model, namespace)]
|
|
||||||
elif isinstance(to_embed, dict):
|
|
||||||
return [embed_dict_type(to_embed, model)]
|
|
||||||
elif isinstance(to_embed, list):
|
|
||||||
return embed_list_type(to_embed, model, namespace)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid input format for embedding")
|
|
||||||
|
114
libs/experimental/langchain_experimental/rl_chain/helpers.py
Normal file
114
libs/experimental/langchain_experimental/rl_chain/helpers.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
class _Embed:
|
||||||
|
def __init__(self, value: Any, keep: bool = False):
|
||||||
|
self.value = value
|
||||||
|
self.keep = keep
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
__repr__ = __str__
|
||||||
|
|
||||||
|
|
||||||
|
def stringify_embedding(embedding: List) -> str:
|
||||||
|
"""Convert an embedding to a string."""
|
||||||
|
|
||||||
|
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
|
||||||
|
|
||||||
|
|
||||||
|
def is_stringtype_instance(item: Any) -> bool:
|
||||||
|
"""Check if an item is a string."""
|
||||||
|
|
||||||
|
return isinstance(item, str) or (
|
||||||
|
isinstance(item, _Embed) and isinstance(item.value, str)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def embed_string_type(
|
||||||
|
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
|
||||||
|
) -> Dict[str, Union[str, List[str]]]:
|
||||||
|
"""Embed a string or an _Embed object."""
|
||||||
|
|
||||||
|
keep_str = ""
|
||||||
|
if isinstance(item, _Embed):
|
||||||
|
encoded = stringify_embedding(model.encode(item.value))
|
||||||
|
if item.keep:
|
||||||
|
keep_str = item.value.replace(" ", "_") + " "
|
||||||
|
elif isinstance(item, str):
|
||||||
|
encoded = item.replace(" ", "_")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported type {type(item)} for embedding")
|
||||||
|
|
||||||
|
if namespace is None:
|
||||||
|
raise ValueError(
|
||||||
|
"The default namespace must be provided when embedding a string or _Embed object." # noqa: E501
|
||||||
|
)
|
||||||
|
|
||||||
|
return {namespace: keep_str + encoded}
|
||||||
|
|
||||||
|
|
||||||
|
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
|
||||||
|
"""Embed a dictionary item."""
|
||||||
|
inner_dict: Dict = {}
|
||||||
|
for ns, embed_item in item.items():
|
||||||
|
if isinstance(embed_item, list):
|
||||||
|
inner_dict[ns] = []
|
||||||
|
for embed_list_item in embed_item:
|
||||||
|
embedded = embed_string_type(embed_list_item, model, ns)
|
||||||
|
inner_dict[ns].append(embedded[ns])
|
||||||
|
else:
|
||||||
|
inner_dict.update(embed_string_type(embed_item, model, ns))
|
||||||
|
return inner_dict
|
||||||
|
|
||||||
|
|
||||||
|
def embed_list_type(
|
||||||
|
item: list, model: Any, namespace: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||||
|
"""Embed a list item."""
|
||||||
|
|
||||||
|
ret_list: List = []
|
||||||
|
for embed_item in item:
|
||||||
|
if isinstance(embed_item, dict):
|
||||||
|
ret_list.append(embed_dict_type(embed_item, model))
|
||||||
|
elif isinstance(embed_item, list):
|
||||||
|
item_embedding = embed_list_type(embed_item, model, namespace)
|
||||||
|
# Get the first key from the first dictionary
|
||||||
|
first_key = next(iter(item_embedding[0]))
|
||||||
|
# Group the values under that key
|
||||||
|
grouping = {first_key: [item[first_key] for item in item_embedding]}
|
||||||
|
ret_list.append(grouping)
|
||||||
|
else:
|
||||||
|
ret_list.append(embed_string_type(embed_item, model, namespace))
|
||||||
|
return ret_list
|
||||||
|
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
|
||||||
|
model: Any,
|
||||||
|
namespace: Optional[str] = None,
|
||||||
|
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||||
|
"""
|
||||||
|
Embed the actions or context using the SentenceTransformer model
|
||||||
|
(or a model that has an `encode` function).
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries.
|
||||||
|
namespace: (str, optional) The default namespace to use when dictionary or list of dictionaries not provided.
|
||||||
|
model: (Any, required) The model to use for embedding
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
|
||||||
|
""" # noqa: E501
|
||||||
|
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
|
||||||
|
to_embed, str
|
||||||
|
):
|
||||||
|
return [embed_string_type(to_embed, model, namespace)]
|
||||||
|
elif isinstance(to_embed, dict):
|
||||||
|
return [embed_dict_type(to_embed, model)]
|
||||||
|
elif isinstance(to_embed, list):
|
||||||
|
return embed_list_type(to_embed, model, namespace)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid input format for embedding")
|
@ -9,6 +9,7 @@ from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
|||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
|
|
||||||
import langchain_experimental.rl_chain.base as base
|
import langchain_experimental.rl_chain.base as base
|
||||||
|
from langchain_experimental.rl_chain.helpers import embed
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -90,14 +91,14 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
|||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
def get_context_and_action_embeddings(self, event: PickBestEvent) -> tuple:
|
def get_context_and_action_embeddings(self, event: PickBestEvent) -> tuple:
|
||||||
context_emb = base.embed(event.based_on, self.model) if event.based_on else None
|
context_emb = embed(event.based_on, self.model) if event.based_on else None
|
||||||
to_select_from_var_name, to_select_from = next(
|
to_select_from_var_name, to_select_from = next(
|
||||||
iter(event.to_select_from.items()), (None, None)
|
iter(event.to_select_from.items()), (None, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
action_embs = (
|
action_embs = (
|
||||||
(
|
(
|
||||||
base.embed(to_select_from, self.model, to_select_from_var_name)
|
embed(to_select_from, self.model, to_select_from_var_name)
|
||||||
if event.to_select_from
|
if event.to_select_from
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
@ -6,6 +6,7 @@ from langchain_core.prompts.prompt import PromptTemplate
|
|||||||
from test_utils import MockEncoder, MockEncoderReturnsList
|
from test_utils import MockEncoder, MockEncoderReturnsList
|
||||||
|
|
||||||
import langchain_experimental.rl_chain.base as rl_chain
|
import langchain_experimental.rl_chain.base as rl_chain
|
||||||
|
import langchain_experimental.rl_chain.helpers
|
||||||
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
||||||
|
|
||||||
encoded_keyword = "[encoded]"
|
encoded_keyword = "[encoded]"
|
||||||
@ -197,13 +198,21 @@ def test_everything_embedded() -> None:
|
|||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
str3 = "2"
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
list(encoded_keyword + str1)
|
||||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
)
|
||||||
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
|
|
||||||
ctx_str_1 = "context1"
|
ctx_str_1 = "context1"
|
||||||
|
|
||||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + ctx_str_1)
|
||||||
|
)
|
||||||
|
|
||||||
expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """ # noqa
|
expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """ # noqa
|
||||||
|
|
||||||
@ -314,10 +323,14 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
|||||||
|
|
||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
encoded_str2 = rl_chain.stringify_embedding([1.0, 2.0])
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
[1.0, 2.0]
|
||||||
|
)
|
||||||
ctx_str_1 = "context1"
|
ctx_str_1 = "context1"
|
||||||
ctx_str_2 = "context2"
|
ctx_str_2 = "context2"
|
||||||
encoded_ctx_str_1 = rl_chain.stringify_embedding([1.0, 2.0])
|
encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
[1.0, 2.0]
|
||||||
|
)
|
||||||
dot_prod = "dotprod 0:5.0 1:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
|
dot_prod = "dotprod 0:5.0 1:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
|
||||||
|
|
||||||
expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}""" # noqa
|
expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}""" # noqa
|
||||||
|
@ -2,6 +2,7 @@ import pytest
|
|||||||
from test_utils import MockEncoder
|
from test_utils import MockEncoder
|
||||||
|
|
||||||
import langchain_experimental.rl_chain.base as rl_chain
|
import langchain_experimental.rl_chain.base as rl_chain
|
||||||
|
import langchain_experimental.rl_chain.helpers
|
||||||
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
||||||
|
|
||||||
encoded_keyword = "[encoded]"
|
encoded_keyword = "[encoded]"
|
||||||
@ -92,12 +93,20 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None:
|
|||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
str3 = "2"
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
list(encoded_keyword + str1)
|
||||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
)
|
||||||
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
|
|
||||||
ctx_str_1 = "context1"
|
ctx_str_1 = "context1"
|
||||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + ctx_str_1)
|
||||||
|
)
|
||||||
|
|
||||||
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
|
||||||
context = {"context": rl_chain.Embed(ctx_str_1)}
|
context = {"context": rl_chain.Embed(ctx_str_1)}
|
||||||
@ -118,12 +127,20 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
|
|||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
str3 = "2"
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
list(encoded_keyword + str1)
|
||||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
)
|
||||||
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
|
|
||||||
ctx_str_1 = "context1"
|
ctx_str_1 = "context1"
|
||||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + ctx_str_1)
|
||||||
|
)
|
||||||
|
|
||||||
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
|
||||||
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
|
||||||
@ -192,14 +209,24 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None
|
|||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
str3 = "2"
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
list(encoded_keyword + str1)
|
||||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
)
|
||||||
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
|
|
||||||
ctx_str_1 = "context1"
|
ctx_str_1 = "context1"
|
||||||
ctx_str_2 = "context2"
|
ctx_str_2 = "context2"
|
||||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
list(encoded_keyword + ctx_str_1)
|
||||||
|
)
|
||||||
|
encoded_ctx_str_2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + ctx_str_2)
|
||||||
|
)
|
||||||
|
|
||||||
named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
|
named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
|
||||||
context = {
|
context = {
|
||||||
@ -227,14 +254,24 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
|
|||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
str3 = "2"
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
|
list(encoded_keyword + str1)
|
||||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
)
|
||||||
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
|
|
||||||
ctx_str_1 = "context1"
|
ctx_str_1 = "context1"
|
||||||
ctx_str_2 = "context2"
|
ctx_str_2 = "context2"
|
||||||
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
|
encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
list(encoded_keyword + ctx_str_1)
|
||||||
|
)
|
||||||
|
encoded_ctx_str_2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + ctx_str_2)
|
||||||
|
)
|
||||||
|
|
||||||
named_actions = {
|
named_actions = {
|
||||||
"action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
|
"action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
|
||||||
@ -262,12 +299,18 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N
|
|||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
str3 = "2"
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
list(encoded_keyword + str1)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
|
|
||||||
ctx_str_1 = "context1"
|
ctx_str_1 = "context1"
|
||||||
ctx_str_2 = "context2"
|
ctx_str_2 = "context2"
|
||||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
encoded_ctx_str_2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + ctx_str_2)
|
||||||
|
)
|
||||||
|
|
||||||
named_actions = {
|
named_actions = {
|
||||||
"action1": [
|
"action1": [
|
||||||
@ -296,12 +339,18 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep()
|
|||||||
str1 = "0"
|
str1 = "0"
|
||||||
str2 = "1"
|
str2 = "1"
|
||||||
str3 = "2"
|
str3 = "2"
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
|
list(encoded_keyword + str1)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
|
|
||||||
ctx_str_1 = "context1"
|
ctx_str_1 = "context1"
|
||||||
ctx_str_2 = "context2"
|
ctx_str_2 = "context2"
|
||||||
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
|
encoded_ctx_str_2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + ctx_str_2)
|
||||||
|
)
|
||||||
|
|
||||||
named_actions = {
|
named_actions = {
|
||||||
"action1": [
|
"action1": [
|
||||||
@ -331,11 +380,15 @@ def test_raw_features_underscored() -> None:
|
|||||||
)
|
)
|
||||||
str1 = "this is a long string"
|
str1 = "this is a long string"
|
||||||
str1_underscored = str1.replace(" ", "_")
|
str1_underscored = str1.replace(" ", "_")
|
||||||
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str1)
|
||||||
|
)
|
||||||
|
|
||||||
ctx_str = "this is a long context"
|
ctx_str = "this is a long context"
|
||||||
ctx_str_underscored = ctx_str.replace(" ", "_")
|
ctx_str_underscored = ctx_str.replace(" ", "_")
|
||||||
encoded_ctx_str = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str))
|
encoded_ctx_str = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + ctx_str)
|
||||||
|
)
|
||||||
|
|
||||||
# No embeddings
|
# No embeddings
|
||||||
named_actions = {"action": [str1]}
|
named_actions = {"action": [str1]}
|
||||||
|
@ -4,6 +4,7 @@ import pytest
|
|||||||
from test_utils import MockEncoder
|
from test_utils import MockEncoder
|
||||||
|
|
||||||
import langchain_experimental.rl_chain.base as base
|
import langchain_experimental.rl_chain.base as base
|
||||||
|
import langchain_experimental.rl_chain.helpers
|
||||||
|
|
||||||
encoded_keyword = "[encoded]"
|
encoded_keyword = "[encoded]"
|
||||||
|
|
||||||
@ -11,18 +12,32 @@ encoded_keyword = "[encoded]"
|
|||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_context_str_no_emb() -> None:
|
def test_simple_context_str_no_emb() -> None:
|
||||||
expected = [{"a_namespace": "test"}]
|
expected = [{"a_namespace": "test"}]
|
||||||
assert base.embed("test", MockEncoder(), "a_namespace") == expected
|
assert (
|
||||||
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
"test", MockEncoder(), "a_namespace"
|
||||||
|
)
|
||||||
|
== expected
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_simple_context_str_w_emb() -> None:
|
def test_simple_context_str_w_emb() -> None:
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str1)
|
||||||
|
)
|
||||||
expected = [{"a_namespace": encoded_str1}]
|
expected = [{"a_namespace": encoded_str1}]
|
||||||
assert base.embed(base.Embed(str1), MockEncoder(), "a_namespace") == expected
|
assert (
|
||||||
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
base.Embed(str1), MockEncoder(), "a_namespace"
|
||||||
|
)
|
||||||
|
== expected
|
||||||
|
)
|
||||||
expected_embed_and_keep = [{"a_namespace": str1 + " " + encoded_str1}]
|
expected_embed_and_keep = [{"a_namespace": str1 + " " + encoded_str1}]
|
||||||
assert (
|
assert (
|
||||||
base.embed(base.EmbedAndKeep(str1), MockEncoder(), "a_namespace")
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
base.EmbedAndKeep(str1), MockEncoder(), "a_namespace"
|
||||||
|
)
|
||||||
== expected_embed_and_keep
|
== expected_embed_and_keep
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,16 +46,22 @@ def test_simple_context_str_w_emb() -> None:
|
|||||||
def test_simple_context_str_w_nested_emb() -> None:
|
def test_simple_context_str_w_nested_emb() -> None:
|
||||||
# nested embeddings, innermost wins
|
# nested embeddings, innermost wins
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str1)
|
||||||
|
)
|
||||||
expected = [{"a_namespace": encoded_str1}]
|
expected = [{"a_namespace": encoded_str1}]
|
||||||
assert (
|
assert (
|
||||||
base.embed(base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace")
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace"
|
||||||
|
)
|
||||||
== expected
|
== expected
|
||||||
)
|
)
|
||||||
|
|
||||||
expected2 = [{"a_namespace": str1 + " " + encoded_str1}]
|
expected2 = [{"a_namespace": str1 + " " + encoded_str1}]
|
||||||
assert (
|
assert (
|
||||||
base.embed(base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace")
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace"
|
||||||
|
)
|
||||||
== expected2
|
== expected2
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -48,18 +69,32 @@ def test_simple_context_str_w_nested_emb() -> None:
|
|||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_no_emb() -> None:
|
def test_context_w_namespace_no_emb() -> None:
|
||||||
expected = [{"test_namespace": "test"}]
|
expected = [{"test_namespace": "test"}]
|
||||||
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
|
assert (
|
||||||
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
{"test_namespace": "test"}, MockEncoder()
|
||||||
|
)
|
||||||
|
== expected
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_w_emb() -> None:
|
def test_context_w_namespace_w_emb() -> None:
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str1)
|
||||||
|
)
|
||||||
expected = [{"test_namespace": encoded_str1}]
|
expected = [{"test_namespace": encoded_str1}]
|
||||||
assert base.embed({"test_namespace": base.Embed(str1)}, MockEncoder()) == expected
|
assert (
|
||||||
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
{"test_namespace": base.Embed(str1)}, MockEncoder()
|
||||||
|
)
|
||||||
|
== expected
|
||||||
|
)
|
||||||
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
|
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
|
||||||
assert (
|
assert (
|
||||||
base.embed({"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder())
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
{"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder()
|
||||||
|
)
|
||||||
== expected_embed_and_keep
|
== expected_embed_and_keep
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -67,12 +102,21 @@ def test_context_w_namespace_w_emb() -> None:
|
|||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_context_w_namespace_w_emb2() -> None:
|
def test_context_w_namespace_w_emb2() -> None:
|
||||||
str1 = "test"
|
str1 = "test"
|
||||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str1)
|
||||||
|
)
|
||||||
expected = [{"test_namespace": encoded_str1}]
|
expected = [{"test_namespace": encoded_str1}]
|
||||||
assert base.embed(base.Embed({"test_namespace": str1}), MockEncoder()) == expected
|
assert (
|
||||||
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
base.Embed({"test_namespace": str1}), MockEncoder()
|
||||||
|
)
|
||||||
|
== expected
|
||||||
|
)
|
||||||
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
|
expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}]
|
||||||
assert (
|
assert (
|
||||||
base.embed(base.EmbedAndKeep({"test_namespace": str1}), MockEncoder())
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
base.EmbedAndKeep({"test_namespace": str1}), MockEncoder()
|
||||||
|
)
|
||||||
== expected_embed_and_keep
|
== expected_embed_and_keep
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -81,10 +125,12 @@ def test_context_w_namespace_w_emb2() -> None:
|
|||||||
def test_context_w_namespace_w_some_emb() -> None:
|
def test_context_w_namespace_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
expected = [{"test_namespace": str1, "test_namespace2": encoded_str2}]
|
expected = [{"test_namespace": str1, "test_namespace2": encoded_str2}]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
{"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder()
|
{"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder()
|
||||||
)
|
)
|
||||||
== expected
|
== expected
|
||||||
@ -96,7 +142,7 @@ def test_context_w_namespace_w_some_emb() -> None:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
{"test_namespace": str1, "test_namespace2": base.EmbedAndKeep(str2)},
|
{"test_namespace": str1, "test_namespace2": base.EmbedAndKeep(str2)},
|
||||||
MockEncoder(),
|
MockEncoder(),
|
||||||
)
|
)
|
||||||
@ -110,8 +156,17 @@ def test_simple_action_strlist_no_emb() -> None:
|
|||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
|
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
|
||||||
to_embed: List[Union[str, base._Embed]] = [str1, str2, str3]
|
to_embed: List[Union[str, langchain_experimental.rl_chain.helpers._Embed]] = [
|
||||||
assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected
|
str1,
|
||||||
|
str2,
|
||||||
|
str3,
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
to_embed, MockEncoder(), "a_namespace"
|
||||||
|
)
|
||||||
|
== expected
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
@ -119,16 +174,24 @@ def test_simple_action_strlist_w_emb() -> None:
|
|||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
list(encoded_keyword + str1)
|
||||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
)
|
||||||
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
expected = [
|
expected = [
|
||||||
{"a_namespace": encoded_str1},
|
{"a_namespace": encoded_str1},
|
||||||
{"a_namespace": encoded_str2},
|
{"a_namespace": encoded_str2},
|
||||||
{"a_namespace": encoded_str3},
|
{"a_namespace": encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace")
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace"
|
||||||
|
)
|
||||||
== expected
|
== expected
|
||||||
)
|
)
|
||||||
expected_embed_and_keep = [
|
expected_embed_and_keep = [
|
||||||
@ -137,7 +200,9 @@ def test_simple_action_strlist_w_emb() -> None:
|
|||||||
{"a_namespace": str3 + " " + encoded_str3},
|
{"a_namespace": str3 + " " + encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace")
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace"
|
||||||
|
)
|
||||||
== expected_embed_and_keep
|
== expected_embed_and_keep
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -147,15 +212,19 @@ def test_simple_action_strlist_w_some_emb() -> None:
|
|||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
expected = [
|
expected = [
|
||||||
{"a_namespace": str1},
|
{"a_namespace": str1},
|
||||||
{"a_namespace": encoded_str2},
|
{"a_namespace": encoded_str2},
|
||||||
{"a_namespace": encoded_str3},
|
{"a_namespace": encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
[str1, base.Embed(str2), base.Embed(str3)], MockEncoder(), "a_namespace"
|
[str1, base.Embed(str2), base.Embed(str3)], MockEncoder(), "a_namespace"
|
||||||
)
|
)
|
||||||
== expected
|
== expected
|
||||||
@ -166,7 +235,7 @@ def test_simple_action_strlist_w_some_emb() -> None:
|
|||||||
{"a_namespace": str3 + " " + encoded_str3},
|
{"a_namespace": str3 + " " + encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
[str1, base.EmbedAndKeep(str2), base.EmbedAndKeep(str3)],
|
[str1, base.EmbedAndKeep(str2), base.EmbedAndKeep(str3)],
|
||||||
MockEncoder(),
|
MockEncoder(),
|
||||||
"a_namespace",
|
"a_namespace",
|
||||||
@ -186,7 +255,7 @@ def test_action_w_namespace_no_emb() -> None:
|
|||||||
{"test_namespace": str3},
|
{"test_namespace": str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
[
|
[
|
||||||
{"test_namespace": str1},
|
{"test_namespace": str1},
|
||||||
{"test_namespace": str2},
|
{"test_namespace": str2},
|
||||||
@ -203,16 +272,22 @@ def test_action_w_namespace_w_emb() -> None:
|
|||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
list(encoded_keyword + str1)
|
||||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
)
|
||||||
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
expected = [
|
expected = [
|
||||||
{"test_namespace": encoded_str1},
|
{"test_namespace": encoded_str1},
|
||||||
{"test_namespace": encoded_str2},
|
{"test_namespace": encoded_str2},
|
||||||
{"test_namespace": encoded_str3},
|
{"test_namespace": encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
[
|
[
|
||||||
{"test_namespace": base.Embed(str1)},
|
{"test_namespace": base.Embed(str1)},
|
||||||
{"test_namespace": base.Embed(str2)},
|
{"test_namespace": base.Embed(str2)},
|
||||||
@ -228,7 +303,7 @@ def test_action_w_namespace_w_emb() -> None:
|
|||||||
{"test_namespace": str3 + " " + encoded_str3},
|
{"test_namespace": str3 + " " + encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
[
|
[
|
||||||
{"test_namespace": base.EmbedAndKeep(str1)},
|
{"test_namespace": base.EmbedAndKeep(str1)},
|
||||||
{"test_namespace": base.EmbedAndKeep(str2)},
|
{"test_namespace": base.EmbedAndKeep(str2)},
|
||||||
@ -245,16 +320,22 @@ def test_action_w_namespace_w_emb2() -> None:
|
|||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
list(encoded_keyword + str1)
|
||||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
)
|
||||||
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
expected = [
|
expected = [
|
||||||
{"test_namespace1": encoded_str1},
|
{"test_namespace1": encoded_str1},
|
||||||
{"test_namespace2": encoded_str2},
|
{"test_namespace2": encoded_str2},
|
||||||
{"test_namespace3": encoded_str3},
|
{"test_namespace3": encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
base.Embed(
|
base.Embed(
|
||||||
[
|
[
|
||||||
{"test_namespace1": str1},
|
{"test_namespace1": str1},
|
||||||
@ -272,7 +353,7 @@ def test_action_w_namespace_w_emb2() -> None:
|
|||||||
{"test_namespace3": str3 + " " + encoded_str3},
|
{"test_namespace3": str3 + " " + encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
base.EmbedAndKeep(
|
base.EmbedAndKeep(
|
||||||
[
|
[
|
||||||
{"test_namespace1": str1},
|
{"test_namespace1": str1},
|
||||||
@ -291,15 +372,19 @@ def test_action_w_namespace_w_some_emb() -> None:
|
|||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
expected = [
|
expected = [
|
||||||
{"test_namespace": str1},
|
{"test_namespace": str1},
|
||||||
{"test_namespace": encoded_str2},
|
{"test_namespace": encoded_str2},
|
||||||
{"test_namespace": encoded_str3},
|
{"test_namespace": encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
[
|
[
|
||||||
{"test_namespace": str1},
|
{"test_namespace": str1},
|
||||||
{"test_namespace": base.Embed(str2)},
|
{"test_namespace": base.Embed(str2)},
|
||||||
@ -315,7 +400,7 @@ def test_action_w_namespace_w_some_emb() -> None:
|
|||||||
{"test_namespace": str3 + " " + encoded_str3},
|
{"test_namespace": str3 + " " + encoded_str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
[
|
[
|
||||||
{"test_namespace": str1},
|
{"test_namespace": str1},
|
||||||
{"test_namespace": base.EmbedAndKeep(str2)},
|
{"test_namespace": base.EmbedAndKeep(str2)},
|
||||||
@ -332,16 +417,22 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
|
|||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
str3 = "test3"
|
str3 = "test3"
|
||||||
encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1))
|
encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
list(encoded_keyword + str1)
|
||||||
encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3))
|
)
|
||||||
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
|
encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str3)
|
||||||
|
)
|
||||||
expected = [
|
expected = [
|
||||||
{"test_namespace": encoded_str1, "test_namespace2": str1},
|
{"test_namespace": encoded_str1, "test_namespace2": str1},
|
||||||
{"test_namespace": encoded_str2, "test_namespace2": str2},
|
{"test_namespace": encoded_str2, "test_namespace2": str2},
|
||||||
{"test_namespace": encoded_str3, "test_namespace2": str3},
|
{"test_namespace": encoded_str3, "test_namespace2": str3},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
[
|
[
|
||||||
{"test_namespace": base.Embed(str1), "test_namespace2": str1},
|
{"test_namespace": base.Embed(str1), "test_namespace2": str1},
|
||||||
{"test_namespace": base.Embed(str2), "test_namespace2": str2},
|
{"test_namespace": base.Embed(str2), "test_namespace2": str2},
|
||||||
@ -366,7 +457,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
assert (
|
assert (
|
||||||
base.embed(
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
[
|
[
|
||||||
{"test_namespace": base.EmbedAndKeep(str1), "test_namespace2": str1},
|
{"test_namespace": base.EmbedAndKeep(str1), "test_namespace2": str1},
|
||||||
{"test_namespace": base.EmbedAndKeep(str2), "test_namespace2": str2},
|
{"test_namespace": base.EmbedAndKeep(str2), "test_namespace2": str2},
|
||||||
@ -383,17 +474,26 @@ def test_one_namespace_w_list_of_features_no_emb() -> None:
|
|||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
expected = [{"test_namespace": [str1, str2]}]
|
expected = [{"test_namespace": [str1, str2]}]
|
||||||
assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected
|
assert (
|
||||||
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
{"test_namespace": [str1, str2]}, MockEncoder()
|
||||||
|
)
|
||||||
|
== expected
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
|
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
|
||||||
str1 = "test1"
|
str1 = "test1"
|
||||||
str2 = "test2"
|
str2 = "test2"
|
||||||
encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2))
|
encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding(
|
||||||
|
list(encoded_keyword + str2)
|
||||||
|
)
|
||||||
expected = [{"test_namespace": [str1, encoded_str2]}]
|
expected = [{"test_namespace": [str1, encoded_str2]}]
|
||||||
assert (
|
assert (
|
||||||
base.embed({"test_namespace": [str1, base.Embed(str2)]}, MockEncoder())
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
{"test_namespace": [str1, base.Embed(str2)]}, MockEncoder()
|
||||||
|
)
|
||||||
== expected
|
== expected
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -401,22 +501,30 @@ def test_one_namespace_w_list_of_features_w_some_emb() -> None:
|
|||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_nested_list_features_throws() -> None:
|
def test_nested_list_features_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
{"test_namespace": [[1, 2], [3, 4]]}, MockEncoder()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_dict_in_list_throws() -> None:
|
def test_dict_in_list_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
{"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_nested_dict_throws() -> None:
|
def test_nested_dict_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
{"test_namespace": {"a": {"b": 1}}}, MockEncoder()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next")
|
@pytest.mark.requires("vowpal_wabbit_next")
|
||||||
def test_list_of_tuples_throws() -> None:
|
def test_list_of_tuples_throws() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())
|
langchain_experimental.rl_chain.helpers.embed(
|
||||||
|
{"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder()
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user