add profile

This commit is contained in:
YeAnbang 2025-06-12 15:03:44 +08:00
parent 8992def757
commit 43a0e99ae1
7 changed files with 139 additions and 10 deletions

View File

@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch
from .utils import CustomProfiler, bind_batch, post_recv, unbind_batch
class BaseConsumer:
@ -94,6 +94,7 @@ class BaseConsumer:
self.buffer = []
self.recv_cnt = 0
self.profiler = CustomProfiler(f"C{self.rank}")
def state_dict(self) -> Dict[str, torch.Tensor]:
raise NotImplementedError
@ -117,9 +118,11 @@ class BaseConsumer:
# receive data from producers
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.profiler.enter(f"recv_broadcast_data_P{r}")
raw_batch = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
self.profiler.exit(f"recv_broadcast_data_P{r}")
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
# we need to calculate the metrics before filtering here for logging
# [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@ -192,7 +195,9 @@ class BaseConsumer:
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.profiler.exit("step")
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
@ -228,6 +233,7 @@ class BaseConsumer:
)
else:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
self.profiler.enter("sync_model")
torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.pp_size > 1:
@ -245,9 +251,14 @@ class BaseConsumer:
)
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")
def __del__(self):
if hasattr(self, "profiler"):
self.profiler.close()
@ray.remote
@ray.remote # (runtime_env={ "nsight": "default"})
class SimpleConsumer(BaseConsumer):
def __init__(
self,

View File

@ -3,18 +3,18 @@ from typing import Any, Optional
import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
@ray.remote
@ray.remote # (runtime_env={ "nsight": "default"})
class GRPOConsumer(BaseConsumer):
def __init__(
self,

View File

@ -16,7 +16,7 @@ def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
lines = f.readlines()
lines = [line for line in lines if line.strip()]
return len(lines) - 1
return len(lines)
def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:

View File

@ -7,7 +7,6 @@ import ray
import ray.util.collective as cc
import torch
import tqdm
import wandb
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
@ -16,11 +15,12 @@ from ray.util.collective.types import Backend, ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
import wandb
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send, safe_append_to_jsonl_file
from .utils import CustomProfiler, pre_send, safe_append_to_jsonl_file
try:
from vllm import SamplingParams
@ -75,6 +75,7 @@ class BaseProducer:
self.log_rollout_interval = log_rollout_interval
self.latest_rollout_log_step = -1
self.grpo_config = grpo_config
self.profiler = CustomProfiler(f"P{self.producer_idx}")
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
@ -268,11 +269,14 @@ class BaseProducer:
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
self.eval_mode = False
self.latest_eval_step = self.consumer_global_step
self.profiler.enter("rollout")
outputs = self.rollout(**batch)
self.profiler.exit("rollout")
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
self.profiler.enter("calculate_reward")
if self.grpo_config["reward_fn_type"] == "code":
test_cases = []
for prompt_id in range(bs):
@ -310,12 +314,15 @@ class BaseProducer:
outputs.pop("gt_answer")
if "test_cases" in outputs:
outputs.pop("test_cases")
self.profiler.exit("calculate_reward")
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
self.profiler.enter("send_broadcast_data")
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
)
self.profiler.exit("send_broadcast_data")
if (i + 1) % self.num_microbatches == 0 and (
episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
):
@ -324,6 +331,7 @@ class BaseProducer:
):
self.model.llm.sleep() # revict KV_cache to avoid OOM
# don't sync model for last iteration
self.profiler.enter("sync_model")
torch.cuda.empty_cache()
if self.consumer_pp_size > 1:
@ -349,6 +357,7 @@ class BaseProducer:
self.load_state_dict(state_dict)
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
@ -364,8 +373,11 @@ class BaseProducer:
"temperature"
] + ratio * 0.9
def __del__(self):
self.profiler.close()
@ray.remote
@ray.remote # (runtime_env={ "nsight": "default"})
class SimpleProducer(BaseProducer):
def __init__(
self,

View File

@ -1,5 +1,6 @@
import json
import os
import time
from typing import Any, Dict, List
import torch
@ -143,3 +144,25 @@ def safe_append_to_jsonl_file(file_path, data):
for entry in data:
json_line = json.dumps(entry, ensure_ascii=False)
f.write(json_line + "\n")
class CustomProfiler:
def __init__(self, name):
self.name = name
self.pid = os.getpid()
self.file = open(f"{name}.prof", "w")
def log(self, message):
current_time = time.time()
self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
self.file.flush()
def enter(self, event_name):
self.log(f"Enter {event_name}")
def exit(self, event_name):
self.log(f"Exit {event_name}")
def close(self):
self.file.close()
print(f"Profiler data written to {self.name}.prof")

View File

@ -0,0 +1,82 @@
from collections import defaultdict
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
def parse_logs(log_file_path):
logs_by_actor = defaultdict(list)
with open(log_file_path, "r") as f:
for line in f:
parts = line.strip().split()
if len(parts) < 5:
continue
timestamp = float(parts[0])
actor = parts[1]
event = parts[3]
function_name = " ".join(parts[4:])
logs_by_actor[actor].append((timestamp, event, function_name))
return logs_by_actor
def build_intervals(logs_by_actor):
actor_intervals = defaultdict(list)
for actor, events in logs_by_actor.items():
func_stack = {}
for timestamp, event, func in events:
(actor, func)
if event == "Enter":
func_stack[func] = timestamp
elif event == "Exit" and func in func_stack:
start = func_stack.pop(func)
actor_intervals[actor].append((func, start, timestamp))
return actor_intervals
def plot_actor_timelines(actor_intervals):
fig, ax = plt.subplots(figsize=(12, 6))
ytick_labels = []
yticks = []
color_map = plt.get_cmap("tab10")
color_lookup = {}
y = 0
for idx, (actor, intervals) in enumerate(sorted(actor_intervals.items())):
color_lookup[actor] = color_map(idx % 10)
for func, start, end in intervals:
ax.barh(y, end - start, left=start, height=0.3, color=color_lookup[actor], label=actor)
ax.text(start, y + 0.1, func, fontsize=8, color="black")
yticks.append(y)
ytick_labels.append(actor)
y += 1
ax.set_yticks(yticks)
ax.set_yticklabels(ytick_labels)
ax.set_xlabel("Unix Timestamp")
ax.set_title("Ray Actor Function Timeline")
ax.grid(True, axis="x", linestyle="--", alpha=0.6)
# Unique legend
handles = [mpatches.Patch(color=color_lookup[a], label=a) for a in color_lookup]
ax.legend(handles=handles, title="Actors", loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
# ==== Usage ====
# Replace with your actual log file path
import glob
files = glob.glob("*.prof")
logs = {}
for file in files:
print(f"Processing file: {file}")
logs_by_actor = parse_logs(log_file_path)
logs.update(logs_by_actor)
actor_intervals = build_intervals(logs)
plot_actor_timelines(actor_intervals)

View File

@ -233,9 +233,10 @@ if __name__ == "__main__":
generate_config.update(
dict(
max_tokens=args.max_new_tokens, # max new tokens
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
include_stop_str_in_output=True,
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
# for profiling, we let vllm generate till max_new_tokens is reached
stop=None, # ["</answer>"] if args.reward_type == "think_answer_tags" else None,
ignore_eos=True,
)
)
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation