diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py index c69b21dd1f6..4b5ac572f9c 100644 --- a/libs/langchain/langchain/chains/rl_chain/base.py +++ b/libs/langchain/langchain/chains/rl_chain/base.py @@ -316,7 +316,7 @@ class RLChain(Chain, Generic[TEvent]): - selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None. - policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt. - auto_embed (bool): Determines if embedding should be automatic. Default is False. - - metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None. + - metrics (Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]]): Tracker for metrics. Can be set to None. Initialization Attributes: - feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs. @@ -325,7 +325,8 @@ class RLChain(Chain, Generic[TEvent]): - vw_cmd (List[str], optional): Command line arguments for the VW model. - policy (Type[VwPolicy]): Policy used by the chain. - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs. - - metrics_step (int): Step for the metrics tracker. Default is -1. + - metrics_step (int): Step for the metrics tracker. Default is -1. If set without metrics_window_size, average metrics will be tracked; otherwise rolling window metrics will be tracked. + - metrics_window_size (int): Window size for the metrics tracker. Default is -1. If set, rolling window metrics will be tracked. Notes: The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called. 
diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index e6a3007a560..791d12cdb46 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -137,7 +137,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): context_matrix = np.stack([v for k, v in context_embeddings.items()]) dot_product_matrix = np.dot(context_matrix, action_matrix.T) - indexed_dot_product: Dict[Dict] = {} + indexed_dot_product: Dict = {} for i, context_key in enumerate(context_embeddings.keys()): indexed_dot_product[context_key] = {} @@ -258,6 +258,18 @@ class PickBest(base.RLChain[PickBestEvent]): ): auto_embed = kwargs.get("auto_embed", False) + feature_embedder = kwargs.get("feature_embedder", None) + if feature_embedder: + if "auto_embed" in kwargs: + logger.warning( + "auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501 + ) + # turning auto_embed off for cli setting below + auto_embed = False + else: + feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed) + kwargs["feature_embedder"] = feature_embedder + vw_cmd = kwargs.get("vw_cmd", []) if vw_cmd: if "--cb_explore_adf" not in vw_cmd: @@ -281,16 +293,6 @@ class PickBest(base.RLChain[PickBestEvent]): kwargs["vw_cmd"] = vw_cmd - feature_embedder = kwargs.get("feature_embedder", None) - if feature_embedder: - if "auto_embed" in kwargs: - logger.warning( - "auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501 - ) - else: - feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed) - kwargs["feature_embedder"] = feature_embedder - super().__init__(*args, **kwargs) def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent: