[feature] fit RL style generation (#6213)

* [feature] fit rl style generation

* [doc] add docstr

* [doc] add docstr
This commit is contained in:
Hongxin Liu
2025-02-21 17:28:19 +08:00
committed by GitHub
parent 43c9b5fb44
commit de282dd694
3 changed files with 140 additions and 49 deletions

View File

@@ -168,6 +168,11 @@ class SimpleConsumer(BaseConsumer):
self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
def step(self, step_idx: int, **kwargs) -> Optional[float]:
labels = kwargs["input_ids"].clone()
labels[kwargs["attention_mask"] == 0] = -100
kwargs["labels"] = labels
assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
need_update = (step_idx + 1) % self.num_microbatches == 0
ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)