mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-05 12:48:12 +00:00)

mypy fixes and formatting

This commit is contained in:
parent 7725192a0d
commit 6a1102d4c0
@@ -3,7 +3,18 @@ from __future__ import annotations
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)

 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
@@ -26,47 +37,47 @@ logger = logging.getLogger(__name__)


 class _BasedOn:
-    def __init__(self, value):
+    def __init__(self, value: Any):
         self.value = value

-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.value)

     __repr__ = __str__


-def BasedOn(anything):
+def BasedOn(anything: Any) -> _BasedOn:
     return _BasedOn(anything)


 class _ToSelectFrom:
-    def __init__(self, value):
+    def __init__(self, value: Any):
         self.value = value

-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.value)

     __repr__ = __str__


-def ToSelectFrom(anything):
+def ToSelectFrom(anything: Any) -> _ToSelectFrom:
     if not isinstance(anything, list):
         raise ValueError("ToSelectFrom must be a list to select from")
     return _ToSelectFrom(anything)


 class _Embed:
-    def __init__(self, value, keep=False):
+    def __init__(self, value: Any, keep: bool = False):
         self.value = value
         self.keep = keep

-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.value)

     __repr__ = __str__


-def Embed(anything, keep=False):
+def Embed(anything: Any, keep: bool = False) -> Any:
     if isinstance(anything, _ToSelectFrom):
         return ToSelectFrom(Embed(anything.value, keep=keep))
     elif isinstance(anything, _BasedOn):
@@ -80,7 +91,7 @@ def Embed(anything, keep=False):
     return _Embed(anything, keep=keep)


-def EmbedAndKeep(anything):
+def EmbedAndKeep(anything: Any) -> Any:
     return Embed(anything, keep=True)

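Note: the wrappers above only tag values for the chain to unwrap later. A minimal usage sketch (the input names and the consuming chain are illustrative, not part of this diff; BasedOn, ToSelectFrom, and EmbedAndKeep are the helpers defined above):

inputs = {
    "user": BasedOn("Tom"),  # context the policy conditions on
    "action": ToSelectFrom(["ice cream", "pizza"]),  # candidates to choose between
    "topic": EmbedAndKeep(BasedOn("sports")),  # embed the value but keep the raw text
}

# ToSelectFrom guards against scalars:
# ToSelectFrom("pizza")  # raises ValueError("ToSelectFrom must be a list to select from")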
@@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp
     return [parser.parse_line(line) for line in input_str.split("\n")]


-def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
+def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
     to_select_from = {
         k: inputs[k].value
         for k in inputs.keys()
@@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
     return based_on, to_select_from


-def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
+def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
     """
     go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
     then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
@@ -134,29 +145,35 @@ class Selected(ABC):
     pass


-class Event(ABC):
-    inputs: Dict[str, Any]
-    selected: Optional[Selected]
+TSelected = TypeVar("TSelected", bound=Selected)

-    def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None):
+
+class Event(Generic[TSelected], ABC):
+    inputs: Dict[str, Any]
+    selected: Optional[TSelected]
+
+    def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
         self.inputs = inputs
         self.selected = selected


+TEvent = TypeVar("TEvent", bound=Event)
+
+
 class Policy(ABC):
     @abstractmethod
-    def predict(self, event: Event) -> Any:
-        pass
+    def predict(self, event: TEvent) -> Any:
+        ...

     @abstractmethod
-    def learn(self, event: Event):
-        pass
+    def learn(self, event: TEvent) -> None:
+        ...

     @abstractmethod
-    def log(self, event: Event):
-        pass
+    def log(self, event: TEvent) -> None:
+        ...

-    def save(self):
+    def save(self) -> None:
         pass
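The TypeVar/Generic change above is the heart of these mypy fixes: binding TSelected lets each Event subclass declare which Selected subtype it carries, so `event.selected` type-checks without casts. A standalone sketch of the same pattern (the subclass names here are illustrative, not from this commit):

from abc import ABC
from typing import Generic, Optional, TypeVar


class Selected(ABC):
    ...


TSelected = TypeVar("TSelected", bound=Selected)


class Event(Generic[TSelected], ABC):
    def __init__(self, selected: Optional[TSelected] = None):
        self.selected = selected


class PickedItem(Selected):  # illustrative subclass
    score: Optional[float] = None


class PickEvent(Event[PickedItem]):
    ...


def record(event: PickEvent) -> None:
    # mypy infers event.selected as Optional[PickedItem], so the
    # subclass-specific attribute access below type-checks.
    if event.selected is not None:
        event.selected.score = 1.0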
@@ -164,11 +181,11 @@ class VwPolicy(Policy):
     def __init__(
         self,
         model_repo: ModelRepository,
-        vw_cmd: Sequence[str],
+        vw_cmd: List[str],
         feature_embedder: Embedder,
         vw_logger: VwLogger,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
         self.model_repo = model_repo
@@ -176,7 +193,7 @@ class VwPolicy(Policy):
         self.feature_embedder = feature_embedder
         self.vw_logger = vw_logger

-    def predict(self, event: Event) -> Any:
+    def predict(self, event: TEvent) -> Any:
         import vowpal_wabbit_next as vw

         text_parser = vw.TextFormatParser(self.workspace)
@@ -184,7 +201,7 @@ class VwPolicy(Policy):
             parse_lines(text_parser, self.feature_embedder.format(event))
         )

-    def learn(self, event: Event):
+    def learn(self, event: TEvent) -> None:
         import vowpal_wabbit_next as vw

         vw_ex = self.feature_embedder.format(event)
@@ -192,19 +209,19 @@ class VwPolicy(Policy):
         multi_ex = parse_lines(text_parser, vw_ex)
         self.workspace.learn_one(multi_ex)

-    def log(self, event: Event):
+    def log(self, event: TEvent) -> None:
         if self.vw_logger.logging_enabled():
             vw_ex = self.feature_embedder.format(event)
             self.vw_logger.log(vw_ex)

-    def save(self):
-        self.model_repo.save()
+    def save(self) -> None:
+        self.model_repo.save(self.workspace)


-class Embedder(ABC):
+class Embedder(Generic[TEvent], ABC):
     @abstractmethod
-    def format(self, event: Event) -> str:
-        pass
+    def format(self, event: TEvent) -> str:
+        ...


 class SelectionScorer(ABC, BaseModel):
@@ -212,7 +229,7 @@ class SelectionScorer(ABC, BaseModel):

     @abstractmethod
     def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
-        pass
+        ...


 class AutoSelectionScorer(SelectionScorer, BaseModel):
@@ -243,7 +260,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
         return chat_prompt

     @root_validator(pre=True)
-    def set_prompt_and_llm_chain(cls, values):
+    def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         llm = values.get("llm")
         prompt = values.get("prompt")
         scoring_criteria_template_str = values.get("scoring_criteria_template_str")
@@ -275,7 +292,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
     )


-class RLChain(Chain):
+class RLChain(Generic[TEvent], Chain):
     """
     The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.

@@ -305,7 +322,7 @@ class RLChain(Chain):
     output_key: str = "result"  #: :meta private:
     prompt: BasePromptTemplate
     selection_scorer: Union[SelectionScorer, None]
-    policy: Optional[Policy]
+    policy: Policy
     auto_embed: bool = True
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"
@@ -314,14 +331,14 @@ class RLChain(Chain):
     def __init__(
         self,
         feature_embedder: Embedder,
-        model_save_dir="./",
-        reset_model=False,
-        vw_cmd=None,
-        policy=VwPolicy,
+        model_save_dir: str = "./",
+        reset_model: bool = False,
+        vw_cmd: Optional[List[str]] = None,
+        policy: Type[Policy] = VwPolicy,
         vw_logs: Optional[Union[str, os.PathLike]] = None,
-        metrics_step=-1,
-        *args,
-        **kwargs,
+        metrics_step: int = -1,
+        *args: Any,
+        **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
         if self.selection_scorer is None:
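`policy: Type[Policy] = VwPolicy` annotates a class object rather than an instance; the chain instantiates it later with its own arguments. A minimal standalone sketch of the Type[...] idiom (names here are illustrative):

from typing import Type


class Base:
    ...


class Concrete(Base):
    ...


def build(cls: Type[Base] = Concrete) -> Base:
    # mypy checks that cls is Base or a subclass of it,
    # and that calling it yields a Base instance.
    return cls()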
@@ -374,29 +391,29 @@ class RLChain(Chain):
         )

     @abstractmethod
-    def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
-        pass
+    def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
+        ...

     @abstractmethod
     def _call_after_predict_before_llm(
-        self, inputs: Dict[str, Any], event: Event, prediction: Any
-    ) -> Tuple[Dict[str, Any], Event]:
-        pass
+        self, inputs: Dict[str, Any], event: TEvent, prediction: Any
+    ) -> Tuple[Dict[str, Any], TEvent]:
+        ...

     @abstractmethod
     def _call_after_llm_before_scoring(
-        self, llm_response: str, event: Event
-    ) -> Tuple[Dict[str, Any], Event]:
-        pass
+        self, llm_response: str, event: TEvent
+    ) -> Tuple[Dict[str, Any], TEvent]:
+        ...

     @abstractmethod
     def _call_after_scoring_before_learning(
-        self, event: Event, score: Optional[float]
-    ) -> Event:
-        pass
+        self, event: TEvent, score: Optional[float]
+    ) -> TEvent:
+        ...

     def update_with_delayed_score(
-        self, score: float, event: Event, force_score=False
+        self, score: float, event: TEvent, force_score: bool = False
     ) -> None:
         """
         Updates the learned policy with the score provided.
@@ -407,6 +424,7 @@ class RLChain(Chain):
                 "The selection scorer is set, and force_score was not set to True. \
                 Please set force_score=True to use this function."
             )
-        self.metrics.on_feedback(score)
+        if self.metrics:
+            self.metrics.on_feedback(score)
         self._call_after_scoring_before_learning(event=event, score=score)
         self.policy.learn(event=event)
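The added `if self.metrics:` guard is an Optional-narrowing fix: when an attribute is typed as Optional, mypy rejects a bare method call on it. A minimal sketch, assuming a `metrics` attribute declared Optional as in this chain:

from typing import Optional


class MetricsTracker:
    def on_feedback(self, score: float) -> None:
        ...


class SomeChain:
    metrics: Optional[MetricsTracker] = None

    def feedback(self, score: float) -> None:
        # Without the guard, mypy reports:
        #   Item "None" of "Optional[MetricsTracker]" has no attribute "on_feedback"
        if self.metrics:  # narrows Optional[MetricsTracker] to MetricsTracker
            self.metrics.on_feedback(score)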
@@ -422,14 +440,15 @@ class RLChain(Chain):
         self,
         inputs: Dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> Dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

         if self.auto_embed:
             inputs = prepare_inputs_for_autoembed(inputs=inputs)

-        event = self._call_before_predict(inputs=inputs)
+        event: TEvent = self._call_before_predict(inputs=inputs)
         prediction = self.policy.predict(event=event)
-        self.metrics.on_decision()
+        if self.metrics:
+            self.metrics.on_decision()

         next_chain_inputs, event = self._call_after_predict_before_llm(
@@ -462,6 +481,7 @@ class RLChain(Chain):
                     f"The selection scorer was not able to score, \
                     and the chain was not able to adjust to this response, error: {e}"
                 )
-        self.metrics.on_feedback(score)
+        if self.metrics:
+            self.metrics.on_feedback(score)
         event = self._call_after_scoring_before_learning(score=score, event=event)
         self.policy.learn(event=event)
@@ -515,7 +535,7 @@ def embed_string_type(

 def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
     """Helper function to embed a dictionary item."""
-    inner_dict = {}
+    inner_dict: Dict[str, Union[str, List[str]]] = {}
     for ns, embed_item in item.items():
         if isinstance(embed_item, list):
             inner_dict[ns] = []
@@ -530,7 +550,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
 def embed_list_type(
     item: list, model: Any, namespace: Optional[str] = None
 ) -> List[Dict[str, Union[str, List[str]]]]:
-    ret_list = []
+    ret_list: List[Dict[str, Union[str, List[str]]]] = []
     for embed_item in item:
         if isinstance(embed_item, dict):
             ret_list.append(embed_dict_type(embed_item, model))
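The two annotations above fix a classic strict-mypy complaint: an empty literal gives the checker nothing to infer an element type from. A short sketch:

from typing import Dict, List, Union

# inner_dict = {}  # mypy (strict): Need type annotation for "inner_dict"
inner_dict: Dict[str, Union[str, List[str]]] = {}
inner_dict["ns"] = "a b c"        # a single namespace string
inner_dict["other"] = ["a", "b"]  # or a list of strings, per the Union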
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 if TYPE_CHECKING:
     import pandas as pd
@@ -6,11 +6,11 @@ if TYPE_CHECKING:

 class MetricsTracker:
     def __init__(self, step: int):
-        self._history = []
-        self._step = step
-        self._i = 0
-        self._num = 0
-        self._denom = 0
+        self._history: List[Dict[str, Union[int, float]]] = []
+        self._step: int = step
+        self._i: int = 0
+        self._num: float = 0
+        self._denom: float = 0

     @property
     def score(self) -> float:
@@ -4,7 +4,7 @@ import logging
 import os
 import shutil
 from pathlib import Path
-from typing import TYPE_CHECKING, Sequence, Union
+from typing import TYPE_CHECKING, List, Union

 if TYPE_CHECKING:
     import vowpal_wabbit_next as vw
@@ -22,7 +22,7 @@ class ModelRepository:
         self.folder = Path(folder)
         self.model_path = self.folder / "latest.vw"
         self.with_history = with_history
-        if reset and self.has_history:
+        if reset and self.has_history():
             logger.warning(
                 "There is non empty history which is recommended to be cleaned up"
             )
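Note that `if reset and self.has_history:` was a genuine bug, not just a typing nit: without parentheses, `self.has_history` is a bound method object and therefore always truthy, so the warning fired on every reset. Newer mypy releases flag the unparenthesized form (the truthy-function check). A minimal reproduction:

class Repo:
    def has_history(self) -> bool:
        return False


repo = Repo()

if repo.has_history:  # bug: method object, always truthy
    print("reached even though has_history() is False")

if repo.has_history():  # fix: actually call the method
    print("never reached")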
@@ -44,7 +44,7 @@ class ModelRepository:
         if self.with_history:  # write history
             shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")

-    def load(self, commandline: Sequence[str]) -> "vw.Workspace":
+    def load(self, commandline: List[str]) -> "vw.Workspace":
         import vowpal_wabbit_next as vw

         model_data = None
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
 SENTINEL = object()


-class PickBestFeatureEmbedder(base.Embedder):
+class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
     """
     Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy

@@ -25,7 +25,7 @@ class PickBestFeatureEmbedder(base.Embedder):
     model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
     """  # noqa E501

-    def __init__(self, model: Optional[Any] = None, *args, **kwargs):
+    def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)

         if model is None:
@@ -88,7 +88,7 @@ class PickBestFeatureEmbedder(base.Embedder):
         return example_string[:-1]


-class PickBest(base.RLChain):
+class PickBest(base.RLChain[PickBest.Event]):
     """
     `PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.

@@ -131,7 +131,7 @@ class PickBest(base.RLChain):
             self.probability = probability
             self.score = score

-    class Event(base.Event):
+    class Event(base.Event[PickBest.Selected]):
         def __init__(
             self,
             inputs: Dict[str, Any],
@@ -146,8 +146,8 @@ class PickBest(base.RLChain):
     def __init__(
         self,
         feature_embedder: Optional[PickBestFeatureEmbedder] = None,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         vw_cmd = kwargs.get("vw_cmd", [])
         if not vw_cmd:
@@ -170,7 +170,7 @@ class PickBest(base.RLChain):

         super().__init__(feature_embedder=feature_embedder, *args, **kwargs)

-    def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event:
+    def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
         context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
         if not actions:
             raise ValueError(
@@ -198,7 +198,7 @@ class PickBest(base.RLChain):

     def _call_after_predict_before_llm(
         self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
-    ) -> Tuple[Dict[str, Any], PickBest.Event]:
+    ) -> Tuple[Dict[str, Any], Event]:
         import numpy as np

         prob_sum = sum(prob for _, prob in prediction)
@@ -218,8 +218,8 @@ class PickBest(base.RLChain):
         return next_chain_inputs, event

     def _call_after_llm_before_scoring(
-        self, llm_response: str, event: PickBest.Event
-    ) -> Tuple[Dict[str, Any], PickBest.Event]:
+        self, llm_response: str, event: Event
+    ) -> Tuple[Dict[str, Any], Event]:
         next_chain_inputs = event.inputs.copy()
         # only one key, value pair in event.to_select_from
         value = next(iter(event.to_select_from.values()))
@@ -232,7 +232,7 @@ class PickBest(base.RLChain):
         return next_chain_inputs, event

     def _call_after_scoring_before_learning(
-        self, event: PickBest.Event, score: Optional[float]
+        self, event: Event, score: Optional[float]
     ) -> Event:
         event.selected.score = score
         return event
@@ -9,10 +9,10 @@ class VwLogger:
         if self.path:
             self.path.parent.mkdir(parents=True, exist_ok=True)

-    def log(self, vw_ex: str):
+    def log(self, vw_ex: str) -> None:
         if self.path:
             with open(self.path, "a") as f:
                 f.write(f"{vw_ex}\n\n")

-    def logging_enabled(self):
+    def logging_enabled(self) -> bool:
         return bool(self.path)
@@ -1,3 +1,3 @@
 class MockEncoder:
-    def encode(self, to_encode):
+    def encode(self, to_encode: str) -> str:
         return "[encoded]" + to_encode