Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-25 20:46:13 +00:00

Commit ff1689b69a ("add profiling"), parent 43a0e99ae1

.gitignore (vendored, 2 lines added)
@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
 applications/ColossalChat/*.db
 applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png
applications/ColossalChat/coati/distributed/consumer.py

@@ -1,4 +1,5 @@
 import os
+import time
 from contextlib import nullcontext
 from typing import Any, Dict, Optional
 
@@ -79,6 +80,8 @@ class BaseConsumer:
         self.tp_size = dist.get_world_size(self.plugin.tp_group)
         self.pp_size = dist.get_world_size(self.plugin.pp_group)
 
+        torch.cuda.reset_peak_memory_stats()
+
         # Init Hybrid ray process group
         for i in range(self.num_producers):
             cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
@@ -106,6 +109,8 @@ class BaseConsumer:
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
+        start_time = time.time()
+        total_step = 0
         for episode in range(self.num_episodes):
             with tqdm(
                 range(self.num_update_per_episode),
@@ -197,6 +202,7 @@ class BaseConsumer:
                         batch = post_recv(batch)
                         self.profiler.enter("step")
                         loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                        total_step += 1
                         self.profiler.exit("step")
                         self.buffer = self.buffer[
                             effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
@@ -252,6 +258,8 @@ class BaseConsumer:
                         del state_dict
                         torch.cuda.empty_cache()
                     self.profiler.exit("sync_model")
+                    print(f"[T{self.rank}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+                    print(f"Average running time per step: {(time.time() - start_time) / total_step:.2f} seconds")
 
     def __del__(self):
         if hasattr(self, "profiler"):
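The start_time / total_step counters and the reset_peak_memory_stats() call added above are what produce the end-of-training summary lines. A minimal standalone sketch of the same measurement pattern (illustrative only, not part of the commit; train_step is a placeholder):

    import time

    import torch

    def profile_steps(train_step, num_steps: int) -> None:
        # Restart CUDA peak-memory tracking so max_memory_allocated() reflects this run only.
        torch.cuda.reset_peak_memory_stats()
        start_time = time.time()
        total_step = 0
        for _ in range(num_steps):
            train_step()  # placeholder for the real training step
            total_step += 1
        print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
        print(f"Average running time per step: {(time.time() - start_time) / total_step:.2f} seconds")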
applications/ColossalChat/coati/distributed/grpo_consumer.py

@@ -3,13 +3,13 @@ from typing import Any, Optional
 
 import ray
 import torch
+import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-import wandb
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 
@@ -258,9 +258,11 @@ class GRPOConsumer(BaseConsumer):
             ]
 
             if self.plugin.pp_size > 1:
+                # torch.cuda.empty_cache()
                 # Support training with PP.
                 if self.policy_loss_fn.beta > 0:
                     with torch.no_grad():
+                        torch.cuda.reset_peak_memory_stats()
                         reference_model_outputs = self.booster.execute_pipeline(
                             iter(
                                 [
@@ -278,14 +280,32 @@ class GRPOConsumer(BaseConsumer):
                             return_loss=False,
                             return_outputs=True,
                         )
+                        self.profiler.log(
+                            f"reference_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+                        )
 
                     if self.booster.plugin.stage_manager.is_last_stage():
-                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                        reference_action_log_probs = calc_action_log_probs(
-                            reference_model_logits / self.generate_config["temperature"],
-                            input_ids_forward_micro_batch,
-                            num_action,
-                            self.plugin.shard_config,
+                        # breakpoint()
+                        torch.cuda.reset_peak_memory_stats()
+                        reference_action_log_probs = torch.zeros(
+                            (input_ids_forward_micro_batch.size(0), num_action),
+                            device=input_ids_forward_micro_batch.device,
+                        )
+                        for i in range(reference_action_log_probs.size(0)):
+                            # The activation for log_softmax is too large if vocab size and sequence length are large;
+                            # e.g., with a 152064 vocab size, 32K sequence length, and a micro batch size of 4 (for pp=4),
+                            # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB.
+                            reference_action_log_probs[i, :] += calc_action_log_probs(
+                                reference_model_outputs["outputs"]["logits"][i : i + 1]
+                                / self.generate_config["temperature"],
+                                input_ids_forward_micro_batch[i : i + 1],
+                                num_action,
+                                self.plugin.shard_config,
+                            )[0]
+                        # breakpoint()
+                        # torch.cuda.empty_cache()
+                        self.profiler.log(
+                            f"reference_action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
                         )
                     else:
                         # Dummy reference logprobs for data iterator.
@@ -308,12 +328,26 @@ class GRPOConsumer(BaseConsumer):
 
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
-                        inputs["input_ids"],
-                        num_action,
-                        self.plugin.shard_config,
+                    action_log_probs = torch.zeros(
+                        (inputs["input_ids"].size(0), num_action), device=action_logits.device
                     )
+                    # breakpoint()
+                    torch.cuda.reset_peak_memory_stats()
+                    for i in range(action_log_probs.size(0)):
+                        # The activation for log_softmax is too large if vocab size and sequence length are large;
+                        # e.g., with a 152064 vocab size, 32K sequence length, and a micro batch size of 4 (for pp=4),
+                        # this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB.
+                        action_log_probs[i, :] += calc_action_log_probs(
+                            action_logits[i : i + 1] / self.generate_config["temperature"],
+                            inputs["input_ids"][i : i + 1],
+                            num_action,
+                            self.plugin.shard_config,
+                        )[0]
+                    # torch.cuda.empty_cache()
+                    self.profiler.log(
+                        f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+                    )
+                    # breakpoint()
                     if "reference_action_log_probs" in inputs:
                         per_token_kl = (
                             torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
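The per-sample loops in the two hunks above trade speed for memory: a full-batch log-softmax materializes a [batch, seq, vocab] float tensor (the 72.5 GB figure in the comments is 4 x 32000 x 152064 x 4 bytes), while looping over the batch dimension divides that peak by the batch size. A simplified sketch of the chunking idea (not the repo's calc_action_log_probs; it ignores the one-token logit shift and tensor sharding that the real helper handles):

    import torch

    def chunked_action_log_probs(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Gather per-token log-probs one sample at a time to bound log_softmax activations."""
        out = torch.zeros(input_ids.shape, device=logits.device)
        for i in range(logits.size(0)):  # iterate over the batch dimension
            row_logp = torch.log_softmax(logits[i].float(), dim=-1)  # [seq, vocab] for one sample
            out[i] = row_logp.gather(-1, input_ids[i].unsqueeze(-1)).squeeze(-1)
        return out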
@@ -347,6 +381,9 @@ class GRPOConsumer(BaseConsumer):
                     return_loss=True,
                     return_outputs=False,
                 )
+                self.profiler.log(
+                    f"policy_model_forward_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+                )
                 loss = policy_model_outputs["loss"]
 
                 if self.booster.plugin.stage_manager.is_last_stage():
@ -373,6 +410,8 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
input_ids=input_ids_forward_micro_batch,
|
input_ids=input_ids_forward_micro_batch,
|
||||||
attention_mask=attention_mask_forward_micro_batch,
|
attention_mask=attention_mask_forward_micro_batch,
|
||||||
).logits
|
).logits
|
||||||
|
|
||||||
|
# torch.cuda.empty_cache()
|
||||||
reference_action_log_probs = calc_action_log_probs(
|
reference_action_log_probs = calc_action_log_probs(
|
||||||
reference_model_logits / self.generate_config["temperature"],
|
reference_model_logits / self.generate_config["temperature"],
|
||||||
input_ids_forward_micro_batch,
|
input_ids_forward_micro_batch,
|
||||||
@@ -422,6 +461,9 @@ class GRPOConsumer(BaseConsumer):
             self.accum_advantages.add_(advantages.data)
             self.accum_count += 1
         if need_update:
+            # breakpoint()
+            # torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
@@ -429,6 +471,9 @@ class GRPOConsumer(BaseConsumer):
             sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
+            self.profiler.log(
+                f"optimizer_step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+            )
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
applications/ColossalChat/coati/distributed/launch.py

@@ -7,9 +7,15 @@ import ray
 
 from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
-from .producer import SimpleProducer
+from .producer import DummyProducer, SimpleProducer
 
-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer, "GRPO-DUMMY-TEST": GRPOConsumer}
+PRODUCER_MAP = {
+    "Simple": SimpleProducer,
+    "GRPO": SimpleProducer,
+    "DAPO": SimpleProducer,
+    "GRPO-DUMMY-TEST": DummyProducer,
+}
 
 
 def get_jsonl_size_fast(path: str) -> int:
@@ -62,6 +68,7 @@ def launch_distributed(
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:
         core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
+        core_producer = PRODUCER_MAP.get(core_algo, SimpleProducer)
 
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
    assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
@ -81,7 +88,7 @@ def launch_distributed(
|
|||||||
|
|
||||||
procs = []
|
procs = []
|
||||||
for i in range(num_producers):
|
for i in range(num_producers):
|
||||||
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
|
producer = core_producer.options(num_gpus=num_proc_per_producer).remote(
|
||||||
producer_idx=i,
|
producer_idx=i,
|
||||||
num_producers=num_producers,
|
num_producers=num_producers,
|
||||||
num_consumer_procs=num_consumer_procs,
|
num_consumer_procs=num_consumer_procs,
|
||||||
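PRODUCER_MAP mirrors ALGO_MAP: the algorithm name now selects both actor classes, and "GRPO-DUMMY-TEST" pairs the regular GRPO consumer with the new DummyProducer so the training side can be profiled without real generation. A minimal sketch of the same registry dispatch (standalone, with stand-in classes):

    class SimpleProducer: ...

    class DummyProducer: ...

    PRODUCER_MAP = {"Simple": SimpleProducer, "GRPO": SimpleProducer, "GRPO-DUMMY-TEST": DummyProducer}

    def pick_producer(core_algo: str) -> type:
        # Unknown algorithm names fall back to the default, matching launch.py above.
        return PRODUCER_MAP.get(core_algo, SimpleProducer)

    assert pick_producer("GRPO-DUMMY-TEST") is DummyProducer
    assert pick_producer("unknown") is SimpleProducer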
applications/ColossalChat/coati/distributed/producer.py

@@ -7,6 +7,7 @@ import ray
 import ray.util.collective as cc
 import torch
 import tqdm
+import wandb
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
@@ -15,7 +16,6 @@ from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
 
-import wandb
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
@@ -76,6 +76,7 @@ class BaseProducer:
         self.latest_rollout_log_step = -1
         self.grpo_config = grpo_config
         self.profiler = CustomProfiler(f"P{self.producer_idx}")
+        torch.cuda.reset_peak_memory_stats()
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
@@ -372,11 +373,126 @@ class BaseProducer:
                     self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
                         "temperature"
                     ] + ratio * 0.9
+        print(f"[P{self.producer_idx}] Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
 
     def __del__(self):
         self.profiler.close()
 
 
+@ray.remote  # (runtime_env={ "nsight": "default"})
+class DummyProducer(BaseProducer):
+    def __init__(
+        self,
+        producer_idx,
+        num_producers,
+        num_consumer_procs,
+        num_episodes,
+        batch_size,
+        train_dataset_config,
+        model_config,
+        generate_config,
+        tokenizer_config=None,
+        microbatch_size=1,
+        backend="transformers",
+        num_generations: int = 8,
+        consumer_plugin_config=None,
+        eval_dataset_config=None,
+        eval_interval=-1,  # disable evaluation
+        grpo_config: Dict[str, Any] = None,
+        eval_save_dir: str = "./eval",
+        eval_generation_config={},
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
+    ):
+        super().__init__(
+            producer_idx,
+            num_producers,
+            num_consumer_procs,
+            num_episodes,
+            batch_size,
+            train_dataset_config,
+            model_config,
+            generate_config,
+            tokenizer_config,
+            microbatch_size,
+            backend,
+            consumer_plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            grpo_config=grpo_config,
+            eval_save_dir=eval_save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
+        )
+        self.num_generations = num_generations
+        self.max_length = generate_config.get("max_tokens", generate_config.get("max_length", 512))
+        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
+        self.eval_generation_config = copy.deepcopy(self.model.generate_config)
+        self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
+        self.eval_generation_config.update(eval_generation_config)
+        self.eval_sample_params = SamplingParams(**self.eval_generation_config)
+
+    @torch.no_grad()
+    def rollout(self, input_ids, attention_mask, **kwargs):
+        # generate dummy rollouts
+        device = get_current_device()
+        # self.profiler.log(f"{input_ids.size()}, {attention_mask.size()}, {self.max_length}")
+        num_new_tokens = self.max_length
+        rollouts = {
+            "input_ids": torch.cat(
+                [
+                    torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device),
+                    torch.ones(
+                        (input_ids.size(0), self.num_generations, num_new_tokens),
+                        dtype=input_ids.dtype,
+                    ).to(device),
+                ],
+                dim=-1,
+            ).to(device),
+            "attention_mask": torch.cat(
+                [
+                    torch.repeat_interleave(attention_mask.unsqueeze(1), self.num_generations, dim=1).to(device),
+                    torch.ones(
+                        (input_ids.size(0), self.num_generations, num_new_tokens),
+                        dtype=attention_mask.dtype,
+                    ).to(device),
+                ],
+                dim=-1,
+            ),
+            "action_log_probs": torch.zeros(
+                (input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.float32
+            ).to(device),
+            "action_mask": torch.ones((input_ids.size(0), self.num_generations, num_new_tokens), dtype=torch.bool).to(
+                device
+            ),
+            "response_idx": torch.tensor(
+                [[[input_ids.size(-1), input_ids.size(-1) + num_new_tokens]] * self.num_generations] * input_ids.size(0)
+            )
+            .to(device)
+            .to(torch.int),
+        }
+        if "gt_answer" in kwargs:
+            rollouts["gt_answer"] = kwargs["gt_answer"]
+        if "test_cases" in kwargs:
+            rollouts["test_cases"] = kwargs["test_cases"]
+        return rollouts
+
+    def __del__(self):
+        if self.producer_idx == 0:
+            self.wandb_run.finish()
+        if hasattr(self, "rollout_log_file"):
+            self.rollout_log_file.close()
+
+    def load_state_dict(self, state_dict):
+        self.model.load_state_dict(state_dict)
+
+
 @ray.remote  # (runtime_env={ "nsight": "default"})
 class SimpleProducer(BaseProducer):
     def __init__(
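DummyProducer.rollout skips vLLM generation entirely and fabricates tensors with the shapes the consumer expects: input_ids and attention_mask come out as [batch, num_generations, prompt_len + max_new_tokens], while action_log_probs and action_mask cover only the [batch, num_generations, max_new_tokens] response part. A minimal shape check (illustrative sketch, not part of the commit):

    import torch

    batch, num_generations, prompt_len, num_new_tokens = 2, 8, 512, 1024
    input_ids = torch.randint(0, 100, (batch, prompt_len))

    # Replicate each prompt num_generations times, then append all-ones dummy response tokens.
    prompts = torch.repeat_interleave(input_ids.unsqueeze(1), num_generations, dim=1)
    responses = torch.ones((batch, num_generations, num_new_tokens), dtype=input_ids.dtype)
    rollout_ids = torch.cat([prompts, responses], dim=-1)

    assert rollout_ids.shape == (batch, num_generations, prompt_len + num_new_tokens)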
Profiler utility (class CustomProfiler; the file path is not preserved in this view)

@@ -152,16 +152,21 @@ class CustomProfiler:
         self.pid = os.getpid()
         self.file = open(f"{name}.prof", "w")
 
-    def log(self, message):
+    def _log(self, message):
         current_time = time.time()
         self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
         self.file.flush()
 
+    def log(self, message):
+        current_time = time.time()
+        self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
     def enter(self, event_name):
-        self.log(f"Enter {event_name}")
+        self._log(f"Enter {event_name}")
 
     def exit(self, event_name):
-        self.log(f"Exit {event_name}")
+        self._log(f"Exit {event_name}")
 
     def close(self):
         self.file.close()
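The log/_log split keeps enter/exit events machine-parsable ("<timestamp> <name> <pid>:: Enter|Exit <event>") while free-form messages get a "[Log]:" prefix that the plotting script skips. A usage sketch (the import path for CustomProfiler is not shown in this diff; timestamps, actor name, and pid are illustrative):

    profiler = CustomProfiler("C0")        # writes to C0.prof
    profiler.enter("step")                 # -> "1753481173.42 C0 12345:: Enter step"
    profiler.log("peak_memory: 512.00MB")  # -> "[Log]: 1753481173.43 C0 12345:: peak_memory: 512.00MB"
    profiler.exit("step")                  # -> "1753481174.01 C0 12345:: Exit step"
    profiler.close()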
applications/ColossalChat/profile_grpo.py (new file, 86 lines)

@@ -0,0 +1,86 @@
+# Re-import required libraries due to kernel reset
+import argparse
+from collections import defaultdict
+from datetime import datetime
+
+import matplotlib.cm as cm
+import matplotlib.dates as mdates
+import matplotlib.pyplot as plt
+
+# Argument parser for command line arguments
+parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
+parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
+args = parser.parse_args()
+
+# Raw log lines
+log_lines = []
+
+import glob
+
+files = glob.glob("*.prof")
+for file in files:
+    with open(file, "r") as f:
+        log_lines += f.readlines()
+
+# Parse logs and collect function intervals grouped by actor
+actors = defaultdict(lambda: defaultdict(list))
+current_entries = {}
+
+for line in log_lines:
+    if line.startswith("[Log]"):
+        continue
+    parts = line.split()
+    timestamp = float(parts[0])
+    actor = parts[1]
+    action = parts[3]
+    func_name = parts[4]
+    key = (actor, func_name)
+    if action == "Enter":
+        current_entries[key] = timestamp
+    elif action == "Exit":
+        start_time = current_entries.pop(key, None)
+        if start_time is not None:
+            actors[actor][func_name].append((start_time, timestamp))
+
+# Plotting setup
+fig, ax = plt.subplots(figsize=(12, 6))
+colors = cm.get_cmap("tab10", len(actors))
+
+actor_offsets = {}
+base_offset = 0
+function_spacing = 0.9
+
+yticks = []
+yticklabels = []
+
+for idx, (actor, func_dict) in enumerate(actors.items()):
+    actor_offsets[actor] = base_offset
+    color = colors(idx)
+    for j, (func, intervals) in enumerate(func_dict.items()):
+        y_val = base_offset + j * function_spacing
+        yticks.append(y_val)
+        yticklabels.append(f"{actor}:{func}")
+        for start, end in intervals:
+            ax.plot(
+                [datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
+                [y_val, y_val],
+                color=color,
+                linewidth=4,
+                label=actor if j == 0 else "",
+            )
+    base_offset += len(func_dict) * function_spacing + 1
+
+# Formatting
+ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
+ax.set_yticks(yticks)
+ax.set_yticklabels(yticklabels)
+ax.set_xlabel("Time")
+ax.set_title("Timeline per Actor")
+# Remove duplicate labels in legend
+handles, labels = ax.get_legend_handles_labels()
+unique = dict(zip(labels, handles))
+ax.legend(unique.values(), unique.keys())
+plt.tight_layout()
+plt.grid(True)
+plt.savefig(args.visualization)
+print(f"Plot saved as {args.visualization}")
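The parser relies on the whitespace layout of the _log format above: parts[0] is the timestamp, parts[1] the actor name, parts[2] the "pid::" token, parts[3] Enter or Exit, and parts[4] the event name. For example (illustrative values):

    line = "1753481173.42 C0 12345:: Enter step"
    parts = line.split()
    # parts == ["1753481173.42", "C0", "12345::", "Enter", "step"]
    timestamp, actor, action, func_name = float(parts[0]), parts[1], parts[3], parts[4]

Multi-word event names would spill past parts[4]; the enter/exit events written by CustomProfiler use single-token names, so the assumption holds here.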
Removed file (the old timeline-plotting script, superseded by profile_grpo.py above)

@@ -1,82 +0,0 @@
-from collections import defaultdict
-
-import matplotlib.patches as mpatches
-import matplotlib.pyplot as plt
-
-
-def parse_logs(log_file_path):
-    logs_by_actor = defaultdict(list)
-
-    with open(log_file_path, "r") as f:
-        for line in f:
-            parts = line.strip().split()
-            if len(parts) < 5:
-                continue
-            timestamp = float(parts[0])
-            actor = parts[1]
-            event = parts[3]
-            function_name = " ".join(parts[4:])
-            logs_by_actor[actor].append((timestamp, event, function_name))
-
-    return logs_by_actor
-
-
-def build_intervals(logs_by_actor):
-    actor_intervals = defaultdict(list)
-
-    for actor, events in logs_by_actor.items():
-        func_stack = {}
-        for timestamp, event, func in events:
-            (actor, func)
-            if event == "Enter":
-                func_stack[func] = timestamp
-            elif event == "Exit" and func in func_stack:
-                start = func_stack.pop(func)
-                actor_intervals[actor].append((func, start, timestamp))
-
-    return actor_intervals
-
-
-def plot_actor_timelines(actor_intervals):
-    fig, ax = plt.subplots(figsize=(12, 6))
-    ytick_labels = []
-    yticks = []
-    color_map = plt.get_cmap("tab10")
-    color_lookup = {}
-
-    y = 0
-    for idx, (actor, intervals) in enumerate(sorted(actor_intervals.items())):
-        color_lookup[actor] = color_map(idx % 10)
-        for func, start, end in intervals:
-            ax.barh(y, end - start, left=start, height=0.3, color=color_lookup[actor], label=actor)
-            ax.text(start, y + 0.1, func, fontsize=8, color="black")
-        yticks.append(y)
-        ytick_labels.append(actor)
-        y += 1
-
-    ax.set_yticks(yticks)
-    ax.set_yticklabels(ytick_labels)
-    ax.set_xlabel("Unix Timestamp")
-    ax.set_title("Ray Actor Function Timeline")
-    ax.grid(True, axis="x", linestyle="--", alpha=0.6)
-
-    # Unique legend
-    handles = [mpatches.Patch(color=color_lookup[a], label=a) for a in color_lookup]
-    ax.legend(handles=handles, title="Actors", loc="upper left", bbox_to_anchor=(1, 1))
-
-    plt.tight_layout()
-    plt.show()
-
-
-# ==== Usage ====
-# Replace with your actual log file path
-import glob
-
-files = glob.glob("*.prof")
-logs = {}
-for file in files:
-    print(f"Processing file: {file}")
-    logs_by_actor = parse_logs(log_file_path)
-    logs.update(logs_by_actor)
-actor_intervals = build_intervals(logs)
-plot_actor_timelines(actor_intervals)
applications/ColossalChat/profiling.sh (new file, executable, 28 lines)

@@ -0,0 +1,28 @@
+# profile under different setups
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+# zero2 ibs32 tbs32
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_2_zero_2_GRPO_profile.png
+
+# # zero2 ibs64 tbs32
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_zero_2_GRPO_profile.png
+
+# # zero2 ibs96 tbs32
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 24 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 2 2>&1| tee ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_96_tbs_32_tmbs_2_zero_2_GRPO_profile.png
+
+# 4K
+# MAX_NEW_TOKENS=$((4096-512))
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mpt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
+
+# # 32K
+# MAX_NEW_TOKENS=$((32768-512))
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 4 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png
+
+# 16K
+MAX_NEW_TOKENS=$((16384-512))
+python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.txt
+python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
applications/ColossalChat/rl_example.py

@@ -93,7 +93,7 @@ if __name__ == "__main__":
     parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
 
     # GRPO parameters
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "GRPO-DUMMY-TEST"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
     parser.add_argument(
@@ -172,7 +172,7 @@ if __name__ == "__main__":
         namespace="ray-example",
         runtime_env={
             "env_vars": {
-                # "RAY_DEBUG_POST_MORTEM": "1"  # enable post-mortem debugging with ray
+                # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                 "TOKENIZERS_PARALLELISM": "false"
             },
         },
@@ -243,7 +243,7 @@ if __name__ == "__main__":
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
 
-    if args.algo == "GRPO":
+    if "GRPO" in args.algo:
         # Default Settings
         grpo_config = {
             "lr": args.learning_rate,
@@ -318,6 +318,8 @@ if __name__ == "__main__":
         plugin_config={
             "tp_size": args.tensor_parallel_size,
             "pp_size": args.pipeline_parallel_size,
+            # "num_layers_per_stage": [12,12,2,2],
+            "num_layers_per_stage": [20, 8],
             "microbatch_size": max(
                 1, args.train_microbatch_size // args.pipeline_parallel_size
             ),  # microbatch size should be set to train_microbatch_size // pp_size
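The hard-coded "num_layers_per_stage": [20, 8] splits the model unevenly across the two pipeline stages, presumably to offset the extra memory the last stage spends on the lm_head output and the log-softmax over the 152064-token vocabulary. A hypothetical plugin_config mirroring the hunk above (values assume a 28-layer model such as Qwen2.5-7B with pp_size=2):

    plugin_config = {
        "tp_size": 2,
        "pp_size": 2,
        # Uneven split: stage 0 holds 20 transformer layers, stage 1 holds 8,
        # leaving headroom on the last stage for the large output projection.
        "num_layers_per_stage": [20, 8],
        "microbatch_size": 2,
    }

Note that this pins the config to one model/parallelism combination; the commented-out [12,12,2,2] variant suggests the split was being tuned per run.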
applications/ColossalChat/test_profiling.sh (new file, executable, 4 lines)

@@ -0,0 +1,4 @@
+MAX_NEW_TOKENS=$((4096-512))
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO -ibs 32 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.txt
+python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_4096_GRPO_profile.png