add DAPO support

YeAnbang 2025-04-15 18:28:35 +08:00
parent 9474316132
commit c0c0da2f26
3 changed files with 9 additions and 7 deletions

View File

@@ -112,7 +112,7 @@ class BaseConsumer:
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
batch = bind_batch(batches)
batch = post_recv(batch)
- loss = self.step(i, **batch)
+ loss = self.step(i, pbar, **batch)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
@@ -181,7 +181,7 @@ class SimpleConsumer(BaseConsumer):
super().setup()
self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
- def step(self, step_idx: int, **kwargs) -> Optional[float]:
+ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
labels = kwargs["input_ids"].clone()
labels[kwargs["attention_mask"] == 0] = -100
kwargs["labels"] = labels

View File

@@ -1,7 +1,7 @@
import json
import os
from contextlib import nullcontext
- from typing import Optional
+ from typing import Optional, Any
import ray
import torch
@@ -144,7 +144,7 @@ class GRPOConsumer(BaseConsumer):
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR")
- def step(self, step_idx: int, **kwargs) -> Optional[float]:
+ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
"""
Step data from policy model:
[{
@@ -212,6 +212,7 @@ class GRPOConsumer(BaseConsumer):
# update the gradient only once at least 0.75 * batch_size * num_generations valid samples have been collected, since many samples may be invalid and get filtered out.
# balance between efficiency and accuracy
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
pbar.set_postfix({"Step": self.global_step + 1, "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}"})
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
ctx = (
@@ -409,6 +410,7 @@ class GRPOConsumer(BaseConsumer):
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
+ self.global_step += 1
sample_utilization = self.effective_sample_count / self.total_sample_count
self.effective_sample_count = 0
self.total_sample_count = 0
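
Note: the two hunks above carry the DAPO-style loop bookkeeping: gradients keep accumulating until at least 75% of batch_size * dp_size * num_generations non-filtered ("effective") samples have been collected, the new postfix makes that collection progress visible, and sample_utilization records how many generated samples actually contributed to the update. A standalone sketch of the accounting with made-up sizes (all names and numbers below are illustrative, not the repo's exact code):

batch_size, dp_size, num_generations = 4, 2, 8
target = batch_size * dp_size * num_generations * 0.75  # optimizer steps once this many valid samples are seen

effective_sample_count = 0  # responses that survive filtering (e.g. truncated ones removed)
total_sample_count = 0      # everything that was generated
global_step = 0


def accumulate(num_valid: int, num_total: int) -> bool:
    # Accumulate one microbatch; return True when the optimizer should step.
    global effective_sample_count, total_sample_count, global_step
    effective_sample_count += num_valid
    total_sample_count += num_total
    need_update = effective_sample_count >= target
    if need_update:
        global_step += 1
        sample_utilization = effective_sample_count / total_sample_count
        print(f"step {global_step}: sample utilization {sample_utilization:.2f}")
        effective_sample_count = 0
        total_sample_count = 0
    return need_update


for _ in range(6):
    accumulate(num_valid=12, num_total=16)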

View File

@@ -83,7 +83,7 @@ if __name__ == "__main__":
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
generate_config.update(
dict(
- max_tokens=2048,
+ max_tokens=4096,
ignore_eos=True,
include_stop_str_in_output=True,
stop=["</answer>"],
@@ -120,8 +120,8 @@ if __name__ == "__main__":
"beta": 0.0, # no KL penalty
"loss_variation": "token_level",
"soft_over_length_punishment": True,
"max_length": 1024 * 2,
"cache_length": 256,
"max_length": 4096,
"cache_length": 512,
"filter_truncated_response": True,
}
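
Note: the config changes above wire in DAPO's length handling: max_tokens in the generation config is raised to 4096 to match the new max_length, truncated responses are filtered out, and responses approaching the limit are softly penalized over the last cache_length tokens. A sketch of the soft over-length punishment these settings describe, following the shaping used in the DAPO paper (the repo's exact reward code may differ):

def soft_overlength_penalty(response_len: int, max_length: int = 4096, cache_length: int = 512) -> float:
    # No penalty up to max_length - cache_length, then a linear penalty down to -1 at max_length.
    threshold = max_length - cache_length
    if response_len <= threshold:
        return 0.0
    if response_len <= max_length:
        return (threshold - response_len) / cache_length
    return -1.0  # anything longer is truncated and can simply be filtered out


# With max_length=4096 and cache_length=512 the penalty ramps in over the last 512 tokens:
print(soft_overlength_penalty(3500))  # 0.0
print(soft_overlength_penalty(3840))  # -0.5
print(soft_overlength_penalty(4096))  # -1.0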