Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-15 16:12:37 +00:00)

fix pp+tp, fix dataloader (#6280)

Parent: 28795f560c
Commit: 2ca1e3c630
@@ -66,7 +66,11 @@ class BaseConsumer:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
 
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
-        if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
+        if (
+            self.plugin_config.get("pp_size", 1) > 1
+            and "num_microbatches" not in self.plugin_config
+            and "microbatch_size" not in self.plugin_config
+        ):
             plugin_config["microbatch_size"] = self.minibatch_size
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
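HybridParallelPlugin needs a microbatch schedule (either num_microbatches or microbatch_size) whenever pp_size > 1; the widened guard above fills in a default only when the caller supplied neither key, instead of triggering on the absence of num_microbatches alone. A minimal standalone sketch of this resolution order, with hypothetical names (resolve_plugin_config, user_config) in place of the consumer's attributes:

from typing import Any, Dict


def resolve_plugin_config(user_config: Dict[str, Any], minibatch_size: int) -> Dict[str, Any]:
    # The consumer's defaults: no tp/pp, bf16, ZeRO stage 2.
    config: Dict[str, Any] = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
    # Pipeline parallelism needs a microbatch schedule; fill in a default
    # only when the caller set neither num_microbatches nor microbatch_size.
    if (
        user_config.get("pp_size", 1) > 1
        and "num_microbatches" not in user_config
        and "microbatch_size" not in user_config
    ):
        config["microbatch_size"] = minibatch_size
    # Caller-supplied keys always override the defaults.
    config.update(user_config)
    return config


# A caller-provided microbatch_size is left untouched by the default:
print(resolve_plugin_config({"pp_size": 2, "microbatch_size": 4}, minibatch_size=16))
# {'tp_size': 1, 'pp_size': 2, 'precision': 'bf16', 'zero_stage': 2, 'microbatch_size': 4}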
@@ -49,6 +49,12 @@ class GRPOConsumer(BaseConsumer):
                 UserWarning,
             )
             minibatch_size = batch_size
+        if (
+            plugin_config.get("pp_size", 1) > 1
+            and "num_microbatches" not in plugin_config
+            and "microbatch_size" not in plugin_config
+        ):
+            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
         super().__init__(
             num_producers,
             num_episodes,
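Here the GRPO consumer derives its pipeline microbatch size by halving train_microbatch_size with a floor of 1. Note that grpo_config.get("train_microbatch_size") with no default yields None when the key is missing, which would fail on //, so the key is presumably guaranteed upstream. The arithmetic in isolation (derive_pp_microbatch_size is an illustrative name, not from the repo):

def derive_pp_microbatch_size(train_microbatch_size: int) -> int:
    # Halve the training microbatch for the pipeline schedule, but never
    # drop below 1, so the schedule stays valid for tiny batch sizes.
    return max(1, train_microbatch_size // 2)


assert derive_pp_microbatch_size(8) == 4
assert derive_pp_microbatch_size(1) == 1  # 1 // 2 == 0, floored back to 1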
@@ -373,7 +379,7 @@ class GRPOConsumer(BaseConsumer):
                         loss_mask=inputs["loss_mask"],
                         total_effective_tokens_in_batch=total_effective_tokens_count,
                     )
-                    return loss, num_excessive_samples // self.num_generations
+                    return loss
 
             policy_model_outputs = self.booster.execute_pipeline(
                 iter([data_policy_forward]),
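On the pipeline path, the schedule driven by execute_pipeline ultimately backpropagates through whatever the loss callback returns, so it must hand back a bare loss tensor; the (loss, sample count) tuple only works on the non-pipeline path, which is why the pp branch now returns loss alone. A toy illustration of the failure mode (this is plain PyTorch behaviour, not ColossalAI's API):

import torch


def criterion_tuple(logits: torch.Tensor):
    return logits.sum(), 42  # loss plus extra bookkeeping


def criterion_scalar(logits: torch.Tensor) -> torch.Tensor:
    return logits.sum()


logits = torch.randn(2, 3, requires_grad=True)
try:
    criterion_tuple(logits).backward()  # tuples have no .backward()
except AttributeError as exc:
    print("tuple return breaks backward:", exc)
criterion_scalar(logits).backward()  # fine: gradients flow into logits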
@@ -468,10 +474,10 @@ class GRPOConsumer(BaseConsumer):
             sample_utilization = self.effective_sample_count / self.total_sample_count
             self.effective_sample_count = 0
             self.total_sample_count = 0
-            loss_scalar = self.accum_loss.item()
-            if not self.plugin.pp_size > 1 or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
-            ):
+            loss_scalar = self.accum_loss.item()
+            if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+            ):
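The reworked condition gates metric reporting to exactly one rank: without pipeline parallelism every rank holds the loss, so only global rank 0 reports (the old condition let every rank through); with pipeline parallelism only the last stage holds the loss, and tp rank 0 is picked among that stage's tensor-parallel replicas. The predicate in isolation (should_emit_metrics is an illustrative name; the real code reads these values from the plugin and booster):

def should_emit_metrics(pp_size: int, is_last_stage: bool, tp_rank: int, global_rank: int) -> bool:
    if pp_size <= 1:
        # Loss is replicated across data/tensor-parallel ranks;
        # let exactly one of them (global rank 0) report it.
        return global_rank == 0
    # Under pipeline parallelism the loss only exists on the last stage;
    # among that stage's tensor-parallel replicas, pick tp rank 0.
    return is_last_stage and tp_rank == 0


assert should_emit_metrics(pp_size=1, is_last_stage=False, tp_rank=0, global_rank=0)
assert not should_emit_metrics(pp_size=1, is_last_stage=False, tp_rank=0, global_rank=2)
assert should_emit_metrics(pp_size=2, is_last_stage=True, tp_rank=0, global_rank=3)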
@@ -507,7 +513,7 @@ class GRPOConsumer(BaseConsumer):
                 self.accum_advantages.zero_()
                 self.accum_response_length.zero_()
                 self.accum_count = 0
-            return loss_scalar, num_excessive_samples // self.num_generations
+            return loss_scalar, num_excessive_samples // self.num_generations
         else:
             return None, num_excessive_samples // self.num_generations
@@ -33,7 +33,6 @@ def launch_distributed(
     inference_batch_size: int,
     inference_microbatch_size: int,
     train_batch_size: int,
-    train_microbatch_size: int,
     train_minibatch_size: int,
     dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
@@ -68,6 +68,7 @@ class BaseProducer:
                 seed=42,
             ),
             num_workers=4,
+            drop_last=True,
         )
         self.device = get_current_device()
 
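drop_last=True makes the producer's DataLoader discard a final incomplete batch, so every iteration yields exactly batch_size prompts and the fixed microbatch accounting downstream never sees a ragged batch. A minimal standalone illustration (not the repo's loader):

from torch.utils.data import DataLoader

data = list(range(10))  # 10 samples with batch_size 4: a ragged 2-sample tail
loader = DataLoader(data, batch_size=4, drop_last=True)
print([batch.tolist() for batch in loader])  # [[0, 1, 2, 3], [4, 5, 6, 7]]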
@@ -116,7 +117,6 @@ class BaseProducer:
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )
-
                 if (i + 1) % self.num_microbatches == 0 and (
                     episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                 ):
@@ -198,7 +198,6 @@ if __name__ == "__main__":
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
-        train_microbatch_size=args.train_microbatch_size,
         dataset_config={
             "path": args.dataset,
             "max_length": args.max_prompt_tokens,
@@ -216,7 +215,8 @@ if __name__ == "__main__":
         # currently not support tp/pp
         # plugin_config={
         #     "tp_size": 2,
-        #     "microbatch_size": args.train_microbatch_size // 2,
+        #     "pp_size": 2,
+        #     "microbatch_size": max(1, args.train_microbatch_size // 2),
         #     "zero_stage": 0,
         #     "max_norm": 1.0,
         # },  # for pp
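For reference, the commented-out block above would expand to something like the following once tp/pp support lands in the example script; the values are illustrative, with train_microbatch_size standing in for args.train_microbatch_size:

train_microbatch_size = 8  # stand-in for args.train_microbatch_size

plugin_config = {
    "pp_size": 2,
    "microbatch_size": max(1, train_microbatch_size // 2),
    "zero_stage": 0,
    "max_norm": 1.0,
}
print(plugin_config)  # {'pp_size': 2, 'microbatch_size': 4, 'zero_stage': 0, 'max_norm': 1.0}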