Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-07-24 20:20:53 +00:00

Commit 43a0e99ae1: add profile
Parent: 8992def757
applications/ColossalChat/coati/distributed/consumer.py

@@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, post_recv, unbind_batch
+from .utils import CustomProfiler, bind_batch, post_recv, unbind_batch
 
 
 class BaseConsumer:
@@ -94,6 +94,7 @@ class BaseConsumer:
 
         self.buffer = []
         self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}")
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -117,9 +118,11 @@ class BaseConsumer:
                 # receive data from producers
                 for r in range(self.num_producers):
                     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                    self.profiler.enter(f"recv_broadcast_data_P{r}")
                     raw_batch = ray_broadcast_tensor_dict(
                         None, src=0, device=self.device, group_name=f"sync_data_{r}"
                     )
+                    self.profiler.exit(f"recv_broadcast_data_P{r}")
                     # calculate group reward and apply filtering. Since only the filtered groups will be used for training (an incomplete set),
                     # we need to calculate the metrics here, before filtering, for logging
                     # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -192,7 +195,9 @@ class BaseConsumer:
                     }
                     batch = bind_batch([t[0] for t in batches])
                     batch = post_recv(batch)
+                    self.profiler.enter("step")
                     loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                    self.profiler.exit("step")
                     self.buffer = self.buffer[
                         effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
                     ]
@@ -228,6 +233,7 @@ class BaseConsumer:
                     )
                 else:
                     print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    self.profiler.enter("sync_model")
                     torch.cuda.empty_cache()
                     state_dict = self.state_dict()
                     if self.pp_size > 1:
@@ -245,9 +251,14 @@ class BaseConsumer:
                         )
                     del state_dict
                     torch.cuda.empty_cache()
+                    self.profiler.exit("sync_model")
+
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()
 
 
-@ray.remote
+@ray.remote  # (runtime_env={ "nsight": "default"})
 class SimpleConsumer(BaseConsumer):
     def __init__(
         self,
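The instrumentation pattern above writes paired, timestamped Enter/Exit records that the plotting script later matches by event name. As written, the pairs are not exception-safe: if `self.step` raises, the Exit record is never emitted and the interval is silently dropped. A small context-manager wrapper (hypothetical, not part of this commit) would keep the pairs balanced:

import contextlib

@contextlib.contextmanager
def profiled(profiler, event_name):
    # Hypothetical helper around CustomProfiler: guarantees the Exit
    # record is written even if the profiled block raises.
    profiler.enter(event_name)
    try:
        yield
    finally:
        profiler.exit(event_name)

# usage sketch:
#     with profiled(self.profiler, "step"):
#         loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)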
applications/ColossalChat/coati/distributed/grpo_consumer.py

@@ -3,18 +3,18 @@ from typing import Any, Optional
 
 import ray
 import torch
-import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import wandb
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 
 
-@ray.remote
+@ray.remote  # (runtime_env={ "nsight": "default"})
 class GRPOConsumer(BaseConsumer):
     def __init__(
         self,
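The commented-out `runtime_env` argument on the `@ray.remote` decorators points at Ray's built-in Nsight Systems integration as a heavier-weight alternative to `CustomProfiler`. Enabled, it would look roughly like the sketch below (an assumption: it requires a Ray version with Nsight runtime-env support and `nsys` installed on every worker node; the class is a placeholder):

import ray

# Sketch: launch an actor process under NVIDIA Nsight Systems via Ray's
# runtime_env, producing an nsys report instead of a .prof text trace.
@ray.remote(runtime_env={"nsight": "default"})
class ProfiledWorker:
    def run(self):
        ...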
applications/ColossalChat/coati/distributed/launch.py

@@ -16,7 +16,7 @@ def get_jsonl_size_fast(path: str) -> int:
     with open(path) as f:
         lines = f.readlines()
         lines = [line for line in lines if line.strip()]
-        return len(lines) - 1
+        return len(lines)
 
 
 def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
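Since blank lines are already filtered out on the previous line, every remaining entry is one JSONL record; the dataset size is therefore simply `len(lines)`, and the old `return len(lines) - 1` undercounted by one.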
applications/ColossalChat/coati/distributed/producer.py

@@ -7,7 +7,6 @@ import ray
 import ray.util.collective as cc
 import torch
 import tqdm
-import wandb
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
@@ -16,11 +15,12 @@ from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
 
+import wandb
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_append_to_jsonl_file
+from .utils import CustomProfiler, pre_send, safe_append_to_jsonl_file
 
 try:
     from vllm import SamplingParams
@@ -75,6 +75,7 @@ class BaseProducer:
         self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
         self.grpo_config = grpo_config
+        self.profiler = CustomProfiler(f"P{self.producer_idx}")
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
@@ -268,11 +269,14 @@ class BaseProducer:
                        self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                    self.eval_mode = False
                    self.latest_eval_step = self.consumer_global_step
+                self.profiler.enter("rollout")
                outputs = self.rollout(**batch)
+                self.profiler.exit("rollout")
                outputs["temperature"] = torch.tensor(
                    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                ).to(outputs["input_ids"].device)
                bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
+                self.profiler.enter("calculate_reward")
                if self.grpo_config["reward_fn_type"] == "code":
                    test_cases = []
                    for prompt_id in range(bs):
@@ -310,12 +314,15 @@ class BaseProducer:
                    outputs.pop("gt_answer")
                if "test_cases" in outputs:
                    outputs.pop("test_cases")
+                self.profiler.exit("calculate_reward")
 
                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                outputs = pre_send(outputs)
+                self.profiler.enter("send_broadcast_data")
                ray_broadcast_tensor_dict(
                    outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                )
+                self.profiler.exit("send_broadcast_data")
                if (i + 1) % self.num_microbatches == 0 and (
                    episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                ):
@@ -324,6 +331,7 @@ class BaseProducer:
                    ):
                        self.model.llm.sleep()  # evict KV cache to avoid OOM
                    # don't sync model for last iteration
+                    self.profiler.enter("sync_model")
                    torch.cuda.empty_cache()
 
                    if self.consumer_pp_size > 1:
@@ -349,6 +357,7 @@ class BaseProducer:
                        self.load_state_dict(state_dict)
                    del state_dict
                    torch.cuda.empty_cache()
+                    self.profiler.exit("sync_model")
                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                        "enable_sleep_mode", False
                    ):
@@ -364,8 +373,11 @@ class BaseProducer:
                    "temperature"
                ] + ratio * 0.9
 
+    def __del__(self):
+        self.profiler.close()
+
 
-@ray.remote
+@ray.remote  # (runtime_env={ "nsight": "default"})
 class SimpleProducer(BaseProducer):
     def __init__(
         self,
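Note that the producer's `send_broadcast_data` interval and the consumer's `recv_broadcast_data_P{r}` interval bracket the same `ray_broadcast_tensor_dict` rendezvous from opposite ends, so comparing the two bars on the combined timeline shows which side spends its time blocked waiting for the other rather than computing.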
applications/ColossalChat/coati/distributed/utils.py

@@ -1,5 +1,6 @@
 import json
 import os
+import time
 from typing import Any, Dict, List
 
 import torch
@@ -143,3 +144,25 @@ def safe_append_to_jsonl_file(file_path, data):
         for entry in data:
             json_line = json.dumps(entry, ensure_ascii=False)
             f.write(json_line + "\n")
+
+
+class CustomProfiler:
+    def __init__(self, name):
+        self.name = name
+        self.pid = os.getpid()
+        self.file = open(f"{name}.prof", "w")
+
+    def log(self, message):
+        current_time = time.time()
+        self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
+    def enter(self, event_name):
+        self.log(f"Enter {event_name}")
+
+    def exit(self, event_name):
+        self.log(f"Exit {event_name}")
+
+    def close(self):
+        self.file.close()
+        print(f"Profiler data written to {self.name}.prof")
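For reference, a consumer trace such as `C0.prof` ends up holding whitespace-separated records in the shape below (timestamps and PID illustrative). `parse_logs` in profile_utils.py reads exactly this layout: parts[0] is the timestamp, parts[1] the actor name, parts[3] the Enter/Exit keyword, and everything after it the event name.

1752345600.1234 C0 41235:: Enter recv_broadcast_data_P0
1752345600.9876 C0 41235:: Exit recv_broadcast_data_P0
1752345601.0021 C0 41235:: Enter step
1752345603.4417 C0 41235:: Exit step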
applications/ColossalChat/profile_utils.py (new file)

@@ -0,0 +1,82 @@
+from collections import defaultdict
+
+import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+
+
+def parse_logs(log_file_path):
+    logs_by_actor = defaultdict(list)
+
+    with open(log_file_path, "r") as f:
+        for line in f:
+            parts = line.strip().split()
+            if len(parts) < 5:
+                continue
+            timestamp = float(parts[0])
+            actor = parts[1]
+            event = parts[3]
+            function_name = " ".join(parts[4:])
+            logs_by_actor[actor].append((timestamp, event, function_name))
+
+    return logs_by_actor
+
+
+def build_intervals(logs_by_actor):
+    actor_intervals = defaultdict(list)
+
+    for actor, events in logs_by_actor.items():
+        func_stack = {}
+        for timestamp, event, func in events:
+            if event == "Enter":
+                func_stack[func] = timestamp
+            elif event == "Exit" and func in func_stack:
+                start = func_stack.pop(func)
+                actor_intervals[actor].append((func, start, timestamp))
+
+    return actor_intervals
+
+
+def plot_actor_timelines(actor_intervals):
+    fig, ax = plt.subplots(figsize=(12, 6))
+    ytick_labels = []
+    yticks = []
+    color_map = plt.get_cmap("tab10")
+    color_lookup = {}
+
+    y = 0
+    for idx, (actor, intervals) in enumerate(sorted(actor_intervals.items())):
+        color_lookup[actor] = color_map(idx % 10)
+        for func, start, end in intervals:
+            ax.barh(y, end - start, left=start, height=0.3, color=color_lookup[actor], label=actor)
+            ax.text(start, y + 0.1, func, fontsize=8, color="black")
+        yticks.append(y)
+        ytick_labels.append(actor)
+        y += 1
+
+    ax.set_yticks(yticks)
+    ax.set_yticklabels(ytick_labels)
+    ax.set_xlabel("Unix Timestamp")
+    ax.set_title("Ray Actor Function Timeline")
+    ax.grid(True, axis="x", linestyle="--", alpha=0.6)
+
+    # Unique legend: one patch per actor rather than one per bar
+    handles = [mpatches.Patch(color=color_lookup[a], label=a) for a in color_lookup]
+    ax.legend(handles=handles, title="Actors", loc="upper left", bbox_to_anchor=(1, 1))
+
+    plt.tight_layout()
+    plt.show()
+
+
+# ==== Usage ====
+# Collect every .prof trace in the working directory and plot them together
+import glob
+
+files = glob.glob("*.prof")
+logs = {}
+for file in files:
+    print(f"Processing file: {file}")
+    logs_by_actor = parse_logs(file)
+    logs.update(logs_by_actor)
+actor_intervals = build_intervals(logs)
+plot_actor_timelines(actor_intervals)
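Because every actor stamps its events with wall-clock `time.time()`, merging the per-actor logs onto one shared axis lines the producer and consumer phases up against each other: a consumer bar that opens long before the matching producer `send_broadcast_data` bar, for instance, is idle wait rather than compute. Running `python profile_utils.py` in the directory containing the `.prof` traces produces the combined timeline.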
applications/ColossalChat/rl_example.py

@@ -233,9 +233,10 @@ if __name__ == "__main__":
     generate_config.update(
         dict(
             max_tokens=args.max_new_tokens,  # max new tokens
-            ignore_eos=True if args.reward_type == "think_answer_tags" else False,
             include_stop_str_in_output=True,
-            stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
+            # for profiling, we let vllm generate till max_new_tokens is reached
+            stop=None,  # ["</answer>"] if args.reward_type == "think_answer_tags" else None,
+            ignore_eos=True,
         )
     )
     eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
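The swapped generation settings force fixed-length rollouts: with `ignore_eos=True` and no stop strings, every sample runs all the way to `max_tokens`, which keeps rollout timings comparable across steps. These are standard vLLM `SamplingParams` fields; a minimal sketch with illustrative values:

from vllm import SamplingParams

params = SamplingParams(
    max_tokens=1024,  # every rollout generates exactly this many tokens
    ignore_eos=True,  # EOS no longer terminates generation
    stop=None,        # no stop strings cut generation short
)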