diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 8a2221b92..026056783 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -120,24 +120,85 @@ class BaseConsumer: raw_batch = unbind_batch( ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}") ) - processed_batch = [ - self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch - ] - filtered_batch = [t for t in processed_batch if t is not None] + recv_effective_count = 0 + # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete), + # we need to calculate the metrics before filtering here for logging + for group in raw_batch: + group_with_reward = self.calculate_group_reward(group) + group_reward_mean = group_with_reward["reward"].mean().cpu().item() + group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item() + group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item() + group_response_len = ( + ( + group_with_reward["response_idx"][:, 1] + - group_with_reward["response_idx"][:, 0] + + 1 + ) + .type(torch.float32) + .mean() + .cpu() + .item() + ) + filtered_group = self.prompt_level_filtering(group_with_reward) + recv_effective_count += 1 if filtered_group is not None else 0 + self.buffer.append( + [ + filtered_group, + group_reward_mean, + group_format_acc_mean, + group_ans_acc_mean, + group_response_len, + ] + ) if self.filter_range is not None: print( - f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}" + f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {recv_effective_count}" ) + # mapping the effective group to the raw group for indexing + effective_group_to_raw_group_mapping = {} + for buffer_idx in range(len(self.buffer)): + if self.buffer[buffer_idx][0] is not None: + effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( + buffer_idx + ) - self.buffer.extend(filtered_batch) - while len(self.buffer) >= self.dp_size * self.minibatch_size: - batches = self.buffer[ - self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size + while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size: + # on each dp_rank, we use minibatch_size effective samples to form a batch + batches = [ + self.buffer[effective_group_to_raw_group_mapping[i]] + for i in range( + self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size + ) ] - batch = bind_batch(batches) + # every dp_rank will receive a complete mini-batch, no need to sync within step() later + # each mini-batch use the first self.dp_size * minibatch_size effective samples + raw_mini_batches = self.buffer[ + : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 + ] # include the last effective sample + raw_mini_batches_metric_dict = { + "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches], + "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches], + "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches], + "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches], + } + batch = bind_batch([t[0] for t in batches]) batch = post_recv(batch) - loss = self.step(i, pbar, **batch) - self.buffer = self.buffer[self.dp_size * self.minibatch_size :] + loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict) + self.buffer = self.buffer[ + effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 : + ] + # recalculate the effective group to raw group mapping + effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping) + effective_group_to_raw_group_mapping = {} + for buffer_idx in range(len(self.buffer)): + if self.buffer[buffer_idx][0] is not None: + effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( + buffer_idx + ) + assert ( + len(effective_group_to_raw_group_mapping) + == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size + ) if loss is not None: pbar.set_postfix({"loss": loss}) i += 1 diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index bc2bd45d1..369023977 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -72,12 +72,12 @@ class GRPOConsumer(BaseConsumer): self.policy_model.gradient_checkpointing_enable() self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, device=self.device) - self.accum_reward = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) - self.accum_format_acc = torch.zeros(1, device=self.device) - self.accum_ans_acc = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device) - self.accum_response_length = torch.zeros(1, device=self.device) + self.raw_train_batch_reward = [] + self.raw_train_batch_format_acc = [] + self.raw_train_batch_ans_acc = [] + self.raw_train_batch_response_len = [] self.accum_count = 0 self.generate_config = generate_config self.grpo_config = grpo_config @@ -186,7 +186,11 @@ class GRPOConsumer(BaseConsumer): [minibatch_size, num_of_generation, prompt_length + response_length] --- ............. """ # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length] - data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()} + data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k} + self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"]) + self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"]) + self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"]) + self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"]) action_mask = data["action_mask"] num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] @@ -430,11 +434,7 @@ class GRPOConsumer(BaseConsumer): self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) if self.policy_loss_fn.beta > 0: self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) - self.accum_reward.add_(reward.data) - self.accum_format_acc.add_(format_acc.data) - self.accum_ans_acc.add_(ans_acc.data) self.accum_advantages.add_(advantages.data) - self.accum_response_length.add_(response_length.data) self.accum_count += 1 if need_update: self.optimizer.step() @@ -452,21 +452,33 @@ class GRPOConsumer(BaseConsumer): if (not self.plugin.pp_size > 1 and self.rank == 0) or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): + raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward) + raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len( + self.raw_train_batch_format_acc + ) + raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc) + raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len( + self.raw_train_batch_response_len + ) + self.raw_train_batch_reward = [] + self.raw_train_batch_format_acc = [] + self.raw_train_batch_ans_acc = [] + self.raw_train_batch_response_len = [] to_log_msg = [ f"Loss: {self.accum_loss.item() / self.accum_count:.4f}", - f"Reward: {self.accum_reward.item() / self.accum_count:.4f}", - f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}", - f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}", + f"Reward: {raw_batch_reward_mean:.4f}", + f"format Reward: {raw_batch_format_acc_mean:.4f}", + f"Acc Reward: {raw_batch_ans_acc_mean:.4f}", f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}", - f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}", + f"Response Length: {raw_batch_response_len_mean:.4f}", f"Sample_utilization: {sample_utilization:.4f}", ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else []) print("\n".join(to_log_msg)) metrics = { - "metrics/reward": self.accum_reward.item() / self.accum_count, - "metrics/format_acc": self.accum_format_acc.item() / self.accum_count, - "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count, - "metrics/response_length": self.accum_response_length.item() / self.accum_count, + "metrics/reward": raw_batch_reward_mean, + "metrics/format_acc": raw_batch_format_acc_mean, + "metrics/ans_acc": raw_batch_ans_acc_mean, + "metrics/response_length": raw_batch_response_len_mean, "train/loss": self.accum_loss.item() / self.accum_count, "train/advantages": self.accum_advantages.item() / self.accum_count, "train/learning_rate": self.lr_scheduler.get_last_lr()[0], @@ -478,12 +490,8 @@ class GRPOConsumer(BaseConsumer): if self.wandb_run is not None: self.wandb_run.log(metrics) self.accum_loss.zero_() - self.accum_reward.zero_() - self.accum_ans_acc.zero_() - self.accum_format_acc.zero_() self.accum_kl.zero_() self.accum_advantages.zero_() - self.accum_response_length.zero_() self.accum_count = 0 return loss_scalar else: diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 6eeb5d379..220a47eb8 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -1,4 +1,5 @@ import copy +import os import uuid from typing import Any, Dict, Optional @@ -56,7 +57,7 @@ def launch_distributed( eval_save_dir: Optional[str] = None, eval_generation_config: Optional[Dict[str, Any]] = None, log_rollout_interval: int = 20, - rollout_log_file: str = "./rollout_log.jsonl", + rollout_save_dir: str = "./rollout", ): if core_algo not in ALGO_MAP: raise NotImplementedError(f"{core_algo} is not supported yet.") @@ -74,6 +75,10 @@ def launch_distributed( run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" wandb_group_name = str(uuid.uuid4()) + rollout_log_file = os.path.join( + rollout_save_dir, + f"{project_name}_run_{wandb_group_name}.jsonl", + ) procs = [] for i in range(num_producers): diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index a97c2b210..53f166d66 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -121,6 +121,34 @@ if __name__ == "__main__": parser.add_argument( "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings." ) + parser.add_argument( + "-tp", + "--tensor-parallel-size", + type=int, + default=1, + help="Tensor parallel size for the inference backend. Please check the generation arguments documentation for your backend.", + ) + parser.add_argument( + "-pp", + "--pipeline-parallel-size", + type=int, + default=1, + help="Pipeline parallel size for the inference backend. Please check the generation arguments documentation for your backend.", + ) + parser.add_argument( + "-zero", + "--zero-stage", + type=int, + default=0, + help="Zero stage for the inference backend. Please check the generation arguments documentation for your backend.", + ) + parser.add_argument( + "-ptp", + "--produce-tensor-parallel-size", + type=int, + default=1, + help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.", + ) args = parser.parse_args() if args.train_minibatch_size is None: @@ -178,7 +206,7 @@ if __name__ == "__main__": enforce_eager=True, enable_chunked_prefill=True, max_model_len=args.max_new_tokens + args.max_prompt_tokens, - tensor_parallel_size=1, + tensor_parallel_size=args.produce_tensor_parallel_size, ) ) generate_config.update( @@ -228,7 +256,7 @@ if __name__ == "__main__": launch_distributed( num_producers=args.num_inferencer, - num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1), + num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.produce_tensor_parallel_size), num_consumer_procs=args.num_trainers, num_episodes=args.num_episodes, inference_batch_size=args.inference_batch_size, @@ -247,17 +275,14 @@ if __name__ == "__main__": train_model_config=train_model_config, grpo_config=grpo_config, plugin_config={ - "zero_stage": 2, - }, # for zero - # plugin_config={ - # "tp_size": 2, - # "pp_size": 2, - # "microbatch_size": max( - # 1, args.train_microbatch_size // 2 - # ), # microbatch size should be set to train_microbatch_size // pp_size - # "zero_stage": 0, - # "max_norm": 1.0, - # }, # for pp, tp + "tp_size": args.tensor_parallel_size, + "pp_size": args.pipeline_parallel_size, + "microbatch_size": max( + 1, args.train_microbatch_size // args.pipeline_parallel_size + ), # microbatch size should be set to train_microbatch_size // pp_size + "zero_stage": args.zero_stage, + "max_norm": 1.0, + }, # for pp, tp inference_backend=args.backend, master_addr="localhost", master_port=args.master_port, @@ -273,5 +298,5 @@ if __name__ == "__main__": eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")), eval_generation_config=eval_generation_config, log_rollout_interval=20, - rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"), + rollout_save_dir=args.rollout_save_dir, )