add DAPO support

YeAnbang 2025-04-15 18:28:35 +08:00
parent 9474316132
commit c0c0da2f26
3 changed files with 9 additions and 7 deletions

View File

@@ -112,7 +112,7 @@ class BaseConsumer:
                         self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
                         batch = bind_batch(batches)
                         batch = post_recv(batch)
-                        loss = self.step(i, **batch)
+                        loss = self.step(i, pbar, **batch)
                         if loss is not None:
                             pbar.set_postfix({"loss": loss})
                         i += 1
@@ -181,7 +181,7 @@ class SimpleConsumer(BaseConsumer):
         super().setup()
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
 
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         labels = kwargs["input_ids"].clone()
         labels[kwargs["attention_mask"] == 0] = -100
         kwargs["labels"] = labels

View File

@@ -1,7 +1,7 @@
 import json
 import os
 from contextlib import nullcontext
-from typing import Optional
+from typing import Optional, Any
 
 import ray
 import torch
@@ -144,7 +144,7 @@ class GRPOConsumer(BaseConsumer):
             self.reference_model, *_ = self.booster.boost(self.reference_model)
         self.plugin.logger.set_level("ERROR")
 
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         """
         Step data from policy model:
             [{
@@ -212,6 +212,7 @@ class GRPOConsumer(BaseConsumer):
         # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
         # balance between efficiency and accuracy
         need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
+        pbar.set_postfix({"Step": self.global_step + 1, "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}"})
 
         # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
         ctx = (
@@ -409,6 +410,7 @@ class GRPOConsumer(BaseConsumer):
             if need_update:
                 self.optimizer.step()
                 self.optimizer.zero_grad()
+                self.global_step += 1
                 sample_utilization = self.effective_sample_count / self.total_sample_count
                 self.effective_sample_count = 0
                 self.total_sample_count = 0
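
Taken together, the two hunks above add the DAPO-style dynamic-sampling bookkeeping: rollouts keep accumulating until at least 75% of batch_size * dp_size * num_generations samples survive filtering, the bar reports collection progress, and global_step advances only when the optimizer actually steps. A rough self-contained sketch of that logic (the numbers and the on_rollout helper are illustrative, not repo code):

# Illustrative numbers; the threshold mirrors the 0.75 factor used above.
batch_size, dp_size, num_generations = 4, 1, 8
threshold = batch_size * dp_size * num_generations * 0.75  # 24.0 effective samples

effective_sample_count = 0
global_step = 0


def on_rollout(new_effective: int, pbar=None) -> bool:
    """Accumulate effective (non-filtered) samples; True means run an optimizer step."""
    global effective_sample_count, global_step
    effective_sample_count += new_effective
    need_update = effective_sample_count >= threshold
    if pbar is not None:
        pbar.set_postfix({"Step": global_step + 1, "Status": f"Collecting: {effective_sample_count}/{threshold}"})
    if need_update:
        global_step += 1  # counted per optimizer update, not per rollout
        effective_sample_count = 0
    return need_update


assert on_rollout(20) is False  # 20 < 24: keep collecting
assert on_rollout(20) is True   # 40 >= 24: update and reset the counter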

View File

@@ -83,7 +83,7 @@ if __name__ == "__main__":
         inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
         generate_config.update(
             dict(
-                max_tokens=2048,
+                max_tokens=4096,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
@@ -120,8 +120,8 @@ if __name__ == "__main__":
             "beta": 0.0,  # no KL penalty
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
-            "max_length": 1024 * 2,
-            "cache_length": 256,
+            "max_length": 4096,
+            "cache_length": 512,
             "filter_truncated_response": True,
         }
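
The script changes keep the generator and the trainer consistent: max_tokens=4096 lets generation run as long as the trainer's max_length=4096, and cache_length=512 widens the soft over-length window. Assuming the penalty ramps linearly over the last cache_length tokens, as in DAPO's soft overlong punishment (the reward code itself is outside this diff), the effect looks roughly like this:

# Assumed shape of the soft over-length penalty; not the repo's reward code.
MAX_LENGTH = 4096   # was 1024 * 2
CACHE_LENGTH = 512  # was 256


def over_length_penalty(response_len: int) -> float:
    """0 up to max_length - cache_length, then linear down to -1 at max_length."""
    soft_start = MAX_LENGTH - CACHE_LENGTH  # first 3584 tokens are penalty-free
    if response_len <= soft_start:
        return 0.0
    if response_len >= MAX_LENGTH:
        return -1.0
    return -(response_len - soft_start) / CACHE_LENGTH


print(over_length_penalty(3500))  # 0.0
print(over_length_penalty(3840))  # -0.5
print(over_length_penalty(4200))  # -1.0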