From 414154fa59d73e62241140a43996cadc01e1ff7a Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 14 Aug 2024 10:09:43 -0700 Subject: [PATCH] experimental[patch]: refactor rl chain structure (#25398) can't have a class and function with same name but different capitalization in same file for api reference building --- .../rl_chain/__init__.py | 3 +- .../langchain_experimental/rl_chain/base.py | 112 +-------- .../rl_chain/helpers.py | 114 +++++++++ .../rl_chain/pick_best_chain.py | 5 +- .../rl_chain/test_pick_best_chain_call.py | 25 +- .../rl_chain/test_pick_best_text_embedder.py | 105 ++++++--- .../rl_chain/test_rl_chain_base_embedder.py | 218 +++++++++++++----- 7 files changed, 380 insertions(+), 202 deletions(-) create mode 100644 libs/experimental/langchain_experimental/rl_chain/helpers.py diff --git a/libs/experimental/langchain_experimental/rl_chain/__init__.py b/libs/experimental/langchain_experimental/rl_chain/__init__.py index bc595101f77..eac581f9676 100644 --- a/libs/experimental/langchain_experimental/rl_chain/__init__.py +++ b/libs/experimental/langchain_experimental/rl_chain/__init__.py @@ -19,9 +19,8 @@ from langchain_experimental.rl_chain.base import ( SelectionScorer, ToSelectFrom, VwPolicy, - embed, - stringify_embedding, ) +from langchain_experimental.rl_chain.helpers import embed, stringify_embedding from langchain_experimental.rl_chain.pick_best_chain import ( PickBest, PickBestEvent, diff --git a/libs/experimental/langchain_experimental/rl_chain/base.py b/libs/experimental/langchain_experimental/rl_chain/base.py index 7b15f00fc1f..ca11686da7e 100644 --- a/libs/experimental/langchain_experimental/rl_chain/base.py +++ b/libs/experimental/langchain_experimental/rl_chain/base.py @@ -27,6 +27,7 @@ from langchain_core.prompts import ( ) from langchain_experimental.pydantic_v1 import BaseModel, root_validator +from langchain_experimental.rl_chain.helpers import _Embed from langchain_experimental.rl_chain.metrics import ( MetricsTrackerAverage, MetricsTrackerRollingWindow, @@ -74,17 +75,6 @@ def ToSelectFrom(anything: Any) -> _ToSelectFrom: return _ToSelectFrom(anything) -class _Embed: - def __init__(self, value: Any, keep: bool = False): - self.value = value - self.keep = keep - - def __str__(self) -> str: - return str(self.value) - - __repr__ = __str__ - - def Embed(anything: Any, keep: bool = False) -> Any: """Wrap a value to indicate that it should be embedded.""" @@ -110,12 +100,6 @@ def EmbedAndKeep(anything: Any) -> Any: # helper functions -def stringify_embedding(embedding: List) -> str: - """Convert an embedding to a string.""" - - return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)]) - - def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]: """Parse the input string into a list of examples.""" @@ -559,97 +543,3 @@ class RLChain(Chain, Generic[TEvent]): @property def _chain_type(self) -> str: return "llm_personalizer_chain" - - -def is_stringtype_instance(item: Any) -> bool: - """Check if an item is a string.""" - - return isinstance(item, str) or ( - isinstance(item, _Embed) and isinstance(item.value, str) - ) - - -def embed_string_type( - item: Union[str, _Embed], model: Any, namespace: Optional[str] = None -) -> Dict[str, Union[str, List[str]]]: - """Embed a string or an _Embed object.""" - - keep_str = "" - if isinstance(item, _Embed): - encoded = stringify_embedding(model.encode(item.value)) - if item.keep: - keep_str = item.value.replace(" ", "_") + " " - elif isinstance(item, str): - encoded = item.replace(" ", "_") - else: - raise ValueError(f"Unsupported type {type(item)} for embedding") - - if namespace is None: - raise ValueError( - "The default namespace must be provided when embedding a string or _Embed object." # noqa: E501 - ) - - return {namespace: keep_str + encoded} - - -def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]: - """Embed a dictionary item.""" - inner_dict: Dict = {} - for ns, embed_item in item.items(): - if isinstance(embed_item, list): - inner_dict[ns] = [] - for embed_list_item in embed_item: - embedded = embed_string_type(embed_list_item, model, ns) - inner_dict[ns].append(embedded[ns]) - else: - inner_dict.update(embed_string_type(embed_item, model, ns)) - return inner_dict - - -def embed_list_type( - item: list, model: Any, namespace: Optional[str] = None -) -> List[Dict[str, Union[str, List[str]]]]: - """Embed a list item.""" - - ret_list: List = [] - for embed_item in item: - if isinstance(embed_item, dict): - ret_list.append(embed_dict_type(embed_item, model)) - elif isinstance(embed_item, list): - item_embedding = embed_list_type(embed_item, model, namespace) - # Get the first key from the first dictionary - first_key = next(iter(item_embedding[0])) - # Group the values under that key - grouping = {first_key: [item[first_key] for item in item_embedding]} - ret_list.append(grouping) - else: - ret_list.append(embed_string_type(embed_item, model, namespace)) - return ret_list - - -def embed( - to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]], - model: Any, - namespace: Optional[str] = None, -) -> List[Dict[str, Union[str, List[str]]]]: - """ - Embed the actions or context using the SentenceTransformer model - (or a model that has an `encode` function). - - Attributes: - to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries. - namespace: (str, optional) The default namespace to use when dictionary or list of dictionaries not provided. - model: (Any, required) The model to use for embedding - Returns: - List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value - """ # noqa: E501 - if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance( - to_embed, str - ): - return [embed_string_type(to_embed, model, namespace)] - elif isinstance(to_embed, dict): - return [embed_dict_type(to_embed, model)] - elif isinstance(to_embed, list): - return embed_list_type(to_embed, model, namespace) - else: - raise ValueError("Invalid input format for embedding") diff --git a/libs/experimental/langchain_experimental/rl_chain/helpers.py b/libs/experimental/langchain_experimental/rl_chain/helpers.py new file mode 100644 index 00000000000..e4e221089e6 --- /dev/null +++ b/libs/experimental/langchain_experimental/rl_chain/helpers.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union + + +class _Embed: + def __init__(self, value: Any, keep: bool = False): + self.value = value + self.keep = keep + + def __str__(self) -> str: + return str(self.value) + + __repr__ = __str__ + + +def stringify_embedding(embedding: List) -> str: + """Convert an embedding to a string.""" + + return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)]) + + +def is_stringtype_instance(item: Any) -> bool: + """Check if an item is a string.""" + + return isinstance(item, str) or ( + isinstance(item, _Embed) and isinstance(item.value, str) + ) + + +def embed_string_type( + item: Union[str, _Embed], model: Any, namespace: Optional[str] = None +) -> Dict[str, Union[str, List[str]]]: + """Embed a string or an _Embed object.""" + + keep_str = "" + if isinstance(item, _Embed): + encoded = stringify_embedding(model.encode(item.value)) + if item.keep: + keep_str = item.value.replace(" ", "_") + " " + elif isinstance(item, str): + encoded = item.replace(" ", "_") + else: + raise ValueError(f"Unsupported type {type(item)} for embedding") + + if namespace is None: + raise ValueError( + "The default namespace must be provided when embedding a string or _Embed object." # noqa: E501 + ) + + return {namespace: keep_str + encoded} + + +def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]: + """Embed a dictionary item.""" + inner_dict: Dict = {} + for ns, embed_item in item.items(): + if isinstance(embed_item, list): + inner_dict[ns] = [] + for embed_list_item in embed_item: + embedded = embed_string_type(embed_list_item, model, ns) + inner_dict[ns].append(embedded[ns]) + else: + inner_dict.update(embed_string_type(embed_item, model, ns)) + return inner_dict + + +def embed_list_type( + item: list, model: Any, namespace: Optional[str] = None +) -> List[Dict[str, Union[str, List[str]]]]: + """Embed a list item.""" + + ret_list: List = [] + for embed_item in item: + if isinstance(embed_item, dict): + ret_list.append(embed_dict_type(embed_item, model)) + elif isinstance(embed_item, list): + item_embedding = embed_list_type(embed_item, model, namespace) + # Get the first key from the first dictionary + first_key = next(iter(item_embedding[0])) + # Group the values under that key + grouping = {first_key: [item[first_key] for item in item_embedding]} + ret_list.append(grouping) + else: + ret_list.append(embed_string_type(embed_item, model, namespace)) + return ret_list + + +def embed( + to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]], + model: Any, + namespace: Optional[str] = None, +) -> List[Dict[str, Union[str, List[str]]]]: + """ + Embed the actions or context using the SentenceTransformer model + (or a model that has an `encode` function). + + Attributes: + to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries. + namespace: (str, optional) The default namespace to use when dictionary or list of dictionaries not provided. + model: (Any, required) The model to use for embedding + Returns: + List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value + """ # noqa: E501 + if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance( + to_embed, str + ): + return [embed_string_type(to_embed, model, namespace)] + elif isinstance(to_embed, dict): + return [embed_dict_type(to_embed, model)] + elif isinstance(to_embed, list): + return embed_list_type(to_embed, model, namespace) + else: + raise ValueError("Invalid input format for embedding") diff --git a/libs/experimental/langchain_experimental/rl_chain/pick_best_chain.py b/libs/experimental/langchain_experimental/rl_chain/pick_best_chain.py index 7df96d0a4ec..73ccdadcb4b 100644 --- a/libs/experimental/langchain_experimental/rl_chain/pick_best_chain.py +++ b/libs/experimental/langchain_experimental/rl_chain/pick_best_chain.py @@ -9,6 +9,7 @@ from langchain_core.callbacks.manager import CallbackManagerForChainRun from langchain_core.prompts import BasePromptTemplate import langchain_experimental.rl_chain.base as base +from langchain_experimental.rl_chain.helpers import embed logger = logging.getLogger(__name__) @@ -90,14 +91,14 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): return None, None, None def get_context_and_action_embeddings(self, event: PickBestEvent) -> tuple: - context_emb = base.embed(event.based_on, self.model) if event.based_on else None + context_emb = embed(event.based_on, self.model) if event.based_on else None to_select_from_var_name, to_select_from = next( iter(event.to_select_from.items()), (None, None) ) action_embs = ( ( - base.embed(to_select_from, self.model, to_select_from_var_name) + embed(to_select_from, self.model, to_select_from_var_name) if event.to_select_from else None ) diff --git a/libs/experimental/tests/unit_tests/rl_chain/test_pick_best_chain_call.py b/libs/experimental/tests/unit_tests/rl_chain/test_pick_best_chain_call.py index d3373b6039c..97aed9d2284 100644 --- a/libs/experimental/tests/unit_tests/rl_chain/test_pick_best_chain_call.py +++ b/libs/experimental/tests/unit_tests/rl_chain/test_pick_best_chain_call.py @@ -6,6 +6,7 @@ from langchain_core.prompts.prompt import PromptTemplate from test_utils import MockEncoder, MockEncoderReturnsList import langchain_experimental.rl_chain.base as rl_chain +import langchain_experimental.rl_chain.helpers import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain encoded_keyword = "[encoded]" @@ -197,13 +198,21 @@ def test_everything_embedded() -> None: str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) - encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) ctx_str_1 = "context1" - encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str_1) + ) expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """ # noqa @@ -314,10 +323,14 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None: str1 = "0" str2 = "1" - encoded_str2 = rl_chain.stringify_embedding([1.0, 2.0]) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + [1.0, 2.0] + ) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_ctx_str_1 = rl_chain.stringify_embedding([1.0, 2.0]) + encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + [1.0, 2.0] + ) dot_prod = "dotprod 0:5.0 1:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0] expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}""" # noqa diff --git a/libs/experimental/tests/unit_tests/rl_chain/test_pick_best_text_embedder.py b/libs/experimental/tests/unit_tests/rl_chain/test_pick_best_text_embedder.py index fafa77e9f49..e080f83dfc8 100644 --- a/libs/experimental/tests/unit_tests/rl_chain/test_pick_best_text_embedder.py +++ b/libs/experimental/tests/unit_tests/rl_chain/test_pick_best_text_embedder.py @@ -2,6 +2,7 @@ import pytest from test_utils import MockEncoder import langchain_experimental.rl_chain.base as rl_chain +import langchain_experimental.rl_chain.helpers import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain encoded_keyword = "[encoded]" @@ -92,12 +93,20 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None: str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) - encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) ctx_str_1 = "context1" - encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str_1) + ) named_actions = {"action1": rl_chain.Embed([str1, str2, str3])} context = {"context": rl_chain.Embed(ctx_str_1)} @@ -118,12 +127,20 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) - encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) ctx_str_1 = "context1" - encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str_1) + ) named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])} context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)} @@ -192,14 +209,24 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) - encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) - encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) + encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str_1) + ) + encoded_ctx_str_2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str_2) + ) named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])} context = { @@ -227,14 +254,24 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) - encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) - encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) + encoded_ctx_str_1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str_1) + ) + encoded_ctx_str_2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str_2) + ) named_actions = { "action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3]) @@ -262,12 +299,18 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) - encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) + encoded_ctx_str_2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str_2) + ) named_actions = { "action1": [ @@ -296,12 +339,18 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) - encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) + encoded_ctx_str_2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str_2) + ) named_actions = { "action1": [ @@ -331,11 +380,15 @@ def test_raw_features_underscored() -> None: ) str1 = "this is a long string" str1_underscored = str1.replace(" ", "_") - encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) ctx_str = "this is a long context" ctx_str_underscored = ctx_str.replace(" ", "_") - encoded_ctx_str = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str)) + encoded_ctx_str = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + ctx_str) + ) # No embeddings named_actions = {"action": [str1]} diff --git a/libs/experimental/tests/unit_tests/rl_chain/test_rl_chain_base_embedder.py b/libs/experimental/tests/unit_tests/rl_chain/test_rl_chain_base_embedder.py index 7e8b23857f2..8e8465b0b88 100644 --- a/libs/experimental/tests/unit_tests/rl_chain/test_rl_chain_base_embedder.py +++ b/libs/experimental/tests/unit_tests/rl_chain/test_rl_chain_base_embedder.py @@ -4,6 +4,7 @@ import pytest from test_utils import MockEncoder import langchain_experimental.rl_chain.base as base +import langchain_experimental.rl_chain.helpers encoded_keyword = "[encoded]" @@ -11,18 +12,32 @@ encoded_keyword = "[encoded]" @pytest.mark.requires("vowpal_wabbit_next") def test_simple_context_str_no_emb() -> None: expected = [{"a_namespace": "test"}] - assert base.embed("test", MockEncoder(), "a_namespace") == expected + assert ( + langchain_experimental.rl_chain.helpers.embed( + "test", MockEncoder(), "a_namespace" + ) + == expected + ) @pytest.mark.requires("vowpal_wabbit_next") def test_simple_context_str_w_emb() -> None: str1 = "test" - encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) expected = [{"a_namespace": encoded_str1}] - assert base.embed(base.Embed(str1), MockEncoder(), "a_namespace") == expected + assert ( + langchain_experimental.rl_chain.helpers.embed( + base.Embed(str1), MockEncoder(), "a_namespace" + ) + == expected + ) expected_embed_and_keep = [{"a_namespace": str1 + " " + encoded_str1}] assert ( - base.embed(base.EmbedAndKeep(str1), MockEncoder(), "a_namespace") + langchain_experimental.rl_chain.helpers.embed( + base.EmbedAndKeep(str1), MockEncoder(), "a_namespace" + ) == expected_embed_and_keep ) @@ -31,16 +46,22 @@ def test_simple_context_str_w_emb() -> None: def test_simple_context_str_w_nested_emb() -> None: # nested embeddings, innermost wins str1 = "test" - encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) expected = [{"a_namespace": encoded_str1}] assert ( - base.embed(base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace") + langchain_experimental.rl_chain.helpers.embed( + base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace" + ) == expected ) expected2 = [{"a_namespace": str1 + " " + encoded_str1}] assert ( - base.embed(base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace") + langchain_experimental.rl_chain.helpers.embed( + base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace" + ) == expected2 ) @@ -48,18 +69,32 @@ def test_simple_context_str_w_nested_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_context_w_namespace_no_emb() -> None: expected = [{"test_namespace": "test"}] - assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected + assert ( + langchain_experimental.rl_chain.helpers.embed( + {"test_namespace": "test"}, MockEncoder() + ) + == expected + ) @pytest.mark.requires("vowpal_wabbit_next") def test_context_w_namespace_w_emb() -> None: str1 = "test" - encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) expected = [{"test_namespace": encoded_str1}] - assert base.embed({"test_namespace": base.Embed(str1)}, MockEncoder()) == expected + assert ( + langchain_experimental.rl_chain.helpers.embed( + {"test_namespace": base.Embed(str1)}, MockEncoder() + ) + == expected + ) expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}] assert ( - base.embed({"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder()) + langchain_experimental.rl_chain.helpers.embed( + {"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder() + ) == expected_embed_and_keep ) @@ -67,12 +102,21 @@ def test_context_w_namespace_w_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_context_w_namespace_w_emb2() -> None: str1 = "test" - encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) expected = [{"test_namespace": encoded_str1}] - assert base.embed(base.Embed({"test_namespace": str1}), MockEncoder()) == expected + assert ( + langchain_experimental.rl_chain.helpers.embed( + base.Embed({"test_namespace": str1}), MockEncoder() + ) + == expected + ) expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}] assert ( - base.embed(base.EmbedAndKeep({"test_namespace": str1}), MockEncoder()) + langchain_experimental.rl_chain.helpers.embed( + base.EmbedAndKeep({"test_namespace": str1}), MockEncoder() + ) == expected_embed_and_keep ) @@ -81,10 +125,12 @@ def test_context_w_namespace_w_emb2() -> None: def test_context_w_namespace_w_some_emb() -> None: str1 = "test1" str2 = "test2" - encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) expected = [{"test_namespace": str1, "test_namespace2": encoded_str2}] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( {"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder() ) == expected @@ -96,7 +142,7 @@ def test_context_w_namespace_w_some_emb() -> None: } ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( {"test_namespace": str1, "test_namespace2": base.EmbedAndKeep(str2)}, MockEncoder(), ) @@ -110,8 +156,17 @@ def test_simple_action_strlist_no_emb() -> None: str2 = "test2" str3 = "test3" expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}] - to_embed: List[Union[str, base._Embed]] = [str1, str2, str3] - assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected + to_embed: List[Union[str, langchain_experimental.rl_chain.helpers._Embed]] = [ + str1, + str2, + str3, + ] + assert ( + langchain_experimental.rl_chain.helpers.embed( + to_embed, MockEncoder(), "a_namespace" + ) + == expected + ) @pytest.mark.requires("vowpal_wabbit_next") @@ -119,16 +174,24 @@ def test_simple_action_strlist_w_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) - encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) expected = [ {"a_namespace": encoded_str1}, {"a_namespace": encoded_str2}, {"a_namespace": encoded_str3}, ] assert ( - base.embed(base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace") + langchain_experimental.rl_chain.helpers.embed( + base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace" + ) == expected ) expected_embed_and_keep = [ @@ -137,7 +200,9 @@ def test_simple_action_strlist_w_emb() -> None: {"a_namespace": str3 + " " + encoded_str3}, ] assert ( - base.embed(base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace") + langchain_experimental.rl_chain.helpers.embed( + base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace" + ) == expected_embed_and_keep ) @@ -147,15 +212,19 @@ def test_simple_action_strlist_w_some_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) expected = [ {"a_namespace": str1}, {"a_namespace": encoded_str2}, {"a_namespace": encoded_str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( [str1, base.Embed(str2), base.Embed(str3)], MockEncoder(), "a_namespace" ) == expected @@ -166,7 +235,7 @@ def test_simple_action_strlist_w_some_emb() -> None: {"a_namespace": str3 + " " + encoded_str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( [str1, base.EmbedAndKeep(str2), base.EmbedAndKeep(str3)], MockEncoder(), "a_namespace", @@ -186,7 +255,7 @@ def test_action_w_namespace_no_emb() -> None: {"test_namespace": str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( [ {"test_namespace": str1}, {"test_namespace": str2}, @@ -203,16 +272,22 @@ def test_action_w_namespace_w_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) - encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) expected = [ {"test_namespace": encoded_str1}, {"test_namespace": encoded_str2}, {"test_namespace": encoded_str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( [ {"test_namespace": base.Embed(str1)}, {"test_namespace": base.Embed(str2)}, @@ -228,7 +303,7 @@ def test_action_w_namespace_w_emb() -> None: {"test_namespace": str3 + " " + encoded_str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( [ {"test_namespace": base.EmbedAndKeep(str1)}, {"test_namespace": base.EmbedAndKeep(str2)}, @@ -245,16 +320,22 @@ def test_action_w_namespace_w_emb2() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) - encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) expected = [ {"test_namespace1": encoded_str1}, {"test_namespace2": encoded_str2}, {"test_namespace3": encoded_str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( base.Embed( [ {"test_namespace1": str1}, @@ -272,7 +353,7 @@ def test_action_w_namespace_w_emb2() -> None: {"test_namespace3": str3 + " " + encoded_str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( base.EmbedAndKeep( [ {"test_namespace1": str1}, @@ -291,15 +372,19 @@ def test_action_w_namespace_w_some_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) expected = [ {"test_namespace": str1}, {"test_namespace": encoded_str2}, {"test_namespace": encoded_str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( [ {"test_namespace": str1}, {"test_namespace": base.Embed(str2)}, @@ -315,7 +400,7 @@ def test_action_w_namespace_w_some_emb() -> None: {"test_namespace": str3 + " " + encoded_str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( [ {"test_namespace": str1}, {"test_namespace": base.EmbedAndKeep(str2)}, @@ -332,16 +417,22 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) - encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) - encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) + encoded_str1 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str1) + ) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) + encoded_str3 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str3) + ) expected = [ {"test_namespace": encoded_str1, "test_namespace2": str1}, {"test_namespace": encoded_str2, "test_namespace2": str2}, {"test_namespace": encoded_str3, "test_namespace2": str3}, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( [ {"test_namespace": base.Embed(str1), "test_namespace2": str1}, {"test_namespace": base.Embed(str2), "test_namespace2": str2}, @@ -366,7 +457,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None: }, ] assert ( - base.embed( + langchain_experimental.rl_chain.helpers.embed( [ {"test_namespace": base.EmbedAndKeep(str1), "test_namespace2": str1}, {"test_namespace": base.EmbedAndKeep(str2), "test_namespace2": str2}, @@ -383,17 +474,26 @@ def test_one_namespace_w_list_of_features_no_emb() -> None: str1 = "test1" str2 = "test2" expected = [{"test_namespace": [str1, str2]}] - assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected + assert ( + langchain_experimental.rl_chain.helpers.embed( + {"test_namespace": [str1, str2]}, MockEncoder() + ) + == expected + ) @pytest.mark.requires("vowpal_wabbit_next") def test_one_namespace_w_list_of_features_w_some_emb() -> None: str1 = "test1" str2 = "test2" - encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str2 = langchain_experimental.rl_chain.helpers.stringify_embedding( + list(encoded_keyword + str2) + ) expected = [{"test_namespace": [str1, encoded_str2]}] assert ( - base.embed({"test_namespace": [str1, base.Embed(str2)]}, MockEncoder()) + langchain_experimental.rl_chain.helpers.embed( + {"test_namespace": [str1, base.Embed(str2)]}, MockEncoder() + ) == expected ) @@ -401,22 +501,30 @@ def test_one_namespace_w_list_of_features_w_some_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_nested_list_features_throws() -> None: with pytest.raises(ValueError): - base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder()) + langchain_experimental.rl_chain.helpers.embed( + {"test_namespace": [[1, 2], [3, 4]]}, MockEncoder() + ) @pytest.mark.requires("vowpal_wabbit_next") def test_dict_in_list_throws() -> None: with pytest.raises(ValueError): - base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder()) + langchain_experimental.rl_chain.helpers.embed( + {"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder() + ) @pytest.mark.requires("vowpal_wabbit_next") def test_nested_dict_throws() -> None: with pytest.raises(ValueError): - base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder()) + langchain_experimental.rl_chain.helpers.embed( + {"test_namespace": {"a": {"b": 1}}}, MockEncoder() + ) @pytest.mark.requires("vowpal_wabbit_next") def test_list_of_tuples_throws() -> None: with pytest.raises(ValueError): - base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder()) + langchain_experimental.rl_chain.helpers.embed( + {"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder() + )