Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 05:20:33 +00:00.
Add GRPO and support RLVR for PPO (#6186)

* Add GRPO; support RLVR
* Tested the DeepSeek R1 pipeline
* Add CI
* Verify GRPO R1
* Update README; remove unused code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Remove hard-coded path
* Clean up code
* Fix circular import
* Fix CI OOM
* Skip KTO TP; fix Qwen generation

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,8 @@ class NaiveExperienceBuffer(ExperienceBuffer):
|
||||
self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
# TODO(ver217): add prefetch
|
||||
self.items: List[BufferItem] = []
|
||||
self.rng_sequence = []
|
||||
self.ptr = 0
|
||||
|
||||
@torch.no_grad()
|
||||
def append(self, experience: Experience) -> None:
|
||||
@@ -40,6 +42,9 @@ class NaiveExperienceBuffer(ExperienceBuffer):
|
||||
if samples_to_remove > 0:
|
||||
logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.")
|
||||
self.items = self.items[samples_to_remove:]
|
||||
self.rng_sequence = [i for i in range(len(self.items))]
|
||||
random.shuffle(self.rng_sequence)
|
||||
self.ptr = 0
|
||||
|
||||
def clear(self) -> None:
|
||||
self.items.clear()
|
||||
@@ -52,7 +57,10 @@ class NaiveExperienceBuffer(ExperienceBuffer):
|
||||
Returns:
|
||||
A batch of sampled experiences.
|
||||
"""
|
||||
items = random.sample(self.items, self.sample_batch_size)
|
||||
items = []
|
||||
for _ in range(self.sample_batch_size):
|
||||
self.ptr = (self.ptr + 1) % len(self.items)
|
||||
items.append(self.items[self.rng_sequence[self.ptr]])
|
||||
experience = make_experience_batch(items)
|
||||
if self.cpu_offload:
|
||||
experience.to_device(self.target_device)
|
||||
|
Reference in New Issue
Block a user