From 43a0e99ae1074a6b44ef0f268e9b33c4f7a0d23b Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 12 Jun 2025 15:03:44 +0800
Subject: [PATCH] add profiler

---
 .../coati/distributed/consumer.py             | 15 +++-
 .../coati/distributed/grpo_consumer.py        |  4 +-
 .../ColossalChat/coati/distributed/launch.py  |  2 +-
 .../coati/distributed/producer.py             | 18 +++-
 .../ColossalChat/coati/distributed/utils.py   | 23 ++++++
 applications/ColossalChat/profile_utils.py    | 82 +++++++++++++++++++
 applications/ColossalChat/rl_example.py       |  5 +-
 7 files changed, 139 insertions(+), 10 deletions(-)
 create mode 100644 applications/ColossalChat/profile_utils.py

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index a7abb1588..a1eac136a 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, post_recv, unbind_batch
+from .utils import CustomProfiler, bind_batch, post_recv, unbind_batch
 
 
 class BaseConsumer:
@@ -94,6 +94,7 @@ class BaseConsumer:
 
         self.buffer = []
         self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}")
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -117,9 +118,11 @@ class BaseConsumer:
                 # receive data from producers
                 for r in range(self.num_producers):
                     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                    self.profiler.enter(f"recv_broadcast_data_P{r}")
                     raw_batch = ray_broadcast_tensor_dict(
                         None, src=0, device=self.device, group_name=f"sync_data_{r}"
                     )
+                    self.profiler.exit(f"recv_broadcast_data_P{r}")
                     # calculate the group reward and related metrics before filtering. Since only the filtered group is used for training (and is therefore incomplete),
                     # the metrics needed for logging must be computed here, before filtering
                     # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...] 
@@ -192,7 +195,9 @@ class BaseConsumer: } batch = bind_batch([t[0] for t in batches]) batch = post_recv(batch) + self.profiler.enter("step") loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict) + self.profiler.exit("step") self.buffer = self.buffer[ effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 : ] @@ -228,6 +233,7 @@ class BaseConsumer: ) else: print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") + self.profiler.enter("sync_model") torch.cuda.empty_cache() state_dict = self.state_dict() if self.pp_size > 1: @@ -245,9 +251,14 @@ class BaseConsumer: ) del state_dict torch.cuda.empty_cache() + self.profiler.exit("sync_model") + + def __del__(self): + if hasattr(self, "profiler"): + self.profiler.close() -@ray.remote +@ray.remote # (runtime_env={ "nsight": "default"}) class SimpleConsumer(BaseConsumer): def __init__( self, diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 51cdcf322..2b846b669 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -3,18 +3,18 @@ from typing import Any, Optional import ray import torch -import wandb from coati.distributed.consumer import BaseConsumer from coati.distributed.loss import PolicyLoss from coati.distributed.utils import calc_action_log_probs from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer +import wandb from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -@ray.remote +@ray.remote # (runtime_env={ "nsight": "default"}) class GRPOConsumer(BaseConsumer): def __init__( self, diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 41ba8ea55..b666d083a 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -16,7 +16,7 @@ def get_jsonl_size_fast(path: str) -> int: with open(path) as f: lines = f.readlines() lines = [line for line in lines if line.strip()] - return len(lines) - 1 + return len(lines) def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int: diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 854c2fcc2..830ebcba6 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -7,7 +7,6 @@ import ray import ray.util.collective as cc import torch import tqdm -import wandb from coati.dataset.loader import RawConversationDataset, collate_fn_grpo from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward @@ -16,11 +15,12 @@ from ray.util.collective.types import Backend, ReduceOp from torch.utils.data import DataLoader, DistributedSampler from transformers import AutoTokenizer +import wandb from colossalai.utils import get_current_device from .comm import ray_broadcast_tensor_dict from .inference_backend import BACKEND_MAP -from .utils import pre_send, safe_append_to_jsonl_file +from .utils import CustomProfiler, pre_send, safe_append_to_jsonl_file try: from vllm import SamplingParams @@ -75,6 +75,7 @@ class BaseProducer: self.log_rollout_interval = 
log_rollout_interval
         self.latest_rollout_log_step = -1
         self.grpo_config = grpo_config
+        self.profiler = CustomProfiler(f"P{self.producer_idx}")
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
@@ -268,11 +269,14 @@ class BaseProducer:
                         self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                     self.eval_mode = False
                     self.latest_eval_step = self.consumer_global_step
+                self.profiler.enter("rollout")
                 outputs = self.rollout(**batch)
+                self.profiler.exit("rollout")
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
+                self.profiler.enter("calculate_reward")
                 if self.grpo_config["reward_fn_type"] == "code":
                     test_cases = []
                     for prompt_id in range(bs):
@@ -310,12 +314,15 @@
                         outputs.pop("gt_answer")
                 if "test_cases" in outputs:
                     outputs.pop("test_cases")
+                self.profiler.exit("calculate_reward")
 
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs = pre_send(outputs)
+                self.profiler.enter("send_broadcast_data")
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )
+                self.profiler.exit("send_broadcast_data")
                 if (i + 1) % self.num_microbatches == 0 and (
                     episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                 ):
@@ -324,6 +331,7 @@
                     ):
                         self.model.llm.sleep()  # evict KV cache to avoid OOM
                     # don't sync model for last iteration
+                    self.profiler.enter("sync_model")
                     torch.cuda.empty_cache()
 
                     if self.consumer_pp_size > 1:
@@ -349,6 +357,7 @@
                         self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
+                    self.profiler.exit("sync_model")
                 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                     "enable_sleep_mode", False
                 ):
@@ -364,8 +373,11 @@
                     "temperature"
                 ] + ratio * 0.9
 
+    def __del__(self):
+        self.profiler.close()
 
-@ray.remote
+
+@ray.remote  # (runtime_env={ "nsight": "default"})
 class SimpleProducer(BaseProducer):
     def __init__(
         self,
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index a40ebbcfb..6d4169ee0 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -1,5 +1,6 @@
 import json
 import os
+import time
 from typing import Any, Dict, List
 
 import torch
@@ -143,3 +144,25 @@ def safe_append_to_jsonl_file(file_path, data):
         for entry in data:
             json_line = json.dumps(entry, ensure_ascii=False)
             f.write(json_line + "\n")
+
+
+class CustomProfiler:
+    def __init__(self, name):
+        self.name = name
+        self.pid = os.getpid()
+        self.file = open(f"{name}.prof", "w")
+
+    def log(self, message):
+        current_time = time.time()
+        self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
+    def enter(self, event_name):
+        self.log(f"Enter {event_name}")
+
+    def exit(self, event_name):
+        self.log(f"Exit {event_name}")
+
+    def close(self):
+        self.file.close()
+        print(f"Profiler data written to {self.name}.prof")
diff --git a/applications/ColossalChat/profile_utils.py b/applications/ColossalChat/profile_utils.py
new file mode 100644
index 000000000..a0a2ee4da
--- /dev/null
+++ b/applications/ColossalChat/profile_utils.py
@@ -0,0 +1,82 @@
+from collections import defaultdict
+
+import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+
+
+def 
parse_logs(log_file_path):
+    logs_by_actor = defaultdict(list)
+
+    with open(log_file_path, "r") as f:
+        for line in f:
+            parts = line.strip().split()
+            if len(parts) < 5:
+                continue
+            timestamp = float(parts[0])
+            actor = parts[1]
+            event = parts[3]  # "Enter" or "Exit"; parts[2] is the "<pid>::" field
+            function_name = " ".join(parts[4:])
+            logs_by_actor[actor].append((timestamp, event, function_name))
+
+    return logs_by_actor
+
+
+def build_intervals(logs_by_actor):
+    actor_intervals = defaultdict(list)
+
+    for actor, events in logs_by_actor.items():
+        func_stack = {}
+        for timestamp, event, func in events:
+            # pair each Exit with the most recent matching Enter for this actor
+            if event == "Enter":
+                func_stack[func] = timestamp
+            elif event == "Exit" and func in func_stack:
+                start = func_stack.pop(func)
+                actor_intervals[actor].append((func, start, timestamp))
+
+    return actor_intervals
+
+
+def plot_actor_timelines(actor_intervals):
+    fig, ax = plt.subplots(figsize=(12, 6))
+    ytick_labels = []
+    yticks = []
+    color_map = plt.get_cmap("tab10")
+    color_lookup = {}
+
+    y = 0
+    for idx, (actor, intervals) in enumerate(sorted(actor_intervals.items())):
+        color_lookup[actor] = color_map(idx % 10)
+        for func, start, end in intervals:
+            ax.barh(y, end - start, left=start, height=0.3, color=color_lookup[actor], label=actor)
+            ax.text(start, y + 0.1, func, fontsize=8, color="black")
+        yticks.append(y)
+        ytick_labels.append(actor)
+        y += 1
+
+    ax.set_yticks(yticks)
+    ax.set_yticklabels(ytick_labels)
+    ax.set_xlabel("Unix Timestamp")
+    ax.set_title("Ray Actor Function Timeline")
+    ax.grid(True, axis="x", linestyle="--", alpha=0.6)
+
+    # Unique legend
+    handles = [mpatches.Patch(color=color_lookup[a], label=a) for a in color_lookup]
+    ax.legend(handles=handles, title="Actors", loc="upper left", bbox_to_anchor=(1, 1))
+
+    plt.tight_layout()
+    plt.show()
+
+
+# ==== Usage ====
+# Parse every *.prof file in the current directory and plot one timeline per actor
+import glob
+
+files = glob.glob("*.prof")
+logs = {}
+for file in files:
+    print(f"Processing file: {file}")
+    logs_by_actor = parse_logs(file)
+    logs.update(logs_by_actor)
+actor_intervals = build_intervals(logs)
+plot_actor_timelines(actor_intervals)
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 0536a746d..1a2db4063 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -233,9 +233,10 @@ if __name__ == "__main__":
     generate_config.update(
         dict(
             max_tokens=args.max_new_tokens,  # max new tokens
-            ignore_eos=True if args.reward_type == "think_answer_tags" else False,
             include_stop_str_in_output=True,
-            stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
+            # for profiling, we let vllm generate until max_new_tokens is reached
+            stop=None,  # ["</answer>"] if args.reward_type == "think_answer_tags" else None,
+            ignore_eos=True,
         )
     )
     eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
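-- 
A minimal usage sketch of the profiling flow added above (illustration only, not
part of the applied diff). The actor name "P0" and the event name "rollout" are
placeholders mirroring the instrumentation in producer.py, and the snippet assumes
the coati package is importable (e.g. ColossalChat installed or on PYTHONPATH):

    from coati.distributed.utils import CustomProfiler

    profiler = CustomProfiler("P0")  # opens P0.prof in the current working directory
    profiler.enter("rollout")        # writes "<unix_time> P0 <pid>:: Enter rollout"
    # ... timed work (e.g. generation) happens here ...
    profiler.exit("rollout")         # writes "<unix_time> P0 <pid>:: Exit rollout"
    profiler.close()

    # profile_utils.py then globs *.prof, splits each line into
    # (timestamp, actor, pid, Enter/Exit, event name), pairs each Enter with its
    # Exit into per-actor (event, start, end) intervals, and renders one
    # matplotlib timeline bar per actor.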