diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py
index fb4143f4655..c250815943d 100644
--- a/libs/langchain/langchain/chains/rl_chain/base.py
+++ b/libs/langchain/langchain/chains/rl_chain/base.py
@@ -343,6 +343,7 @@ class RLChain(Chain, Generic[TEvent]):
     selection_scorer: Union[SelectionScorer, None]
     active_policy: Policy = _NoOpPolicy()
     auto_embed: bool = False
+    selection_scorer_activated: bool = True
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"
     metrics: Optional[MetricsTracker] = None
@@ -400,6 +401,42 @@ class RLChain(Chain, Generic[TEvent]):
         """
         return [self.output_key]
 
+    def update_with_delayed_score(
+        self, score: float, event: TEvent, force_score: bool = False
+    ) -> None:
+        """
+        Updates the learned policy with the score provided.
+        Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
+        """  # noqa: E501
+        if self._can_use_selection_scorer() and not force_score:
+            raise RuntimeError(
+                "The selection scorer is set, and force_score was not set to True. \
+                    Please set force_score=True to use this function."
+            )
+        if self.metrics:
+            self.metrics.on_feedback(score)
+        self._call_after_scoring_before_learning(event=event, score=score)
+        self.active_policy.learn(event=event)
+        self.active_policy.log(event=event)
+
+    def deactivate_selection_scorer(self) -> None:
+        """
+        Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
+        """  # noqa: E501
+        self.selection_scorer_activated = False
+
+    def activate_selection_scorer(self) -> None:
+        """
+        Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
+        """  # noqa: E501
+        self.selection_scorer_activated = True
+
+    def save_progress(self) -> None:
+        """
+        This function should be called to save the state of the learned policy model.
+        """  # noqa: E501
+        self.active_policy.save()
+
     def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
         super()._validate_inputs(inputs)
         if (
@@ -412,6 +449,12 @@ class RLChain(Chain, Generic[TEvent]):
                 they are reserved for internal use during auto reward."
            )
 
+    def _can_use_selection_scorer(self) -> bool:
+        """
+        Returns whether the chain can use the selection scorer to score responses or not.
+        """  # noqa: E501
+        return self.selection_scorer is not None and self.selection_scorer_activated
+
     @abstractmethod
     def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
         ...
@@ -434,30 +477,6 @@ class RLChain(Chain, Generic[TEvent]):
     ) -> TEvent:
         ...
 
-    def update_with_delayed_score(
-        self, score: float, event: TEvent, force_score: bool = False
-    ) -> None:
-        """
-        Updates the learned policy with the score provided.
-        Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
-        """  # noqa: E501
-        if self.selection_scorer and not force_score:
-            raise RuntimeError(
-                "The selection scorer is set, and force_score was not set to True. \
-                    Please set force_score=True to use this function."
-            )
-        if self.metrics:
-            self.metrics.on_feedback(score)
-        self._call_after_scoring_before_learning(event=event, score=score)
-        self.active_policy.learn(event=event)
-        self.active_policy.log(event=event)
-
-    def set_auto_embed(self, auto_embed: bool) -> None:
-        """
-        Sets whether the chain should auto embed the inputs or not.
- """ - self.auto_embed = auto_embed - def _call( self, inputs: Dict[str, Any], @@ -494,8 +513,8 @@ class RLChain(Chain, Generic[TEvent]): score = None try: - if self.selection_scorer: - score = self.selection_scorer.score_response( + if self._can_use_selection_scorer(): + score = self.selection_scorer.score_response( # type: ignore inputs=next_chain_inputs, llm_response=output, event=event ) except Exception as e: @@ -511,12 +530,6 @@ class RLChain(Chain, Generic[TEvent]): return {self.output_key: {"response": output, "selection_metadata": event}} - def save_progress(self) -> None: - """ - This function should be called to save the state of the learned policy model. - """ - self.active_policy.save() - @property def _chain_type(self) -> str: return "llm_personalizer_chain" diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py index 2af08840b5e..d7dee7fdf64 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py @@ -363,3 +363,41 @@ def test_calling_chain_w_reserved_inputs_throws() -> None: User=rl_chain.BasedOn("Context"), rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]), ) + + +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") +def test_activate_and_deactivate_scorer() -> None: + llm, PROMPT = setup() + scorer_llm = FakeListChatModel(responses=[300]) + chain = pick_best_chain.PickBest.from_llm( + llm=llm, + prompt=PROMPT, + selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + ) + response = chain.run( + User=pick_best_chain.base.BasedOn("Context"), + action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]), + ) + # chain llm used for both basic prompt and for scoring + assert response["response"] == "hey" + selection_metadata = response["selection_metadata"] + assert selection_metadata.selected.score == 300.0 + + chain.deactivate_selection_scorer() + response = chain.run( + User=pick_best_chain.base.BasedOn("Context"), + action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]), + ) + assert response["response"] == "hey" + selection_metadata = response["selection_metadata"] + assert selection_metadata.selected.score is None + + chain.activate_selection_scorer() + response = chain.run( + User=pick_best_chain.base.BasedOn("Context"), + action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]), + ) + assert response["response"] == "hey" + selection_metadata = response["selection_metadata"] + assert selection_metadata.selected.score == 300.0