no errors in pick best chain

This commit is contained in:
olgavrou 2023-08-28 08:13:23 -04:00
parent 6a1102d4c0
commit dd6fff1c62
2 changed files with 21 additions and 28 deletions

View File

@ -13,7 +13,7 @@ from langchain.chains.rl_chain.base import (
from langchain.chains.rl_chain.pick_best_chain import PickBest from langchain.chains.rl_chain.pick_best_chain import PickBest
def configure_logger(): def configure_logger() -> None:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
ch = logging.StreamHandler() ch = logging.StreamHandler()

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Type, Union
import langchain.chains.rl_chain.base as base import langchain.chains.rl_chain.base as base
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
@ -145,7 +145,6 @@ class PickBest(base.RLChain[PickBest.Event]):
def __init__( def __init__(
self, self,
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
): ):
@ -163,12 +162,14 @@ class PickBest(base.RLChain[PickBest.Event]):
raise ValueError( raise ValueError(
"If vw_cmd is specified, it must include --cb_explore_adf" "If vw_cmd is specified, it must include --cb_explore_adf"
) )
kwargs["vw_cmd"] = vw_cmd kwargs["vw_cmd"] = vw_cmd
feature_embedder = kwargs.get("feature_embedder", None)
if not feature_embedder: if not feature_embedder:
feature_embedder = PickBestFeatureEmbedder() feature_embedder = PickBestFeatureEmbedder()
kwargs["feature_embedder"] = feature_embedder
super().__init__(feature_embedder=feature_embedder, *args, **kwargs) super().__init__(*args, **kwargs)
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event: def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs) context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
@ -223,10 +224,15 @@ class PickBest(base.RLChain[PickBest.Event]):
next_chain_inputs = event.inputs.copy() next_chain_inputs = event.inputs.copy()
# only one key, value pair in event.to_select_from # only one key, value pair in event.to_select_from
value = next(iter(event.to_select_from.values())) value = next(iter(event.to_select_from.values()))
v = (
value[event.selected.index]
if event.selected
else event.to_select_from.values()
)
next_chain_inputs.update( next_chain_inputs.update(
{ {
self.selected_based_on_input_key: str(event.based_on), self.selected_based_on_input_key: str(event.based_on),
self.selected_input_key: value[event.selected.index], self.selected_input_key: v,
} }
) )
return next_chain_inputs, event return next_chain_inputs, event
@ -234,6 +240,7 @@ class PickBest(base.RLChain[PickBest.Event]):
def _call_after_scoring_before_learning( def _call_after_scoring_before_learning(
self, event: Event, score: Optional[float] self, event: Event, score: Optional[float]
) -> Event: ) -> Event:
if event.selected:
event.selected.score = score event.selected.score = score
return event return event
@ -249,34 +256,20 @@ class PickBest(base.RLChain[PickBest.Event]):
return "rl_chain_pick_best" return "rl_chain_pick_best"
@classmethod @classmethod
def from_chain( def from_llm(
cls, cls: Type[PickBest],
llm_chain: Chain, llm: BaseLanguageModel,
prompt: BasePromptTemplate, prompt: BasePromptTemplate,
selection_scorer=SENTINEL, selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
**kwargs: Any, **kwargs: Any,
): ) -> PickBest:
llm_chain = LLMChain(llm=llm, prompt=prompt)
if selection_scorer is SENTINEL: if selection_scorer is SENTINEL:
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm) selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
return PickBest( return PickBest(
llm_chain=llm_chain, llm_chain=llm_chain,
prompt=prompt, prompt=prompt,
selection_scorer=selection_scorer, selection_scorer=selection_scorer,
**kwargs, **kwargs,
) )
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
selection_scorer=SENTINEL,
**kwargs: Any,
):
llm_chain = LLMChain(llm=llm, prompt=prompt)
return PickBest.from_chain(
llm_chain=llm_chain,
prompt=prompt,
selection_scorer=selection_scorer,
**kwargs,
)