diff --git a/.gitignore b/.gitignore
index 94d952295..f06e00b24 100644
--- a/.gitignore
+++ b/.gitignore
@@ -173,3 +173,4 @@ applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
 applications/ColossalChat/*.prof
 applications/ColossalChat/*.png
+applications/ColossalChat/profiling_log/
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 554650069..7a02c83d8 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -200,10 +200,15 @@ class BaseConsumer:
                     }
                     batch = bind_batch([t[0] for t in batches])
                     batch = post_recv(batch)
+                    torch.cuda.empty_cache()
+                    torch.cuda.reset_peak_memory_stats()
                     self.profiler.enter("step")
                     loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
                     total_step += 1
                     self.profiler.exit("step")
+                    self.profiler.log(
+                        f"step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+                    )
                     self.buffer = self.buffer[
                         effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
                     ]
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 79817b0bb..d4a3f1de0 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -6,7 +6,7 @@
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import calc_action_log_probs, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -262,7 +262,7 @@ class GRPOConsumer(BaseConsumer):
                 # Support training with PP.
                 if self.policy_loss_fn.beta > 0:
                     with torch.no_grad():
-                        torch.cuda.reset_peak_memory_stats()
+                        # torch.cuda.reset_peak_memory_stats()
                         reference_model_outputs = self.booster.execute_pipeline(
                             iter(
                                 [
@@ -286,22 +286,32 @@ class GRPOConsumer(BaseConsumer):
                         if self.booster.plugin.stage_manager.is_last_stage():
                             # breakpoint()
-                            torch.cuda.reset_peak_memory_stats()
-                            reference_action_log_probs = torch.zeros(
-                                (input_ids_forward_micro_batch.size(0), num_action),
-                                device=input_ids_forward_micro_batch.device,
+                            # torch.cuda.empty_cache()
+                            # current_memory = torch.cuda.memory_allocated()
+                            # torch.cuda.reset_peak_memory_stats()
+
+                            # reference_action_log_probs = calc_action_log_probs(
+                            #     reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
+                            #     input_ids_forward_micro_batch,
+                            #     num_action,
+                            #     self.plugin.shard_config,
+                            # )
+
+                            # self.profiler.log(f"reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
+                            # torch.cuda.empty_cache()
+                            # current_memory = torch.cuda.memory_allocated()
+                            # torch.cuda.reset_peak_memory_stats()
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"],
+                                input_ids_forward_micro_batch,
+                                num_action,
+                                shard_config=self.plugin.shard_config,
                             )
-                            for i in range(reference_action_log_probs.size(0)):
-                                # activation for log_softmax is too large if vocab size and sequence length are large
-                                # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-                                # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                                reference_action_log_probs[i, :] += calc_action_log_probs(
-                                    reference_model_outputs["outputs"]["logits"][i : i + 1]
-                                    / self.generate_config["temperature"],
-                                    input_ids_forward_micro_batch[i : i + 1],
-                                    num_action,
-                                    self.plugin.shard_config,
-                                )[0]
+                            # self.profiler.log(f"me_reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
+                            # if torch.allclose(reference_action_log_probs, me_reference_action_log_probs):
+                            #     self.profiler.log("Memory efficient reference action log probs is same as normal reference action log probs")
+                            # else:
+                            #     self.profiler.log("Memory efficient reference action log probs is different from normal reference action log probs")
                             # breakpoint()
                             # torch.cuda.empty_cache()
                             self.profiler.log(
@@ -310,6 +320,7 @@ class GRPOConsumer(BaseConsumer):
                         else:
                             # Dummy reference logprobs for data iterator.
                             reference_action_log_probs = None
+                        del reference_model_outputs
                 else:
                     reference_action_log_probs = None
@@ -328,26 +339,43 @@ class GRPOConsumer(BaseConsumer):
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    action_log_probs = torch.zeros(
-                        (inputs["input_ids"].size(0), num_action), device=action_logits.device
-                    )
                     # breakpoint()
-                    torch.cuda.reset_peak_memory_stats()
-                    for i in range(action_log_probs.size(0)):
-                        # activation for log_softmax is too large if vocab size and sequence length are large
-                        # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-                        # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-                        action_log_probs[i, :] += calc_action_log_probs(
-                            action_logits[i : i + 1] / self.generate_config["temperature"],
-                            inputs["input_ids"][i : i + 1],
-                            num_action,
-                            self.plugin.shard_config,
-                        )[0]
                     # torch.cuda.empty_cache()
-                    self.profiler.log(
-                        f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+                    # current_memory = torch.cuda.memory_allocated()
+                    # torch.cuda.reset_peak_memory_stats()
+                    # action_log_probs = calc_action_log_probs(
+                    #     action_logits / self.generate_config["temperature"],
+                    #     inputs["input_ids"],
+                    #     num_action,
+                    #     self.plugin.shard_config,
+                    # )
+                    # # torch.cuda.empty_cache()
+                    # self.profiler.log(
+                    #     f"action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
+                    # )
+                    # torch.cuda.empty_cache()
+                    # current_memory = torch.cuda.memory_allocated()
+                    # torch.cuda.reset_peak_memory_stats()
+                    action_log_probs = memory_efficient_logprob(
+                        action_logits,
+                        inputs["input_ids"],
+                        num_action,
+                        shard_config=self.plugin.shard_config,
                     )
+                    # self.profiler.log(
+                    #     f"me_action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
+                    # )
+                    # if torch.allclose(action_log_probs, me_action_log_probs):
+                    #     self.profiler.log("Memory efficient action log probs is same as normal action log probs")
+                    # else:
+                    #     self.profiler.log("Memory efficient action log probs is different from normal action log probs")
+                    # torch.cuda.empty_cache()
                     # breakpoint()
+                    # current_memory = torch.cuda.memory_allocated()
+                    # torch.cuda.empty_cache()
+                    # self.profiler.log(
+                    #     f"released by del outputs: {(torch.cuda.memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
+                    # )
                     if "reference_action_log_probs" in inputs:
                         per_token_kl = (
                             torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -463,7 +491,7 @@ class GRPOConsumer(BaseConsumer):
         if need_update:
            # breakpoint()
            # torch.cuda.empty_cache()
-           torch.cuda.reset_peak_memory_stats()
+           # torch.cuda.reset_peak_memory_stats()
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.global_step += 1
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index c3fa23189..bb623269d 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -448,7 +448,9 @@ class DummyProducer(BaseProducer):
                 "input_ids": torch.cat(
                     [
                         torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device),
-                        torch.ones(
+                        torch.randint(
+                            0,
+                            self.tokenizer.vocab_size,  # assuming vocab_size is available in tokenizer
                             (input_ids.size(0), self.num_generations, num_new_tokens),
                             dtype=input_ids.dtype,
                         ).to(device),
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 85a0ec59b..6c4e10743 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -146,6 +146,45 @@ def safe_append_to_jsonl_file(file_path, data):
             f.write(json_line + "\n")


+def memory_efficient_logprob(
+    logits: torch.Tensor,
+    inputs: torch.Tensor,
+    num_action: int,
+    chunk_size: int = 2048,
+    shard_config: Any = None,
+    vocab_size: int = None,
+) -> torch.Tensor:
+    """
+    Calculate action log probs in a memory-efficient way by processing in chunks.
+    Args:
+        logits (torch.Tensor): Output tensor of Actor.forward.logits.
+        inputs (torch.LongTensor): Input sequences.
+        num_action (int): Number of actions.
+        chunk_size (int, optional): Size of each chunk to process. Default is 2048.
+        shard_config: Shard configuration for distributed computation.
+        vocab_size (int, optional): Vocabulary size. Default is None.
+    Returns:
+        torch.Tensor: Action log probs.
+    """
+    action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
+    context_length = logits.size(1) - num_action
+    for i in range(action_log_probs.size(0)):
+        # loop over each sample in the micro-batch
+        for start in range(context_length, logits.size(1), chunk_size):
+            end = min(start + chunk_size, logits.size(1))
+            # calculate log probs in chunks to save memory
+            log_probs = dist_log_prob(
+                inputs[i : i + 1, start - 1 : end],
+                logits[i : i + 1, start - 1 : end],
+                shard_config,
+                vocab_size,
+                logits.dtype,
+            )  # [1, chunk_size, 1]
+            log_probs = log_probs.squeeze(-1)
+            action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
+    return action_log_probs
+
+
 class CustomProfiler:
     def __init__(self, name):
         self.name = name
diff --git a/applications/ColossalChat/profile_grpo.py b/applications/ColossalChat/profile_grpo.py
index ca4dc80d0..bde49303f 100644
--- a/applications/ColossalChat/profile_grpo.py
+++ b/applications/ColossalChat/profile_grpo.py
@@ -1,10 +1,8 @@
 # Re-import required libraries due to kernel reset
 import argparse
 from collections import defaultdict
-from datetime import datetime

 import matplotlib.cm as cm
-import matplotlib.dates as mdates
 import matplotlib.pyplot as plt

 # Argument parser for command line arguments
@@ -26,6 +24,10 @@ for file in files:
 actors = defaultdict(lambda: defaultdict(list))
 current_entries = {}

+# First, collect all timestamps to find the minimum
+all_timestamps = []
+parsed_lines = []
+
 for line in log_lines:
     if line.startswith("[Log]"):
         continue
@@ -34,13 +36,23 @@ for line in log_lines:
     actor = parts[1]
     action = parts[3]
     func_name = parts[4]
+    parsed_lines.append((timestamp, actor, action, func_name))
+    all_timestamps.append(timestamp)
+
+if not all_timestamps:
+    raise ValueError("No valid log entries found.")
+
+min_timestamp = min(all_timestamps)
+
+for timestamp, actor, action, func_name in parsed_lines:
+    rel_timestamp = timestamp - min_timestamp
     key = (actor, func_name)
     if action == "Enter":
-        current_entries[key] = timestamp
+        current_entries[key] = rel_timestamp
     elif action == "Exit":
         start_time = current_entries.pop(key, None)
         if start_time is not None:
-            actors[actor][func_name].append((start_time, timestamp))
+            actors[actor][func_name].append((start_time, rel_timestamp))

 # Plotting setup
 fig, ax = plt.subplots(figsize=(12, 6))
@@ -57,21 +69,23 @@ for idx, (actor, func_dict) in enumerate(actors.items()):
     actor_offsets[actor] = base_offset
     color = colors(idx)
     for j, (func, intervals) in enumerate(func_dict.items()):
+        print(actor, func, intervals)
         y_val = base_offset + j * function_spacing
         yticks.append(y_val)
         yticklabels.append(f"{actor}:{func}")
         for start, end in intervals:
+            if end - start < 100:
+                end = start + 100  # Ensure minimum length of 100ms
             ax.plot(
-                [datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
+                [start, end],
                 [y_val, y_val],
                 color=color,
-                linewidth=4,
+                linewidth=2,
                 label=actor if j == 0 else "",
             )
     base_offset += len(func_dict) * function_spacing + 1

 # Formatting
-ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
 ax.set_yticks(yticks)
 ax.set_yticklabels(yticklabels)
 ax.set_xlabel("Time")
diff --git a/applications/ColossalChat/profiling.sh b/applications/ColossalChat/profiling.sh
index 6771376ec..fe45f8a22 100755
--- a/applications/ColossalChat/profiling.sh
+++ b/applications/ColossalChat/profiling.sh
@@ -23,6 +23,17 @@
 export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
 # python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png
 # 16K
-MAX_NEW_TOKENS=$((16384-512))
-python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.txt
-python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
+# MAX_NEW_TOKENS=$((16384-512))
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_s_2_ptp_2_16384_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmb2_pp_2_ptp_2_16384_GRPO_profile.png
+
+# 8K
+# MAX_NEW_TOKENS=$((8192-512))
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO-DUMMY-TEST -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
+
+
+# # 32K
+MAX_NEW_TOKENS=$((32768-512))
+python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -tMbs 2 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -tp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.txt
+python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.png
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 5bf701839..12741fdc7 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -318,8 +318,10 @@ if __name__ == "__main__":
         plugin_config={
             "tp_size": args.tensor_parallel_size,
             "pp_size": args.pipeline_parallel_size,
-            # "num_layers_per_stage": [12,12,2,2],
-            "num_layers_per_stage": [20, 8],
+            # "num_layers_per_stage": [18, 10],
+            # "num_layers_per_stage": [16, 11, 1],
+            # "num_layers_per_stage": [24, 4],
+            "num_layers_per_stage": [15, 13],
             "microbatch_size": max(
                 1, args.train_microbatch_size // args.pipeline_parallel_size
             ),  # microbatch size should be set to train_microbatch_size // pp_size
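
Reviewer note (appended, not part of the patch): the sketch below is a minimal, self-contained illustration of the chunked log-prob computation that memory_efficient_logprob implements, under the simplifying assumption of no tensor-parallel sharding, in which case the per-chunk dist_log_prob call reduces to a plain log_softmax plus gather. The helper name chunked_action_log_probs and the toy sizes are illustrative only; the real code path goes through dist_log_prob with a shard_config.

# Reviewer sketch: chunked action log-prob computation, assuming no tensor-parallel
# sharding (so dist_log_prob reduces to log_softmax + gather).
import torch
import torch.nn.functional as F


def chunked_action_log_probs(
    logits: torch.Tensor, input_ids: torch.Tensor, num_action: int, chunk_size: int = 2048
) -> torch.Tensor:
    # Materialize the log_softmax for only `chunk_size` positions at a time,
    # mirroring the chunking loop in memory_efficient_logprob.
    bsz, seq_len = input_ids.shape
    context_length = seq_len - num_action
    out = torch.zeros(bsz, num_action, dtype=logits.dtype, device=logits.device)
    for start in range(context_length, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        chunk_logits = logits[:, start - 1 : end - 1, :]  # logits at t-1 predict token t
        targets = input_ids[:, start:end]
        log_probs = F.log_softmax(chunk_logits.float(), dim=-1)
        out[:, start - context_length : end - context_length] = (
            log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1).to(logits.dtype)
        )
    return out


if __name__ == "__main__":
    # Sanity check against the straightforward full-sequence log_softmax on toy sizes.
    torch.manual_seed(0)
    B, S, V, A = 2, 64, 128, 48
    logits, ids = torch.randn(B, S, V), torch.randint(0, V, (B, S))
    full = (
        F.log_softmax(logits[:, :-1].float(), dim=-1)
        .gather(-1, ids[:, 1:].unsqueeze(-1))
        .squeeze(-1)[:, S - A - 1 :]
    )
    assert torch.allclose(chunked_action_log_probs(logits, ids, A, chunk_size=7), full, atol=1e-5)

Chunking over the sequence dimension leaves the result unchanged because log_softmax normalizes over the vocabulary dimension only; peak activation memory per sample drops from roughly seq_len * vocab_size to chunk_size * vocab_size.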