From dd6fff1c6209f6b05cfac0f343654947d6bd9e2e Mon Sep 17 00:00:00 2001 From: olgavrou Date: Mon, 28 Aug 2023 08:13:23 -0400 Subject: [PATCH] no errors in pick best chain --- .../langchain/chains/rl_chain/__init__.py | 2 +- .../chains/rl_chain/pick_best_chain.py | 47 ++++++++----------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/libs/langchain/langchain/chains/rl_chain/__init__.py b/libs/langchain/langchain/chains/rl_chain/__init__.py index e71de1da6cc..6d5cfc3e29c 100644 --- a/libs/langchain/langchain/chains/rl_chain/__init__.py +++ b/libs/langchain/langchain/chains/rl_chain/__init__.py @@ -13,7 +13,7 @@ from langchain.chains.rl_chain.base import ( from langchain.chains.rl_chain.pick_best_chain import PickBest -def configure_logger(): +def configure_logger() -> None: logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) ch = logging.StreamHandler() diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index e60e685a0b5..ca920522680 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type, Union import langchain.chains.rl_chain.base as base from langchain.base_language import BaseLanguageModel @@ -145,7 +145,6 @@ class PickBest(base.RLChain[PickBest.Event]): def __init__( self, - feature_embedder: Optional[PickBestFeatureEmbedder] = None, *args: Any, **kwargs: Any, ): @@ -163,12 +162,14 @@ class PickBest(base.RLChain[PickBest.Event]): raise ValueError( "If vw_cmd is specified, it must include --cb_explore_adf" ) - kwargs["vw_cmd"] = vw_cmd + + feature_embedder = kwargs.get("feature_embedder", None) if not feature_embedder: feature_embedder = PickBestFeatureEmbedder() + kwargs["feature_embedder"] = feature_embedder - super().__init__(feature_embedder=feature_embedder, *args, **kwargs) + super().__init__(*args, **kwargs) def _call_before_predict(self, inputs: Dict[str, Any]) -> Event: context, actions = base.get_based_on_and_to_select_from(inputs=inputs) @@ -223,10 +224,15 @@ class PickBest(base.RLChain[PickBest.Event]): next_chain_inputs = event.inputs.copy() # only one key, value pair in event.to_select_from value = next(iter(event.to_select_from.values())) + v = ( + value[event.selected.index] + if event.selected + else event.to_select_from.values() + ) next_chain_inputs.update( { self.selected_based_on_input_key: str(event.based_on), - self.selected_input_key: value[event.selected.index], + self.selected_input_key: v, } ) return next_chain_inputs, event @@ -234,7 +240,8 @@ class PickBest(base.RLChain[PickBest.Event]): def _call_after_scoring_before_learning( self, event: Event, score: Optional[float] ) -> Event: - event.selected.score = score + if event.selected: + event.selected.score = score return event def _call( @@ -249,34 +256,20 @@ class PickBest(base.RLChain[PickBest.Event]): return "rl_chain_pick_best" @classmethod - def from_chain( - cls, - llm_chain: Chain, + def from_llm( + cls: Type[PickBest], + llm: BaseLanguageModel, prompt: BasePromptTemplate, - selection_scorer=SENTINEL, + selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL, **kwargs: Any, - ): + ) -> PickBest: + llm_chain = LLMChain(llm=llm, prompt=prompt) if selection_scorer is SENTINEL: selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm) + return PickBest( llm_chain=llm_chain, prompt=prompt, selection_scorer=selection_scorer, **kwargs, ) - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - prompt: BasePromptTemplate, - selection_scorer=SENTINEL, - **kwargs: Any, - ): - llm_chain = LLMChain(llm=llm, prompt=prompt) - return PickBest.from_chain( - llm_chain=llm_chain, - prompt=prompt, - selection_scorer=selection_scorer, - **kwargs, - )