fix pp+tp, fix dataloader (#6280)

Authored by YeAnbang on 2025-04-28 17:10:00 +08:00; committed by GitHub
parent 28795f560c
commit 2ca1e3c630
5 changed files with 17 additions and 8 deletions

View File

@@ -66,7 +66,11 @@ class BaseConsumer:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
 
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
-        if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
+        if (
+            self.plugin_config.get("pp_size", 1) > 1
+            and "num_microbatches" not in self.plugin_config
+            and "microbatch_size" not in self.plugin_config
+        ):
             plugin_config["microbatch_size"] = self.minibatch_size
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
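
A note on this guard, as a standalone sketch (the helper function and its direct use are hypothetical; HybridParallelPlugin, the config keys, and the defaulting rule come from the hunk above): with pipeline parallelism enabled, the plugin needs some microbatch setting, and the default should only be injected when the caller supplied neither num_microbatches nor microbatch_size.

def resolve_plugin_config(user_config: dict, minibatch_size: int) -> dict:
    # Defaults mirroring the hunk above; user-supplied keys win via update().
    config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
    if (
        user_config.get("pp_size", 1) > 1
        and "num_microbatches" not in user_config
        and "microbatch_size" not in user_config
    ):
        # Only fall back to one microbatch per minibatch when the caller set
        # no microbatching option at all.
        config["microbatch_size"] = minibatch_size
    config.update(user_config)
    return config

print(resolve_plugin_config({"pp_size": 2}, minibatch_size=4))
# {'tp_size': 1, 'pp_size': 2, 'precision': 'bf16', 'zero_stage': 2, 'microbatch_size': 4}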

View File

@@ -49,6 +49,12 @@ class GRPOConsumer(BaseConsumer):
                 UserWarning,
             )
             minibatch_size = batch_size
+        if (
+            plugin_config.get("pp_size", 1) > 1
+            and "num_microbatches" not in plugin_config
+            and "microbatch_size" not in plugin_config
+        ):
+            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
         super().__init__(
             num_producers,
             num_episodes,
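
One caveat worth flagging in the new default above: grpo_config.get("train_microbatch_size") carries no fallback, so the halving assumes the key is always set (None // 2 raises a TypeError). A defensive sketch of the same heuristic; the fallback value of 2 is an assumption, not the repo's behavior:

def default_pp_microbatch(grpo_config: dict) -> int:
    # Halve the training microbatch for the pipeline schedule, flooring at 1.
    train_mbs = grpo_config.get("train_microbatch_size") or 2  # fallback is hypothetical
    return max(1, train_mbs // 2)

print(default_pp_microbatch({"train_microbatch_size": 8}))  # 4
print(default_pp_microbatch({"train_microbatch_size": 1}))  # 1, kept positive by max(1, ...)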
@@ -373,7 +379,7 @@
                     loss_mask=inputs["loss_mask"],
                     total_effective_tokens_in_batch=total_effective_tokens_count,
                 )
-                return loss, num_excessive_samples // self.num_generations
+                return loss
 
             policy_model_outputs = self.booster.execute_pipeline(
                 iter([data_policy_forward]),
@@ -468,10 +474,10 @@
            sample_utilization = self.effective_sample_count / self.total_sample_count
            self.effective_sample_count = 0
            self.total_sample_count = 0
            loss_scalar = self.accum_loss.item()
-           if not self.plugin.pp_size > 1 or (
+           if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
            ):
@@ -507,7 +513,7 @@
                self.accum_advantages.zero_()
                self.accum_response_length.zero_()
                self.accum_count = 0
-               return loss_scalar, num_excessive_samples // self.num_generations
+           return loss_scalar, num_excessive_samples // self.num_generations
        else:
            return None, num_excessive_samples // self.num_generations
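
Taken together, the two hunks above implement a single rank-gating pattern: exactly one rank emits metrics (rank 0 without pipeline parallelism; tp_rank 0 on the last pipeline stage with it), while the dedented return hands the loss back from every rank instead of leaving non-logging ranks to return an implicit None. A minimal sketch, with print standing in for the metric-logging block and names mirroring the diff:

def maybe_log_and_return(loss_scalar: float, pp_size: int, rank: int,
                         is_last_stage: bool, tp_rank: int) -> float:
    should_log = (pp_size == 1 and rank == 0) or (
        pp_size > 1 and is_last_stage and tp_rank == 0
    )
    if should_log:
        print(f"loss: {loss_scalar:.4f}")  # stand-in for the logging block
    return loss_scalar  # every rank returns, not just the logging one

maybe_log_and_return(0.5, pp_size=2, rank=3, is_last_stage=True, tp_rank=0)   # logs, returns 0.5
maybe_log_and_return(0.5, pp_size=2, rank=1, is_last_stage=False, tp_rank=0)  # returns 0.5 silently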

View File

@@ -33,7 +33,6 @@ def launch_distributed(
     inference_batch_size: int,
     inference_microbatch_size: int,
     train_batch_size: int,
-    train_microbatch_size: int,
     train_minibatch_size: int,
     dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
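
This parameter removal pairs with the matching call-site change in the final hunk below, where train_microbatch_size=args.train_microbatch_size is dropped; updating only one side would fail with a TypeError for an unexpected keyword argument. A toy illustration, with a hypothetical two-argument signature:

def launch(train_batch_size: int, train_minibatch_size: int) -> None:
    print(train_batch_size, train_minibatch_size)

launch(train_batch_size=8, train_minibatch_size=2)  # ok
# launch(train_batch_size=8, train_microbatch_size=4)
# TypeError: launch() got an unexpected keyword argument 'train_microbatch_size'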

View File

@@ -68,6 +68,7 @@ class BaseProducer:
                seed=42,
            ),
            num_workers=4,
+           drop_last=True,
        )
        self.device = get_current_device()
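
The added drop_last=True is the "fix dataloader" half of the commit title: when the dataset length is not a multiple of the batch size, the final short batch would otherwise reach the producer-to-consumer exchange with a smaller first dimension. A toy demonstration (sizes hypothetical):

from torch.utils.data import DataLoader

data = list(range(10))
print([len(b) for b in DataLoader(data, batch_size=4)])                  # [4, 4, 2]
print([len(b) for b in DataLoader(data, batch_size=4, drop_last=True)])  # [4, 4]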
@@ -116,7 +117,6 @@
                    ray_broadcast_tensor_dict(
                        outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                    )
-
                    if (i + 1) % self.num_microbatches == 0 and (
                        episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                    ):
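
For reference, the guard shown here fires on every num_microbatches-th producer step except the very last step of the final episode, where no further rollout would consume updated weights. A runnable trace with toy sizes (all values hypothetical):

num_episodes, num_valid_microbatches, num_microbatches = 2, 4, 2
for episode in range(num_episodes):
    for i in range(num_valid_microbatches):
        if (i + 1) % num_microbatches == 0 and (
            episode != num_episodes - 1 or i != num_valid_microbatches - 1
        ):
            print(f"sync at episode={episode}, step={i}")
# sync at (0, 1), (0, 3) and (1, 1); (1, 3) is skipped.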

View File

@@ -198,7 +198,6 @@ if __name__ == "__main__":
        inference_microbatch_size=args.inference_microbatch_size,
        train_batch_size=args.train_batch_size,
        train_minibatch_size=args.train_minibatch_size,
-       train_microbatch_size=args.train_microbatch_size,
        dataset_config={
            "path": args.dataset,
            "max_length": args.max_prompt_tokens,
@@ -216,7 +215,8 @@
        # currently not support tp/pp
        # plugin_config={
        #     "tp_size": 2,
-       #     "microbatch_size": args.train_microbatch_size // 2,
+       #     "pp_size": 2,
+       #     "microbatch_size": max(1, args.train_microbatch_size // 2),
        #     "zero_stage": 0,
        #     "max_norm": 1.0,
        # }, # for pp
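
A small arithmetic note on the replacement line: the max(1, ...) guard matters when train_microbatch_size is 1, where plain integer halving would request an invalid microbatch size of 0:

for mbs in (1, 2, 8):
    print(mbs, "->", max(1, mbs // 2))  # 1 -> 1, 2 -> 1, 8 -> 4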