fix all mypy errors
This commit is contained in:
parent dd6fff1c62
commit a11ad11d06
@@ -161,6 +161,9 @@ TEvent = TypeVar("TEvent", bound=Event)
 
 
 class Policy(ABC):
+    def __init__(self, **kwargs: Any):
+        pass
+
     @abstractmethod
     def predict(self, event: TEvent) -> Any:
         ...
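The three added lines give Policy a typed constructor. The chain holds its policy as a class object and instantiates it with keyword arguments; without a declared __init__, mypy checks those calls against object.__init__ and rejects every keyword. A minimal sketch of the pattern this enables (ConstantPolicy and make_policy are illustrative names, not part of the commit):

    from abc import ABC, abstractmethod
    from typing import Any, Type

    class Policy(ABC):
        # **kwargs: Any lets subclasses take arbitrary constructor
        # arguments while calls through Type[Policy] still type-check.
        def __init__(self, **kwargs: Any):
            pass

        @abstractmethod
        def predict(self, event: Any) -> Any:
            ...

    class ConstantPolicy(Policy):
        def __init__(self, value: int = 0, **kwargs: Any):
            super().__init__(**kwargs)
            self.value = value

        def predict(self, event: Any) -> Any:
            return self.value

    def make_policy(policy_cls: Type[Policy]) -> Policy:
        # mypy resolves this call through Policy.__init__(**kwargs: Any);
        # against object.__init__ it would flag the keyword argument.
        return policy_cls(value=1)

    print(make_policy(ConstantPolicy).predict(None))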
@@ -233,7 +236,7 @@ class SelectionScorer(ABC, BaseModel):
 
 
 class AutoSelectionScorer(SelectionScorer, BaseModel):
-    llm_chain: Union[LLMChain, None] = None
+    llm_chain: LLMChain
     prompt: Union[BasePromptTemplate, None] = None
     scoring_criteria_template_str: Optional[str] = None
 
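This hunk tightens rather than loosens the type: llm_chain becomes a required field instead of Union[LLMChain, None] = None (presumably it is always populated by a validator elsewhere in the class), so mypy stops reporting union-attr errors wherever the field is dereferenced. A minimal pydantic sketch of the difference, with stand-in classes (Chain and Scorer are illustrative):

    from pydantic import BaseModel

    class Chain(BaseModel):
        def run(self, text: str) -> str:
            return text.upper()

    class Scorer(BaseModel):
        chain: Chain  # required field: never None as far as mypy is concerned

        def score(self, text: str) -> str:
            # Declared as Optional[Chain] = None, this line would need a
            # None guard (or a cast) to pass mypy's union-attr check.
            return self.chain.run(text)

    print(Scorer(chain=Chain()).score("ok"))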
@@ -309,7 +312,7 @@ class RLChain(Generic[TEvent], Chain):
     - model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
     - reset_model (bool): If set to True, the model starts training from scratch. Default is False.
     - vw_cmd (List[str], optional): Command line arguments for the VW model.
-    - policy (VwPolicy): Policy used by the chain.
+    - policy (Type[VwPolicy]): Policy used by the chain.
     - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
     - metrics_step (int): Step for the metrics tracker. Default is -1.
 
@@ -322,7 +325,7 @@ class RLChain(Generic[TEvent], Chain):
     output_key: str = "result" #: :meta private:
     prompt: BasePromptTemplate
     selection_scorer: Union[SelectionScorer, None]
-    policy: Policy
+    active_policy: Policy
     auto_embed: bool = True
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"
@@ -347,14 +350,17 @@ class RLChain(Generic[TEvent], Chain):
                 reinforcement learning will be done in the RL chain \
                 unless update_with_delayed_score is called."
             )
-        self.policy = policy(
-            model_repo=ModelRepository(
-                model_save_dir, with_history=True, reset=reset_model
-            ),
-            vw_cmd=vw_cmd or [],
-            feature_embedder=feature_embedder,
-            vw_logger=VwLogger(vw_logs),
-        )
+
+        if self.active_policy is None:
+            self.active_policy = policy(
+                model_repo=ModelRepository(
+                    model_save_dir, with_history=True, reset=reset_model
+                ),
+                vw_cmd=vw_cmd or [],
+                feature_embedder=feature_embedder,
+                vw_logger=VwLogger(vw_logs),
+            )
 
         self.metrics = MetricsTracker(step=metrics_step)
 
     class Config:
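Together with the two hunks above, the constructor now treats the policy argument as a class (hence Type[VwPolicy] in the docstring) and stores the constructed instance under the distinct name active_policy, guarded so a policy injected earlier is not overwritten. A minimal sketch of the class-argument, instance-attribute split (names mirror the diff; the bodies are illustrative):

    from typing import Any, Optional, Type

    class Policy:
        def __init__(self, **kwargs: Any):
            pass

    class VwPolicy(Policy):
        def __init__(self, vw_cmd: Optional[list] = None, **kwargs: Any):
            super().__init__(**kwargs)
            self.vw_cmd = vw_cmd or []

    class RLChainSketch:
        active_policy: Optional[Policy] = None  # holds the instance

        def __init__(self, policy: Type[Policy] = VwPolicy, **kwargs: Any):
            # `policy` is the class itself; only instantiate it when no
            # policy was supplied beforehand, mirroring the None guard.
            if self.active_policy is None:
                self.active_policy = policy(vw_cmd=["--quiet"])

    chain = RLChainSketch()
    assert isinstance(chain.active_policy, VwPolicy)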
@@ -427,8 +433,8 @@ class RLChain(Generic[TEvent], Chain):
         if self.metrics:
             self.metrics.on_feedback(score)
         self._call_after_scoring_before_learning(event=event, score=score)
-        self.policy.learn(event=event)
-        self.policy.log(event=event)
+        self.active_policy.learn(event=event)
+        self.active_policy.log(event=event)
 
     def set_auto_embed(self, auto_embed: bool) -> None:
         """
@@ -447,7 +453,7 @@ class RLChain(Generic[TEvent], Chain):
         inputs = prepare_inputs_for_autoembed(inputs=inputs)
 
         event: TEvent = self._call_before_predict(inputs=inputs)
-        prediction = self.policy.predict(event=event)
+        prediction = self.active_policy.predict(event=event)
         if self.metrics:
             self.metrics.on_decision()
 
@@ -484,8 +490,8 @@ class RLChain(Generic[TEvent], Chain):
         if self.metrics:
             self.metrics.on_feedback(score)
         event = self._call_after_scoring_before_learning(score=score, event=event)
-        self.policy.learn(event=event)
-        self.policy.log(event=event)
+        self.active_policy.learn(event=event)
+        self.active_policy.log(event=event)
 
         return {self.output_key: {"response": output, "selection_metadata": event}}
 
@@ -493,7 +499,7 @@ class RLChain(Generic[TEvent], Chain):
         """
         This function should be called to save the state of the learned policy model.
         """
-        self.policy.save()
+        self.active_policy.save()
 
     @property
     def _chain_type(self) -> str:
@@ -509,7 +515,7 @@ def is_stringtype_instance(item: Any) -> bool:
 
 def embed_string_type(
     item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
-) -> Dict[str, str]:
+) -> Dict[str, Union[str, List[str]]]:
     """Helper function to embed a string or an _Embed object."""
     join_char = ""
     keep_str = ""
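The widened return type is a Dict-invariance fix: a Dict[str, str] is not accepted where Dict[str, Union[str, List[str]]] is expected, even though str fits the union, so the helper is annotated to match the containers its results flow into. A minimal sketch (function names mirror the diff; the bodies are stubs):

    from typing import Dict, List, Union

    EmbeddedItem = Dict[str, Union[str, List[str]]]

    def embed_string_type(item: str, namespace: str) -> EmbeddedItem:
        # Annotated with the widened value type even when this branch
        # returns a plain str value: Dict is invariant in its value type,
        # so a Dict[str, str] would be rejected by the append below.
        return {namespace: item.lower()}

    def embed(items: List[str], namespace: str) -> List[EmbeddedItem]:
        out: List[EmbeddedItem] = []
        for item in items:
            out.append(embed_string_type(item, namespace))
        return out

    print(embed(["Hello", "World"], "ns"))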
@@ -533,9 +539,9 @@ embed_string_type
     return {namespace: keep_str + join_char.join(map(str, encoded))}
 
 
-def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
+def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
     """Helper function to embed a dictionary item."""
-    inner_dict: Dict[str, Union[str, List[str]]] = {}
+    inner_dict: Dict[str, Any] = {}
     for ns, embed_item in item.items():
         if isinstance(embed_item, list):
             inner_dict[ns] = []
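inner_dict is loosened to Dict[str, Any] because the loop assigns a list to a key and then appends to it; under the Union annotation mypy treats inner_dict[ns] as Union[str, List[str]] at the append site and reports that str has no attribute append. A minimal sketch with stub embedding logic:

    from typing import Any, Dict

    def embed_dict_type(item: Dict[str, Any]) -> Dict[str, Any]:
        inner_dict: Dict[str, Any] = {}
        for ns, embed_item in item.items():
            if isinstance(embed_item, list):
                inner_dict[ns] = []
                for part in embed_item:
                    # With Dict[str, Union[str, List[str]]] this call fails
                    # type checking: the value could be a str, and mypy does
                    # not narrow a subscripted dict value across statements.
                    inner_dict[ns].append(str(part).lower())
            else:
                inner_dict[ns] = str(embed_item).lower()
        return inner_dict

    print(embed_dict_type({"a": ["X", "Y"], "b": "Z"}))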
@@ -560,9 +566,7 @@ def embed_list_type(
 
 
 def embed(
-    to_embed: Union[
-        Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
-    ],
+    to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
     model: Any,
     namespace: Optional[str] = None,
 ) -> List[Dict[str, Union[str, List[str]]]]:
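The old annotation used call syntax: Union(str, _Embed(str)) invokes typing.Union (a TypeError whenever the annotation is actually evaluated) and instantiates _Embed rather than naming the class, so mypy cannot interpret it as a type. The fix subscripts Union with brackets and bare classes. A short illustration (the _Embed stub stands in for the real wrapper):

    from typing import Dict, List, Union

    class _Embed:
        # Illustrative stand-in for the wrapper type in the diff.
        def __init__(self, value: str):
            self.value = value

    # Valid type expression: brackets, and the class _Embed itself.
    # Union(str, _Embed("x")) would be a function call, not a type.
    ToEmbed = Union[str, _Embed, Dict, List[Union[str, _Embed]], List[Dict]]

    def describe(to_embed: ToEmbed) -> str:
        return type(to_embed).__name__

    print(describe("plain"), describe(_Embed("wrapped")))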
@@ -54,9 +54,14 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
         to_select_from_var_name, to_select_from = next(
             iter(event.to_select_from.items()), (None, None)
         )
 
         action_embs = (
-            base.embed(to_select_from, self.model, to_select_from_var_name)
-            if event.to_select_from
+            (
+                base.embed(to_select_from, self.model, to_select_from_var_name)
+                if event.to_select_from
+                else None
+            )
+            if to_select_from
             else None
         )
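The reworked expression guards on the unpacked local to_select_from as well: the (None, None) default on next() makes both locals Optional, and mypy narrows None out of a local only when that local itself is checked, so testing event.to_select_from alone leaves an Optional-argument error at the base.embed call. A minimal sketch of the narrowing pattern (stand-in types; the embed helper is hypothetical):

    from typing import Dict, List, Optional

    def embed(values: List[str], namespace: str) -> List[str]:
        # Hypothetical stand-in for base.embed in the diff.
        return [f"{namespace}:{v}" for v in values]

    def build_action_embs(selection: Dict[str, List[str]]) -> Optional[List[str]]:
        name: Optional[str]
        values: Optional[List[str]]
        name, values = next(iter(selection.items()), (None, None))

        # Guarding on `selection` alone does not narrow `values` or `name`;
        # the outer guard on the locals themselves is what lets the inner
        # embed call type-check without Optional-argument errors.
        return (
            (embed(values, name) if selection else None)
            if values is not None and name is not None
            else None
        )

    print(build_action_embs({"actions": ["a", "b"]}))
    print(build_action_embs({}))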