[feat] Support DAPO (#6263)

* update help information

* update style

* fix

* minor fix

* support PP training

* add pp support

* remove unused code

* address conversation

* fix memory leakage; support tp+pp

* move empty cache

* move empty cache

* add DAPO support

* remove format reward

* fix filtering, still buggy

* small fix

* add DAPO support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tested multi-node training; fix bind_batch bug

* fix conversation; support sleep mode

* support reusing excessive samples

* add dynamic batching control flag

* add dynamic batching control flag

* refactored

* fix logging

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: YeAnbang
Date: 2025-04-25 17:39:17 +08:00
Committed by: GitHub
Parent: b823c6eec7
Commit: 26d859f68e
10 changed files with 552 additions and 359 deletions

@@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, post_recv, unbind_batch
+from .utils import bind_batch, pad_batch, post_recv, unbind_batch


 class BaseConsumer:
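
The first hunk pulls a new pad_batch helper into the consumer; its implementation is not shown here. Below is a minimal sketch of what such a helper might do, assuming each sample is a dict of tensors padded along the last dimension. The name pad_batch_sketch, the key names, and the padding defaults are illustrative assumptions, not the actual .utils code:

    import torch
    from typing import Dict, List

    def pad_batch_sketch(
        batches: List[Dict[str, torch.Tensor]], pad_token_id: int = 0
    ) -> List[Dict[str, torch.Tensor]]:
        """Pad every tensor to the longest "input_ids" length in `batches` (assumed layout)."""
        max_len = max(b["input_ids"].size(-1) for b in batches)
        for b in batches:
            pad = max_len - b["input_ids"].size(-1)
            if pad > 0:
                for k, v in b.items():
                    # right-pad the last dim: token ids with pad_token_id, all other
                    # tensors (e.g. attention_mask) with 0
                    fill = pad_token_id if k == "input_ids" else 0
                    b[k] = torch.nn.functional.pad(v, (0, pad), value=fill)
        return batches

Padding in place and also returning the list would explain why the diff below can ignore pad_batch's return value and still stack the padded samples afterwards.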
@@ -33,7 +33,7 @@ class BaseConsumer:
         batch_size: int,
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
-        microbatch_size: int = 1,
+        minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
     ):
@@ -46,11 +46,11 @@ class BaseConsumer:
         self.num_update_per_episode = num_update_per_episode
         self.num_recv_per_update = num_recv_per_update
         self.batch_size = batch_size
-        self.microbatch_size = microbatch_size
+        self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
-        assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
-        self.num_microbatches = batch_size // microbatch_size
+        assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
+        self.num_microbatches = batch_size // minibatch_size

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -67,7 +67,7 @@ class BaseConsumer:
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
-            plugin_config["microbatch_size"] = self.microbatch_size
+            plugin_config["microbatch_size"] = self.minibatch_size
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
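
The rename separates two ideas that previously shared one name: minibatch_size is the number of samples the consumer pulls from its buffer per training step, while microbatch_size remains the HybridParallelPlugin's pipeline-parallel microbatch, which now defaults to the minibatch size when pp_size > 1 and no explicit num_microbatches is given. A toy illustration of the arithmetic, with made-up numbers:

    # Toy illustration of the divisibility contract; all numbers are made up.
    batch_size = 32      # samples per global update
    minibatch_size = 8   # samples pulled from the buffer per consumer step
    assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
    num_microbatches = batch_size // minibatch_size  # -> 4 minibatches per update

    # With pipeline parallelism and no explicit num_microbatches, the plugin's
    # pipeline microbatch defaults to the training minibatch size:
    plugin_config = {"pp_size": 2}
    if plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in plugin_config:
        plugin_config["microbatch_size"] = minibatch_size
    print(num_microbatches, plugin_config)  # 4 {'pp_size': 2, 'microbatch_size': 8}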
@@ -105,18 +105,26 @@ class BaseConsumer:
)
)
)
while len(self.buffer) >= self.dp_size * self.microbatch_size:
while len(self.buffer) >= self.dp_size * self.minibatch_size:
batches = self.buffer[
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
]
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
batch = pad_batch(
batches
) # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
batch = bind_batch(batches)
batch = post_recv(batch)
loss = self.step(i, **batch)
loss, num_excessive_prompts = self.step(i, pbar, **batch)
self.buffer = (
self.buffer[
(self.dp_rank + 1) * self.minibatch_size
- num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
]
+ self.buffer[self.dp_size * self.minibatch_size :]
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
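
This hunk is the core of the excessive-sample reuse: step now reports how many prompts in the minibatch were excessive (for example, over-collected by dynamic batching), and those samples are kept at the head of the buffer instead of being discarded with the consumed slice, so the old assertion that the buffer drains to empty no longer holds. A self-contained sketch of the slicing, with plain integers standing in for samples and made-up rank/size values:

    # Plain-list illustration of the buffer-reuse slicing; integers stand in
    # for samples, and dp_rank/dp_size/minibatch_size values are made up.
    dp_rank, dp_size, minibatch_size = 0, 2, 4
    buffer = list(range(10))  # 10 queued samples

    # This rank trains on its slice of the buffer:
    batches = buffer[dp_rank * minibatch_size : (dp_rank + 1) * minibatch_size]  # [0, 1, 2, 3]

    # Suppose step() reports 2 excessive prompts: the last 2 samples of this
    # rank's slice are pushed back, ahead of everything beyond all ranks'
    # slices (rank 1 consumes buffer[4:8] via its own copy of this logic).
    num_excessive_prompts = 2
    buffer = (
        buffer[(dp_rank + 1) * minibatch_size - num_excessive_prompts : (dp_rank + 1) * minibatch_size]
        + buffer[dp_size * minibatch_size :]
    )
    print(batches, buffer)  # [0, 1, 2, 3] [2, 3, 8, 9]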
@@ -154,7 +162,9 @@ class SimpleConsumer(BaseConsumer):
         batch_size,
         model_config,
         plugin_config,
-        microbatch_size=1,
+        minibatch_size=1,
+        save_interval: int = 100,
+        save_dir="./model",
     ):
         super().__init__(
             num_producers,
@@ -168,7 +178,7 @@ class SimpleConsumer(BaseConsumer):
             batch_size,
             model_config,
             plugin_config,
-            microbatch_size,
+            minibatch_size,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -181,7 +191,7 @@ class SimpleConsumer(BaseConsumer):
         super().setup()
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         labels = kwargs["input_ids"].clone()
         labels[kwargs["attention_mask"] == 0] = -100
         kwargs["labels"] = labels
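
Consumer step implementations now receive the progress bar so they can report their own statistics, and, since the BaseConsumer loop above unpacks two return values, a consumer used with that loop is expected to hand back the number of excessive prompts alongside the loss. A hedged sketch of a conforming override follows; the class name and body are illustrative placeholders, not the PR's actual GRPO/DAPO consumer:

    from typing import Any, Optional, Tuple

    class MyDAPOConsumer:  # illustrative stand-in for a BaseConsumer subclass
        def step(self, step_idx: int, pbar: Any, **kwargs) -> Tuple[Optional[float], int]:
            """Sketch of the interface the BaseConsumer loop now expects."""
            # ... forward/backward and optimizer logic would go here ...
            loss = 0.0                 # placeholder value
            num_excessive_prompts = 0  # how many tail samples to return to the buffer
            return loss, num_excessive_prompts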