From d961a5f7251f377856fc2def74dc1de98269fe49 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Tue, 1 Apr 2025 11:24:09 +0800
Subject: [PATCH] support PP training

---
 .../coati/distributed/consumer.py             |   1 -
 .../coati/distributed/grpo_consumer.py        | 166 +++++++++++-------
 applications/ColossalChat/rl_example.py       |  16 +-
 .../booster/plugin/hybrid_parallel_plugin.py  |   6 +-
 colossalai/shardformer/modeling/qwen2.py      |   1 +
 5 files changed, 121 insertions(+), 69 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 027acc2e7..4e1cd1f31 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -54,7 +54,6 @@ class BaseConsumer:
 
         self.model_config = model_config
         self.plugin_config = plugin_config
-        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
 
         self.device = get_current_device()
         self.lr_scheduler = None
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 4174f9651..d05709feb 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -96,7 +96,7 @@ class GRPOConsumer(BaseConsumer):
         self.global_step = 0
         if use_wandb and self.rank == 0:
             name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
 
         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -168,72 +168,120 @@ class GRPOConsumer(BaseConsumer):
                 ).repeat_interleave(self.num_generations, dim=0)
             )
             mean_kl, mean_loss = [], []
-            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
-                input_ids_forward_micro_batch = data["input_ids"][
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                ]
-                attention_mask_forward_micro_batch = data["attention_mask"][
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                ]
-                action_mask_forward_micro_batch = action_mask[
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                ]
-                loss_mask_forward_micro_batch = (
-                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
-                    if loss_mask is not None
-                    else None
-                )
-                advantages_forward_micro_batch = advantages[
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                ]
-                policy_model_logits = self.policy_model(
-                    input_ids=input_ids_forward_micro_batch,
-                    attention_mask=attention_mask_forward_micro_batch,
-                ).logits
-                action_log_probs = calc_action_log_probs(
-                    policy_model_logits / self.generate_config["temperature"],
-                    input_ids_forward_micro_batch,
-                    num_action,
-                    self.plugin.shard_config,
-                )
-
-                with torch.no_grad():
-                    reference_model_logits = self.reference_model(
-                        input_ids=input_ids_forward_micro_batch,
-                        attention_mask=attention_mask_forward_micro_batch,
-                    ).logits
-                    reference_action_log_probs = calc_action_log_probs(
-                        reference_model_logits / self.generate_config["temperature"],
-                        input_ids_forward_micro_batch,
-                        num_action,
-                        self.plugin.shard_config,
-                    )
-
-                per_token_kl = (
-                    torch.exp(reference_action_log_probs - action_log_probs)
-                    - (reference_action_log_probs - action_log_probs)
-                    - 1
-                )
-                kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
-                    action_mask_forward_micro_batch, dim=-1
-                )
-
-                loss, skip_update, _ = self.policy_loss_fn(
-                    action_log_probs,
-                    old_action_log_probs,
-                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
-                    per_token_kl,
-                    action_mask_forward_micro_batch,
-                    loss_mask=loss_mask_forward_micro_batch,
-                )
-
-                if not skip_update:
-                    self.booster.backward(loss, self.optimizer)
-                loss = all_reduce_mean(loss, self.plugin)
-                kl = all_reduce_mean(kl.mean(), self.plugin)
-                # Calculate accumulate value.
-                mean_kl.append(kl.data)
-                mean_loss.append(loss.data)
+
+            if self.plugin.pp_size > 1:
+                # Support training with PP.
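+                # The PP path drives both models through booster.execute_pipeline:
+                # the reference model runs under no_grad with a dummy criterion
+                # (only its logits are needed), its action log-probs are attached
+                # to the batch, and the policy model is then executed with the
+                # real criterion so backward runs inside the pipeline schedule.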
+                data_iter = iter([data])
+                with torch.no_grad():
+                    reference_model_outputs = self.booster.execute_pipeline(
+                        data_iter,
+                        self.reference_model,
+                        criterion=lambda outputs, inputs: outputs.logits.mean(),  # dummy criterion
+                        optimizer=None,
+                        return_loss=False,
+                        return_outputs=True,
+                    )
+
+                if self.booster.plugin.stage_manager.is_last_stage():
+                    reference_model_logits = reference_model_outputs["outputs"]["logits"]
+                    reference_action_log_probs = calc_action_log_probs(
+                        reference_model_logits / self.generate_config["temperature"],
+                        data["input_ids"],
+                        num_action,
+                        self.plugin.shard_config,
+                    )
+                else:
+                    # Dummy reference logprobs for data iterator.
+                    reference_action_log_probs = torch.zeros(
+                        (old_action_log_probs.size(0), old_action_log_probs.size(1))
+                    )
+
+                data["reference_action_log_probs"] = reference_action_log_probs
+
+                data_iter = iter([data])
+
+                def _criterion(outputs, inputs):
+                    pass
+
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.policy_model,
+                    criterion=_criterion,
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+
+                if self.booster.plugin.stage_manager.is_last_stage():
+                    loss = all_reduce_mean(loss, self.plugin)
+                    mean_loss.append(loss.data)
+            else:
+                for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+                    input_ids_forward_micro_batch = data["input_ids"][
+                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    ]
+                    attention_mask_forward_micro_batch = data["attention_mask"][
+                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    ]
+                    action_mask_forward_micro_batch = action_mask[
+                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    ]
+                    loss_mask_forward_micro_batch = (
+                        loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                        if loss_mask is not None
+                        else None
+                    )
+                    advantages_forward_micro_batch = advantages[
+                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    ]
+                    policy_model_logits = self.policy_model(
+                        input_ids=input_ids_forward_micro_batch,
+                        attention_mask=attention_mask_forward_micro_batch,
+                    ).logits
+                    action_log_probs = calc_action_log_probs(
+                        policy_model_logits / self.generate_config["temperature"],
+                        input_ids_forward_micro_batch,
+                        num_action,
+                        self.plugin.shard_config,
+                    )
+
+                    with torch.no_grad():
+                        reference_model_logits = self.reference_model(
+                            input_ids=input_ids_forward_micro_batch,
+                            attention_mask=attention_mask_forward_micro_batch,
+                        ).logits
+                        reference_action_log_probs = calc_action_log_probs(
+                            reference_model_logits / self.generate_config["temperature"],
+                            input_ids_forward_micro_batch,
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+
+                    per_token_kl = (
+                        torch.exp(reference_action_log_probs - action_log_probs)
+                        - (reference_action_log_probs - action_log_probs)
+                        - 1
+                    )
+                    kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                        action_mask_forward_micro_batch, dim=-1
+                    )
+
+                    loss, skip_update, _ = self.policy_loss_fn(
+                        action_log_probs,
+                        old_action_log_probs,
+                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                        per_token_kl,
+                        action_mask_forward_micro_batch,
+                        loss_mask=loss_mask_forward_micro_batch,
+                    )
+
+                    if not skip_update:
+                        self.booster.backward(loss, self.optimizer)
+                    loss = all_reduce_mean(loss, self.plugin)
+                    kl = all_reduce_mean(kl.mean(), self.plugin)
+                    # Calculate accumulate value.
+                    mean_kl.append(kl.data)
+                    mean_loss.append(loss.data)
 
             reward = all_reduce_mean(reward.mean(), self.plugin)
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index bb719a13c..2b6faaa4a 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -31,7 +31,13 @@ if __name__ == "__main__":
         default=1,
         help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
     )
-    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.")
+    parser.add_argument(
+        "-tmbs",
+        "--train-microbatch-size",
+        type=int,
+        default=2,
+        help="Number of samples per device. PP micro batch size when PP is activated.",
+    )
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()
@@ -45,11 +51,7 @@ if __name__ == "__main__":
     ray.init(address="local", namespace="ray-example")
 
     inference_model_config = dict(path=args.model)
-    train_model_config = dict(
-        path=args.model,
-        # use_flash_attention_2=True,
-        # use_cache=False
-    )
+    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
     generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
 
     if args.backend == "transformers":
@@ -107,7 +109,7 @@
         generate_config=generate_config,
         num_generations=args.num_generations,
         train_model_config=train_model_config,
-        plugin_config={},
+        plugin_config={"pp_size": 2, "tp_size": 1, "microbatch_size": 2, "zero_stage": 0},
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=29505,
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 1684fd702..74349091b 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1411,8 +1411,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
         # run with gradients accumulation
-        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        if (
+            not torch.is_grad_enabled()
+            or model.require_grad_sync == False
+            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
         ):
             return outputs
 
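Note on the grpo_consumer.py change earlier in this patch: _criterion is left as a pass stub even though execute_pipeline is called with return_loss=True and outputs["loss"] is read back, so the criterion body is evidently elided here. Below is a rough sketch of what a last-stage GRPO criterion could look like, reusing names from the surrounding step code (calc_action_log_probs, self.policy_loss_fn, old_action_log_probs, num_action) and assuming the extra tensors it needs (advantages, action_mask, loss_mask) are packed into the micro-batch dict the same way reference_action_log_probs is. This is an assumption about the wiring, not part of the patch:

    def _criterion(outputs, inputs):
        # Sketch only: assumes it is defined inside step(), where self,
        # num_action and old_action_log_probs are in scope, and that
        # advantages / action_mask / loss_mask were added to data before
        # building the data iterator, just like reference_action_log_probs.
        action_log_probs = calc_action_log_probs(
            outputs.logits / self.generate_config["temperature"],
            inputs["input_ids"],
            num_action,
            self.plugin.shard_config,
        )
        per_token_kl = (
            torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
            - (inputs["reference_action_log_probs"] - action_log_probs)
            - 1
        )
        loss, _, _ = self.policy_loss_fn(
            action_log_probs,
            old_action_log_probs,
            inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
            per_token_kl,
            inputs["action_mask"],
            loss_mask=inputs.get("loss_mask"),
        )
        return loss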
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 71e3557fe..27571309e 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -284,6 +284,7 @@ class Qwen2PipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
+        **kwargs,
     ):
         r"""
         Args:
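The one-line qwen2.py change above is what lets extra per-sample tensors ride through the pipeline: each micro-batch dict is fed to the stage forward as keyword arguments, so a key such as reference_action_log_probs (added to the batch in grpo_consumer.py purely for the loss criterion) has to be accepted silently instead of raising a TypeError. A minimal, self-contained illustration of the pattern, with hypothetical names rather than the actual shardformer code:

    # Illustration only: a stage forward that tolerates extra micro-batch keys.
    def stage_forward(input_ids=None, attention_mask=None, **kwargs):
        # Keys such as "reference_action_log_probs" travel in the micro-batch
        # solely for the criterion; the forward simply ignores them.
        return {"logits": input_ids}

    micro_batch = {
        "input_ids": [[1, 2, 3]],
        "attention_mask": [[1, 1, 1]],
        "reference_action_log_probs": [[0.0, 0.0, 0.0]],  # ride-along tensor
    }
    outputs = stage_forward(**micro_batch)  # no TypeError despite the extra key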