add profile

YeAnbang 2025-06-12 15:03:44 +08:00
parent 8992def757
commit 43a0e99ae1
7 changed files with 139 additions and 10 deletions

View File

@@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, post_recv, unbind_batch
+from .utils import CustomProfiler, bind_batch, post_recv, unbind_batch
 class BaseConsumer:
@@ -94,6 +94,7 @@ class BaseConsumer:
 self.buffer = []
 self.recv_cnt = 0
+self.profiler = CustomProfiler(f"C{self.rank}")
 def state_dict(self) -> Dict[str, torch.Tensor]:
 raise NotImplementedError
@@ -117,9 +118,11 @@
 # receive data from producers
 for r in range(self.num_producers):
 print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+self.profiler.enter(f"recv_broadcast_data_P{r}")
 raw_batch = ray_broadcast_tensor_dict(
 None, src=0, device=self.device, group_name=f"sync_data_{r}"
 )
+self.profiler.exit(f"recv_broadcast_data_P{r}")
 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
 # we need to calculate the metrics before filtering here for logging
 # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -192,7 +195,9 @@
 }
 batch = bind_batch([t[0] for t in batches])
 batch = post_recv(batch)
+self.profiler.enter("step")
 loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+self.profiler.exit("step")
 self.buffer = self.buffer[
 effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
 ]
@@ -228,6 +233,7 @@
 )
 else:
 print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+self.profiler.enter("sync_model")
 torch.cuda.empty_cache()
 state_dict = self.state_dict()
 if self.pp_size > 1:
@@ -245,9 +251,14 @@
 )
 del state_dict
 torch.cuda.empty_cache()
+self.profiler.exit("sync_model")
+def __del__(self):
+if hasattr(self, "profiler"):
+self.profiler.close()
-@ray.remote
+@ray.remote  # (runtime_env={ "nsight": "default"})
 class SimpleConsumer(BaseConsumer):
 def __init__(
 self,

View File

@@ -3,18 +3,18 @@ from typing import Any, Optional
 import ray
 import torch
-import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import wandb
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-@ray.remote
+@ray.remote  # (runtime_env={ "nsight": "default"})
 class GRPOConsumer(BaseConsumer):
 def __init__(
 self,

View File

@@ -16,7 +16,7 @@ def get_jsonl_size_fast(path: str) -> int:
 with open(path) as f:
 lines = f.readlines()
 lines = [line for line in lines if line.strip()]
-return len(lines) - 1
+return len(lines)
 def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:

View File

@@ -7,7 +7,6 @@ import ray
 import ray.util.collective as cc
 import torch
 import tqdm
-import wandb
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
@@ -16,11 +15,12 @@ from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
+import wandb
 from colossalai.utils import get_current_device
 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_append_to_jsonl_file
+from .utils import CustomProfiler, pre_send, safe_append_to_jsonl_file
 try:
 from vllm import SamplingParams
@@ -75,6 +75,7 @@
 self.log_rollout_interval = log_rollout_interval
 self.latest_rollout_log_step = -1
 self.grpo_config = grpo_config
+self.profiler = CustomProfiler(f"P{self.producer_idx}")
 reward_model_kwargs = {
 k: v
 for k, v in grpo_config.items()
@@ -268,11 +269,14 @@
 self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
 self.eval_mode = False
 self.latest_eval_step = self.consumer_global_step
+self.profiler.enter("rollout")
 outputs = self.rollout(**batch)
+self.profiler.exit("rollout")
 outputs["temperature"] = torch.tensor(
 [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
 ).to(outputs["input_ids"].device)
 bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
+self.profiler.enter("calculate_reward")
 if self.grpo_config["reward_fn_type"] == "code":
 test_cases = []
 for prompt_id in range(bs):
@@ -310,12 +314,15 @@
 outputs.pop("gt_answer")
 if "test_cases" in outputs:
 outputs.pop("test_cases")
+self.profiler.exit("calculate_reward")
 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
 outputs = pre_send(outputs)
+self.profiler.enter("send_broadcast_data")
 ray_broadcast_tensor_dict(
 outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
 )
+self.profiler.exit("send_broadcast_data")
 if (i + 1) % self.num_microbatches == 0 and (
 episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
 ):
@@ -324,6 +331,7 @@
 ):
 self.model.llm.sleep()  # revict KV_cache to avoid OOM
 # don't sync model for last iteration
+self.profiler.enter("sync_model")
 torch.cuda.empty_cache()
 if self.consumer_pp_size > 1:
@@ -349,6 +357,7 @@
 self.load_state_dict(state_dict)
 del state_dict
 torch.cuda.empty_cache()
+self.profiler.exit("sync_model")
 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
 "enable_sleep_mode", False
 ):
@@ -364,8 +373,11 @@
 "temperature"
 ] + ratio * 0.9
+def __del__(self):
+self.profiler.close()
-@ray.remote
+@ray.remote  # (runtime_env={ "nsight": "default"})
 class SimpleProducer(BaseProducer):
 def __init__(
 self,

View File

@@ -1,5 +1,6 @@
 import json
 import os
+import time
 from typing import Any, Dict, List
 import torch
@@ -143,3 +144,25 @@ def safe_append_to_jsonl_file(file_path, data):
 for entry in data:
 json_line = json.dumps(entry, ensure_ascii=False)
 f.write(json_line + "\n")
+
+
+class CustomProfiler:
+    def __init__(self, name):
+        self.name = name
+        self.pid = os.getpid()
+        self.file = open(f"{name}.prof", "w")
+
+    def log(self, message):
+        current_time = time.time()
+        self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
+    def enter(self, event_name):
+        self.log(f"Enter {event_name}")
+
+    def exit(self, event_name):
+        self.log(f"Exit {event_name}")
+
+    def close(self):
+        self.file.close()
+        print(f"Profiler data written to {self.name}.prof")

View File

@@ -0,0 +1,82 @@
from collections import defaultdict

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt


def parse_logs(log_file_path):
    # Collect (timestamp, event, function_name) tuples per actor from a .prof file.
    logs_by_actor = defaultdict(list)
    with open(log_file_path, "r") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            timestamp = float(parts[0])
            actor = parts[1]
            event = parts[3]
            function_name = " ".join(parts[4:])
            logs_by_actor[actor].append((timestamp, event, function_name))
    return logs_by_actor


def build_intervals(logs_by_actor):
    # Pair each "Enter" with the matching "Exit" to form (func, start, end) intervals.
    actor_intervals = defaultdict(list)
    for actor, events in logs_by_actor.items():
        func_stack = {}
        for timestamp, event, func in events:
            if event == "Enter":
                func_stack[func] = timestamp
            elif event == "Exit" and func in func_stack:
                start = func_stack.pop(func)
                actor_intervals[actor].append((func, start, timestamp))
    return actor_intervals


def plot_actor_timelines(actor_intervals):
    fig, ax = plt.subplots(figsize=(12, 6))
    ytick_labels = []
    yticks = []
    color_map = plt.get_cmap("tab10")
    color_lookup = {}

    y = 0
    for idx, (actor, intervals) in enumerate(sorted(actor_intervals.items())):
        color_lookup[actor] = color_map(idx % 10)
        for func, start, end in intervals:
            ax.barh(y, end - start, left=start, height=0.3, color=color_lookup[actor], label=actor)
            ax.text(start, y + 0.1, func, fontsize=8, color="black")
        yticks.append(y)
        ytick_labels.append(actor)
        y += 1

    ax.set_yticks(yticks)
    ax.set_yticklabels(ytick_labels)
    ax.set_xlabel("Unix Timestamp")
    ax.set_title("Ray Actor Function Timeline")
    ax.grid(True, axis="x", linestyle="--", alpha=0.6)

    # Unique legend
    handles = [mpatches.Patch(color=color_lookup[a], label=a) for a in color_lookup]
    ax.legend(handles=handles, title="Actors", loc="upper left", bbox_to_anchor=(1, 1))

    plt.tight_layout()
    plt.show()


# ==== Usage ====
# Parses every .prof file in the current directory
import glob

files = glob.glob("*.prof")
logs = {}
for file in files:
    print(f"Processing file: {file}")
    logs_by_actor = parse_logs(file)
    logs.update(logs_by_actor)

actor_intervals = build_intervals(logs)
plot_actor_timelines(actor_intervals)
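The parser assumes exactly the line layout written by CustomProfiler.log; a representative .prof line (timestamp, actor name, and pid here are illustrative) looks like:

1718175824.123456 P0 31245:: Enter rollout
1718175826.789012 P0 31245:: Exit rollout

so parts[0] is the timestamp, parts[1] the actor, parts[3] the Enter/Exit marker, and parts[4:] the event name. Running the script from the directory that holds the .prof files picks them all up and draws one horizontal timeline row per actor.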

View File

@@ -233,9 +233,10 @@ if __name__ == "__main__":
 generate_config.update(
 dict(
 max_tokens=args.max_new_tokens,  # max new tokens
-ignore_eos=True if args.reward_type == "think_answer_tags" else False,
 include_stop_str_in_output=True,
-stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
+# for profiling, we let vllm generate till max_new_tokens is reached
+stop=None,  # ["</answer>"] if args.reward_type == "think_answer_tags" else None,
+ignore_eos=True,
 )
 )
 eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation