experimental: docstrings update (#18048)

Added missing docstrings. Formatted docstrings to a consistent format.
Leonid Ganeline
2024-02-23 18:24:16 -08:00
committed by GitHub
parent 56b955fc31
commit 3f6bf852ea
61 changed files with 316 additions and 102 deletions

View File

@@ -51,6 +51,8 @@ class _BasedOn:
def BasedOn(anything: Any) -> _BasedOn:
"""Wrap a value to indicate that it should be based on."""
return _BasedOn(anything)
@@ -65,6 +67,8 @@ class _ToSelectFrom:
def ToSelectFrom(anything: Any) -> _ToSelectFrom:
"""Wrap a value to indicate that it should be selected from."""
if not isinstance(anything, list):
raise ValueError("ToSelectFrom must be a list to select from")
return _ToSelectFrom(anything)
@@ -82,6 +86,8 @@ class _Embed:
def Embed(anything: Any, keep: bool = False) -> Any:
"""Wrap a value to indicate that it should be embedded."""
if isinstance(anything, _ToSelectFrom):
return ToSelectFrom(Embed(anything.value, keep=keep))
elif isinstance(anything, _BasedOn):
@@ -96,6 +102,8 @@ def Embed(anything: Any, keep: bool = False) -> Any:
def EmbedAndKeep(anything: Any) -> Any:
"""Wrap a value to indicate that it should be embedded and kept."""
return Embed(anything, keep=True)
@@ -103,14 +111,19 @@ def EmbedAndKeep(anything: Any) -> Any:
def stringify_embedding(embedding: List) -> str:
"""Convert an embedding to a string."""
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
"""Parse the input string into a list of examples."""
return [parser.parse_line(line) for line in input_str.split("\n")]
def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
"""Get the BasedOn and ToSelectFrom from the inputs."""
to_select_from = {
k: inputs[k].value
for k in inputs.keys()
@@ -132,8 +145,9 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
"""Prepare the inputs for auto embedding.
Go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
""" # noqa: E501
@@ -149,6 +163,8 @@ def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
class Selected(ABC):
"""Abstract class to represent the selected item."""
pass
@@ -156,6 +172,8 @@ TSelected = TypeVar("TSelected", bound=Selected)
class Event(Generic[TSelected], ABC):
"""Abstract class to represent an event."""
inputs: Dict[str, Any]
selected: Optional[TSelected]
@@ -168,6 +186,8 @@ TEvent = TypeVar("TEvent", bound=Event)
class Policy(Generic[TEvent], ABC):
"""Abstract class to represent a policy."""
def __init__(self, **kwargs: Any):
pass
@@ -188,6 +208,8 @@ class Policy(Generic[TEvent], ABC):
class VwPolicy(Policy):
"""Vowpal Wabbit policy."""
def __init__(
self,
model_repo: ModelRepository,
@@ -229,6 +251,8 @@ class VwPolicy(Policy):
class Embedder(Generic[TEvent], ABC):
"""Abstract class to represent an embedder."""
def __init__(self, *args: Any, **kwargs: Any):
pass
@@ -238,7 +262,7 @@ class Embedder(Generic[TEvent], ABC):
class SelectionScorer(Generic[TEvent], ABC, BaseModel):
"""Abstract method to grade the chosen selection or the response of the llm"""
"""Abstract class to grade the chosen selection or the response of the llm."""
@abstractmethod
def score_response(
@@ -248,6 +272,8 @@ class SelectionScorer(Generic[TEvent], ABC, BaseModel):
class AutoSelectionScorer(SelectionScorer[Event], BaseModel):
"""Auto selection scorer."""
llm_chain: LLMChain
prompt: Union[BasePromptTemplate, None] = None
scoring_criteria_template_str: Optional[str] = None
@@ -308,8 +334,8 @@ class AutoSelectionScorer(SelectionScorer[Event], BaseModel):
class RLChain(Chain, Generic[TEvent]):
"""
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
"""Chain that leverages the Vowpal Wabbit (VW) model as a learned policy
for reinforcement learning.
Attributes:
- llm_chain (Chain): Represents the underlying Language Model chain.
@@ -547,7 +573,8 @@ class RLChain(Chain, Generic[TEvent]):
def is_stringtype_instance(item: Any) -> bool:
"""Helper function to check if an item is a string."""
"""Check if an item is a string."""
return isinstance(item, str) or (
isinstance(item, _Embed) and isinstance(item.value, str)
)
@@ -556,7 +583,8 @@ def is_stringtype_instance(item: Any) -> bool:
def embed_string_type(
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
) -> Dict[str, Union[str, List[str]]]:
"""Helper function to embed a string or an _Embed object."""
"""Embed a string or an _Embed object."""
keep_str = ""
if isinstance(item, _Embed):
encoded = stringify_embedding(model.encode(item.value))
@@ -576,7 +604,7 @@ def embed_string_type(
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
"""Helper function to embed a dictionary item."""
"""Embed a dictionary item."""
inner_dict: Dict = {}
for ns, embed_item in item.items():
if isinstance(embed_item, list):
@@ -592,6 +620,8 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
def embed_list_type(
item: list, model: Any, namespace: Optional[str] = None
) -> List[Dict[str, Union[str, List[str]]]]:
"""Embed a list item."""
ret_list: List = []
for embed_item in item:
if isinstance(embed_item, dict):
@@ -614,7 +644,8 @@ def embed(
namespace: Optional[str] = None,
) -> List[Dict[str, Union[str, List[str]]]]:
"""
Embeds the actions or context using the SentenceTransformer model (or a model that has an `encode` function)
Embed the actions or context using the SentenceTransformer model
(or a model that has an `encode` function).
Attributes:
to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries.
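
For orientation, an illustrative sketch (not part of the diff): the wrappers and helpers documented above mark which chain inputs are context and which are candidate actions, and stringify_embedding produces Vowpal Wabbit's "index:value" text format. The import path and the input names below are assumptions based on the module shown here.

from langchain_experimental.rl_chain.base import (
    BasedOn,
    Embed,
    ToSelectFrom,
    stringify_embedding,
)

# stringify_embedding turns a numeric vector into VW's "index:value" text format.
print(stringify_embedding([0.25, -0.1, 0.8]))  # "0:0.25 1:-0.1 2:0.8"

# Chain inputs are tagged so the feature embedder knows what is context
# (BasedOn) and what are the candidate actions (ToSelectFrom); Embed forces
# a dense embedding of the wrapped value.
inputs = {
    "user": BasedOn("Tom"),                      # context feature
    "meal": ToSelectFrom(["pizza", "sushi"]),    # candidate actions to pick from
    "note": Embed("prefers vegetarian dishes"),  # embed this value explicitly
}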

View File

@@ -6,6 +6,8 @@ if TYPE_CHECKING:
class MetricsTrackerAverage:
"""Metrics Tracker Average."""
def __init__(self, step: int):
self.history: List[Dict[str, Union[int, float]]] = [{"step": 0, "score": 0}]
self.step: int = step
@@ -33,6 +35,8 @@ class MetricsTrackerAverage:
class MetricsTrackerRollingWindow:
"""Metrics Tracker Rolling Window."""
def __init__(self, window_size: int, step: int):
self.history: List[Dict[str, Union[int, float]]] = [{"step": 0, "score": 0}]
self.step: int = step
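
The two trackers above record a score per step; the rolling-window variant averages only the most recent scores rather than the full history. A self-contained sketch of that idea (not the tracker's actual method names, which are outside this diff):

from collections import deque


# Illustrative only: average of the most recent `window_size` scores.
class RollingAverage:
    def __init__(self, window_size: int):
        self.window: deque = deque(maxlen=window_size)  # old scores drop off automatically

    def add(self, score: float) -> float:
        self.window.append(score)
        return sum(self.window) / len(self.window)


tracker = RollingAverage(window_size=3)
for s in [1.0, 0.0, 1.0, 1.0]:
    print(tracker.add(s))  # 1.0, 0.5, 0.666..., 0.666...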

View File

@@ -13,6 +13,8 @@ logger = logging.getLogger(__name__)
class ModelRepository:
"""Model Repository."""
def __init__(
self,
folder: Union[str, os.PathLike],

View File

@@ -18,6 +18,8 @@ SENTINEL = object()
class PickBestSelected(base.Selected):
"""Selected class for PickBest chain."""
index: Optional[int]
probability: Optional[float]
score: Optional[float]
@@ -34,6 +36,8 @@ class PickBestSelected(base.Selected):
class PickBestEvent(base.Event[PickBestSelected]):
"""Event class for PickBest chain."""
def __init__(
self,
inputs: Dict[str, Any],
@@ -47,8 +51,8 @@ class PickBestEvent(base.Event[PickBestSelected]):
class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
"""
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
"""Embed the `BasedOn` and `ToSelectFrom` inputs into a format that can be used
by the learning policy.
Attributes:
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
@@ -225,6 +229,8 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
class PickBestRandomPolicy(base.Policy[PickBestEvent]):
"""Random policy for PickBest chain."""
def __init__(self, feature_embedder: base.Embedder, **kwargs: Any):
self.feature_embedder = feature_embedder
@@ -240,8 +246,8 @@ class PickBestRandomPolicy(base.Policy[PickBestEvent]):
class PickBest(base.RLChain[PickBestEvent]):
"""
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
"""Chain that leverages the Vowpal Wabbit (VW) model for reinforcement learning
with a context, with the goal of modifying the prompt before the LLM call.
Each invocation of the chain's `run()` method should be equipped with a set of potential actions (`ToSelectFrom`) and will result in the selection of a specific action based on the `BasedOn` input. This chosen action then informs the LLM (Language Model) prompt for the subsequent response generation.
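
A hedged usage sketch of the description above (not part of the diff), assuming the `from_llm` constructor the chain exposes; `llm`, `prompt`, and the input names are placeholders for illustration.

import langchain_experimental.rl_chain as rl_chain

# `llm` and `prompt` are assumed to be defined elsewhere; the prompt template
# is expected to reference the same variable names ({meal}, {user}).
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=prompt)

response = chain.run(
    meal=rl_chain.ToSelectFrom(["pizza", "sushi", "salad"]),  # candidate actions
    user=rl_chain.BasedOn("Tom"),                             # context the policy conditions on
)
print(response["response"])  # LLM output generated with the selected action injected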

View File

@@ -4,6 +4,8 @@ from typing import Optional, Union
class VwLogger:
"""Vowpal Wabbit custom logger."""
def __init__(self, path: Optional[Union[str, PathLike]]):
self.path = Path(path) if path else None
if self.path: