Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 05:20:33 +00:00
Add GRPO and Support RLVR for PPO (#6186)
* add grpo, support rlvr
* add grpo, support rlvr
* tested deepseek r1 pipeline
* add ci
* verify grpo r1
* verify grpo r1
* update readme, remove unused code
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* remove path
* clean code
* fix circular import
* fix ci OOM
* fix ci OOM
* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
```diff
@@ -96,6 +96,7 @@ class OLTrainer(ABC):
         self.sample_buffer = sample_buffer
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks
+        self.num_train_step = 0

     @contextmanager
     def _fit_ctx(self) -> None:
@@ -212,5 +213,6 @@ class OLTrainer(ABC):
                         self._update_phase(update_step)
                     # NOTE: this is for on-policy algorithms
                     self.data_buffer.clear()
-                if self.save_interval > 0 and (episode + 1) % (self.save_interval) == 0:
-                    self._save_checkpoint(episode + 1)
+
+                if self.num_train_step > 0 and (self.num_train_step + 1) % (self.save_interval) == 0:
+                    self._save_checkpoint(self.num_train_step + 1)
```
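The second hunk swaps the episode-based save condition for one keyed on the new `num_train_step` counter, so checkpoints track optimizer updates rather than episodes. Below is a minimal, self-contained sketch of that pattern; `TrainerSketch`, its `save_interval` default, and the print-only `_save_checkpoint` body are illustrative stand-ins, not the actual OLTrainer implementation.

```python
# Sketch of step-interval checkpointing as in the hunk above.
# Names and defaults here are hypothetical stand-ins.


class TrainerSketch:
    def __init__(self, save_interval: int = 100):
        self.save_interval = save_interval
        self.num_train_step = 0  # incremented once per update step

    def _save_checkpoint(self, step: int) -> None:
        # Placeholder: the real trainer would write model/optimizer state here.
        print(f"saving checkpoint at step {step}")

    def fit(self, num_episodes: int, num_update_steps: int) -> None:
        for _episode in range(num_episodes):
            for _ in range(num_update_steps):
                # ... collect rollouts and run one update phase ...
                self.num_train_step += 1

            # Checkpoint on a training-step interval instead of an episode
            # interval, mirroring the condition added in the diff.
            if self.num_train_step > 0 and (self.num_train_step + 1) % self.save_interval == 0:
                self._save_checkpoint(self.num_train_step + 1)


if __name__ == "__main__":
    TrainerSketch(save_interval=10).fit(num_episodes=3, num_update_steps=3)
```

With these toy numbers the counter reaches 9 after the third episode, so `(9 + 1) % 10 == 0` triggers a single save, which illustrates why the check fires on step boundaries regardless of how many episodes have elapsed.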