From b1f646c7e75552c8b9a87d21c55f0c97cf1f7e7e Mon Sep 17 00:00:00 2001
From: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Date: Mon, 30 Jun 2025 13:21:08 +0800
Subject: [PATCH] [feat] Support one-behind to reduce bubble time. Add
 profiling code (#6353)

* support n_behind, add profiling

* fix bugs

* fix visualization

* fix behind

* fix loop issue

* add profiling

* fix update

* update assert

* remove assert

---------

Co-authored-by: Tong Li
---
 .../coati/distributed/consumer.py             | 171 +++++++++++++-----
 .../coati/distributed/grpo_consumer.py        |   4 +
 .../ColossalChat/coati/distributed/launch.py  |   6 +
 .../coati/distributed/producer.py             |  27 ++-
 .../coati/distributed/profiling_utils.py      |  37 ++++
 applications/ColossalChat/profiling.sh        |  13 ++
 applications/ColossalChat/rl_example.py       |  89 +++++----
 applications/ColossalChat/visualization.py    | 100 ++++++++++
 8 files changed, 365 insertions(+), 82 deletions(-)
 create mode 100644 applications/ColossalChat/coati/distributed/profiling_utils.py
 create mode 100755 applications/ColossalChat/profiling.sh
 create mode 100644 applications/ColossalChat/visualization.py

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index a7abb1588..e360392e7 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -6,6 +6,7 @@ import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.profiling_utils import CustomProfiler
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 
@@ -36,6 +37,8 @@ class BaseConsumer:
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -49,6 +52,7 @@ class BaseConsumer:
         self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
+        self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
         self.num_microbatches = batch_size // minibatch_size
 
@@ -57,6 +61,7 @@ class BaseConsumer:
         self.device = get_current_device()
         self.lr_scheduler = None
+        self.n_behind = n_behind
 
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,6 +99,7 @@ class BaseConsumer:
 
         self.buffer = []
         self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -101,6 +107,41 @@ class BaseConsumer:
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         raise NotImplementedError
 
+    def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
+        """
+        Prepare a training mini-batch from the effective-group-to-raw-group mapping.
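+
+        Args:
+            effective_group_to_raw_group_mapping: maps the i-th effective
+                (non-filtered) group to its index in the raw replay buffer.
+
+        Returns:
+            A tuple (mini-batch dict bound for this dp rank, dict of raw
+            per-sample metric lists used for logging).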
+ """ + batches = [ + self.buffer[effective_group_to_raw_group_mapping[i]] + for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size) + ] + # every dp_rank will receive a complete mini-batch, no need to sync within step() later + # each mini-batch use the first self.dp_size * minibatch_size effective samples + raw_mini_batches = self.buffer[ + : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 + ] # include the last effective sample + raw_mini_batches_metric_dict = { + "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches], + "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches], + "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches], + "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches], + } + batch = bind_batch([t[0] for t in batches]) + batch = post_recv(batch) + return batch, raw_mini_batches_metric_dict + + def calculate_effective_group_to_raw_group_mapping(self, step): + effective_group_to_raw_group_mapping = {} + for buffer_idx in range(len(self.buffer)): + if self.buffer[buffer_idx][0] is not None: + if self.n_behind == 0: + effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx + else: + if self.buffer[buffer_idx][-1] <= step - self.n_behind: + effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx + return effective_group_to_raw_group_mapping + def loop(self) -> None: print( f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}" @@ -112,14 +153,53 @@ class BaseConsumer: disable=self.rank != 0, ) as pbar: for step in pbar: + torch.cuda.reset_peak_memory_stats() i = 0 + + self.profiler.enter(f"rollout_episode_{episode}_step_{step}") for _ in range(self.num_recv_per_update): + if self.n_behind > 0: + # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data + effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping( + step=step + ) + while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size: + self.profiler.log( + f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training" + ) + batch, raw_mini_batches_metric_dict = self.prepare_mini_batch( + effective_group_to_raw_group_mapping + ) + self.profiler.enter("step") + loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict) + self.profiler.exit("step") + self.buffer = self.buffer[ + effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 : + ] + # recalculate the effective group to raw group mapping + effective_group_to_raw_group_mapping_size_before = len( + effective_group_to_raw_group_mapping + ) + effective_group_to_raw_group_mapping = ( + self.calculate_effective_group_to_raw_group_mapping(step=step) + ) + assert ( + len(effective_group_to_raw_group_mapping) + == effective_group_to_raw_group_mapping_size_before + - self.dp_size * self.minibatch_size + ) + if loss is not None: + pbar.set_postfix({"loss": loss}) + i += 1 + # receive data from producers for r in range(self.num_producers): print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}") + self.profiler.enter(f"recv_broadcast_data_P{r}") raw_batch = ray_broadcast_tensor_dict( None, src=0, device=self.device, group_name=f"sync_data_{r}" ) + 
self.profiler.exit(f"recv_broadcast_data_P{r}") # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete), # we need to calculate the metrics before filtering here for logging # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...] @@ -154,6 +234,7 @@ class BaseConsumer: format_acc[group_idx], ans_acc[group_idx], response_len[group_idx], + step, ] ) if effective_group_mask is not None: @@ -161,56 +242,44 @@ class BaseConsumer: f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups" ) # mapping the effective group to the raw group for indexing - effective_group_to_raw_group_mapping = {} - for buffer_idx in range(len(self.buffer)): - if self.buffer[buffer_idx][0] is not None: - effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( - buffer_idx - ) + effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping( + step=step + ) print( f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}" ) - while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size: - # on each dp_rank, we use minibatch_size effective samples to form a batch - batches = [ - self.buffer[effective_group_to_raw_group_mapping[i]] - for i in range( - self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size + if self.n_behind == 0: + # If n_behind is 0, we start training after receiving data from producers. + while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size: + self.profiler.log( + f"Collect {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training" ) - ] - # every dp_rank will receive a complete mini-batch, no need to sync within step() later - # each mini-batch use the first self.dp_size * minibatch_size effective samples - raw_mini_batches = self.buffer[ - : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 - ] # include the last effective sample - raw_mini_batches_metric_dict = { - "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches], - "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches], - "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches], - "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches], - } - batch = bind_batch([t[0] for t in batches]) - batch = post_recv(batch) - loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict) - self.buffer = self.buffer[ - effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 : - ] - # recalculate the effective group to raw group mapping - effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping) - effective_group_to_raw_group_mapping = {} - for buffer_idx in range(len(self.buffer)): - if self.buffer[buffer_idx][0] is not None: - effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( - buffer_idx - ) - assert ( - len(effective_group_to_raw_group_mapping) - == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size - ) - if loss is not None: - pbar.set_postfix({"loss": loss}) - i += 1 + batch, raw_mini_batches_metric_dict = self.prepare_mini_batch( + effective_group_to_raw_group_mapping + ) + self.profiler.enter("step") + loss = self.step(i, 
+                                self.profiler.exit("step")
+                                self.buffer = self.buffer[
+                                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                                ]
+                                # recalculate the effective group to raw group mapping
+                                effective_group_to_raw_group_mapping_size_before = len(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                effective_group_to_raw_group_mapping = (
+                                    self.calculate_effective_group_to_raw_group_mapping(step=step)
+                                )
+                                assert (
+                                    len(effective_group_to_raw_group_mapping)
+                                    == effective_group_to_raw_group_mapping_size_before
+                                    - self.dp_size * self.minibatch_size
+                                )
+                                if loss is not None:
+                                    pbar.set_postfix({"loss": loss})
+                                i += 1
+
                     if self.lr_scheduler is not None:
                         self.lr_scheduler.step()
                     if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
@@ -221,13 +290,16 @@ class BaseConsumer:
                         if self.rank == 0:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
-                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                    if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
+                        episode != 0 or step >= self.n_behind
+                    ):
                         if self.pp_size > 1:
                             print(
                                 f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                             )
                        else:
                             print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        self.profiler.enter("sync_model")
                         torch.cuda.empty_cache()
                         state_dict = self.state_dict()
                         if self.pp_size > 1:
@@ -245,6 +317,13 @@ class BaseConsumer:
                             )
                         del state_dict
                         torch.cuda.empty_cache()
+                        self.profiler.exit("sync_model")
+                    self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+                    self.profiler.exit(f"rollout_episode_{episode}_step_{step}")
+
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()
 
 
 @ray.remote
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 8d50734a9..f8ce1afde 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -38,6 +38,8 @@ class GRPOConsumer(BaseConsumer):
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -63,6 +65,8 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index dc8bf0057..a48246c87 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -57,6 +57,8 @@ def launch_distributed(
     eval_generation_config: Optional[Dict[str, Any]] = None,
     log_rollout_interval: int = 20,
     rollout_save_dir: str = "./rollout",
+    enable_profiling: bool = False,
+    n_behind: int = 0,
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -132,6 +134,8 @@ def launch_distributed(
             wandb_group_name=wandb_group_name,
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         producer_procs.append(producer)
     ray.get([p.setup.remote() for p in producer_procs])
@@ -171,6 +175,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         consumer_procs.append(consumer)
     ray.get([p.setup.remote() for p in consumer_procs])
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 854c2fcc2..2a3746391 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -9,6 +9,7 @@ import torch
 import tqdm
 import wandb
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
+from coati.distributed.profiling_utils import CustomProfiler
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from ray.util.collective import allreduce
@@ -52,6 +53,8 @@ class BaseProducer:
         wandb_group_name: str = None,
         log_rollout_interval: int = 20,
         rollout_log_file: str = "./rollout_log.jsonl",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -62,6 +65,7 @@ class BaseProducer:
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
         self.latest_eval_step = -1
+        self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
 
         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -75,6 +79,7 @@ class BaseProducer:
         self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
         self.grpo_config = grpo_config
+        self.n_behind = n_behind
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
@@ -268,11 +273,14 @@ class BaseProducer:
                         self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                     self.eval_mode = False
                     self.latest_eval_step = self.consumer_global_step
+                self.profiler.enter("rollout")
                 outputs = self.rollout(**batch)
+                self.profiler.exit("rollout")
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
+                self.profiler.enter("calculate_reward")
                 if self.grpo_config["reward_fn_type"] == "code":
                     test_cases = []
                     for prompt_id in range(bs):
@@ -310,14 +318,19 @@ class BaseProducer:
                     outputs.pop("gt_answer")
                 if "test_cases" in outputs:
                     outputs.pop("test_cases")
+                self.profiler.exit("calculate_reward")
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs = pre_send(outputs)
+                self.profiler.enter("send_broadcast_data")
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )
-                if (i + 1) % self.num_microbatches == 0 and (
-                    episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
+                self.profiler.exit("send_broadcast_data")
+                if (
+                    (i + 1) % self.num_microbatches == 0
+                    and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
+                    and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
                 ):
                     if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                         "enable_sleep_mode", False
                     ):
                         self.model.llm.sleep()  # evict KV cache to avoid OOM
                     # don't sync model for last iteration
                     torch.cuda.empty_cache()
-
self.profiler.enter("sync_model") if self.consumer_pp_size > 1: for pp_idx in range(self.consumer_pp_size): print( @@ -347,6 +360,7 @@ class BaseProducer: if "consumer_global_step" in state_dict: self.consumer_global_step = state_dict.pop("consumer_global_step").item() self.load_state_dict(state_dict) + self.profiler.exit("sync_model") del state_dict torch.cuda.empty_cache() if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( @@ -364,6 +378,9 @@ class BaseProducer: "temperature" ] + ratio * 0.9 + def __del__(self): + self.profiler.close() + @ray.remote class SimpleProducer(BaseProducer): @@ -392,6 +409,8 @@ class SimpleProducer(BaseProducer): wandb_group_name: str = None, log_rollout_interval: int = 20, rollout_log_file: str = "./rollout_log.jsonl", + enable_profiling: bool = False, + n_behind: int = 0, ): super().__init__( producer_idx, @@ -415,6 +434,8 @@ class SimpleProducer(BaseProducer): wandb_group_name=wandb_group_name, log_rollout_interval=log_rollout_interval, rollout_log_file=rollout_log_file, + enable_profiling=enable_profiling, + n_behind=n_behind, ) self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) self.eval_generation_config = copy.deepcopy(self.model.generate_config) diff --git a/applications/ColossalChat/coati/distributed/profiling_utils.py b/applications/ColossalChat/coati/distributed/profiling_utils.py new file mode 100644 index 000000000..1c1169b50 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/profiling_utils.py @@ -0,0 +1,37 @@ +import os +import time + + +class CustomProfiler: + def __init__(self, name, disabled=True): + self.disabled = disabled + if not disabled: + self.name = name + self.pid = os.getpid() + self.file = open(f"{name}.prof", "w") + + def _log(self, message): + if self.disabled: + return + current_time = time.time() + self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n") + self.file.flush() + + def log(self, message): + if self.disabled: + return + current_time = time.time() + self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n") + self.file.flush() + + def enter(self, event_name): + self._log(f"Enter {event_name}") + + def exit(self, event_name): + self._log(f"Exit {event_name}") + + def close(self): + if self.disabled: + return + self.file.close() + print(f"Profiler data written to {self.name}.prof") diff --git a/applications/ColossalChat/profiling.sh b/applications/ColossalChat/profiling.sh new file mode 100755 index 000000000..d9f3d9a93 --- /dev/null +++ b/applications/ColossalChat/profiling.sh @@ -0,0 +1,13 @@ +export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 + +# 8K context length +# rm -rf *.prof +# MAX_NEW_TOKENS=$((8192-512)) +# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." 
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1 | tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
+
+# 4K context length
+rm -rf *.prof
+MAX_NEW_TOKENS=$((4096-512))
+python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 4 -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1 | tee ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.txt
+python visualization.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.png
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index c60923e00..d7b7c2a5d 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -67,6 +67,27 @@ if __name__ == "__main__":
         default=2,
         help="Effective batch size per dp group for the forward and backward passes. Please select based on the available memory.",
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-pp",
+        "--pipeline-parallel-size",
+        type=int,
+        default=1,
+        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-zero",
+        "--zero-stage",
+        type=int,
+        default=0,
+        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
     parser.add_argument(
         "--ray_dir", type=str, default=None, help="Custom temporary directory for storing ray cluster data, Optional"
     )
@@ -97,6 +118,13 @@ if __name__ == "__main__":
     parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
     parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
     parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
+    parser.add_argument(
+        "-ptp",
+        "--producer-tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+    )
 
     # GRPO parameters
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
@@ -117,6 +145,13 @@ if __name__ == "__main__":
         default=100,
         help="Interval for evaluation. Evaluate every ei training steps.",
     )
+    parser.add_argument(
+        "-nb",
+        "--n-behind",
+        type=int,
+        default=0,
+        help="Number of producer batches to roll out ahead of training to keep the data buffer filled and reduce bubble time.",
+    )
 
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
     parser.add_argument(
         "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout logs."
     )
     parser.add_argument(
-        "-tp",
-        "--tensor-parallel-size",
-        type=int,
-        default=1,
-        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-pp",
-        "--pipeline-parallel-size",
-        type=int,
-        default=1,
-        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-zero",
-        "--zero-stage",
-        type=int,
-        default=0,
-        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-ptp",
-        "--producer-tensor-parallel-size",
-        type=int,
-        default=1,
-        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+        "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
     )
 
     args = parser.parse_args()
@@ -236,14 +246,25 @@ if __name__ == "__main__":
                 tensor_parallel_size=args.producer_tensor_parallel_size,
             )
         )
-        generate_config.update(
-            dict(
-                max_tokens=args.max_new_tokens,  # max new tokens
-                ignore_eos=True if args.reward_type == "think_answer_tags" else False,
-                include_stop_str_in_output=True,
-                stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
+        if args.enable_profiling:
+            # If profiling is enabled, force the model to generate max_new_tokens every time
+            generate_config.update(
+                dict(
+                    max_tokens=args.max_new_tokens,  # max new tokens
+                    ignore_eos=True,
+                    include_stop_str_in_output=True,
+                    stop=None,
+                )
+            )
+        else:
+            generate_config.update(
+                dict(
+                    max_tokens=args.max_new_tokens,  # max new tokens
+                    ignore_eos=True if args.reward_type == "think_answer_tags" else False,
+                    include_stop_str_in_output=True,
+                    stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
+                )
             )
-        )
         eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
@@ -353,4 +374,6 @@ if __name__ == "__main__":
         eval_generation_config=eval_generation_config,
         log_rollout_interval=20,
         rollout_save_dir=args.rollout_save_dir,
+        enable_profiling=args.enable_profiling,
+        n_behind=args.n_behind,
     )
diff --git a/applications/ColossalChat/visualization.py b/applications/ColossalChat/visualization.py
new file mode 100644
index 000000000..afb729884
--- /dev/null
+++ b/applications/ColossalChat/visualization.py
@@ -0,0 +1,100 @@
+# Parse CustomProfiler logs (*.prof) and draw a per-actor timeline plot.
+import argparse
+import glob
+from collections import defaultdict
+
+import matplotlib.cm as cm
+import matplotlib.pyplot as plt
+
+# Argument parser for command line arguments
+parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
+parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
+args = parser.parse_args()
+
+# Raw log lines
+log_lines = []
+
+files = glob.glob("*.prof")
+for file in files:
+    with open(file, "r") as f:
+        log_lines += f.readlines()
+
+# Parse logs and collect function intervals grouped by actor
+actors = defaultdict(lambda: defaultdict(list))
+current_entries = {}
+
+# First, collect all timestamps to find the minimum
+all_timestamps = []
+parsed_lines = []
+
+for line in log_lines:
+    if line.startswith("[Log]"):
+        continue
+    parts = line.split()
+    timestamp = float(parts[0])
+    actor = parts[1]
+    action = parts[3]
+    func_name = parts[4]
+    parsed_lines.append((timestamp, actor, action, func_name))
+    all_timestamps.append(timestamp)
+
+if not all_timestamps:
+    raise ValueError("No valid log entries found.")
ValueError("No valid log entries found.") + +min_timestamp = min(all_timestamps) + +for timestamp, actor, action, func_name in parsed_lines: + rel_timestamp = timestamp - min_timestamp + key = (actor, func_name) + if action == "Enter": + current_entries[key] = rel_timestamp + elif action == "Exit": + start_time = current_entries.pop(key, None) + if start_time is not None: + actors[actor][func_name].append((start_time, rel_timestamp)) + +# Plotting setup +fig, ax = plt.subplots(figsize=(12, 6)) +colors = cm.get_cmap("tab10", len(actors)) + +actor_offsets = {} +base_offset = 0 +function_spacing = 0.9 + +yticks = [] +yticklabels = [] + +for idx, (actor, func_dict) in enumerate(actors.items()): + actor_offsets[actor] = base_offset + color = colors(idx) + for j, (func, intervals) in enumerate(func_dict.items()): + print(actor, func, intervals) + y_val = base_offset + j * function_spacing + yticks.append(y_val) + yticklabels.append(f"{actor}:{func}") + for start, end in intervals: + if end - start < 1: + end = start + 1 # Ensure all lines are at least 3 units long + ax.plot( + [start, end], + [y_val, y_val], + color=color, + linewidth=2, + label=actor if j == 0 else "", + ) + base_offset += len(func_dict) * function_spacing + 1 + +# Formatting +ax.set_yticks(yticks) +ax.set_yticklabels(yticklabels) +ax.set_xlabel("Time") +ax.set_title("Timeline per Actor") +# Remove duplicate labels in legend +handles, labels = ax.get_legend_handles_labels() +unique = dict(zip(labels, handles)) +ax.legend(unique.values(), unique.keys()) +plt.tight_layout() +plt.grid(True) +plt.savefig(args.visualization, dpi=600) # Increase dpi for higher resolution +print(f"Plot saved as {args.visualization}")