add profiling

This commit is contained in:
YeAnbang 2025-06-13 18:00:31 +08:00
parent 43a0e99ae1
commit ff1689b69a
11 changed files with 325 additions and 104 deletions

.gitignore vendored
View File

@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
applications/ColossalChat/*.db
applications/ColossalChat/stdin
applications/ColossalChat/*.zip
applications/ColossalChat/*.prof
applications/ColossalChat/*.png

View File

@ -1,4 +1,5 @@
import os
import time
from contextlib import nullcontext
from typing import Any, Dict, Optional
@ -79,6 +80,8 @@ class BaseConsumer:
self.tp_size = dist.get_world_size(self.plugin.tp_group)
self.pp_size = dist.get_world_size(self.plugin.pp_group)
torch.cuda.reset_peak_memory_stats()
# Init Hybrid ray process group
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
@ -106,6 +109,8 @@ class BaseConsumer:
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
start_time = time.time()
total_step = 0
for episode in range(self.num_episodes):
with tqdm(
range(self.num_update_per_episode),
@ -197,6 +202,7 @@ class BaseConsumer:
batch = post_recv(batch)
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
total_step += 1
self.profiler.exit("step")
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
@ -252,6 +258,8 @@ class BaseConsumer:
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")
print(f"[T{self.rank}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
print(f"Average running time per step: {(time.time() - start_time) / total_step:.2f} seconds")
def __del__(self):
if hasattr(self, "profiler"):
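The peak-memory and average-step-time prints added above follow PyTorch's standard CUDA memory-statistics pattern: reset the allocator's high-water mark before the region of interest, then read it back afterwards. A minimal sketch of that pattern (illustrative only, not part of this diff):

import torch

torch.cuda.reset_peak_memory_stats()  # zero the high-water mark before the region of interest
# ... run a training step or forward pass ...
peak_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
print(f"Peak memory usage: {peak_mb:.2f} MB")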

View File

@ -3,13 +3,13 @@ from typing import Any, Optional
import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
@ -258,9 +258,11 @@ class GRPOConsumer(BaseConsumer):
]
if self.plugin.pp_size > 1:
# torch.cuda.empty_cache()
# Support training with PP.
if self.policy_loss_fn.beta > 0:
with torch.no_grad():
torch.cuda.reset_peak_memory_stats()
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
@ -278,14 +280,32 @@ class GRPOConsumer(BaseConsumer):
return_loss=False,
return_outputs=True,
)
self.profiler.log(
f"reference_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
# breakpoint()
torch.cuda.reset_peak_memory_stats()
reference_action_log_probs = torch.zeros(
(input_ids_forward_micro_batch.size(0), num_action),
device=input_ids_forward_micro_batch.device,
)
for i in range(reference_action_log_probs.size(0)):
# the log_softmax activation is too large if the vocab size and sequence length are large,
# e.g., with a 152064 vocab size, 32K sequence length and a micro batch size of 4 (as when pp=4),
# this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
reference_action_log_probs[i, :] += calc_action_log_probs(
reference_model_outputs["outputs"]["logits"][i : i + 1]
/ self.generate_config["temperature"],
input_ids_forward_micro_batch[i : i + 1],
num_action,
self.plugin.shard_config,
)[0]
# breakpoint()
# torch.cuda.empty_cache()
self.profiler.log(
f"reference_action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
else:
# Dummy reference logprobs for data iterator.
@ -308,12 +328,26 @@ class GRPOConsumer(BaseConsumer):
def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = calc_action_log_probs(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
num_action,
self.plugin.shard_config,
action_log_probs = torch.zeros(
(inputs["input_ids"].size(0), num_action), device=action_logits.device
)
# breakpoint()
torch.cuda.reset_peak_memory_stats()
for i in range(action_log_probs.size(0)):
# the log_softmax activation is too large if the vocab size and sequence length are large,
# e.g., with a 152064 vocab size, 32K sequence length and a micro batch size of 4 (as when pp=4),
# this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
action_log_probs[i, :] += calc_action_log_probs(
action_logits[i : i + 1] / self.generate_config["temperature"],
inputs["input_ids"][i : i + 1],
num_action,
self.plugin.shard_config,
)[0]
# torch.cuda.empty_cache()
self.profiler.log(
f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
# breakpoint()
if "reference_action_log_probs" in inputs:
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@ -347,6 +381,9 @@ class GRPOConsumer(BaseConsumer):
return_loss=True,
return_outputs=False,
)
self.profiler.log(
f"policy_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
loss = policy_model_outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
@ -373,6 +410,8 @@ class GRPOConsumer(BaseConsumer):
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
# torch.cuda.empty_cache()
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
@ -422,6 +461,9 @@ class GRPOConsumer(BaseConsumer):
self.accum_advantages.add_(advantages.data)
self.accum_count += 1
if need_update:
# breakpoint()
# torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
@ -429,6 +471,9 @@ class GRPOConsumer(BaseConsumer):
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
self.effective_prompt_count = 0
self.effective_sample_count = 0
self.profiler.log(
f"optimizer_step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
loss_scalar = self.accum_loss.item()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
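The per-sample loops added above avoid materializing the log_softmax activation for a whole micro batch at once by computing action log-probabilities one sample at a time. A minimal sketch of the same chunking idea, separate from the repository's calc_action_log_probs (function and variable names below are illustrative assumptions, not the project API):

import torch

def chunked_action_log_probs(logits, input_ids, num_actions, chunk_size=1):
    # logits: (B, S, V); input_ids: (B, S). Return log-probs of the last num_actions tokens,
    # processing chunk_size samples at a time so the log_softmax activation stays small.
    out = torch.zeros(input_ids.size(0), num_actions, device=logits.device)
    for i in range(0, input_ids.size(0), chunk_size):
        log_probs = torch.log_softmax(logits[i : i + chunk_size, :-1, :].float(), dim=-1)
        labels = input_ids[i : i + chunk_size, 1:].unsqueeze(-1)
        token_log_probs = log_probs.gather(-1, labels).squeeze(-1)
        out[i : i + chunk_size] = token_log_probs[:, -num_actions:]
    return out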

View File

@ -7,9 +7,15 @@ import ray
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer
from .producer import DummyProducer, SimpleProducer
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer, "GRPO-DUMMY-TEST": GRPOConsumer}
PRODUCER_MAP = {
"Simple": SimpleProducer,
"GRPO": SimpleProducer,
"DAPO": SimpleProducer,
"GRPO-DUMMY-TEST": DummyProducer,
}
def get_jsonl_size_fast(path: str) -> int:
@ -62,6 +68,7 @@ def launch_distributed(
raise NotImplementedError(f"{core_algo} is not supported yet.")
else:
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
core_producer = PRODUCER_MAP.get(core_algo, SimpleProducer)
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
@ -81,7 +88,7 @@ def launch_distributed(
procs = []
for i in range(num_producers):
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
producer = core_producer.options(num_gpus=num_proc_per_producer).remote(
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,

View File

@ -7,6 +7,7 @@ import ray
import ray.util.collective as cc
import torch
import tqdm
import wandb
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
@ -15,7 +16,6 @@ from ray.util.collective.types import Backend, ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
import wandb
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
@ -76,6 +76,7 @@ class BaseProducer:
self.latest_rollout_log_step = -1
self.grpo_config = grpo_config
self.profiler = CustomProfiler(f"P{self.producer_idx}")
torch.cuda.reset_peak_memory_stats()
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
@ -372,11 +373,126 @@ class BaseProducer:
self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.9
print(f"[P{self.producer_idx}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
def __del__(self):
self.profiler.close()
@ray.remote # (runtime_env={ "nsight": "default"})
class DummyProducer(BaseProducer):
def __init__(
self,
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
):
super().__init__(
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
tokenizer_config,
microbatch_size,
backend,
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
)
self.num_generations = num_generations
self.max_length = generate_config.get("max_tokens", generate_config.get("max_length", 512))
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
# generate dummy rollouts
device = get_current_device()
# self.profiler.log(f"{input_ids.size()}, {attention_mask.size()}, {self.max_length}")
num_new_tokens = self.max_length
rollouts = {
"input_ids": torch.cat(
[
torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device),
torch.ones(
(input_ids.size(0), self.num_generations, num_new_tokens),
dtype=input_ids.dtype,
).to(device),
],
dim=-1,
).to(device),
"attention_mask": torch.cat(
[
torch.repeat_interleave(attention_mask.unsqueeze(1), self.num_generations, dim=1).to(device),
torch.ones(
(input_ids.size(0), self.num_generations, num_new_tokens),
dtype=attention_mask.dtype,
).to(device),
],
dim=-1,
),
"action_log_probs": torch.zeros(
(input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.float32
).to(device),
"action_mask": torch.ones((input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.bool).to(
device
),
"response_idx": torch.tensor(
[[[input_ids.size(-1), input_ids.size(-1) + num_new_tokens]] * self.num_generations] * input_ids.size(0)
)
.to(device)
.to(torch.int),
}
if "gt_answer" in kwargs:
rollouts["gt_answer"] = kwargs["gt_answer"]
if "test_cases" in kwargs:
rollouts["test_cases"] = kwargs["test_cases"]
return rollouts
def __del__(self):
if self.producer_idx == 0:
self.wandb_run.finish()
if hasattr(self, "rollout_log_file"):
self.rollout_log_file.close()
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
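The dummy rollout above skips real generation and pads every prompt with placeholder tokens up to the full generation budget, so the consumer side can be profiled against worst-case sequence lengths. A quick shape illustration under assumed sizes (all numbers are illustrative, not taken from the commit):

import torch

B, G, P, N = 2, 8, 512, 16384 - 512  # batch, generations per prompt, prompt length, new tokens
input_ids = torch.randint(0, 1000, (B, P))
rollout_ids = torch.cat(
    [
        torch.repeat_interleave(input_ids.unsqueeze(1), G, dim=1),  # (B, G, P) repeated prompts
        torch.ones(B, G, N, dtype=input_ids.dtype),                 # (B, G, N) fake generated tokens
    ],
    dim=-1,
)
print(rollout_ids.shape)  # torch.Size([2, 8, 16384])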
@ray.remote # (runtime_env={ "nsight": "default"})
class SimpleProducer(BaseProducer):
def __init__(

View File

@ -152,16 +152,21 @@ class CustomProfiler:
self.pid = os.getpid()
self.file = open(f"{name}.prof", "w")
def log(self, message):
def _log(self, message):
current_time = time.time()
self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
self.file.flush()
def log(self, message):
current_time = time.time()
self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n")
self.file.flush()
def enter(self, event_name):
self.log(f"Enter {event_name}")
self._log(f"Enter {event_name}")
def exit(self, event_name):
self.log(f"Exit {event_name}")
self._log(f"Exit {event_name}")
def close(self):
self.file.close()
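With this split, timeline events (Enter/Exit) go through _log and are what the plotting script below parses, while free-form messages from log are prefixed with "[Log]:" and skipped during parsing. A minimal usage sketch (the profiler name "C0" is an illustrative placeholder):

profiler = CustomProfiler("C0")        # writes to C0.prof
profiler.enter("step")                 # "<timestamp> C0 <pid>:: Enter step"
# ... run one training step ...
profiler.exit("step")                  # "<timestamp> C0 <pid>:: Exit step"
profiler.log("peak_memory: 42.00MB")   # "[Log]: <timestamp> C0 <pid>:: peak_memory: 42.00MB"
profiler.close()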

View File

@ -0,0 +1,86 @@
# Parse profiling (.prof) logs from all actors and plot a per-actor timeline of Enter/Exit events.
import argparse
import glob
from collections import defaultdict
from datetime import datetime
import matplotlib.cm as cm
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
# Argument parser for command line arguments
parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
args = parser.parse_args()
# Collect raw log lines from every .prof file in the current directory
log_lines = []
files = glob.glob("*.prof")
for file in files:
with open(file, "r") as f:
log_lines += f.readlines()
# Parse logs and collect function intervals grouped by actor
actors = defaultdict(lambda: defaultdict(list))
current_entries = {}
for line in log_lines:
if line.startswith("[Log]"):
continue
parts = line.split()
timestamp = float(parts[0])
actor = parts[1]
action = parts[3]
func_name = parts[4]
key = (actor, func_name)
if action == "Enter":
current_entries[key] = timestamp
elif action == "Exit":
start_time = current_entries.pop(key, None)
if start_time is not None:
actors[actor][func_name].append((start_time, timestamp))
# Plotting setup
fig, ax = plt.subplots(figsize=(12, 6))
colors = cm.get_cmap("tab10", len(actors))
actor_offsets = {}
base_offset = 0
function_spacing = 0.9
yticks = []
yticklabels = []
for idx, (actor, func_dict) in enumerate(actors.items()):
actor_offsets[actor] = base_offset
color = colors(idx)
for j, (func, intervals) in enumerate(func_dict.items()):
y_val = base_offset + j * function_spacing
yticks.append(y_val)
yticklabels.append(f"{actor}:{func}")
for start, end in intervals:
ax.plot(
[datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
[y_val, y_val],
color=color,
linewidth=4,
label=actor if j == 0 else "",
)
base_offset += len(func_dict) * function_spacing + 1
# Formatting
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_xlabel("Time")
ax.set_title("Timeline per Actor")
# Remove duplicate labels in legend
handles, labels = ax.get_legend_handles_labels()
unique = dict(zip(labels, handles))
ax.legend(unique.values(), unique.keys())
plt.tight_layout()
plt.grid(True)
plt.savefig(args.visualization)
print(f"Plot saved as {args.visualization}")

View File

@ -1,82 +0,0 @@
from collections import defaultdict
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
def parse_logs(log_file_path):
logs_by_actor = defaultdict(list)
with open(log_file_path, "r") as f:
for line in f:
parts = line.strip().split()
if len(parts) < 5:
continue
timestamp = float(parts[0])
actor = parts[1]
event = parts[3]
function_name = " ".join(parts[4:])
logs_by_actor[actor].append((timestamp, event, function_name))
return logs_by_actor
def build_intervals(logs_by_actor):
actor_intervals = defaultdict(list)
for actor, events in logs_by_actor.items():
func_stack = {}
for timestamp, event, func in events:
if event == "Enter":
func_stack[func] = timestamp
elif event == "Exit" and func in func_stack:
start = func_stack.pop(func)
actor_intervals[actor].append((func, start, timestamp))
return actor_intervals
def plot_actor_timelines(actor_intervals):
fig, ax = plt.subplots(figsize=(12, 6))
ytick_labels = []
yticks = []
color_map = plt.get_cmap("tab10")
color_lookup = {}
y = 0
for idx, (actor, intervals) in enumerate(sorted(actor_intervals.items())):
color_lookup[actor] = color_map(idx % 10)
for func, start, end in intervals:
ax.barh(y, end - start, left=start, height=0.3, color=color_lookup[actor], label=actor)
ax.text(start, y + 0.1, func, fontsize=8, color="black")
yticks.append(y)
ytick_labels.append(actor)
y += 1
ax.set_yticks(yticks)
ax.set_yticklabels(ytick_labels)
ax.set_xlabel("Unix Timestamp")
ax.set_title("Ray Actor Function Timeline")
ax.grid(True, axis="x", linestyle="--", alpha=0.6)
# Unique legend
handles = [mpatches.Patch(color=color_lookup[a], label=a) for a in color_lookup]
ax.legend(handles=handles, title="Actors", loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
# ==== Usage ====
# Replace with your actual log file path
import glob
files = glob.glob("*.prof")
logs = {}
for file in files:
print(f"Processing file: {file}")
logs_by_actor = parse_logs(file)
logs.update(logs_by_actor)
actor_intervals = build_intervals(logs)
plot_actor_timelines(actor_intervals)

View File

@ -0,0 +1,28 @@
# profile under different setups
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
# zero2 ibs32 tbs32
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.png
# # zero2 ibs64 tbs32
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.png
# # zero2 ibs96 tbs32
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 24 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.png
# 4K
# MAX_NEW_TOKENS=$((4096-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mpt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
# # 32K
# MAX_NEW_TOKENS=$((32768-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 4 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png
# 16K
MAX_NEW_TOKENS=$((16384-512))
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png

View File

@ -93,7 +93,7 @@ if __name__ == "__main__":
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "GRPO-DUMMY-TEST"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
@ -172,7 +172,7 @@ if __name__ == "__main__":
namespace="ray-example",
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false"
},
},
@ -243,7 +243,7 @@ if __name__ == "__main__":
else:
raise ValueError(f"Unsupported backend: {args.backend}")
if args.algo == "GRPO":
if "GRPO" in args.algo:
# Default Settings
grpo_config = {
"lr": args.learning_rate,
@ -318,6 +318,8 @@ if __name__ == "__main__":
plugin_config={
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
# "num_layers_per_stage": [12,12,2,2],
"num_layers_per_stage": [20, 8],
"microbatch_size": max(
1, args.train_microbatch_size // args.pipeline_parallel_size
), # microbatch size should be set to train_microbatch_size // pp_size

View File

@ -0,0 +1,4 @@
MAX_NEW_TOKENS=$((4096-512))
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.png