Merge pull request #6 from VowpalWabbit/cb_defaults

cb defaults and some fixes
This commit is contained in:
olgavrou 2023-08-29 08:47:28 -04:00 committed by GitHub
commit 5fb781dfde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 16 deletions

View File

@ -60,7 +60,7 @@ jobs:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
# - "3.11"
name: Python ${{ matrix.python-version }} extended tests
steps:
- uses: actions/checkout@v3

View File

@ -227,15 +227,17 @@ class Embedder(Generic[TEvent], ABC):
...
class SelectionScorer(Generic[TEvent], ABC, BaseModel):
    """Abstract base class for grading the chosen selection or the LLM response.

    The class is generic over the event type ``TEvent``; concrete scorers
    implement ``score_response`` and receive the event alongside the inputs
    and the raw response text.
    """

    @abstractmethod
    def score_response(
        self, inputs: Dict[str, Any], llm_response: str, event: TEvent
    ) -> float:
        """Return a numeric score for ``llm_response``.

        Args:
            inputs: Key/value inputs that accompany the response being scored.
            llm_response: The text produced by the LLM.
            event: The event associated with this scoring call.

        Returns:
            The score as a float.
        """
        ...
class AutoSelectionScorer(SelectionScorer, BaseModel):
class AutoSelectionScorer(SelectionScorer[Event], BaseModel):
llm_chain: LLMChain
prompt: Union[BasePromptTemplate, None] = None
scoring_criteria_template_str: Optional[str] = None
@ -254,7 +256,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
def get_default_prompt() -> ChatPromptTemplate:
human_template = 'Given this based_on "{rl_chain_selected_based_on}" \
as the most important attribute, rank how good or bad this text is: \
"{llm_response}".'
"{rl_chain_selected}".'
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
chat_prompt = ChatPromptTemplate.from_messages(
@ -281,7 +283,9 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
values["llm_chain"] = LLMChain(llm=llm, prompt=prompt)
return values
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
def score_response(
self, inputs: Dict[str, Any], llm_response: str, event: Event
) -> float:
ranking = self.llm_chain.predict(llm_response=llm_response, **inputs)
ranking = ranking.strip()
try:
@ -304,7 +308,7 @@ class RLChain(Chain, Generic[TEvent]):
- prompt (BasePromptTemplate): The template for the base prompt.
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
- auto_embed (bool): Determines if embedding should be automatic. Default is True.
- auto_embed (bool): Determines if embedding should be automatic. Default is False.
- metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None.
Initialization Attributes:
@ -338,7 +342,7 @@ class RLChain(Chain, Generic[TEvent]):
prompt: BasePromptTemplate
selection_scorer: Union[SelectionScorer, None]
active_policy: Policy = _NoOpPolicy()
auto_embed: bool = True
auto_embed: bool = False
selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on"
metrics: Optional[MetricsTracker] = None
@ -492,7 +496,7 @@ class RLChain(Chain, Generic[TEvent]):
try:
if self.selection_scorer:
score = self.selection_scorer.score_response(
inputs=next_chain_inputs, llm_response=output
inputs=next_chain_inputs, llm_response=output, event=event
)
except Exception as e:
logger.info(
@ -553,7 +557,7 @@ def embed_string_type(
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
"""Helper function to embed a dictionary item."""
inner_dict: Dict[str, Any] = {}
inner_dict: Dict = {}
for ns, embed_item in item.items():
if isinstance(embed_item, list):
inner_dict[ns] = []
@ -568,10 +572,17 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
def embed_list_type(
    item: list, model: Any, namespace: Optional[str] = None
) -> List[Dict[str, Union[str, List[str]]]]:
    """Embed each element of ``item``, dispatching on the element's type.

    * ``dict`` elements are handed to ``embed_dict_type``.
    * ``list`` elements are embedded recursively and the results are collapsed
      into a single dict keyed by the first key of the first result.
    * Anything else is handed to ``embed_string_type`` together with
      ``namespace``.

    Args:
        item: The list whose elements should be embedded.
        model: The encoder passed through to the embedding helpers.
        namespace: Optional namespace forwarded to ``embed_string_type`` and
            to the recursive call for nested lists.

    Returns:
        One entry per element of ``item``, in the same order.

    Raises:
        IndexError: If a nested list produces no embeddings (e.g. it is empty),
            because its first result is used to pick the grouping key.
    """
    embedded: List = []
    for element in item:
        if isinstance(element, dict):
            embedded.append(embed_dict_type(element, model))
            continue
        if not isinstance(element, list):
            embedded.append(embed_string_type(element, model, namespace))
            continue
        nested = embed_list_type(element, model, namespace)
        # The first key of the first nested result names the group; every
        # nested result is expected to carry a value under that same key.
        group_key = next(iter(nested[0]))
        embedded.append({group_key: [entry[group_key] for entry in nested]})
    return embedded

View File

@ -161,7 +161,7 @@ class PickBest(base.RLChain[PickBestEvent]):
"--quiet",
"--interactions=::",
"--coin",
"--epsilon=0.2",
"--squarecb",
]
else:
if "--cb_explore_adf" not in vw_cmd:

View File

@ -140,7 +140,12 @@ def test_user_defined_scorer() -> None:
llm, PROMPT = setup()
class CustomSelectionScorer(rl_chain.SelectionScorer):
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
def score_response(
self,
inputs: Dict[str, Any],
llm_response: str,
event: pick_best_chain.PickBestEvent,
) -> float:
score = 200
return score
@ -161,11 +166,11 @@ def test_user_defined_scorer() -> None:
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings() -> None:
def test_auto_embeddings_on() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
)
str1 = "0"
@ -194,6 +199,32 @@ def test_default_embeddings() -> None:
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_auto_embedder_is_off() -> None:
    """A chain built without ``auto_embed`` formats features as plain text."""
    llm, PROMPT = setup()
    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
    chain = pick_best_chain.PickBest.from_llm(
        llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
    )

    str1, str2, str3 = "0", "1", "2"
    ctx_str_1 = "context1"

    response = chain.run(
        User=pick_best_chain.base.BasedOn(ctx_str_1),
        action=pick_best_chain.base.ToSelectFrom([str1, str2, str3]),
    )
    vw_str = feature_embedder.format(response["selection_metadata"])

    # The context and actions appear verbatim, i.e. nothing was embedded.
    expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """  # noqa
    assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_off() -> None:
llm, PROMPT = setup()
@ -225,7 +256,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
)
str1 = "0"