diff --git a/.gitignore b/.gitignore
index e603f5015..94d952295 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
 applications/ColossalChat/*.db
 applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index a1eac136a..554650069 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -1,4 +1,5 @@
 import os
+import time
 from contextlib import nullcontext
 from typing import Any, Dict, Optional
@@ -79,6 +80,8 @@ class BaseConsumer:
         self.tp_size = dist.get_world_size(self.plugin.tp_group)
         self.pp_size = dist.get_world_size(self.plugin.pp_group)

+        torch.cuda.reset_peak_memory_stats()
+
         # Init Hybrid ray process group
         for i in range(self.num_producers):
             cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
@@ -106,6 +109,8 @@ class BaseConsumer:
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
+        start_time = time.time()
+        total_step = 0
         for episode in range(self.num_episodes):
             with tqdm(
                 range(self.num_update_per_episode),
@@ -197,6 +202,7 @@ class BaseConsumer:
                         batch = post_recv(batch)
                         self.profiler.enter("step")
                         loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                        total_step += 1
                         self.profiler.exit("step")
                         self.buffer = self.buffer[
                             effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
@@ -252,6 +258,8 @@ class BaseConsumer:
                         del state_dict
                         torch.cuda.empty_cache()
                     self.profiler.exit("sync_model")
+        print(f"[T{self.rank}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+        print(f"Average running time per step: {(time.time() - start_time) / total_step:.2f} seconds")

     def __del__(self):
         if hasattr(self, "profiler"):
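The consumer instrumentation above follows the standard CUDA high-water-mark pattern: reset the allocator's peak counter once at startup, then read it after the run. A minimal standalone sketch of that pattern (loop body and tensor sizes are illustrative, and a CUDA device is assumed):

import time

import torch

torch.cuda.reset_peak_memory_stats()  # clear the allocator's high-water mark
start_time, total_step = time.time(), 0
for _ in range(10):  # stand-in for the consumer's training loop
    x = torch.randn(1024, 1024, device="cuda")
    (x @ x).sum().item()  # .item() synchronizes, so the timing includes the kernel
    total_step += 1
peak_mb = torch.cuda.max_memory_allocated() / 1024 ** 2  # high-water mark since reset
print(f"Peak memory usage: {peak_mb:.2f} MB")
print(f"Average running time per step: {(time.time() - start_time) / total_step:.2f} seconds")
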
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 2b846b669..79817b0bb 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -3,13 +3,13 @@ from typing import Any, Optional

 import ray
 import torch
+import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

-import wandb
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
@@ -258,9 +258,11 @@ class GRPOConsumer(BaseConsumer):
             ]

             if self.plugin.pp_size > 1:
+                # torch.cuda.empty_cache()
                 # Support training with PP.
                 if self.policy_loss_fn.beta > 0:
                     with torch.no_grad():
+                        torch.cuda.reset_peak_memory_stats()
                         reference_model_outputs = self.booster.execute_pipeline(
                             iter(
                                 [
@@ -278,14 +280,32 @@ class GRPOConsumer(BaseConsumer):
                             return_loss=False,
                             return_outputs=True,
                         )
+                        self.profiler.log(
+                            f"reference_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+                        )

                     if self.booster.plugin.stage_manager.is_last_stage():
-                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                        reference_action_log_probs = calc_action_log_probs(
-                            reference_model_logits / self.generate_config["temperature"],
-                            input_ids_forward_micro_batch,
-                            num_action,
-                            self.plugin.shard_config,
+                        # breakpoint()
+                        torch.cuda.reset_peak_memory_stats()
+                        reference_action_log_probs = torch.zeros(
+                            (input_ids_forward_micro_batch.size(0), num_action),
+                            device=input_ids_forward_micro_batch.device,
+                        )
+                        for i in range(reference_action_log_probs.size(0)):
+                            # activation for log_softmax is too large if vocab size and sequence length are large
+                            # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
+                            # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
+                            reference_action_log_probs[i, :] += calc_action_log_probs(
+                                reference_model_outputs["outputs"]["logits"][i : i + 1]
+                                / self.generate_config["temperature"],
+                                input_ids_forward_micro_batch[i : i + 1],
+                                num_action,
+                                self.plugin.shard_config,
+                            )[0]
+                        # breakpoint()
+                        # torch.cuda.empty_cache()
+                        self.profiler.log(
+                            f"reference_action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
                         )
                     else:
                         # Dummy reference logprobs for data iterator.
@@ -308,12 +328,26 @@ class GRPOConsumer(BaseConsumer):

                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
-                        inputs["input_ids"],
-                        num_action,
-                        self.plugin.shard_config,
+                    action_log_probs = torch.zeros(
+                        (inputs["input_ids"].size(0), num_action), device=action_logits.device
                     )
+                    # breakpoint()
+                    torch.cuda.reset_peak_memory_stats()
+                    for i in range(action_log_probs.size(0)):
+                        # activation for log_softmax is too large if vocab size and sequence length are large
+                        # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
+                        # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
+                        action_log_probs[i, :] += calc_action_log_probs(
+                            action_logits[i : i + 1] / self.generate_config["temperature"],
+                            inputs["input_ids"][i : i + 1],
+                            num_action,
+                            self.plugin.shard_config,
+                        )[0]
+                    # torch.cuda.empty_cache()
+                    self.profiler.log(
+                        f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+                    )
+                    # breakpoint()
                     if "reference_action_log_probs" in inputs:
                         per_token_kl = (
                             torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -347,6 +381,9 @@ class GRPOConsumer(BaseConsumer):
                     return_loss=True,
                     return_outputs=False,
                 )
+                self.profiler.log(
+                    f"policy_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+                )
                 loss = policy_model_outputs["loss"]

                 if self.booster.plugin.stage_manager.is_last_stage():
@@ -373,6 +410,8 @@ class GRPOConsumer(BaseConsumer):
                             input_ids=input_ids_forward_micro_batch,
                             attention_mask=attention_mask_forward_micro_batch,
                         ).logits
+
+                        # torch.cuda.empty_cache()
                         reference_action_log_probs = calc_action_log_probs(
                             reference_model_logits / self.generate_config["temperature"],
                             input_ids_forward_micro_batch,
                             num_action,
                             self.plugin.shard_config,
@@ -422,6 +461,9 @@ class GRPOConsumer(BaseConsumer):
             self.accum_advantages.add_(advantages.data)
             self.accum_count += 1
         if need_update:
+            # breakpoint()
+            # torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
@@ -429,6 +471,9 @@ class GRPOConsumer(BaseConsumer):
             sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
+            self.profiler.log(
+                f"optimizer_step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+            )
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
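The per-sample loops above are the core of this change: computing log-probabilities one row at a time keeps the log_softmax activation at [1, seq, vocab] instead of [batch, seq, vocab]. The comment's arithmetic checks out: 152064 x 32000 x 4 x 4 bytes is about 72.5 GB for a micro batch of 4, versus roughly 18.1 GB for a single row. A minimal sketch of the idea follows; this is not coati's calc_action_log_probs, and all names and shapes are illustrative:

import torch

def per_row_action_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, num_actions: int) -> torch.Tensor:
    # Looping over the batch caps the transient log_softmax activation at [1, S, V];
    # a full [4, 32000, 152064] fp32 tensor would be ~72.5 GB, one row ~18.1 GB.
    out = torch.zeros(logits.size(0), num_actions, device=logits.device)
    for i in range(logits.size(0)):
        row = torch.log_softmax(logits[i : i + 1, :-1, :].float(), dim=-1)  # [1, S-1, V]
        labels = input_ids[i : i + 1, 1:].unsqueeze(-1)                     # [1, S-1, 1], next-token targets
        token_log_probs = row.gather(-1, labels).squeeze(-1)                # [1, S-1]
        out[i] = token_log_probs[0, -num_actions:]                          # log-probs of the generated tokens
    return out

# Tiny CPU-sized smoke test (stand-ins for the 7B shapes):
logits = torch.randn(4, 16, 32)            # [B, S, V]
input_ids = torch.randint(0, 32, (4, 16))  # [B, S]
assert per_row_action_log_probs(logits, input_ids, num_actions=5).shape == (4, 5)

The trade-off is wall-clock time for memory: the loop serializes B small log_softmax calls instead of one large one, which is why the surrounding code also logs peak memory per phase to confirm the savings.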
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index b666d083a..e4cf33fb6 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -7,9 +7,15 @@ import ray

 from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
-from .producer import SimpleProducer
+from .producer import DummyProducer, SimpleProducer

-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer, "GRPO-DUMMY-TEST": GRPOConsumer}
+PRODUCER_MAP = {
+    "Simple": SimpleProducer,
+    "GRPO": SimpleProducer,
+    "DAPO": SimpleProducer,
+    "GRPO-DUMMY-TEST": DummyProducer,
+}


 def get_jsonl_size_fast(path: str) -> int:
@@ -62,6 +68,7 @@ def launch_distributed(
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:
         core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
+        core_producer = PRODUCER_MAP.get(core_algo, SimpleProducer)

     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
@@ -81,7 +88,7 @@ def launch_distributed(
     procs = []
     for i in range(num_producers):
-        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+        producer = core_producer.options(num_gpus=num_proc_per_producer).remote(
             producer_idx=i,
             num_producers=num_producers,
             num_consumer_procs=num_consumer_procs,
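With the two maps above, the new "GRPO-DUMMY-TEST" algo keeps the real GRPO trainer but swaps in the dummy producer, so the training side can be profiled without paying for actual vLLM generation. A short illustration, assuming ALGO_MAP and PRODUCER_MAP from launch.py are in scope:

# Dispatch for the new test algo; the resolved classes follow from the maps above.
core_consumer = ALGO_MAP.get("GRPO-DUMMY-TEST", SimpleConsumer)      # -> GRPOConsumer
core_producer = PRODUCER_MAP.get("GRPO-DUMMY-TEST", SimpleProducer)  # -> DummyProducer
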
CustomProfiler(f"P{self.producer_idx}") + torch.cuda.reset_peak_memory_stats() reward_model_kwargs = { k: v for k, v in grpo_config.items() @@ -372,11 +373,126 @@ class BaseProducer: self.model.sample_params.temperature = (1 - ratio) * self.generate_config[ "temperature" ] + ratio * 0.9 + print(f"[P{self.producer_idx}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") def __del__(self): self.profiler.close() +@ray.remote # (runtime_env={ "nsight": "default"}) +class DummyProducer(BaseProducer): + def __init__( + self, + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + tokenizer_config=None, + microbatch_size=1, + backend="transformers", + num_generations: int = 8, + consumer_plugin_config=None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + grpo_config: Dict[str, Any] = None, + eval_save_dir: str = "./eval", + eval_generation_config={}, + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", + ): + super().__init__( + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + tokenizer_config, + microbatch_size, + backend, + consumer_plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + grpo_config=grpo_config, + eval_save_dir=eval_save_dir, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, + ) + self.num_generations = num_generations + self.max_length = generate_config.get("max_tokens", generate_config.get("max_length", 512)) + self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) + self.eval_generation_config = copy.deepcopy(self.model.generate_config) + self.eval_generation_config["n"] = 1 # use 1 generation for evaluation + self.eval_generation_config.update(eval_generation_config) + self.eval_sample_params = SamplingParams(**self.eval_generation_config) + + @torch.no_grad() + def rollout(self, input_ids, attention_mask, **kwargs): + # generate dummy rollouts + device = get_current_device() + # self.profiler.log(f"{input_ids.size()}, {attention_mask.size()}, {self.max_length}") + num_new_tokens = self.max_length + rollouts = { + "input_ids": torch.cat( + [ + torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device), + torch.ones( + (input_ids.size(0), self.num_generations, num_new_tokens), + dtype=input_ids.dtype, + ).to(device), + ], + dim=-1, + ).to(device), + "attention_mask": torch.cat( + [ + torch.repeat_interleave(attention_mask.unsqueeze(1), self.num_generations, dim=1).to(device), + torch.ones( + (input_ids.size(0), self.num_generations, num_new_tokens), + dtype=attention_mask.dtype, + ).to(device), + ], + dim=-1, + ), + "action_log_probs": torch.zeros( + (input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.float32 + ).to(device), + "action_mask": torch.ones((input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.bool).to( + device + ), + "response_idx": torch.tensor( + [[[input_ids.size(-1), input_ids.size(-1) + num_new_tokens]] * self.num_generations] * input_ids.size(0) + ) + .to(device) + .to(torch.int), + } + if "gt_answer" in kwargs: + rollouts["gt_answer"] = kwargs["gt_answer"] + if 
"test_cases" in kwargs: + rollouts["test_cases"] = kwargs["test_cases"] + return rollouts + + def __del__(self): + if self.producer_idx == 0: + self.wandb_run.finish() + if hasattr(self, "rollout_log_file"): + self.rollout_log_file.close() + + def load_state_dict(self, state_dict): + self.model.load_state_dict(state_dict) + + @ray.remote # (runtime_env={ "nsight": "default"}) class SimpleProducer(BaseProducer): def __init__( diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 6d4169ee0..85a0ec59b 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -152,16 +152,21 @@ class CustomProfiler: self.pid = os.getpid() self.file = open(f"{name}.prof", "w") - def log(self, message): + def _log(self, message): current_time = time.time() self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n") self.file.flush() + def log(self, message): + current_time = time.time() + self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n") + self.file.flush() + def enter(self, event_name): - self.log(f"Enter {event_name}") + self._log(f"Enter {event_name}") def exit(self, event_name): - self.log(f"Exit {event_name}") + self._log(f"Exit {event_name}") def close(self): self.file.close() diff --git a/applications/ColossalChat/profile_grpo.py b/applications/ColossalChat/profile_grpo.py new file mode 100644 index 000000000..ca4dc80d0 --- /dev/null +++ b/applications/ColossalChat/profile_grpo.py @@ -0,0 +1,86 @@ +# Re-import required libraries due to kernel reset +import argparse +from collections import defaultdict +from datetime import datetime + +import matplotlib.cm as cm +import matplotlib.dates as mdates +import matplotlib.pyplot as plt + +# Argument parser for command line arguments +parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.") +parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.") +args = parser.parse_args() + +# Raw log lines +log_lines = [] + +import glob + +files = glob.glob("*.prof") +for file in files: + with open(file, "r") as f: + log_lines += f.readlines() + +# Parse logs and collect function intervals grouped by actor +actors = defaultdict(lambda: defaultdict(list)) +current_entries = {} + +for line in log_lines: + if line.startswith("[Log]"): + continue + parts = line.split() + timestamp = float(parts[0]) + actor = parts[1] + action = parts[3] + func_name = parts[4] + key = (actor, func_name) + if action == "Enter": + current_entries[key] = timestamp + elif action == "Exit": + start_time = current_entries.pop(key, None) + if start_time is not None: + actors[actor][func_name].append((start_time, timestamp)) + +# Plotting setup +fig, ax = plt.subplots(figsize=(12, 6)) +colors = cm.get_cmap("tab10", len(actors)) + +actor_offsets = {} +base_offset = 0 +function_spacing = 0.9 + +yticks = [] +yticklabels = [] + +for idx, (actor, func_dict) in enumerate(actors.items()): + actor_offsets[actor] = base_offset + color = colors(idx) + for j, (func, intervals) in enumerate(func_dict.items()): + y_val = base_offset + j * function_spacing + yticks.append(y_val) + yticklabels.append(f"{actor}:{func}") + for start, end in intervals: + ax.plot( + [datetime.fromtimestamp(start), datetime.fromtimestamp(end)], + [y_val, y_val], + color=color, + linewidth=4, + label=actor if j == 0 else "", + ) + base_offset += 
diff --git a/applications/ColossalChat/profile_grpo.py b/applications/ColossalChat/profile_grpo.py
new file mode 100644
index 000000000..ca4dc80d0
--- /dev/null
+++ b/applications/ColossalChat/profile_grpo.py
@@ -0,0 +1,86 @@
+# Parse CustomProfiler logs (*.prof) and plot a per-actor event timeline.
+import argparse
+import glob
+from collections import defaultdict
+from datetime import datetime
+
+import matplotlib.cm as cm
+import matplotlib.dates as mdates
+import matplotlib.pyplot as plt
+
+# Argument parser for command line arguments
+parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
+parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
+args = parser.parse_args()
+
+# Collect raw log lines from all profiler output files in the working directory
+log_lines = []
+files = glob.glob("*.prof")
+for file in files:
+    with open(file, "r") as f:
+        log_lines += f.readlines()
+
+# Parse logs and collect function intervals grouped by actor
+actors = defaultdict(lambda: defaultdict(list))
+current_entries = {}
+
+for line in log_lines:
+    if line.startswith("[Log]"):
+        continue
+    parts = line.split()
+    timestamp = float(parts[0])
+    actor = parts[1]
+    action = parts[3]
+    func_name = parts[4]
+    key = (actor, func_name)
+    if action == "Enter":
+        current_entries[key] = timestamp
+    elif action == "Exit":
+        start_time = current_entries.pop(key, None)
+        if start_time is not None:
+            actors[actor][func_name].append((start_time, timestamp))
+
+# Plotting setup
+fig, ax = plt.subplots(figsize=(12, 6))
+colors = cm.get_cmap("tab10", len(actors))
+
+actor_offsets = {}
+base_offset = 0
+function_spacing = 0.9
+
+yticks = []
+yticklabels = []
+
+for idx, (actor, func_dict) in enumerate(actors.items()):
+    actor_offsets[actor] = base_offset
+    color = colors(idx)
+    for j, (func, intervals) in enumerate(func_dict.items()):
+        y_val = base_offset + j * function_spacing
+        yticks.append(y_val)
+        yticklabels.append(f"{actor}:{func}")
+        for start, end in intervals:
+            ax.plot(
+                [datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
+                [y_val, y_val],
+                color=color,
+                linewidth=4,
+                label=actor if j == 0 else "",
+            )
+    base_offset += len(func_dict) * function_spacing + 1
+
+# Formatting
+ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
+ax.set_yticks(yticks)
+ax.set_yticklabels(yticklabels)
+ax.set_xlabel("Time")
+ax.set_title("Timeline per Actor")
+# Remove duplicate labels in legend
+handles, labels = ax.get_legend_handles_labels()
+unique = dict(zip(labels, handles))
+ax.legend(unique.values(), unique.keys())
+plt.tight_layout()
+plt.grid(True)
+plt.savefig(args.visualization)
+print(f"Plot saved as {args.visualization}")
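One portability note on the script above: cm.get_cmap(name, lutsize) was deprecated in Matplotlib 3.7 and removed in 3.9, so on newer installs the colormap lookup needs the registry API instead. A drop-in sketch, assuming Matplotlib >= 3.6 ("actors" stands in for the script's parsed per-actor dict):

import matplotlib

actors = {"C0": {}, "P0": {}}  # stand-in for the script's parsed per-actor dict
# Equivalent of cm.get_cmap("tab10", len(actors)) on Matplotlib >= 3.6:
colors = matplotlib.colormaps["tab10"].resampled(max(len(actors), 1))
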
within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.txt +# python profile_grpo.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.png + +# # zero2 ibs64 tbs32 +# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.txt +# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.png + +# # zero2 ibs96 tbs32 +# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 24 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.txt +# python profile_grpo.py --visualization actor_timelines_ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.png + +# 4K +# MAX_NEW_TOKENS=$((4096-512)) +# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mpt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt +# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png + +# # 32K +# MAX_NEW_TOKENS=$((32768-512)) +# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 4 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.txt +# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png + +# 16K +MAX_NEW_TOKENS=$((16384-512)) +python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." 
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 1a2db4063..5bf701839 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -93,7 +93,7 @@ if __name__ == "__main__":
     parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")

     # GRPO parameters
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "GRPO-DUMMY-TEST"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
     parser.add_argument(
@@ -172,7 +172,7 @@ if __name__ == "__main__":
         namespace="ray-example",
         runtime_env={
             "env_vars": {
-                # "RAY_DEBUG_POST_MORTEM": "1"  # enable post-mortem debugging with ray
+                # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                 "TOKENIZERS_PARALLELISM": "false"
             },
         },
@@ -243,7 +243,7 @@ if __name__ == "__main__":
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")

-    if args.algo == "GRPO":
+    if "GRPO" in args.algo:
         # Default Settings
         grpo_config = {
             "lr": args.learning_rate,
@@ -318,6 +318,8 @@ if __name__ == "__main__":
         plugin_config={
             "tp_size": args.tensor_parallel_size,
             "pp_size": args.pipeline_parallel_size,
+            # "num_layers_per_stage": [12,12,2,2],
+            "num_layers_per_stage": [20, 8],
             "microbatch_size": max(
                 1, args.train_microbatch_size // args.pipeline_parallel_size
             ),  # microbatch size should be set to train_microbatch_size // pp_size
diff --git a/applications/ColossalChat/test_profiling.sh b/applications/ColossalChat/test_profiling.sh
new file mode 100755
index 000000000..8cae99497
--- /dev/null
+++ b/applications/ColossalChat/test_profiling.sh
@@ -0,0 +1,4 @@
+MAX_NEW_TOKENS=$((4096-512))
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt
+python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.png
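The hard-coded "num_layers_per_stage" in rl_example.py overrides the plugin's default even layer split, front-loading the early pipeline stages, presumably to balance the last stage, which also carries the LM head and the log-probability computation. A quick sanity check, assuming the splits target Qwen2.5-Math-7B's 28 decoder layers:

# The pp=2 split [20, 8] and the commented pp=4 split [12, 12, 2, 2]
# should both cover all 28 layers of the assumed model:
for stages in ([20, 8], [12, 12, 2, 2]):
    assert sum(stages) == 28, stages
# microbatch_size is derived as in the diff: train_microbatch_size // pp_size, floored at 1
train_microbatch_size, pp_size = 2, 2
assert max(1, train_microbatch_size // pp_size) == 1

Note that a hard-coded per-stage split only suits this one model and pp degree; runs with other pipeline sizes would need the override removed or adjusted.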