mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[feature] fit RL style generation (#6213)
* [feature] fit rl style generation * [doc] add docstr * [doc] add docstr
This commit is contained in:
@@ -168,6 +168,11 @@ class SimpleConsumer(BaseConsumer):
|
||||
self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
|
||||
|
||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
||||
labels = kwargs["input_ids"].clone()
|
||||
labels[kwargs["attention_mask"] == 0] = -100
|
||||
kwargs["labels"] = labels
|
||||
assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
|
||||
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
|
||||
ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
|
||||
|
Reference in New Issue
Block a user