mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-19 17:45:25 +00:00
fix all mypy errors
This commit is contained in:
parent
dd6fff1c62
commit
a11ad11d06
@ -161,6 +161,9 @@ TEvent = TypeVar("TEvent", bound=Event)
|
|||||||
|
|
||||||
|
|
||||||
class Policy(ABC):
|
class Policy(ABC):
|
||||||
|
def __init__(self, **kwargs: Any):
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def predict(self, event: TEvent) -> Any:
|
def predict(self, event: TEvent) -> Any:
|
||||||
...
|
...
|
||||||
@ -233,7 +236,7 @@ class SelectionScorer(ABC, BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AutoSelectionScorer(SelectionScorer, BaseModel):
|
class AutoSelectionScorer(SelectionScorer, BaseModel):
|
||||||
llm_chain: Union[LLMChain, None] = None
|
llm_chain: LLMChain
|
||||||
prompt: Union[BasePromptTemplate, None] = None
|
prompt: Union[BasePromptTemplate, None] = None
|
||||||
scoring_criteria_template_str: Optional[str] = None
|
scoring_criteria_template_str: Optional[str] = None
|
||||||
|
|
||||||
@ -309,7 +312,7 @@ class RLChain(Generic[TEvent], Chain):
|
|||||||
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
|
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
|
||||||
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
|
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
|
||||||
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
||||||
- policy (VwPolicy): Policy used by the chain.
|
- policy (Type[VwPolicy]): Policy used by the chain.
|
||||||
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
||||||
- metrics_step (int): Step for the metrics tracker. Default is -1.
|
- metrics_step (int): Step for the metrics tracker. Default is -1.
|
||||||
|
|
||||||
@ -322,7 +325,7 @@ class RLChain(Generic[TEvent], Chain):
|
|||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
prompt: BasePromptTemplate
|
prompt: BasePromptTemplate
|
||||||
selection_scorer: Union[SelectionScorer, None]
|
selection_scorer: Union[SelectionScorer, None]
|
||||||
policy: Policy
|
active_policy: Policy
|
||||||
auto_embed: bool = True
|
auto_embed: bool = True
|
||||||
selected_input_key = "rl_chain_selected"
|
selected_input_key = "rl_chain_selected"
|
||||||
selected_based_on_input_key = "rl_chain_selected_based_on"
|
selected_based_on_input_key = "rl_chain_selected_based_on"
|
||||||
@ -347,14 +350,17 @@ class RLChain(Generic[TEvent], Chain):
|
|||||||
reinforcement learning will be done in the RL chain \
|
reinforcement learning will be done in the RL chain \
|
||||||
unless update_with_delayed_score is called."
|
unless update_with_delayed_score is called."
|
||||||
)
|
)
|
||||||
self.policy = policy(
|
|
||||||
model_repo=ModelRepository(
|
if self.active_policy is None:
|
||||||
model_save_dir, with_history=True, reset=reset_model
|
self.active_policy = policy(
|
||||||
),
|
model_repo=ModelRepository(
|
||||||
vw_cmd=vw_cmd or [],
|
model_save_dir, with_history=True, reset=reset_model
|
||||||
feature_embedder=feature_embedder,
|
),
|
||||||
vw_logger=VwLogger(vw_logs),
|
vw_cmd=vw_cmd or [],
|
||||||
)
|
feature_embedder=feature_embedder,
|
||||||
|
vw_logger=VwLogger(vw_logs),
|
||||||
|
)
|
||||||
|
|
||||||
self.metrics = MetricsTracker(step=metrics_step)
|
self.metrics = MetricsTracker(step=metrics_step)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -427,8 +433,8 @@ class RLChain(Generic[TEvent], Chain):
|
|||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.on_feedback(score)
|
self.metrics.on_feedback(score)
|
||||||
self._call_after_scoring_before_learning(event=event, score=score)
|
self._call_after_scoring_before_learning(event=event, score=score)
|
||||||
self.policy.learn(event=event)
|
self.active_policy.learn(event=event)
|
||||||
self.policy.log(event=event)
|
self.active_policy.log(event=event)
|
||||||
|
|
||||||
def set_auto_embed(self, auto_embed: bool) -> None:
|
def set_auto_embed(self, auto_embed: bool) -> None:
|
||||||
"""
|
"""
|
||||||
@ -447,7 +453,7 @@ class RLChain(Generic[TEvent], Chain):
|
|||||||
inputs = prepare_inputs_for_autoembed(inputs=inputs)
|
inputs = prepare_inputs_for_autoembed(inputs=inputs)
|
||||||
|
|
||||||
event: TEvent = self._call_before_predict(inputs=inputs)
|
event: TEvent = self._call_before_predict(inputs=inputs)
|
||||||
prediction = self.policy.predict(event=event)
|
prediction = self.active_policy.predict(event=event)
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.on_decision()
|
self.metrics.on_decision()
|
||||||
|
|
||||||
@ -484,8 +490,8 @@ class RLChain(Generic[TEvent], Chain):
|
|||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.on_feedback(score)
|
self.metrics.on_feedback(score)
|
||||||
event = self._call_after_scoring_before_learning(score=score, event=event)
|
event = self._call_after_scoring_before_learning(score=score, event=event)
|
||||||
self.policy.learn(event=event)
|
self.active_policy.learn(event=event)
|
||||||
self.policy.log(event=event)
|
self.active_policy.log(event=event)
|
||||||
|
|
||||||
return {self.output_key: {"response": output, "selection_metadata": event}}
|
return {self.output_key: {"response": output, "selection_metadata": event}}
|
||||||
|
|
||||||
@ -493,7 +499,7 @@ class RLChain(Generic[TEvent], Chain):
|
|||||||
"""
|
"""
|
||||||
This function should be called to save the state of the learned policy model.
|
This function should be called to save the state of the learned policy model.
|
||||||
"""
|
"""
|
||||||
self.policy.save()
|
self.active_policy.save()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
@ -509,7 +515,7 @@ def is_stringtype_instance(item: Any) -> bool:
|
|||||||
|
|
||||||
def embed_string_type(
|
def embed_string_type(
|
||||||
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
|
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, Union[str, List[str]]]:
|
||||||
"""Helper function to embed a string or an _Embed object."""
|
"""Helper function to embed a string or an _Embed object."""
|
||||||
join_char = ""
|
join_char = ""
|
||||||
keep_str = ""
|
keep_str = ""
|
||||||
@ -533,9 +539,9 @@ def embed_string_type(
|
|||||||
return {namespace: keep_str + join_char.join(map(str, encoded))}
|
return {namespace: keep_str + join_char.join(map(str, encoded))}
|
||||||
|
|
||||||
|
|
||||||
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
|
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
|
||||||
"""Helper function to embed a dictionary item."""
|
"""Helper function to embed a dictionary item."""
|
||||||
inner_dict: Dict[str, Union[str, List[str]]] = {}
|
inner_dict: Dict[str, Any] = {}
|
||||||
for ns, embed_item in item.items():
|
for ns, embed_item in item.items():
|
||||||
if isinstance(embed_item, list):
|
if isinstance(embed_item, list):
|
||||||
inner_dict[ns] = []
|
inner_dict[ns] = []
|
||||||
@ -560,9 +566,7 @@ def embed_list_type(
|
|||||||
|
|
||||||
|
|
||||||
def embed(
|
def embed(
|
||||||
to_embed: Union[
|
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
|
||||||
Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]
|
|
||||||
],
|
|
||||||
model: Any,
|
model: Any,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||||
|
@ -54,9 +54,14 @@ class PickBestFeatureEmbedder(base.Embedder[PickBest.Event]):
|
|||||||
to_select_from_var_name, to_select_from = next(
|
to_select_from_var_name, to_select_from = next(
|
||||||
iter(event.to_select_from.items()), (None, None)
|
iter(event.to_select_from.items()), (None, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
action_embs = (
|
action_embs = (
|
||||||
base.embed(to_select_from, self.model, to_select_from_var_name)
|
(
|
||||||
if event.to_select_from
|
base.embed(to_select_from, self.model, to_select_from_var_name)
|
||||||
|
if event.to_select_from
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if to_select_from
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user