Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-23 06:00:44 +00:00)

add DAPO support

parent 9474316132
commit c0c0da2f26
@@ -112,7 +112,7 @@ class BaseConsumer:
                         self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
                         batch = bind_batch(batches)
                         batch = post_recv(batch)
-                        loss = self.step(i, **batch)
+                        loss = self.step(i, pbar, **batch)
                         if loss is not None:
                             pbar.set_postfix({"loss": loss})
                         i += 1
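The only behavioural change here is that the shared tqdm progress bar is now passed into `step`, so consumer subclasses can publish intermediate status (e.g. how many valid samples have been collected) rather than only a final loss. A minimal sketch of the pattern, using an illustrative class that is not part of the repo:

```python
from typing import Any, Optional

from tqdm import tqdm


class ToyConsumer:
    """Illustrative stand-in for a BaseConsumer subclass (hypothetical, not the repo's API)."""

    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
        loss = 0.0  # placeholder for the real forward/backward pass
        # with the bar threaded through, the consumer can surface per-step status itself
        pbar.set_postfix({"Step": step_idx, "loss": loss})
        return loss


if __name__ == "__main__":
    consumer = ToyConsumer()
    with tqdm(range(3)) as pbar:
        for i in pbar:
            consumer.step(i, pbar)
```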
@@ -181,7 +181,7 @@ class SimpleConsumer(BaseConsumer):
         super().setup()
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
 
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         labels = kwargs["input_ids"].clone()
         labels[kwargs["attention_mask"] == 0] = -100
         kwargs["labels"] = labels
@@ -1,7 +1,7 @@
 import json
 import os
 from contextlib import nullcontext
-from typing import Optional
+from typing import Optional, Any
 
 import ray
 import torch
@@ -144,7 +144,7 @@ class GRPOConsumer(BaseConsumer):
         self.reference_model, *_ = self.booster.boost(self.reference_model)
         self.plugin.logger.set_level("ERROR")
 
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         """
         Step data from policy model:
         [{
@@ -212,6 +212,7 @@ class GRPOConsumer(BaseConsumer):
         # update the gradient only once at least 0.75 * batch_size * dp_size * num_generations valid samples have been collected, in case many samples are invalid and get filtered out
         # balance between efficiency and accuracy
         need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
+        pbar.set_postfix({"Step": self.global_step + 1, "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}"})
 
         # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
         ctx = (
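This is the DAPO-style dynamic-sampling gate: microbatches keep arriving, invalid (filtered) samples are discarded, and the optimizer step only fires once the count of surviving samples reaches 0.75 × batch_size × dp_size × num_generations. A hedged sketch of that accumulation logic, with illustrative values rather than the repo's configuration:

```python
# Sketch of the accumulation gate; batch_size/dp_size/num_generations are made-up values.
batch_size, dp_size, num_generations = 4, 2, 8
threshold = batch_size * dp_size * num_generations * 0.75  # 48 effective samples

effective_sample_count = 0
total_sample_count = 0


def on_microbatch(num_valid: int, num_total: int) -> bool:
    """Accumulate counts for a received microbatch; return True when an optimizer step is due."""
    global effective_sample_count, total_sample_count
    effective_sample_count += num_valid
    total_sample_count += num_total
    return effective_sample_count >= threshold
```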
@@ -409,6 +410,7 @@ class GRPOConsumer(BaseConsumer):
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
+            sample_utilization = self.effective_sample_count / self.total_sample_count
             self.effective_sample_count = 0
             self.total_sample_count = 0
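`sample_utilization` is a logging metric: the fraction of generated samples that survived filtering during the window that triggered this optimizer step; both counters are then reset for the next window. In spirit (the divide-by-zero guard is an addition of this sketch, not in the diff):

```python
# computed once per optimizer step, right after global_step is advanced
sample_utilization = effective_sample_count / max(total_sample_count, 1)
effective_sample_count = 0
total_sample_count = 0
```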
@@ -83,7 +83,7 @@ if __name__ == "__main__":
         inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
         generate_config.update(
             dict(
-                max_tokens=2048,
+                max_tokens=4096,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
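The generation budget is doubled to 4096 tokens, and decoding stops on the closing `</answer>` tag, which is kept in the output so the reward function can parse the answer span even though EOS is ignored. Assuming these keys map onto vLLM sampling parameters, the resulting configuration would look roughly like this (a sketch, not copied from the repo):

```python
from vllm import SamplingParams

# illustrative only; temperature/top_p and other fields are omitted
sampling_params = SamplingParams(
    max_tokens=4096,                  # raised from 2048 to leave room for long reasoning traces
    ignore_eos=True,                  # keep sampling past EOS until a stop string or the token budget
    include_stop_str_in_output=True,  # keep "</answer>" so the reward parser sees the full tag
    stop=["</answer>"],
)
```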
@@ -120,8 +120,8 @@ if __name__ == "__main__":
             "beta": 0.0,  # no KL penalty
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
-            "max_length": 1024 * 2,
-            "cache_length": 256,
+            "max_length": 4096,
+            "cache_length": 512,
             "filter_truncated_response": True,
         }
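`soft_over_length_punishment` with `max_length=4096` and `cache_length=512` follows DAPO's soft overlong punishment: responses are unpenalized up to `max_length - cache_length` tokens, then draw a linearly growing penalty over the final `cache_length` tokens, and fully truncated responses can additionally be dropped via `filter_truncated_response`. A hedged sketch of that shaping, not the repo's exact implementation:

```python
def soft_overlong_penalty(response_len: int, max_length: int = 4096, cache_length: int = 512) -> float:
    """Length penalty added to the reward: 0 inside the budget, falling linearly to -1 at max_length."""
    expected = max_length - cache_length  # 3584 tokens with the defaults above
    if response_len <= expected:
        return 0.0
    if response_len >= max_length:
        return -1.0
    return -(response_len - expected) / cache_length


# example: 0.0, -0.5 and -1.0 for lengths well inside, halfway into, and at the cap
print(soft_overlong_penalty(3000), soft_overlong_penalty(3840), soft_overlong_penalty(4096))
```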