mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-22 03:03:23 +00:00

[feat] Support one-behind to reduce bubble time. Add profiling code (#6353)

* support n_behind, add profiling
* fix bugs
* fix visualization
* fix behind
* fix loop issue
* add profiling
* fix update
* update assert
* remove assert

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

parent 8880b83791
commit b1f646c7e7
@@ -6,6 +6,7 @@ import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.profiling_utils import CustomProfiler
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 
@@ -36,6 +37,8 @@ class BaseConsumer:
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -49,6 +52,7 @@
         self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
+        self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
 
@@ -57,6 +61,7 @@
 
         self.device = get_current_device()
         self.lr_scheduler = None
+        self.n_behind = n_behind
 
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,6 +99,7 @@
 
         self.buffer = []
         self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -101,6 +107,41 @@
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         raise NotImplementedError
 
+    def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
+        """
+        Prepare a mini-batch from the effective group to raw group mapping.
+        This method is used to create a mini-batch for training.
+        """
+        batches = [
+            self.buffer[effective_group_to_raw_group_mapping[i]]
+            for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
+        ]
+        # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+        # each mini-batch use the first self.dp_size * minibatch_size effective samples
+        raw_mini_batches = self.buffer[
+            : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+        ]  # include the last effective sample
+        raw_mini_batches_metric_dict = {
+            "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+            "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+            "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+            "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+        }
+        batch = bind_batch([t[0] for t in batches])
+        batch = post_recv(batch)
+        return batch, raw_mini_batches_metric_dict
+
+    def calculate_effective_group_to_raw_group_mapping(self, step):
+        effective_group_to_raw_group_mapping = {}
+        for buffer_idx in range(len(self.buffer)):
+            if self.buffer[buffer_idx][0] is not None:
+                if self.n_behind == 0:
+                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+                else:
+                    if self.buffer[buffer_idx][-1] <= step - self.n_behind:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+        return effective_group_to_raw_group_mapping
+
     def loop(self) -> None:
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
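The gating in calculate_effective_group_to_raw_group_mapping is the heart of the one-behind scheme: with n_behind > 0, a buffered group only becomes eligible for training once the rollout step stored in its last slot is at least n_behind steps older than the current training step, so fresh rollouts keep accumulating while the trainer consumes older data. A minimal standalone sketch of that predicate (toy buffer entries, not the real experience objects):

# Toy illustration of the n_behind eligibility predicate, assuming each
# buffer entry stores its producer rollout step in the last slot, as the
# diff above does via `self.buffer[buffer_idx][-1]`.
def eligible_indices(buffer, step, n_behind):
    mapping = {}
    for idx, entry in enumerate(buffer):
        if entry[0] is None:  # filtered-out group, never eligible
            continue
        if n_behind == 0 or entry[-1] <= step - n_behind:
            mapping[len(mapping)] = idx
    return mapping

# Entries: (payload, rollout_step); payload None means filtered out.
buffer = [("g0", 0), ("g1", 0), (None, 1), ("g2", 1), ("g3", 2)]
assert eligible_indices(buffer, step=2, n_behind=1) == {0: 0, 1: 1, 2: 3}
assert eligible_indices(buffer, step=2, n_behind=0) == {0: 0, 1: 1, 2: 3, 3: 4}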
@@ -112,14 +153,53 @@ class BaseConsumer:
                 disable=self.rank != 0,
             ) as pbar:
                 for step in pbar:
+                    torch.cuda.reset_peak_memory_stats()
                     i = 0
+
+                    self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
                     for _ in range(self.num_recv_per_update):
+                        if self.n_behind > 0:
+                            # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                                step=step
+                            )
+                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                                self.profiler.log(
+                                    f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                                )
+                                batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                self.profiler.enter("step")
+                                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                                self.profiler.exit("step")
+                                self.buffer = self.buffer[
+                                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                                ]
+                                # recalculate the effective group to raw group mapping
+                                effective_group_to_raw_group_mapping_size_before = len(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                effective_group_to_raw_group_mapping = (
+                                    self.calculate_effective_group_to_raw_group_mapping(step=step)
+                                )
+                                assert (
+                                    len(effective_group_to_raw_group_mapping)
+                                    == effective_group_to_raw_group_mapping_size_before
+                                    - self.dp_size * self.minibatch_size
+                                )
+                                if loss is not None:
+                                    pbar.set_postfix({"loss": loss})
+                                i += 1
+
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                            self.profiler.enter(f"recv_broadcast_data_P{r}")
                             raw_batch = ray_broadcast_tensor_dict(
                                 None, src=0, device=self.device, group_name=f"sync_data_{r}"
                             )
+                            self.profiler.exit(f"recv_broadcast_data_P{r}")
                             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                             # we need to calculate the metrics before filtering here for logging
                             # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -154,6 +234,7 @@ class BaseConsumer:
                                         format_acc[group_idx],
                                         ans_acc[group_idx],
                                         response_len[group_idx],
+                                        step,
                                     ]
                                 )
                         if effective_group_mask is not None:
@@ -161,56 +242,44 @@ class BaseConsumer:
                                 f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                             )
                             # mapping the effective group to the raw group for indexing
-                            effective_group_to_raw_group_mapping = {}
-                            for buffer_idx in range(len(self.buffer)):
-                                if self.buffer[buffer_idx][0] is not None:
-                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                        buffer_idx
-                                    )
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                                step=step
+                            )
                             print(
                                 f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
                             )
-                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
-                            # on each dp_rank, we use minibatch_size effective samples to form a batch
-                            batches = [
-                                self.buffer[effective_group_to_raw_group_mapping[i]]
-                                for i in range(
-                                    self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
-                                )
-                            ]
-                            # every dp_rank will receive a complete mini-batch, no need to sync within step() later
-                            # each mini-batch use the first self.dp_size * minibatch_size effective samples
-                            raw_mini_batches = self.buffer[
-                                : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
-                            ]  # include the last effective sample
-                            raw_mini_batches_metric_dict = {
-                                "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
-                                "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
-                                "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
-                                "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
-                            }
-                            batch = bind_batch([t[0] for t in batches])
-                            batch = post_recv(batch)
-                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
-                            self.buffer = self.buffer[
-                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
-                            ]
-                            # recalculate the effective group to raw group mapping
-                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                            effective_group_to_raw_group_mapping = {}
-                            for buffer_idx in range(len(self.buffer)):
-                                if self.buffer[buffer_idx][0] is not None:
-                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                        buffer_idx
-                                    )
-                            assert (
-                                len(effective_group_to_raw_group_mapping)
-                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
-                            )
-                            if loss is not None:
-                                pbar.set_postfix({"loss": loss})
-                            i += 1
+
+                        if self.n_behind == 0:
+                            # If n_behind is 0, we start training after receiving data from producers.
+                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                                self.profiler.log(
+                                    f"Collect {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                                )
+                                batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                self.profiler.enter("step")
+                                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                                self.profiler.exit("step")
+                                self.buffer = self.buffer[
+                                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                                ]
+                                # recalculate the effective group to raw group mapping
+                                effective_group_to_raw_group_mapping_size_before = len(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                effective_group_to_raw_group_mapping = (
+                                    self.calculate_effective_group_to_raw_group_mapping(step=step)
+                                )
+                                assert (
+                                    len(effective_group_to_raw_group_mapping)
+                                    == effective_group_to_raw_group_mapping_size_before
+                                    - self.dp_size * self.minibatch_size
+                                )
+                                if loss is not None:
+                                    pbar.set_postfix({"loss": loss})
+                                i += 1
 
                 if self.lr_scheduler is not None:
                     self.lr_scheduler.step()
                 if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
@@ -221,13 +290,16 @@ class BaseConsumer:
                     if self.rank == 0:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
-                if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
+                    episode != 0 or step >= self.n_behind
+                ):
                     if self.pp_size > 1:
                         print(
                             f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                         )
                     else:
                         print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    self.profiler.enter("sync_model")
                     torch.cuda.empty_cache()
                     state_dict = self.state_dict()
                     if self.pp_size > 1:
@@ -245,6 +317,13 @@ class BaseConsumer:
                             )
                     del state_dict
                     torch.cuda.empty_cache()
+                self.profiler.exit("sync_model")
+                self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+                self.profiler.exit(f"rollout_episode_{episode}_step_{step}")
+
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()
 
 
 @ray.remote
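For reference when reading the hunks above: each self.buffer entry appears to be a list of the form [experience, reward, format_acc, ans_acc, response_len, step] after this commit. The appended step is what the n_behind gate reads via entry[-1], and slots 1 through 4 feed the raw_train_mini_batch_* metrics. A hypothetical illustration (field names here are descriptive guesses, not identifiers from the repo):

# Hypothetical sketch of the buffer entry layout implied by the diff above.
from collections import namedtuple

BufferEntry = namedtuple(
    "BufferEntry", ["experience", "reward", "format_acc", "ans_acc", "response_len", "step"]
)

entry = BufferEntry("exp", 0.5, 1.0, 0.0, 256, 3)
assert entry[-1] == entry.step  # the n_behind gate compares this against the current step
assert (entry[1], entry[2], entry[3], entry[4]) == (0.5, 1.0, 0.0, 256)  # metric dict slots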
@@ -38,6 +38,8 @@ class GRPOConsumer(BaseConsumer):
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -63,6 +65,8 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)

@@ -57,6 +57,8 @@ def launch_distributed(
     eval_generation_config: Optional[Dict[str, Any]] = None,
     log_rollout_interval: int = 20,
     rollout_save_dir: str = "./rollout",
+    enable_profiling: bool = False,
+    n_behind: int = 0,
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -132,6 +134,8 @@ def launch_distributed(
            wandb_group_name=wandb_group_name,
            log_rollout_interval=log_rollout_interval,
            rollout_log_file=rollout_log_file,
+           enable_profiling=enable_profiling,
+           n_behind=n_behind,
        )
        producer_procs.append(producer)
    ray.get([p.setup.remote() for p in producer_procs])
@@ -171,6 +175,8 @@ def launch_distributed(
            project_name=project_name,
            run_name=run_name,
            wandb_group_name=wandb_group_name,
+           enable_profiling=enable_profiling,
+           n_behind=n_behind,
        )
        consumer_procs.append(consumer)
    ray.get([p.setup.remote() for p in consumer_procs])

@@ -9,6 +9,7 @@ import torch
 import tqdm
 import wandb
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
+from coati.distributed.profiling_utils import CustomProfiler
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from ray.util.collective import allreduce
@@ -52,6 +53,8 @@ class BaseProducer:
         wandb_group_name: str = None,
         log_rollout_interval: int = 20,
         rollout_log_file: str = "./rollout_log.jsonl",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -62,6 +65,7 @@ class BaseProducer:
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
         self.latest_eval_step = -1
+        self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
 
         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -75,6 +79,7 @@ class BaseProducer:
         self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
         self.grpo_config = grpo_config
+        self.n_behind = n_behind
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
@@ -268,11 +273,14 @@ class BaseProducer:
                        self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                    self.eval_mode = False
                    self.latest_eval_step = self.consumer_global_step
+                self.profiler.enter("rollout")
                outputs = self.rollout(**batch)
+                self.profiler.exit("rollout")
                outputs["temperature"] = torch.tensor(
                    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                ).to(outputs["input_ids"].device)
                bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
+                self.profiler.enter("calculate_reward")
                if self.grpo_config["reward_fn_type"] == "code":
                    test_cases = []
                    for prompt_id in range(bs):
@@ -310,14 +318,19 @@ class BaseProducer:
                    outputs.pop("gt_answer")
                if "test_cases" in outputs:
                    outputs.pop("test_cases")
+                self.profiler.exit("calculate_reward")
 
                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                outputs = pre_send(outputs)
+                self.profiler.enter("send_broadcast_data")
                ray_broadcast_tensor_dict(
                    outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                )
-                if (i + 1) % self.num_microbatches == 0 and (
-                    episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
+                self.profiler.exit("send_broadcast_data")
+                if (
+                    (i + 1) % self.num_microbatches == 0
+                    and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
+                    and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
                ):
                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                        "enable_sleep_mode", False
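The new third clause in the producer's sync condition mirrors the consumer-side `step >= self.n_behind` guard: during episode 0 the producer skips model syncs until it has sent n_behind full batches, which is what lets it run ahead and fill the data buffer. A quick sanity check of the arithmetic with made-up values (and ignoring the last-iteration clause for brevity):

# Sanity check of the producer-side sync gate above, with illustrative
# values: n_behind = 1, num_microbatches = 4. In episode 0 the producer
# skips the sync after microbatch 4 (i == 3) and first syncs after
# microbatch 8 (i == 7), i.e. once it is one full batch ahead.
n_behind, num_microbatches = 1, 4

def would_sync(episode, i):
    return (i + 1) % num_microbatches == 0 and (episode != 0 or (i + 1) > n_behind * num_microbatches)

assert not would_sync(episode=0, i=3)  # first batch of episode 0: keep rolling out
assert would_sync(episode=0, i=7)      # one batch ahead now: sync with the trainer
assert would_sync(episode=1, i=3)      # later episodes sync on every batch boundary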
@@ -325,7 +338,7 @@ class BaseProducer:
                        self.model.llm.sleep()  # revict KV_cache to avoid OOM
                    # don't sync model for last iteration
                    torch.cuda.empty_cache()
-
+                    self.profiler.enter("sync_model")
                    if self.consumer_pp_size > 1:
                        for pp_idx in range(self.consumer_pp_size):
                            print(
@@ -347,6 +360,7 @@ class BaseProducer:
                        if "consumer_global_step" in state_dict:
                            self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                        self.load_state_dict(state_dict)
+                    self.profiler.exit("sync_model")
                    del state_dict
                    torch.cuda.empty_cache()
                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
@@ -364,6 +378,9 @@ class BaseProducer:
                        "temperature"
                    ] + ratio * 0.9
 
+    def __del__(self):
+        self.profiler.close()
+
 
 @ray.remote
 class SimpleProducer(BaseProducer):
@@ -392,6 +409,8 @@ class SimpleProducer(BaseProducer):
         wandb_group_name: str = None,
         log_rollout_interval: int = 20,
         rollout_log_file: str = "./rollout_log.jsonl",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         super().__init__(
             producer_idx,
@@ -415,6 +434,8 @@ class SimpleProducer(BaseProducer):
             wandb_group_name=wandb_group_name,
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)

37 applications/ColossalChat/coati/distributed/profiling_utils.py Normal file
@@ -0,0 +1,37 @@
+import os
+import time
+
+
+class CustomProfiler:
+    def __init__(self, name, disabled=True):
+        self.disabled = disabled
+        if not disabled:
+            self.name = name
+            self.pid = os.getpid()
+            self.file = open(f"{name}.prof", "w")
+
+    def _log(self, message):
+        if self.disabled:
+            return
+        current_time = time.time()
+        self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
+    def log(self, message):
+        if self.disabled:
+            return
+        current_time = time.time()
+        self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
+    def enter(self, event_name):
+        self._log(f"Enter {event_name}")
+
+    def exit(self, event_name):
+        self._log(f"Exit {event_name}")
+
+    def close(self):
+        if self.disabled:
+            return
+        self.file.close()
+        print(f"Profiler data written to {self.name}.prof")
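CustomProfiler is deliberately simple: every enter/exit call appends one timestamped line to <name>.prof, and visualization.py later pairs Enter/Exit lines into intervals, while log() lines get a "[Log]:" prefix and are skipped by the plot parser. A minimal usage sketch with hypothetical values:

# Minimal usage sketch (hypothetical event names and actor name). Each
# enter/exit pair becomes one "<timestamp> <name> <pid>:: Enter|Exit <event>"
# line in C0.prof in the current working directory.
import time
from coati.distributed.profiling_utils import CustomProfiler

profiler = CustomProfiler("C0", disabled=False)  # writes to ./C0.prof
profiler.enter("step")
time.sleep(0.1)  # stand-in for a training step
profiler.exit("step")
profiler.log("Peak memory usage: 123.45 MB")  # free-form, ignored by the timeline plot
profiler.close()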
13 applications/ColossalChat/profiling.sh Executable file
@@ -0,0 +1,13 @@
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+
+# 8K context length
+# rm -rf *.prof
+# MAX_NEW_TOKENS=$((8192-512))
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
+
+# 4K context length
+rm -rf *.prof
+MAX_NEW_TOKENS=$((4096-512))
+python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 4 -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.txt
+python visualization.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.png
@@ -67,6 +67,27 @@ if __name__ == "__main__":
         default=2,
         help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-pp",
+        "--pipeline-parallel-size",
+        type=int,
+        default=1,
+        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-zero",
+        "--zero-stage",
+        type=int,
+        default=0,
+        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
     parser.add_argument(
         "--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
     )
@@ -97,6 +118,13 @@ if __name__ == "__main__":
     parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
     parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
     parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
+    parser.add_argument(
+        "-ptp",
+        "--producer-tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+    )
 
     # GRPO parameters
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
@@ -117,6 +145,13 @@ if __name__ == "__main__":
         default=100,
         help="Interval for evaluation. Evaluate every ei training steps.",
     )
+    parser.add_argument(
+        "-nb",
+        "--n-behind",
+        type=int,
+        default=0,
+        help="Number of producer batches to rollout to fill the data buffer before trainer starts to decrease bubble time",
+    )
 
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -128,32 +163,7 @@ if __name__ == "__main__":
         "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
     )
     parser.add_argument(
-        "-tp",
-        "--tensor-parallel-size",
-        type=int,
-        default=1,
-        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-pp",
-        "--pipeline-parallel-size",
-        type=int,
-        default=1,
-        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-zero",
-        "--zero-stage",
-        type=int,
-        default=0,
-        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-ptp",
-        "--producer-tensor-parallel-size",
-        type=int,
-        default=1,
-        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+        "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
     )
     args = parser.parse_args()
 
@@ -236,14 +246,25 @@ if __name__ == "__main__":
                tensor_parallel_size=args.producer_tensor_parallel_size,
            )
        )
-        generate_config.update(
-            dict(
-                max_tokens=args.max_new_tokens,  # max new tokens
-                ignore_eos=True if args.reward_type == "think_answer_tags" else False,
-                include_stop_str_in_output=True,
-                stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
-            )
-        )
+        if args.enable_profiling:
+            # If profiling is enabled, we force model to generate to max_new_tokens
+            generate_config.update(
+                dict(
+                    max_tokens=args.max_new_tokens,  # max new tokens
+                    ignore_eos=True,
+                    include_stop_str_in_output=True,
+                    stop=None,
+                )
+            )
+        else:
+            generate_config.update(
+                dict(
+                    max_tokens=args.max_new_tokens,  # max new tokens
+                    ignore_eos=True if args.reward_type == "think_answer_tags" else False,
+                    include_stop_str_in_output=True,
+                    stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
+                )
+            )
        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
@@ -353,4 +374,6 @@ if __name__ == "__main__":
        eval_generation_config=eval_generation_config,
        log_rollout_interval=20,
        rollout_save_dir=args.rollout_save_dir,
+        enable_profiling=args.enable_profiling,
+        n_behind=args.n_behind,
    )
100 applications/ColossalChat/visualization.py Normal file
@@ -0,0 +1,100 @@
+# Re-import required libraries due to kernel reset
+import argparse
+from collections import defaultdict
+
+import matplotlib.cm as cm
+import matplotlib.pyplot as plt
+
+# Argument parser for command line arguments
+parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
+parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
+args = parser.parse_args()
+
+# Raw log lines
+log_lines = []
+
+import glob
+
+files = glob.glob("*.prof")
+for file in files:
+    with open(file, "r") as f:
+        log_lines += f.readlines()
+
+# Parse logs and collect function intervals grouped by actor
+actors = defaultdict(lambda: defaultdict(list))
+current_entries = {}
+
+# First, collect all timestamps to find the minimum
+all_timestamps = []
+parsed_lines = []
+
+for line in log_lines:
+    if line.startswith("[Log]"):
+        continue
+    parts = line.split()
+    timestamp = float(parts[0])
+    actor = parts[1]
+    action = parts[3]
+    func_name = parts[4]
+    parsed_lines.append((timestamp, actor, action, func_name))
+    all_timestamps.append(timestamp)
+
+if not all_timestamps:
+    raise ValueError("No valid log entries found.")
+
+min_timestamp = min(all_timestamps)
+
+for timestamp, actor, action, func_name in parsed_lines:
+    rel_timestamp = timestamp - min_timestamp
+    key = (actor, func_name)
+    if action == "Enter":
+        current_entries[key] = rel_timestamp
+    elif action == "Exit":
+        start_time = current_entries.pop(key, None)
+        if start_time is not None:
+            actors[actor][func_name].append((start_time, rel_timestamp))
+
+# Plotting setup
+fig, ax = plt.subplots(figsize=(12, 6))
+colors = cm.get_cmap("tab10", len(actors))
+
+actor_offsets = {}
+base_offset = 0
+function_spacing = 0.9
+
+yticks = []
+yticklabels = []
+
+for idx, (actor, func_dict) in enumerate(actors.items()):
+    actor_offsets[actor] = base_offset
+    color = colors(idx)
+    for j, (func, intervals) in enumerate(func_dict.items()):
+        print(actor, func, intervals)
+        y_val = base_offset + j * function_spacing
+        yticks.append(y_val)
+        yticklabels.append(f"{actor}:{func}")
+        for start, end in intervals:
+            if end - start < 1:
+                end = start + 1  # Ensure all plotted intervals are at least 1 unit long so they stay visible
+            ax.plot(
+                [start, end],
+                [y_val, y_val],
+                color=color,
+                linewidth=2,
+                label=actor if j == 0 else "",
+            )
+    base_offset += len(func_dict) * function_spacing + 1
+
+# Formatting
+ax.set_yticks(yticks)
+ax.set_yticklabels(yticklabels)
+ax.set_xlabel("Time")
+ax.set_title("Timeline per Actor")
+# Remove duplicate labels in legend
+handles, labels = ax.get_legend_handles_labels()
+unique = dict(zip(labels, handles))
+ax.legend(unique.values(), unique.keys())
+plt.tight_layout()
+plt.grid(True)
+plt.savefig(args.visualization, dpi=600)  # Increase dpi for higher resolution
+print(f"Plot saved as {args.visualization}")
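The parser above depends on the exact format CustomProfiler._log writes, f"{current_time} {self.name} {self.pid}:: {message}": whitespace-splitting such a line puts the timestamp at index 0, the actor name at index 1, the "pid::" token at index 2 (unused), and the action and event name at indexes 3 and 4. A worked example on a fabricated line:

# Worked example of the parser's field indexing on a fabricated .prof line
# (timestamp and pid are made up).
line = "1750000000.123 C0 12345:: Enter step"
parts = line.split()
assert float(parts[0]) == 1750000000.123  # timestamp
assert parts[1] == "C0"                   # actor (consumer rank 0)
assert parts[3] == "Enter"                # action
assert parts[4] == "step"                 # event name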