add profiling

YeAnbang 2025-06-13 18:00:31 +08:00
parent 43a0e99ae1
commit ff1689b69a
11 changed files with 325 additions and 104 deletions

.gitignore

@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
applications/ColossalChat/*.db
applications/ColossalChat/stdin
applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png


@@ -1,4 +1,5 @@
import os
+import time
from contextlib import nullcontext
from typing import Any, Dict, Optional
@@ -79,6 +80,8 @@ class BaseConsumer:
self.tp_size = dist.get_world_size(self.plugin.tp_group)
self.pp_size = dist.get_world_size(self.plugin.pp_group)
+torch.cuda.reset_peak_memory_stats()
# Init Hybrid ray process group
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
@@ -106,6 +109,8 @@ class BaseConsumer:
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
+start_time = time.time()
+total_step = 0
for episode in range(self.num_episodes):
with tqdm(
range(self.num_update_per_episode),
@@ -197,6 +202,7 @@ class BaseConsumer:
batch = post_recv(batch)
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+total_step += 1
self.profiler.exit("step")
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
@@ -252,6 +258,8 @@ class BaseConsumer:
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")
+print(f"[T{self.rank}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+print(f"Average running time per step: {(time.time() - start_time) / total_step:.2f} seconds")

def __del__(self):
if hasattr(self, "profiler"):
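The consumer-side additions above boil down to a simple measurement pattern: reset the CUDA peak-memory counter once at startup, count update steps, then report torch.cuda.max_memory_allocated() and the average wall-clock time per step at the end of training. A minimal standalone sketch of that pattern (the dummy matmul workload and sizes are illustrative, and a CUDA device is assumed):

import time

import torch

torch.cuda.reset_peak_memory_stats()  # open a fresh peak-memory window
start_time = time.time()
total_step = 0

for _ in range(10):  # stand-in for the training update loop
    x = torch.randn(4096, 4096, device="cuda")
    (x @ x).sum().item()  # some GPU work; .item() also forces a sync
    total_step += 1

print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
print(f"Average running time per step: {(time.time() - start_time) / total_step:.2f} seconds")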


@@ -3,13 +3,13 @@ from typing import Any, Optional
import ray
import torch
+import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

-import wandb
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
@@ -258,9 +258,11 @@ class GRPOConsumer(BaseConsumer):
]
if self.plugin.pp_size > 1:
+# torch.cuda.empty_cache()
# Support training with PP.
if self.policy_loss_fn.beta > 0:
with torch.no_grad():
+torch.cuda.reset_peak_memory_stats()
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
@@ -278,14 +280,32 @@
return_loss=False,
return_outputs=True,
)
+self.profiler.log(
+f"reference_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+)
if self.booster.plugin.stage_manager.is_last_stage():
-reference_model_logits = reference_model_outputs["outputs"]["logits"]
-reference_action_log_probs = calc_action_log_probs(
-reference_model_logits / self.generate_config["temperature"],
-input_ids_forward_micro_batch,
-num_action,
-self.plugin.shard_config,
-)
+# breakpoint()
+torch.cuda.reset_peak_memory_stats()
+reference_action_log_probs = torch.zeros(
+(input_ids_forward_micro_batch.size(0), num_action),
+device=input_ids_forward_micro_batch.device,
+)
+for i in range(reference_action_log_probs.size(0)):
+# activation for log_softmax is too large if vocab size and sequence length are large
+# e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4, for example),
+# this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB
+reference_action_log_probs[i, :] += calc_action_log_probs(
+reference_model_outputs["outputs"]["logits"][i : i + 1]
+/ self.generate_config["temperature"],
+input_ids_forward_micro_batch[i : i + 1],
+num_action,
+self.plugin.shard_config,
+)[0]
+# breakpoint()
+# torch.cuda.empty_cache()
+self.profiler.log(
+f"reference_action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+)
else:
# Dummy reference logprobs for data iterator.
@@ -308,12 +328,26 @@
def _criterion(outputs, inputs):
action_logits = outputs.logits
-action_log_probs = calc_action_log_probs(
-action_logits / self.generate_config["temperature"],
-inputs["input_ids"],
-num_action,
-self.plugin.shard_config,
-)
+action_log_probs = torch.zeros(
+(inputs["input_ids"].size(0), num_action), device=action_logits.device
+)
+# breakpoint()
+torch.cuda.reset_peak_memory_stats()
+for i in range(action_log_probs.size(0)):
+# activation for log_softmax is too large if vocab size and sequence length are large
+# e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4, for example),
+# this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB
+action_log_probs[i, :] += calc_action_log_probs(
+action_logits[i : i + 1] / self.generate_config["temperature"],
+inputs["input_ids"][i : i + 1],
+num_action,
+self.plugin.shard_config,
+)[0]
+# torch.cuda.empty_cache()
+self.profiler.log(
+f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+)
+# breakpoint()
if "reference_action_log_probs" in inputs:
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -347,6 +381,9 @@
return_loss=True,
return_outputs=False,
)
+self.profiler.log(
+f"policy_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+)
loss = policy_model_outputs["loss"]

if self.booster.plugin.stage_manager.is_last_stage():
@@ -373,6 +410,8 @@
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
+# torch.cuda.empty_cache()
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
@@ -422,6 +461,9 @@
self.accum_advantages.add_(advantages.data)
self.accum_count += 1
if need_update:
+# breakpoint()
+# torch.cuda.empty_cache()
+torch.cuda.reset_peak_memory_stats()
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
@@ -429,6 +471,9 @@
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
self.effective_prompt_count = 0
self.effective_sample_count = 0
+self.profiler.log(
+f"optimizer_step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+)
loss_scalar = self.accum_loss.item()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
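Both per-sample loops above exist to cap the log_softmax activation: with a 152064-token vocabulary, a 32K sequence and fp32 logits, a single sample's activation is already around 18 GB, so materializing it for a micro batch of 4 costs the quoted ~72.5 GB. A minimal standalone sketch of the same chunking idea follows; the helper is illustrative and is not the repo's calc_action_log_probs (which additionally handles tensor-parallel sharded logits via shard_config). It assumes logits of shape [batch, seq_len, vocab] and that the last num_action positions are the generated tokens:

import torch
import torch.nn.functional as F

def chunked_action_log_probs(logits, input_ids, num_action, temperature=1.0):
    # Process one sample at a time so the log_softmax activation stays at
    # [1, seq_len, vocab] instead of [micro_batch, seq_len, vocab].
    out = torch.zeros(input_ids.size(0), num_action, device=logits.device)
    for i in range(input_ids.size(0)):
        log_probs = F.log_softmax(logits[i : i + 1, :-1] / temperature, dim=-1)
        labels = input_ids[i : i + 1, 1:].unsqueeze(-1)  # next-token targets
        token_log_probs = torch.gather(log_probs, dim=-1, index=labels).squeeze(-1)
        out[i] = token_log_probs[0, -num_action:]  # keep only the generated (action) positions
    return out

At micro batch size 4 this trades a few extra kernel launches for roughly a 4x reduction in the transient activation.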


@@ -7,9 +7,15 @@ import ray
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
-from .producer import SimpleProducer
+from .producer import DummyProducer, SimpleProducer

-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer, "GRPO-DUMMY-TEST": GRPOConsumer}
+PRODUCER_MAP = {
+"Simple": SimpleProducer,
+"GRPO": SimpleProducer,
+"DAPO": SimpleProducer,
+"GRPO-DUMMY-TEST": DummyProducer,
+}

def get_jsonl_size_fast(path: str) -> int:
@@ -62,6 +68,7 @@ def launch_distributed(
raise NotImplementedError(f"{core_algo} is not supported yet.")
else:
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
+core_producer = PRODUCER_MAP.get(core_algo, SimpleProducer)

train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
@@ -81,7 +88,7 @@
procs = []
for i in range(num_producers):
-producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+producer = core_producer.options(num_gpus=num_proc_per_producer).remote(
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,


@@ -7,6 +7,7 @@ import ray
import ray.util.collective as cc
import torch
import tqdm
+import wandb
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
@@ -15,7 +16,6 @@ from ray.util.collective.types import Backend, ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

-import wandb
from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
@@ -76,6 +76,7 @@ class BaseProducer:
self.latest_rollout_log_step = -1
self.grpo_config = grpo_config
self.profiler = CustomProfiler(f"P{self.producer_idx}")
+torch.cuda.reset_peak_memory_stats()
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
@@ -372,11 +373,126 @@
self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.9
+print(f"[P{self.producer_idx}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")

def __del__(self):
self.profiler.close()
@ray.remote # (runtime_env={ "nsight": "default"})
class DummyProducer(BaseProducer):
def __init__(
self,
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
):
super().__init__(
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
tokenizer_config,
microbatch_size,
backend,
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
)
self.num_generations = num_generations
self.max_length = generate_config.get("max_tokens", generate_config.get("max_length", 512))
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
# generate dummy rollouts
device = get_current_device()
# self.profiler.log(f"{input_ids.size()}, {attention_mask.size()}, {self.max_length}")
num_new_tokens = self.max_length
rollouts = {
"input_ids": torch.cat(
[
torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device),
torch.ones(
(input_ids.size(0), self.num_generations, num_new_tokens),
dtype=input_ids.dtype,
).to(device),
],
dim=-1,
).to(device),
"attention_mask": torch.cat(
[
torch.repeat_interleave(attention_mask.unsqueeze(1), self.num_generations, dim=1).to(device),
torch.ones(
(input_ids.size(0), self.num_generations, num_new_tokens),
dtype=attention_mask.dtype,
).to(device),
],
dim=-1,
),
"action_log_probs": torch.zeros(
(input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.float32
).to(device),
"action_mask": torch.ones((input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.bool).to(
device
),
"response_idx": torch.tensor(
[[[input_ids.size(-1), input_ids.size(-1) + num_new_tokens]] * self.num_generations] * input_ids.size(0)
)
.to(device)
.to(torch.int),
}
if "gt_answer" in kwargs:
rollouts["gt_answer"] = kwargs["gt_answer"]
if "test_cases" in kwargs:
rollouts["test_cases"] = kwargs["test_cases"]
return rollouts
def __del__(self):
if self.producer_idx == 0:
self.wandb_run.finish()
if hasattr(self, "rollout_log_file"):
self.rollout_log_file.close()
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
@ray.remote # (runtime_env={ "nsight": "default"})
class SimpleProducer(BaseProducer):
def __init__(
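The new DummyProducer above only has to reproduce the tensor layout a real rollout would return: every field is [batch, num_generations, ...], with the prompt repeated per generation and max_tokens worth of padding standing in for the completion. A small standalone check of that shape contract (all sizes below are made up):

import torch

B, G, P, N = 2, 8, 16, 32  # batch, generations per prompt, prompt length, new tokens
input_ids = torch.randint(0, 1000, (B, P))

prompt = torch.repeat_interleave(input_ids.unsqueeze(1), G, dim=1)  # [B, G, P]
fake_completion = torch.ones(B, G, N, dtype=input_ids.dtype)
rollout_ids = torch.cat([prompt, fake_completion], dim=-1)

assert rollout_ids.shape == (B, G, P + N)
response_idx = torch.tensor([[[P, P + N]] * G] * B)  # start/end index of each response span
assert response_idx.shape == (B, G, 2)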


@@ -152,16 +152,21 @@ class CustomProfiler:
self.pid = os.getpid()
self.file = open(f"{name}.prof", "w")

-def log(self, message):
+def _log(self, message):
current_time = time.time()
self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
self.file.flush()

+def log(self, message):
+current_time = time.time()
+self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n")
+self.file.flush()

def enter(self, event_name):
-self.log(f"Enter {event_name}")
+self._log(f"Enter {event_name}")

def exit(self, event_name):
-self.log(f"Exit {event_name}")
+self._log(f"Exit {event_name}")

def close(self):
self.file.close()
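For reference, splitting _log from log matters downstream: the timeline parser in the new script below treats unprefixed lines as Enter/Exit events and skips anything starting with "[Log]:". The two kinds of records in a *.prof file would look roughly like this (timestamps, actor names and pids are illustrative):

1749808839.102 C0 41231:: Enter step
1749808843.877 C0 41231:: Exit step
[Log]: 1749808843.880 C0 41231:: optimizer_step_3: peak_memory: 41253.20MB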


@@ -0,0 +1,86 @@
# Parse profiling logs from *.prof files and generate a per-actor timeline plot.
import argparse
import glob
from collections import defaultdict
from datetime import datetime

import matplotlib.cm as cm
import matplotlib.dates as mdates
import matplotlib.pyplot as plt

# Argument parser for command line arguments
parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
args = parser.parse_args()

# Collect raw log lines from every profiler output file in the working directory
log_lines = []
files = glob.glob("*.prof")
for file in files:
    with open(file, "r") as f:
        log_lines += f.readlines()
# Parse logs and collect function intervals grouped by actor
actors = defaultdict(lambda: defaultdict(list))
current_entries = {}

for line in log_lines:
    if line.startswith("[Log]"):
        continue
    parts = line.split()
    timestamp = float(parts[0])
    actor = parts[1]
    action = parts[3]
    func_name = parts[4]
    key = (actor, func_name)
    if action == "Enter":
        current_entries[key] = timestamp
    elif action == "Exit":
        start_time = current_entries.pop(key, None)
        if start_time is not None:
            actors[actor][func_name].append((start_time, timestamp))

# Plotting setup
fig, ax = plt.subplots(figsize=(12, 6))
colors = cm.get_cmap("tab10", len(actors))
actor_offsets = {}
base_offset = 0
function_spacing = 0.9
yticks = []
yticklabels = []

for idx, (actor, func_dict) in enumerate(actors.items()):
    actor_offsets[actor] = base_offset
    color = colors(idx)
    for j, (func, intervals) in enumerate(func_dict.items()):
        y_val = base_offset + j * function_spacing
        yticks.append(y_val)
        yticklabels.append(f"{actor}:{func}")
        for start, end in intervals:
            ax.plot(
                [datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
                [y_val, y_val],
                color=color,
                linewidth=4,
                label=actor if j == 0 else "",
            )
    base_offset += len(func_dict) * function_spacing + 1

# Formatting
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_xlabel("Time")
ax.set_title("Timeline per Actor")

# Remove duplicate labels in legend
handles, labels = ax.get_legend_handles_labels()
unique = dict(zip(labels, handles))
ax.legend(unique.values(), unique.keys())

plt.tight_layout()
plt.grid(True)
plt.savefig(args.visualization)
print(f"Plot saved as {args.visualization}")


@@ -1,82 +0,0 @@
from collections import defaultdict
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
def parse_logs(log_file_path):
logs_by_actor = defaultdict(list)
with open(log_file_path, "r") as f:
for line in f:
parts = line.strip().split()
if len(parts) < 5:
continue
timestamp = float(parts[0])
actor = parts[1]
event = parts[3]
function_name = " ".join(parts[4:])
logs_by_actor[actor].append((timestamp, event, function_name))
return logs_by_actor
def build_intervals(logs_by_actor):
actor_intervals = defaultdict(list)
for actor, events in logs_by_actor.items():
func_stack = {}
for timestamp, event, func in events:
(actor, func)
if event == "Enter":
func_stack[func] = timestamp
elif event == "Exit" and func in func_stack:
start = func_stack.pop(func)
actor_intervals[actor].append((func, start, timestamp))
return actor_intervals
def plot_actor_timelines(actor_intervals):
fig, ax = plt.subplots(figsize=(12, 6))
ytick_labels = []
yticks = []
color_map = plt.get_cmap("tab10")
color_lookup = {}
y = 0
for idx, (actor, intervals) in enumerate(sorted(actor_intervals.items())):
color_lookup[actor] = color_map(idx % 10)
for func, start, end in intervals:
ax.barh(y, end - start, left=start, height=0.3, color=color_lookup[actor], label=actor)
ax.text(start, y + 0.1, func, fontsize=8, color="black")
yticks.append(y)
ytick_labels.append(actor)
y += 1
ax.set_yticks(yticks)
ax.set_yticklabels(ytick_labels)
ax.set_xlabel("Unix Timestamp")
ax.set_title("Ray Actor Function Timeline")
ax.grid(True, axis="x", linestyle="--", alpha=0.6)
# Unique legend
handles = [mpatches.Patch(color=color_lookup[a], label=a) for a in color_lookup]
ax.legend(handles=handles, title="Actors", loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
# ==== Usage ====
# Replace with your actual log file path
import glob
files = glob.glob("*.prof")
logs = {}
for file in files:
print(f"Processing file: {file}")
logs_by_actor = parse_logs(log_file_path)
logs.update(logs_by_actor)
actor_intervals = build_intervals(logs)
plot_actor_timelines(actor_intervals)


@@ -0,0 +1,28 @@
# profile under different setups
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
# zero2 ibs32 tbs32
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.png
# # zero2 ibs64 tbs32
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.png
# # zero2 ibs96 tbs32
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 24 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.png
# 4K
# MAX_NEW_TOKENS=$((4096-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mpt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
# # 32K
# MAX_NEW_TOKENS=$((32768-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 4 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png
# 16K
MAX_NEW_TOKENS=$((16384-512))
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png


@@ -93,7 +93,7 @@ if __name__ == "__main__":
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")

# GRPO parameters
-parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "GRPO-DUMMY-TEST"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
@@ -172,7 +172,7 @@
namespace="ray-example",
runtime_env={
"env_vars": {
-# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
+# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false"
},
},
@@ -243,7 +243,7 @@ if __name__ == "__main__":
else:
raise ValueError(f"Unsupported backend: {args.backend}")

-if args.algo == "GRPO":
+if "GRPO" in args.algo:
# Default Settings
grpo_config = {
"lr": args.learning_rate,
@@ -318,6 +318,8 @@
plugin_config={
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
+# "num_layers_per_stage": [12,12,2,2],
+"num_layers_per_stage": [20, 8],
"microbatch_size": max(
1, args.train_microbatch_size // args.pipeline_parallel_size
), # microbatch size should be set to train_microbatch_size // pp_size
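The hardcoded "num_layers_per_stage": [20, 8] (and the commented 4-stage variant) skews the pipeline split so the last stage, which also produces the full-vocabulary logits, holds fewer transformer layers. A hedged illustration of the intent, assuming a 28-layer model such as Qwen2.5-Math-7B and pp_size=2; the exact behaviour depends on the model and the ColossalAI plugin version:

# Hypothetical sketch: split a 28-layer model unevenly across 2 pipeline stages.
# A uniform split would be [14, 14]; giving the last stage only 8 layers leaves
# headroom for the lm_head / logits activations that live on that stage.
plugin_config = {
    "pp_size": 2,
    "num_layers_per_stage": [20, 8],  # assumed to sum to the model's transformer layer count
    "microbatch_size": 2,  # train_microbatch_size // pp_size
}
assert sum(plugin_config["num_layers_per_stage"]) == 28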


@@ -0,0 +1,4 @@
MAX_NEW_TOKENS=$((4096-512))
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.png