mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
no errors in pick best chain
This commit is contained in:
parent
6a1102d4c0
commit
dd6fff1c62
@ -13,7 +13,7 @@ from langchain.chains.rl_chain.base import (
|
|||||||
from langchain.chains.rl_chain.pick_best_chain import PickBest
|
from langchain.chains.rl_chain.pick_best_chain import PickBest
|
||||||
|
|
||||||
|
|
||||||
def configure_logger():
|
def configure_logger() -> None:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
ch = logging.StreamHandler()
|
ch = logging.StreamHandler()
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import langchain.chains.rl_chain.base as base
|
import langchain.chains.rl_chain.base as base
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
@ -145,7 +145,6 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
|
|
||||||
*args: Any,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
@ -163,12 +162,14 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If vw_cmd is specified, it must include --cb_explore_adf"
|
"If vw_cmd is specified, it must include --cb_explore_adf"
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs["vw_cmd"] = vw_cmd
|
kwargs["vw_cmd"] = vw_cmd
|
||||||
|
|
||||||
|
feature_embedder = kwargs.get("feature_embedder", None)
|
||||||
if not feature_embedder:
|
if not feature_embedder:
|
||||||
feature_embedder = PickBestFeatureEmbedder()
|
feature_embedder = PickBestFeatureEmbedder()
|
||||||
|
kwargs["feature_embedder"] = feature_embedder
|
||||||
|
|
||||||
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
|
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
|
||||||
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
||||||
@ -223,10 +224,15 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
next_chain_inputs = event.inputs.copy()
|
next_chain_inputs = event.inputs.copy()
|
||||||
# only one key, value pair in event.to_select_from
|
# only one key, value pair in event.to_select_from
|
||||||
value = next(iter(event.to_select_from.values()))
|
value = next(iter(event.to_select_from.values()))
|
||||||
|
v = (
|
||||||
|
value[event.selected.index]
|
||||||
|
if event.selected
|
||||||
|
else event.to_select_from.values()
|
||||||
|
)
|
||||||
next_chain_inputs.update(
|
next_chain_inputs.update(
|
||||||
{
|
{
|
||||||
self.selected_based_on_input_key: str(event.based_on),
|
self.selected_based_on_input_key: str(event.based_on),
|
||||||
self.selected_input_key: value[event.selected.index],
|
self.selected_input_key: v,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return next_chain_inputs, event
|
return next_chain_inputs, event
|
||||||
@ -234,7 +240,8 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
def _call_after_scoring_before_learning(
|
def _call_after_scoring_before_learning(
|
||||||
self, event: Event, score: Optional[float]
|
self, event: Event, score: Optional[float]
|
||||||
) -> Event:
|
) -> Event:
|
||||||
event.selected.score = score
|
if event.selected:
|
||||||
|
event.selected.score = score
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
@ -249,34 +256,20 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|||||||
return "rl_chain_pick_best"
|
return "rl_chain_pick_best"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_chain(
|
def from_llm(
|
||||||
cls,
|
cls: Type[PickBest],
|
||||||
llm_chain: Chain,
|
llm: BaseLanguageModel,
|
||||||
prompt: BasePromptTemplate,
|
prompt: BasePromptTemplate,
|
||||||
selection_scorer=SENTINEL,
|
selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
) -> PickBest:
|
||||||
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
if selection_scorer is SENTINEL:
|
if selection_scorer is SENTINEL:
|
||||||
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
|
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
|
||||||
|
|
||||||
return PickBest(
|
return PickBest(
|
||||||
llm_chain=llm_chain,
|
llm_chain=llm_chain,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
selection_scorer=selection_scorer,
|
selection_scorer=selection_scorer,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_llm(
|
|
||||||
cls,
|
|
||||||
llm: BaseLanguageModel,
|
|
||||||
prompt: BasePromptTemplate,
|
|
||||||
selection_scorer=SENTINEL,
|
|
||||||
**kwargs: Any,
|
|
||||||
):
|
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
|
||||||
return PickBest.from_chain(
|
|
||||||
llm_chain=llm_chain,
|
|
||||||
prompt=prompt,
|
|
||||||
selection_scorer=selection_scorer,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user