Mirror of https://github.com/hpcaitech/ColossalAI.git
[feat] Support DAPO (#6263)
* update help information
* update style
* fix
* minor fix
* support PP training
* add pp support
* remove unused code
* address conversation
* fix memory leakage support tp+pp
* move empty cache
* move empty cache
* add DAPO support
* remove format reward
* fix filtering, still buggy
* small fix
* add DAPO support
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* tested multi-node training; fix bind_batch bug
* fix conversation; support sleep mode
* support reusing excessive samples
* add dynamic batching control flag
* add dynamic batching control flag
* refactored
* fix logging

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, post_recv, unbind_batch
+from .utils import bind_batch, pad_batch, post_recv, unbind_batch
 
 
 class BaseConsumer:
@@ -33,7 +33,7 @@ class BaseConsumer:
         batch_size: int,
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
-        microbatch_size: int = 1,
+        minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
     ):
@@ -46,11 +46,11 @@ class BaseConsumer:
         self.num_update_per_episode = num_update_per_episode
         self.num_recv_per_update = num_recv_per_update
         self.batch_size = batch_size
-        self.microbatch_size = microbatch_size
+        self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
-        assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
-        self.num_microbatches = batch_size // microbatch_size
+        assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
+        self.num_microbatches = batch_size // minibatch_size
 
         self.model_config = model_config
         self.plugin_config = plugin_config
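A quick worked example of the renamed constraint above, with illustrative numbers that are not taken from the commit: the consumer now splits each received batch into batch_size // minibatch_size minibatches and performs one optimizer step per minibatch.

# Illustrative values only (not from the commit): the divisibility assert
# guarantees the batch splits evenly into minibatches.
batch_size = 16
minibatch_size = 4
assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
num_minibatches = batch_size // minibatch_size  # -> 4 optimizer steps per received batch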
@@ -67,7 +67,7 @@ class BaseConsumer:
 
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
-            plugin_config["microbatch_size"] = self.microbatch_size
+            plugin_config["microbatch_size"] = self.minibatch_size
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
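To make the pipeline-parallel branch above concrete, here is a minimal sketch of the config merge with a hypothetical user-supplied plugin config; the specific values are assumptions, not part of the commit.

# Hypothetical user config enabling pipeline parallelism (values are illustrative).
default_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
user_plugin_config = {"pp_size": 2, "tp_size": 2}
minibatch_size = 4

if user_plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in user_plugin_config:
    # With pp enabled and no explicit num_microbatches, the plugin's microbatch
    # size now comes from the consumer's minibatch_size (formerly microbatch_size).
    default_config["microbatch_size"] = minibatch_size
default_config.update(user_plugin_config)
# default_config == {"tp_size": 2, "pp_size": 2, "precision": "bf16",
#                    "zero_stage": 2, "microbatch_size": 4}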
@@ -105,18 +105,26 @@ class BaseConsumer:
                             )
                         )
                     )
-                while len(self.buffer) >= self.dp_size * self.microbatch_size:
+                while len(self.buffer) >= self.dp_size * self.minibatch_size:
                     batches = self.buffer[
-                        self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
+                        self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
                     ]
-                    self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
+                    batch = pad_batch(
+                        batches
+                    )  # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
                     batch = bind_batch(batches)
                     batch = post_recv(batch)
-                    loss = self.step(i, **batch)
+                    loss, num_excessive_prompts = self.step(i, pbar, **batch)
+                    self.buffer = (
+                        self.buffer[
+                            (self.dp_rank + 1) * self.minibatch_size
+                            - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
+                        ]
+                        + self.buffer[self.dp_size * self.minibatch_size :]
+                    )
                     if loss is not None:
                         pbar.set_postfix({"loss": loss})
                     i += 1
             assert len(self.buffer) == 0
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step()
             if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
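The new pad_batch helper is imported from .utils but its implementation is outside this hunk; below is a minimal, hypothetical sketch of what such padding could look like. The field names, pad value, and right-padding direction are assumptions, not the library's actual code.

from typing import Dict, List

import torch
import torch.nn.functional as F


def pad_batch_sketch(batches: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
    # Hypothetical stand-in for pad_batch: right-pad every sequence-shaped
    # tensor so all samples share the longest sequence length, allowing
    # bind_batch to stack them afterwards.
    max_len = max(b["input_ids"].size(-1) for b in batches)
    for b in batches:
        pad = max_len - b["input_ids"].size(-1)
        if pad > 0:
            for key, tensor in b.items():
                if tensor.dim() >= 2:
                    b[key] = F.pad(tensor, (0, pad), value=0)  # 0 is an assumed pad value
    return batches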
@@ -154,7 +162,9 @@ class SimpleConsumer(BaseConsumer):
         batch_size,
         model_config,
         plugin_config,
-        microbatch_size=1,
+        minibatch_size=1,
+        save_interval: int = 100,
+        save_dir="./model",
     ):
         super().__init__(
             num_producers,
@@ -168,7 +178,7 @@ class SimpleConsumer(BaseConsumer):
             batch_size,
             model_config,
             plugin_config,
-            microbatch_size,
+            minibatch_size,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -181,7 +191,7 @@ class SimpleConsumer(BaseConsumer):
         super().setup()
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
 
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         labels = kwargs["input_ids"].clone()
         labels[kwargs["attention_mask"] == 0] = -100
         kwargs["labels"] = labels
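For context on the unchanged label-masking lines above: tokens where attention_mask is 0 become -100, the index that PyTorch's cross-entropy loss ignores, so padding never contributes to the causal-LM loss. A self-contained illustration with made-up tensors:

import torch

# Made-up example tensors; shapes and values are illustrative only.
input_ids = torch.tensor([[5, 6, 7, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

labels = input_ids.clone()
labels[attention_mask == 0] = -100  # CrossEntropyLoss(ignore_index=-100) skips these positions
# labels -> tensor([[   5,    6,    7, -100, -100]])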