diff --git a/.gitignore b/.gitignore
index 16f764c1b..533450a7c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,3 +164,4 @@ coverage.xml
 applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
+applications/ColossalChat/model
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 027acc2e7..c6ae7be2d 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -54,7 +54,6 @@ class BaseConsumer:
         self.model_config = model_config
         self.plugin_config = plugin_config
-        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
 
         self.device = get_current_device()
 
         self.lr_scheduler = None
@@ -95,7 +94,6 @@ class BaseConsumer:
             i = 0
             for _ in range(self.num_recv_per_update):
                 # receive data from producers
-
                 for r in range(self.num_producers):
                     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                     self.buffer.extend(
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 4174f9651..a282439cb 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -39,6 +39,7 @@ class GRPOConsumer(BaseConsumer):
         use_wandb=True,
         generate_config=None,
         training_config={},
+        project_name=None,
     ):
         super().__init__(
             num_producers,
@@ -69,6 +70,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_count = 0
         self.generate_config = generate_config
         self.training_config = training_config
+        self.project_name = project_name
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -94,9 +96,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
 
-        if use_wandb and self.rank == 0:
-            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
+        self.use_wandb = use_wandb
 
         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -107,10 +107,19 @@
     def setup(self):
         super().setup()
+        if self.use_wandb and (
+            (not self.plugin.pp_size > 1 and self.rank == 0)
+            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
+        ):
+            # Initialize wandb.
+            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
+            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
+
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
         )
         self.reference_model, *_ = self.booster.boost(self.reference_model)
+        self.plugin.logger.set_level("ERROR")
 
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         """
@@ -168,6 +177,7 @@ class GRPOConsumer(BaseConsumer):
                 ).repeat_interleave(self.num_generations, dim=0)
             )
             mean_kl, mean_loss = [], []
+
             for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
@@ -186,112 +196,210 @@ class GRPOConsumer(BaseConsumer):
                 advantages_forward_micro_batch = advantages[
                     forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
                 ]
-                policy_model_logits = self.policy_model(
-                    input_ids=input_ids_forward_micro_batch,
-                    attention_mask=attention_mask_forward_micro_batch,
-                ).logits
-                action_log_probs = calc_action_log_probs(
-                    policy_model_logits / self.generate_config["temperature"],
-                    input_ids_forward_micro_batch,
-                    num_action,
-                    self.plugin.shard_config,
-                )
-                with torch.no_grad():
-                    reference_model_logits = self.reference_model(
+                if self.plugin.pp_size > 1:
+                    # Support training with PP.
+
+                    with torch.no_grad():
+                        reference_model_outputs = self.booster.execute_pipeline(
+                            iter(
+                                [
+                                    {
+                                        "input_ids": input_ids_forward_micro_batch,
+                                        "attention_mask": attention_mask_forward_micro_batch,
+                                    }
+                                ]
+                            ),
+                            self.reference_model,
+                            criterion=lambda outputs, inputs: torch.tensor(
+                                [0.0], device=action_mask.device
+                            ),  # dummy criterion
+                            optimizer=None,
+                            return_loss=False,
+                            return_outputs=True,
+                        )
+
+                    if self.booster.plugin.stage_manager.is_last_stage():
+                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
+                        reference_action_log_probs = calc_action_log_probs(
+                            reference_model_logits / self.generate_config["temperature"],
+                            input_ids_forward_micro_batch,
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+                    else:
+                        # Dummy reference logprobs for data iterator.
+                        reference_action_log_probs = None
+
+                    data_policy_forward = {
+                        "input_ids": input_ids_forward_micro_batch,
+                        "attention_mask": attention_mask_forward_micro_batch,
+                        "action_mask": action_mask_forward_micro_batch,
+                        "reference_action_log_probs": reference_action_log_probs,
+                        "advantages": advantages_forward_micro_batch,
+                        "loss_mask": loss_mask_forward_micro_batch,
+                        "source": self.rank,
+                    }
+
+                    kl = []
+
+                    def _criterion(outputs, inputs):
+                        action_logits = outputs.logits
+                        action_log_probs = calc_action_log_probs(
+                            action_logits / self.generate_config["temperature"],
+                            inputs["input_ids"],
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+                        per_token_kl = (
+                            torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
+                            - (inputs["reference_action_log_probs"] - action_log_probs)
+                            - 1
+                        )
+                        approx_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
+                            inputs["action_mask"], dim=-1
+                        )
+                        kl.append(approx_kl.mean())
+                        loss, skip_update, _ = self.policy_loss_fn(
+                            action_log_probs,
+                            action_log_probs,
+                            inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+                            per_token_kl,
+                            inputs["action_mask"],
+                            loss_mask=inputs["loss_mask"],
+                        )
+                        return loss
+
+                    policy_model_outputs = self.booster.execute_pipeline(
+                        iter([data_policy_forward]),
+                        self.policy_model,
+                        criterion=_criterion,
+                        optimizer=self.optimizer,
+                        return_loss=True,
+                        return_outputs=True,
+                    )
+                    loss = policy_model_outputs["loss"]
+
+                    if self.booster.plugin.stage_manager.is_last_stage():
+                        if len(kl) > 0:
+                            kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                            mean_kl.append(kl)
+                        loss = all_reduce_mean(loss, self.plugin)
+                        mean_loss.append(loss.data)
+                else:
+
+                    policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                    reference_action_log_probs = calc_action_log_probs(
-                        reference_model_logits / self.generate_config["temperature"],
-                        input_ids_forward_micro_batch,
-                        num_action,
-                        self.plugin.shard_config,
-                    )
+                    action_log_probs = calc_action_log_probs(
+                        policy_model_logits / self.generate_config["temperature"],
+                        input_ids_forward_micro_batch,
+                        num_action,
+                        self.plugin.shard_config,
+                    )
-                per_token_kl = (
-                    torch.exp(reference_action_log_probs - action_log_probs)
-                    - (reference_action_log_probs - action_log_probs)
-                    - 1
-                )
-                kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
-                    action_mask_forward_micro_batch, dim=-1
-                )
+                    with torch.no_grad():
+                        reference_model_logits = self.reference_model(
+                            input_ids=input_ids_forward_micro_batch,
+                            attention_mask=attention_mask_forward_micro_batch,
+                        ).logits
+                        reference_action_log_probs = calc_action_log_probs(
+                            reference_model_logits / self.generate_config["temperature"],
+                            input_ids_forward_micro_batch,
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+                    per_token_kl = (
+                        torch.exp(reference_action_log_probs - action_log_probs)
+                        - (reference_action_log_probs - action_log_probs)
+                        - 1
+                    )
+                    kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                        action_mask_forward_micro_batch, dim=-1
+                    )
-                loss, skip_update, _ = self.policy_loss_fn(
-                    action_log_probs,
-                    old_action_log_probs,
-                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
-                    per_token_kl,
-                    action_mask_forward_micro_batch,
-                    loss_mask=loss_mask_forward_micro_batch,
-                )
+                    loss, skip_update, _ = self.policy_loss_fn(
+                        action_log_probs,
+                        old_action_log_probs,
+                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                        per_token_kl,
+                        action_mask_forward_micro_batch,
+                        loss_mask=loss_mask_forward_micro_batch,
+                    )
-                if not skip_update:
-                    self.booster.backward(loss, self.optimizer)
-                loss = all_reduce_mean(loss, self.plugin)
-                kl = all_reduce_mean(kl.mean(), self.plugin)
-                # Calculate accumulate value.
-                mean_kl.append(kl.data)
-                mean_loss.append(loss.data)
-
-            reward = all_reduce_mean(reward.mean(), self.plugin)
-            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
-            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
-            advantages = all_reduce_mean(advantages.mean(), self.plugin)
-            response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
-            self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
-            self.accum_reward.add_(reward.data)
-            self.accum_format_reward.add_(format_reward.data)
-            self.accum_acc_reward.add_(acc_reward.data)
-            self.accum_advantages.add_(advantages.data)
-            self.accum_response_length.add_(response_length.data)
-            self.accum_count += 1
+                    if not skip_update:
+                        self.booster.backward(loss, self.optimizer)
+                    loss = all_reduce_mean(loss, self.plugin)
+                    kl = all_reduce_mean(kl.mean(), self.plugin)
+                    # Calculate accumulate value.
+                    mean_kl.append(kl.data)
+                    mean_loss.append(loss.data)
+            if not self.plugin.pp_size > 1 or (
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+            ):
+                reward = all_reduce_mean(reward.mean(), self.plugin)
+                format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+                acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+                advantages = all_reduce_mean(advantages.mean(), self.plugin)
+                response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
+                self.accum_reward.add_(reward.data)
+                self.accum_format_reward.add_(format_reward.data)
+                self.accum_acc_reward.add_(acc_reward.data)
+                self.accum_advantages.add_(advantages.data)
+                self.accum_response_length.add_(response_length.data)
+                self.accum_count += 1
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
-            loss_scalar = self.accum_loss.item()
-            if self.rank == 0:
-                print(
-                    "Loss:",
-                    self.accum_loss.item() / self.accum_count,
-                    "\nReward:",
-                    self.accum_reward.item() / self.accum_count,
-                    "\nFormat Reward:",
-                    self.accum_format_reward.item() / self.accum_count,
-                    "\nAcc Reward:",
-                    self.accum_acc_reward.item() / self.accum_count,
-                    "\nKL:",
-                    self.accum_kl.item() / self.accum_count,
-                    "\nAdvantages:",
-                    self.accum_advantages.item() / self.accum_count,
-                    "\nResponse Length:",
-                    self.accum_response_length.item() / self.accum_count,
-                )
-                self.wandb_run.log(
-                    {
-                        "metrics/reward": self.accum_reward.item() / self.accum_count,
-                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
-                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
-                        "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/kl": self.accum_kl.item() / self.accum_count,
-                        "train/advantages": self.accum_advantages.item() / self.accum_count,
-                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
-                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
-                    }
-                )
-            self.accum_loss.zero_()
-            self.accum_reward.zero_()
-            self.accum_acc_reward.zero_()
-            self.accum_format_reward.zero_()
-            self.accum_kl.zero_()
-            self.accum_advantages.zero_()
-            self.accum_response_length.zero_()
+            if not self.plugin.pp_size > 1 or (
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+            ):
+                loss_scalar = self.accum_loss.item()
+                if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                ):
+                    print(
+                        "Loss:",
+                        self.accum_loss.item() / self.accum_count,
+                        "\nReward:",
+                        self.accum_reward.item() / self.accum_count,
+                        "\nFormat Reward:",
+                        self.accum_format_reward.item() / self.accum_count,
+                        "\nAcc Reward:",
+                        self.accum_acc_reward.item() / self.accum_count,
+                        "\nKL:",
+                        self.accum_kl.item() / self.accum_count,
+                        "\nAdvantages:",
+                        self.accum_advantages.item() / self.accum_count,
+                        "\nResponse Length:",
+                        self.accum_response_length.item() / self.accum_count,
+                    )
+                    self.wandb_run.log(
+                        {
+                            "metrics/reward": self.accum_reward.item() / self.accum_count,
+                            "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                            "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                            "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+                            "train/loss": self.accum_loss.item() / self.accum_count,
+                            "train/kl": self.accum_kl.item() / self.accum_count,
+                            "train/advantages": self.accum_advantages.item() / self.accum_count,
+                            "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                            "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
+                        }
+                    )
+                self.accum_loss.zero_()
+                self.accum_reward.zero_()
+                self.accum_acc_reward.zero_()
+                self.accum_format_reward.zero_()
+                self.accum_kl.zero_()
+                self.accum_advantages.zero_()
+                self.accum_response_length.zero_()
-            self.accum_count = 0
-            return loss_scalar
+                self.accum_count = 0
+                return loss_scalar
 
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index ba5d3a9d4..699d90a8c 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -47,6 +47,7 @@ def launch_distributed(
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
+    project_name: Optional[str] = None,
 ):
 
     if core_algo not in ALGO_MAP:
@@ -108,6 +109,7 @@ def launch_distributed(
                 "train_microbatch_size": train_microbatch_size,
             },
             num_generations=num_generations,
+            project_name=project_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 4a4a4c340..f87f12ed2 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -10,13 +10,44 @@ if __name__ == "__main__":
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
-    parser.add_argument("-g", "--num-generations", type=int, default=8)
-    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
-    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
-    parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
-    parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
-    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
-    parser.add_argument("-b", "--backend", type=str, default="transformers")
+    parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+    parser.add_argument(
+        "-ibs",
+        "--inference-batch-size",
+        type=int,
+        default=64,
+        help="Number of prompts to generate per inference step. It should be divisible by tbs; the weights on the inference backend are synced every ibs/tbs training steps of the policy model.",
+    )
+    parser.add_argument(
+        "-imbs",
+        "--inference-microbatch-size",
+        type=int,
+        default=8,
+        help="Effective batch size for the inference backend to run generation. Please select based on memory constraints.",
+    )
+    parser.add_argument(
+        "-tbs",
+        "--train-batch-size",
+        type=int,
+        default=32,
+        help="Number of unique prompts used to update the policy model per step per dp group. Gradients are accumulated across tbs * dp_size unique prompts, i.e., tbs * g * dp_size samples.",
+    )
+    parser.add_argument(
+        "-tMbs",
+        "--train-minibatch-size",
+        type=int,
+        default=1,
+        help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before each forward pass. Must satisfy tMbs * g >= tmbs.",
+    )
+    parser.add_argument(
+        "-tmbs",
+        "--train-microbatch-size",
+        type=int,
+        default=2,
+        help="Effective batch size per dp group for the forward and backward passes. Please select based on the available memory.",
+    )
+    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
 
     args = parser.parse_args()
@@ -29,11 +60,7 @@ if __name__ == "__main__":
     ray.init(address="local", namespace="ray-example")
 
     inference_model_config = dict(path=args.model)
-    train_model_config = dict(
-        path=args.model,
-        # use_flash_attention_2=True,
-        # use_cache=False
-    )
+    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
     generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
 
     if args.backend == "transformers":
@@ -91,9 +118,17 @@ if __name__ == "__main__":
         generate_config=generate_config,
         num_generations=args.num_generations,
         train_model_config=train_model_config,
-        plugin_config={},
+        # plugin_config={},  # for zero
+        plugin_config={
+            "pp_size": 2,
+            "tp_size": 1,
+            "microbatch_size": args.train_microbatch_size // 2,
+            "zero_stage": 0,
+            "max_norm": 1.0,
+        },  # for pp
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29505,
+        master_port=29506,
         core_algo=args.algo,
+        project_name=args.project,
     )
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 1684fd702..74349091b 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1411,8 +1411,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
         # run with gradients accumulation
-        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        if (
+            not torch.is_grad_enabled()
+            or model.require_grad_sync == False
+            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
         ):
             return outputs
 
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 71e3557fe..27571309e 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -284,6 +284,7 @@ class Qwen2PipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
+        **kwargs,
     ):
         r"""
         Args:
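The per-token KL term this patch threads through both the pipeline-parallel `_criterion` and the non-PP branch is the exp(d) - d - 1 approximation (the k3 estimator) with d = reference log-prob minus policy log-prob, averaged over response tokens via `action_mask`. A minimal standalone sketch on toy tensors; the shapes and values below are illustrative, not the consumer's real inputs:

```python
import torch

# Toy log-probabilities for 2 sequences of 5 response tokens each.
action_log_probs = torch.randn(2, 5) * 0.1            # current policy
reference_action_log_probs = torch.randn(2, 5) * 0.1  # frozen reference model
action_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 0, 0, 0]], dtype=torch.float)

# k3 estimator used in grpo_consumer.py: exp(d) - d - 1, with d = ref - policy.
per_token_kl = (
    torch.exp(reference_action_log_probs - action_log_probs)
    - (reference_action_log_probs - action_log_probs)
    - 1
)
# Masked mean over the response tokens, as in the diff.
approx_kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
print(approx_kl)  # always >= 0, since exp(d) - d - 1 >= 0 for any d
```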
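The new `rl_example.py` help strings encode how the batch-size flags relate to one another. A small sanity check makes the arithmetic concrete; the numbers are the script's defaults, and `dp_size = 2` is only an assumed example value:

```python
# Relationships described in the new -ibs/-tbs/-tMbs/-tmbs help strings.
ibs, tbs, tMbs, tmbs, g, dp_size = 64, 32, 1, 2, 8, 2  # dp_size is hypothetical

assert ibs % tbs == 0, "ibs should be divisible by tbs"
sync_interval = ibs // tbs                   # inference weights sync every ibs/tbs policy steps
samples_per_update = tbs * g * dp_size       # samples accumulated per optimizer step
samples_per_minibatch = tMbs * g * dp_size   # samples generated before each training minibatch
assert tMbs * g >= tmbs, "each minibatch must cover at least one train micro-batch"

print(sync_interval, samples_per_update, samples_per_minibatch)  # 2 512 16
```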
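The same gating condition appears three times in `grpo_consumer.py` (wandb init, console printing, metric logging): rank 0 owns logging when pipeline parallelism is off, otherwise the last pipeline stage does, since that is where the loss is produced. Restated as a standalone predicate for clarity; the function name and signature are illustrative, not part of the patch:

```python
def should_log(pp_size: int, rank: int, is_last_pipeline_stage: bool) -> bool:
    """Mirror the patch's gating: rank 0 without PP, last pipeline stage with PP."""
    if pp_size > 1:
        return is_last_pipeline_stage
    return rank == 0

# With pp_size=2 only the last stage logs; with pp_size=1 only global rank 0 does.
assert should_log(2, 3, True) and not should_log(2, 0, False)
assert should_log(1, 0, False) and not should_log(1, 1, False)
```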