mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 21:20:33 +00:00
mypy fixes and formatting
This commit is contained in:
parent
7725192a0d
commit
6a1102d4c0
@ -3,7 +3,18 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
@ -26,47 +37,47 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class _BasedOn:
|
class _BasedOn:
|
||||||
def __init__(self, value):
|
def __init__(self, value: Any):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return str(self.value)
|
return str(self.value)
|
||||||
|
|
||||||
__repr__ = __str__
|
__repr__ = __str__
|
||||||
|
|
||||||
|
|
||||||
def BasedOn(anything):
|
def BasedOn(anything: Any) -> _BasedOn:
|
||||||
return _BasedOn(anything)
|
return _BasedOn(anything)
|
||||||
|
|
||||||
|
|
||||||
class _ToSelectFrom:
|
class _ToSelectFrom:
|
||||||
def __init__(self, value):
|
def __init__(self, value: Any):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return str(self.value)
|
return str(self.value)
|
||||||
|
|
||||||
__repr__ = __str__
|
__repr__ = __str__
|
||||||
|
|
||||||
|
|
||||||
def ToSelectFrom(anything):
|
def ToSelectFrom(anything: Any) -> _ToSelectFrom:
|
||||||
if not isinstance(anything, list):
|
if not isinstance(anything, list):
|
||||||
raise ValueError("ToSelectFrom must be a list to select from")
|
raise ValueError("ToSelectFrom must be a list to select from")
|
||||||
return _ToSelectFrom(anything)
|
return _ToSelectFrom(anything)
|
||||||
|
|
||||||
|
|
||||||
class _Embed:
|
class _Embed:
|
||||||
def __init__(self, value, keep=False):
|
def __init__(self, value: Any, keep: bool = False):
|
||||||
self.value = value
|
self.value = value
|
||||||
self.keep = keep
|
self.keep = keep
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return str(self.value)
|
return str(self.value)
|
||||||
|
|
||||||
__repr__ = __str__
|
__repr__ = __str__
|
||||||
|
|
||||||
|
|
||||||
def Embed(anything, keep=False):
|
def Embed(anything: Any, keep: bool = False) -> Any:
|
||||||
if isinstance(anything, _ToSelectFrom):
|
if isinstance(anything, _ToSelectFrom):
|
||||||
return ToSelectFrom(Embed(anything.value, keep=keep))
|
return ToSelectFrom(Embed(anything.value, keep=keep))
|
||||||
elif isinstance(anything, _BasedOn):
|
elif isinstance(anything, _BasedOn):
|
||||||
@ -80,7 +91,7 @@ def Embed(anything, keep=False):
|
|||||||
return _Embed(anything, keep=keep)
|
return _Embed(anything, keep=keep)
|
||||||
|
|
||||||
|
|
||||||
def EmbedAndKeep(anything):
|
def EmbedAndKeep(anything: Any) -> Any:
|
||||||
return Embed(anything, keep=True)
|
return Embed(anything, keep=True)
|
||||||
|
|
||||||
|
|
||||||
@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp
|
|||||||
return [parser.parse_line(line) for line in input_str.split("\n")]
|
return [parser.parse_line(line) for line in input_str.split("\n")]
|
||||||
|
|
||||||
|
|
||||||
def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
|
def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
|
||||||
to_select_from = {
|
to_select_from = {
|
||||||
k: inputs[k].value
|
k: inputs[k].value
|
||||||
for k in inputs.keys()
|
for k in inputs.keys()
|
||||||
@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
|
|||||||
return based_on, to_select_from
|
return based_on, to_select_from
|
||||||
|
|
||||||
|
|
||||||
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
|
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
|
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
|
||||||
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
|
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
|
||||||
@ -134,29 +145,35 @@ class Selected(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Event(ABC):
|
TSelected = TypeVar("TSelected", bound=Selected)
|
||||||
inputs: Dict[str, Any]
|
|
||||||
selected: Optional[Selected]
|
|
||||||
|
|
||||||
def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None):
|
|
||||||
|
class Event(Generic[TSelected], ABC):
|
||||||
|
inputs: Dict[str, Any]
|
||||||
|
selected: Optional[TSelected]
|
||||||
|
|
||||||
|
def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
self.selected = selected
|
self.selected = selected
|
||||||
|
|
||||||
|
|
||||||
|
TEvent = TypeVar("TEvent", bound=Event)
|
||||||
|
|
||||||
|
|
||||||
class Policy(ABC):
|
class Policy(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def predict(self, event: Event) -> Any:
|
def predict(self, event: TEvent) -> Any:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def learn(self, event: Event):
|
def learn(self, event: TEvent) -> None:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def log(self, event: Event):
|
def log(self, event: TEvent) -> None:
|
||||||
pass
|
...
|
||||||
|
|
||||||
def save(self):
|
def save(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -164,11 +181,11 @@ class VwPolicy(Policy):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_repo: ModelRepository,
|
model_repo: ModelRepository,
|
||||||
vw_cmd: Sequence[str],
|
vw_cmd: List[str],
|
||||||
feature_embedder: Embedder,
|
feature_embedder: Embedder,
|
||||||
vw_logger: VwLogger,
|
vw_logger: VwLogger,
|
||||||
*args,
|
*args: Any,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.model_repo = model_repo
|
self.model_repo = model_repo
|
||||||
@ -176,7 +193,7 @@ class VwPolicy(Policy):
|
|||||||
self.feature_embedder = feature_embedder
|
self.feature_embedder = feature_embedder
|
||||||
self.vw_logger = vw_logger
|
self.vw_logger = vw_logger
|
||||||
|
|
||||||
def predict(self, event: Event) -> Any:
|
def predict(self, event: TEvent) -> Any:
|
||||||
import vowpal_wabbit_next as vw
|
import vowpal_wabbit_next as vw
|
||||||
|
|
||||||
text_parser = vw.TextFormatParser(self.workspace)
|
text_parser = vw.TextFormatParser(self.workspace)
|
||||||
@ -184,7 +201,7 @@ class VwPolicy(Policy):
|
|||||||
parse_lines(text_parser, self.feature_embedder.format(event))
|
parse_lines(text_parser, self.feature_embedder.format(event))
|
||||||
)
|
)
|
||||||
|
|
||||||
def learn(self, event: Event):
|
def learn(self, event: TEvent) -> None:
|
||||||
import vowpal_wabbit_next as vw
|
import vowpal_wabbit_next as vw
|
||||||
|
|
||||||
vw_ex = self.feature_embedder.format(event)
|
vw_ex = self.feature_embedder.format(event)
|
||||||
@ -192,19 +209,19 @@ class VwPolicy(Policy):
|
|||||||
multi_ex = parse_lines(text_parser, vw_ex)
|
multi_ex = parse_lines(text_parser, vw_ex)
|
||||||
self.workspace.learn_one(multi_ex)
|
self.workspace.learn_one(multi_ex)
|
||||||
|
|
||||||
def log(self, event: Event):
|
def log(self, event: TEvent) -> None:
|
||||||
if self.vw_logger.logging_enabled():
|
if self.vw_logger.logging_enabled():
|
||||||
vw_ex = self.feature_embedder.format(event)
|
vw_ex = self.feature_embedder.format(event)
|
||||||
self.vw_logger.log(vw_ex)
|
self.vw_logger.log(vw_ex)
|
||||||
|
|
||||||
def save(self):
|
def save(self) -> None:
|
||||||
self.model_repo.save()
|
self.model_repo.save(self.workspace)
|
||||||
|
|
||||||
|
|
||||||
class Embedder(ABC):
|
class Embedder(Generic[TEvent], ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format(self, event: Event) -> str:
|
def format(self, event: TEvent) -> str:
|
||||||
pass
|
...
|
||||||
|
|
||||||
|
|
||||||
class SelectionScorer(ABC, BaseModel):
|
class SelectionScorer(ABC, BaseModel):
|
||||||
@ -212,7 +229,7 @@ class SelectionScorer(ABC, BaseModel):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
|
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
|
||||||
pass
|
...
|
||||||
|
|
||||||
|
|
||||||
class AutoSelectionScorer(SelectionScorer, BaseModel):
|
class AutoSelectionScorer(SelectionScorer, BaseModel):
|
||||||
@ -243,7 +260,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
|||||||
return chat_prompt
|
return chat_prompt
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def set_prompt_and_llm_chain(cls, values):
|
def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
llm = values.get("llm")
|
llm = values.get("llm")
|
||||||
prompt = values.get("prompt")
|
prompt = values.get("prompt")
|
||||||
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
|
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
|
||||||
@ -275,7 +292,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RLChain(Chain):
|
class RLChain(Generic[TEvent], Chain):
|
||||||
"""
|
"""
|
||||||
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
|
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
|
||||||
|
|
||||||
@ -305,7 +322,7 @@ class RLChain(Chain):
|
|||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
prompt: BasePromptTemplate
|
prompt: BasePromptTemplate
|
||||||
selection_scorer: Union[SelectionScorer, None]
|
selection_scorer: Union[SelectionScorer, None]
|
||||||
policy: Optional[Policy]
|
policy: Policy
|
||||||
auto_embed: bool = True
|
auto_embed: bool = True
|
||||||
selected_input_key = "rl_chain_selected"
|
selected_input_key = "rl_chain_selected"
|
||||||
selected_based_on_input_key = "rl_chain_selected_based_on"
|
selected_based_on_input_key = "rl_chain_selected_based_on"
|
||||||
@ -314,14 +331,14 @@ class RLChain(Chain):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_embedder: Embedder,
|
feature_embedder: Embedder,
|
||||||
model_save_dir="./",
|
model_save_dir: str = "./",
|
||||||
reset_model=False,
|
reset_model: bool = False,
|
||||||
vw_cmd=None,
|
vw_cmd: Optional[List[str]] = None,
|
||||||
policy=VwPolicy,
|
policy: Type[Policy] = VwPolicy,
|
||||||
vw_logs: Optional[Union[str, os.PathLike]] = None,
|
vw_logs: Optional[Union[str, os.PathLike]] = None,
|
||||||
metrics_step=-1,
|
metrics_step: int = -1,
|
||||||
*args,
|
*args: Any,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
if self.selection_scorer is None:
|
if self.selection_scorer is None:
|
||||||
@ -374,29 +391,29 @@ class RLChain(Chain):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
|
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _call_after_predict_before_llm(
|
def _call_after_predict_before_llm(
|
||||||
self, inputs: Dict[str, Any], event: Event, prediction: Any
|
self, inputs: Dict[str, Any], event: TEvent, prediction: Any
|
||||||
) -> Tuple[Dict[str, Any], Event]:
|
) -> Tuple[Dict[str, Any], TEvent]:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _call_after_llm_before_scoring(
|
def _call_after_llm_before_scoring(
|
||||||
self, llm_response: str, event: Event
|
self, llm_response: str, event: TEvent
|
||||||
) -> Tuple[Dict[str, Any], Event]:
|
) -> Tuple[Dict[str, Any], TEvent]:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _call_after_scoring_before_learning(
|
def _call_after_scoring_before_learning(
|
||||||
self, event: Event, score: Optional[float]
|
self, event: TEvent, score: Optional[float]
|
||||||
) -> Event:
|
) -> TEvent:
|
||||||
pass
|
...
|
||||||
|
|
||||||
def update_with_delayed_score(
|
def update_with_delayed_score(
|
||||||
self, score: float, event: Event, force_score=False
|
self, score: float, event: TEvent, force_score: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Updates the learned policy with the score provided.
|
Updates the learned policy with the score provided.
|
||||||
@ -407,6 +424,7 @@ class RLChain(Chain):
|
|||||||
"The selection scorer is set, and force_score was not set to True. \
|
"The selection scorer is set, and force_score was not set to True. \
|
||||||
Please set force_score=True to use this function."
|
Please set force_score=True to use this function."
|
||||||
)
|
)
|
||||||
|
if self.metrics:
|
||||||
self.metrics.on_feedback(score)
|
self.metrics.on_feedback(score)
|
||||||
self._call_after_scoring_before_learning(event=event, score=score)
|
self._call_after_scoring_before_learning(event=event, score=score)
|
||||||
self.policy.learn(event=event)
|
self.policy.learn(event=event)
|
||||||
@ -422,14 +440,15 @@ class RLChain(Chain):
|
|||||||
self,
|
self,
|
||||||
inputs: Dict[str, Any],
|
inputs: Dict[str, Any],
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, Any]:
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
|
|
||||||
if self.auto_embed:
|
if self.auto_embed:
|
||||||
inputs = prepare_inputs_for_autoembed(inputs=inputs)
|
inputs = prepare_inputs_for_autoembed(inputs=inputs)
|
||||||
|
|
||||||
event = self._call_before_predict(inputs=inputs)
|
event: TEvent = self._call_before_predict(inputs=inputs)
|
||||||
prediction = self.policy.predict(event=event)
|
prediction = self.policy.predict(event=event)
|
||||||
|
if self.metrics:
|
||||||
self.metrics.on_decision()
|
self.metrics.on_decision()
|
||||||
|
|
||||||
next_chain_inputs, event = self._call_after_predict_before_llm(
|
next_chain_inputs, event = self._call_after_predict_before_llm(
|
||||||
@ -462,6 +481,7 @@ class RLChain(Chain):
|
|||||||
f"The selection scorer was not able to score, \
|
f"The selection scorer was not able to score, \
|
||||||
and the chain was not able to adjust to this response, error: {e}"
|
and the chain was not able to adjust to this response, error: {e}"
|
||||||
)
|
)
|
||||||
|
if self.metrics:
|
||||||
self.metrics.on_feedback(score)
|
self.metrics.on_feedback(score)
|
||||||
event = self._call_after_scoring_before_learning(score=score, event=event)
|
event = self._call_after_scoring_before_learning(score=score, event=event)
|
||||||
self.policy.learn(event=event)
|
self.policy.learn(event=event)
|
||||||
@ -515,7 +535,7 @@ def embed_string_type(
|
|||||||
|
|
||||||
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
|
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
|
||||||
"""Helper function to embed a dictionary item."""
|
"""Helper function to embed a dictionary item."""
|
||||||
inner_dict = {}
|
inner_dict: Dict[str, Union[str, List[str]]] = {}
|
||||||
for ns, embed_item in item.items():
|
for ns, embed_item in item.items():
|
||||||
if isinstance(embed_item, list):
|
if isinstance(embed_item, list):
|
||||||
inner_dict[ns] = []
|
inner_dict[ns] = []
|
||||||
@ -530,7 +550,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
|
|||||||
def embed_list_type(
|
def embed_list_type(
|
||||||
item: list, model: Any, namespace: Optional[str] = None
|
item: list, model: Any, namespace: Optional[str] = None
|
||||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||||
ret_list = []
|
ret_list: List[Dict[str, Union[str, List[str]]]] = []
|
||||||
for embed_item in item:
|
for embed_item in item:
|
||||||
if isinstance(embed_item, dict):
|
if isinstance(embed_item, dict):
|
||||||
ret_list.append(embed_dict_type(embed_item, model))
|
ret_list.append(embed_dict_type(embed_item, model))
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -6,11 +6,11 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class MetricsTracker:
|
class MetricsTracker:
|
||||||
def __init__(self, step: int):
|
def __init__(self, step: int):
|
||||||
self._history = []
|
self._history: List[Dict[str, Union[int, float]]] = []
|
||||||
self._step = step
|
self._step: int = step
|
||||||
self._i = 0
|
self._i: int = 0
|
||||||
self._num = 0
|
self._num: float = 0
|
||||||
self._denom = 0
|
self._denom: float = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def score(self) -> float:
|
def score(self) -> float:
|
||||||
|
@ -4,7 +4,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Sequence, Union
|
from typing import TYPE_CHECKING, List, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import vowpal_wabbit_next as vw
|
import vowpal_wabbit_next as vw
|
||||||
@ -22,7 +22,7 @@ class ModelRepository:
|
|||||||
self.folder = Path(folder)
|
self.folder = Path(folder)
|
||||||
self.model_path = self.folder / "latest.vw"
|
self.model_path = self.folder / "latest.vw"
|
||||||
self.with_history = with_history
|
self.with_history = with_history
|
||||||
if reset and self.has_history:
|
if reset and self.has_history():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"There is non empty history which is recommended to be cleaned up"
|
"There is non empty history which is recommended to be cleaned up"
|
||||||
)
|
)
|
||||||
@ -44,7 +44,7 @@ class ModelRepository:
|
|||||||
if self.with_history: # write history
|
if self.with_history: # write history
|
||||||
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
|
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
|
||||||
|
|
||||||
def load(self, commandline: Sequence[str]) -> "vw.Workspace":
|
def load(self, commandline: List[str]) -> "vw.Workspace":
|
||||||
import vowpal_wabbit_next as vw
|
import vowpal_wabbit_next as vw
|
||||||
|
|
||||||
model_data = None
|
model_data = None
|
||||||
|
@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
|||||||
SENTINEL = object()
|
SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
class PickBestFeatureEmbedder(base.Embedder):
|
class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
|
||||||
"""
|
"""
|
||||||
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
|
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
|||||||
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
||||||
""" # noqa E501
|
""" # noqa E501
|
||||||
|
|
||||||
def __init__(self, model: Optional[Any] = None, *args, **kwargs):
|
def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
@ -88,7 +88,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
|||||||
return example_string[:-1]
|
return example_string[:-1]
|
||||||
|
|
||||||
|
|
||||||
class PickBest(base.RLChain):
|
class PickBest(base.RLChain[PickBest.Event]):
|
||||||
"""
|
"""
|
||||||
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ class PickBest(base.RLChain):
|
|||||||
self.probability = probability
|
self.probability = probability
|
||||||
self.score = score
|
self.score = score
|
||||||
|
|
||||||
class Event(base.Event):
|
class Event(base.Event[PickBest.Selected]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
inputs: Dict[str, Any],
|
inputs: Dict[str, Any],
|
||||||
@ -146,8 +146,8 @@ class PickBest(base.RLChain):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
|
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
|
||||||
*args,
|
*args: Any,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
vw_cmd = kwargs.get("vw_cmd", [])
|
vw_cmd = kwargs.get("vw_cmd", [])
|
||||||
if not vw_cmd:
|
if not vw_cmd:
|
||||||
@ -170,7 +170,7 @@ class PickBest(base.RLChain):
|
|||||||
|
|
||||||
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
|
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
|
||||||
|
|
||||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event:
|
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
|
||||||
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
||||||
if not actions:
|
if not actions:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -198,7 +198,7 @@ class PickBest(base.RLChain):
|
|||||||
|
|
||||||
def _call_after_predict_before_llm(
|
def _call_after_predict_before_llm(
|
||||||
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
|
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
|
||||||
) -> Tuple[Dict[str, Any], PickBest.Event]:
|
) -> Tuple[Dict[str, Any], Event]:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
prob_sum = sum(prob for _, prob in prediction)
|
prob_sum = sum(prob for _, prob in prediction)
|
||||||
@ -218,8 +218,8 @@ class PickBest(base.RLChain):
|
|||||||
return next_chain_inputs, event
|
return next_chain_inputs, event
|
||||||
|
|
||||||
def _call_after_llm_before_scoring(
|
def _call_after_llm_before_scoring(
|
||||||
self, llm_response: str, event: PickBest.Event
|
self, llm_response: str, event: Event
|
||||||
) -> Tuple[Dict[str, Any], PickBest.Event]:
|
) -> Tuple[Dict[str, Any], Event]:
|
||||||
next_chain_inputs = event.inputs.copy()
|
next_chain_inputs = event.inputs.copy()
|
||||||
# only one key, value pair in event.to_select_from
|
# only one key, value pair in event.to_select_from
|
||||||
value = next(iter(event.to_select_from.values()))
|
value = next(iter(event.to_select_from.values()))
|
||||||
@ -232,7 +232,7 @@ class PickBest(base.RLChain):
|
|||||||
return next_chain_inputs, event
|
return next_chain_inputs, event
|
||||||
|
|
||||||
def _call_after_scoring_before_learning(
|
def _call_after_scoring_before_learning(
|
||||||
self, event: PickBest.Event, score: Optional[float]
|
self, event: Event, score: Optional[float]
|
||||||
) -> Event:
|
) -> Event:
|
||||||
event.selected.score = score
|
event.selected.score = score
|
||||||
return event
|
return event
|
||||||
|
@ -9,10 +9,10 @@ class VwLogger:
|
|||||||
if self.path:
|
if self.path:
|
||||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def log(self, vw_ex: str):
|
def log(self, vw_ex: str) -> None:
|
||||||
if self.path:
|
if self.path:
|
||||||
with open(self.path, "a") as f:
|
with open(self.path, "a") as f:
|
||||||
f.write(f"{vw_ex}\n\n")
|
f.write(f"{vw_ex}\n\n")
|
||||||
|
|
||||||
def logging_enabled(self):
|
def logging_enabled(self) -> bool:
|
||||||
return bool(self.path)
|
return bool(self.path)
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
class MockEncoder:
|
class MockEncoder:
|
||||||
def encode(self, to_encode):
|
def encode(self, to_encode: str) -> str:
|
||||||
return "[encoded]" + to_encode
|
return "[encoded]" + to_encode
|
||||||
|
Loading…
Reference in New Issue
Block a user