diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index a7abb1588..98aa3a3c4 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -6,6 +6,7 @@ import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.profiling_utils import CustomProfiler
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 
@@ -36,6 +37,8 @@ class BaseConsumer:
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -49,6 +52,7 @@ class BaseConsumer:
         self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
+        self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
 
@@ -57,6 +61,7 @@ class BaseConsumer:
 
         self.device = get_current_device()
         self.lr_scheduler = None
+        self.n_behind = n_behind
 
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,6 +99,7 @@ class BaseConsumer:
 
         self.buffer = []
         self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -112,14 +118,17 @@ class BaseConsumer:
                 disable=self.rank != 0,
             ) as pbar:
                 for step in pbar:
+                    torch.cuda.reset_peak_memory_stats()
                     i = 0
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                            self.profiler.enter(f"recv_broadcast_data_P{r}")
                             raw_batch = ray_broadcast_tensor_dict(
                                 None, src=0, device=self.device, group_name=f"sync_data_{r}"
                             )
+                            self.profiler.exit(f"recv_broadcast_data_P{r}")
                             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                             # we need to calculate the metrics before filtering here for logging
                             # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -192,7 +201,9 @@ class BaseConsumer:
                             }
                             batch = bind_batch([t[0] for t in batches])
                             batch = post_recv(batch)
+                            self.profiler.enter("step")
                             loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.profiler.exit("step")
                             self.buffer = self.buffer[
                                 effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
                             ]
@@ -221,13 +232,16 @@ class BaseConsumer:
                         if self.rank == 0:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
-                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                    if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
+                        episode != 0 or step >= self.n_behind
+                    ):
                         if self.pp_size > 1:
                             print(
                                 f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                             )
                         else:
                             print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        self.profiler.enter("sync_model")
                         torch.cuda.empty_cache()
                         state_dict = self.state_dict()
                         if self.pp_size > 1:
@@ -245,6 +259,12 @@ class BaseConsumer:
                             )
                         del state_dict
                         torch.cuda.empty_cache()
+                        self.profiler.exit("sync_model")
+                    self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()
 
 
 @ray.remote
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 5e8f329eb..5d4b95fda 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -38,6 +38,8 @@ class GRPOConsumer(BaseConsumer):
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -63,6 +65,8 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index dc8bf0057..a48246c87 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -57,6 +57,8 @@ def launch_distributed(
     eval_generation_config: Optional[Dict[str, Any]] = None,
     log_rollout_interval: int = 20,
     rollout_save_dir: str = "./rollout",
+    enable_profiling: bool = False,
+    n_behind: int = 0,
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -132,6 +134,8 @@ def launch_distributed(
             wandb_group_name=wandb_group_name,
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         producer_procs.append(producer)
     ray.get([p.setup.remote() for p in producer_procs])
@@ -171,6 +175,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         consumer_procs.append(consumer)
     ray.get([p.setup.remote() for p in consumer_procs])
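
Note: the consumer changes above reset the CUDA peak-memory counter at the top of every training step and log the high-water mark once the optional weight sync is done. A minimal standalone sketch of that bookkeeping (the 1 GiB tensor is an arbitrary stand-in for a training step):

    import torch

    assert torch.cuda.is_available()
    torch.cuda.reset_peak_memory_stats()             # top of the step
    x = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB transient allocation
    del x
    torch.cuda.empty_cache()
    # mirrors the profiler.log(...) call added to the consumer loop
    print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
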
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 854c2fcc2..1b23c463d 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -9,6 +9,7 @@ import torch
 import tqdm
 import wandb
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
+from coati.distributed.profiling_utils import CustomProfiler
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from ray.util.collective import allreduce
@@ -52,6 +53,8 @@ class BaseProducer:
         wandb_group_name: str = None,
         log_rollout_interval: int = 20,
         rollout_log_file: str = "./rollout_log.jsonl",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -62,6 +65,7 @@ class BaseProducer:
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
         self.latest_eval_step = -1
+        self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
 
         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -75,6 +79,7 @@ class BaseProducer:
         self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
         self.grpo_config = grpo_config
+        self.n_behind = n_behind
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
@@ -268,11 +273,14 @@ class BaseProducer:
                             self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                         self.eval_mode = False
                         self.latest_eval_step = self.consumer_global_step
+                self.profiler.enter("rollout")
                 outputs = self.rollout(**batch)
+                self.profiler.exit("rollout")
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
+                self.profiler.enter("calculate_reward")
                 if self.grpo_config["reward_fn_type"] == "code":
                     test_cases = []
                     for prompt_id in range(bs):
@@ -310,20 +318,26 @@ class BaseProducer:
                     outputs.pop("gt_answer")
                 if "test_cases" in outputs:
                     outputs.pop("test_cases")
+                self.profiler.exit("calculate_reward")
 
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs = pre_send(outputs)
+                self.profiler.enter("send_broadcast_data")
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )
-                if (i + 1) % self.num_microbatches == 0 and (
-                    episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
+                self.profiler.exit("send_broadcast_data")
+                if (
+                    (i + 1) % self.num_microbatches == 0
+                    and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
+                    and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
                 ):
                     if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                         "enable_sleep_mode", False
                     ):
                         self.model.llm.sleep()  # revict KV_cache to avoid OOM
                     # don't sync model for last iteration
+                    self.profiler.enter("sync_model")
                     torch.cuda.empty_cache()
 
                     if self.consumer_pp_size > 1:
@@ -349,6 +363,7 @@ class BaseProducer:
                         self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
+                    self.profiler.exit("sync_model")
                     if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                         "enable_sleep_mode", False
                     ):
@@ -364,6 +379,10 @@ class BaseProducer:
                             "temperature"
                         ] + ratio * 0.9
 
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()
+
 
 @ray.remote
 class SimpleProducer(BaseProducer):
@@ -392,6 +410,8 @@ class SimpleProducer(BaseProducer):
         wandb_group_name: str = None,
         log_rollout_interval: int = 20,
         rollout_log_file: str = "./rollout_log.jsonl",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         super().__init__(
             producer_idx,
@@ -415,6 +435,8 @@ class SimpleProducer(BaseProducer):
             wandb_group_name=wandb_group_name,
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
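
Note: the two n_behind conditions are complementary. In episode 0 the producer keeps generating without pulling updated weights until it is n_behind batches ahead, while the consumer skips its first n_behind weight broadcasts as its buffer fills; afterwards both fall back to the usual schedule. A worked sketch of both gates, assuming n_behind=1 and num_microbatches=4:

    # Walkthrough of the episode-0 gating added in this patch (values assumed).
    n_behind, num_microbatches = 1, 4

    for i in range(8):  # producer microbatch index
        pulls_weights = (i + 1) % num_microbatches == 0 and (i + 1) > n_behind * num_microbatches
        print(f"producer i={i}: pull updated weights -> {pulls_weights}")
        # i == 3 -> False (still building the one-batch head start); i == 7 -> True

    for step in range(3):  # consumer update step
        pushes_weights = step >= n_behind
        print(f"consumer step={step}: broadcast weights -> {pushes_weights}")
        # step == 0 -> False (buffer still filling); step >= 1 -> True
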
diff --git a/applications/ColossalChat/coati/distributed/profiling_utils.py b/applications/ColossalChat/coati/distributed/profiling_utils.py
new file mode 100644
index 000000000..4c4621a48
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/profiling_utils.py
@@ -0,0 +1,37 @@
+import os
+import time
+
+
+class CustomProfiler:
+    def __init__(self, name, disabled=True):
+        self.disabled = disabled
+        if not disabled:
+            self.name = name
+            self.pid = os.getpid()
+            self.file = open(f"{name}.prof", "w")
+
+    def _log(self, message):
+        if self.disabled:
+            return
+        current_time = time.time()
+        self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
+    def log(self, message):
+        if self.disabled:
+            return
+        current_time = time.time()
+        self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
+    def enter(self, event_name):
+        self._log(f"Enter {event_name}")
+
+    def exit(self, event_name):
+        self._log(f"Exit {event_name}")
+
+    def close(self):
+        if self.disabled:
+            return
+        self.file.close()
+        print(f"Profiler data written to {self.name}.prof")
diff --git a/applications/ColossalChat/profiling.sh b/applications/ColossalChat/profiling.sh
new file mode 100755
index 000000000..1123f97f2
--- /dev/null
+++ b/applications/ColossalChat/profiling.sh
@@ -0,0 +1,7 @@
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+
+# 8K context length
+rm -rf *.prof
+MAX_NEW_TOKENS=$((8192-512))
+python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1 | tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
+python visualization.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
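
Note: a minimal usage sketch of CustomProfiler as the producer and consumer drive it; the actor name, pid, and timestamps in the comments are illustrative:

    from coati.distributed.profiling_utils import CustomProfiler

    profiler = CustomProfiler("P0", disabled=False)  # opens P0.prof for writing
    profiler.enter("rollout")  # writes e.g. "1718000000.12 P0 12345:: Enter rollout"
    profiler.exit("rollout")   # writes e.g. "1718000002.34 P0 12345:: Exit rollout"
    profiler.log("Peak memory usage: 1024.00 MB")  # "[Log]:"-prefixed; skipped by the timeline parser
    profiler.close()  # prints "Profiler data written to P0.prof"
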
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index c60923e00..da381f8a7 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -67,6 +67,27 @@ if __name__ == "__main__":
         default=2,
         help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-pp",
+        "--pipeline-parallel-size",
+        type=int,
+        default=1,
+        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-zero",
+        "--zero-stage",
+        type=int,
+        default=0,
+        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
     parser.add_argument(
         "--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
     )
@@ -97,6 +118,13 @@ if __name__ == "__main__":
     parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
     parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
     parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
+    parser.add_argument(
+        "-ptp",
+        "--producer-tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+    )
 
     # GRPO parameters
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
@@ -117,6 +145,13 @@ if __name__ == "__main__":
         default=100,
         help="Interval for evaluation. Evaluate every ei training steps.",
     )
+    parser.add_argument(
+        "-nb",
+        "--n-behind",
+        type=int,
+        default=0,
+        help="Number of producer batches to roll out ahead of the trainer to fill the data buffer and reduce bubble time.",
+    )
 
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -128,32 +163,7 @@ if __name__ == "__main__":
         "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
     )
     parser.add_argument(
-        "-tp",
-        "--tensor-parallel-size",
-        type=int,
-        default=1,
-        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-pp",
-        "--pipeline-parallel-size",
-        type=int,
-        default=1,
-        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-zero",
-        "--zero-stage",
-        type=int,
-        default=0,
-        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-ptp",
-        "--producer-tensor-parallel-size",
-        type=int,
-        default=1,
-        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+        "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
     )
 
     args = parser.parse_args()
@@ -353,4 +363,6 @@ if __name__ == "__main__":
         eval_generation_config=eval_generation_config,
         log_rollout_interval=20,
         rollout_save_dir=args.rollout_save_dir,
+        enable_profiling=args.enable_profiling,
+        n_behind=args.n_behind,
     )
diff --git a/applications/ColossalChat/visualization.py b/applications/ColossalChat/visualization.py
new file mode 100644
index 000000000..bde49303f
--- /dev/null
+++ b/applications/ColossalChat/visualization.py
@@ -0,0 +1,99 @@
+# Parse CustomProfiler logs (*.prof) and render a per-actor timeline plot.
+import argparse
+import glob
+from collections import defaultdict
+
+import matplotlib.cm as cm
+import matplotlib.pyplot as plt
+
+# Argument parser for command line arguments
+parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
+parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
+args = parser.parse_args()
+
+# Raw log lines
+log_lines = []
+
+files = glob.glob("*.prof")
+for file in files:
+    with open(file, "r") as f:
+        log_lines += f.readlines()
+
+# Parse logs and collect function intervals grouped by actor
+actors = defaultdict(lambda: defaultdict(list))
+current_entries = {}
+
+# First, collect all timestamps to find the minimum
+all_timestamps = []
+parsed_lines = []
+
+for line in log_lines:
+    if line.startswith("[Log]"):
+        continue
+    parts = line.split()
+    timestamp = float(parts[0])
+    actor = parts[1]
+    action = parts[3]
+    func_name = parts[4]
+    parsed_lines.append((timestamp, actor, action, func_name))
+    all_timestamps.append(timestamp)
+
+if not all_timestamps:
+    raise ValueError("No valid log entries found.")
+
+min_timestamp = min(all_timestamps)
+
+for timestamp, actor, action, func_name in parsed_lines:
+    rel_timestamp = timestamp - min_timestamp
+    key = (actor, func_name)
+    if action == "Enter":
+        current_entries[key] = rel_timestamp
+    elif action == "Exit":
+        start_time = current_entries.pop(key, None)
+        if start_time is not None:
+            actors[actor][func_name].append((start_time, rel_timestamp))
+
+# Plotting setup
+fig, ax = plt.subplots(figsize=(12, 6))
+colors = cm.get_cmap("tab10", len(actors))
+
+actor_offsets = {}
+base_offset = 0
+function_spacing = 0.9
+
+yticks = []
+yticklabels = []
+
+for idx, (actor, func_dict) in enumerate(actors.items()):
+    actor_offsets[actor] = base_offset
+    color = colors(idx)
+    for j, (func, intervals) in enumerate(func_dict.items()):
+        print(actor, func, intervals)
+        y_val = base_offset + j * function_spacing
+        yticks.append(y_val)
+        yticklabels.append(f"{actor}:{func}")
+        for start, end in intervals:
+            if end - start < 0.1:
+                end = start + 0.1  # ensure a minimum visible width of 100 ms
+            ax.plot(
+                [start, end],
+                [y_val, y_val],
+                color=color,
+                linewidth=2,
+                label=actor if j == 0 else "",
+            )
+    base_offset += len(func_dict) * function_spacing + 1
+
+# Formatting
+ax.set_yticks(yticks)
+ax.set_yticklabels(yticklabels)
+ax.set_xlabel("Time (s)")
+ax.set_title("Timeline per Actor")
+# Remove duplicate labels in legend
+handles, labels = ax.get_legend_handles_labels()
+unique = dict(zip(labels, handles))
+ax.legend(unique.values(), unique.keys())
+plt.tight_layout()
+plt.grid(True)
+plt.savefig(args.visualization)
+print(f"Plot saved as {args.visualization}")
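
Note: visualization.py can be smoke-tested without a training run by hand-writing a .prof file in the format CustomProfiler emits; the actor name, pid, and timings below are made up:

    import subprocess
    import time

    # Synthetic trace: one 2-second "rollout" event for a fake producer "P0".
    t0 = time.time()
    with open("P0.prof", "w") as f:
        f.write(f"{t0} P0 12345:: Enter rollout\n")
        f.write(f"{t0 + 2.0} P0 12345:: Exit rollout\n")

    # The script globs *.prof in the working directory and saves the plot.
    subprocess.run(["python", "visualization.py", "--visualization", "smoke_test.png"], check=True)
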