add profiling
commit ff1689b69a (parent 43a0e99ae1)
mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-23 11:44:15 +00:00
.gitignore (vendored): 2 changes
@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
applications/ColossalChat/*.db
applications/ColossalChat/stdin
applications/ColossalChat/*.zip
applications/ColossalChat/*.prof
applications/ColossalChat/*.png
@@ -1,4 +1,5 @@
import os
import time
from contextlib import nullcontext
from typing import Any, Dict, Optional

@@ -79,6 +80,8 @@ class BaseConsumer:
self.tp_size = dist.get_world_size(self.plugin.tp_group)
self.pp_size = dist.get_world_size(self.plugin.pp_group)

torch.cuda.reset_peak_memory_stats()

# Init Hybrid ray process group
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
@@ -106,6 +109,8 @@ class BaseConsumer:
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
start_time = time.time()
total_step = 0
for episode in range(self.num_episodes):
with tqdm(
range(self.num_update_per_episode),
@@ -197,6 +202,7 @@ class BaseConsumer:
batch = post_recv(batch)
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
total_step += 1
self.profiler.exit("step")
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
@@ -252,6 +258,8 @@ class BaseConsumer:
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")
print(f"[T{self.rank}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
print(f"Average running time per step: {(time.time() - start_time) / total_step:.2f} seconds")

def __del__(self):
if hasattr(self, "profiler"):
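The hunks above wrap each consumer training step in profiler.enter("step") / profiler.exit("step") and print an average per-step time at shutdown. A minimal sketch of that usage pattern (illustrative only, not part of the commit; the import path is assumed):

    from coati.distributed.profiling_utils import CustomProfiler  # assumed location of the profiler

    profiler = CustomProfiler("C0")  # records go to C0.prof
    profiler.enter("step")           # writes "Enter step" with a timestamp
    # ... run one training step ...
    profiler.exit("step")            # writes "Exit step"; the Enter/Exit pair becomes one timeline interval
    profiler.close()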
@@ -3,13 +3,13 @@ from typing import Any, Optional

import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

import wandb
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam

@@ -258,9 +258,11 @@ class GRPOConsumer(BaseConsumer):
]

if self.plugin.pp_size > 1:
# torch.cuda.empty_cache()
# Support training with PP.
if self.policy_loss_fn.beta > 0:
with torch.no_grad():
torch.cuda.reset_peak_memory_stats()
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
@@ -278,14 +280,32 @@ class GRPOConsumer(BaseConsumer):
return_loss=False,
return_outputs=True,
)
self.profiler.log(
f"reference_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)

if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
# breakpoint()
torch.cuda.reset_peak_memory_stats()
reference_action_log_probs = torch.zeros(
(input_ids_forward_micro_batch.size(0), num_action),
device=input_ids_forward_micro_batch.device,
)
for i in range(reference_action_log_probs.size(0)):
# activation for log_softmax is too large if vocab size and sequence length are large
# e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
# this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
reference_action_log_probs[i, :] += calc_action_log_probs(
reference_model_outputs["outputs"]["logits"][i : i + 1]
/ self.generate_config["temperature"],
input_ids_forward_micro_batch[i : i + 1],
num_action,
self.plugin.shard_config,
)[0]
# breakpoint()
# torch.cuda.empty_cache()
self.profiler.log(
f"reference_action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
else:
# Dummy reference logprobs for data iterator.
@@ -308,12 +328,26 @@ class GRPOConsumer(BaseConsumer):

def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = calc_action_log_probs(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
num_action,
self.plugin.shard_config,
action_log_probs = torch.zeros(
(inputs["input_ids"].size(0), num_action), device=action_logits.device
)
# breakpoint()
torch.cuda.reset_peak_memory_stats()
for i in range(action_log_probs.size(0)):
# activation for log_softmax is too large if vocab size and sequence length are large
# e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
# this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
action_log_probs[i, :] += calc_action_log_probs(
action_logits[i : i + 1] / self.generate_config["temperature"],
inputs["input_ids"][i : i + 1],
num_action,
self.plugin.shard_config,
)[0]
# torch.cuda.empty_cache()
self.profiler.log(
f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
# breakpoint()
if "reference_action_log_probs" in inputs:
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -347,6 +381,9 @@ class GRPOConsumer(BaseConsumer):
return_loss=True,
return_outputs=False,
)
self.profiler.log(
f"policy_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
loss = policy_model_outputs["loss"]

if self.booster.plugin.stage_manager.is_last_stage():
@@ -373,6 +410,8 @@ class GRPOConsumer(BaseConsumer):
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits

# torch.cuda.empty_cache()
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
@@ -422,6 +461,9 @@ class GRPOConsumer(BaseConsumer):
self.accum_advantages.add_(advantages.data)
self.accum_count += 1
if need_update:
# breakpoint()
# torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
@@ -429,6 +471,9 @@ class GRPOConsumer(BaseConsumer):
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
self.effective_prompt_count = 0
self.effective_sample_count = 0
self.profiler.log(
f"optimizer_step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
loss_scalar = self.accum_loss.item()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
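The per-sample loops above replace one calc_action_log_probs call over the whole micro-batch with one call per sample, so the log_softmax activation peak shrinks by roughly the micro-batch size. A rough check of the figure quoted in the comments (illustrative sketch using the same numbers):

    vocab, seq_len, micro_bs, fp32_bytes = 152064, 32000, 4, 4
    full_batch_gib = vocab * seq_len * micro_bs * fp32_bytes / 1024**3  # logits materialized at once for log_softmax
    per_sample_gib = full_batch_gib / micro_bs                          # processing one sample per iteration
    print(f"{full_batch_gib:.1f} GiB -> {per_sample_gib:.1f} GiB")      # ~72.5 GiB -> ~18.1 GiB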
@@ -7,9 +7,15 @@ import ray

from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer
from .producer import DummyProducer, SimpleProducer

ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer, "GRPO-DUMMY-TEST": GRPOConsumer}
PRODUCER_MAP = {
"Simple": SimpleProducer,
"GRPO": SimpleProducer,
"DAPO": SimpleProducer,
"GRPO-DUMMY-TEST": DummyProducer,
}


def get_jsonl_size_fast(path: str) -> int:
@@ -62,6 +68,7 @@ def launch_distributed(
raise NotImplementedError(f"{core_algo} is not supported yet.")
else:
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
core_producer = PRODUCER_MAP.get(core_algo, SimpleProducer)

train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
@@ -81,7 +88,7 @@ def launch_distributed(

procs = []
for i in range(num_producers):
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
producer = core_producer.options(num_gpus=num_proc_per_producer).remote(
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,
@@ -7,6 +7,7 @@ import ray
import ray.util.collective as cc
import torch
import tqdm
import wandb
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
@@ -15,7 +16,6 @@ from ray.util.collective.types import Backend, ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

import wandb
from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
@@ -76,6 +76,7 @@ class BaseProducer:
self.latest_rollout_log_step = -1
self.grpo_config = grpo_config
self.profiler = CustomProfiler(f"P{self.producer_idx}")
torch.cuda.reset_peak_memory_stats()
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
@@ -372,11 +373,126 @@ class BaseProducer:
self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.9
print(f"[P{self.producer_idx}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")

def __del__(self):
self.profiler.close()


@ray.remote # (runtime_env={ "nsight": "default"})
class DummyProducer(BaseProducer):
def __init__(
self,
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
):
super().__init__(
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
tokenizer_config,
microbatch_size,
backend,
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
)
self.num_generations = num_generations
self.max_length = generate_config.get("max_tokens", generate_config.get("max_length", 512))
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)
self.eval_sample_params = SamplingParams(**self.eval_generation_config)

@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
# generate dummy rollouts
device = get_current_device()
# self.profiler.log(f"{input_ids.size()}, {attention_mask.size()}, {self.max_length}")
num_new_tokens = self.max_length
rollouts = {
"input_ids": torch.cat(
[
torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device),
torch.ones(
(input_ids.size(0), self.num_generations, num_new_tokens),
dtype=input_ids.dtype,
).to(device),
],
dim=-1,
).to(device),
"attention_mask": torch.cat(
[
torch.repeat_interleave(attention_mask.unsqueeze(1), self.num_generations, dim=1).to(device),
torch.ones(
(input_ids.size(0), self.num_generations, num_new_tokens),
dtype=attention_mask.dtype,
).to(device),
],
dim=-1,
),
"action_log_probs": torch.zeros(
(input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.float32
).to(device),
"action_mask": torch.ones((input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.bool).to(
device
),
"response_idx": torch.tensor(
[[[input_ids.size(-1), input_ids.size(-1) + num_new_tokens]] * self.num_generations] * input_ids.size(0)
)
.to(device)
.to(torch.int),
}
if "gt_answer" in kwargs:
rollouts["gt_answer"] = kwargs["gt_answer"]
if "test_cases" in kwargs:
rollouts["test_cases"] = kwargs["test_cases"]
return rollouts

def __del__(self):
if self.producer_idx == 0:
self.wandb_run.finish()
if hasattr(self, "rollout_log_file"):
self.rollout_log_file.close()

def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)


@ray.remote # (runtime_env={ "nsight": "default"})
class SimpleProducer(BaseProducer):
def __init__(
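DummyProducer skips real generation and returns fixed-shape tensors, so producer/consumer scheduling can be profiled without paying for inference. A shape sanity check for the rollout above (illustrative only; assumes an already constructed DummyProducer-like object `producer` with num_generations=8 and max_length=512):

    import torch

    batch, prompt_len = 2, 16
    input_ids = torch.ones(batch, prompt_len, dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)
    out = producer.rollout(input_ids, attention_mask)
    assert out["input_ids"].shape == (batch, 8, prompt_len + 512)   # prompt repeated per generation + dummy response
    assert out["action_log_probs"].shape == (batch, 8, 512)
    assert out["response_idx"].shape == (batch, 8, 2)               # [start, end] indices of the dummy response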
@@ -152,16 +152,21 @@ class CustomProfiler:
self.pid = os.getpid()
self.file = open(f"{name}.prof", "w")

def log(self, message):
def _log(self, message):
current_time = time.time()
self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
self.file.flush()

def log(self, message):
current_time = time.time()
self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n")
self.file.flush()

def enter(self, event_name):
self.log(f"Enter {event_name}")
self._log(f"Enter {event_name}")

def exit(self, event_name):
self.log(f"Exit {event_name}")
self._log(f"Exit {event_name}")

def close(self):
self.file.close()
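After this change, Enter/Exit records (written through _log) and free-form log() messages carry distinct prefixes, which is what lets the new profile_grpo.py keep timeline events while skipping anything starting with "[Log]". A sketch of the two record shapes and how the parser splits them (timestamp and pid values are made up):

    line = "1748000000.123 C0 4242:: Enter step"         # produced by enter() via _log()
    skipped = "[Log]: 1748000000.456 C0 4242:: message"   # produced by log(); ignored by the parser
    parts = line.split()
    timestamp, actor, action, func_name = float(parts[0]), parts[1], parts[3], parts[4]
    assert (actor, action, func_name) == ("C0", "Enter", "step")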
applications/ColossalChat/profile_grpo.py (new file): 86 lines
@@ -0,0 +1,86 @@
# Re-import required libraries due to kernel reset
import argparse
from collections import defaultdict
from datetime import datetime

import matplotlib.cm as cm
import matplotlib.dates as mdates
import matplotlib.pyplot as plt

# Argument parser for command line arguments
parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
args = parser.parse_args()

# Raw log lines
log_lines = []

import glob

files = glob.glob("*.prof")
for file in files:
    with open(file, "r") as f:
        log_lines += f.readlines()

# Parse logs and collect function intervals grouped by actor
actors = defaultdict(lambda: defaultdict(list))
current_entries = {}

for line in log_lines:
    if line.startswith("[Log]"):
        continue
    parts = line.split()
    timestamp = float(parts[0])
    actor = parts[1]
    action = parts[3]
    func_name = parts[4]
    key = (actor, func_name)
    if action == "Enter":
        current_entries[key] = timestamp
    elif action == "Exit":
        start_time = current_entries.pop(key, None)
        if start_time is not None:
            actors[actor][func_name].append((start_time, timestamp))

# Plotting setup
fig, ax = plt.subplots(figsize=(12, 6))
colors = cm.get_cmap("tab10", len(actors))

actor_offsets = {}
base_offset = 0
function_spacing = 0.9

yticks = []
yticklabels = []

for idx, (actor, func_dict) in enumerate(actors.items()):
    actor_offsets[actor] = base_offset
    color = colors(idx)
    for j, (func, intervals) in enumerate(func_dict.items()):
        y_val = base_offset + j * function_spacing
        yticks.append(y_val)
        yticklabels.append(f"{actor}:{func}")
        for start, end in intervals:
            ax.plot(
                [datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
                [y_val, y_val],
                color=color,
                linewidth=4,
                label=actor if j == 0 else "",
            )
    base_offset += len(func_dict) * function_spacing + 1

# Formatting
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_xlabel("Time")
ax.set_title("Timeline per Actor")
# Remove duplicate labels in legend
handles, labels = ax.get_legend_handles_labels()
unique = dict(zip(labels, handles))
ax.legend(unique.values(), unique.keys())
plt.tight_layout()
plt.grid(True)
plt.savefig(args.visualization)
print(f"Plot saved as {args.visualization}")
@@ -1,82 +0,0 @@
from collections import defaultdict

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt


def parse_logs(log_file_path):
    logs_by_actor = defaultdict(list)

    with open(log_file_path, "r") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            timestamp = float(parts[0])
            actor = parts[1]
            event = parts[3]
            function_name = " ".join(parts[4:])
            logs_by_actor[actor].append((timestamp, event, function_name))

    return logs_by_actor


def build_intervals(logs_by_actor):
    actor_intervals = defaultdict(list)

    for actor, events in logs_by_actor.items():
        func_stack = {}
        for timestamp, event, func in events:
            (actor, func)
            if event == "Enter":
                func_stack[func] = timestamp
            elif event == "Exit" and func in func_stack:
                start = func_stack.pop(func)
                actor_intervals[actor].append((func, start, timestamp))

    return actor_intervals


def plot_actor_timelines(actor_intervals):
    fig, ax = plt.subplots(figsize=(12, 6))
    ytick_labels = []
    yticks = []
    color_map = plt.get_cmap("tab10")
    color_lookup = {}

    y = 0
    for idx, (actor, intervals) in enumerate(sorted(actor_intervals.items())):
        color_lookup[actor] = color_map(idx % 10)
        for func, start, end in intervals:
            ax.barh(y, end - start, left=start, height=0.3, color=color_lookup[actor], label=actor)
            ax.text(start, y + 0.1, func, fontsize=8, color="black")
        yticks.append(y)
        ytick_labels.append(actor)
        y += 1

    ax.set_yticks(yticks)
    ax.set_yticklabels(ytick_labels)
    ax.set_xlabel("Unix Timestamp")
    ax.set_title("Ray Actor Function Timeline")
    ax.grid(True, axis="x", linestyle="--", alpha=0.6)

    # Unique legend
    handles = [mpatches.Patch(color=color_lookup[a], label=a) for a in color_lookup]
    ax.legend(handles=handles, title="Actors", loc="upper left", bbox_to_anchor=(1, 1))

    plt.tight_layout()
    plt.show()


# ==== Usage ====
# Replace with your actual log file path
import glob

files = glob.glob("*.prof")
logs = {}
for file in files:
    print(f"Processing file: {file}")
    logs_by_actor = parse_logs(log_file_path)
    logs.update(logs_by_actor)
actor_intervals = build_intervals(logs)
plot_actor_timelines(actor_intervals)
applications/ColossalChat/profiling.sh (new executable file): 28 lines
@@ -0,0 +1,28 @@
# profile under different setups
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
# zero2 ibs32 tbs32
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.png

# # zero2 ibs64 tbs32
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.png

# # zero2 ibs96 tbs32
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 24 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.png

# 4K
# MAX_NEW_TOKENS=$((4096-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mpt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png

# # 32K
# MAX_NEW_TOKENS=$((32768-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 4 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png

# 16K
MAX_NEW_TOKENS=$((16384-512))
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
@@ -93,7 +93,7 @@ if __name__ == "__main__":
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")

# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "GRPO-DUMMY-TEST"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
@@ -172,7 +172,7 @@ if __name__ == "__main__":
namespace="ray-example",
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false"
},
},
@@ -243,7 +243,7 @@ if __name__ == "__main__":
else:
raise ValueError(f"Unsupported backend: {args.backend}")

if args.algo == "GRPO":
if "GRPO" in args.algo:
# Default Settings
grpo_config = {
"lr": args.learning_rate,
@@ -318,6 +318,8 @@ if __name__ == "__main__":
plugin_config={
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
# "num_layers_per_stage": [12,12,2,2],
"num_layers_per_stage": [20, 8],
"microbatch_size": max(
1, args.train_microbatch_size // args.pipeline_parallel_size
), # microbatch size should be set to train_microbatch_size // pp_size
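A worked example of the microbatch_size expression in the last hunk (assumed values matching the -tmbs 2 / -pp 2 runs in profiling.sh):

    train_microbatch_size, pipeline_parallel_size = 2, 2
    microbatch_size = max(1, train_microbatch_size // pipeline_parallel_size)  # each pipeline stage sees micro-batches of size 1
    assert microbatch_size == 1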
applications/ColossalChat/test_profiling.sh (new executable file): 4 lines
@@ -0,0 +1,4 @@
MAX_NEW_TOKENS=$((4096-512))
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.png